This repo was created because I found that most repos on GitHub trying to reproduce V-Net in PyTorch have too many differences from the original paper *V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation*, which may cause potential misunderstandings.
"""model/VNet.py -- a PyTorch reproduction of V-Net.

Reference: Milletari et al., "V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation" (2016).

Differences from the paper (deliberate, see README): InstanceNorm3d is used
between Conv3d and PReLU, and the last activation is a Sigmoid rather than a
softmax.
"""
import torch.nn as nn
import torch.nn.functional as F  # kept for downstream users of this module
import torch


class conv3d(nn.Module):
    """Basic V-Net convolution unit: Conv3d(5x5x5, pad 2) -> InstanceNorm3d -> PReLU.

    The 5x5x5 kernel with padding 2 preserves the spatial size, so stacking
    these units never changes depth/height/width.
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.relu = nn.PReLU()
        # affine=True gives the normalization learnable scale/shift parameters.
        self.norm = nn.InstanceNorm3d(out_channels, affine=True)

    def forward(self, x):
        # Order is conv -> norm -> activation.
        return self.relu(self.norm(self.conv(x)))


class conv3d_x3(nn.Module):
    """Three serial convs with a residual connection.

    Structure:
        inputs --> ① --> ② --> ③ --> outputs
            ↓ ------- add ------- ↑

    The 1x1x1 skip conv projects the input to ``out_channels`` so the
    element-wise addition is well defined when channel counts differ.
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d_x3, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.conv_3 = conv3d(out_channels, out_channels)
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        z_1 = self.conv_1(x)
        z_3 = self.conv_3(self.conv_2(z_1))
        return z_3 + self.skip_connection(x)


class conv3d_x2(nn.Module):
    """Two serial convs with a residual connection.

    Structure:
        inputs --> ① --> ② --> outputs
            ↓ ---- add ---- ↑
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d_x2, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        z_1 = self.conv_1(x)
        z_2 = self.conv_2(z_1)
        return z_2 + self.skip_connection(x)


class conv3d_x1(nn.Module):
    """A single conv with a residual connection.

    Structure:
        inputs --> ① --> outputs
            ↓ - add - ↑
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d_x1, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        z_1 = self.conv_1(x)
        return z_1 + self.skip_connection(x)


class deconv3d_x3(nn.Module):
    """Right-side (decoder) stage with three convs after the skip concatenation.

    ``rhs`` (coarser features) is upsampled 2x, ``lhs`` (encoder skip
    features) is projected to ``out_channels``, the two are concatenated on
    the channel axis, refined by three 5x5x5 convs, and the upsampled path is
    added back as a residual.
    """

    def __init__(self, in_channels, out_channels):
        super(deconv3d_x3, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        # Encoder skip features carry out_channels // 2 channels at this scale.
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        self.conv_x3 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, 5, 1, 2),
            nn.PReLU(),
            nn.Conv3d(out_channels, out_channels, 5, 1, 2),
            nn.PReLU(),
            nn.Conv3d(out_channels, out_channels, 5, 1, 2),
            nn.PReLU(),
        )

    def forward(self, lhs, rhs):
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        # Concatenate (paper-style skip), then residual-add the upsampled path.
        rhs_add = torch.cat((rhs_up, lhs_conv), dim=1)
        return self.conv_x3(rhs_add) + rhs_up


class deconv3d_x2(nn.Module):
    """Decoder stage with two convs after the skip concatenation (see deconv3d_x3)."""

    def __init__(self, in_channels, out_channels):
        super(deconv3d_x2, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        self.conv_x2 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, 5, 1, 2),
            nn.PReLU(),
            nn.Conv3d(out_channels, out_channels, 5, 1, 2),
            nn.PReLU(),
        )

    def forward(self, lhs, rhs):
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        rhs_add = torch.cat((rhs_up, lhs_conv), dim=1)
        return self.conv_x2(rhs_add) + rhs_up


class deconv3d_x1(nn.Module):
    """Decoder stage with one conv after the skip concatenation (see deconv3d_x3)."""

    def __init__(self, in_channels, out_channels):
        super(deconv3d_x1, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        self.conv_x1 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, 5, 1, 2),
            nn.PReLU(),
        )

    def forward(self, lhs, rhs):
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        rhs_add = torch.cat((rhs_up, lhs_conv), dim=1)
        return self.conv_x1(rhs_add) + rhs_up


def conv3d_as_pool(in_channels, out_channels, kernel_size=2, stride=2):
    """Strided 2x2x2 conv used instead of max-pooling (halves each spatial dim)."""
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0),
        nn.PReLU())


def deconv3d_as_up(in_channels, out_channels, kernel_size=2, stride=2):
    """Transposed 2x2x2 conv used for upsampling (doubles each spatial dim)."""
    return nn.Sequential(
        nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride),
        nn.PReLU()
    )


class softmax_out(nn.Module):
    """Output head: 5x5x5 conv -> 1x1x1 conv -> sigmoid.

    NOTE(review): despite the class name, the final activation is a Sigmoid
    (binary segmentation), as stated in the README. The name is kept for
    backward compatibility.
    """

    def __init__(self, in_channels, out_channels):
        super(softmax_out, self).__init__()
        self.conv_1 = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv_2 = nn.Conv3d(out_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Output with shape [batch_size, out_channels, depth, height, width]."""
        # Do NOT add a normalization layer here, or the output values vanish.
        y_conv = self.conv_2(self.conv_1(x))
        # Functional sigmoid: identical to nn.Sigmoid()(y_conv) but avoids
        # constructing a fresh Module on every forward call.
        return torch.sigmoid(y_conv)


class VNet(nn.Module):
    """Full V-Net: 4-level encoder, bottom stage, 4-level decoder with skips.

    Args:
        in_channels: channels of the input volume (default 1, as in the paper).
        num_classes: output channels / foreground maps (default 1, sigmoid head).

    Input spatial dims must be divisible by 16 (four stride-2 poolings).
    """

    def __init__(self, in_channels=1, num_classes=1):
        super(VNet, self).__init__()
        # Encoder: channels double each time the resolution halves.
        self.conv_1 = conv3d_x1(in_channels, 16)
        self.pool_1 = conv3d_as_pool(16, 32)
        self.conv_2 = conv3d_x2(32, 32)
        self.pool_2 = conv3d_as_pool(32, 64)
        self.conv_3 = conv3d_x3(64, 64)
        self.pool_3 = conv3d_as_pool(64, 128)
        self.conv_4 = conv3d_x3(128, 128)
        self.pool_4 = conv3d_as_pool(128, 256)

        self.bottom = conv3d_x3(256, 256)

        # Decoder: mirror of the encoder, consuming the matching skip features.
        self.deconv_4 = deconv3d_x3(256, 256)
        self.deconv_3 = deconv3d_x3(256, 128)
        self.deconv_2 = deconv3d_x2(128, 64)
        self.deconv_1 = deconv3d_x1(64, 32)

        self.out = softmax_out(32, num_classes)

    def forward(self, x):
        conv_1 = self.conv_1(x)
        pool = self.pool_1(conv_1)
        conv_2 = self.conv_2(pool)
        pool = self.pool_2(conv_2)
        conv_3 = self.conv_3(pool)
        pool = self.pool_3(conv_3)
        conv_4 = self.conv_4(pool)
        pool = self.pool_4(conv_4)
        bottom = self.bottom(pool)
        deconv = self.deconv_4(conv_4, bottom)
        deconv = self.deconv_3(conv_3, deconv)
        deconv = self.deconv_2(conv_2, deconv)
        deconv = self.deconv_1(conv_1, deconv)
        return self.out(deconv)