├── .gitignore
├── LICENSE
├── README.md
└── src
    ├── README.md
    └── model.py

/.gitignore:
--------------------------------------------------------------------------------
1 | deploy.sh
2 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 HoritaDaichi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3D-UNet-PyTorch
2 | This repository is a PyTorch implementation of [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://arxiv.org/abs/1606.06650).
3 | 
4 | ## Environment
5 | + Python 3.6
6 | + PyTorch 1.0
7 | 
8 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation
9 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger
10 | Conditionally accepted for MICCAI 2016
11 | https://arxiv.org/abs/1606.06650
12 | 
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | ## src
2 | + model.py -> the 3D-UNet model definition.
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | # 3D-UNet model.
2 | # x: a 5D input of shape [batch, channel, depth, height, width]; the demo in __main__ uses a 128x128x128 volume.
3 | import torch
4 | import torch.nn as nn
5 | 
6 | 
7 | def conv_block_3d(in_dim, out_dim, activation):
8 |     return nn.Sequential(
9 |         nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
10 |         nn.BatchNorm3d(out_dim),
11 |         activation,)
12 | 
13 | 
14 | def conv_trans_block_3d(in_dim, out_dim, activation):
15 |     return nn.Sequential(
16 |         nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
17 |         nn.BatchNorm3d(out_dim),
18 |         activation,)
19 | 
20 | 
21 | def max_pooling_3d():
22 |     return nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
23 | 
24 | 
25 | def conv_block_2_3d(in_dim, out_dim, activation):
26 |     return nn.Sequential(
27 |         conv_block_3d(in_dim, out_dim, activation),
28 |         nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
29 |         nn.BatchNorm3d(out_dim),)
30 | 
31 | class UNet(nn.Module):
32 |     def __init__(self, in_dim, out_dim, num_filters):
33 |         super(UNet, self).__init__()
34 | 
35 |         self.in_dim = in_dim
36 |         self.out_dim = out_dim
37 |         self.num_filters = num_filters
38 |         activation = nn.LeakyReLU(0.2, inplace=True)
39 | 
40 |         # Down sampling
41 |         self.down_1 = conv_block_2_3d(self.in_dim, self.num_filters, activation)
42 |         self.pool_1 = max_pooling_3d()
43 |         self.down_2 = conv_block_2_3d(self.num_filters, self.num_filters * 2, activation)
44 |         self.pool_2 = max_pooling_3d()
45 |         self.down_3 = conv_block_2_3d(self.num_filters * 2, self.num_filters * 4, activation)
46 |         self.pool_3 = max_pooling_3d()
47 |         self.down_4 = conv_block_2_3d(self.num_filters * 4, self.num_filters * 8, activation)
48 |         self.pool_4 = max_pooling_3d()
49 |         self.down_5 = conv_block_2_3d(self.num_filters * 8, self.num_filters * 16, activation)
50 |         self.pool_5 = max_pooling_3d()
51 | 
52 |         # Bridge
53 |         self.bridge = conv_block_2_3d(self.num_filters * 16, self.num_filters * 32, activation)
54 | 
55 |         # Up sampling
56 |         self.trans_1 = conv_trans_block_3d(self.num_filters * 32, self.num_filters * 32, activation)
57 |         self.up_1 = conv_block_2_3d(self.num_filters * 48, self.num_filters * 16, activation)
58 |         self.trans_2 = conv_trans_block_3d(self.num_filters * 16, self.num_filters * 16, activation)
59 |         self.up_2 = conv_block_2_3d(self.num_filters * 24, self.num_filters * 8, activation)
60 |         self.trans_3 = conv_trans_block_3d(self.num_filters * 8, self.num_filters * 8, activation)
61 |         self.up_3 = conv_block_2_3d(self.num_filters * 12, self.num_filters * 4, activation)
62 |         self.trans_4 = conv_trans_block_3d(self.num_filters * 4, self.num_filters * 4, activation)
63 |         self.up_4 = conv_block_2_3d(self.num_filters * 6, self.num_filters * 2, activation)
64 |         self.trans_5 = conv_trans_block_3d(self.num_filters * 2, self.num_filters * 2, activation)
65 |         self.up_5 = conv_block_2_3d(self.num_filters * 3, self.num_filters * 1, activation)
66 | 
67 |         # Output
68 |         self.out = conv_block_3d(self.num_filters, out_dim, activation)
69 | 
70 |     def forward(self, x):
71 |         # Down sampling
72 |         down_1 = self.down_1(x)  # -> [1, 4, 128, 128, 128]
73 |         pool_1 = self.pool_1(down_1)  # -> [1, 4, 64, 64, 64]
74 | 
75 |         down_2 = self.down_2(pool_1)  # -> [1, 8, 64, 64, 64]
76 |         pool_2 = self.pool_2(down_2)  # -> [1, 8, 32, 32, 32]
77 | 
78 |         down_3 = self.down_3(pool_2)  # -> [1, 16, 32, 32, 32]
79 |         pool_3 = self.pool_3(down_3)  # -> [1, 16, 16, 16, 16]
80 | 
81 |         down_4 = self.down_4(pool_3)  # -> [1, 32, 16, 16, 16]
82 |         pool_4 = self.pool_4(down_4)  # -> [1, 32, 8, 8, 8]
83 | 
84 |         down_5 = self.down_5(pool_4)  # -> [1, 64, 8, 8, 8]
85 |         pool_5 = self.pool_5(down_5)  # -> [1, 64, 4, 4, 4]
86 | 
87 |         # Bridge
88 |         bridge = self.bridge(pool_5)  # -> [1, 128, 4, 4, 4]
89 | 
90 |         # Up sampling
91 |         trans_1 = self.trans_1(bridge)  # -> [1, 128, 8, 8, 8]
92 |         concat_1 = torch.cat([trans_1, down_5], dim=1)  # -> [1, 192, 8, 8, 8]
93 |         up_1 = self.up_1(concat_1)  # -> [1, 64, 8, 8, 8]
94 | 
95 |         trans_2 = self.trans_2(up_1)  # -> [1, 64, 16, 16, 16]
96 |         concat_2 = torch.cat([trans_2, down_4], dim=1)  # -> [1, 96, 16, 16, 16]
97 |         up_2 = self.up_2(concat_2)  # -> [1, 32, 16, 16, 16]
98 | 
99 |         trans_3 = self.trans_3(up_2)  # -> [1, 32, 32, 32, 32]
100 |         concat_3 = torch.cat([trans_3, down_3], dim=1)  # -> [1, 48, 32, 32, 32]
101 |         up_3 = self.up_3(concat_3)  # -> [1, 16, 32, 32, 32]
102 | 
103 |         trans_4 = self.trans_4(up_3)  # -> [1, 16, 64, 64, 64]
104 |         concat_4 = torch.cat([trans_4, down_2], dim=1)  # -> [1, 24, 64, 64, 64]
105 |         up_4 = self.up_4(concat_4)  # -> [1, 8, 64, 64, 64]
106 | 
107 |         trans_5 = self.trans_5(up_4)  # -> [1, 8, 128, 128, 128]
108 |         concat_5 = torch.cat([trans_5, down_1], dim=1)  # -> [1, 12, 128, 128, 128]
109 |         up_5 = self.up_5(concat_5)  # -> [1, 4, 128, 128, 128]
110 | 
111 |         # Output
112 |         out = self.out(up_5)  # -> [1, 3, 128, 128, 128]
113 |         return out
114 | 
115 | if __name__ == "__main__":
116 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
117 |     image_size = 128
118 |     x = torch.randn(1, 3, image_size, image_size, image_size)
119 |     x = x.to(device)  # Tensor.to() is not in-place, so reassign
120 |     print("x size: {}".format(x.size()))
121 | 
122 |     model = UNet(in_dim=3, out_dim=3, num_filters=4).to(device)  # keep model and input on the same device
123 | 
124 |     out = model(x)
125 |     print("out size: {}".format(out.size()))
--------------------------------------------------------------------------------
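
The repository ships only the model definition; there is no training or inference script. As a usage sketch, the hypothetical snippet below wires UNet into a single optimization step. Everything beyond the UNet class itself is an assumption made for illustration: the `from model import UNet` path (it presumes src/ is on the Python path), the 64x64x64 input size, the Adam optimizer and learning rate, and the choice of nn.CrossEntropyLoss over the three output channels.

import torch
import torch.nn as nn

from model import UNet  # assumes src/ is on the Python path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_dim=3, out_dim=3, num_filters=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer and lr are illustrative choices
criterion = nn.CrossEntropyLoss()  # voxel-wise classification over the 3 output channels

# One random batch: a 3-channel 64x64x64 volume and an integer label map with
# values in {0, 1, 2}. 64 is divisible by 2**5, so all five poolings fit.
volume = torch.randn(1, 3, 64, 64, 64, device=device)
labels = torch.randint(0, 3, (1, 64, 64, 64), device=device)

optimizer.zero_grad()
output = model(volume)  # -> [1, 3, 64, 64, 64]
# Note: the final conv_block_3d applies a LeakyReLU, so `output` is an activated
# map rather than raw logits; CrossEntropyLoss still applies log-softmax itself.
loss = criterion(output, labels)
loss.backward()
optimizer.step()
print("loss: {:.4f}".format(loss.item()))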