├── .gitignore
├── LICENSE
├── README.md
├── asset
├── c1.png
└── c2.png
├── dataset
└── _empty_
├── main.py
├── model.py
├── tmp
└── _empty_
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ping-Yao, Chang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # infoGAN-pytorch
2 | Implementation of infoGAN using PyTorch
3 |
4 | ## Results
5 | **Categorical variable (K) and continuous variable (c1)**
6 | Each row corresponds to one value of the categorical variable, from K = 0 to K = 9 (top to bottom), which characterizes the digit.
7 | Each column corresponds to a value of the continuous variable c1 (width), varying from -1 to 1 (left to right).
8 | 
9 |
10 | **Categorical variable (K) and continuous variable (c2)**
11 | Each row corresponds to one value of the categorical variable, from K = 0 to K = 9 (top to bottom), which characterizes the digit.
12 | Each column corresponds to a value of the continuous variable c2 (rotation), varying from -1 to 1 (left to right).
13 | 
14 |
--------------------------------------------------------------------------------
/asset/c1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pianomania/infoGAN-pytorch/db6fec7de791d534912be9999430cccdee85709a/asset/c1.png
--------------------------------------------------------------------------------
/asset/c2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pianomania/infoGAN-pytorch/db6fec7de791d534912be9999430cccdee85709a/asset/c2.png
--------------------------------------------------------------------------------
/dataset/_empty_:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pianomania/infoGAN-pytorch/db6fec7de791d534912be9999430cccdee85709a/dataset/_empty_
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""Entry point: build the InfoGAN networks and start training."""
import torch

from model import D, FrontEnd, G, Q, weights_init
from trainer import Trainer


def main():
    """Instantiate the four networks, initialise their weights and train.

    The networks are created in the order FrontEnd, D, Q, G. Each one is
    moved to CUDA when it is available (otherwise it stays on the CPU) and
    then initialised with ``weights_init`` before being handed to the
    ``Trainer``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    fe = FrontEnd()
    d = D()
    q = Q()
    g = G()

    for net in (fe, d, q, g):
        net.to(device)
        net.apply(weights_init)

    trainer = Trainer(g, fe, d, q)
    trainer.train()


# The guard matters: training starts a DataLoader worker process, and on
# platforms that spawn workers the module is re-imported in the child.
# Without the guard that re-import would start training again.
if __name__ == '__main__':
    main()
16 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class FrontEnd(nn.Module):
    """Convolutional trunk whose features feed both the D head and the Q head.

    Two stride-2 4x4 convolutions followed by a 7x7 convolution. For a
    1x28x28 input the spatial size goes 28 -> 14 -> 7 -> 1, so the output
    is a (N, 1024, 1, 1) feature map.
    """

    def __init__(self):
        super(FrontEnd, self).__init__()

        layers = [
            # 1 -> 64 channels, halves the spatial size
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.1, inplace=True),
            # 64 -> 128 channels, halves the spatial size again
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            # 7x7 kernel without padding collapses a 7x7 map to 1x1
            nn.Conv2d(128, 1024, 7, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the feature map produced by the convolutional stack."""
        return self.main(x)
24 |
25 |
class D(nn.Module):
    """Discriminator head: a 1x1 convolution followed by a sigmoid.

    Consumes a 1024-channel feature map and returns one probability per
    row, reshaped to two dimensions with a single column.
    """

    def __init__(self):
        super(D, self).__init__()

        self.main = nn.Sequential(nn.Conv2d(1024, 1, 1), nn.Sigmoid())

    def forward(self, x):
        """Return sigmoid scores flattened to shape (-1, 1)."""
        probs = self.main(x)
        return probs.view(-1, 1)
40 |
41 |
class Q(nn.Module):
    """Recognition head that predicts the latent codes from shared features.

    A 1x1 convolution (1024 -> 128) followed by batch normalisation and a
    leaky ReLU feeds three 1x1 convolution heads:

    * ``conv_disc`` -- 10 logits for the categorical code,
    * ``conv_mu``   -- 2 means for the continuous codes,
    * ``conv_var``  -- 2 log-variances, exponentiated so the returned
      variance is strictly positive.
    """

    def __init__(self):
        super(Q, self).__init__()

        # bias is omitted because BatchNorm follows and supplies its own shift
        self.conv = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn = nn.BatchNorm2d(128)
        self.lReLU = nn.LeakyReLU(0.1, inplace=True)
        self.conv_disc = nn.Conv2d(128, 10, 1)
        self.conv_mu = nn.Conv2d(128, 2, 1)
        self.conv_var = nn.Conv2d(128, 2, 1)

    @staticmethod
    def _drop_spatial(t):
        """Squeeze only the two trailing (spatial) axes, never the batch axis."""
        return t.squeeze(-1).squeeze(-1)

    def forward(self, x):
        """Return ``(disc_logits, mu, var)`` for the feature map ``x``.

        For a (N, 1024, 1, 1) input the outputs have shapes (N, 10), (N, 2)
        and (N, 2), including when N == 1.
        """
        # Apply the normalisation and non-linearity declared in __init__;
        # previously they were constructed but skipped here.
        y = self.lReLU(self.bn(self.conv(x)))

        # A bare .squeeze() would also remove the batch axis when N == 1.
        disc_logits = self._drop_spatial(self.conv_disc(y))

        mu = self._drop_spatial(self.conv_mu(y))
        var = self._drop_spatial(self.conv_var(y)).exp()

        return disc_logits, mu, var
64 |
65 |
class G(nn.Module):
    """Generator: maps a 74-channel 1x1 latent to a single-channel image.

    Four transposed convolutions grow the spatial size 1 -> 1 -> 7 -> 14
    -> 28; the final sigmoid keeps pixel values in [0, 1].
    """

    def __init__(self):
        super(G, self).__init__()

        stages = []
        # (in_channels, out_channels, conv-transpose positional args)
        hidden = (
            (74, 1024, (1, 1)),
            (1024, 128, (7, 1)),
            (128, 64, (4, 2, 1)),
        )
        for c_in, c_out, conv_args in hidden:
            stages.append(nn.ConvTranspose2d(c_in, c_out, *conv_args, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(True))

        # output layer: no normalisation, sigmoid activation
        stages.append(nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False))
        stages.append(nn.Sigmoid())

        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return the generated image batch for latent input ``x``."""
        return self.main(x)
88 |
def weights_init(m):
    """Initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Any other module is left alone.
    """
    kind = type(m).__name__

    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
--------------------------------------------------------------------------------
/tmp/_empty_:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pianomania/infoGAN-pytorch/db6fec7de791d534912be9999430cccdee85709a/tmp/_empty_
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.autograd import Variable
5 | import torch.autograd as autograd
6 |
7 | import torchvision.datasets as dset
8 | import torchvision.transforms as transforms
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import save_image
11 |
12 | import numpy as np
13 |
class log_gaussian:
    """Callable loss: negative log-likelihood under a diagonal Gaussian."""

    def __call__(self, x, mu, var):
        """Return the negated, batch-averaged Gaussian log-likelihood of ``x``.

        The per-element log-density is summed over dimension 1, averaged
        over the remaining entries and negated. A 1e-6 term guards both the
        logarithm and the division.
        """
        # normalisation term: -0.5 * log(2*pi*var)
        log_norm = (var.mul(2 * np.pi) + 1e-6).log().mul(-0.5)
        # quadratic term: (x - mu)^2 / (2*var)
        quad = (x - mu).pow(2).div(var.mul(2.0) + 1e-6)

        logli = log_norm - quad

        return logli.sum(1).mean().mul(-1)
22 |
class Trainer:
    """Trains an InfoGAN (generator, shared front end, D head, Q head) on MNIST.

    The latent vector has 74 entries: 62 noise values, a 10-way one-hot
    categorical code and 2 continuous codes.
    """

    def __init__(self, G, FE, D, Q):
        """Store the four networks and set the batch size to 100."""
        self.G = G
        self.FE = FE
        self.D = D
        self.Q = Q

        self.batch_size = 100

    def _noise_sample(self, dis_c, con_c, noise, bs):
        """Fill the three latent buffers in place and assemble the latent batch.

        Args:
            dis_c: (bs, 10) buffer, overwritten with random one-hot rows.
            con_c: (bs, 2) buffer, overwritten with uniform(-1, 1) values.
            noise: (bs, 62) buffer, overwritten with uniform(-1, 1) values.
            bs: number of rows to sample.

        Returns:
            ``(z, idx)``: ``z`` is the concatenation [noise, dis_c, con_c]
            reshaped to (-1, 74, 1, 1); ``idx`` is the NumPy array of the
            sampled class indices.
        """
        idx = np.random.randint(10, size=bs)
        one_hot = np.zeros((bs, 10))
        one_hot[range(bs), idx] = 1.0

        dis_c.data.copy_(torch.Tensor(one_hot))
        con_c.data.uniform_(-1.0, 1.0)
        noise.data.uniform_(-1.0, 1.0)
        z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)

        return z, idx

    def _fixed_latents(self, device):
        """Build the two fixed 100-row latent grids used for preview images.

        Both grids share one noise draw and use class k for rows 10k..10k+9.
        The first sweeps continuous code 1 over linspace(-1, 1, 10) with
        code 2 at zero; the second sweeps code 2 with code 1 at zero.
        """
        sweep = np.linspace(-1, 1, 10).reshape(1, -1)
        sweep = np.repeat(sweep, 10, 0).reshape(-1, 1)
        flat = np.zeros_like(sweep)

        classes = np.arange(10).repeat(10)
        one_hot = np.zeros((100, 10))
        one_hot[range(100), classes] = 1

        fix_noise = torch.Tensor(100, 62).uniform_(-1, 1)
        dis = torch.Tensor(one_hot)

        grids = []
        for con in (np.hstack([sweep, flat]), np.hstack([flat, sweep])):
            con_t = torch.from_numpy(con).float()
            z = torch.cat([fix_noise, dis, con_t], 1).view(-1, 74, 1, 1)
            grids.append(z.to(device))
        return grids

    def train(self):
        """Run 100 epochs of InfoGAN training on MNIST.

        All tensors are created on the device that holds the generator's
        parameters, so training runs wherever the caller placed the
        networks. Every 100 iterations the losses are printed and two
        preview grids are written to ``./tmp/c1.png`` and ``./tmp/c2.png``.
        """
        device = next(self.G.parameters()).device

        criterionD = nn.BCELoss()
        criterionQ_dis = nn.CrossEntropyLoss()
        criterionQ_con = log_gaussian()

        optimD = optim.Adam([{'params': self.FE.parameters()}, {'params': self.D.parameters()}], lr=0.0002, betas=(0.5, 0.99))
        optimG = optim.Adam([{'params': self.G.parameters()}, {'params': self.Q.parameters()}], lr=0.001, betas=(0.5, 0.99))

        dataset = dset.MNIST('./dataset', transform=transforms.ToTensor(), download=True)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=1)

        # Fixed latents are kept separate from the per-batch buffers so the
        # 10x10 preview grid does not depend on the current batch size.
        z_c1, z_c2 = self._fixed_latents(device)

        for epoch in range(100):
            for num_iters, batch_data in enumerate(dataloader, 0):

                x, _ = batch_data
                bs = x.size(0)

                real_x = x.to(device=device, dtype=torch.float)
                real_label = torch.ones(bs, 1, device=device)
                fake_label = torch.zeros(bs, 1, device=device)
                dis_c = torch.empty(bs, 10, device=device)
                con_c = torch.empty(bs, 2, device=device)
                noise = torch.empty(bs, 62, device=device)

                # ---- discriminator step: real part
                optimD.zero_grad()

                probs_real = self.D(self.FE(real_x))
                loss_real = criterionD(probs_real, real_label)
                loss_real.backward()

                # ---- discriminator step: fake part
                z, idx = self._noise_sample(dis_c, con_c, noise, bs)
                fake_x = self.G(z)
                # detach so this pass does not backpropagate into G
                probs_fake = self.D(self.FE(fake_x.detach()))
                loss_fake = criterionD(probs_fake, fake_label)
                loss_fake.backward()

                D_loss = loss_real + loss_fake

                optimD.step()

                # ---- generator and Q step
                optimG.zero_grad()

                fe_out = self.FE(fake_x)
                probs_fake = self.D(fe_out)

                # G is rewarded when D labels its samples as real
                reconstruct_loss = criterionD(probs_fake, real_label)

                q_logits, q_mu, q_var = self.Q(fe_out)
                target = torch.LongTensor(idx).to(device)
                dis_loss = criterionQ_dis(q_logits, target)
                con_loss = criterionQ_con(con_c, q_mu, q_var) * 0.1

                G_loss = reconstruct_loss + dis_loss + con_loss
                G_loss.backward()
                optimG.step()

                if num_iters % 100 == 0:

                    print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
                        epoch, num_iters, D_loss.detach().cpu().numpy(),
                        G_loss.detach().cpu().numpy())
                    )

                    # previews need no gradients; skip building a graph
                    with torch.no_grad():
                        save_image(self.G(z_c1), './tmp/c1.png', nrow=10)
                        save_image(self.G(z_c2), './tmp/c2.png', nrow=10)
157 |
--------------------------------------------------------------------------------