├── LICENSE ├── README.md ├── bin ├── efficient_capsnet_MNIST.h5 ├── efficient_capsnet_MULTIMNIST.h5 ├── efficient_capsnet_SMALLNORB.h5 └── original_capsnet_MNIST.h5 ├── config.json ├── dynamic_visualization_capsules_dimensions_perturbation.ipynb ├── efficient_capsnet_test.ipynb ├── efficient_capsnet_train.ipynb ├── media ├── dimension_perturbation.gif ├── efficient_capsnet_architecture.png └── routing_capsules.png ├── models ├── __init__.py ├── efficient_capsnet_graph_mnist.py ├── efficient_capsnet_graph_multimnist.py ├── efficient_capsnet_graph_smallnorb.py ├── model.py └── original_capsnet_graph_mnist.py ├── original_capsnet_test.ipynb ├── original_capsnet_train.ipynb ├── requirements.txt └── utils ├── __init__.py ├── dataset.py ├── layers.py ├── layers_hinton.py ├── pre_process_mnist.py ├── pre_process_multimnist.py ├── pre_process_smallnorb.py ├── tools.py └── visualization.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](http://img.shields.io/badge/arXiv-2001.09136-B31B1B.svg)](https://arxiv.org/abs/2101.12491) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/efficient-capsnet-capsule-network-with-self/image-classification-on-smallnorb)](https://paperswithcode.com/sota/image-classification-on-smallnorb?p=efficient-capsnet-capsule-network-with-self) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/efficient-capsnet-capsule-network-with-self/image-classification-on-mnist)](https://paperswithcode.com/sota/image-classification-on-mnist?p=efficient-capsnet-capsule-network-with-self) 4 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 5 | 6 | 7 | 8 |

~ Efficient-CapsNet ~

9 | Are you tired of over-inflated and overused convolutional neural networks? You're right! It's time for CAPSULES :)
10 | 11 | This repository has been made for two primary reasons: 12 | 13 | - open source the code (most of it) developed during our "first-stage" research on capsules, summarized by the forthcoming article "Efficient-CapsNet: Capsule Network with Self-Attention Routing". The repository lets you play with Efficient-CapsNet and lets you set the base for your own experiments. 14 | - be a hub and a headlight in the cyberspace to spread to the machine learning community the intrinsic potential and value of capsules. However, despite the remarkable results achieved by capsule networks, we're fully aware that they're still limited to toy datasets. Nevertheless, there's a lot to make us think that with the right effort and collaboration of the scientific community, capsule-based networks could really make a difference in the long run. For now, feel free to dive into our work :)) 15 | 16 |

17 | 18 |

19 | 20 | # 1.0 Getting Started 21 | 22 | ## 1.1 Installation 23 | 24 | Python 3 and TensorFlow 2.x are required and should be installed on the host machine following the [official guide](https://www.tensorflow.org/install). Good luck with it! 25 | 26 | 1. Clone this repository 27 | ```bash 28 | git clone https://github.com/EscVM/Efficient-CapsNet.git 29 | ``` 30 | 2. Install the required packages 31 | ```bash 32 | pip3 install -r requirements.txt 33 | ``` 34 | Peek inside the requirements file if you have everything already installed. Most of the dependencies are common libraries. 35 | 36 | # 2.0 Efficient-CapsNet Notebooks 37 | The repository provides two starting notebooks to make you comfortable with our architecture. They all have the information and explanations to let you dive further into new research and experiments. 38 | The [first](https://github.com/EscVM/Efficient-CapsNet/blob/main/efficient_capsnet_test.ipynb) one lets you test Efficient-CapsNet over three different datasets. The repository is provided with some of the weights derived from our own experiments. 39 | On the other hand, the [second](https://github.com/EscVM/Efficient-CapsNet/blob/main/efficient_capsnet_train.ipynb) one lets you train the network from scratch. It's a very lightweight network so you don't need a "Deep Mind" TPU arsenal to train it. However, even if a GP-GPU is not compulsory, it's strongly suggested (No GPU, no deep learning, no party). 40 | 41 | # 3.0 Original CapsNet Notebooks 42 | It goes without saying that our work has been inspired by Geoffrey Hinton and his article "[Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829)". It's really an honor to build on his idea. Nevertheless, when we took our first steps in the capsule world, we were pretty disappointed to find that all repositories/implementations were ultimately wrong in some aspects. 
So, we implemented everything from scratch, carefully following Sara's original [repository](https://github.com/Sarasra/models/tree/master/research/capsules). However, our implementation, besides being written for the new TensorFlow 2 version, is much easier and more practical to use. Sara's is really overcomplicated and so mazy that you can get lost pretty easily. 43 | 44 | As in the previous section, we provide two notebooks, [one](https://github.com/EscVM/Efficient-CapsNet/blob/main/original_capsnet_test.ipynb) for testing (weights have been derived from Sara's repository) and [one](https://github.com/EscVM/Efficient-CapsNet/blob/main/original_capsnet_train.ipynb) for training. 45 | 46 | Nevertheless, there's a really negative note (at least for us :)); like all other repositories that you can find on the web, ours is also not capable of achieving the scores reported in the [paper](https://arxiv.org/abs/1710.09829). We really did our best, but there is no way to make the network achieve a score greater than 99.64% on MNIST. Exactly for this reason, the weights provided in this repository are derived from their repository. Anyway, it's Geoffrey, so we can excuse him. 47 | 48 | 49 | # 4.0 Capsules Dimensions Perturbation Notebook 50 | The network is trained with a reconstruction regularizer that is simply a fully connected network trained in conjunction with the main one. So, we can use it to visualize the inner capsule representations. In particular, we should expect a dimension of a digit capsule to learn to span the space of variations in the way digits of that class are instantiated. We can see what the individual dimensions represent by making use of the decoder network and injecting some noise into one of the dimensions of the main digit capsule layer that is predicting the class of the input. 
51 | 52 | So, we coded a practical notebook in which you can dynamically tweak whichever dimension you want of the capsule that is making the prediction (longest one). 53 | 54 | Finally, if you don't have the necessary resources (GP-GPU holy grail) you can still try this interesting notebook out on 55 | Open In Colab. 56 |

57 | 58 |

59 | 60 | # Citation 61 | Use this bibtex if you enjoyed this repository and you want to cite it: 62 | 63 | ``` 64 | @article{mazzia2021efficient, 65 | title={Efficient-CapsNet: capsule network with self-attention routing}, 66 | author={Mazzia, Vittorio and Salvetti, Francesco and Chiaberge, Marcello}, 67 | year={2021}, 68 | journal={Scientific reports}, 69 | publisher={Nature Publishing Group}, 70 | volume={11} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /bin/efficient_capsnet_MNIST.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/bin/efficient_capsnet_MNIST.h5 -------------------------------------------------------------------------------- /bin/efficient_capsnet_MULTIMNIST.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/bin/efficient_capsnet_MULTIMNIST.h5 -------------------------------------------------------------------------------- /bin/efficient_capsnet_SMALLNORB.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/bin/efficient_capsnet_SMALLNORB.h5 -------------------------------------------------------------------------------- /bin/original_capsnet_MNIST.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/bin/original_capsnet_MNIST.h5 -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "eps": 10e-21, 3 | "MNIST_INPUT_SHAPE": [28,28,1], 4 | 
"SMALLNORB_INPUT_SHAPE": [48,48,2], 5 | "MULTIMNIST_INPUT_SHAPE": [36,36,1], 6 | "lr": 5e-4, 7 | "lmd_gen": 0.392, 8 | "lr_dec": 0.97, 9 | "batch_size": 16, 10 | "epochs":150, 11 | "saved_model_dir": "bin", 12 | "tb_log_save_dir": "logs", 13 | "mnist_path": "mnist.npz", 14 | "scale_smallnorb": 64, 15 | "patch_smallnorb": 48, 16 | "n_overlay_multimnist": 1000, 17 | "shift_multimnist": 6, 18 | "pad_multimnist": 4 19 | } -------------------------------------------------------------------------------- /dynamic_visualization_capsules_dimensions_perturbation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Capsule Dynamic Dimensions Perturbation\n", 8 | "\n", 9 | "This notebook provide a simple way to visualize dimensions perturbation of the capsule with the greatest module. We make use of the decoder network to see how the injected noise affects inner reppresentations. Indeed, after computing the activity vector for the correct digit capsule, we can feed a perturbed version of this activity vector to the decoder network and see how the perturbation affects the reconstruction. 
For more information read section 5.1 @ https://arxiv.org/pdf/1710.09829.pdf" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2021-02-01T10:42:35.128687Z", 18 | "start_time": "2021-02-01T10:42:35.112414Z" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "ExecuteTime": { 32 | "end_time": "2021-02-01T10:43:42.320641Z", 33 | "start_time": "2021-02-01T10:43:42.294940Z" 34 | } 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import matplotlib\n", 39 | "matplotlib.__version__" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "ExecuteTime": { 47 | "end_time": "2021-02-01T10:42:36.699811Z", 48 | "start_time": "2021-02-01T10:42:35.407043Z" 49 | } 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "import tensorflow as tf\n", 54 | "from utils import AffineVisualizer, Dataset\n", 55 | "from models import EfficientCapsNet" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "ExecuteTime": { 63 | "end_time": "2021-02-01T10:42:36.785789Z", 64 | "start_time": "2021-02-01T10:42:36.735232Z" 65 | } 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 70 | "tf.config.experimental.set_visible_devices(gpus[0], 'GPU')\n", 71 | "tf.config.experimental.set_memory_growth(gpus[0], True)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# 1.0 Prepare the Environment" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "## 1.1 Import the dataset" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "ExecuteTime": { 93 | "end_time": 
"2021-02-01T10:42:38.632750Z", 94 | "start_time": "2021-02-01T10:42:38.334839Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "mnist_dataset = Dataset('MNIST', config_path='config.json') # only MNIST" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## 1.2 Import Efficient-CapsNet" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "ExecuteTime": { 114 | "end_time": "2021-02-01T10:42:40.653927Z", 115 | "start_time": "2021-02-01T10:42:39.463244Z" 116 | }, 117 | "scrolled": true 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "model_test = EfficientCapsNet('MNIST', mode='test', verbose=False)\n", 122 | "model_test.load_graph_weights()\n", 123 | "model_play = EfficientCapsNet('MNIST', mode='play', verbose=False)\n", 124 | "model_play.load_graph_weights()" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## 1.2.1 Evaluate the model" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": { 138 | "ExecuteTime": { 139 | "end_time": "2021-02-01T10:42:45.276700Z", 140 | "start_time": "2021-02-01T10:42:42.158439Z" 141 | } 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "model_test.evaluate(mnist_dataset.X_test, mnist_dataset.y_test)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# 2.0 Visualize Affine Transformation" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": { 159 | "ExecuteTime": { 160 | "end_time": "2021-02-01T10:42:46.396439Z", 161 | "start_time": "2021-02-01T10:42:45.934451Z" 162 | } 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "AffineVisualizer(model_play, mnist_dataset.X_test, mnist_dataset.y_test, hist=True).start()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | 
"metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "Python 3", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.6.9" 194 | }, 195 | "toc": { 196 | "base_numbering": 1, 197 | "nav_menu": {}, 198 | "number_sections": false, 199 | "sideBar": true, 200 | "skip_h1_title": false, 201 | "title_cell": "Table of Contents", 202 | "title_sidebar": "Contents", 203 | "toc_cell": false, 204 | "toc_position": {}, 205 | "toc_section_display": true, 206 | "toc_window_display": false 207 | }, 208 | "varInspector": { 209 | "cols": { 210 | "lenName": 16, 211 | "lenType": 16, 212 | "lenVar": 40 213 | }, 214 | "kernels_config": { 215 | "python": { 216 | "delete_cmd_postfix": "", 217 | "delete_cmd_prefix": "del ", 218 | "library": "var_list.py", 219 | "varRefreshCmd": "print(var_dic_list())" 220 | }, 221 | "r": { 222 | "delete_cmd_postfix": ") ", 223 | "delete_cmd_prefix": "rm(", 224 | "library": "var_list.r", 225 | "varRefreshCmd": "cat(var_dic_list()) " 226 | } 227 | }, 228 | "types_to_exclude": [ 229 | "module", 230 | "function", 231 | "builtin_function_or_method", 232 | "instance", 233 | "_Feature" 234 | ], 235 | "window_display": false 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 4 240 | } 241 | -------------------------------------------------------------------------------- /efficient_capsnet_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Efficient-CapsNet Model Test\n", 8 | "\n", 9 | "In this notebook we provide a simple interface to test the 
different trained Efficient-CapsNet models on the three datasets:\n", 10 | "\n", 11 | "- MNIST (MNIST)\n", 12 | "- smallNORB (SMALLNORB)\n", 13 | "- Multi-MNIST (MULTIMNIST)\n", 14 | "\n", 15 | "**NB**: remember to modify the \"config.json\" file with the appropriate parameters." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "ExecuteTime": { 23 | "end_time": "2021-02-16T08:57:31.782498Z", 24 | "start_time": "2021-02-16T08:57:31.762987Z" 25 | } 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "%load_ext autoreload\n", 30 | "%autoreload 2" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "ExecuteTime": { 38 | "end_time": "2021-02-16T08:57:33.233885Z", 39 | "start_time": "2021-02-16T08:57:31.936799Z" 40 | } 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "import tensorflow as tf\n", 45 | "from utils import Dataset, plotImages, plotWrongImages\n", 46 | "from models import EfficientCapsNet" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": { 53 | "ExecuteTime": { 54 | "end_time": "2021-02-16T08:57:33.321947Z", 55 | "start_time": "2021-02-16T08:57:33.270060Z" 56 | } 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 61 | "tf.config.experimental.set_visible_devices(gpus[0], 'GPU')\n", 62 | "tf.config.experimental.set_memory_growth(gpus[0], True)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "ExecuteTime": { 70 | "end_time": "2021-02-16T08:57:33.374047Z", 71 | "start_time": "2021-02-16T08:57:33.357993Z" 72 | } 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "# some parameters\n", 77 | "model_name = 'MNIST' \n", 78 | "custom_path = None # if you've trained a new model, insert here the full graph weights path" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 
| "# 1.0 Import the Dataset" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "ExecuteTime": { 93 | "end_time": "2021-02-16T08:58:45.850797Z", 94 | "start_time": "2021-02-16T08:58:45.458549Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "dataset = Dataset(model_name, config_path='config.json')" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## 1.1 Visualize imported dataset" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "ExecuteTime": { 114 | "end_time": "2021-02-16T08:58:48.500994Z", 115 | "start_time": "2021-02-16T08:58:47.175795Z" 116 | } 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "n_images = 20 # number of images to be plotted\n", 121 | "plotImages(dataset.X_test[:n_images,...,0], dataset.y_test[:n_images], n_images, dataset.class_names)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "# 2.0 Load the Model" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": { 135 | "ExecuteTime": { 136 | "end_time": "2021-02-16T08:58:48.778928Z", 137 | "start_time": "2021-02-16T08:58:48.538547Z" 138 | }, 139 | "scrolled": false 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "model_test = EfficientCapsNet(model_name, mode='test', verbose=True, custom_path=custom_path)\n", 144 | "\n", 145 | "model_test.load_graph_weights() # load graph weights (bin folder)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# 3.0 Test the Model" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": { 159 | "ExecuteTime": { 160 | "end_time": "2021-02-16T09:01:44.955356Z", 161 | "start_time": "2021-02-16T08:58:52.847947Z" 162 | }, 163 | "scrolled": false 164 | }, 165 | "outputs": [], 166 | "source": 
[ 167 | "model_test.evaluate(dataset.X_test, dataset.y_test) # if \"smallnorb\" use X_test_patch" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "## 3.1 Plot misclassified images" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": { 181 | "ExecuteTime": { 182 | "end_time": "2021-02-16T09:01:58.076066Z", 183 | "start_time": "2021-02-16T09:01:55.841686Z" 184 | } 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "#not working with MultiMNIST\n", 189 | "y_pred = model_test.predict(dataset.X_test)[0] # if \"smallnorb\" use X_test_patch\n", 190 | "\n", 191 | "n_images = 20\n", 192 | "plotWrongImages(dataset.X_test, dataset.y_test, y_pred, # if \"smallnorb\" use X_test_patch\n", 193 | " n_images, dataset.class_names)" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "Python 3", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.6.9" 214 | }, 215 | "toc": { 216 | "base_numbering": 1, 217 | "nav_menu": {}, 218 | "number_sections": false, 219 | "sideBar": true, 220 | "skip_h1_title": false, 221 | "title_cell": "Table of Contents", 222 | "title_sidebar": "Contents", 223 | "toc_cell": false, 224 | "toc_position": {}, 225 | "toc_section_display": true, 226 | "toc_window_display": false 227 | }, 228 | "varInspector": { 229 | "cols": { 230 | "lenName": 16, 231 | "lenType": 16, 232 | "lenVar": 40 233 | }, 234 | "kernels_config": { 235 | "python": { 236 | "delete_cmd_postfix": "", 237 | "delete_cmd_prefix": "del ", 238 | "library": "var_list.py", 239 | "varRefreshCmd": "print(var_dic_list())" 240 | }, 241 | "r": { 242 | "delete_cmd_postfix": 
") ", 243 | "delete_cmd_prefix": "rm(", 244 | "library": "var_list.r", 245 | "varRefreshCmd": "cat(var_dic_list()) " 246 | } 247 | }, 248 | "types_to_exclude": [ 249 | "module", 250 | "function", 251 | "builtin_function_or_method", 252 | "instance", 253 | "_Feature" 254 | ], 255 | "window_display": false 256 | } 257 | }, 258 | "nbformat": 4, 259 | "nbformat_minor": 4 260 | } 261 | -------------------------------------------------------------------------------- /efficient_capsnet_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Efficient-CapsNet Model Train\n", 8 | "\n", 9 | "In this notebook we provide a simple interface to train Efficient-CapsNet on the three dataset discussed in \"Efficient-CapsNet: Capsule Network with Self-Attention Routing\":\n", 10 | "\n", 11 | "- MNIST (MNIST)\n", 12 | "- smallNORB (SMALLNORB)\n", 13 | "- Multi-MNIST (MULTIMNIST)\n", 14 | "\n", 15 | "The hyperparameters have been only slightly investigated. So, there's a lot of room for improvements. Good luck!\n", 16 | "\n", 17 | "**NB**: remember to modify the \"config.json\" file with the appropriate parameters." 
18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": { 24 | "ExecuteTime": { 25 | "end_time": "2021-01-28T14:17:38.152068Z", 26 | "start_time": "2021-01-28T14:17:38.145241Z" 27 | } 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "%load_ext autoreload\n", 32 | "%autoreload 2" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "ExecuteTime": { 40 | "end_time": "2021-01-28T14:17:39.436665Z", 41 | "start_time": "2021-01-28T14:17:38.152986Z" 42 | } 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import tensorflow as tf\n", 47 | "from utils import Dataset, plotImages, plotWrongImages, plotHistory\n", 48 | "from models import EfficientCapsNet" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "ExecuteTime": { 56 | "end_time": "2021-01-28T14:17:39.485046Z", 57 | "start_time": "2021-01-28T14:17:39.438120Z" 58 | } 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 63 | "tf.config.experimental.set_visible_devices(gpus[0], 'GPU')\n", 64 | "tf.config.experimental.set_memory_growth(gpus[0], True)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "ExecuteTime": { 72 | "end_time": "2021-01-28T14:17:39.502857Z", 73 | "start_time": "2021-01-28T14:17:39.486169Z" 74 | } 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "# some parameters\n", 79 | "model_name = 'MNIST'" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "# 1.0 Import the Dataset" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "ExecuteTime": { 94 | "end_time": "2021-01-28T14:17:39.898397Z", 95 | "start_time": "2021-01-28T14:17:39.503821Z" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "dataset = Dataset(model_name, config_path='config.json')" 101 | ] 
102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## 1.1 Visualize imported dataset" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "ExecuteTime": { 115 | "end_time": "2021-01-28T14:17:41.229443Z", 116 | "start_time": "2021-01-28T14:17:39.899261Z" 117 | } 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "n_images = 20 # number of images to be plotted\n", 122 | "plotImages(dataset.X_test[:n_images,...,0], dataset.y_test[:n_images], n_images, dataset.class_names)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "# 2.0 Load the Model" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": { 136 | "ExecuteTime": { 137 | "end_time": "2021-01-28T14:21:46.634945Z", 138 | "start_time": "2021-01-28T14:21:46.311296Z" 139 | }, 140 | "scrolled": false 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "model_train = EfficientCapsNet(model_name, mode='train', verbose=True)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "ExecuteTime": { 151 | "end_time": "2021-01-25T17:38:06.189031Z", 152 | "start_time": "2021-01-25T17:38:05.460415Z" 153 | } 154 | }, 155 | "source": [ 156 | "# 3.0 Train the Model" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "ExecuteTime": { 164 | "end_time": "2021-01-28T14:22:02.087316Z", 165 | "start_time": "2021-01-28T14:22:02.031863Z" 166 | } 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "dataset_train, dataset_val = dataset.get_tf_data() " 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "ExecuteTime": { 178 | "end_time": "2021-01-28T14:25:09.510250Z", 179 | "start_time": "2021-01-28T14:24:56.018640Z" 180 | }, 181 | "scrolled": true 182 | }, 183 | "outputs": [], 184 | "source": [ 
185 | "history = model_train.train(dataset, initial_epoch=0)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "plotHistory(history)" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.6.9" 215 | }, 216 | "toc": { 217 | "base_numbering": 1, 218 | "nav_menu": {}, 219 | "number_sections": false, 220 | "sideBar": true, 221 | "skip_h1_title": false, 222 | "title_cell": "Table of Contents", 223 | "title_sidebar": "Contents", 224 | "toc_cell": false, 225 | "toc_position": {}, 226 | "toc_section_display": true, 227 | "toc_window_display": false 228 | }, 229 | "varInspector": { 230 | "cols": { 231 | "lenName": 16, 232 | "lenType": 16, 233 | "lenVar": 40 234 | }, 235 | "kernels_config": { 236 | "python": { 237 | "delete_cmd_postfix": "", 238 | "delete_cmd_prefix": "del ", 239 | "library": "var_list.py", 240 | "varRefreshCmd": "print(var_dic_list())" 241 | }, 242 | "r": { 243 | "delete_cmd_postfix": ") ", 244 | "delete_cmd_prefix": "rm(", 245 | "library": "var_list.r", 246 | "varRefreshCmd": "cat(var_dic_list()) " 247 | } 248 | }, 249 | "types_to_exclude": [ 250 | "module", 251 | "function", 252 | "builtin_function_or_method", 253 | "instance", 254 | "_Feature" 255 | ], 256 | "window_display": false 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 4 261 | } 262 | -------------------------------------------------------------------------------- /media/dimension_perturbation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/media/dimension_perturbation.gif -------------------------------------------------------------------------------- /media/efficient_capsnet_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/media/efficient_capsnet_architecture.png -------------------------------------------------------------------------------- /media/routing_capsules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EscVM/Efficient-CapsNet/4b337f5bf79d70a56e4dec8b113fe54b44cfe963/media/routing_capsules.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.model import EfficientCapsNet, CapsNet 2 | 3 | -------------------------------------------------------------------------------- /models/efficient_capsnet_graph_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================

import numpy as np
import tensorflow as tf
from utils.layers import PrimaryCaps, FCCaps, Length, Mask


def efficient_capsnet_graph(input_shape):
    """
    Build the Efficient-CapsNet backbone for MNIST.

    Parameters
    ----------
    input_shape: list
        network input shape

    Returns
    -------
    tf.keras.Model
        maps an image batch to [digit_caps, digit_caps_len]
    """
    inputs = tf.keras.Input(input_shape)

    # Convolutional feature extractor: (filters, kernel_size, strides) per
    # stage, each convolution followed by batch normalization.
    conv_specs = ((32, 5, 1), (64, 3, 1), (64, 3, 1), (128, 3, 2))
    features = inputs
    for filters, kernel, stride in conv_specs:
        features = tf.keras.layers.Conv2D(filters, kernel, stride,
                                          activation='relu', padding='valid',
                                          kernel_initializer='he_normal')(features)
        features = tf.keras.layers.BatchNormalization()(features)

    primary = PrimaryCaps(128, 9, 16, 8)(features)
    digit_caps = FCCaps(10, 16)(primary)
    digit_caps_len = Length(name='length_capsnet_output')(digit_caps)

    return tf.keras.Model(inputs=inputs,
                          outputs=[digit_caps, digit_caps_len],
                          name='Efficient_CapsNet')


def generator_graph(input_shape):
    """
    Build the fully-connected reconstruction generator.

    Parameters
    ----------
    input_shape: list
        network input shape (shape of the reconstructed image)
    """
    # Input is the flattened, masked capsule tensor (10 capsules x 16 dims).
    inputs = tf.keras.Input(16 * 10)

    hidden = tf.keras.layers.Dense(512, activation='relu',
                                   kernel_initializer='he_normal')(inputs)
    hidden = tf.keras.layers.Dense(1024, activation='relu',
                                   kernel_initializer='he_normal')(hidden)
    flat = tf.keras.layers.Dense(np.prod(input_shape), activation='sigmoid',
                                 kernel_initializer='glorot_normal')(hidden)
    image = tf.keras.layers.Reshape(target_shape=input_shape,
                                    name='out_generator')(flat)

    return tf.keras.Model(inputs=inputs, outputs=image, name='Generator')


def build_graph(input_shape, mode, verbose):
    """
    Assemble Efficient-CapsNet together with its reconstruction regularizer.

    Parameters
    ----------
    input_shape: list
        network input shape
    mode: str
        working mode ('train', 'test' & 'play')
    verbose: bool
        when True, print the sub-model summaries

    Raises
    ------
    RuntimeError
        if `mode` is not one of 'train', 'test' or 'play'
    """
    inputs = tf.keras.Input(input_shape)
    y_true = tf.keras.layers.Input(shape=(10,))
    noise = tf.keras.layers.Input(shape=(10, 16))

    efficient_capsnet = efficient_capsnet_graph(input_shape)
    if verbose:
        efficient_capsnet.summary()
        print("\n\n")

    digit_caps, digit_caps_len = efficient_capsnet(inputs)
    # Perturbed capsules are only consumed by the 'play' head.
    noised_digitcaps = tf.keras.layers.Add()([digit_caps, noise])

    masked_by_y = Mask()([digit_caps, y_true])           # mask with ground truth (train)
    masked = Mask()(digit_caps)                          # mask with longest capsule (test)
    masked_noised_y = Mask()([noised_digitcaps, y_true]) # mask perturbed capsules (play)

    generator = generator_graph(input_shape)
    if verbose:
        generator.summary()
        print("\n\n")

    x_gen_train = generator(masked_by_y)
    x_gen_eval = generator(masked)
    x_gen_play = generator(masked_noised_y)

    # Dispatch table: mode -> (model inputs, model outputs).
    heads = {
        'train': ([inputs, y_true], [digit_caps_len, x_gen_train]),
        'test': (inputs, [digit_caps_len, x_gen_eval]),
        'play': ([inputs, y_true, noise], [digit_caps_len, x_gen_play]),
    }
    if mode not in heads:
        raise RuntimeError('mode not recognized')
    graph_inputs, graph_outputs = heads[mode]
    # NOTE(review): model name keeps the historical spelling used by the repo.
    return tf.keras.models.Model(graph_inputs, graph_outputs,
                                 name='Efficinet_CapsNet_Generator')
def efficient_capsnet_graph(input_shape):
    """
    Build the Efficient-CapsNet backbone for MultiMNIST.

    Parameters
    ----------
    input_shape: list
        network input shape
    """
    inputs = tf.keras.Input(input_shape)

    # Convolutional feature extractor: (filters, kernel_size, strides) per
    # stage, each convolution followed by batch normalization.
    conv_specs = ((32, 5, 1), (64, 3, 1), (64, 3, 2), (128, 3, 2))
    features = inputs
    for filters, kernel, stride in conv_specs:
        features = tf.keras.layers.Conv2D(filters, kernel, stride,
                                          activation='relu', padding='valid',
                                          kernel_initializer='he_normal')(features)
        features = tf.keras.layers.BatchNormalization()(features)

    primary = PrimaryCaps(128, 5, 16, 8, 2)(features)
    digit_caps = FCCaps(10, 16)(primary)
    digit_caps_len = Length(name='length_capsnet_output')(digit_caps)

    return tf.keras.Model(inputs=inputs,
                          outputs=[digit_caps, digit_caps_len],
                          name='Efficient_CapsNet')


def generator_graph(input_shape):
    """
    Build the fully-connected reconstruction generator.

    Parameters
    ----------
    input_shape: list
        network input shape (shape of the reconstructed image)
    """
    # Input is the flattened, masked capsule tensor (10 capsules x 16 dims).
    inputs = tf.keras.Input(16 * 10)

    hidden = tf.keras.layers.Dense(512, activation='relu',
                                   kernel_initializer='he_normal')(inputs)
    hidden = tf.keras.layers.Dense(1024, activation='relu',
                                   kernel_initializer='he_normal')(hidden)
    flat = tf.keras.layers.Dense(np.prod(input_shape), activation='sigmoid',
                                 kernel_initializer='glorot_normal')(hidden)
    image = tf.keras.layers.Reshape(target_shape=input_shape,
                                    name='out_generator')(flat)

    return tf.keras.Model(inputs=inputs, outputs=image, name='Generator')


def build_graph(input_shape, mode, verbose):
    """
    Assemble Efficient-CapsNet for MultiMNIST with its reconstruction
    regularizer. Two overlapping digits are supervised at once, so the mask
    layer runs in double-mask mode and the generator reconstructs both.

    Parameters
    ----------
    input_shape: list
        network input shape
    mode: str
        working mode ('train' & 'test')
    verbose: bool
        when True, print the sub-model summaries

    Raises
    ------
    RuntimeError
        if `mode` is not 'train' or 'test'
    """
    inputs = tf.keras.Input(input_shape)
    y_true1 = tf.keras.layers.Input(shape=(10,))
    y_true2 = tf.keras.layers.Input(shape=(10,))

    efficient_capsnet = efficient_capsnet_graph(input_shape)
    if verbose:
        efficient_capsnet.summary()
        print("\n\n")

    digit_caps, digit_caps_len = efficient_capsnet(inputs)

    # Ground-truth masking (train) and longest-capsule masking (test),
    # one masked tensor per overlaid digit.
    masked_by_y1, masked_by_y2 = Mask()([digit_caps, y_true1, y_true2],
                                        double_mask=True)
    masked1, masked2 = Mask()(digit_caps, double_mask=True)

    generator = generator_graph(input_shape)
    if verbose:
        generator.summary()
        print("\n\n")

    x_gen_train1, x_gen_train2 = generator(masked_by_y1), generator(masked_by_y2)
    x_gen_eval1, x_gen_eval2 = generator(masked1), generator(masked2)

    # Dispatch table: mode -> (model inputs, model outputs).
    heads = {
        'train': ([inputs, y_true1, y_true2],
                  [digit_caps_len, x_gen_train1, x_gen_train2]),
        'test': (inputs, [digit_caps_len, x_gen_eval1, x_gen_eval2]),
    }
    if mode not in heads:
        raise RuntimeError('mode not recognized')
    graph_inputs, graph_outputs = heads[mode]
    # NOTE(review): model name keeps the historical spelling used by the repo.
    return tf.keras.models.Model(graph_inputs, graph_outputs,
                                 name='Efficinet_CapsNet_Generator')
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
from utils.layers import PrimaryCaps, FCCaps, Length, Mask
import tensorflow_addons as tfa


def _conv_block(x, filters, kernel, stride=1):
    """Conv2D (no activation) -> LeakyReLU -> InstanceNormalization stage."""
    x = tf.keras.layers.Conv2D(filters, kernel, stride, activation=None,
                               padding='valid',
                               kernel_initializer='he_normal')(x)
    x = tf.keras.layers.LeakyReLU()(x)
    return tfa.layers.InstanceNormalization(axis=3,
                                            center=True,
                                            scale=True,
                                            beta_initializer="random_uniform",
                                            gamma_initializer="random_uniform")(x)


def efficient_capsnet_graph(input_shape):
    """
    Build the Efficient-CapsNet backbone for smallNORB.

    Parameters
    ----------
    input_shape: list
        network input shape
    """
    inputs = tf.keras.Input(input_shape)

    features = _conv_block(inputs, 32, 7, 2)
    features = _conv_block(features, 64, 3)
    features = _conv_block(features, 64, 3)
    features = _conv_block(features, 128, 3, 2)

    primary = PrimaryCaps(128, 8, 16, 8)(features)  # there could be an error

    digit_caps = FCCaps(5, 16)(primary)
    digit_caps_len = Length(name='length_capsnet_output')(digit_caps)

    return tf.keras.Model(inputs=inputs,
                          outputs=[digit_caps, digit_caps_len],
                          name='Efficient_CapsNet')


def generator_graph(input_shape):
    """
    Build the convolutional reconstruction generator for smallNORB.

    Parameters
    ----------
    input_shape: list
        network input shape
    """
    # Input is the flattened, masked capsule tensor (5 capsules x 16 dims).
    inputs = tf.keras.Input(16 * 5)

    x = tf.keras.layers.Dense(64)(inputs)
    x = tf.keras.layers.Reshape(target_shape=(8, 8, 1))(x)
    # Three upsample+conv stages grow the 8x8 seed toward image resolution.
    for filters in (64, 128, 128):
        x = tf.keras.layers.UpSampling2D(size=(2, 2),
                                         interpolation='bilinear')(x)
        x = tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.leaky_relu)(x)
    # Two-channel sigmoid output (smallNORB images come in stereo pairs —
    # presumably why channels=2; confirm against the dataset pipeline).
    x = tf.keras.layers.Conv2D(filters=2, kernel_size=(3, 3),
                               padding="valid", activation=tf.nn.sigmoid)(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name='Generator')


def build_graph(input_shape, mode, verbose):
    """
    Assemble Efficient-CapsNet for smallNORB with its reconstruction
    regularizer.

    Parameters
    ----------
    input_shape: list
        network input shape
    mode: str
        working mode ('train' & 'test')
    verbose: bool
        when True, print the sub-model summaries

    Raises
    ------
    RuntimeError
        if `mode` is not 'train' or 'test'
    """
    inputs = tf.keras.Input(input_shape)
    y_true = tf.keras.layers.Input(shape=(5,))

    efficient_capsnet = efficient_capsnet_graph(input_shape)
    if verbose:
        efficient_capsnet.summary()
        print("\n\n")

    digit_caps, digit_caps_len = efficient_capsnet(inputs)

    masked_by_y = Mask()([digit_caps, y_true])  # mask with ground truth (train)
    masked = Mask()(digit_caps)                 # mask with longest capsule (test)

    generator = generator_graph(input_shape)
    if verbose:
        generator.summary()
        print("\n\n")

    x_gen_train = generator(masked_by_y)
    x_gen_eval = generator(masked)

    if mode == 'train':
        return tf.keras.models.Model([inputs, y_true],
                                     [digit_caps_len, x_gen_train])
    elif mode == 'test':
        return tf.keras.models.Model(inputs, [digit_caps_len, x_gen_eval])
    else:
        raise RuntimeError('mode not recognized')
# ==============================================================================

import numpy as np
import tensorflow as tf
from utils.layers import PrimaryCaps, FCCaps, Length
from utils.tools import get_callbacks, marginLoss, multiAccuracy
from utils.dataset import Dataset
from utils import pre_process_multimnist
from models import efficient_capsnet_graph_mnist, efficient_capsnet_graph_smallnorb, efficient_capsnet_graph_multimnist, original_capsnet_graph_mnist
import os
import json
from tqdm.notebook import tqdm


class Model(object):
    """
    A class used to share common model functions and attributes.

    ...

    Attributes
    ----------
    model_name: str
        name of the model (Ex. 'MNIST')
    mode: str
        model modality (Ex. 'test')
    config_path: str
        path configuration file
    verbose: bool

    Methods
    -------
    load_config():
        load configuration file
    load_graph_weights():
        load network weights
    predict(dataset_test):
        use the model to predict dataset_test
    evaluate(X_test, y_test):
        compute accuracy and test error with the given dataset (X_test, y_test)
    save_graph_weights():
        save model weights
    """
    def __init__(self, model_name, mode='test', config_path='config.json', verbose=True):
        self.model_name = model_name
        self.model = None          # built by the subclass's load_graph()
        self.mode = mode
        self.config_path = config_path
        self.config = None
        self.verbose = verbose
        self.load_config()

    def load_config(self):
        """
        Load config file
        """
        with open(self.config_path) as json_data_file:
            self.config = json.load(json_data_file)

    def load_graph_weights(self):
        """Load network weights from self.model_path (set by the subclass)."""
        try:
            self.model.load_weights(self.model_path)
        except Exception as e:
            # Best-effort load is kept (callers don't expect a raise), but the
            # cause is now reported instead of being silently discarded.
            print(f"[ERROR] Graph weights not found: {e}")

    def predict(self, dataset_test):
        """Run the underlying Keras model on dataset_test."""
        return self.model.predict(dataset_test)

    def evaluate(self, X_test, y_test):
        """
        Compute accuracy and test error on (X_test, y_test).

        MULTIMNIST is handled separately: test batches are generated on the
        fly and scored with multiAccuracy over both overlaid digits.
        """
        print('-'*30 + f'{self.model_name} Evaluation' + '-'*30)
        if self.model_name == "MULTIMNIST":
            dataset_test = pre_process_multimnist.generate_tf_data_test(X_test, y_test, self.config["shift_multimnist"], n_multi=self.config['n_overlay_multimnist'])
            acc = []
            for X, y in tqdm(dataset_test, total=len(X_test)):
                y_pred, X_gen1, X_gen2 = self.model.predict(X)
                acc.append(multiAccuracy(y, y_pred))
            acc = np.mean(acc)
        else:
            y_pred, X_gen = self.model.predict(X_test)
            acc = np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1)) / y_test.shape[0]
        test_error = 1 - acc
        print('Test acc:', acc)
        print(f"Test error [%]: {(test_error):.4%}")
        if self.model_name == "MULTIMNIST":
            print(f"N° misclassified images: {int(test_error*len(y_test)*self.config['n_overlay_multimnist'])} out of {len(y_test)*self.config['n_overlay_multimnist']}")
        else:
            print(f"N° misclassified images: {int(test_error*len(y_test))} out of {len(y_test)}")

    def save_graph_weights(self):
        """Save the current model weights to self.model_path."""
        self.model.save_weights(self.model_path)


class EfficientCapsNet(Model):
    """
    A class used to manage an Efficient-CapsNet model. 'model_name' and 'mode'
    define the particular architecture and modality of the generated network.

    ...

    Attributes
    ----------
    model_name: str
        name of the model (Ex. 'MNIST')
    mode: str
        model modality (Ex. 'test')
    config_path: str
        path configuration file
    custom_path: str
        custom weights path
    verbose: bool

    Methods
    -------
    load_graph():
        load the network graph given the model_name
    train(dataset, initial_epoch):
        train the constructed network with a given dataset. All train
        hyperparameters are defined in the configuration file
    """
    def __init__(self, model_name, mode='test', config_path='config.json', custom_path=None, verbose=True):
        Model.__init__(self, model_name, mode, config_path, verbose)
        if custom_path is not None:
            self.model_path = custom_path
        else:
            self.model_path = os.path.join(self.config['saved_model_dir'], f"efficient_capsnet_{self.model_name}.h5")
        # Fixed: was f"efficient_capsnet{...}" (missing underscore), which is
        # inconsistent with every other checkpoint path in this file.
        self.model_path_new_train = os.path.join(self.config['saved_model_dir'], f"efficient_capsnet_{self.model_name}_new_train.h5")
        self.tb_path = os.path.join(self.config['tb_log_save_dir'], f"efficient_capsnet_{self.model_name}")
        self.load_graph()

    def load_graph(self):
        """Build the Keras graph matching self.model_name."""
        if self.model_name == 'MNIST':
            self.model = efficient_capsnet_graph_mnist.build_graph(self.config['MNIST_INPUT_SHAPE'], self.mode, self.verbose)
        elif self.model_name == 'SMALLNORB':
            self.model = efficient_capsnet_graph_smallnorb.build_graph(self.config['SMALLNORB_INPUT_SHAPE'], self.mode, self.verbose)
        elif self.model_name == 'MULTIMNIST':
            self.model = efficient_capsnet_graph_multimnist.build_graph(self.config['MULTIMNIST_INPUT_SHAPE'], self.mode, self.verbose)

    def train(self, dataset=None, initial_epoch=0):
        """
        Train the network; hyperparameters come from the configuration file.

        Parameters
        ----------
        dataset: Dataset, optional
            pre-built dataset wrapper; constructed from config when None
        initial_epoch: int
            epoch to resume from

        Returns
        -------
        tf.keras History object from Model.fit
        """
        callbacks = get_callbacks(self.tb_path, self.model_path_new_train, self.config['lr_dec'], self.config['lr'])

        if dataset is None:  # was `== None`; identity check is the correct idiom
            dataset = Dataset(self.model_name, self.config_path)
        dataset_train, dataset_val = dataset.get_tf_data()

        if self.model_name == 'MULTIMNIST':
            # Two reconstruction heads (one per overlaid digit), each weighted
            # by half of lmd_gen so the total generator weight matches.
            self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.config['lr']),
                               loss=[marginLoss, 'mse', 'mse'],
                               loss_weights=[1., self.config['lmd_gen']/2, self.config['lmd_gen']/2],
                               metrics={'Efficient_CapsNet': multiAccuracy})
            steps = 10 * int(dataset.y_train.shape[0] / self.config['batch_size'])
        else:
            self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.config['lr']),
                               loss=[marginLoss, 'mse'],
                               loss_weights=[1., self.config['lmd_gen']],
                               metrics={'Efficient_CapsNet': 'accuracy'})
            steps = None

        print('-'*30 + f'{self.model_name} train' + '-'*30)

        history = self.model.fit(dataset_train,
                                 epochs=self.config['epochs'],  # was f'epochs': pointless f-string
                                 steps_per_epoch=steps,
                                 validation_data=dataset_val,
                                 batch_size=self.config['batch_size'],
                                 initial_epoch=initial_epoch,
                                 callbacks=callbacks)

        return history
class CapsNet(Model):
    """
    A class used to manage the original CapsNet architecture
    ("Dynamic Routing Between Capsules").

    ...

    Attributes
    ----------
    model_name: str
        name of the model (only MNIST provided)
    mode: str
        model modality (Ex. 'test')
    config_path: str
        path configuration file
    verbose: bool
    n_routing: int
        number of routing iterations

    Methods
    -------
    load_graph():
        load the network graph given the model_name
    train(dataset, initial_epoch):
        train the constructed network with a given dataset. All train
        hyperparameters are defined in the configuration file
    """
    def __init__(self, model_name, mode='test', config_path='config.json', custom_path=None, verbose=True, n_routing=3):
        # Model.__init__ already calls load_config(); the redundant second
        # call that used to follow has been removed.
        Model.__init__(self, model_name, mode, config_path, verbose)
        self.n_routing = n_routing
        if custom_path is not None:
            self.model_path = custom_path
        else:
            # BUG FIX: this previously pointed at the Efficient-CapsNet
            # checkpoint (f"efficient_capsnet_{...}.h5"); the original-CapsNet
            # weights ship as original_capsnet_<name>.h5 (see bin/), matching
            # the original_capsnet_* prefix used by the paths below.
            self.model_path = os.path.join(self.config['saved_model_dir'], f"original_capsnet_{self.model_name}.h5")
        self.model_path_new_train = os.path.join(self.config['saved_model_dir'], f"original_capsnet_{self.model_name}_new_train.h5")
        self.tb_path = os.path.join(self.config['tb_log_save_dir'], f"original_capsnet_{self.model_name}")
        self.load_graph()

    def load_graph(self):
        """Build the original CapsNet Keras graph (MNIST only)."""
        self.model = original_capsnet_graph_mnist.build_graph(self.config['MNIST_INPUT_SHAPE'], self.mode, self.n_routing, self.verbose)

    def train(self, dataset=None, initial_epoch=0):
        """
        Train the network; hyperparameters come from the configuration file.

        Parameters
        ----------
        dataset: Dataset, optional
            pre-built dataset wrapper; constructed from config when None
        initial_epoch: int
            epoch to resume from

        Returns
        -------
        tf.keras History object from Model.fit
        """
        callbacks = get_callbacks(self.tb_path, self.model_path_new_train, self.config['lr_dec'], self.config['lr'])

        if dataset is None:  # was `== None`; identity check is the correct idiom
            dataset = Dataset(self.model_name, self.config_path)
        dataset_train, dataset_val = dataset.get_tf_data()

        self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.config['lr']),
                           loss=[marginLoss, 'mse'],
                           loss_weights=[1., self.config['lmd_gen']],
                           metrics={'Original_CapsNet': 'accuracy'})

        print('-'*30 + f'{self.model_name} train' + '-'*30)

        history = self.model.fit(dataset_train,
                                 epochs=self.config['epochs'],
                                 validation_data=dataset_val,
                                 batch_size=self.config['batch_size'],
                                 initial_epoch=initial_epoch,
                                 callbacks=callbacks)

        return history
# Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
from utils.layers_hinton import PrimaryCaps, DigitCaps, Length, Mask
import tensorflow_addons as tfa


def capsnet_graph(input_shape, routing):
    """
    Original CapsNet graph architecture described in 'Dynamic routing between capsules'.

    Parameters
    ----------
    input_shape: list
        network input shape
    routing: int
        number of routing iterations

    Returns
    -------
    tf.keras.Model mapping an image batch to [primary, digit_caps, digit_caps_len].
    """
    inputs = tf.keras.Input(input_shape)

    x = tf.keras.layers.Conv2D(256, 9, activation="relu")(inputs)
    primary = PrimaryCaps(C=32, L=8, k=9, s=2)(x)
    digit_caps = DigitCaps(10, 16, routing=routing)(primary)
    digit_caps_len = Length(name='capsnet_output_len')(digit_caps)
    # flatten the (H, W, C) grid of primary capsules into a single capsule axis
    pr_shape = primary.shape
    primary = tf.reshape(primary, (-1, pr_shape[1]*pr_shape[2]*pr_shape[3], pr_shape[-1]))

    return tf.keras.Model(inputs=inputs, outputs=[primary, digit_caps, digit_caps_len], name='Original_CapsNet')


def generator_graph(input_shape):
    """
    Generator (reconstruction) graph architecture.

    Parameters
    ----------
    input_shape: list
        shape of the image to reconstruct (the generator INPUT is the fixed
        16*10 flattened masked digit-caps vector, not this shape)
    """
    inputs = tf.keras.Input(16*10)

    x = tf.keras.layers.Dense(512, activation='relu')(inputs)
    x = tf.keras.layers.Dense(1024, activation='relu')(x)
    x = tf.keras.layers.Dense(np.prod(input_shape), activation='sigmoid')(x)
    x = tf.keras.layers.Reshape(target_shape=input_shape, name='out_generator')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name='Generator')


def build_graph(input_shape, mode, n_routing, verbose):
    """
    Original CapsNet graph architecture with reconstruction regularizer.
    The network can be initialized with different modalities.

    Parameters
    ----------
    input_shape: list
        network input shape
    mode: str
        working mode ('train', 'test' or 'play')
    n_routing: int
        number of routing iterations
    verbose: bool

    Raises
    ------
    RuntimeError
        if mode is not one of 'train', 'test', 'play'.
    """
    inputs = tf.keras.Input(input_shape)
    # FIX (consistency/idiom): use tf.keras.Input everywhere and a proper
    # 1-tuple for the label shape; behavior is unchanged.
    y_true = tf.keras.Input(shape=(10,))
    noise = tf.keras.Input(shape=(10, 16))

    capsnet = capsnet_graph(input_shape, routing=n_routing)
    primary, digit_caps, digit_caps_len = capsnet(inputs)
    noised_digitcaps = tf.keras.layers.Add()([digit_caps, noise])  # only used when mode is 'play'

    if verbose:
        capsnet.summary()
        print("\n\n")


    masked_by_y = Mask()([digit_caps, y_true])  # true label masks the capsule output (training)
    masked = Mask()(digit_caps)                 # mask using the capsule with maximal length (prediction)
    masked_noised_y = Mask()([noised_digitcaps, y_true])


    generator = generator_graph(input_shape)

    if verbose:
        generator.summary()
        print("\n\n")

    x_gen_train = generator(masked_by_y)
    x_gen_eval = generator(masked)
    x_gen_play = generator(masked_noised_y)


    if mode == 'train':
        return tf.keras.models.Model([inputs, y_true], [digit_caps_len, x_gen_train], name='CapsNet_Generator')
    elif mode == 'test':
        return tf.keras.models.Model(inputs, [digit_caps_len, x_gen_eval], name='CapsNet_Generator')
    elif mode == 'play':
        return tf.keras.models.Model([inputs, y_true, noise], [digit_caps_len, x_gen_play], name='CapsNet_Generator')
    else:
        raise RuntimeError('mode not recognized')
" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2021-01-25T22:03:01.755485Z", 18 | "start_time": "2021-01-25T22:03:01.745674Z" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "ExecuteTime": { 32 | "end_time": "2021-01-25T22:03:03.310938Z", 33 | "start_time": "2021-01-25T22:03:02.045618Z" 34 | } 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import tensorflow as tf\n", 39 | "from utils import Dataset, plotImages, plotWrongImages\n", 40 | "from models import CapsNet" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "ExecuteTime": { 48 | "end_time": "2021-01-25T22:03:03.373518Z", 49 | "start_time": "2021-01-25T22:03:03.328301Z" 50 | } 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 55 | "tf.config.experimental.set_visible_devices(gpus[0], 'GPU')\n", 56 | "tf.config.experimental.set_memory_growth(gpus[0], True)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "ExecuteTime": { 64 | "end_time": "2021-01-25T22:03:04.829181Z", 65 | "start_time": "2021-01-25T22:03:04.805968Z" 66 | } 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "# some parameters\n", 71 | "model_name = 'MNIST' # only MNIST is available\n", 72 | "n_routing = 3" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# 1.0 Import the Dataset" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "ExecuteTime": { 87 | "end_time": "2021-01-25T22:03:08.790666Z", 88 | "start_time": "2021-01-25T22:03:08.483172Z" 89 | } 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "dataset = Dataset(model_name, 
config_path='config.json') # only MNIST" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## 1.1 Visualize imported dataset" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "ExecuteTime": { 108 | "end_time": "2021-01-25T22:03:11.531218Z", 109 | "start_time": "2021-01-25T22:03:10.193090Z" 110 | } 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "n_images = 20 # number of images to be plotted\n", 115 | "plotImages(dataset.X_test[:n_images,...,0], dataset.y_test[:n_images], n_images, dataset.class_names)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "# 2.0 Load the Model" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": { 129 | "ExecuteTime": { 130 | "end_time": "2021-01-25T22:03:15.707231Z", 131 | "start_time": "2021-01-25T22:03:14.733048Z" 132 | } 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "model_test = CapsNet(model_name, mode='test', verbose=True, n_routing=n_routing)\n", 137 | "\n", 138 | "model_test.load_graph_weights() # load graph weights (bin folder)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "# 3.0 Test the Model" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "ExecuteTime": { 153 | "end_time": "2021-01-25T22:03:25.415576Z", 154 | "start_time": "2021-01-25T22:03:18.826201Z" 155 | } 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "model_test.evaluate(dataset.X_test, dataset.y_test) # if \"smallnorb\" use X_test_patch\n", 160 | "y_pred = model_test.predict(dataset.X_test)[0] # if \"smallnorb\" use X_test_patch" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## 3.1 Plot misclassified images" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | 
"execution_count": null, 173 | "metadata": { 174 | "ExecuteTime": { 175 | "end_time": "2021-01-25T22:03:26.222073Z", 176 | "start_time": "2021-01-25T22:03:25.437213Z" 177 | } 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "n_images = 20\n", 182 | "plotWrongImages(dataset.X_test, dataset.y_test, y_pred, # if \"smallnorb\" use X_test_patch\n", 183 | " n_images, dataset.class_names)" 184 | ] 185 | } 186 | ], 187 | "metadata": { 188 | "kernelspec": { 189 | "display_name": "Python 3", 190 | "language": "python", 191 | "name": "python3" 192 | }, 193 | "language_info": { 194 | "codemirror_mode": { 195 | "name": "ipython", 196 | "version": 3 197 | }, 198 | "file_extension": ".py", 199 | "mimetype": "text/x-python", 200 | "name": "python", 201 | "nbconvert_exporter": "python", 202 | "pygments_lexer": "ipython3", 203 | "version": "3.6.9" 204 | }, 205 | "toc": { 206 | "base_numbering": 1, 207 | "nav_menu": {}, 208 | "number_sections": false, 209 | "sideBar": true, 210 | "skip_h1_title": false, 211 | "title_cell": "Table of Contents", 212 | "title_sidebar": "Contents", 213 | "toc_cell": false, 214 | "toc_position": {}, 215 | "toc_section_display": true, 216 | "toc_window_display": false 217 | }, 218 | "varInspector": { 219 | "cols": { 220 | "lenName": 16, 221 | "lenType": 16, 222 | "lenVar": 40 223 | }, 224 | "kernels_config": { 225 | "python": { 226 | "delete_cmd_postfix": "", 227 | "delete_cmd_prefix": "del ", 228 | "library": "var_list.py", 229 | "varRefreshCmd": "print(var_dic_list())" 230 | }, 231 | "r": { 232 | "delete_cmd_postfix": ") ", 233 | "delete_cmd_prefix": "rm(", 234 | "library": "var_list.r", 235 | "varRefreshCmd": "cat(var_dic_list()) " 236 | } 237 | }, 238 | "types_to_exclude": [ 239 | "module", 240 | "function", 241 | "builtin_function_or_method", 242 | "instance", 243 | "_Feature" 244 | ], 245 | "window_display": false 246 | } 247 | }, 248 | "nbformat": 4, 249 | "nbformat_minor": 4 250 | } 251 | 
-------------------------------------------------------------------------------- /original_capsnet_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Original CapsNet Model Train\n", 8 | "\n", 9 | "In this notebook we provide a simple interface to train the original CapsNet model described in \"Dynamic routinig between capsules\". The model is copycat of the original Sara's repository (https://github.com/Sarasra/models/tree/master/research/capsules).
\n", 10 | "However, if you really reach 99.75, you've got to buy me a drink :)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "ExecuteTime": { 18 | "end_time": "2021-02-16T09:09:04.587350Z", 19 | "start_time": "2021-02-16T09:09:04.570402Z" 20 | } 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "%load_ext autoreload\n", 25 | "%autoreload 2" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "ExecuteTime": { 33 | "end_time": "2021-02-16T09:09:05.887355Z", 34 | "start_time": "2021-02-16T09:09:04.588441Z" 35 | } 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "import tensorflow as tf\n", 40 | "from utils import Dataset, plotImages, plotWrongImages\n", 41 | "from models import CapsNet" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "ExecuteTime": { 49 | "end_time": "2021-02-16T09:09:05.948466Z", 50 | "start_time": "2021-02-16T09:09:05.888700Z" 51 | } 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 56 | "tf.config.experimental.set_visible_devices(gpus[0], 'GPU')\n", 57 | "tf.config.experimental.set_memory_growth(gpus[0], True)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2021-02-16T09:09:05.969183Z", 66 | "start_time": "2021-02-16T09:09:05.949736Z" 67 | } 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# some parameters\n", 72 | "model_name = 'MNIST' # only MNIST is available\n", 73 | "n_routing = 3" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "# 1.0 Import the Dataset" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "ExecuteTime": { 88 | "end_time": "2021-02-16T09:09:06.264512Z", 89 | "start_time": "2021-02-16T09:09:05.970183Z" 90 | } 91 | }, 92 
| "outputs": [], 93 | "source": [ 94 | "dataset = Dataset(model_name, config_path='config.json') # only MNIST" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "## 1.1 Visualize imported dataset" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": { 108 | "ExecuteTime": { 109 | "end_time": "2021-02-16T09:09:07.594324Z", 110 | "start_time": "2021-02-16T09:09:06.265453Z" 111 | } 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "n_images = 20 # number of images to be plotted\n", 116 | "plotImages(dataset.X_test[:n_images,...,0], dataset.y_test[:n_images], n_images, dataset.class_names)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "# 2.0 Load the Model" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "ExecuteTime": { 131 | "end_time": "2021-02-16T09:09:13.228879Z", 132 | "start_time": "2021-02-16T09:09:12.391672Z" 133 | } 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "model_train = CapsNet(model_name, mode='train', verbose=True, n_routing=n_routing)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "# 3.0 Train the Model" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": { 151 | "ExecuteTime": { 152 | "end_time": "2021-02-16T09:09:29.014172Z", 153 | "start_time": "2021-02-16T09:09:14.064376Z" 154 | } 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "history = model_train.train(dataset, initial_epoch=0)" 159 | ] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "Python 3", 165 | "language": "python", 166 | "name": "python3" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | 
"name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.6.9" 179 | }, 180 | "toc": { 181 | "base_numbering": 1, 182 | "nav_menu": {}, 183 | "number_sections": false, 184 | "sideBar": true, 185 | "skip_h1_title": false, 186 | "title_cell": "Table of Contents", 187 | "title_sidebar": "Contents", 188 | "toc_cell": false, 189 | "toc_position": {}, 190 | "toc_section_display": true, 191 | "toc_window_display": false 192 | }, 193 | "varInspector": { 194 | "cols": { 195 | "lenName": 16, 196 | "lenType": 16, 197 | "lenVar": 40 198 | }, 199 | "kernels_config": { 200 | "python": { 201 | "delete_cmd_postfix": "", 202 | "delete_cmd_prefix": "del ", 203 | "library": "var_list.py", 204 | "varRefreshCmd": "print(var_dic_list())" 205 | }, 206 | "r": { 207 | "delete_cmd_postfix": ") ", 208 | "delete_cmd_prefix": "rm(", 209 | "library": "var_list.r", 210 | "varRefreshCmd": "cat(var_dic_list()) " 211 | } 212 | }, 213 | "types_to_exclude": [ 214 | "module", 215 | "function", 216 | "builtin_function_or_method", 217 | "instance", 218 | "_Feature" 219 | ], 220 | "window_display": false 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 4 225 | } 226 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tensorflow-addons 4 | opencv-python 5 | tqdm 6 | tensorflow 7 | matplotlib 8 | pytest 9 | jupyter 10 | tensorflow-datasets 11 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.layers import * 2 | from utils.visualization import AffineVisualizer, plotImages, plotWrongImages, plotHistory 3 | from utils.dataset import Dataset 4 | -------------------------------------------------------------------------------- /utils/dataset.py: 
# Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import os
from utils import pre_process_mnist, pre_process_multimnist, pre_process_smallnorb
import json


class Dataset(object):
    """
    A class used to share common dataset functions and attributes.

    ...

    Attributes
    ----------
    model_name: str
        name of the model ('MNIST', 'SMALLNORB' or 'MULTIMNIST')
    config_path: str
        path of the configuration file

    Methods
    -------
    load_config():
        load the configuration file
    get_dataset():
        load the dataset defined by model_name and pre-process it
    get_tf_data():
        get a tf.data.Dataset object of the loaded dataset.
    """
    def __init__(self, model_name, config_path='config.json'):
        self.model_name = model_name
        self.config_path = config_path
        self.config = None
        self.X_train = None
        self.y_train = None
        self.X_test = None
        self.y_test = None
        self.class_names = None
        self.X_test_patch = None  # patched test set, populated only for SMALLNORB
        self.load_config()
        self.get_dataset()


    def load_config(self):
        """
        Load the JSON config file into self.config.
        """
        with open(self.config_path) as json_data_file:
            self.config = json.load(json_data_file)


    def get_dataset(self):
        """
        Load and pre-process the dataset selected by self.model_name.

        Raises
        ------
        ValueError
            if self.model_name is not a supported dataset name.
        """
        if self.model_name == 'MNIST':
            (self.X_train, self.y_train), (self.X_test, self.y_test) = tf.keras.datasets.mnist.load_data(path=self.config['mnist_path'])
            # prepare the data
            self.X_train, self.y_train = pre_process_mnist.pre_process(self.X_train, self.y_train)
            self.X_test, self.y_test = pre_process_mnist.pre_process(self.X_test, self.y_test)
            self.class_names = list(range(10))
        elif self.model_name == 'SMALLNORB':
            # import the dataset
            (ds_train, ds_test), ds_info = tfds.load(
                'smallnorb',
                split=['train', 'test'],
                shuffle_files=True,
                as_supervised=False,
                with_info=True)
            self.X_train, self.y_train = pre_process_smallnorb.pre_process(ds_train)
            self.X_test, self.y_test = pre_process_smallnorb.pre_process(ds_test)

            self.X_train, self.y_train = pre_process_smallnorb.standardize(self.X_train, self.y_train)
            self.X_train, self.y_train = pre_process_smallnorb.rescale(self.X_train, self.y_train, self.config)
            self.X_test, self.y_test = pre_process_smallnorb.standardize(self.X_test, self.y_test)
            self.X_test, self.y_test = pre_process_smallnorb.rescale(self.X_test, self.y_test, self.config)
            self.X_test_patch, self.y_test = pre_process_smallnorb.test_patches(self.X_test, self.y_test, self.config)
            self.class_names = ds_info.features['label_category'].names
        elif self.model_name == 'MULTIMNIST':
            (self.X_train, self.y_train), (self.X_test, self.y_test) = tf.keras.datasets.mnist.load_data(path=self.config['mnist_path'])
            # prepare the data
            self.X_train = pre_process_multimnist.pad_dataset(self.X_train, self.config["pad_multimnist"])
            self.X_test = pre_process_multimnist.pad_dataset(self.X_test, self.config["pad_multimnist"])
            self.X_train, self.y_train = pre_process_multimnist.pre_process(self.X_train, self.y_train)
            self.X_test, self.y_test = pre_process_multimnist.pre_process(self.X_test, self.y_test)
            self.class_names = list(range(10))
        else:
            # BUG FIX: an unknown model_name previously fell through silently,
            # leaving every attribute None and failing much later.
            raise ValueError(f"unrecognized model_name '{self.model_name}' "
                             "(expected 'MNIST', 'SMALLNORB' or 'MULTIMNIST')")
        print("[INFO] Dataset loaded!")  # hoisted: identical message was duplicated per branch


    def get_tf_data(self):
        """
        Return (dataset_train, dataset_test) tf.data pipelines for the loaded dataset.

        Raises
        ------
        ValueError
            if self.model_name is not a supported dataset name.
        """
        if self.model_name == 'MNIST':
            dataset_train, dataset_test = pre_process_mnist.generate_tf_data(self.X_train, self.y_train, self.X_test, self.y_test, self.config['batch_size'])
        elif self.model_name == 'SMALLNORB':
            dataset_train, dataset_test = pre_process_smallnorb.generate_tf_data(self.X_train, self.y_train, self.X_test_patch, self.y_test, self.config['batch_size'])
        elif self.model_name == 'MULTIMNIST':
            dataset_train, dataset_test = pre_process_multimnist.generate_tf_data(self.X_train, self.y_train, self.X_test, self.y_test, self.config['batch_size'], self.config["shift_multimnist"])
        else:
            # BUG FIX: previously raised an opaque UnboundLocalError here.
            raise ValueError(f"unrecognized model_name '{self.model_name}' "
                             "(expected 'MNIST', 'SMALLNORB' or 'MULTIMNIST')")

        return dataset_train, dataset_test
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf


class SquashHinton(tf.keras.layers.Layer):
    """
    Squash activation function presented in 'Dynamic routing between capsules'.

    ...

    Attributes
    ----------
    eps: float
        fuzz factor used in numeric expressions to avoid division by zero

    Methods
    -------
    call(s)
        compute the activation from input capsules

    """

    def __init__(self, eps=10e-21, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def call(self, s):
        # v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
        n = tf.norm(s, axis=-1, keepdims=True)
        return tf.multiply(n**2/(1+n**2)/(n+self.eps), s)

    def get_config(self):
        # NOTE(review): eps is not serialized, so a reloaded layer uses the
        # default — harmless as long as eps is never customized; verify.
        base_config = super().get_config()
        return {**base_config}

    def compute_output_shape(self, input_shape):
        return input_shape



class Squash(tf.keras.layers.Layer):
    """
    Squash activation used in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'.

    ...

    Attributes
    ----------
    eps: float
        fuzz factor used in numeric expressions to avoid division by zero

    Methods
    -------
    call(s)
        compute the activation from input capsules
    """

    def __init__(self, eps=10e-21, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def call(self, s):
        # v = (1 - 1/e^||s||) * (s / ||s||)
        n = tf.norm(s, axis=-1, keepdims=True)
        return (1 - 1/(tf.math.exp(n)+self.eps))*(s/(n+self.eps))

    def get_config(self):
        # NOTE(review): eps is not serialized, see SquashHinton.get_config.
        base_config = super().get_config()
        return {**base_config}

    def compute_output_shape(self, input_shape):
        return input_shape




class PrimaryCaps(tf.keras.layers.Layer):
    """
    Create a primary capsule layer with the methodology described in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'.
    Properties of each capsule s_n are extracted using a 2D depthwise convolution.

    ...

    Attributes
    ----------
    F: int
        depthwise conv number of features
    K: int
        depthwise conv kernel dimension
    N: int
        number of primary capsules
    D: int
        primary capsules dimension (number of properties)
    s: int
        depthwise conv strides
    Methods
    -------
    call(inputs)
        compute the primary capsule layer
    """
    def __init__(self, F, K, N, D, s=1, **kwargs):
        super(PrimaryCaps, self).__init__(**kwargs)
        self.F = F
        self.K = K
        self.N = N
        self.D = D
        self.s = s

    def build(self, input_shape):
        # groups=F makes this a depthwise convolution (one filter per channel)
        self.DW_Conv2D = tf.keras.layers.Conv2D(self.F, self.K, self.s,
                                                activation='linear', groups=self.F, padding='valid')

        self.built = True

    def call(self, inputs):
        x = self.DW_Conv2D(inputs)
        x = tf.keras.layers.Reshape((self.N, self.D))(x)
        x = Squash()(x)

        return x



class FCCaps(tf.keras.layers.Layer):
    """
    Fully-connected caps layer. It exploits the routing mechanism, explained in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing',
    to create a parent layer of capsules.

    ...

    Attributes
    ----------
    N: int
        number of output (parent) capsules
    D: int
        output capsules dimension (number of properties)
    kernel_initializer: str
        matrix W initialization strategy

    Methods
    -------
    call(inputs)
        compute the fully-connected capsule layer
    """
    def __init__(self, N, D, kernel_initializer='he_normal', **kwargs):
        super(FCCaps, self).__init__(**kwargs)
        self.N = N
        self.D = D
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)

    def build(self, input_shape):
        input_N = input_shape[-2]
        input_D = input_shape[-1]

        self.W = self.add_weight(shape=[self.N, input_N, input_D, self.D], initializer=self.kernel_initializer, name='W')
        self.b = self.add_weight(shape=[self.N, input_N, 1], initializer=tf.zeros_initializer(), name='b')
        self.built = True

    def call(self, inputs, training=None):

        u = tf.einsum('...ji,kjiz->...kjz', inputs, self.W)    # u shape=(None,N,H*W*input_N,D)

        # self-attention routing: agreement scores, scaled and softmaxed over parents
        c = tf.einsum('...ij,...kj->...i', u, u)[..., None]    # c shape=(None,N,H*W*input_N,1) -> (None,j,i,1)
        c = c/tf.sqrt(tf.cast(self.D, tf.float32))
        c = tf.nn.softmax(c, axis=1)                           # c shape=(None,N,H*W*input_N,1) -> (None,j,i,1)
        c = c + self.b
        s = tf.reduce_sum(tf.multiply(u, c), axis=-2)          # s shape=(None,N,D)
        v = Squash()(s)                                        # v shape=(None,N,D)

        return v

    def compute_output_shape(self, input_shape):
        # BUG FIX: this previously returned (None, self.C, self.L), but FCCaps
        # has no C/L attributes (those belong to the Hinton layers) — it would
        # raise AttributeError whenever Keras needed the static output shape.
        return (None, self.N, self.D)

    def get_config(self):
        config = {
            'N': self.N,
            'D': self.D
        }
        base_config = super(FCCaps, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))



class Length(tf.keras.layers.Layer):
    """
    Compute the length of each capsule n of a layer l.
    ...

    Methods
    -------
    call(inputs)
        compute the length of each capsule
    """

    def call(self, inputs, **kwargs):
        """
        Compute the length of each capsule

        Parameters
        ----------
        inputs: tensor
            tensor with shape [None, num_capsules (N), dim_capsules (D)]
        """
        return tf.sqrt(tf.reduce_sum(tf.square(inputs), - 1) + tf.keras.backend.epsilon())

    def compute_output_shape(self, input_shape):
        return input_shape[:-1]

    def get_config(self):
        config = super(Length, self).get_config()
        return config



class Mask(tf.keras.layers.Layer):
    """
    Mask operation described in 'Dynamic routing between capsules'.

    ...

    Methods
    -------
    call(inputs, double_mask)
        mask a capsule layer
        set double_mask for multimnist dataset
    """
    def call(self, inputs, double_mask=None, **kwargs):
        if type(inputs) is list:
            # masks provided explicitly (training: true labels)
            if double_mask:
                inputs, mask1, mask2 = inputs
            else:
                inputs, mask = inputs
        else:
            # no labels: mask with the longest capsule(s) (prediction)
            x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            if double_mask:
                mask1 = tf.keras.backend.one_hot(tf.argsort(x, direction='DESCENDING', axis=-1)[..., 0], num_classes=x.get_shape().as_list()[1])
                mask2 = tf.keras.backend.one_hot(tf.argsort(x, direction='DESCENDING', axis=-1)[..., 1], num_classes=x.get_shape().as_list()[1])
            else:
                mask = tf.keras.backend.one_hot(indices=tf.argmax(x, 1), num_classes=x.get_shape().as_list()[1])

        if double_mask:
            masked1 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask1, -1))
            masked2 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask2, -1))
            return masked1, masked2
        else:
            masked = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask, -1))
            return masked

    def compute_output_shape(self, input_shape):
        if type(input_shape[0]) is tuple:
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:  # generation step
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        config = super(Mask, self).get_config()
        return config


# ==============================================================================
# utils/layers_hinton.py
# ==============================================================================
# Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf



def squash(s):
    """
    Squash activation function presented in 'Dynamic routing between capsules'.
    ...

    Parameters
    ----------
    s: tensor
        input tensor
    """
    n = tf.norm(s, axis=-1, keepdims=True)
    return tf.multiply(n**2/(1+n**2)/(n + tf.keras.backend.epsilon()), s)
39 | 40 | Attributes 41 | ---------- 42 | C: int 43 | number of primary capsules 44 | L: int 45 | primary capsules dimension (number of properties) 46 | k: int 47 | kernel dimension 48 | s: int 49 | conv stride 50 | 51 | Methods 52 | ------- 53 | call(inputs) 54 | compute the primary capsule layer 55 | """ 56 | def __init__(self, C, L, k, s, **kwargs): 57 | super(PrimaryCaps, self).__init__(**kwargs) 58 | self.C = C 59 | self.L = L 60 | self.k = k 61 | self.s = s 62 | 63 | def build(self, input_shape): 64 | self.kernel = self.add_weight(shape=(self.k, self.k, input_shape[-1], self.C*self.L), initializer='glorot_uniform', name='kernel') 65 | self.biases = self.add_weight(shape=(self.C,self.L), initializer='zeros', name='biases') 66 | self.built = True 67 | 68 | def call(self, inputs): 69 | x = tf.nn.conv2d(inputs, self.kernel, self.s, 'VALID') 70 | H,W = x.shape[1:3] 71 | x = tf.keras.layers.Reshape((H, W, self.C, self.L))(x) 72 | x /= self.C 73 | x += self.biases 74 | x = squash(x) 75 | return x 76 | 77 | def compute_output_shape(self, input_shape): 78 | H,W = input_shape.shape[1:3] 79 | return (None, (H - self.k)/self.s + 1, (W - self.k)/self.s + 1, self.C, self.L) 80 | 81 | def get_config(self): 82 | config = { 83 | 'C': self.C, 84 | 'L': self.L, 85 | 'k': self.k, 86 | 's': self.s 87 | } 88 | base_config = super(PrimaryCaps, self).get_config() 89 | return dict(list(base_config.items()) + list(config.items())) 90 | 91 | class DigitCaps(tf.keras.layers.Layer): 92 | """ 93 | Create a digitcaps layer as described in 'Dynamic routing between capsules'. 94 | 95 | ... 
96 | 97 | Attributes 98 | ---------- 99 | C: int 100 | number of primary capsules 101 | L: int 102 | primary capsules dimension (number of properties) 103 | routing: int 104 | number of routing iterations 105 | kernel_initializer: 106 | matrix W kernel initializer 107 | 108 | Methods 109 | ------- 110 | call(inputs) 111 | compute the primary capsule layer 112 | """ 113 | def __init__(self, C, L, routing=None, kernel_initializer='glorot_uniform', **kwargs): 114 | super(DigitCaps, self).__init__(**kwargs) 115 | self.C = C 116 | self.L = L 117 | self.routing = routing 118 | self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) 119 | 120 | def build(self, input_shape): 121 | assert len(input_shape) >= 5, "The input Tensor should have shape=[None,H,W,input_C,input_L]" 122 | H = input_shape[-4] 123 | W = input_shape[-3] 124 | input_C = input_shape[-2] 125 | input_L = input_shape[-1] 126 | 127 | self.W = self.add_weight(shape=[H*W*input_C, input_L, self.L*self.C], initializer=self.kernel_initializer, name='W') 128 | self.biases = self.add_weight(shape=[self.C,self.L], initializer='zeros', name='biases') 129 | self.built = True 130 | 131 | def call(self, inputs): 132 | H,W,input_C,input_L = inputs.shape[1:] # input shape=(None,H,W,input_C,input_L) 133 | x = tf.reshape(inputs,(-1, H*W*input_C, input_L)) # x shape=(None,H*W*input_C,input_L) 134 | 135 | u = tf.einsum('...ji,jik->...jk', x, self.W) # u shape=(None,H*W*input_C,C*L) 136 | u = tf.reshape(u,(-1, H*W*input_C, self.C, self.L))# u shape=(None,H*W*input_C,C,L) 137 | 138 | if self.routing: 139 | #Hinton's routing 140 | b = tf.zeros(tf.shape(u)[:-1])[...,None] # b shape=(None,H*W*input_C,C,1) -> (None,i,j,1) 141 | for r in range(self.routing): 142 | c = tf.nn.softmax(b,axis=2) # c shape=(None,H*W*input_C,C,1) -> (None,i,j,1) 143 | s = tf.reduce_sum(tf.multiply(u,c),axis=1,keepdims=True) # s shape=(None,1,C,L) 144 | s += self.biases 145 | v = squash(s) # v shape=(None,1,C,L) 146 | if r < self.routing-1: 
147 | b += tf.reduce_sum(tf.multiply(u, v), axis=-1, keepdims=True) 148 | v = v[:,0,...] # v shape=(None,C,L) 149 | else: 150 | s = tf.reduce_sum(u, axis=1, keepdims=True) 151 | s += self.biases 152 | v = squash(s) 153 | v = v[:,0,...] 154 | return v 155 | 156 | def compute_output_shape(self, input_shape): 157 | return (None, self.C, self.L) 158 | 159 | def get_config(self): 160 | config = { 161 | 'C': self.C, 162 | 'L': self.L, 163 | 'routing': self.routing 164 | } 165 | base_config = super(DigitCaps, self).get_config() 166 | return dict(list(base_config.items()) + list(config.items())) 167 | 168 | class Length(tf.keras.layers.Layer): 169 | """ 170 | Compute the length of each capsule n of a layer l. 171 | ... 172 | 173 | Methods 174 | ------- 175 | call(inputs) 176 | compute the length of each capsule 177 | """ 178 | 179 | def call(self, inputs, **kwargs): 180 | """ 181 | Compute the length of each capsule 182 | 183 | Parameters 184 | ---------- 185 | inputs: tensor 186 | tensor with shape [None, num_capsules (N), dim_capsules (D)] 187 | """ 188 | return tf.sqrt(tf.reduce_sum(tf.square(inputs), - 1) + tf.keras.backend.epsilon()) 189 | 190 | def compute_output_shape(self, input_shape): 191 | return input_shape[:-1] 192 | 193 | def get_config(self): 194 | config = super(Length, self).get_config() 195 | return config 196 | 197 | class Mask(tf.keras.layers.Layer): 198 | """ 199 | Mask operation described in 'Dynamic routinig between capsules'. 200 | 201 | ... 
202 | 203 | Methods 204 | ------- 205 | call(inputs) 206 | mask a capsule layer 207 | 208 | """ 209 | def call(self, inputs, **kwargs): 210 | if type(inputs) is list: 211 | inputs, mask = inputs 212 | else: 213 | x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1)) 214 | mask = tf.keras.backend.one_hot(indices=tf.argmax(x, 1), num_classes=x.get_shape().as_list()[1]) 215 | 216 | masked = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask, -1)) 217 | return masked 218 | 219 | def compute_output_shape(self, input_shape): 220 | if type(input_shape[0]) is tuple: 221 | return tuple([None, input_shape[0][1] * input_shape[0][2]]) 222 | else: # generation step 223 | return tuple([None, input_shape[1] * input_shape[2]]) 224 | 225 | def get_config(self): 226 | config = super(Mask, self).get_config() 227 | return config 228 | 229 | -------------------------------------------------------------------------------- /utils/pre_process_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Adam Byerly & Vittorio Mazzia & Francesco Salvetti. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import os 19 | import cv2 20 | tf2 = tf.compat.v2 21 | 22 | # constants 23 | MNIST_IMG_SIZE = 28 24 | MNIST_TRAIN_IMAGE_COUNT = 60000 25 | PARALLEL_INPUT_CALLS = 16 26 | 27 | # normalize dataset 28 | def pre_process(image, label): 29 | return (image / 256)[...,None].astype('float32'), tf.keras.utils.to_categorical(label, num_classes=10) 30 | 31 | def image_shift_rand(image, label): 32 | image = tf.reshape(image, [MNIST_IMG_SIZE, MNIST_IMG_SIZE]) 33 | nonzero_x_cols = tf.cast(tf.where(tf.greater( 34 | tf.reduce_sum(image, axis=0), 0)), tf.int32) 35 | nonzero_y_rows = tf.cast(tf.where(tf.greater( 36 | tf.reduce_sum(image, axis=1), 0)), tf.int32) 37 | left_margin = tf.reduce_min(nonzero_x_cols) 38 | right_margin = MNIST_IMG_SIZE - tf.reduce_max(nonzero_x_cols) - 1 39 | top_margin = tf.reduce_min(nonzero_y_rows) 40 | bot_margin = MNIST_IMG_SIZE - tf.reduce_max(nonzero_y_rows) - 1 41 | rand_dirs = tf.random.uniform([2]) 42 | dir_idxs = tf.cast(tf.floor(rand_dirs * 2), tf.int32) 43 | rand_amts = tf.minimum(tf.abs(tf.random.normal([2], 0, .33)), .9999) 44 | x_amts = [tf.floor(-1.0 * rand_amts[0] * 45 | tf.cast(left_margin, tf.float32)), tf.floor(rand_amts[0] * 46 | tf.cast(1 + right_margin, tf.float32))] 47 | y_amts = [tf.floor(-1.0 * rand_amts[1] * 48 | tf.cast(top_margin, tf.float32)), tf.floor(rand_amts[1] * 49 | tf.cast(1 + bot_margin, tf.float32))] 50 | x_amt = tf.cast(tf.gather(x_amts, dir_idxs[1], axis=0), tf.int32) 51 | y_amt = tf.cast(tf.gather(y_amts, dir_idxs[0], axis=0), tf.int32) 52 | image = tf.reshape(image, [MNIST_IMG_SIZE * MNIST_IMG_SIZE]) 53 | image = tf.roll(image, y_amt * MNIST_IMG_SIZE, axis=0) 54 | image = tf.reshape(image, [MNIST_IMG_SIZE, MNIST_IMG_SIZE]) 55 | image = tf.transpose(image) 56 | image = tf.reshape(image, [MNIST_IMG_SIZE * MNIST_IMG_SIZE]) 57 | image = tf.roll(image, x_amt * 
MNIST_IMG_SIZE, axis=0) 58 | image = tf.reshape(image, [MNIST_IMG_SIZE, MNIST_IMG_SIZE]) 59 | image = tf.transpose(image) 60 | image = tf.reshape(image, [MNIST_IMG_SIZE, MNIST_IMG_SIZE, 1]) 61 | return image, label 62 | 63 | def image_rotate_random_py_func(image, angle): 64 | rot_mat = cv2.getRotationMatrix2D( 65 | (MNIST_IMG_SIZE/2, MNIST_IMG_SIZE/2), int(angle), 1.0) 66 | rotated = cv2.warpAffine(image.numpy(), rot_mat, 67 | (MNIST_IMG_SIZE, MNIST_IMG_SIZE)) 68 | return rotated 69 | 70 | def image_rotate_random(image, label): 71 | rand_amts = tf.maximum(tf.minimum( 72 | tf.random.normal([2], 0, .33), .9999), -.9999) 73 | angle = rand_amts[0] * 30 # degrees 74 | new_image = tf.py_function(image_rotate_random_py_func, 75 | (image, angle), tf.float32) 76 | new_image = tf.cond(rand_amts[1] > 0, lambda: image, lambda: new_image) 77 | return new_image, label 78 | 79 | def image_erase_random(image, label): 80 | sess = tf.compat.v1.Session() 81 | with sess.as_default(): 82 | rand_amts = tf.random.uniform([2]) 83 | x = tf.cast(tf.floor(rand_amts[0]*19)+4, tf.int32) 84 | y = tf.cast(tf.floor(rand_amts[1]*19)+4, tf.int32) 85 | patch = tf.zeros([4, 4]) 86 | mask = tf.pad(patch, [[x, MNIST_IMG_SIZE-x-4], 87 | [y, MNIST_IMG_SIZE-y-4]], 88 | mode='CONSTANT', constant_values=1) 89 | image = tf.multiply(image, tf.expand_dims(mask, -1)) 90 | return image, label 91 | 92 | 93 | def image_squish_random(image, label): 94 | rand_amts = tf.minimum(tf.abs(tf.random.normal([2], 0, .33)), .9999) 95 | width_mod = tf.cast(tf.floor( 96 | (rand_amts[0] * (MNIST_IMG_SIZE / 4)) + 1), tf.int32) 97 | offset_mod = tf.cast(tf.floor(rand_amts[1] * 2.0), tf.int32) 98 | offset = (width_mod // 2) + offset_mod 99 | image = tf.image.resize(image, 100 | [MNIST_IMG_SIZE, MNIST_IMG_SIZE - width_mod], 101 | method=tf2.image.ResizeMethod.LANCZOS3, 102 | preserve_aspect_ratio=False, 103 | antialias=True) 104 | image = tf.image.pad_to_bounding_box( 105 | image, 0, offset, MNIST_IMG_SIZE, MNIST_IMG_SIZE + 
offset_mod) 106 | image = tf.image.crop_to_bounding_box( 107 | image, 0, 0, MNIST_IMG_SIZE, MNIST_IMG_SIZE) 108 | return image, label 109 | 110 | def generator(image, label): 111 | return (image, label), (label, image) 112 | 113 | def generate_tf_data(X_train, y_train, X_test, y_test, batch_size): 114 | dataset_train = tf.data.Dataset.from_tensor_slices((X_train,y_train)) 115 | dataset_train = dataset_train.shuffle(buffer_size=MNIST_TRAIN_IMAGE_COUNT) 116 | dataset_train = dataset_train.map(image_rotate_random, 117 | num_parallel_calls=PARALLEL_INPUT_CALLS) 118 | dataset_train = dataset_train.map(image_shift_rand, 119 | num_parallel_calls=PARALLEL_INPUT_CALLS) 120 | dataset_train = dataset_train.map(image_squish_random, 121 | num_parallel_calls=PARALLEL_INPUT_CALLS) 122 | dataset_train = dataset_train.map(image_erase_random, 123 | num_parallel_calls=PARALLEL_INPUT_CALLS) 124 | dataset_train = dataset_train.map(generator, 125 | num_parallel_calls=PARALLEL_INPUT_CALLS) 126 | dataset_train = dataset_train.batch(batch_size) 127 | dataset_train = dataset_train.prefetch(-1) 128 | 129 | dataset_test = tf.data.Dataset.from_tensor_slices((X_test, y_test)) 130 | dataset_test = dataset_test.cache() 131 | dataset_test = dataset_test.map(generator, 132 | num_parallel_calls=PARALLEL_INPUT_CALLS) 133 | dataset_test = dataset_test.batch(batch_size) 134 | dataset_test = dataset_test.prefetch(-1) 135 | 136 | return dataset_train, dataset_test 137 | -------------------------------------------------------------------------------- /utils/pre_process_multimnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import os 19 | import cv2 20 | 21 | # constants 22 | MULTIMNIST_IMG_SIZE = 36 23 | 24 | def pad_dataset(images,pad): 25 | return np.pad(images,[(0,0),(pad,pad),(pad,pad)]) 26 | 27 | def pre_process(image, label): 28 | return (image / 255)[...,None].astype('float32'), tf.keras.utils.to_categorical(label, num_classes=10) 29 | 30 | def shift_images(images, shifts, max_shift): 31 | l = images.shape[1] 32 | images_sh = np.pad(images,((0,0),(max_shift,max_shift),(max_shift,max_shift),(0,0))) 33 | shifts = max_shift - shifts 34 | batches = np.arange(len(images))[:,None,None] 35 | images_sh = images_sh[batches,np.arange(l+max_shift*2)[None,:,None],(shifts[:,0,None]+np.arange(0,l))[:,None,:]] 36 | images_sh = images_sh[batches,(shifts[:,1,None]+np.arange(0,l))[...,None],np.arange(l)[None,None]] 37 | return images_sh 38 | 39 | def merge_with_image(images,labels,i,shift,n_multi=1000): #for an image i, generate n_multi merged images 40 | base_image = images[i] 41 | base_label = labels[i] 42 | indexes = np.arange(len(images))[np.bitwise_not((labels==base_label).all(axis=-1))] 43 | indexes = np.random.choice(indexes,n_multi,replace=False) 44 | top_images = images[indexes] 45 | top_labels = labels[indexes] 46 | shifts = np.random.randint(-shift,shift+1,(n_multi+1,2)) 47 | images_sh = shift_images(np.concatenate((base_image[None],top_images),axis=0),shifts,shift) 48 | base_sh = images_sh[0] 49 | 
top_sh = images_sh[1:] 50 | merged = np.clip(base_sh+top_sh,0,1) 51 | merged_labels = base_label+top_labels 52 | return merged,merged_labels 53 | 54 | def multi_mnist_generator(images,labels,shift): 55 | def multi_mnist(): 56 | while True: 57 | i = np.random.randint(len(images)) 58 | j = np.random.randint(len(images)) 59 | while np.all(images[i]==images[j]): 60 | j = np.random.randint(len(images)) 61 | base = shift_images(images[i:i+1],np.random.randint(-shift,shift+1,(1,2)),shift)[0] 62 | top = shift_images(images[j:j+1],np.random.randint(-shift,shift+1,(1,2)),shift)[0] 63 | merged = tf.clip_by_value(tf.add(base, top),0,1) 64 | yield (merged,labels[i],labels[j]),(labels[i]+labels[j],base,top) 65 | return multi_mnist 66 | 67 | def multi_mnist_generator_validation(images,labels,shift): 68 | def multi_mnist_val(): 69 | for i in range(len(images)): 70 | j = np.random.randint(len(images)) 71 | while np.all(labels[i]==labels[j]): 72 | j = np.random.randint(len(images)) 73 | base = shift_images(images[i:i+1],np.random.randint(-shift,shift+1,(1,2)),shift)[0] 74 | top = shift_images(images[j:j+1],np.random.randint(-shift,shift+1,(1,2)),shift)[0] 75 | merged = tf.clip_by_value(tf.add(base, top),0,1) 76 | yield (merged,labels[i],labels[j]),(labels[i]+labels[j],base,top) 77 | return multi_mnist_val 78 | 79 | def multi_mnist_generator_test(images,labels,shift,n_multi=1000): 80 | def multi_mnist_test(): 81 | for i in range(len(images)): 82 | X_merged,y_merged = merge_with_image(images,labels,i,shift,n_multi) 83 | yield X_merged,y_merged 84 | return multi_mnist_test 85 | 86 | def generate_tf_data(X_train, y_train, X_test, y_test, batch_size, shift): 87 | input_shape = (MULTIMNIST_IMG_SIZE,MULTIMNIST_IMG_SIZE,1) 88 | dataset_train = tf.data.Dataset.from_generator(multi_mnist_generator(X_train,y_train,shift), 89 | output_shapes=((input_shape,(10,),(10,)),((10,),input_shape,input_shape)), 90 | output_types=((tf.float32,tf.float32,tf.float32), 91 | 
(tf.float32,tf.float32,tf.float32))) 92 | dataset_train = dataset_train.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 93 | dataset_test = tf.data.Dataset.from_generator(multi_mnist_generator_validation(X_test,y_test,shift), 94 | output_shapes=((input_shape,(10,),(10,)),((10,),input_shape,input_shape)), 95 | output_types=((tf.float32,tf.float32,tf.float32), 96 | (tf.float32,tf.float32,tf.float32))) 97 | dataset_test = dataset_test.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 98 | return dataset_train, dataset_test 99 | 100 | def generate_tf_data_test(X_test, y_test, shift, n_multi=1000, random_seed=42): 101 | input_shape = (MULTIMNIST_IMG_SIZE,MULTIMNIST_IMG_SIZE,1) 102 | np.random.seed(random_seed) 103 | dataset_test = tf.data.Dataset.from_generator(multi_mnist_generator_test(X_test,y_test,shift,n_multi), 104 | output_shapes=((n_multi,)+input_shape,(n_multi,10,)), 105 | output_types=(tf.float32,tf.float32)) 106 | dataset_test = dataset_test.prefetch(tf.data.experimental.AUTOTUNE) 107 | return dataset_test 108 | -------------------------------------------------------------------------------- /utils/pre_process_smallnorb.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import os 19 | from tqdm.notebook import tqdm 20 | 21 | 22 | # constants 23 | SAMPLES = 24300 24 | INPUT_SHAPE = 96 25 | PATCH_SMALLNORB = 48 26 | N_CLASSES = 5 27 | MAX_DELTA = 2.0 28 | LOWER_CONTRAST = 0.5 29 | UPPER_CONTRAST = 1.5 30 | PARALLEL_INPUT_CALLS = 16 31 | 32 | 33 | def pre_process(ds): 34 | X = np.empty((SAMPLES, INPUT_SHAPE, INPUT_SHAPE, 2)) 35 | y = np.empty((SAMPLES,)) 36 | 37 | for index, d in tqdm(enumerate(ds.batch(1))): 38 | X[index, :, :, 0:1] = d['image'] 39 | X[index, :, :, 1:2] = d['image2'] 40 | y[index] = d['label_category'] 41 | return X, y 42 | 43 | 44 | def standardize(x, y): 45 | x[...,0] = (x[...,0] - x[...,0].mean()) / x[...,0].std() 46 | x[...,1] = (x[...,1] - x[...,1].mean()) / x[...,1].std() 47 | return x, tf.one_hot(y, N_CLASSES) 48 | 49 | def rescale(x, y, config): 50 | with tf.device("/cpu:0"): 51 | x = tf.image.resize(x , [config['scale_smallnorb'], config['scale_smallnorb']]) 52 | return x, y 53 | 54 | def test_patches(x, y, config): 55 | res = (config['scale_smallnorb'] - config['patch_smallnorb']) // 2 56 | return x[:,res:-res,res:-res,:], y 57 | 58 | 59 | def generator(image, label): 60 | return (image, label), (label, image) 61 | 62 | def random_patches(x, y): 63 | return tf.image.random_crop(x, [PATCH_SMALLNORB, PATCH_SMALLNORB, 2]), y 64 | 65 | def random_brightness(x, y): 66 | return tf.image.random_brightness(x, max_delta=MAX_DELTA), y 67 | 68 | def random_contrast(x, y): 69 | return tf.image.random_contrast(x, lower=LOWER_CONTRAST, upper=UPPER_CONTRAST), y 70 | 71 | 72 | def generate_tf_data(X_train, y_train, X_test_patch, y_test, batch_size): 73 | dataset_train = tf.data.Dataset.from_tensor_slices((X_train, y_train)) 74 | # dataset_train = dataset_train.shuffle(buffer_size=SAMPLES) not needed if imported with tfds 75 | dataset_train = dataset_train.map(random_patches, 
76 | num_parallel_calls=PARALLEL_INPUT_CALLS) 77 | dataset_train = dataset_train.map(random_brightness, 78 | num_parallel_calls=PARALLEL_INPUT_CALLS) 79 | dataset_train = dataset_train.map(random_contrast, 80 | num_parallel_calls=PARALLEL_INPUT_CALLS) 81 | dataset_train = dataset_train.map(generator, 82 | num_parallel_calls=PARALLEL_INPUT_CALLS) 83 | dataset_train = dataset_train.batch(batch_size) 84 | dataset_train = dataset_train.prefetch(-1) 85 | 86 | dataset_test = tf.data.Dataset.from_tensor_slices((X_test_patch, y_test)) 87 | dataset_test = dataset_test.cache() 88 | dataset_test = dataset_test.map(generator, 89 | num_parallel_calls=PARALLEL_INPUT_CALLS) 90 | dataset_test = dataset_test.batch(1) 91 | dataset_test = dataset_test.prefetch(-1) 92 | 93 | return dataset_train, dataset_test 94 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | def learn_scheduler(lr_dec, lr): 20 | def learning_scheduler_fn(epoch): 21 | lr_new = lr * (lr_dec ** epoch) 22 | return lr_new if lr_new >= 5e-5 else 5e-5 23 | return learning_scheduler_fn 24 | 25 | 26 | def get_callbacks(tb_log_save_path, saved_model_path, lr_dec, lr): 27 | tb = tf.keras.callbacks.TensorBoard(log_dir=tb_log_save_path, histogram_freq=0) 28 | 29 | model_checkpoint = tf.keras.callbacks.ModelCheckpoint(saved_model_path, monitor='val_Efficient_CapsNet_accuracy', 30 | save_best_only=True, save_weights_only=True, verbose=1) 31 | 32 | lr_decay = tf.keras.callbacks.LearningRateScheduler(learn_scheduler(lr_dec, lr)) 33 | 34 | reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_CapsNet_accuracy', factor=0.9, 35 | patience=4, min_lr=0.00001, min_delta=0.0001, mode='max') 36 | 37 | return [tb, model_checkpoint, lr_decay] 38 | 39 | 40 | def marginLoss(y_true, y_pred): 41 | lbd = 0.5 42 | m_plus = 0.9 43 | m_minus = 0.1 44 | 45 | L = y_true * tf.square(tf.maximum(0., m_plus - y_pred)) + \ 46 | lbd * (1 - y_true) * tf.square(tf.maximum(0., y_pred - m_minus)) 47 | 48 | return tf.reduce_mean(tf.reduce_sum(L, axis=1)) 49 | 50 | 51 | def multiAccuracy(y_true, y_pred): 52 | label_pred = tf.argsort(y_pred,axis=-1)[:,-2:] 53 | label_true = tf.argsort(y_true,axis=-1)[:,-2:] 54 | 55 | acc = tf.reduce_sum(tf.cast(label_pred[:,:1]==label_true,tf.int8),axis=-1) + \ 56 | tf.reduce_sum(tf.cast(label_pred[:,1:]==label_true,tf.int8),axis=-1) 57 | acc /= 2 58 | return tf.reduce_mean(acc,axis=-1) 59 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Vittorio Mazzia & Francesco Salvetti. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, interactive
import os
import pandas as pd

class AffineVisualizer(object):
    """Interactive widget that perturbs the 16 capsule dimensions of an MNIST
    digit and shows the regenerated image (only MNIST)."""
    def __init__(self, model, X, y, hist=True):
        # slider range for each of the 16 capsule dimensions
        self.min_value = -0.30
        self.max_value = +0.30
        self.step = 0.05
        self.sliders = {str(i): widgets.FloatSlider(min=self.min_value, max=self.max_value, step=self.step) for i in range(16)}
        self.text = widgets.IntText()
        self.sliders['index'] = self.text
        self.model = model
        self.X = X
        self.y = y
        self.hist = hist

    def affineTransform(self, **info):
        """Run the generator with the current slider perturbations and plot."""
        index = abs(int(info['index']))
        tmp = np.zeros([1, 10, 16])

        # same perturbation applied to every class capsule; the Mask layer
        # selects the relevant one
        for d in range(16):
            tmp[:, :, d] = info[str(d)]

        y_pred, X_gen = self.model.predict([self.X[index:index + 1], self.y[index:index + 1], tmp])

        if self.hist:
            fig, ax = plt.subplots(1, 3, figsize=(15, 3))
        else:
            fig, ax = plt.subplots(1, 2, figsize=(12, 12))
        ax[0].imshow(self.X[index, ..., 0], cmap='gray')
        ax[0].set_title('Input Digit')
        ax[1].imshow(X_gen[0, ..., 0], cmap='gray')
        ax[1].set_title('Output Generator')
        if self.hist:
            ax[2].set_title('Output Caps Length')
            ax[2].bar(range(10), y_pred[0])
        plt.show()

    def on_button_clicked(self, k):
        """Reset all 16 dimension sliders to zero."""
        for i in range(16):
            self.sliders[str(i)].value = 0

    def start(self):
        """Lay out the widgets and start the interactive session.

        NOTE(review): `display` is assumed to be the IPython notebook builtin —
        this only works inside a Jupyter environment.
        """
        button = widgets.Button(description="Reset")
        button.on_click(self.on_button_clicked)

        main = widgets.HBox([self.text, button])
        u1 = widgets.HBox([self.sliders[str(i)] for i in range(0, 4)])
        u2 = widgets.HBox([self.sliders[str(i)] for i in range(4, 8)])
        u3 = widgets.HBox([self.sliders[str(i)] for i in range(8, 12)])
        u4 = widgets.HBox([self.sliders[str(i)] for i in range(12, 16)])

        out = widgets.interactive_output(self.affineTransform, self.sliders)

        display(main, u1, u2, u3, u4, out)


def plotHistory(history):
    """
    Plot the loss and accuracy curves for training and validation.
    """
    pd.DataFrame(history.history).plot(figsize=(8, 5), y=list(history.history.keys())[0:-1:2])
    plt.grid(True)
    plt.show()


def plotImages(X_batch, y_batch, n_img, class_names):
    """Plot up to n_img images with their class names, max 5 per row."""
    max_c = 5  # max images per row

    if n_img <= max_c:
        r = 1
        c = n_img
    else:
        r = int(np.ceil(n_img / max_c))
        c = max_c

    # fixed: squeeze=False keeps `axes` a 2-D array so .flatten() also works
    # when n_img == 1 (plt.subplots would otherwise return a bare Axes)
    fig, axes = plt.subplots(r, c, figsize=(15, 15), squeeze=False)
    axes = axes.flatten()
    for img_batch, label_batch, ax in zip(X_batch, y_batch, axes):
        ax.imshow(img_batch, cmap='gray')
        ax.grid()
        ax.set_title('Class: {}'.format(class_names[np.argmax(label_batch)]))
    plt.tight_layout()
    plt.show()

def plotWrongImages(X_test, y_test, y_pred, n_img, class_names):
    """Plot up to n_img misclassified test images with the true class and the
    top-2 predictions (with their scores)."""
    max_c = 5  # max images per row

    indices = np.where(np.argmax(y_pred, -1) != np.argmax(y_test, -1))[0]  # indices of wrong images

    if n_img <= max_c:
        r = 1
        c = n_img
    else:
        r = int(np.ceil(n_img / max_c))
        c = max_c

    # fixed: squeeze=False keeps `axes` a 2-D array so .flatten() also works
    # when n_img == 1 (plt.subplots would otherwise return a bare Axes)
    fig, axes = plt.subplots(r, c, figsize=(20, 20), squeeze=False)
    axes = axes.flatten()
    for index, ax in zip(indices, axes):
        ax.imshow(X_test[index, :, :, 0], cmap='gray')
        ax.set_axis_off()
        ax.set_title('Class: {} ({:.3f}) \nPred [1]: {} ({:.3f}) \nPred [2]: {} ({:.3f})'.format(class_names[np.argmax(y_test[index])], y_pred[index][np.argmax(y_test[index])],
                     class_names[np.argmax(y_pred[index])], np.max(y_pred[index]),
                     class_names[np.argsort(y_pred[index], axis=0)[-2]], y_pred[index][np.argsort(y_pred[index], axis=0)[-2]]),
                     color='black', fontsize=14)
    plt.tight_layout()
    plt.show()