├── LICENSE
├── PyTorch
│   ├── README.md
│   ├── attention-unet.py
│   ├── multiresunet.py
│   ├── resunet.py
│   ├── unet.py
│   └── unetr_2d.py
├── README.md
└── TensorFlow
    ├── attention-unet.py
    ├── colonsegnet.py
    ├── deeplabv3plus.py
    ├── densenet121.py
    ├── doubleunet.py
    ├── efficientnetb0_unet.py
    ├── inception_resnetv2_unet.py
    ├── mobilenetv2_unet.py
    ├── multiresunet.py
    ├── notebook
    │   ├── ColonSegNet.ipynb
    │   ├── README.md
    │   └── images
    │       ├── ColonSegNet.png
    │       ├── ResidualBlock.png
    │       ├── Strided_Conv_Block.png
    │       └── squeeze_and_excitation_detailed_block_diagram.png
    ├── resnet50_unet.py
    ├── resunet++.py
    ├── resunet.py
    ├── u2-net.py
    ├── unet.py
    ├── unetr_2d.py
    ├── vgg16_unet.py
    └── vgg19_unet.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/PyTorch/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | This directory contains implementations of different segmentation models in the PyTorch framework.
3 |
--------------------------------------------------------------------------------
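The models below are plain `nn.Module` classes, so a minimal usage sketch looks like this (illustrative only, not part of the repository; it assumes the `PyTorch/` directory is on `PYTHONPATH` so that `unet.py` is importable as `unet`):

```python
# Illustrative sketch: build the U-Net from PyTorch/unet.py and run a forward pass.
import torch
from unet import build_unet  # assumes PyTorch/ is on PYTHONPATH

model = build_unet()                  # 3-channel input, 1-channel output map
x = torch.randn((2, 3, 512, 512))     # batch of 2 RGB images
with torch.no_grad():
    y = model(x)                      # raw logits; apply torch.sigmoid() for probabilities
print(y.shape)                        # torch.Size([2, 1, 512, 512])
```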
/PyTorch/attention-unet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 | class conv_block(nn.Module):
6 | def __init__(self, in_c, out_c):
7 | super().__init__()
8 |
9 | self.conv = nn.Sequential(
10 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
11 | nn.BatchNorm2d(out_c),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
14 | nn.BatchNorm2d(out_c),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 | def forward(self, x):
19 | return self.conv(x)
20 |
21 | class encoder_block(nn.Module):
22 | def __init__(self, in_c, out_c):
23 | super().__init__()
24 |
25 | self.conv = conv_block(in_c, out_c)
26 | self.pool = nn.MaxPool2d((2, 2))
27 |
28 | def forward(self, x):
29 | s = self.conv(x)
30 | p = self.pool(s)
31 | return s, p
32 |
33 | class attention_gate(nn.Module):
34 | def __init__(self, in_c, out_c):
35 | super().__init__()
36 |
37 | self.Wg = nn.Sequential(
38 | nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
39 | nn.BatchNorm2d(out_c)
40 | )
41 | self.Ws = nn.Sequential(
42 | nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
43 | nn.BatchNorm2d(out_c)
44 | )
45 | self.relu = nn.ReLU(inplace=True)
46 | self.output = nn.Sequential(
47 | nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
48 | nn.Sigmoid()
49 | )
50 |
51 | def forward(self, g, s):
52 | Wg = self.Wg(g)
53 | Ws = self.Ws(s)
54 | out = self.relu(Wg + Ws)
55 | out = self.output(out)
56 | return out * s
57 |
58 | class decoder_block(nn.Module):
59 | def __init__(self, in_c, out_c):
60 | super().__init__()
61 |
62 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
63 | self.ag = attention_gate(in_c, out_c)
64 | self.c1 = conv_block(in_c[0]+out_c, out_c)
65 |
66 | def forward(self, x, s):
67 | x = self.up(x)
68 | s = self.ag(x, s)
69 | x = torch.cat([x, s], axis=1)
70 | x = self.c1(x)
71 | return x
72 |
73 | class attention_unet(nn.Module):
74 | def __init__(self):
75 | super().__init__()
76 |
77 | self.e1 = encoder_block(3, 64)
78 | self.e2 = encoder_block(64, 128)
79 | self.e3 = encoder_block(128, 256)
80 |
81 | self.b1 = conv_block(256, 512)
82 |
83 | self.d1 = decoder_block([512, 256], 256)
84 | self.d2 = decoder_block([256, 128], 128)
85 | self.d3 = decoder_block([128, 64], 64)
86 |
87 | self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
88 |
89 | def forward(self, x):
90 | s1, p1 = self.e1(x)
91 | s2, p2 = self.e2(p1)
92 | s3, p3 = self.e3(p2)
93 |
94 | b1 = self.b1(p3)
95 |
96 | d1 = self.d1(b1, s3)
97 | d2 = self.d2(d1, s2)
98 | d3 = self.d3(d2, s1)
99 |
100 | output = self.output(d3)
101 | return output
102 |
103 |
104 | if __name__ == "__main__":
105 | x = torch.randn((8, 3, 256, 256))
106 | model = attention_unet()
107 | output = model(x)
108 | print(output.shape)
109 |
--------------------------------------------------------------------------------
/PyTorch/multiresunet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class conv_block(nn.Module):
7 | def __init__(self, in_c, out_c, kernel_size=3, padding=1, act=True):
8 | super().__init__()
9 |
10 | layers = [
11 | nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding, bias=False),
12 | nn.BatchNorm2d(out_c)
13 | ]
14 | if act == True:
15 | layers.append(nn.ReLU(inplace=True))
16 |
17 | self.conv = nn.Sequential(*layers)
18 |
19 | def forward(self, x):
20 | return self.conv(x)
21 |
22 | class multires_block(nn.Module):
23 | def __init__(self, in_c, out_c, alpha=1.67):
24 | super().__init__()
25 |
26 | W = out_c * alpha
27 | self.c1 = conv_block(in_c, int(W*0.167))
28 | self.c2 = conv_block(int(W*0.167), int(W*0.333))
29 | self.c3 = conv_block(int(W*0.333), int(W*0.5))
30 |
31 | nf = int(W*0.167) + int(W*0.333) + int(W*0.5)
32 | self.b1 = nn.BatchNorm2d(nf)
33 | self.c4 = conv_block(in_c, nf)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.b2 = nn.BatchNorm2d(nf)
36 |
37 | def forward(self, x):
38 | x0 = x
39 | x1 = self.c1(x0)
40 | x2 = self.c2(x1)
41 | x3 = self.c3(x2)
42 | xc = torch.cat([x1, x2, x3], dim=1)
43 | xc = self.b1(xc)
44 |
45 | sc = self.c4(x0)
46 | x = self.relu(xc + sc)
47 | x = self.b2(x)
48 | return x
49 |
50 | class res_path_block(nn.Module):
51 | def __init__(self, in_c, out_c):
52 | super().__init__()
53 |
54 | self.c1 = conv_block(in_c, out_c, act=False)
55 | self.s1 = conv_block(in_c, out_c, kernel_size=1, padding=0, act=False)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.bn = nn.BatchNorm2d(out_c)
58 |
59 | def forward(self, x):
60 | x1 = self.c1(x)
61 | s1 = self.s1(x)
62 | x = self.relu(x1 + s1)
63 | x = self.bn(x)
64 | return x
65 |
66 | class res_path(nn.Module):
67 | def __init__(self, in_c, out_c, length):
68 | super().__init__()
69 |
70 | layers = []
71 | for i in range(length):
72 | layers.append(res_path_block(in_c, out_c))
73 | in_c = out_c
74 |
75 | self.conv = nn.Sequential(*layers)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
79 |
80 | def cal_nf(ch, alpha=1.67):
81 | W = ch * alpha
82 | return int(W*0.167) + int(W*0.333) + int(W*0.5)
83 |
84 | class encoder_block(nn.Module):
85 | def __init__(self, in_c, out_c, length):
86 | super().__init__()
87 |
88 | self.c1 = multires_block(in_c, out_c)
89 | nf = cal_nf(out_c)
90 | self.s1 = res_path(nf, out_c, length)
91 | self.pool = nn.MaxPool2d((2, 2))
92 |
93 | def forward(self, x):
94 | x = self.c1(x)
95 | s = self.s1(x)
96 | p = self.pool(x)
97 | return s, p
98 |
99 | class decoder_block(nn.Module):
100 | def __init__(self, in_c, out_c):
101 | super().__init__()
102 |
103 | self.c1 = nn.ConvTranspose2d(in_c[0], out_c, kernel_size=2, stride=2, padding=0)
104 | self.c2 = multires_block(out_c+in_c[1], out_c)
105 |
106 | def forward(self, x, s):
107 | x = self.c1(x)
108 | x = torch.cat([x, s], dim=1)
109 | x = self.c2(x)
110 | return x
111 |
112 | class build_multiresunet(nn.Module):
113 | def __init__(self):
114 | super().__init__()
115 |
116 | """ Encoder """
117 | self.e1 = encoder_block(3, 32, 4)
118 | self.e2 = encoder_block(cal_nf(32), 64, 3)
119 | self.e3 = encoder_block(cal_nf(64), 128, 2)
120 | self.e4 = encoder_block(cal_nf(128), 256, 1)
121 |
122 | """ Bridge """
123 | self.b1 = multires_block(cal_nf(256), 512)
124 |
125 | """ Decoder """
126 | self.d1 = decoder_block([cal_nf(512), 256], 256)
127 | self.d2 = decoder_block([cal_nf(256), 128], 128)
128 | self.d3 = decoder_block([cal_nf(128), 64], 64)
129 | self.d4 = decoder_block([cal_nf(64), 32], 32)
130 |
131 | """ Output """
132 | self.output = nn.Conv2d(cal_nf(32), 1, kernel_size=1, padding=0)
133 |
134 | def forward(self, x):
135 | s1, p1 = self.e1(x)
136 | s2, p2 = self.e2(p1)
137 | s3, p3 = self.e3(p2)
138 | s4, p4 = self.e4(p3)
139 |
140 | b1 = self.b1(p4)
141 |
142 | d1 = self.d1(b1, s4)
143 | d2 = self.d2(d1, s3)
144 | d3 = self.d3(d2, s2)
145 | d4 = self.d4(d3, s1)
146 |
147 | output = self.output(d4)
148 | return output
149 |
150 | if __name__ == "__main__":
151 | x = torch.randn((8, 3, 256, 256))
152 | model = build_multiresunet()
153 | output = model(x)
154 | print(output.shape)
155 |
--------------------------------------------------------------------------------
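The channel arithmetic in `multires_block`/`cal_nf` above is easy to misread, so here is a short worked example (an editorial illustration, not part of the file): with the default `alpha=1.67` and `ch=32`, `W = 53.44`, and the three branch widths are `int(8.92) = 8`, `int(17.79) = 17` and `int(26.72) = 26`, so a `multires_block(in_c, 32)` emits `cal_nf(32) = 51` channels.

```python
# Re-derivation of cal_nf for the default alpha, mirroring the function above.
def cal_nf_demo(ch, alpha=1.67):
    W = ch * alpha
    return int(W * 0.167) + int(W * 0.333) + int(W * 0.5)

for ch in [32, 64, 128, 256, 512]:
    print(ch, "->", cal_nf_demo(ch))  # 32 -> 51, 64 -> 105, 128 -> 212, 256 -> 426, 512 -> 853
```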
/PyTorch/resunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class batchnorm_relu(nn.Module):
5 | def __init__(self, in_c):
6 | super().__init__()
7 |
8 | self.bn = nn.BatchNorm2d(in_c)
9 | self.relu = nn.ReLU()
10 |
11 | def forward(self, inputs):
12 | x = self.bn(inputs)
13 | x = self.relu(x)
14 | return x
15 |
16 | class residual_block(nn.Module):
17 | def __init__(self, in_c, out_c, stride=1):
18 | super().__init__()
19 |
20 | """ Convolutional layer """
21 | self.b1 = batchnorm_relu(in_c)
22 | self.c1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride)
23 | self.b2 = batchnorm_relu(out_c)
24 | self.c2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, stride=1)
25 |
26 | """ Shortcut Connection (Identity Mapping) """
27 | self.s = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0, stride=stride)
28 |
29 | def forward(self, inputs):
30 | x = self.b1(inputs)
31 | x = self.c1(x)
32 | x = self.b2(x)
33 | x = self.c2(x)
34 | s = self.s(inputs)
35 |
36 | skip = x + s
37 | return skip
38 |
39 | class decoder_block(nn.Module):
40 | def __init__(self, in_c, out_c):
41 | super().__init__()
42 |
43 | self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
44 | self.r = residual_block(in_c+out_c, out_c)
45 |
46 | def forward(self, inputs, skip):
47 | x = self.upsample(inputs)
48 | x = torch.cat([x, skip], axis=1)
49 | x = self.r(x)
50 | return x
51 |
52 | class build_resunet(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 |
56 | """ Encoder 1 """
57 | self.c11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
58 | self.br1 = batchnorm_relu(64)
59 | self.c12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
60 | self.c13 = nn.Conv2d(3, 64, kernel_size=1, padding=0)
61 |
62 | """ Encoder 2 and 3 """
63 | self.r2 = residual_block(64, 128, stride=2)
64 | self.r3 = residual_block(128, 256, stride=2)
65 |
66 | """ Bridge """
67 | self.r4 = residual_block(256, 512, stride=2)
68 |
69 | """ Decoder """
70 | self.d1 = decoder_block(512, 256)
71 | self.d2 = decoder_block(256, 128)
72 | self.d3 = decoder_block(128, 64)
73 |
74 | """ Output """
75 | self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
76 | self.sigmoid = nn.Sigmoid()
77 |
78 | def forward(self, inputs):
79 | """ Encoder 1 """
80 | x = self.c11(inputs)
81 | x = self.br1(x)
82 | x = self.c12(x)
83 | s = self.c13(inputs)
84 | skip1 = x + s
85 |
86 | """ Encoder 2 and 3 """
87 | skip2 = self.r2(skip1)
88 | skip3 = self.r3(skip2)
89 |
90 | """ Bridge """
91 | b = self.r4(skip3)
92 |
93 | """ Decoder """
94 | d1 = self.d1(b, skip3)
95 | d2 = self.d2(d1, skip2)
96 | d3 = self.d3(d2, skip1)
97 |
98 | """ output """
99 | output = self.output(d3)
100 | output = self.sigmoid(output)
101 |
102 | return output
103 |
104 |
105 | if __name__ == "__main__":
106 | inputs = torch.randn((4, 3, 256, 256))
107 | model = build_resunet()
108 | y = model(inputs)
109 | print(y.shape)
110 |
--------------------------------------------------------------------------------
/PyTorch/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | """ Convolutional block:
5 | It consists of two 3x3 convolutional layers, each followed by batch normalization and a ReLU activation.
6 | """
7 | class conv_block(nn.Module):
8 | def __init__(self, in_c, out_c):
9 | super().__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
12 | self.bn1 = nn.BatchNorm2d(out_c)
13 |
14 | self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
15 | self.bn2 = nn.BatchNorm2d(out_c)
16 |
17 | self.relu = nn.ReLU()
18 |
19 | def forward(self, inputs):
20 | x = self.conv1(inputs)
21 | x = self.bn1(x)
22 | x = self.relu(x)
23 |
24 | x = self.conv2(x)
25 | x = self.bn2(x)
26 | x = self.relu(x)
27 |
28 | return x
29 |
30 | """ Encoder block:
31 | It consists of a conv_block followed by max pooling.
32 | Here the number of filters doubles and the height and width halve after every block.
33 | """
34 | class encoder_block(nn.Module):
35 | def __init__(self, in_c, out_c):
36 | super().__init__()
37 |
38 | self.conv = conv_block(in_c, out_c)
39 | self.pool = nn.MaxPool2d((2, 2))
40 |
41 | def forward(self, inputs):
42 | x = self.conv(inputs)
43 | p = self.pool(x)
44 |
45 | return x, p
46 |
47 | """ Decoder block:
48 | The decoder block begins with a transpose convolution, followed by a concatenation with the skip
49 | connection from the encoder block. Next comes the conv_block.
50 | Here the number of filters halves and the height and width double.
51 | """
52 | class decoder_block(nn.Module):
53 | def __init__(self, in_c, out_c):
54 | super().__init__()
55 |
56 | self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
57 | self.conv = conv_block(out_c+out_c, out_c)
58 |
59 | def forward(self, inputs, skip):
60 | x = self.up(inputs)
61 | x = torch.cat([x, skip], axis=1)
62 | x = self.conv(x)
63 |
64 | return x
65 |
66 |
67 | class build_unet(nn.Module):
68 | def __init__(self):
69 | super().__init__()
70 |
71 | """ Encoder """
72 | self.e1 = encoder_block(3, 64)
73 | self.e2 = encoder_block(64, 128)
74 | self.e3 = encoder_block(128, 256)
75 | self.e4 = encoder_block(256, 512)
76 |
77 | """ Bottleneck """
78 | self.b = conv_block(512, 1024)
79 |
80 | """ Decoder """
81 | self.d1 = decoder_block(1024, 512)
82 | self.d2 = decoder_block(512, 256)
83 | self.d3 = decoder_block(256, 128)
84 | self.d4 = decoder_block(128, 64)
85 |
86 | """ Classifier """
87 | self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)
88 |
89 | def forward(self, inputs):
90 | """ Encoder """
91 | s1, p1 = self.e1(inputs)
92 | s2, p2 = self.e2(p1)
93 | s3, p3 = self.e3(p2)
94 | s4, p4 = self.e4(p3)
95 |
96 | """ Bottleneck """
97 | b = self.b(p4)
98 |
99 | """ Decoder """
100 | d1 = self.d1(b, s4)
101 | d2 = self.d2(d1, s3)
102 | d3 = self.d3(d2, s2)
103 | d4 = self.d4(d3, s1)
104 |
105 | """ Classifier """
106 | outputs = self.outputs(d4)
107 |
108 | return outputs
109 |
110 | if __name__ == "__main__":
111 | # inputs = torch.randn((2, 32, 256, 256))
112 | # e = encoder_block(32, 64)
113 | # x, p = e(inputs)
114 | # print(x.shape, p.shape)
115 | #
116 | # d = decoder_block(64, 32)
117 | # y = d(p, x)
118 | # print(y.shape)
119 |
120 | inputs = torch.randn((2, 3, 512, 512))
121 | model = build_unet()
122 | y = model(inputs)
123 | print(y.shape)
124 |
--------------------------------------------------------------------------------
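As the docstrings in `unet.py` state, every encoder block doubles the number of filters while halving the spatial size, and every decoder block does the reverse. A small shape check (illustrative only, assuming `unet.py` is importable), mirroring the commented-out test in the `__main__` block:

```python
import torch
from unet import encoder_block, decoder_block  # assumes PyTorch/ is on PYTHONPATH

e = encoder_block(3, 64)
s, p = e(torch.randn((1, 3, 256, 256)))
print(s.shape, p.shape)   # torch.Size([1, 64, 256, 256]) torch.Size([1, 64, 128, 128])

d = decoder_block(128, 64)
y = d(torch.randn((1, 128, 64, 64)), torch.randn((1, 64, 128, 128)))
print(y.shape)            # torch.Size([1, 64, 128, 128])
```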
/PyTorch/unetr_2d.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class ConvBlock(nn.Module):
7 | def __init__(self, in_c, out_c, kernel_size=3, padding=1):
8 | super().__init__()
9 |
10 | self.layers = nn.Sequential(
11 | nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding),
12 | nn.BatchNorm2d(out_c),
13 | nn.ReLU(inplace=True)
14 | )
15 |
16 |
17 | def forward(self, x):
18 | return self.layers(x)
19 |
20 |
21 | class DeconvBlock(nn.Module):
22 | def __init__(self, in_c, out_c):
23 | super().__init__()
24 |
25 | self.deconv = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
26 |
27 | def forward(self, x):
28 | return self.deconv(x)
29 |
30 |
31 | class UNETR_2D(nn.Module):
32 | def __init__(self, cf):
33 | super().__init__()
34 | self.cf = cf
35 |
36 | """ Patch + Position Embeddings """
37 | self.patch_embed = nn.Linear(
38 | cf["patch_size"]*cf["patch_size"]*cf["num_channels"],
39 | cf["hidden_dim"]
40 | )
41 |
42 |         self.register_buffer("positions", torch.arange(cf["num_patches"], dtype=torch.long), persistent=False)  ## buffer so the indices follow the model's device
43 | self.pos_embed = nn.Embedding(cf["num_patches"], cf["hidden_dim"])
44 |
45 | """ Transformer Encoder """
46 |         self.trans_encoder_layers = nn.ModuleList()  ## ModuleList registers the transformer layers as sub-modules
47 |
48 | for i in range(cf["num_layers"]):
49 | layer = nn.TransformerEncoderLayer(
50 | d_model=cf["hidden_dim"],
51 | nhead=cf["num_heads"],
52 | dim_feedforward=cf["mlp_dim"],
53 | dropout=cf["dropout_rate"],
54 | activation=nn.GELU(),
55 | batch_first=True
56 | )
57 | self.trans_encoder_layers.append(layer)
58 |
59 | """ CNN Decoder """
60 | ## Decoder 1
61 | self.d1 = DeconvBlock(cf["hidden_dim"], 512)
62 | self.s1 = nn.Sequential(
63 | DeconvBlock(cf["hidden_dim"], 512),
64 | ConvBlock(512, 512)
65 | )
66 | self.c1 = nn.Sequential(
67 | ConvBlock(512+512, 512),
68 | ConvBlock(512, 512)
69 | )
70 |
71 | ## Decoder 2
72 | self.d2 = DeconvBlock(512, 256)
73 | self.s2 = nn.Sequential(
74 | DeconvBlock(cf["hidden_dim"], 256),
75 | ConvBlock(256, 256),
76 | DeconvBlock(256, 256),
77 | ConvBlock(256, 256)
78 | )
79 | self.c2 = nn.Sequential(
80 | ConvBlock(256+256, 256),
81 | ConvBlock(256, 256)
82 | )
83 |
84 | ## Decoder 3
85 | self.d3 = DeconvBlock(256, 128)
86 | self.s3 = nn.Sequential(
87 | DeconvBlock(cf["hidden_dim"], 128),
88 | ConvBlock(128, 128),
89 | DeconvBlock(128, 128),
90 | ConvBlock(128, 128),
91 | DeconvBlock(128, 128),
92 | ConvBlock(128, 128)
93 | )
94 | self.c3 = nn.Sequential(
95 | ConvBlock(128+128, 128),
96 | ConvBlock(128, 128)
97 | )
98 |
99 | ## Decoder 4
100 | self.d4 = DeconvBlock(128, 64)
101 | self.s4 = nn.Sequential(
102 | ConvBlock(3, 64),
103 | ConvBlock(64, 64)
104 | )
105 | self.c4 = nn.Sequential(
106 | ConvBlock(64+64, 64),
107 | ConvBlock(64, 64)
108 | )
109 |
110 | """ Output """
111 | self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
112 |
113 | def forward(self, inputs):
114 | """ Patch + Position Embeddings """
115 | patch_embed = self.patch_embed(inputs) ## [8, 256, 768]
116 |
117 | positions = self.positions
118 | pos_embed = self.pos_embed(positions) ## [256, 768]
119 |
120 | x = patch_embed + pos_embed ## [8, 256, 768]
121 |
122 | """ Transformer Encoder """
123 | skip_connection_index = [3, 6, 9, 12]
124 | skip_connections = []
125 |
126 | for i in range(self.cf["num_layers"]):
127 | layer = self.trans_encoder_layers[i]
128 | x = layer(x)
129 |
130 | if (i+1) in skip_connection_index:
131 | skip_connections.append(x)
132 |
133 | """ CNN Decoder """
134 | z3, z6, z9, z12 = skip_connections
135 |
136 | ## Reshaping
137 | batch = inputs.shape[0]
138 | z0 = inputs.view((batch, self.cf["num_channels"], self.cf["image_size"], self.cf["image_size"]))
139 |
140 | shape = (batch, self.cf["hidden_dim"], self.cf["patch_size"], self.cf["patch_size"])
141 | z3 = z3.view(shape)
142 | z6 = z6.view(shape)
143 | z9 = z9.view(shape)
144 | z12 = z12.view(shape)
145 |
146 |
147 | ## Decoder 1
148 | x = self.d1(z12)
149 | s = self.s1(z9)
150 | x = torch.cat([x, s], dim=1)
151 | x = self.c1(x)
152 |
153 | ## Decoder 2
154 | x = self.d2(x)
155 | s = self.s2(z6)
156 | x = torch.cat([x, s], dim=1)
157 | x = self.c2(x)
158 |
159 | ## Decoder 3
160 | x = self.d3(x)
161 | s = self.s3(z3)
162 | x = torch.cat([x, s], dim=1)
163 | x = self.c3(x)
164 |
165 | ## Decoder 4
166 | x = self.d4(x)
167 | s = self.s4(z0)
168 | x = torch.cat([x, s], dim=1)
169 | x = self.c4(x)
170 |
171 | """ Output """
172 | output = self.output(x)
173 |
174 | return output
175 |
176 |
177 | if __name__ == "__main__":
178 | config = {}
179 | config["image_size"] = 256
180 | config["num_layers"] = 12
181 | config["hidden_dim"] = 768
182 | config["mlp_dim"] = 3072
183 | config["num_heads"] = 12
184 | config["dropout_rate"] = 0.1
185 | config["num_patches"] = 256
186 | config["patch_size"] = 16
187 | config["num_channels"] = 3
188 |
189 | x = torch.randn((
190 | 8,
191 | config["num_patches"],
192 | config["patch_size"]*config["patch_size"]*config["num_channels"]
193 | ))
194 | model = UNETR_2D(config)
195 | output = model(x)
196 | print(output.shape)
197 |
--------------------------------------------------------------------------------
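Unlike the other PyTorch models here, `UNETR_2D` expects its input already flattened into a patch sequence of shape `(batch, num_patches, patch_size*patch_size*num_channels)`, i.e. `(8, 256, 768)` for the configuration in the `__main__` block, since `num_patches = (image_size/patch_size)^2 = (256/16)^2 = 256` and each patch flattens to `16*16*3 = 768` values. The sketch below (an assumption, not part of the file) shows one way to produce a tensor with that shape; the exact flattening convention is left to the data pipeline.

```python
import torch

def patchify(images, patch_size=16):
    """Hypothetical helper: flatten (B, H, W, C) images into (B, num_patches, patch_size*patch_size*C)."""
    B, H, W, C = images.shape
    x = images.reshape(B, H // patch_size, patch_size, W // patch_size, patch_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5)   # bring the two patch-grid axes next to each other
    return x.reshape(B, (H // patch_size) * (W // patch_size), patch_size * patch_size * C)

patches = patchify(torch.randn(8, 256, 256, 3))
print(patches.shape)   # torch.Size([8, 256, 768]) -- matches the __main__ configuration above
```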
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic-Segmentation-Architecture
2 | This repository contains the code for various semantic segmentation architectures in the TensorFlow and PyTorch frameworks.
3 |
4 | ## Research papers
5 | - [U-Net](https://arxiv.org/pdf/1505.04597.pdf)
6 | - [ResU-Net](https://arxiv.org/pdf/1711.10684.pdf)
7 | - [MultiResU-Net](https://arxiv.org/pdf/1902.04049)
8 |
9 | ## Star History
10 | [](https://github.com/nikhilroxtomar/Semantic-Segmentation-Architecture/stargazers)
11 |
--------------------------------------------------------------------------------
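Each TensorFlow script below builds a standard Keras `Model`, and most print a summary when run directly (e.g. `python TensorFlow/multiresunet.py`). A minimal usage sketch (illustrative; assumes TensorFlow 2.x and that the `TensorFlow/` directory is on `PYTHONPATH`; the compile/fit setup is an assumption, not part of the repository):

```python
from multiresunet import build_multiresunet  # TensorFlow/multiresunet.py

model = build_multiresunet((256, 256, 3))    # RGB input, single-channel sigmoid output
model.summary()

# Hypothetical binary-segmentation training setup (masks scaled to [0, 1]):
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# model.fit(train_images, train_masks, batch_size=8, epochs=10)
```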
/TensorFlow/attention-unet.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 | import tensorflow.keras.layers as L
4 | from tensorflow.keras.models import Model
5 |
6 | def conv_block(x, num_filters):
7 | x = L.Conv2D(num_filters, 3, padding="same")(x)
8 | x = L.BatchNormalization()(x)
9 | x = L.Activation("relu")(x)
10 |
11 | x = L.Conv2D(num_filters, 3, padding="same")(x)
12 | x = L.BatchNormalization()(x)
13 | x = L.Activation("relu")(x)
14 |
15 | return x
16 |
17 | def encoder_block(x, num_filters):
18 | x = conv_block(x, num_filters)
19 | p = L.MaxPool2D((2, 2))(x)
20 | return x, p
21 |
22 | def attention_gate(g, s, num_filters):
23 | Wg = L.Conv2D(num_filters, 1, padding="same")(g)
24 | Wg = L.BatchNormalization()(Wg)
25 |
26 | Ws = L.Conv2D(num_filters, 1, padding="same")(s)
27 | Ws = L.BatchNormalization()(Ws)
28 |
29 | out = L.Activation("relu")(Wg + Ws)
30 | out = L.Conv2D(num_filters, 1, padding="same")(out)
31 | out = L.Activation("sigmoid")(out)
32 |
33 | return out * s
34 |
35 | def decoder_block(x, s, num_filters):
36 | x = L.UpSampling2D(interpolation="bilinear")(x)
37 | s = attention_gate(x, s, num_filters)
38 | x = L.Concatenate()([x, s])
39 | x = conv_block(x, num_filters)
40 | return x
41 |
42 | def attention_unet(input_shape):
43 | """ Inputs """
44 | inputs = L.Input(input_shape)
45 |
46 | """ Encoder """
47 | s1, p1 = encoder_block(inputs, 64)
48 | s2, p2 = encoder_block(p1, 128)
49 | s3, p3 = encoder_block(p2, 256)
50 |
51 | b1 = conv_block(p3, 512)
52 |
53 | """ Decoder """
54 | d1 = decoder_block(b1, s3, 256)
55 | d2 = decoder_block(d1, s2, 128)
56 | d3 = decoder_block(d2, s1, 64)
57 |
58 | """ Outputs """
59 | outputs = L.Conv2D(1, 1, padding="same", activation="sigmoid")(d3)
60 |
61 | """ Model """
62 | model = Model(inputs, outputs, name="Attention-UNET")
63 | return model
64 |
65 | if __name__ == "__main__":
66 | input_shape = (256, 256, 3)
67 | model = attention_unet(input_shape)
68 | model.summary()
69 |
--------------------------------------------------------------------------------
/TensorFlow/colonsegnet.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 |
5 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Dense
6 | from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2DTranspose, Concatenate, Input
7 | from tensorflow.keras.layers import MaxPool2D
8 | from tensorflow.keras.models import Model
9 |
10 | def se_layer(x, num_filters, reduction=16):
11 | x_init = x
12 |
13 | x = GlobalAveragePooling2D()(x)
14 | x = Dense(num_filters//reduction, use_bias=False, activation="relu")(x)
15 | x = Dense(num_filters, use_bias=False, activation="sigmoid")(x)
16 | x = x * x_init
17 | return x
18 |
19 | def residual_block(x, num_filters):
20 | x_init = x
21 |
22 | x = Conv2D(num_filters, 3, padding="same")(x)
23 | x = BatchNormalization()(x)
24 | x = Activation("relu")(x)
25 |
26 | x = Conv2D(num_filters, 3, padding="same")(x)
27 | x = BatchNormalization()(x)
28 |
29 | s = Conv2D(num_filters, 1, padding="same")(x_init)
30 | s = BatchNormalization()(s)
31 | s = se_layer(s, num_filters)
32 |
33 | x = Activation("relu")(x + s)
34 | return x
35 |
36 | def strided_conv_block(x, num_filters):
37 | x = Conv2D(num_filters, 3, strides=2, padding="same")(x)
38 | x = BatchNormalization()(x)
39 | x = Activation("relu")(x)
40 | return x
41 |
42 | def encoder_block(x, num_filters):
43 | x1 = residual_block(x, num_filters)
44 | x2 = strided_conv_block(x1, num_filters)
45 | x3 = residual_block(x2, num_filters)
46 | p = MaxPool2D((2, 2))(x3)
47 |
48 | return x1, x3, p
49 |
50 | def build_colonsegnet(input_shape):
51 | """ Input """
52 | inputs = Input(input_shape)
53 |
54 | """ Encoder """
55 | s11, s12, p1 = encoder_block(inputs, 64)
56 | s21, s22, p2 = encoder_block(p1, 256)
57 |
58 | """ Decoder 1 """
59 | x = Conv2DTranspose(128, 4, strides=4, padding="same")(s22)
60 | x = Concatenate()([x, s12])
61 | x = residual_block(x, 128)
62 | r1 = x
63 |
64 | x = Conv2DTranspose(128, 4, strides=2, padding="same")(s21)
65 | x = Concatenate()([x, r1])
66 | x = residual_block(x, 128)
67 |
68 | """ Decoder 2 """
69 | x = Conv2DTranspose(64, 4, strides=2, padding="same")(x)
70 | x = Concatenate()([x, s11])
71 | x = residual_block(x, 64)
72 | r2 = x
73 |
74 | x = Conv2DTranspose(32, 4, strides=2, padding="same")(s12)
75 | x = Concatenate()([x, r2])
76 | x = residual_block(x, 32)
77 |
78 | """ Output """
79 | output = Conv2D(1, 1, padding="same")(x)
80 |
81 | """ Model """
82 | model = Model(inputs, output)
83 |
84 | return model
85 |
86 | if __name__ == "__main__":
87 | input_shape = (512, 512, 3)
88 | model = build_colonsegnet(input_shape)
89 | model.summary()
90 |
--------------------------------------------------------------------------------
/TensorFlow/deeplabv3plus.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 |
5 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D
6 | from tensorflow.keras.layers import AveragePooling2D, Conv2DTranspose, Concatenate, Input
7 | from tensorflow.keras.models import Model
8 | from tensorflow.keras.applications import ResNet50
9 |
10 | """ Atrous Spatial Pyramid Pooling """
11 | def ASPP(inputs):
12 | shape = inputs.shape
13 |
14 | y_pool = AveragePooling2D(pool_size=(shape[1], shape[2]), name='average_pooling')(inputs)
15 | y_pool = Conv2D(filters=256, kernel_size=1, padding='same', use_bias=False)(y_pool)
16 | y_pool = BatchNormalization(name=f'bn_1')(y_pool)
17 | y_pool = Activation('relu', name=f'relu_1')(y_pool)
18 | y_pool = UpSampling2D((shape[1], shape[2]), interpolation="bilinear")(y_pool)
19 |
20 | y_1 = Conv2D(filters=256, kernel_size=1, dilation_rate=1, padding='same', use_bias=False)(inputs)
21 | y_1 = BatchNormalization()(y_1)
22 | y_1 = Activation('relu')(y_1)
23 |
24 | y_6 = Conv2D(filters=256, kernel_size=3, dilation_rate=6, padding='same', use_bias=False)(inputs)
25 | y_6 = BatchNormalization()(y_6)
26 | y_6 = Activation('relu')(y_6)
27 |
28 | y_12 = Conv2D(filters=256, kernel_size=3, dilation_rate=12, padding='same', use_bias=False)(inputs)
29 | y_12 = BatchNormalization()(y_12)
30 | y_12 = Activation('relu')(y_12)
31 |
32 | y_18 = Conv2D(filters=256, kernel_size=3, dilation_rate=18, padding='same', use_bias=False)(inputs)
33 | y_18 = BatchNormalization()(y_18)
34 | y_18 = Activation('relu')(y_18)
35 |
36 | y = Concatenate()([y_pool, y_1, y_6, y_12, y_18])
37 |
38 | y = Conv2D(filters=256, kernel_size=1, dilation_rate=1, padding='same', use_bias=False)(y)
39 | y = BatchNormalization()(y)
40 | y = Activation('relu')(y)
41 | return y
42 |
43 | def DeepLabV3Plus(shape):
44 | """ Inputs """
45 | inputs = Input(shape)
46 |
47 | """ Pre-trained ResNet50 """
48 | base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=inputs)
49 |
50 | """ Pre-trained ResNet50 Output """
51 | image_features = base_model.get_layer('conv4_block6_out').output
52 | x_a = ASPP(image_features)
53 | x_a = UpSampling2D((4, 4), interpolation="bilinear")(x_a)
54 |
55 | """ Get low-level features """
56 | x_b = base_model.get_layer('conv2_block2_out').output
57 | x_b = Conv2D(filters=48, kernel_size=1, padding='same', use_bias=False)(x_b)
58 | x_b = BatchNormalization()(x_b)
59 | x_b = Activation('relu')(x_b)
60 |
61 | x = Concatenate()([x_a, x_b])
62 |
63 |     x = Conv2D(filters=256, kernel_size=3, padding='same', use_bias=False)(x)  # ReLU is applied after BatchNormalization below
64 | x = BatchNormalization()(x)
65 | x = Activation('relu')(x)
66 |
67 |     x = Conv2D(filters=256, kernel_size=3, padding='same', use_bias=False)(x)  # ReLU is applied after BatchNormalization below
68 | x = BatchNormalization()(x)
69 | x = Activation('relu')(x)
70 | x = UpSampling2D((4, 4), interpolation="bilinear")(x)
71 |
72 | """ Outputs """
73 | x = Conv2D(1, (1, 1), name='output_layer')(x)
74 | x = Activation('sigmoid')(x)
75 |
76 | """ Model """
77 | model = Model(inputs=inputs, outputs=x)
78 | return model
79 |
80 | if __name__ == "__main__":
81 | input_shape = (512, 512, 3)
82 | model = DeepLabV3Plus(input_shape)
83 | model.summary()
84 |
85 |
--------------------------------------------------------------------------------
/TensorFlow/densenet121.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.applications import DenseNet121
4 |
5 | def conv_block(inputs, num_filters):
6 | x = Conv2D(num_filters, 3, padding="same")(inputs)
7 | x = BatchNormalization()(x)
8 | x = Activation("relu")(x)
9 |
10 | x = Conv2D(num_filters, 3, padding="same")(x)
11 | x = BatchNormalization()(x)
12 | x = Activation("relu")(x)
13 |
14 | return x
15 |
16 | def decoder_block(inputs, skip_features, num_filters):
17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
18 | x = Concatenate()([x, skip_features])
19 | x = conv_block(x, num_filters)
20 | return x
21 |
22 | def build_densenet121_unet(input_shape):
23 | """ Input """
24 | inputs = Input(input_shape)
25 |
26 | """ Pre-trained DenseNet121 Model """
27 | densenet = DenseNet121(include_top=False, weights="imagenet", input_tensor=inputs)
28 |
29 | """ Encoder """
30 | s1 = densenet.get_layer("input_1").output ## 512
31 | s2 = densenet.get_layer("conv1/relu").output ## 256
32 | s3 = densenet.get_layer("pool2_relu").output ## 128
33 | s4 = densenet.get_layer("pool3_relu").output ## 64
34 |
35 | """ Bridge """
36 | b1 = densenet.get_layer("pool4_relu").output ## 32
37 |
38 | """ Decoder """
39 | d1 = decoder_block(b1, s4, 512) ## 64
40 | d2 = decoder_block(d1, s3, 256) ## 128
41 | d3 = decoder_block(d2, s2, 128) ## 256
42 | d4 = decoder_block(d3, s1, 64) ## 512
43 |
44 | """ Outputs """
45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
46 |
47 | model = Model(inputs, outputs)
48 | return model
49 |
50 |
51 | if __name__ == "__main__":
52 | input_shape = (512, 512, 3)
53 | model = build_densenet121_unet(input_shape)
54 | model.summary()
55 |
--------------------------------------------------------------------------------
/TensorFlow/doubleunet.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 |
5 | import tensorflow as tf
6 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
7 | from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Multiply, AveragePooling2D, UpSampling2D
8 | from tensorflow.keras.models import Model
9 | from tensorflow.keras.applications import VGG19
10 |
11 | def squeeze_excite_block(inputs, ratio=8):
12 | init = inputs ## (b, 128, 128, 32)
13 | channel_axis = -1
14 | filters = init.shape[channel_axis]
15 | se_shape = (1, 1, filters)
16 |
17 | se = GlobalAveragePooling2D()(init) ## (b, 32) -> (b, 1, 1, 32)
18 | se = Reshape(se_shape)(se)
19 | se = Dense(filters//ratio, activation="relu", use_bias=False)(se)
20 | se = Dense(filters, activation="sigmoid", use_bias=False)(se)
21 |
22 | x = Multiply()([inputs, se])
23 | return x
24 |
25 | def ASPP(x, filter):
26 | shape = x.shape
27 |
28 | y1 = AveragePooling2D(pool_size=(shape[1], shape[2]))(x)
29 | y1 = Conv2D(filter, 1, padding="same")(y1)
30 | y1 = BatchNormalization()(y1)
31 | y1 = Activation("relu")(y1)
32 | y1 = UpSampling2D((shape[1], shape[2]), interpolation="bilinear")(y1)
33 |
34 | y2 = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(x)
35 | y2 = BatchNormalization()(y2)
36 | y2 = Activation("relu")(y2)
37 |
38 | y3 = Conv2D(filter, 3, dilation_rate=6, padding="same", use_bias=False)(x)
39 | y3 = BatchNormalization()(y3)
40 | y3 = Activation("relu")(y3)
41 |
42 | y4 = Conv2D(filter, 3, dilation_rate=12, padding="same", use_bias=False)(x)
43 | y4 = BatchNormalization()(y4)
44 | y4 = Activation("relu")(y4)
45 |
46 | y5 = Conv2D(filter, 3, dilation_rate=18, padding="same", use_bias=False)(x)
47 | y5 = BatchNormalization()(y5)
48 | y5 = Activation("relu")(y5)
49 |
50 | y = Concatenate()([y1, y2, y3, y4, y5])
51 |
52 | y = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(y)
53 | y = BatchNormalization()(y)
54 | y = Activation("relu")(y)
55 |
56 | return y
57 |
58 | def conv_block(x, filters):
59 | x = Conv2D(filters, 3, padding="same")(x)
60 | x = BatchNormalization()(x)
61 | x = Activation("relu")(x)
62 |
63 | x = Conv2D(filters, 3, padding="same")(x)
64 | x = BatchNormalization()(x)
65 | x = Activation("relu")(x)
66 |
67 | x = squeeze_excite_block(x)
68 |
69 | return x
70 |
71 | def encoder1(inputs):
72 | skip_connections = []
73 |
74 | model = VGG19(include_top=False, weights="imagenet", input_tensor=inputs)
75 | names = ["block1_conv2", "block2_conv2", "block3_conv4", "block4_conv4"]
76 | for name in names:
77 | skip_connections.append(model.get_layer(name).output)
78 |
79 | output = model.get_layer("block5_conv4").output
80 | return output, skip_connections
81 |
82 | def decoder1(inputs, skip_connections):
83 | num_filters = [256, 128, 64, 32]
84 | skip_connections.reverse()
85 |
86 | x = inputs
87 | for i, f in enumerate(num_filters):
88 | x = UpSampling2D((2, 2), interpolation="bilinear")(x)
89 | x = Concatenate()([x, skip_connections[i]])
90 | x = conv_block(x, f)
91 |
92 | return x
93 |
94 | def output_block(inputs):
95 | x = Conv2D(1, 1, padding="same")(inputs)
96 | x = Activation("sigmoid")(x)
97 | return x
98 |
99 | def encoder2(inputs):
100 | num_filters = [32, 64, 128, 256]
101 | skip_connections = []
102 |
103 | x = inputs
104 | for i, f in enumerate(num_filters):
105 | x = conv_block(x, f)
106 | skip_connections.append(x)
107 | x = MaxPool2D((2, 2))(x)
108 |
109 | return x, skip_connections
110 |
111 | def decoder2(inputs, skip_1, skip_2):
112 | num_filters = [256, 128, 64, 32]
113 | skip_2.reverse()
114 |
115 | x = inputs
116 | for i, f in enumerate(num_filters):
117 | x = UpSampling2D((2, 2), interpolation="bilinear")(x)
118 | x = Concatenate()([x, skip_1[i], skip_2[i]])
119 | x = conv_block(x, f)
120 |
121 | return x
122 |
123 | def build_model(input_shape):
124 | inputs = Input(input_shape)
125 | x, skip_1 = encoder1(inputs)
126 | x = ASPP(x, 64)
127 | x = decoder1(x, skip_1)
128 | output1 = output_block(x)
129 |
130 | x = inputs * output1
131 |
132 | x, skip_2 = encoder2(x)
133 | x = ASPP(x, 64)
134 | x = decoder2(x, skip_1, skip_2)
135 | output2 = output_block(x)
136 |
137 | outputs = Concatenate()([output1, output2])
138 | model = Model(inputs, outputs)
139 | return model
140 |
141 |
142 | if __name__ == "__main__":
143 | input_shape = (256, 256, 3)
144 | model = build_model(input_shape)
145 | model.summary()
146 |
--------------------------------------------------------------------------------
/TensorFlow/efficientnetb0_unet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.applications import EfficientNetB0
4 | import tensorflow as tf
5 |
6 | print("TF Version: ", tf.__version__)
7 |
8 | def conv_block(inputs, num_filters):
9 | x = Conv2D(num_filters, 3, padding="same")(inputs)
10 | x = BatchNormalization()(x)
11 | x = Activation("relu")(x)
12 |
13 | x = Conv2D(num_filters, 3, padding="same")(x)
14 | x = BatchNormalization()(x)
15 | x = Activation("relu")(x)
16 |
17 | return x
18 |
19 | def decoder_block(inputs, skip, num_filters):
20 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
21 | x = Concatenate()([x, skip])
22 | x = conv_block(x, num_filters)
23 | return x
24 |
25 | def build_efficientnet_unet(input_shape):
26 | """ Input """
27 | inputs = Input(input_shape)
28 |
29 | """ Pre-trained Encoder """
30 | encoder = EfficientNetB0(include_top=False, weights="imagenet", input_tensor=inputs)
31 |
32 | s1 = encoder.get_layer("input_1").output ## 256
33 | s2 = encoder.get_layer("block2a_expand_activation").output ## 128
34 | s3 = encoder.get_layer("block3a_expand_activation").output ## 64
35 | s4 = encoder.get_layer("block4a_expand_activation").output ## 32
36 |
37 | """ Bottleneck """
38 | b1 = encoder.get_layer("block6a_expand_activation").output ## 16
39 |
40 | """ Decoder """
41 | d1 = decoder_block(b1, s4, 512) ## 32
42 | d2 = decoder_block(d1, s3, 256) ## 64
43 | d3 = decoder_block(d2, s2, 128) ## 128
44 | d4 = decoder_block(d3, s1, 64) ## 256
45 |
46 | """ Output """
47 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
48 |
49 | model = Model(inputs, outputs, name="EfficientNetB0_UNET")
50 | return model
51 |
52 | if __name__ == "__main__":
53 | input_shape = (256, 256, 3)
54 |     model = build_efficientnet_unet(input_shape)
55 | model.summary()
56 |
--------------------------------------------------------------------------------
/TensorFlow/inception_resnetv2_unet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input, ZeroPadding2D
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.applications import InceptionResNetV2
4 |
5 | def conv_block(input, num_filters):
6 | x = Conv2D(num_filters, 3, padding="same")(input)
7 | x = BatchNormalization()(x)
8 | x = Activation("relu")(x)
9 |
10 | x = Conv2D(num_filters, 3, padding="same")(x)
11 | x = BatchNormalization()(x)
12 | x = Activation("relu")(x)
13 |
14 | return x
15 |
16 | def decoder_block(input, skip_features, num_filters):
17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
18 | x = Concatenate()([x, skip_features])
19 | x = conv_block(x, num_filters)
20 | return x
21 |
22 | def build_inception_resnetv2_unet(input_shape):
23 | """ Input """
24 | inputs = Input(input_shape)
25 |
26 | """ Pre-trained InceptionResNetV2 Model """
27 | encoder = InceptionResNetV2(include_top=False, weights="imagenet", input_tensor=inputs)
28 |
29 | """ Encoder """
30 | s1 = encoder.get_layer("input_1").output ## (512 x 512)
31 |
32 | s2 = encoder.get_layer("activation").output ## (255 x 255)
33 | s2 = ZeroPadding2D(( (1, 0), (1, 0) ))(s2) ## (256 x 256)
34 |
35 | s3 = encoder.get_layer("activation_3").output ## (126 x 126)
36 | s3 = ZeroPadding2D((1, 1))(s3) ## (128 x 128)
37 |
38 | s4 = encoder.get_layer("activation_74").output ## (61 x 61)
39 | s4 = ZeroPadding2D(( (2, 1),(2, 1) ))(s4) ## (64 x 64)
40 |
41 | """ Bridge """
42 | b1 = encoder.get_layer("activation_161").output ## (30 x 30)
43 | b1 = ZeroPadding2D((1, 1))(b1) ## (32 x 32)
44 |
45 | """ Decoder """
46 | d1 = decoder_block(b1, s4, 512) ## (64 x 64)
47 | d2 = decoder_block(d1, s3, 256) ## (128 x 128)
48 | d3 = decoder_block(d2, s2, 128) ## (256 x 256)
49 | d4 = decoder_block(d3, s1, 64) ## (512 x 512)
50 |
51 | """ Output """
52 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
53 |
54 | model = Model(inputs, outputs, name="InceptionResNetV2_U-Net")
55 | return model
56 |
57 | if __name__ == "__main__":
58 | input_shape = (512, 512, 3)
59 | model = build_inception_resnetv2_unet(input_shape)
60 | model.summary()
61 |
--------------------------------------------------------------------------------
/TensorFlow/mobilenetv2_unet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
3 | from tensorflow.keras.models import Model
4 | from tensorflow.keras.applications import MobileNetV2
5 |
6 | print("TF Version: ", tf.__version__)
7 |
8 | def conv_block(inputs, num_filters):
9 | x = Conv2D(num_filters, 3, padding="same")(inputs)
10 | x = BatchNormalization()(x)
11 | x = Activation("relu")(x)
12 |
13 | x = Conv2D(num_filters, 3, padding="same")(x)
14 | x = BatchNormalization()(x)
15 | x = Activation("relu")(x)
16 |
17 | return x
18 |
19 | def decoder_block(inputs, skip, num_filters):
20 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
21 | x = Concatenate()([x, skip])
22 | x = conv_block(x, num_filters)
23 |
24 | return x
25 |
26 | def build_mobilenetv2_unet(input_shape): ## (512, 512, 3)
27 | """ Input """
28 | inputs = Input(shape=input_shape)
29 |
30 | """ Pre-trained MobileNetV2 """
31 | encoder = MobileNetV2(include_top=False, weights="imagenet",
32 | input_tensor=inputs, alpha=1.4)
33 |
34 | """ Encoder """
35 | s1 = encoder.get_layer("input_1").output ## (512 x 512)
36 | s2 = encoder.get_layer("block_1_expand_relu").output ## (256 x 256)
37 | s3 = encoder.get_layer("block_3_expand_relu").output ## (128 x 128)
38 | s4 = encoder.get_layer("block_6_expand_relu").output ## (64 x 64)
39 |
40 | """ Bridge """
41 | b1 = encoder.get_layer("block_13_expand_relu").output ## (32 x 32)
42 |
43 | """ Decoder """
44 | d1 = decoder_block(b1, s4, 512) ## (64 x 64)
45 | d2 = decoder_block(d1, s3, 256) ## (128 x 128)
46 | d3 = decoder_block(d2, s2, 128) ## (256 x 256)
47 | d4 = decoder_block(d3, s1, 64) ## (512 x 512)
48 |
49 | """ Output """
50 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
51 |
52 | model = Model(inputs, outputs, name="MobileNetV2_U-Net")
53 | return model
54 |
55 | if __name__ == "__main__":
56 | model = build_mobilenetv2_unet((512, 512, 3))
57 | model.summary()
58 |
--------------------------------------------------------------------------------
/TensorFlow/multiresunet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Conv2DTranspose
2 | from tensorflow.keras.layers import Concatenate, Input
3 | from tensorflow.keras.models import Model
4 |
5 | def conv_block(x, num_filters, kernel_size, padding="same", act=True):
6 | x = Conv2D(num_filters, kernel_size, padding=padding, use_bias=False)(x)
7 | x = BatchNormalization()(x)
8 | if act:
9 | x = Activation("relu")(x)
10 | return x
11 |
12 | def multires_block(x, num_filters, alpha=1.67):
13 | W = num_filters * alpha
14 |
15 | x0 = x
16 | x1 = conv_block(x0, int(W*0.167), 3)
17 | x2 = conv_block(x1, int(W*0.333), 3)
18 | x3 = conv_block(x2, int(W*0.5), 3)
19 | xc = Concatenate()([x1, x2, x3])
20 | xc = BatchNormalization()(xc)
21 |
22 | nf = int(W*0.167) + int(W*0.333) + int(W*0.5)
23 | sc = conv_block(x0, nf, 1, act=False)
24 |
25 | x = Activation("relu")(xc + sc)
26 | x = BatchNormalization()(x)
27 | return x
28 |
29 | def res_path(x, num_filters, length):
30 | for i in range(length):
31 | x0 = x
32 | x1 = conv_block(x0, num_filters, 3, act=False)
33 | sc = conv_block(x0, num_filters, 1, act=False)
34 | x = Activation("relu")(x1 + sc)
35 | x = BatchNormalization()(x)
36 | return x
37 |
38 | def encoder_block(x, num_filters, length):
39 | x = multires_block(x, num_filters)
40 | s = res_path(x, num_filters, length)
41 | p = MaxPooling2D((2, 2))(x)
42 | return s, p
43 |
44 | def decoder_block(x, skip, num_filters):
45 | x = Conv2DTranspose(num_filters, 2, strides=2, padding="same")(x)
46 | x = Concatenate()([x, skip])
47 | x = multires_block(x, num_filters)
48 | return x
49 |
50 | def build_multiresunet(shape):
51 | """ Input """
52 | inputs = Input(shape)
53 |
54 | """ Encoder """
55 | p0 = inputs
56 | s1, p1 = encoder_block(p0, 32, 4)
57 | s2, p2 = encoder_block(p1, 64, 3)
58 | s3, p3 = encoder_block(p2, 128, 2)
59 | s4, p4 = encoder_block(p3, 256, 1)
60 |
61 | """ Bridge """
62 | b1 = multires_block(p4, 512)
63 |
64 | """ Decoder """
65 | d1 = decoder_block(b1, s4, 256)
66 | d2 = decoder_block(d1, s3, 128)
67 | d3 = decoder_block(d2, s2, 64)
68 | d4 = decoder_block(d3, s1, 32)
69 |
70 | """ Output """
71 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
72 |
73 | """ Model """
74 | model = Model(inputs, outputs, name="MultiResUNET")
75 |
76 | return model
77 |
78 | if __name__ == "__main__":
79 | shape = (256, 256, 3)
80 | model = build_multiresunet(shape)
81 | model.summary()
82 |
--------------------------------------------------------------------------------
/TensorFlow/notebook/ColonSegNet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# ColonSegNet\n",
8 | "\n",
9 | "
\n",
10 | "\n",
11 | "\n",
12 | "Research Paper: Real-Time Polyp Detection, Localization and Segmentation in Colonoscopy Using Deep Learning \n",
13 | "\n",
14 | "
\n",
15 | "\n",
16 | " - ColonSegNet is an encoder-decoder architecture developed for the purpose of polyp segmentation.
\n",
17 | " - It uses Residual block with Squeeze and Excitation as the main component.
\n",
18 | "
\n",
19 | " \n",
20 | "\n",
21 | "\n",
22 | "
"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "## Import"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 1,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Dense\n",
39 | "from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2DTranspose, Concatenate, Input\n",
40 | "from tensorflow.keras.layers import MaxPool2D\n",
41 | "from tensorflow.keras.models import Model"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "## Squeeze and Excitation\n",
49 | "
"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "def se_layer(x, num_filters, reduction=16):\n",
59 | " x_init = x\n",
60 | " \n",
61 | " x = GlobalAveragePooling2D()(x)\n",
62 | " x = Dense(num_filters//reduction, use_bias=False, activation=\"relu\")(x)\n",
63 | " x = Dense(num_filters, use_bias=False, activation=\"sigmoid\")(x)\n",
64 | " \n",
65 | " return x_init * x"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "## Residual Block\n",
73 | "
"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 3,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "def residual_block(x, num_filters):\n",
83 | " x_init = x\n",
84 | " \n",
85 | " x = Conv2D(num_filters, 3, padding=\"same\")(x)\n",
86 | " x = BatchNormalization()(x)\n",
87 | " x = Activation(\"relu\")(x)\n",
88 | " \n",
89 | " x = Conv2D(num_filters, 3, padding=\"same\")(x)\n",
90 | " x = BatchNormalization()(x)\n",
91 | " \n",
92 | " s = Conv2D(num_filters, 1, padding=\"same\")(x_init)\n",
93 | " s = BatchNormalization()(x)\n",
94 | " s = se_layer(s, num_filters)\n",
95 | " \n",
96 | " x = Activation(\"relu\")(x + s)\n",
97 | " \n",
98 | " return x"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "## Strided Convolution\n",
106 | "
"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 4,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "def strided_conv_block(x, num_filters):\n",
116 | " x = Conv2D(num_filters, 3, strides=2, padding=\"same\")(x)\n",
117 | " x = BatchNormalization()(x)\n",
118 | " x = Activation(\"relu\")(x)\n",
119 | " return x"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "## Encoder Block"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 5,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "def encoder_block(x, num_filters):\n",
136 | " x1 = residual_block(x, num_filters)\n",
137 | " x2 = strided_conv_block(x1, num_filters)\n",
138 | " x3 = residual_block(x2, num_filters)\n",
139 | " p = MaxPool2D((2, 2))(x3)\n",
140 | " \n",
141 | " return x1, x3, p"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "## ColonSegNet"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 12,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "def build_colonsegnet(input_shape):\n",
158 | " \"\"\" Input \"\"\"\n",
159 | " inputs = Input(input_shape)\n",
160 | " \n",
161 | " \"\"\" Encoder \"\"\"\n",
162 | " s11, s12, p1 = encoder_block(inputs, 64)\n",
163 | " s21, s22, p2 = encoder_block(p1, 256)\n",
164 | " \n",
165 | " \"\"\" Decoder 1 \"\"\"\n",
166 | " x = Conv2DTranspose(128, 4, strides=4, padding=\"same\")(s22)\n",
167 | " x = Concatenate()([x, s12])\n",
168 | " x = residual_block(x, 128)\n",
169 | " r1 = x\n",
170 | " \n",
171 | " x = Conv2DTranspose(128, 4, strides=2, padding=\"same\")(s21)\n",
172 | " x = Concatenate()([x, r1])\n",
173 | " x = residual_block(x, 128)\n",
174 | " \n",
175 | " \"\"\" Decoder 2 \"\"\"\n",
176 | " x = Conv2DTranspose(64, 4, strides=2, padding=\"same\")(x)\n",
177 | " x = Concatenate()([x, s11])\n",
178 | " x = residual_block(x, 64)\n",
179 | " r2 = x\n",
180 | " \n",
181 | " x = Conv2DTranspose(64, 4, strides=2, padding=\"same\")(s12)\n",
182 | " x = Concatenate()([x, r2])\n",
183 | " x = residual_block(x, 32)\n",
184 | " \n",
185 | " \"\"\" Output \"\"\"\n",
186 | " output = Conv2D(5, 1, padding=\"same\", activation=\"softmax\")(x)\n",
187 | " \n",
188 | " \"\"\" Model \"\"\"\n",
189 | " model = Model(inputs, output, name=\"ColonSegNet\")\n",
190 | " \n",
191 | " return model"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "## Model"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 13,
204 | "metadata": {},
205 | "outputs": [],
206 | "source": [
207 | "input_shape = (512, 512, 3)\n",
208 | "model = build_colonsegnet(input_shape)"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": 14,
214 | "metadata": {
215 | "scrolled": false
216 | },
217 | "outputs": [
218 | {
219 | "name": "stdout",
220 | "output_type": "stream",
221 | "text": [
222 | "Model: \"ColonSegNet\"\n",
223 | "__________________________________________________________________________________________________\n",
224 | "Layer (type) Output Shape Param # Connected to \n",
225 | "==================================================================================================\n",
226 | "input_3 (InputLayer) [(None, 512, 512, 3) 0 \n",
227 | "__________________________________________________________________________________________________\n",
228 | "conv2d_54 (Conv2D) (None, 512, 512, 64) 1792 input_3[0][0] \n",
229 | "__________________________________________________________________________________________________\n",
230 | "batch_normalization_52 (BatchNo (None, 512, 512, 64) 256 conv2d_54[0][0] \n",
231 | "__________________________________________________________________________________________________\n",
232 | "activation_36 (Activation) (None, 512, 512, 64) 0 batch_normalization_52[0][0] \n",
233 | "__________________________________________________________________________________________________\n",
234 | "conv2d_55 (Conv2D) (None, 512, 512, 64) 36928 activation_36[0][0] \n",
235 | "__________________________________________________________________________________________________\n",
236 | "batch_normalization_53 (BatchNo (None, 512, 512, 64) 256 conv2d_55[0][0] \n",
237 | "__________________________________________________________________________________________________\n",
238 | "batch_normalization_54 (BatchNo (None, 512, 512, 64) 256 batch_normalization_53[0][0] \n",
239 | "__________________________________________________________________________________________________\n",
240 | "global_average_pooling2d_16 (Gl (None, 64) 0 batch_normalization_54[0][0] \n",
241 | "__________________________________________________________________________________________________\n",
242 | "dense_32 (Dense) (None, 4) 256 global_average_pooling2d_16[0][0]\n",
243 | "__________________________________________________________________________________________________\n",
244 | "dense_33 (Dense) (None, 64) 256 dense_32[0][0] \n",
245 | "__________________________________________________________________________________________________\n",
246 | "tf.math.multiply_16 (TFOpLambda (None, 512, 512, 64) 0 batch_normalization_54[0][0] \n",
247 | " dense_33[0][0] \n",
248 | "__________________________________________________________________________________________________\n",
249 | "tf.__operators__.add_16 (TFOpLa (None, 512, 512, 64) 0 batch_normalization_53[0][0] \n",
250 | " tf.math.multiply_16[0][0] \n",
251 | "__________________________________________________________________________________________________\n",
252 | "activation_37 (Activation) (None, 512, 512, 64) 0 tf.__operators__.add_16[0][0] \n",
253 | "__________________________________________________________________________________________________\n",
254 | "conv2d_57 (Conv2D) (None, 256, 256, 64) 36928 activation_37[0][0] \n",
255 | "__________________________________________________________________________________________________\n",
256 | "batch_normalization_55 (BatchNo (None, 256, 256, 64) 256 conv2d_57[0][0] \n",
257 | "__________________________________________________________________________________________________\n",
258 | "activation_38 (Activation) (None, 256, 256, 64) 0 batch_normalization_55[0][0] \n",
259 | "__________________________________________________________________________________________________\n",
260 | "conv2d_58 (Conv2D) (None, 256, 256, 64) 36928 activation_38[0][0] \n",
261 | "__________________________________________________________________________________________________\n",
262 | "batch_normalization_56 (BatchNo (None, 256, 256, 64) 256 conv2d_58[0][0] \n",
263 | "__________________________________________________________________________________________________\n",
264 | "activation_39 (Activation) (None, 256, 256, 64) 0 batch_normalization_56[0][0] \n",
265 | "__________________________________________________________________________________________________\n",
266 | "conv2d_59 (Conv2D) (None, 256, 256, 64) 36928 activation_39[0][0] \n",
267 | "__________________________________________________________________________________________________\n",
268 | "batch_normalization_57 (BatchNo (None, 256, 256, 64) 256 conv2d_59[0][0] \n",
269 | "__________________________________________________________________________________________________\n",
270 | "batch_normalization_58 (BatchNo (None, 256, 256, 64) 256 batch_normalization_57[0][0] \n",
271 | "__________________________________________________________________________________________________\n",
272 | "global_average_pooling2d_17 (Gl (None, 64) 0 batch_normalization_58[0][0] \n",
273 | "__________________________________________________________________________________________________\n",
274 | "dense_34 (Dense) (None, 4) 256 global_average_pooling2d_17[0][0]\n",
275 | "__________________________________________________________________________________________________\n",
276 | "dense_35 (Dense) (None, 64) 256 dense_34[0][0] \n",
277 | "__________________________________________________________________________________________________\n",
278 | "tf.math.multiply_17 (TFOpLambda (None, 256, 256, 64) 0 batch_normalization_58[0][0] \n",
279 | " dense_35[0][0] \n",
280 | "__________________________________________________________________________________________________\n",
281 | "tf.__operators__.add_17 (TFOpLa (None, 256, 256, 64) 0 batch_normalization_57[0][0] \n",
282 | " tf.math.multiply_17[0][0] \n",
283 | "__________________________________________________________________________________________________\n",
284 | "activation_40 (Activation) (None, 256, 256, 64) 0 tf.__operators__.add_17[0][0] \n",
285 | "__________________________________________________________________________________________________\n",
286 | "max_pooling2d_4 (MaxPooling2D) (None, 128, 128, 64) 0 activation_40[0][0] \n",
287 | "__________________________________________________________________________________________________\n",
288 | "conv2d_61 (Conv2D) (None, 128, 128, 256 147712 max_pooling2d_4[0][0] \n",
289 | "__________________________________________________________________________________________________\n",
290 | "batch_normalization_59 (BatchNo (None, 128, 128, 256 1024 conv2d_61[0][0] \n",
291 | "__________________________________________________________________________________________________\n",
292 | "activation_41 (Activation) (None, 128, 128, 256 0 batch_normalization_59[0][0] \n",
293 | "__________________________________________________________________________________________________\n",
294 | "conv2d_62 (Conv2D) (None, 128, 128, 256 590080 activation_41[0][0] \n",
295 | "__________________________________________________________________________________________________\n",
296 | "batch_normalization_60 (BatchNo (None, 128, 128, 256 1024 conv2d_62[0][0] \n",
297 | "__________________________________________________________________________________________________\n",
298 | "batch_normalization_61 (BatchNo (None, 128, 128, 256 1024 batch_normalization_60[0][0] \n",
299 | "__________________________________________________________________________________________________\n",
300 | "global_average_pooling2d_18 (Gl (None, 256) 0 batch_normalization_61[0][0] \n",
301 | "__________________________________________________________________________________________________\n",
302 | "dense_36 (Dense) (None, 16) 4096 global_average_pooling2d_18[0][0]\n",
303 | "__________________________________________________________________________________________________\n",
304 | "dense_37 (Dense) (None, 256) 4096 dense_36[0][0] \n",
305 | "__________________________________________________________________________________________________\n",
306 | "tf.math.multiply_18 (TFOpLambda (None, 128, 128, 256 0 batch_normalization_61[0][0] \n",
307 | " dense_37[0][0] \n",
308 | "__________________________________________________________________________________________________\n",
309 | "tf.__operators__.add_18 (TFOpLa (None, 128, 128, 256 0 batch_normalization_60[0][0] \n",
310 | " tf.math.multiply_18[0][0] \n",
311 | "__________________________________________________________________________________________________\n",
312 | "activation_42 (Activation) (None, 128, 128, 256 0 tf.__operators__.add_18[0][0] \n",
313 | "__________________________________________________________________________________________________\n",
314 | "conv2d_64 (Conv2D) (None, 64, 64, 256) 590080 activation_42[0][0] \n",
315 | "__________________________________________________________________________________________________\n",
316 | "batch_normalization_62 (BatchNo (None, 64, 64, 256) 1024 conv2d_64[0][0] \n",
317 | "__________________________________________________________________________________________________\n",
318 | "activation_43 (Activation) (None, 64, 64, 256) 0 batch_normalization_62[0][0] \n",
319 | "__________________________________________________________________________________________________\n",
320 | "conv2d_65 (Conv2D) (None, 64, 64, 256) 590080 activation_43[0][0] \n",
321 | "__________________________________________________________________________________________________\n",
322 | "batch_normalization_63 (BatchNo (None, 64, 64, 256) 1024 conv2d_65[0][0] \n",
323 | "__________________________________________________________________________________________________\n",
324 | "activation_44 (Activation) (None, 64, 64, 256) 0 batch_normalization_63[0][0] \n",
325 | "__________________________________________________________________________________________________\n",
326 | "conv2d_66 (Conv2D) (None, 64, 64, 256) 590080 activation_44[0][0] \n",
327 | "__________________________________________________________________________________________________\n",
328 | "batch_normalization_64 (BatchNo (None, 64, 64, 256) 1024 conv2d_66[0][0] \n",
329 | "__________________________________________________________________________________________________\n",
330 | "batch_normalization_65 (BatchNo (None, 64, 64, 256) 1024 batch_normalization_64[0][0] \n",
331 | "__________________________________________________________________________________________________\n",
332 | "global_average_pooling2d_19 (Gl (None, 256) 0 batch_normalization_65[0][0] \n",
333 | "__________________________________________________________________________________________________\n",
334 | "dense_38 (Dense) (None, 16) 4096 global_average_pooling2d_19[0][0]\n",
335 | "__________________________________________________________________________________________________\n",
336 | "dense_39 (Dense) (None, 256) 4096 dense_38[0][0] \n",
337 | "__________________________________________________________________________________________________\n",
338 | "tf.math.multiply_19 (TFOpLambda (None, 64, 64, 256) 0 batch_normalization_65[0][0] \n",
339 | " dense_39[0][0] \n",
340 | "__________________________________________________________________________________________________\n",
341 | "tf.__operators__.add_19 (TFOpLa (None, 64, 64, 256) 0 batch_normalization_64[0][0] \n",
342 | " tf.math.multiply_19[0][0] \n",
343 | "__________________________________________________________________________________________________\n",
344 | "activation_45 (Activation) (None, 64, 64, 256) 0 tf.__operators__.add_19[0][0] \n",
345 | "__________________________________________________________________________________________________\n",
346 | "conv2d_transpose_8 (Conv2DTrans (None, 256, 256, 128 524416 activation_45[0][0] \n",
347 | "__________________________________________________________________________________________________\n",
348 | "concatenate_8 (Concatenate) (None, 256, 256, 192 0 conv2d_transpose_8[0][0] \n",
349 | " activation_40[0][0] \n",
350 | "__________________________________________________________________________________________________\n",
351 | "conv2d_68 (Conv2D) (None, 256, 256, 128 221312 concatenate_8[0][0] \n",
352 | "__________________________________________________________________________________________________\n",
353 | "batch_normalization_66 (BatchNo (None, 256, 256, 128 512 conv2d_68[0][0] \n",
354 | "__________________________________________________________________________________________________\n",
355 | "activation_46 (Activation) (None, 256, 256, 128 0 batch_normalization_66[0][0] \n",
356 | "__________________________________________________________________________________________________\n",
357 | "conv2d_69 (Conv2D) (None, 256, 256, 128 147584 activation_46[0][0] \n",
358 | "__________________________________________________________________________________________________\n",
359 | "batch_normalization_67 (BatchNo (None, 256, 256, 128 512 conv2d_69[0][0] \n",
360 | "__________________________________________________________________________________________________\n",
361 | "batch_normalization_68 (BatchNo (None, 256, 256, 128 512 batch_normalization_67[0][0] \n",
362 | "__________________________________________________________________________________________________\n",
363 | "global_average_pooling2d_20 (Gl (None, 128) 0 batch_normalization_68[0][0] \n",
364 | "__________________________________________________________________________________________________\n",
365 | "dense_40 (Dense) (None, 8) 1024 global_average_pooling2d_20[0][0]\n",
366 | "__________________________________________________________________________________________________\n",
367 | "dense_41 (Dense) (None, 128) 1024 dense_40[0][0] \n",
368 | "__________________________________________________________________________________________________\n",
369 | "tf.math.multiply_20 (TFOpLambda (None, 256, 256, 128 0 batch_normalization_68[0][0] \n",
370 | " dense_41[0][0] \n",
371 | "__________________________________________________________________________________________________\n",
372 | "tf.__operators__.add_20 (TFOpLa (None, 256, 256, 128 0 batch_normalization_67[0][0] \n",
373 | " tf.math.multiply_20[0][0] \n",
374 | "__________________________________________________________________________________________________\n",
375 | "conv2d_transpose_9 (Conv2DTrans (None, 256, 256, 128 524416 activation_42[0][0] \n",
376 | "__________________________________________________________________________________________________\n",
377 | "activation_47 (Activation) (None, 256, 256, 128 0 tf.__operators__.add_20[0][0] \n",
378 | "__________________________________________________________________________________________________\n",
379 | "concatenate_9 (Concatenate) (None, 256, 256, 256 0 conv2d_transpose_9[0][0] \n",
380 | " activation_47[0][0] \n",
381 | "__________________________________________________________________________________________________\n",
382 | "conv2d_71 (Conv2D) (None, 256, 256, 128 295040 concatenate_9[0][0] \n",
383 | "__________________________________________________________________________________________________\n",
384 | "batch_normalization_69 (BatchNo (None, 256, 256, 128 512 conv2d_71[0][0] \n",
385 | "__________________________________________________________________________________________________\n",
386 | "activation_48 (Activation) (None, 256, 256, 128 0 batch_normalization_69[0][0] \n",
387 | "__________________________________________________________________________________________________\n",
388 | "conv2d_72 (Conv2D) (None, 256, 256, 128 147584 activation_48[0][0] \n",
389 | "__________________________________________________________________________________________________\n",
390 | "batch_normalization_70 (BatchNo (None, 256, 256, 128 512 conv2d_72[0][0] \n",
391 | "__________________________________________________________________________________________________\n",
392 | "batch_normalization_71 (BatchNo (None, 256, 256, 128 512 batch_normalization_70[0][0] \n",
393 | "__________________________________________________________________________________________________\n",
394 | "global_average_pooling2d_21 (Gl (None, 128) 0 batch_normalization_71[0][0] \n",
395 | "__________________________________________________________________________________________________\n",
396 | "dense_42 (Dense) (None, 8) 1024 global_average_pooling2d_21[0][0]\n",
397 | "__________________________________________________________________________________________________\n",
398 | "dense_43 (Dense) (None, 128) 1024 dense_42[0][0] \n",
399 | "__________________________________________________________________________________________________\n",
400 | "tf.math.multiply_21 (TFOpLambda (None, 256, 256, 128 0 batch_normalization_71[0][0] \n",
401 | " dense_43[0][0] \n",
402 | "__________________________________________________________________________________________________\n",
403 | "tf.__operators__.add_21 (TFOpLa (None, 256, 256, 128 0 batch_normalization_70[0][0] \n",
404 | " tf.math.multiply_21[0][0] \n",
405 | "__________________________________________________________________________________________________\n",
406 | "activation_49 (Activation) (None, 256, 256, 128 0 tf.__operators__.add_21[0][0] \n",
407 | "__________________________________________________________________________________________________\n",
408 | "conv2d_transpose_10 (Conv2DTran (None, 512, 512, 64) 131136 activation_49[0][0] \n",
409 | "__________________________________________________________________________________________________\n",
410 | "concatenate_10 (Concatenate) (None, 512, 512, 128 0 conv2d_transpose_10[0][0] \n",
411 | " activation_37[0][0] \n",
412 | "__________________________________________________________________________________________________\n",
413 | "conv2d_74 (Conv2D) (None, 512, 512, 64) 73792 concatenate_10[0][0] \n",
414 | "__________________________________________________________________________________________________\n",
415 | "batch_normalization_72 (BatchNo (None, 512, 512, 64) 256 conv2d_74[0][0] \n",
416 | "__________________________________________________________________________________________________\n",
417 | "activation_50 (Activation) (None, 512, 512, 64) 0 batch_normalization_72[0][0] \n",
418 | "__________________________________________________________________________________________________\n",
419 | "conv2d_75 (Conv2D) (None, 512, 512, 64) 36928 activation_50[0][0] \n",
420 | "__________________________________________________________________________________________________\n",
421 | "batch_normalization_73 (BatchNo (None, 512, 512, 64) 256 conv2d_75[0][0] \n",
422 | "__________________________________________________________________________________________________\n",
423 | "batch_normalization_74 (BatchNo (None, 512, 512, 64) 256 batch_normalization_73[0][0] \n",
424 | "__________________________________________________________________________________________________\n",
425 | "global_average_pooling2d_22 (Gl (None, 64) 0 batch_normalization_74[0][0] \n",
426 | "__________________________________________________________________________________________________\n",
427 | "dense_44 (Dense) (None, 4) 256 global_average_pooling2d_22[0][0]\n",
428 | "__________________________________________________________________________________________________\n",
429 | "dense_45 (Dense) (None, 64) 256 dense_44[0][0] \n",
430 | "__________________________________________________________________________________________________\n",
431 | "tf.math.multiply_22 (TFOpLambda (None, 512, 512, 64) 0 batch_normalization_74[0][0] \n",
432 | " dense_45[0][0] \n",
433 | "__________________________________________________________________________________________________\n",
434 | "tf.__operators__.add_22 (TFOpLa (None, 512, 512, 64) 0 batch_normalization_73[0][0] \n",
435 | " tf.math.multiply_22[0][0] \n",
436 | "__________________________________________________________________________________________________\n",
437 | "conv2d_transpose_11 (Conv2DTran (None, 512, 512, 64) 65600 activation_40[0][0] \n",
438 | "__________________________________________________________________________________________________\n",
439 | "activation_51 (Activation) (None, 512, 512, 64) 0 tf.__operators__.add_22[0][0] \n",
440 | "__________________________________________________________________________________________________\n",
441 | "concatenate_11 (Concatenate) (None, 512, 512, 128 0 conv2d_transpose_11[0][0] \n",
442 | " activation_51[0][0] \n",
443 | "__________________________________________________________________________________________________\n",
444 | "conv2d_77 (Conv2D) (None, 512, 512, 32) 36896 concatenate_11[0][0] \n",
445 | "__________________________________________________________________________________________________\n",
446 | "batch_normalization_75 (BatchNo (None, 512, 512, 32) 128 conv2d_77[0][0] \n",
447 | "__________________________________________________________________________________________________\n",
448 | "activation_52 (Activation) (None, 512, 512, 32) 0 batch_normalization_75[0][0] \n",
449 | "__________________________________________________________________________________________________\n",
450 | "conv2d_78 (Conv2D) (None, 512, 512, 32) 9248 activation_52[0][0] \n",
451 | "__________________________________________________________________________________________________\n",
452 | "batch_normalization_76 (BatchNo (None, 512, 512, 32) 128 conv2d_78[0][0] \n",
453 | "__________________________________________________________________________________________________\n",
454 | "batch_normalization_77 (BatchNo (None, 512, 512, 32) 128 batch_normalization_76[0][0] \n",
455 | "__________________________________________________________________________________________________\n",
456 | "global_average_pooling2d_23 (Gl (None, 32) 0 batch_normalization_77[0][0] \n",
457 | "__________________________________________________________________________________________________\n",
458 | "dense_46 (Dense) (None, 2) 64 global_average_pooling2d_23[0][0]\n",
459 | "__________________________________________________________________________________________________\n",
460 | "dense_47 (Dense) (None, 32) 64 dense_46[0][0] \n",
461 | "__________________________________________________________________________________________________\n",
462 | "tf.math.multiply_23 (TFOpLambda (None, 512, 512, 32) 0 batch_normalization_77[0][0] \n",
463 | " dense_47[0][0] \n",
464 | "__________________________________________________________________________________________________\n",
465 | "tf.__operators__.add_23 (TFOpLa (None, 512, 512, 32) 0 batch_normalization_76[0][0] \n",
466 | " tf.math.multiply_23[0][0] \n",
467 | "__________________________________________________________________________________________________\n",
468 | "activation_53 (Activation) (None, 512, 512, 32) 0 tf.__operators__.add_23[0][0] \n",
469 | "__________________________________________________________________________________________________\n",
470 | "conv2d_80 (Conv2D) (None, 512, 512, 5) 165 activation_53[0][0] \n",
471 | "==================================================================================================\n",
472 | "Total params: 4,906,981\n",
473 | "Trainable params: 4,900,389\n",
474 | "Non-trainable params: 6,592\n",
475 | "__________________________________________________________________________________________________\n"
476 | ]
477 | }
478 | ],
479 | "source": [
480 | "model.summary()"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": null,
486 | "metadata": {},
487 | "outputs": [],
488 | "source": []
489 | }
490 | ],
491 | "metadata": {
492 | "kernelspec": {
493 | "display_name": "tf",
494 | "language": "python",
495 | "name": "tf"
496 | },
497 | "language_info": {
498 | "codemirror_mode": {
499 | "name": "ipython",
500 | "version": 3
501 | },
502 | "file_extension": ".py",
503 | "mimetype": "text/x-python",
504 | "name": "python",
505 | "nbconvert_exporter": "python",
506 | "pygments_lexer": "ipython3",
507 | "version": "3.8.10"
508 | }
509 | },
510 | "nbformat": 4,
511 | "nbformat_minor": 4
512 | }
513 |
--------------------------------------------------------------------------------
/TensorFlow/notebook/README.md:
--------------------------------------------------------------------------------
1 | # Jupyter Notebook
2 |
--------------------------------------------------------------------------------
/TensorFlow/notebook/images/ColonSegNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/ColonSegNet.png
--------------------------------------------------------------------------------
/TensorFlow/notebook/images/ResidualBlock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/ResidualBlock.png
--------------------------------------------------------------------------------
/TensorFlow/notebook/images/Strided_Conv_Block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/Strided_Conv_Block.png
--------------------------------------------------------------------------------
/TensorFlow/notebook/images/squeeze_and_excitation_detailed_block_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/Semantic-Segmentation-Architecture/558a6b88108a58aac21ce4109397022e21cc6f1c/TensorFlow/notebook/images/squeeze_and_excitation_detailed_block_diagram.png
--------------------------------------------------------------------------------
/TensorFlow/resnet50_unet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.applications import ResNet50
4 |
5 | def conv_block(input, num_filters):
6 | x = Conv2D(num_filters, 3, padding="same")(input)
7 | x = BatchNormalization()(x)
8 | x = Activation("relu")(x)
9 |
10 | x = Conv2D(num_filters, 3, padding="same")(x)
11 | x = BatchNormalization()(x)
12 | x = Activation("relu")(x)
13 |
14 | return x
15 |
16 | def decoder_block(input, skip_features, num_filters):
17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
18 | x = Concatenate()([x, skip_features])
19 | x = conv_block(x, num_filters)
20 | return x
21 |
22 | def build_resnet50_unet(input_shape):
23 | """ Input """
24 | inputs = Input(input_shape)
25 |
26 | """ Pre-trained ResNet50 Model """
27 | resnet50 = ResNet50(include_top=False, weights="imagenet", input_tensor=inputs)
28 |
29 | """ Encoder """
30 | s1 = resnet50.get_layer("input_1").output ## (512 x 512)
31 | s2 = resnet50.get_layer("conv1_relu").output ## (256 x 256)
32 | s3 = resnet50.get_layer("conv2_block3_out").output ## (128 x 128)
33 | s4 = resnet50.get_layer("conv3_block4_out").output ## (64 x 64)
34 |
35 | """ Bridge """
36 | b1 = resnet50.get_layer("conv4_block6_out").output ## (32 x 32)
37 |
38 | """ Decoder """
39 | d1 = decoder_block(b1, s4, 512) ## (64 x 64)
40 | d2 = decoder_block(d1, s3, 256) ## (128 x 128)
41 | d3 = decoder_block(d2, s2, 128) ## (256 x 256)
42 | d4 = decoder_block(d3, s1, 64) ## (512 x 512)
43 |
44 | """ Output """
45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
46 |
47 | model = Model(inputs, outputs, name="ResNet50_U-Net")
48 | return model
49 |
50 | if __name__ == "__main__":
51 | input_shape = (512, 512, 3)
52 | model = build_resnet50_unet(input_shape)
53 | model.summary()
54 |
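55 |     ## Minimal, hypothetical training setup (an assumption; the script itself
56 |     ## only builds the model): binary cross-entropy pairs naturally with the
57 |     ## single-channel sigmoid output above.
58 |     model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])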
--------------------------------------------------------------------------------
/TensorFlow/resunet++.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.keras.layers as L
3 | from tensorflow.keras.models import Model
4 |
5 | def SE(inputs, ratio=8):
6 | ## [8, H, W, 32]
7 | channel_axis = -1
8 | num_filters = inputs.shape[channel_axis]
9 | se_shape = (1, 1, num_filters)
10 |
11 | x = L.GlobalAveragePooling2D()(inputs) ## [8, 32]
12 | x = L.Reshape(se_shape)(x)
13 | x = L.Dense(num_filters // ratio, activation='relu', use_bias=False)(x)
14 | x = L.Dense(num_filters, activation='sigmoid', use_bias=False)(x)
15 |
16 | x = L.Multiply()([inputs, x])
17 | return x
18 |
19 |
20 | def stem_block(inputs, num_filters):
21 | ## Conv 1
22 | x = L.Conv2D(num_filters, 3, padding="same")(inputs)
23 | x = L.BatchNormalization()(x)
24 | x = L.Activation("relu")(x)
25 | x = L.Conv2D(num_filters, 3, padding="same")(x)
26 |
27 | ## Shortcut
28 | s = L.Conv2D(num_filters, 1, padding="same")(inputs)
29 |
30 | ## Add
31 | x = L.Add()([x, s])
32 | return x
33 |
34 | def resnet_block(inputs, num_filters, strides=1):
35 | ## SE
36 | inputs = SE(inputs)
37 |
38 | ## Conv 1
39 | x = L.BatchNormalization()(inputs)
40 | x = L.Activation("relu")(x)
41 | x = L.Conv2D(num_filters, 3, padding="same", strides=strides)(x)
42 |
43 | ## Conv 2
44 | x = L.BatchNormalization()(x)
45 | x = L.Activation("relu")(x)
46 | x = L.Conv2D(num_filters, 3, padding="same", strides=1)(x)
47 |
48 | ## Shortcut
49 | s = L.Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)
50 |
51 | ## Add
52 | x = L.Add()([x, s])
53 |
54 | return x
55 |
56 | def aspp_block(inputs, num_filters):
57 | x1 = L.Conv2D(num_filters, 3, dilation_rate=6, padding="same")(inputs)
58 | x1 = L.BatchNormalization()(x1)
59 |
60 | x2 = L.Conv2D(num_filters, 3, dilation_rate=12, padding="same")(inputs)
61 | x2 = L.BatchNormalization()(x2)
62 |
63 | x3 = L.Conv2D(num_filters, 3, dilation_rate=18, padding="same")(inputs)
64 | x3 = L.BatchNormalization()(x3)
65 |
66 | x4 = L.Conv2D(num_filters, (3, 3), padding="same")(inputs)
67 | x4 = L.BatchNormalization()(x4)
68 |
69 | y = L.Add()([x1, x2, x3, x4])
70 | y = L.Conv2D(num_filters, 1, padding="same")(y)
71 |
72 | return y
73 |
74 | def attention_block(x1, x2):
75 | num_filters = x2.shape[-1]
76 |
77 | x1_conv = L.BatchNormalization()(x1)
78 | x1_conv = L.Activation("relu")(x1_conv)
79 | x1_conv = L.Conv2D(num_filters, 3, padding="same")(x1_conv)
80 | x1_pool = L.MaxPooling2D((2, 2))(x1_conv)
81 |
82 | x2_conv = L.BatchNormalization()(x2)
83 | x2_conv = L.Activation("relu")(x2_conv)
84 | x2_conv = L.Conv2D(num_filters, 3, padding="same")(x2_conv)
85 |
86 | x = L.Add()([x1_pool, x2_conv])
87 |
88 | x = L.BatchNormalization()(x)
89 | x = L.Activation("relu")(x)
90 | x = L.Conv2D(num_filters, 3, padding="same")(x)
91 |
92 | x = L.Multiply()([x, x2])
93 | return x
94 |
95 | def resunet_pp(input_shape):
96 | """ Inputs """
97 | inputs = L.Input(input_shape)
98 |
99 | """ Encoder """
100 | c1 = stem_block(inputs, 16)
101 | c2 = resnet_block(c1, 32, strides=2)
102 | c3 = resnet_block(c2, 64, strides=2)
103 | c4 = resnet_block(c3, 128, strides=2)
104 |
105 | """ Bridge """
106 | b1 = aspp_block(c4, 256)
107 |
108 | """ Decoder """
109 |     d1 = attention_block(c3, b1)
110 | d1 = L.UpSampling2D((2, 2))(d1)
111 | d1 = L.Concatenate()([d1, c3])
112 | d1 = resnet_block(d1, 128)
113 |
114 |     d2 = attention_block(c2, d1)
115 | d2 = L.UpSampling2D((2, 2))(d2)
116 | d2 = L.Concatenate()([d2, c2])
117 | d2 = resnet_block(d2, 64)
118 |
119 |     d3 = attention_block(c1, d2)
120 | d3 = L.UpSampling2D((2, 2))(d3)
121 | d3 = L.Concatenate()([d3, c1])
122 | d3 = resnet_block(d3, 32)
123 |
124 |     """ Output """
125 | outputs = aspp_block(d3, 16)
126 | outputs = L.Conv2D(1, 1, padding="same")(outputs)
127 | outputs = L.Activation("sigmoid")(outputs)
128 |
129 | """ Model """
130 | model = Model(inputs, outputs)
131 | return model
132 |
133 | if __name__ == "__main__":
134 | input_shape = (256, 256, 3)
135 | model = resunet_pp(input_shape)
136 | model.summary()
137 |
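138 | ## Added resolution walk-through (not in the original script): each
139 | ## resnet_block with strides=2 halves the feature map, so a 256x256 input
140 | ## runs through the encoder at 256 -> 128 -> 64 -> 32 and the ASPP bridge
141 | ## sees 32x32 features before the three UpSampling2D stages restore 256x256.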
--------------------------------------------------------------------------------
/TensorFlow/resunet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 |
4 | def batchnorm_relu(inputs):
5 | """ Batch Normalization & ReLU """
6 | x = BatchNormalization()(inputs)
7 | x = Activation("relu")(x)
8 | return x
9 |
10 | def residual_block(inputs, num_filters, strides=1):
11 | """ Convolutional Layers """
12 | x = batchnorm_relu(inputs)
13 | x = Conv2D(num_filters, 3, padding="same", strides=strides)(x)
14 | x = batchnorm_relu(x)
15 | x = Conv2D(num_filters, 3, padding="same", strides=1)(x)
16 |
17 | """ Shortcut Connection (Identity Mapping) """
18 | s = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)
19 |
20 | """ Addition """
21 | x = x + s
22 | return x
23 |
24 | def decoder_block(inputs, skip_features, num_filters):
25 | """ Decoder Block """
26 |
27 | x = UpSampling2D((2, 2))(inputs)
28 | x = Concatenate()([x, skip_features])
29 | x = residual_block(x, num_filters, strides=1)
30 | return x
31 |
32 | def build_resunet(input_shape):
33 | """ RESUNET Architecture """
34 |
35 | inputs = Input(input_shape)
36 |
37 |     """ Encoder 1 """
38 | x = Conv2D(64, 3, padding="same", strides=1)(inputs)
39 | x = batchnorm_relu(x)
40 | x = Conv2D(64, 3, padding="same", strides=1)(x)
41 | s = Conv2D(64, 1, padding="same")(inputs)
42 | s1 = x + s
43 |
44 | """ Encoder 2, 3 """
45 | s2 = residual_block(s1, 128, strides=2)
46 | s3 = residual_block(s2, 256, strides=2)
47 |
48 | """ Bridge """
49 | b = residual_block(s3, 512, strides=2)
50 |
51 | """ Decoder 1, 2, 3 """
52 | x = decoder_block(b, s3, 256)
53 | x = decoder_block(x, s2, 128)
54 | x = decoder_block(x, s1, 64)
55 |
56 | """ Classifier """
57 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)
58 |
59 | """ Model """
60 | model = Model(inputs, outputs, name="RESUNET")
61 |
62 | return model
63 |
64 | if __name__ == "__main__":
65 | shape = (224, 224, 3)
66 | model = build_resunet(shape)
67 | model.summary()
68 |
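69 | ## Added note: the three strides=2 stages (encoder 2, encoder 3 and the
70 | ## bridge) downsample by a factor of 8 overall, so the input height and width
71 | ## should be divisible by 8; with (224, 224, 3) the bridge works on 28x28
72 | ## feature maps.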
--------------------------------------------------------------------------------
/TensorFlow/u2-net.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 |
5 | import tensorflow as tf
6 | from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Add
7 |
8 | def conv_block(inputs, out_ch, rate=1):
9 | x = Conv2D(out_ch, 3, padding="same", dilation_rate=rate)(inputs)
10 | x = BatchNormalization()(x)
11 | x = Activation("relu")(x)
12 | return x
13 |
14 | def RSU_L(inputs, out_ch, int_ch, num_layers, rate=2):
15 | """ Initial Conv """
16 | x = conv_block(inputs, out_ch)
17 | init_feats = x
18 |
19 | """ Encoder """
20 | skip = []
21 | x = conv_block(x, int_ch)
22 | skip.append(x)
23 |
24 | for i in range(num_layers-2):
25 | x = MaxPool2D((2, 2))(x)
26 | x = conv_block(x, int_ch)
27 | skip.append(x)
28 |
29 | """ Bridge """
30 | x = conv_block(x, int_ch, rate=rate)
31 |
32 | """ Decoder """
33 | skip.reverse()
34 |
35 | x = Concatenate()([x, skip[0]])
36 | x = conv_block(x, int_ch)
37 |
38 | for i in range(num_layers-3):
39 | x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
40 | x = Concatenate()([x, skip[i+1]])
41 | x = conv_block(x, int_ch)
42 |
43 | x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
44 | x = Concatenate()([x, skip[-1]])
45 | x = conv_block(x, out_ch)
46 |
47 | """ Add """
48 | x = Add()([x, init_feats])
49 | return x
50 |
51 | def RSU_4F(inputs, out_ch, int_ch):
52 | """ Initial Conv """
53 | x0 = conv_block(inputs, out_ch, rate=1)
54 |
55 | """ Encoder """
56 | x1 = conv_block(x0, int_ch, rate=1)
57 | x2 = conv_block(x1, int_ch, rate=2)
58 | x3 = conv_block(x2, int_ch, rate=4)
59 |
60 | """ Bridge """
61 | x4 = conv_block(x3, int_ch, rate=8)
62 |
63 | """ Decoder """
64 | x = Concatenate()([x4, x3])
65 | x = conv_block(x, int_ch, rate=4)
66 |
67 | x = Concatenate()([x, x2])
68 | x = conv_block(x, int_ch, rate=2)
69 |
70 | x = Concatenate()([x, x1])
71 | x = conv_block(x, out_ch, rate=1)
72 |
73 | """ Addition """
74 | x = Add()([x, x0])
75 | return x
76 |
77 | def u2net(input_shape, out_ch, int_ch, num_classes=1):
78 | """ Input Layer """
79 | inputs = Input(input_shape)
80 | s0 = inputs
81 |
82 | """ Encoder """
83 | s1 = RSU_L(s0, out_ch[0], int_ch[0], 7)
84 | p1 = MaxPool2D((2, 2))(s1)
85 |
86 | s2 = RSU_L(p1, out_ch[1], int_ch[1], 6)
87 | p2 = MaxPool2D((2, 2))(s2)
88 |
89 | s3 = RSU_L(p2, out_ch[2], int_ch[2], 5)
90 | p3 = MaxPool2D((2, 2))(s3)
91 |
92 | s4 = RSU_L(p3, out_ch[3], int_ch[3], 4)
93 | p4 = MaxPool2D((2, 2))(s4)
94 |
95 | s5 = RSU_4F(p4, out_ch[4], int_ch[4])
96 | p5 = MaxPool2D((2, 2))(s5)
97 |
98 | """ Bridge """
99 | b1 = RSU_4F(p5, out_ch[5], int_ch[5])
100 | b2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(b1)
101 |
102 | """ Decoder """
103 | d1 = Concatenate()([b2, s5])
104 | d1 = RSU_4F(d1, out_ch[6], int_ch[6])
105 | u1 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d1)
106 |
107 | d2 = Concatenate()([u1, s4])
108 | d2 = RSU_L(d2, out_ch[7], int_ch[7], 4)
109 | u2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d2)
110 |
111 | d3 = Concatenate()([u2, s3])
112 | d3 = RSU_L(d3, out_ch[8], int_ch[8], 5)
113 | u3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d3)
114 |
115 | d4 = Concatenate()([u3, s2])
116 | d4 = RSU_L(d4, out_ch[9], int_ch[9], 6)
117 | u4 = UpSampling2D(size=(2, 2), interpolation="bilinear")(d4)
118 |
119 | d5 = Concatenate()([u4, s1])
120 | d5 = RSU_L(d5, out_ch[10], int_ch[10], 7)
121 |
122 | """ Side Outputs """
123 | y1 = Conv2D(num_classes, 3, padding="same")(d5)
124 |
125 | y2 = Conv2D(num_classes, 3, padding="same")(d4)
126 | y2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(y2)
127 |
128 | y3 = Conv2D(num_classes, 3, padding="same")(d3)
129 | y3 = UpSampling2D(size=(4, 4), interpolation="bilinear")(y3)
130 |
131 | y4 = Conv2D(num_classes, 3, padding="same")(d2)
132 | y4 = UpSampling2D(size=(8, 8), interpolation="bilinear")(y4)
133 |
134 | y5 = Conv2D(num_classes, 3, padding="same")(d1)
135 | y5 = UpSampling2D(size=(16, 16), interpolation="bilinear")(y5)
136 |
137 | y6 = Conv2D(num_classes, 3, padding="same")(b1)
138 | y6 = UpSampling2D(size=(32, 32), interpolation="bilinear")(y6)
139 |
140 | y0 = Concatenate()([y1, y2, y3, y4, y5, y6])
141 | y0 = Conv2D(num_classes, 3, padding="same")(y0)
142 |
143 | y0 = Activation("sigmoid")(y0)
144 | y1 = Activation("sigmoid")(y1)
145 | y2 = Activation("sigmoid")(y2)
146 | y3 = Activation("sigmoid")(y3)
147 | y4 = Activation("sigmoid")(y4)
148 | y5 = Activation("sigmoid")(y5)
149 | y6 = Activation("sigmoid")(y6)
150 |
151 | model = tf.keras.models.Model(inputs, outputs=[y0, y1, y2, y3, y4, y5, y6])
152 | return model
153 |
154 | def build_u2net(input_shape, num_classes=1):
155 | out_ch = [64, 128, 256, 512, 512, 512, 512, 256, 128, 64, 64]
156 | int_ch = [32, 32, 64, 128, 256, 256, 256, 128, 64, 32, 16]
157 | model = u2net(input_shape, out_ch, int_ch, num_classes=num_classes)
158 | return model
159 |
160 | def build_u2net_lite(input_shape, num_classes=1):
161 | out_ch = [64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]
162 | int_ch = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
163 | model = u2net(input_shape, out_ch, int_ch, num_classes=num_classes)
164 | return model
165 |
166 | if __name__ == "__main__":
167 | model = build_u2net_lite((512, 512, 3))
168 | model.summary()
169 |
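170 |     ## Illustrative deep-supervision setup (an assumption; this file contains
171 |     ## no training code): the model returns the fused map y0 plus six side
172 |     ## outputs, and a common choice is to apply the same binary cross-entropy
173 |     ## loss to all seven of them.
174 |     model.compile(optimizer="adam", loss=["binary_crossentropy"] * 7)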
--------------------------------------------------------------------------------
/TensorFlow/unet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 |
4 | def conv_block(input, num_filters):
5 | x = Conv2D(num_filters, 3, padding="same")(input)
6 | x = BatchNormalization()(x)
7 | x = Activation("relu")(x)
8 |
9 | x = Conv2D(num_filters, 3, padding="same")(x)
10 | x = BatchNormalization()(x)
11 | x = Activation("relu")(x)
12 |
13 | return x
14 |
15 | def encoder_block(input, num_filters):
16 | x = conv_block(input, num_filters)
17 | p = MaxPool2D((2, 2))(x)
18 | return x, p
19 |
20 | def decoder_block(input, skip_features, num_filters):
21 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
22 | x = Concatenate()([x, skip_features])
23 | x = conv_block(x, num_filters)
24 | return x
25 |
26 | def build_unet(input_shape):
27 | inputs = Input(input_shape)
28 |
29 | s1, p1 = encoder_block(inputs, 64)
30 | s2, p2 = encoder_block(p1, 128)
31 | s3, p3 = encoder_block(p2, 256)
32 | s4, p4 = encoder_block(p3, 512)
33 |
34 | b1 = conv_block(p4, 1024)
35 |
36 | d1 = decoder_block(b1, s4, 512)
37 | d2 = decoder_block(d1, s3, 256)
38 | d3 = decoder_block(d2, s2, 128)
39 | d4 = decoder_block(d3, s1, 64)
40 |
41 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
42 |
43 | model = Model(inputs, outputs, name="U-Net")
44 | return model
45 |
46 | if __name__ == "__main__":
47 | input_shape = (512, 512, 3)
48 | model = build_unet(input_shape)
49 | model.summary()
50 |
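51 | ## Added note: the four encoder blocks halve the resolution while doubling the
52 | ## filters (512 -> 256 -> 128 -> 64 -> 32 pixels with 64 -> 128 -> 256 -> 512
53 | ## filters, and 1024 at the bridge); the decoder mirrors this path with
54 | ## Conv2DTranspose upsampling and skip concatenations.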
--------------------------------------------------------------------------------
/TensorFlow/unetr_2d.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 | import tensorflow.keras.layers as L
4 | from tensorflow.keras.models import Model
5 |
6 | def mlp(x, cf):
7 | x = L.Dense(cf["mlp_dim"], activation="gelu")(x)
8 | x = L.Dropout(cf["dropout_rate"])(x)
9 | x = L.Dense(cf["hidden_dim"])(x)
10 | x = L.Dropout(cf["dropout_rate"])(x)
11 | return x
12 |
13 | def transformer_encoder(x, cf):
14 | skip_1 = x
15 | x = L.LayerNormalization()(x)
16 | x = L.MultiHeadAttention(
17 | num_heads=cf["num_heads"], key_dim=cf["hidden_dim"]
18 | )(x, x)
19 | x = L.Add()([x, skip_1])
20 |
21 | skip_2 = x
22 | x = L.LayerNormalization()(x)
23 | x = mlp(x, cf)
24 | x = L.Add()([x, skip_2])
25 |
26 | return x
27 |
28 | def conv_block(x, num_filters, kernel_size=3):
29 | x = L.Conv2D(num_filters, kernel_size=kernel_size, padding="same")(x)
30 | x = L.BatchNormalization()(x)
31 | x = L.ReLU()(x)
32 | return x
33 |
34 | def deconv_block(x, num_filters):
35 | x = L.Conv2DTranspose(num_filters, kernel_size=2, padding="same", strides=2)(x)
36 | return x
37 |
38 | def build_unetr_2d(cf):
39 | """ Inputs """
40 | input_shape = (cf["num_patches"], cf["patch_size"]*cf["patch_size"]*cf["num_channels"])
41 | inputs = L.Input(input_shape) ## (None, 256, 768)
42 |
43 | """ Patch + Position Embeddings """
44 | patch_embed = L.Dense(cf["hidden_dim"])(inputs) ## (None, 256, 768)
45 |
46 | positions = tf.range(start=0, limit=cf["num_patches"], delta=1) ## (256,)
47 | pos_embed = L.Embedding(input_dim=cf["num_patches"], output_dim=cf["hidden_dim"])(positions) ## (256, 768)
48 | x = patch_embed + pos_embed ## (None, 256, 768)
49 |
50 | """ Transformer Encoder """
51 | skip_connection_index = [3, 6, 9, 12]
52 | skip_connections = []
53 |
54 | for i in range(1, cf["num_layers"]+1, 1):
55 | x = transformer_encoder(x, cf)
56 |
57 | if i in skip_connection_index:
58 | skip_connections.append(x)
59 |
60 | """ CNN Decoder """
61 | z3, z6, z9, z12 = skip_connections
62 |
63 | ## Reshaping
64 | z0 = L.Reshape((cf["image_size"], cf["image_size"], cf["num_channels"]))(inputs)
65 | z3 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z3)
66 | z6 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z6)
67 | z9 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z9)
68 | z12 = L.Reshape((cf["patch_size"], cf["patch_size"], cf["hidden_dim"]))(z12)
69 |
70 | ## Decoder 1
71 | x = deconv_block(z12, 512)
72 |
73 | s = deconv_block(z9, 512)
74 | s = conv_block(s, 512)
75 | x = L.Concatenate()([x, s])
76 |
77 | x = conv_block(x, 512)
78 | x = conv_block(x, 512)
79 |
80 | ## Decoder 2
81 | x = deconv_block(x, 256)
82 |
83 | s = deconv_block(z6, 256)
84 | s = conv_block(s, 256)
85 | s = deconv_block(s, 256)
86 | s = conv_block(s, 256)
87 |
88 | x = L.Concatenate()([x, s])
89 | x = conv_block(x, 256)
90 | x = conv_block(x, 256)
91 |
92 | ## Decoder 3
93 | x = deconv_block(x, 128)
94 |
95 | s = deconv_block(z3, 128)
96 | s = conv_block(s, 128)
97 | s = deconv_block(s, 128)
98 | s = conv_block(s, 128)
99 | s = deconv_block(s, 128)
100 | s = conv_block(s, 128)
101 |
102 | x = L.Concatenate()([x, s])
103 | x = conv_block(x, 128)
104 | x = conv_block(x, 128)
105 |
106 | ## Decoder 4
107 | x = deconv_block(x, 64)
108 |
109 | s = conv_block(z0, 64)
110 | s = conv_block(s, 64)
111 |
112 | x = L.Concatenate()([x, s])
113 | x = conv_block(x, 64)
114 | x = conv_block(x, 64)
115 |
116 | """ Output """
117 | outputs = L.Conv2D(1, kernel_size=1, padding="same", activation="sigmoid")(x)
118 |
119 | return Model(inputs, outputs, name="UNETR_2D")
120 |
121 | if __name__ == "__main__":
122 | config = {}
123 | config["image_size"] = 256
124 | config["num_layers"] = 12
125 | config["hidden_dim"] = 768
126 | config["mlp_dim"] = 3072
127 | config["num_heads"] = 12
128 | config["dropout_rate"] = 0.1
129 | config["num_patches"] = 256
130 | config["patch_size"] = 16
131 | config["num_channels"] = 3
132 |
133 | model = build_unetr_2d(config)
134 | model.summary()
135 |
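136 |     ## Illustrative input preparation (an assumption about the preprocessing;
137 |     ## this file only defines the network): the model expects images already
138 |     ## cut into flattened patches of shape (num_patches, patch_size*patch_size*num_channels).
139 |     import numpy as np
140 |     size, p, c = config["image_size"], config["patch_size"], config["num_channels"]
141 |     image = np.random.rand(size, size, c).astype("float32")
142 |     patches = image.reshape(size // p, p, size // p, p, c)
143 |     patches = patches.transpose(0, 2, 1, 3, 4).reshape(1, config["num_patches"], p * p * c)
144 |     print(model.predict(patches).shape)  ## expected: (1, 256, 256, 1)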
--------------------------------------------------------------------------------
/TensorFlow/vgg16_unet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.applications import VGG16
4 |
5 | def conv_block(input, num_filters):
6 | x = Conv2D(num_filters, 3, padding="same")(input)
7 | x = BatchNormalization()(x)
8 | x = Activation("relu")(x)
9 |
10 | x = Conv2D(num_filters, 3, padding="same")(x)
11 | x = BatchNormalization()(x)
12 | x = Activation("relu")(x)
13 |
14 | return x
15 |
16 | def decoder_block(input, skip_features, num_filters):
17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
18 | x = Concatenate()([x, skip_features])
19 | x = conv_block(x, num_filters)
20 | return x
21 |
22 | def build_vgg16_unet(input_shape):
23 | """ Input """
24 | inputs = Input(input_shape)
25 |
26 | """ Pre-trained VGG16 Model """
27 | vgg16 = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)
28 |
29 | """ Encoder """
30 | s1 = vgg16.get_layer("block1_conv2").output ## (512 x 512)
31 | s2 = vgg16.get_layer("block2_conv2").output ## (256 x 256)
32 | s3 = vgg16.get_layer("block3_conv3").output ## (128 x 128)
33 | s4 = vgg16.get_layer("block4_conv3").output ## (64 x 64)
34 |
35 | """ Bridge """
36 | b1 = vgg16.get_layer("block5_conv3").output ## (32 x 32)
37 |
38 | """ Decoder """
39 | d1 = decoder_block(b1, s4, 512) ## (64 x 64)
40 | d2 = decoder_block(d1, s3, 256) ## (128 x 128)
41 | d3 = decoder_block(d2, s2, 128) ## (256 x 256)
42 | d4 = decoder_block(d3, s1, 64) ## (512 x 512)
43 |
44 | """ Output """
45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
46 |
47 | model = Model(inputs, outputs, name="VGG16_U-Net")
48 | return model
49 |
50 | if __name__ == "__main__":
51 | input_shape = (512, 512, 3)
52 | model = build_vgg16_unet(input_shape)
53 | model.summary()
54 |
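55 |     ## A common fine-tuning choice (an assumption, not prescribed by this
56 |     ## script): freeze the pre-trained VGG16 encoder, whose layer names all
57 |     ## start with "block", and train only the randomly initialised decoder first.
58 |     for layer in model.layers:
59 |         if layer.name.startswith("block"):
60 |             layer.trainable = False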
--------------------------------------------------------------------------------
/TensorFlow/vgg19_unet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
2 | from tensorflow.keras.models import Model
3 | from tensorflow.keras.applications import VGG19
4 |
5 | def conv_block(input, num_filters):
6 | x = Conv2D(num_filters, 3, padding="same")(input)
7 | x = BatchNormalization()(x)
8 | x = Activation("relu")(x)
9 |
10 | x = Conv2D(num_filters, 3, padding="same")(x)
11 | x = BatchNormalization()(x)
12 | x = Activation("relu")(x)
13 |
14 | return x
15 |
16 | def decoder_block(input, skip_features, num_filters):
17 | x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
18 | x = Concatenate()([x, skip_features])
19 | x = conv_block(x, num_filters)
20 | return x
21 |
22 | def build_vgg19_unet(input_shape):
23 | """ Input """
24 | inputs = Input(input_shape)
25 |
26 | """ Pre-trained VGG19 Model """
27 | vgg19 = VGG19(include_top=False, weights="imagenet", input_tensor=inputs)
28 |
29 | """ Encoder """
30 | s1 = vgg19.get_layer("block1_conv2").output ## (512 x 512)
31 | s2 = vgg19.get_layer("block2_conv2").output ## (256 x 256)
32 | s3 = vgg19.get_layer("block3_conv4").output ## (128 x 128)
33 | s4 = vgg19.get_layer("block4_conv4").output ## (64 x 64)
34 |
35 | """ Bridge """
36 | b1 = vgg19.get_layer("block5_conv4").output ## (32 x 32)
37 |
38 | """ Decoder """
39 | d1 = decoder_block(b1, s4, 512) ## (64 x 64)
40 | d2 = decoder_block(d1, s3, 256) ## (128 x 128)
41 | d3 = decoder_block(d2, s2, 128) ## (256 x 256)
42 | d4 = decoder_block(d3, s1, 64) ## (512 x 512)
43 |
44 | """ Output """
45 | outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)
46 |
47 | model = Model(inputs, outputs, name="VGG19_U-Net")
48 | return model
49 |
50 | if __name__ == "__main__":
51 | input_shape = (512, 512, 3)
52 | model = build_vgg19_unet(input_shape)
53 | model.summary()
54 |
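55 | ## Added note: this script mirrors vgg16_unet.py; the only differences are
56 | ## the VGG19 backbone and the deeper skip layers (block3_conv4, block4_conv4
57 | ## and block5_conv4 in place of the *_conv3 layers used with VGG16).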
--------------------------------------------------------------------------------