├── .gitignore
├── LICENSE
├── README.md
├── download.sh
├── images
├── MNIST_RESULTS.PNG
└── cleba_results.PNG
├── main.py
├── model
├── .gitignore
├── __init__.py
├── base_trainer.py
├── loss.py
├── model.py
├── sub_layer.py
├── trainer.py
└── vae.py
├── test.py
└── utils
├── __init__.py
└── get_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | #
6 | results/
7 |
8 | data/
9 | runs/
10 | save_model/
11 | MNIST_results/
12 | celeba_results/
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | env/
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | results/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | testing.py
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # dotenv
92 | .env
93 |
94 | # virtualenv
95 | .venv
96 | venv/
97 | ENV/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Cheonbok Park
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models
2 | **[Jesse Engel, Matthew Hoffman, Adam Roberts, "Learning to Generate Conditionally from Unconditional Generative Models" arXiv preprint arXiv:1711.05772 (2017)](https://arxiv.org/abs/1711.05772).**
3 | ## Meta overview
4 | This repository provides a PyTorch implementation of "Learning to Generate Conditionally from Unconditional Generative Models".
5 |
6 |
7 |
8 |
9 | ## Current update status
10 | * [ ] Reproduce result as the paper
11 | * [ ] Attribute Classifier
12 | * [ ] Gradient Penalty
13 | * [x] TensorBoard logging
14 | * [x] Trainer
15 | * [x] Implemented actor-critic pairs and distance penalty
16 | * [x] Implemented VAE
17 |
18 |
19 | ## Results
20 |
21 |
22 | ### MNIST Results
23 |
![MNIST_results](./images/MNIST_RESULTS.PNG)
24 |
25 | ### CelebA Results
26 | 
27 |
28 |
29 | ## Usage
30 |
31 | #### 1.Clone the repository
32 | ```bash
33 | $ git clone https://github.com/cheonbok94/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models.git
34 | $ cd Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models
35 | ```
36 |
37 | #### 2.Download datasets (Celeba) & install requirement packages
38 | ```bash
39 | $ bash download.sh
40 | ```
41 |
42 | #### 3. Train
43 |
44 | ##### (1) Training VAE
45 |
46 | ##### (2) Training the Actor-Critic Pairs
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
# Download the CelebA "no-crop" image archive into ./data/.
# BUG FIX: all expansions are now quoted — the URLs contain '?', which is a
# shell glob character, and unquoted paths break on spaces.
URL="https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0"
ZIP_FILE="./data/CelebA_nocrop.zip"
mkdir -p ./data/
wget -N "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d ./data/
rm "$ZIP_FILE"

# CelebA attribute labels (list_attr_celeba.txt inside the zip).
URL="https://www.dropbox.com/s/auexdy98c6g7y25/list_attr_celeba.zip?dl=0"
ZIP_FILE="./data/list_attr_celeba.zip"
wget -N "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d ./data/
rm "$ZIP_FILE"
14 |
--------------------------------------------------------------------------------
/images/MNIST_RESULTS.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/images/MNIST_RESULTS.PNG
--------------------------------------------------------------------------------
/images/cleba_results.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/images/cleba_results.PNG
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy
3 | from model.trainer import Trainer,AC_Trainer
4 | from model.loss import loss_function,celeba_loss
5 |
6 |
7 | from model.vae import Mnist_VAE,Celeba_VAE
8 | from utils.get_data import MNIST_DATA ,Celeba_DATA
9 | from model.model import Actor,Critic
10 | import torch
11 |
def main():
	"""Entry point: parse CLI options and launch training.

	Two modes:
	  * --vae_mode : pre-train a VAE on MNIST or CelebA.
	  * default    : train the actor-critic pair on top of a previously
	                 saved VAE checkpoint (loaded from ./save_model/).
	"""
	parser = argparse.ArgumentParser()

	parser.add_argument('--gpu_num', type = int, default = None)
	parser.add_argument('--data',type =str,required=True)
	parser.add_argument('--num_epoch',type=int,default =50)
	parser.add_argument('--batch_size',type=int,default =128)
	parser.add_argument('--tensorboard_dirs',type=str,default ='./run')
	parser.add_argument('--train_id',type=str , default = 'my_model')
	parser.add_argument('--gpu_accelerate',action='store_true')
	parser.add_argument('--image_dir',type = str , default = './data/CelebA_nocrop/images')
	parser.add_argument('--attr_path',type = str ,default = './data/list_attr_celeba.txt')
	parser.add_argument('--distance_penalty',type=float, default = 0.1)
	parser.add_argument('--vae_mode',action = 'store_true')
	parser.add_argument('--save_model',type=str,default = './')
	parser_config = parser.parse_args()
	print(parser_config)

	if parser_config.gpu_num is not None:
		torch.cuda.set_device(parser_config.gpu_num)
	if parser_config.gpu_accelerate:
		# cudnn autotuner: picks the fastest conv algorithms for fixed input sizes.
		torch.backends.cudnn.benchmark = True

	if parser_config.vae_mode:
		if parser_config.data == 'MNIST':
			trainDset, testDset, trainDataLoader, testDataLoader = MNIST_DATA(batch_size=parser_config.batch_size)
			model = Mnist_VAE(input_dim=28 * 28, layer_num=4, d_model=1024)
			trainer = Trainer(model=model,
							  loss=loss_function,
							  data=parser_config.data,
							  epoch=parser_config.num_epoch,
							  trainDataLoader=trainDataLoader,
							  testDataLoader=testDataLoader)
		elif parser_config.data == 'celeba':
			trainDset, testDset, trainDataLoader, testDataLoader = Celeba_DATA(celeba_img_dir=parser_config.image_dir, attr_path=parser_config.attr_path, batch_size=parser_config.batch_size, image_size=64, celeba_crop_size=128)
			model = Celeba_VAE(128, d_model=1024, layer_num=3)
			trainer = Trainer(model=model,
							  loss=celeba_loss,
							  data=parser_config.data,
							  epoch=parser_config.num_epoch,
							  trainDataLoader=trainDataLoader,
							  testDataLoader=testDataLoader)
		else:
			raise NotImplementedError
	else:
		if parser_config.data == 'MNIST':
			trainDset, testDset, trainDataLoader, testDataLoader = MNIST_DATA(batch_size=parser_config.batch_size)
			model = Mnist_VAE(input_dim=28 * 28, layer_num=4, d_model=1024)
			actor = Actor(1024, 2048)
			real_critic = Critic(1024, 2048, num_labels=10, condition_mode=True)
			attr_critic = Critic(1024, 2048, num_labels=10, num_output=10, condition_mode=False)
			actrainer = AC_Trainer(vae_model=model,
								   actor=actor,
								   real_critic=real_critic,
								   attr_critic=attr_critic,
								   epoch=parser_config.num_epoch,
								   data=parser_config.data,
								   trainDataLoader=trainDataLoader,
								   testDataLoader=testDataLoader)
			actrainer.load_vae('./save_model/vae_model50_MNIST.path.tar')
			actrainer._set_label_type()
		elif parser_config.data == 'celeba':
			# BUG FIX: this branch was a separate `if`; with any other --data value
			# `actrainer` was never bound and the code crashed later with NameError.
			selected_attrs = ['Bald', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Eyeglasses', 'Male', 'No_Beard', 'Smiling', 'Wearing_Hat', 'Young']
			trainDset, testDset, trainDataLoader, testDataLoader = Celeba_DATA(celeba_img_dir=parser_config.image_dir, attr_path=parser_config.attr_path, selected_attrs=selected_attrs, batch_size=parser_config.batch_size, image_size=64, celeba_crop_size=128)
			model = Celeba_VAE(128, d_model=1024, layer_num=3)
			actor = Actor(1024, 2048)
			real_critic = Critic(1024, 2048, num_labels=10, condition_mode=True)
			attr_critic = Critic(1024, 2048, num_labels=10, num_output=10, condition_mode=False)
			actrainer = AC_Trainer(vae_model=model,
								   actor=actor,
								   real_critic=real_critic,
								   attr_critic=attr_critic,
								   epoch=parser_config.num_epoch,
								   data=parser_config.data,
								   trainDataLoader=trainDataLoader,
								   testDataLoader=testDataLoader)
			actrainer.load_vae('./save_model/vae_model150_celeba.path.tar')
		else:
			raise NotImplementedError

	print("[+] Train model start")
	if parser_config.vae_mode:
		trainer.train()
	else:
		actrainer.train()

if __name__ == '__main__':
	main()
--------------------------------------------------------------------------------
/model/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/model/.gitignore
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/model/__init__.py
--------------------------------------------------------------------------------
/model/base_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import logging
5 | import torch
6 | import torch.optim as optim
7 | from utils.util import ensure_dir
8 |
9 |
class BaseTrainer:
    """
    Base class for all trainers

    Owns the common training loop, best-model tracking, checkpointing and
    resume logic; subclasses implement `_train_epoch` and return a result
    dict (it must contain a 'loss' key — see `_save_checkpoint`).

    Expected (partial) config schema — TODO confirm against callers:
        name, cuda, optimizer_type, optimizer (kwargs dict),
        trainer: {epochs, save_freq, verbosity, monitor, monitor_mode, save_dir}
    """
    def __init__(self, model, loss, metrics, resume, config, train_logger=None):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.name = config['name']
        self.epochs = config['trainer']['epochs']
        self.save_freq = config['trainer']['save_freq']
        self.verbosity = config['trainer']['verbosity']
        # Use CUDA only when both requested and actually available.
        self.with_cuda = config['cuda'] and torch.cuda.is_available()
        if config['cuda'] and not torch.cuda.is_available():
            self.logger.warning('Warning: There\'s no CUDA support on this machine, '
                                'training is performed on CPU.')
        self.train_logger = train_logger
        # Optimizer class looked up by name on torch.optim; kwargs from config.
        self.optimizer = getattr(optim, config['optimizer_type'])(model.parameters(),
                                                                  **config['optimizer'])
        # `monitor` is the key in the per-epoch log used to decide "best model".
        self.monitor = config['trainer']['monitor']
        self.monitor_mode = config['trainer']['monitor_mode']
        assert self.monitor_mode == 'min' or self.monitor_mode == 'max'
        self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf
        self.start_epoch = 1
        self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], self.name)
        ensure_dir(self.checkpoint_dir)
        # Persist the config next to the checkpoints for reproducibility.
        json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'),
                  indent=4, sort_keys=False)
        if resume:
            self._resume_checkpoint(resume)

    def train(self):
        """
        Full training logic

        Per epoch: run `_train_epoch`, flatten its result into a log dict
        (metric lists expanded by metric name), feed the train_logger,
        save a "best" checkpoint when the monitored value improves and a
        regular checkpoint every `save_freq` epochs.
        """
        for epoch in range(self.start_epoch, self.epochs+1):
            result = self._train_epoch(epoch)
            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    # expand the metrics list into named entries
                    for i, metric in enumerate(self.metrics):
                        log[metric.__name__] = result['metrics'][i]
                elif key == 'val_metrics':
                    for i, metric in enumerate(self.metrics):
                        log['val_' + metric.__name__] = result['val_metrics'][i]
                else:
                    log[key] = value
            if self.train_logger is not None:
                self.train_logger.add_entry(log)
            if self.verbosity >= 1:
                for key, value in log.items():
                    self.logger.info('    {:15s}: {}'.format(str(key), value))
            # NOTE(review): raises KeyError if `_train_epoch` does not emit
            # the monitored key — confirm subclass contract.
            if (self.monitor_mode == 'min' and log[self.monitor] < self.monitor_best)\
                    or (self.monitor_mode == 'max' and log[self.monitor] > self.monitor_best):
                self.monitor_best = log[self.monitor]
                self._save_checkpoint(epoch, log, save_best=True)
            if epoch % self.save_freq == 0:
                self._save_checkpoint(epoch, log)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch
        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _save_checkpoint(self, epoch, log, save_best=False):
        """
        Saving checkpoints
        :param epoch: current epoch number
        :param log: logging information of the epoch (must contain 'loss')
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'logger': self.train_logger,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{:03d}-loss-{:.4f}.pth.tar'
                                .format(epoch, log['loss']))
        torch.save(state, filename)
        if save_best:
            # the "best" checkpoint replaces (renames) the one just written
            os.rename(filename, os.path.join(self.checkpoint_dir, 'model_best.pth.tar'))
            self.logger.info("Saving current best: {} ...".format('model_best.pth.tar'))
        else:
            self.logger.info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints
        :param resume_path: Checkpoint path to be resumed
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.monitor_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.train_logger = checkpoint['logger']
        # checkpoint config overrides the one passed to __init__
        self.config = checkpoint['config']
        self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch))
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
def loss_function(recon_x, x, mu, sig_var):
	"""VAE loss for MNIST: scaled sum-of-squares reconstruction + KL divergence.

	:param recon_x: reconstructed batch, flattened to (N, 784)
	:param x: input batch (any shape reshapeable to (N, 784))
	:param mu: latent mean
	:param sig_var: latent scale — treated as sigma here (sigma^2 appears in the KL);
	                NOTE(review): the trainer names this value `log_sigma` — confirm against vae.py.
	:return: scalar loss (summed over the batch)
	"""
	# reduction='sum' replaces the deprecated size_average=False.
	# The /0.01 factor corresponds to a fixed decoder variance of 0.01.
	reconstruction_loss = F.mse_loss(recon_x, x.view(-1, 784), reduction='sum') / 0.01
	# KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
	# Rewritten without the original chained in-place ops for clarity.
	KLD = -0.5 * torch.sum(1 + sig_var.pow(2).log() - mu.pow(2) - sig_var.pow(2))
	return KLD + reconstruction_loss
10 |
def celeba_loss(recon_x, x, mu, sig_var):
	"""VAE loss for CelebA: per-sample-averaged reconstruction + KL divergence.

	:param recon_x: reconstructed batch, same shape as x
	:param x: input batch
	:param mu: latent mean
	:param sig_var: latent scale — treated as sigma (sigma^2 in the KL);
	                NOTE(review): the trainer names this value `log_sigma` — confirm against vae.py.
	:return: scalar loss
	"""
	# reduction='sum' replaces the deprecated size_average=False; divide by the
	# batch size, then by the fixed decoder variance 0.1.
	reconstruction_loss = F.mse_loss(recon_x, x, reduction='sum') / recon_x.size(0) / 0.1
	# KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
	KLD = -0.5 * torch.sum(1 + sig_var.pow(2).log() - mu.pow(2) - sig_var.pow(2))
	return KLD + reconstruction_loss
class AC_loss:
	"""Loss helper for the actor-critic stage (work in progress).

	NOTE(review): `lambda_dist` / `lambda_attr` are stored but never used by
	the visible code, and `__real_loss` computes `distance_penalty` without
	folding it into the returned loss — presumably unfinished; confirm intent.
	"""
	def __init__(self,lambda_dist,lambda_attr):
		# weighting coefficients for the distance / attribute terms
		self.lambda_attr = lambda_attr
		self.lambda_dist = lambda_dist
	def __real_loss(self,z,z_prime,sigvar,predict_d,grth_d):
		"""
		inputs:
			z : Batch Size * Z_dim
			z_prime : Batch Size * Z_dim
			sig_var : std
		"""
		# variance-weighted L2 distance between z and z'
		sum_variance = torch.sum(sigvar.pow(2),dim=-1)
		# NOTE(review): computed but unused — the method returns only the BCE term.
		distance_penalty = torch.sum(F.mse_loss(z_prime,z,size_average=False,reduce =False)*sum_variance)
		real_loss = F.binary_cross_entropy(predict_d,grth_d)

		return real_loss

	#def __attr_loss(self,z,z_prime,sigvar,predict_d,grth_d)
37 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .sub_layer import Linear
4 | import torch.nn.functional as F
5 | import pdb
class Actor(nn.Module):
	"""Gated actor network: shifts a latent code toward the constrained region.

	Produces z' = (1 - g) * z + g * dz, where the gate g and the proposal dz
	are both emitted by an MLP over z (concatenated with an embedded label
	when `condition_mode` is True).
	"""
	def __init__(self, d_z, d_model=2048, layer_num=4, num_label=10, condition_mode=True):
		super(Actor, self).__init__()
		self.d_z = d_z
		self.d_model = d_model
		self.layer_num = layer_num
		self.condition_mode = condition_mode
		layer_list = []
		for i in range(layer_num):
			if i == 0:
				# first layer consumes z (+ the embedded label when conditioning)
				input_dim = d_z + d_model if condition_mode else d_z
				layer_list.append(Linear(input_dim, d_model))
			else:
				layer_list.append(Linear(d_model, d_model))
			layer_list.append(nn.ReLU())
		# final projection: [gate logits | delta z], split in forward()
		layer_list.append(Linear(d_model, self.d_z * 2))

		self.fw_layer = nn.Sequential(*layer_list)
		self.gate = nn.Sigmoid()
		if condition_mode:
			self.num_label = num_label
			self.condition_layer = Linear(num_label, d_model)

	def forward(self, x, label=None):
		"""Return the gated update of latent `x`; `label` is required only
		when `condition_mode` is True (backward-compatible default None)."""
		original_x = x
		# BUG FIX: the original always applied self.condition_layer, which is
		# not created when condition_mode=False -> AttributeError. Guard like
		# Critic.forward does.
		if self.condition_mode:
			x = torch.cat((x, self.condition_layer(label)), dim=-1)
		out = self.fw_layer(x)
		input_gate, dz = out.chunk(2, dim=-1)
		gate_value = self.gate(input_gate)
		new_z = (1 - gate_value) * original_x + gate_value * dz
		return new_z
41 |
class Critic(nn.Module):
	"""Critic network D(z[, y]): MLP over a latent code (optionally concatenated
	with an embedded label), squashed to (0, 1) per output unit."""
	def __init__(self, d_z, d_model, layer_num=4, num_labels=None, num_output=1, condition_mode=True):
		super(Critic, self).__init__()
		self.d_z = d_z
		self.d_model = d_model
		self.condition_mode = condition_mode
		layer_list = []
		for i in range(layer_num):
			if i == 0:
				# first layer consumes z (+ the embedded label when conditioning)
				input_dim = d_z + d_model if condition_mode else d_z
				layer_list.append(Linear(input_dim, d_model))
			else:
				layer_list.append(Linear(d_model, d_model))
			layer_list.append(nn.ReLU())
		layer_list.append(Linear(d_model, num_output))
		self.fw_layer = nn.Sequential(*layer_list)

		if condition_mode:
			self.num_labels = num_labels
			self.condition_layer = Linear(num_labels, d_model)

	def forward(self, x, label=None):
		"""Return per-unit probabilities; `label` is required only when
		`condition_mode` is True."""
		if self.condition_mode:
			x = torch.cat((x, self.condition_layer(label)), dim=-1)
		out = self.fw_layer(x)
		# torch.sigmoid replaces the deprecated F.sigmoid
		return torch.sigmoid(out)
72 |
--------------------------------------------------------------------------------
/model/sub_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
class Linear(nn.Module):
	"""A fully-connected layer followed by 1-d batch normalisation."""

	def __init__(self, input_dim, output_dim, bias=True):
		super(Linear, self).__init__()
		# keep attribute names stable so state_dict keys do not change
		self.linear = nn.Linear(input_dim, output_dim, bias=bias)
		self.batch_norm = nn.BatchNorm1d(output_dim)

	def forward(self, x):
		"""Apply the affine map, then normalise over the batch dimension."""
		return self.batch_norm(self.linear(x))
class View(nn.Module):
	"""Reshape a batch inside an nn.Sequential.

	With shape=None the input is flattened to (N, -1); otherwise it is
	viewed as (N, *shape).
	"""

	def __init__(self, shape=None):
		super(View, self).__init__()
		self.shape = shape

	def forward(self, x):
		"""Return `x` reshaped, always preserving the batch dimension."""
		target = (-1,) if self.shape is None else tuple(self.shape)
		return x.view(x.size(0), *target)
22 |
--------------------------------------------------------------------------------
/model/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pdb
4 | from torchvision.utils import save_image
5 | import torch.optim as optim
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 | import torch.nn.functional as F
9 | from tqdm import tqdm
class Trainer:
	"""VAE training loop: optimisation, periodic validation, sampling and checkpoints.

	NOTE(review): the original comment says this should be refactored onto
	BaseTrainer; kept standalone to avoid changing the constructor contract.
	"""
	def __init__(self,model,trainDataLoader,loss,epoch,data = 'celeba' ,metrics=None,resume=None,config=None,validDataLoader = None,device = 0, testDataLoader = None,train_logger =None,optimizer_type='Adam',lr=1e-3):
		"""
		:param model: VAE exposing forward(data) -> (recon, z, mu, log_sigma),
		              plus `decoder`, `sigmoid` and `d_model` attributes.
		:param loss: callable(recon, data, mu, log_sigma) -> scalar tensor
		:param data: dataset tag ('MNIST' or 'celeba'); also names output dirs
		:param device: torch device index, or 'cpu'
		"""
		self.model = model
		self.trainDataLoader = trainDataLoader
		self.testDataLoader = testDataLoader
		self.validDataLoader = validDataLoader
		self.valid = self.validDataLoader is not None
		self.test = self.testDataLoader is not None
		self.device = device
		self.model.to(self.device)
		self.d_model = self.model.d_model  # latent dimensionality, used for sampling
		self.train_loss = 0

		self.data = data
		self.tensorboad_writer = SummaryWriter()
		self.epoch = epoch
		self.loss = loss
		self.start_epoch = 1
		self.with_cuda = torch.cuda.is_available()
		self.save_freq = 500
		self.total_iteration = 0
		# optimizer class looked up by name on torch.optim
		self.optimizer = getattr(optim, optimizer_type)(self.model.parameters(), lr=lr)
		self.valid_term = 10  # validate/checkpoint every N epochs

	def train(self):
		"""Run the full schedule; validate and checkpoint every `valid_term` epochs."""
		for epoch in range(self.start_epoch, self.epoch + 1):
			self._train_epoch(epoch)
			self.get_sample(epoch)
			if epoch % self.valid_term == 0:
				self._test(epoch)
				self.save_model(epoch)
		print("[+] Finished Training Model")

	def _train_epoch(self, epoch):
		"""One optimisation pass over the training set; logs the summed loss."""
		self.model.train()
		train_loss = 0
		for batch_idx, (data, labels) in enumerate(self.trainDataLoader):
			data = data.to(self.device)
			self.optimizer.zero_grad()
			recon_batch, z, mu, log_sigma = self.model(data)
			loss = self.loss(recon_batch, data, mu, log_sigma)
			loss.backward()
			self.optimizer.step()
			train_loss += loss.item()
			if batch_idx == 1:
				# dump one reconstruction / ground-truth pair per epoch
				if self.data == 'MNIST':
					recon_batch = recon_batch.view(-1, 1, 28, 28)
				save_image(recon_batch.cpu(), self.data + '_results/sample_train_' + str(epoch) + '.png')
				save_image(data.cpu(), self.data + '_results/grtruth_train_' + str(epoch) + '.png')
		self._summary_wrtie(train_loss, epoch)
		print("[+] Epoch:[{}/{}] train average loss :{}".format(epoch, self.epoch, train_loss))

	def _test(self, epoch):
		"""Evaluate on the held-out loader without gradients and dump sample images."""
		self.model.eval()
		test_loss = 0
		with torch.no_grad():
			for i, (data, labels) in enumerate(self.testDataLoader):
				# BUG FIX: was data.cuda(), which crashed CPU-only runs and
				# ignored self.device; now consistent with _train_epoch.
				data = data.to(self.device)
				recon_batch, z, mu, log_sigma = self.model(data)

				loss = self.loss(recon_batch, data, mu, log_sigma)
				test_loss += loss.item()
				if i == 1:
					if self.data == 'MNIST':
						recon_batch = recon_batch.view(-1, 1, 28, 28)
					save_image(recon_batch.cpu(), self.data + '_results/sample_valid_' + str(epoch) + '.png')
					save_image(data.cpu(), self.data + '_results/grtruth_valid_' + str(epoch) + '.png')

		print("[+] Validation result {}".format(test_loss))

	def get_sample(self, epoch):
		"""Decode 64 random latents and save them as an image grid."""
		self.model.eval()
		with torch.no_grad():
			sample = torch.randn(64, self.d_model).to(self.device)
			# NOTE(review): uses model.decoder/model.sigmoid attributes directly —
			# confirm this matches the decode() API used elsewhere (vae.py not in view).
			out = self.model.decoder(sample)
			out = self.model.sigmoid(out)
			if self.data == 'MNIST':
				save_image(out.view(-1, 1, 28, 28), self.data + '_results/sample_' + str(epoch) + '.png')
			else:
				save_image(out.cpu(), self.data + '_results/sample_' + str(epoch) + '.png')

	def save_model(self, epoch):
		"""Save only the model weights (state_dict) under ./save_model/."""
		torch.save(self.model.state_dict(), './save_model/vae_model' + str(epoch) + '_' + self.data + '.path.tar')

	def _summary_wrtie(self, loss, epoch):
		"""Write loss plus per-parameter weight/grad histograms to TensorBoard.
		(Name kept as-is — historical typo — to preserve the public interface.)"""
		self.tensorboad_writer.add_scalar('data/loss', loss, epoch)
		for name, param in self.model.named_parameters():
			self.tensorboad_writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch, bins='sturges')
			self.tensorboad_writer.add_histogram(name + '/grad', param.grad.clone().cpu().data.numpy(), epoch, bins='sturges')

	def _eval_metric(self, output, target):
		raise NotImplementedError
104 |
105 | class AC_Trainer:
	def __init__(self,vae_model,actor,real_critic,attr_critic,epoch,trainDataLoader,data,metrics=None,resume=None,config=None,validDataLoader = None,device = 0, testDataLoader = None,train_logger =None,optimizer_type='Adam',lr=1e-4):
		"""Actor-critic trainer over a frozen, pre-trained VAE.

		:param vae_model: trained VAE (kept in eval mode; weights are not updated here)
		:param actor: Actor network producing shifted latents
		:param real_critic: realism critic D(z, y)
		:param attr_critic: attribute critic (has its own optimizer below)
		:param data: dataset tag ('MNIST' or 'celeba')
		"""
		self.model = vae_model
		self.actor = actor
		self.real_critic = real_critic
		self.attr_critic = attr_critic
		#self.loss = loss # lossfunction class
		self.epoch = epoch

		self.trainDataLoader = trainDataLoader
		self.testDataLoader = testDataLoader
		self.validDataLoader = validDataLoader
		self.valid = True if self.validDataLoader is not None else False
		self.test = True if self.testDataLoader is not None else False
		self.device = device

		# VAE stays frozen in eval mode; only actor/critics train.
		self.model.to(self.device)
		self.model.eval()
		self.actor.to(self.device)
		self.real_critic.to(self.device)
		self.attr_critic.to(self.device)

		self.d_model = self.model.d_model
		self.train_loss = 0

		self.iteration = 0
		self.data = data
		self.tensorboad_writer = SummaryWriter()
		self.epoch = epoch
		#self.loss = loss

		self.start_epoch = 1
		self.with_cuda = torch.cuda.is_available()
		self.save_freq = 500
		self.total_iteration = 0
		#self.gen_optimizer = getattr(optim, optimizer_type)(list(self.actor.parameters()) + list(self.critic.parameters()),lr=lr)
		# NOTE(review): actor/real optimizers hard-code 'Adam' and ignore
		# `optimizer_type`; all three scale the base lr by 3 — confirm intent.
		self.actor_optim = getattr(optim, 'Adam')(self.actor.parameters(),lr=lr*3)
		self.real_optim = getattr(optim, 'Adam')(self.real_critic.parameters(),lr=lr*3)
		self.attr_optim = getattr(optim, optimizer_type)(self.attr_critic.parameters(),lr=lr*3)
		self.valid_term = 10
		self._set_label_type()
146 |
	def train(self):
		"""Actor-critic training schedule with step-wise LR decay near epochs 30 and 50."""
		for epoch in tqdm(range(self.start_epoch,self.epoch+1)):
			result = self._train_epoch(epoch)
			self.get_sample(epoch)
			if epoch%self.valid_term == 0:
				#self._test(epoch)
				# NOTE(review): get_sample was already called just above, so on
				# checkpoint epochs samples are generated twice — confirm intent.
				self.get_sample(epoch)
				self.save_model(epoch)
			# decay actor/real-critic lr by 10x at epochs 29 and 49 (epoch+1 check)
			if (epoch+1) == 30:
				self.actor_optim.param_groups[0]['lr'] /= 10
				self.real_optim.param_groups[0]['lr'] /= 10
				print("learning rate change!")

			if (epoch+1) == 50:
				self.actor_optim.param_groups[0]['lr'] /= 10
				self.real_optim.param_groups[0]['lr'] /= 10
				print("learning rate change!")
		print ("[+] Finished Training Model")
165 |
166 |
	def _train_epoch(self,epoch):
		"""One epoch of adversarial training in latent space.

		Every batch updates the realism critic on (real z, prior/actor z, z with
		shuffled attributes); every 10th batch updates the actor with a BCE term
		plus a variance-weighted distance penalty. The VAE itself stays frozen.

		NOTE(review): if the loader yields fewer than 10 batches,
		`actor_iteration` stays 0 and the final prints divide by zero.
		"""
		# Run Data Loader
		# realism Constraints pz~ is 0 E(q|x) ~ x
		# attr Constraints
		self.model.eval()
		self.actor.train()
		self.real_critic.train()
		self.attr_critic.train()
		train_loss = 0  # NOTE(review): never updated; final print always shows 0
		iteration = 0
		total_actor_loss = 0
		total_real_loss = 0
		total_distance_penalty = 0
		actor_iteration = 0
		for batch_idx,(data,labels) in enumerate(self.trainDataLoader):
			self.iteration += 1
			iteration += 1
			m_batchsize = data.size(0)

			# critic targets: 1 for encoder posteriors, 0 for prior/mismatched-attr samples
			real_data = torch.ones(m_batchsize,1)
			fake_data = torch.zeros(m_batchsize,1)
			fake_z = torch.randn(m_batchsize,self.d_model)
			fake_attr = self.fake_attr_generate(m_batchsize)

			real_data = real_data.to(self.device)
			fake_data = fake_data.to(self.device)
			labels = labels.to(self.device)
			fake_z = fake_z.to(self.device)
			fake_attr = fake_attr.to(self.device)
			data = data.to(self.device)
			# encode without tracking gradients — the VAE is frozen
			with torch.no_grad():
				sig_var,mu,z = self.model.encode(data)

				z = self.reparameterize(mu,z)


			fake_z.requires_grad =True
			z.requires_grad = True
			labels.requires_grad =True
			fake_attr.requires_grad =True

			# ---- realism-critic update ----
			self.real_critic.zero_grad()
			if np.random.rand(1) < 0.1:
				# 10% of the time: contrast real z against raw prior samples
				#input_data = torch.cat([z,fake_z,z],dim=0)
				input_data = torch.cat([z,fake_z,z],dim=0)


				input_attr = torch.cat([labels,labels,fake_attr],dim=0)
				#input_attr = torch.cat([labels,labels],dim=0)
				real_labels = torch.cat([real_data,fake_data,fake_data])

				#real_labels = torch.cat([real_data,fake_data])

				logit_out = self.real_critic(input_data,input_attr)
				critic_loss = F.binary_cross_entropy(logit_out,real_labels)
			else:
				# otherwise: contrast real z against actor-transformed prior samples
				#pdb.set_trace()
				fake_z.requires_grad =True
				z_g = self.actor(fake_z,labels)

				input_data = torch.cat([z,z_g,z],dim=0)
				#input_data = torch.cat([z,z_g],dim=0)

				input_attr = torch.cat([labels,labels,fake_attr],dim=0)
				#input_attr = torch.cat([labels,labels],dim=0)


				real_labels = torch.cat([real_data,fake_data,fake_data])
				#real_labels = torch.cat([real_data,fake_data])

				logit_out = self.real_critic(input_data,input_attr)

				critic_loss = F.binary_cross_entropy(logit_out,real_labels)

			#print ("critic_loss : {}".format(critic_loss))
			critic_loss.backward()
			self.real_optim.step()
			if (batch_idx+1)%1000 == 0:
				self.d_critic_histogram(self.iteration)
			total_real_loss += critic_loss.item()

			# ---- actor update (every 10th batch) ----
			if (batch_idx+1)%10 ==0:
				self.actor.zero_grad()
				#actor_labels = self.re_allocate(labels)
				fake_z = torch.randn(m_batchsize,self.d_model)
				fake_z = fake_z.to(self.device)
				fake_z = self.re_allocate(fake_z)
				#actor_truth = self.re_allocate(real_data)

				actor_labels = labels
				fake_z = fake_z
				actor_truth = real_data  # actor wants the critic to output "real"

				actor_labels.requires_grad =True
				fake_z.requires_grad = True

				actor_g = self.actor(fake_z,actor_labels)
				real_g = self.actor(z,actor_labels)
				zg_critic_out = self.real_critic(actor_g,actor_labels)
				zg_critic_real = self.real_critic(real_g,actor_labels)
				#fake_output = self.attr_critic(z_g)  # unclear whether the prior term is needed here
				# per-dimension posterior variance averaged over the batch,
				# used to weight the distance penalty (paper's Eq. for L_dist)
				weight_var = torch.mean(sig_var,0,True)

				# (variable name `distnace_penalty` is a historical typo, kept as-is)
				distnace_penalty = torch.mean(torch.sum((1 + (actor_g-fake_z).pow(2)).log()*weight_var.pow(-2),1),0)
				distnace_penalty = distnace_penalty + torch.mean(torch.sum((1 + (real_g-z).pow(2)).log()*weight_var.pow(-2),1),0)
				#distnace_penalty = 0

				actor_loss = F.binary_cross_entropy(zg_critic_out,actor_truth,size_average=False)+ F.binary_cross_entropy(zg_critic_real,actor_truth,size_average=False)+distnace_penalty

				#actor_loss = actor_loss + distnace_penalty
				actor_loss.backward()
				total_actor_loss += actor_loss.item()
				self.actor_optim.step()
				total_distance_penalty += distnace_penalty.item()
				actor_iteration +=1
				if (actor_iteration%100) == 0 :
					self.d_actor_histogram(self.iteration)

			# ---- once per epoch: dump decoded sample grids ----
			if batch_idx == 1:
				fake_z = torch.randn(m_batchsize,self.d_model)
				fake_z = fake_z.to(self.device)
				z_g = self.actor(fake_z,labels)
				z_g_recon = self.model.decode(z_g)
				prior_recon = self.model.decode(fake_z)
				data_recon = self.model.decode(z)
				if self.data == 'MNIST':
					data = data.view(-1,1,28,28)
					z_g_recon = z_g_recon.view(-1,1,28,28)
					data_recon = data_recon.view(-1,1,28,28)
					prior_recon = prior_recon.view(-1,1,28,28)
				save_image(z_g_recon.cpu(),self.data+'_results_ac/sample_z_g_train_' + str(epoch) +'.png')
				save_image(prior_recon.cpu(),self.data+'_results_ac/sample_prior_train_' + str(epoch) +'.png')
				save_image(data_recon.cpu(),self.data+'_results_ac/sample_recon_train_' + str(epoch) +'.png')
				save_image(data.cpu(),self.data+'_results_ac/grtruth_train_' + str(epoch) +'.png')
				print ("distance penalty : {} , {} | critic_loss :{}".format(total_distance_penalty/actor_iteration,total_actor_loss/actor_iteration,total_real_loss/iteration))


		self._summary_wrtie(total_distance_penalty/actor_iteration,total_actor_loss/actor_iteration,total_real_loss/iteration,epoch)
		print ("[+] Epoch:[{}/{}] train actor average loss :{}".format(epoch,self.epoch,train_loss))
312 | def re_allocate(self,data):
313 | new_data = data.detach()
314 | new_data.requiers_grad = True
315 | return new_data
316 | def get_sample(self,epoch,data =None,labels =None):
317 | self.model.eval()
318 | self.actor.eval()
319 | self.real_critic.eval()
320 | self.attr_critic.eval()
321 | with torch.no_grad():
322 | for i in range(self.num_labels):
323 | test_labels = self.labels[i]
324 | test_labels = test_labels.expand(64,-1)
325 |
326 | sample = torch.randn(64,self.d_model).to(self.device)
327 | sample = self.actor(sample,test_labels)
328 |
329 | out = self.model.decoder(sample)
330 | out = self.model.sigmoid(out) #?? Amiguity Labels input?
331 | if self.data == 'MNIST':
332 | save_image(out.view(-1,1,28,28),self.data+'_results_ac/sample_' + str(epoch)+'_class:'+str(i) +'.png')
333 | else:
334 | save_image(out.cpu(),self.data+'_results_ac/sample_' + str(epoch)+'_class:'+str(i) +'.png')
335 | def _test(self,epoch):
336 | self.model.eval()
337 | self.actor.eval()
338 | self.real_critic.eval()
339 | self.attr_critic.eval()
340 | test_loss = 0
341 | with torch.no_grad():
342 | for i, (data,lebels) in enumerate(self.testDataLoader):
343 | data = data.cuda()
344 | recon_batch,z,mu,log_sigma = self.model(data)
345 |
346 | loss = self.loss(recon_batch,data,mu,log_sigma)
347 | test_loss += loss.item()
348 | if i == 1:
349 | if self.data == 'MNIST':
350 | recon_batch = recon_batch.view(-1,1,28,28)
351 | save_image(recon_batch.cpu(),self.data+'_results/sample_valid_' + str(epoch) +'.png')
352 | save_image(data.cpu(),self.data+'_results/grtruth_valid_' + str(epoch) +'.png')
353 |
354 | def _set_label_type(self):
355 |
356 | self.labels = torch.eye(10)
357 | self.labels = self.labels.to(self.device)
358 | self.num_labels = self.labels.size(0)
359 |
360 | def fake_attr_generate(self,batch_size,selection_index = None ):
361 | if self.data == 'MNIST':
362 | if selection_index is None:
363 | m = batch_size
364 | selection = np.random.randint(self.num_labels,size=m)
365 | selection = torch.from_numpy(selection).to(self.device)
366 | fake_attr = torch.index_select(self.labels,0,selection)
367 | else:
368 | selection_index = selection_index.cuda()
369 | fake_attr = torch.index_select(self.labels,0,selection_index)
370 | else:
371 | selection_start = np.random.randint(self.trainDataLoader.dataset.num_images-batch_size-1)
372 | selection_data = self.trainDataLoader.dataset.train_dataset[selection_start:selection_start+batch_size]
373 | fake_attr = []
374 | for (name,labels) in selection_data:
375 | fake_attr.append(labels)
376 | fake_attr = torch.FloatTensor(fake_attr)
377 |
378 | return fake_attr
379 |
380 |
381 | def reparameterize(self,mu,sig_var):
382 |
383 | std = sig_var # need to check sig_var is log (sigma^2)
384 | eps = std.data.new(std.size()).normal_(std=1)
385 | return eps.mul(std).add_(mu)
386 | def _summary_wrtie(self,distance_penalty,loss,real_loss,epoch):
387 | self.tensorboad_writer.add_scalar('data/loss',loss,epoch) # need to modify . We use four loss value .
388 | self.tensorboad_writer.add_scalar('data/distance_penalty',distance_penalty,epoch) # need to modify . We use four loss value .
389 | self.tensorboad_writer.add_scalar('data/discriminator_loss',real_loss,epoch)
390 | #for name,param in self.actor.named_parameters(): #actor
391 | # self.tensorboad_writer.add_histogram('actor/'+name, param.clone().cpu().data.numpy(), epoch,bins='sturges')
392 | # self.tensorboad_writer.add_histogram('actor/'+name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges')
393 | #for name,param in self.real_critic.named_parameters(): #actor
394 | # self.tensorboad_writer.add_histogram('real_critic/'+name, param.clone().cpu().data.numpy(), epoch,bins='sturges')
395 | # self.tensorboad_writer.add_histogram('real_critic/'+name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges')
396 | #for name,param in self.attr_critic.named_parameters(): #actor
397 | # self.tensorboad_writer.add_histogram('attr_critic/'+name, param.clone().cpu().data.numpy(), epoch,bins='sturges')
398 | # self.tensorboad_writer.add_histogram('attr_critic/'+name+'/grad', param.grad.clone().cpu().data.numpy(), epoch,bins='sturges')
399 | def d_critic_histogram(self,iteration):
400 | for name,param in self.real_critic.named_parameters(): #actor
401 | self.tensorboad_writer.add_histogram('real_critic/'+name, param.clone().cpu().data.numpy(), iteration,bins='sturges')
402 | self.tensorboad_writer.add_histogram('real_critic/'+name+'/grad', param.grad.clone().cpu().data.numpy(), iteration,bins='sturges')
403 | def d_actor_histogram(self,iteration):
404 | for name,param in self.actor.named_parameters(): #actor
405 | self.tensorboad_writer.add_histogram('actor/'+name, param.clone().cpu().data.numpy(), iteration,bins='sturges')
406 | self.tensorboad_writer.add_histogram('actor/'+name+'/grad', param.grad.clone().cpu().data.numpy(), iteration,bins='sturges')
407 | def save_model(self,epoch):
408 | torch.save(self.actor.state_dict(), './save_model/actor_model'+str(epoch)+'_'+self.data+'.path.tar')
409 | torch.save(self.real_critic.state_dict(), './save_model/real_d_model'+str(epoch)+'_'+self.data+'.path.tar')
410 | torch.save(self.attr_critic.state_dict(), './save_model/attr_d_model'+str(epoch)+'_'+self.data+'.path.tar')
411 | def load_vae(self,path):
412 |
413 | print ("[+] Load pre-trained VAE model")
414 | checkpoint = torch.load(path)
415 | self.model.load_state_dict(checkpoint)
416 |
--------------------------------------------------------------------------------
/model/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from .sub_layer import Linear,View
5 | import pdb
6 |
class Celeba_VAE(nn.Module):
    """Convolutional VAE for CelebA face images.

    The encoder ends in a Linear layer producing 2*d_model features that are
    split into (sig_var, mu_var); sig_var passes through Softplus, so it is
    used directly as a standard deviation (not as log-variance).

    NOTE(review): the encoder's final Linear assumes a 4x4 spatial map after
    the four stride-2 convolutions, which corresponds to 64x64 inputs (the
    size produced by the preprocessing script) — confirm against the data
    pipeline before changing `image_size`.
    """

    def __init__(self, input_size=128, d_model=1024, layer_num=3):
        super(Celeba_VAE, self).__init__()
        self.d_model = d_model
        self.layer_num = layer_num

        self.encoder = self.build_encoder(self.d_model, self.layer_num)
        self.sig_layer = nn.Softplus()

        self.decoder = self.build_decoder(self.d_model, self.layer_num)
        self.sigmoid = nn.Sigmoid()

    def build_encoder(self, d_model, layer_num):
        """Stack four stride-2 convolutions, flatten, and project to 2*d_model."""
        layers = []
        layers.append(nn.Conv2d(in_channels=3, out_channels=256, kernel_size=5, stride=2, padding=1))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=1))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=3, stride=2, padding=1))
        layers.append(nn.ReLU())
        layers.append(View())
        layers.append(nn.Linear(4 * 4 * 2048, 2048))
        return nn.Sequential(*layers)

    def build_decoder(self, d_model, layer_num):
        """Mirror of the encoder: linear up-projection then four transposed convs.

        (A stray no-op `decoder_layerList` expression was removed here.)
        """
        layers = []
        layers.append(nn.Linear(d_model, 2048 * 4 * 4))
        layers.append(View([2048, 4, 4]))
        layers.append(nn.ConvTranspose2d(2048, 1024, 3, stride=2, padding=1, output_padding=1))
        layers.append(nn.ReLU())
        layers.append(nn.ConvTranspose2d(1024, 512, 3, stride=2, padding=1, output_padding=0))
        layers.append(nn.ReLU())
        layers.append(nn.ConvTranspose2d(512, 256, 5, stride=2, padding=1, output_padding=0))
        layers.append(nn.ReLU())
        layers.append(nn.ConvTranspose2d(256, 3, 5, stride=2, padding=1, output_padding=1))
        return nn.Sequential(*layers)

    def reparameterize(self, mu, sig_var):
        """Sample z = mu + sig_var * eps during training; return mu at eval time."""
        if self.training:
            std = sig_var  # sig_var is already a std (Softplus output), not log-variance
            eps = std.data.new(std.size()).normal_(std=1)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def encode(self, x):
        """Return (sig_var, mu_var, z) for a batch of images."""
        encoder_out = self.encoder(x)
        sig_var, mu_var = encoder_out.chunk(2, dim=-1)
        sig_var = self.sig_layer(sig_var)
        z = self.reparameterize(mu_var, sig_var)
        return sig_var, mu_var, z

    def decode(self, z):
        """Map a latent batch back to image space (sigmoid-squashed)."""
        output = self.decoder(z)
        output = self.sigmoid(output)
        return output

    def forward(self, x):
        """Full VAE pass: returns (reconstruction, z, mu_var, sig_var)."""
        # CONSISTENCY: reuse encode()/decode() instead of duplicating their
        # bodies inline (behavior is identical).
        sig_var, mu_var, z = self.encode(x)
        output = self.decode(z)
        return output, z, mu_var, sig_var
89 |
90 |
91 |
class Mnist_VAE(nn.Module):
    """Fully-connected VAE for MNIST.

    Images are flattened to `input_dim` features; the encoder emits 2*d_model
    features split into (sig_var, mu_var), with sig_var squashed by Softplus
    and used directly as a standard deviation (not as log-variance).
    """

    def __init__(self, input_dim=28*28, d_model=1024, layer_num=3):
        super(Mnist_VAE, self).__init__()
        self.d_model = d_model
        self.layer_num = layer_num
        self.input_dim = input_dim

        self.encoder = self.build_encoder(self.input_dim, self.d_model, self.layer_num)
        self.sig_layer = nn.Softplus()

        self.decoder = self.build_decoder(self.input_dim, self.d_model, self.layer_num)

        self.sigmoid = nn.Sigmoid()

    def build_encoder(self, input_dim, d_model, layer_num):
        """`layer_num` Linear+ReLU blocks followed by a projection to 2*d_model."""
        layers = []
        for i in range(layer_num):
            layers.append(nn.Linear(input_dim if i == 0 else d_model, d_model))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(d_model, 2 * d_model))
        return nn.Sequential(*layers)

    def build_decoder(self, input_dim, d_model, layer_num):
        """`layer_num` Linear+ReLU blocks followed by a projection back to input_dim."""
        layers = []
        for _ in range(layer_num):
            layers.append(nn.Linear(d_model, d_model))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(d_model, input_dim))
        return nn.Sequential(*layers)

    def reparameterize(self, mu, sig_var):
        """Sample z = mu + sig_var * eps during training; return mu at eval time."""
        if self.training:
            std = sig_var  # sig_var is already a std (Softplus output), not log-variance
            eps = std.data.new(std.size()).normal_(std=1)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def encode(self, x):
        """Return (sig_var, mu_var, z) for a batch of images."""
        # FIX: flatten to the configured input_dim instead of hardcoded 28*28,
        # so non-default input sizes actually work (default is unchanged).
        x = x.view(-1, self.input_dim)
        encoder_out = self.encoder(x)
        sig_var, mu_var = encoder_out.chunk(2, dim=-1)
        sig_var = self.sig_layer(sig_var)
        z = self.reparameterize(mu_var, sig_var)
        return sig_var, mu_var, z

    def decode(self, z):
        """Map a latent batch back to flattened image space (sigmoid-squashed)."""
        output = self.decoder(z)
        output = self.sigmoid(output)
        return output

    def forward(self, x):
        """Full VAE pass: returns (reconstruction, z, mu_var, sig_var)."""
        # CONSISTENCY: reuse encode()/decode() instead of duplicating them.
        sig_var, mu_var, z = self.encode(x)
        output = self.decode(z)
        return output, z, mu_var, sig_var
159 |
160 |
161 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from utils.get_data import celebA_data_preprocess
2 |
3 |
4 |
if __name__ == '__main__':
    # Smoke-test the CelebA resize pipeline with its default paths.
    print ("[+] Testing celebA_data_preprocess function")
    celebA_data_preprocess()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cbokpark/Pytorch-Latent-Constraints-Learning-to-Generate-Conditionally-from-Unconditional-Generative-Models/0dbd182b294e0c6d3ad0deda3be1dd855fd57617/utils/__init__.py
--------------------------------------------------------------------------------
/utils/get_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from scipy.misc import imresize
5 |
6 | from PIL import Image
7 | import random
8 |
9 | import torchvision.utils as utils
10 | from torch.utils import data
11 | import torchvision.datasets as vision_dsets
12 | import torchvision.transforms as T
13 |
14 |
15 | from tqdm import tqdm
16 | import pdb
def celebA_data_preprocess(root ='/hdd1/cheonbok_experiment/celevA/data/CelebA_nocrop/images/' ,save_root='/hdd1/cheonbok_experiment/celevA/data/CelebA_resize/',resize=64):
	"""
	Preprocessing the celevA data set (resizing).

	Resizes every image found in `root` to `resize` x `resize` and saves it
	under `save_root`/celebA/ with the original filename.  Both path
	arguments are expected to end with '/'.
	"""
	if not os.path.isdir(save_root):
		os.mkdir(save_root)
	if not os.path.isdir(save_root + 'celebA'):
		os.mkdir(save_root + 'celebA')
	img_list = os.listdir(root)

	for i in tqdm(range(len(img_list)), desc='CelebA Preprocessing'):
		# FIX: scipy.misc.imresize was removed in SciPy 1.3; use PIL (already
		# imported in this module) with bilinear interpolation, which matches
		# imresize's default `interp='bilinear'`.
		img = Image.open(root + img_list[i]).resize((resize, resize), Image.BILINEAR)
		img.save(save_root + 'celebA/' + img_list[i])
	print ("[+] Finished the CelebA Data set Preprocessing")
35 |
def MNIST_DATA(root='./data',train =True,transforms=None ,download =True,batch_size = 32,num_worker = 2):
	"""Download (if needed) MNIST and build train/test loaders.

	Note: `train` is accepted for interface compatibility but both splits
	are always returned.  The custom `transforms` is applied to the train
	split only; the test split always uses plain ToTensor (original
	behavior, kept as-is).

	Returns (train_set, test_set, train_loader, test_loader).
	"""
	if transforms is None:
		transforms = T.ToTensor()
	print ("[+] Get the MNIST DATA")
	mnist_train = vision_dsets.MNIST(root = root,
									train = True,
									transform = transforms,
									download = download)  # FIX: was hardcoded True, ignoring the parameter
	mnist_test = vision_dsets.MNIST(root = root,
									train = False,
									transform = T.ToTensor(),
									download = download)
	trainDataLoader = data.DataLoader(dataset = mnist_train,
									batch_size = batch_size,
									shuffle = True,
									num_workers = num_worker)  # FIX: was hardcoded 2, ignoring the parameter
	testDataLoader = data.DataLoader(dataset = mnist_test,
									batch_size = batch_size,
									shuffle = False,
									num_workers = num_worker)
	print ("[+] Finished loading data & Preprocessing")
	return mnist_train,mnist_test,trainDataLoader,testDataLoader
59 |
60 |
class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset.

    Parses the attribute annotation file once at construction time into
    (filename, label) pairs with a deterministic train/test split; the
    actual images are loaded lazily in __getitem__.
    """

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode, un_condition_mode=True):
        """Initialize and preprocess the CelebA dataset."""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.un_condition_mode = un_condition_mode
        self.preprocess()

        active = self.train_dataset if mode == 'train' else self.test_dataset
        self.num_images = len(active)

    def preprocess(self):
        """Preprocess the CelebA attribute file."""
        with open(self.attr_path, 'r') as f:
            lines = [line.rstrip() for line in f]
        for idx, attr_name in enumerate(lines[1].split()):
            self.attr2idx[attr_name] = idx
            self.idx2attr[idx] = attr_name

        entries = lines[2:]
        random.seed(1234)  # fixed seed -> reproducible split
        random.shuffle(entries)
        for i, entry in enumerate(entries):
            parts = entry.split()
            filename, values = parts[0], parts[1:]

            label = [values[self.attr2idx[name]] == '1' for name in self.selected_attrs]

            # Keep only positively-labelled samples unless in unconditional mode.
            if self.un_condition_mode or sum(label) >= 1:
                target = self.test_dataset if (i + 1) < 2000 else self.train_dataset
                target.append([filename, label])
        print('[+]Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))

        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images
125 |
126 |
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1,un_condition_mode = True):
    """Build and return a data loader."""
    steps = []
    if mode == 'train':
        steps.append(T.RandomHorizontalFlip())
    steps.append(T.CenterCrop(crop_size))
    steps.append(T.Resize(image_size))
    steps.append(T.ToTensor())
    transform = T.Compose(steps)

    if dataset == 'CelebA':
        dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode, un_condition_mode)
    elif dataset == 'RaFD':
        # NOTE(review): `ImageFolder` is not defined or imported in this
        # module, so this branch raises NameError if ever taken — confirm.
        dataset = ImageFolder(image_dir, transform, un_condition_mode)

    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=(mode == 'train'),
                                  num_workers=num_workers)
    return data_loader, dataset
149 |
def Celeba_DATA(celeba_img_dir ,attr_path,image_size=128,celeba_crop_size=178,selected_attrs=None,batch_size = 32,num_worker = 1,un_condition_mode =True):
    """Convenience wrapper building CelebA train and test loaders.

    Returns (train_dataset, test_dataset, train_loader, test_loader).
    """
    if selected_attrs is None:
        # Default to the full CelebA annotation set used by this project.
        selected_attrs = ['5_o_Clock_Shadow','Arched_Eyebrows','Attractive','Bags_Under_Eyes', 'Bangs',
        'Big_Lips','Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby',
        'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male',
        'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose',
        'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']

    trainDataLoader, trainData = get_loader(celeba_img_dir, attr_path, selected_attrs,
                                            celeba_crop_size, image_size, batch_size,
                                            'CelebA', 'train', num_worker, un_condition_mode)
    testDataLoader, testData = get_loader(celeba_img_dir, attr_path, selected_attrs,
                                          celeba_crop_size, image_size, batch_size,
                                          'CelebA', 'test', num_worker, un_condition_mode)

    return trainData, testData, trainDataLoader, testDataLoader
--------------------------------------------------------------------------------