├── LICENSE ├── PyTorch ├── README.md ├── attention-unet.py ├── multiresunet.py ├── resunet.py ├── unet.py └── unetr_2d.py ├── README.md └── TensorFlow ├── attention-unet.py ├── colonsegnet.py ├── deeplabv3plus.py ├── densenet121.py ├── doubleunet.py ├── efficientnetb0_unet.py ├── inception_resnetv2_unet.py ├── mobilenetv2_unet.py ├── multiresunet.py ├── notebook ├── ColonSegNet.ipynb ├── README.md └── images │ ├── ColonSegNet.png │ ├── ResidualBlock.png │ ├── Strided_Conv_Block.png │ └── squeeze_and_excitation_detailed_block_diagram.png ├── resnet50_unet.py ├── resunet++.py ├── resunet.py ├── u2-net.py ├── unet.py ├── unetr_2d.py ├── vgg16_unet.py └── vgg19_unet.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /PyTorch/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | This directory contains the implementation of the different segmentation models in PyTorch framework. 
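Each script defines the model plus a small `__main__` test. A minimal usage sketch, assuming it is run from inside this `PyTorch/` directory (the other model files are used the same way via their own build functions or classes):

```python
# Minimal usage sketch (assumption: run from inside the PyTorch/ directory).
import torch
from unet import build_unet   # see the __main__ block of unet.py for the same pattern

model = build_unet()                  # the 4-level U-Net defined in unet.py
x = torch.randn((2, 3, 512, 512))     # a batch of RGB images in NCHW layout
y = model(x)                          # raw 1-channel logits (no sigmoid applied)
print(y.shape)                        # torch.Size([2, 1, 512, 512])
```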
3 | -------------------------------------------------------------------------------- /PyTorch/attention-unet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | class conv_block(nn.Module): 6 | def __init__(self, in_c, out_c): 7 | super().__init__() 8 | 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 11 | nn.BatchNorm2d(out_c), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), 14 | nn.BatchNorm2d(out_c), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | def forward(self, x): 19 | return self.conv(x) 20 | 21 | class encoder_block(nn.Module): 22 | def __init__(self, in_c, out_c): 23 | super().__init__() 24 | 25 | self.conv = conv_block(in_c, out_c) 26 | self.pool = nn.MaxPool2d((2, 2)) 27 | 28 | def forward(self, x): 29 | s = self.conv(x) 30 | p = self.pool(s) 31 | return s, p 32 | 33 | class attention_gate(nn.Module): 34 | def __init__(self, in_c, out_c): 35 | super().__init__() 36 | 37 | self.Wg = nn.Sequential( 38 | nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0), 39 | nn.BatchNorm2d(out_c) 40 | ) 41 | self.Ws = nn.Sequential( 42 | nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0), 43 | nn.BatchNorm2d(out_c) 44 | ) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.output = nn.Sequential( 47 | nn.Conv2d(out_c, out_c, kernel_size=1, padding=0), 48 | nn.Sigmoid() 49 | ) 50 | 51 | def forward(self, g, s): 52 | Wg = self.Wg(g) 53 | Ws = self.Ws(s) 54 | out = self.relu(Wg + Ws) 55 | out = self.output(out) 56 | return out * s 57 | 58 | class decoder_block(nn.Module): 59 | def __init__(self, in_c, out_c): 60 | super().__init__() 61 | 62 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 63 | self.ag = attention_gate(in_c, out_c) 64 | self.c1 = conv_block(in_c[0]+out_c, out_c) 65 | 66 | def forward(self, x, s): 67 | x = self.up(x) 68 | s = self.ag(x, s) 69 | x = torch.cat([x, s], axis=1) 70 | x = self.c1(x) 71 | return x 72 | 73 | class attention_unet(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | 77 | self.e1 = encoder_block(3, 64) 78 | self.e2 = encoder_block(64, 128) 79 | self.e3 = encoder_block(128, 256) 80 | 81 | self.b1 = conv_block(256, 512) 82 | 83 | self.d1 = decoder_block([512, 256], 256) 84 | self.d2 = decoder_block([256, 128], 128) 85 | self.d3 = decoder_block([128, 64], 64) 86 | 87 | self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0) 88 | 89 | def forward(self, x): 90 | s1, p1 = self.e1(x) 91 | s2, p2 = self.e2(p1) 92 | s3, p3 = self.e3(p2) 93 | 94 | b1 = self.b1(p3) 95 | 96 | d1 = self.d1(b1, s3) 97 | d2 = self.d2(d1, s2) 98 | d3 = self.d3(d2, s1) 99 | 100 | output = self.output(d3) 101 | return output 102 | 103 | 104 | if __name__ == "__main__": 105 | x = torch.randn((8, 3, 256, 256)) 106 | model = attention_unet() 107 | output = model(x) 108 | print(output.shape) 109 | -------------------------------------------------------------------------------- /PyTorch/multiresunet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class conv_block(nn.Module): 7 | def __init__(self, in_c, out_c, kernel_size=3, padding=1, act=True): 8 | super().__init__() 9 | 10 | layers = [ 11 | nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding, bias=False), 12 | nn.BatchNorm2d(out_c) 13 | ] 14 | if act == True: 15 | layers.append(nn.ReLU(inplace=True)) 16 | 17 | 
self.conv = nn.Sequential(*layers) 18 | 19 | def forward(self, x): 20 | return self.conv(x) 21 | 22 | class multires_block(nn.Module): 23 | def __init__(self, in_c, out_c, alpha=1.67): 24 | super().__init__() 25 | 26 | W = out_c * alpha 27 | self.c1 = conv_block(in_c, int(W*0.167)) 28 | self.c2 = conv_block(int(W*0.167), int(W*0.333)) 29 | self.c3 = conv_block(int(W*0.333), int(W*0.5)) 30 | 31 | nf = int(W*0.167) + int(W*0.333) + int(W*0.5) 32 | self.b1 = nn.BatchNorm2d(nf) 33 | self.c4 = conv_block(in_c, nf) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.b2 = nn.BatchNorm2d(nf) 36 | 37 | def forward(self, x): 38 | x0 = x 39 | x1 = self.c1(x0) 40 | x2 = self.c2(x1) 41 | x3 = self.c3(x2) 42 | xc = torch.cat([x1, x2, x3], dim=1) 43 | xc = self.b1(xc) 44 | 45 | sc = self.c4(x0) 46 | x = self.relu(xc + sc) 47 | x = self.b2(x) 48 | return x 49 | 50 | class res_path_block(nn.Module): 51 | def __init__(self, in_c, out_c): 52 | super().__init__() 53 | 54 | self.c1 = conv_block(in_c, out_c, act=False) 55 | self.s1 = conv_block(in_c, out_c, kernel_size=1, padding=0, act=False) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.bn = nn.BatchNorm2d(out_c) 58 | 59 | def forward(self, x): 60 | x1 = self.c1(x) 61 | s1 = self.s1(x) 62 | x = self.relu(x1 + s1) 63 | x = self.bn(x) 64 | return x 65 | 66 | class res_path(nn.Module): 67 | def __init__(self, in_c, out_c, length): 68 | super().__init__() 69 | 70 | layers = [] 71 | for i in range(length): 72 | layers.append(res_path_block(in_c, out_c)) 73 | in_c = out_c 74 | 75 | self.conv = nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | 80 | def cal_nf(ch, alpha=1.67): 81 | W = ch * alpha 82 | return int(W*0.167) + int(W*0.333) + int(W*0.5) 83 | 84 | class encoder_block(nn.Module): 85 | def __init__(self, in_c, out_c, length): 86 | super().__init__() 87 | 88 | self.c1 = multires_block(in_c, out_c) 89 | nf = cal_nf(out_c) 90 | self.s1 = res_path(nf, out_c, length) 91 | self.pool = nn.MaxPool2d((2, 2)) 92 | 93 | def forward(self, x): 94 | x = self.c1(x) 95 | s = self.s1(x) 96 | p = self.pool(x) 97 | return s, p 98 | 99 | class decoder_block(nn.Module): 100 | def __init__(self, in_c, out_c): 101 | super().__init__() 102 | 103 | self.c1 = nn.ConvTranspose2d(in_c[0], out_c, kernel_size=2, stride=2, padding=0) 104 | self.c2 = multires_block(out_c+in_c[1], out_c) 105 | 106 | def forward(self, x, s): 107 | x = self.c1(x) 108 | x = torch.cat([x, s], dim=1) 109 | x = self.c2(x) 110 | return x 111 | 112 | class build_multiresunet(nn.Module): 113 | def __init__(self): 114 | super().__init__() 115 | 116 | """ Encoder """ 117 | self.e1 = encoder_block(3, 32, 4) 118 | self.e2 = encoder_block(cal_nf(32), 64, 3) 119 | self.e3 = encoder_block(cal_nf(64), 128, 2) 120 | self.e4 = encoder_block(cal_nf(128), 256, 1) 121 | 122 | """ Bridge """ 123 | self.b1 = multires_block(cal_nf(256), 512) 124 | 125 | """ Decoder """ 126 | self.d1 = decoder_block([cal_nf(512), 256], 256) 127 | self.d2 = decoder_block([cal_nf(256), 128], 128) 128 | self.d3 = decoder_block([cal_nf(128), 64], 64) 129 | self.d4 = decoder_block([cal_nf(64), 32], 32) 130 | 131 | """ Output """ 132 | self.output = nn.Conv2d(cal_nf(32), 1, kernel_size=1, padding=0) 133 | 134 | def forward(self, x): 135 | s1, p1 = self.e1(x) 136 | s2, p2 = self.e2(p1) 137 | s3, p3 = self.e3(p2) 138 | s4, p4 = self.e4(p3) 139 | 140 | b1 = self.b1(p4) 141 | 142 | d1 = self.d1(b1, s4) 143 | d2 = self.d2(d1, s3) 144 | d3 = self.d3(d2, s2) 145 | d4 = self.d4(d3, s1) 146 | 147 | output = self.output(d4) 148 | 
return output 149 | 150 | if __name__ == "__main__": 151 | x = torch.randn((8, 3, 256, 256)) 152 | model = build_multiresunet() 153 | output = model(x) 154 | print(output.shape) 155 | -------------------------------------------------------------------------------- /PyTorch/resunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class batchnorm_relu(nn.Module): 5 | def __init__(self, in_c): 6 | super().__init__() 7 | 8 | self.bn = nn.BatchNorm2d(in_c) 9 | self.relu = nn.ReLU() 10 | 11 | def forward(self, inputs): 12 | x = self.bn(inputs) 13 | x = self.relu(x) 14 | return x 15 | 16 | class residual_block(nn.Module): 17 | def __init__(self, in_c, out_c, stride=1): 18 | super().__init__() 19 | 20 | """ Convolutional layer """ 21 | self.b1 = batchnorm_relu(in_c) 22 | self.c1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride) 23 | self.b2 = batchnorm_relu(out_c) 24 | self.c2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, stride=1) 25 | 26 | """ Shortcut Connection (Identity Mapping) """ 27 | self.s = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0, stride=stride) 28 | 29 | def forward(self, inputs): 30 | x = self.b1(inputs) 31 | x = self.c1(x) 32 | x = self.b2(x) 33 | x = self.c2(x) 34 | s = self.s(inputs) 35 | 36 | skip = x + s 37 | return skip 38 | 39 | class decoder_block(nn.Module): 40 | def __init__(self, in_c, out_c): 41 | super().__init__() 42 | 43 | self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 44 | self.r = residual_block(in_c+out_c, out_c) 45 | 46 | def forward(self, inputs, skip): 47 | x = self.upsample(inputs) 48 | x = torch.cat([x, skip], axis=1) 49 | x = self.r(x) 50 | return x 51 | 52 | class build_resunet(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | 56 | """ Encoder 1 """ 57 | self.c11 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 58 | self.br1 = batchnorm_relu(64) 59 | self.c12 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 60 | self.c13 = nn.Conv2d(3, 64, kernel_size=1, padding=0) 61 | 62 | """ Encoder 2 and 3 """ 63 | self.r2 = residual_block(64, 128, stride=2) 64 | self.r3 = residual_block(128, 256, stride=2) 65 | 66 | """ Bridge """ 67 | self.r4 = residual_block(256, 512, stride=2) 68 | 69 | """ Decoder """ 70 | self.d1 = decoder_block(512, 256) 71 | self.d2 = decoder_block(256, 128) 72 | self.d3 = decoder_block(128, 64) 73 | 74 | """ Output """ 75 | self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0) 76 | self.sigmoid = nn.Sigmoid() 77 | 78 | def forward(self, inputs): 79 | """ Encoder 1 """ 80 | x = self.c11(inputs) 81 | x = self.br1(x) 82 | x = self.c12(x) 83 | s = self.c13(inputs) 84 | skip1 = x + s 85 | 86 | """ Encoder 2 and 3 """ 87 | skip2 = self.r2(skip1) 88 | skip3 = self.r3(skip2) 89 | 90 | """ Bridge """ 91 | b = self.r4(skip3) 92 | 93 | """ Decoder """ 94 | d1 = self.d1(b, skip3) 95 | d2 = self.d2(d1, skip2) 96 | d3 = self.d3(d2, skip1) 97 | 98 | """ output """ 99 | output = self.output(d3) 100 | output = self.sigmoid(output) 101 | 102 | return output 103 | 104 | 105 | if __name__ == "__main__": 106 | inputs = torch.randn((4, 3, 256, 256)) 107 | model = build_resunet() 108 | y = model(inputs) 109 | print(y.shape) 110 | -------------------------------------------------------------------------------- /PyTorch/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """ Convolutional block: 5 | It follows a two 3x3 
convolutional layer, each followed by a batch normalization and a relu activation. 6 | """ 7 | class conv_block(nn.Module): 8 | def __init__(self, in_c, out_c): 9 | super().__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1) 12 | self.bn1 = nn.BatchNorm2d(out_c) 13 | 14 | self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1) 15 | self.bn2 = nn.BatchNorm2d(out_c) 16 | 17 | self.relu = nn.ReLU() 18 | 19 | def forward(self, inputs): 20 | x = self.conv1(inputs) 21 | x = self.bn1(x) 22 | x = self.relu(x) 23 | 24 | x = self.conv2(x) 25 | x = self.bn2(x) 26 | x = self.relu(x) 27 | 28 | return x 29 | 30 | """ Encoder block: 31 | It consists of an conv_block followed by a max pooling. 32 | Here the number of filters doubles and the height and width half after every block. 33 | """ 34 | class encoder_block(nn.Module): 35 | def __init__(self, in_c, out_c): 36 | super().__init__() 37 | 38 | self.conv = conv_block(in_c, out_c) 39 | self.pool = nn.MaxPool2d((2, 2)) 40 | 41 | def forward(self, inputs): 42 | x = self.conv(inputs) 43 | p = self.pool(x) 44 | 45 | return x, p 46 | 47 | """ Decoder block: 48 | The decoder block begins with a transpose convolution, followed by a concatenation with the skip 49 | connection from the encoder block. Next comes the conv_block. 50 | Here the number filters decreases by half and the height and width doubles. 51 | """ 52 | class decoder_block(nn.Module): 53 | def __init__(self, in_c, out_c): 54 | super().__init__() 55 | 56 | self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0) 57 | self.conv = conv_block(out_c+out_c, out_c) 58 | 59 | def forward(self, inputs, skip): 60 | x = self.up(inputs) 61 | x = torch.cat([x, skip], axis=1) 62 | x = self.conv(x) 63 | 64 | return x 65 | 66 | 67 | class build_unet(nn.Module): 68 | def __init__(self): 69 | super().__init__() 70 | 71 | """ Encoder """ 72 | self.e1 = encoder_block(3, 64) 73 | self.e2 = encoder_block(64, 128) 74 | self.e3 = encoder_block(128, 256) 75 | self.e4 = encoder_block(256, 512) 76 | 77 | """ Bottleneck """ 78 | self.b = conv_block(512, 1024) 79 | 80 | """ Decoder """ 81 | self.d1 = decoder_block(1024, 512) 82 | self.d2 = decoder_block(512, 256) 83 | self.d3 = decoder_block(256, 128) 84 | self.d4 = decoder_block(128, 64) 85 | 86 | """ Classifier """ 87 | self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0) 88 | 89 | def forward(self, inputs): 90 | """ Encoder """ 91 | s1, p1 = self.e1(inputs) 92 | s2, p2 = self.e2(p1) 93 | s3, p3 = self.e3(p2) 94 | s4, p4 = self.e4(p3) 95 | 96 | """ Bottleneck """ 97 | b = self.b(p4) 98 | 99 | """ Decoder """ 100 | d1 = self.d1(b, s4) 101 | d2 = self.d2(d1, s3) 102 | d3 = self.d3(d2, s2) 103 | d4 = self.d4(d3, s1) 104 | 105 | """ Classifier """ 106 | outputs = self.outputs(d4) 107 | 108 | return outputs 109 | 110 | if __name__ == "__main__": 111 | # inputs = torch.randn((2, 32, 256, 256)) 112 | # e = encoder_block(32, 64) 113 | # x, p = e(inputs) 114 | # print(x.shape, p.shape) 115 | # 116 | # d = decoder_block(64, 32) 117 | # y = d(p, x) 118 | # print(y.shape) 119 | 120 | inputs = torch.randn((2, 3, 512, 512)) 121 | model = build_unet() 122 | y = model(inputs) 123 | print(y.shape) 124 | -------------------------------------------------------------------------------- /PyTorch/unetr_2d.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ConvBlock(nn.Module): 7 | def __init__(self, in_c, out_c, kernel_size=3, 
padding=1): 8 | super().__init__() 9 | 10 | self.layers = nn.Sequential( 11 | nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding), 12 | nn.BatchNorm2d(out_c), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | 21 | class DeconvBlock(nn.Module): 22 | def __init__(self, in_c, out_c): 23 | super().__init__() 24 | 25 | self.deconv = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0) 26 | 27 | def forward(self, x): 28 | return self.deconv(x) 29 | 30 | 31 | class UNETR_2D(nn.Module): 32 | def __init__(self, cf): 33 | super().__init__() 34 | self.cf = cf 35 | 36 | """ Patch + Position Embeddings """ 37 | self.patch_embed = nn.Linear( 38 | cf["patch_size"]*cf["patch_size"]*cf["num_channels"], 39 | cf["hidden_dim"] 40 | ) 41 | 42 | self.positions = torch.arange(start=0, end=cf["num_patches"], step=1, dtype=torch.int32) 43 | self.pos_embed = nn.Embedding(cf["num_patches"], cf["hidden_dim"]) 44 | 45 | """ Transformer Encoder """ 46 | self.trans_encoder_layers = [] 47 | 48 | for i in range(cf["num_layers"]): 49 | layer = nn.TransformerEncoderLayer( 50 | d_model=cf["hidden_dim"], 51 | nhead=cf["num_heads"], 52 | dim_feedforward=cf["mlp_dim"], 53 | dropout=cf["dropout_rate"], 54 | activation=nn.GELU(), 55 | batch_first=True 56 | ) 57 | self.trans_encoder_layers.append(layer) 58 | 59 | """ CNN Decoder """ 60 | ## Decoder 1 61 | self.d1 = DeconvBlock(cf["hidden_dim"], 512) 62 | self.s1 = nn.Sequential( 63 | DeconvBlock(cf["hidden_dim"], 512), 64 | ConvBlock(512, 512) 65 | ) 66 | self.c1 = nn.Sequential( 67 | ConvBlock(512+512, 512), 68 | ConvBlock(512, 512) 69 | ) 70 | 71 | ## Decoder 2 72 | self.d2 = DeconvBlock(512, 256) 73 | self.s2 = nn.Sequential( 74 | DeconvBlock(cf["hidden_dim"], 256), 75 | ConvBlock(256, 256), 76 | DeconvBlock(256, 256), 77 | ConvBlock(256, 256) 78 | ) 79 | self.c2 = nn.Sequential( 80 | ConvBlock(256+256, 256), 81 | ConvBlock(256, 256) 82 | ) 83 | 84 | ## Decoder 3 85 | self.d3 = DeconvBlock(256, 128) 86 | self.s3 = nn.Sequential( 87 | DeconvBlock(cf["hidden_dim"], 128), 88 | ConvBlock(128, 128), 89 | DeconvBlock(128, 128), 90 | ConvBlock(128, 128), 91 | DeconvBlock(128, 128), 92 | ConvBlock(128, 128) 93 | ) 94 | self.c3 = nn.Sequential( 95 | ConvBlock(128+128, 128), 96 | ConvBlock(128, 128) 97 | ) 98 | 99 | ## Decoder 4 100 | self.d4 = DeconvBlock(128, 64) 101 | self.s4 = nn.Sequential( 102 | ConvBlock(3, 64), 103 | ConvBlock(64, 64) 104 | ) 105 | self.c4 = nn.Sequential( 106 | ConvBlock(64+64, 64), 107 | ConvBlock(64, 64) 108 | ) 109 | 110 | """ Output """ 111 | self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0) 112 | 113 | def forward(self, inputs): 114 | """ Patch + Position Embeddings """ 115 | patch_embed = self.patch_embed(inputs) ## [8, 256, 768] 116 | 117 | positions = self.positions 118 | pos_embed = self.pos_embed(positions) ## [256, 768] 119 | 120 | x = patch_embed + pos_embed ## [8, 256, 768] 121 | 122 | """ Transformer Encoder """ 123 | skip_connection_index = [3, 6, 9, 12] 124 | skip_connections = [] 125 | 126 | for i in range(self.cf["num_layers"]): 127 | layer = self.trans_encoder_layers[i] 128 | x = layer(x) 129 | 130 | if (i+1) in skip_connection_index: 131 | skip_connections.append(x) 132 | 133 | """ CNN Decoder """ 134 | z3, z6, z9, z12 = skip_connections 135 | 136 | ## Reshaping 137 | batch = inputs.shape[0] 138 | z0 = inputs.view((batch, self.cf["num_channels"], self.cf["image_size"], self.cf["image_size"])) 139 | 140 | shape = (batch, self.cf["hidden_dim"], 
self.cf["patch_size"], self.cf["patch_size"]) 141 | z3 = z3.view(shape) 142 | z6 = z6.view(shape) 143 | z9 = z9.view(shape) 144 | z12 = z12.view(shape) 145 | 146 | 147 | ## Decoder 1 148 | x = self.d1(z12) 149 | s = self.s1(z9) 150 | x = torch.cat([x, s], dim=1) 151 | x = self.c1(x) 152 | 153 | ## Decoder 2 154 | x = self.d2(x) 155 | s = self.s2(z6) 156 | x = torch.cat([x, s], dim=1) 157 | x = self.c2(x) 158 | 159 | ## Decoder 3 160 | x = self.d3(x) 161 | s = self.s3(z3) 162 | x = torch.cat([x, s], dim=1) 163 | x = self.c3(x) 164 | 165 | ## Decoder 4 166 | x = self.d4(x) 167 | s = self.s4(z0) 168 | x = torch.cat([x, s], dim=1) 169 | x = self.c4(x) 170 | 171 | """ Output """ 172 | output = self.output(x) 173 | 174 | return output 175 | 176 | 177 | if __name__ == "__main__": 178 | config = {} 179 | config["image_size"] = 256 180 | config["num_layers"] = 12 181 | config["hidden_dim"] = 768 182 | config["mlp_dim"] = 3072 183 | config["num_heads"] = 12 184 | config["dropout_rate"] = 0.1 185 | config["num_patches"] = 256 186 | config["patch_size"] = 16 187 | config["num_channels"] = 3 188 | 189 | x = torch.randn(( 190 | 8, 191 | config["num_patches"], 192 | config["patch_size"]*config["patch_size"]*config["num_channels"] 193 | )) 194 | model = UNETR_2D(config) 195 | output = model(x) 196 | print(output.shape) 197 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic-Segmentation-Architecture 2 | A repository contains the code for various semantic segmentation in TensorFlow and PyTorch framework. 3 | 4 | ## Research papers 5 | - [U-Net](https://arxiv.org/pdf/1505.04597.pdf) 6 | - [ResU-Net](https://arxiv.org/pdf/1711.10684.pdf) 7 | - [MultiResU-Net](https://arxiv.org/pdf/1902.04049) 8 | 9 | ## Star History 10 | [![Stargazers repo roster for ](https://bytecrank.com/nastyox/reporoster/php/stargazersSVG.php?user=nikhilroxtomar&repo=Semantic-Segmentation-Architecture)](https://github.com/nikhilroxtomar/Semantic-Segmentation-Architecture/stargazers) 11 | -------------------------------------------------------------------------------- /TensorFlow/attention-unet.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import tensorflow.keras.layers as L 4 | from tensorflow.keras.models import Model 5 | 6 | def conv_block(x, num_filters): 7 | x = L.Conv2D(num_filters, 3, padding="same")(x) 8 | x = L.BatchNormalization()(x) 9 | x = L.Activation("relu")(x) 10 | 11 | x = L.Conv2D(num_filters, 3, padding="same")(x) 12 | x = L.BatchNormalization()(x) 13 | x = L.Activation("relu")(x) 14 | 15 | return x 16 | 17 | def encoder_block(x, num_filters): 18 | x = conv_block(x, num_filters) 19 | p = L.MaxPool2D((2, 2))(x) 20 | return x, p 21 | 22 | def attention_gate(g, s, num_filters): 23 | Wg = L.Conv2D(num_filters, 1, padding="same")(g) 24 | Wg = L.BatchNormalization()(Wg) 25 | 26 | Ws = L.Conv2D(num_filters, 1, padding="same")(s) 27 | Ws = L.BatchNormalization()(Ws) 28 | 29 | out = L.Activation("relu")(Wg + Ws) 30 | out = L.Conv2D(num_filters, 1, padding="same")(out) 31 | out = L.Activation("sigmoid")(out) 32 | 33 | return out * s 34 | 35 | def decoder_block(x, s, num_filters): 36 | x = L.UpSampling2D(interpolation="bilinear")(x) 37 | s = attention_gate(x, s, num_filters) 38 | x = L.Concatenate()([x, s]) 39 | x = conv_block(x, num_filters) 40 | return x 41 | 42 | def attention_unet(input_shape): 43 | """ Inputs """ 44 
| inputs = L.Input(input_shape) 45 | 46 | """ Encoder """ 47 | s1, p1 = encoder_block(inputs, 64) 48 | s2, p2 = encoder_block(p1, 128) 49 | s3, p3 = encoder_block(p2, 256) 50 | 51 | b1 = conv_block(p3, 512) 52 | 53 | """ Decoder """ 54 | d1 = decoder_block(b1, s3, 256) 55 | d2 = decoder_block(d1, s2, 128) 56 | d3 = decoder_block(d2, s1, 64) 57 | 58 | """ Outputs """ 59 | outputs = L.Conv2D(1, 1, padding="same", activation="sigmoid")(d3) 60 | 61 | """ Model """ 62 | model = Model(inputs, outputs, name="Attention-UNET") 63 | return model 64 | 65 | if __name__ == "__main__": 66 | input_shape = (256, 256, 3) 67 | model = attention_unet(input_shape) 68 | model.summary() 69 | -------------------------------------------------------------------------------- /TensorFlow/colonsegnet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | 5 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Dense 6 | from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2DTranspose, Concatenate, Input 7 | from tensorflow.keras.layers import MaxPool2D 8 | from tensorflow.keras.models import Model 9 | 10 | def se_layer(x, num_filters, reduction=16): 11 | x_init = x 12 | 13 | x = GlobalAveragePooling2D()(x) 14 | x = Dense(num_filters//reduction, use_bias=False, activation="relu")(x) 15 | x = Dense(num_filters, use_bias=False, activation="sigmoid")(x) 16 | x = x * x_init 17 | return x 18 | 19 | def residual_block(x, num_filters): 20 | x_init = x 21 | 22 | x = Conv2D(num_filters, 3, padding="same")(x) 23 | x = BatchNormalization()(x) 24 | x = Activation("relu")(x) 25 | 26 | x = Conv2D(num_filters, 3, padding="same")(x) 27 | x = BatchNormalization()(x) 28 | 29 | s = Conv2D(num_filters, 1, padding="same")(x_init) 30 | s = BatchNormalization()(s) 31 | s = se_layer(s, num_filters) 32 | 33 | x = Activation("relu")(x + s) 34 | return x 35 | 36 | def strided_conv_block(x, num_filters): 37 | x = Conv2D(num_filters, 3, strides=2, padding="same")(x) 38 | x = BatchNormalization()(x) 39 | x = Activation("relu")(x) 40 | return x 41 | 42 | def encoder_block(x, num_filters): 43 | x1 = residual_block(x, num_filters) 44 | x2 = strided_conv_block(x1, num_filters) 45 | x3 = residual_block(x2, num_filters) 46 | p = MaxPool2D((2, 2))(x3) 47 | 48 | return x1, x3, p 49 | 50 | def build_colonsegnet(input_shape): 51 | """ Input """ 52 | inputs = Input(input_shape) 53 | 54 | """ Encoder """ 55 | s11, s12, p1 = encoder_block(inputs, 64) 56 | s21, s22, p2 = encoder_block(p1, 256) 57 | 58 | """ Decoder 1 """ 59 | x = Conv2DTranspose(128, 4, strides=4, padding="same")(s22) 60 | x = Concatenate()([x, s12]) 61 | x = residual_block(x, 128) 62 | r1 = x 63 | 64 | x = Conv2DTranspose(128, 4, strides=2, padding="same")(s21) 65 | x = Concatenate()([x, r1]) 66 | x = residual_block(x, 128) 67 | 68 | """ Decoder 2 """ 69 | x = Conv2DTranspose(64, 4, strides=2, padding="same")(x) 70 | x = Concatenate()([x, s11]) 71 | x = residual_block(x, 64) 72 | r2 = x 73 | 74 | x = Conv2DTranspose(32, 4, strides=2, padding="same")(s12) 75 | x = Concatenate()([x, r2]) 76 | x = residual_block(x, 32) 77 | 78 | """ Output """ 79 | output = Conv2D(1, 1, padding="same")(x) 80 | 81 | """ Model """ 82 | model = Model(inputs, output) 83 | 84 | return model 85 | 86 | if __name__ == "__main__": 87 | input_shape = (512, 512, 3) 88 | model = build_colonsegnet(input_shape) 89 | model.summary() 90 | 
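    # Added sketch, not part of the original file: the final Conv2D(1, 1) above has no
    # activation, so build_colonsegnet() returns raw logits rather than probabilities.
    # One way the model could be compiled for binary segmentation -- the optimizer and
    # loss choices below are illustrative assumptions, not taken from the original code:
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.losses import BinaryCrossentropy
    model.compile(optimizer=Adam(1e-4), loss=BinaryCrossentropy(from_logits=True))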
-------------------------------------------------------------------------------- /TensorFlow/deeplabv3plus.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | 5 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D 6 | from tensorflow.keras.layers import AveragePooling2D, Conv2DTranspose, Concatenate, Input 7 | from tensorflow.keras.models import Model 8 | from tensorflow.keras.applications import ResNet50 9 | 10 | """ Atrous Spatial Pyramid Pooling """ 11 | def ASPP(inputs): 12 | shape = inputs.shape 13 | 14 | y_pool = AveragePooling2D(pool_size=(shape[1], shape[2]), name='average_pooling')(inputs) 15 | y_pool = Conv2D(filters=256, kernel_size=1, padding='same', use_bias=False)(y_pool) 16 | y_pool = BatchNormalization(name=f'bn_1')(y_pool) 17 | y_pool = Activation('relu', name=f'relu_1')(y_pool) 18 | y_pool = UpSampling2D((shape[1], shape[2]), interpolation="bilinear")(y_pool) 19 | 20 | y_1 = Conv2D(filters=256, kernel_size=1, dilation_rate=1, padding='same', use_bias=False)(inputs) 21 | y_1 = BatchNormalization()(y_1) 22 | y_1 = Activation('relu')(y_1) 23 | 24 | y_6 = Conv2D(filters=256, kernel_size=3, dilation_rate=6, padding='same', use_bias=False)(inputs) 25 | y_6 = BatchNormalization()(y_6) 26 | y_6 = Activation('relu')(y_6) 27 | 28 | y_12 = Conv2D(filters=256, kernel_size=3, dilation_rate=12, padding='same', use_bias=False)(inputs) 29 | y_12 = BatchNormalization()(y_12) 30 | y_12 = Activation('relu')(y_12) 31 | 32 | y_18 = Conv2D(filters=256, kernel_size=3, dilation_rate=18, padding='same', use_bias=False)(inputs) 33 | y_18 = BatchNormalization()(y_18) 34 | y_18 = Activation('relu')(y_18) 35 | 36 | y = Concatenate()([y_pool, y_1, y_6, y_12, y_18]) 37 | 38 | y = Conv2D(filters=256, kernel_size=1, dilation_rate=1, padding='same', use_bias=False)(y) 39 | y = BatchNormalization()(y) 40 | y = Activation('relu')(y) 41 | return y 42 | 43 | def DeepLabV3Plus(shape): 44 | """ Inputs """ 45 | inputs = Input(shape) 46 | 47 | """ Pre-trained ResNet50 """ 48 | base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=inputs) 49 | 50 | """ Pre-trained ResNet50 Output """ 51 | image_features = base_model.get_layer('conv4_block6_out').output 52 | x_a = ASPP(image_features) 53 | x_a = UpSampling2D((4, 4), interpolation="bilinear")(x_a) 54 | 55 | """ Get low-level features """ 56 | x_b = base_model.get_layer('conv2_block2_out').output 57 | x_b = Conv2D(filters=48, kernel_size=1, padding='same', use_bias=False)(x_b) 58 | x_b = BatchNormalization()(x_b) 59 | x_b = Activation('relu')(x_b) 60 | 61 | x = Concatenate()([x_a, x_b]) 62 | 63 | x = Conv2D(filters=256, kernel_size=3, padding='same', activation='relu',use_bias=False)(x) 64 | x = BatchNormalization()(x) 65 | x = Activation('relu')(x) 66 | 67 | x = Conv2D(filters=256, kernel_size=3, padding='same', activation='relu', use_bias=False)(x) 68 | x = BatchNormalization()(x) 69 | x = Activation('relu')(x) 70 | x = UpSampling2D((4, 4), interpolation="bilinear")(x) 71 | 72 | """ Outputs """ 73 | x = Conv2D(1, (1, 1), name='output_layer')(x) 74 | x = Activation('sigmoid')(x) 75 | 76 | """ Model """ 77 | model = Model(inputs=inputs, outputs=x) 78 | return model 79 | 80 | if __name__ == "__main__": 81 | input_shape = (512, 512, 3) 82 | model = DeepLabV3Plus(input_shape) 83 | model.summary() 84 | 85 | -------------------------------------------------------------------------------- /TensorFlow/densenet121.py: 
-------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.applications import DenseNet121 4 | 5 | def conv_block(inputs, num_filters): 6 | x = Conv2D(num_filters, 3, padding="same")(inputs) 7 | x = BatchNormalization()(x) 8 | x = Activation("relu")(x) 9 | 10 | x = Conv2D(num_filters, 3, padding="same")(x) 11 | x = BatchNormalization()(x) 12 | x = Activation("relu")(x) 13 | 14 | return x 15 | 16 | def decoder_block(inputs, skip_features, num_filters): 17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs) 18 | x = Concatenate()([x, skip_features]) 19 | x = conv_block(x, num_filters) 20 | return x 21 | 22 | def build_densenet121_unet(input_shape): 23 | """ Input """ 24 | inputs = Input(input_shape) 25 | 26 | """ Pre-trained DenseNet121 Model """ 27 | densenet = DenseNet121(include_top=False, weights="imagenet", input_tensor=inputs) 28 | 29 | """ Encoder """ 30 | s1 = densenet.get_layer("input_1").output ## 512 31 | s2 = densenet.get_layer("conv1/relu").output ## 256 32 | s3 = densenet.get_layer("pool2_relu").output ## 128 33 | s4 = densenet.get_layer("pool3_relu").output ## 64 34 | 35 | """ Bridge """ 36 | b1 = densenet.get_layer("pool4_relu").output ## 32 37 | 38 | """ Decoder """ 39 | d1 = decoder_block(b1, s4, 512) ## 64 40 | d2 = decoder_block(d1, s3, 256) ## 128 41 | d3 = decoder_block(d2, s2, 128) ## 256 42 | d4 = decoder_block(d3, s1, 64) ## 512 43 | 44 | """ Outputs """ 45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 46 | 47 | model = Model(inputs, outputs) 48 | return model 49 | 50 | 51 | if __name__ == "__main__": 52 | input_shape = (512, 512, 3) 53 | model = build_densenet121_unet(input_shape) 54 | model.summary() 55 | -------------------------------------------------------------------------------- /TensorFlow/doubleunet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 7 | from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Multiply, AveragePooling2D, UpSampling2D 8 | from tensorflow.keras.models import Model 9 | from tensorflow.keras.applications import VGG19 10 | 11 | def squeeze_excite_block(inputs, ratio=8): 12 | init = inputs ## (b, 128, 128, 32) 13 | channel_axis = -1 14 | filters = init.shape[channel_axis] 15 | se_shape = (1, 1, filters) 16 | 17 | se = GlobalAveragePooling2D()(init) ## (b, 32) -> (b, 1, 1, 32) 18 | se = Reshape(se_shape)(se) 19 | se = Dense(filters//ratio, activation="relu", use_bias=False)(se) 20 | se = Dense(filters, activation="sigmoid", use_bias=False)(se) 21 | 22 | x = Multiply()([inputs, se]) 23 | return x 24 | 25 | def ASPP(x, filter): 26 | shape = x.shape 27 | 28 | y1 = AveragePooling2D(pool_size=(shape[1], shape[2]))(x) 29 | y1 = Conv2D(filter, 1, padding="same")(y1) 30 | y1 = BatchNormalization()(y1) 31 | y1 = Activation("relu")(y1) 32 | y1 = UpSampling2D((shape[1], shape[2]), interpolation="bilinear")(y1) 33 | 34 | y2 = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(x) 35 | y2 = BatchNormalization()(y2) 36 | y2 = Activation("relu")(y2) 37 | 38 | y3 = Conv2D(filter, 3, dilation_rate=6, 
padding="same", use_bias=False)(x) 39 | y3 = BatchNormalization()(y3) 40 | y3 = Activation("relu")(y3) 41 | 42 | y4 = Conv2D(filter, 3, dilation_rate=12, padding="same", use_bias=False)(x) 43 | y4 = BatchNormalization()(y4) 44 | y4 = Activation("relu")(y4) 45 | 46 | y5 = Conv2D(filter, 3, dilation_rate=18, padding="same", use_bias=False)(x) 47 | y5 = BatchNormalization()(y5) 48 | y5 = Activation("relu")(y5) 49 | 50 | y = Concatenate()([y1, y2, y3, y4, y5]) 51 | 52 | y = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(y) 53 | y = BatchNormalization()(y) 54 | y = Activation("relu")(y) 55 | 56 | return y 57 | 58 | def conv_block(x, filters): 59 | x = Conv2D(filters, 3, padding="same")(x) 60 | x = BatchNormalization()(x) 61 | x = Activation("relu")(x) 62 | 63 | x = Conv2D(filters, 3, padding="same")(x) 64 | x = BatchNormalization()(x) 65 | x = Activation("relu")(x) 66 | 67 | x = squeeze_excite_block(x) 68 | 69 | return x 70 | 71 | def encoder1(inputs): 72 | skip_connections = [] 73 | 74 | model = VGG19(include_top=False, weights="imagenet", input_tensor=inputs) 75 | names = ["block1_conv2", "block2_conv2", "block3_conv4", "block4_conv4"] 76 | for name in names: 77 | skip_connections.append(model.get_layer(name).output) 78 | 79 | output = model.get_layer("block5_conv4").output 80 | return output, skip_connections 81 | 82 | def decoder1(inputs, skip_connections): 83 | num_filters = [256, 128, 64, 32] 84 | skip_connections.reverse() 85 | 86 | x = inputs 87 | for i, f in enumerate(num_filters): 88 | x = UpSampling2D((2, 2), interpolation="bilinear")(x) 89 | x = Concatenate()([x, skip_connections[i]]) 90 | x = conv_block(x, f) 91 | 92 | return x 93 | 94 | def output_block(inputs): 95 | x = Conv2D(1, 1, padding="same")(inputs) 96 | x = Activation("sigmoid")(x) 97 | return x 98 | 99 | def encoder2(inputs): 100 | num_filters = [32, 64, 128, 256] 101 | skip_connections = [] 102 | 103 | x = inputs 104 | for i, f in enumerate(num_filters): 105 | x = conv_block(x, f) 106 | skip_connections.append(x) 107 | x = MaxPool2D((2, 2))(x) 108 | 109 | return x, skip_connections 110 | 111 | def decoder2(inputs, skip_1, skip_2): 112 | num_filters = [256, 128, 64, 32] 113 | skip_2.reverse() 114 | 115 | x = inputs 116 | for i, f in enumerate(num_filters): 117 | x = UpSampling2D((2, 2), interpolation="bilinear")(x) 118 | x = Concatenate()([x, skip_1[i], skip_2[i]]) 119 | x = conv_block(x, f) 120 | 121 | return x 122 | 123 | def build_model(input_shape): 124 | inputs = Input(input_shape) 125 | x, skip_1 = encoder1(inputs) 126 | x = ASPP(x, 64) 127 | x = decoder1(x, skip_1) 128 | output1 = output_block(x) 129 | 130 | x = inputs * output1 131 | 132 | x, skip_2 = encoder2(x) 133 | x = ASPP(x, 64) 134 | x = decoder2(x, skip_1, skip_2) 135 | output2 = output_block(x) 136 | 137 | outputs = Concatenate()([output1, output2]) 138 | model = Model(inputs, outputs) 139 | return model 140 | 141 | 142 | if __name__ == "__main__": 143 | input_shape = (256, 256, 3) 144 | model = build_model(input_shape) 145 | model.summary() 146 | -------------------------------------------------------------------------------- /TensorFlow/efficientnetb0_unet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.applications import EfficientNetB0 4 | import tensorflow as tf 5 | 6 | print("TF Version: ", 
tf.__version__) 7 | 8 | def conv_block(inputs, num_filters): 9 | x = Conv2D(num_filters, 3, padding="same")(inputs) 10 | x = BatchNormalization()(x) 11 | x = Activation("relu")(x) 12 | 13 | x = Conv2D(num_filters, 3, padding="same")(x) 14 | x = BatchNormalization()(x) 15 | x = Activation("relu")(x) 16 | 17 | return x 18 | 19 | def decoder_block(inputs, skip, num_filters): 20 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs) 21 | x = Concatenate()([x, skip]) 22 | x = conv_block(x, num_filters) 23 | return x 24 | 25 | def build_effienet_unet(input_shape): 26 | """ Input """ 27 | inputs = Input(input_shape) 28 | 29 | """ Pre-trained Encoder """ 30 | encoder = EfficientNetB0(include_top=False, weights="imagenet", input_tensor=inputs) 31 | 32 | s1 = encoder.get_layer("input_1").output ## 256 33 | s2 = encoder.get_layer("block2a_expand_activation").output ## 128 34 | s3 = encoder.get_layer("block3a_expand_activation").output ## 64 35 | s4 = encoder.get_layer("block4a_expand_activation").output ## 32 36 | 37 | """ Bottleneck """ 38 | b1 = encoder.get_layer("block6a_expand_activation").output ## 16 39 | 40 | """ Decoder """ 41 | d1 = decoder_block(b1, s4, 512) ## 32 42 | d2 = decoder_block(d1, s3, 256) ## 64 43 | d3 = decoder_block(d2, s2, 128) ## 128 44 | d4 = decoder_block(d3, s1, 64) ## 256 45 | 46 | """ Output """ 47 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 48 | 49 | model = Model(inputs, outputs, name="EfficientNetB0_UNET") 50 | return model 51 | 52 | if __name__ == "__main__": 53 | input_shape = (256, 256, 3) 54 | model = build_effienet_unet(input_shape) 55 | model.summary() 56 | -------------------------------------------------------------------------------- /TensorFlow/inception_resnetv2_unet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input, ZeroPadding2D 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.applications import InceptionResNetV2 4 | 5 | def conv_block(input, num_filters): 6 | x = Conv2D(num_filters, 3, padding="same")(input) 7 | x = BatchNormalization()(x) 8 | x = Activation("relu")(x) 9 | 10 | x = Conv2D(num_filters, 3, padding="same")(x) 11 | x = BatchNormalization()(x) 12 | x = Activation("relu")(x) 13 | 14 | return x 15 | 16 | def decoder_block(input, skip_features, num_filters): 17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input) 18 | x = Concatenate()([x, skip_features]) 19 | x = conv_block(x, num_filters) 20 | return x 21 | 22 | def build_inception_resnetv2_unet(input_shape): 23 | """ Input """ 24 | inputs = Input(input_shape) 25 | 26 | """ Pre-trained InceptionResNetV2 Model """ 27 | encoder = InceptionResNetV2(include_top=False, weights="imagenet", input_tensor=inputs) 28 | 29 | """ Encoder """ 30 | s1 = encoder.get_layer("input_1").output ## (512 x 512) 31 | 32 | s2 = encoder.get_layer("activation").output ## (255 x 255) 33 | s2 = ZeroPadding2D(( (1, 0), (1, 0) ))(s2) ## (256 x 256) 34 | 35 | s3 = encoder.get_layer("activation_3").output ## (126 x 126) 36 | s3 = ZeroPadding2D((1, 1))(s3) ## (128 x 128) 37 | 38 | s4 = encoder.get_layer("activation_74").output ## (61 x 61) 39 | s4 = ZeroPadding2D(( (2, 1),(2, 1) ))(s4) ## (64 x 64) 40 | 41 | """ Bridge """ 42 | b1 = encoder.get_layer("activation_161").output ## (30 x 30) 43 | b1 = ZeroPadding2D((1, 1))(b1) ## (32 x 32) 44 | 45 | """ Decoder """ 46 | d1 = 
decoder_block(b1, s4, 512) ## (64 x 64) 47 | d2 = decoder_block(d1, s3, 256) ## (128 x 128) 48 | d3 = decoder_block(d2, s2, 128) ## (256 x 256) 49 | d4 = decoder_block(d3, s1, 64) ## (512 x 512) 50 | 51 | """ Output """ 52 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 53 | 54 | model = Model(inputs, outputs, name="InceptionResNetV2_U-Net") 55 | return model 56 | 57 | if __name__ == "__main__": 58 | input_shape = (512, 512, 3) 59 | model = build_inception_resnetv2_unet(input_shape) 60 | model.summary() 61 | -------------------------------------------------------------------------------- /TensorFlow/mobilenetv2_unet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 3 | from tensorflow.keras.models import Model 4 | from tensorflow.keras.applications import MobileNetV2 5 | 6 | print("TF Version: ", tf.__version__) 7 | 8 | def conv_block(inputs, num_filters): 9 | x = Conv2D(num_filters, 3, padding="same")(inputs) 10 | x = BatchNormalization()(x) 11 | x = Activation("relu")(x) 12 | 13 | x = Conv2D(num_filters, 3, padding="same")(x) 14 | x = BatchNormalization()(x) 15 | x = Activation("relu")(x) 16 | 17 | return x 18 | 19 | def decoder_block(inputs, skip, num_filters): 20 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs) 21 | x = Concatenate()([x, skip]) 22 | x = conv_block(x, num_filters) 23 | 24 | return x 25 | 26 | def build_mobilenetv2_unet(input_shape): ## (512, 512, 3) 27 | """ Input """ 28 | inputs = Input(shape=input_shape) 29 | 30 | """ Pre-trained MobileNetV2 """ 31 | encoder = MobileNetV2(include_top=False, weights="imagenet", 32 | input_tensor=inputs, alpha=1.4) 33 | 34 | """ Encoder """ 35 | s1 = encoder.get_layer("input_1").output ## (512 x 512) 36 | s2 = encoder.get_layer("block_1_expand_relu").output ## (256 x 256) 37 | s3 = encoder.get_layer("block_3_expand_relu").output ## (128 x 128) 38 | s4 = encoder.get_layer("block_6_expand_relu").output ## (64 x 64) 39 | 40 | """ Bridge """ 41 | b1 = encoder.get_layer("block_13_expand_relu").output ## (32 x 32) 42 | 43 | """ Decoder """ 44 | d1 = decoder_block(b1, s4, 512) ## (64 x 64) 45 | d2 = decoder_block(d1, s3, 256) ## (128 x 128) 46 | d3 = decoder_block(d2, s2, 128) ## (256 x 256) 47 | d4 = decoder_block(d3, s1, 64) ## (512 x 512) 48 | 49 | """ Output """ 50 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 51 | 52 | model = Model(inputs, outputs, name="MobileNetV2_U-Net") 53 | return model 54 | 55 | if __name__ == "__main__": 56 | model = build_mobilenetv2_unet((512, 512, 3)) 57 | model.summary() 58 | -------------------------------------------------------------------------------- /TensorFlow/multiresunet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Conv2DTranspose 2 | from tensorflow.keras.layers import Concatenate, Input 3 | from tensorflow.keras.models import Model 4 | 5 | def conv_block(x, num_filters, kernel_size, padding="same", act=True): 6 | x = Conv2D(num_filters, kernel_size, padding=padding, use_bias=False)(x) 7 | x = BatchNormalization()(x) 8 | if act: 9 | x = Activation("relu")(x) 10 | return x 11 | 12 | def multires_block(x, num_filters, alpha=1.67): 13 | W = num_filters * alpha 14 | 15 | x0 = x 16 | x1 = conv_block(x0, int(W*0.167), 3) 
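    # Added note, not in the original file: together with x1 above, the next two 3x3
    # convolutions use roughly 1/3 and 1/2 of W = num_filters * alpha (x1 uses ~1/6),
    # so the concatenated branch widths approximate the 3x3 / 5x5 / 7x7 multi-scale
    # paths of the original MultiRes block at a comparable parameter count.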
17 | x2 = conv_block(x1, int(W*0.333), 3) 18 | x3 = conv_block(x2, int(W*0.5), 3) 19 | xc = Concatenate()([x1, x2, x3]) 20 | xc = BatchNormalization()(xc) 21 | 22 | nf = int(W*0.167) + int(W*0.333) + int(W*0.5) 23 | sc = conv_block(x0, nf, 1, act=False) 24 | 25 | x = Activation("relu")(xc + sc) 26 | x = BatchNormalization()(x) 27 | return x 28 | 29 | def res_path(x, num_filters, length): 30 | for i in range(length): 31 | x0 = x 32 | x1 = conv_block(x0, num_filters, 3, act=False) 33 | sc = conv_block(x0, num_filters, 1, act=False) 34 | x = Activation("relu")(x1 + sc) 35 | x = BatchNormalization()(x) 36 | return x 37 | 38 | def encoder_block(x, num_filters, length): 39 | x = multires_block(x, num_filters) 40 | s = res_path(x, num_filters, length) 41 | p = MaxPooling2D((2, 2))(x) 42 | return s, p 43 | 44 | def decoder_block(x, skip, num_filters): 45 | x = Conv2DTranspose(num_filters, 2, strides=2, padding="same")(x) 46 | x = Concatenate()([x, skip]) 47 | x = multires_block(x, num_filters) 48 | return x 49 | 50 | def build_multiresunet(shape): 51 | """ Input """ 52 | inputs = Input(shape) 53 | 54 | """ Encoder """ 55 | p0 = inputs 56 | s1, p1 = encoder_block(p0, 32, 4) 57 | s2, p2 = encoder_block(p1, 64, 3) 58 | s3, p3 = encoder_block(p2, 128, 2) 59 | s4, p4 = encoder_block(p3, 256, 1) 60 | 61 | """ Bridge """ 62 | b1 = multires_block(p4, 512) 63 | 64 | """ Decoder """ 65 | d1 = decoder_block(b1, s4, 256) 66 | d2 = decoder_block(d1, s3, 128) 67 | d3 = decoder_block(d2, s2, 64) 68 | d4 = decoder_block(d3, s1, 32) 69 | 70 | """ Output """ 71 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 72 | 73 | """ Model """ 74 | model = Model(inputs, outputs, name="MultiResUNET") 75 | 76 | return model 77 | 78 | if __name__ == "__main__": 79 | shape = (256, 256, 3) 80 | model = build_multiresunet(shape) 81 | model.summary() 82 | -------------------------------------------------------------------------------- /TensorFlow/notebook/ColonSegNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ColonSegNet\n", 8 | "\n", 9 | "
\n", 10 | "\n", 11 | "\n", 12 | "Research Paper: Real-Time Polyp Detection, Localization and Segmentation in Colonoscopy Using Deep Learning \n", 13 | "\n", 14 | "
\n", 15 | "\n", 19 | " \n", 20 | "
\n", 21 | "\n", 22 | "" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Import" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Dense\n", 39 | "from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2DTranspose, Concatenate, Input\n", 40 | "from tensorflow.keras.layers import MaxPool2D\n", 41 | "from tensorflow.keras.models import Model" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## Squeeze and Excitation\n", 49 | "" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "def se_layer(x, num_filters, reduction=16):\n", 59 | " x_init = x\n", 60 | " \n", 61 | " x = GlobalAveragePooling2D()(x)\n", 62 | " x = Dense(num_filters//reduction, use_bias=False, activation=\"relu\")(x)\n", 63 | " x = Dense(num_filters, use_bias=False, activation=\"sigmoid\")(x)\n", 64 | " \n", 65 | " return x_init * x" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Residual Block\n", 73 | "" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def residual_block(x, num_filters):\n", 83 | " x_init = x\n", 84 | " \n", 85 | " x = Conv2D(num_filters, 3, padding=\"same\")(x)\n", 86 | " x = BatchNormalization()(x)\n", 87 | " x = Activation(\"relu\")(x)\n", 88 | " \n", 89 | " x = Conv2D(num_filters, 3, padding=\"same\")(x)\n", 90 | " x = BatchNormalization()(x)\n", 91 | " \n", 92 | " s = Conv2D(num_filters, 1, padding=\"same\")(x_init)\n", 93 | " s = BatchNormalization()(x)\n", 94 | " s = se_layer(s, num_filters)\n", 95 | " \n", 96 | " x = Activation(\"relu\")(x + s)\n", 97 | " \n", 98 | " return x" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "## Strided Convolution\n", 106 | "" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 4, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "def strided_conv_block(x, num_filters):\n", 116 | " x = Conv2D(num_filters, 3, strides=2, padding=\"same\")(x)\n", 117 | " x = BatchNormalization()(x)\n", 118 | " x = Activation(\"relu\")(x)\n", 119 | " return x" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## Encoder Block" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "def encoder_block(x, num_filters):\n", 136 | " x1 = residual_block(x, num_filters)\n", 137 | " x2 = strided_conv_block(x1, num_filters)\n", 138 | " x3 = residual_block(x2, num_filters)\n", 139 | " p = MaxPool2D((2, 2))(x3)\n", 140 | " \n", 141 | " return x1, x3, p" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "## ColonSegNet" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 12, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "def build_colonsegnet(input_shape):\n", 158 | " \"\"\" Input \"\"\"\n", 159 | " inputs = Input(input_shape)\n", 160 | " \n", 161 | " \"\"\" Encoder \"\"\"\n", 162 | " s11, s12, p1 = encoder_block(inputs, 64)\n", 163 | " s21, s22, p2 = 
encoder_block(p1, 256)\n", 164 | " \n", 165 | " \"\"\" Decoder 1 \"\"\"\n", 166 | " x = Conv2DTranspose(128, 4, strides=4, padding=\"same\")(s22)\n", 167 | " x = Concatenate()([x, s12])\n", 168 | " x = residual_block(x, 128)\n", 169 | " r1 = x\n", 170 | " \n", 171 | " x = Conv2DTranspose(128, 4, strides=2, padding=\"same\")(s21)\n", 172 | " x = Concatenate()([x, r1])\n", 173 | " x = residual_block(x, 128)\n", 174 | " \n", 175 | " \"\"\" Decoder 2 \"\"\"\n", 176 | " x = Conv2DTranspose(64, 4, strides=2, padding=\"same\")(x)\n", 177 | " x = Concatenate()([x, s11])\n", 178 | " x = residual_block(x, 64)\n", 179 | " r2 = x\n", 180 | " \n", 181 | " x = Conv2DTranspose(64, 4, strides=2, padding=\"same\")(s12)\n", 182 | " x = Concatenate()([x, r2])\n", 183 | " x = residual_block(x, 32)\n", 184 | " \n", 185 | " \"\"\" Output \"\"\"\n", 186 | " output = Conv2D(5, 1, padding=\"same\", activation=\"softmax\")(x)\n", 187 | " \n", 188 | " \"\"\" Model \"\"\"\n", 189 | " model = Model(inputs, output, name=\"ColonSegNet\")\n", 190 | " \n", 191 | " return model" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "## Model" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 13, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "input_shape = (512, 512, 3)\n", 208 | "model = build_colonsegnet(input_shape)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 14, 214 | "metadata": { 215 | "scrolled": false 216 | }, 217 | "outputs": [ 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "Model: \"ColonSegNet\"\n", 223 | "__________________________________________________________________________________________________\n", 224 | "Layer (type) Output Shape Param # Connected to \n", 225 | "==================================================================================================\n", 226 | "input_3 (InputLayer) [(None, 512, 512, 3) 0 \n", 227 | "__________________________________________________________________________________________________\n", 228 | "conv2d_54 (Conv2D) (None, 512, 512, 64) 1792 input_3[0][0] \n", 229 | "__________________________________________________________________________________________________\n", 230 | "batch_normalization_52 (BatchNo (None, 512, 512, 64) 256 conv2d_54[0][0] \n", 231 | "__________________________________________________________________________________________________\n", 232 | "activation_36 (Activation) (None, 512, 512, 64) 0 batch_normalization_52[0][0] \n", 233 | "__________________________________________________________________________________________________\n", 234 | "conv2d_55 (Conv2D) (None, 512, 512, 64) 36928 activation_36[0][0] \n", 235 | "__________________________________________________________________________________________________\n", 236 | "batch_normalization_53 (BatchNo (None, 512, 512, 64) 256 conv2d_55[0][0] \n", 237 | "__________________________________________________________________________________________________\n", 238 | "batch_normalization_54 (BatchNo (None, 512, 512, 64) 256 batch_normalization_53[0][0] \n", 239 | "__________________________________________________________________________________________________\n", 240 | "global_average_pooling2d_16 (Gl (None, 64) 0 batch_normalization_54[0][0] \n", 241 | "__________________________________________________________________________________________________\n", 242 | "dense_32 (Dense) (None, 4) 256 global_average_pooling2d_16[0][0]\n", 
243 | "__________________________________________________________________________________________________\n", 244 | "dense_33 (Dense) (None, 64) 256 dense_32[0][0] \n", 245 | "__________________________________________________________________________________________________\n", 246 | "tf.math.multiply_16 (TFOpLambda (None, 512, 512, 64) 0 batch_normalization_54[0][0] \n", 247 | " dense_33[0][0] \n", 248 | "__________________________________________________________________________________________________\n", 249 | "tf.__operators__.add_16 (TFOpLa (None, 512, 512, 64) 0 batch_normalization_53[0][0] \n", 250 | " tf.math.multiply_16[0][0] \n", 251 | "__________________________________________________________________________________________________\n", 252 | "activation_37 (Activation) (None, 512, 512, 64) 0 tf.__operators__.add_16[0][0] \n", 253 | "__________________________________________________________________________________________________\n", 254 | "conv2d_57 (Conv2D) (None, 256, 256, 64) 36928 activation_37[0][0] \n", 255 | "__________________________________________________________________________________________________\n", 256 | "batch_normalization_55 (BatchNo (None, 256, 256, 64) 256 conv2d_57[0][0] \n", 257 | "__________________________________________________________________________________________________\n", 258 | "activation_38 (Activation) (None, 256, 256, 64) 0 batch_normalization_55[0][0] \n", 259 | "__________________________________________________________________________________________________\n", 260 | "conv2d_58 (Conv2D) (None, 256, 256, 64) 36928 activation_38[0][0] \n", 261 | "__________________________________________________________________________________________________\n", 262 | "batch_normalization_56 (BatchNo (None, 256, 256, 64) 256 conv2d_58[0][0] \n", 263 | "__________________________________________________________________________________________________\n", 264 | "activation_39 (Activation) (None, 256, 256, 64) 0 batch_normalization_56[0][0] \n", 265 | "__________________________________________________________________________________________________\n", 266 | "conv2d_59 (Conv2D) (None, 256, 256, 64) 36928 activation_39[0][0] \n", 267 | "__________________________________________________________________________________________________\n", 268 | "batch_normalization_57 (BatchNo (None, 256, 256, 64) 256 conv2d_59[0][0] \n", 269 | "__________________________________________________________________________________________________\n", 270 | "batch_normalization_58 (BatchNo (None, 256, 256, 64) 256 batch_normalization_57[0][0] \n", 271 | "__________________________________________________________________________________________________\n", 272 | "global_average_pooling2d_17 (Gl (None, 64) 0 batch_normalization_58[0][0] \n", 273 | "__________________________________________________________________________________________________\n", 274 | "dense_34 (Dense) (None, 4) 256 global_average_pooling2d_17[0][0]\n", 275 | "__________________________________________________________________________________________________\n", 276 | "dense_35 (Dense) (None, 64) 256 dense_34[0][0] \n", 277 | "__________________________________________________________________________________________________\n", 278 | "tf.math.multiply_17 (TFOpLambda (None, 256, 256, 64) 0 batch_normalization_58[0][0] \n", 279 | " dense_35[0][0] \n", 280 | "__________________________________________________________________________________________________\n", 281 | "tf.__operators__.add_17 (TFOpLa (None, 256, 
256, 64) 0 batch_normalization_57[0][0] \n", 282 | " tf.math.multiply_17[0][0] \n", 283 | "__________________________________________________________________________________________________\n", 284 | "activation_40 (Activation) (None, 256, 256, 64) 0 tf.__operators__.add_17[0][0] \n", 285 | "__________________________________________________________________________________________________\n", 286 | "max_pooling2d_4 (MaxPooling2D) (None, 128, 128, 64) 0 activation_40[0][0] \n", 287 | "__________________________________________________________________________________________________\n", 288 | "conv2d_61 (Conv2D) (None, 128, 128, 256 147712 max_pooling2d_4[0][0] \n", 289 | "__________________________________________________________________________________________________\n", 290 | "batch_normalization_59 (BatchNo (None, 128, 128, 256 1024 conv2d_61[0][0] \n", 291 | "__________________________________________________________________________________________________\n", 292 | "activation_41 (Activation) (None, 128, 128, 256 0 batch_normalization_59[0][0] \n", 293 | "__________________________________________________________________________________________________\n", 294 | "conv2d_62 (Conv2D) (None, 128, 128, 256 590080 activation_41[0][0] \n", 295 | "__________________________________________________________________________________________________\n", 296 | "batch_normalization_60 (BatchNo (None, 128, 128, 256 1024 conv2d_62[0][0] \n", 297 | "__________________________________________________________________________________________________\n", 298 | "batch_normalization_61 (BatchNo (None, 128, 128, 256 1024 batch_normalization_60[0][0] \n", 299 | "__________________________________________________________________________________________________\n", 300 | "global_average_pooling2d_18 (Gl (None, 256) 0 batch_normalization_61[0][0] \n", 301 | "__________________________________________________________________________________________________\n", 302 | "dense_36 (Dense) (None, 16) 4096 global_average_pooling2d_18[0][0]\n", 303 | "__________________________________________________________________________________________________\n", 304 | "dense_37 (Dense) (None, 256) 4096 dense_36[0][0] \n", 305 | "__________________________________________________________________________________________________\n", 306 | "tf.math.multiply_18 (TFOpLambda (None, 128, 128, 256 0 batch_normalization_61[0][0] \n", 307 | " dense_37[0][0] \n", 308 | "__________________________________________________________________________________________________\n", 309 | "tf.__operators__.add_18 (TFOpLa (None, 128, 128, 256 0 batch_normalization_60[0][0] \n", 310 | " tf.math.multiply_18[0][0] \n", 311 | "__________________________________________________________________________________________________\n", 312 | "activation_42 (Activation) (None, 128, 128, 256 0 tf.__operators__.add_18[0][0] \n", 313 | "__________________________________________________________________________________________________\n", 314 | "conv2d_64 (Conv2D) (None, 64, 64, 256) 590080 activation_42[0][0] \n", 315 | "__________________________________________________________________________________________________\n", 316 | "batch_normalization_62 (BatchNo (None, 64, 64, 256) 1024 conv2d_64[0][0] \n", 317 | "__________________________________________________________________________________________________\n", 318 | "activation_43 (Activation) (None, 64, 64, 256) 0 batch_normalization_62[0][0] \n", 319 | 
"__________________________________________________________________________________________________\n", 320 | "conv2d_65 (Conv2D) (None, 64, 64, 256) 590080 activation_43[0][0] \n", 321 | "__________________________________________________________________________________________________\n", 322 | "batch_normalization_63 (BatchNo (None, 64, 64, 256) 1024 conv2d_65[0][0] \n", 323 | "__________________________________________________________________________________________________\n", 324 | "activation_44 (Activation) (None, 64, 64, 256) 0 batch_normalization_63[0][0] \n", 325 | "__________________________________________________________________________________________________\n", 326 | "conv2d_66 (Conv2D) (None, 64, 64, 256) 590080 activation_44[0][0] \n", 327 | "__________________________________________________________________________________________________\n", 328 | "batch_normalization_64 (BatchNo (None, 64, 64, 256) 1024 conv2d_66[0][0] \n", 329 | "__________________________________________________________________________________________________\n", 330 | "batch_normalization_65 (BatchNo (None, 64, 64, 256) 1024 batch_normalization_64[0][0] \n", 331 | "__________________________________________________________________________________________________\n", 332 | "global_average_pooling2d_19 (Gl (None, 256) 0 batch_normalization_65[0][0] \n", 333 | "__________________________________________________________________________________________________\n", 334 | "dense_38 (Dense) (None, 16) 4096 global_average_pooling2d_19[0][0]\n", 335 | "__________________________________________________________________________________________________\n", 336 | "dense_39 (Dense) (None, 256) 4096 dense_38[0][0] \n", 337 | "__________________________________________________________________________________________________\n", 338 | "tf.math.multiply_19 (TFOpLambda (None, 64, 64, 256) 0 batch_normalization_65[0][0] \n", 339 | " dense_39[0][0] \n", 340 | "__________________________________________________________________________________________________\n", 341 | "tf.__operators__.add_19 (TFOpLa (None, 64, 64, 256) 0 batch_normalization_64[0][0] \n", 342 | " tf.math.multiply_19[0][0] \n", 343 | "__________________________________________________________________________________________________\n", 344 | "activation_45 (Activation) (None, 64, 64, 256) 0 tf.__operators__.add_19[0][0] \n", 345 | "__________________________________________________________________________________________________\n", 346 | "conv2d_transpose_8 (Conv2DTrans (None, 256, 256, 128 524416 activation_45[0][0] \n", 347 | "__________________________________________________________________________________________________\n", 348 | "concatenate_8 (Concatenate) (None, 256, 256, 192 0 conv2d_transpose_8[0][0] \n", 349 | " activation_40[0][0] \n", 350 | "__________________________________________________________________________________________________\n", 351 | "conv2d_68 (Conv2D) (None, 256, 256, 128 221312 concatenate_8[0][0] \n", 352 | "__________________________________________________________________________________________________\n", 353 | "batch_normalization_66 (BatchNo (None, 256, 256, 128 512 conv2d_68[0][0] \n", 354 | "__________________________________________________________________________________________________\n", 355 | "activation_46 (Activation) (None, 256, 256, 128 0 batch_normalization_66[0][0] \n", 356 | "__________________________________________________________________________________________________\n", 357 | "conv2d_69 
(Conv2D) (None, 256, 256, 128 147584 activation_46[0][0] \n", 358 | "__________________________________________________________________________________________________\n", 359 | "batch_normalization_67 (BatchNo (None, 256, 256, 128 512 conv2d_69[0][0] \n", 360 | "__________________________________________________________________________________________________\n", 361 | "batch_normalization_68 (BatchNo (None, 256, 256, 128 512 batch_normalization_67[0][0] \n", 362 | "__________________________________________________________________________________________________\n", 363 | "global_average_pooling2d_20 (Gl (None, 128) 0 batch_normalization_68[0][0] \n", 364 | "__________________________________________________________________________________________________\n", 365 | "dense_40 (Dense) (None, 8) 1024 global_average_pooling2d_20[0][0]\n", 366 | "__________________________________________________________________________________________________\n", 367 | "dense_41 (Dense) (None, 128) 1024 dense_40[0][0] \n", 368 | "__________________________________________________________________________________________________\n", 369 | "tf.math.multiply_20 (TFOpLambda (None, 256, 256, 128 0 batch_normalization_68[0][0] \n", 370 | " dense_41[0][0] \n", 371 | "__________________________________________________________________________________________________\n", 372 | "tf.__operators__.add_20 (TFOpLa (None, 256, 256, 128 0 batch_normalization_67[0][0] \n", 373 | " tf.math.multiply_20[0][0] \n", 374 | "__________________________________________________________________________________________________\n", 375 | "conv2d_transpose_9 (Conv2DTrans (None, 256, 256, 128 524416 activation_42[0][0] \n", 376 | "__________________________________________________________________________________________________\n", 377 | "activation_47 (Activation) (None, 256, 256, 128 0 tf.__operators__.add_20[0][0] \n", 378 | "__________________________________________________________________________________________________\n", 379 | "concatenate_9 (Concatenate) (None, 256, 256, 256 0 conv2d_transpose_9[0][0] \n", 380 | " activation_47[0][0] \n", 381 | "__________________________________________________________________________________________________\n", 382 | "conv2d_71 (Conv2D) (None, 256, 256, 128 295040 concatenate_9[0][0] \n", 383 | "__________________________________________________________________________________________________\n", 384 | "batch_normalization_69 (BatchNo (None, 256, 256, 128 512 conv2d_71[0][0] \n", 385 | "__________________________________________________________________________________________________\n", 386 | "activation_48 (Activation) (None, 256, 256, 128 0 batch_normalization_69[0][0] \n", 387 | "__________________________________________________________________________________________________\n", 388 | "conv2d_72 (Conv2D) (None, 256, 256, 128 147584 activation_48[0][0] \n", 389 | "__________________________________________________________________________________________________\n", 390 | "batch_normalization_70 (BatchNo (None, 256, 256, 128 512 conv2d_72[0][0] \n", 391 | "__________________________________________________________________________________________________\n", 392 | "batch_normalization_71 (BatchNo (None, 256, 256, 128 512 batch_normalization_70[0][0] \n", 393 | "__________________________________________________________________________________________________\n", 394 | "global_average_pooling2d_21 (Gl (None, 128) 0 batch_normalization_71[0][0] \n", 395 | 
"__________________________________________________________________________________________________\n", 396 | "dense_42 (Dense) (None, 8) 1024 global_average_pooling2d_21[0][0]\n", 397 | "__________________________________________________________________________________________________\n", 398 | "dense_43 (Dense) (None, 128) 1024 dense_42[0][0] \n", 399 | "__________________________________________________________________________________________________\n", 400 | "tf.math.multiply_21 (TFOpLambda (None, 256, 256, 128 0 batch_normalization_71[0][0] \n", 401 | " dense_43[0][0] \n", 402 | "__________________________________________________________________________________________________\n", 403 | "tf.__operators__.add_21 (TFOpLa (None, 256, 256, 128 0 batch_normalization_70[0][0] \n", 404 | " tf.math.multiply_21[0][0] \n", 405 | "__________________________________________________________________________________________________\n", 406 | "activation_49 (Activation) (None, 256, 256, 128 0 tf.__operators__.add_21[0][0] \n", 407 | "__________________________________________________________________________________________________\n", 408 | "conv2d_transpose_10 (Conv2DTran (None, 512, 512, 64) 131136 activation_49[0][0] \n", 409 | "__________________________________________________________________________________________________\n", 410 | "concatenate_10 (Concatenate) (None, 512, 512, 128 0 conv2d_transpose_10[0][0] \n", 411 | " activation_37[0][0] \n", 412 | "__________________________________________________________________________________________________\n", 413 | "conv2d_74 (Conv2D) (None, 512, 512, 64) 73792 concatenate_10[0][0] \n", 414 | "__________________________________________________________________________________________________\n", 415 | "batch_normalization_72 (BatchNo (None, 512, 512, 64) 256 conv2d_74[0][0] \n", 416 | "__________________________________________________________________________________________________\n", 417 | "activation_50 (Activation) (None, 512, 512, 64) 0 batch_normalization_72[0][0] \n", 418 | "__________________________________________________________________________________________________\n", 419 | "conv2d_75 (Conv2D) (None, 512, 512, 64) 36928 activation_50[0][0] \n", 420 | "__________________________________________________________________________________________________\n", 421 | "batch_normalization_73 (BatchNo (None, 512, 512, 64) 256 conv2d_75[0][0] \n", 422 | "__________________________________________________________________________________________________\n", 423 | "batch_normalization_74 (BatchNo (None, 512, 512, 64) 256 batch_normalization_73[0][0] \n", 424 | "__________________________________________________________________________________________________\n", 425 | "global_average_pooling2d_22 (Gl (None, 64) 0 batch_normalization_74[0][0] \n", 426 | "__________________________________________________________________________________________________\n", 427 | "dense_44 (Dense) (None, 4) 256 global_average_pooling2d_22[0][0]\n", 428 | "__________________________________________________________________________________________________\n", 429 | "dense_45 (Dense) (None, 64) 256 dense_44[0][0] \n", 430 | "__________________________________________________________________________________________________\n", 431 | "tf.math.multiply_22 (TFOpLambda (None, 512, 512, 64) 0 batch_normalization_74[0][0] \n", 432 | " dense_45[0][0] \n", 433 | "__________________________________________________________________________________________________\n", 434 | 
"tf.__operators__.add_22 (TFOpLa (None, 512, 512, 64) 0 batch_normalization_73[0][0] \n", 435 | " tf.math.multiply_22[0][0] \n", 436 | "__________________________________________________________________________________________________\n", 437 | "conv2d_transpose_11 (Conv2DTran (None, 512, 512, 64) 65600 activation_40[0][0] \n", 438 | "__________________________________________________________________________________________________\n", 439 | "activation_51 (Activation) (None, 512, 512, 64) 0 tf.__operators__.add_22[0][0] \n", 440 | "__________________________________________________________________________________________________\n", 441 | "concatenate_11 (Concatenate) (None, 512, 512, 128 0 conv2d_transpose_11[0][0] \n", 442 | " activation_51[0][0] \n", 443 | "__________________________________________________________________________________________________\n", 444 | "conv2d_77 (Conv2D) (None, 512, 512, 32) 36896 concatenate_11[0][0] \n", 445 | "__________________________________________________________________________________________________\n", 446 | "batch_normalization_75 (BatchNo (None, 512, 512, 32) 128 conv2d_77[0][0] \n", 447 | "__________________________________________________________________________________________________\n", 448 | "activation_52 (Activation) (None, 512, 512, 32) 0 batch_normalization_75[0][0] \n", 449 | "__________________________________________________________________________________________________\n", 450 | "conv2d_78 (Conv2D) (None, 512, 512, 32) 9248 activation_52[0][0] \n", 451 | "__________________________________________________________________________________________________\n", 452 | "batch_normalization_76 (BatchNo (None, 512, 512, 32) 128 conv2d_78[0][0] \n", 453 | "__________________________________________________________________________________________________\n", 454 | "batch_normalization_77 (BatchNo (None, 512, 512, 32) 128 batch_normalization_76[0][0] \n", 455 | "__________________________________________________________________________________________________\n", 456 | "global_average_pooling2d_23 (Gl (None, 32) 0 batch_normalization_77[0][0] \n", 457 | "__________________________________________________________________________________________________\n", 458 | "dense_46 (Dense) (None, 2) 64 global_average_pooling2d_23[0][0]\n", 459 | "__________________________________________________________________________________________________\n", 460 | "dense_47 (Dense) (None, 32) 64 dense_46[0][0] \n", 461 | "__________________________________________________________________________________________________\n", 462 | "tf.math.multiply_23 (TFOpLambda (None, 512, 512, 32) 0 batch_normalization_77[0][0] \n", 463 | " dense_47[0][0] \n", 464 | "__________________________________________________________________________________________________\n", 465 | "tf.__operators__.add_23 (TFOpLa (None, 512, 512, 32) 0 batch_normalization_76[0][0] \n", 466 | " tf.math.multiply_23[0][0] \n", 467 | "__________________________________________________________________________________________________\n", 468 | "activation_53 (Activation) (None, 512, 512, 32) 0 tf.__operators__.add_23[0][0] \n", 469 | "__________________________________________________________________________________________________\n", 470 | "conv2d_80 (Conv2D) (None, 512, 512, 5) 165 activation_53[0][0] \n", 471 | "==================================================================================================\n", 472 | "Total params: 4,906,981\n", 473 | "Trainable params: 4,900,389\n", 474 | 
"Non-trainable params: 6,592\n", 475 | "__________________________________________________________________________________________________\n" 476 | ] 477 | } 478 | ], 479 | "source": [ 480 | "model.summary()" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [] 489 | } 490 | ], 491 | "metadata": { 492 | "kernelspec": { 493 | "display_name": "tf", 494 | "language": "python", 495 | "name": "tf" 496 | }, 497 | "language_info": { 498 | "codemirror_mode": { 499 | "name": "ipython", 500 | "version": 3 501 | }, 502 | "file_extension": ".py", 503 | "mimetype": "text/x-python", 504 | "name": "python", 505 | "nbconvert_exporter": "python", 506 | "pygments_lexer": "ipython3", 507 | "version": "3.8.10" 508 | } 509 | }, 510 | "nbformat": 4, 511 | "nbformat_minor": 4 512 | } 513 | -------------------------------------------------------------------------------- /TensorFlow/notebook/README.md: -------------------------------------------------------------------------------- 1 | # Jupyter Notebook 2 | -------------------------------------------------------------------------------- /TensorFlow/notebook/images/ColonSegNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/ColonSegNet.png -------------------------------------------------------------------------------- /TensorFlow/notebook/images/ResidualBlock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/ResidualBlock.png -------------------------------------------------------------------------------- /TensorFlow/notebook/images/Strided_Conv_Block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/Strided_Conv_Block.png -------------------------------------------------------------------------------- /TensorFlow/notebook/images/squeeze_and_excitation_detailed_block_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/squeeze_and_excitation_detailed_block_diagram.png -------------------------------------------------------------------------------- /TensorFlow/resnet50_unet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.applications import ResNet50 4 | 5 | def conv_block(input, num_filters): 6 | x = Conv2D(num_filters, 3, padding="same")(input) 7 | x = BatchNormalization()(x) 8 | x = Activation("relu")(x) 9 | 10 | x = Conv2D(num_filters, 3, padding="same")(x) 11 | x = BatchNormalization()(x) 12 | x = Activation("relu")(x) 13 | 14 | return x 15 | 16 | def decoder_block(input, skip_features, num_filters): 17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input) 18 | x = 
Concatenate()([x, skip_features]) 19 | x = conv_block(x, num_filters) 20 | return x 21 | 22 | def build_resnet50_unet(input_shape): 23 | """ Input """ 24 | inputs = Input(input_shape) 25 | 26 | """ Pre-trained ResNet50 Model """ 27 | resnet50 = ResNet50(include_top=False, weights="imagenet", input_tensor=inputs) 28 | 29 | """ Encoder """ 30 | s1 = resnet50.get_layer("input_1").output ## (512 x 512) 31 | s2 = resnet50.get_layer("conv1_relu").output ## (256 x 256) 32 | s3 = resnet50.get_layer("conv2_block3_out").output ## (128 x 128) 33 | s4 = resnet50.get_layer("conv3_block4_out").output ## (64 x 64) 34 | 35 | """ Bridge """ 36 | b1 = resnet50.get_layer("conv4_block6_out").output ## (32 x 32) 37 | 38 | """ Decoder """ 39 | d1 = decoder_block(b1, s4, 512) ## (64 x 64) 40 | d2 = decoder_block(d1, s3, 256) ## (128 x 128) 41 | d3 = decoder_block(d2, s2, 128) ## (256 x 256) 42 | d4 = decoder_block(d3, s1, 64) ## (512 x 512) 43 | 44 | """ Output """ 45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 46 | 47 | model = Model(inputs, outputs, name="ResNet50_U-Net") 48 | return model 49 | 50 | if __name__ == "__main__": 51 | input_shape = (512, 512, 3) 52 | model = build_resnet50_unet(input_shape) 53 | model.summary() 54 | -------------------------------------------------------------------------------- /TensorFlow/resunet++.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.layers as L 3 | from tensorflow.keras.models import Model 4 | 5 | def SE(inputs, ratio=8): 6 | ## [8, H, W, 32] 7 | channel_axis = -1 8 | num_filters = inputs.shape[channel_axis] 9 | se_shape = (1, 1, num_filters) 10 | 11 | x = L.GlobalAveragePooling2D()(inputs) ## [8, 32] 12 | x = L.Reshape(se_shape)(x) 13 | x = L.Dense(num_filters // ratio, activation='relu', use_bias=False)(x) 14 | x = L.Dense(num_filters, activation='sigmoid', use_bias=False)(x) 15 | 16 | x = L.Multiply()([inputs, x]) 17 | return x 18 | 19 | 20 | def stem_block(inputs, num_filters): 21 | ## Conv 1 22 | x = L.Conv2D(num_filters, 3, padding="same")(inputs) 23 | x = L.BatchNormalization()(x) 24 | x = L.Activation("relu")(x) 25 | x = L.Conv2D(num_filters, 3, padding="same")(x) 26 | 27 | ## Shortcut 28 | s = L.Conv2D(num_filters, 1, padding="same")(inputs) 29 | 30 | ## Add 31 | x = L.Add()([x, s]) 32 | return x 33 | 34 | def resnet_block(inputs, num_filters, strides=1): 35 | ## SE 36 | inputs = SE(inputs) 37 | 38 | ## Conv 1 39 | x = L.BatchNormalization()(inputs) 40 | x = L.Activation("relu")(x) 41 | x = L.Conv2D(num_filters, 3, padding="same", strides=strides)(x) 42 | 43 | ## Conv 2 44 | x = L.BatchNormalization()(x) 45 | x = L.Activation("relu")(x) 46 | x = L.Conv2D(num_filters, 3, padding="same", strides=1)(x) 47 | 48 | ## Shortcut 49 | s = L.Conv2D(num_filters, 1, padding="same", strides=strides)(inputs) 50 | 51 | ## Add 52 | x = L.Add()([x, s]) 53 | 54 | return x 55 | 56 | def aspp_block(inputs, num_filters): 57 | x1 = L.Conv2D(num_filters, 3, dilation_rate=6, padding="same")(inputs) 58 | x1 = L.BatchNormalization()(x1) 59 | 60 | x2 = L.Conv2D(num_filters, 3, dilation_rate=12, padding="same")(inputs) 61 | x2 = L.BatchNormalization()(x2) 62 | 63 | x3 = L.Conv2D(num_filters, 3, dilation_rate=18, padding="same")(inputs) 64 | x3 = L.BatchNormalization()(x3) 65 | 66 | x4 = L.Conv2D(num_filters, (3, 3), padding="same")(inputs) 67 | x4 = L.BatchNormalization()(x4) 68 | 69 | y = L.Add()([x1, x2, x3, x4]) 70 | y = L.Conv2D(num_filters, 1, padding="same")(y) 71 
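## The four parallel branches above (3x3 convolutions with dilation rates 6, 12
## and 18, plus a plain 3x3 branch) are summed and fused by the final 1x1
## convolution, giving the ASPP-style bridge a large receptive field at
## constant spatial resolution.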
| 72 | return y 73 | 74 | def attention_block(x1, x2): 75 | num_filters = x2.shape[-1] 76 | 77 | x1_conv = L.BatchNormalization()(x1) 78 | x1_conv = L.Activation("relu")(x1_conv) 79 | x1_conv = L.Conv2D(num_filters, 3, padding="same")(x1_conv) 80 | x1_pool = L.MaxPooling2D((2, 2))(x1_conv) 81 | 82 | x2_conv = L.BatchNormalization()(x2) 83 | x2_conv = L.Activation("relu")(x2_conv) 84 | x2_conv = L.Conv2D(num_filters, 3, padding="same")(x2_conv) 85 | 86 | x = L.Add()([x1_pool, x2_conv]) 87 | 88 | x = L.BatchNormalization()(x) 89 | x = L.Activation("relu")(x) 90 | x = L.Conv2D(num_filters, 3, padding="same")(x) 91 | 92 | x = L.Multiply()([x, x2]) 93 | return x 94 | 95 | def resunet_pp(input_shape): 96 | """ Inputs """ 97 | inputs = L.Input(input_shape) 98 | 99 | """ Encoder """ 100 | c1 = stem_block(inputs, 16) 101 | c2 = resnet_block(c1, 32, strides=2) 102 | c3 = resnet_block(c2, 64, strides=2) 103 | c4 = resnet_block(c3, 128, strides=2) 104 | 105 | """ Bridge """ 106 | b1 = aspp_block(c4, 256) 107 | 108 | """ Decoder """ 109 | d1 = attention_block(c3, b1) 110 | d1 = L.UpSampling2D((2, 2))(d1) 111 | d1 = L.Concatenate()([d1, c3]) 112 | d1 = resnet_block(d1, 128) 113 | 114 | d2 = attention_block(c2, d1) 115 | d2 = L.UpSampling2D((2, 2))(d2) 116 | d2 = L.Concatenate()([d2, c2]) 117 | d2 = resnet_block(d2, 64) 118 | 119 | d3 = attention_block(c1, d2) 120 | d3 = L.UpSampling2D((2, 2))(d3) 121 | d3 = L.Concatenate()([d3, c1]) 122 | d3 = resnet_block(d3, 32) 123 | 124 | """ Output """ 125 | outputs = aspp_block(d3, 16) 126 | outputs = L.Conv2D(1, 1, padding="same")(outputs) 127 | outputs = L.Activation("sigmoid")(outputs) 128 | 129 | """ Model """ 130 | model = Model(inputs, outputs) 131 | return model 132 | 133 | if __name__ == "__main__": 134 | input_shape = (256, 256, 3) 135 | model = resunet_pp(input_shape) 136 | model.summary() 137 | -------------------------------------------------------------------------------- /TensorFlow/resunet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | 4 | def batchnorm_relu(inputs): 5 | """ Batch Normalization & ReLU """ 6 | x = BatchNormalization()(inputs) 7 | x = Activation("relu")(x) 8 | return x 9 | 10 | def residual_block(inputs, num_filters, strides=1): 11 | """ Convolutional Layers """ 12 | x = batchnorm_relu(inputs) 13 | x = Conv2D(num_filters, 3, padding="same", strides=strides)(x) 14 | x = batchnorm_relu(x) 15 | x = Conv2D(num_filters, 3, padding="same", strides=1)(x) 16 | 17 | """ Shortcut Connection (Identity Mapping) """ 18 | s = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs) 19 | 20 | """ Addition """ 21 | x = x + s 22 | return x 23 | 24 | def decoder_block(inputs, skip_features, num_filters): 25 | """ Decoder Block """ 26 | 27 | x = UpSampling2D((2, 2))(inputs) 28 | x = Concatenate()([x, skip_features]) 29 | x = residual_block(x, num_filters, strides=1) 30 | return x 31 | 32 | def build_resunet(input_shape): 33 | """ RESUNET Architecture """ 34 | 35 | inputs = Input(input_shape) 36 | 37 | """ Encoder 1 """ 38 | x = Conv2D(64, 3, padding="same", strides=1)(inputs) 39 | x = batchnorm_relu(x) 40 | x = Conv2D(64, 3, padding="same", strides=1)(x) 41 | s = Conv2D(64, 1, padding="same")(inputs) 42 | s1 = x + s 43 | 44 | """ Encoder 2, 3 """ 45 | s2 = residual_block(s1, 128, strides=2) 46 | s3 = residual_block(s2, 256, strides=2) 47 | 48 
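## Each strided residual block above halves the spatial resolution, so for the
## 224 x 224 input used in __main__ the encoder features are 224, 112 and 56
## pixels wide at s1, s2 and s3, and 28 at the bridge below.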
| """ Bridge """ 49 | b = residual_block(s3, 512, strides=2) 50 | 51 | """ Decoder 1, 2, 3 """ 52 | x = decoder_block(b, s3, 256) 53 | x = decoder_block(x, s2, 128) 54 | x = decoder_block(x, s1, 64) 55 | 56 | """ Classifier """ 57 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x) 58 | 59 | """ Model """ 60 | model = Model(inputs, outputs, name="RESUNET") 61 | 62 | return model 63 | 64 | if __name__ == "__main__": 65 | shape = (224, 224, 3) 66 | model = build_resunet(shape) 67 | model.summary() 68 | -------------------------------------------------------------------------------- /TensorFlow/u2-net.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Add 7 | 8 | def conv_block(inputs, out_ch, rate=1): 9 | x = Conv2D(out_ch, 3, padding="same", dilation_rate=rate)(inputs) 10 | x = BatchNormalization()(x) 11 | x = Activation("relu")(x) 12 | return x 13 | 14 | def RSU_L(inputs, out_ch, int_ch, num_layers, rate=2): 15 | """ Initial Conv """ 16 | x = conv_block(inputs, out_ch) 17 | init_feats = x 18 | 19 | """ Encoder """ 20 | skip = [] 21 | x = conv_block(x, int_ch) 22 | skip.append(x) 23 | 24 | for i in range(num_layers-2): 25 | x = MaxPool2D((2, 2))(x) 26 | x = conv_block(x, int_ch) 27 | skip.append(x) 28 | 29 | """ Bridge """ 30 | x = conv_block(x, int_ch, rate=rate) 31 | 32 | """ Decoder """ 33 | skip.reverse() 34 | 35 | x = Concatenate()([x, skip[0]]) 36 | x = conv_block(x, int_ch) 37 | 38 | for i in range(num_layers-3): 39 | x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x) 40 | x = Concatenate()([x, skip[i+1]]) 41 | x = conv_block(x, int_ch) 42 | 43 | x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x) 44 | x = Concatenate()([x, skip[-1]]) 45 | x = conv_block(x, out_ch) 46 | 47 | """ Add """ 48 | x = Add()([x, init_feats]) 49 | return x 50 | 51 | def RSU_4F(inputs, out_ch, int_ch): 52 | """ Initial Conv """ 53 | x0 = conv_block(inputs, out_ch, rate=1) 54 | 55 | """ Encoder """ 56 | x1 = conv_block(x0, int_ch, rate=1) 57 | x2 = conv_block(x1, int_ch, rate=2) 58 | x3 = conv_block(x2, int_ch, rate=4) 59 | 60 | """ Bridge """ 61 | x4 = conv_block(x3, int_ch, rate=8) 62 | 63 | """ Decoder """ 64 | x = Concatenate()([x4, x3]) 65 | x = conv_block(x, int_ch, rate=4) 66 | 67 | x = Concatenate()([x, x2]) 68 | x = conv_block(x, int_ch, rate=2) 69 | 70 | x = Concatenate()([x, x1]) 71 | x = conv_block(x, out_ch, rate=1) 72 | 73 | """ Addition """ 74 | x = Add()([x, x0]) 75 | return x 76 | 77 | def u2net(input_shape, out_ch, int_ch, num_classes=1): 78 | """ Input Layer """ 79 | inputs = Input(input_shape) 80 | s0 = inputs 81 | 82 | """ Encoder """ 83 | s1 = RSU_L(s0, out_ch[0], int_ch[0], 7) 84 | p1 = MaxPool2D((2, 2))(s1) 85 | 86 | s2 = RSU_L(p1, out_ch[1], int_ch[1], 6) 87 | p2 = MaxPool2D((2, 2))(s2) 88 | 89 | s3 = RSU_L(p2, out_ch[2], int_ch[2], 5) 90 | p3 = MaxPool2D((2, 2))(s3) 91 | 92 | s4 = RSU_L(p3, out_ch[3], int_ch[3], 4) 93 | p4 = MaxPool2D((2, 2))(s4) 94 | 95 | s5 = RSU_4F(p4, out_ch[4], int_ch[4]) 96 | p5 = MaxPool2D((2, 2))(s5) 97 | 98 | """ Bridge """ 99 | b1 = RSU_4F(p5, out_ch[5], int_ch[5]) 100 | b2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(b1) 101 | 102 | """ Decoder """ 103 | d1 = Concatenate()([b2, s5]) 104 | d1 = RSU_4F(d1, out_ch[6], int_ch[6]) 105 | u1 = UpSampling2D(size=(2, 2), 
interpolation="bilinear")(d1) 106 | 107 | d2 = Concatenate()([u1, s4]) 108 | d2 = RSU_L(d2, out_ch[7], int_ch[7], 4) 109 | u2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d2) 110 | 111 | d3 = Concatenate()([u2, s3]) 112 | d3 = RSU_L(d3, out_ch[8], int_ch[8], 5) 113 | u3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d3) 114 | 115 | d4 = Concatenate()([u3, s2]) 116 | d4 = RSU_L(d4, out_ch[9], int_ch[9], 6) 117 | u4 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d4) 118 | 119 | d5 = Concatenate()([u4, s1]) 120 | d5 = RSU_L(d5, out_ch[10], int_ch[10], 7) 121 | 122 | """ Side Outputs """ 123 | y1 = Conv2D(num_classes, 3, padding="same")(d5) 124 | 125 | y2 = Conv2D(num_classes, 3, padding="same")(d4) 126 | y2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(y2) 127 | 128 | y3 = Conv2D(num_classes, 3, padding="same")(d3) 129 | y3 = UpSampling2D(size=(4, 4), interpolation="bilinear")(y3) 130 | 131 | y4 = Conv2D(num_classes, 3, padding="same")(d2) 132 | y4 = UpSampling2D(size=(8, 8), interpolation="bilinear")(y4) 133 | 134 | y5 = Conv2D(num_classes, 3, padding="same")(d1) 135 | y5 = UpSampling2D(size=(16, 16), interpolation="bilinear")(y5) 136 | 137 | y6 = Conv2D(num_classes, 3, padding="same")(b1) 138 | y6 = UpSampling2D(size=(32, 32), interpolation="bilinear")(y6) 139 | 140 | y0 = Concatenate()([y1, y2, y3, y4, y5, y6]) 141 | y0 = Conv2D(num_classes, 3, padding="same")(y0) 142 | 143 | y0 = Activation("sigmoid")(y0) 144 | y1 = Activation("sigmoid")(y1) 145 | y2 = Activation("sigmoid")(y2) 146 | y3 = Activation("sigmoid")(y3) 147 | y4 = Activation("sigmoid")(y4) 148 | y5 = Activation("sigmoid")(y5) 149 | y6 = Activation("sigmoid")(y6) 150 | 151 | model = tf.keras.models.Model(inputs, outputs=[y0, y1, y2, y3, y4, y5, y6]) 152 | return model 153 | 154 | def build_u2net(input_shape, num_classes=1): 155 | out_ch = [64, 128, 256, 512, 512, 512, 512, 256, 128, 64, 64] 156 | int_ch = [32, 32, 64, 128, 256, 256, 256, 128, 64, 32, 16] 157 | model = u2net(input_shape, out_ch, int_ch, num_classes=num_classes) 158 | return model 159 | 160 | def build_u2net_lite(input_shape, num_classes=1): 161 | out_ch = [64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64] 162 | int_ch = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16] 163 | model = u2net(input_shape, out_ch, int_ch, num_classes=num_classes) 164 | return model 165 | 166 | if __name__ == "__main__": 167 | model = build_u2net_lite((512, 512, 3)) 168 | model.summary() 169 | -------------------------------------------------------------------------------- /TensorFlow/unet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | 4 | def conv_block(input, num_filters): 5 | x = Conv2D(num_filters, 3, padding="same")(input) 6 | x = BatchNormalization()(x) 7 | x = Activation("relu")(x) 8 | 9 | x = Conv2D(num_filters, 3, padding="same")(x) 10 | x = BatchNormalization()(x) 11 | x = Activation("relu")(x) 12 | 13 | return x 14 | 15 | def encoder_block(input, num_filters): 16 | x = conv_block(input, num_filters) 17 | p = MaxPool2D((2, 2))(x) 18 | return x, p 19 | 20 | def decoder_block(input, skip_features, num_filters): 21 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input) 22 | x = Concatenate()([x, skip_features]) 23 | x = conv_block(x, num_filters) 24 | return x 25 | 26 | def build_unet(input_shape): 27 | inputs = 
Input(input_shape) 28 | 29 | s1, p1 = encoder_block(inputs, 64) 30 | s2, p2 = encoder_block(p1, 128) 31 | s3, p3 = encoder_block(p2, 256) 32 | s4, p4 = encoder_block(p3, 512) 33 | 34 | b1 = conv_block(p4, 1024) 35 | 36 | d1 = decoder_block(b1, s4, 512) 37 | d2 = decoder_block(d1, s3, 256) 38 | d3 = decoder_block(d2, s2, 128) 39 | d4 = decoder_block(d3, s1, 64) 40 | 41 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 42 | 43 | model = Model(inputs, outputs, name="U-Net") 44 | return model 45 | 46 | if __name__ == "__main__": 47 | input_shape = (512, 512, 3) 48 | model = build_unet(input_shape) 49 | model.summary() 50 | -------------------------------------------------------------------------------- /TensorFlow/unetr_2d.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import tensorflow.keras.layers as L 4 | from tensorflow.keras.models import Model 5 | 6 | def mlp(x, cf): 7 | x = L.Dense(cf["mlp_dim"], activation="gelu")(x) 8 | x = L.Dropout(cf["dropout_rate"])(x) 9 | x = L.Dense(cf["hidden_dim"])(x) 10 | x = L.Dropout(cf["dropout_rate"])(x) 11 | return x 12 | 13 | def transformer_encoder(x, cf): 14 | skip_1 = x 15 | x = L.LayerNormalization()(x) 16 | x = L.MultiHeadAttention( 17 | num_heads=cf["num_heads"], key_dim=cf["hidden_dim"] 18 | )(x, x) 19 | x = L.Add()([x, skip_1]) 20 | 21 | skip_2 = x 22 | x = L.LayerNormalization()(x) 23 | x = mlp(x, cf) 24 | x = L.Add()([x, skip_2]) 25 | 26 | return x 27 | 28 | def conv_block(x, num_filters, kernel_size=3): 29 | x = L.Conv2D(num_filters, kernel_size=kernel_size, padding="same")(x) 30 | x = L.BatchNormalization()(x) 31 | x = L.ReLU()(x) 32 | return x 33 | 34 | def deconv_block(x, num_filters): 35 | x = L.Conv2DTranspose(num_filters, kernel_size=2, padding="same", strides=2)(x) 36 | return x 37 | 38 | def build_unetr_2d(cf): 39 | """ Inputs """ 40 | input_shape = (cf["num_patches"], cf["patch_size"]*cf["patch_size"]*cf["num_channels"]) 41 | inputs = L.Input(input_shape) ## (None, 256, 768) 42 | 43 | """ Patch + Position Embeddings """ 44 | patch_embed = L.Dense(cf["hidden_dim"])(inputs) ## (None, 256, 768) 45 | 46 | positions = tf.range(start=0, limit=cf["num_patches"], delta=1) ## (256,) 47 | pos_embed = L.Embedding(input_dim=cf["num_patches"], output_dim=cf["hidden_dim"])(positions) ## (256, 768) 48 | x = patch_embed + pos_embed ## (None, 256, 768) 49 | 50 | """ Transformer Encoder """ 51 | skip_connection_index = [3, 6, 9, 12] 52 | skip_connections = [] 53 | 54 | for i in range(1, cf["num_layers"]+1, 1): 55 | x = transformer_encoder(x, cf) 56 | 57 | if i in skip_connection_index: 58 | skip_connections.append(x) 59 | 60 | """ CNN Decoder """ 61 | z3, z6, z9, z12 = skip_connections 62 | 63 | ## Reshaping 64 | z0 = L.Reshape((cf["image_size"], cf["image_size"], cf["num_channels"]))(inputs) 65 | z3 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z3) 66 | z6 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z6) 67 | z9 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z9) 68 | z12 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z12) 69 | 70 | ## Decoder 1 71 | x = deconv_block(z12, 512) 72 | 73 | s = deconv_block(z9, 512) 74 | s = conv_block(s, 512) 75 | x = L.Concatenate()([x, s]) 76 | 77 | x = conv_block(x, 512) 78 | x = conv_block(x, 512) 79 | 80 | ## Decoder 2 81 | x = deconv_block(x, 256) 82 | 83 | s = deconv_block(z6, 256) 84 | s = conv_block(s, 256) 85 | s = deconv_block(s, 256) 
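## z6 was reshaped to the 16 x 16 patch grid, so it is upsampled twice (2x per
## deconv_block) to reach the 64 x 64 resolution of the main decoder path here
## (for the 256 x 256 image / patch size 16 configuration used in __main__)
## before the two branches are concatenated.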
86 | s = conv_block(s, 256) 87 | 88 | x = L.Concatenate()([x, s]) 89 | x = conv_block(x, 256) 90 | x = conv_block(x, 256) 91 | 92 | ## Decoder 3 93 | x = deconv_block(x, 128) 94 | 95 | s = deconv_block(z3, 128) 96 | s = conv_block(s, 128) 97 | s = deconv_block(s, 128) 98 | s = conv_block(s, 128) 99 | s = deconv_block(s, 128) 100 | s = conv_block(s, 128) 101 | 102 | x = L.Concatenate()([x, s]) 103 | x = conv_block(x, 128) 104 | x = conv_block(x, 128) 105 | 106 | ## Decoder 4 107 | x = deconv_block(x, 64) 108 | 109 | s = conv_block(z0, 64) 110 | s = conv_block(s, 64) 111 | 112 | x = L.Concatenate()([x, s]) 113 | x = conv_block(x, 64) 114 | x = conv_block(x, 64) 115 | 116 | """ Output """ 117 | outputs = L.Conv2D(1, kernel_size=1, padding="same", activation="sigmoid")(x) 118 | 119 | return Model(inputs, outputs, name="UNETR_2D") 120 | 121 | if __name__ == "__main__": 122 | config = {} 123 | config["image_size"] = 256 124 | config["num_layers"] = 12 125 | config["hidden_dim"] = 768 126 | config["mlp_dim"] = 3072 127 | config["num_heads"] = 12 128 | config["dropout_rate"] = 0.1 129 | config["num_patches"] = 256 130 | config["patch_size"] = 16 131 | config["num_channels"] = 3 132 | 133 | model = build_unetr_2d(config) 134 | model.summary() 135 | -------------------------------------------------------------------------------- /TensorFlow/vgg16_unet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.applications import VGG16 4 | 5 | def conv_block(input, num_filters): 6 | x = Conv2D(num_filters, 3, padding="same")(input) 7 | x = BatchNormalization()(x) 8 | x = Activation("relu")(x) 9 | 10 | x = Conv2D(num_filters, 3, padding="same")(x) 11 | x = BatchNormalization()(x) 12 | x = Activation("relu")(x) 13 | 14 | return x 15 | 16 | def decoder_block(input, skip_features, num_filters): 17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input) 18 | x = Concatenate()([x, skip_features]) 19 | x = conv_block(x, num_filters) 20 | return x 21 | 22 | def build_vgg16_unet(input_shape): 23 | """ Input """ 24 | inputs = Input(input_shape) 25 | 26 | """ Pre-trained VGG16 Model """ 27 | vgg16 = VGG16(include_top=False, weights="imagenet", input_tensor=inputs) 28 | 29 | """ Encoder """ 30 | s1 = vgg16.get_layer("block1_conv2").output ## (512 x 512) 31 | s2 = vgg16.get_layer("block2_conv2").output ## (256 x 256) 32 | s3 = vgg16.get_layer("block3_conv3").output ## (128 x 128) 33 | s4 = vgg16.get_layer("block4_conv3").output ## (64 x 64) 34 | 35 | """ Bridge """ 36 | b1 = vgg16.get_layer("block5_conv3").output ## (32 x 32) 37 | 38 | """ Decoder """ 39 | d1 = decoder_block(b1, s4, 512) ## (64 x 64) 40 | d2 = decoder_block(d1, s3, 256) ## (128 x 128) 41 | d3 = decoder_block(d2, s2, 128) ## (256 x 256) 42 | d4 = decoder_block(d3, s1, 64) ## (512 x 512) 43 | 44 | """ Output """ 45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 46 | 47 | model = Model(inputs, outputs, name="VGG16_U-Net") 48 | return model 49 | 50 | if __name__ == "__main__": 51 | input_shape = (512, 512, 3) 52 | model = build_vgg16_unet(input_shape) 53 | model.summary() 54 | -------------------------------------------------------------------------------- /TensorFlow/vgg19_unet.py: -------------------------------------------------------------------------------- 1 | from 
tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.applications import VGG19 4 | 5 | def conv_block(input, num_filters): 6 | x = Conv2D(num_filters, 3, padding="same")(input) 7 | x = BatchNormalization()(x) 8 | x = Activation("relu")(x) 9 | 10 | x = Conv2D(num_filters, 3, padding="same")(x) 11 | x = BatchNormalization()(x) 12 | x = Activation("relu")(x) 13 | 14 | return x 15 | 16 | def decoder_block(input, skip_features, num_filters): 17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input) 18 | x = Concatenate()([x, skip_features]) 19 | x = conv_block(x, num_filters) 20 | return x 21 | 22 | def build_vgg19_unet(input_shape): 23 | """ Input """ 24 | inputs = Input(input_shape) 25 | 26 | """ Pre-trained VGG19 Model """ 27 | vgg19 = VGG19(include_top=False, weights="imagenet", input_tensor=inputs) 28 | 29 | """ Encoder """ 30 | s1 = vgg19.get_layer("block1_conv2").output ## (512 x 512) 31 | s2 = vgg19.get_layer("block2_conv2").output ## (256 x 256) 32 | s3 = vgg19.get_layer("block3_conv4").output ## (128 x 128) 33 | s4 = vgg19.get_layer("block4_conv4").output ## (64 x 64) 34 | 35 | """ Bridge """ 36 | b1 = vgg19.get_layer("block5_conv4").output ## (32 x 32) 37 | 38 | """ Decoder """ 39 | d1 = decoder_block(b1, s4, 512) ## (64 x 64) 40 | d2 = decoder_block(d1, s3, 256) ## (128 x 128) 41 | d3 = decoder_block(d2, s2, 128) ## (256 x 256) 42 | d4 = decoder_block(d3, s1, 64) ## (512 x 512) 43 | 44 | """ Output """ 45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4) 46 | 47 | model = Model(inputs, outputs, name="VGG19_U-Net") 48 | return model 49 | 50 | if __name__ == "__main__": 51 | input_shape = (512, 512, 3) 52 | model = build_vgg19_unet(input_shape) 53 | model.summary() 54 | --------------------------------------------------------------------------------