├── images
│   ├── Fig.2.png
│   ├── Fig.7.png
│   ├── Fig.8.png
│   └── Fig.13.png
├── README.md
├── DEN_Network.py
├── PDF-UNet.py
└── PMG_Network.py

/images/Fig.2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.2.png
--------------------------------------------------------------------------------
/images/Fig.7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.7.png
--------------------------------------------------------------------------------
/images/Fig.8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.8.png
--------------------------------------------------------------------------------
/images/Fig.13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.13.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This repository contains the original implementation of "**[PDF-UNet: A semi-supervised method for segmentation of breast tumor images using a U-shaped pyramid-dilated network](https://doi.org/10.1016/j.eswa.2023.119718)**" in PyTorch. This paper has been published in "*Expert Systems with Applications - Elsevier, IF: 8.665*".

## Proposed Architecture

![Proposed architecture](images/Fig.2.png)

## PDF-UNet

![PDF-UNet](images/Fig.7.png)

## Pyramid-dilated fusion block

![Pyramid-dilated fusion block](images/Fig.8.png)

## Visualization of segmentation results

![Visualization of segmentation results](images/Fig.13.png)

## Requirements

- Python 3.9.7
- PyTorch: 1.10.1
- OpenCV: 4.6.0
- NumPy: 1.22.3
- Matplotlib: 3.5.1

## Cite

If you use the PDF-UNet architecture in your project, please cite the following paper:
```
Iqbal, A., & Sharif, M. (2023). PDF-UNet: A semi-supervised method for segmentation of breast tumor images using a U-shaped pyramid-dilated network. Expert Systems with Applications, 119718. DOI: https://doi.org/10.1016/j.eswa.2023.119718
```
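## Example usage

A minimal, illustrative sketch for running a forward pass with the PDF-UNet model (this is not an official script from the paper; because `PDF-UNet.py` contains a hyphen, it is assumed here to be saved or renamed as `pdf_unet.py` so it can be imported):

```python
import torch
from pdf_unet import PDFUNet  # assumed module name; see note above

model = PDFUNet()                   # expects 1-channel input, predicts 1 class
model.eval()
x = torch.randn(1, 1, 128, 128)     # e.g., a grayscale breast-ultrasound patch
with torch.no_grad():
    logits = model(x)               # shape: (1, 1, 128, 128)
mask = (torch.sigmoid(logits) > 0.5).float()  # binary tumor mask
```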
--------------------------------------------------------------------------------
/DEN_Network.py:
--------------------------------------------------------------------------------
import os

import numpy as np
from glob import glob
from matplotlib import pyplot

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

IMG_H = 128
IMG_W = 128
IMG_C = 1

w_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_png(img, channels=3)  # force 3 channels so rgb_to_grayscale is always valid
    img = tf.image.rgb_to_grayscale(img)
    img = tf.image.resize_with_crop_or_pad(img, IMG_H, IMG_W)
    img = tf.cast(img, tf.float32)
    img = (img - 127.5) / 127.5  # scale to [-1, 1] to match the generator's tanh output
    return img

def tf_dataset(images_path, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices(images_path)
    dataset = dataset.shuffle(buffer_size=10240)
    dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

def conv_block(inputs, num_filters, kernel_size, padding="same", strides=2, activation=True):
    x = Conv2D(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=w_init,
        padding=padding,
        strides=strides,
    )(inputs)

    if activation:
        x = LeakyReLU(alpha=0.2)(x)
        x = Dropout(0.3)(x)
    return x

def deconv_block(inputs, num_filters, kernel_size, strides, bn=True):
    x = Conv2DTranspose(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=w_init,
        padding="same",
        strides=strides,
        use_bias=True
    )(inputs)

    if bn:
        x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    return x

def build_generator(latent_dim):
    f = [2**i for i in range(5)][::-1]  # [16, 8, 4, 2, 1] channel multipliers
    filters = 32
    output_strides = 16
    h_output = IMG_H // output_strides
    w_output = IMG_W // output_strides

    noise = Input(shape=(latent_dim,), name="generator_noise_input")

    x = Dense(f[0] * filters * h_output * w_output, use_bias=False)(noise)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((h_output, w_output, f[0] * filters))(x)

    for i in range(1, 5):
        x = deconv_block(x,
            num_filters=f[i] * filters,
            kernel_size=5,
            strides=2,
            bn=True
        )

    x = conv_block(x,
        num_filters=1,
        kernel_size=5,
        strides=1,
        activation=False
    )
    fake_output = Activation("tanh")(x)

    return Model(noise, fake_output, name="generator")

def build_discriminator():
    f = [2**i for i in range(4)]  # [1, 2, 4, 8] channel multipliers
    filters = 64
    image_input = Input(shape=(IMG_H, IMG_W, IMG_C))
    x = image_input

    for i in range(0, 4):
        x = conv_block(x, num_filters=f[i] * filters, kernel_size=3, strides=2)

    x = Flatten()(x)
    x = Dense(1)(x)  # raw logit; the loss is configured with from_logits=True

    return Model(image_input, x, name="discriminator")
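# Illustrative shape note (added; not in the original source): with the
# defaults above, build_generator(128) maps (None, 128) noise vectors to
# (None, 128, 128, 1) images in [-1, 1], and build_discriminator() maps
# (None, 128, 128, 1) images to (None, 1) real/fake logits.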
class DEN(Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(DEN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(DEN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]

        for _ in range(2):
            ## Train the discriminator on generated (fake) images
            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
            generated_images = self.generator(random_latent_vectors)
            generated_labels = tf.zeros((batch_size, 1))

            with tf.GradientTape() as ftape:
                predictions = self.discriminator(generated_images)
                d1_loss = self.loss_fn(generated_labels, predictions)
            grads = ftape.gradient(d1_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

            ## Train the discriminator on real images
            labels = tf.ones((batch_size, 1))

            with tf.GradientTape() as rtape:
                predictions = self.discriminator(real_images)
                d2_loss = self.loss_fn(labels, predictions)
            grads = rtape.gradient(d2_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        ## Train the generator to fool the discriminator
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = tf.ones((batch_size, 1))

        with tf.GradientTape() as gtape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = gtape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {"d1_loss": d1_loss, "d2_loss": d2_loss, "g_loss": g_loss}

def save_plot(examples, epoch, n):
    examples = (examples + 1) / 2.0  # rescale from [-1, 1] to [0, 1] for display
    for i in range(n * n):
        pyplot.subplot(n, n, i + 1)
        pyplot.axis("off")
        pyplot.imshow(np.squeeze(examples[i], axis=-1), cmap='gray')
    filename = f"samples/Experiment_01/Epochs/generated_plot_epoch-{epoch+1}.png"
    pyplot.savefig(filename)
    pyplot.close()

if __name__ == "__main__":
    ## Hyperparameters
    batch_size = 16
    latent_dim = 128
    num_epochs = 300
    images_path = glob("Datasets/DEN_BUS_Dataset/Aug_All/*")

    ## Make sure the output directories exist before training starts
    os.makedirs("samples/Experiment_01/Epochs", exist_ok=True)
    os.makedirs("saved_model/Experiments", exist_ok=True)

    d_model = build_discriminator()
    g_model = build_generator(latent_dim)

    ## d_model.load_weights("saved_model/d_model.h5")
    ## g_model.load_weights("saved_model/g_model.h5")

    d_model.summary()
    g_model.summary()

    den = DEN(d_model, g_model, latent_dim)

    bce_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    den.compile(d_optimizer, g_optimizer, bce_loss_fn)

    images_dataset = tf_dataset(images_path, batch_size)

    for epoch in range(num_epochs):
        den.fit(images_dataset, epochs=1)
        g_model.save(f"saved_model/Experiments/g_model_{epoch}.h5")
        d_model.save(f"saved_model/Experiments/d_model_{epoch}.h5")

        n_samples = 25
        noise = np.random.normal(size=(n_samples, latent_dim))
        examples = g_model.predict(noise)
        save_plot(examples, epoch, int(np.sqrt(n_samples)))
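# Hedged example (added; not part of the original script): to sample from a
# trained generator later, one could do something like the following, where
# the checkpoint path is hypothetical:
#
#   g_model = tf.keras.models.load_model("saved_model/Experiments/g_model_299.h5")
#   noise = np.random.normal(size=(25, latent_dim))
#   samples = g_model.predict(noise)   # values in [-1, 1]
#   samples = (samples + 1) / 2.0      # rescale to [0, 1] for viewing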
--------------------------------------------------------------------------------
/PDF-UNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super(ConvLayer, self).__init__()
        padding = int((kernel_size - 1) / 2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv(x)

class SEBlock(nn.Module):
    def __init__(self, in_channels, r):
        super(SEBlock, self).__init__()

        redu_chns = int(in_channels / r)
        self.se_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0),
            nn.ReLU())

    def forward(self, x):
        f = self.se_layers(x)
        return f * x + x

class PDFBlock(nn.Module):
    def __init__(self, in_channels, out_channels_list, kernel_size_list, dilation_list):
        super(PDFBlock, self).__init__()
        self.conv_num = len(out_channels_list)
        assert(self.conv_num == 4)
        assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list))
        pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0])
        pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1])
        pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2])
        pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3])
        self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size=kernel_size_list[0], dilation=dilation_list[0], padding=pad0)
        self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size=kernel_size_list[1], dilation=dilation_list[1], padding=pad1)
        self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size=kernel_size_list[2], dilation=dilation_list[2], padding=pad2)
        self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size=kernel_size_list[3], dilation=dilation_list[3], padding=pad3)

        out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3]
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU())

    def forward(self, x):
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        y = torch.cat([x1, x2, x3, x4], dim=1)
        y = self.conv_1x1(y)
        return y
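# Illustrative note (added; not in the original source): each of the four
# parallel dilated convolutions keeps the spatial size unchanged because its
# padding is (kernel_size - 1) / 2 * dilation, so PDFBlock maps
# (N, in_channels, H, W) to (N, sum(out_channels_list), H, W) after fusion.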
class ConBNActBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConBNActBlock, self).__init__()
        self.conv_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            SEBlock(out_channels, 2)
        )

    def forward(self, x):
        return self.conv_conv(x)

class UpBlock(nn.Module):
    def __init__(self, in_channels1, in_channels2, out_channels,
                 bilinear=True, dropout_p=0.5):
        super(UpBlock, self).__init__()
        self.bilinear = bilinear
        if bilinear:
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConBNActBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        if self.bilinear:
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x_cat = torch.cat([x2, x1], dim=1)
        y = self.conv(x_cat)
        return y + x_cat  # residual fusion; requires out_channels == in_channels2 * 2

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.avgpool = nn.AvgPool2d(2)
        self.conv = ConBNActBlock(2 * in_channels, out_channels, dropout_p)

    def forward(self, x):
        x_max = self.maxpool(x)
        x_avg = self.avgpool(x)
        x_cat = torch.cat([x_max, x_avg], dim=1)
        y = self.conv(x_cat)
        return y + x_cat  # residual fusion; requires out_channels == 2 * in_channels

class PDFUNet(nn.Module):
    def __init__(self):
        super(PDFUNet, self).__init__()
        self.in_chns = 1
        self.f_chan = [32, 64, 128, 256, 512]
        self.n_class = 1
        self.bilinear = True
        self.dropout = [0.0, 0.0, 0.3, 0.4, 0.5]
        assert(len(self.f_chan) == 5)

        f0_half = int(self.f_chan[0] / 2)
        f1_half = int(self.f_chan[1] / 2)
        f2_half = int(self.f_chan[2] / 2)
        f3_half = int(self.f_chan[3] / 2)
        self.in_conv = ConBNActBlock(self.in_chns, self.f_chan[0], self.dropout[0])
        self.down1 = DownBlock(self.f_chan[0], self.f_chan[1], self.dropout[1])
        self.down2 = DownBlock(self.f_chan[1], self.f_chan[2], self.dropout[2])
        self.down3 = DownBlock(self.f_chan[2], self.f_chan[3], self.dropout[3])
        self.down4 = DownBlock(self.f_chan[3], self.f_chan[4], self.dropout[4])

        self.bridge0 = ConvLayer(self.f_chan[0], f0_half)
        self.bridge1 = ConvLayer(self.f_chan[1], f1_half)
        self.bridge2 = ConvLayer(self.f_chan[2], f2_half)
        self.bridge3 = ConvLayer(self.f_chan[3], f3_half)

        self.up1 = UpBlock(self.f_chan[4], f3_half, self.f_chan[3], dropout_p=self.dropout[3])
        self.up2 = UpBlock(self.f_chan[3], f2_half, self.f_chan[2], dropout_p=self.dropout[2])
        self.up3 = UpBlock(self.f_chan[2], f1_half, self.f_chan[1], dropout_p=self.dropout[1])
        self.up4 = UpBlock(self.f_chan[1], f0_half, self.f_chan[0], dropout_p=self.dropout[0])

        f4 = self.f_chan[4]
        aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)]
        aspp_knls = [1, 3, 3, 3]
        aspp_dila = [1, 2, 4, 6]
        self.aspp = PDFBlock(f4, aspp_chns, aspp_knls, aspp_dila)

        self.out_conv = nn.Conv2d(self.f_chan[0], self.n_class,
                                  kernel_size=3, padding=1)
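    # Illustrative summary (added comment; not in the original source): the
    # encoder halves the spatial size four times (128 -> 8 for a 128x128
    # input), the PDF block replaces a plain bottleneck, and each bridge
    # halves a skip connection's channels before it is fused in the decoder.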
    def forward(self, x):
        x_shape = list(x.shape)
        if(len(x_shape) == 5):
            # Fold a 5D volume (N, C, D, H, W) into a batch of 2D slices
            [N, C, D, H, W] = x_shape
            new_shape = [N * D, C, H, W]
            x = torch.transpose(x, 1, 2)
            x = torch.reshape(x, new_shape)
        x0 = self.in_conv(x)
        x0b = self.bridge0(x0)
        x1 = self.down1(x0)
        x1b = self.bridge1(x1)
        x2 = self.down2(x1)
        x2b = self.bridge2(x2)
        x3 = self.down3(x2)
        x3b = self.bridge3(x3)
        x4 = self.down4(x3)
        x4 = self.aspp(x4)

        x = self.up1(x4, x3b)
        x = self.up2(x, x2b)
        x = self.up3(x, x1b)
        x = self.up4(x, x0b)
        output = self.out_conv(x)

        if(len(x_shape) == 5):
            # Unfold the slice batch back into a 5D volume
            new_shape = [N, D] + list(output.shape)[1:]
            output = torch.reshape(output, new_shape)
            output = torch.transpose(output, 1, 2)
        return output
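# Minimal smoke test (added; an illustrative sketch, not part of the original
# repository): run a single grayscale 128x128 tensor through the network.
if __name__ == "__main__":
    net = PDFUNet()
    net.eval()
    dummy = torch.randn(1, 1, 128, 128)
    with torch.no_grad():
        out = net(dummy)
    print(out.shape)  # expected: torch.Size([1, 1, 128, 128])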
--------------------------------------------------------------------------------
/PMG_Network.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, rate=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=rate, padding=rate, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.rate = rate

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
        self.inplanes = 64
        super(ResNet, self).__init__()
        if os == 16:
            strides = [1, 2, 2, 1]
            rates = [1, 1, 1, 2]
            blocks = [1, 2, 4]
        elif os == 8:
            strides = [1, 2, 1, 1]
            rates = [1, 1, 2, 2]
            blocks = [1, 2, 1]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2])
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], rate=rates[3])

        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate=blocks[0] * rate, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, rate=blocks[i] * rate))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        low_level_feat = x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

def ResNet101(nInputChannels=3, os=16, pretrained=False):
    model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained)
    return model
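# Illustrative note (added; not in the original source): with os=16 the
# multi-grid unit above gives layer4 dilation rates of blocks * rate =
# [1, 2, 4] * 2 = [2, 4, 8] at stride 1, enlarging the receptive field
# without any further downsampling.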
class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class PMG_Network(nn.Module):
    def __init__(self, nInputChannels=3, n_classes=1, os=16, pretrained=False, _print=True):
        if _print:
            print("Constructing PMG model...")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(PMG_Network, self).__init__()

        # Atrous Conv backbone
        self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)

        # ASPP
        if os == 16:
            rates = [1, 6, 12, 18]
        elif os == 8:
            rates = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = ASPP_module(2048, 256, rate=rates[0])
        self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
        self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
        self.aspp4 = ASPP_module(2048, 256, rate=rates[3])

        self.relu = nn.ReLU()

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, 1, stride=1, bias=False),
                                             nn.BatchNorm2d(256),
                                             nn.ReLU())

        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)

        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, n_classes, kernel_size=1, stride=1))

    def forward(self, input):
        x, low_level_features = self.resnet_features(input)
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = F.interpolate(x, size=(int(math.ceil(input.size()[-2] / 4)),
                                   int(math.ceil(input.size()[-1] / 4))), mode='bilinear', align_corners=True)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
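# Illustrative data-flow note (added; not in the original source): for an
# input (N, 3, H, W), the backbone yields (N, 2048, H/16, W/16) plus
# low-level features (N, 256, H/4, W/4); the four ASPP branches and the
# pooled branch concatenate to 1280 channels, are reduced to 256, upsampled
# to H/4, fused with the 48-channel low-level projection (304 total), and
# finally refined and upsampled back to (N, n_classes, H, W).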
def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for
    the last classification layer. Note that for each batchnorm layer,
    requires_grad is set to False in PMG_resnet.py, therefore this function
    does not return any batchnorm parameter.
    """
    b = [model.resnet_features]
    for i in range(len(b)):
        for k in b[i].parameters():
            if k.requires_grad:
                yield k

def get_10x_lr_params(model):
    """
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixels into classes.
    """
    b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
    for j in range(len(b)):
        for k in b[j].parameters():
            if k.requires_grad:
                yield k
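# Hedged usage sketch (added; not from the original repository): differential
# learning rates for the backbone vs. the segmentation head, using the two
# parameter generators above. The learning-rate values are hypothetical.
if __name__ == "__main__":
    model = PMG_Network(nInputChannels=3, n_classes=1, os=16, pretrained=False, _print=False)
    base_lr = 1e-3
    optimizer = torch.optim.SGD([
        {"params": get_1x_lr_params(model), "lr": base_lr},
        {"params": get_10x_lr_params(model), "lr": 10 * base_lr},
    ], momentum=0.9)

    # Quick forward check; eval mode so batchnorm works with a batch of one.
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)
        print(model(dummy).shape)  # expected: torch.Size([1, 1, 256, 256])
--------------------------------------------------------------------------------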