├── .gitignore
├── README.md
├── generator.py
├── main.py
├── main_1.py
├── network.py
├── network_1.py
├── results
│   ├── original.png
│   ├── recon.png
│   └── sampled.png
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | 
49 | # Translations
50 | *.mo
51 | *.pot
52 | 
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 | 
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 | 
61 | # Scrapy stuff:
62 | .scrapy
63 | 
64 | # Sphinx documentation
65 | docs/_build/
66 | 
67 | # PyBuilder
68 | target/
69 | 
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 | 
73 | # pyenv
74 | .python-version
75 | 
76 | # celery beat schedule file
77 | celerybeat-schedule
78 | 
79 | # SageMath parsed files
80 | *.sage.py
81 | 
82 | # dotenv
83 | .env
84 | 
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 | 
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 | 
94 | # Rope project settings
95 | .ropeproject
96 | 
97 | # mkdocs documentation
98 | /site
99 | 
100 | # mypy
101 | .mypy_cache/
102 | 
103 | /data
104 | /logs
105 | /.idea
106 | /runs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VAEGAN-PYTORCH
2 | VAEGAN from "Autoencoding beyond pixels using a learned similarity metric" implemented in PyTorch.
3 | Clean, clear and with comments.
4 | 
5 | ## Requirements
6 | * pytorch
7 | * torchvision
8 | * tensorboard-pytorch
9 | * progressbar2
10 | * matplotlib
11 | 
12 | For package versions, please refer to my ```pip freeze``` output below:
13 | 
14 | ```
15 | bleach==1.5.0
16 | cycler==0.10.0
17 | decorator==4.1.2
18 | enum34==1.1.6
19 | html5lib==0.9999999
20 | ipython==6.2.1
21 | ipython-genutils==0.2.0
22 | jedi==0.11.0
23 | Markdown==2.6.9
24 | matplotlib==2.0.2
25 | numpy==1.13.3
26 | olefile==0.44
27 | parso==0.1.0
28 | pexpect==4.2.1
29 | pickleshare==0.7.4
30 | Pillow==4.3.0
31 | pkg-resources==0.0.0
32 | progressbar2==3.34.3
33 | prompt-toolkit==1.0.15
34 | protobuf==3.4.0
35 | ptyprocess==0.5.2
36 | Pygments==2.2.0
37 | pyparsing==2.2.0
38 | PyQt5==5.9.1
39 | python-dateutil==2.6.1
40 | python-utils==2.2.0
41 | pytz==2017.3
42 | PyYAML==3.12
43 | scipy==1.0.0
44 | simplegeneric==0.8.1
45 | sip==4.19.5
46 | six==1.11.0
47 | tensorboardX==0.8
48 | tensorflow==1.4.0
49 | tensorflow-tensorboard==0.4.0rc2
50 | torch==0.2.0+28f3d50
51 | torchvision==0.1.9
52 | traitlets==4.3.2
53 | wcwidth==0.1.7
54 | Werkzeug==0.12.2
55 | ```
56 | ## Visual Results
57 | Results after 13 epochs using lr=0.0001:
58 | 
59 | ![Alt text](/results/original.png?raw=true "Original")
60 | ![Alt text](/results/recon.png?raw=true "Reconstructed")
61 | ![Alt text](/results/sampled.png?raw=true "Sampled")
62 | 
63 | Reconstructions are not bad (these are images never seen during training); the generated samples could still be better.
64 | 
65 | ## Implementation details
66 | Using a GAN makes training REALLY unstable: at any moment the generator or the discriminator can collapse, producing awful results. As such, some tricks have been employed in the original implementation (and also here) to try to tame this instability:
67 | 
68 | ### Equilibrium Theory
69 | When one of the two players in the min-max game of adversarial training tends to overpower the other and break the equilibrium, the winner gets punished by having its updates stopped. This is achieved thanks to separate optimizers for each of the 3 sub-networks of the implementation. The equilibrium value is taken from the original implementation.
70 | 
71 | ### Gradient Clip
72 | Even if it's not used in this implementation (nor in the original, as far as I know), some projects out there clip the gradients from each of the 3 losses to ```[-1,1]```. This could prevent degenerate patterns.
73 | 
74 | ### Low Learning Rate
75 | 0.0001 is a really low learning rate, but even a slightly higher one can make strange patterns appear.
76 | 
77 | ### Exponential Decay
78 | Not sure whether this really helps :)
79 | 
80 | ## Theory Explanation
81 | Here be Dragons.
82 | 
83 | ### VAE
84 | A plain VAE makes just one statistical assumption: everything is a Gaussian (and that works, btw).
85 | The base structure is formed by an encoder and a decoder... so it's an autoencoder? NO my fella, it's not. In fact, what you decode is not the code you have generated, but a sample from a Gaussian (whose parameters are what you have generated). So you're not encoding a discrete point in the latent space, but the parameters of a distribution over it.
86 | 
87 | The final loss is just the reconstruction error between original and reconstructed images, plus a KL-divergence term.
88 | 
89 | #### What the heck is KL-divergence?
90 | To put it simply, it is just a way to bring two distributions closer together. In our case we want our latent distribution to be a Gaussian N(0,I), so that at test time we can sample from it using only samples from a standard Gaussian N(0,I). If the KLD were not included in the loss function, the latent space could be spread out over the N-dimensional space, and our samples at test time would yield just random noise.
91 | 
92 | ### GAN
93 | I really don't have enough time to explain why I hate GANs so much. They are unstable, hard to train and hard to understand, but they do work (sometimes). Moving from a plain VAE to the VAEGAN has been a pain in the ass and it took me 2 weeks, so I don't think I'm well-suited to lecture you about them. What we need to know here is that to obtain the VAEGAN we just stick a discriminator at the end of the plain VAE.
94 | 
95 | ### VAEGAN
96 | Instead of forcing a low reconstruction error, VAEGAN imposes a low error between intermediate features in the discriminator. If you think about it, if the reconstructed image is very similar to the original one, their intermediate representations in the discriminator should be similar too. This is why the paper drops what it calls the "element-wise error" and prefers the "feature-wise error" (FWE). The authors also make a strong assumption on the GAN loss, as they use only the original batch and the one sampled from a Gaussian N(0,I) to compute it, leaving out the reconstructed batch.
97 | 
98 | ##### Encoder Loss
99 | ```KLD+FWE```
100 | So: a latent space close to a Gaussian, but with samples that resemble the originals from the discriminator's point of view.
101 | 
102 | ##### Decoder Loss
103 | ```alpha*FWE-GAN```
104 | Yeah I know... this is not how GANs are usually trained on the generator side, where one would instead swap the labels (so fake becomes real and real becomes fake). I'm still wondering whether the two lead to the same results (my graphs seem to suggest otherwise).
105 | Alpha here is really low (1e-06, if I remember correctly), probably because the error is computed using the sum over the single images (so it's the mean of the error between layers, and not the mean of the mean error of layers... what did I just write?)
106 | ##### GAN Loss
107 | ```GAN```
108 | Nothing special to say here; including the reconstruction loss seems to lower the results (and I REALLY REALLY don't understand why). There are too many hyperparameters to investigate them all.
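To make these three losses concrete, here is a minimal PyTorch sketch of how they are combined, mirroring `VaeGan.loss` and the training loop in `main.py`. The function and the tensor names (`feat_real`, `feat_recon`, `dis_real`, `dis_sampled`, `mu`, `logvar`) are illustrative placeholders, not the actual identifiers used in this repo:

```python
import torch

def vaegan_losses(feat_real, feat_recon, dis_real, dis_sampled, mu, logvar, alpha=1e-6):
    # KLD between N(mu, sigma^2) and N(0, I), summed over the latent dimensions
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), 1)
    # FWE: squared error between intermediate discriminator features
    fwe = torch.sum(0.5 * (feat_real - feat_recon) ** 2, 1)
    # GAN: discriminator should score originals as real, samples as fake
    gan = -torch.log(dis_real + 1e-3) - torch.log(1 - dis_sampled + 1e-3)
    loss_encoder = torch.sum(kld) + torch.sum(fwe)                          # KLD+FWE
    loss_decoder = torch.sum(alpha * fwe) - (1.0 - alpha) * torch.sum(gan)  # alpha*FWE-GAN
    loss_discriminator = torch.sum(gan)                                     # GAN
    return loss_encoder, loss_decoder, loss_discriminator
```

Each of the three results is backpropagated through its own sub-network only, which is why the training scripts keep one optimizer per sub-network.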
109 | 110 | 111 | 112 | ## TODO 113 | - [x] requirements 114 | - [x] visual results 115 | - [ ] TB log 116 | - [x] theory explanation 117 | - [x] implementation details 118 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | from torch.utils.data import Dataset, DataLoader 4 | import cv2 5 | from skimage import filters,transform 6 | numpy.random.seed(5) 7 | 8 | def _resize(img): 9 | rescale_size = 64 10 | bbox = (40, 218 - 30, 15, 178 - 15) 11 | img = img[bbox[0]:bbox[1], bbox[2]:bbox[3]] 12 | # Smooth image before resize to avoid moire patterns 13 | scale = img.shape[0] / float(rescale_size) 14 | sigma = numpy.sqrt(scale) / 2.0 15 | img = filters.gaussian(img, sigma=sigma, multichannel=True) 16 | img = transform.resize(img, (rescale_size, rescale_size, 3), order=3,mode="constant") 17 | img = (img*255).astype(numpy.uint8) 18 | return img 19 | 20 | class CELEBA(Dataset): 21 | """ 22 | loader for the CELEB-A dataset 23 | """ 24 | 25 | def __init__(self, data_folder): 26 | #len is the number of files 27 | self.len = len(os.listdir(data_folder)) 28 | #list of file names 29 | self.data_names = [os.path.join(data_folder, name) for name in sorted(os.listdir(data_folder))] 30 | #data_all 31 | #if "train" in data_folder: 32 | # self.data = numpy.load("/home/lapis/Desktop/full_train.npy") 33 | #else: 34 | # self.data = numpy.load("/home/lapis/Desktop/full_test.npy") 35 | 36 | self.len = len(self.data_names) 37 | def __len__(self): 38 | return self.len 39 | 40 | def __iter__(self): 41 | return self 42 | 43 | def __getitem__(self, item): 44 | """ 45 | 46 | :param item: image index between 0-(len-1) 47 | :return: image 48 | """ 49 | #load image,crop 128x128,resize,transpose(to channel first),scale (so we can use tanh) 50 | data = cv2.cvtColor(cv2.imread(self.data_names[item]), cv2.COLOR_BGR2RGB) 51 | 52 | data = _resize(data) 53 | 54 | # CHANNEL FIRST 55 | data = data.transpose(2, 0, 1) 56 | # TANH 57 | data = data.astype("float32") / 127.5 - 1.0 58 | 59 | return (data.copy(),data.copy()) 60 | 61 | 62 | class CELEBA_SLURM(Dataset): 63 | """ 64 | loader for the CELEB-A dataset 65 | """ 66 | 67 | def __init__(self, data_folder): 68 | #open the file 69 | self.file = open(os.path.join(data_folder,"imgs"),"rb") 70 | #get len 71 | self.len = int(os.path.getsize(os.path.join(data_folder,"imgs"))/(64*64*3)) 72 | def __len__(self): 73 | return self.len 74 | 75 | def __iter__(self): 76 | return self 77 | 78 | def __getitem__(self, item): 79 | """ 80 | 81 | :param item: image index between 0-(len-1) 82 | :return: image 83 | """ 84 | offset = item*3*64*64 85 | self.file.seek(offset) 86 | data = numpy.fromfile(self.file, dtype=numpy.uint8, count=(3 * 64 * 64)) 87 | data = numpy.reshape(data, newshape=(3, 64, 64)) 88 | data = data.astype("float32") / 127.5 - 1.0 89 | return (data.copy(),data.copy()) 90 | 91 | 92 | if __name__ == "__main__": 93 | dataset = CELEBA_SLURM(".") 94 | gen = DataLoader(dataset, batch_size=128, shuffle=False,num_workers=1) 95 | #file = open("test",mode="wb+") 96 | from matplotlib import pyplot 97 | imgs = [] 98 | for i,(b,l) in enumerate(gen): 99 | print("{}:{}".format(i,len(gen))) 100 | #b.numpy().astype("uint8").tofile(file) 101 | #file.close() 102 | 103 | 104 | #for i in range(1000): 105 | 106 | #a = gen.__iter__().next() 107 | #scale between (0,1) 108 | #a = (a + 1) / 2 109 | #for el in a: 110 | # 
pyplot.imshow(numpy.transpose(el.numpy(), (1, 2, 0))) 111 | # pyplot.show() 112 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import argparse 4 | numpy.random.seed(8) 5 | torch.manual_seed(8) 6 | torch.cuda.manual_seed(8) 7 | from network import VaeGan 8 | from torch.autograd import Variable 9 | from torch.utils.data import Dataset 10 | from tensorboardX import SummaryWriter 11 | from torch.optim import RMSprop,Adam,SGD 12 | from torch.optim.lr_scheduler import ExponentialLR,MultiStepLR 13 | import progressbar 14 | from torchvision.utils import make_grid 15 | from generator import CELEBA,CELEBA_SLURM 16 | from utils import RollingMeasure 17 | 18 | if __name__ == "__main__": 19 | 20 | parser = argparse.ArgumentParser(description="VAEGAN") 21 | parser.add_argument("--train_folder",action="store",dest="train_folder") 22 | parser.add_argument("--test_folder",action="store",dest="test_folder") 23 | parser.add_argument("--n_epochs",default=12,action="store",type=int,dest="n_epochs") 24 | parser.add_argument("--z_size",default=128,action="store",type=int,dest="z_size") 25 | parser.add_argument("--recon_level",default=3,action="store",type=int,dest="recon_level") 26 | parser.add_argument("--lambda_mse",default=1e-6,action="store",type=float,dest="lambda_mse") 27 | parser.add_argument("--lr",default=3e-4,action="store",type=float,dest="lr") 28 | parser.add_argument("--decay_lr",default=0.75,action="store",type=float,dest="decay_lr") 29 | parser.add_argument("--decay_mse",default=1,action="store",type=float,dest="decay_mse") 30 | parser.add_argument("--decay_margin",default=1,action="store",type=float,dest="decay_margin") 31 | parser.add_argument("--decay_equilibrium",default=1,action="store",type=float,dest="decay_equilibrium") 32 | parser.add_argument("--slurm",default=False,action="store",type=bool,dest="slurm") 33 | 34 | args = parser.parse_args() 35 | 36 | train_folder = args.train_folder 37 | test_folder = args.test_folder 38 | z_size = args.z_size 39 | recon_level = args.recon_level 40 | decay_mse = args.decay_mse 41 | decay_margin = args.decay_margin 42 | n_epochs = args.n_epochs 43 | lambda_mse = args.lambda_mse 44 | lr = args.lr 45 | decay_lr = args.decay_lr 46 | decay_equilibrium = args.decay_equilibrium 47 | slurm = args.slurm 48 | 49 | writer = SummaryWriter(comment="_CELEBA_NEW_DATA_STOCK_GAN") 50 | net = VaeGan(z_size=z_size,recon_level=recon_level).cuda() 51 | 52 | # DATASET 53 | if not slurm: 54 | dataloader = torch.utils.data.DataLoader(CELEBA(train_folder), batch_size=64, 55 | shuffle=True, num_workers=4) 56 | # DATASET for test 57 | # if you want to split train from test just move some files in another dir 58 | dataloader_test = torch.utils.data.DataLoader(CELEBA(test_folder), batch_size=100, 59 | shuffle=False, num_workers=1) 60 | else: 61 | dataloader = torch.utils.data.DataLoader(CELEBA_SLURM(train_folder), batch_size=64, 62 | shuffle=True, num_workers=1) 63 | # DATASET for test 64 | # if you want to split train from test just move some files in another dir 65 | dataloader_test = torch.utils.data.DataLoader(CELEBA_SLURM(test_folder), batch_size=100, 66 | shuffle=False, num_workers=1) 67 | #margin and equilibirum 68 | margin = 0.35 69 | equilibrium = 0.68 70 | #mse_lambda = 1.0 71 | # OPTIM-LOSS 72 | # an optimizer for each of the sub-networks, so we can selectively backprop 73 | #optimizer_encoder = 
Adam(params=net.encoder.parameters(),lr = lr,betas=(0.9,0.999)) 74 | optimizer_encoder = RMSprop(params=net.encoder.parameters(),lr=lr,alpha=0.9,eps=1e-8,weight_decay=0,momentum=0,centered=False) 75 | #lr_encoder = MultiStepLR(optimizer_encoder,milestones=[2],gamma=1) 76 | lr_encoder = ExponentialLR(optimizer_encoder, gamma=decay_lr) 77 | #optimizer_decoder = Adam(params=net.decoder.parameters(),lr = lr,betas=(0.9,0.999)) 78 | optimizer_decoder = RMSprop(params=net.decoder.parameters(),lr=lr,alpha=0.9,eps=1e-8,weight_decay=0,momentum=0,centered=False) 79 | lr_decoder = ExponentialLR(optimizer_decoder, gamma=decay_lr) 80 | #lr_decoder = MultiStepLR(optimizer_decoder,milestones=[2],gamma=1) 81 | #optimizer_discriminator = Adam(params=net.discriminator.parameters(),lr = lr,betas=(0.9,0.999)) 82 | optimizer_discriminator = RMSprop(params=net.discriminator.parameters(),lr=lr,alpha=0.9,eps=1e-8,weight_decay=0,momentum=0,centered=False) 83 | lr_discriminator = ExponentialLR(optimizer_discriminator, gamma=decay_lr) 84 | #lr_discriminator = MultiStepLR(optimizer_discriminator,milestones=[2],gamma=1) 85 | 86 | batch_number = len(dataloader) 87 | step_index = 0 88 | widgets = [ 89 | 90 | 'Batch: ', progressbar.Counter(), 91 | '/', progressbar.FormatCustomText('%(total)s', {"total": batch_number}), 92 | ' ', progressbar.Bar(marker="-", left='[', right=']'), 93 | ' ', progressbar.ETA(), 94 | ' ', 95 | progressbar.DynamicMessage('loss_nle'), 96 | ' ', 97 | progressbar.DynamicMessage('loss_encoder'), 98 | ' ', 99 | progressbar.DynamicMessage('loss_decoder'), 100 | ' ', 101 | progressbar.DynamicMessage('loss_discriminator'), 102 | ' ', 103 | progressbar.DynamicMessage('loss_mse_layer'), 104 | ' ', 105 | progressbar.DynamicMessage('loss_kld'), 106 | ' ', 107 | progressbar.DynamicMessage("epoch") 108 | ] 109 | # for each epoch 110 | if slurm: 111 | print(args) 112 | for i in range(n_epochs): 113 | 114 | progress = progressbar.ProgressBar(min_value=0, max_value=batch_number, initial_value=0, 115 | widgets=widgets).start() 116 | # reset rolling average 117 | loss_nle_mean = RollingMeasure() 118 | loss_encoder_mean = RollingMeasure() 119 | loss_decoder_mean = RollingMeasure() 120 | loss_discriminator_mean = RollingMeasure() 121 | loss_reconstruction_layer_mean = RollingMeasure() 122 | loss_kld_mean = RollingMeasure() 123 | gan_gen_eq_mean = RollingMeasure() 124 | gan_dis_eq_mean = RollingMeasure() 125 | #print("LR:{}".format(lr_encoder.get_lr())) 126 | 127 | # for each batch 128 | for j, (data_batch,target_batch) in enumerate(dataloader): 129 | 130 | # set to train mode 131 | net.train() 132 | # target and input are the same images 133 | 134 | data_target = Variable(target_batch, requires_grad=False).float().cuda() 135 | data_in = Variable(data_batch, requires_grad=False).float().cuda() 136 | 137 | 138 | # get output 139 | out, out_labels, out_layer, mus, variances = net(data_in) 140 | # split so we can get the different parts 141 | out_layer_predicted = out_layer[:len(out_layer) // 2] 142 | out_layer_original = out_layer[len(out_layer) // 2:] 143 | # TODO set a batch_len variable to get a clean code here 144 | out_labels_original = out_labels[:len(out_labels) // 2] 145 | out_labels_sampled = out_labels[-len(out_labels) // 2:] 146 | # loss, nothing special here 147 | nle_value, kl_value, mse_value, bce_dis_original_value, bce_dis_sampled_value,\ 148 | bce_gen_original_value,bce_gen_sampled_value= VaeGan.loss(data_target, out, out_layer_original, 149 | out_layer_predicted, out_labels_original, 150 | 
out_labels_sampled, mus, 151 | variances) 152 | # THIS IS THE MOST IMPORTANT PART OF THE CODE 153 | loss_encoder = torch.sum(kl_value)+torch.sum(mse_value) 154 | loss_discriminator = torch.sum(bce_dis_original_value) + torch.sum(bce_dis_sampled_value) 155 | loss_decoder = torch.sum(lambda_mse * mse_value) - (1.0 - lambda_mse) * loss_discriminator 156 | #loss_decoder = torch.sum(mse_lambda * mse_value) + (1.0-mse_lambda)*(torch.sum(bce_gen_sampled_value)+torch.sum(bce_gen_original_value)) 157 | 158 | # register mean values of the losses for logging 159 | loss_nle_mean(torch.mean(nle_value).data.cpu().numpy()[0]) 160 | loss_discriminator_mean((torch.mean(bce_dis_original_value) + torch.mean(bce_dis_sampled_value)).data.cpu().numpy()[0]) 161 | loss_decoder_mean((torch.mean(lambda_mse * mse_value) - (1 - lambda_mse) * (torch.mean(bce_dis_original_value) + torch.mean(bce_dis_sampled_value))).data.cpu().numpy()[0]) 162 | #loss_decoder_mean((torch.mean(mse_lambda * mse_value) + (1-mse_lambda)*(torch.mean(bce_gen_original_value) + torch.mean(bce_gen_sampled_value))).data.cpu().numpy()[0]) 163 | 164 | loss_encoder_mean((torch.mean(kl_value) + torch.mean(mse_value)).data.cpu().numpy()[0]) 165 | loss_reconstruction_layer_mean(torch.mean(mse_value).data.cpu().numpy()[0]) 166 | loss_kld_mean(torch.mean(kl_value).data.cpu().numpy()[0]) 167 | # selectively disable the decoder of the discriminator if they are unbalanced 168 | train_dis = True 169 | train_dec = True 170 | if torch.mean(bce_dis_original_value).data[0] < equilibrium-margin or torch.mean(bce_dis_sampled_value).data[0] < equilibrium-margin: 171 | train_dis = False 172 | if torch.mean(bce_dis_original_value).data[0] > equilibrium+margin or torch.mean(bce_dis_sampled_value).data[0] > equilibrium+margin: 173 | train_dec = False 174 | if train_dec is False and train_dis is False: 175 | train_dis = True 176 | train_dec = True 177 | 178 | #aggiungo log 179 | if train_dis: 180 | gan_dis_eq_mean(1.0) 181 | else: 182 | gan_dis_eq_mean(0.0) 183 | 184 | if train_dec: 185 | gan_gen_eq_mean(1.0) 186 | else: 187 | gan_gen_eq_mean(0.0) 188 | 189 | # BACKPROP 190 | # clean grads 191 | net.zero_grad() 192 | # encoder 193 | loss_encoder.backward(retain_graph=True) 194 | # someone likes to clamp the grad here 195 | #[p.grad.data.clamp_(-1,1) for p in net.encoder.parameters()] 196 | # update parameters 197 | optimizer_encoder.step() 198 | # clean others, so they are not afflicted by encoder loss 199 | net.zero_grad() 200 | #decoder 201 | if train_dec: 202 | loss_decoder.backward(retain_graph=True) 203 | #[p.grad.data.clamp_(-1,1) for p in net.decoder.parameters()] 204 | optimizer_decoder.step() 205 | #clean the discriminator 206 | net.discriminator.zero_grad() 207 | #discriminator 208 | if train_dis: 209 | loss_discriminator.backward() 210 | #[p.grad.data.clamp_(-1,1) for p in net.discriminator.parameters()] 211 | optimizer_discriminator.step() 212 | 213 | # LOGGING 214 | if slurm: 215 | progress.update(progress.value + 1, loss_nle=loss_nle_mean.measure, 216 | loss_encoder=loss_encoder_mean.measure, 217 | loss_decoder=loss_decoder_mean.measure, 218 | loss_discriminator=loss_discriminator_mean.measure, 219 | loss_mse_layer=loss_reconstruction_layer_mean.measure, 220 | loss_kld=loss_kld_mean.measure, 221 | epoch=i + 1) 222 | 223 | # EPOCH END 224 | if slurm: 225 | progress.update(progress.value + 1, loss_nle=loss_nle_mean.measure, 226 | loss_encoder=loss_encoder_mean.measure, 227 | loss_decoder=loss_decoder_mean.measure, 228 | 
loss_discriminator=loss_discriminator_mean.measure, 229 | loss_mse_layer=loss_reconstruction_layer_mean.measure, 230 | loss_kld=loss_kld_mean.measure, 231 | epoch=i + 1) 232 | lr_encoder.step() 233 | lr_decoder.step() 234 | lr_discriminator.step() 235 | margin *=decay_margin 236 | equilibrium *=decay_equilibrium 237 | #margin non puo essere piu alto di equilibrium 238 | if margin > equilibrium: 239 | equilibrium = margin 240 | lambda_mse *=decay_mse 241 | if lambda_mse > 1: 242 | lambda_mse=1 243 | progress.finish() 244 | 245 | writer.add_scalar('loss_encoder', loss_encoder_mean.measure, step_index) 246 | writer.add_scalar('loss_decoder', loss_decoder_mean.measure, step_index) 247 | writer.add_scalar('loss_discriminator', loss_discriminator_mean.measure, step_index) 248 | writer.add_scalar('loss_reconstruction', loss_nle_mean.measure, step_index) 249 | writer.add_scalar('loss_kld',loss_kld_mean.measure,step_index) 250 | writer.add_scalar('gan_gen',gan_gen_eq_mean.measure,step_index) 251 | writer.add_scalar('gan_dis',gan_dis_eq_mean.measure,step_index) 252 | 253 | for j, (data_batch,target_batch) in enumerate(dataloader_test): 254 | net.eval() 255 | 256 | data_in = Variable(data_batch, requires_grad=False).float().cuda() 257 | data_target = Variable(target_batch, requires_grad=False).float().cuda() 258 | out = net(data_in) 259 | out = out.data.cpu() 260 | out = (out + 1) / 2 261 | out = make_grid(out, nrow=8) 262 | writer.add_image("reconstructed", out, step_index) 263 | 264 | out = net(None, 100) 265 | out = out.data.cpu() 266 | out = (out + 1) / 2 267 | out = make_grid(out, nrow=8) 268 | writer.add_image("generated", out, step_index) 269 | 270 | out = data_target.data.cpu() 271 | out = (out + 1) / 2 272 | out = make_grid(out, nrow=8) 273 | writer.add_image("original", out, step_index) 274 | break 275 | 276 | step_index += 1 277 | exit(0) 278 | -------------------------------------------------------------------------------- /main_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import argparse 4 | numpy.random.seed(8) 5 | torch.manual_seed(8) 6 | torch.cuda.manual_seed(8) 7 | from network_1 import VaeGan 8 | from torch.autograd import Variable 9 | from torch.utils.data import Dataset 10 | from tensorboardX import SummaryWriter 11 | from torch.optim import RMSprop,Adam,SGD 12 | from torch.optim.lr_scheduler import ExponentialLR,MultiStepLR 13 | import progressbar 14 | from torchvision.utils import make_grid 15 | from generator import CELEBA,CELEBA_SLURM 16 | from utils import RollingMeasure 17 | 18 | if __name__ == "__main__": 19 | 20 | parser = argparse.ArgumentParser(description="VAEGAN") 21 | parser.add_argument("--train_folder",action="store",dest="train_folder") 22 | parser.add_argument("--test_folder",action="store",dest="test_folder") 23 | parser.add_argument("--n_epochs",default=12,action="store",type=int,dest="n_epochs") 24 | parser.add_argument("--z_size",default=128,action="store",type=int,dest="z_size") 25 | parser.add_argument("--recon_level",default=3,action="store",type=int,dest="recon_level") 26 | parser.add_argument("--lambda_mse",default=1e-6,action="store",type=float,dest="lambda_mse") 27 | parser.add_argument("--lr",default=3e-4,action="store",type=float,dest="lr") 28 | parser.add_argument("--decay_lr",default=0.75,action="store",type=float,dest="decay_lr") 29 | parser.add_argument("--decay_mse",default=1,action="store",type=float,dest="decay_mse") 30 | 
parser.add_argument("--decay_margin",default=1,action="store",type=float,dest="decay_margin") 31 | parser.add_argument("--decay_equilibrium",default=1,action="store",type=float,dest="decay_equilibrium") 32 | parser.add_argument("--slurm",default=False,action="store",type=bool,dest="slurm") 33 | 34 | args = parser.parse_args() 35 | 36 | train_folder = args.train_folder 37 | test_folder = args.test_folder 38 | z_size = args.z_size 39 | recon_level = args.recon_level 40 | decay_mse = args.decay_mse 41 | decay_margin = args.decay_margin 42 | n_epochs = args.n_epochs 43 | lambda_mse = args.lambda_mse 44 | lr = args.lr 45 | decay_lr = args.decay_lr 46 | decay_equilibrium = args.decay_equilibrium 47 | slurm = args.slurm 48 | 49 | writer = SummaryWriter(comment="_CELEBA_ALL") 50 | net = VaeGan(z_size=z_size,recon_level=recon_level).cuda() 51 | # DATASET 52 | if not slurm: 53 | dataloader = torch.utils.data.DataLoader(CELEBA(train_folder), batch_size=64, 54 | shuffle=True, num_workers=4) 55 | # DATASET for test 56 | # if you want to split train from test just move some files in another dir 57 | dataloader_test = torch.utils.data.DataLoader(CELEBA(test_folder), batch_size=100, 58 | shuffle=False, num_workers=1) 59 | else: 60 | dataloader = torch.utils.data.DataLoader(CELEBA_SLURM(train_folder), batch_size=64, 61 | shuffle=True, num_workers=4) 62 | # DATASET for test 63 | # if you want to split train from test just move some files in another dir 64 | dataloader_test = torch.utils.data.DataLoader(CELEBA_SLURM(test_folder), batch_size=100, 65 | shuffle=False, num_workers=1) 66 | #margin and equilibirum 67 | margin = 0.35 68 | equilibrium = 0.68 69 | #mse_lambda = 1.0 70 | # OPTIM-LOSS 71 | # an optimizer for each of the sub-networks, so we can selectively backprop 72 | optimizer_encoder = RMSprop(params=net.encoder.parameters(),lr=lr,alpha=0.9,eps=1e-8,weight_decay=0,momentum=0,centered=False) 73 | #lr_encoder = MultiStepLR(optimizer_encoder,milestones=[2],gamma=1) 74 | lr_encoder = ExponentialLR(optimizer_encoder, gamma=decay_lr) 75 | optimizer_decoder = RMSprop(params=net.decoder.parameters(),lr=lr,alpha=0.9,eps=1e-8,weight_decay=0,momentum=0,centered=False) 76 | lr_decoder = ExponentialLR(optimizer_decoder, gamma=decay_lr) 77 | #lr_decoder = MultiStepLR(optimizer_decoder,milestones=[2],gamma=1) 78 | optimizer_discriminator = RMSprop(params=net.discriminator.parameters(),lr=lr,alpha=0.9,eps=1e-8,weight_decay=0,momentum=0,centered=False) 79 | lr_discriminator = ExponentialLR(optimizer_discriminator, gamma=decay_lr) 80 | #lr_discriminator = MultiStepLR(optimizer_discriminator,milestones=[2],gamma=1) 81 | 82 | batch_number = len(dataloader) 83 | step_index = 0 84 | widgets = [ 85 | 86 | 'Batch: ', progressbar.Counter(), 87 | '/', progressbar.FormatCustomText('%(total)s', {"total": batch_number}), 88 | ' ', progressbar.Bar(marker="-", left='[', right=']'), 89 | ' ', progressbar.ETA(), 90 | ' ', 91 | progressbar.DynamicMessage('loss_nle'), 92 | ' ', 93 | progressbar.DynamicMessage('loss_encoder'), 94 | ' ', 95 | progressbar.DynamicMessage('loss_decoder'), 96 | ' ', 97 | progressbar.DynamicMessage('loss_discriminator'), 98 | ' ', 99 | progressbar.DynamicMessage('loss_mse_layer'), 100 | ' ', 101 | progressbar.DynamicMessage('loss_kld'), 102 | ' ', 103 | progressbar.DynamicMessage("epoch") 104 | ] 105 | if slurm: 106 | print(args) 107 | # for each epoch 108 | for i in range(n_epochs): 109 | progress = progressbar.ProgressBar(min_value=0, max_value=batch_number, initial_value=0, 110 | widgets=widgets).start() 
111 | # reset rolling average 112 | loss_nle_mean = RollingMeasure() 113 | loss_encoder_mean = RollingMeasure() 114 | loss_decoder_mean = RollingMeasure() 115 | loss_discriminator_mean = RollingMeasure() 116 | loss_reconstruction_layer_mean = RollingMeasure() 117 | loss_kld_mean = RollingMeasure() 118 | gan_gen_eq_mean = RollingMeasure() 119 | gan_dis_eq_mean = RollingMeasure() 120 | #print("LR:{}".format(lr_encoder.get_lr())) 121 | 122 | # for each batch 123 | for j, (data_batch,target_batch) in enumerate(dataloader): 124 | # set to train mode 125 | train_batch = len(data_batch) 126 | net.train() 127 | # target and input are the same images 128 | data_target = Variable(target_batch, requires_grad=False).float().cuda() 129 | data_in = Variable(data_batch, requires_grad=False).float().cuda() 130 | 131 | 132 | # get output 133 | out, out_labels, out_layer, mus, variances = net(data_in) 134 | # split so we can get the different parts 135 | out_layer_predicted = out_layer[:train_batch] 136 | out_layer_original = out_layer[train_batch:-train_batch] 137 | out_layer_sampled = out_layer[-train_batch:] 138 | #labels 139 | out_labels_predicted = out_labels[:train_batch] 140 | out_labels_original = out_labels[train_batch:-train_batch] 141 | out_labels_sampled = out_labels[-train_batch:] 142 | # loss, nothing special here 143 | nle_value, kl_value, mse_value_1,mse_value_2, bce_dis_original_value, bce_dis_sampled_value, \ 144 | bce_dis_predicted_value,bce_gen_sampled_value,bce_gen_predicted_value= VaeGan.loss(data_target, out, out_layer_original, 145 | out_layer_predicted,out_layer_sampled, out_labels_original, 146 | out_labels_predicted,out_labels_sampled, mus, 147 | variances) 148 | # THIS IS THE MOST IMPORTANT PART OF THE CODE 149 | loss_encoder = torch.sum(kl_value)+torch.sum(mse_value_1)+torch.sum(mse_value_2) 150 | loss_discriminator = torch.sum(bce_dis_original_value) + torch.sum(bce_dis_sampled_value)+ torch.sum(bce_dis_predicted_value) 151 | loss_decoder = torch.sum(bce_gen_sampled_value) + torch.sum(bce_gen_predicted_value) 152 | loss_decoder = torch.sum(lambda_mse/2 * mse_value_1)+ torch.sum(lambda_mse/2 * mse_value_2) + (1.0 - lambda_mse) * loss_decoder 153 | 154 | # register mean values of the losses for logging 155 | loss_nle_mean(torch.mean(nle_value).data.cpu().numpy()[0]) 156 | loss_discriminator_mean((torch.mean(bce_dis_original_value) + torch.mean(bce_dis_sampled_value)).data.cpu().numpy()[0]) 157 | loss_decoder_mean((torch.mean(lambda_mse * mse_value_1/2)+torch.mean(lambda_mse * mse_value_2/2) + (1 - lambda_mse) * (torch.mean(bce_gen_predicted_value) + torch.mean(bce_gen_sampled_value))).data.cpu().numpy()[0]) 158 | 159 | loss_encoder_mean((torch.mean(kl_value) + torch.mean(mse_value_1)+ torch.mean(mse_value_2)).data.cpu().numpy()[0]) 160 | loss_reconstruction_layer_mean((torch.mean(mse_value_1)+torch.mean(mse_value_2)).data.cpu().numpy()[0]) 161 | loss_kld_mean(torch.mean(kl_value).data.cpu().numpy()[0]) 162 | # selectively disable the decoder of the discriminator if they are unbalanced 163 | train_dis = True 164 | train_dec = True 165 | if torch.mean(bce_dis_original_value).data[0] < equilibrium-margin or torch.mean(bce_dis_sampled_value).data[0] < equilibrium-margin: 166 | train_dis = False 167 | if torch.mean(bce_dis_original_value).data[0] > equilibrium+margin or torch.mean(bce_dis_sampled_value).data[0] > equilibrium+margin: 168 | train_dec = False 169 | if train_dec is False and train_dis is False: 170 | train_dis = True 171 | train_dec = True 172 | 173 | #aggiungo log 174 | 
if train_dis: 175 | gan_dis_eq_mean(1.0) 176 | else: 177 | gan_dis_eq_mean(0.0) 178 | 179 | if train_dec: 180 | gan_gen_eq_mean(1.0) 181 | else: 182 | gan_gen_eq_mean(0.0) 183 | 184 | # BACKPROP 185 | # clean grads 186 | net.zero_grad() 187 | # encoder 188 | loss_encoder.backward(retain_graph=True) 189 | # someone likes to clamp the grad here 190 | #[p.grad.data.clamp_(-1,1) for p in net.encoder.parameters()] 191 | # update parameters 192 | optimizer_encoder.step() 193 | # clean others, so they are not afflicted by encoder loss 194 | net.zero_grad() 195 | #decoder 196 | if train_dec: 197 | loss_decoder.backward(retain_graph=True) 198 | #[p.grad.data.clamp_(-1,1) for p in net.decoder.parameters()] 199 | optimizer_decoder.step() 200 | #clean the discriminator 201 | net.discriminator.zero_grad() 202 | #discriminator 203 | if train_dis: 204 | loss_discriminator.backward() 205 | #[p.grad.data.clamp_(-1,1) for p in net.discriminator.parameters()] 206 | optimizer_discriminator.step() 207 | 208 | # LOGGING 209 | if not slurm: 210 | progress.update(progress.value + 1, loss_nle=loss_nle_mean.measure, 211 | loss_encoder=loss_encoder_mean.measure, 212 | loss_decoder=loss_decoder_mean.measure, 213 | loss_discriminator=loss_discriminator_mean.measure, 214 | loss_mse_layer=loss_reconstruction_layer_mean.measure, 215 | loss_kld=loss_kld_mean.measure, 216 | epoch=i + 1) 217 | 218 | if slurm: 219 | progress.update(progress.value, loss_nle=loss_nle_mean.measure, 220 | loss_encoder=loss_encoder_mean.measure, 221 | loss_decoder=loss_decoder_mean.measure, 222 | loss_discriminator=loss_discriminator_mean.measure, 223 | loss_mse_layer=loss_reconstruction_layer_mean.measure, 224 | loss_kld=loss_kld_mean.measure, 225 | epoch=i + 1) 226 | 227 | # EPOCH END 228 | lr_encoder.step() 229 | lr_decoder.step() 230 | lr_discriminator.step() 231 | margin *=decay_margin 232 | equilibrium *=decay_equilibrium 233 | #margin non puo essere piu alto di equilibrium 234 | if margin > equilibrium: 235 | equilibrium = margin 236 | lambda_mse *=decay_mse 237 | if lambda_mse > 1: 238 | lambda_mse=1 239 | progress.finish() 240 | 241 | writer.add_scalar('loss_encoder', loss_encoder_mean.measure, step_index) 242 | writer.add_scalar('loss_decoder', loss_decoder_mean.measure, step_index) 243 | writer.add_scalar('loss_discriminator', loss_discriminator_mean.measure, step_index) 244 | writer.add_scalar('loss_reconstruction', loss_nle_mean.measure, step_index) 245 | writer.add_scalar('loss_kld',loss_kld_mean.measure,step_index) 246 | writer.add_scalar('gan_gen',gan_gen_eq_mean.measure,step_index) 247 | writer.add_scalar('gan_dis',gan_dis_eq_mean.measure,step_index) 248 | 249 | for j, (data_batch,target_batch) in enumerate(dataloader_test): 250 | net.eval() 251 | 252 | data_in = Variable(data_batch, requires_grad=False).float().cuda() 253 | data_target = Variable(target_batch, requires_grad=False).float().cuda() 254 | out = net(data_in) 255 | out = out.data.cpu() 256 | out = (out + 1) / 2 257 | out = make_grid(out, nrow=8) 258 | writer.add_image("reconstructed", out, step_index) 259 | 260 | out = net(None, 100) 261 | out = out.data.cpu() 262 | out = (out + 1) / 2 263 | out = make_grid(out, nrow=8) 264 | writer.add_image("generated", out, step_index) 265 | 266 | out = data_target.data.cpu() 267 | out = (out + 1) / 2 268 | out = make_grid(out, nrow=8) 269 | writer.add_image("original", out, step_index) 270 | break 271 | 272 | step_index += 1 273 | exit(0) 274 | -------------------------------------------------------------------------------- 
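Both training scripts above gate the discriminator/decoder updates with the equilibrium trick described in the README. Distilled into a standalone sketch (the function name and signature are mine; the thresholds are the values hard-coded in `main.py`/`main_1.py`):

```python
MARGIN = 0.35       # hard-coded in main.py / main_1.py
EQUILIBRIUM = 0.68

def balance(bce_dis_original, bce_dis_sampled, margin=MARGIN, equilibrium=EQUILIBRIUM):
    """Decide which players get updated this step, given the discriminator's
    mean BCE on the original and the sampled batch."""
    train_dis = train_dec = True
    # discriminator too strong (its loss too low): freeze it
    if bce_dis_original < equilibrium - margin or bce_dis_sampled < equilibrium - margin:
        train_dis = False
    # discriminator too weak (its loss too high): freeze the decoder instead
    if bce_dis_original > equilibrium + margin or bce_dis_sampled > equilibrium + margin:
        train_dec = False
    # never freeze both at once
    if not (train_dis or train_dec):
        train_dis = train_dec = True
    return train_dis, train_dec
```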
/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy 6 | 7 | # encoder block (used in encoder and discriminator) 8 | class EncoderBlock(nn.Module): 9 | def __init__(self, channel_in, channel_out): 10 | super(EncoderBlock, self).__init__() 11 | # convolution to halve the dimensions 12 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=5, padding=2, stride=2, 13 | bias=False) 14 | self.bn = nn.BatchNorm2d(num_features=channel_out, momentum=0.9) 15 | 16 | def forward(self, ten, out=False,t = False): 17 | # here we want to be able to take an intermediate output for reconstruction error 18 | if out: 19 | ten = self.conv(ten) 20 | ten_out = ten 21 | ten = self.bn(ten) 22 | ten = F.relu(ten, False) 23 | return ten, ten_out 24 | else: 25 | ten = self.conv(ten) 26 | ten = self.bn(ten) 27 | ten = F.relu(ten, True) 28 | return ten 29 | 30 | 31 | # decoder block (used in the decoder) 32 | class DecoderBlock(nn.Module): 33 | def __init__(self, channel_in, channel_out): 34 | super(DecoderBlock, self).__init__() 35 | # transpose convolution to double the dimensions 36 | self.conv = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=5, padding=2, stride=2, output_padding=1, 37 | bias=False) 38 | self.bn = nn.BatchNorm2d(channel_out, momentum=0.9) 39 | 40 | def forward(self, ten): 41 | ten = self.conv(ten) 42 | ten = self.bn(ten) 43 | ten = F.relu(ten, True) 44 | return ten 45 | 46 | 47 | class Encoder(nn.Module): 48 | def __init__(self, channel_in=3, z_size=128): 49 | super(Encoder, self).__init__() 50 | self.size = channel_in 51 | layers_list = [] 52 | # the first time 3->64, for every other double the channel size 53 | for i in range(3): 54 | if i == 0: 55 | layers_list.append(EncoderBlock(channel_in=self.size, channel_out=64)) 56 | self.size = 64 57 | else: 58 | layers_list.append(EncoderBlock(channel_in=self.size, channel_out=self.size * 2)) 59 | self.size *= 2 60 | 61 | # final shape Bx256x8x8 62 | self.conv = nn.Sequential(*layers_list) 63 | self.fc = nn.Sequential(nn.Linear(in_features=8 * 8 * self.size, out_features=1024, bias=False), 64 | nn.BatchNorm1d(num_features=1024,momentum=0.9), 65 | nn.ReLU(True)) 66 | # two linear to get the mu vector and the diagonal of the log_variance 67 | self.l_mu = nn.Linear(in_features=1024, out_features=z_size) 68 | self.l_var = nn.Linear(in_features=1024, out_features=z_size) 69 | 70 | def forward(self, ten): 71 | ten = self.conv(ten) 72 | ten = ten.view(len(ten), -1) 73 | ten = self.fc(ten) 74 | mu = self.l_mu(ten) 75 | logvar = self.l_var(ten) 76 | return mu, logvar 77 | 78 | def __call__(self, *args, **kwargs): 79 | return super(Encoder, self).__call__(*args, **kwargs) 80 | 81 | 82 | class Decoder(nn.Module): 83 | def __init__(self, z_size, size): 84 | super(Decoder, self).__init__() 85 | # start from B*z_size 86 | self.fc = nn.Sequential(nn.Linear(in_features=z_size, out_features=8 * 8 * size, bias=False), 87 | nn.BatchNorm1d(num_features=8 * 8 * size,momentum=0.9), 88 | nn.ReLU(True)) 89 | self.size = size 90 | layers_list = [] 91 | layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size)) 92 | layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//2)) 93 | self.size = self.size//2 94 | layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//4)) 95 | self.size = self.size//4 96 | # 
final conv to get 3 channels and tanh layer 97 | layers_list.append(nn.Sequential( 98 | nn.Conv2d(in_channels=self.size, out_channels=3, kernel_size=5, stride=1, padding=2), 99 | nn.Tanh() 100 | )) 101 | 102 | self.conv = nn.Sequential(*layers_list) 103 | 104 | def forward(self, ten): 105 | 106 | ten = self.fc(ten) 107 | ten = ten.view(len(ten), -1, 8, 8) 108 | ten = self.conv(ten) 109 | return ten 110 | 111 | def __call__(self, *args, **kwargs): 112 | return super(Decoder, self).__call__(*args, **kwargs) 113 | 114 | 115 | class Discriminator(nn.Module): 116 | def __init__(self, channel_in=3,recon_level=3): 117 | super(Discriminator, self).__init__() 118 | self.size = channel_in 119 | self.recon_levl = recon_level 120 | # module list because we need need to extract an intermediate output 121 | self.conv = nn.ModuleList() 122 | self.conv.append(nn.Sequential( 123 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2), 124 | nn.ReLU(inplace=True))) 125 | self.size = 32 126 | self.conv.append(EncoderBlock(channel_in=self.size, channel_out=128)) 127 | self.size = 128 128 | self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256)) 129 | self.size = 256 130 | self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256)) 131 | # final fc to get the score (real or fake) 132 | self.fc = nn.Sequential( 133 | nn.Linear(in_features=8 * 8 * self.size, out_features=512, bias=False), 134 | nn.BatchNorm1d(num_features=512,momentum=0.9), 135 | nn.ReLU(inplace=True), 136 | nn.Linear(in_features=512, out_features=1), 137 | 138 | ) 139 | 140 | def forward(self, ten,other_ten,mode='REC'): 141 | if mode == "REC": 142 | ten = torch.cat((ten, other_ten), 0) 143 | for i, lay in enumerate(self.conv): 144 | # we take the 9th layer as one of the outputs 145 | if i == self.recon_levl: 146 | ten, layer_ten = lay(ten, True) 147 | # we need the layer representations just for the original and reconstructed, 148 | # flatten, because it's a convolutional shape 149 | layer_ten = layer_ten.view(len(layer_ten), -1) 150 | return layer_ten 151 | else: 152 | ten = lay(ten) 153 | else: 154 | ten = torch.cat((ten, other_ten), 0) 155 | for i, lay in enumerate(self.conv): 156 | ten = lay(ten) 157 | 158 | ten = ten.view(len(ten), -1) 159 | ten = self.fc(ten) 160 | return F.sigmoid(ten) 161 | 162 | 163 | def __call__(self, *args, **kwargs): 164 | return super(Discriminator, self).__call__(*args, **kwargs) 165 | 166 | 167 | class VaeGan(nn.Module): 168 | def __init__(self,z_size=128,recon_level=3): 169 | super(VaeGan, self).__init__() 170 | # latent space size 171 | self.z_size = z_size 172 | self.encoder = Encoder(z_size=self.z_size) 173 | self.decoder = Decoder(z_size=self.z_size, size=self.encoder.size) 174 | self.discriminator = Discriminator(channel_in=3,recon_level=recon_level) 175 | # self-defined function to init the parameters 176 | self.init_parameters() 177 | 178 | def init_parameters(self): 179 | # just explore the network, find every weight and bias matrix and fill it 180 | for m in self.modules(): 181 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 182 | if hasattr(m, "weight") and m.weight is not None and m.weight.requires_grad: 183 | #init as original implementation 184 | scale = 1.0/numpy.sqrt(numpy.prod(m.weight.shape[1:])) 185 | scale /=numpy.sqrt(3) 186 | #nn.init.xavier_normal(m.weight,1) 187 | #nn.init.constant(m.weight,0.005) 188 | nn.init.uniform(m.weight,-scale,scale) 189 | if hasattr(m, "bias") and m.bias is not None and m.bias.requires_grad: 190 | 
nn.init.constant(m.bias, 0.0) 191 | 192 | def forward(self, ten, gen_size=10): 193 | if self.training: 194 | # save the original images 195 | ten_original = ten 196 | # encode 197 | mus, log_variances = self.encoder(ten) 198 | # we need the true variances, not the log one 199 | variances = torch.exp(log_variances * 0.5) 200 | # sample from a gaussian 201 | 202 | ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True) 203 | # shift and scale using the means and variances 204 | 205 | ten = ten_from_normal * variances + mus 206 | # decode the tensor 207 | ten = self.decoder(ten) 208 | # discriminator for reconstruction 209 | ten_layer = self.discriminator(ten, ten_original, "REC") 210 | # decoder for samples 211 | 212 | ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True) 213 | 214 | ten = self.decoder(ten_from_normal) 215 | ten_class = self.discriminator(ten_original, ten, "GAN") 216 | return ten, ten_class, ten_layer, mus, log_variances 217 | else: 218 | if ten is None: 219 | # just sample and decode 220 | 221 | ten = Variable(torch.randn(gen_size, self.z_size).cuda(), requires_grad=False) 222 | ten = self.decoder(ten) 223 | else: 224 | mus, log_variances = self.encoder(ten) 225 | # we need the true variances, not the log one 226 | variances = torch.exp(log_variances * 0.5) 227 | # sample from a gaussian 228 | 229 | ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=False) 230 | # shift and scale using the means and variances 231 | ten = ten_from_normal * variances + mus 232 | # decode the tensor 233 | ten = self.decoder(ten) 234 | return ten 235 | 236 | 237 | 238 | def __call__(self, *args, **kwargs): 239 | return super(VaeGan, self).__call__(*args, **kwargs) 240 | 241 | @staticmethod 242 | def loss(ten_original, ten_predict, layer_original, layer_predicted, labels_original, 243 | labels_sampled, mus, variances): 244 | """ 245 | 246 | :param ten_original: original images 247 | :param ten_predict: predicted images (output of the decoder) 248 | :param layer_original: intermediate layer for original (intermediate output of the discriminator) 249 | :param layer_predicted: intermediate layer for reconstructed (intermediate output of the discriminator) 250 | :param labels_original: labels for original (output of the discriminator) 251 | :param labels_predicted: labels for reconstructed (output of the discriminator) 252 | :param labels_sampled: labels for sampled from gaussian (0,1) (output of the discriminator) 253 | :param mus: tensor of means 254 | :param variances: tensor of diagonals of log_variances 255 | :return: 256 | """ 257 | 258 | # reconstruction error, not used for the loss but useful to evaluate quality 259 | nle = 0.5*(ten_original.view(len(ten_original), -1) - ten_predict.view(len(ten_predict), -1)) ** 2 260 | # kl-divergence 261 | kl = -0.5 * torch.sum(-variances.exp() - torch.pow(mus,2) + variances + 1, 1) 262 | # mse between intermediate layers 263 | mse = torch.sum(0.5*(layer_original - layer_predicted) ** 2, 1) 264 | # bce for decoder and discriminator for original,sampled and reconstructed 265 | # the only excluded is the bce_gen_original 266 | 267 | bce_dis_original = -torch.log(labels_original + 1e-3) 268 | bce_dis_sampled = -torch.log(1 - labels_sampled + 1e-3) 269 | 270 | bce_gen_original = -torch.log(1-labels_original + 1e-3) 271 | bce_gen_sampled = -torch.log(labels_sampled + 1e-3) 272 | ''' 273 | 274 | 275 | bce_gen_predicted = 
nn.BCEWithLogitsLoss(size_average=False)(labels_predicted, 276 | Variable(torch.ones_like(labels_predicted.data).cuda(), requires_grad=False)) 277 | bce_gen_sampled = nn.BCEWithLogitsLoss(size_average=False)(labels_sampled, 278 | Variable(torch.ones_like(labels_sampled.data).cuda(), requires_grad=False)) 279 | bce_dis_original = nn.BCEWithLogitsLoss(size_average=False)(labels_original, 280 | Variable(torch.ones_like(labels_original.data).cuda(), requires_grad=False)) 281 | bce_dis_predicted = nn.BCEWithLogitsLoss(size_average=False)(labels_predicted, 282 | Variable(torch.zeros_like(labels_predicted.data).cuda(), requires_grad=False)) 283 | bce_dis_sampled = nn.BCEWithLogitsLoss(size_average=False)(labels_sampled, 284 | Variable(torch.zeros_like(labels_sampled.data).cuda(), requires_grad=False)) 285 | ''' 286 | return nle, kl, mse, bce_dis_original, bce_dis_sampled,bce_gen_original,bce_gen_sampled 287 | -------------------------------------------------------------------------------- /network_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy 6 | 7 | # encoder block (used in encoder and discriminator) 8 | class EncoderBlock(nn.Module): 9 | def __init__(self, channel_in, channel_out): 10 | super(EncoderBlock, self).__init__() 11 | # convolution to halve the dimensions 12 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=5, padding=2, stride=2, 13 | bias=False) 14 | self.bn = nn.BatchNorm2d(num_features=channel_out, momentum=0.9) 15 | 16 | def forward(self, ten, out=False,t = False): 17 | # here we want to be able to take an intermediate output for reconstruction error 18 | if out: 19 | ten = self.conv(ten) 20 | ten_out = ten 21 | ten = self.bn(ten) 22 | ten = F.relu(ten, False) 23 | return ten, ten_out 24 | else: 25 | ten = self.conv(ten) 26 | ten = self.bn(ten) 27 | ten = F.relu(ten, True) 28 | return ten 29 | 30 | 31 | # decoder block (used in the decoder) 32 | class DecoderBlock(nn.Module): 33 | def __init__(self, channel_in, channel_out): 34 | super(DecoderBlock, self).__init__() 35 | # transpose convolution to double the dimensions 36 | self.conv = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=5, padding=2, stride=2, output_padding=1, 37 | bias=False) 38 | self.bn = nn.BatchNorm2d(channel_out, momentum=0.9) 39 | 40 | def forward(self, ten): 41 | ten = self.conv(ten) 42 | ten = self.bn(ten) 43 | ten = F.relu(ten, True) 44 | return ten 45 | 46 | 47 | class Encoder(nn.Module): 48 | def __init__(self, channel_in=3, z_size=128): 49 | super(Encoder, self).__init__() 50 | self.size = channel_in 51 | layers_list = [] 52 | # the first time 3->64, for every other double the channel size 53 | for i in range(3): 54 | if i == 0: 55 | layers_list.append(EncoderBlock(channel_in=self.size, channel_out=64)) 56 | self.size = 64 57 | else: 58 | layers_list.append(EncoderBlock(channel_in=self.size, channel_out=self.size * 2)) 59 | self.size *= 2 60 | 61 | # final shape Bx256x8x8 62 | self.conv = nn.Sequential(*layers_list) 63 | self.fc = nn.Sequential(nn.Linear(in_features=8 * 8 * self.size, out_features=1024, bias=False), 64 | nn.BatchNorm1d(num_features=1024,momentum=0.9), 65 | nn.ReLU(True)) 66 | # two linear to get the mu vector and the diagonal of the log_variance 67 | self.l_mu = nn.Linear(in_features=1024, out_features=z_size) 68 | self.l_var = nn.Linear(in_features=1024, 
out_features=z_size) 69 | 70 | def forward(self, ten): 71 | ten = self.conv(ten) 72 | ten = ten.view(len(ten), -1) 73 | ten = self.fc(ten) 74 | mu = self.l_mu(ten) 75 | logvar = self.l_var(ten) 76 | return mu, logvar 77 | 78 | def __call__(self, *args, **kwargs): 79 | return super(Encoder, self).__call__(*args, **kwargs) 80 | 81 | 82 | class Decoder(nn.Module): 83 | def __init__(self, z_size, size): 84 | super(Decoder, self).__init__() 85 | # start from B*z_size 86 | self.fc = nn.Sequential(nn.Linear(in_features=z_size, out_features=8 * 8 * size, bias=False), 87 | nn.BatchNorm1d(num_features=8 * 8 * size,momentum=0.9), 88 | nn.ReLU(True)) 89 | self.size = size 90 | layers_list = [] 91 | layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size)) 92 | layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//2)) 93 | self.size = self.size//2 94 | layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//4)) 95 | self.size = self.size//4 96 | # final conv to get 3 channels and tanh layer 97 | layers_list.append(nn.Sequential( 98 | nn.Conv2d(in_channels=self.size, out_channels=3, kernel_size=5, stride=1, padding=2), 99 | nn.Tanh() 100 | )) 101 | 102 | self.conv = nn.Sequential(*layers_list) 103 | 104 | def forward(self, ten): 105 | 106 | ten = self.fc(ten) 107 | ten = ten.view(len(ten), -1, 8, 8) 108 | ten = self.conv(ten) 109 | return ten 110 | 111 | def __call__(self, *args, **kwargs): 112 | return super(Decoder, self).__call__(*args, **kwargs) 113 | 114 | 115 | class Discriminator(nn.Module): 116 | def __init__(self, channel_in=3,recon_level=3): 117 | super(Discriminator, self).__init__() 118 | self.size = channel_in 119 | self.recon_levl = recon_level 120 | # module list because we need need to extract an intermediate output 121 | self.conv = nn.ModuleList() 122 | self.conv.append(nn.Sequential( 123 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2), 124 | nn.ReLU(inplace=True))) 125 | self.size = 32 126 | self.conv.append(EncoderBlock(channel_in=self.size, channel_out=128)) 127 | self.size = 128 128 | self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256)) 129 | self.size = 256 130 | self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256)) 131 | # final fc to get the score (real or fake) 132 | self.fc = nn.Sequential( 133 | nn.Linear(in_features=8 * 8 * self.size, out_features=512, bias=False), 134 | nn.BatchNorm1d(num_features=512,momentum=0.9), 135 | nn.ReLU(inplace=True), 136 | nn.Linear(in_features=512, out_features=1), 137 | 138 | ) 139 | 140 | def forward(self, ten,ten_original,ten_sampled): 141 | 142 | ten = torch.cat((ten, ten_original,ten_sampled), 0) 143 | 144 | for i, lay in enumerate(self.conv): 145 | # we take the 9th layer as one of the outputs 146 | if i == self.recon_levl: 147 | ten, layer_ten = lay(ten, True) 148 | # we need the layer representations just for the original and reconstructed, 149 | # flatten, because it's a convolutional shape 150 | layer_ten = layer_ten.view(len(layer_ten), -1) 151 | else: 152 | ten = lay(ten) 153 | 154 | ten = ten.view(len(ten), -1) 155 | ten = self.fc(ten) 156 | return layer_ten,F.sigmoid(ten) 157 | 158 | 159 | def __call__(self, *args, **kwargs): 160 | return super(Discriminator, self).__call__(*args, **kwargs) 161 | 162 | 163 | class VaeGan(nn.Module): 164 | def __init__(self,z_size=128,recon_level=3): 165 | super(VaeGan, self).__init__() 166 | # latent space size 167 | self.z_size = z_size 168 | self.encoder = 
Encoder(z_size=self.z_size) 169 | self.decoder = Decoder(z_size=self.z_size, size=self.encoder.size) 170 | self.discriminator = Discriminator(channel_in=3,recon_level=recon_level) 171 | # self-defined function to init the parameters 172 | self.init_parameters() 173 | 174 | def init_parameters(self): 175 | # just explore the network, find every weight and bias matrix and fill it 176 | for m in self.modules(): 177 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 178 | if hasattr(m, "weight") and m.weight is not None and m.weight.requires_grad: 179 | #init as original implementation 180 | scale = 1.0/numpy.sqrt(numpy.prod(m.weight.shape[1:])) 181 | scale /=numpy.sqrt(3) 182 | #nn.init.xavier_normal(m.weight,1) 183 | #nn.init.constant(m.weight,0.005) 184 | nn.init.uniform(m.weight,-scale,scale) 185 | if hasattr(m, "bias") and m.bias is not None and m.bias.requires_grad: 186 | nn.init.constant(m.bias, 0.0) 187 | 188 | def forward(self, ten, gen_size=10): 189 | if self.training: 190 | # save the original images 191 | ten_original = ten 192 | # encode 193 | mus, log_variances = self.encoder(ten) 194 | # we need the true variances, not the log one 195 | variances = torch.exp(log_variances * 0.5) 196 | # sample from a gaussian 197 | 198 | ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True) 199 | # shift and scale using the means and variances 200 | 201 | ten = ten_from_normal * variances + mus 202 | # decode the tensor 203 | ten = self.decoder(ten) 204 | ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True) 205 | ten_from_normal = self.decoder(ten_from_normal) 206 | #discriminator 207 | ten_layer,ten_class = self.discriminator(ten,ten_original,ten_from_normal) 208 | 209 | return ten, ten_class, ten_layer, mus, log_variances 210 | 211 | else: 212 | if ten is None: 213 | # just sample and decode 214 | 215 | ten = Variable(torch.randn(gen_size, self.z_size).cuda(), requires_grad=False) 216 | ten = self.decoder(ten) 217 | else: 218 | mus, log_variances = self.encoder(ten) 219 | # we need the true variances, not the log one 220 | variances = torch.exp(log_variances * 0.5) 221 | # sample from a gaussian 222 | 223 | ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=False) 224 | # shift and scale using the means and variances 225 | ten = ten_from_normal * variances + mus 226 | # decode the tensor 227 | ten = self.decoder(ten) 228 | return ten 229 | 230 | 231 | 232 | def __call__(self, *args, **kwargs): 233 | return super(VaeGan, self).__call__(*args, **kwargs) 234 | 235 | @staticmethod 236 | def loss(ten_original, ten_predicted, layer_original, layer_predicted,layer_sampled, labels_original, 237 | labels_predicted,labels_sampled, mus, variances): 238 | """ 239 | 240 | :param ten_original: original images 241 | :param ten_predicted: predicted images (output of the decoder) 242 | :param layer_original: intermediate layer for original (intermediate output of the discriminator) 243 | :param layer_predicted: intermediate layer for reconstructed (intermediate output of the discriminator) 244 | :param labels_original: labels for original (output of the discriminator) 245 | :param labels_predicted: labels for reconstructed (output of the discriminator) 246 | :param labels_sampled: labels for sampled from gaussian (0,1) (output of the discriminator) 247 | :param mus: tensor of means 248 | :param variances: tensor of diagonals of log_variances 249 | :return: 250 | """ 251 | 252 | # 
reconstruction error, not used for the loss but useful to evaluate quality
253 |         nle = 0.5*(ten_original.view(len(ten_original), -1) - ten_predicted.view(len(ten_predicted), -1)) ** 2
254 |         # kl-divergence
255 |         kl = -0.5 * torch.sum(-variances.exp() - torch.pow(mus, 2) + variances + 1, 1)
256 |         # mse between intermediate layers, for both reconstructed and sampled
257 |         mse_1 = torch.sum(0.5*(layer_original - layer_predicted) ** 2, 1)
258 |         mse_2 = torch.sum(0.5*(layer_original - layer_sampled) ** 2, 1)
259 |         # bce for decoder and discriminator for original, sampled and reconstructed
260 |         # the only one excluded is bce_gen_original
261 | 
262 |         bce_dis_original = -torch.log(labels_original + 1e-3)
263 |         bce_dis_sampled = -torch.log(1 - labels_sampled + 1e-3)
264 |         bce_dis_recon = -torch.log(1 - labels_predicted + 1e-3)
265 | 
266 |         #bce_gen_original = -torch.log(1-labels_original + 1e-3)
267 |         bce_gen_sampled = -torch.log(labels_sampled + 1e-3)
268 |         bce_gen_recon = -torch.log(labels_predicted + 1e-3)
269 |         '''
270 | 
271 | 
272 |         bce_gen_predicted = nn.BCEWithLogitsLoss(size_average=False)(labels_predicted,
273 |                             Variable(torch.ones_like(labels_predicted.data).cuda(), requires_grad=False))
274 |         bce_gen_sampled = nn.BCEWithLogitsLoss(size_average=False)(labels_sampled,
275 |                             Variable(torch.ones_like(labels_sampled.data).cuda(), requires_grad=False))
276 |         bce_dis_original = nn.BCEWithLogitsLoss(size_average=False)(labels_original,
277 |                             Variable(torch.ones_like(labels_original.data).cuda(), requires_grad=False))
278 |         bce_dis_predicted = nn.BCEWithLogitsLoss(size_average=False)(labels_predicted,
279 |                             Variable(torch.zeros_like(labels_predicted.data).cuda(), requires_grad=False))
280 |         bce_dis_sampled = nn.BCEWithLogitsLoss(size_average=False)(labels_sampled,
281 |                             Variable(torch.zeros_like(labels_sampled.data).cuda(), requires_grad=False))
282 |         '''
283 |         return nle, kl, mse_1, mse_2,\
284 |                bce_dis_original, bce_dis_sampled, bce_dis_recon, bce_gen_sampled, bce_gen_recon
285 | 
--------------------------------------------------------------------------------
/results/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucabergamini/VAEGAN-PYTORCH/fa75275d743ef27075aaf198cf4abff5ffbba540/results/original.png
--------------------------------------------------------------------------------
/results/recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucabergamini/VAEGAN-PYTORCH/fa75275d743ef27075aaf198cf4abff5ffbba540/results/recon.png
--------------------------------------------------------------------------------
/results/sampled.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucabergamini/VAEGAN-PYTORCH/fa75275d743ef27075aaf198cf4abff5ffbba540/results/sampled.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # just a class to store a rolling average
2 | # useful to log to TB
3 | class RollingMeasure(object):
4 |     def __init__(self):
5 |         self.measure = 0.0
6 |         self.iter = 0
7 | 
8 |     def __call__(self, measure):
9 |         # feed in a new value, get back the running average of everything so far
10 |         # the counter is incremented *before* the update, otherwise the first
11 |         # value would be overwritten on the second call (off-by-one)
12 |         self.iter += 1
13 |         self.measure += (measure - self.measure) / self.iter
14 |         return self.measure
15 | 
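A quick usage sketch for `RollingMeasure`, with made-up values, showing the incremental-mean behavior after the off-by-one fix above:

```python
from utils import RollingMeasure

loss_mean = RollingMeasure()
for batch_loss in (4.0, 2.0, 3.0):
    running = loss_mean(batch_loss)  # returns the mean of all values so far
print(loss_mean.measure)  # 3.0, the mean of 4.0, 2.0 and 3.0
```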
--------------------------------------------------------------------------------