├── .gitignore ├── LICENSE ├── README.md ├── asset ├── c1.png └── c2.png ├── dataset └── _empty_ ├── main.py ├── model.py ├── tmp └── _empty_ └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ping-Yao, Chang 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # infoGAN-pytorch 2 | Implementation of infoGAN using PyTorch 3 | 4 | ## Results 5 | **Categorical variable (K) and continuous variable (c1)**
6 | Rows represent the categorical variable from K = 0 to K = 9 (top to bottom), which characterizes the digit.
7 | Columns represent the continuous variable c1 (width), varying from -1 to 1 (left to right).
8 | ![c1](./asset/c1.png) 9 | 10 | **Categorical variable (K) and continuous variable (c2)**
11 | Rows represent the categorical variable from K = 0 to K = 9 (top to bottom), which characterizes the digit.
12 | Columns represent the continuous variable c2 (rotation), varying from -1 to 1 (left to right).
"""InfoGAN networks for 1x28x28 images: shared front end, D and Q heads, and G."""
import torch.nn as nn


class FrontEnd(nn.Module):
    """Convolutional trunk whose output is fed to both ``D`` and ``Q``.

    Two stride-2 convolutions followed by a 7x7 convolution turn a
    ``(N, 1, 28, 28)`` input into a ``(N, 1024, 1, 1)`` feature map.
    """

    def __init__(self):
        super(FrontEnd, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(128, 1024, 7, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1, inplace=True),
        )

    def forward(self, x):
        """Return the feature map computed by the trunk for ``x``."""
        return self.main(x)


class D(nn.Module):
    """Discriminator head: 1x1 convolution plus sigmoid on the trunk features."""

    def __init__(self):
        super(D, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(1024, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the sigmoid outputs reshaped to ``(-1, 1)``."""
        return self.main(x).view(-1, 1)


class Q(nn.Module):
    """Latent-code head operating on the trunk features.

    Produces logits over the 10 categorical codes plus a mean and a
    variance for each of the 2 continuous codes.
    """

    def __init__(self):
        super(Q, self).__init__()

        self.conv = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn = nn.BatchNorm2d(128)
        self.lReLU = nn.LeakyReLU(0.1, inplace=True)
        self.conv_disc = nn.Conv2d(128, 10, 1)
        self.conv_mu = nn.Conv2d(128, 2, 1)
        self.conv_var = nn.Conv2d(128, 2, 1)

    @staticmethod
    def _drop_spatial(t):
        """Squeeze the two trailing spatial axes (when they are 1), never the batch axis."""
        return t.squeeze(3).squeeze(2)

    def forward(self, x):
        """Return ``(disc_logits, mu, var)`` for the feature map ``x``.

        The hidden activation goes through the batch norm and LeakyReLU
        declared in ``__init__``; previously both layers were built but
        never applied, so the heads saw the raw convolution output.
        """
        hidden = self.lReLU(self.bn(self.conv(x)))

        # A bare ``squeeze()`` would also remove the batch axis for a batch
        # of one, so only the spatial axes are squeezed here.
        disc_logits = self._drop_spatial(self.conv_disc(hidden))
        mu = self._drop_spatial(self.conv_mu(hidden))
        # The head output is exponentiated so the variance is positive.
        var = self._drop_spatial(self.conv_var(hidden)).exp()

        return disc_logits, mu, var


class G(nn.Module):
    """Generator: maps a ``(N, 74, 1, 1)`` latent to a ``(N, 1, 28, 28)`` image."""

    def __init__(self):
        super(G, self).__init__()

        self.main = nn.Sequential(
            nn.ConvTranspose2d(74, 1024, 1, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 128, 7, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the generated images (sigmoid output, values in [0, 1])."""
        return self.main(x)


def weights_init(m):
    """Initialise ``m`` in place; intended for ``module.apply(weights_init)``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02); modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and zero bias. Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


if __name__ == "__main__":
    # Training driver (the main.py script): build the four networks, move
    # them to the GPU, initialise their weights and run the trainer.
    from trainer import Trainer

    fe = FrontEnd()
    d = D()
    q = Q()
    g = G()

    for net in [fe, d, q, g]:
        net.cuda()
        net.apply(weights_init)

    trainer = Trainer(g, fe, d, q)
    trainer.train()
"""Training loop for InfoGAN on MNIST."""
import os

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

# NOTE: the torchvision imports (datasets, transforms, save_image) are done
# inside ``Trainer.train`` so that the loss and sampling helpers in this
# module can be imported without torchvision being installed.


class log_gaussian:
    """Negative Gaussian log-likelihood used as the continuous-code loss."""

    def __call__(self, x, mu, var):
        """Return the negated mean log-density of ``x`` under N(mu, var).

        The per-element log-density is summed over dim 1 and the result is
        averaged; ``1e-6`` is added inside the log and the denominator to
        keep both finite when ``var`` is close to zero.
        """
        log_norm = (var.mul(2 * np.pi) + 1e-6).log().mul(-0.5)
        sq_err = (x - mu).pow(2).div(var.mul(2.0) + 1e-6)
        logli = log_norm - sq_err

        return logli.sum(1).mean().mul(-1)


class Trainer:
    """Trains the generator, front end, discriminator and Q networks."""

    def __init__(self, G, FE, D, Q):
        """Store the four networks; the batch size is fixed at 100."""
        self.G = G
        self.FE = FE
        self.D = D
        self.Q = Q

        self.batch_size = 100

    def _noise_sample(self, dis_c, con_c, noise, bs):
        """Fill the latent buffers in place and return ``(z, idx)``.

        ``dis_c`` receives ``bs`` random one-hot rows over 10 categories,
        ``con_c`` and ``noise`` are refilled uniformly in [-1, 1].  ``z`` is
        the concatenation ``[noise, dis_c, con_c]`` viewed as
        ``(-1, 74, 1, 1)`` and ``idx`` is the numpy array of drawn categories.
        """
        idx = np.random.randint(10, size=bs)
        c = np.zeros((bs, 10))
        c[range(bs), idx] = 1.0

        dis_c.data.copy_(torch.Tensor(c))
        con_c.data.uniform_(-1.0, 1.0)
        noise.data.uniform_(-1.0, 1.0)
        z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)

        return z, idx

    def _fixed_latents(self):
        """Build the two fixed 100-sample latent batches used for snapshots.

        Row ``r`` uses category ``r // 10`` and continuous value
        ``linspace(-1, 1, 10)[r % 10]``; the first batch varies c1 with c2
        held at 0, the second varies c2 with c1 held at 0.  Both share one
        uniform noise draw.  Returns two CPU float tensors of shape
        ``(100, 74, 1, 1)``.
        """
        c = np.linspace(-1, 1, 10).reshape(1, -1)
        c = np.repeat(c, 10, 0).reshape(-1, 1)

        c1 = torch.from_numpy(np.hstack([c, np.zeros_like(c)])).float()
        c2 = torch.from_numpy(np.hstack([np.zeros_like(c), c])).float()

        idx = np.arange(10).repeat(10)
        one_hot = np.zeros((100, 10))
        one_hot[range(100), idx] = 1
        one_hot = torch.Tensor(one_hot)
        fix_noise = torch.Tensor(100, 62).uniform_(-1, 1)

        z_c1 = torch.cat([fix_noise, one_hot, c1], 1).view(-1, 74, 1, 1)
        z_c2 = torch.cat([fix_noise, one_hot, c2], 1).view(-1, 74, 1, 1)
        return z_c1, z_c2

    def train(self, num_epochs=100):
        """Run InfoGAN training on MNIST using CUDA.

        Downloads MNIST into ``./dataset`` if needed.  Every 100 iterations
        the losses are printed and two 10x10 sample grids are written to
        ``./tmp/c1.png`` and ``./tmp/c2.png``.

        Args:
            num_epochs: number of passes over the dataset (default 100).
        """
        import torchvision.datasets as dset
        import torchvision.transforms as transforms
        from torchvision.utils import save_image

        # Reusable GPU buffers, resized per batch and filled in place.
        real_x = Variable(torch.FloatTensor(self.batch_size, 1, 28, 28).cuda())
        label = Variable(torch.FloatTensor(self.batch_size, 1).cuda(),
                         requires_grad=False)
        dis_c = Variable(torch.FloatTensor(self.batch_size, 10).cuda())
        con_c = Variable(torch.FloatTensor(self.batch_size, 2).cuda())
        noise = Variable(torch.FloatTensor(self.batch_size, 62).cuda())

        criterionD = nn.BCELoss().cuda()
        criterionQ_dis = nn.CrossEntropyLoss().cuda()
        criterionQ_con = log_gaussian()

        # FE and D are stepped by optimD; G and Q are stepped by optimG.
        optimD = optim.Adam([{'params': self.FE.parameters()},
                             {'params': self.D.parameters()}],
                            lr=0.0002, betas=(0.5, 0.99))
        optimG = optim.Adam([{'params': self.G.parameters()},
                             {'params': self.Q.parameters()}],
                            lr=0.001, betas=(0.5, 0.99))

        dataset = dset.MNIST('./dataset', transform=transforms.ToTensor(),
                             download=True)
        dataloader = DataLoader(dataset, batch_size=self.batch_size,
                                shuffle=True, num_workers=1)

        # Fixed latents live in their own tensors: copying them into the
        # training buffers only works when the batch size is exactly 100.
        z_c1, z_c2 = self._fixed_latents()
        fixed_z_c1 = Variable(z_c1.cuda())
        fixed_z_c2 = Variable(z_c2.cuda())

        # save_image does not create missing directories.
        if not os.path.isdir('./tmp'):
            os.makedirs('./tmp')

        for epoch in range(num_epochs):
            for num_iters, batch_data in enumerate(dataloader, 0):

                # real part
                optimD.zero_grad()

                x, _ = batch_data

                bs = x.size(0)
                real_x.data.resize_(x.size())
                label.data.resize_(bs, 1)
                dis_c.data.resize_(bs, 10)
                con_c.data.resize_(bs, 2)
                noise.data.resize_(bs, 62)

                real_x.data.copy_(x)
                fe_out1 = self.FE(real_x)
                probs_real = self.D(fe_out1)
                label.data.fill_(1)
                loss_real = criterionD(probs_real, label)
                loss_real.backward()

                # fake part; detach() keeps this loss from reaching G
                z, idx = self._noise_sample(dis_c, con_c, noise, bs)
                fake_x = self.G(z)
                fe_out2 = self.FE(fake_x.detach())
                probs_fake = self.D(fe_out2)
                label.data.fill_(0)
                loss_fake = criterionD(probs_fake, label)
                loss_fake.backward()

                D_loss = loss_real + loss_fake

                optimD.step()

                # G and Q part
                optimG.zero_grad()

                fe_out = self.FE(fake_x)
                probs_fake = self.D(fe_out)
                label.data.fill_(1.0)

                reconstruct_loss = criterionD(probs_fake, label)

                q_logits, q_mu, q_var = self.Q(fe_out)
                class_ = torch.LongTensor(idx).cuda()
                target = Variable(class_)
                dis_loss = criterionQ_dis(q_logits, target)
                con_loss = criterionQ_con(con_c, q_mu, q_var) * 0.1

                G_loss = reconstruct_loss + dis_loss + con_loss
                G_loss.backward()
                optimG.step()

                if num_iters % 100 == 0:

                    print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
                        epoch, num_iters, D_loss.data.cpu().numpy(),
                        G_loss.data.cpu().numpy())
                    )

                    x_save = self.G(fixed_z_c1)
                    save_image(x_save.data, './tmp/c1.png', nrow=10)

                    x_save = self.G(fixed_z_c2)
                    save_image(x_save.data, './tmp/c2.png', nrow=10)