├── dataloader.py ├── misc.py ├── makeLabel.py ├── README.md ├── models.py ├── main.py └── CAAE_128_jupyter.ipynb /dataloader.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torchvision.datasets as dset 3 | import torchvision.utils as vutils 4 | from PIL import ImageFile 5 | import torch 6 | 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | 11 | def loadImgs(des_dir = "./data/",img_size=128,batchSize = 20): 12 | 13 | dataset = dset.ImageFolder(root=des_dir, 14 | transform=transforms.Compose([ 15 | transforms.Resize(img_size), 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 18 | ])) 19 | 20 | dataloader = torch.utils.data.DataLoader(dataset, 21 | batch_size= batchSize, 22 | shuffle=True) 23 | 24 | return dataloader 25 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import optim 3 | from torch.autograd import Variable 4 | import torch 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv') != -1 or classname.find("Linear") !=-1: 9 | m.weight.data.normal_(0.0, 0.02) 10 | elif classname.find('BatchNorm') != -1: 11 | m.weight.data.normal_(1.0, 0.02) 12 | m.bias.data.fill_(0) 13 | 14 | def one_hot(labelTensor,batchSize,n_l,use_cuda=False): 15 | oneHot = - torch.ones(batchSize*n_l).view(batchSize,n_l) 16 | for i,j in enumerate(labelTensor): 17 | oneHot[i,j] = 1 18 | if use_cuda: 19 | return Variable(oneHot).cuda() 20 | else: 21 | return Variable(oneHot) 22 | 23 | def TV_LOSS(imgTensor,img_size=128): 24 | x = (imgTensor[:,:,1:,:]-imgTensor[:,:,:img_size-1,:])**2 25 | y = (imgTensor[:,:,:,1:]-imgTensor[:,:,:,:img_size-1])**2 26 | 27 | out = (x.mean(dim=2)+y.mean(dim=3)).mean() 28 | return out 29 | 
# -------------------------------------------------------------------------------- /makeLabel.py: --------------------------------------------------------------------------------
import os

origin_dir = "./UTKFace"
des_dir = "./data"
# Guard: allow importing this module even before the dataset has been
# downloaded (the original crashed at import time if ./UTKFace was missing).
imgFiles = os.listdir(origin_dir) if os.path.isdir(origin_dir) else []


def encodeAge(n):
    """Map an age in years to one of 10 age-bucket indices (0..9).

    Buckets: <=5, <=10, <=15, <=20, <=30, <=40, <=50, <=60, <=70, over 70.
    """
    bounds = (5, 10, 15, 20, 30, 40, 50, 60, 70)
    for idx, bound in enumerate(bounds):
        if n <= bound:
            return idx
    return 9


def makeDir():
    """Create ./data and the 20 class folders "00".."19".

    Bug fix: the original used the format spec "<02", which left-aligns and
    zero-fills on the RIGHT, so format(1, "<02") and format(10, "<02") both
    produced "10" (a folder collision) and the alphabetical class ordering
    that torchvision's ImageFolder assigns no longer matched the numeric
    label that main.py decodes.  The right-aligned zero-pad "02" keeps
    sorted("00".."19") identical to the numeric order.
    """
    if not os.path.exists(des_dir):
        os.mkdir(des_dir)

    for i in range(20):
        new_folder = os.path.join(des_dir, format(i, "02"))
        if not os.path.exists(new_folder):
            os.mkdir(new_folder)


def moveFiles():
    """Move each UTKFace image "<age>_<gender>_..." into data/<label>/.

    label = encodeAge(age) * 2 + gender, zero-padded to two digits so it
    matches the folders created by makeDir().
    """
    for file in imgFiles:
        lst = file.split("_")

        age = int(lst[0])
        gender = int(lst[1])

        folder = format(encodeAge(age) * 2 + gender, "02")
        origin_file = os.path.join(origin_dir, file)
        des_file = os.path.join(des_dir, folder, file)
        os.rename(origin_file, des_file)

# -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# NOTE: README.md content from the dump, kept as comments so this merged file stays importable:
# # Face-Aging-CAAE-Pytorch
#
# * Pytorch implementation of [Age Progression/Regression by Conditional Adversarial Autoencoder](http://web.eecs.utk.edu/~zzhang61/docs/papers/2017_CVPR_Age.pdf)
# * reference: [TensorFlow implementation of CAAE](https://github.com/ZZUTK/Face-Aging-CAAE)
# * gave a presentation in [2017 YONSEI BIGDATA CONFERENCE](https://onoffmix.com/event/121883) by team FACEBIGTA.
6 | 7 | 8 | ## Requirements 9 | * pytorch 0.2.0 10 | * [UTKFace Aligned&Cropped](https://drive.google.com/drive/folders/0BxYys69jI14kU0I1YUQyY1ZDRUE) dataset 11 | 12 | ## Usage 13 | * git clone or download zip file of this repository 14 | * download Aligned & Cropped version of UTKFace from [here](https://drive.google.com/drive/folders/0BxYys69jI14kU0I1YUQyY1ZDRUE) 15 | * execute main.py in bash 16 | > python main.py 17 | 18 | ## Results 19 | 20 | **UTKFace** 21 | > rows: years of 0 ~ 5, 6 ~ 10, 11 ~ 15, 16 ~ 20, 21 ~ 30, 31 ~ 40, 41 ~ 50, 51 ~ 60, 61 ~ 70, over 70 22 | 23 | 24 | 25 |

# README.md (continued) — kept as comments so this merged dump stays importable:
#
# **Irene, Korean Celebrity**
#

# -------------------------------------------------------------------------------- /models.py: --------------------------------------------------------------------------------
## model definitions for the CAAE: encoder, generator and the two discriminators

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable


# architecture hyper-parameters
n_channel = 3                  # RGB in/out channels
n_disc = 16                    # discriminator base width
n_gen = 64                     # generator base width
n_encode = 64                  # encoder base width
n_l = 10                       # number of age classes
n_z = 50                       # latent dimension
img_size = 128
batchSize = 20
use_cuda = torch.cuda.is_available()
n_age = int(n_z / n_l)         # per-class duplication factor for the age condition
n_gender = int(n_z / 2)        # duplication factor for the gender condition


class Encoder(nn.Module):
    """Map a (N, 3, 128, 128) image batch to a (N, n_z) latent code."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            # input: 3*128*128
            nn.Conv2d(n_channel, n_encode, 5, 2, 2),
            nn.ReLU(),

            nn.Conv2d(n_encode, 2 * n_encode, 5, 2, 2),
            nn.ReLU(),

            nn.Conv2d(2 * n_encode, 4 * n_encode, 5, 2, 2),
            nn.ReLU(),

            nn.Conv2d(4 * n_encode, 8 * n_encode, 5, 2, 2),
            nn.ReLU(),
        )
        # spatial: 128 -> 64 -> 32 -> 16 -> 8, so the flattened size is
        # 8*n_encode * 8 * 8.  Fix: output width spelled as n_z instead of
        # the magic constant 50 (same value, but kept in sync by name).
        self.fc = nn.Linear(8 * n_encode * 8 * 8, n_z)

    def forward(self, x):
        conv = self.conv(x).view(-1, 8 * n_encode * 8 * 8)
        out = self.fc(conv)
        return out


class Generator(nn.Module):
    """Decode (latent z, age one-hot, gender scalar) into a (N, 3, 128, 128) image."""

    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(nn.Linear(n_z + n_l * n_age + n_gender,
                                          8 * 8 * n_gen * 16),
                                nn.ReLU())
        self.upconv = nn.Sequential(
            # spatial: 8 -> 16 -> 32 -> 64 -> 128
            nn.ConvTranspose2d(16 * n_gen, 8 * n_gen, 4, 2, 1),
            nn.ReLU(),

            nn.ConvTranspose2d(8 * n_gen, 4 * n_gen, 4, 2, 1),
            nn.ReLU(),

            nn.ConvTranspose2d(4 * n_gen, 2 * n_gen, 4, 2, 1),
            nn.ReLU(),

            nn.ConvTranspose2d(2 * n_gen, n_gen, 4, 2, 1),
            nn.ReLU(),

            nn.ConvTranspose2d(n_gen, n_channel, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z, age, gender):
        ## duplicate age & gender conditions as described in https://github.com/ZZUTK/Face-Aging-CAAE
        l = age.repeat(1, n_age).float()
        k = gender.view(-1, 1).repeat(1, n_gender).float()

        x = torch.cat([z, l, k], dim=1)
        fc = self.fc(x).view(-1, 16 * n_gen, 8, 8)
        out = self.upconv(fc)
        return out


class Dimg(nn.Module):
    """Conditional image discriminator.

    Returns (real/fake probability, age-class distribution) for an image
    conditioned on the duplicated age/gender labels.
    """

    def __init__(self):
        super(Dimg, self).__init__()
        self.conv_img = nn.Sequential(
            nn.Conv2d(n_channel, n_disc, 4, 2, 1),
        )
        # Broadcast the (N, n_l*n_age + n_gender, 1, 1) condition up to a
        # 64x64 map so it can be concatenated with the first conv's features.
        self.conv_l = nn.Sequential(
            nn.ConvTranspose2d(n_l * n_age + n_gender, n_l * n_age + n_gender, 64, 1, 0),
            nn.ReLU()
        )
        self.total_conv = nn.Sequential(
            nn.Conv2d(n_disc + n_l * n_age + n_gender, n_disc * 2, 4, 2, 1),
            nn.ReLU(),

            nn.Conv2d(n_disc * 2, n_disc * 4, 4, 2, 1),
            nn.ReLU(),

            nn.Conv2d(n_disc * 4, n_disc * 8, 4, 2, 1),
            nn.ReLU()
        )

        # NOTE: 8*8*img_size equals the real flattened size (n_disc*8) * 8 * 8
        # only because n_disc*8 == 128 == img_size; kept for byte-compatible
        # weights, but the coincidence is worth knowing about.
        self.fc_common = nn.Sequential(
            nn.Linear(8 * 8 * img_size, 1024),
            nn.ReLU()
        )
        self.fc_head1 = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
        # Fix: softmax dimension made explicit (classes live on dim=1); the
        # bare nn.Softmax() form is deprecated and only guessed the dim.
        self.fc_head2 = nn.Sequential(
            nn.Linear(1024, n_l),
            nn.Softmax(dim=1)
        )

    def forward(self, img, age, gender):
        ## duplicate age & gender conditions as described in https://github.com/ZZUTK/Face-Aging-CAAE
        l = age.repeat(1, n_age, 1, 1,)
        k = gender.repeat(1, n_gender, 1, 1,)
        conv_img = self.conv_img(img)
        conv_l = self.conv_l(torch.cat([l, k], dim=1))
        catted = torch.cat((conv_img, conv_l), dim=1)
        total_conv = self.total_conv(catted).view(-1, 8 * 8 * img_size)
        body = self.fc_common(total_conv)

        head1 = self.fc_head1(body)
        head2 = self.fc_head2(body)

        return head1, head2


class Dz(nn.Module):
    """Latent discriminator: probability that z was drawn from the prior U(-1, 1)."""

    def __init__(self):
        super(Dz, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(n_z, n_disc * 4),
            nn.ReLU(),

            nn.Linear(n_disc * 4, n_disc * 2),
            nn.ReLU(),

            nn.Linear(n_disc * 2, n_disc),
            nn.ReLU(),

            nn.Linear(n_disc, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.model(z)
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from torch.autograd import Variable 5 | from dataloader import * 6 | from misc import * 7 | from models import * 8 | import pickle 9 | from makeLabel import * 10 | import os 11 | 12 | ## boolean variable indicating whether cuda is available 13 | use_cuda = torch.cuda.is_available() 14 | 15 | makeDir() 16 | moveFiles() 17 | 18 | 19 | dataloader = loadImgs() 20 | 21 | ## build model and use cuda if available 22 | if use_cuda: 23 | netE = Encoder().cuda() 24 | netD_img = Dimg().cuda() 25 | netD_z = Dz().cuda() 26 | netG = Generator().cuda() 27 | else: 28 | netE = Encoder() 29 | netD_img = Dimg() 30 | netD_z = Dz() 31 | netG = Generator() 32 | 33 | ## apply weight initialization 34 | netE.apply(weights_init) 35 | netD_img.apply(weights_init) 36 | netD_z.apply(weights_init) 37 | netG.apply(weights_init) 38 | 39 | ## build optimizer for each networks 40 | optimizerE = optim.Adam(netE.parameters(),lr=0.0002,betas=(0.5,0.999)) 41 | optimizerD_z = optim.Adam(netD_z.parameters(),lr=0.0002,betas=(0.5,0.999)) 42 | optimizerD_img = optim.Adam(netD_img.parameters(),lr=0.0002,betas=(0.5,0.999)) 43 | optimizerG = optim.Adam(netG.parameters(),lr=0.0002,betas=(0.5,0.999)) 44 | 45 | ## build criterions to calculate loss, and use cuda if available 46 | if use_cuda: 47 | BCE = nn.BCELoss().cuda() 48 | L1 = nn.L1Loss().cuda() 49 | CE = nn.CrossEntropyLoss().cuda() 50 | MSE = nn.MSELoss().cuda() 51 | else: 52 | BCE = nn.BCELoss() 53 | L1 = nn.L1Loss() 54 | CE = nn.CrossEntropyLoss() 55 | MSE = nn.MSELoss() 56 | 57 | ## fixed variables to regress / progress age 58 | fixed_l = -torch.ones(80*10).view(80,10) 59 | for i,l in enumerate(fixed_l): 60 | l[i//8] = 1 61 | 62 | fixed_l_v = Variable(fixed_l) 63 | 64 | if use_cuda: 65 | fixed_l_v = 
fixed_l_v.cuda() 66 | 67 | 68 | outf='./result_tv_gender' 69 | 70 | if os.path.exists(outf): 71 | os.mkdir(outf) 72 | 73 | niter=50 74 | 75 | for epoch in range(niter): 76 | for i,(img_data,img_label) in enumerate(dataloader): 77 | 78 | # make image variable and class variable 79 | 80 | img_data_v = Variable(img_data) 81 | img_age = img_label/2 82 | img_gender = img_label%2*2-1 83 | 84 | img_age_v = Variable(img_age).view(-1,1) 85 | img_gender_v = Variable(img_gender.float()) 86 | 87 | if epoch == 0 and i == 0: 88 | fixed_noise = img_data[:8].repeat(10,1,1,1) 89 | fixed_g = img_gender[:8].view(-1,1).repeat(10,1) 90 | 91 | 92 | fixed_img_v = Variable(fixed_noise) 93 | fixed_g_v = Variable(fixed_g) 94 | 95 | pickle.dump(fixed_noise,open("fixed_noise.p","wb")) 96 | 97 | if use_cuda: 98 | fixed_img_v = fixed_img_v.cuda() 99 | fixed_g_v = fixed_g_v.cuda() 100 | if use_cuda: 101 | img_data_v = img_data_v.cuda() 102 | img_age_v = img_age_v.cuda() 103 | img_gender_v = img_gender_v.cuda() 104 | 105 | # make one hot encoding version of label 106 | batchSize = img_data_v.size(0) 107 | age_ohe = one_hot(img_age,batchSize,n_l,use_cuda) 108 | 109 | # prior distribution z_star, real_label, fake_label 110 | z_star = Variable(torch.FloatTensor(batchSize*n_z).uniform_(-1,1)).view(batchSize,n_z) 111 | real_label = Variable(torch.ones(batchSize).fill_(1)).view(-1,1) 112 | fake_label = Variable(torch.ones(batchSize).fill_(0)).view(-1,1) 113 | 114 | if use_cuda: 115 | z_star, real_label, fake_label = z_star.cuda(),real_label.cuda(),fake_label.cuda() 116 | 117 | 118 | ## train Encoder and Generator with reconstruction loss 119 | netE.zero_grad() 120 | netG.zero_grad() 121 | 122 | # EG_loss 1. L1 reconstruction loss 123 | z = netE(img_data_v) 124 | reconst = netG(z,age_ohe,img_gender_v) 125 | EG_L1_loss = L1(reconst,img_data_v) 126 | 127 | 128 | # EG_loss 2. 
GAN loss - image 129 | z = netE(img_data_v) 130 | reconst = netG(z,age_ohe,img_gender_v) 131 | D_reconst,_ = netD_img(reconst,age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1)) 132 | G_img_loss = BCE(D_reconst,real_label) 133 | 134 | 135 | 136 | ## EG_loss 3. GAN loss - z 137 | Dz_prior = netD_z(z_star) 138 | Dz = netD_z(z) 139 | Ez_loss = BCE(Dz,real_label) 140 | 141 | ## EG_loss 4. TV loss - G 142 | reconst = netG(z.detach(),age_ohe,img_gender_v) 143 | G_tv_loss = TV_LOSS(reconst) 144 | 145 | EG_loss = EG_L1_loss + 0.0001*G_img_loss + 0.01*Ez_loss + G_tv_loss 146 | EG_loss.backward() 147 | 148 | optimizerE.step() 149 | optimizerG.step() 150 | 151 | 152 | 153 | ## train netD_z with prior distribution U(-1,1) 154 | netD_z.zero_grad() 155 | Dz_prior = netD_z(z_star) 156 | Dz = netD_z(z.detach()) 157 | 158 | Dz_loss = BCE(Dz_prior,real_label)+BCE(Dz,fake_label) 159 | Dz_loss.backward() 160 | optimizerD_z.step() 161 | 162 | 163 | 164 | ## train D_img with real images 165 | netD_img.zero_grad() 166 | D_img,D_clf = netD_img(img_data_v,age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1)) 167 | D_reconst,_ = netD_img(reconst.detach(),age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1)) 168 | 169 | D_loss = BCE(D_img,real_label)+BCE(D_reconst,fake_label) 170 | D_loss.backward() 171 | optimizerD_img.step() 172 | 173 | 174 | 175 | ## save fixed img for every 20 step 176 | fixed_z = netE(fixed_img_v) 177 | fixed_fake = netG(fixed_z,fixed_l_v,fixed_g_v) 178 | vutils.save_image(fixed_fake.data, 179 | '%s/reconst_epoch%03d.png' % (outf,epoch+1), 180 | normalize=True) 181 | 182 | ## checkpoint 183 | if epoch%10==0: 184 | torch.save(netE.state_dict(),"%s/netE_%03d.pth"%(outf,epoch+1)) 185 | torch.save(netG.state_dict(),"%s/netG_%03d.pth"%(outf,epoch+1)) 186 | torch.save(netD_img.state_dict(),"%s/netD_img_%03d.pth"%(outf,epoch+1)) 187 | torch.save(netD_z.state_dict(),"%s/netD_z_%03d.pth"%(outf,epoch+1)) 188 | 189 | 190 | msg1 = 
"epoch:{}, step:{}".format(epoch+1,i+1) 191 | msg2 = format("EG_L1_loss:%f"%(EG_L1_loss.data[0]),"<30")+"|"+format("G_img_loss:%f"%(G_img_loss.data[0]),"<30") 192 | msg5 = format("G_tv_loss:%f"%(G_tv_loss.data[0]),"<30")+"|"+"Ez_loss:%f"%(Ez_loss.data[0]) 193 | msg3 = format("D_img:%f"%(D_img.mean().data[0]),"<30")+"|"+format("D_reconst:%f"%(D_reconst.mean().data[0]),"<30")\ 194 | +"|"+format("D_loss:%f"%(D_loss.data[0]),"<30") 195 | msg4 = format("D_z:%f"%(Dz.mean().data[0]),"<30")+"|"+format("D_z_prior:%f"%(Dz_prior.mean().data[0]),"<30")\ 196 | +"|"+format("Dz_loss:%f"%(Dz_loss.data[0]),"<30") 197 | 198 | print() 199 | print(msg1) 200 | print(msg2) 201 | print(msg5) 202 | print(msg3) 203 | print(msg4) 204 | print() 205 | print("-"*80) 206 | -------------------------------------------------------------------------------- /CAAE_128_jupyter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "from torch import nn\n", 13 | "from torch import optim\n", 14 | "from torch.autograd import Variable\n", 15 | "import torchvision.transforms as transforms\n", 16 | "import torchvision.datasets as dset\n", 17 | "import torch.nn.functional as F\n", 18 | "import torchvision.utils as vutils\n", 19 | "import pickle\n", 20 | "from PIL import ImageFile\n", 21 | "ImageFile.LOAD_TRUNCATED_IMAGES = True" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "n_channel = 3\n", 31 | "n_disc = 16\n", 32 | "n_gen = 64\n", 33 | "n_encode = 64\n", 34 | "n_l = 10\n", 35 | "n_z = 50\n", 36 | "img_size = 128\n", 37 | "batchSize = 20\n", 38 | "use_cuda = torch.cuda.is_available()\n", 39 | "n_age = int(n_z/n_l)\n", 40 | "n_gender = int(n_z/2)" 41 | ] 42 | }, 
43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "des_dir = \"./data/\"\n", 50 | "\n", 51 | "dataset = dset.ImageFolder(root=des_dir,\n", 52 | " transform=transforms.Compose([\n", 53 | " transforms.Scale(img_size),\n", 54 | " transforms.ToTensor(),\n", 55 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", 56 | " ]))\n", 57 | "\n", 58 | "dataloader = torch.utils.data.DataLoader(dataset,\n", 59 | " batch_size= batchSize,\n", 60 | " shuffle=True)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "class Encoder(nn.Module):\n", 70 | " def __init__(self):\n", 71 | " super(Encoder,self).__init__()\n", 72 | " self.conv = nn.Sequential(\n", 73 | " #input: 3*128*128\n", 74 | " nn.Conv2d(n_channel,n_encode,5,2,2),\n", 75 | " nn.ReLU(),\n", 76 | " \n", 77 | " nn.Conv2d(n_encode,2*n_encode,5,2,2),\n", 78 | " nn.ReLU(),\n", 79 | " \n", 80 | " nn.Conv2d(2*n_encode,4*n_encode,5,2,2),\n", 81 | " nn.ReLU(),\n", 82 | " \n", 83 | " nn.Conv2d(4*n_encode,8*n_encode,5,2,2),\n", 84 | " nn.ReLU(),\n", 85 | " \n", 86 | " )\n", 87 | " self.fc = nn.Linear(8*n_encode*8*8,50)\n", 88 | " \n", 89 | " def forward(self,x):\n", 90 | " conv = self.conv(x).view(-1,8*n_encode*8*8)\n", 91 | " out = self.fc(conv)\n", 92 | " return out" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "class Generator(nn.Module):\n", 102 | " def __init__(self):\n", 103 | " super(Generator,self).__init__()\n", 104 | " self.fc = nn.Sequential(nn.Linear(n_z+n_l*n_age+n_gender,\n", 105 | " 8*8*n_gen*16),\n", 106 | " nn.ReLU())\n", 107 | " self.upconv= nn.Sequential(\n", 108 | " nn.ConvTranspose2d(16*n_gen,8*n_gen,4,2,1),\n", 109 | " nn.ReLU(),\n", 110 | " \n", 111 | " nn.ConvTranspose2d(8*n_gen,4*n_gen,4,2,1),\n", 112 | " nn.ReLU(),\n", 113 | " \n", 114 | 
" nn.ConvTranspose2d(4*n_gen,2*n_gen,4,2,1),\n", 115 | " nn.ReLU(),\n", 116 | " \n", 117 | " nn.ConvTranspose2d(2*n_gen,n_gen,4,2,1),\n", 118 | " nn.ReLU(),\n", 119 | " \n", 120 | " nn.ConvTranspose2d(n_gen,n_channel,3,1,1),\n", 121 | " nn.Tanh(),\n", 122 | " \n", 123 | " )\n", 124 | " \n", 125 | " def forward(self,z,age,gender):\n", 126 | " l = age.repeat(1,n_age)\n", 127 | " k = gender.view(-1,1).repeat(1,n_gender)\n", 128 | " \n", 129 | " x = torch.cat([z,l,k],dim=1)\n", 130 | " fc = self.fc(x).view(-1,16*n_gen,8,8)\n", 131 | " out = self.upconv(fc)\n", 132 | " return out" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "class Dimg(nn.Module):\n", 142 | " def __init__(self):\n", 143 | " super(Dimg,self).__init__()\n", 144 | " self.conv_img = nn.Sequential(\n", 145 | " nn.Conv2d(n_channel,n_disc,4,2,1),\n", 146 | " )\n", 147 | " self.conv_l = nn.Sequential(\n", 148 | " nn.ConvTranspose2d(n_l*n_age+n_gender, n_l*n_age+n_gender, 64, 1, 0),\n", 149 | " nn.ReLU()\n", 150 | " )\n", 151 | " self.total_conv = nn.Sequential(\n", 152 | " nn.Conv2d(n_disc+n_l*n_age+n_gender,n_disc*2,4,2,1),\n", 153 | " nn.ReLU(),\n", 154 | " \n", 155 | " nn.Conv2d(n_disc*2,n_disc*4,4,2,1),\n", 156 | " nn.ReLU(),\n", 157 | " \n", 158 | " nn.Conv2d(n_disc*4,n_disc*8,4,2,1),\n", 159 | " nn.ReLU()\n", 160 | " )\n", 161 | " \n", 162 | " self.fc_common = nn.Sequential(\n", 163 | " nn.Linear(8*8*img_size,1024),\n", 164 | " nn.ReLU()\n", 165 | " )\n", 166 | " self.fc_head1 = nn.Sequential(\n", 167 | " nn.Linear(1024,1),\n", 168 | " nn.Sigmoid()\n", 169 | " )\n", 170 | " self.fc_head2 = nn.Sequential(\n", 171 | " nn.Linear(1024,n_l),\n", 172 | " nn.Softmax()\n", 173 | " )\n", 174 | " \n", 175 | " def forward(self,img,age,gender):\n", 176 | " l = age.repeat(1,n_age,1,1,)\n", 177 | " k = gender.repeat(1,n_gender,1,1,)\n", 178 | " conv_img = self.conv_img(img)\n", 179 | " conv_l = 
self.conv_l(torch.cat([l,k],dim=1))\n", 180 | " catted = torch.cat((conv_img,conv_l),dim=1)\n", 181 | " total_conv = self.total_conv(catted).view(-1,8*8*img_size)\n", 182 | " body = self.fc_common(total_conv)\n", 183 | " \n", 184 | " head1 = self.fc_head1(body)\n", 185 | " head2 = self.fc_head2(body)\n", 186 | " \n", 187 | " return head1,head2" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "class Dz(nn.Module):\n", 197 | " def __init__(self):\n", 198 | " super(Dz,self).__init__()\n", 199 | " self.model = nn.Sequential(\n", 200 | " nn.Linear(n_z,n_disc*4),\n", 201 | " nn.ReLU(),\n", 202 | " \n", 203 | " nn.Linear(n_disc*4,n_disc*2),\n", 204 | " nn.ReLU(),\n", 205 | " \n", 206 | " nn.Linear(n_disc*2,n_disc),\n", 207 | " nn.ReLU(),\n", 208 | " \n", 209 | " nn.Linear(n_disc,1),\n", 210 | " nn.Sigmoid()\n", 211 | " )\n", 212 | " def forward(self,z):\n", 213 | " return self.model(z)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "if use_cuda:\n", 223 | " netE = Encoder().cuda()\n", 224 | " netD_img = Dimg().cuda()\n", 225 | " netD_z = Dz().cuda()\n", 226 | " netG = Generator().cuda()\n", 227 | "else:\n", 228 | " netE = Encoder()\n", 229 | " netD_img = Dimg()\n", 230 | " netD_z = Dz()\n", 231 | " netG = Generator()" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "def weights_init(m):\n", 241 | " classname = m.__class__.__name__\n", 242 | " if classname.find('Conv') != -1 or classname.find(\"Linear\") !=-1:\n", 243 | " m.weight.data.normal_(0.0, 0.02)\n", 244 | " elif classname.find('BatchNorm') != -1:\n", 245 | " m.weight.data.normal_(1.0, 0.02)\n", 246 | " m.bias.data.fill_(0)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | 
"metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "netE.apply(weights_init)\n", 256 | "netD_img.apply(weights_init)\n", 257 | "netD_z.apply(weights_init)\n", 258 | "netG.apply(weights_init)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "optimizerE = optim.Adam(netE.parameters(),lr=0.0002,betas=(0.5,0.999))\n", 268 | "optimizerD_z = optim.Adam(netD_z.parameters(),lr=0.0002,betas=(0.5,0.999))\n", 269 | "optimizerD_img = optim.Adam(netD_img.parameters(),lr=0.0002,betas=(0.5,0.999))\n", 270 | "optimizerG = optim.Adam(netG.parameters(),lr=0.0002,betas=(0.5,0.999))" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def one_hot(labelTensor):\n", 280 | " oneHot = - torch.ones(batchSize*n_l).view(batchSize,n_l)\n", 281 | " for i,j in enumerate(labelTensor):\n", 282 | " oneHot[i,j] = 1\n", 283 | " if use_cuda:\n", 284 | " return Variable(oneHot).cuda()\n", 285 | " else:\n", 286 | " return Variable(oneHot)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "if use_cuda:\n", 296 | " BCE = nn.BCELoss().cuda()\n", 297 | " L1 = nn.L1Loss().cuda()\n", 298 | " CE = nn.CrossEntropyLoss().cuda()\n", 299 | " MSE = nn.MSELoss().cuda()\n", 300 | "else:\n", 301 | " BCE = nn.BCELoss()\n", 302 | " L1 = nn.L1Loss()\n", 303 | " CE = nn.CrossEntropyLoss()\n", 304 | " MSE = nn.MSELoss()" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "def TV_LOSS(imgTensor):\n", 314 | " x = (imgTensor[:,:,1:,:]-imgTensor[:,:,:img_size-1,:])**2\n", 315 | " y = (imgTensor[:,:,:,1:]-imgTensor[:,:,:,:img_size-1])**2 \n", 316 | " out = (x.mean(dim=1)+y.mean(dim=1)).mean()\n", 317 | " return out" 318 | ] 
319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "niter=150" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "fixed_noise = pickle.load(open(\"fixed_noise.p\",\"rb\"))" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "fixed_l = -torch.ones(80*10).view(80,10)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "for i,l in enumerate(fixed_l):\n", 354 | " l[i//8] = 1" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "fixed_g = -1*torch.FloatTensor([1,-1,-1,-1,-1,1,1,1]).view(-1,1).repeat(10,1)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "fixed_l_v = Variable(fixed_l)\n", 373 | "fixed_img_v = Variable(fixed_noise)\n", 374 | "fixed_g_v = Variable(fixed_g)\n", 375 | "if use_cuda:\n", 376 | " fixed_l_v = fixed_l_v.cuda()\n", 377 | " fixed_img_v = fixed_img_v.cuda()\n", 378 | " fixed_g_v = fixed_g_v.cuda()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "outf='./result_tv_gender'" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": { 394 | "scrolled": true 395 | }, 396 | "outputs": [], 397 | "source": [ 398 | "for epoch in range(30,niter):\n", 399 | " for i,(img_data,img_label) in enumerate(dataloader):\n", 400 | " \n", 401 | " # make image variable and class variable\n", 402 | " \n", 403 | " img_data_v = Variable(img_data)\n", 404 
| " img_age = img_label/2\n", 405 | " img_gender = img_label%2*2-1\n", 406 | " \n", 407 | " img_age_v = Variable(img_age).view(-1,1)\n", 408 | " img_gender_v = Variable(img_gender.float())\n", 409 | "\n", 410 | "\n", 411 | " if use_cuda:\n", 412 | " img_data_v = img_data_v.cuda()\n", 413 | " img_age_v = img_age_v.cuda()\n", 414 | " img_gender_v = img_gender_v.cuda() \n", 415 | " \n", 416 | " # make one hot encoding version of label\n", 417 | " batchSize = img_data_v.size(0)\n", 418 | " age_ohe = one_hot(img_age)\n", 419 | " \n", 420 | " # prior distribution z_star, real_label, fake_label\n", 421 | " z_star = Variable(torch.FloatTensor(batchSize*n_z).uniform_(-1,1)).view(batchSize,n_z)\n", 422 | " real_label = Variable(torch.ones(batchSize).fill_(1)).view(-1,1)\n", 423 | " fake_label = Variable(torch.ones(batchSize).fill_(0)).view(-1,1)\n", 424 | " \n", 425 | " if use_cuda:\n", 426 | " z_star, real_label, fake_label = z_star.cuda(),real_label.cuda(),fake_label.cuda()\n", 427 | " \n", 428 | " \n", 429 | " ## train Encoder and Generator with reconstruction loss\n", 430 | " netE.zero_grad()\n", 431 | " netG.zero_grad()\n", 432 | " \n", 433 | " # EG_loss 1. L1 reconstruction loss\n", 434 | " z = netE(img_data_v)\n", 435 | " reconst = netG(z,age_ohe,img_gender_v)\n", 436 | " EG_L1_loss = L1(reconst,img_data_v)\n", 437 | " \n", 438 | " \n", 439 | " # EG_loss 2. GAN loss - image\n", 440 | " z = netE(img_data_v)\n", 441 | " reconst = netG(z,age_ohe,img_gender_v)\n", 442 | " D_reconst,_ = netD_img(reconst,age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1))\n", 443 | " G_img_loss = BCE(D_reconst,real_label)\n", 444 | "\n", 445 | " \n", 446 | " \n", 447 | " ## EG_loss 3. GAN loss - z \n", 448 | " Dz_prior = netD_z(z_star)\n", 449 | " Dz = netD_z(z)\n", 450 | " Ez_loss = BCE(Dz,real_label)\n", 451 | " \n", 452 | " ## EG_loss 4. 
TV loss - G\n", 453 | " reconst = netG(z.detach(),age_ohe,img_gender_v)\n", 454 | " G_tv_loss = TV_LOSS(reconst)\n", 455 | " \n", 456 | " EG_loss = EG_L1_loss + 0.0001*G_img_loss + 0.01*Ez_loss + G_tv_loss\n", 457 | " EG_loss.backward()\n", 458 | " \n", 459 | " optimizerE.step()\n", 460 | " optimizerG.step()\n", 461 | " \n", 462 | "\n", 463 | "\n", 464 | " ## train netD_z with prior distribution U(-1,1)\n", 465 | " netD_z.zero_grad() \n", 466 | " Dz_prior = netD_z(z_star)\n", 467 | " Dz = netD_z(z.detach())\n", 468 | " \n", 469 | " Dz_loss = BCE(Dz_prior,real_label)+BCE(Dz,fake_label)\n", 470 | " Dz_loss.backward()\n", 471 | " optimizerD_z.step()\n", 472 | " \n", 473 | "\n", 474 | "\n", 475 | " ## train D_img with real images\n", 476 | " netD_img.zero_grad()\n", 477 | " D_img,D_clf = netD_img(img_data_v,age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1))\n", 478 | " D_reconst,_ = netD_img(reconst.detach(),age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1))\n", 479 | "\n", 480 | " D_loss = BCE(D_img,real_label)+BCE(D_reconst,fake_label)\n", 481 | " D_loss.backward()\n", 482 | " optimizerD_img.step()\n", 483 | " \n", 484 | "\n", 485 | " \n", 486 | " ## save fixed img for every 20 step \n", 487 | " fixed_z = netE(fixed_img_v)\n", 488 | " fixed_fake = netG(fixed_z,fixed_l_v,fixed_g_v)\n", 489 | " vutils.save_image(fixed_fake.data,\n", 490 | " '%s/reconst_epoch%03d.png' % (outf,epoch+1),\n", 491 | " normalize=True)\n", 492 | " \n", 493 | " ## checkpoint\n", 494 | " if epoch%10==0:\n", 495 | " torch.save(netE.state_dict(),\"%s/netE_%03d.pth\"%(outf,epoch+1))\n", 496 | " torch.save(netG.state_dict(),\"%s/netG_%03d.pth\"%(outf,epoch+1))\n", 497 | " torch.save(netD_img.state_dict(),\"%s/netD_img_%03d.pth\"%(outf,epoch+1))\n", 498 | " torch.save(netD_z.state_dict(),\"%s/netD_z_%03d.pth\"%(outf,epoch+1))\n", 499 | "\n", 500 | "\n", 501 | " msg1 = \"epoch:{}, step:{}\".format(epoch+1,i+1)\n", 502 | " msg2 = 
format(\"EG_L1_loss:%f\"%(EG_L1_loss.data[0]),\"<30\")+\"|\"+format(\"G_img_loss:%f\"%(G_img_loss.data[0]),\"<30\")\n", 503 | " msg5 = format(\"G_tv_loss:%f\"%(G_tv_loss.data[0]),\"<30\")+\"|\"+\"Ez_loss:%f\"%(Ez_loss.data[0])\n", 504 | " msg3 = format(\"D_img:%f\"%(D_img.mean().data[0]),\"<30\")+\"|\"+format(\"D_reconst:%f\"%(D_reconst.mean().data[0]),\"<30\")\\\n", 505 | " +\"|\"+format(\"D_loss:%f\"%(D_loss.data[0]),\"<30\")\n", 506 | " msg4 = format(\"D_z:%f\"%(Dz.mean().data[0]),\"<30\")+\"|\"+format(\"D_z_prior:%f\"%(Dz_prior.mean().data[0]),\"<30\")\\\n", 507 | " +\"|\"+format(\"Dz_loss:%f\"%(Dz_loss.data[0]),\"<30\")\n", 508 | "\n", 509 | " print()\n", 510 | " print(msg1)\n", 511 | " print(msg2)\n", 512 | " print(msg5)\n", 513 | " print(msg3)\n", 514 | " print(msg4) \n", 515 | " print()\n", 516 | " print(\"-\"*80)\n", 517 | " \n", 518 | " " 519 | ] 520 | } 521 | ], 522 | "metadata": { 523 | "kernelspec": { 524 | "display_name": "Python 3", 525 | "language": "python", 526 | "name": "python3" 527 | }, 528 | "language_info": { 529 | "codemirror_mode": { 530 | "name": "ipython", 531 | "version": 3 532 | }, 533 | "file_extension": ".py", 534 | "mimetype": "text/x-python", 535 | "name": "python", 536 | "nbconvert_exporter": "python", 537 | "pygments_lexer": "ipython3", 538 | "version": "3.6.3" 539 | } 540 | }, 541 | "nbformat": 4, 542 | "nbformat_minor": 2 543 | } 544 | --------------------------------------------------------------------------------