├── README.md
├── exec.py
├── model.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# introvae.pytorch

**WORK IN PROGRESS!**

A PyTorch implementation of *IntroVAE: Introspective Variational Autoencoders for Photographic Image Synthesis*.

## requirements

* PyTorch>=0.4.1
* [homura](https://github.com/moskomule/homura)

--------------------------------------------------------------------------------
/exec.py:
--------------------------------------------------------------------------------
def main():
    pass


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F


class MiniEncoder(nn.Module):
    """MNIST-sized encoder: maps a flattened 28x28 image to the mean and
    log-variance of a 20-dimensional Gaussian posterior."""

    def __init__(self):
        super(MiniEncoder, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)

    def forward(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)


class MiniDecoder(nn.Module):
    """MNIST-sized decoder: maps a 20-dimensional latent code back to a
    flattened image in [0, 1]."""

    def __init__(self):
        super(MiniDecoder, self).__init__()

        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def forward(self, z):
        h3 = F.relu(self.fc3(z))
        # F.sigmoid is deprecated; torch.sigmoid is available from 0.4.1 on
        return torch.sigmoid(self.fc4(h3))


MODELS = {"mini": (MiniEncoder, MiniDecoder)}

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
from homura.utils import Trainer as TrainerBase
from torch import nn
from torch.nn import functional as F


def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)


def l_reg(mu, log_var):
    # KL(q(z|x) || N(0, I)), summed over the latent dimension (one value per sample)
    return -0.5 * torch.sum(1 + log_var - mu ** 2 - torch.exp(log_var), dim=-1)


class Trainer(TrainerBase):
    def __init__(self, encoder, decoder, optimizer, callbacks, scheduler, hp, verb=False):
        """
        :param encoder: encoder model
        :param decoder: decoder model
        :param optimizer: pair of optimizers (opt_enc, opt_dec)
        :param callbacks: callbacks passed to homura's Trainer
        :param scheduler: learning-rate scheduler passed to homura's Trainer
        :param hp: dict of hyperparameters (`alpha`, `beta`, `margin`)
        :param verb: verbosity flag passed to homura's Trainer
        """
        assert isinstance(optimizer, (tuple, list)) and len(optimizer) == 2, \
            "Need two optimizers, one for the encoder and one for the decoder"
        assert {"alpha", "beta", "margin"} <= set(hp.keys()), "hp needs `alpha`, `beta` and `margin`"
        super(Trainer, self).__init__(model=nn.ModuleList([encoder, decoder]), optimizer=optimizer, loss_f=None,
                                      callbacks=callbacks, scheduler=scheduler, verb=verb, hp=hp)
        self.encoder = self.model[0]
        self.decoder = self.model[1]

    def iteration(self, data, is_train):
        alpha = self.hp["alpha"]
        beta = self.hp["beta"]
        margin = self.hp["margin"]

        input, _ = self.to_device(data)
        if is_train:
            # encode the real image and decode both the posterior sample and a prior sample
            z_mu, z_log_var = self.encoder(input)
            z = reparameterize(z_mu, z_log_var)
            z_p = torch.randn_like(z)
            x_r = self.decoder(z)
            x_p = self.decoder(z_p)

            # encoder update: keep the real codes regularized while pushing the KL of
            # reconstructed and sampled images above the margin
            self.optimizer[0].zero_grad()
            z_r = self.encoder(x_r.detach())
            z_pp = self.encoder(x_p.detach())
            l_enc = (l_reg(z_mu, z_log_var)
                     + alpha * (F.relu(margin - l_reg(*z_r)) + F.relu(margin - l_reg(*z_pp)))).mean()
            (l_enc + beta * F.mse_loss(x_r, input)).backward()
            self.optimizer[0].step()

            # decoder update: re-decode from detached codes so this backward pass does not
            # reuse the graph that was freed (and whose encoder weights changed) above
            self.optimizer[1].zero_grad()
            x_r = self.decoder(z.detach())
            x_p = self.decoder(z_p)
            z_r = self.encoder(x_r)
            z_pp = self.encoder(x_p)
            l_dec = alpha * (l_reg(*z_r) + l_reg(*z_pp)).mean()
            (l_dec + beta * F.mse_loss(x_r, input)).backward()
            self.optimizer[1].step()

    def test(self, data_loader, name=None):
        pass

    def generate(self, num_sample):
        self.model.eval()
        # sample from the prior; 20 is the latent size of the mini models
        return self.decoder(torch.randn(num_sample, 20))

--------------------------------------------------------------------------------
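`exec.py` is still an empty stub, so nothing above shows how the pieces fit together. Below is a minimal, self-contained sketch of one IntroVAE-style update using the mini models, `reparameterize`, and `l_reg` defined above. It runs on random MNIST-shaped tensors and bypasses the homura-based `Trainer` entirely; the optimizer choice, learning rate, and `alpha`/`beta`/`margin` values are placeholder assumptions, not values from the repository.

```python
import torch
from torch.nn import functional as F

from model import MODELS
from utils import l_reg, reparameterize

encoder, decoder = (cls() for cls in MODELS["mini"])
opt_enc = torch.optim.Adam(encoder.parameters(), lr=2e-4)
opt_dec = torch.optim.Adam(decoder.parameters(), lr=2e-4)
alpha, beta, margin = 0.25, 0.5, 110.0  # placeholder hyperparameters

x = torch.rand(16, 784)  # stand-in for a batch of flattened MNIST images

# encode the batch, then decode the posterior sample and a prior sample
mu, log_var = encoder(x)
z = reparameterize(mu, log_var)
z_p = torch.randn_like(z)
x_r = decoder(z)
x_p = decoder(z_p)

# encoder update: regularize real codes, push reconstructed/sampled codes above the margin
opt_enc.zero_grad()
adv = F.relu(margin - l_reg(*encoder(x_r.detach()))) + F.relu(margin - l_reg(*encoder(x_p.detach())))
l_enc = (l_reg(mu, log_var) + alpha * adv).mean() + beta * F.mse_loss(x_r, x)
l_enc.backward()
opt_enc.step()

# decoder update: re-decode from detached codes and pull the same KL terms back down
opt_dec.zero_grad()
x_r, x_p = decoder(z.detach()), decoder(z_p)
l_dec = alpha * (l_reg(*encoder(x_r)) + l_reg(*encoder(x_p))).mean() + beta * F.mse_loss(x_r, x)
l_dec.backward()
opt_dec.step()

print(l_enc.item(), l_dec.item())
```

In the repository itself this loop would presumably be driven by `utils.Trainer` and homura's training machinery once `exec.py` is filled in.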