├── LICENSE ├── MultiModelCNN.ipynb └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
-------------------------------------------------------------------------------- /MultiModelCNN.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "MultiModelCNN.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyN+9ITdnGE5c3RaXbagD/gz",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | }
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | "\"Open"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {
35 | "id": "ENi7VEVSnHhK"
36 | },
37 | "outputs": [],
38 | "source": [
39 | "######################################### Connecting to Google Drive\n",
40 | "\n",
41 | "from google.colab import drive\n",
42 | "drive.mount('/content/drive')"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "source": [
48 | "######################## Your file directory in Google Drive\n",
49 | "\n",
50 | "%cd /content/drive/MyDrive/MMCNN"
51 | ],
52 | "metadata": {
53 | "id": "lumXfxIjnIh1"
54 | },
55 | "execution_count": null,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "code",
60 | "source": [
61 | "from tensorflow import keras \n",
62 | "from tensorflow.keras.layers import Conv2D, Conv3D, Flatten, Dense, Reshape, BatchNormalization, MaxPool2D, MaxPooling1D, Add, ConvLSTM2D, LSTM, Conv1D\n",
63 | "from tensorflow.keras.layers import Dropout, Input\n",
64 | "from tensorflow.keras.models import Model\n",
65 | "#from keras.optimizers import Adam, SGD\n",
66 | "from tensorflow.keras.callbacks import ModelCheckpoint\n",
67 | "from tensorflow.keras.utils import to_categorical   # replaces the removed keras.utils.np_utils\n",
68 | "from tensorflow.keras import backend as Kb\n",
69 | "from tensorflow.keras.layers import Lambda\n",
70 | "from tensorflow.keras.layers import Activation\n",
71 | "from tensorflow.keras.layers import add, concatenate\n",
72 | "from tensorflow.keras.layers import AveragePooling2D\n",
73 | "#from keras.utils.vis_utils import plot_model\n",
74 | "from tensorflow.keras.utils import plot_model\n",
75 | "#from keras.utils import plot_model\n",
"import tensorflow\n", 77 | " \n", 78 | "from sklearn.model_selection import train_test_split\n", 79 | "from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score\n", 80 | " \n", 81 | "from sklearn.decomposition import FactorAnalysis\n", 82 | "from sklearn.decomposition import PCA\n", 83 | "from operator import truediv\n", 84 | " \n", 85 | "from plotly.offline import init_notebook_mode\n", 86 | " \n", 87 | "import numpy as np\n", 88 | "import matplotlib.pyplot as plt\n", 89 | "import scipy.io as sio\n", 90 | "import os\n", 91 | "#!pip install spectral\n", 92 | "import spectral" 93 | ], 94 | "metadata": { 95 | "id": "65eE0GP6nL5k" 96 | }, 97 | "execution_count": null, 98 | "outputs": [] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "source": [ 103 | "## GLOBAL VARIABLES\n", 104 | "test_ratio = 0.5\n", 105 | "#windowSize = 8" 106 | ], 107 | "metadata": { 108 | "id": "SM21YOlzni8I" 109 | }, 110 | "execution_count": null, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "source": [ 116 | "def loadData(name):\n", 117 | " data_path = os.path.join(os.getcwd(),'Data/')\n", 118 | " if name == 'BrunswickS1':\n", 119 | " \n", 120 | " data = sio.loadmat(os.path.join(data_path, 'BrunswickS1.mat'))['BrunswickS1']\n", 121 | " labels = sio.loadmat(os.path.join(data_path, 'Brunswick_gt.mat'))['Brunswick_gt']\n", 122 | "\n", 123 | " else if name == 'BrunswickS2':\n", 124 | " \n", 125 | " data = sio.loadmat(os.path.join(data_path, 'BrunswickS2.mat'))['BrunswickS2]\n", 126 | " labels = sio.loadmat(os.path.join(data_path, 'Brunswick_gt.mat'))['Brunswick_gt']\n", 127 | "\n", 128 | " else if name == 'BrunswickDEM':\n", 129 | " \n", 130 | " data = sio.loadmat(os.path.join(data_path, 'BrunswickDEM.mat'))['BrunswickDEM]\n", 131 | " labels = sio.loadmat(os.path.join(data_path, 'Brunswick_gt.mat'))['Brunswick_gt']\n", 132 | "\n", 133 | "\n", 134 | " return data, labels" 135 | ], 136 | "metadata": { 137 | "id": "VeYHMgFNnl3r" 138 | }, 139 | "execution_count": null, 140 | "outputs": [] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "source": [ 145 | "def splitTrainTestSet(X, y, testRatio, randomState=345):\n", 146 | " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=testRatio, random_state=randomState,\n", 147 | " stratify=y)\n", 148 | " return X_train, X_test, y_train, y_test" 149 | ], 150 | "metadata": { 151 | "id": "VaYeGGDKnvZC" 152 | }, 153 | "execution_count": null, 154 | "outputs": [] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "source": [ 159 | "def padWithZeros(X, margin=2):\n", 160 | " newX = np.zeros((X.shape[0] + 2 * margin, X.shape[1] + 2* margin, X.shape[2]))\n", 161 | " x_offset = margin\n", 162 | " y_offset = margin\n", 163 | " newX[x_offset:X.shape[0] + x_offset, y_offset:X.shape[1] + y_offset, :] = X\n", 164 | " return newX" 165 | ], 166 | "metadata": { 167 | "id": "FWzvHIR_nzHa" 168 | }, 169 | "execution_count": null, 170 | "outputs": [] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "source": [ 175 | "def createImageCubes(X, y, windowSize=8, removeZeroLabels = True):\n", 176 | " margin = int((windowSize) / 2)\n", 177 | " zeroPaddedX = padWithZeros(X, margin=margin)\n", 178 | " # split patches\n", 179 | " patchesData = np.zeros((X.shape[0] * X.shape[1], windowSize, windowSize, X.shape[2]))\n", 180 | " patchesLabels = np.zeros((X.shape[0] * X.shape[1]))\n", 181 | " patchIndex = 0\n", 182 | " for r in range(margin, zeroPaddedX.shape[0] - margin):\n", 183 | " for c in range(margin, zeroPaddedX.shape[1] - 
margin):\n", 184 | " patch = zeroPaddedX[r - margin:r + margin , c - margin:c + margin ] \n", 185 | " patchesData[patchIndex, :, :, :] = patch\n", 186 | " patchesLabels[patchIndex] = y[r-margin, c-margin]\n", 187 | " patchIndex = patchIndex + 1\n", 188 | " if removeZeroLabels:\n", 189 | " patchesData = patchesData[patchesLabels>0,:,:,:]\n", 190 | " patchesLabels = patchesLabels[patchesLabels>0]\n", 191 | " patchesLabels -= 1\n", 192 | " return patchesData, patchesLabels" 193 | ], 194 | "metadata": { 195 | "id": "MSHKmHhgn0_p" 196 | }, 197 | "execution_count": null, 198 | "outputs": [] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "source": [ 203 | "dataset = 'BrunswickS2'\n", 204 | "X, y = loadData(dataset)\n", 205 | "\n", 206 | "X=(X-np.min(X))/(np.max(X)-np.min(X))\n", 207 | "X.shape, y.shape" 208 | ], 209 | "metadata": { 210 | "id": "0WSk6mS2n2yQ" 211 | }, 212 | "execution_count": null, 213 | "outputs": [] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "source": [ 218 | "dataset = 'BrunswickS1'\n", 219 | "X1, y = loadData(dataset)\n", 220 | "\n", 221 | "X1=(X1-np.min(X1))/(np.max(X1)-np.min(X1))\n", 222 | "X1.shape, y.shape" 223 | ], 224 | "metadata": { 225 | "id": "6P-6_AHCn5Dt" 226 | }, 227 | "execution_count": null, 228 | "outputs": [] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "source": [ 233 | "dataset = 'BrunswickDEM'\n", 234 | "X2, y = loadData(dataset)\n", 235 | "\n", 236 | "X2=(X2-np.min(X2))/(np.max(X2)-np.min(X2))\n", 237 | "\n", 238 | "X2=X2.reshape(667, 2323,1)\n", 239 | "X2.shape, y.shape\n" 240 | ], 241 | "metadata": { 242 | "id": "jpolBnVKn9mJ" 243 | }, 244 | "execution_count": null, 245 | "outputs": [] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "source": [ 250 | "X11, y1 = createImageCubes(X, y, windowSize=4)\n", 251 | "\n", 252 | "X11.shape, y1.shape" 253 | ], 254 | "metadata": { 255 | "id": "l04lJztSorNF" 256 | }, 257 | "execution_count": null, 258 | "outputs": [] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "source": [ 263 | "X12, y2 = createImageCubes(X1, y, windowSize=8)\n", 264 | "\n", 265 | "X12.shape, y2.shape" 266 | ], 267 | "metadata": { 268 | "id": "qRrVDEt2oznM" 269 | }, 270 | "execution_count": null, 271 | "outputs": [] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "source": [ 276 | "X12 = X12.reshape((X12.shape[0],8,8,4,1))\n", 277 | "X12.shape" 278 | ], 279 | "metadata": { 280 | "id": "aE5BF5UHo1W_" 281 | }, 282 | "execution_count": null, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "source": [ 288 | "X13, y3 = createImageCubes(X2, y, windowSize=8)\n", 289 | "\n", 290 | "X13.shape, y3.shape" 291 | ], 292 | "metadata": { 293 | "id": "iWvELGi0o1aP" 294 | }, 295 | "execution_count": null, 296 | "outputs": [] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "source": [ 301 | "X1train, X1test, y1train, y1test = splitTrainTestSet(X11, y1, test_ratio)\n", 302 | "\n", 303 | "X1train.shape, X1test.shape, y1train.shape, y1test.shape\n", 304 | "\n" 305 | ], 306 | "metadata": { 307 | "id": "qKwiPQy6o56z" 308 | }, 309 | "execution_count": null, 310 | "outputs": [] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "source": [ 315 | "X2train, X2test, y2train, y2test = splitTrainTestSet(X12, y2, test_ratio)\n", 316 | "\n", 317 | "X2train.shape, X2test.shape, y2train.shape, y2test.shape\n" 318 | ], 319 | "metadata": { 320 | "id": "YHyFe53Ko59-" 321 | }, 322 | "execution_count": null, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "source": [ 328 | "X3train, X3test, y3train, y3test = 
splitTrainTestSet(X13, y3, test_ratio)\n",
329 | "\n",
330 | "X3train.shape, X3test.shape, y3train.shape, y3test.shape\n"
331 | ],
332 | "metadata": {
333 | "id": "HKC7Kzzxo6BL"
334 | },
335 | "execution_count": null,
336 | "outputs": []
337 | },
338 | {
339 | "cell_type": "code",
340 | "source": [
341 | "ytrain = tensorflow.keras.utils.to_categorical(y1train)\n",
342 | "ytrain.shape\n"
343 | ],
344 | "metadata": {
345 | "id": "lZ_dCS5Po6EM"
346 | },
347 | "execution_count": null,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "code",
352 | "source": [
353 | "output_units = 11   # number of output classes"
354 | ],
355 | "metadata": {
356 | "id": "77-zFmR6o6HR"
357 | },
358 | "execution_count": null,
359 | "outputs": []
360 | },
361 | {
362 | "cell_type": "code",
363 | "source": [
364 | "import tensorflow as tf\n",
365 | "#import tensorflow_addons as tfa   # unused below; needs a separate install, so kept commented out\n",
366 | "from tensorflow import keras\n"
367 | ],
368 | "metadata": {
369 | "id": "xcUnXGwao1dN"
370 | },
371 | "execution_count": null,
372 | "outputs": []
373 | },
374 | {
375 | "cell_type": "code",
376 | "source": [
377 | "from tensorflow.keras import layers"
378 | ],
379 | "metadata": {
380 | "id": "BnucHzkHo1gb"
381 | },
382 | "execution_count": null,
383 | "outputs": []
384 | },
385 | {
386 | "cell_type": "code",
387 | "source": [
388 | "######################## Swin Transformer settings (feel free to experiment with these)\n",
389 | "\n",
390 | "input_shape = (8, 8, 1)\n",
391 | "patch_size = (2, 2)  # 2-by-2 sized patches\n",
392 | "dropout_rate = 0.03  # Dropout rate\n",
393 | "num_heads = 8  # Attention heads\n",
394 | "embed_dim = 64  # Embedding dimension\n",
395 | "num_mlp = 256  # MLP layer size\n",
396 | "qkv_bias = True  # Convert embedded patches to query, key, and values with a learnable additive value\n",
397 | "window_size = 2  # Size of attention window\n",
398 | "shift_size = 1  # Size of shifting window\n",
399 | "image_dimension = 8  # Initial image size\n",
400 | "\n",
401 | "num_patch_x = input_shape[0] // patch_size[0]\n",
402 | "num_patch_y = input_shape[1] // patch_size[1]\n"
403 | ],
404 | "metadata": {
405 | "id": "mpVRHCLvorUq"
406 | },
407 | "execution_count": null,
408 | "outputs": []
409 | },
410 | {
411 | "cell_type": "code",
412 | "source": [
413 | "def window_partition(x, window_size):\n",
414 | "    _, height, width, channels = x.shape\n",
415 | "    patch_num_y = height // window_size\n",
416 | "    patch_num_x = width // window_size\n",
417 | "    x = tf.reshape(\n",
418 | "        x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)\n",
419 | "    )\n",
420 | "    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))\n",
421 | "    windows = tf.reshape(x, shape=(-1, window_size, window_size, channels))\n",
422 | "    return windows\n",
423 | "\n",
424 | "\n",
425 | "def window_reverse(windows, window_size, height, width, channels):\n",
426 | "    patch_num_y = height // window_size\n",
427 | "    patch_num_x = width // window_size\n",
428 | "    x = tf.reshape(\n",
429 | "        windows,\n",
430 | "        shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),\n",
431 | "    )\n",
432 | "    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))\n",
433 | "    x = tf.reshape(x, shape=(-1, height, width, channels))\n",
434 | "    return x\n",
435 | "\n",
436 | "\n",
437 | "class DropPath(layers.Layer):\n",
438 | "    def __init__(self, drop_prob=None, **kwargs):\n",
439 | "        super(DropPath, self).__init__(**kwargs)\n",
440 | "        self.drop_prob = drop_prob\n",
441 | "\n",
442 | "    def call(self, x):\n",
443 | "        input_shape = tf.shape(x)\n",
444 | "        batch_size = input_shape[0]\n",
445 | "        rank = 
x.shape.rank\n", 446 | " shape = (batch_size,) + (1,) * (rank - 1)\n", 447 | " random_tensor = (1 - self.drop_prob) + tf.random.uniform(shape, dtype=x.dtype)\n", 448 | " path_mask = tf.floor(random_tensor)\n", 449 | " output = tf.math.divide(x, 1 - self.drop_prob) * path_mask\n", 450 | " return output" 451 | ], 452 | "metadata": { 453 | "id": "IzYwvRhJorXq" 454 | }, 455 | "execution_count": null, 456 | "outputs": [] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "source": [ 461 | "class WindowAttention(layers.Layer):\n", 462 | " def __init__(\n", 463 | " self, dim, window_size, num_heads, qkv_bias=True, dropout_rate=0.0, **kwargs\n", 464 | " ):\n", 465 | " super(WindowAttention, self).__init__(**kwargs)\n", 466 | " self.dim = dim\n", 467 | " self.window_size = window_size\n", 468 | " self.num_heads = num_heads\n", 469 | " self.scale = (dim // num_heads) ** -0.5\n", 470 | " self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)\n", 471 | " self.dropout = layers.Dropout(dropout_rate)\n", 472 | " self.proj = layers.Dense(dim)\n", 473 | "\n", 474 | " def build(self, input_shape):\n", 475 | " num_window_elements = (2 * self.window_size[0] - 1) * (\n", 476 | " 2 * self.window_size[1] - 1\n", 477 | " )\n", 478 | " self.relative_position_bias_table = self.add_weight(\n", 479 | " shape=(num_window_elements, self.num_heads),\n", 480 | " initializer=tf.initializers.Zeros(),\n", 481 | " trainable=True,\n", 482 | " )\n", 483 | " coords_h = np.arange(self.window_size[0])\n", 484 | " coords_w = np.arange(self.window_size[1])\n", 485 | " coords_matrix = np.meshgrid(coords_h, coords_w, indexing=\"ij\")\n", 486 | " coords = np.stack(coords_matrix)\n", 487 | " coords_flatten = coords.reshape(2, -1)\n", 488 | " relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]\n", 489 | " relative_coords = relative_coords.transpose([1, 2, 0])\n", 490 | " relative_coords[:, :, 0] += self.window_size[0] - 1\n", 491 | " relative_coords[:, :, 1] += self.window_size[1] - 1\n", 492 | " relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n", 493 | " relative_position_index = relative_coords.sum(-1)\n", 494 | "\n", 495 | " self.relative_position_index = tf.Variable(\n", 496 | " initial_value=tf.convert_to_tensor(relative_position_index), trainable=False\n", 497 | " )\n", 498 | "\n", 499 | " def call(self, x, mask=None):\n", 500 | " _, size, channels = x.shape\n", 501 | " head_dim = channels // self.num_heads\n", 502 | " x_qkv = self.qkv(x)\n", 503 | " x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))\n", 504 | " x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))\n", 505 | " q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]\n", 506 | " q = q * self.scale\n", 507 | " k = tf.transpose(k, perm=(0, 1, 3, 2))\n", 508 | " attn = q @ k\n", 509 | "\n", 510 | " num_window_elements = self.window_size[0] * self.window_size[1]\n", 511 | " relative_position_index_flat = tf.reshape(\n", 512 | " self.relative_position_index, shape=(-1,)\n", 513 | " )\n", 514 | " relative_position_bias = tf.gather(\n", 515 | " self.relative_position_bias_table, relative_position_index_flat\n", 516 | " )\n", 517 | " relative_position_bias = tf.reshape(\n", 518 | " relative_position_bias, shape=(num_window_elements, num_window_elements, -1)\n", 519 | " )\n", 520 | " relative_position_bias = tf.transpose(relative_position_bias, perm=(2, 0, 1))\n", 521 | " attn = attn + tf.expand_dims(relative_position_bias, axis=0)\n", 522 | "\n", 523 | " if mask is not None:\n", 524 | " nW = mask.get_shape()[0]\n", 525 | " mask_float = 
tf.cast(\n", 526 | " tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32\n", 527 | " )\n", 528 | " attn = (\n", 529 | " tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))\n", 530 | " + mask_float\n", 531 | " )\n", 532 | " attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))\n", 533 | " attn = keras.activations.softmax(attn, axis=-1)\n", 534 | " else:\n", 535 | " attn = keras.activations.softmax(attn, axis=-1)\n", 536 | " attn = self.dropout(attn)\n", 537 | "\n", 538 | " x_qkv = attn @ v\n", 539 | " x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))\n", 540 | " x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))\n", 541 | " x_qkv = self.proj(x_qkv)\n", 542 | " x_qkv = self.dropout(x_qkv)\n", 543 | " return x_qkv" 544 | ], 545 | "metadata": { 546 | "id": "Qr2KOkIApLQU" 547 | }, 548 | "execution_count": null, 549 | "outputs": [] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "source": [ 554 | "class SwinTransformer(layers.Layer):\n", 555 | " def __init__(\n", 556 | " self,\n", 557 | " dim,\n", 558 | " num_patch,\n", 559 | " num_heads,\n", 560 | " window_size=7,\n", 561 | " shift_size=0,\n", 562 | " num_mlp=1024,\n", 563 | " qkv_bias=True,\n", 564 | " dropout_rate=0.0,\n", 565 | " **kwargs,\n", 566 | " ):\n", 567 | " super(SwinTransformer, self).__init__(**kwargs)\n", 568 | "\n", 569 | " self.dim = dim # number of input dimensions\n", 570 | " self.num_patch = num_patch # number of embedded patches\n", 571 | " self.num_heads = num_heads # number of attention heads\n", 572 | " self.window_size = window_size # size of window\n", 573 | " self.shift_size = shift_size # size of window shift\n", 574 | " self.num_mlp = num_mlp # number of MLP nodes\n", 575 | "\n", 576 | " self.norm1 = layers.LayerNormalization(epsilon=1e-5)\n", 577 | " self.attn = WindowAttention(\n", 578 | " dim,\n", 579 | " window_size=(self.window_size, self.window_size),\n", 580 | " num_heads=num_heads,\n", 581 | " qkv_bias=qkv_bias,\n", 582 | " dropout_rate=dropout_rate,\n", 583 | " )\n", 584 | " self.drop_path = DropPath(dropout_rate)\n", 585 | " self.norm2 = layers.LayerNormalization(epsilon=1e-5)\n", 586 | "\n", 587 | " self.mlp = keras.Sequential(\n", 588 | " [\n", 589 | " layers.Dense(num_mlp),\n", 590 | " layers.Activation(keras.activations.gelu),\n", 591 | " layers.Dropout(dropout_rate),\n", 592 | " layers.Dense(dim),\n", 593 | " layers.Dropout(dropout_rate),\n", 594 | " ]\n", 595 | " )\n", 596 | "\n", 597 | " if min(self.num_patch) < self.window_size:\n", 598 | " self.shift_size = 0\n", 599 | " self.window_size = min(self.num_patch)\n", 600 | "\n", 601 | " def build(self, input_shape):\n", 602 | " if self.shift_size == 0:\n", 603 | " self.attn_mask = None\n", 604 | " else:\n", 605 | " height, width = self.num_patch\n", 606 | " h_slices = (\n", 607 | " slice(0, -self.window_size),\n", 608 | " slice(-self.window_size, -self.shift_size),\n", 609 | " slice(-self.shift_size, None),\n", 610 | " )\n", 611 | " w_slices = (\n", 612 | " slice(0, -self.window_size),\n", 613 | " slice(-self.window_size, -self.shift_size),\n", 614 | " slice(-self.shift_size, None),\n", 615 | " )\n", 616 | " mask_array = np.zeros((1, height, width, 1))\n", 617 | " count = 0\n", 618 | " for h in h_slices:\n", 619 | " for w in w_slices:\n", 620 | " mask_array[:, h, w, :] = count\n", 621 | " count += 1\n", 622 | " mask_array = tf.convert_to_tensor(mask_array)\n", 623 | "\n", 624 | " # mask array to windows\n", 625 | " mask_windows = window_partition(mask_array, self.window_size)\n", 626 | " mask_windows = tf.reshape(\n", 
627 | " mask_windows, shape=[-1, self.window_size * self.window_size]\n", 628 | " )\n", 629 | " attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(\n", 630 | " mask_windows, axis=2\n", 631 | " )\n", 632 | " attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)\n", 633 | " attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)\n", 634 | " self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)\n", 635 | "\n", 636 | " def call(self, x):\n", 637 | " height, width = self.num_patch\n", 638 | " _, num_patches_before, channels = x.shape\n", 639 | " x_skip = x\n", 640 | " x = self.norm1(x)\n", 641 | " x = tf.reshape(x, shape=(-1, height, width, channels))\n", 642 | " if self.shift_size > 0:\n", 643 | " shifted_x = tf.roll(\n", 644 | " x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]\n", 645 | " )\n", 646 | " else:\n", 647 | " shifted_x = x\n", 648 | "\n", 649 | " x_windows = window_partition(shifted_x, self.window_size)\n", 650 | " x_windows = tf.reshape(\n", 651 | " x_windows, shape=(-1, self.window_size * self.window_size, channels)\n", 652 | " )\n", 653 | " attn_windows = self.attn(x_windows, mask=self.attn_mask)\n", 654 | "\n", 655 | " attn_windows = tf.reshape(\n", 656 | " attn_windows, shape=(-1, self.window_size, self.window_size, channels)\n", 657 | " )\n", 658 | " shifted_x = window_reverse(\n", 659 | " attn_windows, self.window_size, height, width, channels\n", 660 | " )\n", 661 | " if self.shift_size > 0:\n", 662 | " x = tf.roll(\n", 663 | " shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]\n", 664 | " )\n", 665 | " else:\n", 666 | " x = shifted_x\n", 667 | "\n", 668 | " x = tf.reshape(x, shape=(-1, height * width, channels))\n", 669 | " x = self.drop_path(x)\n", 670 | " x = x_skip + x\n", 671 | " x_skip = x\n", 672 | " x = self.norm2(x)\n", 673 | " x = self.mlp(x)\n", 674 | " x = self.drop_path(x)\n", 675 | " x = x_skip + x\n", 676 | " return x" 677 | ], 678 | "metadata": { 679 | "id": "Xq2ybc1spLTs" 680 | }, 681 | "execution_count": null, 682 | "outputs": [] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "source": [ 687 | "class PatchExtract(layers.Layer):\n", 688 | " def __init__(self, patch_size, **kwargs):\n", 689 | " super(PatchExtract, self).__init__(**kwargs)\n", 690 | " self.patch_size_x = patch_size[0]\n", 691 | " self.patch_size_y = patch_size[0]\n", 692 | "\n", 693 | " def call(self, images):\n", 694 | " batch_size = tf.shape(images)[0]\n", 695 | " patches = tf.image.extract_patches(\n", 696 | " images=images,\n", 697 | " sizes=(1, self.patch_size_x, self.patch_size_y, 1),\n", 698 | " strides=(1, self.patch_size_x, self.patch_size_y, 1),\n", 699 | " rates=(1, 1, 1, 1),\n", 700 | " padding=\"VALID\",\n", 701 | " )\n", 702 | " patch_dim = patches.shape[-1]\n", 703 | " patch_num = patches.shape[1]\n", 704 | " return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))\n", 705 | "\n", 706 | "\n", 707 | "class PatchEmbedding(layers.Layer):\n", 708 | " def __init__(self, num_patch, embed_dim, **kwargs):\n", 709 | " super(PatchEmbedding, self).__init__(**kwargs)\n", 710 | " self.num_patch = num_patch\n", 711 | " self.proj = layers.Dense(embed_dim)\n", 712 | " self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)\n", 713 | "\n", 714 | " def call(self, patch):\n", 715 | " pos = tf.range(start=0, limit=self.num_patch, delta=1)\n", 716 | " return self.proj(patch) + self.pos_embed(pos)\n", 717 | "\n", 718 | "\n", 719 | "class PatchMerging(tf.keras.layers.Layer):\n", 720 | " def __init__(self, 
num_patch, embed_dim):\n",
721 | "        super(PatchMerging, self).__init__()\n",
722 | "        self.num_patch = num_patch\n",
723 | "        self.embed_dim = embed_dim\n",
724 | "        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)\n",
725 | "\n",
726 | "    def call(self, x):\n",
727 | "        height, width = self.num_patch\n",
728 | "        _, _, C = x.get_shape().as_list()\n",
729 | "        x = tf.reshape(x, shape=(-1, height, width, C))\n",
730 | "        x0 = x[:, 0::2, 0::2, :]\n",
731 | "        x1 = x[:, 1::2, 0::2, :]\n",
732 | "        x2 = x[:, 0::2, 1::2, :]\n",
733 | "        x3 = x[:, 1::2, 1::2, :]\n",
734 | "        x = tf.concat((x0, x1, x2, x3), axis=-1)\n",
735 | "        x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))\n",
736 | "        return self.linear_trans(x)"
737 | ],
738 | "metadata": {
739 | "id": "xCc1323NpLWb"
740 | },
741 | "execution_count": null,
742 | "outputs": []
743 | },
744 | {
745 | "cell_type": "code",
746 | "source": [
747 | "def get_cnn_model():\n",
748 | "    \n",
749 | "    input_shape1 = 4, 4, 12    # Sentinel-2 patches\n",
750 | "    input_shape2 = 8, 8, 4, 1  # Sentinel-1 patches, with a depth axis for the 3D CNN\n",
751 | "    input_shape3 = 8, 8, 1     # LiDAR-derived DEM patches\n",
752 | "    \n",
753 | "    \n",
754 | "    input1_ = Input(shape=input_shape1)\n",
755 | "    input2_ = Input(shape=input_shape2)\n",
756 | "    input3_ = Input(shape=input_shape3)\n",
757 | "\n",
758 | "    \n",
759 | "\n",
760 | "############################ Modified VGG-16 CNN Module (Filter numbers and kernel sizes are changed)\n",
761 | "########### (It should be noted that the kernel sizes for Conv2D in the original VGG-16 are all 3 by 3, i.e. kernel_size=(3,3))\n",
762 | "    \n",
763 | "    conv_layer1 = Conv2D(filters=32, kernel_size=(1,1), activation='relu',padding='same')(input1_)\n",
764 | "    conv_layer2 = Conv2D(filters=32, kernel_size=(3,3), activation='relu',padding='same')(conv_layer1)\n",
765 | "    conv_layer3 = MaxPool2D(pool_size=(1,1))(conv_layer2)\n",
766 | "    \n",
767 | "    conv_layer4 = Conv2D(filters=64, kernel_size=(1,1), activation='relu',padding='same')(conv_layer3)\n",
768 | "    conv_layer5 = Conv2D(filters=64, kernel_size=(3,3), activation='relu',padding='same')(conv_layer4)\n",
769 | "    conv_layer6 = MaxPool2D(pool_size=(1,1))(conv_layer5)\n",
770 | "    \n",
771 | "    conv_layer7 = Conv2D(filters=64, kernel_size=(1,1), activation='relu',padding='same')(conv_layer6)\n",
772 | "    conv_layer8 = Conv2D(filters=64, kernel_size=(1,1), activation='relu',padding='same')(conv_layer7)\n",
773 | "    conv_layer9 = Conv2D(filters=64, kernel_size=(3,3), activation='relu',padding='same')(conv_layer8)\n",
774 | "    conv_layer10 = MaxPool2D(pool_size=(1,1))(conv_layer9)\n",
775 | "    \n",
776 | "    conv_layer11 = Conv2D(filters=64, kernel_size=(1,1), activation='relu',padding='same')(conv_layer10)\n",
777 | "    conv_layer12 = Conv2D(filters=64, kernel_size=(1,1), activation='relu',padding='same')(conv_layer11)\n",
778 | "    conv_layer13 = Conv2D(filters=64, kernel_size=(3,3), activation='relu',padding='same')(conv_layer12)\n",
779 | "    conv_layer14 = MaxPool2D(pool_size=(1,1))(conv_layer13)\n",
780 | "    \n",
781 | "    conv_layer15 = Conv2D(filters=128, kernel_size=(1,1), activation='relu',padding='same')(conv_layer14)\n",
782 | "    conv_layer16 = Conv2D(filters=128, kernel_size=(1,1), activation='relu',padding='same')(conv_layer15)\n",
783 | "    conv_layer17 = Conv2D(filters=128, kernel_size=(3,3), activation='relu',padding='same')(conv_layer16)\n",
784 | "    conv_layer18 = MaxPool2D(pool_size=(1,1))(conv_layer17)\n",
785 | "    \n",
786 | "    \n",
787 | "    ###########################################################\n",
788 | "###################### 3D CNN\n",
789 | "    \n",
790 | "    \n",
791 | "\n",
| " conv_layerb1 = Conv3D(filters=64, kernel_size=(2, 2, 2), activation='relu', padding='same', name='conv1')(input2_)\n", 793 | " norm_1 = BatchNormalization(name='norm_a1')(conv_layerb1)\n", 794 | " conv_layerb2 = Conv3D(filters=64, kernel_size=(3, 3, 3), activation='relu',padding='same', name='conv2')(norm_1)\n", 795 | " \n", 796 | " \n", 797 | " conv3d_shape = conv_layerb2.shape\n", 798 | " conv_layerb2 = Reshape((conv3d_shape[1], conv3d_shape[2], conv3d_shape[3]*conv3d_shape[4]))(conv_layerb2)\n", 799 | "\n", 800 | " conv_layerb3 = Conv2D(filters=128, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(conv_layerb2)\n", 801 | "\n", 802 | " max_b_2 = MaxPool2D((2,2), strides=(2,2), padding='same')(conv_layerb3)\n", 803 | " \n", 804 | " #################################### \n", 805 | " ############################### concate first CNN to Second CNN\n", 806 | " \n", 807 | " concate_level_1 = concatenate([conv_layer18, max_b_2])\n", 808 | " conv_c2 = Conv2D(32, kernel_size=(1, 1), padding='same', name='conv_2')(concate_level_1)\n", 809 | " norm_c2 = BatchNormalization(name='norm_2')(conv_c2)\n", 810 | " relu_c2 = Activation('relu', name='relu_2')(norm_c2)\n", 811 | "\n", 812 | " \n", 813 | " pool_e_1 = AveragePooling2D(pool_size=(2, 2), strides=1, padding='same', name='avg_pool_5_1')(relu_c2)\n", 814 | " \n", 815 | "\n", 816 | " flatten_1=Flatten()(pool_e_1)\n", 817 | " \n", 818 | " dense_layer1 = Dense(units=100, activation='relu')(flatten_1)\n", 819 | " dense_layer1 = Dropout(0.4)(dense_layer1)\n", 820 | " dense_layer2 = Dense(units=50, activation='relu')(dense_layer1)\n", 821 | " dense_layer2 = Dropout(0.4)(dense_layer2)\n", 822 | " \n", 823 | " ######################################## Swin Transformer\n", 824 | " \n", 825 | " x = layers.RandomCrop(image_dimension, image_dimension)(input3_)\n", 826 | " x = layers.RandomFlip(\"horizontal\")(x)\n", 827 | " x = PatchExtract(patch_size)(x)\n", 828 | " x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)\n", 829 | " x = SwinTransformer(\n", 830 | " dim=embed_dim,\n", 831 | " num_patch=(num_patch_x, num_patch_y),\n", 832 | " num_heads=num_heads,\n", 833 | " window_size=window_size,\n", 834 | " shift_size=0,\n", 835 | " num_mlp=num_mlp,\n", 836 | " qkv_bias=qkv_bias,\n", 837 | " dropout_rate=dropout_rate,\n", 838 | " )(x)\n", 839 | " x = SwinTransformer(\n", 840 | " dim=embed_dim,\n", 841 | " num_patch=(num_patch_x, num_patch_y),\n", 842 | " num_heads=num_heads,\n", 843 | " window_size=window_size,\n", 844 | " shift_size=shift_size,\n", 845 | " num_mlp=num_mlp,\n", 846 | " qkv_bias=qkv_bias,\n", 847 | " dropout_rate=dropout_rate,\n", 848 | " )(x)\n", 849 | " x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)\n", 850 | " x = layers.GlobalAveragePooling1D()(x)\n", 851 | " x = layers.Dense(50, activation=\"softmax\")(x)\n", 852 | " \n", 853 | " #############################################\n", 854 | " concate_level_3 = concatenate([x, dense_layer2])\n", 855 | "\n", 856 | " output_layer = Dense(units=output_units, activation='softmax')(concate_level_3)\n", 857 | " \n", 858 | " model = Model(inputs=[input1_,input2_,input3_], outputs=output_layer)\n", 859 | " model.summary()\n", 860 | " \n", 861 | " plot_model(model, to_file='CNN_TransformersBrunswick-VGG16.png', show_shapes=True, show_layer_names=True)\n", 862 | " \n", 863 | "\n", 864 | " return model" 865 | ], 866 | "metadata": { 867 | "id": "plFfnSH-pLZJ" 868 | }, 869 | "execution_count": null, 870 | "outputs": [] 871 | }, 872 | { 873 | "cell_type": 
"code", 874 | "source": [ 875 | "model = get_cnn_model()" 876 | ], 877 | "metadata": { 878 | "id": "Xhr3NS00pLb8" 879 | }, 880 | "execution_count": null, 881 | "outputs": [] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "source": [ 886 | "\n", 887 | "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", 888 | "model_checkpoint_callback = keras.callbacks.ModelCheckpoint(\"VGG16.h5\",save_best_only=True)\n", 889 | "history = model.fit(x=[X1train,X2train,X3train], y=ytrain, batch_size = 32, epochs=100,callbacks=model_checkpoint_callback)\n" 890 | ], 891 | "metadata": { 892 | "id": "HuJEpoiyorbH" 893 | }, 894 | "execution_count": null, 895 | "outputs": [] 896 | }, 897 | { 898 | "cell_type": "code", 899 | "source": [ 900 | "plt.figure(figsize=(7,7)) \n", 901 | "plt.grid() \n", 902 | "plt.plot(history.history['loss'])\n", 903 | "\n", 904 | "plt.savefig('BrunswickLoss-VGG16.tiff',facecolor='w', dpi=500)" 905 | ], 906 | "metadata": { 907 | "id": "Kw188TyZordq" 908 | }, 909 | "execution_count": null, 910 | "outputs": [] 911 | }, 912 | { 913 | "cell_type": "code", 914 | "source": [ 915 | "X1test = X1test.reshape(-1, 4, 4, 12)\n", 916 | "X2test = X2test.reshape(-1, 8, 8, 4,1)\n", 917 | "X3test = X3test.reshape(-1, 8, 8, 1)\n", 918 | "\n", 919 | "X1test.shape,X2test.shape,X3test.shape" 920 | ], 921 | "metadata": { 922 | "id": "CszPX3l6pcmJ" 923 | }, 924 | "execution_count": null, 925 | "outputs": [] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "source": [ 930 | "ytest = np_utils.to_categorical(y1test)\n", 931 | "\n", 932 | "ytest.shape" 933 | ], 934 | "metadata": { 935 | "id": "V8RbreGkpcpO" 936 | }, 937 | "execution_count": null, 938 | "outputs": [] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "source": [ 943 | "Y_pred_test = model.predict([X1test,X2test,X3test])\n", 944 | "y_pred_test = np.argmax(Y_pred_test, axis=1)\n" 945 | ], 946 | "metadata": { 947 | "id": "vRzNtIMHpcsP" 948 | }, 949 | "execution_count": null, 950 | "outputs": [] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "source": [ 955 | "ca = np.sum(y_pred_test == np.argmax(ytest, axis=1)) / ytest.shape[0]\n", 956 | "\n", 957 | "print(\"Classification accuracy: %.5f\" % ca)" 958 | ], 959 | "metadata": { 960 | "id": "AYywssLypcvH" 961 | }, 962 | "execution_count": null, 963 | "outputs": [] 964 | }, 965 | { 966 | "cell_type": "code", 967 | "source": [ 968 | "classification = classification_report(np.argmax(ytest, axis=1), y_pred_test)\n", 969 | "print(classification)" 970 | ], 971 | "metadata": { 972 | "id": "tClCFVDQorgp" 973 | }, 974 | "execution_count": null, 975 | "outputs": [] 976 | }, 977 | { 978 | "cell_type": "code", 979 | "source": [ 980 | "" 981 | ], 982 | "metadata": { 983 | "id": "V17zCcRXplKt" 984 | }, 985 | "execution_count": null, 986 | "outputs": [] 987 | } 988 | ] 989 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiModelCNN 2 | 3 | **Here are the codes for the "Swin Transformer and Deep Convolutional Neural Networks for Coastal Wetland Classification using Sentinel-1, Sentinel-2, and LiDAR Data" paper. 4 | 5 | The paper is published in Remote Sensing journal. 6 | 7 | **Jamali, Ali, and Masoud Mahdianpari. 2022. "Swin Transformer and Deep Convolutional Neural Networks for Coastal Wetland Classification Using Sentinel-1, Sentinel-2, and LiDAR Data" Remote Sensing 14, no. 2: 359. 
https://doi.org/10.3390/rs14020359** 8 | 9 | ![Model](https://user-images.githubusercontent.com/22929034/174426194-ce4acb69-635d-41d3-8c37-ad85bed46645.png) 10 | ![Model2](https://user-images.githubusercontent.com/22929034/174426196-dfeb1ef0-f6a6-410b-be1d-4b4c00306821.png) 11 | 12 | 13 | There are three branches in the proposed multi-model deep CNN network: 14 | 1. A modified version of VGG-16 for Sentinel-2 data 15 | 2. A 3D CNN for Sentinel-1 data 16 | 3. The Swin Transformer for the DEM generated from LiDAR data 17 | 18 | 19 | **In our VGG-16 network, we modified the number of filters and kernel sizes to reduce the complexity of the original VGG-16 Deep CNN Network** 20 | 21 | 22 | 23 | 24 | --------------------------------------------------------------------------------
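
## Data layout

The notebook expects to run from a working directory (in the notebook, `/content/drive/MyDrive/MMCNN`) that contains a `Data/` folder with the four `.mat` files referenced by `loadData`. The data files themselves are not included in this repository:

```text
MMCNN/
└── Data/
    ├── BrunswickS1.mat    # Sentinel-1 bands, stored under the key 'BrunswickS1'
    ├── BrunswickS2.mat    # Sentinel-2 bands, stored under the key 'BrunswickS2'
    ├── BrunswickDEM.mat   # LiDAR-derived DEM, stored under the key 'BrunswickDEM'
    └── Brunswick_gt.mat   # ground-truth labels, stored under the key 'Brunswick_gt'
```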
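
## Usage sketch

The notebook is self-contained, but the end-to-end flow is easy to lose among the cells. Below is a minimal sketch of the whole pipeline, assuming the helper functions defined in `MultiModelCNN.ipynb` (`loadData`, `createImageCubes`, `splitTrainTestSet`, `get_cnn_model`) have already been run and the `Data/` folder above is in place. The `validation_split` argument is an addition of this sketch (the notebook trains without a validation set); it is only there so that `save_best_only=True` checkpointing has a `val_loss` to monitor.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras.utils import to_categorical

# Load the three modalities; all of them share the same ground-truth raster.
X_s2, y = loadData('BrunswickS2')    # Sentinel-2 optical bands
X_s1, _ = loadData('BrunswickS1')    # Sentinel-1 SAR bands
X_dem, _ = loadData('BrunswickDEM')  # LiDAR-derived DEM


def minmax(a):
    # Min-max normalize one modality to the [0, 1] range.
    return (a - np.min(a)) / (np.max(a) - np.min(a))


X_s2, X_s1, X_dem = minmax(X_s2), minmax(X_s1), minmax(X_dem)
X_dem = X_dem.reshape(X_dem.shape[0], X_dem.shape[1], 1)  # the DEM is a single band

# Extract a patch around every labeled pixel, sized to each branch's input.
X1, y1 = createImageCubes(X_s2, y, windowSize=4)   # -> (N, 4, 4, 12)
X2, y2 = createImageCubes(X_s1, y, windowSize=8)   # -> (N, 8, 8, 4)
X3, y3 = createImageCubes(X_dem, y, windowSize=8)  # -> (N, 8, 8, 1)
X2 = X2.reshape((-1, 8, 8, 4, 1))                  # the 3D CNN expects a depth axis

# Stratified 50/50 split; the fixed seed inside splitTrainTestSet keeps
# the rows of the three modalities aligned with each other.
X1tr, X1te, ytr, yte = splitTrainTestSet(X1, y1, 0.5)
X2tr, X2te, _, _ = splitTrainTestSet(X2, y2, 0.5)
X3tr, X3te, _, _ = splitTrainTestSet(X3, y3, 0.5)

model = get_cnn_model()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit([X1tr, X2tr, X3tr], to_categorical(ytr),
          batch_size=32, epochs=100,
          validation_split=0.1,  # assumption: gives ModelCheckpoint a val_loss to track
          callbacks=[keras.callbacks.ModelCheckpoint('VGG16.h5', save_best_only=True)])

y_pred = np.argmax(model.predict([X1te, X2te, X3te]), axis=1)
print('Test accuracy: %.4f' % (y_pred == yte).mean())
```

From here, `classification_report(yte, y_pred)` from scikit-learn gives the per-class scores, as in the notebook's final cells.

--------------------------------------------------------------------------------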