├── ResUNetFormer.jpg
├── Neighborhood_Attention_Makes_the_Encoder_of_ResUNet_Stronger_for_Accurate_Road_Extraction.pdf
├── README.md
├── LICENSE
└── ResU_NetFormer_Het.ipynb

/ResUNetFormer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aj1365/ResUNetFormer/HEAD/ResUNetFormer.jpg
--------------------------------------------------------------------------------
/Neighborhood_Attention_Makes_the_Encoder_of_ResUNet_Stronger_for_Accurate_Road_Extraction.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aj1365/ResUNetFormer/HEAD/Neighborhood_Attention_Makes_the_Encoder_of_ResUNet_Stronger_for_Accurate_Road_Extraction.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neighborhood Attention Makes the Encoder of ResUNet Stronger for Accurate Road Extraction
2 | 
3 | [Ali Jamali](https://www.researchgate.net/profile/Ali-Jamali), [Swalpa Kumar Roy](https://swalpa.github.io), [Jonathan Li](https://uwaterloo.ca/geography-environmental-management/people-profiles/jonathan-li), and [Pedram Ghamisi](https://www.iarai.ac.at/people/pedramghamisi/)
4 | 
5 | 
6 | 
7 | ___________
8 | 
9 | This Keras code is for the paper: A. Jamali, S. K. Roy, J. Li and P. Ghamisi, "[Neighborhood Attention Makes the Encoder of ResUNet Stronger for Accurate Road Extraction](https://ieeexplore.ieee.org/document/10400502)," in IEEE Geoscience and Remote Sensing Letters, doi: 10.1109/LGRS.2024.3354560.
10 | 
11 | 
12 | Citation
13 | ---------------------
14 | 
15 | **Please kindly cite the paper if this code is useful and helpful for your research.**
16 | 
17 |     @article{10400502,
18 |       title={Neighborhood Attention Makes the Encoder of ResUNet Stronger for Accurate Road Extraction},
19 |       author={Jamali, Ali and Roy, Swalpa Kumar and Li, Jonathan and Ghamisi, Pedram},
20 |       journal={IEEE Geoscience and Remote Sensing Letters},
21 |       year={2024},
22 |       volume={},
23 |       number={},
24 |       pages={1-5},
25 |       doi={10.1109/LGRS.2024.3354560}
26 |     }
27 | 
28 | 
29 | 
30 | Acknowledgement
31 | ---------------------
32 | 
33 | Part of the local window attention (LWA) block is implemented from [Neighborhood Attention Transformer](https://github.com/SHI-Labs/Neighborhood-Attention-Transformer).
34 | 
35 | ## License
36 | 
37 | Copyright (c) 2023 Ali Jamali. Released under the Apache License 2.0. See [LICENSE](LICENSE) for details.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 | 
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner. For the purposes of this definition, "submitted"
54 |       means any form of electronic, verbal, or written communication sent
55 |       to the Licensor or its representatives, including but not limited to
56 |       communication on electronic mailing lists, source code control systems,
57 |       and issue tracking systems that are managed by, or on behalf of, the
58 |       Licensor for the purpose of discussing and improving the Work, but
59 |       excluding communication that is conspicuously marked or otherwise
60 |       designated in writing by the copyright owner as "Not a Contribution."
61 | 
62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
63 |       on behalf of whom a Contribution has been received by Licensor and
64 |       subsequently incorporated within the Work.
65 | 
66 |    2. Grant of Copyright License. Subject to the terms and conditions of
67 |       this License, each Contributor hereby grants to You a perpetual,
68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 |       copyright license to reproduce, prepare Derivative Works of,
70 |       publicly display, publicly perform, sublicense, and distribute the
71 |       Work and such Derivative Works in Source or Object form.
72 | 
73 |    3. Grant of Patent License. Subject to the terms and conditions of
74 |       this License, each Contributor hereby grants to You a perpetual,
75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 |       (except as stated in this section) patent license to make, have made,
77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
78 |       where such license applies only to those patent claims licensable
79 |       by such Contributor that are necessarily infringed by their
80 |       Contribution(s) alone or by combination of their Contribution(s)
81 |       with the Work to which such Contribution(s) was submitted. If You
82 |       institute patent litigation against any entity (including a
83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
84 |       or a Contribution incorporated within the Work constitutes direct
85 |       or contributory patent infringement, then any patent licenses
86 |       granted to You under this License for that Work shall terminate
87 |       as of the date such litigation is filed.
88 | 
89 |    4. Redistribution. You may reproduce and distribute copies of the
90 |       Work or Derivative Works thereof in any medium, with or without
91 |       modifications, and in Source or Object form, provided that You
92 |       meet the following conditions:
93 | 
94 |       (a) You must give any other recipients of the Work or
95 |           Derivative Works a copy of this License; and
96 | 
97 |       (b) You must cause any modified files to carry prominent notices
98 |           stating that You changed the files; and
99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
--------------------------------------------------------------------------------
/ResU_NetFormer_Het.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "nbformat": 4,
3 |   "nbformat_minor": 0,
4 |   "metadata": {
5 |     "colab": {
6 |       "provenance": [],
7 |       "authorship_tag": "ABX9TyNipqv0BIUSaVZOOBCb/s17",
8 |       "include_colab_link": true
9 |     },
10 |     "kernelspec": {
11 |       "name": "python3",
12 |       "display_name": "Python 3"
13 |     },
14 |     "language_info": {
15 |       "name": "python"
16 |     }
17 |   },
18 |   "cells": [
19 |     {
20 |       "cell_type": "markdown",
21 |       "metadata": {
22 |         "id": "view-in-github",
23 |         "colab_type": "text"
24 |       },
25 |       "source": [
26 |         "Open In Colab"
27 |       ]
28 |     },
29 |     {
30 |       "cell_type": "code",
31 |       "execution_count": null,
32 |       "metadata": {
33 |         "id": "ChByVI4kU4QT"
34 |       },
35 |       "outputs": [],
36 |       "source": [
37 |         "import cv2  # for image I/O when writing predictions\n",
38 |         "from PIL import Image  # to open and resize images\n",
39 |         "import numpy as np\n",
40 |         "\n",
41 |         "# For packaging the dataset as HDF5 and listing files\n",
42 |         "import h5py\n",
43 |         "import os\n",
44 |         "from patchify import patchify  # imported but not used below"
45 |       ]
46 |     },
47 |     {
48 |       "cell_type": "code",
49 |       "source": [
50 |         "##### Creating input & mask arrays\n",
51 |         "\n",
52 |         "\n",
53 |         "images = []\n",
54 |         "originalImages = os.listdir(\"E:/MRD/tiff/train/\")\n",
55 |         "\n",
56 |         "for index,image in enumerate(originalImages):\n",
57 |         "    print(\"Image number: \" + str(index))\n",
58 |         "    img = Image.open(\"E:/MRD/tiff/train/\" + str(image))\n",
59 |         "    img = img.resize((384, 384))\n",
60 |         "    arr = np.array(img)\n",
61 |         "    #arr = np.expand_dims(arr, -1)\n",
62 |         "    images.append(arr)\n",
63 |         "\n",
64 |         "TrainX=images\n",
65 |         "TrainX = np.array(TrainX)\n",
66 |         "\n",
67 |         "\n",
68 |         "\n",
69 |         "images = []\n",
70 |         "originalImages = os.listdir(\"E:/MRD/tiff/train_labels/\")\n",
71 |         "\n",
72 |         "for index,image in enumerate(originalImages):\n",
73 |         "    print(\"Image number: \" + str(index))\n",
74 |         "    img = Image.open(\"E:/MRD/tiff/train_labels/\" + str(image))\n",
75 |         "    img = img.resize((384, 384))\n",
76 |         "    arr = np.array(img)\n",
77 |         "    #arr = np.expand_dims(arr, -1)\n",
78 |         "    images.append(arr)\n",
79 |         "\n",
80 |         "TrainY=images\n",
81 |         "TrainY = np.array(TrainY)\n"
82 |       ],
83 |       "metadata": {
84 |         "id": "1Lx8g5r_Vfrt"
85 |       },
86 |       "execution_count": null,
87 |       "outputs": []
88 |     },
89 |     {
90 |       "cell_type": "code",
91 |       "source": [
92 |         "images = []\n",
93 |         "originalImages = os.listdir(\"E:/MRD/tiff/test/\")\n",
94 |         "\n",
95 |         "for index,image in enumerate(originalImages):\n",
96 |         "    print(\"Image number: \" + str(index))\n",
97 |         "    img = Image.open(\"E:/MRD/tiff/test/\" + str(image))\n",
98 |         "    img = img.resize((384, 384))\n",
99 |         "    arr = np.array(img)\n",
100 |         "    #arr = np.expand_dims(arr, -1)\n",
101 |         "    images.append(arr)\n",
102 |         "\n",
103 |         "TestX=images\n",
104 |         "TestX = np.array(TestX)\n",
105 |         "TestX.shape\n",
106 |         "\n",
107 |         "images = []\n",
108 |         "originalImages = os.listdir(\"E:/MRD/tiff/test_labels/\")\n",
109 |         "\n",
110 |         "for index,image in enumerate(originalImages):\n",
111 | " print(\"Image number : \" +str(index) )\n", 112 | " img = Image.open(\"E:/MRD/tiff/test_labels/\" + str(image))\n", 113 | " img = img.resize((384, 384))\n", 114 | " arr = np.array(img)\n", 115 | " #arr = np.expand_dims(arr, -1)\n", 116 | " images.append(arr)\n", 117 | "\n", 118 | "TestY=images\n", 119 | "TestY = np.array(TestY)\n" 120 | ], 121 | "metadata": { 122 | "id": "jYVppjM8VmjO" 123 | }, 124 | "execution_count": null, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "source": [ 130 | "TrainX=TrainX[0:800,:,:,:]\n", 131 | "TrainY=TrainY[0:800,:,:]" 132 | ], 133 | "metadata": { 134 | "id": "SmOBHAwwV1nu" 135 | }, 136 | "execution_count": null, 137 | "outputs": [] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "source": [ 142 | "TrainY=TrainY.reshape(TrainY.shape[0],TrainY.shape[1],TrainY.shape[1],1)\n", 143 | "TestY=TestY.reshape(TestY.shape[0],TestY.shape[1],TestY.shape[1],1)\n", 144 | "\n", 145 | "TrainY.shape, TestY.shape" 146 | ], 147 | "metadata": { 148 | "id": "yadFVdEoV1q-" 149 | }, 150 | "execution_count": null, 151 | "outputs": [] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "source": [ 156 | "##############Finalizing Dataset for Training#######\n", 157 | "\n", 158 | "with h5py.File(\"E:/Dataset_train.h5\", 'w') as hdf:\n", 159 | " hdf.create_dataset('images', data=TrainX, compression='gzip', compression_opts=9)\n", 160 | " hdf.create_dataset('masks', data=TrainY, compression='gzip', compression_opts=9)" 161 | ], 162 | "metadata": { 163 | "id": "HblewMj8V1uG" 164 | }, 165 | "execution_count": null, 166 | "outputs": [] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "source": [ 171 | "from keras.models import *\n", 172 | "from keras.layers import *\n", 173 | "from keras.optimizers import *\n", 174 | "import keras\n", 175 | "import keras.callbacks\n", 176 | "from keras.callbacks import TensorBoard\n", 177 | "from keras.callbacks import ModelCheckpoint\n", 178 | "from keras import backend as keras\n", 179 | "import matplotlib.pyplot as plt\n", 180 | "from tensorflow.keras.optimizers import Adam\n", 181 | "import tensorflow as tf\n", 182 | "import tensorflow.keras.backend as K\n", 183 | "from typing import Callable\n", 184 | "from keras_cv_attention_models.attention_layers import (\n", 185 | " activation_by_name,\n", 186 | " ChannelAffine,\n", 187 | " conv2d_no_bias,\n", 188 | " depthwise_conv2d_no_bias,\n", 189 | " drop_block,\n", 190 | " #MixupToken,\n", 191 | " mlp_block,\n", 192 | " output_block,\n", 193 | " add_pre_post_process,\n", 194 | ")\n", 195 | "from keras_cv_attention_models.download_and_load import reload_model_weights\n", 196 | "from keras_cv_attention_models.attention_layers import (\n", 197 | " ChannelAffine,\n", 198 | " CompatibleExtractPatches,\n", 199 | " conv2d_no_bias,\n", 200 | " drop_block,\n", 201 | " layer_norm,\n", 202 | " mlp_block,\n", 203 | " output_block,\n", 204 | " add_pre_post_process,\n", 205 | ")\n", 206 | "from keras_cv_attention_models.download_and_load import reload_model_weights\n" 207 | ], 208 | "metadata": { 209 | "id": "fvUtATE7V1x3" 210 | }, 211 | "execution_count": null, 212 | "outputs": [] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "source": [ 217 | "# Metrics to be used when evaluating the network\n", 218 | "from tensorflow_addons.metrics import F1Score\n", 219 | "\n", 220 | "precision = tf.keras.metrics.Precision()\n", 221 | "recall = tf.keras.metrics.Recall()\n", 222 | "f1 = F1Score(num_classes=1, name='f1', average='micro', threshold=0.4)\n", 223 | "sgd_optimizer = Adam()" 224 | ], 225 
| "metadata": { 226 | "id": "bhcM9nqOWCQA" 227 | }, 228 | "execution_count": null, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "source": [ 234 | "import math\n", 235 | "import tensorflow_addons as tfa\n", 236 | "\n", 237 | "tfk = tf.keras\n", 238 | "tfkl = tfk.layers\n", 239 | "tfm = tf.math\n", 240 | "L2_WEIGHT_DECAY = 1e-4" 241 | ], 242 | "metadata": { 243 | "id": "M1DAhO0jWCS4" 244 | }, 245 | "execution_count": null, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "source": [ 251 | "class MultiHeadRelativePositionalKernelBias(tf.keras.layers.Layer):\n", 252 | " def __init__(self, input_height=-1, is_heads_first=False, **kwargs):\n", 253 | " super().__init__(**kwargs)\n", 254 | " self.input_height, self.is_heads_first = input_height, is_heads_first\n", 255 | "\n", 256 | " def build(self, input_shape):\n", 257 | " # input (is_heads_first=False): `[batch, height * width, num_heads, ..., size * size]`\n", 258 | " # input (is_heads_first=True): `[batch, num_heads, height * width, ..., size * size]`\n", 259 | " blocks, num_heads = (input_shape[2], input_shape[1]) if self.is_heads_first else (input_shape[1], input_shape[2])\n", 260 | " size = int(tf.math.sqrt(float(input_shape[-1])))\n", 261 | " height = self.input_height if self.input_height > 0 else int(tf.math.sqrt(float(blocks)))\n", 262 | " width = blocks // height\n", 263 | " pos_size = 2 * size - 1\n", 264 | " initializer = tf.initializers.truncated_normal(stddev=0.02)\n", 265 | " self.pos_bias = self.add_weight(name=\"positional_embedding\", shape=(num_heads, pos_size * pos_size), initializer=initializer, trainable=True)\n", 266 | "\n", 267 | " idx_hh, idx_ww = tf.range(0, size), tf.range(0, size)\n", 268 | " coords = tf.reshape(tf.expand_dims(idx_hh, -1) * pos_size + idx_ww, [-1])\n", 269 | " bias_hh = tf.concat([idx_hh[: size // 2], tf.repeat(idx_hh[size // 2], height - size + 1), idx_hh[size // 2 + 1 :]], axis=-1)\n", 270 | " bias_ww = tf.concat([idx_ww[: size // 2], tf.repeat(idx_ww[size // 2], width - size + 1), idx_ww[size // 2 + 1 :]], axis=-1)\n", 271 | " bias_hw = tf.expand_dims(bias_hh, -1) * pos_size + bias_ww\n", 272 | " bias_coords = tf.expand_dims(bias_hw, -1) + coords\n", 273 | " bias_coords = tf.reshape(bias_coords, [-1, size**2])[::-1] # torch.flip(bias_coords, [0])\n", 274 | "\n", 275 | " bias_coords_shape = [bias_coords.shape[0]] + [1] * (len(input_shape) - 4) + [bias_coords.shape[1]]\n", 276 | " self.bias_coords = tf.reshape(bias_coords, bias_coords_shape) # [height * width, 1 * n, size * size]\n", 277 | " if not self.is_heads_first:\n", 278 | " self.transpose_perm = [1, 0] + list(range(2, len(input_shape) - 1)) # transpose [num_heads, height * width] -> [height * width, num_heads]\n", 279 | "\n", 280 | " def call(self, inputs):\n", 281 | " if self.is_heads_first:\n", 282 | " return inputs + tf.gather(self.pos_bias, self.bias_coords, axis=-1)\n", 283 | " else:\n", 284 | " return inputs + tf.transpose(tf.gather(self.pos_bias, self.bias_coords, axis=-1), self.transpose_perm)\n", 285 | "\n", 286 | " def get_config(self):\n", 287 | " base_config = super().get_config()\n", 288 | " base_config.update({\"input_height\": self.input_height, \"is_heads_first\": self.is_heads_first})\n", 289 | " return base_config\n", 290 | "\n", 291 | "\n", 292 | "def LWA(\n", 293 | " inputs, kernel_size=7, num_heads=4, key_dim=0, out_weight=True, qkv_bias=True, out_bias=True, attn_dropout=0, output_dropout=0, name=None\n", 294 | "):\n", 295 | " _, hh, ww, cc = inputs.shape\n", 296 | " key_dim = 
key_dim if key_dim > 0 else cc // num_heads\n", 297 | " qk_scale = 1.0 / (float(key_dim) ** 0.5)\n", 298 | " out_shape = cc\n", 299 | " qkv_out = num_heads * key_dim\n", 300 | "\n", 301 | " should_pad_hh, should_pad_ww = max(0, kernel_size - hh), max(0, kernel_size - ww)\n", 302 | " if should_pad_hh or should_pad_ww:\n", 303 | " inputs = tf.pad(inputs, [[0, 0], [0, should_pad_hh], [0, should_pad_ww], [0, 0]])\n", 304 | " _, hh, ww, cc = inputs.shape\n", 305 | "\n", 306 | " qkv = keras.layers.Dense(qkv_out * 3, use_bias=qkv_bias, name=name and name + \"qkv\")(inputs)\n", 307 | " query, key_value = tf.split(qkv, [qkv_out, qkv_out * 2], axis=-1) # Matching weights from PyTorch\n", 308 | " query = tf.expand_dims(tf.reshape(query, [-1, hh * ww, num_heads, key_dim]), -2) # [batch, hh * ww, num_heads, 1, key_dim]\n", 309 | "\n", 310 | " # key_value: [batch, height // kernel_size, width // kernel_size, kernel_size, kernel_size, key + value]\n", 311 | " key_value = CompatibleExtractPatches(sizes=kernel_size, strides=1, padding=\"VALID\", compressed=False)(key_value)\n", 312 | " padded = (kernel_size - 1) // 2\n", 313 | " # torch.pad 'replicate'\n", 314 | " key_value = tf.concat([tf.repeat(key_value[:, :1], padded, axis=1), key_value, tf.repeat(key_value[:, -1:], padded, axis=1)], axis=1)\n", 315 | " key_value = tf.concat([tf.repeat(key_value[:, :, :1], padded, axis=2), key_value, tf.repeat(key_value[:, :, -1:], padded, axis=2)], axis=2)\n", 316 | "\n", 317 | " key_value = tf.reshape(key_value, [-1, kernel_size * kernel_size, key_value.shape[-1]])\n", 318 | " key, value = tf.split(key_value, 2, axis=-1) # [batch * block_height * block_width, kernel_size * kernel_size, key_dim]\n", 319 | " key = tf.transpose(tf.reshape(key, [-1, key.shape[1], num_heads, key_dim]), [0, 2, 3, 1]) # [batch * hh*ww, num_heads, key_dim, kernel_size * kernel_size]\n", 320 | " key = tf.reshape(key, [-1, hh * ww, num_heads, key_dim, kernel_size * kernel_size]) # [batch, hh*ww, num_heads, key_dim, kernel_size * kernel_size]\n", 321 | " value = tf.transpose(tf.reshape(value, [-1, value.shape[1], num_heads, key_dim]), [0, 2, 1, 3])\n", 322 | " value = tf.reshape(value, [-1, hh * ww, num_heads, kernel_size * kernel_size, key_dim]) # [batch, hh*ww, num_heads, kernel_size * kernel_size, key_dim]\n", 323 | " # print(f\">>>> {query.shape = }, {key.shape = }, {value.shape = }\")\n", 324 | "\n", 325 | " # [batch, hh * ww, num_heads, 1, kernel_size * kernel_size]\n", 326 | " attention_scores = keras.layers.Lambda(lambda xx: tf.matmul(xx[0], xx[1]))([query, key]) * qk_scale\n", 327 | " attention_scores = MultiHeadRelativePositionalKernelBias(input_height=hh, name=name and name + \"pos\")(attention_scores)\n", 328 | " attention_scores = keras.layers.Softmax(axis=-1, name=name and name + \"attention_scores\")(attention_scores)\n", 329 | " attention_scores = keras.layers.Dropout(attn_dropout, name=name and name + \"attn_drop\")(attention_scores) if attn_dropout > 0 else attention_scores\n", 330 | "\n", 331 | " # attention_output = [batch, block_height * block_width, num_heads, 1, key_dim]\n", 332 | " attention_output = keras.layers.Lambda(lambda xx: tf.matmul(xx[0], xx[1]))([attention_scores, value])\n", 333 | " attention_output = tf.reshape(attention_output, [-1, hh, ww, num_heads * key_dim])\n", 334 | " # print(f\">>>> {attention_output.shape = }, {attention_scores.shape = }\")\n", 335 | "\n", 336 | " if should_pad_hh or should_pad_ww:\n", 337 | " attention_output = attention_output[:, : hh - should_pad_hh, : ww - should_pad_ww, :]\n", 
338 | "\n", 339 | " if out_weight:\n", 340 | " # [batch, hh, ww, num_heads * key_dim] * [num_heads * key_dim, out] --> [batch, hh, ww, out]\n", 341 | " attention_output = keras.layers.Dense(out_shape, use_bias=out_bias, name=name and name + \"output\")(attention_output)\n", 342 | " attention_output = keras.layers.Dropout(output_dropout, name=name and name + \"out_drop\")(attention_output) if output_dropout > 0 else attention_output\n", 343 | " return attention_output" 344 | ], 345 | "metadata": { 346 | "id": "3gj4ajVzWCV7" 347 | }, 348 | "execution_count": null, 349 | "outputs": [] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "source": [ 354 | "\n", 355 | "################################## LIBRARIES ##################################\n", 356 | "\n", 357 | "from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, Conv2DTranspose, concatenate, Lambda, UpSampling2D\n", 358 | "from tensorflow.keras import Model, Input\n", 359 | "from contextlib import redirect_stdout\n", 360 | "\n", 361 | "\n", 362 | "############################# CONVOLUTIONAL BLOCK #############################\n", 363 | "\n", 364 | "def HetConv(feature_map, conv_filter, kernel_size , strides):\n", 365 | "\n", 366 | " # Groupwise Convolution\n", 367 | " x1=Conv2D(filters=conv_filter, kernel_size=(3,3), groups=3, strides=strides, padding='same')(feature_map)\n", 368 | "\n", 369 | " # Pointwise Convolution\n", 370 | " x2= Conv2D(filters=conv_filter, kernel_size=(1,1), strides=strides, padding='same')(feature_map)\n", 371 | "\n", 372 | "\n", 373 | " addition = Add()([x1, x2])\n", 374 | "\n", 375 | " return addition\n", 376 | "\n", 377 | "\n", 378 | "def conv_block(feature_map):\n", 379 | "\n", 380 | " # Main Path\n", 381 | " conv_1 = HetConv(feature_map, conv_filter=66, kernel_size=(3,3), strides=(1,1))\n", 382 | " bn = BatchNormalization()(conv_1)\n", 383 | " relu = Activation(activation='relu')(bn)\n", 384 | " conv_2 = HetConv(relu,conv_filter=66, kernel_size=(3,3), strides=(1,1))\n", 385 | "\n", 386 | " res_conn = HetConv(feature_map,conv_filter=66, kernel_size=(1,1), strides=(1,1))\n", 387 | " res_conn = BatchNormalization()(res_conn)\n", 388 | " addition = Add()([res_conn, conv_2])\n", 389 | "\n", 390 | " return addition\n", 391 | "\n", 392 | "\n", 393 | "############################### RESIDUAL BLOCK ################################\n", 394 | "\n", 395 | "def res_block(feature_map, conv_filter, stride):\n", 396 | "\n", 397 | " bn_1 = BatchNormalization()(feature_map)\n", 398 | " relu_1 = Activation(activation='relu')(bn_1)\n", 399 | " conv_1 = HetConv(relu_1, conv_filter, kernel_size=(3,3), strides=stride[0])\n", 400 | "\n", 401 | "\n", 402 | "\n", 403 | " bn_2 = BatchNormalization()(conv_1)\n", 404 | " relu_2 = Activation(activation='relu')(bn_2)\n", 405 | " conv_2 = HetConv(relu_2, conv_filter, kernel_size=(3,3), strides=stride[1])\n", 406 | "\n", 407 | "\n", 408 | " res_conn = HetConv(feature_map, conv_filter, kernel_size=(1,1), strides=stride[0])\n", 409 | " res_conn = BatchNormalization()(res_conn)\n", 410 | " addition = Add()([res_conn, conv_2])\n", 411 | "\n", 412 | " return addition\n", 413 | "\n", 414 | "################################### ENCODER ###################################\n", 415 | "\n", 416 | "def encoder(feature_map):\n", 417 | "\n", 418 | " # Initialize the to_decoder connection\n", 419 | " to_decoder = []\n", 420 | "\n", 421 | " # Block 1 - Convolution Block\n", 422 | " path = conv_block(feature_map)\n", 423 | " to_decoder.append(path)\n", 424 | "\n", 425 | " # Block 2 
- Residual Block 1\n", 426 | " path = res_block(path, 126, [(2, 2), (1, 1)])\n", 427 | " to_decoder.append(path)\n", 428 | "\n", 429 | " # Block 3 - Residual Block 2\n", 430 | " path = res_block(path, 252, [(2, 2), (1, 1)])\n", 431 | " to_decoder.append(path)\n", 432 | "\n", 433 | " return to_decoder\n", 434 | "\n", 435 | "################################### DECODER ###################################\n", 436 | "\n", 437 | "def decoder(feature_map, from_encoder):\n", 438 | "\n", 439 | " # Block 1: Up-sample, Concatenation + Residual Block 1\n", 440 | " main_path = UpSampling2D(size=(2,2), interpolation='bilinear')(feature_map)\n", 441 | " # main_path = Conv2DTranspose(filters=256, kernel_size=(2,2), strides=(2,2), padding='same')(feature_map)\n", 442 | " main_path = concatenate([main_path, from_encoder[2]], axis=3)\n", 443 | " main_path = res_block(main_path, 252, [(1, 1), (1, 1)])\n", 444 | "\n", 445 | " # Block 2: Up-sample, Concatenation + Residual Block 2\n", 446 | " main_path = UpSampling2D(size=(2,2), interpolation='bilinear')(main_path)\n", 447 | " # main_path = Conv2DTranspose(filters=128, kernel_size=(2,2), strides=(2,2), padding='same')(main_path)\n", 448 | " main_path = concatenate([main_path, from_encoder[1]], axis=3)\n", 449 | " main_path = res_block(main_path, 126, [(1, 1), (1, 1)])\n", 450 | "\n", 451 | " # Block 3: Up-sample, Concatenation + Residual Block 3\n", 452 | " main_path = UpSampling2D(size=(2,2), interpolation='bilinear')(main_path)\n", 453 | " # main_path = Conv2DTranspose(filters=64, kernel_size=(2,2), strides=(2,2), padding='same')(main_path)\n", 454 | " main_path = concatenate([main_path, from_encoder[0]], axis=3)\n", 455 | " main_path = res_block(main_path, 66, [(1, 1), (1, 1)])\n", 456 | "\n", 457 | " return main_path\n", 458 | "\n", 459 | "################################ RESIDUAL UNET ################################\n", 460 | "n=1\n", 461 | "attn_kernel_size=3\n", 462 | "num_heads=4\n", 463 | "attn_drop_rate=0.1\n", 464 | "hidden_size=256\n", 465 | "\n", 466 | "def ResLWAUNet():\n", 467 | "\n", 468 | " # Input\n", 469 | " x = Input(shape=(384, 384, 3))\n", 470 | " # model_input_float = Lambda(lambda x: x / 255)(model_input)\n", 471 | "\n", 472 | "\n", 473 | " # Encoder Path\n", 474 | " model_encoder = encoder(x)\n", 475 | " model_bottleneck = res_block(model_encoder[2], 510, [(2, 2), (1, 1)])\n", 476 | " # Transformer/Encoder\n", 477 | "\n", 478 | " y= LWA(model_bottleneck,\n", 479 | " attn_kernel_size,\n", 480 | " num_heads,\n", 481 | " attn_dropout=attn_drop_rate,\n", 482 | " name=f\"Transformer/encoderblock_{n}\")\n", 483 | "\n", 484 | "\n", 485 | " # Bottleneck\n", 486 | "\n", 487 | " # Decoder Path\n", 488 | " model_decoder = decoder(y, model_encoder)\n", 489 | "\n", 490 | " # Output\n", 491 | "\n", 492 | " output_layer = Conv2D(filters=1, kernel_size=(1, 1), strides=(1, 1), activation='sigmoid', padding='same')(model_decoder)\n", 493 | "\n", 494 | "\n", 495 | " model=Model(inputs=x, outputs=output_layer)\n", 496 | " model.compile(optimizer=sgd_optimizer, loss='binary_crossentropy', metrics=['accuracy', precision, recall, f1])\n", 497 | "\n", 498 | "\n", 499 | " return model" 500 | ], 501 | "metadata": { 502 | "id": "PS1mN4lFWCY0" 503 | }, 504 | "execution_count": null, 505 | "outputs": [] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "source": [ 510 | "model=ResLWAUNet()\n", 511 | "model.summary()" 512 | ], 513 | "metadata": { 514 | "id": "5xgFiHKmWCbf" 515 | }, 516 | "execution_count": null, 517 | "outputs": [] 518 | }, 519 | { 520 | 
"cell_type": "code", 521 | "source": [ 522 | "print('*'*30)\n", 523 | "print('Loading and preprocessing train data...')\n", 524 | "print('*'*30)\n", 525 | "file = h5py.File('E:/Dataset_train.h5', 'r')\n", 526 | "imgs_train = file.get('images')\n", 527 | "imgs_mask_train = file.get('masks')\n", 528 | "imgs_train = np.array(imgs_train)\n", 529 | "imgs_mask_train = np.array(imgs_mask_train)\n", 530 | "\n", 531 | "print(imgs_train.shape)\n", 532 | "print(imgs_mask_train.shape)\n", 533 | "\n", 534 | "\n", 535 | "imgs_train = imgs_train.astype('float32')\n", 536 | "\n", 537 | "mean = np.mean(imgs_train) # mean for data centering\n", 538 | "std = np.std(imgs_train) # std for data normalization\n", 539 | "\n", 540 | "imgs_train -= mean\n", 541 | "imgs_train /= std\n", 542 | "\n", 543 | "imgs_mask_train = imgs_mask_train.astype('float32')\n", 544 | "imgs_mask_train /= 255 # scale masks to [0, 1]\n", 545 | "\n", 546 | "print('*'*30)\n", 547 | "print('Creating and compiling model...')\n", 548 | "print('*'*30)\n", 549 | "model = ResLWAUNet()" 550 | ], 551 | "metadata": { 552 | "id": "NrpR3IwSWCen" 553 | }, 554 | "execution_count": null, 555 | "outputs": [] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "source": [ 560 | "weight_decay = 0.0001\n", 561 | "learning_rate=1e-4\n", 562 | "\n", 563 | "optimizer = tfa.optimizers.AdamW(\n", 564 | " learning_rate=learning_rate, weight_decay=weight_decay\n", 565 | " )\n", 566 | "\n", 567 | "\n", 568 | "checkpoint_filepath = \"E:/MRD100/ResUNetFormer.h5\"\n", 569 | "\n", 570 | "\n", 571 | "\n", 572 | "#with tf.device('/CPU:0'):\n", 573 | "history = model.fit(\n", 574 | " x=imgs_train,\n", 575 | " y=imgs_mask_train,\n", 576 | " batch_size=1,\n", 577 | " epochs=20,\n", 578 | " validation_split=0.1\n", 579 | " )\n" 580 | ], 581 | "metadata": { 582 | "id": "_GWfplfAWwyW" 583 | }, 584 | "execution_count": null, 585 | "outputs": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "source": [ 590 | "model.save('E:/MRD100/ResUNetFormer.h5')" 591 | ], 592 | "metadata": { 593 | "id": "aPRac4xBdCVt" 594 | }, 595 | "execution_count": null, 596 | "outputs": [] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "source": [ 601 | "###### Creating Test Dataset\n", 602 | "testImages=TestX\n", 603 | "\n", 604 | "testImages.shape\n", 605 | "\n", 606 | "with h5py.File(\"E:/Dataset_test.h5\", 'w') as hdf:\n", 607 | " hdf.create_dataset('images', data=testImages, compression='gzip', compression_opts=9)\n", 608 | "\n", 609 | "\n", 610 | "\n", 611 | "file = h5py.File('E:/Dataset_test.h5', 'r')\n", 612 | "imgs_test = file.get('images')\n", 613 | "#imgs_mask_test = file.get('masks')\n", 614 | "imgs_test = np.array(imgs_test)\n", 615 | "#imgs_mask_test = np.array(imgs_mask_test)\n", 616 | "imgs_test = imgs_test.astype('float32')\n", 617 | "imgs_test -= mean\n", 618 | "imgs_test /= std\n", 619 | "\n", 620 | "print('*'*30)\n", 621 | "print('Loading saved weights...')\n", 622 | "print('*'*30)\n", 623 | "model.load_weights('E:/MRD100/ResUNetFormer.h5')\n", 624 | "\n", 625 | "print('*'*30)\n", 626 | "print('Predicting masks on test data...')\n", 627 | "print('*'*30)\n", 628 | "imgs_mask_test = model.predict(imgs_test, verbose=1,batch_size=1)\n", 629 | "imgs_mask_test=(imgs_mask_test - np.min(imgs_mask_test))/(np.max(imgs_mask_test) - np.min(imgs_mask_test))\n", 630 | "imgs_mask_test = (imgs_mask_test * 255).astype(np.uint8)\n", 631 | "\n", 632 | "#imgs_mask_test = (imgs_mask_test * 255).astype(np.uint8)" 633 | ], 634 | "metadata": { 635 | "id": "P_eNBawtWw1R" 636 | }, 637 | 
"execution_count": null, 638 | "outputs": [] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "source": [ 643 | "acc = model.evaluate(imgs_test, TestY, batch_size=1)" 644 | ], 645 | "metadata": { 646 | "id": "63zO1i3wWw4W" 647 | }, 648 | "execution_count": null, 649 | "outputs": [] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "source": [ 654 | "######################### Write the predicted images\n", 655 | "\n", 656 | "print('*' * 30)\n", 657 | "print('Saving predicted masks to files...')\n", 658 | "print('*' * 30)\n", 659 | "pred_dir = 'E:/PredictionsResUNetFormer'\n", 660 | "\n", 661 | "if not os.path.exists(pred_dir):\n", 662 | " os.mkdir(pred_dir)\n", 663 | "for i, image in enumerate(imgs_mask_test):\n", 664 | " #image = (image * 255).astype(np.uint8)\n", 665 | "\n", 666 | " image=(image - np.min(image))/(np.max(image) - np.min(image))\n", 667 | " image = (image * 255).astype(np.uint8)\n", 668 | "\n", 669 | " cv2.imwrite(os.path.join(pred_dir, str(i + 1) + '_pred.png'), image)" 670 | ], 671 | "metadata": { 672 | "id": "JHVD8N9NXcY1" 673 | }, 674 | "execution_count": null, 675 | "outputs": [] 676 | } 677 | ] 678 | } --------------------------------------------------------------------------------