├── Arche.JPG
├── License
├── README.md
├── requirements.txt
└── unetr.py
/Arche.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamasino52/UNETR/eeb5277a95b0c28d35bfeb24fa0eb6d2d43b16ec/Arche.JPG
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Minseok_Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UNETR: Transformers for 3D Medical Image Segmentation (WACV 2022)
3 |
4 |
5 | Unofficial codebase for:
6 | > [**UNETR: Transformers for 3D Medical Image Segmentation**](https://arxiv.org/abs/2103.10504),
7 | > Ali Hatamizadeh, Dong Yang, Holger Roth, Daguang Xu, et al. 2021.
8 | > *(https://arxiv.org/abs/2103.10504?context=cs.CV)*
9 |
10 | ![UNETR architecture](Arche.JPG)
11 |
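12 | ## Usage
13 |
14 | A minimal usage sketch (not from the original repo): it builds the model with the defaults from `unetr.py` (4 input channels, 3 output classes, a 128³ volume split into 16³ patches) and runs a dummy forward pass. The full-size volume needs several GB of memory; shrink `img_shape` (keeping each side divisible by 16) for a quicker check.
15 |
16 | ```python
17 | import torch
18 | from unetr import UNETR
19 |
20 | # Defaults from unetr.py: embed_dim=768, 12 transformer layers, 12 heads, 16^3 patches.
21 | model = UNETR(img_shape=(128, 128, 128), input_dim=4, output_dim=3)
22 |
23 | x = torch.randn(1, 4, 128, 128, 128)  # (batch, channels, depth, height, width)
24 | with torch.no_grad():
25 |     y = model(x)
26 | print(y.shape)  # torch.Size([1, 3, 128, 128, 128]) -- one channel per segmentation class
27 | ```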
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
--------------------------------------------------------------------------------
/unetr.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | class SingleDeconv3DBlock(nn.Module):
9 | def __init__(self, in_planes, out_planes):
10 | super().__init__()
11 | self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
12 |
13 | def forward(self, x):
14 | return self.block(x)
15 |
16 |
17 | class SingleConv3DBlock(nn.Module):
18 | def __init__(self, in_planes, out_planes, kernel_size):
19 | super().__init__()
20 | self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
21 | padding=((kernel_size - 1) // 2))
22 |
23 | def forward(self, x):
24 | return self.block(x)
25 |
26 |
27 | class Conv3DBlock(nn.Module):
28 | def __init__(self, in_planes, out_planes, kernel_size=3):
29 | super().__init__()
30 | self.block = nn.Sequential(
31 | SingleConv3DBlock(in_planes, out_planes, kernel_size),
32 | nn.BatchNorm3d(out_planes),
33 | nn.ReLU(True)
34 | )
35 |
36 | def forward(self, x):
37 | return self.block(x)
38 |
39 |
40 | class Deconv3DBlock(nn.Module):
41 | def __init__(self, in_planes, out_planes, kernel_size=3):
42 | super().__init__()
43 | self.block = nn.Sequential(
44 | SingleDeconv3DBlock(in_planes, out_planes),
45 | SingleConv3DBlock(out_planes, out_planes, kernel_size),
46 | nn.BatchNorm3d(out_planes),
47 | nn.ReLU(True)
48 | )
49 |
50 | def forward(self, x):
51 | return self.block(x)
52 |
53 |
54 | class SelfAttention(nn.Module):
55 | def __init__(self, num_heads, embed_dim, dropout):
56 | super().__init__()
57 | self.num_attention_heads = num_heads
58 | self.attention_head_size = int(embed_dim / num_heads)
59 | self.all_head_size = self.num_attention_heads * self.attention_head_size
60 |
61 | self.query = nn.Linear(embed_dim, self.all_head_size)
62 | self.key = nn.Linear(embed_dim, self.all_head_size)
63 | self.value = nn.Linear(embed_dim, self.all_head_size)
64 |
65 | self.out = nn.Linear(embed_dim, embed_dim)
66 | self.attn_dropout = nn.Dropout(dropout)
67 | self.proj_dropout = nn.Dropout(dropout)
68 |
69 | self.softmax = nn.Softmax(dim=-1)
70 |
71 | self.vis = False
72 |
73 | def transpose_for_scores(self, x):
74 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
75 | x = x.view(*new_x_shape)
76 | return x.permute(0, 2, 1, 3)
77 |
78 | def forward(self, hidden_states):
79 | mixed_query_layer = self.query(hidden_states)
80 | mixed_key_layer = self.key(hidden_states)
81 | mixed_value_layer = self.value(hidden_states)
82 |
83 | query_layer = self.transpose_for_scores(mixed_query_layer)
84 | key_layer = self.transpose_for_scores(mixed_key_layer)
85 | value_layer = self.transpose_for_scores(mixed_value_layer)
86 |
87 |         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
88 |         attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # scaled dot-product attention
89 |         attention_probs = self.softmax(attention_scores)
90 |         weights = attention_probs if self.vis else None  # keep attention maps only when visualizing
91 |         attention_probs = self.attn_dropout(attention_probs)
92 |
93 | context_layer = torch.matmul(attention_probs, value_layer)
94 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
95 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
96 | context_layer = context_layer.view(*new_context_layer_shape)
97 | attention_output = self.out(context_layer)
98 | attention_output = self.proj_dropout(attention_output)
99 | return attention_output, weights
100 |
101 |
102 | class Mlp(nn.Module):  # simple position-wise MLP (defined here but unused; the blocks below use PositionwiseFeedForward)
103 |     def __init__(self, in_features, act_layer=nn.GELU, drop=0.):
104 |         super().__init__()
105 |         self.fc1 = nn.Linear(in_features, in_features)
106 |         self.act = act_layer()
107 |         self.drop = nn.Dropout(drop)
108 |
109 |     def forward(self, x):
110 |         x = self.fc1(x)
111 |         x = self.act(x)
112 |         x = self.drop(x)
113 |         return x
114 |
115 |
116 | class PositionwiseFeedForward(nn.Module):
117 |     def __init__(self, d_model=768, d_ff=2048, dropout=0.1):
118 |         super().__init__()
119 |         # torch.nn.Linear layers include a bias term by default.
120 |         self.w_1 = nn.Linear(d_model, d_ff)
121 |         self.w_2 = nn.Linear(d_ff, d_model)
122 |         self.dropout = nn.Dropout(dropout)
123 |
124 |     def forward(self, x):
125 |         return self.w_2(self.dropout(F.relu(self.w_1(x))))
126 |
127 |
128 | class Embeddings(nn.Module):
129 | def __init__(self, input_dim, embed_dim, cube_size, patch_size, dropout):
130 | super().__init__()
131 | self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
132 | self.patch_size = patch_size
133 | self.embed_dim = embed_dim
134 | self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim,
135 | kernel_size=patch_size, stride=patch_size)
136 | self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))
137 | self.dropout = nn.Dropout(dropout)
138 |
139 |     def forward(self, x):
140 |         x = self.patch_embeddings(x)  # (B, embed_dim, D/ps, H/ps, W/ps)
141 |         x = x.flatten(2)              # (B, embed_dim, n_patches)
142 |         x = x.transpose(-1, -2)       # (B, n_patches, embed_dim)
143 |         embeddings = x + self.position_embeddings
144 |         embeddings = self.dropout(embeddings)
145 |         return embeddings
146 |
147 |
148 | class TransformerBlock(nn.Module):
149 | def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size):
150 | super().__init__()
151 | self.attention_norm = nn.LayerNorm(embed_dim, eps=1e-6)
152 | self.mlp_norm = nn.LayerNorm(embed_dim, eps=1e-6)
153 |         self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))  # equals the number of patches; not used below
154 | self.mlp = PositionwiseFeedForward(embed_dim, 2048)
155 | self.attn = SelfAttention(num_heads, embed_dim, dropout)
156 |
157 | def forward(self, x):
158 | h = x
159 | x = self.attention_norm(x)
160 | x, weights = self.attn(x)
161 | x = x + h
162 | h = x
163 |
164 | x = self.mlp_norm(x)
165 | x = self.mlp(x)
166 |
167 | x = x + h
168 | return x, weights
169 |
170 |
171 | class Transformer(nn.Module):
172 | def __init__(self, input_dim, embed_dim, cube_size, patch_size, num_heads, num_layers, dropout, extract_layers):
173 | super().__init__()
174 | self.embeddings = Embeddings(input_dim, embed_dim, cube_size, patch_size, dropout)
175 | self.layer = nn.ModuleList()
176 | self.encoder_norm = nn.LayerNorm(embed_dim, eps=1e-6)
177 | self.extract_layers = extract_layers
178 | for _ in range(num_layers):
179 | layer = TransformerBlock(embed_dim, num_heads, dropout, cube_size, patch_size)
180 | self.layer.append(copy.deepcopy(layer))
181 |
182 | def forward(self, x):
183 | extract_layers = []
184 | hidden_states = self.embeddings(x)
185 |
186 | for depth, layer_block in enumerate(self.layer):
187 | hidden_states, _ = layer_block(hidden_states)
188 | if depth + 1 in self.extract_layers:
189 | extract_layers.append(hidden_states)
190 |
191 | return extract_layers
192 |
193 |
194 | class UNETR(nn.Module):
195 | def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3, embed_dim=768, patch_size=16, num_heads=12, dropout=0.1):
196 | super().__init__()
197 | self.input_dim = input_dim
198 | self.output_dim = output_dim
199 | self.embed_dim = embed_dim
200 | self.img_shape = img_shape
201 | self.patch_size = patch_size
202 | self.num_heads = num_heads
203 | self.dropout = dropout
204 |         self.num_layers = 12
205 |         self.ext_layers = [3, 6, 9, 12]  # transformer layers whose outputs feed the decoder skip paths
206 |
207 |         self.patch_dim = [int(x / patch_size) for x in img_shape]  # spatial size of the patch/token grid
208 |
209 | # Transformer Encoder
210 | self.transformer = \
211 | Transformer(
212 | input_dim,
213 | embed_dim,
214 | img_shape,
215 | patch_size,
216 | num_heads,
217 | self.num_layers,
218 | dropout,
219 | self.ext_layers
220 | )
221 |
222 | # U-Net Decoder
223 | self.decoder0 = \
224 | nn.Sequential(
225 | Conv3DBlock(input_dim, 32, 3),
226 | Conv3DBlock(32, 64, 3)
227 | )
228 |
229 | self.decoder3 = \
230 | nn.Sequential(
231 | Deconv3DBlock(embed_dim, 512),
232 | Deconv3DBlock(512, 256),
233 | Deconv3DBlock(256, 128)
234 | )
235 |
236 | self.decoder6 = \
237 | nn.Sequential(
238 | Deconv3DBlock(embed_dim, 512),
239 | Deconv3DBlock(512, 256),
240 | )
241 |
242 | self.decoder9 = \
243 | Deconv3DBlock(embed_dim, 512)
244 |
245 | self.decoder12_upsampler = \
246 | SingleDeconv3DBlock(embed_dim, 512)
247 |
248 | self.decoder9_upsampler = \
249 | nn.Sequential(
250 | Conv3DBlock(1024, 512),
251 | Conv3DBlock(512, 512),
252 | Conv3DBlock(512, 512),
253 | SingleDeconv3DBlock(512, 256)
254 | )
255 |
256 | self.decoder6_upsampler = \
257 | nn.Sequential(
258 | Conv3DBlock(512, 256),
259 | Conv3DBlock(256, 256),
260 | SingleDeconv3DBlock(256, 128)
261 | )
262 |
263 | self.decoder3_upsampler = \
264 | nn.Sequential(
265 | Conv3DBlock(256, 128),
266 | Conv3DBlock(128, 128),
267 | SingleDeconv3DBlock(128, 64)
268 | )
269 |
270 | self.decoder0_header = \
271 | nn.Sequential(
272 | Conv3DBlock(128, 64),
273 | Conv3DBlock(64, 64),
274 | SingleConv3DBlock(64, output_dim, 1)
275 | )
276 |
277 | def forward(self, x):
278 |         z = self.transformer(x)
279 |         z0, z3, z6, z9, z12 = x, *z  # input volume plus hidden states from layers 3, 6, 9, 12
280 |         z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)   # tokens -> (B, embed_dim, D', H', W')
281 |         z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
282 |         z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
283 |         z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
284 |
285 | z12 = self.decoder12_upsampler(z12)
286 | z9 = self.decoder9(z9)
287 | z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
288 | z6 = self.decoder6(z6)
289 | z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
290 | z3 = self.decoder3(z3)
291 | z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
292 | z0 = self.decoder0(z0)
293 | output = self.decoder0_header(torch.cat([z0, z3], dim=1))
294 | return output
295 |
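296 |
297 | if __name__ == "__main__":
298 |     # Quick usage sketch / smoke test: run a dummy volume through a smaller UNETR
299 |     # (64^3 instead of the default 128^3 to keep memory modest) and check that the
300 |     # output keeps the input resolution with one channel per segmentation class.
301 |     model = UNETR(img_shape=(64, 64, 64), input_dim=4, output_dim=3)
302 |     dummy = torch.randn(1, 4, 64, 64, 64)
303 |     with torch.no_grad():
304 |         out = model(dummy)
305 |     assert out.shape == (1, 3, 64, 64, 64), out.shape
306 |     print("UNETR forward pass OK:", tuple(out.shape))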
--------------------------------------------------------------------------------