├── data └── describe_data ├── README.md ├── sub.py ├── img_dataloader.py ├── utills.py ├── create_datasets.py ├── aae_basic.py ├── aae_supervised.py └── aae_semi_supervised.py /data/describe_data: -------------------------------------------------------------------------------- 1 | You have to download the MNIST dataset into this directory. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AAE-PyTorch 2 | Adversarial autoencoder (basic/semi-supervised/supervised) 3 | 4 | First, 5 | 6 | ```bash 7 | $ python create_datasets.py 8 | ``` 9 | It takes some time. 10 | 11 | Then, you get data/MNIST and data/subMNIST (automatically downloaded into the data/ directory), which are MNIST image datasets. 12 | You also get train_labeled.p, train_unlabeled.p, and validation.p, which are lists of the tr_l, tr_u, and tt images. 13 | 14 | Second, 15 | 16 | ```bash 17 | $ python aae_basic.py 18 | ``` 19 | or 20 | 21 | 22 | ```bash 23 | $ python aae_supervised.py 24 | ``` 25 | You can get category-conditional images like the ones below, 26 | 27 | 28 | 29 | 30 | or 31 | 32 | 33 | ```bash 34 | $ python aae_semi_supervised.py 35 | ``` 36 | -------------------------------------------------------------------------------- /sub.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from torchvision.datasets import MNIST 6 | 7 | class subMNIST(MNIST): 8 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, k=3000): 9 | super(subMNIST, self).__init__(root, train, transform, target_transform, download) 10 | self.k = k 11 | 12 | def __len__(self): 13 | if self.train: 14 | return self.k 15 | else: 16 | return 10000 17 | 18 | transform = transforms.Compose([transforms.Resize(32), 
transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 19 | trainset = subMNIST(root='data', train=True, 20 | download=True, transform=transform) 21 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 22 | shuffle=True, num_workers=2) 23 | 24 | 25 | #print(len(trainset)) 26 | #print(len(trainloader)) -------------------------------------------------------------------------------- /img_dataloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | 6 | 7 | def load_data(data_path='data/'): 8 | print('loading data!') 9 | trainset_labeled = pickle.load(open(data_path + "train_labeled.p", "rb")) 10 | trainset_unlabeled = pickle.load(open(data_path + "train_unlabeled.p", "rb")) 11 | # Set -1 as labels for unlabeled data 12 | trainset_unlabeled.targets = torch.from_numpy(np.array([-1] * 47000)) 13 | validset = pickle.load(open(data_path + "validation.p", "rb")) 14 | train_batch_size = 64 15 | valid_batch_size = 64 16 | train_labeled_loader = data.DataLoader(trainset_labeled, batch_size=train_batch_size, shuffle=True) 17 | train_unlabeled_loader = data.DataLoader(trainset_unlabeled, batch_size=train_batch_size, shuffle=True) 18 | valid_loader = data.DataLoader(validset, batch_size=valid_batch_size, shuffle=True) 19 | 20 | return train_labeled_loader, train_unlabeled_loader, valid_loader 21 | 22 | 23 | if __name__ == "__main__": 24 | trainset_labeled = pickle.load(open('data/'+ "train_labeled.p", "rb")) 25 | print(trainset_labeled) 26 | -------------------------------------------------------------------------------- /utills.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | import pickle 5 | import numpy as np 6 | import itertools 7 | # from viz import * 8 | from torch.autograd import Variable 9 | import torch.nn as nn 10 | import 
torch.nn.functional as F 11 | import torch.optim as optim 12 | 13 | cuda = False 14 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 15 | def sample_image(n_row): 16 | z = Variable(Tensor(np.random.normal(0,1,(n_row**2, 10)))) 17 | print(z.shape) 18 | z_cat = Tensor(np.array([1,0,0,0,0,0,0,0,0,0])) 19 | 20 | 21 | def sample_categorical(batch_size, n_classes=10): 22 | ''' 23 | Sample from a categorical distribution 24 | of size batch_size and # of classes n_classes 25 | return: torch.autograd.Variable with the sample 26 | ''' 27 | cat = np.random.randint(0, 10, batch_size) 28 | cat = np.eye(n_classes)[cat].astype('float32') 29 | cat = torch.from_numpy(cat) 30 | return cat 31 | 32 | def get_categorical(labels, n_classes=10): 33 | cat = np.array(labels.data.tolist()) 34 | cat = np.eye(n_classes)[cat].astype('float32') 35 | cat = torch.from_numpy(cat) 36 | return cat 37 | 38 | def classification_accuracy(Q, data_loader): 39 | Q.eval() 40 | labels = [] 41 | test_loss = 0 42 | correct = 0 43 | cuda = True 44 | for batch_idx, (X, target) in enumerate(data_loader): 45 | # X = X * 0.3081 + 0.1307 46 | # X.resize_(data_loader.batch_size, X_dim) 47 | X, target = Variable(X), Variable(target) 48 | if cuda: 49 | X, target = X.cuda(), target.cuda() 50 | 51 | labels.extend(target.data.tolist()) 52 | # Reconstruction phase 53 | output = Q(X)[1] 54 | 55 | test_loss += F.nll_loss(output, target).item() 56 | 57 | pred = output.data.max(1)[1] 58 | correct += pred.eq(target.data).cpu().sum() 59 | 60 | test_loss /= len(data_loader) 61 | return 100. 
* correct / len(data_loader.dataset) 62 | 63 | 64 | 65 | if __name__ == "__main__": 66 | a = sample_categorical(5, n_classes=10) 67 | print(a.shape) 68 | sample_image(10) 69 | 70 | 71 | z_cat = Tensor(np.array([1,0,0,0,0,0,0,0,0,0])) 72 | hat = z_cat.repeat(2,1) 73 | print(hat) -------------------------------------------------------------------------------- /create_datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from torchvision import datasets, transforms 6 | from sub import subMNIST 7 | 8 | transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), 9 | transforms.Normalize((0.5,), (0.5,))]) 10 | 11 | trainset_original = datasets.MNIST('data', train=True, download=True, 12 | transform=transform) 13 | 14 | train_label_index = [] 15 | valid_label_index = [] 16 | for i in range(10): 17 | train_label_list = trainset_original.train_labels.numpy() 18 | label_index = np.where(train_label_list == i)[0] 19 | label_subindex = list(label_index[:300]) 20 | valid_subindex = list(label_index[300: 1000 + 300]) 21 | train_label_index += label_subindex 22 | valid_label_index += valid_subindex 23 | 24 | trainset_np = trainset_original.train_data.numpy() 25 | trainset_label_np = trainset_original.train_labels.numpy() 26 | train_data_sub = torch.from_numpy(trainset_np[train_label_index]) 27 | train_labels_sub = torch.from_numpy(trainset_label_np[train_label_index]) 28 | 29 | trainset_new = subMNIST(root='data', train=True, download=True, transform=transform, k=3000) 30 | trainset_new.data = train_data_sub.clone() 31 | trainset_new.targets = train_labels_sub.clone() 32 | 33 | pickle.dump(trainset_new, open("data/train_labeled.p", "wb")) 34 | 35 | 36 | validset_np = trainset_original.train_data.numpy() 37 | validset_label_np = trainset_original.train_labels.numpy() 38 | valid_data_sub = 
torch.from_numpy(validset_np[valid_label_index]) 39 | valid_labels_sub = torch.from_numpy(validset_label_np[valid_label_index]) 40 | 41 | 42 | validset = subMNIST(root='data', train=False, download=True, transform=transform, k=10000) 43 | validset.data = valid_data_sub.clone() 44 | validset.targets = valid_labels_sub.clone() 45 | 46 | pickle.dump(validset, open("data/validation.p", "wb")) 47 | 48 | 49 | train_unlabel_index = [] 50 | for i in range(60000): 51 | if i in train_label_index or i in valid_label_index: 52 | pass 53 | else: 54 | train_unlabel_index.append(i) 55 | 56 | trainset_np = trainset_original.train_data.numpy() 57 | trainset_label_np = trainset_original.train_labels.numpy() 58 | train_data_sub_unl = torch.from_numpy(trainset_np[train_unlabel_index]) 59 | train_labels_sub_unl = torch.from_numpy(trainset_label_np[train_unlabel_index]) 60 | 61 | trainset_new_unl = subMNIST(root='data', train=True, download=True, transform=transform, k=47000) 62 | trainset_new_unl.data = train_data_sub_unl.clone() 63 | trainset_new_unl.targets = None # Unlabeled 64 | 65 | trainset_new_unl.targets 66 | 67 | pickle.dump(trainset_new_unl, open("data/train_unlabeled.p", "wb")) 68 | -------------------------------------------------------------------------------- /aae_basic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import itertools 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | from torchvision.utils import save_image 12 | import torchvision.transforms as transforms 13 | from torchvision import datasets 14 | 15 | from img_dataloader import load_data 16 | 17 | ####################################################### 18 | # hyperparameter setting 19 | ####################################################### 20 | 21 | parser = argparse.ArgumentParser("Basic aae model by hyu") 22 | 
parser.add_argument("--n_epochs", type=int, default=500, help="number of epochs of training") 23 | parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") 24 | 25 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") 26 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 27 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 28 | 29 | parser.add_argument("--latent_dim", type=int, default=10, help="dimensionality of the latent code") 30 | parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension") 31 | parser.add_argument("--channels", type=int, default=1, help="number of image channels") 32 | 33 | parser.add_argument("--img_dir", type=str, default='image_basic', help="number of classes of image datasets") 34 | 35 | args = parser.parse_args() 36 | print(args) 37 | 38 | # config cuda 39 | cuda = torch.cuda.is_available() 40 | img_shape = (args.channels, args.img_size, args.img_size) 41 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 42 | 43 | 44 | ####################################################### 45 | # Define Networks 46 | ####################################################### 47 | 48 | class Encoder(nn.Module): 49 | def __init__(self): 50 | super(Encoder, self).__init__() 51 | 52 | self.model = nn.Sequential( 53 | nn.Linear(int(np.prod(img_shape)), 512), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | nn.Linear(512, 512), 56 | nn.BatchNorm1d(512), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | nn.Linear(512, args.latent_dim) 59 | ) 60 | 61 | self.mu = nn.Linear(512, args.latent_dim) 62 | self.logvar = nn.Linear(512, args.latent_dim) 63 | 64 | def forward(self, img): 65 | img_flat = img.view(img.shape[0], -1) 66 | x = self.model(img_flat) 67 | # mu = self.mu(x) 68 | # logvar = self.logvar(x) 69 | # z = reparameterization(mu, logvar) 70 | 
return x 71 | 72 | 73 | class Decoder(nn.Module): 74 | def __init__(self): 75 | super(Decoder, self).__init__() 76 | 77 | self.model = nn.Sequential( 78 | nn.Linear(args.latent_dim, 512), 79 | nn.LeakyReLU(0.2, inplace=True), 80 | nn.Linear(512, 512), 81 | nn.BatchNorm1d(512), 82 | nn.LeakyReLU(0.2, inplace=True), 83 | nn.Linear(512, int(np.prod(img_shape))), 84 | nn.Tanh(), 85 | ) 86 | 87 | def forward(self, z): 88 | img_flat = self.model(z) 89 | img = img_flat.view(img_flat.shape[0], *img_shape) 90 | return img 91 | 92 | 93 | class Discriminator(nn.Module): 94 | def __init__(self): 95 | super(Discriminator, self).__init__() 96 | 97 | self.model = nn.Sequential( 98 | nn.Linear(args.latent_dim, 512), 99 | nn.LeakyReLU(0.2, inplace=True), 100 | nn.Linear(512, 256), 101 | nn.LeakyReLU(0.2, inplace=True), 102 | nn.Linear(256, 1), 103 | nn.Sigmoid(), 104 | ) 105 | 106 | def forward(self, z): 107 | validity = self.model(z) 108 | return validity 109 | 110 | 111 | ####################################################### 112 | # Preparation part 113 | ####################################################### 114 | 115 | # data 116 | train_labeled_loader = load_data('data/')[1] 117 | 118 | # define model 119 | # 1) generator 120 | encoder = Encoder() 121 | decoder = Decoder() 122 | # 2) discriminator 123 | discriminator = Discriminator() 124 | 125 | # loss 126 | adversarial_loss = nn.BCELoss() 127 | reconstruction_loss = nn.MSELoss() 128 | 129 | # optimizer 130 | optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()), lr=args.lr, betas=(args.b1, args.b2)) 131 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 132 | 133 | if cuda: 134 | encoder.cuda() 135 | decoder.cuda() 136 | discriminator.cuda() 137 | adversarial_loss.cuda() 138 | reconstruction_loss.cuda() 139 | 140 | 141 | ####################################################### 142 | # Training part 143 | 
####################################################### 144 | 145 | def sample_image(n_row, epoch, img_dir): 146 | if not os.path.exists(img_dir): 147 | os.makedirs(img_dir) 148 | z = Variable(Tensor(np.random.normal(0,1,(n_row**2, args.latent_dim)))) 149 | generated_imgs = decoder(z) 150 | save_image(generated_imgs.data, os.path.join(img_dir, "%depoch.png" % epoch), nrow = n_row, normalize = True) 151 | 152 | 153 | # training phase 154 | for epoch in range(args.n_epochs): 155 | for i, (x, idx) in enumerate(train_labeled_loader): 156 | 157 | valid = Variable(Tensor(x.shape[0], 1).fill_(1.0), requires_grad=False) 158 | fake = Variable(Tensor(x.shape[0], 1).fill_(0.0), requires_grad=False) 159 | 160 | if cuda: 161 | x = x.cuda() 162 | 163 | # 1) reconstruction + generator loss 164 | optimizer_G.zero_grad() 165 | fake_z = encoder(x) 166 | decoded_x = decoder(fake_z) 167 | validity_fake_z = discriminator(fake_z) 168 | G_loss = 0.001*adversarial_loss(validity_fake_z, valid) + 0.999*reconstruction_loss(decoded_x, x) 169 | G_loss.backward() 170 | optimizer_G.step() 171 | 172 | # 2) discriminator loss 173 | optimizer_D.zero_grad() 174 | real_z = Variable(Tensor(np.random.normal(0,1,(x.shape[0], args.latent_dim)))) 175 | real_loss = adversarial_loss(discriminator(real_z), valid) 176 | fake_loss = adversarial_loss(discriminator(fake_z.detach()), fake) 177 | D_loss = 0.5*(real_loss + fake_loss) 178 | D_loss.backward() 179 | optimizer_D.step() 180 | 181 | # print loss 182 | print( 183 | "[Epoch %d/%d] [G loss: %f] [D loss: %f]" 184 | % (epoch, args.n_epochs, G_loss.item(), D_loss.item()) 185 | ) 186 | 187 | sample_image(n_row = 10, epoch=epoch, img_dir=args.img_dir) 188 | 189 | 190 | -------------------------------------------------------------------------------- /aae_supervised.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import itertools 6 | 7 | import torch 8 
| import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | from torchvision.utils import save_image 12 | import torchvision.transforms as transforms 13 | from torchvision import datasets 14 | 15 | from img_dataloader import load_data 16 | from utills import * 17 | 18 | ####################################################### 19 | # hyperparameter setting 20 | ####################################################### 21 | 22 | parser = argparse.ArgumentParser("supervised aae model by hyu") 23 | parser.add_argument("--n_epochs", type=int, default=500, 24 | help="number of epochs of training") 25 | parser.add_argument("--batch_size", type=int, default=64, 26 | help="size of the batches") 27 | 28 | parser.add_argument("--lr", type=float, default=0.0002, 29 | help="adam: learning rate") 30 | parser.add_argument("--b1", type=float, default=0.5, 31 | help="adam: decay of first order momentum of gradient") 32 | parser.add_argument("--b2", type=float, default=0.999, 33 | help="adam: decay of first order momentum of gradient") 34 | 35 | parser.add_argument("--latent_dim", type=int, default=20, 36 | help="dimensionality of the latent code") 37 | parser.add_argument("--img_size", type=int, default=32, 38 | help="size of each image dimension") 39 | parser.add_argument("--channels", type=int, default=1, 40 | help="number of image channels") 41 | parser.add_argument("--n_classes", type=int, default=10, 42 | help="number of classes of image datasets") 43 | 44 | parser.add_argument("--img_dir", type=str, default='image_supervised', 45 | help="number of classes of image datasets") 46 | 47 | args = parser.parse_args() 48 | print(args) 49 | 50 | # config cuda 51 | cuda = torch.cuda.is_available() 52 | img_shape = (args.channels, args.img_size, args.img_size) 53 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 54 | 55 | 56 | ####################################################### 57 | # Define Networks 58 | 
####################################################### 59 | 60 | class Encoder(nn.Module): 61 | def __init__(self): 62 | super(Encoder, self).__init__() 63 | 64 | self.model = nn.Sequential( 65 | nn.Linear(int(np.prod(img_shape)), 512), 66 | nn.LeakyReLU(0.2, inplace=True), 67 | nn.Linear(512, 512), 68 | nn.BatchNorm1d(512), 69 | nn.LeakyReLU(0.2, inplace=True), 70 | nn.Linear(512, args.latent_dim) 71 | ) 72 | 73 | self.mu = nn.Linear(512, args.latent_dim) 74 | self.logvar = nn.Linear(512, args.latent_dim) 75 | 76 | def forward(self, img): 77 | img_flat = img.view(img.shape[0], -1) 78 | x = self.model(img_flat) 79 | # mu = self.mu(x) 80 | # logvar = self.logvar(x) 81 | # z = reparameterization(mu, logvar) 82 | return x 83 | 84 | 85 | class Decoder(nn.Module): 86 | def __init__(self): 87 | super(Decoder, self).__init__() 88 | 89 | self.model = nn.Sequential( 90 | nn.Linear(args.latent_dim + args.n_classes, 512), 91 | nn.LeakyReLU(0.2, inplace=True), 92 | nn.Linear(512, 512), 93 | nn.BatchNorm1d(512), 94 | nn.LeakyReLU(0.2, inplace=True), 95 | nn.Linear(512, int(np.prod(img_shape))), 96 | nn.Tanh(), 97 | ) 98 | 99 | def forward(self, z): 100 | img_flat = self.model(z) 101 | img = img_flat.view(img_flat.shape[0], *img_shape) 102 | return img 103 | 104 | 105 | class Discriminator(nn.Module): 106 | def __init__(self): 107 | super(Discriminator, self).__init__() 108 | 109 | self.model = nn.Sequential( 110 | nn.Linear(args.latent_dim, 512), 111 | nn.LeakyReLU(0.2, inplace=True), 112 | nn.Linear(512, 256), 113 | nn.LeakyReLU(0.2, inplace=True), 114 | nn.Linear(256, 1), 115 | nn.Sigmoid(), 116 | ) 117 | 118 | def forward(self, z): 119 | validity = self.model(z) 120 | return validity 121 | 122 | 123 | ####################################################### 124 | # Preparation part 125 | ####################################################### 126 | 127 | # data 128 | train_labeled_loader = load_data('data/')[0] 129 | 130 | # define model 131 | # 1) generator 132 | encoder = 
Encoder() 133 | decoder = Decoder() 134 | # 2) discriminator 135 | discriminator = Discriminator() 136 | 137 | # loss 138 | adversarial_loss = nn.BCELoss() 139 | reconstruction_loss = nn.MSELoss() 140 | 141 | # optimizer 142 | optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()), 143 | lr=args.lr, betas=(args.b1, args.b2)) 144 | optimizer_D = torch.optim.Adam(discriminator.parameters(), 145 | lr=args.lr, betas=(args.b1, args.b2)) 146 | 147 | if cuda: 148 | encoder.cuda() 149 | decoder.cuda() 150 | discriminator.cuda() 151 | adversarial_loss.cuda() 152 | reconstruction_loss.cuda() 153 | 154 | 155 | ####################################################### 156 | # Training part 157 | ####################################################### 158 | 159 | def sample_image(n_row, epoch, img_dir): 160 | if not os.path.exists(img_dir): 161 | os.makedirs(img_dir) 162 | z = Variable(Tensor(np.random.normal(0, 1, (n_row**2, args.latent_dim*2)))) 163 | decoder.eval() 164 | with torch.no_grad(): 165 | generated_imgs = decoder(z) 166 | save_image(generated_imgs.data, os.path.join(img_dir, "%depoch.png" % epoch), 167 | nrow = n_row, normalize = True) 168 | 169 | def sample_image_cat(n_row, epoch, img_dir): 170 | # form latent vector 171 | z = Variable(Tensor(np.random.normal(0, 1, (n_row**2, args.latent_dim)))) 172 | # form categorical value 173 | z_cat = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).repeat(n_row) 174 | z_cat = np.eye(args.n_classes)[z_cat].astype('float32') 175 | z_cat = Tensor(z_cat) 176 | 177 | z = torch.cat([z_cat, z], dim=1) 178 | 179 | decoder.eval() 180 | with torch.no_grad(): 181 | generated_imgs = decoder(z) 182 | save_image(generated_imgs.data, os.path.join(img_dir, "%depoch.png" % epoch), 183 | nrow = n_row, normalize = True) 184 | 185 | # training phase 186 | for epoch in range(1, args.n_epochs+1): 187 | encoder.train() 188 | decoder.train() 189 | discriminator.train() 190 | for i, (x, idx) in 
enumerate(train_labeled_loader): 191 | valid = Variable(Tensor(x.shape[0], 1).fill_(1.0), requires_grad=False) 192 | fake = Variable(Tensor(x.shape[0], 1).fill_(0.0), requires_grad=False) 193 | 194 | if cuda: 195 | x = x.cuda() 196 | idx = idx.cuda() 197 | 198 | # 1) reconstruction + generator loss 199 | optimizer_G.zero_grad() 200 | fake_z = encoder(x) 201 | fake_z_cat = get_categorical(idx, n_classes=args.n_classes) 202 | if cuda: 203 | fake_z_cat = fake_z_cat.cuda() 204 | fake_z_concatenate = torch.cat((fake_z_cat, fake_z), 1) 205 | decoded_x = decoder(fake_z_concatenate) 206 | validity_fake_z = discriminator(fake_z) 207 | 208 | G_loss = 0.005*adversarial_loss(validity_fake_z, valid) \ 209 | + 0.995*reconstruction_loss(decoded_x, x) 210 | G_loss.backward() 211 | optimizer_G.step() 212 | 213 | # 2) discriminator loss 214 | optimizer_D.zero_grad() 215 | real_z = Variable(Tensor(np.random.normal(0, 1, (x.shape[0], args.latent_dim)))) 216 | real_loss = adversarial_loss(discriminator(real_z), valid) 217 | fake_loss = adversarial_loss(discriminator(fake_z.detach()), fake) 218 | D_loss = 0.5*(real_loss + fake_loss) 219 | D_loss.backward() 220 | optimizer_D.step() 221 | 222 | # print loss 223 | print("[Epoch %d/%d] [G loss: %.3f] [D loss: %.3f] --> [real loss: %.3f] [fake loss: %.3f]" \ 224 | % (epoch, args.n_epochs, G_loss.item(), D_loss.item(), real_loss.item(), fake_loss.item())) 225 | # Save image 226 | sample_image_cat(n_row = 10, epoch=epoch, img_dir=args.img_dir) 227 | -------------------------------------------------------------------------------- /aae_semi_supervised.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import itertools 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.utils import save_image 11 | 12 | from img_dataloader import load_data 13 | from utills import * 14 | 15 | 
####################################################### 16 | # hyperparameter setting 17 | ####################################################### 18 | 19 | parser = argparse.ArgumentParser("semi-supervised aae model by hyu") 20 | parser.add_argument("--n_epochs", type=int, default=500, help="number of epochs of training") 21 | parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") 22 | 23 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") 24 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 25 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 26 | 27 | parser.add_argument("--latent_dim", type=int, default=10, help="dimensionality of the latent code") 28 | parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension") 29 | parser.add_argument("--channels", type=int, default=1, help="number of image channels") 30 | parser.add_argument("--n_classes", type=int, default=10, help="number of classes of image datasets") 31 | 32 | parser.add_argument("--img_dir", type=str, default='image_semi_supervised', help="number of classes of image datasets") 33 | 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | # config cuda 38 | cuda = torch.cuda.is_available() 39 | img_shape = (args.channels, args.img_size, args.img_size) 40 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 41 | 42 | ####################################################### 43 | # Define Networks 44 | ####################################################### 45 | 46 | class Encoder(nn.Module): 47 | def __init__(self): 48 | super(Encoder, self).__init__() 49 | 50 | self.model = nn.Sequential( 51 | nn.Linear(int(np.prod(img_shape)), 512), 52 | nn.LeakyReLU(0.2, inplace=True), 53 | nn.Linear(512, 512), 54 | nn.BatchNorm1d(512), 55 | nn.LeakyReLU(0.2, inplace=True), 56 | ) 57 | 58 
| self.lin_D = nn.Linear(512, args.latent_dim) 59 | self.lin_D_cat = nn.Sequential(nn.Linear(512, args.n_classes),nn.Softmax()) 60 | 61 | 62 | def forward(self, img): 63 | img_flat = img.view(img.shape[0], -1) 64 | x = self.model(img_flat) 65 | 66 | z_gauss = self.lin_D(x) 67 | z_cat = self.lin_D_cat(x) 68 | 69 | return z_gauss, z_cat 70 | 71 | class Decoder(nn.Module): 72 | def __init__(self): 73 | super(Decoder, self).__init__() 74 | 75 | self.model = nn.Sequential( 76 | nn.Linear(args.latent_dim, 512), 77 | nn.LeakyReLU(0.2, inplace=True), 78 | nn.Linear(512, 512), 79 | nn.BatchNorm1d(512), 80 | nn.LeakyReLU(0.2, inplace=True), 81 | nn.Linear(512, int(np.prod(img_shape))), 82 | nn.Tanh(), 83 | ) 84 | 85 | def forward(self, z): 86 | img_flat = self.model(z) 87 | img = img_flat.view(img_flat.shape[0], *img_shape) 88 | return img 89 | 90 | class Discriminator(nn.Module): 91 | def __init__(self): 92 | super(Discriminator, self).__init__() 93 | 94 | self.model = nn.Sequential( 95 | nn.Linear(args.latent_dim, 512), 96 | nn.LeakyReLU(0.2, inplace=True), 97 | nn.Linear(512, 256), 98 | nn.LeakyReLU(0.2, inplace=True), 99 | nn.Linear(256, 1), 100 | nn.Sigmoid(), 101 | ) 102 | 103 | def forward(self, z): 104 | validity = self.model(z) 105 | return validity 106 | 107 | class Discriminator_category(nn.Module): 108 | def __init__(self): 109 | super(Discriminator_category, self).__init__() 110 | 111 | self.model = nn.Sequential( 112 | nn.Linear(args.n_classes, 512), 113 | nn.LeakyReLU(0.2, inplace=True), 114 | nn.Linear(512, 256), 115 | nn.LeakyReLU(0.2, inplace=True), 116 | nn.Linear(256, 1), 117 | nn.Sigmoid(), 118 | ) 119 | 120 | def forward(self, z): 121 | validity = self.model(z) 122 | return validity 123 | 124 | 125 | ####################################################### 126 | # Preparation 127 | ####################################################### 128 | 129 | # data 130 | train_labeled_loader = load_data('data/')[0] 131 | train_unlabeled_loader = 
load_data('data/')[1] 132 | valid_loader = load_data('data/')[2] 133 | 134 | # define model 135 | # 1) generator 136 | encoder = Encoder() 137 | decoder = Decoder() 138 | # 2) discriminator for z and y(category) 139 | discriminator = Discriminator() 140 | discriminator_cat = Discriminator_category() 141 | 142 | # loss 143 | adversarial_loss = nn.BCELoss() 144 | reconstruction_loss = nn.MSELoss() 145 | 146 | # optimizer 147 | # 1) basic optimizer for G(encoder-decoder) and D(discriminator) 148 | optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()), lr=args.lr, betas=(args.b1, args.b2)) 149 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 150 | 151 | # 2) additional optimizer for E-semi(encoder for semi-supervised learning) and D_cat(Discriminator for category) 152 | optimizer_E_semi = torch.optim.Adam(encoder.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 153 | optimizer_D_cat = torch.optim.Adam(discriminator_cat.parameters(), lr=args.lr, betas=(args.b1, args.b2)) 154 | 155 | if cuda: 156 | encoder.cuda() 157 | decoder.cuda() 158 | discriminator.cuda() 159 | discriminator_cat.cuda() 160 | adversarial_loss.cuda() 161 | reconstruction_loss.cuda() 162 | 163 | 164 | def sample_image(n_row, epoch, img_dir): 165 | if not os.path.exists(img_dir): 166 | os.makedirs(img_dir) 167 | z = nn.Parameter(Tensor(np.random.normal(0,1,(n_row**2, args.latent_dim)))) 168 | generated_imgs = decoder(z) 169 | save_image(generated_imgs.data, os.path.join(img_dir, "%depoch.png" % epoch), nrow = n_row, normalize = True) 170 | 171 | ####################################################### 172 | # Training part 173 | ####################################################### 174 | 175 | # training phase 176 | for epoch in range(args.n_epochs): 177 | encoder.train() 178 | decoder.train() 179 | discriminator.train() 180 | discriminator_cat.train() 181 | for (x_l, idx_l), (x_u, idx_u) in 
zip(train_labeled_loader, train_unlabeled_loader): 182 | for x, idx in [(x_l, idx_l), (x_u, idx_u)]: 183 | if idx[0] == -1: 184 | labeled = False 185 | else: 186 | labeled = True 187 | 188 | if cuda: 189 | x = x.cuda() 190 | idx = idx.cuda() 191 | 192 | if labeled: # supervised phase 193 | optimizer_E_semi.zero_grad() 194 | _, fake_cat = encoder(x) 195 | class_loss = F.cross_entropy(fake_cat, idx) 196 | class_loss.backward() 197 | optimizer_E_semi.step() 198 | 199 | else: # unsupervised phase 200 | 201 | valid = nn.Parameter(Tensor(x.shape[0], 1).fill_(1.0), requires_grad=False) 202 | fake = nn.Parameter(Tensor(x.shape[0], 1).fill_(0.0), requires_grad=False) 203 | 204 | # 1) reconstruction + generator loss for gaussian and category 205 | optimizer_G.zero_grad() 206 | fake_z, fake_cat = encoder(x) 207 | decoded_x = decoder(fake_z) 208 | validity_fake_z = discriminator(fake_z) 209 | G_loss = 0.01*adversarial_loss(validity_fake_z, valid) + 0.01*adversarial_loss(discriminator_cat(fake_cat), valid) + 0.98*reconstruction_loss(decoded_x, x) 210 | G_loss.backward() 211 | optimizer_G.step() 212 | 213 | # 2) discriminator loss for gaussian 214 | optimizer_D.zero_grad() 215 | real_z = nn.Parameter(Tensor(np.random.normal(0,1,(x.shape[0], args.latent_dim))), requires_grad=False) 216 | real_loss = adversarial_loss(discriminator(real_z), valid) 217 | fake_loss = adversarial_loss(discriminator(fake_z.detach()), fake) 218 | D_loss = 0.5*(real_loss + fake_loss) 219 | D_loss.backward() 220 | optimizer_D.step() 221 | 222 | # 3) discriminator loss for category 223 | optimizer_D_cat.zero_grad() 224 | real_cat = nn.Parameter(sample_categorical(x.shape[0], n_classes=args.n_classes), requires_grad=False).cuda() 225 | real_cat_loss = adversarial_loss(discriminator_cat(real_cat), valid) 226 | fake_cat_loss = adversarial_loss(discriminator_cat(fake_cat.detach()), fake) 227 | D_cat_loss = 0.5*(real_cat_loss + fake_cat_loss) 228 | D_cat_loss.backward() 229 | optimizer_D_cat.step() 230 | 231 | 
train_acc = classification_accuracy(encoder, train_labeled_loader) 232 | val_acc = classification_accuracy(encoder, valid_loader) 233 | 234 | print( 235 | "[Epoch %d/%d] [G loss: %f] [D loss: %f] [D_cat loss: %f] [class_loss: %f] [train_acc: %f] [val_acc: %f]" 236 | % (epoch, args.n_epochs, G_loss.item(), D_loss.item(), D_cat_loss.item(), class_loss.item(), train_acc, val_acc) 237 | ) 238 | 239 | sample_image(n_row = 10, epoch=epoch, img_dir=args.img_dir) 240 | --------------------------------------------------------------------------------