├── .gitignore ├── LICENSE ├── README.md ├── download.sh ├── images ├── MNIST_RESULTS.PNG └── cleba_results.PNG ├── main.py ├── model ├── .gitignore ├── __init__.py ├── base_trainer.py ├── loss.py ├── model.py ├── sub_layer.py ├── trainer.py └── vae.py ├── test.py └── utils ├── __init__.py └── get_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | # 6 | results/ 7 | 8 | data/ 9 | runs/ 10 | save_model/ 11 | MNIST_results/ 12 | celeba_results/ 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | env/ 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | results/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | testing.py 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Cheonbok Park 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models 2 | **[Jesse Engel, Matthew Hoffman, Adam Roberts, "Learning to Generate Conditionally from Unconditional Generative Models" arXiv preprint arXiv:1711.05772 (2018)](https://arxiv.org/abs/1711.05772).** 3 | ## Meta overview 4 | This repository provides a PyTorch implementation of Learning to Generate Conditionally from Unconditional Generative Models implemented . 5 | 6 | 7 | 8 | 9 | ## Current update status 10 | * [ ] Reproduce result as the paper 11 | * [ ] Atrribute Classifier 12 | * [ ] Gradinet Penalty 13 | * [x] Tensorboad loggings 14 | * [x] Trainer 15 | * [x] Implemented a actor-crtic pairs , Distance Penalty 16 | * [x] Implemented VAE 17 | 18 | 19 | ## Results 20 | 21 | 22 | ### MNIST Results 23 |

24 | 25 | ### CelebA Results 26 |

27 | 28 | 29 | ## Usage 30 | 31 | #### 1.Clone the repository 32 | ```bash 33 | $ git clone https://github.com/cheonbok94/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models.git 34 | $ cd Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models 35 | ``` 36 | 37 | #### 2.Download datasets (Celeba) & install requirement packages 38 | ```bash 39 | $ bash download.sh 40 | ``` 41 | 42 | #### 3. Train 43 | 44 | ##### (1) Training VAE 45 | 46 | ##### (2) Training a Actor-Critic Pairs 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0 2 | ZIP_FILE=./data/CelebA_nocrop.zip 3 | mkdir -p ./data/ 4 | wget -N $URL -O $ZIP_FILE 5 | unzip $ZIP_FILE -d ./data/ 6 | rm $ZIP_FILE 7 | 8 | # CelebA attribute labels 9 | URL=https://www.dropbox.com/s/auexdy98c6g7y25/list_attr_celeba.zip?dl=0 10 | ZIP_FILE=./data/list_attr_celeba.zip 11 | wget -N $URL -O $ZIP_FILE 12 | unzip $ZIP_FILE -d ./data/ 13 | rm $ZIP_FILE 14 | -------------------------------------------------------------------------------- /images/MNIST_RESULTS.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/images/MNIST_RESULTS.PNG -------------------------------------------------------------------------------- /images/cleba_results.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/images/cleba_results.PNG -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy 3 | from model.trainer import Trainer,AC_Trainer 4 | from model.loss import loss_function,celeba_loss 5 | 6 | 7 | from model.vae import Mnist_VAE,Celeba_VAE 8 | from utils.get_data import MNIST_DATA ,Celeba_DATA 9 | from model.model import Actor,Critic 10 | import torch 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--gpu_num', type = int, default = None) 16 | parser.add_argument('--data',type =str,required=True) 17 | parser.add_argument('--num_epoch',type=int,default =50) 18 | parser.add_argument('--batch_size',type=int,default =128) 19 | parser.add_argument('--tensorboard_dirs',type=str,default ='./run') 20 | parser.add_argument('--train_id',type=str , default = 'my_model') 21 | parser.add_argument('--gpu_accelerate',action='store_true') 22 | parser.add_argument('--image_dir',type = str , default = './data/CelebA_nocrop/images') 23 | parser.add_argument('--attr_path',type = str ,default = './data/list_attr_celeba.txt') 24 | parser.add_argument('--distance_penalty',type=float, default = 0.1) 25 | parser.add_argument('--vae_mode',action = 'store_true') 26 | parser.add_argument('--save_model',type=str,default = './') 27 | parser_config = parser.parse_args() 28 | print (parser_config) 29 | if parser_config.gpu_num is not None : 30 | torch.cuda.set_device(parser_config.gpu_num) 31 | 32 | if parser_config.gpu_accelerate: 33 | torch.backends.cudnn.benchmark = True 34 | if parser_config.gpu_num == -1: 35 | device = 'cpu' 36 | else: 37 | device = parser_config.gpu_num 38 | if parser_config.vae_mode: 39 | 40 | if parser_config.data == 'MNIST': 41 | trainDset,testDset,trainDataLoader,testDataLoader= MNIST_DATA(batch_size = parser_config.batch_size ) 42 | model = Mnist_VAE(input_dim= 28*28 ,layer_num= 4, d_model=1024) 43 | trainer = Trainer(model=model, 44 | loss = loss_function, 45 | data = parser_config.data, 46 | epoch = parser_config.num_epoch, 47 | trainDataLoader=trainDataLoader, 48 | testDataLoader=testDataLoader) 49 | 50 | elif parser_config.data == 'celeba': 51 | trainDset,testDset,trainDataLoader,testDataLoader = Celeba_DATA(celeba_img_dir=parser_config.image_dir ,attr_path=parser_config.attr_path,batch_size = parser_config.batch_size,image_size=64,celeba_crop_size=128) 52 | model = Celeba_VAE(128,d_model=1024,layer_num=3) 53 | trainer = Trainer(model=model, 54 | loss = celeba_loss, 55 | data = parser_config.data, 56 | epoch = parser_config.num_epoch, 57 | trainDataLoader=trainDataLoader, 58 | testDataLoader=testDataLoader) 59 | 60 | else: 61 | raise NotImplementedError 62 | else: 63 | if parser_config.data == 'MNIST': 64 | trainDset,testDset,trainDataLoader,testDataLoader = MNIST_DATA(batch_size = parser_config.batch_size ) 65 | model = Mnist_VAE(input_dim= 28*28 ,layer_num= 4, d_model=1024) 66 | actor = Actor(1024,2048) 67 | real_critic = Critic(1024,2048,num_labels=10,condition_mode =True) 68 | attr_critic = Critic(1024,2048,num_labels=10,num_output=10,condition_mode =False) 69 | actrainer = AC_Trainer(vae_model=model, 70 | actor = actor, 71 | real_critic = real_critic, 72 | attr_critic = attr_critic, 73 | epoch = parser_config.num_epoch, 74 | data = parser_config.data, 75 | trainDataLoader=trainDataLoader, 76 | testDataLoader=testDataLoader) 77 | actrainer.load_vae('./save_model/vae_model50_MNIST.path.tar') 78 | actrainer._set_label_type() 79 | if parser_config.data == 'celeba': 80 | selected_attrs = ['Bald','Black_Hair','Blond_Hair','Brown_Hair','Eyeglasses','Male','No_Beard','Smiling','Wearing_Hat','Young'] 81 | trainDset,testDset,trainDataLoader,testDataLoader = Celeba_DATA(celeba_img_dir=parser_config.image_dir ,attr_path=parser_config.attr_path,selected_attrs=selected_attrs,batch_size = parser_config.batch_size,image_size=64,celeba_crop_size=128) 82 | model = Celeba_VAE(128,d_model=1024,layer_num=3) 83 | actor = Actor(1024,2048) 84 | real_critic = Critic(1024,2048,num_labels=10,condition_mode =True) 85 | attr_critic = Critic(1024,2048,num_labels=10,num_output=10,condition_mode =False) 86 | actrainer = AC_Trainer(vae_model=model, 87 | actor = actor, 88 | real_critic = real_critic, 89 | attr_critic = attr_critic, 90 | epoch = parser_config.num_epoch, 91 | data = parser_config.data, 92 | trainDataLoader=trainDataLoader, 93 | testDataLoader=testDataLoader) 94 | actrainer.load_vae('./save_model/vae_model150_celeba.path.tar') 95 | 96 | 97 | print ("[+] Train model start") 98 | if parser_config.vae_mode: 99 | trainer.train() 100 | else: 101 | actrainer.train() 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /model/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/model/.gitignore -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/model/__init__.py -------------------------------------------------------------------------------- /model/base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import logging 5 | import torch 6 | import torch.optim as optim 7 | from utils.util import ensure_dir 8 | 9 | 10 | class BaseTrainer: 11 | """ 12 | Base class for all trainers 13 | """ 14 | def __init__(self, model, loss, metrics, resume, config, train_logger=None): 15 | self.config = config 16 | self.logger = logging.getLogger(self.__class__.__name__) 17 | self.model = model 18 | self.loss = loss 19 | self.metrics = metrics 20 | self.name = config['name'] 21 | self.epochs = config['trainer']['epochs'] 22 | self.save_freq = config['trainer']['save_freq'] 23 | self.verbosity = config['trainer']['verbosity'] 24 | self.with_cuda = config['cuda'] and torch.cuda.is_available() 25 | if config['cuda'] and not torch.cuda.is_available(): 26 | self.logger.warning('Warning: There\'s no CUDA support on this machine, ' 27 | 'training is performed on CPU.') 28 | self.train_logger = train_logger 29 | self.optimizer = getattr(optim, config['optimizer_type'])(model.parameters(), 30 | **config['optimizer']) 31 | self.monitor = config['trainer']['monitor'] 32 | self.monitor_mode = config['trainer']['monitor_mode'] 33 | assert self.monitor_mode == 'min' or self.monitor_mode == 'max' 34 | self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf 35 | self.start_epoch = 1 36 | self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], self.name) 37 | ensure_dir(self.checkpoint_dir) 38 | json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'), 39 | indent=4, sort_keys=False) 40 | if resume: 41 | self._resume_checkpoint(resume) 42 | 43 | def train(self): 44 | """ 45 | Full training logic 46 | """ 47 | for epoch in range(self.start_epoch, self.epochs+1): 48 | result = self._train_epoch(epoch) 49 | log = {'epoch': epoch} 50 | for key, value in result.items(): 51 | if key == 'metrics': 52 | for i, metric in enumerate(self.metrics): 53 | log[metric.__name__] = result['metrics'][i] 54 | elif key == 'val_metrics': 55 | for i, metric in enumerate(self.metrics): 56 | log['val_' + metric.__name__] = result['val_metrics'][i] 57 | else: 58 | log[key] = value 59 | if self.train_logger is not None: 60 | self.train_logger.add_entry(log) 61 | if self.verbosity >= 1: 62 | for key, value in log.items(): 63 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 64 | if (self.monitor_mode == 'min' and log[self.monitor] < self.monitor_best)\ 65 | or (self.monitor_mode == 'max' and log[self.monitor] > self.monitor_best): 66 | self.monitor_best = log[self.monitor] 67 | self._save_checkpoint(epoch, log, save_best=True) 68 | if epoch % self.save_freq == 0: 69 | self._save_checkpoint(epoch, log) 70 | 71 | def _train_epoch(self, epoch): 72 | """ 73 | Training logic for an epoch 74 | :param epoch: Current epoch number 75 | """ 76 | raise NotImplementedError 77 | 78 | def _save_checkpoint(self, epoch, log, save_best=False): 79 | """ 80 | Saving checkpoints 81 | :param epoch: current epoch number 82 | :param log: logging information of the epoch 83 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar' 84 | """ 85 | arch = type(self.model).__name__ 86 | state = { 87 | 'arch': arch, 88 | 'epoch': epoch, 89 | 'logger': self.train_logger, 90 | 'state_dict': self.model.state_dict(), 91 | 'optimizer': self.optimizer.state_dict(), 92 | 'monitor_best': self.monitor_best, 93 | 'config': self.config 94 | } 95 | filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{:03d}-loss-{:.4f}.pth.tar' 96 | .format(epoch, log['loss'])) 97 | torch.save(state, filename) 98 | if save_best: 99 | os.rename(filename, os.path.join(self.checkpoint_dir, 'model_best.pth.tar')) 100 | self.logger.info("Saving current best: {} ...".format('model_best.pth.tar')) 101 | else: 102 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 103 | 104 | def _resume_checkpoint(self, resume_path): 105 | """ 106 | Resume from saved checkpoints 107 | :param resume_path: Checkpoint path to be resumed 108 | """ 109 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 110 | checkpoint = torch.load(resume_path) 111 | self.start_epoch = checkpoint['epoch'] + 1 112 | self.monitor_best = checkpoint['monitor_best'] 113 | self.model.load_state_dict(checkpoint['state_dict']) 114 | self.optimizer.load_state_dict(checkpoint['optimizer']) 115 | self.train_logger = checkpoint['logger'] 116 | self.config = checkpoint['config'] 117 | self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch)) 118 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def loss_function(recon_x,x,mu,sig_var): 5 | reconstruction_loss = F.mse_loss(recon_x,x.view(-1,784),size_average = False)/0.01 6 | 7 | KLD_element = mu.pow(2).add_(sig_var.pow(2)).mul_(-1).add_(1).add_(sig_var.pow(2).log()) 8 | KLD = torch.sum(KLD_element).mul_(-0.5) 9 | return KLD+ reconstruction_loss 10 | 11 | def celeba_loss(recon_x,x,mu,sig_var): 12 | reconstruction_loss = F.mse_loss(recon_x,x,size_average = False)/recon_x.size(0)/0.1 13 | 14 | KLD_element = mu.pow(2).add_(sig_var.pow(2)).mul_(-1).add_(1).add_(sig_var.pow(2).log()) 15 | KLD = torch.sum(KLD_element).mul_(-0.5) 16 | return (KLD+ reconstruction_loss) 17 | class AC_loss: 18 | def __init__(self,lambda_dist,lambda_attr): 19 | self.lambda_attr = lambda_attr 20 | self.lambda_dist = lambda_dist 21 | def __real_loss(self,z,z_prime,sigvar,predict_d,grth_d): 22 | """ 23 | inputs: 24 | z : Batch Size * Z_dim 25 | z_prime : Batch Size * Z_dim 26 | sig_var : std 27 | """ 28 | 29 | sum_variance = torch.sum(sigvar.pow(2),dim=-1) 30 | distance_penalty = torch.sum(F.mse_loss(z_prime,z,size_average=False,reduce =False)*sum_variance) 31 | real_loss = F.binary_cross_entropy(predict_d,grth_d) 32 | 33 | return real_loss 34 | 35 | #def __attr_loss(self,z,z_prime,sigvar,predict_d,grth_d) 36 | 37 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .sub_layer import Linear 4 | import torch.nn.functional as F 5 | import pdb 6 | class Actor(nn.Module): 7 | def __init__(self,d_z,d_model=2048,layer_num=4,num_label = 10,condition_mode = True): 8 | super(Actor,self).__init__() 9 | self.d_z = d_z 10 | self.d_model = d_model 11 | self.layer_num = layer_num 12 | self.condition_mode = condition_mode 13 | layer_list = [] 14 | for i in range(layer_num): 15 | if i == 0: 16 | if condition_mode: 17 | input_dim = d_z + d_model 18 | else: 19 | input_dim = d_z 20 | layer_list.append(Linear(input_dim,d_model)) 21 | else: 22 | layer_list.append(Linear(d_model,d_model)) 23 | layer_list.append(nn.ReLU()) 24 | layer_list.append(Linear(d_model,self.d_z*2)) 25 | 26 | self.fw_layer = nn.Sequential(*layer_list) 27 | self.gate = nn.Sigmoid() 28 | if condition_mode: 29 | self.num_label = num_label 30 | self.condition_layer = Linear(num_label,d_model) 31 | 32 | def forward(self,x,label): 33 | original_x = x 34 | 35 | x = torch.cat((x,self.condition_layer(label)),dim = -1) 36 | out = self.fw_layer(x) 37 | input_gate , dz = out.chunk(2,dim = -1) 38 | gate_value = self.gate(input_gate) 39 | new_z = (1-gate_value)*original_x + gate_value*dz 40 | return new_z 41 | 42 | class Critic(nn.Module): 43 | def __init__(self,d_z,d_model,layer_num=4,num_labels = None,num_output = 1,condition_mode = True): 44 | super(Critic,self).__init__() 45 | self.d_z = d_z 46 | self.d_model = d_model 47 | self.condition_mode = condition_mode 48 | layer_list = [] 49 | for i in range(layer_num): 50 | if i == 0: 51 | if condition_mode: 52 | input_dim = d_z + d_model 53 | else: 54 | input_dim = d_z 55 | layer_list.append(Linear(input_dim,d_model)) 56 | else: 57 | layer_list.append(Linear(d_model,d_model)) 58 | layer_list.append(nn.ReLU()) 59 | layer_list.append(Linear(d_model,num_output)) 60 | self.fw_layer = nn.Sequential(*layer_list) 61 | 62 | if condition_mode: 63 | self.num_labels = num_labels 64 | self.condition_layer = Linear(num_labels,d_model) 65 | def forward(self, x, label = None): 66 | 67 | if self.condition_mode: 68 | x = torch.cat((x,self.condition_layer(label)),dim = -1) 69 | out = self.fw_layer(x) 70 | return F.sigmoid(out) 71 | 72 | -------------------------------------------------------------------------------- /model/sub_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Linear(nn.Module): 5 | def __init__(self,input_dim,output_dim,bias = True): 6 | super(Linear,self).__init__() 7 | self.linear = nn.Linear(input_dim,output_dim,bias=bias) 8 | self.batch_norm = nn.BatchNorm1d(output_dim) 9 | def forward(self,x): 10 | out = self.linear(x) 11 | return self.batch_norm(out) 12 | class View(nn.Module): 13 | def __init__(self,shape = None): 14 | super(View,self).__init__() 15 | self.shape = shape 16 | def forward(self,x): 17 | if self.shape is None: 18 | return x.view(x.size(0),-1) 19 | else: 20 | 21 | return x.view(x.size(0),*self.shape) 22 | -------------------------------------------------------------------------------- /model/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pdb 4 | from torchvision.utils import save_image 5 | import torch.optim as optim 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | class Trainer: 11 | # need to modify to inheritance version 12 | def __init__(self,model,trainDataLoader,loss,epoch,data = 'celeba' ,metrics=None,resume=None,config=None,validDataLoader = None,device = 0, testDataLoader = None,train_logger =None,optimizer_type='Adam',lr=1e-3): 13 | #super(Trainer,self).__init__(model,loss,metrics,resume,config,train_logger) 14 | self.model = model 15 | self.trainDataLoader = trainDataLoader 16 | self.testDataLoader = testDataLoader 17 | self.validDataLoader = validDataLoader 18 | self.valid = True if self.validDataLoader is not None else False 19 | self.test = True if self.testDataLoader is not None else False 20 | self.device = device 21 | self.model.to(self.device) 22 | self.d_model = self.model.d_model 23 | self.train_loss = 0 24 | 25 | self.data = data 26 | self.tensorboad_writer = SummaryWriter() 27 | self.epoch = epoch 28 | self.loss = loss 29 | self.start_epoch = 1 30 | self.with_cuda = torch.cuda.is_available() 31 | self.save_freq = 500 32 | self.total_iteration = 0 33 | self.optimizer = getattr(optim, optimizer_type)(self.model.parameters(),lr=lr) 34 | self.valid_term = 10 35 | def train(self): 36 | for epoch in range(self.start_epoch,self.epoch+1): 37 | result = self._train_epoch(epoch) 38 | self.get_sample(epoch) 39 | if epoch%self.valid_term == 0: 40 | self._test(epoch) 41 | self.save_model(epoch) 42 | print ("[+] Finished Training Model") 43 | 44 | 45 | def _train_epoch(self,epoch): 46 | 47 | self.model.train() 48 | train_loss = 0 49 | for batch_idx,(data,labels) in enumerate(self.trainDataLoader): 50 | data = data.to(self.device) 51 | self.optimizer.zero_grad() 52 | recon_batch,z,mu,log_sigma = self.model(data) 53 | loss = self.loss(recon_batch,data,mu,log_sigma) 54 | loss.backward() 55 | self.optimizer.step() 56 | train_loss += loss.item() 57 | if batch_idx == 1: 58 | if self.data == 'MNIST': 59 | recon_batch = recon_batch.view(-1,1,28,28) 60 | save_image(recon_batch.cpu(),self.data+'_results/sample_train_' + str(epoch) +'.png') 61 | save_image(data.cpu(),self.data+'_results/grtruth_train_' + str(epoch) +'.png') 62 | self._summary_wrtie(train_loss,epoch) 63 | print ("[+] Epoch:[{}/{}] train average loss :{}".format(epoch,self.epoch,train_loss)) 64 | # print interval state 65 | 66 | def _test(self,epoch): 67 | self.model.eval() 68 | test_loss = 0 69 | with torch.no_grad(): 70 | for i, (data,lebels) in enumerate(self.testDataLoader): 71 | data = data.cuda() 72 | recon_batch,z,mu,log_sigma = self.model(data) 73 | 74 | loss = self.loss(recon_batch,data,mu,log_sigma) 75 | test_loss += loss.item() 76 | if i == 1: 77 | if self.data == 'MNIST': 78 | recon_batch = recon_batch.view(-1,1,28,28) 79 | save_image(recon_batch.cpu(),self.data+'_results/sample_valid_' + str(epoch) +'.png') 80 | save_image(data.cpu(),self.data+'_results/grtruth_valid_' + str(epoch) +'.png') 81 | 82 | print ("[+] Validation result {}".format(test_loss)) 83 | def get_sample(self,epoch): 84 | self.model.eval() 85 | with torch.no_grad(): 86 | 87 | sample = torch.randn(64,self.d_model).to(self.device) 88 | out = self.model.decoder(sample) 89 | out = self.model.sigmoid(out) 90 | if self.data == 'MNIST': 91 | save_image(out.view(-1,1,28,28),self.data+'_results/sample_' + str(epoch) +'.png') 92 | else: 93 | save_image(out.cpu(),self.data+'_results/sample_' + str(epoch) +'.png') 94 | def save_model(self,epoch): 95 | 96 | torch.save(self.model.state_dict(), './save_model/vae_model'+str(epoch)+'_'+self.data+'.path.tar') 97 | def _summary_wrtie(self,loss,epoch): 98 | self.tensorboad_writer.add_scalar('data/loss',loss,epoch) 99 | for name,param in self.model.named_parameters(): 100 | self.tensorboad_writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch,bins='sturges') 101 | self.tensorboad_writer.add_histogram(name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges') 102 | def _eval_metric(self,output,target): 103 | raise NotImplementedError 104 | 105 | class AC_Trainer: 106 | def __init__(self,vae_model,actor,real_critic,attr_critic,epoch,trainDataLoader,data,metrics=None,resume=None,config=None,validDataLoader = None,device = 0, testDataLoader = None,train_logger =None,optimizer_type='Adam',lr=1e-4): 107 | self.model = vae_model 108 | self.actor = actor 109 | self.real_critic = real_critic 110 | self.attr_critic = attr_critic 111 | #self.loss = loss # lossfunction class 112 | self.epoch = epoch 113 | 114 | self.trainDataLoader = trainDataLoader 115 | self.testDataLoader = testDataLoader 116 | self.validDataLoader = validDataLoader 117 | self.valid = True if self.validDataLoader is not None else False 118 | self.test = True if self.testDataLoader is not None else False 119 | self.device = device 120 | 121 | self.model.to(self.device) 122 | self.model.eval() 123 | self.actor.to(self.device) 124 | self.real_critic.to(self.device) 125 | self.attr_critic.to(self.device) 126 | 127 | self.d_model = self.model.d_model 128 | self.train_loss = 0 129 | 130 | self.iteration = 0 131 | self.data = data 132 | self.tensorboad_writer = SummaryWriter() 133 | self.epoch = epoch 134 | #self.loss = loss 135 | 136 | self.start_epoch = 1 137 | self.with_cuda = torch.cuda.is_available() 138 | self.save_freq = 500 139 | self.total_iteration = 0 140 | #self.gen_optimizer = getattr(optim, optimizer_type)(list(self.actor.parameters()) + list(self.critic.parameters()),lr=lr) 141 | self.actor_optim = getattr(optim, 'Adam')(self.actor.parameters(),lr=lr*3) 142 | self.real_optim = getattr(optim, 'Adam')(self.real_critic.parameters(),lr=lr*3) 143 | self.attr_optim = getattr(optim, optimizer_type)(self.attr_critic.parameters(),lr=lr*3) 144 | self.valid_term = 10 145 | self._set_label_type() 146 | 147 | def train(self): 148 | for epoch in tqdm(range(self.start_epoch,self.epoch+1)): 149 | result = self._train_epoch(epoch) 150 | self.get_sample(epoch) 151 | if epoch%self.valid_term == 0: 152 | #self._test(epoch) 153 | self.get_sample(epoch) 154 | self.save_model(epoch) 155 | if (epoch+1) == 30: 156 | self.actor_optim.param_groups[0]['lr'] /= 10 157 | self.real_optim.param_groups[0]['lr'] /= 10 158 | print("learning rate change!") 159 | 160 | if (epoch+1) == 50: 161 | self.actor_optim.param_groups[0]['lr'] /= 10 162 | self.real_optim.param_groups[0]['lr'] /= 10 163 | print("learning rate change!") 164 | print ("[+] Finished Training Model") 165 | 166 | 167 | def _train_epoch(self,epoch): 168 | 169 | # Run Data Loader 170 | # realism Constraints pz~ is 0 E(q|x) ~ x 171 | # attr Constraints 172 | self.model.eval() 173 | self.actor.train() 174 | self.real_critic.train() 175 | self.attr_critic.train() 176 | train_loss = 0 177 | iteration = 0 178 | total_actor_loss = 0 179 | total_real_loss = 0 180 | total_distance_penalty = 0 181 | actor_iteration = 0 182 | for batch_idx,(data,labels) in enumerate(self.trainDataLoader): 183 | self.iteration += 1 184 | iteration += 1 185 | m_batchsize = data.size(0) 186 | 187 | real_data = torch.ones(m_batchsize,1) 188 | fake_data = torch.zeros(m_batchsize,1) 189 | fake_z = torch.randn(m_batchsize,self.d_model) 190 | fake_attr = self.fake_attr_generate(m_batchsize) 191 | 192 | real_data = real_data.to(self.device) 193 | fake_data = fake_data.to(self.device) 194 | labels = labels.to(self.device) 195 | fake_z = fake_z.to(self.device) 196 | fake_attr = fake_attr.to(self.device) 197 | data = data.to(self.device) 198 | with torch.no_grad(): 199 | sig_var,mu,z = self.model.encode(data) 200 | 201 | z = self.reparameterize(mu,z) 202 | 203 | 204 | fake_z.requires_grad =True 205 | z.requires_grad = True 206 | labels.requires_grad =True 207 | fake_attr.requires_grad =True 208 | 209 | self.real_critic.zero_grad() 210 | if np.random.rand(1) < 0.1: 211 | 212 | #input_data = torch.cat([z,fake_z],dim=0) 213 | input_data = torch.cat([z,fake_z,z],dim=0) 214 | 215 | 216 | input_attr = torch.cat([labels,labels,fake_attr],dim=0) 217 | #input_attr = torch.cat([labels,labels],dim=0) 218 | real_labels = torch.cat([real_data,fake_data,fake_data]) 219 | 220 | #real_labels = torch.cat([real_data,fake_data]) 221 | 222 | logit_out = self.real_critic(input_data,input_attr) 223 | critic_loss = F.binary_cross_entropy(logit_out,real_labels) 224 | else: 225 | #pdb.set_trace() 226 | fake_z.requires_grad =True 227 | z_g = self.actor(fake_z,labels) 228 | 229 | input_data = torch.cat([z,z_g,z],dim=0) 230 | #input_data = torch.cat([z,z_g],dim=0) 231 | 232 | input_attr = torch.cat([labels,labels,fake_attr],dim=0) 233 | #input_attr = torch.cat([labels,labels],dim=0) 234 | 235 | 236 | real_labels = torch.cat([real_data,fake_data,fake_data]) 237 | #real_labels = torch.cat([real_data,fake_data]) 238 | 239 | logit_out = self.real_critic(input_data,input_attr) 240 | 241 | critic_loss = F.binary_cross_entropy(logit_out,real_labels) 242 | 243 | #print ("critic_loss : {}".format(critic_loss)) 244 | critic_loss.backward() 245 | self.real_optim.step() 246 | if (batch_idx+1)%1000 == 0: 247 | self.d_critic_histogram(self.iteration) 248 | total_real_loss += critic_loss.item() 249 | 250 | if (batch_idx+1)%10 ==0: 251 | self.actor.zero_grad() 252 | #actor_labels = self.re_allocate(labels) 253 | fake_z = torch.randn(m_batchsize,self.d_model) 254 | fake_z = fake_z.to(self.device) 255 | fake_z = self.re_allocate(fake_z) 256 | #actor_truth = self.re_allocate(real_data) 257 | 258 | actor_labels = labels 259 | fake_z = fake_z 260 | actor_truth = real_data 261 | 262 | actor_labels.requires_grad =True 263 | fake_z.requires_grad = True 264 | 265 | actor_g = self.actor(fake_z,actor_labels) 266 | real_g = self.actor(z,actor_labels) 267 | zg_critic_out = self.real_critic(actor_g,actor_labels) 268 | zg_critic_real = self.real_critic(real_g,actor_labels) 269 | #fake_output = self.attr_critic(z_g) # prior 는 안써도 되는가? 애매하군. 270 | weight_var = torch.mean(sig_var,0,True) 271 | 272 | distnace_penalty = torch.mean(torch.sum((1 + (actor_g-fake_z).pow(2)).log()*weight_var.pow(-2),1),0) 273 | distnace_penalty = distnace_penalty + torch.mean(torch.sum((1 + (real_g-z).pow(2)).log()*weight_var.pow(-2),1),0) 274 | #distnace_penalty = 0 275 | 276 | actor_loss = F.binary_cross_entropy(zg_critic_out,actor_truth,size_average=False)+ F.binary_cross_entropy(zg_critic_real,actor_truth,size_average=False)+distnace_penalty 277 | 278 | #actor_loss = actor_loss + distnace_penalty 279 | actor_loss.backward() 280 | total_actor_loss += actor_loss.item() 281 | self.actor_optim.step() 282 | total_distance_penalty += distnace_penalty.item() 283 | actor_iteration +=1 284 | if (actor_iteration%100) == 0 : 285 | self.d_actor_histogram(self.iteration) 286 | 287 | 288 | 289 | 290 | 291 | if batch_idx == 1: 292 | fake_z = torch.randn(m_batchsize,self.d_model) 293 | fake_z = fake_z.to(self.device) 294 | z_g = self.actor(fake_z,labels) 295 | z_g_recon = self.model.decode(z_g) 296 | prior_recon = self.model.decode(fake_z) 297 | data_recon = self.model.decode(z) 298 | if self.data == 'MNIST': 299 | data = data.view(-1,1,28,28) 300 | z_g_recon = z_g_recon.view(-1,1,28,28) 301 | data_recon = data_recon.view(-1,1,28,28) 302 | prior_recon = prior_recon.view(-1,1,28,28) 303 | save_image(z_g_recon.cpu(),self.data+'_results_ac/sample_z_g_train_' + str(epoch) +'.png') 304 | save_image(prior_recon.cpu(),self.data+'_results_ac/sample_prior_train_' + str(epoch) +'.png') 305 | save_image(data_recon.cpu(),self.data+'_results_ac/sample_recon_train_' + str(epoch) +'.png') 306 | save_image(data.cpu(),self.data+'_results_ac/grtruth_train_' + str(epoch) +'.png') 307 | print ("distance penalty : {} , {} | critic_loss :{}".format(total_distance_penalty/actor_iteration,total_actor_loss/actor_iteration,total_real_loss/iteration)) 308 | 309 | 310 | self._summary_wrtie(total_distance_penalty/actor_iteration,total_actor_loss/actor_iteration,total_real_loss/iteration,epoch) 311 | print ("[+] Epoch:[{}/{}] train actor average loss :{}".format(epoch,self.epoch,train_loss)) 312 | def re_allocate(self,data): 313 | new_data = data.detach() 314 | new_data.requiers_grad = True 315 | return new_data 316 | def get_sample(self,epoch,data =None,labels =None): 317 | self.model.eval() 318 | self.actor.eval() 319 | self.real_critic.eval() 320 | self.attr_critic.eval() 321 | with torch.no_grad(): 322 | for i in range(self.num_labels): 323 | test_labels = self.labels[i] 324 | test_labels = test_labels.expand(64,-1) 325 | 326 | sample = torch.randn(64,self.d_model).to(self.device) 327 | sample = self.actor(sample,test_labels) 328 | 329 | out = self.model.decoder(sample) 330 | out = self.model.sigmoid(out) #?? Amiguity Labels input? 331 | if self.data == 'MNIST': 332 | save_image(out.view(-1,1,28,28),self.data+'_results_ac/sample_' + str(epoch)+'_class:'+str(i) +'.png') 333 | else: 334 | save_image(out.cpu(),self.data+'_results_ac/sample_' + str(epoch)+'_class:'+str(i) +'.png') 335 | def _test(self,epoch): 336 | self.model.eval() 337 | self.actor.eval() 338 | self.real_critic.eval() 339 | self.attr_critic.eval() 340 | test_loss = 0 341 | with torch.no_grad(): 342 | for i, (data,lebels) in enumerate(self.testDataLoader): 343 | data = data.cuda() 344 | recon_batch,z,mu,log_sigma = self.model(data) 345 | 346 | loss = self.loss(recon_batch,data,mu,log_sigma) 347 | test_loss += loss.item() 348 | if i == 1: 349 | if self.data == 'MNIST': 350 | recon_batch = recon_batch.view(-1,1,28,28) 351 | save_image(recon_batch.cpu(),self.data+'_results/sample_valid_' + str(epoch) +'.png') 352 | save_image(data.cpu(),self.data+'_results/grtruth_valid_' + str(epoch) +'.png') 353 | 354 | def _set_label_type(self): 355 | 356 | self.labels = torch.eye(10) 357 | self.labels = self.labels.to(self.device) 358 | self.num_labels = self.labels.size(0) 359 | 360 | def fake_attr_generate(self,batch_size,selection_index = None ): 361 | if self.data == 'MNIST': 362 | if selection_index is None: 363 | m = batch_size 364 | selection = np.random.randint(self.num_labels,size=m) 365 | selection = torch.from_numpy(selection).to(self.device) 366 | fake_attr = torch.index_select(self.labels,0,selection) 367 | else: 368 | selection_index = selection_index.cuda() 369 | fake_attr = torch.index_select(self.labels,0,selection_index) 370 | else: 371 | selection_start = np.random.randint(self.trainDataLoader.dataset.num_images-batch_size-1) 372 | selection_data = self.trainDataLoader.dataset.train_dataset[selection_start:selection_start+batch_size] 373 | fake_attr = [] 374 | for (name,labels) in selection_data: 375 | fake_attr.append(labels) 376 | fake_attr = torch.FloatTensor(fake_attr) 377 | 378 | return fake_attr 379 | 380 | 381 | def reparameterize(self,mu,sig_var): 382 | 383 | std = sig_var # need to check sig_var is log (sigma^2) 384 | eps = std.data.new(std.size()).normal_(std=1) 385 | return eps.mul(std).add_(mu) 386 | def _summary_wrtie(self,distance_penalty,loss,real_loss,epoch): 387 | self.tensorboad_writer.add_scalar('data/loss',loss,epoch) # need to modify . We use four loss value . 388 | self.tensorboad_writer.add_scalar('data/distance_penalty',distance_penalty,epoch) # need to modify . We use four loss value . 389 | self.tensorboad_writer.add_scalar('data/discriminator_loss',real_loss,epoch) 390 | #for name,param in self.actor.named_parameters(): #actor 391 | # self.tensorboad_writer.add_histogram('actor/'+name, param.clone().cpu().data.numpy(), epoch,bins='sturges') 392 | # self.tensorboad_writer.add_histogram('actor/'+name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges') 393 | #for name,param in self.real_critic.named_parameters(): #actor 394 | # self.tensorboad_writer.add_histogram('real_critic/'+name, param.clone().cpu().data.numpy(), epoch,bins='sturges') 395 | # self.tensorboad_writer.add_histogram('real_critic/'+name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges') 396 | #for name,param in self.attr_critic.named_parameters(): #actor 397 | # self.tensorboad_writer.add_histogram('attr_critic/'+name, param.clone().cpu().data.numpy(), epoch,bins='sturges') 398 | # self.tensorboad_writer.add_histogram('attr_critic/'+name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges') 399 | def d_critic_histogram(self,iteration): 400 | for name,param in self.real_critic.named_parameters(): #actor 401 | self.tensorboad_writer.add_histogram('real_critic/'+name, param.clone().cpu().data.numpy(), iteration,bins='sturges') 402 | self.tensorboad_writer.add_histogram('real_critic/'+name+'/grad', param.grad.clone().cpu().data.numpy(), iteration,bins='sturges') 403 | def d_actor_histogram(self,iteration): 404 | for name,param in self.actor.named_parameters(): #actor 405 | self.tensorboad_writer.add_histogram('actor/'+name, param.clone().cpu().data.numpy(), iteration,bins='sturges') 406 | self.tensorboad_writer.add_histogram('actor/'+name+'/grad', param.grad.clone().cpu().data.numpy(), iteration,bins='sturges') 407 | def save_model(self,epoch): 408 | torch.save(self.actor.state_dict(), './save_model/actor_model'+str(epoch)+'_'+self.data+'.path.tar') 409 | torch.save(self.real_critic.state_dict(), './save_model/real_d_model'+str(epoch)+'_'+self.data+'.path.tar') 410 | torch.save(self.attr_critic.state_dict(), './save_model/attr_d_model'+str(epoch)+'_'+self.data+'.path.tar') 411 | def load_vae(self,path): 412 | 413 | print ("[+] Load pre-trained VAE model") 414 | checkpoint = torch.load(path) 415 | self.model.load_state_dict(checkpoint) 416 | -------------------------------------------------------------------------------- /model/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from .sub_layer import Linear,View 5 | import pdb 6 | 7 | class Celeba_VAE(nn.Module): 8 | def __init__(self,input_size=128,d_model=1024,layer_num=3): 9 | super(Celeba_VAE,self).__init__() 10 | self.d_model = d_model 11 | self.layer_num = layer_num 12 | 13 | self.encoder = self.build_encoder(self.d_model,self.layer_num) 14 | self.sig_layer = nn.Softplus() 15 | 16 | self.decoder = self.build_decoder(self.d_model,self.layer_num) 17 | self.sigmoid = nn.Sigmoid() 18 | 19 | 20 | def build_encoder(self,d_model,layer_num): 21 | encoder_layerList = [] 22 | 23 | 24 | encoder_layerList.append(nn.Conv2d(in_channels = 3,out_channels = 256,kernel_size = 5,stride =2,padding=1)) 25 | #encoder_layerList.append(nn.BatchNorm2d(256)) 26 | encoder_layerList.append(nn.ReLU()) 27 | encoder_layerList.append(nn.Conv2d(in_channels = 256 ,out_channels = 256*2 , kernel_size = 5,stride =2,padding=1)) 28 | #encoder_layerList.append(nn.BatchNorm2d(256*2)) 29 | encoder_layerList.append(nn.ReLU()) 30 | encoder_layerList.append(nn.Conv2d(in_channels=512 , out_channels = 1024, kernel_size = 3,stride =2,padding=1)) 31 | #encoder_layerList.append(nn.BatchNorm2d(512*2)) 32 | encoder_layerList.append(nn.ReLU()) 33 | encoder_layerList.append(nn.Conv2d(in_channels = 1024,out_channels = 2048,kernel_size = 3,stride =2,padding=1)) 34 | #encoder_layerList.append(nn.BatchNorm2d(1024*2)) 35 | encoder_layerList.append(nn.ReLU()) 36 | encoder_layerList.append(View()) 37 | encoder_layerList.append(nn.Linear(4*4*2048,2048)) 38 | return nn.Sequential(*encoder_layerList) 39 | 40 | def build_decoder(self,d_model,layer_num): 41 | decoder_layerList = [] 42 | decoder_layerList 43 | 44 | decoder_layerList.append(nn.Linear(d_model,2048*4*4)) 45 | decoder_layerList.append(View([2048,4,4])) 46 | decoder_layerList.append(nn.ConvTranspose2d(2048,1024,3,stride=2 ,padding =1 ,output_padding=1)) 47 | #decoder_layerList.append(nn.BatchNorm2d(1024)) 48 | decoder_layerList.append(nn.ReLU()) 49 | decoder_layerList.append(nn.ConvTranspose2d(1024,512,3,stride=2 ,padding =1 ,output_padding=0)) 50 | #decoder_layerList.append(nn.BatchNorm2d(512)) 51 | decoder_layerList.append(nn.ReLU()) 52 | decoder_layerList.append(nn.ConvTranspose2d(512,256,5,stride=2,padding=1 ,output_padding=0)) 53 | #decoder_layerList.append(nn.BatchNorm2d(256)) 54 | decoder_layerList.append(nn.ReLU()) 55 | decoder_layerList.append(nn.ConvTranspose2d(256,3,5,stride=2 ,padding = 1 ,output_padding =1)) 56 | return nn.Sequential(*decoder_layerList) 57 | 58 | def reparameterize(self,mu,sig_var): 59 | ## need to understand 60 | if self.training: 61 | std = sig_var # need to check sig_var is log (sigma^2) 62 | eps = std.data.new(std.size()).normal_(std=1) 63 | return eps.mul(std).add_(mu) 64 | else: 65 | return mu 66 | def encode(self,x): 67 | encoder_out = self.encoder(x) 68 | 69 | sig_var , mu_var = encoder_out.chunk(2,dim=-1) 70 | 71 | sig_var = self.sig_layer(sig_var) 72 | z = self.reparameterize(mu_var,sig_var) 73 | return sig_var,mu_var,z 74 | def decode(self,z): 75 | output = self.decoder(z) 76 | output = self.sigmoid(output) 77 | return output 78 | 79 | def forward(self,x): 80 | encoder_out = self.encoder(x) 81 | sig_var , mu_var = encoder_out.chunk(2,dim=-1) 82 | 83 | sig_var = self.sig_layer(sig_var) 84 | z = self.reparameterize(mu_var,sig_var) 85 | 86 | output = self.decoder(z) 87 | output = self.sigmoid(output) 88 | return output,z,mu_var,sig_var 89 | 90 | 91 | 92 | class Mnist_VAE(nn.Module): 93 | def __init__(self,input_dim=28*28,d_model=1024,layer_num=3): 94 | super(Mnist_VAE,self).__init__() 95 | self.d_model = d_model 96 | self.layer_num = layer_num 97 | self.input_dim = input_dim 98 | 99 | self.encoder = self.build_encoder(self.input_dim,self.d_model,self.layer_num) 100 | self.sig_layer = nn.Softplus() 101 | 102 | self.decoder = self.build_decoder(self.input_dim,self.d_model,self.layer_num) 103 | 104 | self.sigmoid = nn.Sigmoid() 105 | def build_encoder(self,input_dim,d_model,layer_num): 106 | encoder_layerList = [] 107 | for i in range(layer_num): 108 | if i == 0 : 109 | encoder_layerList.append(nn.Linear(input_dim,d_model)) 110 | else: 111 | encoder_layerList.append(nn.Linear(d_model,d_model)) 112 | encoder_layerList.append(nn.ReLU()) 113 | encoder_layerList.append(nn.Linear(d_model,2*d_model)) 114 | return nn.Sequential(*encoder_layerList) 115 | 116 | def build_decoder(self,input_dim,d_model,layer_num): 117 | decoder_layerList = [] 118 | for i in range(layer_num): 119 | decoder_layerList.append(nn.Linear(d_model,d_model)) 120 | decoder_layerList.append(nn.ReLU()) 121 | 122 | decoder_layerList.append(nn.Linear(d_model,input_dim)) 123 | 124 | return nn.Sequential(*decoder_layerList) 125 | 126 | def reparameterize(self,mu,sig_var): 127 | ## need to understand 128 | if self.training: 129 | std = sig_var # need to check sig_var is log (sigma^2) 130 | eps = std.data.new(std.size()).normal_(std=1) 131 | return eps.mul(std).add_(mu) 132 | else: 133 | return mu 134 | def encode(self,x): 135 | x = x.view(-1,28*28) 136 | encoder_out = self.encoder(x) 137 | sig_var , mu_var = encoder_out.chunk(2,dim=-1) 138 | 139 | sig_var = self.sig_layer(sig_var) 140 | z = self.reparameterize(mu_var,sig_var) 141 | return sig_var,mu_var,z 142 | 143 | def decode(self,z): 144 | output = self.decoder(z) 145 | output = self.sigmoid(output) 146 | return output 147 | 148 | def forward(self,x): 149 | x = x.view(-1,28*28) 150 | encoder_out = self.encoder(x) 151 | sig_var , mu_var = encoder_out.chunk(2,dim=-1) 152 | 153 | sig_var = self.sig_layer(sig_var) 154 | z = self.reparameterize(mu_var,sig_var) 155 | output = self.decoder(z) 156 | output = self.sigmoid(output) 157 | 158 | return output,z,mu_var,sig_var 159 | 160 | 161 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils.get_data import celebA_data_preprocess 2 | 3 | 4 | 5 | if __name__ == '__main__': 6 | print ("[+] Testing celebA_data_preprocess function") 7 | celebA_data_preprocess() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/utils/__init__.py -------------------------------------------------------------------------------- /utils/get_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from scipy.misc import imresize 5 | 6 | from PIL import Image 7 | import random 8 | 9 | import torchvision.utils as utils 10 | from torch.utils import data 11 | import torchvision.datasets as vision_dsets 12 | import torchvision.transforms as T 13 | 14 | 15 | from tqdm import tqdm 16 | import pdb 17 | def celebA_data_preprocess(root ='/hdd1/cheonbok_experiment/celevA/data/CelebA_nocrop/images/' ,save_root='/hdd1/cheonbok_experiment/celevA/data/CelebA_resize/',resize=64): 18 | """ 19 | Preprocessing the celevA data set (resizing) 20 | """ 21 | 22 | 23 | 24 | if not os.path.isdir(save_root): 25 | os.mkdir(save_root) 26 | if not os.path.isdir(save_root + 'celebA'): 27 | os.mkdir(save_root+ 'celebA') 28 | img_list = os.listdir(root) 29 | 30 | for i in tqdm(range(len(img_list)),desc='CelebA Preprocessing'): 31 | img = plt.imread(root+ img_list[i]) 32 | img = imresize(img,(resize,resize)) 33 | plt.imsave(fname = save_root + 'celebA/'+img_list[i],arr=img) 34 | print ("[+] Finished the CelebA Data set Preprocessing") 35 | 36 | def MNIST_DATA(root='./data',train =True,transforms=None ,download =True,batch_size = 32,num_worker = 2): 37 | if transforms is None: 38 | transforms = T.ToTensor() 39 | print ("[+] Get the MNIST DATA") 40 | mnist_train = vision_dsets.MNIST(root = root, 41 | train = True, 42 | transform = transforms, 43 | download = True) 44 | mnist_test = vision_dsets.MNIST(root = root, 45 | train = False, 46 | transform = T.ToTensor(), 47 | download = True) 48 | trainDataLoader = data.DataLoader(dataset = mnist_train, 49 | batch_size = batch_size, 50 | shuffle =True, 51 | num_workers = 2) 52 | 53 | testDataLoader = data.DataLoader(dataset = mnist_test, 54 | batch_size = batch_size, 55 | shuffle = False, 56 | num_workers = 2) 57 | print ("[+] Finished loading data & Preprocessing") 58 | return mnist_train,mnist_test,trainDataLoader,testDataLoader 59 | 60 | 61 | class CelebA(data.Dataset): 62 | """Dataset class for the CelebA dataset.""" 63 | 64 | def __init__(self, image_dir, attr_path, selected_attrs, transform, mode,un_condition_mode=True): 65 | """Initialize and preprocess the CelebA dataset.""" 66 | self.image_dir = image_dir 67 | self.attr_path = attr_path 68 | self.selected_attrs = selected_attrs 69 | self.transform = transform 70 | self.mode = mode 71 | self.train_dataset = [] 72 | self.test_dataset = [] 73 | self.attr2idx = {} 74 | self.idx2attr = {} 75 | self.un_condition_mode = un_condition_mode 76 | self.preprocess() 77 | 78 | if mode == 'train': 79 | self.num_images = len(self.train_dataset) 80 | else: 81 | self.num_images = len(self.test_dataset) 82 | 83 | def preprocess(self): 84 | """Preprocess the CelebA attribute file.""" 85 | lines = [line.rstrip() for line in open(self.attr_path, 'r')] 86 | all_attr_names = lines[1].split() 87 | for i, attr_name in enumerate(all_attr_names): 88 | self.attr2idx[attr_name] = i 89 | self.idx2attr[i] = attr_name 90 | 91 | lines = lines[2:] 92 | random.seed(1234) 93 | random.shuffle(lines) 94 | for i, line in enumerate(lines): 95 | split = line.split() 96 | filename = split[0] 97 | values = split[1:] 98 | 99 | label = [] 100 | for attr_name in self.selected_attrs: 101 | idx = self.attr2idx[attr_name] 102 | label.append(values[idx] == '1') 103 | 104 | if sum(label)>=1 or self.un_condition_mode: 105 | if (i+1) < 2000: 106 | self.test_dataset.append([filename, label]) 107 | else: 108 | self.train_dataset.append([filename, label]) 109 | print('[+]Finished preprocessing the CelebA dataset...') 110 | 111 | def __getitem__(self, index): 112 | """Return one image and its corresponding attribute label.""" 113 | dataset = self.train_dataset if self.mode == 'train' else self.test_dataset 114 | filename, label = dataset[index] 115 | image = Image.open(os.path.join(self.image_dir, filename)) 116 | 117 | return self.transform(image), torch.FloatTensor(label) 118 | 119 | def __len__(self): 120 | """Return the number of images.""" 121 | return self.num_images 122 | #def __figure_label(self,label): 123 | # figure_label = {0:[2,]} 124 | 125 | 126 | 127 | def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, 128 | batch_size=16, dataset='CelebA', mode='train', num_workers=1,un_condition_mode = True): 129 | """Build and return a data loader.""" 130 | transform = [] 131 | if mode == 'train': 132 | transform.append(T.RandomHorizontalFlip()) 133 | transform.append(T.CenterCrop(crop_size)) 134 | transform.append(T.Resize(image_size)) 135 | transform.append(T.ToTensor()) 136 | #transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) 137 | transform = T.Compose(transform) 138 | 139 | if dataset == 'CelebA': 140 | dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode,un_condition_mode) 141 | elif dataset == 'RaFD': 142 | dataset = ImageFolder(image_dir, transform,un_condition_mode) 143 | 144 | data_loader = data.DataLoader(dataset=dataset, 145 | batch_size=batch_size, 146 | shuffle=(mode=='train'), 147 | num_workers=num_workers) 148 | return data_loader,dataset 149 | 150 | def Celeba_DATA(celeba_img_dir ,attr_path,image_size=128,celeba_crop_size=178,selected_attrs=None,batch_size = 32,num_worker = 1,un_condition_mode =True): 151 | 152 | 153 | 154 | if selected_attrs is None: 155 | selected_attrs = ['5_o_Clock_Shadow','Arched_Eyebrows','Attractive','Bags_Under_Eyes', 'Bangs', 156 | 'Big_Lips','Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 157 | 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 158 | 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 159 | 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 160 | 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'] 161 | trainDataLoader,trainData = get_loader(celeba_img_dir,attr_path,selected_attrs, 162 | celeba_crop_size,image_size,batch_size, 163 | 'CelebA','train',num_worker,un_condition_mode) 164 | testDataLoader,testData = get_loader(celeba_img_dir,attr_path,selected_attrs, 165 | celeba_crop_size,image_size,batch_size, 166 | 'CelebA','test',num_worker,un_condition_mode) 167 | 168 | return trainData,testData,trainDataLoader,testDataLoader --------------------------------------------------------------------------------