├── ResUNetFormer.jpg
├── Neighborhood_Attention_Makes_the_Encoder_of_ResUNet_Stronger_for_Accurate_Road_Extraction.pdf
├── README.md
├── LICENSE
└── ResU_NetFormer_Het.ipynb
/ResUNetFormer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aj1365/ResUNetFormer/HEAD/ResUNetFormer.jpg
--------------------------------------------------------------------------------
/Neighborhood_Attention_Makes_the_Encoder_of_ResUNet_Stronger_for_Accurate_Road_Extraction.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aj1365/ResUNetFormer/HEAD/Neighborhood_Attention_Makes_the_Encoder_of_ResUNet_Stronger_for_Accurate_Road_Extraction.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neighborhood Attention Makes the Encoder of ResUNet Stronger for Accurate Road Extraction
2 |
3 | [Ali Jamali](https://www.researchgate.net/profile/Ali-Jamali), [Swalpa Kumar Roy](https://swalpa.github.io), [Jonathan Li](https://uwaterloo.ca/geography-environmental-management/people-profiles/jonathan-li), and [Pedram Ghamisi](https://www.iarai.ac.at/people/pedramghamisi/)
4 |
5 |
6 |
7 | ___________
8 |
 9 | This Keras code is for the paper: A. Jamali, S. K. Roy, J. Li and P. Ghamisi, "[Neighborhood Attention Makes the Encoder of ResUNet Stronger for Accurate Road Extraction](https://ieeexplore.ieee.org/document/10400502)," in IEEE Geoscience and Remote Sensing Letters, doi: 10.1109/LGRS.2024.3354560.
10 |
11 |
12 | Citation
13 | ---------------------
14 |
15 | **Please kindly cite the paper if this code is useful and helpful for your research.**
16 |
17 | ```bibtex
18 | @article{10400502,
19 |   title={Neighborhood Attention Makes the Encoder of ResUNet Stronger for Accurate Road Extraction},
20 |   author={Jamali, Ali and Roy, Swalpa Kumar and Li, Jonathan and Ghamisi, Pedram},
21 |   journal={IEEE Geoscience and Remote Sensing Letters},
22 |   year={2024},
23 |   volume={},
24 |   number={},
25 |   pages={1-5},
26 |   doi={10.1109/LGRS.2024.3354560}
27 | }
28 | ```
27 |
28 |
29 |
30 | Acknowledgement
31 | ---------------------
32 |
33 | Part of the local window attention (LWA) block is adapted from [Neighborhood Attention Transformer](https://github.com/SHI-Labs/Neighborhood-Attention-Transformer).
34 |
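35 | Usage
36 | ---------------------
37 | 
38 | The network and the LWA block are defined in [ResU_NetFormer_Het.ipynb](ResU_NetFormer_Het.ipynb). The snippet below is a minimal sketch of how they are invoked in the notebook; `features`, `train_images`, and `train_masks` are hypothetical placeholders for your own tensors and arrays:
39 | 
40 | ```python
41 | # Sketch only: assumes the LWA and ResLWAUNet definitions from the notebook are in scope.
42 | attended = LWA(features, kernel_size=3, num_heads=4, attn_dropout=0.1)  # local window attention
43 | 
44 | model = ResLWAUNet()  # 384x384x3 input, 1-channel sigmoid road mask
45 | model.fit(train_images, train_masks, batch_size=1, epochs=20, validation_split=0.1)
46 | ```
47 | 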
35 | ## License
36 |
37 | Copyright (c) 2023 Ali Jamali. Released under the Apache License 2.0. See [LICENSE](LICENSE) for details.
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/ResU_NetFormer_Het.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyNipqv0BIUSaVZOOBCb/s17",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
 25 |       "source": []
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {
33 | "id": "ChByVI4kU4QT"
34 | },
35 | "outputs": [],
36 | "source": [
 37 | "import cv2  # OpenCV, used later to write the predicted masks\n",
 38 | "from PIL import Image  # to load and resize images\n",
 39 | "import numpy as np\n",
 40 | "\n",
 41 | "import h5py  # HDF5 storage for the train/test arrays\n",
 42 | "import os\n",
 43 | "from patchify import patchify  # patch extraction (not used below)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "source": [
50 | "##### Creating input & mask arrays\n",
51 | "\n",
52 | "\n",
53 | "images = []\n",
 54 | "originalImages = sorted(os.listdir(\"E:/MRD/tiff/train/\"))  # sorted so images pair with their labels\n",
 55 | "\n",
 56 | "for index, image in enumerate(originalImages):\n",
 57 | "    print(\"Image number: \" + str(index))\n",
 58 | "    img = Image.open(\"E:/MRD/tiff/train/\" + str(image))\n",
 59 | "    img = img.resize((384, 384))\n",
 60 | "    arr = np.array(img)\n",
 61 | "    #arr = np.expand_dims(arr, -1)\n",
 62 | "    images.append(arr)\n",
63 | "\n",
 64 | "TrainX = np.array(images)\n",
66 | "\n",
67 | "\n",
68 | "\n",
69 | "images = []\n",
 70 | "originalImages = sorted(os.listdir(\"E:/MRD/tiff/train_labels/\"))  # sorted to match the image order\n",
 71 | "\n",
 72 | "for index, image in enumerate(originalImages):\n",
 73 | "    print(\"Image number: \" + str(index))\n",
 74 | "    img = Image.open(\"E:/MRD/tiff/train_labels/\" + str(image))\n",
 75 | "    img = img.resize((384, 384))\n",
 76 | "    arr = np.array(img)\n",
 77 | "    #arr = np.expand_dims(arr, -1)\n",
 78 | "    images.append(arr)\n",
79 | "\n",
 80 | "TrainY = np.array(images)\n"
82 | ],
83 | "metadata": {
84 | "id": "1Lx8g5r_Vfrt"
85 | },
86 | "execution_count": null,
87 | "outputs": []
88 | },
89 | {
90 | "cell_type": "code",
91 | "source": [
92 | "images = []\n",
 93 | "originalImages = sorted(os.listdir(\"E:/MRD/tiff/test/\"))  # sorted so images pair with their labels\n",
 94 | "\n",
 95 | "for index, image in enumerate(originalImages):\n",
 96 | "    print(\"Image number: \" + str(index))\n",
 97 | "    img = Image.open(\"E:/MRD/tiff/test/\" + str(image))\n",
 98 | "    img = img.resize((384, 384))\n",
 99 | "    arr = np.array(img)\n",
100 | "    #arr = np.expand_dims(arr, -1)\n",
101 | "    images.append(arr)\n",
102 | "\n",
103 | "TestX = np.array(images)\n",
104 | "print(TestX.shape)\n",
106 | "\n",
107 | "images = []\n",
108 | "originalImages = sorted(os.listdir(\"E:/MRD/tiff/test_labels/\"))  # sorted to match the image order\n",
109 | "\n",
110 | "for index, image in enumerate(originalImages):\n",
111 | "    print(\"Image number: \" + str(index))\n",
112 | "    img = Image.open(\"E:/MRD/tiff/test_labels/\" + str(image))\n",
113 | "    img = img.resize((384, 384))\n",
114 | "    arr = np.array(img)\n",
115 | "    #arr = np.expand_dims(arr, -1)\n",
116 | "    images.append(arr)\n",
117 | "\n",
118 | "TestY = np.array(images)\n"
120 | ],
121 | "metadata": {
122 | "id": "jYVppjM8VmjO"
123 | },
124 | "execution_count": null,
125 | "outputs": []
126 | },
127 | {
128 | "cell_type": "code",
129 | "source": [
130 | "TrainX = TrainX[0:800, :, :, :]  # keep the first 800 image/mask pairs\n",
131 | "TrainY = TrainY[0:800, :, :]"
132 | ],
133 | "metadata": {
134 | "id": "SmOBHAwwV1nu"
135 | },
136 | "execution_count": null,
137 | "outputs": []
138 | },
139 | {
140 | "cell_type": "code",
141 | "source": [
142 | "TrainY = TrainY.reshape(TrainY.shape[0], TrainY.shape[1], TrainY.shape[2], 1)  # add a channel axis\n",
143 | "TestY = TestY.reshape(TestY.shape[0], TestY.shape[1], TestY.shape[2], 1)\n",
144 | "\n",
145 | "TrainY.shape, TestY.shape"
146 | ],
147 | "metadata": {
148 | "id": "yadFVdEoV1q-"
149 | },
150 | "execution_count": null,
151 | "outputs": []
152 | },
153 | {
154 | "cell_type": "code",
155 | "source": [
156 | "##############Finalizing Dataset for Training#######\n",
157 | "\n",
158 | "with h5py.File(\"E:/Dataset_train.h5\", 'w') as hdf:\n",
159 | " hdf.create_dataset('images', data=TrainX, compression='gzip', compression_opts=9)\n",
160 | " hdf.create_dataset('masks', data=TrainY, compression='gzip', compression_opts=9)"
161 | ],
162 | "metadata": {
163 | "id": "HblewMj8V1uG"
164 | },
165 | "execution_count": null,
166 | "outputs": []
167 | },
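168 |     {
169 |       "cell_type": "code",
170 |       "source": [
171 |         "# Sanity check (a minimal sketch): re-open the HDF5 file just written\n",
172 |         "# and confirm the stored shapes match TrainX and TrainY.\n",
173 |         "with h5py.File(\"E:/Dataset_train.h5\", 'r') as hdf:\n",
174 |         "    print(hdf['images'].shape, hdf['masks'].shape)"
175 |       ],
176 |       "metadata": {},
177 |       "execution_count": null,
178 |       "outputs": []
179 |     },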
168 | {
169 | "cell_type": "code",
170 | "source": [
171 | "from keras.models import *\n",
172 | "from keras.layers import *\n",
173 | "from keras.optimizers import *\n",
174 | "import keras\n",
175 | "import keras.callbacks\n",
176 | "from keras.callbacks import TensorBoard\n",
177 | "from keras.callbacks import ModelCheckpoint\n",
179 | "import matplotlib.pyplot as plt\n",
180 | "from tensorflow.keras.optimizers import Adam\n",
181 | "import tensorflow as tf\n",
182 | "import tensorflow.keras.backend as K\n",
183 | "from typing import Callable\n",
184 | "from keras_cv_attention_models.attention_layers import (\n",
185 | "    activation_by_name,\n",
186 | "    ChannelAffine,\n",
187 | "    CompatibleExtractPatches,\n",
188 | "    conv2d_no_bias,\n",
189 | "    depthwise_conv2d_no_bias,\n",
190 | "    drop_block,\n",
191 | "    layer_norm,\n",
192 | "    mlp_block,\n",
193 | "    output_block,\n",
194 | "    add_pre_post_process,\n",
195 | ")\n",
196 | "from keras_cv_attention_models.download_and_load import reload_model_weights\n"
207 | ],
208 | "metadata": {
209 | "id": "fvUtATE7V1x3"
210 | },
211 | "execution_count": null,
212 | "outputs": []
213 | },
214 | {
215 | "cell_type": "code",
216 | "source": [
217 | "# Metrics to be used when evaluating the network\n",
218 | "from tensorflow_addons.metrics import F1Score\n",
219 | "\n",
220 | "precision = tf.keras.metrics.Precision()\n",
221 | "recall = tf.keras.metrics.Recall()\n",
222 | "f1 = F1Score(num_classes=1, name='f1', average='micro', threshold=0.4)\n",
223 | "adam_optimizer = Adam()  # used to compile the model below"
224 | ],
225 | "metadata": {
226 | "id": "bhcM9nqOWCQA"
227 | },
228 | "execution_count": null,
229 | "outputs": []
230 | },
231 | {
232 | "cell_type": "code",
233 | "source": [
234 | "import math\n",
235 | "import tensorflow_addons as tfa\n",
236 | "\n",
237 | "tfk = tf.keras\n",
238 | "tfkl = tfk.layers\n",
239 | "tfm = tf.math\n",
240 | "L2_WEIGHT_DECAY = 1e-4"
241 | ],
242 | "metadata": {
243 | "id": "M1DAhO0jWCS4"
244 | },
245 | "execution_count": null,
246 | "outputs": []
247 | },
248 | {
249 | "cell_type": "code",
250 | "source": [
251 | "class MultiHeadRelativePositionalKernelBias(tf.keras.layers.Layer):\n",
252 | " def __init__(self, input_height=-1, is_heads_first=False, **kwargs):\n",
253 | " super().__init__(**kwargs)\n",
254 | " self.input_height, self.is_heads_first = input_height, is_heads_first\n",
255 | "\n",
256 | " def build(self, input_shape):\n",
257 | " # input (is_heads_first=False): `[batch, height * width, num_heads, ..., size * size]`\n",
258 | " # input (is_heads_first=True): `[batch, num_heads, height * width, ..., size * size]`\n",
259 | " blocks, num_heads = (input_shape[2], input_shape[1]) if self.is_heads_first else (input_shape[1], input_shape[2])\n",
260 | " size = int(tf.math.sqrt(float(input_shape[-1])))\n",
261 | " height = self.input_height if self.input_height > 0 else int(tf.math.sqrt(float(blocks)))\n",
262 | " width = blocks // height\n",
263 | " pos_size = 2 * size - 1\n",
264 | " initializer = tf.initializers.truncated_normal(stddev=0.02)\n",
265 | " self.pos_bias = self.add_weight(name=\"positional_embedding\", shape=(num_heads, pos_size * pos_size), initializer=initializer, trainable=True)\n",
266 | "\n",
267 | " idx_hh, idx_ww = tf.range(0, size), tf.range(0, size)\n",
268 | " coords = tf.reshape(tf.expand_dims(idx_hh, -1) * pos_size + idx_ww, [-1])\n",
269 | " bias_hh = tf.concat([idx_hh[: size // 2], tf.repeat(idx_hh[size // 2], height - size + 1), idx_hh[size // 2 + 1 :]], axis=-1)\n",
270 | " bias_ww = tf.concat([idx_ww[: size // 2], tf.repeat(idx_ww[size // 2], width - size + 1), idx_ww[size // 2 + 1 :]], axis=-1)\n",
271 | " bias_hw = tf.expand_dims(bias_hh, -1) * pos_size + bias_ww\n",
272 | " bias_coords = tf.expand_dims(bias_hw, -1) + coords\n",
273 | " bias_coords = tf.reshape(bias_coords, [-1, size**2])[::-1] # torch.flip(bias_coords, [0])\n",
274 | "\n",
275 | " bias_coords_shape = [bias_coords.shape[0]] + [1] * (len(input_shape) - 4) + [bias_coords.shape[1]]\n",
276 | " self.bias_coords = tf.reshape(bias_coords, bias_coords_shape) # [height * width, 1 * n, size * size]\n",
277 | " if not self.is_heads_first:\n",
278 | " self.transpose_perm = [1, 0] + list(range(2, len(input_shape) - 1)) # transpose [num_heads, height * width] -> [height * width, num_heads]\n",
279 | "\n",
280 | " def call(self, inputs):\n",
281 | " if self.is_heads_first:\n",
282 | " return inputs + tf.gather(self.pos_bias, self.bias_coords, axis=-1)\n",
283 | " else:\n",
284 | " return inputs + tf.transpose(tf.gather(self.pos_bias, self.bias_coords, axis=-1), self.transpose_perm)\n",
285 | "\n",
286 | " def get_config(self):\n",
287 | " base_config = super().get_config()\n",
288 | " base_config.update({\"input_height\": self.input_height, \"is_heads_first\": self.is_heads_first})\n",
289 | " return base_config\n",
290 | "\n",
291 | "\n",
292 | "def LWA(\n",
293 | " inputs, kernel_size=7, num_heads=4, key_dim=0, out_weight=True, qkv_bias=True, out_bias=True, attn_dropout=0, output_dropout=0, name=None\n",
294 | "):\n",
295 | " _, hh, ww, cc = inputs.shape\n",
296 | " key_dim = key_dim if key_dim > 0 else cc // num_heads\n",
297 | " qk_scale = 1.0 / (float(key_dim) ** 0.5)\n",
298 | " out_shape = cc\n",
299 | " qkv_out = num_heads * key_dim\n",
300 | "\n",
301 | " should_pad_hh, should_pad_ww = max(0, kernel_size - hh), max(0, kernel_size - ww)\n",
302 | " if should_pad_hh or should_pad_ww:\n",
303 | " inputs = tf.pad(inputs, [[0, 0], [0, should_pad_hh], [0, should_pad_ww], [0, 0]])\n",
304 | " _, hh, ww, cc = inputs.shape\n",
305 | "\n",
306 | " qkv = keras.layers.Dense(qkv_out * 3, use_bias=qkv_bias, name=name and name + \"qkv\")(inputs)\n",
307 | " query, key_value = tf.split(qkv, [qkv_out, qkv_out * 2], axis=-1) # Matching weights from PyTorch\n",
308 | " query = tf.expand_dims(tf.reshape(query, [-1, hh * ww, num_heads, key_dim]), -2) # [batch, hh * ww, num_heads, 1, key_dim]\n",
309 | "\n",
310 | " # key_value: [batch, height // kernel_size, width // kernel_size, kernel_size, kernel_size, key + value]\n",
311 | " key_value = CompatibleExtractPatches(sizes=kernel_size, strides=1, padding=\"VALID\", compressed=False)(key_value)\n",
312 | " padded = (kernel_size - 1) // 2\n",
313 | "    # Edge padding, matching torch.nn.functional.pad with mode='replicate'\n",
314 | " key_value = tf.concat([tf.repeat(key_value[:, :1], padded, axis=1), key_value, tf.repeat(key_value[:, -1:], padded, axis=1)], axis=1)\n",
315 | " key_value = tf.concat([tf.repeat(key_value[:, :, :1], padded, axis=2), key_value, tf.repeat(key_value[:, :, -1:], padded, axis=2)], axis=2)\n",
316 | "\n",
317 | " key_value = tf.reshape(key_value, [-1, kernel_size * kernel_size, key_value.shape[-1]])\n",
318 | " key, value = tf.split(key_value, 2, axis=-1) # [batch * block_height * block_width, kernel_size * kernel_size, key_dim]\n",
319 | " key = tf.transpose(tf.reshape(key, [-1, key.shape[1], num_heads, key_dim]), [0, 2, 3, 1]) # [batch * hh*ww, num_heads, key_dim, kernel_size * kernel_size]\n",
320 | " key = tf.reshape(key, [-1, hh * ww, num_heads, key_dim, kernel_size * kernel_size]) # [batch, hh*ww, num_heads, key_dim, kernel_size * kernel_size]\n",
321 | " value = tf.transpose(tf.reshape(value, [-1, value.shape[1], num_heads, key_dim]), [0, 2, 1, 3])\n",
322 | " value = tf.reshape(value, [-1, hh * ww, num_heads, kernel_size * kernel_size, key_dim]) # [batch, hh*ww, num_heads, kernel_size * kernel_size, key_dim]\n",
323 | " # print(f\">>>> {query.shape = }, {key.shape = }, {value.shape = }\")\n",
324 | "\n",
325 | " # [batch, hh * ww, num_heads, 1, kernel_size * kernel_size]\n",
326 | " attention_scores = keras.layers.Lambda(lambda xx: tf.matmul(xx[0], xx[1]))([query, key]) * qk_scale\n",
327 | " attention_scores = MultiHeadRelativePositionalKernelBias(input_height=hh, name=name and name + \"pos\")(attention_scores)\n",
328 | " attention_scores = keras.layers.Softmax(axis=-1, name=name and name + \"attention_scores\")(attention_scores)\n",
329 | " attention_scores = keras.layers.Dropout(attn_dropout, name=name and name + \"attn_drop\")(attention_scores) if attn_dropout > 0 else attention_scores\n",
330 | "\n",
331 | " # attention_output = [batch, block_height * block_width, num_heads, 1, key_dim]\n",
332 | " attention_output = keras.layers.Lambda(lambda xx: tf.matmul(xx[0], xx[1]))([attention_scores, value])\n",
333 | " attention_output = tf.reshape(attention_output, [-1, hh, ww, num_heads * key_dim])\n",
334 | " # print(f\">>>> {attention_output.shape = }, {attention_scores.shape = }\")\n",
335 | "\n",
336 | " if should_pad_hh or should_pad_ww:\n",
337 | " attention_output = attention_output[:, : hh - should_pad_hh, : ww - should_pad_ww, :]\n",
338 | "\n",
339 | " if out_weight:\n",
340 | " # [batch, hh, ww, num_heads * key_dim] * [num_heads * key_dim, out] --> [batch, hh, ww, out]\n",
341 | " attention_output = keras.layers.Dense(out_shape, use_bias=out_bias, name=name and name + \"output\")(attention_output)\n",
342 | " attention_output = keras.layers.Dropout(output_dropout, name=name and name + \"out_drop\")(attention_output) if output_dropout > 0 else attention_output\n",
343 | " return attention_output"
344 | ],
345 | "metadata": {
346 | "id": "3gj4ajVzWCV7"
347 | },
348 | "execution_count": null,
349 | "outputs": []
350 | },
351 | {
352 | "cell_type": "code",
353 | "source": [
354 | "\n",
355 | "################################## LIBRARIES ##################################\n",
356 | "\n",
357 | "from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, Conv2DTranspose, concatenate, Lambda, UpSampling2D\n",
358 | "from tensorflow.keras import Model, Input\n",
359 | "from contextlib import redirect_stdout\n",
360 | "\n",
361 | "\n",
362 | "############################# CONVOLUTIONAL BLOCK #############################\n",
363 | "\n",
364 | "def HetConv(feature_map, conv_filter, kernel_size, strides):\n",
365 | "    # Heterogeneous convolution: a 3x3 group convolution plus a 1x1 pointwise\n",
366 | "    # convolution, summed. The kernel_size argument is kept for call-site\n",
367 | "    # symmetry; the kernel sizes themselves are fixed at 3x3 and 1x1.\n",
368 | "\n",
369 | "    # Groupwise convolution\n",
370 | "    x1 = Conv2D(filters=conv_filter, kernel_size=(3, 3), groups=3, strides=strides, padding='same')(feature_map)\n",
371 | "\n",
372 | "    # Pointwise convolution\n",
373 | "    x2 = Conv2D(filters=conv_filter, kernel_size=(1, 1), strides=strides, padding='same')(feature_map)\n",
374 | "\n",
375 | "    addition = Add()([x1, x2])\n",
376 | "\n",
377 | "    return addition\n",
376 | "\n",
377 | "\n",
378 | "def conv_block(feature_map):\n",
379 | "\n",
380 | " # Main Path\n",
381 | " conv_1 = HetConv(feature_map, conv_filter=66, kernel_size=(3,3), strides=(1,1))\n",
382 | " bn = BatchNormalization()(conv_1)\n",
383 | " relu = Activation(activation='relu')(bn)\n",
384 | " conv_2 = HetConv(relu,conv_filter=66, kernel_size=(3,3), strides=(1,1))\n",
385 | "\n",
386 | " res_conn = HetConv(feature_map,conv_filter=66, kernel_size=(1,1), strides=(1,1))\n",
387 | " res_conn = BatchNormalization()(res_conn)\n",
388 | " addition = Add()([res_conn, conv_2])\n",
389 | "\n",
390 | " return addition\n",
391 | "\n",
392 | "\n",
393 | "############################### RESIDUAL BLOCK ################################\n",
394 | "\n",
395 | "def res_block(feature_map, conv_filter, stride):\n",
396 | "\n",
397 | " bn_1 = BatchNormalization()(feature_map)\n",
398 | " relu_1 = Activation(activation='relu')(bn_1)\n",
399 | " conv_1 = HetConv(relu_1, conv_filter, kernel_size=(3,3), strides=stride[0])\n",
400 | "\n",
401 | "\n",
402 | "\n",
403 | " bn_2 = BatchNormalization()(conv_1)\n",
404 | " relu_2 = Activation(activation='relu')(bn_2)\n",
405 | " conv_2 = HetConv(relu_2, conv_filter, kernel_size=(3,3), strides=stride[1])\n",
406 | "\n",
407 | "\n",
408 | " res_conn = HetConv(feature_map, conv_filter, kernel_size=(1,1), strides=stride[0])\n",
409 | " res_conn = BatchNormalization()(res_conn)\n",
410 | " addition = Add()([res_conn, conv_2])\n",
411 | "\n",
412 | " return addition\n",
413 | "\n",
414 | "################################### ENCODER ###################################\n",
415 | "\n",
416 | "def encoder(feature_map):\n",
417 | "\n",
418 | " # Initialize the to_decoder connection\n",
419 | " to_decoder = []\n",
420 | "\n",
421 | " # Block 1 - Convolution Block\n",
422 | " path = conv_block(feature_map)\n",
423 | " to_decoder.append(path)\n",
424 | "\n",
425 | " # Block 2 - Residual Block 1\n",
426 | " path = res_block(path, 126, [(2, 2), (1, 1)])\n",
427 | " to_decoder.append(path)\n",
428 | "\n",
429 | " # Block 3 - Residual Block 2\n",
430 | " path = res_block(path, 252, [(2, 2), (1, 1)])\n",
431 | " to_decoder.append(path)\n",
432 | "\n",
433 | " return to_decoder\n",
434 | "\n",
435 | "################################### DECODER ###################################\n",
436 | "\n",
437 | "def decoder(feature_map, from_encoder):\n",
438 | "\n",
439 | " # Block 1: Up-sample, Concatenation + Residual Block 1\n",
440 | " main_path = UpSampling2D(size=(2,2), interpolation='bilinear')(feature_map)\n",
441 | " # main_path = Conv2DTranspose(filters=256, kernel_size=(2,2), strides=(2,2), padding='same')(feature_map)\n",
442 | " main_path = concatenate([main_path, from_encoder[2]], axis=3)\n",
443 | " main_path = res_block(main_path, 252, [(1, 1), (1, 1)])\n",
444 | "\n",
445 | " # Block 2: Up-sample, Concatenation + Residual Block 2\n",
446 | " main_path = UpSampling2D(size=(2,2), interpolation='bilinear')(main_path)\n",
447 | " # main_path = Conv2DTranspose(filters=128, kernel_size=(2,2), strides=(2,2), padding='same')(main_path)\n",
448 | " main_path = concatenate([main_path, from_encoder[1]], axis=3)\n",
449 | " main_path = res_block(main_path, 126, [(1, 1), (1, 1)])\n",
450 | "\n",
451 | " # Block 3: Up-sample, Concatenation + Residual Block 3\n",
452 | " main_path = UpSampling2D(size=(2,2), interpolation='bilinear')(main_path)\n",
453 | " # main_path = Conv2DTranspose(filters=64, kernel_size=(2,2), strides=(2,2), padding='same')(main_path)\n",
454 | " main_path = concatenate([main_path, from_encoder[0]], axis=3)\n",
455 | " main_path = res_block(main_path, 66, [(1, 1), (1, 1)])\n",
456 | "\n",
457 | " return main_path\n",
458 | "\n",
459 | "################################ RESIDUAL UNET ################################\n",
460 | "n = 1                 # index used in the attention block name\n",
461 | "attn_kernel_size = 3  # local attention window size\n",
462 | "num_heads = 4\n",
463 | "attn_drop_rate = 0.1\n",
464 | "hidden_size = 256     # not used below\n",
465 | "\n",
466 | "def ResLWAUNet():\n",
467 | "\n",
468 | " # Input\n",
469 | " x = Input(shape=(384, 384, 3))\n",
470 | " # model_input_float = Lambda(lambda x: x / 255)(model_input)\n",
471 | "\n",
472 | "\n",
473 | " # Encoder Path\n",
474 | "    model_encoder = encoder(x)\n",
475 | "\n",
476 | "    # Bottleneck\n",
477 | "    model_bottleneck = res_block(model_encoder[2], 510, [(2, 2), (1, 1)])\n",
478 | "\n",
479 | "    # Local window attention (Transformer encoder block) on the bottleneck\n",
480 | "    y = LWA(model_bottleneck,\n",
481 | "            attn_kernel_size,\n",
482 | "            num_heads,\n",
483 | "            attn_dropout=attn_drop_rate,\n",
484 | "            name=f\"Transformer/encoderblock_{n}\")\n",
485 | "\n",
486 | "    # Decoder Path\n",
488 | " model_decoder = decoder(y, model_encoder)\n",
489 | "\n",
490 | " # Output\n",
491 | "\n",
492 | " output_layer = Conv2D(filters=1, kernel_size=(1, 1), strides=(1, 1), activation='sigmoid', padding='same')(model_decoder)\n",
493 | "\n",
494 | "\n",
495 | "    model = Model(inputs=x, outputs=output_layer)\n",
496 | "    model.compile(optimizer=adam_optimizer, loss='binary_crossentropy', metrics=['accuracy', precision, recall, f1])\n",
497 | "\n",
498 | "\n",
499 | " return model"
500 | ],
501 | "metadata": {
502 | "id": "PS1mN4lFWCY0"
503 | },
504 | "execution_count": null,
505 | "outputs": []
506 | },
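507 |     {
508 |       "cell_type": "code",
509 |       "source": [
510 |         "# Shape check for the LWA block: a minimal sketch. The 48x48x64 input is an\n",
511 |         "# assumed stand-in for the 48x48 bottleneck produced from 384x384 inputs;\n",
512 |         "# LWA should preserve both the spatial size and the channel count.\n",
513 |         "demo_in = Input(shape=(48, 48, 64))\n",
514 |         "demo_out = LWA(demo_in, kernel_size=3, num_heads=4)\n",
515 |         "print(demo_out.shape)  # expected: (None, 48, 48, 64)\n"
516 |       ],
517 |       "metadata": {},
518 |       "execution_count": null,
519 |       "outputs": []
520 |     },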
507 | {
508 | "cell_type": "code",
509 | "source": [
510 | "model=ResLWAUNet()\n",
511 | "model.summary()"
512 | ],
513 | "metadata": {
514 | "id": "5xgFiHKmWCbf"
515 | },
516 | "execution_count": null,
517 | "outputs": []
518 | },
519 | {
520 | "cell_type": "code",
521 | "source": [
522 | "print('*'*30)\n",
523 | "print('Loading and preprocessing train data...')\n",
524 | "print('*'*30)\n",
525 | "with h5py.File('E:/Dataset_train.h5', 'r') as file:\n",
526 | "    imgs_train = np.array(file.get('images'))\n",
527 | "    imgs_mask_train = np.array(file.get('masks'))\n",
530 | "\n",
531 | "print(imgs_train.shape)\n",
532 | "print(imgs_mask_train.shape)\n",
533 | "\n",
534 | "\n",
535 | "imgs_train = imgs_train.astype('float32')\n",
536 | "\n",
537 | "mean = np.mean(imgs_train) # mean for data centering\n",
538 | "std = np.std(imgs_train) # std for data normalization\n",
539 | "\n",
540 | "imgs_train -= mean\n",
541 | "imgs_train /= std\n",
542 | "\n",
543 | "imgs_mask_train = imgs_mask_train.astype('float32')\n",
544 | "imgs_mask_train /= 255 # scale masks to [0, 1]\n",
545 | "\n",
546 | "print('*'*30)\n",
547 | "print('Creating and compiling model...')\n",
548 | "print('*'*30)\n",
549 | "model = ResLWAUNet()"
550 | ],
551 | "metadata": {
552 | "id": "NrpR3IwSWCen"
553 | },
554 | "execution_count": null,
555 | "outputs": []
556 | },
557 | {
558 | "cell_type": "code",
559 | "source": [
560 | "weight_decay = 0.0001\n",
561 | "learning_rate = 1e-4\n",
562 | "\n",
563 | "# AdamW is set up here as an alternative optimizer; the model above is compiled with Adam.\n",
564 | "optimizer = tfa.optimizers.AdamW(\n",
565 | "    learning_rate=learning_rate, weight_decay=weight_decay\n",
566 | ")\n",
567 | "\n",
568 | "checkpoint_filepath = \"E:/MRD100/ResUNetFormer.h5\"\n",
569 | "checkpoint = ModelCheckpoint(checkpoint_filepath, save_best_only=True)  # keep the best weights by val_loss\n",
570 | "\n",
571 | "# Optionally wrap the fit call in tf.device('/CPU:0') to force CPU training.\n",
572 | "history = model.fit(\n",
573 | "    x=imgs_train,\n",
574 | "    y=imgs_mask_train,\n",
575 | "    batch_size=1,\n",
576 | "    epochs=20,\n",
577 | "    validation_split=0.1,\n",
578 | "    callbacks=[checkpoint]\n",
579 | ")\n"
580 | ],
581 | "metadata": {
582 | "id": "_GWfplfAWwyW"
583 | },
584 | "execution_count": null,
585 | "outputs": []
586 | },
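587 |     {
588 |       "cell_type": "code",
589 |       "source": [
590 |         "# Visualize training progress: a minimal sketch using the history object\n",
591 |         "# returned by model.fit above (plt was imported earlier).\n",
592 |         "plt.plot(history.history['loss'], label='train loss')\n",
593 |         "plt.plot(history.history['val_loss'], label='val loss')\n",
594 |         "plt.xlabel('epoch')\n",
595 |         "plt.ylabel('binary cross-entropy')\n",
596 |         "plt.legend()\n",
597 |         "plt.show()\n"
598 |       ],
599 |       "metadata": {},
600 |       "execution_count": null,
601 |       "outputs": []
602 |     },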
587 | {
588 | "cell_type": "code",
589 | "source": [
590 | "model.save('E:/MRD100/ResUNetFormer.h5')"
591 | ],
592 | "metadata": {
593 | "id": "aPRac4xBdCVt"
594 | },
595 | "execution_count": null,
596 | "outputs": []
597 | },
598 | {
599 | "cell_type": "code",
600 | "source": [
601 | "###### Creating Test Dataset\n",
602 | "testImages = TestX\n",
603 | "print(testImages.shape)\n",
605 | "\n",
606 | "with h5py.File(\"E:/Dataset_test.h5\", 'w') as hdf:\n",
607 | " hdf.create_dataset('images', data=testImages, compression='gzip', compression_opts=9)\n",
608 | "\n",
609 | "\n",
610 | "\n",
611 | "with h5py.File('E:/Dataset_test.h5', 'r') as file:\n",
612 | "    imgs_test = np.array(file.get('images'))\n",
616 | "imgs_test = imgs_test.astype('float32')\n",
617 | "imgs_test -= mean\n",
618 | "imgs_test /= std\n",
619 | "\n",
620 | "print('*'*30)\n",
621 | "print('Loading saved weights...')\n",
622 | "print('*'*30)\n",
623 | "model.load_weights('E:/MRD100/ResUNetFormer.h5')\n",
624 | "\n",
625 | "print('*'*30)\n",
626 | "print('Predicting masks on test data...')\n",
627 | "print('*'*30)\n",
628 | "imgs_mask_test = model.predict(imgs_test, verbose=1, batch_size=1)\n",
629 | "# Sigmoid outputs are already in [0, 1]; the save loop below rescales each mask to uint8."
633 | ],
634 | "metadata": {
635 | "id": "P_eNBawtWw1R"
636 | },
637 | "execution_count": null,
638 | "outputs": []
639 | },
640 | {
641 | "cell_type": "code",
642 | "source": [
643 | "acc = model.evaluate(imgs_test, TestY.astype('float32') / 255, batch_size=1)  # scale masks to [0, 1] to match the training targets"
644 | ],
645 | "metadata": {
646 | "id": "63zO1i3wWw4W"
647 | },
648 | "execution_count": null,
649 | "outputs": []
650 | },
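651 |     {
652 |       "cell_type": "code",
653 |       "source": [
654 |         "# Road-class IoU, a common road-extraction metric: a minimal sketch,\n",
655 |         "# assuming predictions and masks are binarized at 0.5.\n",
656 |         "pred_bin = (model.predict(imgs_test, batch_size=1) > 0.5).astype(np.float32)\n",
657 |         "true_bin = (TestY.astype(np.float32) / 255 > 0.5).astype(np.float32)\n",
658 |         "intersection = np.sum(pred_bin * true_bin)\n",
659 |         "union = np.sum(np.clip(pred_bin + true_bin, 0, 1))\n",
660 |         "print('Road IoU:', intersection / (union + 1e-7))\n"
661 |       ],
662 |       "metadata": {},
663 |       "execution_count": null,
664 |       "outputs": []
665 |     },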
651 | {
652 | "cell_type": "code",
653 | "source": [
654 | "######################### Write the predicted images\n",
655 | "\n",
656 | "print('*' * 30)\n",
657 | "print('Saving predicted masks to files...')\n",
658 | "print('*' * 30)\n",
659 | "pred_dir = 'E:/PredictionsResUNetFormer'\n",
660 | "\n",
661 | "if not os.path.exists(pred_dir):\n",
662 | " os.mkdir(pred_dir)\n",
663 | "for i, image in enumerate(imgs_mask_test):\n",
664 | "    # Min-max rescale each predicted mask to [0, 255] and save it as a PNG\n",
665 | "    image = (image - np.min(image)) / (np.max(image) - np.min(image))\n",
666 | "    image = (image * 255).astype(np.uint8)\n",
667 | "\n",
668 | "    cv2.imwrite(os.path.join(pred_dir, str(i + 1) + '_pred.png'), image)"
670 | ],
671 | "metadata": {
672 | "id": "JHVD8N9NXcY1"
673 | },
674 | "execution_count": null,
675 | "outputs": []
676 | }
677 | ]
678 | }
--------------------------------------------------------------------------------