├── ACGAN.py ├── BEGAN.py ├── CGAN.py ├── DRAGAN.py ├── EBGAN.py ├── GAN.py ├── LSGAN.py ├── README.md ├── WGAN.py ├── WGAN_GP.py ├── assets ├── celebA_results │ ├── BEGAN_epoch001.png │ ├── BEGAN_epoch010.png │ ├── BEGAN_epoch025.png │ ├── BEGAN_generate_animation.gif │ ├── DRAGAN_epoch001.png │ ├── DRAGAN_epoch010.png │ ├── DRAGAN_epoch025.png │ ├── DRAGAN_generate_animation.gif │ ├── EBGAN_epoch001.png │ ├── EBGAN_epoch010.png │ ├── EBGAN_epoch025.png │ ├── EBGAN_generate_animation.gif │ ├── GAN_epoch001.png │ ├── GAN_epoch010.png │ ├── GAN_epoch025.png │ ├── GAN_generate_animation.gif │ ├── LSGAN_epoch001.png │ ├── LSGAN_epoch010.png │ ├── LSGAN_epoch025.png │ ├── LSGAN_generate_animation.gif │ ├── WGAN_GP_epoch001.png │ ├── WGAN_GP_epoch010.png │ ├── WGAN_GP_epoch025.png │ ├── WGAN_GP_generate_animation.gif │ ├── WGAN_epoch001.png │ ├── WGAN_epoch010.png │ ├── WGAN_epoch025.png │ └── WGAN_generate_animation.gif ├── equations │ ├── ACGAN.png │ ├── BEGAN.png │ ├── CGAN.png │ ├── DRAGAN.png │ ├── EBGAN.png │ ├── GAN.png │ ├── LSGAN.png │ ├── WGAN.png │ ├── WGAN_GP.png │ └── infoGAN.png ├── etc │ └── GAN_structure.png ├── fashion_mnist_results │ ├── ACGAN_epoch001.png │ ├── ACGAN_epoch025.png │ ├── ACGAN_epoch050.png │ ├── ACGAN_generate_animation.gif │ ├── ACGAN_loss.png │ ├── BEGAN_epoch001.png │ ├── BEGAN_epoch025.png │ ├── BEGAN_epoch050.png │ ├── BEGAN_generate_animation.gif │ ├── BEGAN_loss.png │ ├── CGAN_epoch001.png │ ├── CGAN_epoch025.png │ ├── CGAN_epoch050.png │ ├── CGAN_generate_animation.gif │ ├── CGAN_loss.png │ ├── DRAGAN_epoch001.png │ ├── DRAGAN_epoch025.png │ ├── DRAGAN_epoch050.png │ ├── DRAGAN_generate_animation.gif │ ├── DRAGAN_loss.png │ ├── EBGAN_epoch001.png │ ├── EBGAN_epoch025.png │ ├── EBGAN_epoch050.png │ ├── EBGAN_generate_animation.gif │ ├── EBGAN_loss.png │ ├── GAN_epoch001.png │ ├── GAN_epoch025.png │ ├── GAN_epoch050.png │ ├── GAN_generate_animation.gif │ ├── GAN_loss.png │ ├── LSGAN_epoch001.png │ ├── LSGAN_epoch025.png │ ├── 
LSGAN_epoch050.png │ ├── LSGAN_generate_animation.gif │ ├── LSGAN_loss.png │ ├── WGAN_GP_epoch001.png │ ├── WGAN_GP_epoch025.png │ ├── WGAN_GP_epoch050.png │ ├── WGAN_GP_generate_animation.gif │ ├── WGAN_GP_loss.png │ ├── WGAN_epoch001.png │ ├── WGAN_epoch025.png │ ├── WGAN_epoch050.png │ ├── WGAN_generate_animation.gif │ ├── WGAN_loss.png │ ├── infoGAN_cont_epoch001.png │ ├── infoGAN_cont_epoch025.png │ ├── infoGAN_cont_epoch050.png │ ├── infoGAN_cont_generate_animation.gif │ ├── infoGAN_epoch001.png │ ├── infoGAN_epoch025.png │ ├── infoGAN_epoch050.png │ ├── infoGAN_generate_animation.gif │ └── infoGAN_loss.png └── mnist_results │ ├── ACGAN_epoch001.png │ ├── ACGAN_epoch025.png │ ├── ACGAN_epoch050.png │ ├── ACGAN_generate_animation.gif │ ├── ACGAN_loss.png │ ├── BEGAN_epoch001.png │ ├── BEGAN_epoch025.png │ ├── BEGAN_epoch050.png │ ├── BEGAN_generate_animation.gif │ ├── BEGAN_loss.png │ ├── CGAN_epoch001.png │ ├── CGAN_epoch025.png │ ├── CGAN_epoch050.png │ ├── CGAN_generate_animation.gif │ ├── CGAN_loss.png │ ├── DRAGAN_epoch001.png │ ├── DRAGAN_epoch025.png │ ├── DRAGAN_epoch050.png │ ├── DRAGAN_generate_animation.gif │ ├── DRAGAN_loss.png │ ├── EBGAN_epoch001.png │ ├── EBGAN_epoch025.png │ ├── EBGAN_epoch050.png │ ├── EBGAN_generate_animation.gif │ ├── EBGAN_loss.png │ ├── GAN_epoch001.png │ ├── GAN_epoch025.png │ ├── GAN_epoch050.png │ ├── GAN_generate_animation.gif │ ├── GAN_loss.png │ ├── LSGAN_epoch001.png │ ├── LSGAN_epoch025.png │ ├── LSGAN_epoch050.png │ ├── LSGAN_generate_animation.gif │ ├── LSGAN_loss.png │ ├── WGAN_GP_epoch001.png │ ├── WGAN_GP_epoch025.png │ ├── WGAN_GP_epoch050.png │ ├── WGAN_GP_generate_animation.gif │ ├── WGAN_GP_loss.png │ ├── WGAN_epoch001.png │ ├── WGAN_epoch025.png │ ├── WGAN_epoch050.png │ ├── WGAN_generate_animation.gif │ ├── WGAN_loss.png │ ├── infoGAN_cont_epoch001.png │ ├── infoGAN_cont_epoch025.png │ ├── infoGAN_cont_epoch050.png │ ├── infoGAN_cont_generate_animation.gif │ ├── infoGAN_epoch001.png │ ├── 
infoGAN_epoch025.png │ ├── infoGAN_epoch050.png │ ├── infoGAN_generate_animation.gif │ └── infoGAN_loss.png ├── dataloader.py ├── infoGAN.py ├── main.py └── utils.py /ACGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from dataloader import dataloader 6 | 7 | class generator(nn.Module): 8 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 10 | def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10): 11 | super(generator, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.input_size = input_size 15 | self.class_num = class_num 16 | 17 | self.fc = nn.Sequential( 18 | nn.Linear(self.input_dim + self.class_num, 1024), 19 | nn.BatchNorm1d(1024), 20 | nn.ReLU(), 21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 23 | nn.ReLU(), 24 | ) 25 | self.deconv = nn.Sequential( 26 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 27 | nn.BatchNorm2d(64), 28 | nn.ReLU(), 29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 30 | nn.Tanh(), 31 | ) 32 | utils.initialize_weights(self) 33 | 34 | def forward(self, input, label): 35 | x = torch.cat([input, label], 1) 36 | x = self.fc(x) 37 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 38 | x = self.deconv(x) 39 | 40 | return x 41 | 42 | class discriminator(nn.Module): 43 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 44 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 45 | def __init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10): 46 | super(discriminator, self).__init__() 47 | self.input_dim = input_dim 48 | self.output_dim 
= output_dim 49 | self.input_size = input_size 50 | self.class_num = class_num 51 | 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 54 | nn.LeakyReLU(0.2), 55 | nn.Conv2d(64, 128, 4, 2, 1), 56 | nn.BatchNorm2d(128), 57 | nn.LeakyReLU(0.2), 58 | ) 59 | self.fc1 = nn.Sequential( 60 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 61 | nn.BatchNorm1d(1024), 62 | nn.LeakyReLU(0.2), 63 | ) 64 | self.dc = nn.Sequential( 65 | nn.Linear(1024, self.output_dim), 66 | nn.Sigmoid(), 67 | ) 68 | self.cl = nn.Sequential( 69 | nn.Linear(1024, self.class_num), 70 | ) 71 | utils.initialize_weights(self) 72 | 73 | def forward(self, input): 74 | x = self.conv(input) 75 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 76 | x = self.fc1(x) 77 | d = self.dc(x) 78 | c = self.cl(x) 79 | 80 | return d, c 81 | 82 | class ACGAN(object): 83 | def __init__(self, args): 84 | # parameters 85 | self.epoch = args.epoch 86 | self.sample_num = 100 87 | self.batch_size = args.batch_size 88 | self.save_dir = args.save_dir 89 | self.result_dir = args.result_dir 90 | self.dataset = args.dataset 91 | self.log_dir = args.log_dir 92 | self.gpu_mode = args.gpu_mode 93 | self.model_name = args.gan_type 94 | self.input_size = args.input_size 95 | self.z_dim = 62 96 | self.class_num = 10 97 | self.sample_num = self.class_num ** 2 98 | 99 | # load dataset 100 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 101 | data = self.data_loader.__iter__().__next__()[0] 102 | 103 | # networks init 104 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 105 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 106 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 107 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 108 | 109 | if 
self.gpu_mode: 110 | self.G.cuda() 111 | self.D.cuda() 112 | self.BCE_loss = nn.BCELoss().cuda() 113 | self.CE_loss = nn.CrossEntropyLoss().cuda() 114 | else: 115 | self.BCE_loss = nn.BCELoss() 116 | self.CE_loss = nn.CrossEntropyLoss() 117 | 118 | print('---------- Networks architecture -------------') 119 | utils.print_network(self.G) 120 | utils.print_network(self.D) 121 | print('-----------------------------------------------') 122 | 123 | # fixed noise & condition 124 | self.sample_z_ = torch.zeros((self.sample_num, self.z_dim)) 125 | for i in range(self.class_num): 126 | self.sample_z_[i*self.class_num] = torch.rand(1, self.z_dim) 127 | for j in range(1, self.class_num): 128 | self.sample_z_[i*self.class_num + j] = self.sample_z_[i*self.class_num] 129 | 130 | temp = torch.zeros((self.class_num, 1)) 131 | for i in range(self.class_num): 132 | temp[i, 0] = i 133 | 134 | temp_y = torch.zeros((self.sample_num, 1)) 135 | for i in range(self.class_num): 136 | temp_y[i*self.class_num: (i+1)*self.class_num] = temp 137 | 138 | self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1) 139 | if self.gpu_mode: 140 | self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda() 141 | 142 | def train(self): 143 | self.train_hist = {} 144 | self.train_hist['D_loss'] = [] 145 | self.train_hist['G_loss'] = [] 146 | self.train_hist['per_epoch_time'] = [] 147 | self.train_hist['total_time'] = [] 148 | 149 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 150 | if self.gpu_mode: 151 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 152 | 153 | self.D.train() 154 | print('training start!!') 155 | start_time = time.time() 156 | for epoch in range(self.epoch): 157 | self.G.train() 158 | epoch_start_time = time.time() 159 | for iter, (x_, y_) in enumerate(self.data_loader): 160 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 161 | 
break 162 | z_ = torch.rand((self.batch_size, self.z_dim)) 163 | y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1) 164 | if self.gpu_mode: 165 | x_, z_, y_vec_ = x_.cuda(), z_.cuda(), y_vec_.cuda() 166 | 167 | # update D network 168 | self.D_optimizer.zero_grad() 169 | 170 | D_real, C_real = self.D(x_) 171 | D_real_loss = self.BCE_loss(D_real, self.y_real_) 172 | C_real_loss = self.CE_loss(C_real, torch.max(y_vec_, 1)[1]) 173 | 174 | G_ = self.G(z_, y_vec_) 175 | D_fake, C_fake = self.D(G_) 176 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) 177 | C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1]) 178 | 179 | D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss 180 | self.train_hist['D_loss'].append(D_loss.item()) 181 | 182 | D_loss.backward() 183 | self.D_optimizer.step() 184 | 185 | # update G network 186 | self.G_optimizer.zero_grad() 187 | 188 | G_ = self.G(z_, y_vec_) 189 | D_fake, C_fake = self.D(G_) 190 | 191 | G_loss = self.BCE_loss(D_fake, self.y_real_) 192 | C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1]) 193 | 194 | G_loss += C_fake_loss 195 | self.train_hist['G_loss'].append(G_loss.item()) 196 | 197 | G_loss.backward() 198 | self.G_optimizer.step() 199 | 200 | if ((iter + 1) % 100) == 0: 201 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 202 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 203 | 204 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 205 | with torch.no_grad(): 206 | self.visualize_results((epoch+1)) 207 | 208 | self.train_hist['total_time'].append(time.time() - start_time) 209 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 210 | self.epoch, self.train_hist['total_time'][0])) 211 | print("Training finish!... 
save training results") 212 | 213 | self.save() 214 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 215 | self.epoch) 216 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 217 | 218 | def visualize_results(self, epoch, fix=True): 219 | self.G.eval() 220 | 221 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 222 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 223 | 224 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num))) 225 | 226 | if fix: 227 | """ fixed noise """ 228 | samples = self.G(self.sample_z_, self.sample_y_) 229 | else: 230 | """ random noise """ 231 | sample_y_ = torch.zeros(self.batch_size, self.class_num).scatter_(1, torch.randint(0, self.class_num - 1, (self.batch_size, 1)).type(torch.LongTensor), 1) 232 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 233 | if self.gpu_mode: 234 | sample_z_, sample_y_ = sample_z_.cuda(), sample_y_.cuda() 235 | 236 | samples = self.G(sample_z_, sample_y_) 237 | 238 | if self.gpu_mode: 239 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 240 | else: 241 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 242 | 243 | samples = (samples + 1) / 2 244 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 245 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 246 | 247 | def save(self): 248 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 249 | 250 | if not os.path.exists(save_dir): 251 | os.makedirs(save_dir) 252 | 253 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 254 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 255 | 256 | with open(os.path.join(save_dir, self.model_name + 
'_history.pkl'), 'wb') as f: 257 | pickle.dump(self.train_hist, f) 258 | 259 | def load(self): 260 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 261 | 262 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 263 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) 264 | -------------------------------------------------------------------------------- /BEGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from dataloader import dataloader 6 | 7 | class generator(nn.Module): 8 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 10 | def __init__(self, input_dim=100, output_dim=1, input_size=32): 11 | super(generator, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.input_size = input_size 15 | 16 | self.fc = nn.Sequential( 17 | nn.Linear(self.input_dim, 1024), 18 | nn.BatchNorm1d(1024), 19 | nn.ReLU(), 20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.ReLU(), 23 | ) 24 | self.deconv = nn.Sequential( 25 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU(), 28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 29 | nn.Tanh(), 30 | ) 31 | utils.initialize_weights(self) 32 | 33 | def forward(self, input): 34 | x = self.fc(input) 35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 36 | x = self.deconv(x) 37 | 38 | return x 39 | 40 | class discriminator(nn.Module): 41 | # It must be Auto-Encoder style architecture 42 | # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S 43 | def __init__(self, input_dim=1, 
output_dim=1, input_size=32): 44 | super(discriminator, self).__init__() 45 | self.input_dim = input_dim 46 | self.output_dim = output_dim 47 | self.input_size = input_size 48 | 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 51 | nn.ReLU(), 52 | ) 53 | self.fc = nn.Sequential( 54 | nn.Linear(64 * (self.input_size // 2) * (self.input_size // 2), 32), 55 | nn.Linear(32, 64 * (self.input_size // 2) * (self.input_size // 2)), 56 | ) 57 | self.deconv = nn.Sequential( 58 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 59 | #nn.Sigmoid(), 60 | ) 61 | 62 | utils.initialize_weights(self) 63 | 64 | def forward(self, input): 65 | x = self.conv(input) 66 | x = x.view(x.size()[0], -1) 67 | x = self.fc(x) 68 | x = x.view(-1, 64, (self.input_size // 2), (self.input_size // 2)) 69 | x = self.deconv(x) 70 | 71 | return x 72 | 73 | class BEGAN(object): 74 | def __init__(self, args): 75 | # parameters 76 | self.epoch = args.epoch 77 | self.sample_num = 100 78 | self.batch_size = args.batch_size 79 | self.save_dir = args.save_dir 80 | self.result_dir = args.result_dir 81 | self.dataset = args.dataset 82 | self.log_dir = args.log_dir 83 | self.gpu_mode = args.gpu_mode 84 | self.model_name = args.gan_type 85 | self.input_size = args.input_size 86 | self.z_dim = 62 87 | self.gamma = 1 88 | self.lambda_ = 0.001 89 | self.k = 0.0 90 | self.lr_lower_boundary = 0.00002 91 | 92 | # load dataset 93 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 94 | data = self.data_loader.__iter__().__next__()[0] 95 | 96 | # networks init 97 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 98 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 99 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(args.beta1, args.beta2)) 100 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(args.beta1, args.beta2)) 101 | 102 | if 
self.gpu_mode: 103 | self.G.cuda() 104 | self.D.cuda() 105 | 106 | print('---------- Networks architecture -------------') 107 | utils.print_network(self.G) 108 | utils.print_network(self.D) 109 | print('-----------------------------------------------') 110 | 111 | # fixed noise 112 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 113 | if self.gpu_mode: 114 | self.sample_z_ = self.sample_z_.cuda() 115 | 116 | def train(self): 117 | self.train_hist = {} 118 | self.train_hist['D_loss'] = [] 119 | self.train_hist['G_loss'] = [] 120 | self.train_hist['per_epoch_time'] = [] 121 | self.train_hist['total_time'] = [] 122 | self.M = {} 123 | self.M['pre'] = [] 124 | self.M['pre'].append(1) 125 | self.M['cur'] = [] 126 | 127 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 128 | if self.gpu_mode: 129 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 130 | 131 | self.D.train() 132 | print('training start!!') 133 | start_time = time.time() 134 | for epoch in range(self.epoch): 135 | self.G.train() 136 | epoch_start_time = time.time() 137 | for iter, (x_, _) in enumerate(self.data_loader): 138 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 139 | break 140 | 141 | z_ = torch.rand((self.batch_size, self.z_dim)) 142 | 143 | if self.gpu_mode: 144 | x_, z_ = x_.cuda(), z_.cuda() 145 | 146 | # update D network 147 | self.D_optimizer.zero_grad() 148 | 149 | D_real = self.D(x_) 150 | D_real_loss = torch.mean(torch.abs(D_real - x_)) 151 | 152 | G_ = self.G(z_) 153 | D_fake = self.D(G_) 154 | D_fake_loss = torch.mean(torch.abs(D_fake - G_)) 155 | 156 | D_loss = D_real_loss - self.k * D_fake_loss 157 | self.train_hist['D_loss'].append(D_loss.item()) 158 | 159 | D_loss.backward() 160 | self.D_optimizer.step() 161 | 162 | # update G network 163 | self.G_optimizer.zero_grad() 164 | 165 | G_ = self.G(z_) 166 | D_fake = self.D(G_) 167 | D_fake_loss = torch.mean(torch.abs(D_fake - G_)) 168 | 
169 | G_loss = D_fake_loss 170 | self.train_hist['G_loss'].append(G_loss.item()) 171 | 172 | G_loss.backward() 173 | self.G_optimizer.step() 174 | 175 | # convergence metric 176 | temp_M = D_real_loss + torch.abs(self.gamma * D_real_loss - G_loss) 177 | 178 | # operation for updating k 179 | temp_k = self.k + self.lambda_ * (self.gamma * D_real_loss - G_loss) 180 | temp_k = temp_k.item() 181 | 182 | self.k = min(max(temp_k, 0), 1) 183 | self.M['cur'] = temp_M.item() 184 | 185 | if ((iter + 1) % 100) == 0: 186 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" % 187 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item(), self.M['cur'], self.k)) 188 | 189 | 190 | # if epoch == 0: 191 | # self.M['pre'] = self.M['cur'] 192 | # self.M['cur'] = [] 193 | # else: 194 | if np.mean(self.M['pre']) < np.mean(self.M['cur']): 195 | pre_lr = self.G_optimizer.param_groups[0]['lr'] 196 | self.G_optimizer.param_groups[0]['lr'] = max(self.G_optimizer.param_groups[0]['lr'] / 2.0, 197 | self.lr_lower_boundary) 198 | self.D_optimizer.param_groups[0]['lr'] = max(self.D_optimizer.param_groups[0]['lr'] / 2.0, 199 | self.lr_lower_boundary) 200 | print('M_pre: ' + str(np.mean(self.M['pre'])) + ', M_cur: ' + str( 201 | np.mean(self.M['cur'])) + ', lr: ' + str(pre_lr) + ' --> ' + str( 202 | self.G_optimizer.param_groups[0]['lr'])) 203 | else: 204 | print('M_pre: ' + str(np.mean(self.M['pre'])) + ', M_cur: ' + str(np.mean(self.M['cur']))) 205 | self.M['pre'] = self.M['cur'] 206 | 207 | self.M['cur'] = [] 208 | 209 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 210 | with torch.no_grad(): 211 | self.visualize_results((epoch+1)) 212 | 213 | self.train_hist['total_time'].append(time.time() - start_time) 214 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 215 | self.epoch, self.train_hist['total_time'][0])) 216 | 
print("Training finish!... save training results") 217 | 218 | self.save() 219 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 220 | self.epoch) 221 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 222 | 223 | def visualize_results(self, epoch, fix=True): 224 | self.G.eval() 225 | 226 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 227 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 228 | 229 | tot_num_samples = min(self.sample_num, self.batch_size) 230 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 231 | 232 | if fix: 233 | """ fixed noise """ 234 | samples = self.G(self.sample_z_) 235 | else: 236 | """ random noise """ 237 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 238 | if self.gpu_mode: 239 | sample_z_ = sample_z_.cuda() 240 | 241 | samples = self.G(sample_z_) 242 | 243 | if self.gpu_mode: 244 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 245 | else: 246 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 247 | 248 | samples = (samples + 1) / 2 249 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 250 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 251 | 252 | def save(self): 253 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 254 | 255 | if not os.path.exists(save_dir): 256 | os.makedirs(save_dir) 257 | 258 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 259 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 260 | 261 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 262 | pickle.dump(self.train_hist, f) 263 | 264 | def load(self): 265 | save_dir = 
os.path.join(self.save_dir, self.dataset, self.model_name) 266 | 267 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 268 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /CGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from dataloader import dataloader 6 | 7 | class generator(nn.Module): 8 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 10 | def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10): 11 | super(generator, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.input_size = input_size 15 | self.class_num = class_num 16 | 17 | self.fc = nn.Sequential( 18 | nn.Linear(self.input_dim + self.class_num, 1024), 19 | nn.BatchNorm1d(1024), 20 | nn.ReLU(), 21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 23 | nn.ReLU(), 24 | ) 25 | self.deconv = nn.Sequential( 26 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 27 | nn.BatchNorm2d(64), 28 | nn.ReLU(), 29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 30 | nn.Tanh(), 31 | ) 32 | utils.initialize_weights(self) 33 | 34 | def forward(self, input, label): 35 | x = torch.cat([input, label], 1) 36 | x = self.fc(x) 37 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 38 | x = self.deconv(x) 39 | 40 | return x 41 | 42 | class discriminator(nn.Module): 43 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 44 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 45 | def 
__init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10): 46 | super(discriminator, self).__init__() 47 | self.input_dim = input_dim 48 | self.output_dim = output_dim 49 | self.input_size = input_size 50 | self.class_num = class_num 51 | 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(self.input_dim + self.class_num, 64, 4, 2, 1), 54 | nn.LeakyReLU(0.2), 55 | nn.Conv2d(64, 128, 4, 2, 1), 56 | nn.BatchNorm2d(128), 57 | nn.LeakyReLU(0.2), 58 | ) 59 | self.fc = nn.Sequential( 60 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 61 | nn.BatchNorm1d(1024), 62 | nn.LeakyReLU(0.2), 63 | nn.Linear(1024, self.output_dim), 64 | nn.Sigmoid(), 65 | ) 66 | utils.initialize_weights(self) 67 | 68 | def forward(self, input, label): 69 | x = torch.cat([input, label], 1) 70 | x = self.conv(x) 71 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 72 | x = self.fc(x) 73 | 74 | return x 75 | 76 | class CGAN(object): 77 | def __init__(self, args): 78 | # parameters 79 | self.epoch = args.epoch 80 | self.batch_size = args.batch_size 81 | self.save_dir = args.save_dir 82 | self.result_dir = args.result_dir 83 | self.dataset = args.dataset 84 | self.log_dir = args.log_dir 85 | self.gpu_mode = args.gpu_mode 86 | self.model_name = args.gan_type 87 | self.input_size = args.input_size 88 | self.z_dim = 62 89 | self.class_num = 10 90 | self.sample_num = self.class_num ** 2 91 | 92 | # load dataset 93 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 94 | data = self.data_loader.__iter__().__next__()[0] 95 | 96 | # networks init 97 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, class_num=self.class_num) 98 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, class_num=self.class_num) 99 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 100 | self.D_optimizer = 
optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 101 | 102 | if self.gpu_mode: 103 | self.G.cuda() 104 | self.D.cuda() 105 | self.BCE_loss = nn.BCELoss().cuda() 106 | else: 107 | self.BCE_loss = nn.BCELoss() 108 | 109 | print('---------- Networks architecture -------------') 110 | utils.print_network(self.G) 111 | utils.print_network(self.D) 112 | print('-----------------------------------------------') 113 | 114 | # fixed noise & condition 115 | self.sample_z_ = torch.zeros((self.sample_num, self.z_dim)) 116 | for i in range(self.class_num): 117 | self.sample_z_[i*self.class_num] = torch.rand(1, self.z_dim) 118 | for j in range(1, self.class_num): 119 | self.sample_z_[i*self.class_num + j] = self.sample_z_[i*self.class_num] 120 | 121 | temp = torch.zeros((self.class_num, 1)) 122 | for i in range(self.class_num): 123 | temp[i, 0] = i 124 | 125 | temp_y = torch.zeros((self.sample_num, 1)) 126 | for i in range(self.class_num): 127 | temp_y[i*self.class_num: (i+1)*self.class_num] = temp 128 | 129 | self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1) 130 | if self.gpu_mode: 131 | self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda() 132 | 133 | def train(self): 134 | self.train_hist = {} 135 | self.train_hist['D_loss'] = [] 136 | self.train_hist['G_loss'] = [] 137 | self.train_hist['per_epoch_time'] = [] 138 | self.train_hist['total_time'] = [] 139 | 140 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 141 | if self.gpu_mode: 142 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 143 | 144 | self.D.train() 145 | print('training start!!') 146 | start_time = time.time() 147 | for epoch in range(self.epoch): 148 | self.G.train() 149 | epoch_start_time = time.time() 150 | for iter, (x_, y_) in enumerate(self.data_loader): 151 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 152 | 
break 153 | 154 | z_ = torch.rand((self.batch_size, self.z_dim)) 155 | y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1) 156 | y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size) 157 | if self.gpu_mode: 158 | x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda() 159 | 160 | # update D network 161 | self.D_optimizer.zero_grad() 162 | 163 | D_real = self.D(x_, y_fill_) 164 | D_real_loss = self.BCE_loss(D_real, self.y_real_) 165 | 166 | G_ = self.G(z_, y_vec_) 167 | D_fake = self.D(G_, y_fill_) 168 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) 169 | 170 | D_loss = D_real_loss + D_fake_loss 171 | self.train_hist['D_loss'].append(D_loss.item()) 172 | 173 | D_loss.backward() 174 | self.D_optimizer.step() 175 | 176 | # update G network 177 | self.G_optimizer.zero_grad() 178 | 179 | G_ = self.G(z_, y_vec_) 180 | D_fake = self.D(G_, y_fill_) 181 | G_loss = self.BCE_loss(D_fake, self.y_real_) 182 | self.train_hist['G_loss'].append(G_loss.item()) 183 | 184 | G_loss.backward() 185 | self.G_optimizer.step() 186 | 187 | if ((iter + 1) % 100) == 0: 188 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 189 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 190 | 191 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 192 | with torch.no_grad(): 193 | self.visualize_results((epoch+1)) 194 | 195 | self.train_hist['total_time'].append(time.time() - start_time) 196 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 197 | self.epoch, self.train_hist['total_time'][0])) 198 | print("Training finish!... 
save training results") 199 | 200 | self.save() 201 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 202 | self.epoch) 203 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 204 | 205 | def visualize_results(self, epoch, fix=True): 206 | self.G.eval() 207 | 208 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 209 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 210 | 211 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num))) 212 | 213 | if fix: 214 | """ fixed noise """ 215 | samples = self.G(self.sample_z_, self.sample_y_) 216 | else: 217 | """ random noise """ 218 | sample_y_ = torch.zeros(self.batch_size, self.class_num).scatter_(1, torch.randint(0, self.class_num - 1, (self.batch_size, 1)).type(torch.LongTensor), 1) 219 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 220 | if self.gpu_mode: 221 | sample_z_, sample_y_ = sample_z_.cuda(), sample_y_.cuda() 222 | 223 | samples = self.G(sample_z_, sample_y_) 224 | 225 | if self.gpu_mode: 226 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 227 | else: 228 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 229 | 230 | samples = (samples + 1) / 2 231 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 232 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 233 | 234 | def save(self): 235 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 236 | 237 | if not os.path.exists(save_dir): 238 | os.makedirs(save_dir) 239 | 240 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 241 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 242 | 243 | with open(os.path.join(save_dir, self.model_name + 
'_history.pkl'), 'wb') as f: 244 | pickle.dump(self.train_hist, f) 245 | 246 | def load(self): 247 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 248 | 249 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 250 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /DRAGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import grad 6 | from dataloader import dataloader 7 | 8 | class generator(nn.Module): 9 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 10 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 11 | def __init__(self, input_dim=100, output_dim=1, input_size=32): 12 | super(generator, self).__init__() 13 | self.input_dim = input_dim 14 | self.output_dim = output_dim 15 | self.input_size = input_size 16 | 17 | self.fc = nn.Sequential( 18 | nn.Linear(self.input_dim, 1024), 19 | nn.BatchNorm1d(1024), 20 | nn.ReLU(), 21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 23 | nn.ReLU(), 24 | ) 25 | self.deconv = nn.Sequential( 26 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 27 | nn.BatchNorm2d(64), 28 | nn.ReLU(), 29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 30 | nn.Tanh(), 31 | ) 32 | utils.initialize_weights(self) 33 | 34 | def forward(self, input): 35 | x = self.fc(input) 36 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 37 | x = self.deconv(x) 38 | 39 | return x 40 | 41 | class discriminator(nn.Module): 42 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 43 | # Architecture : 
(64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 44 | def __init__(self, input_dim=1, output_dim=1, input_size=32): 45 | super(discriminator, self).__init__() 46 | self.input_dim = input_dim 47 | self.output_dim = output_dim 48 | self.input_size = input_size 49 | 50 | self.conv = nn.Sequential( 51 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 52 | nn.LeakyReLU(0.2), 53 | nn.Conv2d(64, 128, 4, 2, 1), 54 | nn.BatchNorm2d(128), 55 | nn.LeakyReLU(0.2), 56 | ) 57 | self.fc = nn.Sequential( 58 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 59 | nn.BatchNorm1d(1024), 60 | nn.LeakyReLU(0.2), 61 | nn.Linear(1024, self.output_dim), 62 | nn.Sigmoid(), 63 | ) 64 | utils.initialize_weights(self) 65 | 66 | def forward(self, input): 67 | x = self.conv(input) 68 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 69 | x = self.fc(x) 70 | 71 | return x 72 | 73 | class DRAGAN(object): 74 | def __init__(self, args): 75 | # parameters 76 | self.epoch = args.epoch 77 | self.sample_num = 100 78 | self.batch_size = args.batch_size 79 | self.save_dir = args.save_dir 80 | self.result_dir = args.result_dir 81 | self.dataset = args.dataset 82 | self.log_dir = args.log_dir 83 | self.gpu_mode = args.gpu_mode 84 | self.model_name = args.gan_type 85 | self.input_size = args.input_size 86 | self.z_dim = 62 87 | self.lambda_ = 0.25 88 | 89 | # load dataset 90 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 91 | data = self.data_loader.__iter__().__next__()[0] 92 | 93 | # networks init 94 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 95 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 96 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 97 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 98 | 99 | if self.gpu_mode: 100 | self.G.cuda() 101 | 
self.D.cuda() 102 | self.BCE_loss = nn.BCELoss().cuda() 103 | else: 104 | self.BCE_loss = nn.BCELoss() 105 | 106 | print('---------- Networks architecture -------------') 107 | utils.print_network(self.G) 108 | utils.print_network(self.D) 109 | print('-----------------------------------------------') 110 | 111 | # fixed noise 112 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 113 | if self.gpu_mode: 114 | self.sample_z_ = self.sample_z_.cuda() 115 | 116 | def train(self): 117 | self.train_hist = {} 118 | self.train_hist['D_loss'] = [] 119 | self.train_hist['G_loss'] = [] 120 | self.train_hist['per_epoch_time'] = [] 121 | self.train_hist['total_time'] = [] 122 | 123 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 124 | if self.gpu_mode: 125 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 126 | 127 | self.D.train() 128 | print('training start!!') 129 | start_time = time.time() 130 | for epoch in range(self.epoch): 131 | epoch_start_time = time.time() 132 | self.G.train() 133 | for iter, (x_, _) in enumerate(self.data_loader): 134 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 135 | break 136 | 137 | z_ = torch.rand((self.batch_size, self.z_dim)) 138 | if self.gpu_mode: 139 | x_, z_ = x_.cuda(), z_.cuda() 140 | 141 | # update D network 142 | self.D_optimizer.zero_grad() 143 | 144 | D_real = self.D(x_) 145 | D_real_loss = self.BCE_loss(D_real, self.y_real_) 146 | 147 | G_ = self.G(z_) 148 | D_fake = self.D(G_) 149 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) 150 | 151 | """ DRAGAN Loss (Gradient penalty) """ 152 | # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb 153 | alpha = torch.rand(self.batch_size, 1, 1, 1).cuda() 154 | if self.gpu_mode: 155 | alpha = alpha.cuda() 156 | x_p = x_ + 0.5 * x_.std() * torch.rand(x_.size()).cuda() 157 | else: 158 | x_p = x_ + 0.5 * x_.std() * torch.rand(x_.size()) 159 | differences = 
x_p - x_ 160 | interpolates = x_ + (alpha * differences) 161 | interpolates.requires_grad = True 162 | pred_hat = self.D(interpolates) 163 | if self.gpu_mode: 164 | gradients = grad(outputs=pred_hat, inputs=interpolates, grad_outputs=torch.ones(pred_hat.size()).cuda(), 165 | create_graph=True, retain_graph=True, only_inputs=True)[0] 166 | else: 167 | gradients = grad(outputs=pred_hat, inputs=interpolates, grad_outputs=torch.ones(pred_hat.size()), 168 | create_graph=True, retain_graph=True, only_inputs=True)[0] 169 | 170 | gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() 171 | 172 | D_loss = D_real_loss + D_fake_loss + gradient_penalty 173 | self.train_hist['D_loss'].append(D_loss.item()) 174 | D_loss.backward() 175 | self.D_optimizer.step() 176 | 177 | # update G network 178 | self.G_optimizer.zero_grad() 179 | 180 | G_ = self.G(z_) 181 | D_fake = self.D(G_) 182 | 183 | G_loss = self.BCE_loss(D_fake, self.y_real_) 184 | self.train_hist['G_loss'].append(G_loss.item()) 185 | 186 | G_loss.backward() 187 | self.G_optimizer.step() 188 | 189 | if ((iter + 1) % 100) == 0: 190 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 191 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 192 | 193 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 194 | with torch.no_grad(): 195 | self.visualize_results((epoch+1)) 196 | 197 | self.train_hist['total_time'].append(time.time() - start_time) 198 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 199 | self.epoch, self.train_hist['total_time'][0])) 200 | print("Training finish!... 
save training results") 201 | 202 | self.save() 203 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, self.epoch) 204 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 205 | 206 | def visualize_results(self, epoch, fix=True): 207 | self.G.eval() 208 | 209 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 210 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 211 | 212 | tot_num_samples = min(self.sample_num, self.batch_size) 213 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 214 | 215 | if fix: 216 | """ fixed noise """ 217 | samples = self.G(self.sample_z_) 218 | else: 219 | """ random noise """ 220 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 221 | if self.gpu_mode: 222 | sample_z_ = sample_z_.cuda() 223 | 224 | samples = self.G(sample_z_) 225 | 226 | if self.gpu_mode: 227 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 228 | else: 229 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 230 | 231 | samples = (samples + 1) / 2 232 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 233 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 234 | 235 | def save(self): 236 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 237 | 238 | if not os.path.exists(save_dir): 239 | os.makedirs(save_dir) 240 | 241 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 242 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 243 | 244 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 245 | pickle.dump(self.train_hist, f) 246 | 247 | def load(self): 248 | save_dir = os.path.join(self.save_dir, self.dataset, 
self.model_name) 249 | 250 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 251 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /EBGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | from torch.utils.data import DataLoader 7 | from torchvision import datasets, transforms 8 | from dataloader import dataloader 9 | 10 | class generator(nn.Module): 11 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 12 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 13 | def __init__(self, input_dim=100, output_dim=1, input_size=32): 14 | super(generator, self).__init__() 15 | self.input_dim = input_dim 16 | self.output_dim = output_dim 17 | self.input_size = input_size 18 | 19 | self.fc = nn.Sequential( 20 | nn.Linear(self.input_dim, 1024), 21 | nn.BatchNorm1d(1024), 22 | nn.ReLU(), 23 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 24 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 25 | nn.ReLU(), 26 | ) 27 | self.deconv = nn.Sequential( 28 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 29 | nn.BatchNorm2d(64), 30 | nn.ReLU(), 31 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 32 | nn.Tanh(), 33 | ) 34 | utils.initialize_weights(self) 35 | 36 | def forward(self, input): 37 | x = self.fc(input) 38 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 39 | x = self.deconv(x) 40 | 41 | return x 42 | 43 | class discriminator(nn.Module): 44 | # It must be Auto-Encoder style architecture 45 | # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S 46 | def __init__(self, input_dim=1, output_dim=1, 
input_size=32): 47 | super(discriminator, self).__init__() 48 | self.input_dim = input_dim 49 | self.output_dim = output_dim 50 | self.input_size = input_size 51 | 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 54 | nn.ReLU(), 55 | ) 56 | self.code = nn.Sequential( 57 | nn.Linear(64 * (self.input_size // 2) * (self.input_size // 2), 32), # bn and relu are excluded since code is used in pullaway_loss 58 | ) 59 | self.fc = nn.Sequential( 60 | nn.Linear(32, 64 * (self.input_size // 2) * (self.input_size // 2)), 61 | nn.BatchNorm1d(64 * (self.input_size // 2) * (self.input_size // 2)), 62 | nn.ReLU(), 63 | ) 64 | self.deconv = nn.Sequential( 65 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 66 | # nn.Sigmoid(), 67 | ) 68 | utils.initialize_weights(self) 69 | 70 | def forward(self, input): 71 | x = self.conv(input) 72 | x = x.view(x.size()[0], -1) 73 | code = self.code(x) 74 | x = self.fc(code) 75 | x = x.view(-1, 64, (self.input_size // 2), (self.input_size // 2)) 76 | x = self.deconv(x) 77 | 78 | return x, code 79 | 80 | class EBGAN(object): 81 | def __init__(self, args): 82 | # parameters 83 | self.epoch = args.epoch 84 | self.sample_num = 100 85 | self.batch_size = args.batch_size 86 | self.save_dir = args.save_dir 87 | self.result_dir = args.result_dir 88 | self.dataset = args.dataset 89 | self.log_dir = args.log_dir 90 | self.gpu_mode = args.gpu_mode 91 | self.model_name = args.gan_type 92 | self.input_size = args.input_size 93 | self.z_dim = 62 94 | self.pt_loss_weight = 0.1 95 | self.margin = 1 96 | 97 | # load dataset 98 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 99 | data = self.data_loader.__iter__().__next__()[0] 100 | 101 | # networks init 102 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 103 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 104 | self.G_optimizer = 
optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 105 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 106 | 107 | if self.gpu_mode: 108 | self.G.cuda() 109 | self.D.cuda() 110 | self.MSE_loss = nn.MSELoss().cuda() 111 | else: 112 | self.MSE_loss = nn.MSELoss() 113 | 114 | print('---------- Networks architecture -------------') 115 | utils.print_network(self.G) 116 | utils.print_network(self.D) 117 | print('-----------------------------------------------') 118 | 119 | # fixed noise 120 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 121 | if self.gpu_mode: 122 | self.sample_z_ = self.sample_z_.cuda() 123 | 124 | def train(self): 125 | self.train_hist = {} 126 | self.train_hist['D_loss'] = [] 127 | self.train_hist['G_loss'] = [] 128 | self.train_hist['per_epoch_time'] = [] 129 | self.train_hist['total_time'] = [] 130 | 131 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 132 | if self.gpu_mode: 133 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 134 | 135 | self.D.train() 136 | print('training start!!') 137 | start_time = time.time() 138 | for epoch in range(self.epoch): 139 | self.G.train() 140 | epoch_start_time = time.time() 141 | for iter, (x_, _) in enumerate(self.data_loader): 142 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 143 | break 144 | 145 | z_ = torch.rand((self.batch_size, self.z_dim)) 146 | if self.gpu_mode: 147 | x_, z_ = x_.cuda(), z_.cuda() 148 | 149 | # update D network 150 | self.D_optimizer.zero_grad() 151 | 152 | D_real, _ = self.D(x_) 153 | D_real_loss = self.MSE_loss(D_real, x_) 154 | 155 | G_ = self.G(z_) 156 | D_fake, _ = self.D(G_) 157 | D_fake_loss = self.MSE_loss(D_fake, G_.detach()) 158 | 159 | D_loss = D_real_loss + torch.clamp(self.margin - D_fake_loss, min=0) 160 | self.train_hist['D_loss'].append(D_loss.item()) 161 | 162 | D_loss.backward() 163 | 
self.D_optimizer.step() 164 | 165 | # update G network 166 | self.G_optimizer.zero_grad() 167 | 168 | G_ = self.G(z_) 169 | D_fake, D_fake_code = self.D(G_) 170 | D_fake_loss = self.MSE_loss(D_fake, G_.detach()) 171 | G_loss = D_fake_loss + self.pt_loss_weight * self.pullaway_loss(D_fake_code.view(self.batch_size, -1)) 172 | self.train_hist['G_loss'].append(G_loss.item()) 173 | 174 | G_loss.backward() 175 | self.G_optimizer.step() 176 | 177 | if ((iter + 1) % 100) == 0: 178 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 179 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 180 | 181 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 182 | with torch.no_grad(): 183 | self.visualize_results((epoch+1)) 184 | 185 | self.train_hist['total_time'].append(time.time() - start_time) 186 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 187 | self.epoch, self.train_hist['total_time'][0])) 188 | print("Training finish!... 
save training results") 189 | 190 | self.save() 191 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 192 | self.epoch) 193 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 194 | 195 | def pullaway_loss(self, embeddings): 196 | """ pullaway_loss tensorflow version code 197 | 198 | norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True)) 199 | normalized_embeddings = embeddings / norm 200 | similarity = tf.matmul( 201 | normalized_embeddings, normalized_embeddings, transpose_b=True) 202 | batch_size = tf.cast(tf.shape(embeddings)[0], tf.float32) 203 | pt_loss = (tf.reduce_sum(similarity) - batch_size) / (batch_size * (batch_size - 1)) 204 | return pt_loss 205 | 206 | """ 207 | # norm = torch.sqrt(torch.sum(embeddings ** 2, 1, keepdim=True)) 208 | # normalized_embeddings = embeddings / norm 209 | # similarity = torch.matmul(normalized_embeddings, normalized_embeddings.transpose(1, 0)) 210 | # batch_size = embeddings.size()[0] 211 | # pt_loss = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1)) 212 | 213 | norm = torch.norm(embeddings, 1) 214 | normalized_embeddings = embeddings / norm 215 | similarity = torch.matmul(normalized_embeddings, normalized_embeddings.transpose(1, 0)) ** 2 216 | batch_size = embeddings.size()[0] 217 | pt_loss = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1)) 218 | 219 | return pt_loss 220 | 221 | 222 | def visualize_results(self, epoch, fix=True): 223 | self.G.eval() 224 | 225 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 226 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 227 | 228 | tot_num_samples = min(self.sample_num, self.batch_size) 229 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 230 | 231 | if fix: 232 | """ fixed noise """ 233 | samples = self.G(self.sample_z_) 234 | 
else: 235 | """ random noise """ 236 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 237 | if self.gpu_mode: 238 | sample_z_ = sample_z_.cuda() 239 | 240 | samples = self.G(sample_z_) 241 | 242 | if self.gpu_mode: 243 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 244 | else: 245 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 246 | 247 | samples = (samples + 1) / 2 248 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 249 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 250 | 251 | def save(self): 252 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 253 | 254 | if not os.path.exists(save_dir): 255 | os.makedirs(save_dir) 256 | 257 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 258 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 259 | 260 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 261 | pickle.dump(self.train_hist, f) 262 | 263 | def load(self): 264 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 265 | 266 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 267 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /GAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from dataloader import dataloader 6 | 7 | class generator(nn.Module): 8 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 10 | def __init__(self, input_dim=100, 
output_dim=1, input_size=32): 11 | super(generator, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.input_size = input_size 15 | 16 | self.fc = nn.Sequential( 17 | nn.Linear(self.input_dim, 1024), 18 | nn.BatchNorm1d(1024), 19 | nn.ReLU(), 20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.ReLU(), 23 | ) 24 | self.deconv = nn.Sequential( 25 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU(), 28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 29 | nn.Tanh(), 30 | ) 31 | utils.initialize_weights(self) 32 | 33 | def forward(self, input): 34 | x = self.fc(input) 35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 36 | x = self.deconv(x) 37 | 38 | return x 39 | 40 | class discriminator(nn.Module): 41 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 42 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 43 | def __init__(self, input_dim=1, output_dim=1, input_size=32): 44 | super(discriminator, self).__init__() 45 | self.input_dim = input_dim 46 | self.output_dim = output_dim 47 | self.input_size = input_size 48 | 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 51 | nn.LeakyReLU(0.2), 52 | nn.Conv2d(64, 128, 4, 2, 1), 53 | nn.BatchNorm2d(128), 54 | nn.LeakyReLU(0.2), 55 | ) 56 | self.fc = nn.Sequential( 57 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 58 | nn.BatchNorm1d(1024), 59 | nn.LeakyReLU(0.2), 60 | nn.Linear(1024, self.output_dim), 61 | nn.Sigmoid(), 62 | ) 63 | utils.initialize_weights(self) 64 | 65 | def forward(self, input): 66 | x = self.conv(input) 67 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 68 | x = self.fc(x) 69 | 70 | return x 71 | 72 | class GAN(object): 73 | def __init__(self, args): 74 | # parameters 75 | self.epoch 
= args.epoch 76 | self.sample_num = 100 77 | self.batch_size = args.batch_size 78 | self.save_dir = args.save_dir 79 | self.result_dir = args.result_dir 80 | self.dataset = args.dataset 81 | self.log_dir = args.log_dir 82 | self.gpu_mode = args.gpu_mode 83 | self.model_name = args.gan_type 84 | self.input_size = args.input_size 85 | self.z_dim = 62 86 | 87 | # load dataset 88 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 89 | data = self.data_loader.__iter__().__next__()[0] 90 | 91 | # networks init 92 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 93 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 94 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 95 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 96 | 97 | if self.gpu_mode: 98 | self.G.cuda() 99 | self.D.cuda() 100 | self.BCE_loss = nn.BCELoss().cuda() 101 | else: 102 | self.BCE_loss = nn.BCELoss() 103 | 104 | print('---------- Networks architecture -------------') 105 | utils.print_network(self.G) 106 | utils.print_network(self.D) 107 | print('-----------------------------------------------') 108 | 109 | 110 | # fixed noise 111 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 112 | if self.gpu_mode: 113 | self.sample_z_ = self.sample_z_.cuda() 114 | 115 | 116 | def train(self): 117 | self.train_hist = {} 118 | self.train_hist['D_loss'] = [] 119 | self.train_hist['G_loss'] = [] 120 | self.train_hist['per_epoch_time'] = [] 121 | self.train_hist['total_time'] = [] 122 | 123 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 124 | if self.gpu_mode: 125 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 126 | 127 | self.D.train() 128 | print('training start!!') 129 | start_time = time.time() 130 | for epoch in 
range(self.epoch): 131 | self.G.train() 132 | epoch_start_time = time.time() 133 | for iter, (x_, _) in enumerate(self.data_loader): 134 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 135 | break 136 | 137 | z_ = torch.rand((self.batch_size, self.z_dim)) 138 | if self.gpu_mode: 139 | x_, z_ = x_.cuda(), z_.cuda() 140 | 141 | # update D network 142 | self.D_optimizer.zero_grad() 143 | 144 | D_real = self.D(x_) 145 | D_real_loss = self.BCE_loss(D_real, self.y_real_) 146 | 147 | G_ = self.G(z_) 148 | D_fake = self.D(G_) 149 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) 150 | 151 | D_loss = D_real_loss + D_fake_loss 152 | self.train_hist['D_loss'].append(D_loss.item()) 153 | 154 | D_loss.backward() 155 | self.D_optimizer.step() 156 | 157 | # update G network 158 | self.G_optimizer.zero_grad() 159 | 160 | G_ = self.G(z_) 161 | D_fake = self.D(G_) 162 | G_loss = self.BCE_loss(D_fake, self.y_real_) 163 | self.train_hist['G_loss'].append(G_loss.item()) 164 | 165 | G_loss.backward() 166 | self.G_optimizer.step() 167 | 168 | if ((iter + 1) % 100) == 0: 169 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 170 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 171 | 172 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 173 | with torch.no_grad(): 174 | self.visualize_results((epoch+1)) 175 | 176 | self.train_hist['total_time'].append(time.time() - start_time) 177 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 178 | self.epoch, self.train_hist['total_time'][0])) 179 | print("Training finish!... 
save training results") 180 | 181 | self.save() 182 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 183 | self.epoch) 184 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 185 | 186 | def visualize_results(self, epoch, fix=True): 187 | self.G.eval() 188 | 189 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 190 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 191 | 192 | tot_num_samples = min(self.sample_num, self.batch_size) 193 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 194 | 195 | if fix: 196 | """ fixed noise """ 197 | samples = self.G(self.sample_z_) 198 | else: 199 | """ random noise """ 200 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 201 | if self.gpu_mode: 202 | sample_z_ = sample_z_.cuda() 203 | 204 | samples = self.G(sample_z_) 205 | 206 | if self.gpu_mode: 207 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 208 | else: 209 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 210 | 211 | samples = (samples + 1) / 2 212 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 213 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 214 | 215 | def save(self): 216 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 217 | 218 | if not os.path.exists(save_dir): 219 | os.makedirs(save_dir) 220 | 221 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 222 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 223 | 224 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 225 | pickle.dump(self.train_hist, f) 226 | 227 | def load(self): 228 | save_dir = os.path.join(self.save_dir, self.dataset, 
self.model_name) 229 | 230 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 231 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /LSGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from dataloader import dataloader 6 | 7 | class generator(nn.Module): 8 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 10 | def __init__(self, input_dim=100, output_dim=1, input_size=32): 11 | super(generator, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.input_size = input_size 15 | 16 | self.fc = nn.Sequential( 17 | nn.Linear(self.input_dim, 1024), 18 | nn.BatchNorm1d(1024), 19 | nn.ReLU(), 20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.ReLU(), 23 | ) 24 | self.deconv = nn.Sequential( 25 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU(), 28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 29 | nn.Tanh(), 30 | ) 31 | utils.initialize_weights(self) 32 | 33 | def forward(self, input): 34 | x = self.fc(input) 35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 36 | x = self.deconv(x) 37 | 38 | return x 39 | 40 | class discriminator(nn.Module): 41 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 42 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 43 | def __init__(self, input_dim=1, output_dim=1, input_size=32): 44 | super(discriminator, self).__init__() 45 | self.input_dim = input_dim 46 | 
self.output_dim = output_dim 47 | self.input_size = input_size 48 | 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 51 | nn.LeakyReLU(0.2), 52 | nn.Conv2d(64, 128, 4, 2, 1), 53 | nn.BatchNorm2d(128), 54 | nn.LeakyReLU(0.2), 55 | ) 56 | self.fc = nn.Sequential( 57 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 58 | nn.BatchNorm1d(1024), 59 | nn.LeakyReLU(0.2), 60 | nn.Linear(1024, self.output_dim), 61 | # nn.Sigmoid(), 62 | ) 63 | utils.initialize_weights(self) 64 | 65 | def forward(self, input): 66 | x = self.conv(input) 67 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 68 | x = self.fc(x) 69 | 70 | return x 71 | 72 | class LSGAN(object): 73 | def __init__(self, args): 74 | # parameters 75 | self.epoch = args.epoch 76 | self.sample_num = 100 77 | self.batch_size = args.batch_size 78 | self.save_dir = args.save_dir 79 | self.result_dir = args.result_dir 80 | self.dataset = args.dataset 81 | self.log_dir = args.log_dir 82 | self.gpu_mode = args.gpu_mode 83 | self.model_name = args.gan_type 84 | self.input_size = args.input_size 85 | self.z_dim = 62 86 | 87 | # load dataset 88 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 89 | data = self.data_loader.__iter__().__next__()[0] 90 | 91 | # networks init 92 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 93 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 94 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 95 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 96 | 97 | if self.gpu_mode: 98 | self.G.cuda() 99 | self.D.cuda() 100 | self.MSE_loss = nn.MSELoss().cuda() 101 | else: 102 | self.MSE_loss = nn.MSELoss() 103 | 104 | print('---------- Networks architecture -------------') 105 | utils.print_network(self.G) 106 | 
utils.print_network(self.D) 107 | print('-----------------------------------------------') 108 | 109 | # fixed noise 110 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 111 | if self.gpu_mode: 112 | self.sample_z_ = self.sample_z_.cuda() 113 | 114 | def train(self): 115 | self.train_hist = {} 116 | self.train_hist['D_loss'] = [] 117 | self.train_hist['G_loss'] = [] 118 | self.train_hist['per_epoch_time'] = [] 119 | self.train_hist['total_time'] = [] 120 | 121 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 122 | if self.gpu_mode: 123 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 124 | 125 | self.D.train() 126 | print('training start!!') 127 | start_time = time.time() 128 | for epoch in range(self.epoch): 129 | self.G.train() 130 | epoch_start_time = time.time() 131 | for iter, (x_, _) in enumerate(self.data_loader): 132 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 133 | break 134 | 135 | z_ = torch.rand((self.batch_size, self.z_dim)) 136 | if self.gpu_mode: 137 | x_, z_ = x_.cuda(), z_.cuda() 138 | 139 | # update D network 140 | self.D_optimizer.zero_grad() 141 | 142 | D_real = self.D(x_) 143 | D_real_loss = self.MSE_loss(D_real, self.y_real_) 144 | 145 | G_ = self.G(z_) 146 | D_fake = self.D(G_) 147 | D_fake_loss = self.MSE_loss(D_fake, self.y_fake_) 148 | 149 | D_loss = D_real_loss + D_fake_loss 150 | self.train_hist['D_loss'].append(D_loss.item()) 151 | 152 | D_loss.backward() 153 | self.D_optimizer.step() 154 | 155 | # update G network 156 | self.G_optimizer.zero_grad() 157 | 158 | G_ = self.G(z_) 159 | D_fake = self.D(G_) 160 | G_loss = self.MSE_loss(D_fake, self.y_real_) 161 | self.train_hist['G_loss'].append(G_loss.item()) 162 | 163 | G_loss.backward() 164 | self.G_optimizer.step() 165 | 166 | if ((iter + 1) % 100) == 0: 167 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 168 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() 
// self.batch_size, D_loss.item(), G_loss.item())) 169 | 170 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 171 | with torch.no_grad(): 172 | self.visualize_results((epoch+1)) 173 | 174 | self.train_hist['total_time'].append(time.time() - start_time) 175 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 176 | self.epoch, self.train_hist['total_time'][0])) 177 | print("Training finish!... save training results") 178 | 179 | self.save() 180 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 181 | self.epoch) 182 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 183 | 184 | def visualize_results(self, epoch, fix=True): 185 | self.G.eval() 186 | 187 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 188 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 189 | 190 | tot_num_samples = min(self.sample_num, self.batch_size) 191 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 192 | 193 | if fix: 194 | """ fixed noise """ 195 | samples = self.G(self.sample_z_) 196 | else: 197 | """ random noise """ 198 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 199 | if self.gpu_mode: 200 | sample_z_ = sample_z_.cuda() 201 | 202 | samples = self.G(sample_z_) 203 | 204 | if self.gpu_mode: 205 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 206 | else: 207 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 208 | 209 | samples = (samples + 1) / 2 210 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 211 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 212 | 213 | def save(self): 214 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 215 
| 216 | if not os.path.exists(save_dir): 217 | os.makedirs(save_dir) 218 | 219 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 220 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 221 | 222 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 223 | pickle.dump(self.train_hist, f) 224 | 225 | def load(self): 226 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 227 | 228 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 229 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-generative-model-collections 2 | Original : [[Tensorflow version]](https://github.com/hwalsuklee/tensorflow-generative-model-collections) 3 | 4 | Pytorch implementation of various GANs. 5 | 6 | This repository was re-implemented with reference to [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections) by [Hwalsuk Lee](https://github.com/hwalsuklee) 7 | 8 | I tried to implement this repository as much as possible with [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections), But some models are a little different. 9 | 10 | This repository is included code for CPU mode Pytorch, but i did not test. I tested only in GPU mode Pytorch. 11 | 12 | ## Dataset 13 | 14 | - MNIST 15 | - Fashion-MNIST 16 | - CIFAR10 17 | - SVHN 18 | - STL10 19 | - LSUN-bed 20 | #### I only tested the code on MNIST and Fashion-MNIST. 
21 | 22 | ## Generative Adversarial Networks (GANs) 23 | ### Lists (Table is borrowed from [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections)) 24 | 25 | *Name* | *Paper Link* | *Value Function* 26 | :---: | :---: | :--- | 27 | **GAN** | [Arxiv](https://arxiv.org/abs/1406.2661) | 28 | **LSGAN**| [Arxiv](https://arxiv.org/abs/1611.04076) | 29 | **WGAN**| [Arxiv](https://arxiv.org/abs/1701.07875) | 30 | **WGAN_GP**| [Arxiv](https://arxiv.org/abs/1704.00028) | 31 | **DRAGAN**| [Arxiv](https://arxiv.org/abs/1705.07215) | 32 | **CGAN**| [Arxiv](https://arxiv.org/abs/1411.1784) | 33 | **infoGAN**| [Arxiv](https://arxiv.org/abs/1606.03657) | 34 | **ACGAN**| [Arxiv](https://arxiv.org/abs/1610.09585) | 35 | **EBGAN**| [Arxiv](https://arxiv.org/abs/1609.03126) | 36 | **BEGAN**| [Arxiv](https://arxiv.org/abs/1703.10717) | 37 | 38 | #### Variants of GAN structure (Figures are borrowed from [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections)) 39 | 40 | 41 | ### Results for mnist 42 | The network architecture of the generator and discriminator is exactly the same as in the [infoGAN paper](https://arxiv.org/abs/1606.03657). 43 | For a fair comparison of the core ideas across all GAN variants, the network architecture is kept the same in every implementation except EBGAN and BEGAN. A small modification is made for EBGAN/BEGAN, since they adopt an auto-encoder structure for the discriminator, but I tried to keep the capacity of the discriminator comparable. 44 | 45 | The following results can be reproduced with the command: 46 | ``` 47 | python main.py --dataset mnist --gan_type <TYPE> --epoch 50 --batch_size 64 48 | ``` 49 | 50 | #### Fixed generation 51 | All results are generated from the fixed noise vector.
52 | 53 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF* 54 | :---: | :---: | :---: | :---: | :---: | 55 | GAN | | | | 56 | LSGAN | | | | 57 | WGAN | | | | 58 | WGAN_GP | | | | 59 | DRAGAN | | | | 60 | EBGAN | | | | 61 | BEGAN | | | | 62 | 63 | #### Conditional generation 64 | Each row has the same noise vector and each column has the same label condition. 65 | 66 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF* 67 | :---: | :---: | :---: | :---: | :---: | 68 | CGAN | | | | 69 | ACGAN | | | | 70 | infoGAN | | | | 71 | 72 | #### InfoGAN: Manipulating two continuous codes 73 | All results have the same noise vector and label condition, but different continuous vectors. 74 | 75 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF* 76 | :---: | :---: | :---: | :---: | :---: | 77 | infoGAN | | | | 78 | 79 | #### Loss plot 80 | 81 | *Name* | *Loss* 82 | :---: | :---: | 83 | GAN | 84 | LSGAN | 85 | WGAN | 86 | WGAN_GP | 87 | DRAGAN | 88 | EBGAN | 89 | BEGAN | 90 | CGAN | 91 | ACGAN | 92 | infoGAN | 93 | 94 | ### Results for fashion-mnist 95 | The comments on the network architecture for MNIST also apply here. 96 | [Fashion-mnist](https://github.com/zalandoresearch/fashion-mnist) is a recently proposed dataset consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes (T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot). 97 | 98 | The following results can be reproduced with the command: 99 | ``` 100 | python main.py --dataset fashion-mnist --gan_type <TYPE> --epoch 50 --batch_size 64 101 | ``` 102 | 103 | #### Fixed generation 104 | All results are generated from the fixed noise vector.
105 | 106 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF* 107 | :---: | :---: | :---: | :---: | :---: | 108 | GAN | | | | 109 | LSGAN | | | | 110 | WGAN | | | | 111 | WGAN_GP | | | | 112 | DRAGAN | | | | 113 | EBGAN | | | | 114 | BEGAN | | | | 115 | 116 | #### Conditional generation 117 | Each row has the same noise vector and each column has the same label condition. 118 | 119 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF* 120 | :---: | :---: | :---: | :---: | :---: | 121 | CGAN | | | | 122 | ACGAN | | | | 123 | infoGAN | | | | 124 | 125 | - ACGAN tends to fall into mode collapse in [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections), but the PyTorch ACGAN does not. 126 | 127 | #### InfoGAN: Manipulating two continuous codes 128 | All results have the same noise vector and label condition, but different continuous vectors. 129 | 130 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF* 131 | :---: | :---: | :---: | :---: | :---: | 132 | infoGAN | | | | 133 | 134 | #### Loss plot 135 | 136 | *Name* | *Loss* 137 | :---: | :---: | 138 | GAN | 139 | LSGAN | 140 | WGAN | 141 | WGAN_GP | 142 | DRAGAN | 143 | EBGAN | 144 | BEGAN | 145 | CGAN | 146 | ACGAN | 147 | infoGAN | 148 | 149 | ## Folder structure 150 | The following shows the basic folder structure. 151 | ``` 152 | ├── main.py # gateway 153 | ├── data 154 | │ ├── mnist # mnist data (not included in this repo) 155 | │ ├── ... 156 | │ ├── ...
157 | │ └── fashion-mnist # fashion-mnist data (not included in this repo) 158 | │ 159 | ├── GAN.py # vainilla GAN 160 | ├── utils.py # utils 161 | ├── dataloader.py # dataloader 162 | ├── models # model files to be saved here 163 | └── results # generation results to be saved here 164 | ``` 165 | 166 | ## Development Environment 167 | * Ubuntu 16.04 LTS 168 | * NVIDIA GTX 1080 ti 169 | * cuda 9.0 170 | * Python 3.5.2 171 | * pytorch 0.4.0 172 | * torchvision 0.2.1 173 | * numpy 1.14.3 174 | * matplotlib 2.2.2 175 | * imageio 2.3.0 176 | * scipy 1.1.0 177 | 178 | ## Acknowledgements 179 | This implementation has been based on [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections) and tested with Pytorch 0.4.0 on Ubuntu 16.04 using GPU. 180 | 181 | -------------------------------------------------------------------------------- /WGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from dataloader import dataloader 6 | 7 | class generator(nn.Module): 8 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 10 | def __init__(self, input_dim=100, output_dim=1, input_size=32): 11 | super(generator, self).__init__() 12 | self.input_dim = input_dim 13 | self.output_dim = output_dim 14 | self.input_size = input_size 15 | 16 | self.fc = nn.Sequential( 17 | nn.Linear(self.input_dim, 1024), 18 | nn.BatchNorm1d(1024), 19 | nn.ReLU(), 20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.ReLU(), 23 | ) 24 | self.deconv = nn.Sequential( 25 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU(), 28 | nn.ConvTranspose2d(64, 
self.output_dim, 4, 2, 1), 29 | nn.Tanh(), 30 | ) 31 | utils.initialize_weights(self) 32 | 33 | def forward(self, input): 34 | x = self.fc(input) 35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 36 | x = self.deconv(x) 37 | 38 | return x 39 | 40 | class discriminator(nn.Module): 41 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 42 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 43 | def __init__(self, input_dim=1, output_dim=1, input_size=32): 44 | super(discriminator, self).__init__() 45 | self.input_dim = input_dim 46 | self.output_dim = output_dim 47 | self.input_size = input_size 48 | 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 51 | nn.LeakyReLU(0.2), 52 | nn.Conv2d(64, 128, 4, 2, 1), 53 | nn.BatchNorm2d(128), 54 | nn.LeakyReLU(0.2), 55 | ) 56 | self.fc = nn.Sequential( 57 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 58 | nn.BatchNorm1d(1024), 59 | nn.LeakyReLU(0.2), 60 | nn.Linear(1024, self.output_dim), 61 | # nn.Sigmoid(), 62 | ) 63 | utils.initialize_weights(self) 64 | 65 | def forward(self, input): 66 | x = self.conv(input) 67 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 68 | x = self.fc(x) 69 | 70 | return x 71 | 72 | class WGAN(object): 73 | def __init__(self, args): 74 | # parameters 75 | self.epoch = args.epoch 76 | self.sample_num = 100 77 | self.batch_size = args.batch_size 78 | self.save_dir = args.save_dir 79 | self.result_dir = args.result_dir 80 | self.dataset = args.dataset 81 | self.log_dir = args.log_dir 82 | self.gpu_mode = args.gpu_mode 83 | self.model_name = args.gan_type 84 | self.input_size = args.input_size 85 | self.z_dim = 62 86 | self.c = 0.01 # clipping value 87 | self.n_critic = 5 # the number of iterations of the critic per generator iteration 88 | 89 | # load dataset 90 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 91 | data = 
self.data_loader.__iter__().__next__()[0] 92 | 93 | # networks init 94 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 95 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 96 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 97 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 98 | 99 | if self.gpu_mode: 100 | self.G.cuda() 101 | self.D.cuda() 102 | 103 | print('---------- Networks architecture -------------') 104 | utils.print_network(self.G) 105 | utils.print_network(self.D) 106 | print('-----------------------------------------------') 107 | 108 | # fixed noise 109 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 110 | if self.gpu_mode: 111 | self.sample_z_ = self.sample_z_.cuda() 112 | 113 | def train(self): 114 | self.train_hist = {} 115 | self.train_hist['D_loss'] = [] 116 | self.train_hist['G_loss'] = [] 117 | self.train_hist['per_epoch_time'] = [] 118 | self.train_hist['total_time'] = [] 119 | 120 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 121 | if self.gpu_mode: 122 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 123 | 124 | self.D.train() 125 | print('training start!!') 126 | start_time = time.time() 127 | for epoch in range(self.epoch): 128 | self.G.train() 129 | epoch_start_time = time.time() 130 | for iter, (x_, _) in enumerate(self.data_loader): 131 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 132 | break 133 | 134 | z_ = torch.rand((self.batch_size, self.z_dim)) 135 | if self.gpu_mode: 136 | x_, z_ = x_.cuda(), z_.cuda() 137 | 138 | # update D network 139 | self.D_optimizer.zero_grad() 140 | 141 | D_real = self.D(x_) 142 | D_real_loss = -torch.mean(D_real) 143 | 144 | G_ = self.G(z_) 145 | D_fake = self.D(G_) 146 | D_fake_loss = torch.mean(D_fake) 147 | 148 
| D_loss = D_real_loss + D_fake_loss 149 | 150 | D_loss.backward() 151 | self.D_optimizer.step() 152 | 153 | # clipping D 154 | for p in self.D.parameters(): 155 | p.data.clamp_(-self.c, self.c) 156 | 157 | if ((iter+1) % self.n_critic) == 0: 158 | # update G network 159 | self.G_optimizer.zero_grad() 160 | 161 | G_ = self.G(z_) 162 | D_fake = self.D(G_) 163 | G_loss = -torch.mean(D_fake) 164 | self.train_hist['G_loss'].append(G_loss.item()) 165 | 166 | G_loss.backward() 167 | self.G_optimizer.step() 168 | 169 | self.train_hist['D_loss'].append(D_loss.item()) 170 | 171 | if ((iter + 1) % 100) == 0: 172 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 173 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 174 | 175 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 176 | with torch.no_grad(): 177 | self.visualize_results((epoch+1)) 178 | 179 | self.train_hist['total_time'].append(time.time() - start_time) 180 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 181 | self.epoch, self.train_hist['total_time'][0])) 182 | print("Training finish!... 
save training results") 183 | 184 | self.save() 185 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 186 | self.epoch) 187 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 188 | 189 | def visualize_results(self, epoch, fix=True): 190 | self.G.eval() 191 | 192 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 193 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 194 | 195 | tot_num_samples = min(self.sample_num, self.batch_size) 196 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 197 | 198 | if fix: 199 | """ fixed noise """ 200 | samples = self.G(self.sample_z_) 201 | else: 202 | """ random noise """ 203 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 204 | if self.gpu_mode: 205 | sample_z_ = sample_z_.cuda() 206 | 207 | samples = self.G(sample_z_) 208 | 209 | if self.gpu_mode: 210 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 211 | else: 212 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 213 | 214 | samples = (samples + 1) / 2 215 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 216 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 217 | 218 | def save(self): 219 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 220 | 221 | if not os.path.exists(save_dir): 222 | os.makedirs(save_dir) 223 | 224 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 225 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 226 | 227 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 228 | pickle.dump(self.train_hist, f) 229 | 230 | def load(self): 231 | save_dir = os.path.join(self.save_dir, self.dataset, 
self.model_name) 232 | 233 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 234 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /WGAN_GP.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import grad 6 | from dataloader import dataloader 7 | 8 | class generator(nn.Module): 9 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 10 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 11 | def __init__(self, input_dim=100, output_dim=1, input_size=32): 12 | super(generator, self).__init__() 13 | self.input_dim = input_dim 14 | self.output_dim = output_dim 15 | self.input_size = input_size 16 | 17 | self.fc = nn.Sequential( 18 | nn.Linear(self.input_dim, 1024), 19 | nn.BatchNorm1d(1024), 20 | nn.ReLU(), 21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 23 | nn.ReLU(), 24 | ) 25 | self.deconv = nn.Sequential( 26 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 27 | nn.BatchNorm2d(64), 28 | nn.ReLU(), 29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 30 | nn.Tanh(), 31 | ) 32 | utils.initialize_weights(self) 33 | 34 | def forward(self, input): 35 | x = self.fc(input) 36 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 37 | x = self.deconv(x) 38 | 39 | return x 40 | 41 | class discriminator(nn.Module): 42 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 43 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 44 | def __init__(self, input_dim=1, output_dim=1, input_size=32): 45 | super(discriminator, self).__init__() 46 | 
self.input_dim = input_dim 47 | self.output_dim = output_dim 48 | self.input_size = input_size 49 | 50 | self.conv = nn.Sequential( 51 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 52 | nn.LeakyReLU(0.2), 53 | nn.Conv2d(64, 128, 4, 2, 1), 54 | nn.BatchNorm2d(128), 55 | nn.LeakyReLU(0.2), 56 | ) 57 | self.fc = nn.Sequential( 58 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 59 | nn.BatchNorm1d(1024), 60 | nn.LeakyReLU(0.2), 61 | nn.Linear(1024, self.output_dim), 62 | # nn.Sigmoid(), 63 | ) 64 | utils.initialize_weights(self) 65 | 66 | def forward(self, input): 67 | x = self.conv(input) 68 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 69 | x = self.fc(x) 70 | 71 | return x 72 | 73 | class WGAN_GP(object): 74 | def __init__(self, args): 75 | # parameters 76 | self.epoch = args.epoch 77 | self.sample_num = 100 78 | self.batch_size = args.batch_size 79 | self.save_dir = args.save_dir 80 | self.result_dir = args.result_dir 81 | self.dataset = args.dataset 82 | self.log_dir = args.log_dir 83 | self.gpu_mode = args.gpu_mode 84 | self.model_name = args.gan_type 85 | self.input_size = args.input_size 86 | self.z_dim = 62 87 | self.lambda_ = 10 88 | self.n_critic = 5 # the number of iterations of the critic per generator iteration 89 | 90 | # load dataset 91 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 92 | data = self.data_loader.__iter__().__next__()[0] 93 | 94 | # networks init 95 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size) 96 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size) 97 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 98 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 99 | 100 | if self.gpu_mode: 101 | self.G.cuda() 102 | self.D.cuda() 103 | 104 | print('---------- Networks architecture 
-------------') 105 | utils.print_network(self.G) 106 | utils.print_network(self.D) 107 | print('-----------------------------------------------') 108 | 109 | # fixed noise 110 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim)) 111 | if self.gpu_mode: 112 | self.sample_z_ = self.sample_z_.cuda() 113 | 114 | def train(self): 115 | self.train_hist = {} 116 | self.train_hist['D_loss'] = [] 117 | self.train_hist['G_loss'] = [] 118 | self.train_hist['per_epoch_time'] = [] 119 | self.train_hist['total_time'] = [] 120 | 121 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 122 | if self.gpu_mode: 123 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 124 | 125 | self.D.train() 126 | print('training start!!') 127 | start_time = time.time() 128 | for epoch in range(self.epoch): 129 | self.G.train() 130 | epoch_start_time = time.time() 131 | for iter, (x_, _) in enumerate(self.data_loader): 132 | if iter == self.data_loader.dataset.__len__() // self.batch_size: 133 | break 134 | 135 | z_ = torch.rand((self.batch_size, self.z_dim)) 136 | if self.gpu_mode: 137 | x_, z_ = x_.cuda(), z_.cuda() 138 | 139 | # update D network 140 | self.D_optimizer.zero_grad() 141 | 142 | D_real = self.D(x_) 143 | D_real_loss = -torch.mean(D_real) 144 | 145 | G_ = self.G(z_) 146 | D_fake = self.D(G_) 147 | D_fake_loss = torch.mean(D_fake) 148 | 149 | # gradient penalty 150 | alpha = torch.rand((self.batch_size, 1, 1, 1)) 151 | if self.gpu_mode: 152 | alpha = alpha.cuda() 153 | 154 | x_hat = alpha * x_.data + (1 - alpha) * G_.data 155 | x_hat.requires_grad = True 156 | 157 | pred_hat = self.D(x_hat) 158 | if self.gpu_mode: 159 | gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(), 160 | create_graph=True, retain_graph=True, only_inputs=True)[0] 161 | else: 162 | gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()), 163 | create_graph=True, 
retain_graph=True, only_inputs=True)[0] 164 | 165 | gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() 166 | 167 | D_loss = D_real_loss + D_fake_loss + gradient_penalty 168 | 169 | D_loss.backward() 170 | self.D_optimizer.step() 171 | 172 | if ((iter+1) % self.n_critic) == 0: 173 | # update G network 174 | self.G_optimizer.zero_grad() 175 | 176 | G_ = self.G(z_) 177 | D_fake = self.D(G_) 178 | G_loss = -torch.mean(D_fake) 179 | self.train_hist['G_loss'].append(G_loss.item()) 180 | 181 | G_loss.backward() 182 | self.G_optimizer.step() 183 | 184 | self.train_hist['D_loss'].append(D_loss.item()) 185 | 186 | if ((iter + 1) % 100) == 0: 187 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 188 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item())) 189 | 190 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 191 | with torch.no_grad(): 192 | self.visualize_results((epoch+1)) 193 | 194 | self.train_hist['total_time'].append(time.time() - start_time) 195 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 196 | self.epoch, self.train_hist['total_time'][0])) 197 | print("Training finish!... 
save training results") 198 | 199 | self.save() 200 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 201 | self.epoch) 202 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 203 | 204 | def visualize_results(self, epoch, fix=True): 205 | self.G.eval() 206 | 207 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 208 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 209 | 210 | tot_num_samples = min(self.sample_num, self.batch_size) 211 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) 212 | 213 | if fix: 214 | """ fixed noise """ 215 | samples = self.G(self.sample_z_) 216 | else: 217 | """ random noise """ 218 | sample_z_ = torch.rand((self.batch_size, self.z_dim)) 219 | if self.gpu_mode: 220 | sample_z_ = sample_z_.cuda() 221 | 222 | samples = self.G(sample_z_) 223 | 224 | if self.gpu_mode: 225 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 226 | else: 227 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 228 | 229 | samples = (samples + 1) / 2 230 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 231 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 232 | 233 | def save(self): 234 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 235 | 236 | if not os.path.exists(save_dir): 237 | os.makedirs(save_dir) 238 | 239 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 240 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 241 | 242 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 243 | pickle.dump(self.train_hist, f) 244 | 245 | def load(self): 246 | save_dir = os.path.join(self.save_dir, self.dataset, 
self.model_name) 247 | 248 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 249 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) -------------------------------------------------------------------------------- /assets/celebA_results/BEGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/BEGAN_epoch010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/BEGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_epoch025.png -------------------------------------------------------------------------------- /assets/celebA_results/BEGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/celebA_results/DRAGAN_epoch001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/DRAGAN_epoch010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/DRAGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_epoch025.png -------------------------------------------------------------------------------- /assets/celebA_results/DRAGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/celebA_results/EBGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/EBGAN_epoch010.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/EBGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_epoch025.png -------------------------------------------------------------------------------- /assets/celebA_results/EBGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/celebA_results/GAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/GAN_epoch010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/GAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_epoch025.png 
-------------------------------------------------------------------------------- /assets/celebA_results/GAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/celebA_results/LSGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/LSGAN_epoch010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/LSGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_epoch025.png -------------------------------------------------------------------------------- /assets/celebA_results/LSGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_GP_epoch001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_GP_epoch010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_GP_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_epoch025.png -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_GP_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_generate_animation.gif -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_epoch001.png -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_epoch010.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_epoch010.png -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_epoch025.png -------------------------------------------------------------------------------- /assets/celebA_results/WGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/equations/ACGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/ACGAN.png -------------------------------------------------------------------------------- /assets/equations/BEGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/BEGAN.png -------------------------------------------------------------------------------- /assets/equations/CGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/CGAN.png -------------------------------------------------------------------------------- 
/assets/equations/DRAGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/DRAGAN.png -------------------------------------------------------------------------------- /assets/equations/EBGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/EBGAN.png -------------------------------------------------------------------------------- /assets/equations/GAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/GAN.png -------------------------------------------------------------------------------- /assets/equations/LSGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/LSGAN.png -------------------------------------------------------------------------------- /assets/equations/WGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/WGAN.png -------------------------------------------------------------------------------- /assets/equations/WGAN_GP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/WGAN_GP.png 
-------------------------------------------------------------------------------- /assets/equations/infoGAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/infoGAN.png -------------------------------------------------------------------------------- /assets/etc/GAN_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/etc/GAN_structure.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/ACGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/ACGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/ACGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/ACGAN_generate_animation.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/ACGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/BEGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/BEGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/BEGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/BEGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/BEGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/CGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/CGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/CGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/CGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/CGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/DRAGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/DRAGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/DRAGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/DRAGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/DRAGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/EBGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/EBGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/EBGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/EBGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/EBGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/GAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/GAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/GAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/GAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/GAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/LSGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/LSGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/LSGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/LSGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/LSGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_GP_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_GP_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_GP_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_GP_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_GP_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/WGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_loss.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_cont_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_cont_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_cont_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_cont_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_epoch001.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_epoch025.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_epoch050.png -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/fashion_mnist_results/infoGAN_loss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/ACGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/ACGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/ACGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/ACGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/ACGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_loss.png 
-------------------------------------------------------------------------------- /assets/mnist_results/BEGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/BEGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/BEGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/BEGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/BEGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/CGAN_epoch001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/CGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/CGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/CGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/CGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/DRAGAN_epoch001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/DRAGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/DRAGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/DRAGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/DRAGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/EBGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_epoch001.png 
-------------------------------------------------------------------------------- /assets/mnist_results/EBGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/EBGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/EBGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/EBGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/GAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/GAN_epoch025.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/GAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/GAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/GAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/LSGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/LSGAN_epoch025.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/LSGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/LSGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/LSGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_GP_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_GP_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_epoch025.png 
-------------------------------------------------------------------------------- /assets/mnist_results/WGAN_GP_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_GP_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_GP_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_epoch050.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/WGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_loss.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_cont_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_cont_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_cont_epoch050.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_cont_generate_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_epoch001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_epoch001.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_epoch025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_epoch025.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_epoch050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_epoch050.png -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_generate_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_generate_animation.gif -------------------------------------------------------------------------------- /assets/mnist_results/infoGAN_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_loss.png -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import datasets, transforms 3 | 4 | def dataloader(dataset, input_size, batch_size, split='train'): 5 | transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) 6 | if dataset == 'mnist': 7 | data_loader = DataLoader( 8 | datasets.MNIST('data/mnist', train=True, download=True, transform=transform), 9 | batch_size=batch_size, shuffle=True) 10 | elif dataset == 'fashion-mnist': 11 | data_loader = DataLoader( 12 | datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform), 13 | batch_size=batch_size, shuffle=True) 14 | elif dataset == 'cifar10': 15 | data_loader = DataLoader( 16 | datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform), 17 | batch_size=batch_size, shuffle=True) 18 | elif dataset == 'svhn': 19 | data_loader = DataLoader( 20 | datasets.SVHN('data/svhn', split=split, download=True, transform=transform), 21 | batch_size=batch_size, shuffle=True) 22 | elif dataset == 'stl10': 23 | data_loader = DataLoader( 24 | datasets.STL10('data/stl10', split=split, download=True, transform=transform), 25 | batch_size=batch_size, shuffle=True) 26 | 
elif dataset == 'lsun-bed': 27 | data_loader = DataLoader( 28 | datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform), 29 | batch_size=batch_size, shuffle=True) 30 | 31 | return data_loader -------------------------------------------------------------------------------- /infoGAN.py: -------------------------------------------------------------------------------- 1 | import utils, torch, time, os, pickle, itertools 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import matplotlib.pyplot as plt 7 | from dataloader import dataloader 8 | 9 | class generator(nn.Module): 10 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 11 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 12 | def __init__(self, input_dim=100, output_dim=1, input_size=32, len_discrete_code=10, len_continuous_code=2): 13 | super(generator, self).__init__() 14 | self.input_dim = input_dim 15 | self.output_dim = output_dim 16 | self.input_size = input_size 17 | self.len_discrete_code = len_discrete_code # categorical distribution (i.e. label) 18 | self.len_continuous_code = len_continuous_code # gaussian distribution (e.g. 
rotation, thickness) 19 | 20 | self.fc = nn.Sequential( 21 | nn.Linear(self.input_dim + self.len_discrete_code + self.len_continuous_code, 1024), 22 | nn.BatchNorm1d(1024), 23 | nn.ReLU(), 24 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)), 25 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)), 26 | nn.ReLU(), 27 | ) 28 | self.deconv = nn.Sequential( 29 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 30 | nn.BatchNorm2d(64), 31 | nn.ReLU(), 32 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), 33 | nn.Tanh(), 34 | ) 35 | utils.initialize_weights(self) 36 | 37 | def forward(self, input, cont_code, dist_code): 38 | x = torch.cat([input, cont_code, dist_code], 1) 39 | x = self.fc(x) 40 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4)) 41 | x = self.deconv(x) 42 | 43 | return x 44 | 45 | class discriminator(nn.Module): 46 | # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) 47 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 48 | def __init__(self, input_dim=1, output_dim=1, input_size=32, len_discrete_code=10, len_continuous_code=2): 49 | super(discriminator, self).__init__() 50 | self.input_dim = input_dim 51 | self.output_dim = output_dim 52 | self.input_size = input_size 53 | self.len_discrete_code = len_discrete_code # categorical distribution (i.e. label) 54 | self.len_continuous_code = len_continuous_code # gaussian distribution (e.g. 
rotation, thickness) 55 | 56 | self.conv = nn.Sequential( 57 | nn.Conv2d(self.input_dim, 64, 4, 2, 1), 58 | nn.LeakyReLU(0.2), 59 | nn.Conv2d(64, 128, 4, 2, 1), 60 | nn.BatchNorm2d(128), 61 | nn.LeakyReLU(0.2), 62 | ) 63 | self.fc = nn.Sequential( 64 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024), 65 | nn.BatchNorm1d(1024), 66 | nn.LeakyReLU(0.2), 67 | nn.Linear(1024, self.output_dim + self.len_continuous_code + self.len_discrete_code), 68 | # nn.Sigmoid(), 69 | ) 70 | utils.initialize_weights(self) 71 | 72 | def forward(self, input): 73 | x = self.conv(input) 74 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4)) 75 | x = self.fc(x) 76 | a = F.sigmoid(x[:, self.output_dim]) 77 | b = x[:, self.output_dim:self.output_dim + self.len_continuous_code] 78 | c = x[:, self.output_dim + self.len_continuous_code:] 79 | 80 | return a, b, c 81 | 82 | class infoGAN(object): 83 | def __init__(self, args, SUPERVISED=True): 84 | # parameters 85 | self.epoch = args.epoch 86 | self.batch_size = args.batch_size 87 | self.save_dir = args.save_dir 88 | self.result_dir = args.result_dir 89 | self.dataset = args.dataset 90 | self.log_dir = args.log_dir 91 | self.gpu_mode = args.gpu_mode 92 | self.model_name = args.gan_type 93 | self.input_size = args.input_size 94 | self.z_dim = 62 95 | self.SUPERVISED = SUPERVISED # if it is true, label info is directly used for code 96 | self.len_discrete_code = 10 # categorical distribution (i.e. label) 97 | self.len_continuous_code = 2 # gaussian distribution (e.g. 
rotation, thickness) 98 | self.sample_num = self.len_discrete_code ** 2 99 | 100 | # load dataset 101 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size) 102 | data = self.data_loader.__iter__().__next__()[0] 103 | 104 | # networks init 105 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, len_discrete_code=self.len_discrete_code, len_continuous_code=self.len_continuous_code) 106 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, len_discrete_code=self.len_discrete_code, len_continuous_code=self.len_continuous_code) 107 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) 108 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 109 | self.info_optimizer = optim.Adam(itertools.chain(self.G.parameters(), self.D.parameters()), lr=args.lrD, betas=(args.beta1, args.beta2)) 110 | 111 | if self.gpu_mode: 112 | self.G.cuda() 113 | self.D.cuda() 114 | self.BCE_loss = nn.BCELoss().cuda() 115 | self.CE_loss = nn.CrossEntropyLoss().cuda() 116 | self.MSE_loss = nn.MSELoss().cuda() 117 | else: 118 | self.BCE_loss = nn.BCELoss() 119 | self.CE_loss = nn.CrossEntropyLoss() 120 | self.MSE_loss = nn.MSELoss() 121 | 122 | print('---------- Networks architecture -------------') 123 | utils.print_network(self.G) 124 | utils.print_network(self.D) 125 | print('-----------------------------------------------') 126 | 127 | # fixed noise & condition 128 | self.sample_z_ = torch.zeros((self.sample_num, self.z_dim)) 129 | for i in range(self.len_discrete_code): 130 | self.sample_z_[i * self.len_discrete_code] = torch.rand(1, self.z_dim) 131 | for j in range(1, self.len_discrete_code): 132 | self.sample_z_[i * self.len_discrete_code + j] = self.sample_z_[i * self.len_discrete_code] 133 | 134 | temp = torch.zeros((self.len_discrete_code, 1)) 135 | for i in range(self.len_discrete_code): 136 | 
temp[i, 0] = i 137 | 138 | temp_y = torch.zeros((self.sample_num, 1)) 139 | for i in range(self.len_discrete_code): 140 | temp_y[i * self.len_discrete_code: (i + 1) * self.len_discrete_code] = temp 141 | 142 | self.sample_y_ = torch.zeros((self.sample_num, self.len_discrete_code)).scatter_(1, temp_y.type(torch.LongTensor), 1) 143 | self.sample_c_ = torch.zeros((self.sample_num, self.len_continuous_code)) 144 | 145 | # manipulating two continuous code 146 | self.sample_z2_ = torch.rand((1, self.z_dim)).expand(self.sample_num, self.z_dim) 147 | self.sample_y2_ = torch.zeros(self.sample_num, self.len_discrete_code) 148 | self.sample_y2_[:, 0] = 1 149 | 150 | temp_c = torch.linspace(-1, 1, 10) 151 | self.sample_c2_ = torch.zeros((self.sample_num, 2)) 152 | for i in range(self.len_discrete_code): 153 | for j in range(self.len_discrete_code): 154 | self.sample_c2_[i*self.len_discrete_code+j, 0] = temp_c[i] 155 | self.sample_c2_[i*self.len_discrete_code+j, 1] = temp_c[j] 156 | 157 | if self.gpu_mode: 158 | self.sample_z_, self.sample_y_, self.sample_c_, self.sample_z2_, self.sample_y2_, self.sample_c2_ = \ 159 | self.sample_z_.cuda(), self.sample_y_.cuda(), self.sample_c_.cuda(), self.sample_z2_.cuda(), \ 160 | self.sample_y2_.cuda(), self.sample_c2_.cuda() 161 | 162 | def train(self): 163 | self.train_hist = {} 164 | self.train_hist['D_loss'] = [] 165 | self.train_hist['G_loss'] = [] 166 | self.train_hist['info_loss'] = [] 167 | self.train_hist['per_epoch_time'] = [] 168 | self.train_hist['total_time'] = [] 169 | 170 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1) 171 | if self.gpu_mode: 172 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda() 173 | 174 | self.D.train() 175 | print('training start!!') 176 | start_time = time.time() 177 | for epoch in range(self.epoch): 178 | self.G.train() 179 | epoch_start_time = time.time() 180 | for iter, (x_, y_) in enumerate(self.data_loader): 181 | if iter == 
self.data_loader.dataset.__len__() // self.batch_size: 182 | break 183 | z_ = torch.rand((self.batch_size, self.z_dim)) 184 | if self.SUPERVISED == True: 185 | y_disc_ = torch.zeros((self.batch_size, self.len_discrete_code)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1) 186 | else: 187 | y_disc_ = torch.from_numpy( 188 | np.random.multinomial(1, self.len_discrete_code * [float(1.0 / self.len_discrete_code)], 189 | size=[self.batch_size])).type(torch.FloatTensor) 190 | 191 | y_cont_ = torch.from_numpy(np.random.uniform(-1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor) 192 | 193 | if self.gpu_mode: 194 | x_, z_, y_disc_, y_cont_ = x_.cuda(), z_.cuda(), y_disc_.cuda(), y_cont_.cuda() 195 | 196 | # update D network 197 | self.D_optimizer.zero_grad() 198 | 199 | D_real, _, _ = self.D(x_) 200 | D_real_loss = self.BCE_loss(D_real, self.y_real_) 201 | 202 | G_ = self.G(z_, y_cont_, y_disc_) 203 | D_fake, _, _ = self.D(G_) 204 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) 205 | 206 | D_loss = D_real_loss + D_fake_loss 207 | self.train_hist['D_loss'].append(D_loss.item()) 208 | 209 | D_loss.backward(retain_graph=True) 210 | self.D_optimizer.step() 211 | 212 | # update G network 213 | self.G_optimizer.zero_grad() 214 | 215 | G_ = self.G(z_, y_cont_, y_disc_) 216 | D_fake, D_cont, D_disc = self.D(G_) 217 | 218 | G_loss = self.BCE_loss(D_fake, self.y_real_) 219 | self.train_hist['G_loss'].append(G_loss.item()) 220 | 221 | G_loss.backward(retain_graph=True) 222 | self.G_optimizer.step() 223 | 224 | # information loss 225 | disc_loss = self.CE_loss(D_disc, torch.max(y_disc_, 1)[1]) 226 | cont_loss = self.MSE_loss(D_cont, y_cont_) 227 | info_loss = disc_loss + cont_loss 228 | self.train_hist['info_loss'].append(info_loss.item()) 229 | 230 | info_loss.backward() 231 | self.info_optimizer.step() 232 | 233 | 234 | if ((iter + 1) % 100) == 0: 235 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, info_loss: %.8f" % 236 | ((epoch + 1), (iter + 1), 
self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item(), info_loss.item())) 237 | 238 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) 239 | with torch.no_grad(): 240 | self.visualize_results((epoch+1)) 241 | 242 | self.train_hist['total_time'].append(time.time() - start_time) 243 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 244 | self.epoch, self.train_hist['total_time'][0])) 245 | print("Training finish!... save training results") 246 | 247 | self.save() 248 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, 249 | self.epoch) 250 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont', 251 | self.epoch) 252 | self.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) 253 | 254 | def visualize_results(self, epoch): 255 | self.G.eval() 256 | 257 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): 258 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) 259 | 260 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num))) 261 | 262 | """ style by class """ 263 | samples = self.G(self.sample_z_, self.sample_c_, self.sample_y_) 264 | if self.gpu_mode: 265 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) 266 | else: 267 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 268 | 269 | samples = (samples + 1) / 2 270 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 271 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') 272 | 273 | """ manipulating two continous codes """ 274 | samples = self.G(self.sample_z2_, self.sample_c2_, self.sample_y2_) 275 | if self.gpu_mode: 276 | samples 
= samples.cpu().data.numpy().transpose(0, 2, 3, 1) 277 | else: 278 | samples = samples.data.numpy().transpose(0, 2, 3, 1) 279 | 280 | samples = (samples + 1) / 2 281 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], 282 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont_epoch%03d' % epoch + '.png') 283 | 284 | def save(self): 285 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 286 | 287 | if not os.path.exists(save_dir): 288 | os.makedirs(save_dir) 289 | 290 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 291 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 292 | 293 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 294 | pickle.dump(self.train_hist, f) 295 | 296 | def load(self): 297 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 298 | 299 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 300 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) 301 | 302 | def loss_plot(self, hist, path='Train_hist.png', model_name=''): 303 | x = range(len(hist['D_loss'])) 304 | 305 | y1 = hist['D_loss'] 306 | y2 = hist['G_loss'] 307 | y3 = hist['info_loss'] 308 | 309 | plt.plot(x, y1, label='D_loss') 310 | plt.plot(x, y2, label='G_loss') 311 | plt.plot(x, y3, label='info_loss') 312 | 313 | plt.xlabel('Iter') 314 | plt.ylabel('Loss') 315 | 316 | plt.legend(loc=4) 317 | plt.grid(True) 318 | plt.tight_layout() 319 | 320 | path = os.path.join(path, model_name + '_loss.png') 321 | 322 | plt.savefig(path) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse, os, torch 2 | from GAN import GAN 3 | from CGAN import CGAN 4 
from LSGAN import LSGAN
from DRAGAN import DRAGAN
from ACGAN import ACGAN
from WGAN import WGAN
from WGAN_GP import WGAN_GP
from infoGAN import infoGAN
from EBGAN import EBGAN
from BEGAN import BEGAN


def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


# parsing and configuration
def parse_args():
    """Parse the command line and validate it with ``check_args``.

    Returns:
        The argument namespace, or None if validation failed.
    """
    desc = "Pytorch implementation of GAN collections"
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument('--gan_type', type=str, default='GAN',
                        choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN'],
                        help='The type of GAN')
    parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed'],
                        help='The name of dataset')
    parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')
    parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
    parser.add_argument('--input_size', type=int, default=28, help='The size of input image')
    parser.add_argument('--save_dir', type=str, default='models',
                        help='Directory name to save the model')
    parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
    parser.add_argument('--lrG', type=float, default=0.0002)
    parser.add_argument('--lrD', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    # type=bool would treat "False" as True; parse the string explicitly
    parser.add_argument('--gpu_mode', type=_str2bool, default=True)
    parser.add_argument('--benchmark_mode', type=_str2bool, default=True)

    return check_args(parser.parse_args())


# checking arguments
def check_args(args):
    """Create the output directories and validate numeric arguments.

    Returns:
        ``args`` when valid; otherwise prints a message and returns None so
        that ``main`` stops before training.
    """
    # --save_dir, --result_dir, --log_dir
    for directory in (args.save_dir, args.result_dir, args.log_dir):
        os.makedirs(directory, exist_ok=True)

    # --epoch (explicit check: a bare assert is stripped under `python -O`)
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        return None

    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        return None

    return args


# main
def main():
    """Train the GAN selected by --gan_type, then visualize the learned generator."""
    # parse arguments
    args = parse_args()
    if args is None:
        return

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # declare instance for GAN
    gan_classes = {
        'GAN': GAN,
        'CGAN': CGAN,
        'ACGAN': ACGAN,
        'EBGAN': EBGAN,
        'WGAN': WGAN,
        'WGAN_GP': WGAN_GP,
        'DRAGAN': DRAGAN,
        'LSGAN': LSGAN,
        'BEGAN': BEGAN,
    }
    if args.gan_type == 'infoGAN':
        gan = infoGAN(args, SUPERVISED=False)
    elif args.gan_type in gan_classes:
        gan = gan_classes[args.gan_type](args)
    else:
        raise Exception("[!] There is no option for " + args.gan_type)

    # train the model
    gan.train()
    print(" [*] Training finished!")

    # visualize learned generator
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import os, gzip, torch
import torch.nn as nn
import numpy as np
import scipy.misc
import imageio
import matplotlib.pyplot as plt
from torchvision import datasets, transforms


def load_mnist(dataset):
    """Load the gzipped MNIST-format files from ``./data/<dataset>``.

    The train (60000) and test (10000) splits are concatenated and shuffled
    with a fixed seed. Note that this reseeds NumPy's global RNG.

    Returns:
        X: FloatTensor of shape (70000, 1, 28, 28) with values in [0, 1].
        y_vec: FloatTensor of shape (70000, 10) holding one-hot labels.
    """
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        # Skip the IDX header, then read num_data records of data_size bytes.
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
            # np.float (removed in NumPy 1.24) was an alias of float64
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float64)
        return data

    data = extract_data(os.path.join(data_dir, 'train-images-idx3-ubyte.gz'), 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(os.path.join(data_dir, 'train-labels-idx1-ubyte.gz'), 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(os.path.join(data_dir, 't10k-images-idx3-ubyte.gz'), 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(os.path.join(data_dir, 't10k-labels-idx1-ubyte.gz'), 10000, 8, 1)
    teY = data.reshape((10000))

    # np.int (removed in NumPy 1.24) was an alias of the builtin int
    trY = np.asarray(trY).astype(int)
    teY = np.asarray(teY)

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(int)

    # Reseeding before each shuffle applies the same permutation to X and y.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    y_vec[np.arange(len(y)), y] = 1

    # NHWC -> NCHW, scale uint8 pixel values to [0, 1]
    X = X.transpose(0, 3, 1, 2) / 255.

    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec


def load_celebA(dir, transform, batch_size, shuffle):
    """Return a DataLoader over ``datasets.ImageFolder(dir, transform)``.

    ``dir`` is machine specific (e.g. ``data/celebA``). An example transform::

        transforms.Compose([
            transforms.CenterCrop(160),
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
    """
    dset = datasets.ImageFolder(dir, transform)
    data_loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle)

    return data_loader


def print_network(net):
    """Print the module structure of ``net`` and its total parameter count."""
    num_params = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)


def save_images(images, size, image_path):
    """Tile ``images`` into a ``size`` = (rows, cols) grid and write it to ``image_path``."""
    return imsave(images, size, image_path)


def imsave(images, size, path):
    """Merge ``images`` into one grid and write it to ``path`` as 8-bit.

    Values are expected in [0, 1]; they are clipped to that range and scaled
    to 0..255. (The previously used ``scipy.misc.imsave`` was removed in
    SciPy 1.2 and also min/max-stretched the contrast, which is not done here.)
    """
    image = np.squeeze(merge(images, size))
    image = (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)
    return imageio.imwrite(path, image)


def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: array indexed as (N, H, W, C) with C in (1, 3, 4).
        size: (rows, cols) of the grid; images are placed row by row.

    Returns:
        A (rows*H, cols*W, C) float array, or (rows*H, cols*W) when C == 1.

    Raises:
        ValueError: if the channel dimension is not 1, 3 or 4.
    """
    h, w = images.shape[1], images.shape[2]
    c = images.shape[3]
    if c in (3, 4):
        img = np.zeros((h * size[0], w * size[1], c))
    elif c == 1:
        img = np.zeros((h * size[0], w * size[1]))
    else:
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: NxHxWx1 or NxHxWx3 or NxHxWx4')
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        # Single-channel images are written as 2-D grayscale tiles.
        img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0] if c == 1 else image
    return img


def generate_animation(path, num):
    """Combine ``<path>_epoch001.png`` .. ``<path>_epoch<num>.png`` into
    ``<path>_generate_animation.gif`` (fps=5)."""
    images = [imageio.imread(path + '_epoch%03d' % (e + 1) + '.png') for e in range(num)]
    imageio.mimsave(path + '_generate_animation.gif', images, fps=5)


def loss_plot(hist, path = 'Train_hist.png', model_name = ''):
    """Plot hist['D_loss'] and hist['G_loss'] per iteration.

    The figure is saved to ``<path>/<model_name>_loss.png`` (``path`` is used
    as a directory despite its default) and then closed.
    """
    x = range(len(hist['D_loss']))

    plt.plot(x, hist['D_loss'], label='D_loss')
    plt.plot(x, hist['G_loss'], label='G_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    path = os.path.join(path, model_name + '_loss.png')

    plt.savefig(path)

    plt.close()


def initialize_weights(net):
    """Initialise Conv2d, ConvTranspose2d and Linear layers of ``net`` in place.

    Weights are drawn from N(0, 0.02); biases are zeroed when the layer has
    one (layers built with ``bias=False`` have ``bias is None``).
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()
# --------------------------------------------------------------------------------