├── dataloader.py
├── misc.py
├── makeLabel.py
├── README.md
├── models.py
├── main.py
└── CAAE_128_jupyter.ipynb
/dataloader.py:
--------------------------------------------------------------------------------
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils  # re-exported: main.py uses vutils via its star import
from PIL import ImageFile
import torch

# let PIL load truncated image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True


def loadImgs(des_dir="./data/", img_size=128, batchSize=20):
    # build a dataset over the class subfolders created by makeLabel.py,
    # resizing each (square) face to img_size and normalizing to [-1, 1]
    dataset = dset.ImageFolder(root=des_dir,
                               transform=transforms.Compose([
                                   transforms.Resize(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchSize,
                                             shuffle=True)

    return dataloader
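
# A minimal usage sketch (assumes makeLabel.py has already sorted the images
# into ./data/<class>/ folders):
#
#   loader = loadImgs(des_dir="./data/", img_size=128, batchSize=20)
#   imgs, labels = next(iter(loader))
#   # imgs:   FloatTensor of shape (20, 3, 128, 128), values in [-1, 1]
#   # labels: LongTensor of folder indices in 0..19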
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
from torch import nn
from torch import optim
from torch.autograd import Variable
import torch


def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv/linear weights,
    # N(1, 0.02) for BatchNorm scales with zeroed BatchNorm biases
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find("Linear") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def one_hot(labelTensor, batchSize, n_l, use_cuda=False):
    # "one-hot" encoding with -1 (rather than 0) in the off positions,
    # matching the [-1, 1] output range of the Tanh generator
    oneHot = -torch.ones(batchSize * n_l).view(batchSize, n_l)
    for i, j in enumerate(labelTensor):
        oneHot[i, j] = 1
    if use_cuda:
        return Variable(oneHot).cuda()
    else:
        return Variable(oneHot)

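# Illustrative values (a sketch): for labels [2, 0] with n_l = 4 the
# encoding is
#
#   [[-1, -1,  1, -1],
#    [ 1, -1, -1, -1]]
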
def TV_LOSS(imgTensor, img_size=128):
    # squared differences of vertically and horizontally adjacent pixels
    x = (imgTensor[:, :, 1:, :] - imgTensor[:, :, :img_size - 1, :]) ** 2
    y = (imgTensor[:, :, :, 1:] - imgTensor[:, :, :, :img_size - 1]) ** 2

    # x.mean(dim=2) is (N, C, W) and y.mean(dim=3) is (N, C, H); adding them
    # relies on the images being square (H == W == img_size)
    out = (x.mean(dim=2) + y.mean(dim=3)).mean()
    return out
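
# In formula form this is a squared total-variation penalty,
#
#   TV(I) = mean[(I[h+1, w] - I[h, w])^2] + mean[(I[h, w+1] - I[h, w])^2],
#
# which discourages high-frequency artifacts in the generator output.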
--------------------------------------------------------------------------------
/makeLabel.py:
--------------------------------------------------------------------------------
import os

origin_dir = "./UTKFace"
des_dir = "./data"
imgFiles = os.listdir(origin_dir)


def encodeAge(n):
    # bucket an age in years into one of 10 classes
    if n <= 5:
        return 0
    elif n <= 10:
        return 1
    elif n <= 15:
        return 2
    elif n <= 20:
        return 3
    elif n <= 30:
        return 4
    elif n <= 40:
        return 5
    elif n <= 50:
        return 6
    elif n <= 60:
        return 7
    elif n <= 70:
        return 8
    else:
        return 9


def makeDir():
    if not os.path.exists(des_dir):
        os.mkdir(des_dir)

    # one folder per (age bucket, gender) pair; zero-padded names "00".."19"
    # keep ImageFolder's alphabetical class order aligned with the numeric one
    for i in range(20):
        new_folder = os.path.join(des_dir, format(i, "02"))
        if not os.path.exists(new_folder):
            os.mkdir(new_folder)


def moveFiles():
    # UTKFace file names follow [age]_[gender]_[race]_[date&time].jpg
    for file in imgFiles:
        lst = file.split("_")

        age = int(lst[0])
        gender = int(lst[1])

        folder = format(encodeAge(age) * 2 + gender, "02")
        origin_file = os.path.join(origin_dir, file)
        des_file = os.path.join(des_dir, folder, file)
        os.rename(origin_file, des_file)
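
# Worked example of the labelling scheme (hypothetical file name):
#
#   "25_1_0_20170116.jpg"  ->  age 25, gender 1
#   encodeAge(25) = 4, so the file lands in folder format(4*2 + 1, "02") = "09"
#
# ImageFolder later assigns class index 9 to folder "09", from which main.py
# recovers the age bucket as 9 // 2 = 4 and the gender as 9 % 2 = 1.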
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Face-Aging-CAAE-Pytorch

* PyTorch implementation of [Age Progression/Regression by Conditional Adversarial Autoencoder](http://web.eecs.utk.edu/~zzhang61/docs/papers/2017_CVPR_Age.pdf)
* Reference: [TensorFlow implementation of CAAE](https://github.com/ZZUTK/Face-Aging-CAAE)
* Presented at the [2017 YONSEI BIGDATA CONFERENCE](https://onoffmix.com/event/121883) by team FACEBIGTA.


## Requirements
* PyTorch 0.2.0
* [UTKFace Aligned&Cropped](https://drive.google.com/drive/folders/0BxYys69jI14kU0I1YUQyY1ZDRUE) dataset

## Usage
* Clone or download a zip of this repository.
* Download the Aligned & Cropped version of UTKFace from [here](https://drive.google.com/drive/folders/0BxYys69jI14kU0I1YUQyY1ZDRUE) and extract it to `./UTKFace`.
* Run main.py:
> python main.py

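Running `main.py` first sorts the images in `./UTKFace` into 20 class folders (10 age buckets x 2 genders) under `./data` via `makeLabel.py`, then trains for 50 epochs, writing sample grids and model checkpoints to `./result_tv_gender`.

After training, age progression/regression for a single face can be sketched as below (the checkpoint epoch `041` is hypothetical; use whichever checkpoint was saved):

```python
import torch
from torch.autograd import Variable
import torchvision.utils as vutils
from models import Encoder, Generator, n_l
from dataloader import loadImgs

netE, netG = Encoder(), Generator()
# add map_location='cpu' to torch.load if the checkpoint was saved on GPU
netE.load_state_dict(torch.load("result_tv_gender/netE_041.pth"))
netG.load_state_dict(torch.load("result_tv_gender/netG_041.pth"))

imgs, labels = next(iter(loadImgs()))
img_v = Variable(imgs[:1].repeat(10, 1, 1, 1))      # one face, copied per age class
ages = -torch.ones(10, n_l)                         # row i conditions on age class i
for i in range(n_l):
    ages[i, i] = 1
gender = (labels[:1] % 2 * 2 - 1).float().view(-1, 1).repeat(10, 1)

z = netE(img_v)                                     # (10, 50) latent codes
fake = netG(z, Variable(ages), Variable(gender))    # (10, 3, 128, 128) aged faces
vutils.save_image(fake.data, "aged.png", normalize=True)
```
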
## Results

**UTKFace**
> rows: ages 0~5, 6~10, 11~15, 16~20, 21~30, 31~40, 41~50, 51~60, 61~70, over 70

**Irene, Korean Celebrity**
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
## model definitions

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable


n_channel = 3                        # RGB input
n_disc = 16                          # base channel count of the discriminators
n_gen = 64                           # base channel count of the generator
n_encode = 64                        # base channel count of the encoder
n_l = 10                             # number of age classes
n_z = 50                             # latent dimension
img_size = 128
batchSize = 20
use_cuda = torch.cuda.is_available()
n_age = n_z // n_l                   # each age bit is repeated 5 times
n_gender = n_z // 2                  # the gender bit is repeated 25 times

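# Condition bookkeeping: the generator consumes the 50-dim latent code plus a
# duplicated age one-hot (n_l * n_age = 50 dims) and a duplicated gender bit
# (n_gender = 25 dims), i.e. an input of n_z + n_l*n_age + n_gender = 125 dims.
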
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            # input: 3 x 128 x 128
            nn.Conv2d(n_channel, n_encode, 5, 2, 2),
            nn.ReLU(),
            # n_encode x 64 x 64
            nn.Conv2d(n_encode, 2 * n_encode, 5, 2, 2),
            nn.ReLU(),
            # 2*n_encode x 32 x 32
            nn.Conv2d(2 * n_encode, 4 * n_encode, 5, 2, 2),
            nn.ReLU(),
            # 4*n_encode x 16 x 16
            nn.Conv2d(4 * n_encode, 8 * n_encode, 5, 2, 2),
            nn.ReLU(),
            # 8*n_encode x 8 x 8
        )
        self.fc = nn.Linear(8 * n_encode * 8 * 8, n_z)

    def forward(self, x):
        conv = self.conv(x).view(-1, 8 * n_encode * 8 * 8)
        out = self.fc(conv)
        return out

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(nn.Linear(n_z + n_l * n_age + n_gender,
                                          8 * 8 * n_gen * 16),
                                nn.ReLU())
        self.upconv = nn.Sequential(
            # 16*n_gen x 8 x 8
            nn.ConvTranspose2d(16 * n_gen, 8 * n_gen, 4, 2, 1),
            nn.ReLU(),
            # 8*n_gen x 16 x 16
            nn.ConvTranspose2d(8 * n_gen, 4 * n_gen, 4, 2, 1),
            nn.ReLU(),
            # 4*n_gen x 32 x 32
            nn.ConvTranspose2d(4 * n_gen, 2 * n_gen, 4, 2, 1),
            nn.ReLU(),
            # 2*n_gen x 64 x 64
            nn.ConvTranspose2d(2 * n_gen, n_gen, 4, 2, 1),
            nn.ReLU(),
            # n_gen x 128 x 128
            nn.ConvTranspose2d(n_gen, n_channel, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z, age, gender):
        ## duplicate the age & gender conditions as described in
        ## https://github.com/ZZUTK/Face-Aging-CAAE
        l = age.repeat(1, n_age).float()
        k = gender.view(-1, 1).repeat(1, n_gender).float()

        x = torch.cat([z, l, k], dim=1)
        fc = self.fc(x).view(-1, 16 * n_gen, 8, 8)
        out = self.upconv(fc)
        return out

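# Shape trace through Generator.forward (batch size B, a sketch):
#   z (B, 50), age (B, 10) -> l (B, 50), gender (B, 1) -> k (B, 25)
#   cat -> (B, 125) -> fc -> (B, 1024, 8, 8) -> upconv -> (B, 3, 128, 128)
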
class Dimg(nn.Module):
    def __init__(self):
        super(Dimg, self).__init__()
        self.conv_img = nn.Sequential(
            nn.Conv2d(n_channel, n_disc, 4, 2, 1),
        )
        # broadcast the flat condition vector up to a 64 x 64 feature map so
        # it can be concatenated channel-wise with the image features
        self.conv_l = nn.Sequential(
            nn.ConvTranspose2d(n_l * n_age + n_gender, n_l * n_age + n_gender, 64, 1, 0),
            nn.ReLU()
        )
        self.total_conv = nn.Sequential(
            nn.Conv2d(n_disc + n_l * n_age + n_gender, n_disc * 2, 4, 2, 1),
            nn.ReLU(),

            nn.Conv2d(n_disc * 2, n_disc * 4, 4, 2, 1),
            nn.ReLU(),

            nn.Conv2d(n_disc * 4, n_disc * 8, 4, 2, 1),
            nn.ReLU()
        )

        # 8*8*img_size equals 8 * 8 * (n_disc * 8) here, since n_disc * 8 == 128
        self.fc_common = nn.Sequential(
            nn.Linear(8 * 8 * img_size, 1024),
            nn.ReLU()
        )
        self.fc_head1 = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
        self.fc_head2 = nn.Sequential(
            nn.Linear(1024, n_l),
            nn.Softmax()
        )

    def forward(self, img, age, gender):
        ## duplicate the age & gender conditions as described in
        ## https://github.com/ZZUTK/Face-Aging-CAAE
        l = age.repeat(1, n_age, 1, 1)
        k = gender.repeat(1, n_gender, 1, 1)
        conv_img = self.conv_img(img)
        conv_l = self.conv_l(torch.cat([l, k], dim=1))
        catted = torch.cat((conv_img, conv_l), dim=1)
        total_conv = self.total_conv(catted).view(-1, 8 * 8 * img_size)
        body = self.fc_common(total_conv)

        head1 = self.fc_head1(body)
        head2 = self.fc_head2(body)

        return head1, head2

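# head1 is the usual real/fake probability; head2 is a softmax over the n_l
# age classes (an auxiliary classifier; main.py receives it as D_clf but does
# not include it in any loss).
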
class Dz(nn.Module):
    def __init__(self):
        super(Dz, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(n_z, n_disc * 4),
            nn.ReLU(),

            nn.Linear(n_disc * 4, n_disc * 2),
            nn.ReLU(),

            nn.Linear(n_disc * 2, n_disc),
            nn.ReLU(),

            nn.Linear(n_disc, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.model(z)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from dataloader import *
from misc import *
from models import *
import pickle
from makeLabel import *
import os

## boolean flag indicating whether cuda is available
use_cuda = torch.cuda.is_available()

## sort the raw UTKFace images into ./data/<class>/ folders
makeDir()
moveFiles()

dataloader = loadImgs()

## build the models, moving them to the GPU if available
if use_cuda:
    netE = Encoder().cuda()
    netD_img = Dimg().cuda()
    netD_z = Dz().cuda()
    netG = Generator().cuda()
else:
    netE = Encoder()
    netD_img = Dimg()
    netD_z = Dz()
    netG = Generator()

## apply weight initialization
netE.apply(weights_init)
netD_img.apply(weights_init)
netD_z.apply(weights_init)
netG.apply(weights_init)

## build an optimizer for each network
optimizerE = optim.Adam(netE.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD_z = optim.Adam(netD_z.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD_img = optim.Adam(netD_img.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

## build the loss criteria, moving them to the GPU if available
if use_cuda:
    BCE = nn.BCELoss().cuda()
    L1 = nn.L1Loss().cuda()
    CE = nn.CrossEntropyLoss().cuda()
    MSE = nn.MSELoss().cuda()
else:
    BCE = nn.BCELoss()
    L1 = nn.L1Loss()
    CE = nn.CrossEntropyLoss()
    MSE = nn.MSELoss()

## fixed age conditions used to regress / progress age in the sample grids:
## 80 rows, with each block of 8 rows sharing one of the 10 age classes
fixed_l = -torch.ones(80 * 10).view(80, 10)
for i, l in enumerate(fixed_l):
    l[i // 8] = 1

fixed_l_v = Variable(fixed_l)

if use_cuda:
    fixed_l_v = fixed_l_v.cuda()

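# Together with fixed_noise below (8 faces repeated 10 times), this yields a
# saved grid of 8 columns x 10 rows: the same 8 faces rendered once per age
# class, which is the row layout described in the README results.
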
outf = './result_tv_gender'

if not os.path.exists(outf):
    os.mkdir(outf)

niter = 50

for epoch in range(niter):
    for i, (img_data, img_label) in enumerate(dataloader):

        # recover age bucket and gender (in {-1, 1}) from the folder index
        img_data_v = Variable(img_data)
        img_age = img_label / 2          # integer division on a LongTensor
        img_gender = img_label % 2 * 2 - 1

        img_age_v = Variable(img_age).view(-1, 1)
        img_gender_v = Variable(img_gender.float())

        # fix the first 8 faces of the first batch as the sample inputs
        if epoch == 0 and i == 0:
            fixed_noise = img_data[:8].repeat(10, 1, 1, 1)
            fixed_g = img_gender[:8].view(-1, 1).repeat(10, 1)

            fixed_img_v = Variable(fixed_noise)
            fixed_g_v = Variable(fixed_g)

            with open("fixed_noise.p", "wb") as f:
                pickle.dump(fixed_noise, f)

            if use_cuda:
                fixed_img_v = fixed_img_v.cuda()
                fixed_g_v = fixed_g_v.cuda()
        if use_cuda:
            img_data_v = img_data_v.cuda()
            img_age_v = img_age_v.cuda()
            img_gender_v = img_gender_v.cuda()

        # one-hot encoded version of the age label
        batchSize = img_data_v.size(0)
        age_ohe = one_hot(img_age, batchSize, n_l, use_cuda)

        # prior distribution z_star, real_label, fake_label
        z_star = Variable(torch.FloatTensor(batchSize * n_z).uniform_(-1, 1)).view(batchSize, n_z)
        real_label = Variable(torch.ones(batchSize)).view(-1, 1)
        fake_label = Variable(torch.zeros(batchSize)).view(-1, 1)

        if use_cuda:
            z_star, real_label, fake_label = z_star.cuda(), real_label.cuda(), fake_label.cuda()

        ## train Encoder and Generator
        netE.zero_grad()
        netG.zero_grad()

        # EG_loss 1. L1 reconstruction loss
        z = netE(img_data_v)
        reconst = netG(z, age_ohe, img_gender_v)
        EG_L1_loss = L1(reconst, img_data_v)

        # EG_loss 2. GAN loss - image
        z = netE(img_data_v)
        reconst = netG(z, age_ohe, img_gender_v)
        D_reconst, _ = netD_img(reconst, age_ohe.view(batchSize, n_l, 1, 1), img_gender_v.view(batchSize, 1, 1, 1))
        G_img_loss = BCE(D_reconst, real_label)

        ## EG_loss 3. GAN loss - z
        Dz_prior = netD_z(z_star)
        Dz = netD_z(z)
        Ez_loss = BCE(Dz, real_label)

        ## EG_loss 4. TV loss - G
        reconst = netG(z.detach(), age_ohe, img_gender_v)
        G_tv_loss = TV_LOSS(reconst)

        EG_loss = EG_L1_loss + 0.0001 * G_img_loss + 0.01 * Ez_loss + G_tv_loss
        EG_loss.backward()

        optimizerE.step()
        optimizerG.step()

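        # The combined encoder/generator objective minimised in this block:
        #   L_EG = L1(G(E(x), l), x)
        #        + 1e-4 * BCE(D_img(G(E(x), l)), 1)
        #        + 1e-2 * BCE(D_z(E(x)), 1)
        #        + TV(G(E(x), l))
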
        ## train netD_z against the prior distribution U(-1, 1)
        netD_z.zero_grad()
        Dz_prior = netD_z(z_star)
        Dz = netD_z(z.detach())

        Dz_loss = BCE(Dz_prior, real_label) + BCE(Dz, fake_label)
        Dz_loss.backward()
        optimizerD_z.step()

        ## train netD_img on real images and reconstructions
        netD_img.zero_grad()
        D_img, D_clf = netD_img(img_data_v, age_ohe.view(batchSize, n_l, 1, 1), img_gender_v.view(batchSize, 1, 1, 1))
        D_reconst, _ = netD_img(reconst.detach(), age_ohe.view(batchSize, n_l, 1, 1), img_gender_v.view(batchSize, 1, 1, 1))

        D_loss = BCE(D_img, real_label) + BCE(D_reconst, fake_label)
        D_loss.backward()
        optimizerD_img.step()

        ## save images generated from the fixed batch every 20 steps
        ## (the file for the current epoch is overwritten as training proceeds)
        if i % 20 == 0:
            fixed_z = netE(fixed_img_v)
            fixed_fake = netG(fixed_z, fixed_l_v, fixed_g_v)
            vutils.save_image(fixed_fake.data,
                              '%s/reconst_epoch%03d.png' % (outf, epoch + 1),
                              normalize=True)

        ## log losses
        msg1 = "epoch:{}, step:{}".format(epoch + 1, i + 1)
        msg2 = format("EG_L1_loss:%f" % (EG_L1_loss.data[0]), "<30") + "|" + format("G_img_loss:%f" % (G_img_loss.data[0]), "<30")
        msg5 = format("G_tv_loss:%f" % (G_tv_loss.data[0]), "<30") + "|" + "Ez_loss:%f" % (Ez_loss.data[0])
        msg3 = format("D_img:%f" % (D_img.mean().data[0]), "<30") + "|" + format("D_reconst:%f" % (D_reconst.mean().data[0]), "<30") \
            + "|" + format("D_loss:%f" % (D_loss.data[0]), "<30")
        msg4 = format("D_z:%f" % (Dz.mean().data[0]), "<30") + "|" + format("D_z_prior:%f" % (Dz_prior.mean().data[0]), "<30") \
            + "|" + format("Dz_loss:%f" % (Dz_loss.data[0]), "<30")

        print()
        print(msg1)
        print(msg2)
        print(msg5)
        print(msg3)
        print(msg4)
        print()
        print("-" * 80)

    ## checkpoint every 10 epochs
    if epoch % 10 == 0:
        torch.save(netE.state_dict(), "%s/netE_%03d.pth" % (outf, epoch + 1))
        torch.save(netG.state_dict(), "%s/netG_%03d.pth" % (outf, epoch + 1))
        torch.save(netD_img.state_dict(), "%s/netD_img_%03d.pth" % (outf, epoch + 1))
        torch.save(netD_z.state_dict(), "%s/netD_z_%03d.pth" % (outf, epoch + 1))
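
# To resume training from saved checkpoints (a sketch; the epoch number in the
# file names is whatever was last written, e.g. 041):
#
#   netE.load_state_dict(torch.load("%s/netE_041.pth" % outf))
#   netG.load_state_dict(torch.load("%s/netG_041.pth" % outf))
#   netD_img.load_state_dict(torch.load("%s/netD_img_041.pth" % outf))
#   netD_z.load_state_dict(torch.load("%s/netD_z_041.pth" % outf))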
--------------------------------------------------------------------------------
/CAAE_128_jupyter.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
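  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook version of the training pipeline in `main.py`. It assumes that `makeLabel.py` has already arranged UTKFace into `./data/` and that a previous run of `main.py` has written `fixed_noise.p`; the loop below resumes training at epoch 30."
   ]
  },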
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.autograd import Variable\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as dset\n",
    "import torch.nn.functional as F\n",
    "import torchvision.utils as vutils\n",
    "import pickle\n",
    "from PIL import ImageFile\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_channel = 3\n",
    "n_disc = 16\n",
    "n_gen = 64\n",
    "n_encode = 64\n",
    "n_l = 10\n",
    "n_z = 50\n",
    "img_size = 128\n",
    "batchSize = 20\n",
    "use_cuda = torch.cuda.is_available()\n",
    "n_age = n_z // n_l\n",
    "n_gender = n_z // 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "des_dir = \"./data/\"\n",
    "\n",
    "dataset = dset.ImageFolder(root=des_dir,\n",
    "                           transform=transforms.Compose([\n",
    "                               transforms.Resize(img_size),\n",
    "                               transforms.ToTensor(),\n",
    "                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "                           ]))\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(dataset,\n",
    "                                         batch_size=batchSize,\n",
    "                                         shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            # input: 3 x 128 x 128\n",
    "            nn.Conv2d(n_channel, n_encode, 5, 2, 2),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Conv2d(n_encode, 2 * n_encode, 5, 2, 2),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Conv2d(2 * n_encode, 4 * n_encode, 5, 2, 2),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Conv2d(4 * n_encode, 8 * n_encode, 5, 2, 2),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.fc = nn.Linear(8 * n_encode * 8 * 8, n_z)\n",
    "\n",
    "    def forward(self, x):\n",
    "        conv = self.conv(x).view(-1, 8 * n_encode * 8 * 8)\n",
    "        out = self.fc(conv)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Generator, self).__init__()\n",
    "        self.fc = nn.Sequential(nn.Linear(n_z + n_l * n_age + n_gender,\n",
    "                                          8 * 8 * n_gen * 16),\n",
    "                                nn.ReLU())\n",
    "        self.upconv = nn.Sequential(\n",
    "            nn.ConvTranspose2d(16 * n_gen, 8 * n_gen, 4, 2, 1),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.ConvTranspose2d(8 * n_gen, 4 * n_gen, 4, 2, 1),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.ConvTranspose2d(4 * n_gen, 2 * n_gen, 4, 2, 1),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.ConvTranspose2d(2 * n_gen, n_gen, 4, 2, 1),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.ConvTranspose2d(n_gen, n_channel, 3, 1, 1),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, z, age, gender):\n",
    "        # duplicate the age & gender conditions, as in models.py\n",
    "        l = age.repeat(1, n_age)\n",
    "        k = gender.view(-1, 1).repeat(1, n_gender)\n",
    "\n",
    "        x = torch.cat([z, l, k], dim=1)\n",
    "        fc = self.fc(x).view(-1, 16 * n_gen, 8, 8)\n",
    "        out = self.upconv(fc)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dimg(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Dimg, self).__init__()\n",
    "        self.conv_img = nn.Sequential(\n",
    "            nn.Conv2d(n_channel, n_disc, 4, 2, 1),\n",
    "        )\n",
    "        self.conv_l = nn.Sequential(\n",
    "            nn.ConvTranspose2d(n_l * n_age + n_gender, n_l * n_age + n_gender, 64, 1, 0),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.total_conv = nn.Sequential(\n",
    "            nn.Conv2d(n_disc + n_l * n_age + n_gender, n_disc * 2, 4, 2, 1),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Conv2d(n_disc * 2, n_disc * 4, 4, 2, 1),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Conv2d(n_disc * 4, n_disc * 8, 4, 2, 1),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        self.fc_common = nn.Sequential(\n",
    "            nn.Linear(8 * 8 * img_size, 1024),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc_head1 = nn.Sequential(\n",
    "            nn.Linear(1024, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        self.fc_head2 = nn.Sequential(\n",
    "            nn.Linear(1024, n_l),\n",
    "            nn.Softmax()\n",
    "        )\n",
    "\n",
    "    def forward(self, img, age, gender):\n",
    "        # duplicate the age & gender conditions, as in models.py\n",
    "        l = age.repeat(1, n_age, 1, 1)\n",
    "        k = gender.repeat(1, n_gender, 1, 1)\n",
    "        conv_img = self.conv_img(img)\n",
    "        conv_l = self.conv_l(torch.cat([l, k], dim=1))\n",
    "        catted = torch.cat((conv_img, conv_l), dim=1)\n",
    "        total_conv = self.total_conv(catted).view(-1, 8 * 8 * img_size)\n",
    "        body = self.fc_common(total_conv)\n",
    "\n",
    "        head1 = self.fc_head1(body)\n",
    "        head2 = self.fc_head2(body)\n",
    "\n",
    "        return head1, head2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dz(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Dz, self).__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(n_z, n_disc * 4),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Linear(n_disc * 4, n_disc * 2),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Linear(n_disc * 2, n_disc),\n",
    "            nn.ReLU(),\n",
    "\n",
    "            nn.Linear(n_disc, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, z):\n",
    "        return self.model(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_cuda:\n",
    "    netE = Encoder().cuda()\n",
    "    netD_img = Dimg().cuda()\n",
    "    netD_z = Dz().cuda()\n",
    "    netG = Generator().cuda()\n",
    "else:\n",
    "    netE = Encoder()\n",
    "    netD_img = Dimg()\n",
    "    netD_z = Dz()\n",
    "    netG = Generator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weights_init(m):\n",
    "    classname = m.__class__.__name__\n",
    "    if classname.find('Conv') != -1 or classname.find(\"Linear\") != -1:\n",
    "        m.weight.data.normal_(0.0, 0.02)\n",
    "    elif classname.find('BatchNorm') != -1:\n",
    "        m.weight.data.normal_(1.0, 0.02)\n",
    "        m.bias.data.fill_(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "netE.apply(weights_init)\n",
    "netD_img.apply(weights_init)\n",
    "netD_z.apply(weights_init)\n",
    "netG.apply(weights_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizerE = optim.Adam(netE.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "optimizerD_z = optim.Adam(netD_z.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "optimizerD_img = optim.Adam(netD_img.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def one_hot(labelTensor):\n",
    "    # \"one-hot\" encoding with -1 in the off positions; reads the global\n",
    "    # batchSize, which the training loop updates for each batch\n",
    "    oneHot = -torch.ones(batchSize * n_l).view(batchSize, n_l)\n",
    "    for i, j in enumerate(labelTensor):\n",
    "        oneHot[i, j] = 1\n",
    "    if use_cuda:\n",
    "        return Variable(oneHot).cuda()\n",
    "    else:\n",
    "        return Variable(oneHot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_cuda:\n",
    "    BCE = nn.BCELoss().cuda()\n",
    "    L1 = nn.L1Loss().cuda()\n",
    "    CE = nn.CrossEntropyLoss().cuda()\n",
    "    MSE = nn.MSELoss().cuda()\n",
    "else:\n",
    "    BCE = nn.BCELoss()\n",
    "    L1 = nn.L1Loss()\n",
    "    CE = nn.CrossEntropyLoss()\n",
    "    MSE = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def TV_LOSS(imgTensor):\n",
    "    x = (imgTensor[:, :, 1:, :] - imgTensor[:, :, :img_size - 1, :]) ** 2\n",
    "    y = (imgTensor[:, :, :, 1:] - imgTensor[:, :, :, :img_size - 1]) ** 2\n",
    "    # average over rows / columns respectively (matches TV_LOSS in misc.py)\n",
    "    out = (x.mean(dim=2) + y.mean(dim=3)).mean()\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "niter = 150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_noise = pickle.load(open(\"fixed_noise.p\", \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_l = -torch.ones(80 * 10).view(80, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, l in enumerate(fixed_l):\n",
    "    l[i // 8] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_g = -1 * torch.FloatTensor([1, -1, -1, -1, -1, 1, 1, 1]).view(-1, 1).repeat(10, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_l_v = Variable(fixed_l)\n",
    "fixed_img_v = Variable(fixed_noise)\n",
    "fixed_g_v = Variable(fixed_g)\n",
    "if use_cuda:\n",
    "    fixed_l_v = fixed_l_v.cuda()\n",
    "    fixed_img_v = fixed_img_v.cuda()\n",
    "    fixed_g_v = fixed_g_v.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outf = './result_tv_gender'\n",
    "# create the output folder if it does not exist yet\n",
    "if not os.path.exists(outf):\n",
    "    os.mkdir(outf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for epoch in range(30, niter):\n",
    "    for i, (img_data, img_label) in enumerate(dataloader):\n",
    "\n",
    "        # recover age bucket and gender (in {-1, 1}) from the folder index\n",
    "        img_data_v = Variable(img_data)\n",
    "        img_age = img_label / 2          # integer division on a LongTensor\n",
    "        img_gender = img_label % 2 * 2 - 1\n",
    "\n",
    "        img_age_v = Variable(img_age).view(-1, 1)\n",
    "        img_gender_v = Variable(img_gender.float())\n",
    "\n",
    "        if use_cuda:\n",
    "            img_data_v = img_data_v.cuda()\n",
    "            img_age_v = img_age_v.cuda()\n",
    "            img_gender_v = img_gender_v.cuda()\n",
    "\n",
    "        # one-hot encoded version of the age label\n",
    "        batchSize = img_data_v.size(0)\n",
    "        age_ohe = one_hot(img_age)\n",
    "\n",
    "        # prior distribution z_star, real_label, fake_label\n",
    "        z_star = Variable(torch.FloatTensor(batchSize * n_z).uniform_(-1, 1)).view(batchSize, n_z)\n",
    "        real_label = Variable(torch.ones(batchSize)).view(-1, 1)\n",
    "        fake_label = Variable(torch.zeros(batchSize)).view(-1, 1)\n",
    "\n",
    "        if use_cuda:\n",
    "            z_star, real_label, fake_label = z_star.cuda(), real_label.cuda(), fake_label.cuda()\n",
    "\n",
    "        ## train Encoder and Generator\n",
    "        netE.zero_grad()\n",
    "        netG.zero_grad()\n",
    "\n",
    "        # EG_loss 1. L1 reconstruction loss\n",
    "        z = netE(img_data_v)\n",
    "        reconst = netG(z, age_ohe, img_gender_v)\n",
    "        EG_L1_loss = L1(reconst, img_data_v)\n",
    "\n",
    "        # EG_loss 2. GAN loss - image\n",
    "        z = netE(img_data_v)\n",
    "        reconst = netG(z, age_ohe, img_gender_v)\n",
    "        D_reconst, _ = netD_img(reconst, age_ohe.view(batchSize, n_l, 1, 1), img_gender_v.view(batchSize, 1, 1, 1))\n",
    "        G_img_loss = BCE(D_reconst, real_label)\n",
    "\n",
    "        ## EG_loss 3. GAN loss - z\n",
    "        Dz_prior = netD_z(z_star)\n",
    "        Dz = netD_z(z)\n",
    "        Ez_loss = BCE(Dz, real_label)\n",
    "\n",
    "        ## EG_loss 4. TV loss - G\n",
    "        reconst = netG(z.detach(), age_ohe, img_gender_v)\n",
    "        G_tv_loss = TV_LOSS(reconst)\n",
    "\n",
    "        EG_loss = EG_L1_loss + 0.0001 * G_img_loss + 0.01 * Ez_loss + G_tv_loss\n",
    "        EG_loss.backward()\n",
    "\n",
    "        optimizerE.step()\n",
    "        optimizerG.step()\n",
    "\n",
    "        ## train netD_z against the prior distribution U(-1, 1)\n",
    "        netD_z.zero_grad()\n",
    "        Dz_prior = netD_z(z_star)\n",
    "        Dz = netD_z(z.detach())\n",
    "\n",
    "        Dz_loss = BCE(Dz_prior, real_label) + BCE(Dz, fake_label)\n",
    "        Dz_loss.backward()\n",
    "        optimizerD_z.step()\n",
    "\n",
    "        ## train netD_img on real images and reconstructions\n",
    "        netD_img.zero_grad()\n",
    "        D_img, D_clf = netD_img(img_data_v, age_ohe.view(batchSize, n_l, 1, 1), img_gender_v.view(batchSize, 1, 1, 1))\n",
    "        D_reconst, _ = netD_img(reconst.detach(), age_ohe.view(batchSize, n_l, 1, 1), img_gender_v.view(batchSize, 1, 1, 1))\n",
    "\n",
    "        D_loss = BCE(D_img, real_label) + BCE(D_reconst, fake_label)\n",
    "        D_loss.backward()\n",
    "        optimizerD_img.step()\n",
    "\n",
    "        ## save images generated from the fixed batch every 20 steps\n",
    "        if i % 20 == 0:\n",
    "            fixed_z = netE(fixed_img_v)\n",
    "            fixed_fake = netG(fixed_z, fixed_l_v, fixed_g_v)\n",
    "            vutils.save_image(fixed_fake.data,\n",
    "                              '%s/reconst_epoch%03d.png' % (outf, epoch + 1),\n",
    "                              normalize=True)\n",
    "\n",
    "        ## log losses\n",
    "        msg1 = \"epoch:{}, step:{}\".format(epoch + 1, i + 1)\n",
    "        msg2 = format(\"EG_L1_loss:%f\" % (EG_L1_loss.data[0]), \"<30\") + \"|\" + format(\"G_img_loss:%f\" % (G_img_loss.data[0]), \"<30\")\n",
    "        msg5 = format(\"G_tv_loss:%f\" % (G_tv_loss.data[0]), \"<30\") + \"|\" + \"Ez_loss:%f\" % (Ez_loss.data[0])\n",
    "        msg3 = format(\"D_img:%f\" % (D_img.mean().data[0]), \"<30\") + \"|\" + format(\"D_reconst:%f\" % (D_reconst.mean().data[0]), \"<30\") \\\n",
    "            + \"|\" + format(\"D_loss:%f\" % (D_loss.data[0]), \"<30\")\n",
    "        msg4 = format(\"D_z:%f\" % (Dz.mean().data[0]), \"<30\") + \"|\" + format(\"D_z_prior:%f\" % (Dz_prior.mean().data[0]), \"<30\") \\\n",
    "            + \"|\" + format(\"Dz_loss:%f\" % (Dz_loss.data[0]), \"<30\")\n",
    "\n",
    "        print()\n",
    "        print(msg1)\n",
    "        print(msg2)\n",
    "        print(msg5)\n",
    "        print(msg3)\n",
    "        print(msg4)\n",
    "        print()\n",
    "        print(\"-\" * 80)\n",
    "\n",
    "    ## checkpoint every 10 epochs\n",
    "    if epoch % 10 == 0:\n",
    "        torch.save(netE.state_dict(), \"%s/netE_%03d.pth\" % (outf, epoch + 1))\n",
    "        torch.save(netG.state_dict(), \"%s/netG_%03d.pth\" % (outf, epoch + 1))\n",
    "        torch.save(netD_img.state_dict(), \"%s/netD_img_%03d.pth\" % (outf, epoch + 1))\n",
    "        torch.save(netD_z.state_dict(), \"%s/netD_z_%03d.pth\" % (outf, epoch + 1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------