├── README.md
├── data_loader.py
├── train.py
└── model.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
The model uses the AE-GAN (Autoencoder Generative Adversarial Network) architecture to generate upsampled images. It is trained on the Celeb-A (1024 x 1024) image dataset: input images are 128 x 128 and generated images are 480 x 480.
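
A minimal sketch of the generator pipeline (illustrative only; it assumes the classes defined in `model.py` and simply checks the 128 x 128 -> 480 x 480 mapping described above):

```python
import torch
from model import BasicBlock, preEncoder, Encoder, Decoder

# build the three generator stages; eval() so BatchNorm uses running statistics
pre_encoder = preEncoder().eval()
encoder = Encoder(BasicBlock).eval()
decoder = Decoder(BasicBlock).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 128, 128)          # dummy low-resolution input
    out = decoder(encoder(pre_encoder(x)))   # full autoencoder (generator) pass
print(out.shape)                             # torch.Size([1, 3, 480, 480])
```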
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
# importing required libraries
import os
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

class ImageDataset(Dataset):
    """Pairs each low-resolution input image with its high-resolution label image."""

    def __init__(self, data_path, transform=None):
        self.train_data_path = data_path['data']
        self.train_label_path = data_path['labels']
        # sort both listings so inputs and labels pair up deterministically
        self.train_labels = sorted(os.listdir(self.train_label_path))
        self.train_data = sorted(os.listdir(self.train_data_path))
        self.transform = transform

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, indx):
        if indx >= len(self.train_data):
            raise IndexError("Index should be less than {}".format(len(self.train_data)))

        image = Image.open(os.path.join(self.train_data_path, self.train_data[indx])).convert('RGB')
        # labels are read with OpenCV (BGR), resized to the 480x480 target, then converted to RGB
        final_label = cv2.resize(cv2.imread(os.path.join(self.train_label_path, self.train_labels[indx])), (480, 480))
        final_label = Image.fromarray(cv2.cvtColor(final_label, cv2.COLOR_BGR2RGB))

        if self.transform is not None:
            image = self.transform(image)
            final_label = self.transform(final_label)

        return image, final_label
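
# A minimal usage sketch (run this file directly to execute it). The directory
# layout is an assumption that matches the keys used in train.py below; the
# expected shapes follow from the 128x128 inputs and 480x480 labels in the README.
if __name__ == "__main__":
    paths = {'data': "data/train/", 'labels': "data/test/"}
    dataset = ImageDataset(paths, transform=transforms.ToTensor())
    image, label = dataset[0]
    print(image.shape)   # expected: torch.Size([3, 128, 128])
    print(label.shape)   # expected: torch.Size([3, 480, 480])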
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import data_loader as dl
from model import BasicBlock, preEncoder, Encoder, Decoder, Discriminator
from torch.utils.data import DataLoader, random_split

epochs = 30
D_Losses = []
G_Losses = []
count = 0
batch_size = 4
learning_rate = 0.0002
lmbda = 0

# NOTE: data_path will differ on other systems; this model was trained on a
# different machine. The keys must match those read by ImageDataset.
data_path = {'data': "data/train/", 'labels': "data/test/"}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# inverse of the ImageNet normalization above, used when displaying images
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                                     std=[1/0.229, 1/0.224, 1/0.225])

data = dl.ImageDataset(data_path, transform=transform_img)

train_size = int(0.9 * len(data))
test_size = len(data) - train_size
train_data, validation_data = random_split(data, [train_size, test_size])

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=True)

pre_encoder = preEncoder().to(device)
encoder = Encoder(BasicBlock).to(device)
decoder = Decoder(BasicBlock).to(device)
discriminator = Discriminator(BasicBlock).to(device)

ce_loss = nn.CrossEntropyLoss()
l1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

pre_encoder_optim = optim.Adam(pre_encoder.parameters(), lr=learning_rate, weight_decay=0)
encoder_optim = optim.Adam(encoder.parameters(), lr=learning_rate, weight_decay=0)
decoder_optim = optim.Adam(decoder.parameters(), lr=learning_rate, weight_decay=0)
discriminator_optim = optim.Adam(discriminator.parameters(), lr=learning_rate, weight_decay=lmbda)

pre_encoder.train()
encoder.train()
decoder.train()
discriminator.train()


def show_images(images, titles):
    """Un-normalize the first image of each batch and display them side by side."""
    fig = plt.figure(figsize=(10, 10))
    for j, (img, title) in enumerate(zip(images, titles), start=1):
        img = inv_normalize(img.detach()[0]).cpu().numpy().transpose(1, 2, 0).clip(0, 1)
        fig.add_subplot(1, len(images), j)
        plt.imshow(img)
        plt.title(title)
    plt.show()


for epoch in range(epochs):
    for i, batch in enumerate(train_loader):
        image, real_image = batch
        image = image.to(device)
        real_image = real_image.to(device)
        real_label = torch.ones([image.shape[0]]).long().to(device)
        fake_label = torch.zeros([image.shape[0]]).long().to(device)

        pre_encoder_optim.zero_grad()
        encoder_optim.zero_grad()
        decoder_optim.zero_grad()

        x = pre_encoder(image)
        x = encoder(x)
        G_x = decoder(x)

        # train the discriminator twice per generator step; the generated
        # image is detached so only the discriminator is updated here
        for _ in range(2):
            discriminator_optim.zero_grad()
            D_x = discriminator(real_image)
            G_D_x = discriminator(G_x.detach())

            real_loss = ce_loss(D_x, real_label)
            fake_loss = ce_loss(G_D_x, fake_label)
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            discriminator_optim.step()

        # generator update: adversarial loss plus L1 and MSE reconstruction terms
        G_D_x = discriminator(G_x)
        gen_real_loss = ce_loss(G_D_x, real_label)
        abs_loss = l1_loss(G_x, real_image)
        se_loss = mse_loss(G_x, real_image)
        gen_loss = gen_real_loss + abs_loss + se_loss
        gen_loss.backward()

        pre_encoder_optim.step()
        encoder_optim.step()
        decoder_optim.step()

        if count == 0:
            print("Epoch [{}/{}] after iteration {}: \n".format(epoch + 1, epochs, count))
            show_images([image, G_x], ['Input', 'Generated'])
        count += 1

        D_Losses.append(disc_loss.item())
        G_Losses.append(gen_loss.item())

    # show a sample at the end of every epoch
    print("Epoch [{}/{}] after iteration {}: \n".format(epoch + 1, epochs, count))
    show_images([image, G_x], ['Input', 'Generated'])

plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_Losses, label="G")
plt.plot(D_Losses, label="D")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

# testing on validation images
pre_encoder.eval()
encoder.eval()
decoder.eval()

count = 0
with torch.no_grad():
    for i, batch in enumerate(validation_loader):
        image, real_image = batch
        x = pre_encoder(image.to(device))
        x = encoder(x)
        x = decoder(x)
        show_images([image, x, real_image], ['Input', 'Generated', 'Ground Truth'])
        if count == 20:
            break
        count += 1
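
# Not part of the original script: a typical way to persist the trained weights
# once training finishes (a sketch; the .pth filenames are illustrative):
torch.save(pre_encoder.state_dict(), "pre_encoder.pth")
torch.save(encoder.state_dict(), "encoder.pth")
torch.save(decoder.state_dict(), "decoder.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")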
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Residual block: two 3x3 convolutions with BatchNorm and PReLU plus a skip connection."""

    def __init__(self, channels=128, stride=1, padding=1):
        super(BasicBlock, self).__init__()
        self.channels = channels
        self.stride = stride
        self.padding = padding

        self.conv_1 = nn.Conv2d(in_channels=self.channels, out_channels=self.channels,
                                kernel_size=3, stride=self.stride, padding=self.padding)
        self.bn_1 = nn.BatchNorm2d(self.channels)
        self.prelu_1 = nn.PReLU()

        self.conv_2 = nn.Conv2d(in_channels=self.channels, out_channels=self.channels,
                                kernel_size=3, stride=self.stride, padding=self.padding)
        self.bn_2 = nn.BatchNorm2d(self.channels)
        self.prelu_2 = nn.PReLU()

    def forward(self, x):
        identity = x
        x = self.prelu_1(self.bn_1(self.conv_1(x)))
        x = self.bn_2(self.conv_2(x)) + identity
        return self.prelu_2(x)


class preEncoder(nn.Module):
    """Input stem: lifts the 3-channel image to 64 feature channels at full resolution."""

    def __init__(self):
        super(preEncoder, self).__init__()
        self.input_pass = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.PReLU(),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.PReLU(),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU()
        )

    def forward(self, x):
        return self.input_pass(x)


class Encoder(nn.Module):
    """Downsamples 64-channel 128x128 features to a 128-channel 13x13 latent
    using three strided 5x5 convolutions interleaved with residual blocks."""

    def __init__(self, block):
        super(Encoder, self).__init__()
        self.block = block

        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.PReLU(),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.PReLU(),
        )

        self.layer_1 = self.make_layers(1)
        self.layer_2 = self.make_layers(1)
        self.layer_3 = self.make_layers(1)

        self.downsample_conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.downsample_conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.downsample_conv_3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

    def make_layers(self, layers):
        # stack `layers` residual blocks
        res_layers = []
        for i in range(layers):
            res_layers.append(self.block())
        return nn.Sequential(*res_layers)

    def forward(self, x):
        x = self.input_conv(x)
        x = self.layer_1(x)
        x = self.downsample_conv_1(x)
        x = self.layer_2(x)
        x = self.downsample_conv_2(x)
        x = self.layer_3(x)
        x = self.downsample_conv_3(x)
        return x


class Decoder(nn.Module):
    """Generator tail: upsamples the 128x13x13 latent to a 3x480x480 image
    using five transposed convolutions interleaved with residual blocks."""

    def __init__(self, block):
        super(Decoder, self).__init__()
        self.block = block

        self.layer_1 = self.make_layers(1)
        self.layer_2 = self.make_layers(1)
        self.layer_3 = self.make_layers(1)

        self.trans_conv_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.trans_conv_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.trans_conv_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.trans_conv_4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.trans_conv_5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.output_conv = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=2, padding=0),
            nn.BatchNorm2d(64),
            nn.PReLU(),

            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.PReLU(),

            nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, padding=1),
        )

    def make_layers(self, layers):
        # stack `layers` residual blocks
        res_layers = []
        for i in range(layers):
            res_layers.append(self.block())
        return nn.Sequential(*res_layers)

    def forward(self, x):
        x = self.trans_conv_1(x)
        x = self.layer_1(x)
        x = self.trans_conv_2(x)
        x = self.layer_2(x)
        x = self.trans_conv_3(x)
        x = self.layer_3(x)
        x = self.trans_conv_4(x)
        x = self.trans_conv_5(x)
        x = self.output_conv(x)
        return x

class Discriminator(nn.Module):
    """Classifies 480x480 images as real or fake; outputs 2 logits for use
    with CrossEntropyLoss."""

    def __init__(self, block):
        super(Discriminator, self).__init__()
        self.block = block

        self.layer_1 = self.make_layers(1)
        self.layer_2 = self.make_layers(1)
        self.layer_3 = self.make_layers(1)

        self.main_conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2),
            nn.BatchNorm2d(32),
            nn.PReLU()
        )

        self.main_conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.main_conv_3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.main_conv_4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2),
            nn.BatchNorm2d(128),
            nn.PReLU()
        )

        self.main_conv_5 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.PReLU()
        )

        self.main_conv_6 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=2),
            nn.BatchNorm2d(32),
            nn.PReLU()
        )

        self.main_conv_7 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1),
            nn.BatchNorm2d(16),
            nn.PReLU()
        )

        # for a 480x480 input the conv stack above ends at 16 x 4 x 4 = 256
        # features; this fully connected head maps them to 2 logits
        self.main_conv_8 = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.PReLU(),

            nn.Linear(128, 10),
            nn.BatchNorm1d(10),
            nn.PReLU(),

            nn.Linear(10, 2),
        )

    def make_layers(self, layers):
        # stack `layers` residual blocks
        res_layers = []
        for i in range(layers):
            res_layers.append(self.block())
        return nn.Sequential(*res_layers)

    def forward(self, x):
        x = self.main_conv_1(x)
        x = self.main_conv_2(x)
        x = self.layer_1(x)
        x = self.main_conv_3(x)
        x = self.layer_2(x)
        x = self.main_conv_4(x)
        x = self.layer_3(x)
        x = self.main_conv_5(x)
        x = self.main_conv_6(x)
        x = self.main_conv_7(x)
        x = x.view(x.shape[0], -1)   # flatten to (batch, 256)
        x = self.main_conv_8(x)
        return x
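
# Minimal shape sanity check, run with `python model.py` (a sketch; the sizes
# assume the 128x128 -> 480x480 setup described in the README):
if __name__ == "__main__":
    with torch.no_grad():
        pre_encoder = preEncoder().eval()
        encoder = Encoder(BasicBlock).eval()
        decoder = Decoder(BasicBlock).eval()
        discriminator = Discriminator(BasicBlock).eval()

        z = encoder(pre_encoder(torch.randn(1, 3, 128, 128)))
        print(z.shape)                    # torch.Size([1, 128, 13, 13]) latent
        G_x = decoder(z)
        print(G_x.shape)                  # torch.Size([1, 3, 480, 480]) generated image
        print(discriminator(G_x).shape)   # torch.Size([1, 2]) real/fake logits
--------------------------------------------------------------------------------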