├── LICENSE ├── README.md ├── accel_conf.json ├── assets ├── celeb_manip_2.gif ├── clevrer_manip_1.gif └── tarffic_manip.gif ├── checkpoints ├── PLACE PRE-TRAINED CHECKPOINTS HERE.txt └── sample_images │ ├── celeb │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ └── 8.jpg │ ├── clevrer │ ├── 1.png │ └── 2.png │ └── traffic │ └── 1.png ├── datasets ├── celeba_dataset.py ├── clevrer_ds.py ├── shapes_ds.py └── traffic_ds.py ├── dlp_tutorial.ipynb ├── environment17.yml ├── environment19.yml ├── eval └── eval_model.py ├── eval_celeb.py ├── interactive_demo_dlp.py ├── models.py ├── modules └── modules.py ├── requirements17.txt ├── requirements19.txt ├── train_dlp.py ├── train_dlp_accelerate.py └── utils ├── loss_functions.py ├── tps.py └── util_func.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deep-latent-particles-pytorch 2 | 3 | [ICML 2022] Official PyTorch implementation of the paper "Unsupervised Image Representation Learning with Deep Latent Particles" 4 | 5 |

DLPv2 and DDLP (DLP for video generation) have been released: DDLP: Unsupervised Object-Centric Video Prediction with Deep Dynamic Latent Particles

[ICML 2022] Unsupervised Image Representation Learning with Deep Latent Particles

Tal Daniel • Aviv Tamar

Official repository of the paper

ICML 2022

Project Website • Video

Open In Colab

[header badges and animated GIF previews (CelebA, CLEVRER, Traffic manipulation) from `/assets`]
36 | 37 | # Deep Latent Particles 38 | 39 | > **Unsupervised Image Representation Learning with Deep Latent Particles**
40 | > Tal Daniel, Aviv Tamar
41 | >
42 | > **Abstract:** *We propose a new representation of visual data that disentangles object position from appearance.
43 | > Our method, termed Deep Latent Particles (DLP), decomposes the visual input into low-dimensional latent ``particles'',
44 | > where each particle is described by its spatial location and features of its surrounding region.
45 | > To drive learning of such representations, we follow a VAE-based approach and introduce a prior for particle positions
46 | > based on a spatial-softmax architecture, and a modification of the evidence lower bound loss
47 | > inspired by the Chamfer distance between particles. We demonstrate that our DLP representations are useful for
48 | > downstream tasks such as unsupervised keypoint (KP) detection, image manipulation, and video prediction for scenes
49 | > composed of multiple dynamic objects. In addition, we show that our probabilistic interpretation of the problem
50 | > naturally provides uncertainty estimates for particle locations, which can be used for model selection,
51 | > among other tasks.*
52 | 
53 | ## Citation
54 | 
55 | Daniel, Tal, and Aviv Tamar. "Unsupervised Image Representation Learning with Deep Latent Particles." Proceedings of the 39th International Conference on Machine Learning (ICML) 2022.
56 | >
57 | @InProceedings{pmlr-v162-daniel22a,
58 | title = {Unsupervised Image Representation Learning with Deep Latent Particles},
59 | author = {Daniel, Tal and Tamar, Aviv},
60 | booktitle = {Proceedings of the 39th International Conference on Machine Learning},
61 | pages = {4644--4665},
62 | year = {2022},
63 | volume = {162},
64 | series = {Proceedings of Machine Learning Research},
65 | month = {17--23 Jul},
66 | publisher = {PMLR}
67 | }
68 | 
69 | 
70 | 
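For intuition about the Chamfer-distance-inspired KL term mentioned in the abstract, here is a minimal sketch (ours, not the repository's exact code — see `ChamferLossKL` in `utils/loss_functions.py` for the actual implementation): each posterior particle is penalized only against its closest prior particle in KL terms.

```python
import torch

def diag_gauss_kl(mu_q, logvar_q, mu_p, logvar_p):
    # closed-form KL(q || p) between diagonal Gaussians, summed over the coordinate dim
    return 0.5 * (logvar_p - logvar_q
                  + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp() - 1).sum(-1)

def chamfer_kl_sketch(mu_post, logvar_post, mu_prior, logvar_prior):
    # mu_post: [B, N_post, 2], mu_prior: [B, N_prior, 2] -- particle (keypoint) positions
    pairwise = diag_gauss_kl(mu_post.unsqueeze(2), logvar_post.unsqueeze(2),
                             mu_prior.unsqueeze(1), logvar_prior.unsqueeze(1))  # [B, N_post, N_prior]
    # Chamfer-style matching: each posterior particle is matched to its nearest
    # (in KL) prior particle, then the matched KLs are summed over particles
    return pairwise.min(dim=-1).values.sum(-1)  # [B]
```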

Paper on ArXiv: 2205.15821

71 | 
72 | - [deep-latent-particles-pytorch](#deep-latent-particles-pytorch)
73 | - [Deep Latent Particles](#deep-latent-particles)
74 |   * [Citation](#citation)
75 |   * [Prerequisites](#prerequisites)
76 |   * [Pretrained Models](#pretrained-models)
77 |   * [Interactive Demo](#interactive-demo)
78 |   * [Datasets](#datasets)
79 |   * [Training](#training)
80 |   * [Evaluation of Unsupervised Keypoint Regression on CelebA](#evaluation-of-unsupervised-keypoint-regression-on-celeba)
81 |   * [Recommended Hyper-parameters](#recommended-hyper-parameters)
82 |   * [Repository Organization](#repository-organization)
83 |   * [Credits](#credits)
84 | 
85 | ## Prerequisites
86 | 
87 | * For your convenience, we provide an `environment.yml` file which installs the required packages in a `conda`
88 | environment named `torch17`/`torch19`. Alternatively, you can use `pip` to install `requirements.txt`.
89 | * Use the terminal or an Anaconda Prompt and run the following command: `conda env create -f environment17.yml` (or `environment19.yml`).
90 | * For PyTorch 1.7 + CUDA 10.2: `environment17.yml`, `requirements17.txt`
91 | * For PyTorch 1.9 + CUDA 11.1: `environment19.yml`, `requirements19.txt`
92 | 
93 | | Library           | Version          |
94 | |-------------------|------------------|
95 | | `Python`          | `3.7 (Anaconda)` |
96 | | `torch`           | > = `1.7.1`      |
97 | | `torch_geometric` | > = `1.7.1`      |
98 | | `torchvision`     | > = `0.4`        |
99 | | `matplotlib`      | > = `2.2.2`      |
100 | | `numpy`           | > = `1.17`       |
101 | | `py-opencv`       | > = `3.4.2`      |
102 | | `tqdm`            | > = `4.36.1`     |
103 | | `scipy`           | > = `1.3.1`      |
104 | | `scikit-image`    | > = `0.18.1`     |
105 | | `accelerate`      | > = `0.3.0`      |
106 | 
107 | ## Pretrained Models
108 | 
109 | * We provide pre-trained checkpoints for the 3 datasets we used in the paper.
110 | * All model checkpoints should be placed inside the `/checkpoints` directory.
111 | * The interactive demo will use these checkpoints.
112 | 
113 | | Dataset | Filename | Link |
114 | |-------------------|------------------------------------|--------------------------------------------------------------------------------------|
115 | | CelebA (128x128) | `dlp_celeba_gauss_pointnetpp_feat.pth` | [MEGA.co.nz](https://mega.nz/file/ZAkiDSIQ#ndtlzAPwG42TEGZmuADAzR1Wo0AZx2k__qyWUfcOkQc)|
116 | | Traffic (128x128) | `dlp_traffic_gauss_pointnetpp.pth` | [MEGA.co.nz](https://mega.nz/file/9cN0HAYQ#K9AvKsWemA5hvk9WleleautIdQu2Euezf8UOI7aKUtE)|
117 | | CLEVRER (128x128) | `dlp_clevrer_gauss_pointnetpp.pth` | [MEGA.co.nz](https://mega.nz/file/VINjRZCL#rJ25UPXlYJUxWPaP7gDEbxjVZaayey5JB6x9P5Z__CU)|
118 | 
119 | ## Interactive Demo
120 | 
121 | * We designed a simple `matplotlib` interactive GUI to plot and control the particles.
122 | * The demo is **standalone and does not require downloading the original datasets**.
123 | * We provide sample images inside `/checkpoints/sample_images/` which will be used.
124 | 
125 | To run the demo (after downloading the checkpoints), see the available arguments with `python interactive_demo_dlp.py --help`:
126 | 
127 | * `-d`: dataset to use: [`celeba`, `traffic`, `clevrer`]
128 | * `-i`: index of the image to use inside `/checkpoints/sample_images/`
129 | 
130 | Examples:
131 | 
132 | * `python interactive_demo_dlp.py -d celeba -i 2`
133 | * `python interactive_demo_dlp.py -d traffic -i 0`
134 | * `python interactive_demo_dlp.py -d clevrer -i 0`
135 | 
136 | You can modify `interactive_demo_dlp.py` to add additional datasets.
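Outside the GUI, you can also load a pre-trained checkpoint directly and inspect the particles it produces. Below is a minimal sketch for the CelebA checkpoint; the constructor arguments are copied from `eval_celeb.py` and the forward-pass output keys follow `eval/eval_model.py` (the resize step is an assumption — adapt the pre-processing to your data):

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from models import KeyPointVAE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# CelebA hyper-parameters (copied from eval_celeb.py)
model = KeyPointVAE(cdim=3, enc_channels=[32, 64, 128, 256], prior_channels=(16, 32, 64),
                    image_size=128, n_kp=1, learned_feature_dim=10, use_logsoftmax=False,
                    pad_mode='replicate', sigma=0.1, dropout=0.0, dec_bone="gauss_pointnetpp_feat",
                    patch_size=8, n_kp_enc=30, n_kp_prior=50, kp_range=(-1, 1),
                    kp_activation="tanh", mask_threshold=0.2, use_object_enc=True,
                    use_object_dec=False, anchor_s=0.125, learn_order=False).to(device)
model.load_state_dict(torch.load('./checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth',
                                 map_location=device))
model.eval()

# load one of the bundled sample images and run a forward pass
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
img = Image.open('./checkpoints/sample_images/celeb/1.jpg').convert('RGB')
x = transform(img).unsqueeze(0).to(device)  # [1, 3, 128, 128]
with torch.no_grad():
    out = model(x, x_prior=x)
mu, logvar, rec = out['mu'], out['logvar'], out['rec']  # particle means, log-variances, reconstruction
print(mu.shape, logvar.shape, rec.shape)
```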
137 | 
138 | ## Datasets
139 | 
140 | * **CelebA**: we follow [DVE](https://github.com/jamt9000/DVE):
141 |     * [Download](https://github.com/jamt9000/DVE/blob/master/misc/datasets/celeba/README.md) the dataset from
142 |       this [link](http://www.robots.ox.ac.uk/~vgg/research/DVE/data/datasets/celeba.tar.gz).
143 |     * The pre-processing is described in `datasets/celeba_dataset.py`.
144 | * **CLEVRER**: download the training and validation videos from [here](http://clevrer.csail.mit.edu/):
145 |     * [Training Videos](http://data.csail.mit.edu/clevrer/videos/train/video_train.zip)
146 |       , [Validation Videos](http://data.csail.mit.edu/clevrer/videos/validation/video_validation.zip)
147 |     * Follow the pre-processing
148 |       in `datasets/clevrer_ds.py` (`prepare_numpy_file(path_to_img, image_size=128, frameskip=3, start_frame=26)`)
149 | * **Traffic**: this is a self-collected dataset, please contact us if you wish to use it.
150 | * **Shapes**: this dataset is generated automatically in each run for simplicity, see `generate_shape_dataset_torch()`
151 |   in `datasets/shapes_ds.py`.
152 | 
153 | ## Training
154 | 
155 | You can train the model on single-GPU machines and multi-GPU machines. For multi-GPU training, we use
156 | [HuggingFace Accelerate](https://huggingface.co/docs/accelerate/index): `pip install accelerate`.
157 | 
158 | 1. Set visible GPUs under: `os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"` (`NUM_GPUS=4`)
159 | 2. Set "num_processes": NUM_GPUS in `accel_conf.json` (e.g. `"num_processes":4`
160 | if `os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"`).
161 | 
162 | 
163 | * Single-GPU machines: `python train_dlp.py --help`
164 | * Multi-GPU machines: `accelerate launch --config_file ./accel_conf.json train_dlp_accelerate.py --help`
165 | 
166 | You should run the `train_dlp.py` or `train_dlp_accelerate.py` files with the following arguments:
167 | 
168 | | Argument | Description | Legal Values |
169 | |-------------------------|----------------------------------------------------------------------------------------------------------|----------------------------------------------|
170 | | -h, --help | shows arguments description | |
171 | | -d, --dataset | dataset to train on | str: 'celeba', 'traffic', 'clevrer', 'shapes' |
172 | | -o, --override | if specified, the code will override the default hyper-parameters with the ones specified with `argparse` (command line) | bool: default=False |
173 | | -l, --lr | learning rate | float: default=2e-4 |
174 | | -b, --batch_size | batch size | int: default=32 |
175 | | -n, --num_epochs | total number of epochs to run | int: default=100 |
176 | | -e, --eval_freq | evaluation epoch frequency | int: default=2 |
177 | | -s, --sigma | the prior std of the keypoints | float: default=0.1 |
178 | | -p, --prefix | string prefix for logging | str: default="" |
179 | | -r, --beta_rec | beta coefficient for the reconstruction loss | float: default=1.0 |
180 | | -k, --beta_kl | beta coefficient for the kl divergence | float: default=1.0 |
181 | | -c, --kl_balance | coefficient for the balance between the ChamferKL (for the KP) and the standard KL | float: default=0.001 |
182 | | -v, --rec_loss_function | type of reconstruction loss: 'mse', 'vgg' | str: default="mse" |
183 | | --n_kp_enc | number of posterior kp to be learned | int: default=30 |
184 | | --n_kp_enc_prior | number of kp to filter from the set of prior kp | int: default=50 |
185 | | --dec_bone | decoder backbone: 'gauss_pointnetpp_feat': Masked Model, 'gauss_pointnetpp': Object Model | str: default="gauss_pointnetpp" |
186 | | --patch_size | patch size for the prior KP proposals network (not to be confused with the glimpse size) | int: default=8 |
187 | | --learned_feature_dim | the latent visual features dimensions extracted from glimpses | int: default=10 |
188 | | --use_object_enc | set True to use a separate encoder to encode visual features of glimpses | bool: default=False |
189 | | --use_object_dec | set True to use a separate decoder to decode glimpses (Object Model) | bool: default=False |
190 | | --warmup_epoch | number of epochs where only the object decoder is trained | int: default=2 |
191 | | --anchor_s | defines the glimpse size as a ratio of image_size | float: default=0.25 |
192 | | --exclusive_patches | set True to enable non-overlapping object patches | bool: default=False |
193 | 
194 | Examples:
195 | 
196 | * Single-GPU:
197 | 
198 | `python train_dlp.py --dataset shapes`
199 | 
200 | `python train_dlp.py --dataset celeba`
201 | 
202 | `python train_dlp.py --dataset clevrer -o --use_object_enc --use_object_dec --warmup_epoch 1 --beta_kl 40.0 --rec_loss_function vgg --learned_feature_dim 6`
203 | 
204 | * Multi-GPU:
205 | 
206 | `accelerate launch --config_file ./accel_conf.json train_dlp_accelerate.py --dataset celeba`
207 | 
208 | `accelerate launch --config_file ./accel_conf.json train_dlp_accelerate.py --dataset clevrer -o --use_object_enc --use_object_dec --warmup_epoch 1 --beta_kl 40.0 --rec_loss_function vgg --learned_feature_dim 6`
209 | 
210 | * Note: if you want multiple multi-GPU runs, each run should have a different accelerate config file (
211 | e.g., `accel_conf.json`, `accel_conf_2.json`, etc.). The only difference between the files should be
212 | the `main_process_port` field (e.g., for the second config file, set `main_process_port: 81231`).
213 | 
214 | ## Evaluation of Unsupervised Keypoint Regression on CelebA
215 | 
216 | Linear regression of supervised keypoints on the MAFL dataset is performed during training on the CelebA dataset.
217 | 
218 | To evaluate a saved checkpoint of the model: modify the hyper-parameters and paths in `eval_celeb.py`,
219 | and then use `python eval_celeb.py` to calculate and print the normalized error with respect to inter-ocular distance.
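The reported number is the standard normalized landmark error: the mean Euclidean distance between predicted and ground-truth landmarks, divided by the inter-ocular distance of the ground truth. A minimal sketch of the metric (ours; the eye indices are assumptions about the MAFL annotation order — see `eval_celeb.py` and `datasets/celeba_dataset.py` for the actual evaluation code):

```python
import numpy as np

def normalized_landmark_error(pred_kp, gt_kp, left_eye=0, right_eye=1):
    # pred_kp, gt_kp: [N, 5, 2] arrays of (x, y) landmarks for N images (MAFL uses 5 landmarks)
    per_kp_err = np.linalg.norm(pred_kp - gt_kp, axis=-1)                              # [N, 5]
    inter_ocular = np.linalg.norm(gt_kp[:, left_eye] - gt_kp[:, right_eye], axis=-1)   # [N]
    return float((per_kp_err / inter_ocular[:, None]).mean())  # e.g., 0.038 -> 3.8% error
```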
220 | 
221 | ## Recommended Hyper-parameters
222 | 
223 | | Dataset | `dec_bone` (model type) | `n_kp_enc` | `n_kp_prior` | `rec_loss_func` | `beta_kl` | `kl_balance` | `patch_size` | `anchor_s` | `learned_feature_dim` |
224 | |---------------------|-------------------------|------------|--------------|-----------------|-----------|--------------|--------------|------------|-----------------------|
225 | | CelebA (`celeba`) | `gauss_pointnetpp_feat` | 30 | 50 | `vgg` | 40 | 0.001 | 8 | 0.125 | 10 |
226 | | Traffic (`traffic`) | `gauss_pointnetpp` | 15 | 20 | `vgg` | 30 | 0.001 | 16 | 0.25 | 20 |
227 | | CLEVRER (`clevrer`) | `gauss_pointnetpp` | 10 | 20 | `vgg` | 40 | 0.001 | 16 | 0.25 | 5 |
228 | | Shapes (`shapes`) | `gauss_pointnetpp` | 10 | 15 | `mse` | 0.1 | 0.001 | 8 | 0.25 | 6 |
229 | 
230 | 
231 | ## Repository Organization
232 | 
233 | | File name | Content |
234 | |----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
235 | | `/checkpoints` | directory for pre-trained checkpoints and sample images for the interactive demo |
236 | | `/datasets` | directory containing data loading classes for the various datasets |
237 | | `/eval/eval_model.py` | evaluation functions such as evaluating the ELBO |
238 | | `/modules/modules.py` | basic neural network blocks used to implement the DLP model |
239 | | `/utils/tps.py` | implementation of the TPS augmentation used for training on CelebA |
240 | | `/utils/loss_functions.py` | loss functions used to optimize the model such as Chamfer-KL and perceptual (VGG) loss |
241 | | `/utils/util_func.py` | utility functions such as logging and plotting functions |
242 | | `eval_celeb.py` | functions to evaluate the normalized error of keypoint linear regression with respect to inter-ocular distance for the MAFL/CelebA dataset |
243 | | `models.py` | implementation of the DLP model |
244 | | `train_dlp.py` | training function of DLP for single-GPU machines |
245 | | `train_dlp_accelerate.py` | training function of DLP for multi-GPU machines |
246 | | `dlp_tutorial.ipynb` | Jupyter Notebook tutorial for explaining and training DLP on the random shapes dataset |
247 | | `interactive_demo_dlp.py` | `matplotlib`-based interactive demo to plot and interact with learned particles |
248 | | `environment17/19.yml` | Anaconda environment file to install the required dependencies |
249 | | `requirements17/19.txt` | requirements file for `pip` |
250 | | `accel_conf.json` | configuration file for `accelerate` to run training on multiple GPUs |
251 | 
252 | ## Credits
253 | 
254 | * CelebA pre-processing is performed as in [DVE](https://github.com/jamt9000/DVE).
255 | * Normalized inter-ocular distance: [KeyNet (Jakab et al.)](https://github.com/tomasjakab/imm).
256 | -------------------------------------------------------------------------------- /accel_conf.json: -------------------------------------------------------------------------------- 1 | { 2 | "compute_environment": "LOCAL_MACHINE", 3 | "distributed_type": "MULTI_GPU", 4 | "fp16": false, 5 | "machine_rank": 0, 6 | "main_process_ip": null, 7 | "main_process_port": null, 8 | "main_training_function": "main", 9 | "num_machines": 1, 10 | "num_processes": 4 11 | } 12 | -------------------------------------------------------------------------------- /assets/celeb_manip_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/assets/celeb_manip_2.gif -------------------------------------------------------------------------------- /assets/clevrer_manip_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/assets/clevrer_manip_1.gif -------------------------------------------------------------------------------- /assets/tarffic_manip.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/assets/tarffic_manip.gif -------------------------------------------------------------------------------- /checkpoints/PLACE PRE-TRAINED CHECKPOINTS HERE.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/PLACE PRE-TRAINED CHECKPOINTS HERE.txt -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/1.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/2.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/3.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/4.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/5.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/5.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/6.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/7.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/celeb/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/celeb/8.jpg -------------------------------------------------------------------------------- /checkpoints/sample_images/clevrer/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/clevrer/1.png -------------------------------------------------------------------------------- /checkpoints/sample_images/clevrer/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/clevrer/2.png -------------------------------------------------------------------------------- /checkpoints/sample_images/traffic/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/deep-latent-particles-pytorch/a3d9cffe6426d5f8dd68b8bafdefb322f662f18e/checkpoints/sample_images/traffic/1.png -------------------------------------------------------------------------------- /datasets/clevrer_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | functions and classes to process the CLEVRER dataset 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import cv2 9 | import utils.tps as tps 10 | 11 | import torch 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from torch.utils.data import Dataset, DataLoader 15 | import torchvision.transforms as transforms 16 | 17 | 18 | def list_images_in_dir(path): 19 | valid_images = [".jpg", ".gif", ".png"] 20 | img_list = [] 21 | for f in os.listdir(path): 22 | ext = os.path.splitext(f)[1] 23 | if ext.lower() not in valid_images: 24 | continue 25 | img_list.append(os.path.join(path, f)) 26 | return img_list 27 | 28 | 29 | def prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1, start_frame=1): 30 | # path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/' 31 | img_list = list_images_in_dir(path_to_image_dir) 32 | img_list = sorted(img_list, key=lambda x: int(x.split('/')[-1].split('_')[-1].split('.')[0])) 33 | img_list = [img_list[i] for i in range(len(img_list)) if 34 | 
abs(int(img_list[i].split('/')[-1].split('_')[-1].split('.')[0])) % 1000 > start_frame] 35 | print(f'img_list: {len(img_list)}, 0: {img_list[0]}, -1: {img_list[-1]}') 36 | img_np_list = [] 37 | for i in tqdm(range(len(img_list))): 38 | if i % frameskip != 0: 39 | continue 40 | img = Image.open(img_list[i]) 41 | img = img.convert('RGB') 42 | # img = img.crop((60, 0, 480, 420)) 43 | img = img.resize((image_size, image_size), Image.BICUBIC) 44 | img_np = np.asarray(img) 45 | img_np_list.append(img_np) 46 | img_np_array = np.stack(img_np_list, axis=0) 47 | print(f'img_np_array: {img_np_array.shape}') 48 | save_path = os.path.join(path_to_image_dir, f'clevrer_img{image_size}np_fs{frameskip}.npy') 49 | np.save(save_path, img_np_array) 50 | print(f'file save at @ {save_path}') 51 | 52 | 53 | class CLEVRERDataset(Dataset): 54 | def __init__(self, path_to_npy, image_size=128, transform=None, mode='single', train=True, horizon=3, 55 | frames_per_video=34, video_as_index=False): 56 | super(CLEVRERDataset, self).__init__() 57 | assert mode in ['single', 'frames', 'tps', 'horizon'] 58 | self.mode = mode 59 | self.frames_per_video = frames_per_video 60 | self.horizon = horizon if (horizon > 0 and self.mode == 'horizon') else self.frames_per_video 61 | self.train_mode = train 62 | if train: 63 | print(f'clevrer dataset mode: {self.mode}') 64 | if self.mode == 'horizon': 65 | print(f'time steps horizon: {self.horizon}') 66 | if self.mode == 'tps': 67 | self.warper = tps.Warper(H=image_size, W=image_size, warpsd_all=0.00001, 68 | warpsd_subset=0.001, transsd=0.1, scalesd=0.1, 69 | rotsd=2, im1_multiplier=0.1, im1_multiplier_aff=0.1) 70 | else: 71 | self.warper = None 72 | data = np.load(path_to_npy) 73 | # train_size = int(0.9 * data.shape[0]) 74 | # valid_size = data.shape[0] - train_size 75 | if train: 76 | self.data = data 77 | # self.data = data[:self.frames_per_video * 200] 78 | print(f'loaded data with shape: {self.data.shape}, size: {self.data.shape[0]}') 79 | else: 80 | self.data = data[:5000] 81 | self.image_size = image_size 82 | self.num_videos = self.data.shape[0] // self.frames_per_video 83 | self.video_as_index = video_as_index 84 | if transform is None: 85 | self.input_transform = transforms.Compose([ 86 | transforms.ToPILImage(), 87 | transforms.Resize(image_size), 88 | transforms.ToTensor() 89 | ]) 90 | else: 91 | self.input_transform = transform 92 | 93 | def __getitem__(self, index): 94 | if not self.video_as_index: 95 | video_num = int(index / self.frames_per_video) 96 | video_start_idx = video_num * self.frames_per_video 97 | curr_idx = index % self.frames_per_video 98 | max_idx = min(video_start_idx + self.frames_per_video - 1, self.data.shape[0] - 1) 99 | global_idx = video_start_idx + curr_idx 100 | if self.mode == 'single': 101 | return self.input_transform(self.data[index]) 102 | elif self.mode == 'frames': 103 | min_idx = min(video_start_idx, index - 1) 104 | if min_idx == video_start_idx: 105 | im1 = self.input_transform(self.data[min_idx + 1]) 106 | im2 = self.input_transform(self.data[min_idx]) 107 | else: 108 | im1 = self.input_transform(self.data[min_idx]) 109 | im2 = self.input_transform(self.data[min_idx - 1]) 110 | return im1, im2 111 | elif self.mode == 'horizon': 112 | images = [] 113 | length = max_idx 114 | if (index + self.horizon) >= length: 115 | slack = index + self.horizon - length 116 | index = index - slack 117 | for i in range(self.horizon): 118 | t = index + i 119 | images.append(self.input_transform(self.data[t])) 120 | images = torch.stack(images, dim=0) 
121 | return images 122 | elif self.mode == 'tps': 123 | im = self.input_transform(self.data[index]) 124 | im = im * 255 125 | im2, im1, _, _, _, _ = self.warper(im) 126 | return im1 / 255, im2 / 255 127 | else: 128 | raise NotImplementedError 129 | else: 130 | video_num = index 131 | video_start_idx = video_num * self.frames_per_video 132 | max_idx = video_start_idx + self.frames_per_video - 1 133 | images = [] 134 | length = max_idx 135 | frame_idx = video_start_idx 136 | actual_horizon = self.frames_per_video if ((frame_idx + self.horizon) >= length) else self.horizon 137 | for i in range(actual_horizon): 138 | t = frame_idx + i 139 | images.append(self.input_transform(self.data[t])) 140 | images = torch.stack(images, dim=0) 141 | return images 142 | 143 | def __len__(self): 144 | if not self.video_as_index: 145 | return self.data.shape[0] 146 | else: 147 | return self.num_videos 148 | 149 | 150 | if __name__ == '__main__': 151 | path_to_img = '/media/newhd/data/clevrer/train/frames/' 152 | # prepare_numpy_file(path_to_img, image_size=128, frameskip=3, start_frame=26) 153 | test_epochs = False 154 | # load data 155 | path_to_npy = '/media/newhd/data/clevrer/valid/clevrer_img128np_fs3_valid.npy' 156 | mode = 'frames' 157 | horizon = 4 158 | train = True 159 | clevrer_ds = CLEVRERDataset(path_to_npy, mode=mode, train=train, horizon=horizon) 160 | clevrer_dl = DataLoader(clevrer_ds, shuffle=True, pin_memory=True, batch_size=5) 161 | batch = next(iter(clevrer_dl)) 162 | if mode == 'single': 163 | im1 = batch[0] 164 | elif mode == 'frames' or mode == 'tps': 165 | im1 = batch[0][0] 166 | im2 = batch[1][0] 167 | 168 | if mode == 'single': 169 | print(im1.shape) 170 | img_np = im1.permute(1, 2, 0).data.cpu().numpy() 171 | fig = plt.figure(figsize=(5, 5)) 172 | ax = fig.add_subplot(111) 173 | ax.imshow(img_np) 174 | elif mode == 'horizon': 175 | print(f'batch shape: {batch.shape}') 176 | images = batch[0] 177 | print(f'images shape: {images.shape}') 178 | fig = plt.figure(figsize=(8, 8)) 179 | for i in range(images.shape[0]): 180 | ax = fig.add_subplot(1, horizon, i + 1) 181 | im = images[i] 182 | im_np = im.permute(1, 2, 0).data.cpu().numpy() 183 | ax.imshow(im_np) 184 | ax.set_title(f'im {i + 1}') 185 | else: 186 | print(f'im1: {im1.shape}, im2: {im2.shape}') 187 | im1_np = im1.permute(1, 2, 0).data.cpu().numpy() 188 | im2_np = im2.permute(1, 2, 0).data.cpu().numpy() 189 | fig = plt.figure(figsize=(8, 8)) 190 | ax = fig.add_subplot(1, 2, 1) 191 | ax.imshow(im1_np) 192 | ax.set_title('im1') 193 | 194 | ax = fig.add_subplot(1, 2, 2) 195 | ax.imshow(im2_np) 196 | ax.set_title('im2 [t-1] or [tps]') 197 | plt.show() 198 | if test_epochs: 199 | from tqdm import tqdm 200 | 201 | pbar = tqdm(iterable=clevrer_dl) 202 | for batch in pbar: 203 | pass 204 | pbar.close() 205 | -------------------------------------------------------------------------------- /datasets/shapes_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple Random Colored Shapes Dataset 3 | """ 4 | # imports 5 | import numpy as np 6 | from skimage.draw import random_shapes 7 | from tqdm.auto import tqdm 8 | import torch 9 | 10 | 11 | def generate_shape_dataset(img_size=64, min_shapes=2, max_shapes=5, min_size=10, max_size=12, allow_overlap=False, 12 | num_images=10_000): 13 | images = [] 14 | for i in tqdm(range(num_images)): 15 | img, _ = random_shapes((img_size, img_size), min_shapes=min_shapes, max_shapes=max_shapes, 16 | intensity_range=((0, 200),), min_size=min_size, 
max_size=max_size, 17 | allow_overlap=allow_overlap, num_trials=100) 18 | img[:, :, 0][img[:, :, 0] == 255] = 0 19 | img[:, :, 1][img[:, :, 1] == 255] = 255 20 | img[:, :, 2][img[:, :, 2] == 255] = 255 21 | img = img / 255.0 22 | images.append(img) 23 | images = np.stack(images, axis=0) # [num_mages, H, W, 3] 24 | return images 25 | 26 | 27 | def generate_shape_dataset_torch(img_size=64, min_shapes=2, max_shapes=5, min_size=11, max_size=13, allow_overlap=False, 28 | num_images=10_000): 29 | images = generate_shape_dataset(img_size=img_size, min_shapes=min_shapes, max_shapes=max_shapes, min_size=min_size, 30 | max_size=max_size, 31 | allow_overlap=allow_overlap, num_images=num_images) 32 | # create torch dataset 33 | img_data_torch = images.transpose(0, 3, 1, 2) # [num_images, 3, H, W] 34 | img_ds = torch.utils.data.TensorDataset(torch.tensor(img_data_torch, dtype=torch.float)) 35 | return img_ds 36 | -------------------------------------------------------------------------------- /datasets/traffic_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | functions and classes to process the Traffic dataset 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | import cv2 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import torchvision.transforms as transforms 11 | import torchvision.utils as vutils 12 | import matplotlib.pyplot as plt 13 | from PIL import Image 14 | # from tqdm.auto import tqdm 15 | import utils.tps as tps 16 | 17 | 18 | def list_images_in_dir(path): 19 | valid_images = [".jpg", ".gif", ".png"] 20 | img_list = [] 21 | for f in os.listdir(path): 22 | ext = os.path.splitext(f)[1] 23 | if ext.lower() not in valid_images: 24 | continue 25 | img_list.append(os.path.join(path, f)) 26 | return img_list 27 | 28 | 29 | def prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1): 30 | # path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/' 31 | img_list = list_images_in_dir(path_to_image_dir) 32 | img_list = sorted(img_list, key=lambda x: int(x.split('/')[-1].split('.')[0])) 33 | print(f'img_list: {len(img_list)}, 0: {img_list[0]}, -1: {img_list[-1]}') 34 | img_np_list = [] 35 | for i in tqdm(range(len(img_list))): 36 | if i % frameskip != 0: 37 | continue 38 | img = Image.open(img_list[i]) 39 | img = img.convert('RGB') 40 | img = img.crop((60, 0, 480, 420)) 41 | img = img.resize((image_size, image_size), Image.BICUBIC) 42 | img_np = np.asarray(img) 43 | img_np_list.append(img_np) 44 | img_np_array = np.stack(img_np_list, axis=0) 45 | print(f'img_np_array: {img_np_array.shape}') 46 | save_path = os.path.join(path_to_image_dir, f'img{image_size}np_fs{frameskip}.npy') 47 | np.save(save_path, img_np_array) 48 | print(f'file save at @ {save_path}') 49 | 50 | 51 | class TrafficDataset(Dataset): 52 | def __init__(self, path_to_npy, image_size=128, transform=None, mode='single', train=True, horizon=3): 53 | super(TrafficDataset, self).__init__() 54 | assert mode in ['single', 'frames', 'tps', 'horizon'] 55 | self.mode = mode 56 | self.horizon = horizon 57 | if train: 58 | print(f'traffic dataset mode: {self.mode}') 59 | if self.mode == 'horizon': 60 | print(f'time steps horizon: {self.horizon}') 61 | if self.mode == 'tps': 62 | self.warper = tps.Warper(H=image_size, W=image_size, warpsd_all=0.00001, 63 | warpsd_subset=0.001, transsd=0.1, scalesd=0.1, 64 | rotsd=2, im1_multiplier=0.1, im1_multiplier_aff=0.1) 65 | else: 66 | self.warper = None 67 | data = np.load(path_to_npy) 68 | train_size = 
int(0.9 * data.shape[0]) 69 | valid_size = data.shape[0] - train_size 70 | if train: 71 | print(f'loaded data with shape: {data.shape}, train_size: {train_size}, valid_size: {valid_size}') 72 | self.data = data[:train_size] 73 | else: 74 | self.data = data[train_size:] 75 | self.image_size = image_size 76 | if transform is None: 77 | self.input_transform = transforms.Compose([ 78 | transforms.ToPILImage(), 79 | transforms.Resize(image_size), 80 | transforms.ToTensor() 81 | ]) 82 | else: 83 | self.input_transform = transform 84 | 85 | def __getitem__(self, index): 86 | if self.mode == 'single': 87 | return self.input_transform(self.data[index]) 88 | elif self.mode == 'frames': 89 | if index == 0: 90 | im1 = self.input_transform(self.data[index + 1]) 91 | im2 = self.input_transform(self.data[index]) 92 | else: 93 | im1 = self.input_transform(self.data[index]) 94 | im2 = self.input_transform(self.data[index - 1]) 95 | return im1, im2 96 | elif self.mode == 'horizon': 97 | images = [] 98 | length = self.data.shape[0] 99 | if (index + self.horizon) >= length: 100 | slack = index + self.horizon - length 101 | index = index - slack 102 | for i in range(self.horizon): 103 | t = index + i 104 | images.append(self.input_transform(self.data[t])) 105 | images = torch.stack(images, dim=0) 106 | return images 107 | elif self.mode == 'tps': 108 | im = self.input_transform(self.data[index]) 109 | im = im * 255 110 | im2, im1, _, _, _, _ = self.warper(im) 111 | return im1 / 255, im2 / 255 112 | else: 113 | raise NotImplementedError 114 | 115 | def __len__(self): 116 | return self.data.shape[0] 117 | 118 | 119 | if __name__ == '__main__': 120 | # prepare data 121 | path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/' 122 | frameskip = 3 123 | image_size = 128 124 | # prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1) 125 | test_epochs = True 126 | # load data 127 | path_to_npy = '/media/newhd/data/traffic_data/img128np_fs3.npy' 128 | mode = 'horizon' 129 | horizon = 4 130 | train = True 131 | traffic_ds = TrafficDataset(path_to_npy, mode=mode, train=train, horizon=horizon) 132 | traffic_dl = DataLoader(traffic_ds, shuffle=True, pin_memory=True, batch_size=5) 133 | batch = next(iter(traffic_dl)) 134 | if mode == 'single': 135 | im1 = batch[0] 136 | elif mode == 'frames' or mode == 'tps': 137 | im1 = batch[0][0] 138 | im2 = batch[1][0] 139 | 140 | if mode == 'single': 141 | print(im1.shape) 142 | img_np = im1.permute(1, 2, 0).data.cpu().numpy() 143 | fig = plt.figure(figsize=(5, 5)) 144 | ax = fig.add_subplot(111) 145 | ax.imshow(img_np) 146 | elif mode == 'horizon': 147 | print(f'batch shape: {batch.shape}') 148 | images = batch[0] 149 | print(f'images shape: {images.shape}') 150 | fig = plt.figure(figsize=(8, 8)) 151 | for i in range(images.shape[0]): 152 | ax = fig.add_subplot(1, horizon, i + 1) 153 | im = images[i] 154 | im_np = im.permute(1, 2, 0).data.cpu().numpy() 155 | ax.imshow(im_np) 156 | ax.set_title(f'im {i + 1}') 157 | else: 158 | print(f'im1: {im1.shape}, im2: {im2.shape}') 159 | im1_np = im1.permute(1, 2, 0).data.cpu().numpy() 160 | im2_np = im2.permute(1, 2, 0).data.cpu().numpy() 161 | fig = plt.figure(figsize=(8, 8)) 162 | ax = fig.add_subplot(1, 2, 1) 163 | ax.imshow(im1_np) 164 | ax.set_title('im1') 165 | 166 | ax = fig.add_subplot(1, 2, 2) 167 | ax.imshow(im2_np) 168 | ax.set_title('im2 [t-1] or [tps]') 169 | plt.show() 170 | if test_epochs: 171 | from tqdm import tqdm 172 | pbar = tqdm(iterable=traffic_dl) 173 | for batch in pbar: 174 | pass 175 | 
pbar.close() 176 | -------------------------------------------------------------------------------- /environment17.yml: -------------------------------------------------------------------------------- 1 | name: torch17 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=10.2.89 8 | - ffmpeg=4.0 9 | - hdf5=1.10.2 10 | - imageio=2.6.1 11 | - matplotlib=3.4.2 12 | - matplotlib-base=3.4.2 13 | - opencv=3.4.2 14 | - pillow=8.3.1 15 | - pip=21.1.3 16 | - py-opencv=3.4.2 17 | - python=3.7.10 18 | - python-dateutil=2.8.2 19 | - pytorch=1.7.1 20 | - scikit-image=0.18.1 21 | - torchaudio=0.7.2 22 | - torchvision=0.8.2 23 | - pip: 24 | - absl-py==0.15.0 25 | - accelerate==0.3.0 26 | - astunparse==1.6.3 27 | - cachetools==4.2.4 28 | - charset-normalizer==2.0.3 29 | - cloudpickle==2.0.0 30 | - dm-tree==0.1.6 31 | - flatbuffers==1.12 32 | - gast==0.3.3 33 | - google-auth==2.3.3 34 | - google-auth-oauthlib==0.4.6 35 | - google-pasta==0.2.0 36 | - googledrivedownloader==0.4 37 | - grpcio==1.32.0 38 | - h5py==2.10.0 39 | - idna==3.2 40 | - importlib-metadata==4.10.0 41 | - isodate==0.6.0 42 | - jinja2==3.0.1 43 | - joblib==1.0.1 44 | - markdown==3.3.6 45 | - markupsafe==2.0.1 46 | - networkx==2.6.2 47 | - numpy==1.19.5 48 | - oauthlib==3.1.1 49 | - opt-einsum==3.3.0 50 | - packaging==21.0 51 | - pandas==1.3.1 52 | - protobuf==3.19.3 53 | - pyaml==20.4.0 54 | - pyasn1==0.4.8 55 | - pyasn1-modules==0.2.8 56 | - python-louvain==0.15 57 | - pytz==2021.1 58 | - pyyaml==5.4.1 59 | - rdflib==6.0.0 60 | - requests-oauthlib==1.3.0 61 | - rsa==4.8 62 | - scikit-learn==0.24.2 63 | - scipy==1.7.0 64 | - six==1.15.0 65 | - termcolor==1.1.0 66 | - threadpoolctl==2.2.0 67 | - torch-cluster==1.5.9 68 | - torch-geometric==1.7.2 69 | - torch-scatter==2.0.7 70 | - torch-sparse==0.6.9 71 | - torch-spline-conv==1.2.1 72 | - tqdm==4.61.2 73 | - typing-extensions==3.7.4.3 74 | - urllib3==1.26.6 75 | - werkzeug==2.0.2 76 | - wrapt==1.12.1 77 | - zipp==3.7.0 78 | -------------------------------------------------------------------------------- /environment19.yml: -------------------------------------------------------------------------------- 1 | name: torch19 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - cudatoolkit=11.1.74 10 | - ffmpeg=4.3 11 | - h5py=2.10.0 12 | - hdf5=1.10.5 13 | - imageio=2.6.1 14 | - ipykernel=6.4.1 15 | - ipython=7.27.0 16 | - ipython_genutils=0.2.0 17 | - ipywidgets=7.6.4 18 | - jupyter=1.0.0 19 | - jupyter_client=7.0.2 20 | - jupyter_console=6.4.0 21 | - jupyter_core=4.7.1 22 | - jupyterlab_pygments=0.1.2 23 | - jupyterlab_widgets=1.0.1 24 | - matplotlib=3.4.2 25 | - matplotlib-base=3.4.2 26 | - matplotlib-inline=0.1.3 27 | - notebook=6.4.3 28 | - numpy=1.20.3 29 | - numpy-base=1.20.3 30 | - pandas=1.3.0 31 | - pillow=8.3.1 32 | - pip=21.0.1 33 | - python=3.7.11 34 | - pytorch=1.9.0 35 | - pytorch-cluster=1.5.9 36 | - pytorch-scatter=2.0.8 37 | - pytorch-sparse=0.6.12 38 | - pytorch-spline-conv=1.2.1 39 | - pyyaml=5.4.1 40 | - scikit-image=0.18.1 41 | - scikit-learn=0.24.2 42 | - scipy=1.6.3 43 | - torchaudio=0.9.0 44 | - torchfile=0.1.0 45 | - torchvision=0.10.0 46 | - tqdm=4.62.2 47 | - yacs=0.1.6 48 | - yaml=0.2.5 49 | - pip: 50 | - accelerate==0.5.1 51 | - opencv-python==3.4.2.17 52 | -------------------------------------------------------------------------------- /eval/eval_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation of the 
ELBO on the validation set 3 | """ 4 | # imports 5 | import numpy as np 6 | # torch 7 | import torch 8 | import torch.nn.functional as F 9 | from utils.loss_functions import ChamferLossKL, calc_kl, calc_reconstruction_loss, VGGDistance 10 | from torch.utils.data import DataLoader 11 | import torchvision.utils as vutils 12 | # datasets 13 | from datasets.traffic_ds import TrafficDataset 14 | from datasets.clevrer_ds import CLEVRERDataset 15 | # util functions 16 | from utils.util_func import plot_keypoints_on_image_batch 17 | 18 | 19 | def evaluate_validation_elbo(model, ds, epoch, batch_size=100, recon_loss_type="vgg", device=torch.device('cpu'), 20 | save_image=False, fig_dir='./', topk=5, recon_loss_func=None, beta_rec=1.0, beta_kl=1.0, 21 | kl_balance=1.0, accelerator=None): 22 | model.eval() 23 | kp_range = model.kp_range 24 | # load data 25 | if ds == "traffic": 26 | image_size = 128 27 | root = '/mnt/data/tal/traffic_dataset/img128np_fs3.npy' 28 | mode = 'single' 29 | dataset = TrafficDataset(path_to_npy=root, image_size=image_size, mode=mode, train=False) 30 | elif ds == 'clevrer': 31 | image_size = 128 32 | root = '/mnt/data/tal/clevrer/clevrer_img128np_fs3_valid.npy' 33 | # root = '/media/newhd/data/clevrer/valid/clevrer_img128np_fs3_valid.npy' 34 | mode = 'single' 35 | dataset = CLEVRERDataset(path_to_npy=root, image_size=image_size, mode=mode, train=False) 36 | else: 37 | raise NotImplementedError 38 | 39 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=2, drop_last=False) 40 | kl_loss_func = ChamferLossKL(use_reverse_kl=False) 41 | if recon_loss_func is None: 42 | if recon_loss_type == "vgg": 43 | recon_loss_func = VGGDistance(device=device) 44 | else: 45 | recon_loss_func = calc_reconstruction_loss 46 | 47 | elbos = [] 48 | for batch in dataloader: 49 | if ds == 'traffic': 50 | if mode == 'single': 51 | x = batch.to(device) 52 | x_prior = x 53 | else: 54 | x = batch[0].to(device) 55 | x_prior = batch[1].to(device) 56 | elif ds == 'clevrer': 57 | if mode == 'single': 58 | x = batch.to(device) 59 | x_prior = x 60 | else: 61 | x = batch[0].to(device) 62 | x_prior = batch[1].to(device) 63 | else: 64 | x = batch 65 | x_prior = x 66 | batch_size = x.shape[0] 67 | # forward pass 68 | with torch.no_grad(): 69 | model_output = model(x, x_prior=x_prior) 70 | mu_p = model_output['kp_p'] 71 | gmap = model_output['gmap'] 72 | mu = model_output['mu'] 73 | logvar = model_output['logvar'] 74 | rec_x = model_output['rec'] 75 | mu_features = model_output['mu_features'] 76 | logvar_features = model_output['logvar_features'] 77 | # object stuff 78 | dec_objects_original = model_output['dec_objects_original'] 79 | cropped_objects_original = model_output['cropped_objects_original'] 80 | 81 | # reconstruction error 82 | if recon_loss_type == "vgg": 83 | loss_rec = recon_loss_func(x, rec_x, reduction="mean") 84 | else: 85 | loss_rec = calc_reconstruction_loss(x, rec_x, loss_type='mse', reduction='mean') 86 | 87 | # kl-divergence 88 | logvar_p = torch.log(torch.tensor(model.sigma ** 2)).to(mu.device) # logvar of the constant std -> for the kl 89 | logvar_kp = logvar_p.expand_as(mu_p) 90 | 91 | mu_post = mu 92 | logvar_post = logvar 93 | mu_prior = mu_p 94 | logvar_prior = logvar_kp 95 | 96 | loss_kl_kp = kl_loss_func(mu_preds=mu_post, logvar_preds=logvar_post, mu_gts=mu_prior, 97 | logvar_gts=logvar_prior).mean() 98 | if model.learned_feature_dim > 0: 99 | loss_kl_feat = calc_kl(logvar_features.view(-1, logvar_features.shape[-1]), 100 | mu_features.view(-1, 
mu_features.shape[-1]), reduce='none') 101 | loss_kl_feat = loss_kl_feat.view(batch_size, model.n_kp_enc + 1).sum(1).mean() 102 | else: 103 | loss_kl_feat = torch.tensor(0.0, device=mu.device) 104 | loss_kl = loss_kl_kp + kl_balance * loss_kl_feat 105 | elbo = beta_rec * loss_rec + beta_kl * loss_kl 106 | elbos.append(elbo.data.cpu().numpy()) 107 | if save_image: 108 | max_imgs = 8 109 | img_with_kp = plot_keypoints_on_image_batch(mu.clamp(min=model.kp_range[0], max=model.kp_range[1]), x, radius=3, 110 | thickness=1, max_imgs=max_imgs, kp_range=model.kp_range) 111 | img_with_kp_p = plot_keypoints_on_image_batch(mu_p, x_prior, radius=3, thickness=1, max_imgs=max_imgs, 112 | kp_range=model.kp_range) 113 | # top-k 114 | with torch.no_grad(): 115 | logvar_sum = logvar.sum(-1) 116 | logvar_topk = torch.topk(logvar_sum, k=topk, dim=-1, largest=False) 117 | indices = logvar_topk[1] # [batch_size, topk] 118 | batch_indices = torch.arange(mu.shape[0]).view(-1, 1).to(mu.device) 119 | topk_kp = mu[batch_indices, indices] 120 | img_with_kp_topk = plot_keypoints_on_image_batch(topk_kp.clamp(min=kp_range[0], max=kp_range[1]), x, 121 | radius=3, thickness=1, max_imgs=max_imgs, 122 | kp_range=kp_range) 123 | if model.use_object_dec and dec_objects_original is not None: 124 | dec_objects = model_output['dec_objects'] 125 | if accelerator is not None: 126 | if accelerator.is_main_process: 127 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device), 128 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device), 129 | img_with_kp_topk[:max_imgs, -3:].to(mu.device), 130 | dec_objects[:max_imgs, -3:]], 131 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch), 132 | nrow=8, pad_value=1) 133 | else: 134 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device), 135 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device), 136 | img_with_kp_topk[:max_imgs, -3:].to(mu.device), 137 | dec_objects[:max_imgs, -3:]], 138 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch), 139 | nrow=8, pad_value=1) 140 | with torch.no_grad(): 141 | _, dec_objects_rgb = torch.split(dec_objects_original, [1, 3], dim=2) 142 | dec_objects_rgb = dec_objects_rgb.reshape(-1, *dec_objects_rgb.shape[2:]) 143 | cropped_objects_original = cropped_objects_original.clone().reshape(-1, 3, 144 | cropped_objects_original.shape[ 145 | -1], 146 | cropped_objects_original.shape[ 147 | -1]) 148 | if cropped_objects_original.shape[-1] != dec_objects_rgb.shape[-1]: 149 | cropped_objects_original = F.interpolate(cropped_objects_original, 150 | size=dec_objects_rgb.shape[-1], 151 | align_corners=False, mode='bilinear') 152 | if accelerator is not None: 153 | if accelerator.is_main_process: 154 | vutils.save_image( 155 | torch.cat([cropped_objects_original[:max_imgs * 2, -3:], dec_objects_rgb[:max_imgs * 2, -3:]], 156 | dim=0).data.cpu(), '{}/image_obj_valid_{}.jpg'.format(fig_dir, epoch), 157 | nrow=8, pad_value=1) 158 | else: 159 | vutils.save_image( 160 | torch.cat([cropped_objects_original[:max_imgs * 2, -3:], dec_objects_rgb[:max_imgs * 2, -3:]], 161 | dim=0).data.cpu(), '{}/image_obj_valid_{}.jpg'.format(fig_dir, epoch), 162 | nrow=8, pad_value=1) 163 | else: 164 | if accelerator is not None: 165 | if accelerator.is_main_process: 166 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device), 167 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device), 168 | img_with_kp_topk[:max_imgs, 
-3:].to(mu.device)], 169 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch), 170 | nrow=8, pad_value=1) 171 | else: 172 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(mu.device), 173 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(mu.device), 174 | img_with_kp_topk[:max_imgs, -3:].to(mu.device)], 175 | dim=0).data.cpu(), '{}/image_valid_{}.jpg'.format(fig_dir, epoch), 176 | nrow=8, pad_value=1) 177 | return np.mean(elbos) 178 | -------------------------------------------------------------------------------- /eval_celeb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate supervised regression on CelebA-HQ 128x128 3 | """ 4 | # imports 5 | import torch 6 | from datasets.celeba_dataset import evaluate_lin_reg_on_mafl_topk 7 | from models import KeyPointVAE 8 | 9 | if __name__ == '__main__': 10 | image_size = 128 11 | imwidth = 160 12 | crop = 16 13 | ch = 3 14 | enc_channels = [32, 64, 128, 256] 15 | prior_channels = (16, 32, 64) 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | use_logsoftmax = False 18 | pad_mode = 'replicate' 19 | sigma = 0.1 # default sigma for the gaussian maps 20 | n_kp = 1 # num kp per patch 21 | n_kp_enc = 30 # total kp to output from the encoder / filter from prior 22 | n_kp_prior = 50 23 | mask_threshold = 0.2 # mask threshold for the features from the encoder 24 | patch_size = 8 # 8 for playground, 16 for celeb 25 | learned_feature_dim = 10 # additional features than x,y for each kp 26 | kp_range = (-1, 1) 27 | dec_bone = "gauss_pointnetpp_feat" 28 | topk = 10 29 | kp_activation = "tanh" 30 | use_object_enc = True # separate object encoder 31 | use_object_dec = False # separate object decoder 32 | anchor_s = 0.125 33 | learn_order = False 34 | kl_balance = 0.001 35 | dropout = 0.0 36 | root = '/mnt/data/tal/celeba' 37 | path_to_model_ckpt = './checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth' 38 | 39 | model = KeyPointVAE(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 40 | image_size=image_size, n_kp=n_kp, learned_feature_dim=learned_feature_dim, 41 | use_logsoftmax=use_logsoftmax, pad_mode=pad_mode, sigma=sigma, 42 | dropout=dropout, dec_bone=dec_bone, patch_size=patch_size, n_kp_enc=n_kp_enc, 43 | n_kp_prior=n_kp_prior, kp_range=kp_range, kp_activation=kp_activation, 44 | mask_threshold=mask_threshold, use_object_enc=use_object_enc, 45 | use_object_dec=use_object_dec, anchor_s=anchor_s, learn_order=learn_order).to(device) 46 | model.load_state_dict( 47 | torch.load(path_to_model_ckpt, map_location=device)) 48 | print("loaded model from checkpoint") 49 | print('evaluating linear regression') 50 | result = evaluate_lin_reg_on_mafl_topk(model, root=root, batch_size=100, device=device, topk=topk) 51 | print(result) 52 | print(f"all kp (mu) err: {result['mu_err'] * 100}%") 53 | print(f"top-10 confident kp err: {result['mu_confident_err'] * 100}%") 54 | print(f"top-10 uncertain kp err: {result['mu_uncertainty_err'] * 100}%") 55 | print(f"all kp with logvar err: {result['logvar_err'] * 100}%") 56 | print(f"all kp with logvar and features err: {result['feat_err'] * 100}%") 57 | -------------------------------------------------------------------------------- /interactive_demo_dlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interactive demo to observe and explore the learned particles. 
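Usage sketch (editorial addition; assumes the matching pretrained checkpoint, e.g.
./checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth for CelebA-HQ, has been placed under ./checkpoints/):
    python interactive_demo_dlp.py -d celeba -i 0
Drag a particle on the left panel, or move the y/feature sliders on the right, to edit the
keypoints (and, when available, the last learned feature) and watch the decoded image update
on the right panel.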
3 | """ 4 | import torch 5 | from PIL import Image 6 | import torchvision.transforms as transforms 7 | import torchvision.transforms.functional as TF 8 | 9 | from models import KeyPointVAE 10 | from utils.util_func import reparameterize 11 | 12 | import matplotlib 13 | from matplotlib.widgets import Slider, Button 14 | import matplotlib as mpl 15 | from matplotlib import pyplot as plt 16 | import numpy as np 17 | import argparse 18 | 19 | matplotlib.use('Qt5Agg') 20 | 21 | 22 | def update_from_slider(val): 23 | for i in np.arange(N - 1): 24 | yvals[i] = sliders_y[i].val 25 | # xvals[i] = sliders_x[i].val 26 | if learned_feature_dim > 0: 27 | feature_1_vals[i] = sliders_features[i].val 28 | update(val) 29 | 30 | 31 | def update(val): 32 | global yvals 33 | global xvals 34 | global dec_bone 35 | global learned_feature_dim 36 | if learned_feature_dim > 0: 37 | global feature_1_vals 38 | # update curve 39 | for i in np.arange(N - 1): 40 | if learned_feature_dim > 0: 41 | # print(f'{i}: {feature_1_vals[i]}') 42 | feature_1_vals[i] = sliders_features[i].val 43 | # print(f'{i}: {feature_1_vals[i]}') 44 | l.set_offsets(np.c_[xvals, yvals]) 45 | # convert to tensors 46 | new_mu = torch.from_numpy(np.stack([yvals, xvals], axis=-1)).unsqueeze(0).to(device) / (image_size - 1) # [0, 1] 47 | new_mu = new_mu * (kp_range[1] - kp_range[0]) + kp_range[0] # [kp_range[0], kp_range[1]] 48 | new_mu = torch.cat([new_mu, original_mu[:, -1].unsqueeze(1)], dim=1) 49 | delta_mu = new_mu - original_mu 50 | # print(f'delta_mu: {delta_mu}') 51 | if learned_feature_dim > 0: 52 | new_features = torch.from_numpy(feature_1_vals[None, :, None]).to(device) 53 | new_features = torch.cat([mu_features[:, :, :-1], new_features], dim=-1) 54 | else: 55 | new_features = None 56 | with torch.no_grad(): 57 | rec_new, _, _ = model.decode_all(new_mu, new_features, kp_heatmap, obj_on, deterministic=deterministic, 58 | order_weights=order_weights) 59 | rec_new = rec_new.clamp(0, 1) 60 | 61 | image_rec_new = rec_new[0].permute(1, 2, 0).data.cpu().numpy() 62 | m.set_data(image_rec_new) 63 | # redraw canvas while idle 64 | fig.canvas.draw_idle() 65 | 66 | 67 | def reset(event): 68 | global yvals 69 | global xvals 70 | global learned_feature_dim 71 | if learned_feature_dim > 0: 72 | global feature_1_vals 73 | # reset the values 74 | xvals = mu[0, :-1, 1].data.cpu().numpy() * (image_size - 1) 75 | yvals = mu[0, :-1, 0].data.cpu().numpy() * (image_size - 1) 76 | if learned_feature_dim > 0: 77 | # a slider for the last feature dimension 78 | # feature_1_vals = mu_features[0, :, 0].data.cpu().numpy() 79 | feature_1_vals = mu_features[0, :, -1].data.cpu().numpy() 80 | for i in np.arange(N - 1): 81 | sliders_y[i].reset() 82 | if learned_feature_dim > 0: 83 | sliders_features[i].reset() 84 | l.set_offsets(np.c_[xvals, yvals]) 85 | m.set_data(image_rec) 86 | # redraw canvas while idle 87 | fig.canvas.draw_idle() 88 | 89 | 90 | def button_press_callback(event): 91 | 'whenever a mouse button is pressed' 92 | global pind 93 | if event.inaxes is None: 94 | return 95 | if event.button != 1: 96 | return 97 | pind = get_ind_under_point(event) 98 | 99 | 100 | def button_release_callback(event): 101 | 'whenever a mouse button is released' 102 | global pind 103 | if event.button != 1: 104 | return 105 | pind = None 106 | 107 | 108 | def get_ind_under_point(event): 109 | 'get the index of the vertex under point if within epsilon tolerance' 110 | t = ax1.transData.inverted() 111 | tinv = ax1.transData 112 | xy = t.transform([event.x, event.y]) 113 | xr = 
np.reshape(xvals, (np.shape(xvals)[0], 1)) 114 | yr = np.reshape(yvals, (np.shape(yvals)[0], 1)) 115 | xy_vals = np.append(xr, yr, 1) 116 | xyt = tinv.transform(xy_vals) 117 | xt, yt = xyt[:, 0], xyt[:, 1] 118 | d = np.hypot(xt - event.x, yt - event.y) 119 | indseq, = np.nonzero(d == d.min()) 120 | ind = indseq[0] 121 | 122 | if d[ind] >= epsilon: 123 | ind = None 124 | return ind 125 | 126 | 127 | def motion_notify_callback(event): 128 | 'on mouse movement' 129 | global xvals 130 | global yvals 131 | if pind is None: 132 | return 133 | if event.inaxes is None: 134 | return 135 | if event.button != 1: 136 | return 137 | 138 | # update yvals 139 | # print('motion x: {0}; y: {1}'.format(event.xdata,event.ydata)) 140 | # print(f'delta: x: {event.xdata - xvals[pind]}. y: {event.ydata - yvals[pind]}') 141 | delta_x = event.xdata - xvals[pind] 142 | delta_y = event.ydata - yvals[pind] 143 | yvals[pind] = event.ydata 144 | xvals[pind] = event.xdata 145 | 146 | # yvals[pind + 1] = yvals[pind + 1] + delta_y 147 | # xvals[pind + 1] = xvals[pind + 1] + delta_x 148 | 149 | update(None) 150 | 151 | # update curve via sliders and draw 152 | sliders_y[pind].set_val(yvals[pind]) 153 | # sliders_y[pind + 1].set_val(yvals[pind + 1]) 154 | 155 | # sliders_x[pind].set_val(xvals[pind]) 156 | fig.canvas.draw_idle() 157 | 158 | 159 | def bn_eval(model): 160 | """ 161 | https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67 162 | for batch_size = 1 don't use the running stats 163 | """ 164 | for m in model.modules(): 165 | for child in m.children(): 166 | if type(child) == torch.nn.BatchNorm2d or type(child) == torch.nn.BatchNorm1d: 167 | child.track_running_stats = False 168 | child.running_mean = None 169 | child.running_var = None 170 | 171 | 172 | if __name__ == '__main__': 173 | parser = argparse.ArgumentParser(description="DLP Interactive Demo") 174 | parser.add_argument("-d", "--dataset", type=str, default='celeba', 175 | help="dataset of pretrained model: ['celeba', 'traffic', 'clevrer']") 176 | parser.add_argument("-i", "--index", type=int, 177 | help="index of image in ./checkpoints/sample_images/dataset/", default=0) 178 | args = parser.parse_args() 179 | # hyper-parameters for model 180 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 181 | use_logsoftmax = False 182 | pad_mode = 'replicate' 183 | sigma = 0.1 # default sigma for the gaussian maps 184 | dropout = 0.0 185 | n_kp = 1 # num kp per patch 186 | kp_range = (-1, 1) 187 | kp_activation = "tanh" 188 | mask_threshold = 0.2 189 | learn_order = False 190 | 191 | ds = args.dataset 192 | 193 | image_idx = args.index 194 | image_idx = max(0, image_idx) 195 | 196 | if ds == 'celeba': 197 | path_to_model_ckpt = './checkpoints/dlp_celeba_gauss_pointnetpp_feat.pth' 198 | image_size = 128 199 | ch = 3 200 | enc_channels = [32, 64, 128, 256] 201 | prior_channels = (16, 32, 64) 202 | imwidth = 160 203 | crop = 16 204 | n_kp_enc = 30 # total kp to output from the encoder / filter from prior 205 | n_kp_prior = 50 # total kp to filter from prior 206 | use_object_enc = True 207 | use_object_dec = False 208 | learned_feature_dim = 10 209 | patch_size = 8 210 | anchor_s = 0.125 211 | dec_bone = "gauss_pointnetpp_feat" 212 | exclusive_patches = False 213 | elif ds == 'traffic': 214 | path_to_model_ckpt = './checkpoints/dlp_traffic_gauss_pointnetpp.pth' 215 | image_size = 128 216 | ch = 3 217 | enc_channels = [32, 64, 128, 256] 218 | prior_channels = (16, 32, 64) 219 | imwidth = 160 220 | 
crop = 16 221 | n_kp_enc = 15 # total kp to output from the encoder / filter from prior 222 | n_kp_prior = 20 # total kp to filter from prior 223 | use_object_enc = True 224 | use_object_dec = True 225 | learned_feature_dim = 20 226 | patch_size = 16 227 | anchor_s = 0.25 228 | dec_bone = "gauss_pointnetpp" 229 | exclusive_patches = False 230 | elif ds == 'clevrer': 231 | path_to_model_ckpt = './checkpoints/dlp_clevrer_gauss_pointnetpp_orig.pth' 232 | image_size = 128 233 | ch = 3 234 | enc_channels = [32, 64, 128, 256] 235 | prior_channels = (16, 32, 64) 236 | imwidth = 160 237 | crop = 16 238 | n_kp_enc = 10 # total kp to output from the encoder / filter from prior 239 | n_kp_prior = 20 # total kp to filter from prior 240 | use_object_enc = True 241 | use_object_dec = True 242 | learned_feature_dim = 5 243 | # learned_feature_dim = 8 244 | patch_size = 16 245 | anchor_s = 0.25 246 | dec_bone = "gauss_pointnetpp" 247 | exclusive_patches = False 248 | else: 249 | raise NotImplementedError 250 | 251 | model = KeyPointVAE(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 252 | image_size=image_size, n_kp=n_kp, learned_feature_dim=learned_feature_dim, 253 | use_logsoftmax=use_logsoftmax, pad_mode=pad_mode, sigma=sigma, 254 | dropout=dropout, dec_bone=dec_bone, patch_size=patch_size, n_kp_enc=n_kp_enc, 255 | n_kp_prior=n_kp_prior, kp_range=kp_range, kp_activation=kp_activation, 256 | mask_threshold=mask_threshold, use_object_enc=use_object_enc, 257 | exclusive_patches=exclusive_patches, use_object_dec=use_object_dec, anchor_s=anchor_s, 258 | learn_order=learn_order).to(device) 259 | model.load_state_dict(torch.load(path_to_model_ckpt, map_location=device), strict=False) 260 | model.eval() 261 | print("loaded model from checkpoint") 262 | logvar_threshold = 0.0 # threshold to filter particles 263 | if ds == 'celeba': 264 | # load image 265 | path_to_images = ['./checkpoints/sample_images/celeb/1.jpg', './checkpoints/sample_images/celeb/2.jpg', 266 | './checkpoints/sample_images/celeb/3.jpg', './checkpoints/sample_images/celeb/4.jpg', 267 | './checkpoints/sample_images/celeb/5.jpg', './checkpoints/sample_images/celeb/6.jpg', 268 | './checkpoints/sample_images/celeb/7.jpg', './checkpoints/sample_images/celeb/8.jpg'] 269 | image_idx = min(image_idx, len(path_to_images) - 1) 270 | path_to_image = path_to_images[image_idx] 271 | im = Image.open(path_to_image) 272 | # move head up a bit 273 | vertical_shift = 30 274 | initial_crop = lambda im: transforms.functional.crop(im, 30, 0, 178, 178) 275 | initial_transforms = transforms.Compose([initial_crop, transforms.Resize(imwidth)]) 276 | trans = transforms.ToTensor() 277 | data = trans(initial_transforms(im.convert("RGB"))) 278 | if crop != 0: 279 | data = data[:, crop:-crop, crop:-crop] 280 | data = data.unsqueeze(0).to(device) 281 | elif ds == 'traffic': 282 | path_to_images = ['./checkpoints/sample_images/traffic/1.png', ] 283 | image_idx = min(image_idx, len(path_to_images) - 1) 284 | path_to_image = path_to_images[image_idx] 285 | im = Image.open(path_to_image) 286 | im = im.convert('RGB') 287 | im = im.crop((60, 0, 480, 420)) 288 | im = im.resize((image_size, image_size), Image.BICUBIC) 289 | trans = transforms.ToTensor() 290 | data = trans(im) 291 | data = data.unsqueeze(0).to(device) 292 | x = data 293 | logvar_threshold = 14.0 # threshold to filter particles 294 | elif ds == 'clevrer': 295 | path_to_images = ['./checkpoints/sample_images/clevrer/1.png', 296 | './checkpoints/sample_images/clevrer/2.png'] 297 | image_idx = 
min(image_idx, len(path_to_images) - 1) 298 | path_to_image = path_to_images[image_idx] 299 | im = Image.open(path_to_image) 300 | im = im.convert('RGB') 301 | im = im.resize((image_size, image_size), Image.BICUBIC) 302 | trans = transforms.ToTensor() 303 | data = trans(im) 304 | data = data.unsqueeze(0).to(device) 305 | x = data 306 | logvar_threshold = 13.0 # threshold to filter particles 307 | else: 308 | raise NotImplementedError 309 | 310 | with torch.no_grad(): 311 | deterministic = True 312 | enc_out = model.encode_all(data, return_heatmap=True, deterministic=deterministic) 313 | mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = enc_out 314 | if deterministic: 315 | z = mu 316 | z_features = mu_features 317 | else: 318 | z = reparameterize(mu, logvar) 319 | z_features = reparameterize(mu_features, logvar_features) 320 | 321 | # top-k 322 | logvar_sum = logvar.sum(-1) 323 | logvar_topk = torch.topk(logvar_sum, k=5, dim=-1, largest=False) 324 | indices = logvar_topk[1] # [batch_size, topk] 325 | batch_indices = torch.arange(mu.shape[0]).view(-1, 1).to(mu.device) 326 | topk_kp = mu[batch_indices, indices] 327 | print(f'logvar: {logvar_sum[0].data.cpu()}') 328 | 329 | if learn_order: 330 | order_of_kp = [torch.argmax(order_weights[0][i]).item() for i in range(order_weights.shape[-1])] 331 | print(f'order of kp: {order_of_kp}') 332 | if obj_on is not None: 333 | obj_on = torch.where((torch.abs(logvar_sum[:, :-1]) > logvar_threshold), obj_on, 334 | torch.tensor(0.0, dtype=torch.float, device=obj_on.device)) 335 | print(f'obj_on: {obj_on[0].data.cpu()}') 336 | 337 | rec, _, _ = model.decode_all(z, z_features, kp_heatmap, obj_on, deterministic=deterministic, 338 | order_weights=order_weights) 339 | rec = rec.clamp(0, 1) 340 | 341 | N = mu.shape[1] 342 | xmin = 0 343 | xmax = image_size 344 | 345 | x = np.linspace(xmin, xmax, N) 346 | 347 | mu = mu.clamp(kp_range[0], kp_range[1]) 348 | original_mu = mu.clone() 349 | mu = (mu - kp_range[0]) / (kp_range[1] - kp_range[0]) 350 | xvals = mu[0, :-1, 1].data.cpu().numpy() * (image_size - 1) 351 | yvals = mu[0, :-1, 0].data.cpu().numpy() * (image_size - 1) 352 | if learned_feature_dim > 0: 353 | # feature_1_vals = mu_features[0, :, 0].data.cpu().numpy() 354 | feature_1_vals = mu_features[0, :, -1].data.cpu().numpy() 355 | 356 | # set up a plot for topk 357 | topk_kp = topk_kp.clamp(kp_range[0], kp_range[1]) 358 | topk_kp = (topk_kp - kp_range[0]) / (kp_range[1] - kp_range[0]) 359 | xvals_topk = topk_kp[0, :, 1].data.cpu().numpy() * (image_size - 1) 360 | yvals_topk = topk_kp[0, :, 0].data.cpu().numpy() * (image_size - 1) 361 | 362 | fig = plt.figure(figsize=(10, 10)) 363 | ax1 = fig.add_subplot(111) 364 | image = data[0].permute(1, 2, 0).data.cpu().numpy() 365 | ax1.imshow(image) 366 | ax1.scatter(xvals, yvals, label='original', s=70) 367 | ax1.set_axis_off() 368 | ax1.set_title('all particles') 369 | fig = plt.figure(figsize=(10, 10)) 370 | ax2 = fig.add_subplot(111) 371 | ax2.imshow(image) 372 | ax2.scatter(xvals_topk, yvals_topk, label='topk', s=70, color='red') 373 | ax2.set_axis_off() 374 | ax2.set_title('top-5 lowest variance particles') 375 | 376 | # figure.subplot.right 377 | mpl.rcParams['figure.subplot.right'] = 0.8 378 | 379 | # set up a plot 380 | fig, axes = plt.subplots(1, 2, figsize=(15.0, 15.0), sharex=True) 381 | ax1, ax2 = axes 382 | 383 | image = data[0].permute(1, 2, 0).data.cpu().numpy() 384 | ax1.imshow(image) 385 | ax1.set_axis_off() 386 | ax2.set_axis_off() 387 | image_rec = rec[0].permute(1, 2, 
0).data.cpu().numpy() 388 | m = ax2.imshow(image_rec) 389 | 390 | pind = None # active point 391 | epsilon = 10 # max pixel distance 392 | 393 | ax1.scatter(xvals, yvals, label='original', s=70) 394 | l = ax1.scatter(xvals, yvals, color='red', marker='*', s=70) 395 | 396 | ax1.set_xlabel('x') 397 | ax1.set_ylabel('y') 398 | 399 | sliders_y = [] 400 | sliders_x = [] 401 | if learned_feature_dim > 0: 402 | sliders_features = [] 403 | 404 | for i in np.arange(N - 1): 405 | slider_width = 0.04 if learned_feature_dim > 0 else 0.12 406 | axamp = plt.axes([0.84, 0.85 - (i * 0.025), slider_width, 0.01]) 407 | # Slider y 408 | s_y = Slider(axamp, 'p_y{0}'.format(i), 0, image_size, valinit=yvals[i]) 409 | sliders_y.append(s_y) 410 | if learned_feature_dim > 0: 411 | axamp_f = plt.axes([0.93, 0.85 - (i * 0.025), slider_width, 0.01]) 412 | s_feat = Slider(axamp_f, f'f_1', -5, 5, valinit=feature_1_vals[i]) 413 | sliders_features.append(s_feat) 414 | 415 | for i in np.arange(N - 1): 416 | sliders_y[i].on_changed(update_from_slider) 417 | if learned_feature_dim > 0: 418 | sliders_features[i].on_changed(update_from_slider) 419 | 420 | axres = plt.axes([0.84, 0.85 - ((N) * 0.025), 0.12, 0.01]) 421 | bres = Button(axres, 'Reset') 422 | bres.on_clicked(reset) 423 | 424 | fig.canvas.mpl_connect('button_press_event', button_press_callback) 425 | fig.canvas.mpl_connect('button_release_event', button_release_callback) 426 | fig.canvas.mpl_connect('motion_notify_event', motion_notify_callback) 427 | 428 | plt.show() 429 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main DLP model neural network. 3 | """ 4 | # imports 5 | import numpy as np 6 | # torch 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | # modules 11 | from modules.modules import KeyPointCNNOriginal, VariationalKeyPointPatchEncoder, SpatialSoftmaxKP, SpatialLogSoftmaxKP, \ 12 | ToGaussianMapHW, CNNDecoder, ObjectEncoder, ObjectDecoderCNN, PointNetPPToCNN 13 | # util functions 14 | from utils.util_func import reparameterize, get_kp_mask_from_gmap, create_masks_fast 15 | 16 | 17 | class KeyPointVAE(nn.Module): 18 | def __init__(self, cdim=3, enc_channels=(16, 16, 32), prior_channels=(16, 16, 32), image_size=64, n_kp=1, 19 | use_logsoftmax=False, pad_mode='replicate', sigma=0.1, dropout=0.0, dec_bone="gauss_pointnetpp", 20 | patch_size=16, n_kp_enc=20, n_kp_prior=20, learned_feature_dim=16, 21 | kp_range=(-1, 1), kp_activation="tanh", mask_threshold=0.2, anchor_s=0.25, 22 | use_object_enc=False, use_object_dec=False, learn_order=False, exclusive_patches=False): 23 | super(KeyPointVAE, self).__init__() 24 | """ 25 | cdim: channels of the input image (3...) 26 | enc_channels: channels for the posterior CNN (takes in the whole image) 27 | prior_channels: channels for prior CNN (takes in patches) 28 | n_kp: number of kp to extract from each (!) patch 29 | n_kp_prior: number of kp to filter from the set of prior kp (of size n_kp x num_patches) 30 | n_kp_enc: number of posterior kp to be learned (this is the actual number of kp that will be learnt) 31 | use_logsoftmax: for spatial-softmax, set True to use log-softmax for numerical stability 32 | pad_mode: padding for the CNNs, 'zeros' or 'replicate' (default) 33 | sigma: the prior std of the KP 34 | dropout: dropout for the CNNs. We don't use it though... 
35 | dec_bone: decoder backbone -- "gauss_pointnetpp_feat": Masked Model, "gauss_pointnetpp": "Object Model 36 | patch_size: patch size for the prior KP proposals network (not to be confused with the glimpse size) 37 | kp_range: the range of keypoints, can be [-1, 1] (default) or [0,1] 38 | learned_feature_dim: the latent visual features dimensions extracted from glimpses. 39 | kp_activation: the type of activation to apply on the keypoints: "tanh" for kp_range [-1, 1], "sigmoid" for [0, 1] 40 | mask_threshold: activation threshold (>thresh -> 1, else 0) for the binary mask created from the Gaussian-maps. 41 | anchor_s: defines the glimpse size as a ratio of image_size (e.g., 0.25 for image_size=128 -> glimpse_size=32) 42 | learn_order: experimental feature to learn the order of keypoints - but it doesn't work yet. 43 | use_object_enc: set True to use a separate encoder to encode visual features of glimpses. 44 | use_object_dec: set True to use a separate decoder to decode glimpses (Object Model). 45 | exclusive_patches: (mostly) enforce one particle pre object by masking up regions that were already encoded. 46 | """ 47 | if dec_bone not in ["gauss_pointnetpp", "gauss_pointnetpp_feat"]: 48 | raise SystemError(f'unrecognized decoder backbone: {dec_bone}') 49 | print(f'decoder backbone: {dec_bone}') 50 | self.dec_bone = dec_bone 51 | self.image_size = image_size 52 | self.use_logsoftmax = use_logsoftmax 53 | self.sigma = sigma 54 | print(f'prior std: {self.sigma}') 55 | self.dropout = dropout 56 | self.kp_range = kp_range 57 | print(f'keypoints range: {self.kp_range}') 58 | self.num_patches = int((image_size // patch_size) ** 2) 59 | self.n_kp = n_kp 60 | self.n_kp_total = self.n_kp * self.num_patches 61 | self.n_kp_prior = min(self.n_kp_total, n_kp_prior) 62 | print(f'total number of kp: {self.n_kp_total} -> prior kp: {self.n_kp_prior}') 63 | self.n_kp_enc = n_kp_enc 64 | print(f'number of kp from encoder: {self.n_kp_enc}') 65 | self.kp_activation = kp_activation 66 | print(f'kp_activation: {self.kp_activation}') 67 | self.patch_size = patch_size 68 | self.features_dim = int(image_size // (2 ** (len(enc_channels) - 1))) 69 | self.learned_feature_dim = learned_feature_dim 70 | print(f'learnable feature dim: {learned_feature_dim}') 71 | self.mask_threshold = mask_threshold 72 | print(f'mask threshold: {self.mask_threshold}') 73 | self.anchor_s = anchor_s 74 | self.obj_patch_size = np.round(anchor_s * (image_size - 1)).astype(int) 75 | print(f'object patch size: {self.obj_patch_size}') 76 | self.use_object_enc = True if use_object_dec else use_object_enc 77 | self.use_object_dec = use_object_dec 78 | print(f'object encoder: {self.use_object_enc}, object decoder: {self.use_object_dec}') 79 | self.learn_order = learn_order 80 | print(f'learn particles order: {self.learn_order}') 81 | self.exclusive_patches = exclusive_patches 82 | 83 | # encoder 84 | self.enc = KeyPointCNNOriginal(cdim=cdim, channels=enc_channels, image_size=image_size, n_kp=self.n_kp_enc, 85 | pad_mode=pad_mode, use_resblock=False) 86 | enc_output_dim = 2 * 2 87 | # flatten feature maps and extract statistics 88 | self.to_normal_stats = nn.Sequential(nn.Linear(self.n_kp_enc * self.features_dim ** 2, 256), 89 | nn.ReLU(True), 90 | nn.Linear(256, 128), 91 | nn.ReLU(True), 92 | nn.Linear(128, self.n_kp_enc * enc_output_dim)) 93 | if self.use_object_dec: 94 | if self.learn_order: 95 | enc_aux_output_dim = 1 + self.n_kp_enc # obj_on, ordering weights 96 | else: 97 | enc_aux_output_dim = 1 # obj_on 98 | self.aux_enc = 
nn.Sequential(nn.Linear(self.n_kp_enc * self.features_dim ** 2, 256), 99 | nn.ReLU(True), 100 | nn.Linear(256, 128), 101 | nn.ReLU(True), 102 | nn.Linear(128, self.n_kp_enc * enc_aux_output_dim)) 103 | else: 104 | self.aux_enc = None 105 | # object encoder 106 | object_enc_output_dim = self.learned_feature_dim * 2 # [mu_features, sigma_features] 107 | self.object_enc = nn.Sequential(nn.Linear(self.n_kp_enc * self.features_dim ** 2, 256), 108 | nn.ReLU(True), 109 | nn.Linear(256, 128), 110 | nn.ReLU(True), 111 | nn.Linear(128, object_enc_output_dim)) 112 | if self.use_object_enc: 113 | if self.use_object_dec: 114 | self.object_enc_sep = ObjectEncoder(z_dim=learned_feature_dim, anchor_size=anchor_s, 115 | image_size=image_size, ch=cdim, margin=0, cnn=True) 116 | else: 117 | self.object_enc_sep = ObjectEncoder(z_dim=learned_feature_dim, anchor_size=anchor_s, 118 | image_size=self.features_dim, ch=self.n_kp_enc, 119 | margin=0, cnn=False, encode_location=True) 120 | else: 121 | self.object_enc_sep = None 122 | self.prior = VariationalKeyPointPatchEncoder(cdim=cdim, channels=prior_channels, image_size=image_size, 123 | n_kp=n_kp, kp_range=self.kp_range, 124 | patch_size=patch_size, use_logsoftmax=use_logsoftmax, 125 | pad_mode=pad_mode, sigma=sigma, dropout=dropout, 126 | learnable_logvar=False, learned_feature_dim=0) 127 | self.ssm = SpatialLogSoftmaxKP(kp_range=kp_range) if use_logsoftmax else SpatialSoftmaxKP(kp_range=kp_range) 128 | 129 | # decoder 130 | decoder_n_kp = 3 * self.n_kp_enc if self.dec_bone == "gauss_pointnetpp_feat" else 2 * self.n_kp_enc 131 | self.to_gauss_map = ToGaussianMapHW(sigma_w=sigma, sigma_h=sigma, kp_range=kp_range) 132 | self.pointnet = PointNetPPToCNN(axis_dim=2, target_hw=self.features_dim, 133 | n_kp=self.n_kp_enc, features_dim=self.learned_feature_dim, 134 | pad_mode=pad_mode) 135 | self.dec = CNNDecoder(cdim=cdim, channels=enc_channels, image_size=image_size, in_ch=decoder_n_kp, 136 | n_kp=self.n_kp_enc + 1, pad_mode=pad_mode) 137 | # object decoder 138 | if self.use_object_dec: 139 | self.object_dec = ObjectDecoderCNN(patch_size=(self.obj_patch_size, self.obj_patch_size), num_chans=4, 140 | bottleneck_size=learned_feature_dim) 141 | else: 142 | self.object_dec = None 143 | self.init_weights() 144 | 145 | def get_parameters(self, prior=True, encoder=True, decoder=True): 146 | parameters = [] 147 | if prior: 148 | parameters.extend(list(self.prior.parameters())) 149 | if encoder: 150 | parameters.extend(list(self.enc.parameters())) 151 | parameters.extend(list(self.to_normal_stats.parameters())) 152 | parameters.extend(list(self.object_enc.parameters())) 153 | if self.use_object_enc: 154 | parameters.extend(list(self.object_enc_sep.parameters())) 155 | if self.use_object_dec: 156 | parameters.extend(list(self.aux_enc.parameters())) 157 | if decoder: 158 | parameters.extend(list(self.dec.parameters())) 159 | parameters.extend(list(self.pointnet.parameters())) 160 | if self.use_object_dec: 161 | parameters.extend(list(self.object_dec.parameters())) 162 | return parameters 163 | 164 | def set_require_grad(self, prior_value=True, enc_value=True, dec_value=True): 165 | for param in self.prior.parameters(): 166 | param.requires_grad = prior_value 167 | for param in self.enc.parameters(): 168 | param.requires_grad = enc_value 169 | for param in self.to_normal_stats.parameters(): 170 | param.requires_grad = enc_value 171 | for param in self.object_enc.parameters(): 172 | param.requires_grad = enc_value 173 | if self.use_object_enc: 174 | for param in 
self.object_enc_sep.parameters(): 175 | param.requires_grad = enc_value 176 | if self.use_object_dec: 177 | for param in self.aux_enc.parameters(): 178 | param.requires_grad = enc_value 179 | for param in self.dec.parameters(): 180 | param.requires_grad = dec_value 181 | for param in self.pointnet.parameters(): 182 | param.requires_grad = dec_value 183 | if self.use_object_dec: 184 | for param in self.object_dec.parameters(): 185 | param.requires_grad = dec_value 186 | 187 | def init_weights(self): 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.normal_(m.weight, 0, 0.01) 191 | if m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 194 | nn.init.constant_(m.weight, 1) 195 | nn.init.constant_(m.bias, 0) 196 | elif isinstance(m, nn.Linear): 197 | # use pytorch's default 198 | pass 199 | 200 | def encode(self, x, return_heatmap=False, mask=None): 201 | _, z_kp = self.enc(x) # [batch_size, n_kp, features_dim, features_dim] 202 | if mask is None: 203 | masked_hm = z_kp 204 | else: 205 | masked_hm = mask * z_kp 206 | z_kp_v = masked_hm.view(masked_hm.shape[0], -1) # [batch_size, n_kp * features_dim * features_dim] 207 | stats = self.to_normal_stats(z_kp_v) # [batch_size, n_kp * 4] 208 | stats = stats.view(stats.shape[0], self.n_kp_enc, 2 * 2) 209 | # [batch_size, n_kp, 4 + learned_feature_dim * 2] 210 | mu_enc, logvar_enc = torch.chunk(stats, chunks=2, dim=-1) # [batch_size, n_kp, 2 + learned_feature_dim] 211 | mu, logvar = mu_enc[:, :, :2], logvar_enc[:, :, :2] # [x, y] 212 | 213 | logvar_p = torch.log(torch.tensor(self.sigma ** 2, device=logvar.device)) 214 | if self.use_object_dec: 215 | stats_aux = self.aux_enc(z_kp_v.detach()) 216 | if self.learn_order: 217 | stats_aux = stats_aux.view(stats_aux.shape[0], self.n_kp_enc, 1 + self.n_kp_enc) 218 | order_weights = stats_aux[:, :, 1:] 219 | else: 220 | stats_aux = stats_aux.view(stats_aux.shape[0], self.n_kp_enc, 1) 221 | order_weights = None 222 | mu_obj_weight = stats_aux[:, :, 0] 223 | mu_obj_weight = torch.sigmoid(mu_obj_weight) 224 | else: 225 | mu_obj_weight = None 226 | order_weights = None 227 | mu_features, logvar_features = None, None 228 | if self.kp_activation == "tanh": 229 | mu = torch.tanh(mu) 230 | elif self.kp_activation == "sigmoid": 231 | mu = torch.sigmoid(mu) 232 | 233 | mu = torch.cat([mu, torch.zeros_like(mu[:, 0]).unsqueeze(1)], dim=1) 234 | logvar = torch.cat([logvar, logvar_p * torch.ones_like(logvar[:, 0]).unsqueeze(1)], dim=1) 235 | 236 | if return_heatmap: 237 | return mu, logvar, z_kp, mu_features, logvar_features, mu_obj_weight, order_weights 238 | else: 239 | return mu, logvar, mu_features, logvar_features, mu_obj_weight, order_weights 240 | 241 | def encode_object_features(self, features_map, masks): 242 | # features_map [bs, n_kp, feature_dim, feature_dim] 243 | # masks: [bs, n_kp + 1, feature_dim, feature_dim] 244 | y = masks.unsqueeze(2) * features_map.unsqueeze(1) # [bs, n_kp + 1, n_kp, feature_dim, feature_dim] 245 | y = y.view(y.shape[0], y.shape[1], -1) # [bs, n_kp + 1, n_kp * feature_dim ** 2] 246 | enc_out = self.object_enc(y) # [bs, n_kp + 1, learned_feature_dim * 2] 247 | mu_features, logvar_features = torch.chunk(enc_out, chunks=2, dim=-1) # [bs, n_kp + 1, learned_feature_dim] 248 | return mu_features, logvar_features 249 | 250 | def encode_object_features_sep(self, x, kp, features_map, masks, exclusive_patches=False, obj_on=None): 251 | # x: [bs, ch, image_size, image_size] 252 | # kp :[bs, n_kp, 2] 253 | # 
features_map: [bs, n_kp, features_dim, features_dim] 254 | # masks: [bs, n_kp, features_dim, features_dim] 255 | 256 | batch_size, n_kp, features_dim, _ = masks.shape 257 | 258 | # object features 259 | obj_enc_out = self.object_enc_sep(x, kp.detach(), exclusive_patches=exclusive_patches, obj_on=obj_on) 260 | mu_obj, logvar_obj, cropped_objects = obj_enc_out[0], obj_enc_out[1], obj_enc_out[2] 261 | if len(obj_enc_out) > 3: 262 | cropped_objects_masks = obj_enc_out[3] 263 | else: 264 | cropped_objects_masks = None 265 | 266 | # bg beatures 267 | if self.use_object_dec: 268 | obj_fmap_masks = create_masks_fast(kp.detach(), anchor_s=self.anchor_s, feature_dim=self.features_dim) 269 | bg_mask = 1 - obj_fmap_masks.squeeze(2).sum(1, keepdim=True).clamp(0, 1) 270 | # [bs, 1, features_dim, features_dim] 271 | else: 272 | bg_mask = masks[:, -1].unsqueeze(1) # [bs, 1, features_dim, features_dim] 273 | masked_features = bg_mask.unsqueeze(2) * features_map.unsqueeze(1) # [bs, 1, n_kp, f_dim, f_dim] 274 | masked_features = masked_features.view(batch_size, masked_features.shape[1], -1) # flatten 275 | object_enc_out = self.object_enc(masked_features) # [bs, 1, 2 * learned_features_dim] 276 | mu_bg, logvar_bg = object_enc_out.chunk(2, dim=-1) 277 | 278 | mu_features = torch.cat([mu_obj, mu_bg], dim=1) 279 | logvar_features = torch.cat([logvar_obj, logvar_bg], dim=1) 280 | 281 | return mu_features, logvar_features, cropped_objects, cropped_objects_masks 282 | 283 | def encode_all(self, x, return_heatmap=False, mask=None, deterministic=False): 284 | # posterior 285 | enc_out = self.encode(x, return_heatmap=True, mask=mask) 286 | mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = enc_out 287 | if deterministic: 288 | z = mu 289 | else: 290 | z = reparameterize(mu, logvar) 291 | gmap_1_fg = self.to_gauss_map(z[:, :-1], self.features_dim, self.features_dim) 292 | fg_masks_sep = get_kp_mask_from_gmap(gmap_1_fg, threshold=self.mask_threshold, binary=True, 293 | elementwise=True).detach() 294 | fg_masks = fg_masks_sep.sum(1, keepdim=True).clamp(0, 1) 295 | bg_masks = 1 - fg_masks 296 | masks_sep = torch.cat([fg_masks_sep, bg_masks], dim=1) 297 | 298 | if self.learned_feature_dim > 0: 299 | if self.use_object_enc: 300 | feat_source = x if self.use_object_dec else kp_heatmap.detach() 301 | obj_enc_out = self.encode_object_features_sep(feat_source, mu[:, :-1], kp_heatmap.detach(), 302 | masks_sep.detach()) 303 | mu_features, logvar_features, cropped_objects = obj_enc_out[0], obj_enc_out[1], obj_enc_out[2] 304 | else: 305 | mu_features, logvar_features = self.encode_object_features(kp_heatmap.detach(), masks_sep) 306 | if return_heatmap: 307 | return mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights 308 | else: 309 | return mu, logvar, mu_features, logvar_features, obj_on, order_weights 310 | 311 | def encode_prior(self, x): 312 | return self.prior(x) 313 | 314 | def decode(self, z): 315 | return self.dec(z) 316 | 317 | def get_prior_kp(self, x, probs=False): 318 | _, z = self.encode_prior(x) 319 | return self.ssm(z, probs) 320 | 321 | def translate_patches(self, kp_batch, patches_batch, scale=None, translation=None): 322 | """ 323 | translate patches to be centered around given keypoints 324 | kp_batch: [bs, n_kp, 2] in [-1, 1] 325 | patches: [bs, n_kp, ch_patches, patch_size, patch_size] 326 | scale: None or [bs, n_kp, 2] or [bs, n_kp, 1] 327 | translation: None or [bs, n_kp, 2] or [bs, n_kp, 1] (delta from kp) 328 | :return: translated_padded_pathces [bs, n_kp, 
ch, img_size, img_size] 329 | """ 330 | batch_size, n_kp, ch_patch, patch_size, _ = patches_batch.shape 331 | img_size = self.image_size 332 | pad_size = (img_size - patch_size) // 2 333 | padded_patches_batch = F.pad(patches_batch, pad=[pad_size] * 4) 334 | delta_t_batch = 0.0 - kp_batch 335 | delta_t_batch = delta_t_batch.reshape(-1, delta_t_batch.shape[-1]) # [bs * n_kp, 2] 336 | padded_patches_batch = padded_patches_batch.reshape(-1, *padded_patches_batch.shape[2:]) 337 | # [bs * n_kp, 3, patch_size, patch_size] 338 | zeros = torch.zeros([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float() 339 | ones = torch.ones([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float() 340 | 341 | if scale is None: 342 | scale_w = ones 343 | scale_h = ones 344 | elif scale.shape[-1] == 1: 345 | scale_w = scale[:, :-1].reshape(-1, scale.shape[-1]) # no need for bg kp 346 | scale_h = scale[:, :-1].reshape(-1, scale.shape[-1]) # no need for bg kp 347 | else: 348 | scale_h, scale_w = torch.split(scale[:, :-1], [1, 1], dim=-1) 349 | scale_w = scale_w.reshape(-1, scale_w.shape[-1]) 350 | scale_h = scale_h.reshape(-1, scale_h.shape[-1]) 351 | if translation is None: 352 | trans_w = zeros 353 | trans_h = zeros 354 | elif translation.shape[-1] == 1: 355 | trans_w = translation[:, :-1].reshape(-1, translation.shape[-1]) # no need for bg kp 356 | trans_h = translation[:, :-1].reshape(-1, translation.shape[-1]) # no need for bg kp 357 | else: 358 | trans_h, trans_w = torch.split(translation[:, :-1], [1, 1], dim=-1) 359 | trans_w = trans_w.reshape(-1, trans_w.shape[-1]) 360 | trans_h = trans_h.reshape(-1, trans_h.shape[-1]) 361 | 362 | theta = torch.cat([scale_h, zeros, delta_t_batch[:, 1].unsqueeze(-1) + trans_h, 363 | zeros, scale_w, delta_t_batch[:, 0].unsqueeze(-1) + trans_w], dim=-1) 364 | 365 | theta = theta.view(-1, 2, 3) # [batch_size * n_kp, 2, 3] 366 | align_corners = False 367 | padding_mode = 'zeros' 368 | # mode = "nearest" 369 | mode = 'bilinear' 370 | 371 | grid = F.affine_grid(theta, padded_patches_batch.size(), align_corners=align_corners) 372 | trans_padded_patches_batch = F.grid_sample(padded_patches_batch, grid, align_corners=align_corners, 373 | mode=mode, padding_mode=padding_mode) 374 | 375 | trans_padded_patches_batch = trans_padded_patches_batch.view(batch_size, n_kp, *padded_patches_batch.shape[1:]) 376 | # [bs, n_kp, ch, img_size, img_size] 377 | return trans_padded_patches_batch 378 | 379 | def get_objects_alpha_rgb(self, z_kp, z_features, scale=None, translation=None, deterministic=False, 380 | order_weights=None): 381 | dec_objects = self.object_dec(z_features[:, :-1]) # [bs * n_kp, 4, patch_size, patch_size] 382 | dec_objects = dec_objects.view(-1, self.n_kp_enc, 383 | *dec_objects.shape[1:]) # [bs, n_kp, 4, patch_size, patch_size] 384 | # translate patches 385 | dec_objects_trans = self.translate_patches(z_kp[:, :-1], dec_objects, scale, translation) 386 | dec_objects_trans = dec_objects_trans.clamp(0, 1) # STN can change values to be < 0 387 | # dec_objects_trans: [bs, n_kp, 3, im_size, im_size] 388 | if order_weights is not None: 389 | # for each particle, we get a one-hot vector of its place in the order 390 | # we then move all of its maps [4, h, w] to its new place via 1x1 grouped-convolution (group_size=batch_size) 391 | bs, n_kp, n_ch, h, w = dec_objects_trans.shape 392 | order_weights = order_weights.view(order_weights.shape[0], self.n_kp_enc, self.n_kp_enc, 1, 1) 393 | order_weights = F.gumbel_softmax(order_weights, hard=True, dim=1) # straight-through 
gradients (hard=True) 394 | # order weights: [bs, n_kp, n_kp, 1, 1] - for each kp, its location in the order, in one-hot form 395 | # i.e., if kp 1 is 6 in the order of 8 kp, then its vector: [0 0 0 0 0 0 1 0] 396 | order_weights = order_weights.view(order_weights.shape[0] * self.n_kp_enc, self.n_kp_enc, 1, 1) 397 | reordered_objects = dec_objects_trans.reshape(1, -1, h * n_ch, w) # [1, bs * n_kp, h * n_ch, w] 398 | ordered_objects = F.conv2d(reordered_objects, order_weights, bias=None, stride=1, groups=bs) 399 | ordered_objects = ordered_objects.view(bs, n_kp, n_ch, h, w) 400 | dec_objects_trans = ordered_objects 401 | 402 | # multiply by alpha channel 403 | a_obj, rgb_obj = torch.split(dec_objects_trans, [1, 3], dim=2) 404 | 405 | if not deterministic: 406 | attn_mask = torch.where(a_obj > 0.1, 1.0, 0.0) 407 | # attn_mask = self.to_gauss_map(z_kp[:, :-1], a_obj.shape[-1], a_obj.shape[-1]).unsqueeze( 408 | # 2).detach() 409 | a_obj = a_obj + self.sigma * torch.randn_like(a_obj) * attn_mask 410 | return dec_objects, a_obj, rgb_obj 411 | 412 | def stitch_objects(self, a_obj, rgb_obj, obj_on, bg, stitch_method='c'): 413 | # turn off inactive kp 414 | # obj_on: [bs, n_kp, 1] 415 | a_obj = obj_on[:, :, None, None, None] * a_obj # [bs, n_kp, 4, im_size, im_size] 416 | if stitch_method == 'a': 417 | # layer-wise stitching, each particle is a layer 418 | # x_0 = bg 419 | # x_i = (1-a_i) * x_(i-1) + a_i * rgb_i 420 | rec = bg 421 | curr_mask = a_obj[:, 0] 422 | comp_masks = [curr_mask] # to calculate the effective mask, only for plotting 423 | for i in range(a_obj.shape[1]): 424 | rec = (1 - a_obj[:, i]) * rec + a_obj[:, i] * rgb_obj[:, i] 425 | # rec = (1 - a_obj[:, i].detach()) * rec + a_obj[:, i] * rgb_obj[:, i] 426 | # what is the effect of this? bad, masks are not learned properly 427 | if i > 0: 428 | available_space = 1.0 - curr_mask.detach() 429 | curr_mask_tmp = torch.min(available_space, a_obj[:, i]) 430 | comp_masks.append(curr_mask_tmp) 431 | curr_mask = curr_mask + curr_mask_tmp 432 | comp_masks = torch.stack(comp_masks, dim=1) 433 | dec_objects_trans = comp_masks * rgb_obj 434 | dec_objects_trans = dec_objects_trans.sum(1) 435 | elif stitch_method == 'b': 436 | # same formula as method 'a', but with detach and opening the recursive formula 437 | # x_n = bg * \prod_{i=1}^n (1-a_i) + a_n * rgb_n + a_(n-1) * rgb_(n-1) * (1-a_n) + ... 438 | # + a_1 * rgb_1 * \prod_{i=1}^{n-1} (1-a_i) 439 | bg_comp = torch.prod(1 - a_obj, dim=1) * bg 440 | obj = a_obj * rgb_obj 441 | # stitch 442 | rec = obj[:, -1] 443 | for i in reversed(range(a_obj.shape[1] - 1)): 444 | rec = rec + obj[:, i] * torch.prod((1 - a_obj[:, i + 1:].detach()), dim=1) 445 | dec_objects_trans = rec.detach() 446 | rec = rec + bg_comp 447 | else: 448 | # alpha-based stitching: we first calculate the effective masks, assuming the previous 449 | # masks already occupy some space that cannot be taken and finally we multiply the effective masks 450 | # by the rgb channel, the bg mask is the space left from the sum of all effective masks. 
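# Editorial note (added for clarity, not in the original source): in equation form, with a_i the
# obj_on-scaled alpha maps computed above, the effective masks built below are
#     m_1 = a_1,    m_i = min(1 - sum_{j<i} m_j, a_i)    for i > 1,
# and the composite image is
#     rec = (1 - clamp(sum_i m_i, 0, 1)) * bg + sum_i m_i * rgb_i.
# Tiny scalar sketch: if two particles overlap at a pixel with a_1 = 0.7 and a_2 = 0.6, then
# m_1 = 0.7 and m_2 = min(1 - 0.7, 0.6) = 0.3, so the first particle occludes most of the second
# and no background (1 - 1.0 = 0) shows through at that pixel.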
451 | curr_mask = a_obj[:, 0] 452 | comp_masks = [curr_mask] 453 | for i in range(1, a_obj.shape[1]): 454 | available_space = 1.0 - curr_mask.detach() 455 | curr_mask_tmp = torch.min(available_space, a_obj[:, i]) 456 | comp_masks.append(curr_mask_tmp) 457 | curr_mask = curr_mask + curr_mask_tmp 458 | comp_masks = torch.stack(comp_masks, dim=1) 459 | comp_masks_sum = comp_masks.sum(1).clamp(0, 1) 460 | alpha_mask = 1.0 - comp_masks_sum 461 | dec_objects_trans = comp_masks * rgb_obj 462 | dec_objects_trans = dec_objects_trans.sum(1) # [bs, 3, im_size, im_size] 463 | rec = alpha_mask * bg + dec_objects_trans 464 | return rec, dec_objects_trans 465 | 466 | def decode_objects(self, z_kp, z_features, obj_on, scale=None, translation=None, deterministic=False, 467 | order_weights=None, bg=None): 468 | dec_objects, a_obj, rgb_obj = self.get_objects_alpha_rgb(z_kp, z_features, scale=scale, translation=translation, 469 | deterministic=deterministic, 470 | order_weights=order_weights) 471 | if bg is None: 472 | bg = torch.zeros_like(rgb_obj[:, 0]) 473 | # stitching 474 | rec, dec_objects_trans = self.stitch_objects(a_obj, rgb_obj, obj_on=obj_on, bg=bg) 475 | return dec_objects, dec_objects_trans, rec 476 | 477 | def decode_all(self, z, z_features, kp_heatmap, obj_on, deterministic=False, order_weights=None): 478 | gmap_1_fg = self.to_gauss_map(z[:, :-1], self.features_dim, self.features_dim) 479 | gmap_1_bg = 1 - gmap_1_fg.sum(1, keepdim=True).clamp(0, 1).detach() 480 | gmap_1 = torch.cat([gmap_1_fg, gmap_1_bg], dim=1) 481 | fg_masks_sep = get_kp_mask_from_gmap(gmap_1_fg, threshold=self.mask_threshold, binary=True, 482 | elementwise=True).detach() 483 | fg_masks = fg_masks_sep.sum(1, keepdim=True).clamp(0, 1) 484 | bg_masks = 1 - fg_masks 485 | masks = torch.cat([fg_masks.expand_as(gmap_1_fg), bg_masks], dim=1) 486 | # decode object and translate them to the positions of the keypoints 487 | # decode 488 | z_features_in = z_features 489 | if self.dec_bone == "gauss_pointnetpp": 490 | if self.learned_feature_dim > 0: 491 | gmap_2 = self.pointnet(position=z.detach(), 492 | features=torch.cat([z.detach(), z_features_in], dim=-1)) 493 | else: 494 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach()) 495 | gmap = torch.cat([gmap_1[:, :-1], gmap_2], dim=1) 496 | elif self.dec_bone == "gauss_pointnetpp_feat": 497 | if self.learned_feature_dim > 0: 498 | gmap_2 = self.pointnet(position=z.detach(), 499 | features=torch.cat([z.detach(), z_features_in], dim=-1)) 500 | else: 501 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach()) 502 | 503 | fg_masks = masks[:, :-1] 504 | bg_masks = masks[:, -1].unsqueeze(1) 505 | gmap_2 = fg_masks * gmap_2 506 | gmap_3 = bg_masks * kp_heatmap.detach() 507 | gmap = torch.cat([gmap_1[:, :-1], gmap_2, gmap_3], dim=1) 508 | else: 509 | raise NotImplementedError('grow a dec bone') 510 | rec = self.dec(gmap) 511 | 512 | if z_features is not None and self.use_object_dec: 513 | object_dec_out = self.decode_objects(z, z_features, obj_on, deterministic=deterministic, 514 | order_weights=order_weights, bg=rec) 515 | dec_objects, dec_objects_trans, rec = object_dec_out 516 | else: 517 | dec_objects_trans = None 518 | dec_objects = None 519 | return rec, dec_objects, dec_objects_trans 520 | 521 | def forward(self, x, deterministic=False, detach_decoder=False, x_prior=None, warmup=False, stg=False, 522 | noisy_masks=False): 523 | # stg: straight-through-gradients. not used. 
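        # Editorial shape reference (added for readability; inferred from the code below):
        # for input x of shape [bs, ch, image_size, image_size], the returned output_dict holds
        #   'kp_p'    : [bs, n_kp_prior, 2]              prior keypoint proposals
        #   'mu'      : [bs, n_kp_enc + 1, 2]            posterior keypoint means (last entry is the fixed bg kp)
        #   'logvar'  : [bs, n_kp_enc + 1, 2]            posterior keypoint log-variances
        #   'rec'     : [bs, ch, image_size, image_size] full reconstruction
        #   'mu_features', 'logvar_features' : [bs, n_kp_enc + 1, learned_feature_dim]  (when learned_feature_dim > 0)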
524 | # first, extract prior KP proposals 525 | # prior 526 | if x_prior is None: 527 | x_prior = x 528 | kp_p = self.prior(x_prior, global_kp=True) 529 | kp_p = kp_p.view(x_prior.shape[0], -1, 2) # [batch_size, n_kp_total, 2] 530 | # filter proposals by distance to the patches' center 531 | dist_from_center = self.prior.get_distance_from_patch_centers(kp_p, global_kp=True) 532 | _, indices = torch.topk(dist_from_center, k=self.n_kp_prior, dim=-1, largest=True) 533 | batch_indices = torch.arange(kp_p.shape[0]).view(-1, 1).to(kp_p.device) 534 | kp_p = kp_p[batch_indices, indices] 535 | # alternatively, just sample random kp 536 | # kp_p = kp_p[:, torch.randperm(kp_p.shape[1])[:self.n_kp_prior]] 537 | 538 | # encode posterior KP 539 | mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = self.encode(x, 540 | return_heatmap=True) 541 | if deterministic: 542 | z = mu 543 | else: 544 | z = reparameterize(mu, logvar) 545 | 546 | # create gaussian maps (and masks) from the posterior keypoints 547 | gmap_1_fg = self.to_gauss_map(z[:, :-1], self.features_dim, self.features_dim) 548 | gmap_1_bg = 1 - gmap_1_fg.sum(1, keepdim=True).clamp(0, 1).detach() 549 | gmap_1 = torch.cat([gmap_1_fg, gmap_1_bg], dim=1) 550 | fg_masks_sep = get_kp_mask_from_gmap(gmap_1_fg, threshold=self.mask_threshold, binary=True, 551 | elementwise=True).detach() 552 | fg_masks = fg_masks_sep.sum(1, keepdim=True).clamp(0, 1) 553 | bg_masks = 1 - fg_masks 554 | masks = torch.cat([fg_masks.expand_as(gmap_1_fg), bg_masks], dim=1) 555 | masks_sep = torch.cat([fg_masks_sep, bg_masks], dim=1) 556 | 557 | # encode visual features 558 | if self.learned_feature_dim > 0: 559 | if self.use_object_enc: 560 | feat_source = x if self.use_object_dec else kp_heatmap.detach() 561 | obj_on_in = obj_on if not noisy_masks else 0.0 * obj_on + torch.rand_like(obj_on) 562 | obj_enc_out = self.encode_object_features_sep(feat_source, mu[:, :-1], kp_heatmap.detach(), 563 | masks_sep.detach(), 564 | exclusive_patches=self.exclusive_patches, 565 | obj_on=obj_on_in) 566 | mu_features, logvar_features, cropped_objects = obj_enc_out[0], obj_enc_out[1], obj_enc_out[2] 567 | if len(obj_enc_out) > 3: 568 | cropped_objects_masks = obj_enc_out[3] 569 | else: 570 | cropped_objects_masks = None 571 | else: 572 | mu_features, logvar_features = self.encode_object_features(kp_heatmap.detach(), masks_sep) 573 | cropped_objects = None 574 | cropped_objects_masks = None 575 | 576 | if deterministic: 577 | z_features = mu_features 578 | else: 579 | z_features = reparameterize(mu_features, logvar_features) 580 | else: 581 | z_features = None 582 | cropped_objects = None 583 | cropped_objects_masks = None 584 | 585 | # decode 586 | if not warmup or not self.use_object_dec: 587 | z_features_fg, z_features_bg = torch.split(z_features, [self.n_kp_enc, 1], dim=1) 588 | z_features_in = torch.cat([z_features_fg.detach(), z_features_bg], 589 | dim=1) if self.use_object_dec else z_features 590 | if self.dec_bone == "gauss_pointnetpp": 591 | if self.learned_feature_dim > 0: 592 | gmap_2 = self.pointnet(position=z.detach(), 593 | features=torch.cat([z.detach(), z_features_in], dim=-1)) 594 | else: 595 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach()) 596 | gmap = torch.cat([gmap_1[:, :-1], gmap_2], dim=1) 597 | # gmap = torch.cat([gmap_2.detach(), gmap_2], dim=1) 598 | elif self.dec_bone == "gauss_pointnetpp_feat": 599 | if self.learned_feature_dim > 0: 600 | gmap_2 = self.pointnet(position=z.detach(), 601 | features=torch.cat([z.detach(), 
z_features_in], dim=-1)) 602 | else: 603 | gmap_2 = self.pointnet(position=z.detach(), features=z.detach()) 604 | 605 | fg_masks = masks[:, :-1] 606 | bg_masks = masks[:, -1].unsqueeze(1) 607 | gmap_2 = fg_masks * gmap_2 608 | gmap_3 = bg_masks * kp_heatmap.detach() 609 | gmap = torch.cat([gmap_1[:, :-1], gmap_2, gmap_3], dim=1) 610 | else: 611 | raise NotImplementedError('grow a dec bone') 612 | if detach_decoder: 613 | rec = self.dec(gmap.detach()) 614 | else: 615 | rec = self.dec(gmap) 616 | else: 617 | rec = torch.zeros_like(x) 618 | gmap = None 619 | 620 | # decode object and translate them to the positions of the keypoints 621 | if z_features is not None and self.use_object_dec: 622 | obj_on_in = obj_on if not noisy_masks else 0.0 * obj_on + torch.rand_like(obj_on) 623 | object_dec_out = self.decode_objects(z, z_features, obj_on_in, deterministic=not noisy_masks, 624 | order_weights=order_weights, bg=rec) 625 | dec_objects, dec_objects_trans, rec = object_dec_out 626 | else: 627 | dec_objects_trans = None 628 | dec_objects = None 629 | gmap = None 630 | 631 | output_dict = {} 632 | output_dict['kp_p'] = kp_p 633 | output_dict['gmap'] = gmap 634 | output_dict['rec'] = rec 635 | output_dict['mu'] = mu 636 | output_dict['logvar'] = logvar 637 | output_dict['mu_features'] = mu_features 638 | output_dict['logvar_features'] = logvar_features 639 | # object stuff 640 | output_dict['cropped_objects_original'] = cropped_objects 641 | output_dict['cropped_objects_masks'] = cropped_objects_masks 642 | output_dict['obj_on'] = obj_on 643 | output_dict['dec_objects_original'] = dec_objects 644 | output_dict['dec_objects'] = dec_objects_trans 645 | output_dict['order_weights'] = order_weights 646 | 647 | return output_dict 648 | 649 | def lerp(self, other, betta): 650 | # weight interpolation for ema - not used in the paper 651 | if hasattr(other, 'module'): 652 | other = other.module 653 | with torch.no_grad(): 654 | params = self.parameters() 655 | other_param = other.parameters() 656 | for p, p_other in zip(params, other_param): 657 | p.data.lerp_(p_other.data, 1.0 - betta) 658 | -------------------------------------------------------------------------------- /modules/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic modules and layers. 
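Editorial summary: includes the convolutional building blocks and keypoint modules used by
models.KeyPointVAE (e.g., KeyPointCNNOriginal, SpatialSoftmaxKP/SpatialLogSoftmaxKP,
ToGaussianMapHW, CNNDecoder, ObjectEncoder/ObjectDecoderCNN and PointNetPPToCNN).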
3 | """ 4 | 5 | # imports 6 | import numpy as np 7 | # torch 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from utils.util_func import create_masks_fast 12 | # torch geometric 13 | from torch_geometric.nn import MessagePassing, global_max_pool 14 | from torch_cluster import knn_graph 15 | from torch_cluster import fps 16 | 17 | 18 | class ResidualBlock(nn.Module): 19 | def __init__(self, inc=64, outc=64, groups=1, scale=1.0, padding="zeros"): 20 | super(ResidualBlock, self).__init__() 21 | 22 | midc = int(outc * scale) 23 | 24 | if inc is not outc: 25 | self.conv_expand = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1, padding=0, 26 | groups=1, bias=False) 27 | else: 28 | self.conv_expand = None 29 | if padding == "zeros": 30 | self.conv1 = nn.Conv2d(in_channels=inc, out_channels=midc, kernel_size=3, stride=1, padding=1, 31 | groups=groups, 32 | bias=False) 33 | else: 34 | self.conv1 = nn.Sequential(nn.ReplicationPad2d(1), 35 | nn.Conv2d(in_channels=inc, out_channels=midc, kernel_size=3, stride=1, 36 | padding=0, groups=groups, bias=False)) 37 | self.bn1 = nn.BatchNorm2d(midc) 38 | self.relu1 = nn.LeakyReLU(0.2, inplace=True) 39 | if padding == "zeros": 40 | self.conv2 = nn.Conv2d(in_channels=midc, out_channels=outc, kernel_size=3, stride=1, padding=1, 41 | groups=groups, bias=False) 42 | else: 43 | self.conv2 = nn.Sequential(nn.ReplicationPad2d(1), 44 | nn.Conv2d(in_channels=midc, out_channels=outc, kernel_size=3, stride=1, 45 | padding=0, groups=groups, bias=False)) 46 | self.bn2 = nn.BatchNorm2d(outc) 47 | self.relu2 = nn.LeakyReLU(0.01, inplace=True) 48 | 49 | def forward(self, x): 50 | if self.conv_expand is not None: 51 | identity_data = self.conv_expand(x) 52 | else: 53 | identity_data = x 54 | 55 | output = self.relu1(self.bn1(self.conv1(x))) 56 | output = self.conv2(output) 57 | output = self.bn2(output) 58 | output = self.relu2(torch.add(output, identity_data)) 59 | return output 60 | 61 | 62 | class ConvBlock(nn.Module): 63 | """ 64 | Basic convolutional nn block. 65 | """ 66 | 67 | def __init__(self, c_in, c_out, kernel_size, stride=1, pad=0, pool=False, upsample=False, bias=False, 68 | activation=True, batchnorm=True, relu_type='leaky', pad_mode='zeros', use_resblock=False): 69 | super(ConvBlock, self).__init__() 70 | self.main = nn.Sequential() 71 | if use_resblock: 72 | self.main.add_module(f'conv_{c_in}_to_{c_out}', ResidualBlock(c_in, c_out, padding=pad_mode)) 73 | else: 74 | if pad_mode != 'zeros': 75 | self.main.add_module('replicate_pad', nn.ReplicationPad2d(pad)) 76 | pad = 0 77 | self.main.add_module(f'conv_{c_in}_to_{c_out}', nn.Conv2d(c_in, c_out, kernel_size, 78 | stride=stride, padding=pad, bias=bias)) 79 | if batchnorm: 80 | # note: for better performance on small datasets/batches, it is better to replace with nn.GroupNorm 81 | self.main.add_module(f'bathcnorm_{c_out}', nn.BatchNorm2d(c_out)) 82 | if activation: 83 | if relu_type == 'leaky': 84 | self.main.add_module(f'relu', nn.LeakyReLU(0.01)) 85 | else: 86 | self.main.add_module(f'relu', nn.ReLU()) 87 | if pool: 88 | self.main.add_module(f'max_pool2', nn.MaxPool2d(kernel_size=2, stride=2)) 89 | if upsample: 90 | # note: literature recommends using 'nearest' interpolation instead of`bilinear`. 
91 | self.main.add_module(f'upsample_bilinear_2', 92 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)) 93 | 94 | def forward(self, x): 95 | y = self.main(x) 96 | return y 97 | 98 | 99 | class KeyPointCNNOriginal(nn.Module): 100 | """ 101 | CNN to extract heatmaps, inspired by KeyNet (Jakab et.al) 102 | """ 103 | 104 | def __init__(self, cdim=3, channels=(32, 64, 128, 256), image_size=64, n_kp=8, pad_mode='replicate', 105 | use_resblock=False, first_conv_kernel_size=7): 106 | super(KeyPointCNNOriginal, self).__init__() 107 | self.cdim = cdim 108 | self.image_size = image_size 109 | self.n_kp = n_kp 110 | cc = channels[0] 111 | ch = cc 112 | first_conv_pad = first_conv_kernel_size // 2 113 | self.main = nn.Sequential() 114 | self.main.add_module(f'in_block_1', 115 | ConvBlock(cdim, cc, kernel_size=first_conv_kernel_size, stride=1, 116 | pad=first_conv_pad, pool=False, pad_mode=pad_mode, 117 | use_resblock=use_resblock, relu_type='relu')) 118 | self.main.add_module(f'in_block_2', 119 | ConvBlock(cc, cc, kernel_size=3, stride=1, pad=1, pool=False, pad_mode=pad_mode, 120 | use_resblock=use_resblock, relu_type='relu')) 121 | 122 | sz = image_size 123 | for ch in channels[1:]: 124 | self.main.add_module('conv_in_{}_0'.format(sz), ConvBlock(cc, ch, kernel_size=3, stride=2, pad=1, 125 | pool=False, pad_mode=pad_mode, 126 | use_resblock=use_resblock, relu_type='relu')) 127 | self.main.add_module('conv_in_{}_1'.format(ch), ConvBlock(ch, ch, kernel_size=3, stride=1, pad=1, 128 | pool=False, pad_mode=pad_mode, 129 | use_resblock=use_resblock, relu_type='relu')) 130 | cc, sz = ch, sz // 2 131 | 132 | self.keymap = nn.Conv2d(channels[-1], n_kp, kernel_size=1) 133 | self.conv_output_size = self.calc_conv_output_size() 134 | num_fc_features = torch.zeros(self.conv_output_size).view(-1).shape[0] 135 | print("conv shape: ", self.conv_output_size) 136 | # print("num fc features: ", num_fc_features) 137 | # self.fc = nn.Linear(num_fc_features, self.fc_output) 138 | 139 | def calc_conv_output_size(self): 140 | dummy_input = torch.zeros(1, self.cdim, self.image_size, self.image_size) 141 | dummy_input = self.main(dummy_input) 142 | return dummy_input[0].shape 143 | 144 | def forward(self, x): 145 | y = self.main(x) 146 | # heatmap 147 | hm = self.keymap(y) 148 | return y, hm 149 | 150 | 151 | # coverting to probabilities 152 | class SpatialSoftmaxKP(torch.nn.Module): 153 | """ 154 | This module performs spatial-softmax (ssm) by performing marginalization over heatmaps. 
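    Editorial restatement of the computation below: the heatmap H is averaged over one spatial axis,
    softmax-normalized along the other, and each keypoint coordinate is the expected value of a
    linspace over kp_range under that distribution, e.g. for the y coordinate:
        kp_y = sum_h softmax(mean_w(H))[h] * y[h],   y = linspace(kp_range[0], kp_range[1], height).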
155 | """ 156 | 157 | def __init__(self, kp_range=(0, 1)): 158 | super().__init__() 159 | self.kp_range = kp_range 160 | 161 | def forward(self, heatmap, probs=False): 162 | batch_size, n_kp, height, width = heatmap.shape 163 | # marginalize over height (y) 164 | s_h = torch.mean(heatmap, dim=3) # [batch_size, n_kp, features_dim_height] 165 | sm_h = torch.softmax(s_h, dim=-1) # [batch_size, n_kp, features_dim_height] 166 | # marginalize over width (x) 167 | s_w = torch.mean(heatmap, dim=2) # [batch_size, n_kp, features_dim_width] 168 | # probability per spatial coordinate 169 | sm_w = torch.softmax(s_w, dim=-1) # [batch_size, n_kp, features_dim_width] 170 | # each coordinate [0, 1] is assigned a probability 171 | y_axis = torch.linspace(self.kp_range[0], self.kp_range[1], height).type_as(sm_h).expand(1, 1, -1).to( 172 | sm_h.device) # [1, 1, features_dim_height] 173 | # expected value: proability per coordinate * coordinate 174 | kp_h = torch.sum(sm_h * y_axis, dim=-1, keepdim=True) # [batch_size, n_kp, 1] 175 | kp_h = kp_h.squeeze(-1) # [batch_size, n_kp], y coordinate of each kp 176 | 177 | x_axis = torch.linspace(self.kp_range[0], self.kp_range[1], width).type_as(sm_w).expand(1, 1, -1).to( 178 | sm_w.device) # [1, 1, features_dim_width] 179 | kp_w = torch.sum(sm_w * x_axis, dim=-1, keepdim=True).squeeze(-1) # [batch_size, n_kp], x coordinate of each kp 180 | 181 | # stack keypoints 182 | kp = torch.stack([kp_h, kp_w], dim=-1) # [batch_size, n_kp, 2], x, y coordinates of each kp 183 | 184 | if probs: 185 | return kp, sm_h, sm_w 186 | else: 187 | return kp 188 | 189 | 190 | class SpatialLogSoftmaxKP(torch.nn.Module): 191 | """ 192 | This module performs spatial-softmax (ssm) by performing marginalization over heatmaps. Uses log-softmax 193 | for numerical stability. In practice, we do not use it, but feel free to explore it. 
194 | """ 195 | 196 | def __init__(self, kp_range=(0, 1)): 197 | super().__init__() 198 | self.kp_range = kp_range 199 | 200 | def forward(self, heatmap, probs=False): 201 | batch_size, n_kp, height, width = heatmap.shape 202 | # marginalize over height (y) 203 | s_h = torch.mean(heatmap, dim=3) # [batch_size, n_kp, features_dim_height] 204 | sm_h = torch.log_softmax(s_h, dim=-1) # [batch_size, n_kp, features_dim_height] 205 | # marginalize over width (x) 206 | s_w = torch.mean(heatmap, dim=2) # [batch_size, n_kp, features_dim_width] 207 | sm_w = torch.log_softmax(s_w, dim=-1) # [batch_size, n_kp, features_dim_width] 208 | # each coordinate [0, 1] is assigned a probability 209 | y_axis = torch.log(torch.linspace(self.kp_range[0], self.kp_range[1], height)).type_as(sm_h).expand(1, 1, 210 | -1).to( 211 | sm_h.device) 212 | # [1, 1, features_dim_height] 213 | # expected value: proability per coordinate * coordinate 214 | kp_h = torch.sum(torch.exp(sm_h + y_axis), dim=-1, keepdim=True) # [batch_size, n_kp, 1] 215 | kp_h = kp_h.squeeze(-1) # [batch_size, n_kp], y coordinate of each kp 216 | 217 | x_axis = torch.log(torch.linspace(self.kp_range[0], self.kp_range[1], width)).type_as(sm_w).expand(1, 1, -1).to( 218 | sm_w.device) 219 | # [1, 1, features_dim_width] 220 | kp_w = torch.sum(torch.exp(sm_w + x_axis), dim=-1, keepdim=True).squeeze(-1) 221 | # [batch_size, n_kp], x coordinate of each kp 222 | 223 | # stack keypoints 224 | kp = torch.stack([kp_h, kp_w], dim=-1) # [batch_size, n_kp, 2], x, y coordinates of each kp 225 | 226 | if probs: 227 | return kp, sm_h, sm_w 228 | else: 229 | return kp 230 | 231 | 232 | class ToGaussianMapHW(nn.Module): 233 | """ 234 | This module converts KP to a gaussian map centered at the coordinates of the keypoint 235 | """ 236 | 237 | def __init__(self, sigma_w=0.1, sigma_h=0.1, kp_range=(0, 1)): 238 | super().__init__() 239 | self.sigma_w = sigma_w 240 | self.sigma_h = sigma_h 241 | self.kp_range = kp_range 242 | 243 | def forward(self, kp, height, width, logvar_h=None, logvar_w=None): 244 | batch_size, n_kp, _ = kp.shape 245 | # get means 246 | h_mean, w_mean = kp[:, :, 0], kp[:, :, 1] 247 | # create a coordinate map for each axis 248 | h_map = torch.linspace(self.kp_range[0], self.kp_range[1], height, device=h_mean.device).type_as( 249 | h_mean) # [height] 250 | w_map = torch.linspace(self.kp_range[0], self.kp_range[1], width, device=w_mean.device).type_as( 251 | w_mean) # [width] 252 | # duplicate for all keypoints in the batch 253 | h_map = h_map.expand(batch_size, n_kp, height) # [batch_size, n_kp, height] 254 | w_map = w_map.expand(batch_size, n_kp, width) # [batch_size, n_kp, width] 255 | # repeat the mean to match dimensions 256 | h_mean_m = h_mean.expand(height, -1, -1) # [height, batch_size, n_kp] 257 | h_mean_m = h_mean_m.permute(1, 2, 0) # [batch_size, n_kp, height] 258 | w_mean_m = w_mean.expand(width, -1, -1) # [width, batch_size, n_kp] 259 | w_mean_m = w_mean_m.permute(1, 2, 0) # [batch_size, n_kp, width] 260 | # for each pixel in the map, calculate the squared distance from the mean 261 | h_sdiff = (h_map - h_mean_m) ** 2 # [batch_size, n_kp, height] 262 | w_sdiff = (w_map - w_mean_m) ** 2 # [batch_size, n_kp, width] 263 | # compute gaussian 264 | # duplicate for the other dimension 265 | hm = h_sdiff.expand(width, -1, -1, -1).permute(1, 2, 3, 0) # [batch_size, n_kp, height, width] 266 | wm = w_sdiff.expand(height, -1, -1, -1).permute(1, 2, 0, 3) # [batch_size, n_kp, height, width] 267 | if logvar_h is not None: 268 | sigma_h = torch.exp(0.5 * 
logvar_h)
269 |             sigma_h = sigma_h.expand(height, -1, -1)  # [height, batch_size, n_kp]
270 |             sigma_h = sigma_h.permute(1, 2, 0)  # [batch_size, n_kp, height]
271 |             sigma_h = sigma_h.expand(width, -1, -1, -1).permute(1, 2, 3, 0)  # [batch_size, n_kp, height, width]
272 |         else:
273 |             sigma_h = self.sigma_h
274 |         if logvar_w is not None:
275 |             sigma_w = torch.exp(0.5 * logvar_w)
276 |             sigma_w = sigma_w.expand(width, -1, -1)  # [width, batch_size, n_kp]
277 |             sigma_w = sigma_w.permute(1, 2, 0)  # [batch_size, n_kp, width]
278 |             sigma_w = sigma_w.expand(height, -1, -1, -1).permute(1, 2, 0, 3)  # [batch_size, n_kp, height, width]
279 |         else:
280 |             sigma_w = self.sigma_w
281 |         # print(hm.shape, sigma_h.shape, wm.shape, sigma_w.shape)
282 |         gm = -0.5 * (hm / (sigma_h ** 2) + wm / (sigma_w ** 2))
283 |         gm = torch.exp(gm)  # [batch_size, n_kp, height, width]
284 |         return gm
285 | 
286 | 
287 | class CNNDecoder(nn.Module):
288 |     """
289 |     This module upsamples activation maps to the original image size.
290 |     """
291 | 
292 |     def __init__(self, cdim=3, channels=(64, 128, 256, 512, 512, 512), image_size=64, in_ch=16, n_kp=8,
293 |                  pad_mode='zeros', use_resblock=False):
294 |         super(CNNDecoder, self).__init__()
295 |         self.cdim = cdim
296 |         self.image_size = image_size
297 |         cc = channels[-1]
298 |         self.in_ch = in_ch
299 |         self.n_kp = n_kp
300 | 
301 |         sz = 4
302 | 
303 |         self.main = nn.Sequential()
304 |         self.main.add_module('depth_up',
305 |                              ConvBlock(self.in_ch, cc, kernel_size=3, pad=1, upsample=True, pad_mode=pad_mode,
306 |                                        use_resblock=use_resblock))
307 |         for ch in reversed(channels[1:-1]):
308 |             self.main.add_module('conv_to_{}'.format(sz * 2), ConvBlock(cc, ch, kernel_size=3, pad=1, upsample=True,
309 |                                                                         pad_mode=pad_mode, use_resblock=use_resblock))
310 |             cc, sz = ch, sz * 2
311 | 
312 |         self.main.add_module('conv_to_{}'.format(sz * 2),
313 |                              ConvBlock(cc, self.n_kp * (channels[0] // self.n_kp + 1), kernel_size=3, pad=1,
314 |                                        upsample=False,
315 |                                        pad_mode=pad_mode, use_resblock=use_resblock))
316 |         self.final_conv = ConvBlock(self.n_kp * (channels[0] // self.n_kp + 1), cdim, kernel_size=1, bias=True,
317 |                                     activation=False, batchnorm=False)
318 | 
319 |     def forward(self, z, masks=None):
320 |         y = self.main(z)
321 |         if masks is not None:
322 |             # this is not used in the paper
323 |             # masks: [bs, n_kp, feat_dim, feat_dim]
324 |             bs, n_kp, fs, _ = masks.shape
325 |             # y: [bs, n_kp * ch[0], feat_dim, feat_dim]
326 |             y = y.view(bs, n_kp, -1, fs, fs)
327 |             y = masks.unsqueeze(2) * y
328 |             y = y.view(bs, -1, fs, fs)
329 |         y = self.final_conv(y)
330 |         return y
331 | 
332 | 
333 | class ImagePatcher(nn.Module):
334 |     """
335 |     This module takes an image of size B x cdim x H x W and returns a patchified tensor
336 |     B x cdim x num_patches x patch_size x patch_size. It also gives you the global location of the patch
337 |     w.r.t. the original image. We use this module to extract prior KP from patches, and we need to know their
338 |     global coordinates for the Chamfer-KL.
339 | """ 340 | 341 | def __init__(self, cdim=3, image_size=64, patch_size=16): 342 | super(ImagePatcher, self).__init__() 343 | self.cdim = cdim 344 | self.image_size = image_size 345 | self.patch_size = patch_size 346 | self.kh, self.kw = self.patch_size, self.patch_size # kernel size 347 | self.dh, self.dw = self.patch_size, patch_size # stride 348 | self.unfold_shape = self.get_unfold_shape() 349 | self.patch_location_idx = self.get_patch_location_idx() 350 | # print(f'unfold shape: {self.unfold_shape}') 351 | # print(f'patch locations: {self.patch_location_idx}') 352 | 353 | def get_patch_location_idx(self): 354 | h = np.arange(0, self.image_size)[::self.patch_size] 355 | w = np.arange(0, self.image_size)[::self.patch_size] 356 | ww, hh = np.meshgrid(h, w) 357 | hw = np.stack((hh, ww), axis=-1) 358 | hw = hw.reshape(-1, 2) 359 | return torch.from_numpy(hw).int() 360 | 361 | def get_patch_centers(self): 362 | mid = self.patch_size // 2 363 | patch_locations_idx = self.get_patch_location_idx() 364 | patch_locations_idx += mid 365 | return patch_locations_idx 366 | 367 | def get_unfold_shape(self): 368 | dummy_input = torch.zeros(1, self.cdim, self.image_size, self.image_size) 369 | patches = dummy_input.unfold(2, self.kh, self.dh).unfold(3, self.kw, self.dw) 370 | unfold_shape = patches.shape[1:] 371 | return unfold_shape 372 | 373 | def img_to_patches(self, x): 374 | patches = x.unfold(2, self.kh, self.dh).unfold(3, self.kw, self.dw) 375 | patches = patches.contiguous().view(patches.shape[0], patches.shape[1], -1, self.kh, self.kw) 376 | return patches 377 | 378 | def patches_to_img(self, x): 379 | patches_orig = x.view(x.shape[0], *self.unfold_shape) 380 | output_h = self.unfold_shape[1] * self.unfold_shape[2] 381 | output_w = self.unfold_shape[2] * self.unfold_shape[4] 382 | patches_orig = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 383 | patches_orig = patches_orig.view(-1, self.cdim, output_h, output_w) 384 | return patches_orig 385 | 386 | def forward(self, x, patches=True): 387 | # x [batch_size, 3, image_size, image_size] or [batch_size, 3, num_patches, image_size, image_size] 388 | if patches: 389 | return self.img_to_patches(x) 390 | else: 391 | return self.patches_to_img(x) 392 | 393 | 394 | class VariationalKeyPointPatchEncoder(nn.Module): 395 | """ 396 | This module encodes patches to KP via SSM. Additionally, we implement a variational version that encodes 397 | log-variance in addition to the mean, but we don't use it in practice, as constant prior std works better. 398 | We also experimented with extracting features directly from the patches for a prior for the visual features, 399 | but we didn't find it better than a constant prior (~N(0,1)). However, you can explore it by setting 400 | `learned_feature_dim`>0. 
401 | """ 402 | 403 | def __init__(self, cdim=3, channels=(16, 16, 32), image_size=64, n_kp=4, patch_size=16, kp_range=(0, 1), 404 | use_logsoftmax=False, pad_mode='replicate', sigma=0.1, dropout=0.0, learnable_logvar=False, 405 | learned_feature_dim=0): 406 | super(VariationalKeyPointPatchEncoder, self).__init__() 407 | self.use_logsoftmax = use_logsoftmax 408 | self.image_size = image_size 409 | self.dropout = dropout 410 | self.kp_range = kp_range 411 | self.n_kp = n_kp # kp per patch 412 | self.patcher = ImagePatcher(cdim=cdim, image_size=image_size, patch_size=patch_size) 413 | self.features_dim = int(patch_size // (2 ** (len(channels) - 1))) 414 | self.enc = KeyPointCNNOriginal(cdim=cdim, channels=channels, image_size=patch_size, n_kp=n_kp, 415 | pad_mode=pad_mode, use_resblock=False, first_conv_kernel_size=3) 416 | self.ssm = SpatialLogSoftmaxKP(kp_range=kp_range) if use_logsoftmax else SpatialSoftmaxKP(kp_range=kp_range) 417 | self.sigma = sigma 418 | self.learnable_logvar = learnable_logvar 419 | self.learned_feature_dim = learned_feature_dim 420 | if self.learnable_logvar: 421 | self.to_logvar = nn.Sequential(nn.Linear(self.n_kp * (self.features_dim ** 2), 512), 422 | nn.ReLU(True), 423 | nn.Linear(512, 256), 424 | nn.ReLU(True), 425 | nn.Linear(256, self.n_kp * 2)) # logvar_x, logvar_y 426 | if self.learned_feature_dim > 0: 427 | self.to_features = nn.Sequential(nn.Linear(self.n_kp * (self.features_dim ** 2), 512), 428 | nn.ReLU(True), 429 | nn.Linear(512, 256), 430 | nn.ReLU(True), 431 | nn.Linear(256, self.n_kp * self.learned_feature_dim)) # logvar_x, logvar_y 432 | 433 | def img_to_patches(self, x): 434 | return self.patcher.img_to_patches(x) 435 | 436 | def patches_to_img(self, x): 437 | return self.patcher.patches_to_img(x) 438 | 439 | def get_global_kp(self, local_kp): 440 | # local_kp: [batch_size, num_patches, n_kp, 2] 441 | # returns the global coordinates of a KP within the original image. 442 | batch_size, num_patches, n_kp, _ = local_kp.shape 443 | global_coor = self.patcher.get_patch_location_idx().to(local_kp.device) # [num_patches, 2] 444 | global_coor = global_coor[:, None, :].repeat(1, n_kp, 1) 445 | global_coor = (((local_kp - self.kp_range[0]) / (self.kp_range[1] - self.kp_range[0])) * ( 446 | self.patcher.patch_size - 1) + global_coor) / (self.image_size - 1) 447 | global_coor = global_coor * (self.kp_range[1] - self.kp_range[0]) + self.kp_range[0] 448 | return global_coor 449 | 450 | def get_distance_from_patch_centers(self, kp, global_kp=False): 451 | # calculates the distance of a KP from the center of its parent patch. This is useful to understand (and filter) 452 | # if SSM detected something, otherwise, the KP will probably land in the center of the patch 453 | # (e.g., a solid-color patch will have the same activation in all pixels). 
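        # the returned values are squared Euclidean distances in the global (image) kp_range coordinates;
        # e.g., one could keep only KP whose distance from their patch center exceeds a small threshold
        # (an illustrative filtering strategy, not a fixed value used by this code).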
454 | if not global_kp: 455 | global_coor = self.get_global_kp(kp).view(kp.shape[0], -1, 2) 456 | else: 457 | global_coor = kp 458 | centers = 0.5 * (self.kp_range[1] + self.kp_range[0]) * torch.ones_like(kp).to(kp.device) 459 | global_centers = self.get_global_kp(centers.view(kp.shape[0], -1, self.n_kp, 2)).view(kp.shape[0], -1, 2) 460 | return ((global_coor - global_centers) ** 2).sum(-1) 461 | 462 | def encode(self, x, global_kp=False): 463 | # x: [batch_size, cdim, image_size, image_size] 464 | # global_kp: set True to get the global coordinates within the image (instead of local KP inside the patch) 465 | batch_size, cdim, image_size, image_size = x.shape 466 | x_patches = self.img_to_patches(x) # [batch_size, cdim, num_patches, patch_size, patch_size] 467 | x_patches = x_patches.permute(0, 2, 1, 3, 4) # [batch_size, num_patches, cdim, patch_size, patch_size] 468 | x_patches = x_patches.contiguous().view(-1, cdim, self.patcher.patch_size, self.patcher.patch_size) 469 | _, z = self.enc(x_patches) # [batch_size*num_patches, n_kp, features_dim, features_dim] 470 | mu_kp = self.ssm(z, probs=False) # [batch_size * num_patches, n_kp, 2] 471 | mu_kp = mu_kp.view(batch_size, -1, self.n_kp, 2) # [batch_size, num_patches, n_kp, 2] 472 | if global_kp: 473 | mu_kp = self.get_global_kp(mu_kp) 474 | if self.learned_feature_dim > 0: 475 | mu_features = self.to_features(z.view(z.shape[0], -1)) 476 | mu_features = mu_features.view(batch_size, -1, self.n_kp, self.learned_feature_dim) 477 | # [batch_size, num_patches, n_kp, learned_feature_dim] 478 | if self.learnable_logvar: 479 | logvar_kp = self.to_logvar(z.view(z.shape[0], -1)) 480 | logvar_kp = logvar_kp.view(batch_size, -1, self.n_kp, 2) # [batch_size, num_patches, n_kp, 2] 481 | if self.learned_feature_dim > 0: 482 | return mu_kp, logvar_kp, mu_features 483 | else: 484 | return mu_kp, logvar_kp 485 | elif self.learned_feature_dim > 0: 486 | return mu_kp, mu_features 487 | else: 488 | return mu_kp 489 | 490 | def forward(self, x, global_kp=False): 491 | return self.encode(x, global_kp) 492 | 493 | 494 | class ObjectEncoder(nn.Module): 495 | """ 496 | Glimpse-encoder: encodes patches visual features in a variational fashion (mu, log-variance). 497 | Useful for object-based scenes. 
498 | """ 499 | 500 | def __init__(self, z_dim, anchor_size, image_size, cnn_channels=(16, 16, 32), margin=0, ch=3, cnn=False, 501 | encode_location=False): 502 | super().__init__() 503 | 504 | self.anchor_size = anchor_size 505 | self.channels = cnn_channels 506 | self.z_dim = z_dim 507 | self.image_size = image_size 508 | self.patch_size = np.round(anchor_size * (image_size - 1)).astype(int) 509 | self.margin = margin 510 | self.crop_size = self.patch_size + 2 * margin 511 | self.ch = ch 512 | self.encode_location = encode_location 513 | 514 | if cnn: 515 | self.cnn = KeyPointCNNOriginal(cdim=ch, channels=cnn_channels, image_size=self.crop_size, n_kp=32, 516 | pad_mode='replicate', use_resblock=False, first_conv_kernel_size=3) 517 | else: 518 | self.cnn = None 519 | fc_in_dim = 32 * ((self.crop_size // 4) ** 2) if cnn else self.ch * (self.crop_size ** 2) 520 | fc_in_dim = fc_in_dim + 2 if self.encode_location else fc_in_dim 521 | self.fc = nn.Sequential(nn.Linear(fc_in_dim, 256), 522 | nn.ReLU(True), 523 | nn.Linear(256, 128), 524 | nn.ReLU(True), 525 | nn.Linear(128, self.z_dim * 2)) 526 | 527 | def center_objects(self, kp, padded_objects): 528 | # after masking everything outside of the glimpse, move the unmasked area to the center -- makes it easier 529 | # to crop the patches. 530 | batch_size, n_kp, _ = kp.shape 531 | delta_tr_batch = kp 532 | delta_tr_batch = delta_tr_batch.reshape(-1, delta_tr_batch.shape[-1]) # [bs * n_kp, 2] 533 | source_batch = padded_objects.reshape(-1, *padded_objects.shape[2:]) 534 | 535 | zeros = torch.zeros([delta_tr_batch.shape[0], 1], device=delta_tr_batch.device).float() 536 | ones = torch.ones([delta_tr_batch.shape[0], 1], device=delta_tr_batch.device).float() 537 | 538 | theta = torch.cat([ones, zeros, delta_tr_batch[:, 1].unsqueeze(-1), 539 | zeros, ones, delta_tr_batch[:, 0].unsqueeze(-1)], dim=-1) 540 | theta = theta.view(-1, 2, 3) # [batch_size * n_kp, 2, 3] 541 | 542 | align_corners = False 543 | padding_mode = 'zeros' 544 | mode = 'bilinear' 545 | 546 | grid = F.affine_grid(theta, source_batch.size(), align_corners=align_corners) 547 | trans_source_batch = F.grid_sample(source_batch, grid, align_corners=align_corners, 548 | mode=mode, padding_mode=padding_mode) 549 | trans_source_batch = trans_source_batch.view(batch_size, n_kp, *source_batch.shape[1:]) 550 | return trans_source_batch 551 | 552 | def get_cropped_objects(self, centered_objects): 553 | # extracts the unmasked area of the image after the glimpses have been centered in the original image. 
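        # e.g. (illustrative numbers only): with image_size=128, anchor_size=0.25 and margin=0,
        # patch_size = round(0.25 * 127) = 32, so the crop below is the central 32x32 window
        # [48:80, 48:80] of each centered object.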
554 |         center_idx = self.image_size // 2
555 |         margin = self.margin
556 |         w_start = center_idx - self.patch_size // 2 - margin
557 |         w_end = center_idx + self.patch_size // 2 + margin
558 |         h_start = center_idx - self.patch_size // 2 - margin
559 |         h_end = center_idx + self.patch_size // 2 + margin
560 |         cropped_objects = centered_objects[:, :, :, w_start:w_end, h_start:h_end]
561 |         return cropped_objects
562 | 
563 |     def forward(self, x, kp, exclusive_patches=False, obj_on=None):
564 |         # x: [bs, ch, image_size, image_size]
565 |         # kp: [bs, n_kp, 2] in [-1, 1]
566 |         # exclusive_patches: create cumulative masks to avoid overlapping objects, THIS WAS NOT USED IN THE PAPER
567 |         # this will (mostly) enforce one particle per object and
568 |         # won't allow several layers per object
569 |         batch_size, _, _, img_size = x.shape
570 |         _, n_kp, _ = kp.shape
571 |         # create masks from kp
572 |         masks = create_masks_fast(kp.detach(), self.anchor_size, feature_dim=self.image_size)
573 |         # [batch_size, n_kp, 1, feature_dim, feature_dim]
574 |         if exclusive_patches:
575 |             if obj_on is not None:
576 |                 masks = obj_on[:, :, None, None, None] * masks
577 |                 masks = masks.clamp(0, 1)  # STN can cause values to be outside [0, 1]
578 |             # create cumulative masks to avoid overlapping objects
579 |             curr_mask = masks[:, 0]
580 |             comp_masks = [curr_mask]
581 |             for i in range(1, masks.shape[1]):
582 |                 available_space = 1.0 - curr_mask.detach()
583 |                 curr_mask_tmp = torch.min(available_space, masks[:, i])
584 |                 comp_masks.append(curr_mask_tmp)
585 |                 curr_mask = curr_mask + curr_mask_tmp
586 |             masks = torch.stack(comp_masks, dim=1)
587 |         # extract objects
588 |         padded_objects = masks * x.unsqueeze(1)  # [batch_size, n_kp, ch, image_size, image_size]
589 |         # center objects
590 |         centered_objects = self.center_objects(kp, padded_objects)  # [batch_size, n_kp, ch, image_size, image_size]
591 |         # get crop
592 |         cropped_objects = self.get_cropped_objects(centered_objects)
593 |         # [batch_size, n_kp, ch, patch_size + margin * 2, patch_size + margin * 2]
594 | 
595 |         # encode objects - fc
596 |         if self.cnn is not None:
597 |             _, cropped_objects_cnn = self.cnn(cropped_objects.view(-1, *cropped_objects.shape[2:]))
598 |         else:
599 |             cropped_objects_cnn = cropped_objects
600 |         cropped_objects_flat = cropped_objects_cnn.reshape(batch_size, n_kp, -1)  # flatten
601 |         if self.encode_location:
602 |             cropped_objects_flat = torch.cat([cropped_objects_flat, kp], dim=-1)  # [batch_size, n_kp, .. + 2]
603 |         enc_out = self.fc(cropped_objects_flat)
604 |         mu, logvar = enc_out.chunk(2, dim=-1)  # [batch_size, n_kp, z_dim]
605 |         return mu, logvar, cropped_objects
606 | 
607 | 
608 | class ObjectDecoderCNN(nn.Module):
609 |     """
610 |     Glimpse Decoder: this module takes in a latent object tensor (vector) and upsamples it to an RGBA patch
611 |     (RGBA = RGB + alpha channel).
612 | """ 613 | 614 | def __init__(self, patch_size, num_chans=4, bottleneck_size=128, pad_mode='replicate'): 615 | super().__init__() 616 | 617 | if isinstance(patch_size, int): 618 | patch_size = (patch_size, patch_size) 619 | self.patch_size = patch_size 620 | self.num_chans = num_chans 621 | 622 | self.in_ch = 32 623 | 624 | fc_out_dim = self.in_ch * 8 * 8 625 | self.fc = nn.Sequential(nn.Linear(bottleneck_size, 256, bias=True), 626 | nn.ReLU(True), 627 | nn.Linear(256, fc_out_dim), 628 | nn.ReLU(True)) 629 | 630 | # num_upsample = int(np.log(patch_size[0]) // np.log(2)) - 3 631 | num_upsample = int(np.log2(patch_size[0])) - 3 # thanks @MoritzLange 632 | # print(f'ObjDecCNN: fc to cnn num upsample: {num_upsample}') 633 | self.channels = [32] 634 | for i in range(num_upsample): 635 | self.channels.append(64) 636 | cc = self.channels[-1] 637 | 638 | sz = 8 639 | 640 | self.main = nn.Sequential() 641 | self.main.add_module('depth_up', 642 | ConvBlock(self.in_ch, cc, kernel_size=3, pad=1, upsample=True, pad_mode=pad_mode, 643 | use_resblock=False)) 644 | for ch in reversed(self.channels[1:-1]): 645 | self.main.add_module('conv_to_{}'.format(sz * 2), ConvBlock(cc, ch, kernel_size=3, pad=1, upsample=True, 646 | pad_mode=pad_mode, use_resblock=False)) 647 | cc, sz = ch, sz * 2 648 | 649 | self.main.add_module('conv_to_{}'.format(sz * 2), 650 | ConvBlock(cc, self.channels[0], kernel_size=3, pad=1, 651 | upsample=False, pad_mode=pad_mode, use_resblock=False)) 652 | self.main.add_module('final_conv', ConvBlock(self.channels[0], num_chans, kernel_size=1, bias=True, 653 | activation=False, batchnorm=False)) 654 | self.decode = self.main 655 | 656 | def forward(self, x): 657 | if x.dim() == 3: 658 | x = x.reshape(-1, x.shape[-1]) 659 | conv_in = self.fc(x) 660 | conv_in = conv_in.view(-1, 32, 8, 8) 661 | out = self.decode(conv_in).view(-1, self.num_chans, *self.patch_size) 662 | out = torch.sigmoid(out) 663 | return out 664 | 665 | 666 | # torch geometric modules to process graphs/KP 667 | class PointNetPPLayer(MessagePassing): 668 | def __init__(self, in_channels, out_channels, axis_dim=2): 669 | # Message passing with "max" aggregation. 670 | super(PointNetPPLayer, self).__init__('max') 671 | self.axis_dim = axis_dim 672 | # Initialization of the MLP: 673 | # Here, the number of input features correspond to the hidden node 674 | # dimensionality plus point dimensionality (=3). 675 | self.mlp = nn.Sequential(nn.Linear(in_channels + axis_dim, out_channels), 676 | nn.BatchNorm1d(out_channels), 677 | nn.ReLU(), 678 | nn.Linear(out_channels, out_channels), 679 | nn.BatchNorm1d(out_channels)) 680 | 681 | def forward(self, h, pos, edge_index): 682 | # Start propagating messages. 683 | return self.propagate(edge_index, h=h, pos=pos) 684 | 685 | def message(self, h_j, pos_j, pos_i): 686 | # h_j defines the features of neighboring nodes as shape [num_edges, in_channels] 687 | # pos_j defines the position of neighboring nodes as shape [num_edges, 3] 688 | # pos_i defines the position of central nodes as shape [num_edges, 3] 689 | 690 | input_pos = pos_j - pos_i # Compute spatial relation. 691 | 692 | if h_j is not None: 693 | # In the first layer, we may not have any hidden node features, 694 | # so we only combine them in case they are present. 695 | input_pos = torch.cat([h_j, input_pos], dim=-1) 696 | 697 | return self.mlp(input_pos) # Apply our final MLP. 
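# A minimal usage sketch for PointNetPPLayer (illustrative only and kept commented out so the module stays
# import-safe; the shapes and values below are assumptions, not part of the training code):
# pos = torch.rand(4 * 10, 2)                           # 4 examples with 10 particles each, 2D positions
# batch = torch.arange(4).repeat_interleave(10)         # example-assignment vector for torch_cluster
# feats = torch.rand(4 * 10, 4)                         # per-particle features (e.g., logvar + extras)
# edge_index = knn_graph(pos, k=5, batch=batch, loop=True)
# layer = PointNetPPLayer(in_channels=4, out_channels=64, axis_dim=2)
# h = layer(h=feats, pos=pos, edge_index=edge_index)    # -> [40, 64], max-aggregated messages per particle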
698 | 
699 | 
700 | class PointNetPPToCNN(nn.Module):
701 |     def __init__(self, axis_dim=2, target_hw=16, n_kp=8, pad_mode='replicate', with_fps=False, features_dim=2):
702 |         super(PointNetPPToCNN, self).__init__()
703 |         # features_dim : 2 [logvar] + additional features
704 |         self.with_fps = with_fps
705 |         self.axis_dim = axis_dim  # mu
706 |         self.features_dim = features_dim  # logvar, features
707 |         self.conv1 = PointNetPPLayer(self.axis_dim + self.features_dim, 64, axis_dim=axis_dim)
708 |         self.conv2 = PointNetPPLayer(64, 128, axis_dim=axis_dim)
709 |         self.conv3 = PointNetPPLayer(128, 256, axis_dim=axis_dim)
710 |         self.conv4 = PointNetPPLayer(256, 512, axis_dim=axis_dim)
711 | 
712 |         self.n_kp = n_kp
713 | 
714 |         fc_out_dim = self.n_kp * 8 * 8
715 | 
716 |         self.fc = nn.Sequential(nn.Linear(512, fc_out_dim, bias=True),
717 |                                 nn.ReLU(True))
718 | 
719 |         # num_upsample = int(np.log(target_hw) // np.log(2)) - 3
720 |         num_upsample = int(np.log2(target_hw)) - 3  # thanks @MoritzLange
721 |         # print(f'pointnet to cnn num upsample: {num_upsample}')
722 |         self.cnn = nn.Sequential()
723 |         for i in range(num_upsample):
724 |             self.cnn.add_module(f'depth_up_{i}', ConvBlock(n_kp, n_kp, kernel_size=3, pad=1,
725 |                                                            upsample=True, pad_mode=pad_mode))
726 | 
727 |     def forward(self, position, features):
728 |         # position [batch_size, n_kp, 2]
729 |         # features [batch_size, n_kp, features_dim]
730 |         pos = position
731 |         batch = torch.arange(pos.shape[0]).view(-1, 1).repeat(1, pos.shape[1]).view(-1).to(pos.device)
732 |         pos = pos.view(-1, pos.shape[-1])  # [batch_size * n_kp, 2]
733 |         features = features.view(-1, features.shape[-1])  # [batch_size * n_kp, features]
734 |         # x [batch_size, n_kp, 2 or features_dim]
735 |         # Compute the kNN graph:
736 |         # Here, we need to pass the batch vector to the function call in order
737 |         # to prevent creating edges between points of different examples.
738 |         # We also add `loop=True` which will add self-loops to the graph in
739 |         # order to preserve central point information.
740 |         edge_index = knn_graph(pos, k=10, batch=batch, loop=True)
741 | 
742 |         # 3. Start bipartite message passing.
743 |         h = self.conv1(h=features, pos=pos, edge_index=edge_index)
744 |         h = h.relu()
745 |         # print(f'conv1 h: {h.shape}')
746 |         if self.with_fps:
747 |             index = fps(pos, batch=batch, ratio=0.5)
748 |             pos = pos[index]
749 |             h = h[index]
750 |             batch = batch[index]
751 |             edge_index = knn_graph(pos, k=5, batch=batch, loop=True)
752 |         h = self.conv2(h=h, pos=pos, edge_index=edge_index)
753 |         h = h.relu()
754 |         # print(f'conv2 h: {h.shape}')
755 |         if self.with_fps:
756 |             index = fps(pos, batch=batch, ratio=0.5)
757 |             pos = pos[index]
758 |             h = h[index]
759 |             batch = batch[index]
760 |             edge_index = knn_graph(pos, k=3, batch=batch, loop=True)
761 |         h = self.conv3(h=h, pos=pos, edge_index=edge_index)
762 |         h = h.relu()
763 |         h = self.conv4(h=h, pos=pos, edge_index=edge_index)
764 |         h = h.relu()
765 |         # 4. Global Pooling.
766 |         h = global_max_pool(h, batch)  # [num_examples, hidden_channels]
767 |         # 5. 
FC 768 | h = self.fc(h) 769 | h = h.view(-1, self.n_kp, 8, 8) # [batch_size, n_kp, 4, 4] 770 | cnn_out = self.cnn(h) # [batch_size, n_kp, target_hw, target_hw] 771 | return cnn_out 772 | -------------------------------------------------------------------------------- /requirements17.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.3.0 2 | h5py==2.10.0 3 | imageio==2.6.1 4 | matplotlib==3.4.2 5 | numpy==1.19.5 6 | pandas==1.3.1 7 | Pillow==8.3.1 8 | pip==21.1.3 9 | scikit-image==0.18.1 10 | scikit-learn==0.24.2 11 | scipy==1.7.0 12 | torch==1.7.1 13 | torch-cluster==1.5.9 14 | torch-geometric==1.7.2 15 | torch-scatter==2.0.7 16 | torch-sparse==0.6.9 17 | torch-spline-conv==1.2.1 18 | torchvision==0.8.2 19 | tqdm==4.62.3 20 | -------------------------------------------------------------------------------- /requirements19.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.5.1 2 | h5py==2.10.0 3 | imageio==2.6.1 4 | matplotlib==3.4.2 5 | matplotlib-inline==0.1.3 6 | numpy==1.20.3 7 | opencv-python==3.4.2.17 8 | pandas==1.3.0 9 | Pillow==8.3.1 10 | pip==21.0.1 11 | scikit-image==0.18.1 12 | scikit-learn==0.24.2 13 | scipy==1.6.3 14 | torch==1.9.0 15 | torch-cluster==1.5.9 16 | torch-geometric==2.0.0 17 | torch-scatter==2.0.8 18 | torch-sparse==0.6.12 19 | torch-spline-conv==1.2.1 20 | torchfile==0.1.0 21 | torchvision==0.10.0 22 | tqdm==4.62.2 23 | yacs==0.1.6 24 | -------------------------------------------------------------------------------- /utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions implementations used in the optimization of DLP. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import models 8 | 9 | 10 | # functions 11 | def batch_pairwise_kl(mu_x, logvar_x, mu_y, logvar_y, reverse_kl=False): 12 | """ 13 | Calculate batch-wise KL-divergence 14 | mu_x, logvar_x: [batch_size, n_x, points_dim] 15 | mu_y, logvar_y: [batch_size, n_y, points_dim] 16 | kl = -0.5 * Σ_points_dim (1 + logvar_x - logvar_y - exp(logvar_x)/exp(logvar_y) 17 | - ((mu_x - mu_y) ** 2)/exp(logvar_y)) 18 | """ 19 | if reverse_kl: 20 | mu_a, logvar_a = mu_y, logvar_y 21 | mu_b, logvar_b = mu_x, logvar_x 22 | else: 23 | mu_a, logvar_a = mu_x, logvar_x 24 | mu_b, logvar_b = mu_y, logvar_y 25 | bs, n_a, points_dim = mu_a.size() 26 | _, n_b, _ = mu_b.size() 27 | logvar_aa = logvar_a.unsqueeze(2).expand(-1, -1, n_b, -1) # [batch_size, n_a, n_b, points_dim] 28 | logvar_bb = logvar_b.unsqueeze(1).expand(-1, n_a, -1, -1) # [batch_size, n_a, n_b, points_dim] 29 | mu_aa = mu_a.unsqueeze(2).expand(-1, -1, n_b, -1) # [batch_size, n_a, n_b, points_dim] 30 | mu_bb = mu_b.unsqueeze(1).expand(-1, n_a, -1, -1) # [batch_size, n_a, n_b, points_dim] 31 | p_kl = -0.5 * (1 + logvar_aa - logvar_bb - logvar_aa.exp() / logvar_bb.exp() 32 | - ((mu_aa - mu_bb) ** 2) / logvar_bb.exp()).sum(-1) # [batch_size, n_x, n_y] 33 | return p_kl 34 | 35 | 36 | def calc_reconstruction_loss(x, recon_x, loss_type='mse', reduction='sum'): 37 | """ 38 | 39 | :param x: original inputs 40 | :param recon_x: reconstruction of the VAE's input 41 | :param loss_type: "mse", "l1", "bce" 42 | :param reduction: "sum", "mean", "none" 43 | :return: recon_loss 44 | """ 45 | if reduction not in ['sum', 'mean', 'none']: 46 | raise NotImplementedError 47 | recon_x = recon_x.view(recon_x.size(0), -1) 48 | x = x.view(x.size(0), -1) 
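    # note: both x and recon_x were flattened to [batch_size, -1] above before computing the element-wise error below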
49 | if loss_type == 'mse': 50 | recon_error = F.mse_loss(recon_x, x, reduction='none') 51 | recon_error = recon_error.sum(1) 52 | if reduction == 'sum': 53 | recon_error = recon_error.sum() 54 | elif reduction == 'mean': 55 | recon_error = recon_error.mean() 56 | elif loss_type == 'l1': 57 | recon_error = F.l1_loss(recon_x, x, reduction=reduction) 58 | elif loss_type == 'bce': 59 | recon_error = F.binary_cross_entropy(recon_x, x, reduction=reduction) 60 | else: 61 | raise NotImplementedError 62 | return recon_error 63 | 64 | 65 | def calc_kl(logvar, mu, mu_o=0.0, logvar_o=0.0, reduce='sum'): 66 | """ 67 | Calculate kl-divergence 68 | :param logvar: log-variance from the encoder 69 | :param mu: mean from the encoder 70 | :param mu_o: negative mean for outliers (hyper-parameter) 71 | :param logvar_o: negative log-variance for outliers (hyper-parameter) 72 | :param reduce: type of reduce: 'sum', 'none' 73 | :return: kld 74 | """ 75 | if not isinstance(mu_o, torch.Tensor): 76 | mu_o = torch.tensor(mu_o).to(mu.device) 77 | if not isinstance(logvar_o, torch.Tensor): 78 | logvar_o = torch.tensor(logvar_o).to(mu.device) 79 | kl = -0.5 * (1 + logvar - logvar_o - logvar.exp() / torch.exp(logvar_o) - (mu - mu_o).pow(2) / torch.exp( 80 | logvar_o)).sum(1) 81 | if reduce == 'sum': 82 | kl = torch.sum(kl) 83 | elif reduce == 'mean': 84 | kl = torch.mean(kl) 85 | return kl 86 | 87 | 88 | # classes 89 | class ChamferLossKL(nn.Module): 90 | """ 91 | Calculates the KL-divergence between two sets of (R.V.) particle coordinates. 92 | """ 93 | def __init__(self, use_reverse_kl=False): 94 | super(ChamferLossKL, self).__init__() 95 | self.use_reverse_kl = use_reverse_kl 96 | 97 | def forward(self, mu_preds, logvar_preds, mu_gts, logvar_gts): 98 | p_kl = batch_pairwise_kl(mu_preds, logvar_preds, mu_gts, logvar_gts, reverse_kl=False) 99 | if self.use_reverse_kl: 100 | p_rkl = batch_pairwise_kl(mu_preds, logvar_preds, mu_gts, logvar_gts, reverse_kl=True) 101 | p_kl = 0.5 * (p_kl + p_rkl.transpose(2, 1)) 102 | mins, _ = torch.min(p_kl, 1) 103 | loss_1 = torch.sum(mins, 1) 104 | mins, _ = torch.min(p_kl, 2) 105 | loss_2 = torch.sum(mins, 1) 106 | return loss_1 + loss_2 107 | 108 | 109 | class NetVGGFeatures(nn.Module): 110 | 111 | def __init__(self, layer_ids): 112 | super().__init__() 113 | 114 | self.vggnet = models.vgg16(pretrained=True) 115 | self.layer_ids = layer_ids 116 | 117 | def forward(self, x): 118 | output = [] 119 | for i in range(self.layer_ids[-1] + 1): 120 | x = self.vggnet.features[i](x) 121 | 122 | if i in self.layer_ids: 123 | output.append(x) 124 | 125 | return output 126 | 127 | 128 | class VGGDistance(nn.Module): 129 | 130 | def __init__(self, layer_ids=(2, 7, 12, 21, 30), accumulate_mode='sum', device=torch.device("cpu")): 131 | super().__init__() 132 | 133 | self.vgg = NetVGGFeatures(layer_ids).to(device) 134 | self.layer_ids = layer_ids 135 | self.accumulate_mode = accumulate_mode 136 | self.device = device 137 | 138 | def forward(self, I1, I2, reduction='sum', only_image=False): 139 | b_sz = I1.size(0) 140 | num_ch = I1.size(1) 141 | 142 | if self.accumulate_mode == 'sum': 143 | loss = ((I1 - I2) ** 2).view(b_sz, -1).sum(1) 144 | else: 145 | loss = ((I1 - I2) ** 2).view(b_sz, -1).mean(1) 146 | 147 | if num_ch == 1: 148 | I1 = I1.repeat(1, 3, 1, 1) 149 | I2 = I2.repeat(1, 3, 1, 1) 150 | f1 = self.vgg(I1) 151 | f2 = self.vgg(I2) 152 | 153 | if not only_image: 154 | for i in range(len(self.layer_ids)): 155 | if self.accumulate_mode == 'sum': 156 | layer_loss = ((f1[i] - f2[i]) ** 
2).view(b_sz, -1).sum(1) 157 | else: 158 | layer_loss = ((f1[i] - f2[i]) ** 2).view(b_sz, -1).mean(1) 159 | loss = loss + layer_loss 160 | 161 | if reduction == 'mean': 162 | return loss.mean() 163 | elif reduction == 'sum': 164 | return loss.sum() 165 | else: 166 | return loss 167 | 168 | def get_dimensions(self, device=torch.device("cpu")): 169 | dims = [] 170 | dummy_input = torch.zeros(1, 3, 128, 128).to(device) 171 | dims.append(dummy_input.view(1, -1).size(1)) 172 | f = self.vgg(dummy_input) 173 | for i in range(len(self.layer_ids)): 174 | dims.append(f[i].view(1, -1).size(1)) 175 | return dims 176 | 177 | 178 | class ChamferLoss(nn.Module): 179 | 180 | def __init__(self): 181 | super(ChamferLoss, self).__init__() 182 | self.use_cuda = torch.cuda.is_available() 183 | 184 | def forward(self, preds, gts): 185 | P = self.batch_pairwise_dist(gts, preds) 186 | mins, _ = torch.min(P, 1) 187 | loss_1 = torch.sum(mins, 1) 188 | mins, _ = torch.min(P, 2) 189 | loss_2 = torch.sum(mins, 1) 190 | return loss_1 + loss_2 191 | 192 | def batch_pairwise_dist(self, x, y): 193 | bs, num_points_x, points_dim = x.size() 194 | _, num_points_y, _ = y.size() 195 | xx = torch.bmm(x, x.transpose(2, 1)) 196 | yy = torch.bmm(y, y.transpose(2, 1)) 197 | zz = torch.bmm(x, y.transpose(2, 1)) 198 | if self.use_cuda: 199 | dtype = torch.cuda.LongTensor 200 | else: 201 | dtype = torch.LongTensor 202 | diag_ind_x = torch.arange(0, num_points_x).type(dtype) 203 | diag_ind_y = torch.arange(0, num_points_y).type(dtype) 204 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as( 205 | zz.transpose(2, 1)) 206 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) 207 | P = rx.transpose(2, 1) + ry - 2 * zz 208 | return P 209 | -------------------------------------------------------------------------------- /utils/tps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits: 3 | https://github.com/jamt9000/DVE 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import math 9 | 10 | 11 | def tps_grid(H, W): 12 | xi = torch.linspace(-1, 1, W) 13 | yi = torch.linspace(-1, 1, H) 14 | 15 | yy, xx = torch.meshgrid(yi, xi) 16 | grid = torch.stack((xx.reshape(-1), yy.reshape(-1)), 1) 17 | return grid 18 | 19 | 20 | def spatial_grid_unnormalized(H, W): 21 | xi = torch.linspace(0, W - 1, W) 22 | yi = torch.linspace(0, H - 1, H) 23 | 24 | yy, xx = torch.meshgrid(yi, xi) 25 | grid = torch.stack((xx.reshape(-1), yy.reshape(-1)), 1) 26 | return grid.reshape(H, W, 2) 27 | 28 | 29 | def tps_U(grid1, grid2): 30 | D = grid1.reshape(-1, 1, 2) - grid2.reshape(1, -1, 2) 31 | D = torch.sum(D ** 2., 2) 32 | U = D * torch.log(D + 1e-5) 33 | return U 34 | 35 | 36 | def grid_unnormalize(grid, H, W): 37 | x = grid.reshape(-1, H, W, 2) 38 | constants = torch.tensor([W - 1., H - 1.], dtype=x.dtype) 39 | constants = constants.reshape(1, 1, 1, 2).to(x.device) 40 | x = (x + 1.) / 2. * constants 41 | return x.reshape(grid.shape) 42 | 43 | 44 | def grid_normalize(grid, H, W): 45 | x = grid.reshape(-1, H, W, 2) 46 | x = 2. * x / torch.Tensor([W - 1., H - 1.]).reshape(1, 1, 1, 2).to(x.device) - 1 47 | return x.reshape(grid.shape) 48 | 49 | 50 | def random_tps_weights(nctrlpts, warpsd_all, warpsd_subset, transsd, scalesd, rotsd): 51 | W = torch.randn(nctrlpts, 2) * warpsd_all 52 | subset = torch.rand(W.shape) > 0.5 53 | W[subset] = torch.randn(subset.sum()) * warpsd_subset 54 | rot = torch.randn([]) * rotsd * math.pi / 180 55 | sc = 1. 
+ torch.randn([]) * scalesd 56 | tx = torch.randn([]) * transsd 57 | ty = torch.randn([]) * transsd 58 | 59 | aff = torch.Tensor([[tx, ty], 60 | [sc * torch.cos(rot), sc * -torch.sin(rot)], 61 | [sc * torch.sin(rot), sc * torch.cos(rot)]]) 62 | 63 | Wa = torch.cat((W, aff), 0) 64 | return Wa 65 | 66 | 67 | class Warper(object): 68 | returns_pairs = True 69 | 70 | def __init__(self, H, W, warpsd_all=0.001, warpsd_subset=0.01, transsd=0.1, 71 | scalesd=0.1, rotsd=5, im1_multiplier=0.5, im1_multiplier_aff=1.): 72 | self.H = H 73 | self.W = W 74 | self.warpsd_all = warpsd_all 75 | self.warpsd_subset = warpsd_subset 76 | self.transsd = transsd 77 | self.scalesd = scalesd 78 | self.rotsd = rotsd 79 | self.im1_multiplier = im1_multiplier 80 | self.im1_multiplier_aff = im1_multiplier_aff 81 | 82 | self.npixels = H * W 83 | self.nc = 10 84 | self.nctrlpts = self.nc * self.nc 85 | 86 | self.grid_pixels = tps_grid(H, W) 87 | self.grid_pixels_unnormalized = grid_unnormalize(self.grid_pixels.reshape(1, H, W, 2), self.H, self.W) 88 | self.grid_ctrlpts = tps_grid(self.nc, self.nc) 89 | self.U_ctrlpts = tps_U(self.grid_ctrlpts, self.grid_ctrlpts) 90 | self.U_pixels_ctrlpts = tps_U(self.grid_pixels, self.grid_ctrlpts) 91 | self.F = torch.cat((self.U_pixels_ctrlpts, torch.ones(self.npixels, 1), self.grid_pixels), 1) 92 | 93 | def __call__(self, im1, im2=None, keypts=None, crop=0): 94 | Hc = self.H - crop - crop 95 | Wc = self.W - crop - crop 96 | 97 | # im2 should be a copy of im1 with different colour jitter 98 | if im2 is None: 99 | im2 = im1 100 | 101 | kp1 = kp2 = 0 102 | 103 | unsqueezed = False 104 | if len(im1.shape) == 3: 105 | im1 = im1.unsqueeze(0) 106 | im2 = im2.unsqueeze(0) 107 | unsqueezed = True 108 | 109 | assert im1.shape[0] == 1 and im2.shape[0] == 1 110 | 111 | a = self.im1_multiplier 112 | b = self.im1_multiplier_aff 113 | weights1 = random_tps_weights(self.nctrlpts, a * self.warpsd_all, a * self.warpsd_subset, b * self.transsd, 114 | b * self.scalesd, b * self.rotsd) 115 | 116 | grid1 = torch.matmul(self.F, weights1).reshape(1, self.H, self.W, 2) 117 | grid1_unnormalized = grid_unnormalize(grid1, self.H, self.W) 118 | if keypts is not None: 119 | kp1 = self.warp_keypoints(keypts, grid1_unnormalized) 120 | 121 | im1 = F.grid_sample(im1, grid1, align_corners=True) 122 | im2 = F.grid_sample(im2, grid1, align_corners=True) 123 | 124 | weights2 = random_tps_weights(self.nctrlpts, self.warpsd_all, self.warpsd_subset, self.transsd, 125 | self.scalesd, self.rotsd) 126 | grid2 = torch.matmul(self.F, weights2).reshape(1, self.H, self.W, 2) 127 | im2 = F.grid_sample(im2, grid2, align_corners=True) 128 | 129 | if crop != 0: 130 | im1 = im1[:, :, crop:-crop, crop:-crop] 131 | im2 = im2[:, :, crop:-crop, crop:-crop] 132 | 133 | if unsqueezed: 134 | im1 = im1.squeeze(0) 135 | im2 = im2.squeeze(0) 136 | 137 | grid = grid2 138 | grid_unnormalized = grid_unnormalize(grid, self.H, self.W) 139 | 140 | if keypts is not None: 141 | kp2 = self.warp_keypoints(kp1, grid_unnormalized) 142 | 143 | flow = grid_unnormalized - self.grid_pixels_unnormalized 144 | 145 | if crop != 0: 146 | flow = flow[:, crop:-crop, crop:-crop, :] 147 | 148 | grid_cropped = grid_unnormalized[:, crop:-crop, crop:-crop, :] - crop 149 | grid = grid_normalize(grid_cropped, Hc, Wc) 150 | 151 | # hc = flow.shape[1] 152 | # wc = flow.shape[2] 153 | # gridc = flow + grid_unnormalize(tps_grid(hc, wc).reshape(1, hc, wc, 2), hc, wc) 154 | # grid = grid_normalize(gridc, hc, wc) 155 | 156 | if keypts is not None: 157 | kp1 -= crop 158 | kp2 -= 
crop 159 | 160 | # Reverse the order because due to inverse warping the "flow" is in direction im2->im1 161 | # and we want to be consistent with optical flow from videos 162 | return im2, im1, flow, grid, kp2, kp1 163 | 164 | def warp_keypoints(self, keypoints, grid_unnormalized): 165 | from scipy.spatial.kdtree import KDTree 166 | warp_grid = grid_unnormalized.reshape(-1, 2) 167 | regular_grid = self.grid_pixels_unnormalized.reshape(-1, 2) 168 | kd = KDTree(warp_grid) 169 | dists, idxs = kd.query(keypoints) 170 | new_keypoints = regular_grid[idxs] 171 | return new_keypoints 172 | 173 | 174 | class WarperSingle(object): 175 | returns_pairs = False 176 | 177 | def __init__(self, H, W, warpsd_all=0.0005, warpsd_subset=0.0, transsd=0.02, 178 | scalesd=0.02, rotsd=2): 179 | self.H = H 180 | self.W = W 181 | self.warpsd_all = warpsd_all 182 | self.warpsd_subset = warpsd_subset 183 | self.transsd = transsd 184 | self.scalesd = scalesd 185 | self.rotsd = rotsd 186 | 187 | self.npixels = H * W 188 | self.nc = 10 189 | self.nctrlpts = self.nc * self.nc 190 | 191 | self.grid_pixels = tps_grid(H, W) 192 | self.grid_pixels_unnormalized = grid_unnormalize(self.grid_pixels.reshape(1, H, W, 2), self.H, self.W) 193 | self.grid_ctrlpts = tps_grid(self.nc, self.nc) 194 | self.U_ctrlpts = tps_U(self.grid_ctrlpts, self.grid_ctrlpts) 195 | self.U_pixels_ctrlpts = tps_U(self.grid_pixels, self.grid_ctrlpts) 196 | self.F = torch.cat((self.U_pixels_ctrlpts, torch.ones(self.npixels, 1), self.grid_pixels), 1) 197 | 198 | def __call__(self, im1, keypts=None, crop=0): 199 | kp1 = 0 200 | 201 | unsqueezed = False 202 | if len(im1.shape) == 3: 203 | im1 = im1.unsqueeze(0) 204 | unsqueezed = True 205 | 206 | assert im1.shape[0] == 1 207 | 208 | a = 1 209 | weights1 = random_tps_weights(self.nctrlpts, a * self.warpsd_all, a * self.warpsd_subset, a * self.transsd, 210 | a * self.scalesd, a * self.rotsd) 211 | 212 | grid1 = torch.matmul(self.F, weights1).reshape(1, self.H, self.W, 2) 213 | grid1_unnormalized = grid_unnormalize(grid1, self.H, self.W) 214 | if keypts is not None: 215 | kp1 = self.warp_keypoints(keypts, grid1_unnormalized) 216 | 217 | im1 = F.grid_sample(im1, grid1, align_corners=True) 218 | 219 | if crop != 0: 220 | im1 = im1[:, :, crop:-crop, crop:-crop] 221 | 222 | if unsqueezed: 223 | im1 = im1.squeeze(0) 224 | 225 | if crop != 0 and keypts is not None: 226 | kp1 -= crop 227 | 228 | # Reverse the order because due to inverse warping the "flow" is in direction im2->im1 229 | # and we want to be consistent with optical flow from videos 230 | return im1, kp1 231 | 232 | def warp_keypoints(self, keypoints, grid_unnormalized): 233 | from scipy.spatial.kdtree import KDTree 234 | warp_grid = grid_unnormalized.reshape(-1, 2) 235 | regular_grid = self.grid_pixels_unnormalized.reshape(-1, 2) 236 | kd = KDTree(warp_grid) 237 | dists, idxs = kd.query(keypoints) 238 | new_keypoints = regular_grid[idxs] 239 | return new_keypoints -------------------------------------------------------------------------------- /utils/util_func.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility logging, plotting and animation functions. 
3 | """ 4 | 5 | # imports 6 | import numpy as np 7 | # import matplotlib.pyplot as plt 8 | import cv2 9 | # from PIL import Image 10 | import datetime 11 | import os 12 | import json 13 | # import imageio 14 | # torch 15 | import torch 16 | import torch.nn.functional as F 17 | import torchvision.transforms as transforms 18 | import torchvision.ops as ops 19 | 20 | 21 | def color_map(num=100): 22 | """ 23 | Color maps for the keypoints 24 | """ 25 | colormap = ["FF355E", 26 | "8ffe09", 27 | "1d5dec", 28 | "FF9933", 29 | "FFFF66", 30 | "CCFF00", 31 | "AAF0D1", 32 | "FF6EFF", 33 | "FF00CC", 34 | "299617", 35 | "AF6E4D"] * num 36 | s = '' 37 | for color in colormap: 38 | s += color 39 | b = bytes.fromhex(s) 40 | cm = np.frombuffer(b, np.uint8) 41 | cm = cm.reshape(len(colormap), 3) 42 | return cm 43 | 44 | 45 | def plot_keypoints_on_image(k, image_tensor, radius=1, thickness=1, kp_range=(0, 1)): 46 | # https://github.com/DuaneNielsen/keypoints 47 | height, width = image_tensor.size(1), image_tensor.size(2) 48 | num_keypoints = k.size(0) 49 | 50 | if len(k.shape) != 2: 51 | raise Exception('Individual images and keypoints, not batches') 52 | 53 | k = k.clone() 54 | k[:, 0] = ((k[:, 0] - kp_range[0]) / (kp_range[1] - kp_range[0])) * (height - 1) 55 | k[:, 1] = ((k[:, 1] - kp_range[0]) / (kp_range[1] - kp_range[0])) * (width - 1) 56 | # k.floor_() 57 | k.round_() 58 | k = k.detach().cpu().numpy() 59 | # print(k) 60 | 61 | img = transforms.ToPILImage()(image_tensor.cpu()) 62 | 63 | img = np.array(img) 64 | cmap = color_map() 65 | cm = cmap[:num_keypoints].astype(int) 66 | count = 0 67 | for co_ord, color in zip(k, cm): 68 | c = color.item(0), color.item(1), color.item(2) 69 | co_ord = co_ord.squeeze() 70 | cv2.circle(img, (int(co_ord[1]), int(co_ord[0])), radius, c, thickness) 71 | count += 1 72 | 73 | return img 74 | 75 | 76 | def plot_keypoints_on_image_batch(kp_batch_tensor, img_batch_tensor, radius=1, thickness=1, max_imgs=8, 77 | kp_range=(-1, 1)): 78 | num_plot = min(max_imgs, img_batch_tensor.shape[0]) 79 | img_with_kp = [] 80 | for i in range(num_plot): 81 | img_np = plot_keypoints_on_image(kp_batch_tensor[i], img_batch_tensor[i], radius=radius, thickness=thickness, 82 | kp_range=kp_range) 83 | img_tensor = torch.tensor(img_np).float() / 255.0 84 | img_with_kp.append(img_tensor.permute(2, 0, 1)) 85 | img_with_kp = torch.stack(img_with_kp, dim=0) 86 | return img_with_kp 87 | 88 | 89 | def get_kp_mask_from_gmap(gmaps, threshold=0.2, binary=True, elementwise=False): 90 | """ 91 | Transforms the Gaussian-maps created from the KP to (binary) masks. 
92 |     gmaps: [B, K, H, W]
93 |     threshold: pixels above it are set to one, below it to zero
94 |     """
95 |     if elementwise:
96 |         mask = gmaps
97 |     else:
98 |         mask = gmaps.sum(1, keepdim=True)
99 |     if binary:
100 |         mask = torch.where(mask > threshold, 1.0, 0.0)
101 |     else:
102 |         mask = mask.clamp(0, 1)
103 |     return mask
104 | 
105 | 
106 | def reparameterize(mu, logvar):
107 |     """
108 |     This function applies the reparameterization trick:
109 |     z = mu(X) + sigma(X) * epsilon, where sigma(X) = exp(0.5 * logvar(X)) and epsilon ~ N(0,I)
110 |     :param mu: mean of x
111 |     :param logvar: log variance of x
112 |     :return z: the sampled latent variable
113 |     """
114 |     device = mu.device
115 |     std = torch.exp(0.5 * logvar)
116 |     eps = torch.randn_like(mu).to(device)
117 |     return mu + eps * std
118 | 
119 | 
120 | def create_masks_fast(center, anchor_s, feature_dim=16, patch_size=None):
121 |     """
122 |     Creates binary masks where only a box of size round(anchor_s * (feature_dim - 1)) centered around `center`
123 |     is 1, and the rest is 0.
124 |     center: [batch_size, n_kp, 2] in kp_range
125 |     anchor_s: size of the anchor in [0, 1]
126 |     """
127 |     batch_size, n_kp = center.shape[0], center.shape[1]
128 |     if patch_size is None:
129 |         patch_size = np.round(anchor_s * (feature_dim - 1)).astype(int)
130 |     # create white rectangles
131 |     masks = torch.ones(batch_size * n_kp, 1, patch_size, patch_size, device=center.device).float()
132 |     # pad the masks to image size
133 |     pad_size = (feature_dim - patch_size) // 2
134 |     padded_patches_batch = F.pad(masks, pad=[pad_size] * 4)
135 |     # move the masks to be centered around the kp
136 |     delta_t_batch = 0.0 - center
137 |     delta_t_batch = delta_t_batch.reshape(-1, delta_t_batch.shape[-1])  # [bs * n_kp, 2]
138 |     zeros = torch.zeros([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float()
139 |     ones = torch.ones([delta_t_batch.shape[0], 1], device=delta_t_batch.device).float()
140 |     theta = torch.cat([ones, zeros, delta_t_batch[:, 1].unsqueeze(-1),
141 |                        zeros, ones, delta_t_batch[:, 0].unsqueeze(-1)], dim=-1)
142 |     theta = theta.view(-1, 2, 3)  # [batch_size * n_kp, 2, 3]
143 |     align_corners = False
144 |     padding_mode = 'zeros'
145 |     mode = "nearest"
146 |     # mode = 'bilinear'  # makes it differentiable, but we don't care about it here
147 |     grid = F.affine_grid(theta, padded_patches_batch.size(), align_corners=align_corners)
148 |     trans_padded_patches_batch = F.grid_sample(padded_patches_batch, grid, align_corners=align_corners,
149 |                                                mode=mode, padding_mode=padding_mode)
150 |     trans_padded_patches_batch = trans_padded_patches_batch.view(batch_size, n_kp, *padded_patches_batch.shape[1:])
151 |     # [bs, n_kp, 1, feature_dim, feature_dim]
152 |     return trans_padded_patches_batch
153 | 
154 | 
155 | def get_bb_from_masks(masks, width, height):
156 |     # extracts bounding boxes (bb) from masks.
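    # returned boxes are (w_start, h_start, w_end, h_end) in pixel units of the given width/height, one row per mask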
157 | # masks: [n_masks, 1, feature_dim, feature_dim] 158 | n_masks = masks.shape[0] 159 | mask_h, mask_w = masks.shape[2], masks.shape[3] 160 | coor = torch.zeros(size=(n_masks, 4), dtype=torch.int, device=masks.device) 161 | for i in range(n_masks): 162 | mask = masks[i].int().squeeze() # [feature_dim, feature_dim] 163 | indices = (mask == 1).nonzero(as_tuple=False) 164 | if indices.shape[0] > 0: 165 | ws = (indices[0][1] * (width / mask_w)).clamp(0, width).int() 166 | wt = (indices[-1][1] * (width / mask_w)).clamp(0, width).int() 167 | hs = (indices[0][0] * (height / mask_h)).clamp(0, height).int() 168 | ht = (indices[-1][0] * (height / mask_h)).clamp(0, height).int() 169 | coor[i, 0] = ws 170 | coor[i, 1] = hs 171 | coor[i, 2] = wt 172 | coor[i, 3] = ht 173 | return coor 174 | 175 | 176 | def get_bb_from_masks_batch(masks, width, height): 177 | # extracts bounding boxes (bb) from a batch of masks. 178 | # masks: [batch_size, n_masks, 1, feature_dim, feature_dim] 179 | coor = torch.zeros(size=(masks.shape[0], masks.shape[1], 4), dtype=torch.int, device=masks.device) 180 | for i in range(masks.shape[0]): 181 | coor[i, :, :] = get_bb_from_masks(masks[i], width, height) 182 | return coor 183 | 184 | 185 | def nms_single(boxes, scores, iou_thresh=0.5, return_scores=False, remove_ind=None): 186 | # non-maximal suppression on bb and scores from one image. 187 | # boxes: [n_bb, 4], scores: [n_boxes] 188 | nms_indices = ops.nms(boxes.float(), scores, iou_thresh) 189 | # remove low scoring indices from nms output 190 | if remove_ind is not None: 191 | # final_indices = [ind for ind in nms_indices if ind not in remove_ind] 192 | final_indices = list(set(nms_indices.data.cpu().numpy()) - set(remove_ind)) 193 | # print(f'removed indices: {remove_ind}') 194 | else: 195 | final_indices = nms_indices 196 | nms_boxes = boxes[final_indices] # [n_bb_nms, 4] 197 | if return_scores: 198 | return nms_boxes, final_indices, scores[final_indices] 199 | else: 200 | return nms_boxes, final_indices 201 | 202 | 203 | def remove_low_score_bb_single(boxes, scores, return_scores=False, mode='mean', thresh=0.4, hard_thresh=None): 204 | # filters out low-scoring bounding boxes. The score is usually the variance of the particle. 205 | # boxes: [n_bb, 4], scores: [n_boxes] 206 | if hard_thresh is None: 207 | if mode == 'mean': 208 | mean_score = scores.mean() 209 | # indices = (scores > mean_score) 210 | indices = torch.nonzero(scores > thresh, as_tuple=True)[0].data.cpu().numpy() 211 | else: 212 | normalzied_scores = (scores - scores.min()) / (scores.max() - scores.min()) 213 | # indices = (normalzied_scores > thresh) 214 | indices = torch.nonzero(normalzied_scores > thresh, as_tuple=True)[0].data.cpu().numpy() 215 | else: 216 | # indices = (scores > hard_thresh) 217 | indices = torch.nonzero(scores > hard_thresh, as_tuple=True)[0].data.cpu().numpy() 218 | boxes_t = boxes[indices] 219 | scores_t = scores[indices] 220 | if return_scores: 221 | return indices, boxes_t, scores_t 222 | else: 223 | return indices, boxes_t 224 | 225 | 226 | def get_low_score_bb_single(scores, mode='mean', thresh=0.4, hard_thresh=None): 227 | # get indices of low-scoring bounding boxes. 
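    # (complement of remove_low_score_bb_single above: returns the indices to drop rather than the ones to keep)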
228 | # boxes: [n_bb, 4], scores: [n_boxes] 229 | if hard_thresh is None: 230 | if mode == 'mean': 231 | mean_score = scores.mean() 232 | # indices = (scores > mean_score) 233 | indices = torch.nonzero(scores < thresh, as_tuple=True)[0].data.cpu().numpy() 234 | else: 235 | normalzied_scores = (scores - scores.min()) / (scores.max() - scores.min()) 236 | # indices = (normalzied_scores > thresh) 237 | indices = torch.nonzero(normalzied_scores < thresh, as_tuple=True)[0].data.cpu().numpy() 238 | else: 239 | # indices = (scores > hard_thresh) 240 | indices = torch.nonzero(scores < hard_thresh, as_tuple=True)[0].data.cpu().numpy() 241 | return indices 242 | 243 | 244 | def plot_bb_on_image_from_masks_nms(masks, image_tensor, scores, iou_thresh=0.5, thickness=1, hard_thresh=None): 245 | # plot bounding boxes on a single image, use non-maximal suppression to filter low-scoring bbs. 246 | # masks: [n_masks, 1, feature_dim, feature_dim] 247 | n_masks = masks.shape[0] 248 | mask_h, mask_w = masks.shape[2], masks.shape[3] 249 | height, width = image_tensor.size(1), image_tensor.size(2) 250 | img = transforms.ToPILImage()(image_tensor.cpu()) 251 | img = np.array(img) 252 | cmap = color_map() 253 | cm = cmap[:n_masks].astype(int) 254 | count = 0 255 | # get bb coor 256 | coors = get_bb_from_masks(masks, width, height) # [n_masks, 4] 257 | # remove low-score bb 258 | low_score_ind = get_low_score_bb_single(scores, mode='mean', hard_thresh=hard_thresh) 259 | # nms 260 | coors_nms, nms_indices, scores_nms = nms_single(coors, scores, iou_thresh, return_scores=True, 261 | remove_ind=low_score_ind) 262 | # [n_masks_nms, 4] 263 | for coor, color in zip(coors_nms, cm): 264 | c = color.item(0), color.item(1), color.item(2) 265 | ws = (coor[0] - thickness).clamp(0, width) 266 | hs = (coor[1] - thickness).clamp(0, height) 267 | wt = (coor[2] + thickness).clamp(0, width) 268 | ht = (coor[3] + thickness).clamp(0, height) 269 | bb_s = (int(ws), int(hs)) 270 | bb_t = (int(wt), int(ht)) 271 | cv2.rectangle(img, bb_s, bb_t, c, thickness, 1) 272 | score_text = f'{scores_nms[count]:.2f}' 273 | font = cv2.FONT_HERSHEY_SIMPLEX 274 | fontScale = 0.3 275 | thickness = 1 276 | box_w = bb_t[0] - bb_s[0] 277 | box_h = bb_t[1] - bb_s[1] 278 | org = (int(bb_s[0] + box_w / 2), int(bb_s[1] + box_h / 2)) 279 | cv2.putText(img, score_text, org, font, fontScale, thickness=thickness, color=c, lineType=cv2.LINE_AA) 280 | count += 1 281 | 282 | return img, nms_indices 283 | 284 | 285 | def plot_bb_on_image_batch_from_masks_nms(mask_batch_tensor, img_batch_tensor, scores, iou_thresh=0.5, thickness=1, 286 | max_imgs=8, hard_thresh=None): 287 | # plot bounding boxes on a batch of images, use non-maximal suppression to filter low-scoring bbs. 288 | # mask_batch_tensor: [batch_size, n_kp, 1, feature_dim, feature_dim] 289 | num_plot = min(max_imgs, img_batch_tensor.shape[0]) 290 | img_with_bb = [] 291 | indices = [] 292 | for i in range(num_plot): 293 | img_np, nms_indices = plot_bb_on_image_from_masks_nms(mask_batch_tensor[i], img_batch_tensor[i], scores[i], 294 | iou_thresh, thickness=thickness, hard_thresh=hard_thresh) 295 | img_tensor = torch.tensor(img_np).float() / 255.0 296 | img_with_bb.append(img_tensor.permute(2, 0, 1)) 297 | indices.append(nms_indices) 298 | img_with_bb = torch.stack(img_with_bb, dim=0) 299 | return img_with_bb, indices 300 | 301 | 302 | def plot_bb_on_image_from_masks(masks, image_tensor, thickness=1): 303 | # vanilla plotting of bbs from masks. 
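    # unlike plot_bb_on_image_from_masks_nms above, no score filtering or non-maximal suppression is applied here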
304 | # masks: [n_masks, 1, feature_dim, feature_dim] 305 | n_masks = masks.shape[0] 306 | mask_h, mask_w = masks.shape[2], masks.shape[3] 307 | height, width = image_tensor.size(1), image_tensor.size(2) 308 | 309 | img = transforms.ToPILImage()(image_tensor.cpu()) 310 | 311 | img = np.array(img) 312 | cmap = color_map() 313 | cm = cmap[:n_masks].astype(int) 314 | count = 0 315 | for mask, color in zip(masks, cm): 316 | c = color.item(0), color.item(1), color.item(2) 317 | mask = mask.int().squeeze() # [feature_dim, feature_dim] 318 | # print(mask.shape) 319 | indices = (mask == 1).nonzero(as_tuple=False) 320 | # print(indices.shape) 321 | if indices.shape[0] > 0: 322 | ws = (indices[0][1] * (width / mask_w) - thickness).clamp(0, width).int() 323 | wt = (indices[-1][1] * (width / mask_w) + thickness).clamp(0, width).int() 324 | hs = (indices[0][0] * (height / mask_h) - thickness).clamp(0, height).int() 325 | ht = (indices[-1][0] * (height / mask_h) + thickness).clamp(0, height).int() 326 | bb_s = (int(ws), int(hs)) 327 | bb_t = (int(wt), int(ht)) 328 | cv2.rectangle(img, bb_s, bb_t, c, thickness, 1) 329 | count += 1 330 | return img 331 | 332 | 333 | def plot_bb_on_image_batch_from_masks(mask_batch_tensor, img_batch_tensor, thickness=1, max_imgs=8): 334 | # vanilla plotting of bbs from a batch of masks. 335 | # mask_batch_tensor: [batch_size, n_kp, 1, feature_dim, feature_dim] 336 | num_plot = min(max_imgs, img_batch_tensor.shape[0]) 337 | img_with_bb = [] 338 | for i in range(num_plot): 339 | img_np = plot_bb_on_image_from_masks(mask_batch_tensor[i], img_batch_tensor[i], thickness=thickness) 340 | img_tensor = torch.tensor(img_np).float() / 255.0 341 | img_with_bb.append(img_tensor.permute(2, 0, 1)) 342 | img_with_bb = torch.stack(img_with_bb, dim=0) 343 | return img_with_bb 344 | 345 | 346 | def prepare_logdir(runname, src_dir='./'): 347 | """ 348 | Prepare the log directory in which checkpoints, plots and stats will be saved. 349 | """ 350 | td_prefix = datetime.datetime.now().strftime("%d%m%y_%H%M%S") 351 | dir_name = f'{td_prefix}_{runname}' 352 | path_to_dir = os.path.join(src_dir, dir_name) 353 | os.makedirs(path_to_dir, exist_ok=True) 354 | path_to_fig_dir = os.path.join(path_to_dir, 'figures') 355 | os.makedirs(path_to_fig_dir, exist_ok=True) 356 | path_to_save_dir = os.path.join(path_to_dir, 'saves') 357 | os.makedirs(path_to_save_dir, exist_ok=True) 358 | return path_to_dir 359 | 360 | 361 | def save_config(src_dir, hparams): 362 | # saves the hyperparameters of a single run. 363 | path_to_conf = os.path.join(src_dir, 'hparams.json') 364 | with open(path_to_conf, "w") as outfile: 365 | json.dump(hparams, outfile, indent=2) 366 | 367 | 368 | def log_line(src_dir, line): 369 | log_file = os.path.join(src_dir, 'log.txt') 370 | with open(log_file, 'a') as fp: 371 | fp.writelines(line) 372 | --------------------------------------------------------------------------------