├── Arche.JPG
├── License
├── README.md
├── requirements.txt
└── unetr.py

--------------------------------------------------------------------------------
/Arche.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamasino52/UNETR/eeb5277a95b0c28d35bfeb24fa0eb6d2d43b16ec/Arche.JPG

--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Minseok_Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UNETR: Transformers for 3D Medical Image Segmentation (WACV 2022)
[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Ftamasino52%2FUNETR&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com)

Unofficial PyTorch implementation of:
> [**UNETR: Transformers for 3D Medical Image Segmentation**](https://arxiv.org/abs/2103.10504),
> Ali Hatamizadeh, Yucheng Tang, Vishwesh Nath, Dong Yang, Andriy Myronenko, Bennett Landman, Holger Roth, Daguang Xu. WACV 2022.
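
## Usage

A minimal usage sketch (added here for illustration; it is not part of the original repository). It assumes the default configuration in `unetr.py`: 4-channel 128×128×128 input volumes, 3 output classes, and 16³ patches. Install the pinned dependency with `pip install -r requirements.txt` first.

```python
import torch

from unetr import UNETR

# Defaults in unetr.py: 128^3 input volume, 4 input channels (modalities),
# 3 output classes, 16^3 patches, 12 transformer layers with 768-dim embeddings.
# img_shape can be reduced (any multiple of patch_size) for a quicker smoke test.
model = UNETR(img_shape=(128, 128, 128), input_dim=4, output_dim=3,
              embed_dim=768, patch_size=16, num_heads=12, dropout=0.1)
model.eval()

x = torch.randn(1, 4, 128, 128, 128)  # (batch, channels, depth, height, width)
with torch.no_grad():
    logits = model(x)                 # (1, 3, 128, 128, 128) per-voxel class logits
print(logits.shape)
```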

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.4.0

--------------------------------------------------------------------------------
/unetr.py:
--------------------------------------------------------------------------------
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SingleDeconv3DBlock(nn.Module):
    """2x spatial upsampling via a single transposed 3D convolution."""
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv3DBlock(nn.Module):
    """3D convolution with padding that preserves spatial size (for odd kernel sizes)."""
    def __init__(self, in_planes, out_planes, kernel_size):
        super().__init__()
        self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                               padding=((kernel_size - 1) // 2))

    def forward(self, x):
        return self.block(x)


class Conv3DBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU."""
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv3DBlock(in_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv3DBlock(nn.Module):
    """2x transposed-conv upsampling followed by Conv3d -> BatchNorm3d -> ReLU."""
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleDeconv3DBlock(in_planes, out_planes),
            SingleConv3DBlock(out_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class SelfAttention(nn.Module):
    """Multi-head self-attention over the patch-token sequence."""
    def __init__(self, num_heads, embed_dim, dropout):
        super().__init__()
        self.num_attention_heads = num_heads
        self.attention_head_size = int(embed_dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(embed_dim, self.all_head_size)
        self.key = nn.Linear(embed_dim, self.all_head_size)
        self.value = nn.Linear(embed_dim, self.all_head_size)

        self.out = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

        self.softmax = nn.Softmax(dim=-1)

        self.vis = False  # set True to also return the attention maps for visualization

    def transpose_for_scores(self, x):
        # (batch, tokens, all_head_size) -> (batch, heads, tokens, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot-product attention
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    """Simple feed-forward block (not used by the UNETR model below)."""
    def __init__(self, in_features, act_layer=nn.GELU, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, in_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        return x


class PositionwiseFeedForward(nn.Module):
    """Transformer feed-forward block: Linear -> ReLU -> Dropout -> Linear."""
    def __init__(self, d_model=768, d_ff=2048, dropout=0.1):
        super().__init__()
        # Torch Linear layers include a bias term by default.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class Embeddings(nn.Module):
    """Patch embedding: non-overlapping 3D patches projected to embed_dim, plus learned position embeddings."""
    def __init__(self, input_dim, embed_dim, cube_size, patch_size, dropout):
        super().__init__()
        self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim,
                                          kernel_size=patch_size, stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.patch_embeddings(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (batch, n_patches, embed_dim)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


class TransformerBlock(nn.Module):
    """Pre-norm encoder block: LayerNorm -> self-attention -> residual, LayerNorm -> MLP -> residual."""
    def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size):
        super().__init__()
        self.attention_norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.mlp_norm = nn.LayerNorm(embed_dim, eps=1e-6)
        # Note: mlp_dim is computed but not used; the feed-forward width is fixed to 2048 below.
        self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
        self.mlp = PositionwiseFeedForward(embed_dim, 2048)
        self.attn = SelfAttention(num_heads, embed_dim, dropout)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h
        h = x

        x = self.mlp_norm(x)
        x = self.mlp(x)

        x = x + h
        return x, weights


class Transformer(nn.Module):
    """Stack of transformer blocks; returns the hidden states of the layers listed in extract_layers."""
    def __init__(self, input_dim, embed_dim, cube_size, patch_size, num_heads, num_layers, dropout, extract_layers):
        super().__init__()
        self.embeddings = Embeddings(input_dim, embed_dim, cube_size, patch_size, dropout)
        self.layer = nn.ModuleList()
        self.encoder_norm = nn.LayerNorm(embed_dim, eps=1e-6)  # note: defined but not applied in forward()
        self.extract_layers = extract_layers
        for _ in range(num_layers):
            layer = TransformerBlock(embed_dim, num_heads, dropout, cube_size, patch_size)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, x):
        extract_layers = []
        hidden_states = self.embeddings(x)

        for depth, layer_block in enumerate(self.layer):
            hidden_states, _ = layer_block(hidden_states)
            if depth + 1 in self.extract_layers:
                extract_layers.append(hidden_states)

        return extract_layers


class UNETR(nn.Module):
    def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3, embed_dim=768, patch_size=16, num_heads=12, dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = 12
        self.ext_layers = [3, 6, 9, 12]

        self.patch_dim = [int(x / patch_size) for x in img_shape]

        # Transformer Encoder
        self.transformer = \
            Transformer(
                input_dim,
                embed_dim,
                img_shape,
                patch_size,
                num_heads,
                self.num_layers,
                dropout,
                self.ext_layers
            )

        # U-Net Decoder
        self.decoder0 = \
            nn.Sequential(
                Conv3DBlock(input_dim, 32, 3),
                Conv3DBlock(32, 64, 3)
            )

        self.decoder3 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, 512),
                Deconv3DBlock(512, 256),
                Deconv3DBlock(256, 128)
            )

        self.decoder6 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, 512),
                Deconv3DBlock(512, 256),
            )

        self.decoder9 = \
            Deconv3DBlock(embed_dim, 512)

        self.decoder12_upsampler = \
            SingleDeconv3DBlock(embed_dim, 512)

        self.decoder9_upsampler = \
            nn.Sequential(
                Conv3DBlock(1024, 512),
                Conv3DBlock(512, 512),
                Conv3DBlock(512, 512),
                SingleDeconv3DBlock(512, 256)
            )

        self.decoder6_upsampler = \
            nn.Sequential(
                Conv3DBlock(512, 256),
                Conv3DBlock(256, 256),
                SingleDeconv3DBlock(256, 128)
            )

        self.decoder3_upsampler = \
            nn.Sequential(
                Conv3DBlock(256, 128),
                Conv3DBlock(128, 128),
                SingleDeconv3DBlock(128, 64)
            )

        self.decoder0_header = \
            nn.Sequential(
                Conv3DBlock(128, 64),
                Conv3DBlock(64, 64),
                SingleConv3DBlock(64, output_dim, 1)
            )

    def forward(self, x):
        # z3, z6, z9, z12 are the hidden states taken after transformer layers 3, 6, 9 and 12;
        # z0 is the raw input volume used for the highest-resolution skip connection.
        z = self.transformer(x)
        z0, z3, z6, z9, z12 = x, *z
        # Reshape token sequences back into 3D feature volumes: (batch, embed_dim, *patch_dim)
        z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
        z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)

        # Progressive upsampling with skip connections, as in the UNETR decoder
        z12 = self.decoder12_upsampler(z12)
        z9 = self.decoder9(z9)
        z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
        z6 = self.decoder6(z6)
        z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
        z3 = self.decoder3(z3)
        z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
        z0 = self.decoder0(z0)
        output = self.decoder0_header(torch.cat([z0, z3], dim=1))
        return output

--------------------------------------------------------------------------------