├── images
│   ├── Fig.2.png
│   ├── Fig.7.png
│   ├── Fig.8.png
│   └── Fig.13.png
├── README.md
├── DEN_Network.py
├── PDF-UNet.py
└── PMG_Network.py
/images/Fig.2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.2.png
--------------------------------------------------------------------------------
/images/Fig.7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.7.png
--------------------------------------------------------------------------------
/images/Fig.8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.8.png
--------------------------------------------------------------------------------
/images/Fig.13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahmedeqbal/PDF-UNet/HEAD/images/Fig.13.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains the original implementation of "**[PDF-UNet: A semi-supervised method for segmentation of breast tumor images using a U-shaped pyramid-dilated network](https://doi.org/10.1016/j.eswa.2023.119718)**" in PyTorch. The paper was published in *Expert Systems with Applications* (Elsevier, IF: 8.665).
2 |
3 |
4 | ## Proposed Architecture
5 | ![Proposed architecture](images/Fig.2.png)
6 |
7 | ## PDF-UNet
8 | ![PDF-UNet](images/Fig.7.png)
9 |
10 | ## Pyramid-dilated fusion block
11 | ![Pyramid-dilated fusion block](images/Fig.8.png)
12 |
13 | ## Visualization of segmentation results
14 | ![Visualization of segmentation results](images/Fig.13.png)
15 |
16 | ## Requirements
17 |
18 | - Python 3.9.7
19 | - PyTorch: 1.10.1
20 | - OpenCV: 4.6.0
21 | - Numpy: 1.22.3
22 | - Matplotlib: 3.5.1
23 |
24 | ## Cite:
25 |
26 | If you use the PDF-UNet architecture in your project, please cite the following paper:
27 | ```
28 | Iqbal, A., & Sharif, M. (2023). PDF-UNet: A semi-supervised method for segmentation of breast tumor images using a U-shaped pyramid-dilated network. Expert Systems with Applications, 119718. DOI: https://doi.org/10.1016/j.eswa.2023.119718
29 | ```
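30 |
31 | ## Usage
32 |
33 | A minimal usage sketch (an assumption, not taken from the paper; `PDF-UNet.py` has a hyphen in its filename, so it is loaded here via `importlib` rather than a plain import):
34 |
35 | ```python
36 | import importlib.util
37 | import torch
38 |
39 | spec = importlib.util.spec_from_file_location("pdf_unet", "PDF-UNet.py")
40 | pdf_unet = importlib.util.module_from_spec(spec)
41 | spec.loader.exec_module(pdf_unet)
42 |
43 | model = pdf_unet.PDFUNet()
44 | mask_logits = model(torch.randn(1, 1, 128, 128))  # (batch, channel, H, W)
45 | print(mask_logits.shape)  # torch.Size([1, 1, 128, 128])
46 | ```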
--------------------------------------------------------------------------------
/DEN_Network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from glob import glob
3 | from matplotlib import pyplot
4 | from sklearn.utils import shuffle
5 |
6 | import tensorflow as tf
7 | from tensorflow.keras.layers import *
8 | from tensorflow.keras.models import Model
9 | from tensorflow.keras.optimizers import Adam
10 |
11 | IMG_H = 128
12 | IMG_W = 128
13 | IMG_C = 1
14 |
15 | w_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
16 |
17 | def load_image(image_path):
18 | img = tf.io.read_file(image_path)
19 |     img = tf.io.decode_png(img, channels=3)  # decode as 3-channel so rgb_to_grayscale below also handles grayscale PNGs
20 | img = tf.image.rgb_to_grayscale(img)
21 | img = tf.image.resize_with_crop_or_pad(img, IMG_H, IMG_W)
22 | img = tf.cast(img, tf.float32)
23 | img = (img - 127.5) / 127.5
24 | return img
25 |
26 | def tf_dataset(images_path, batch_size):
27 | dataset = tf.data.Dataset.from_tensor_slices(images_path)
28 | dataset = dataset.shuffle(buffer_size=10240)
29 | dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
30 | dataset = dataset.batch(batch_size)
31 | dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
32 | return dataset
33 |
34 | def conv_block(inputs, num_filters, kernel_size, padding="same", strides=2, activation=True):
35 | x = Conv2D(
36 | filters=num_filters,
37 | kernel_size=kernel_size,
38 | kernel_initializer=w_init,
39 | padding=padding,
40 | strides=strides,
41 | )(inputs)
42 |
43 | if activation:
44 | x = LeakyReLU(alpha=0.2)(x)
45 | x = Dropout(0.3)(x)
46 | return x
47 |
48 | def deconv_block(inputs, num_filters, kernel_size, strides, bn=True):
49 | x = Conv2DTranspose(
50 | filters=num_filters,
51 | kernel_size=kernel_size,
52 | kernel_initializer=w_init,
53 | padding="same",
54 | strides=strides,
55 | use_bias=True
56 | )(inputs)
57 |
58 | if bn:
59 | x = BatchNormalization()(x)
60 | x = LeakyReLU(alpha=0.2)(x)
61 | return x
62 |
63 | def build_generator(latent_dim):
64 | f = [2**i for i in range(5)][::-1]
65 | filters = 32
66 | output_strides = 16
67 | h_output = IMG_H // output_strides
68 | w_output = IMG_W // output_strides
69 |
70 | noise = Input(shape=(latent_dim,), name="generator_noise_input")
71 |
72 | x = Dense(f[0] * filters * h_output * w_output, use_bias=False)(noise)
73 | x = BatchNormalization()(x)
74 | x = LeakyReLU(alpha=0.2)(x)
75 |     x = Reshape((h_output, w_output, f[0] * filters))(x)  # f[0] * filters == 16 * filters
76 |
77 | for i in range(1, 5):
78 | x = deconv_block(x,
79 | num_filters=f[i] * filters,
80 | kernel_size=5,
81 | strides=2,
82 | bn=True
83 | )
84 |
85 | x = conv_block(x,
86 | num_filters=1,
87 | kernel_size=5,
88 | strides=1,
89 | activation=False
90 | )
91 | fake_output = Activation("tanh")(x)
92 |
93 | return Model(noise, fake_output, name="generator")
94 |
95 | def build_discriminator():
96 | f = [2**i for i in range(4)]
97 | image_input = Input(shape=(IMG_H, IMG_W, IMG_C))
98 | x = image_input
99 | filters = 64
100 | output_strides = 16
101 | h_output = IMG_H // output_strides
102 | w_output = IMG_W // output_strides
103 |
104 | for i in range(0, 4):
105 | x = conv_block(x, num_filters=f[i] * filters, kernel_size=3, strides=2)
106 |
107 | x = Flatten()(x)
108 | x = Dense(1)(x)
109 |
110 | return Model(image_input, x, name="discriminator")
111 |
112 | class DEN(Model):
113 | def __init__(self, discriminator, generator, latent_dim):
114 | super(DEN, self).__init__()
115 | self.discriminator = discriminator
116 | self.generator = generator
117 | self.latent_dim = latent_dim
118 |
119 | def compile(self, d_optimizer, g_optimizer, loss_fn):
120 | super(DEN, self).compile()
121 | self.d_optimizer = d_optimizer
122 | self.g_optimizer = g_optimizer
123 | self.loss_fn = loss_fn
124 |
125 | def train_step(self, real_images):
126 | batch_size = tf.shape(real_images)[0]
127 |
128 | for _ in range(2):
129 |             ## Train the discriminator on fake (generated) images
130 | random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
131 | generated_images = self.generator(random_latent_vectors)
132 | generated_labels = tf.zeros((batch_size, 1))
133 |
134 | with tf.GradientTape() as ftape:
135 | predictions = self.discriminator(generated_images)
136 | d1_loss = self.loss_fn(generated_labels, predictions)
137 | grads = ftape.gradient(d1_loss, self.discriminator.trainable_weights)
138 | self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
139 |
140 |             ## Train the discriminator on real images
141 | labels = tf.ones((batch_size, 1))
142 |
143 | with tf.GradientTape() as rtape:
144 | predictions = self.discriminator(real_images)
145 | d2_loss = self.loss_fn(labels, predictions)
146 | grads = rtape.gradient(d2_loss, self.discriminator.trainable_weights)
147 | self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
148 |
149 | ## Train the generator
150 | random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
151 | misleading_labels = tf.ones((batch_size, 1))
152 |
153 | with tf.GradientTape() as gtape:
154 | predictions = self.discriminator(self.generator(random_latent_vectors))
155 | g_loss = self.loss_fn(misleading_labels, predictions)
156 | grads = gtape.gradient(g_loss, self.generator.trainable_weights)
157 | self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
158 |
159 | return {"d1_loss": d1_loss, "d2_loss": d2_loss, "g_loss": g_loss}
160 |
161 | def save_plot(examples, epoch, n):
162 | examples = (examples + 1) / 2.0
163 | for i in range(n * n):
164 | pyplot.subplot(n, n, i+1)
165 | pyplot.axis("off")
166 | pyplot.imshow(np.squeeze(examples[i], axis=-1), cmap='gray')
167 |     filename = f"samples/Experiment_01/Epochs/generated_plot_epoch-{epoch+1}.png"  # assumes this directory already exists
168 | pyplot.savefig(filename)
169 | pyplot.close()
170 |
171 | if __name__ == "__main__":
172 | ## Hyperparameters
173 | batch_size = 16
174 | latent_dim = 128
175 | num_epochs = 300
176 | images_path = glob("Datasets/DEN_BUS_Dataset/Aug_All/*")
177 |
178 | d_model = build_discriminator()
179 | g_model = build_generator(latent_dim)
180 |
181 | ## d_model.load_weights("saved_model/d_model.h5")
182 | ## g_model.load_weights("saved_model/g_model.h5")
183 |
184 | d_model.summary()
185 | g_model.summary()
186 |
187 | den = DEN(d_model, g_model, latent_dim)
188 |
189 | bce_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
190 | d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
191 | g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
192 | den.compile(d_optimizer, g_optimizer, bce_loss_fn)
193 |
194 | images_dataset = tf_dataset(images_path, batch_size)
195 |
196 | for epoch in range(num_epochs):
197 | den.fit(images_dataset, epochs=1)
198 | g_model.save(f"saved_model/Experiments/g_model_{epoch}.h5")
199 | d_model.save(f"saved_model/Experiments/d_model_{epoch}.h5")
200 |
201 | n_samples = 25
202 | noise = np.random.normal(size=(n_samples, latent_dim))
203 | examples = g_model.predict(noise)
204 | save_plot(examples, epoch, int(np.sqrt(n_samples)))
205 |
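206 | ## After training, a saved generator can be reloaded for sampling on its own.
207 | ## A sketch (the checkpoint path assumes the save calls above have run):
208 | ## g_model = tf.keras.models.load_model("saved_model/Experiments/g_model_299.h5")
209 | ## examples = g_model.predict(np.random.normal(size=(25, latent_dim)))
210 | ## save_plot(examples, num_epochs, 5)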
--------------------------------------------------------------------------------
/PDF-UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class ConvLayer(nn.Module):
5 | def __init__(self, in_channels, out_channels, kernel_size = 1):
6 | super(ConvLayer, self).__init__()
7 | padding = int((kernel_size - 1) / 2)
8 | self.conv = nn.Sequential(
9 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
10 | nn.BatchNorm2d(out_channels),
11 | nn.LeakyReLU()
12 | )
13 |
14 | def forward(self, x):
15 | return self.conv(x)
16 |
17 | class SEBlock(nn.Module):
18 | def __init__(self, in_channels, r):
19 | super(SEBlock, self).__init__()
20 |
21 | redu_chns = int(in_channels / r)
22 | self.se_layers = nn.Sequential(
23 | nn.AdaptiveAvgPool2d(1),
24 | nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0),
25 | nn.LeakyReLU(),
26 | nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0),
27 | nn.ReLU())
28 |
29 |     def forward(self, x):
30 |         f = self.se_layers(x)
31 |         return f * x + x  # channel attention applied with a residual connection
32 |
33 | class PDFBlock(nn.Module):
34 | def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list):
35 | super(PDFBlock, self).__init__()
36 | self.conv_num = len(out_channels_list)
37 | assert(self.conv_num == 4)
38 | assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list))
39 |         pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0])  # padding keeps spatial size for stride-1 dilated conv
40 | pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1])
41 | pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2])
42 | pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3])
43 | self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], dilation = dilation_list[0], padding = pad0 )
44 | self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], dilation = dilation_list[1], padding = pad1 )
45 | self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], dilation = dilation_list[2], padding = pad2 )
46 | self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], dilation = dilation_list[3], padding = pad3 )
47 |
48 | out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3]
49 | self.conv_1x1 = nn.Sequential(
50 | nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0),
51 | nn.BatchNorm2d(out_channels),
52 | nn.LeakyReLU())
53 |
54 | def forward(self, x):
55 | x1 = self.conv_1(x)
56 | x2 = self.conv_2(x)
57 | x3 = self.conv_3(x)
58 | x4 = self.conv_4(x)
59 |
60 | y = torch.cat([x1, x2, x3, x4], dim=1)
61 | y = self.conv_1x1(y)
62 | return y
63 |
64 | class ConBNActBlock(nn.Module):
65 | def __init__(self,in_channels, out_channels, dropout_p):
66 | super(ConBNActBlock, self).__init__()
67 | self.conv_conv = nn.Sequential(
68 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
69 | nn.BatchNorm2d(out_channels),
70 | nn.LeakyReLU(),
71 | nn.Dropout(dropout_p),
72 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
73 | nn.BatchNorm2d(out_channels),
74 | nn.LeakyReLU(),
75 | SEBlock(out_channels, 2)
76 | )
77 |
78 | def forward(self, x):
79 | return self.conv_conv(x)
80 |
81 | class UpBlock(nn.Module):
82 | def __init__(self, in_channels1, in_channels2, out_channels,
83 | bilinear=True, dropout_p = 0.5):
84 | super(UpBlock, self).__init__()
85 | self.bilinear = bilinear
86 | if bilinear:
87 | self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1)
88 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
89 | else:
90 | self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2)
91 | self.conv = ConBNActBlock(in_channels2 * 2, out_channels, dropout_p)
92 |
93 | def forward(self, x1, x2):
94 | if self.bilinear:
95 | x1 = self.conv1x1(x1)
96 | x1 = self.up(x1)
97 | x_cat = torch.cat([x2, x1], dim=1)
98 | y = self.conv(x_cat)
99 | return y + x_cat
100 |
101 | class DownBlock(nn.Module):
102 | def __init__(self, in_channels, out_channels, dropout_p):
103 | super(DownBlock, self).__init__()
104 | self.maxpool = nn.MaxPool2d(2)
105 | self.avgpool = nn.AvgPool2d(2)
106 | self.conv = ConBNActBlock(2 * in_channels, out_channels, dropout_p)
107 |
108 | def forward(self, x):
109 | x_max = self.maxpool(x)
110 | x_avg = self.avgpool(x)
111 | x_cat = torch.cat([x_max, x_avg], dim=1)
112 | y = self.conv(x_cat)
113 | return y + x_cat
114 |
115 | class PDFUNet(nn.Module):
116 |     def __init__(self):
117 |         super(PDFUNet, self).__init__()
118 | self.in_chns = 1
119 | self.f_chan = [32, 64, 128, 256, 512]
120 | self.n_class = 1
121 | self.bilinear = True
122 | self.dropout = [0.0, 0.0, 0.3, 0.4, 0.5]
123 | assert(len(self.f_chan) == 5)
124 |
125 | f0_half = int(self.f_chan[0] / 2)
126 | f1_half = int(self.f_chan[1] / 2)
127 | f2_half = int(self.f_chan[2] / 2)
128 | f3_half = int(self.f_chan[3] / 2)
129 | self.in_conv= ConBNActBlock(self.in_chns, self.f_chan[0], self.dropout[0])
130 | self.down1 = DownBlock(self.f_chan[0], self.f_chan[1], self.dropout[1])
131 | self.down2 = DownBlock(self.f_chan[1], self.f_chan[2], self.dropout[2])
132 | self.down3 = DownBlock(self.f_chan[2], self.f_chan[3], self.dropout[3])
133 | self.down4 = DownBlock(self.f_chan[3], self.f_chan[4], self.dropout[4])
134 |
135 | self.bridge0= ConvLayer(self.f_chan[0], f0_half)
136 | self.bridge1= ConvLayer(self.f_chan[1], f1_half)
137 | self.bridge2= ConvLayer(self.f_chan[2], f2_half)
138 | self.bridge3= ConvLayer(self.f_chan[3], f3_half)
139 |
140 | self.up1 = UpBlock(self.f_chan[4], f3_half, self.f_chan[3], dropout_p = self.dropout[3])
141 | self.up2 = UpBlock(self.f_chan[3], f2_half, self.f_chan[2], dropout_p = self.dropout[2])
142 | self.up3 = UpBlock(self.f_chan[2], f1_half, self.f_chan[1], dropout_p = self.dropout[1])
143 | self.up4 = UpBlock(self.f_chan[1], f0_half, self.f_chan[0], dropout_p = self.dropout[0])
144 |
145 | f4 = self.f_chan[4]
146 | aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)]
147 | aspp_knls = [1, 3, 3, 3]
148 | aspp_dila = [1, 2, 4, 6]
149 | self.aspp = PDFBlock(f4, aspp_chns, aspp_knls, aspp_dila)
150 |
151 |
152 | self.out_conv = nn.Conv2d(self.f_chan[0], self.n_class,
153 | kernel_size = 3, padding = 1)
154 |
155 | def forward(self, x):
156 | x_shape = list(x.shape)
157 | if(len(x_shape) == 5):
158 | [N, C, D, H, W] = x_shape
159 | new_shape = [N*D, C, H, W]
160 | x = torch.transpose(x, 1, 2)
161 | x = torch.reshape(x, new_shape)
162 | x0 = self.in_conv(x)
163 | x0b = self.bridge0(x0)
164 | x1 = self.down1(x0)
165 | x1b = self.bridge1(x1)
166 | x2 = self.down2(x1)
167 | x2b = self.bridge2(x2)
168 | x3 = self.down3(x2)
169 | x3b = self.bridge3(x3)
170 | x4 = self.down4(x3)
171 | x4 = self.aspp(x4)
172 |
173 | x = self.up1(x4, x3b)
174 | x = self.up2(x, x2b)
175 | x = self.up3(x, x1b)
176 | x = self.up4(x, x0b)
177 | output = self.out_conv(x)
178 |
179 | if(len(x_shape) == 5):
180 | new_shape = [N, D] + list(output.shape)[1:]
181 | output = torch.reshape(output, new_shape)
182 | output = torch.transpose(output, 1, 2)
183 | return output
184 |
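185 | # Minimal smoke test (an assumption, not part of the published code):
186 | # build the network and check that a single-channel input comes back
187 | # at the same spatial resolution.
188 | if __name__ == "__main__":
189 |     model = PDFUNet()
190 |     dummy = torch.randn(2, 1, 128, 128)  # (batch, channel, height, width)
191 |     out = model(dummy)
192 |     print(out.shape)  # expected: torch.Size([2, 1, 128, 128])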
--------------------------------------------------------------------------------
/PMG_Network.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 |
7 | class Bottleneck(nn.Module):
8 | expansion = 4
9 |
10 | def __init__(self, inplanes, planes, stride=1, rate=1, downsample=None):
11 | super(Bottleneck, self).__init__()
12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
13 | self.bn1 = nn.BatchNorm2d(planes)
14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
15 | dilation=rate, padding=rate, bias=False)
16 | self.bn2 = nn.BatchNorm2d(planes)
17 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
18 | self.bn3 = nn.BatchNorm2d(planes * 4)
19 | self.relu = nn.ReLU(inplace=True)
20 | self.downsample = downsample
21 | self.stride = stride
22 | self.rate = rate
23 |
24 | def forward(self, x):
25 | residual = x
26 |
27 | out = self.conv1(x)
28 | out = self.bn1(out)
29 | out = self.relu(out)
30 |
31 | out = self.conv2(out)
32 | out = self.bn2(out)
33 | out = self.relu(out)
34 |
35 | out = self.conv3(out)
36 | out = self.bn3(out)
37 |
38 | if self.downsample is not None:
39 | residual = self.downsample(x)
40 |
41 | out += residual
42 | out = self.relu(out)
43 |
44 | return out
45 |
46 | class ResNet(nn.Module):
47 |
48 | def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
49 | self.inplanes = 64
50 | super(ResNet, self).__init__()
51 | if os == 16:
52 | strides = [1, 2, 2, 1]
53 | rates = [1, 1, 1, 2]
54 | blocks = [1, 2, 4]
55 | elif os == 8:
56 | strides = [1, 2, 1, 1]
57 | rates = [1, 1, 2, 2]
58 | blocks = [1, 2, 1]
59 | else:
60 | raise NotImplementedError
61 |
62 | # Modules
63 | self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3,
64 | bias=False)
65 | self.bn1 = nn.BatchNorm2d(64)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
68 |
69 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0])
70 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1])
71 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2])
72 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], rate=rates[3])
73 |
74 | self._init_weight()
75 |
76 | if pretrained:
77 | self._load_pretrained_model()
78 |
79 | def _make_layer(self, block, planes, blocks, stride=1, rate=1):
80 | downsample = None
81 | if stride != 1 or self.inplanes != planes * block.expansion:
82 | downsample = nn.Sequential(
83 | nn.Conv2d(self.inplanes, planes * block.expansion,
84 | kernel_size=1, stride=stride, bias=False),
85 | nn.BatchNorm2d(planes * block.expansion),
86 | )
87 |
88 | layers = []
89 | layers.append(block(self.inplanes, planes, stride, rate, downsample))
90 | self.inplanes = planes * block.expansion
91 | for i in range(1, blocks):
92 | layers.append(block(self.inplanes, planes))
93 |
94 | return nn.Sequential(*layers)
95 |
96 | def _make_MG_unit(self, block, planes, blocks=[1,2,4], stride=1, rate=1):
97 | downsample = None
98 | if stride != 1 or self.inplanes != planes * block.expansion:
99 | downsample = nn.Sequential(
100 | nn.Conv2d(self.inplanes, planes * block.expansion,
101 | kernel_size=1, stride=stride, bias=False),
102 | nn.BatchNorm2d(planes * block.expansion),
103 | )
104 |
105 | layers = []
106 | layers.append(block(self.inplanes, planes, stride, rate=blocks[0]*rate, downsample=downsample))
107 | self.inplanes = planes * block.expansion
108 | for i in range(1, len(blocks)):
109 | layers.append(block(self.inplanes, planes, stride=1, rate=blocks[i]*rate))
110 |
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, input):
114 | x = self.conv1(input)
115 | x = self.bn1(x)
116 | x = self.relu(x)
117 | x = self.maxpool(x)
118 |
119 | x = self.layer1(x)
120 | low_level_feat = x
121 | x = self.layer2(x)
122 | x = self.layer3(x)
123 | x = self.layer4(x)
124 | return x, low_level_feat
125 |
126 | def _init_weight(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130 | # m.weight.data.normal_(0, math.sqrt(2. / n))
131 | torch.nn.init.kaiming_normal_(m.weight)
132 | elif isinstance(m, nn.BatchNorm2d):
133 | m.weight.data.fill_(1)
134 | m.bias.data.zero_()
135 |
136 | def _load_pretrained_model(self):
137 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
138 | model_dict = {}
139 | state_dict = self.state_dict()
140 | for k, v in pretrain_dict.items():
141 | if k in state_dict:
142 | model_dict[k] = v
143 | state_dict.update(model_dict)
144 | self.load_state_dict(state_dict)
145 |
146 | def ResNet101(nInputChannels=3, os=16, pretrained=False):
147 | model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained)
148 | return model
149 |
150 |
151 | class ASPP_module(nn.Module):
152 | def __init__(self, inplanes, planes, rate):
153 | super(ASPP_module, self).__init__()
154 | if rate == 1:
155 | kernel_size = 1
156 | padding = 0
157 | else:
158 | kernel_size = 3
159 | padding = rate
160 | self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
161 | stride=1, padding=padding, dilation=rate, bias=False)
162 | self.bn = nn.BatchNorm2d(planes)
163 | self.relu = nn.ReLU()
164 |
165 | self._init_weight()
166 |
167 | def forward(self, x):
168 | x = self.atrous_convolution(x)
169 | x = self.bn(x)
170 |
171 | return self.relu(x)
172 |
173 | def _init_weight(self):
174 | for m in self.modules():
175 | if isinstance(m, nn.Conv2d):
176 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
177 | # m.weight.data.normal_(0, math.sqrt(2. / n))
178 | torch.nn.init.kaiming_normal_(m.weight)
179 | elif isinstance(m, nn.BatchNorm2d):
180 | m.weight.data.fill_(1)
181 | m.bias.data.zero_()
182 |
183 |
184 | class PMG_Network(nn.Module):
185 |     def __init__(self, nInputChannels=3, n_classes=1, os=16, pretrained=False, _print=True):
186 | if _print:
187 | print("Constructing PMG model...")
188 | print("Number of classes: {}".format(n_classes))
189 | print("Output stride: {}".format(os))
190 | print("Number of Input Channels: {}".format(nInputChannels))
191 | super(PMG_Network, self).__init__()
192 |
193 | # Atrous Conv
194 | self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)
195 |
196 | # ASPP
197 | if os == 16:
198 | rates = [1, 6, 12, 18]
199 | elif os == 8:
200 | rates = [1, 12, 24, 36]
201 | else:
202 | raise NotImplementedError
203 |
204 | self.aspp1 = ASPP_module(2048, 256, rate=rates[0])
205 | self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
206 | self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
207 | self.aspp4 = ASPP_module(2048, 256, rate=rates[3])
208 |
209 | self.relu = nn.ReLU()
210 |
211 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
212 | nn.Conv2d(2048, 256, 1, stride=1, bias=False),
213 | nn.BatchNorm2d(256),
214 | nn.ReLU())
215 |
216 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
217 | self.bn1 = nn.BatchNorm2d(256)
218 |
219 | # adopt [1x1, 48] for channel reduction.
220 | self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
221 | self.bn2 = nn.BatchNorm2d(48)
222 |
223 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
224 | nn.BatchNorm2d(256),
225 | nn.ReLU(),
226 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
227 | nn.BatchNorm2d(256),
228 | nn.ReLU(),
229 | nn.Conv2d(256, n_classes, kernel_size=1, stride=1))
230 |
231 | def forward(self, input):
232 | x, low_level_features = self.resnet_features(input)
233 | x1 = self.aspp1(x)
234 | x2 = self.aspp2(x)
235 | x3 = self.aspp3(x)
236 | x4 = self.aspp4(x)
237 | x5 = self.global_avg_pool(x)
238 |         x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
239 |
240 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
241 |
242 | x = self.conv1(x)
243 | x = self.bn1(x)
244 | x = self.relu(x)
245 |         x = F.interpolate(x, size=(int(math.ceil(input.size()[-2]/4)),
246 |                                    int(math.ceil(input.size()[-1]/4))), mode='bilinear', align_corners=True)
247 |
248 | low_level_features = self.conv2(low_level_features)
249 | low_level_features = self.bn2(low_level_features)
250 | low_level_features = self.relu(low_level_features)
251 |
252 |
253 | x = torch.cat((x, low_level_features), dim=1)
254 | x = self.last_conv(x)
255 |         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
256 |
257 | return x
258 |
259 | def freeze_bn(self):
260 | for m in self.modules():
261 | if isinstance(m, nn.BatchNorm2d):
262 | m.eval()
263 |
264 | def __init_weight(self):
265 | for m in self.modules():
266 | if isinstance(m, nn.Conv2d):
267 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
268 | # m.weight.data.normal_(0, math.sqrt(2. / n))
269 | torch.nn.init.kaiming_normal_(m.weight)
270 | elif isinstance(m, nn.BatchNorm2d):
271 | m.weight.data.fill_(1)
272 | m.bias.data.zero_()
273 |
274 | def get_1x_lr_params(model):
275 |     """
276 |     This generator returns all the parameters of the net except those of the
277 |     last classification layers. Note that for each batchnorm layer,
278 |     requires_grad is set to False in PMG_resnet.py, so this function does not
279 |     return any batchnorm parameters.
280 |     """
281 | b = [model.resnet_features]
282 | for i in range(len(b)):
283 | for k in b[i].parameters():
284 | if k.requires_grad:
285 | yield k
286 |
287 | def get_10x_lr_params(model):
288 | """
289 |     This generator returns all the parameters of the last layers of the net,
290 |     which perform the per-pixel classification.
291 | """
292 | b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
293 | for j in range(len(b)):
294 | for k in b[j].parameters():
295 | if k.requires_grad:
296 | yield k
297 |
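298 | # Minimal smoke test (an assumption, not part of the original code): run a
299 | # dummy RGB batch through the network in eval mode (eval avoids BatchNorm's
300 | # batch-size-1 restriction after global average pooling) and check that the
301 | # output matches the input resolution.
302 | if __name__ == "__main__":
303 |     model = PMG_Network(nInputChannels=3, n_classes=1, os=16, pretrained=False)
304 |     model.eval()
305 |     with torch.no_grad():
306 |         out = model(torch.randn(1, 3, 256, 256))
307 |     print(out.size())  # expected: torch.Size([1, 1, 256, 256])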
--------------------------------------------------------------------------------