├── __init__.py
├── utils
│   └── constants.py
├── inference.py
├── README.md
├── decoder.py
├── encoder.py
├── CODE_OF_CONDUCT.md
├── discriminator.py
├── CONTRIBUTING.md
├── model.py
└── train.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/constants.py:
--------------------------------------------------------------------------------
KERNEL_SIZE = 5
PADDING = 2
STRIDE = 2
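Note: train.py below imports a RollingMeasure helper from utils that is not part of this listing. A minimal sketch of what it presumably provides, assuming only the interface train.py actually uses (a callable update and a .measure attribute holding the running mean):

class RollingMeasure:
    """Running mean of a scalar metric (sketch of the assumed utils helper)."""

    def __init__(self):
        self.measure = 0.0
        self.count = 0

    def __call__(self, value):
        # incremental mean update: avoids storing the full history
        self.count += 1
        self.measure += (float(value) - self.measure) / self.count
        return self.measure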
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import torch
import torch.nn.functional as F

from model import VAEGAN
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid


np.random.seed(8)
torch.manual_seed(8)
torch.cuda.manual_seed(8)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="VAEGAN")
    parser.add_argument("--z_size", default=128, action="store", type=int, dest="z_size")
    parser.add_argument("--recon_level", default=3, action="store", type=int, dest="recon_level")
    parser.add_argument("--batchsize", default=64, action="store", type=int, dest="batchsize")
    parser.add_argument("--num_classes", default=10, action="store", type=int, dest="num_classes")
    parser.add_argument("--model_path", default='model.pth', action="store", type=str, dest="model_path")

    args = parser.parse_args()

    z_size = args.z_size
    recon_level = args.recon_level
    batchsize = args.batchsize
    num_classes = args.num_classes
    model_path = args.model_path
    step_index = 0

    # TODO: add to argument parser
    dataset_name = 'cifar10'

    writer = SummaryWriter(comment="_CIFAR10_GAN")
    net = VAEGAN(z_size=z_size, recon_level=recon_level).cuda()

    # load the existing model (train.py saves the whole module with torch.save)
    model = torch.load(model_path)
    net.load_state_dict(model.state_dict())

    # switch to inference mode
    net.eval()

    # sample one random class label per generated image; the number of latent
    # samples must match the number of one-hot labels fed to the decoder
    label = np.random.randint(0, num_classes, batchsize)
    one_hot_label = F.one_hot(torch.from_numpy(label), num_classes).float().cuda()

    out = net(None, one_hot_label, batchsize)
    out = out.data.cpu()
    out = (out + 1) / 2  # map tanh output from [-1, 1] to [0, 1]
    out = make_grid(out, nrow=8)
    writer.add_image("generated", out, step_index)
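For a quick qualitative check of the class conditioning, one could instead generate a fixed block of samples per class, so each grid row shows one class. A sketch reusing the objects defined above (the `generated_per_class` tag is illustrative, not part of the original script):

# eight samples for each of the num_classes labels, one class per grid row
label = np.repeat(np.arange(num_classes), 8)
one_hot_label = F.one_hot(torch.from_numpy(label), num_classes).float().cuda()
with torch.no_grad():
    out = net(None, one_hot_label, len(label))
grid = make_grid((out.data.cpu() + 1) / 2, nrow=8)
writer.add_image("generated_per_class", grid, step_index)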
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ACVAEGAN
An approach to imposing a condition on a VAEGAN through the use of an auxiliary classifier.

## Description
- Goal: To build a Conditional VAEGAN by employing an Auxiliary Classifier.
- Architecture:
![ACVAEGAN Architecture](https://github.com/pranavbudhwant/acvaegan/blob/master/architecture.jpg)
- Dataset: WikiArt Emotions [1]
- Approaches:
  1. Generate paintings conditioned on emotion (anger, fear, sadness, ...)
  2. Generate paintings conditioned on category (cubism, surrealism, minimalism, ...)
  3. Generate paintings conditioned on style (contemporary, modern, renaissance, ...)

## Plan
- ~~Prepare dataset for the three approaches~~
  - ~~CSV files containing (image-id, emotion); (image-id, category); (image-id, style)~~
- Auxiliary Classifier Architecture
  - Multilabel or multiclass classifier?
  - Keras/PyTorch implementation
  - Try on MNIST

## References
[1] WikiArt Emotions: An Annotated Dataset of Emotions Evoked by Art. Saif M. Mohammad and Svetlana Kiritchenko. In Proceedings of the 11th Edition of the Language Resources and Evaluation Conference (LREC-2018), May 2018, Miyazaki, Japan.

[2] Autoencoding beyond pixels using a learned similarity metric. https://arxiv.org/abs/1512.09300

[3] Conditional Image Synthesis With Auxiliary Classifier GANs. https://arxiv.org/abs/1610.09585

[4] Twin Auxiliary Classifiers GAN. https://arxiv.org/abs/1907.02690

[5] The Emotional GAN: Priming Adversarial Generation of Art with Emotion. https://nips2017creativity.github.io/doc/The_Emotional_GAN.pdf

[6] CVAE-GAN: Fine-Grained Image Generation through Asymmetric Training. https://arxiv.org/pdf/1703.10155.pdf

[7] Learning Structured Output Representation using Deep Conditional Generative Models. https://pdfs.semanticscholar.org/3f25/e17eb717e5894e0404ea634451332f85d287.pdf

--------------------------------------------------------------------------------
/decoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderBlock(nn.Module):
    def __init__(self, channel_in, channel_out):
        super(DecoderBlock, self).__init__()

        # transposed convolution to double the spatial dimensions
        self.conv = nn.ConvTranspose2d(in_channels=channel_in, out_channels=channel_out, kernel_size=5, padding=2,
                                       stride=2, output_padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channel_out, momentum=0.9)

    def forward(self, ten):
        ten = self.conv(ten)
        ten = self.bn(ten)
        ten = F.relu(ten, True)
        return ten


class Decoder(nn.Module):
    def __init__(self, z_size, size, num_classes=10):
        super(Decoder, self).__init__()

        # start from B x z_size, concatenated with the one-hot encoded class vector
        self.fc = nn.Sequential(nn.Linear(in_features=(z_size + num_classes), out_features=(8 * 8 * size), bias=False),
                                nn.BatchNorm1d(num_features=8 * 8 * size, momentum=0.9),
                                nn.ReLU(True))
        self.size = size

        # three stride-2 blocks upsample the 8x8 feature map to 64x64
        layers = [
            DecoderBlock(channel_in=self.size, channel_out=self.size),
            DecoderBlock(channel_in=self.size, channel_out=self.size // 2)
        ]

        self.size = self.size // 2
        layers.append(DecoderBlock(channel_in=self.size, channel_out=self.size // 4))
        self.size = self.size // 4

        # final convolution to get 3 channels, followed by tanh
        layers.append(nn.Sequential(
            nn.Conv2d(in_channels=self.size, out_channels=3, kernel_size=5, stride=1, padding=2),
            nn.Tanh()
        ))
        self.conv = nn.Sequential(*layers)

    def forward(self, ten, one_hot_classes):
        ten_cat = torch.cat((one_hot_classes, ten), 1)
        ten = self.fc(ten_cat)
        ten = ten.view(len(ten), -1, 8, 8)
        ten = self.conv(ten)
        return ten

    def __call__(self, *args, **kwargs):
        return super(Decoder, self).__call__(*args, **kwargs)
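Each DecoderBlock doubles the spatial resolution, so three blocks take the 8x8 map to 64x64. A small shape-check sketch, assuming z_size=128, size=256 and num_classes=10 as used elsewhere in this repo:

import torch
from decoder import Decoder

dec = Decoder(z_size=128, size=256, num_classes=10)
dec.eval()  # BatchNorm in eval mode, so a tiny batch works deterministically
z = torch.randn(2, 128)
one_hot = torch.eye(10)[[3, 7]]  # one-hot vectors for classes 3 and 7
out = dec(z, one_hot)
print(out.shape)  # torch.Size([2, 3, 64, 64])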
--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class EncoderBlock(nn.Module):
    def __init__(self, channel_in, channel_out):
        super(EncoderBlock, self).__init__()

        # convolution layer to halve the spatial dimensions
        self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=5, padding=2, stride=2,
                              bias=False)
        self.bn = nn.BatchNorm2d(num_features=channel_out, momentum=0.9)

    def forward(self, ten, out=False, t=False):
        ten = self.conv(ten)
        ten_out = ten
        ten = self.bn(ten)
        ten = F.relu(ten, False)

        # if out=True, also return the pre-batchnorm output, used for the
        # feature-matching reconstruction error in the discriminator
        if out:
            return ten, ten_out
        return ten


class Encoder(nn.Module):
    def __init__(self, channel_in=3, z_size=128):
        super(Encoder, self).__init__()
        self.size = channel_in
        layers = []

        # the first block maps 3 -> 64 channels; every subsequent block doubles the channels
        for i in range(3):
            if i == 0:
                layers.append(EncoderBlock(channel_in=self.size, channel_out=64))
                self.size = 64
            else:
                layers.append(EncoderBlock(channel_in=self.size, channel_out=self.size * 2))
                self.size *= 2

        # final shape B x 256 x 8 x 8 (for 64x64 inputs)
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Sequential(nn.Linear(in_features=8 * 8 * self.size, out_features=1024, bias=False),
                                nn.BatchNorm1d(num_features=1024, momentum=0.9),
                                nn.ReLU(True))
        self.l_mu = nn.Linear(in_features=1024, out_features=z_size)
        self.l_var = nn.Linear(in_features=1024, out_features=z_size)

    def forward(self, ten):
        ten = self.conv(ten)
        ten = ten.view(len(ten), -1)
        ten = self.fc(ten)
        mu = self.l_mu(ten)
        logvar = self.l_var(ten)
        return mu, logvar

    def __call__(self, *args, **kwargs):
        return super(Encoder, self).__call__(*args, **kwargs)
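The encoder returns (mu, logvar); sampling a latent from them uses the reparameterization trick, exactly as model.py does internally. A minimal standalone sketch:

import torch
from encoder import Encoder

enc = Encoder(channel_in=3, z_size=128)
enc.eval()
images = torch.randn(4, 3, 64, 64)  # the encoder expects 64x64 RGB inputs
mu, logvar = enc(images)
std = torch.exp(0.5 * logvar)
z = mu + torch.randn_like(mu) * std  # reparameterization: z ~ N(mu, sigma^2)
print(z.shape)  # torch.Size([4, 128])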
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Code of Conduct

As contributors and maintainers of this project, and in the interest of fostering an open
and welcoming community, we pledge to respect all people who contribute through reporting
issues, posting feature requests, updating documentation, submitting pull requests or
patches, and other activities.

We are committed to making participation in this project a harassment-free experience for
everyone, regardless of level of experience, gender, gender identity and expression,
sexual orientation, disability, personal appearance, body size, race, ethnicity, age,
religion, or nationality.

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery
* Personal attacks
* Trolling or insulting/derogatory comments
* Public or private harassment
* Publishing others' private information, such as physical or electronic addresses,
  without explicit permission
* Other unethical or unprofessional conduct

Project maintainers have the right and responsibility to remove, edit, or reject comments,
commits, code, wiki edits, issues, and other contributions that are not aligned to this
Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors
that they deem inappropriate, threatening, offensive, or harmful.

By adopting this Code of Conduct, project maintainers commit themselves to fairly and
consistently applying these principles to every aspect of managing this project. Project
maintainers who do not follow or enforce the Code of Conduct may be permanently removed
from the project team.

This Code of Conduct applies both within project spaces and in public spaces when an
individual is representing the project or its community.

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by
contacting a project maintainer at [ramramrakhya81@gmail.com](mailto:ramramrakhya81@gmail.com). All complaints will
be reviewed and investigated and will result in a response that is deemed necessary and
appropriate to the circumstances. Maintainers are obligated to maintain confidentiality
with regard to the reporter of an incident.

This Code of Conduct is adapted from the
[Contributor Covenant](http://contributor-covenant.org), version 1.3.0, available at
[contributor-covenant.org/version/1/3/0/](http://contributor-covenant.org/version/1/3/0/)
--------------------------------------------------------------------------------
/discriminator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from encoder import EncoderBlock


class Discriminator(nn.Module):
    def __init__(self, channels_in=3, recon_level=3, num_classes=10):
        super(Discriminator, self).__init__()
        self.size = channels_in
        self.recon_level = recon_level

        # a ModuleList rather than Sequential, because we need to extract an intermediate output
        self.conv = nn.ModuleList()
        self.conv.append(nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True)
        ))

        self.size = 32
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=128))
        self.size = 128
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256))
        self.size = 256
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256))

        # shared fc trunk, followed by two heads: a real/fake score
        # and an auxiliary classifier over the num_classes labels
        self.fc = nn.Sequential(
            nn.Linear(in_features=8 * 8 * self.size, out_features=512, bias=False),
            nn.BatchNorm1d(num_features=512, momentum=0.9),
            nn.ReLU(inplace=True)
        )
        self.fc_disc = nn.Linear(in_features=512, out_features=1)
        self.fc_aux = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, ten, other_ten, mode='REC'):
        ten = torch.cat((ten, other_ten), 0)
        if mode == 'REC':
            for i, layer in enumerate(self.conv):
                # take the block at index recon_level as the intermediate output
                if i == self.recon_level:
                    ten, layer_ten = layer(ten, True)
                    # fetch the layer representations for both halves of the batch
                    # (original & reconstructed) and flatten, because they are multidimensional
                    layer_ten = layer_ten.view(len(layer_ten), -1)
                    return layer_ten
                else:
                    ten = layer(ten)
        else:
            for i, layer in enumerate(self.conv):
                ten = layer(ten)

            ten = ten.view(len(ten), -1)
            ten = self.fc(ten)
            ten_disc = self.fc_disc(ten)
            ten_aux = self.fc_aux(ten)
            return torch.sigmoid(ten_disc), F.log_softmax(ten_aux, dim=1)

    def __call__(self, *args, **kwargs):
        return super(Discriminator, self).__call__(*args, **kwargs)
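A sketch of the two modes: 'REC' returns the flattened intermediate features at recon_level (used for the feature-matching reconstruction loss), while 'GAN' returns the real/fake score and the auxiliary class log-probabilities. Both operate on the concatenation of the two input batches:

import torch
from discriminator import Discriminator

disc = Discriminator(channels_in=3, recon_level=3, num_classes=10)
disc.eval()
real = torch.randn(4, 3, 64, 64)
fake = torch.randn(4, 3, 64, 64)

feats = disc(fake, real, mode='REC')       # (8, 256*8*8) intermediate features
score, aux = disc(real, fake, mode='GAN')  # (8, 1) sigmoid score, (8, 10) log-probs
print(feats.shape, score.shape, aux.shape)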
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
## Contributing guidelines

Thank you for your interest in contributing to AC-VAEGAN! Here are a few pointers about how you can help.

### Setting things up

To set up the development environment, follow the instructions in the README.

### Finding something to work on

The issue tracker of AC-VAEGAN is a good place to start. If you find something that interests you, comment on the thread and we'll help get you started.

Alternatively, if you come across a new bug in the module, please file a new issue and comment if you would like to be assigned. The existing issues are tagged with one or more labels, based on the part of the project they touch, their importance, etc., which can help you in selecting one.

If neither of these seems appealing, please post on our channel and we will help find you something else to work on.

### Instructions to submit code

Before you submit code, please talk to us via the issue tracker so we know you are working on it.

Our central development branch is `development`. Coding is done on feature branches based off of `development` and merged into it once stable and reviewed. To submit code, follow these steps:

1. Create a new branch off of `development`. Select a descriptive branch name.

        git fetch upstream
        git checkout development
        git merge upstream/development
        git checkout -b your-branch-name

2. Commit and push code to your branch:

    - Commits should be self-contained and contain a descriptive commit message.

    ##### Rules for a great git commit message style
    - Separate subject from body with a blank line
    - Do not end the subject line with a period
    - Capitalize the subject line and each paragraph
    - Use the imperative mood in the subject line
    - Wrap lines at 72 characters
    - Use the body to explain what and why you have done something. In most cases, you can leave out details about how a change has been made.

    ##### Example for a commit message

        Subject of the commit message

        Body of the commit message...
        ....

    - Please make sure your code is well-formatted and adheres to PEP8 conventions (for Python) and the Airbnb style guide (for JavaScript). For others (Lua, prototxt, etc.), please ensure that the code is well-formatted and the style consistent.
    - Please ensure that your code is well tested.
    - We highly encourage you to use `autopep8` to follow the PEP8 styling. Run the following commands before creating the pull request:

        autopep8 --in-place --exclude env,docs --recursive .
        git commit -a -m "{{commit_message}}"
        git push origin {{branch_name}}

    - For prettifying frontend code, use an ```HTML/JS/CSS Prettifier```.
    - To install the Sublime Package Control manager in the Sublime Text editor, use [this](https://packagecontrol.io/installation#st2) link. Once Package Control is installed, install ```HTML/JS/CSS Prettifier``` through it.

3. Once the code is pushed, create a pull request:

    - On your GitHub fork, select your branch and click "New pull request". Select `development` as the base branch and your branch in the "compare" dropdown. If the code is mergeable (you get a message saying "Able to merge"), go ahead and create the pull request.
    - Check back after some time to see if the Travis checks have passed. If not, click the "Details" link on your PR thread at the right of "The Travis CI build failed", which will take you to the dashboard for your PR. You will see what failed or stalled, and will need to resolve it.
    - If your checks have passed, your PR will be assigned a reviewer who will review your code and provide comments. Please address each review comment by pushing new commits to the same branch (the PR will automatically update, so you don't need to submit a new one). Once you are done, reply below each review comment marking it as "Done". Feel free to use the thread to have a discussion about comments that you don't understand completely or don't agree with.
    - Once all comments are addressed, the reviewer will give an LGTM ('looks good to me') and merge the PR.

Congratulations, you have successfully contributed to Project AC-VAEGAN!
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from encoder import Encoder
from decoder import Decoder
from discriminator import Discriminator
from torch.autograd import Variable


class VAEGAN(nn.Module):
    def __init__(self, z_size=128, recon_level=3, num_classes=10):
        super(VAEGAN, self).__init__()

        # latent space size
        self.z_size = z_size
        self.encoder = Encoder(z_size=self.z_size)
        self.decoder = Decoder(z_size=self.z_size, size=self.encoder.size,
                               num_classes=num_classes)

        self.discriminator = Discriminator(channels_in=3, recon_level=recon_level)

        # initialize self-defined params
        self.init_parameters()

    def init_parameters(self):
        # walk the network, find every weight and bias tensor and fill it
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                if hasattr(m, 'weight') and m.weight is not None and m.weight.requires_grad:
                    # init as in the original implementation
                    scale = 1.0 / np.sqrt(np.prod(m.weight.shape[1:]))
                    scale /= np.sqrt(3)
                    # nn.init.xavier_normal_(m.weight, 1)
                    # nn.init.constant_(m.weight, 0.005)
                    nn.init.uniform_(m.weight, -scale, scale)
                if hasattr(m, 'bias') and m.bias is not None and m.bias.requires_grad:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, ten, one_hot_class, gen_size=10):
        if self.training:
            # save the original images
            ten_original = ten
            # encode
            mu, log_variances = self.encoder(ten)

            # we need the standard deviation, not the log variance
            variances = torch.exp(log_variances * 0.5)

            # sample from a standard gaussian
            ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True)

            # shift and scale using the means and standard deviations (reparameterization trick)
            ten = ten_from_normal * variances + mu

            # decode the latent tensor
            ten = self.decoder(ten, one_hot_class)

            # discriminator features for the reconstruction loss
            ten_layer = self.discriminator(ten, ten_original, mode='REC')

            # decode from prior samples
            ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True)

            ten = self.decoder(ten_from_normal, one_hot_class)
            ten_real_fake, ten_aux = self.discriminator(ten_original, ten, mode='GAN')

            return ten, ten_real_fake, ten_layer, mu, log_variances, ten_aux
        else:
            if ten is None:
                # just sample and decode; gen_size must match len(one_hot_class)
                ten = Variable(torch.randn(gen_size, self.z_size).cuda(), requires_grad=False)
            else:
                mu, log_variances = self.encoder(ten)
                # we need the standard deviation, not the log variance
                variances = torch.exp(log_variances * 0.5)

                # sample from a standard gaussian
                ten_from_normal = Variable(torch.randn(len(ten), self.z_size).cuda(), requires_grad=True)

                # shift and scale using the means and standard deviations
                ten = ten_from_normal * variances + mu

            # decode the latent tensor
            ten = self.decoder(ten, one_hot_class)
            return ten

    def __call__(self, *args, **kwargs):
        return super(VAEGAN, self).__call__(*args, **kwargs)

    @staticmethod
    def loss(ten_original, ten_predict, layer_original, layer_predicted, labels_original, labels_sampled,
             mu, variances, aux_labels_predicted, aux_labels_sampled, aux_labels_original):
        """
        :param ten_original: original images
        :param ten_predict: predicted images (decoder output)
        :param layer_original: intermediate discriminator features for the original images
        :param layer_predicted: intermediate discriminator features for the reconstructed images
        :param labels_original: discriminator scores for the original images
        :param labels_sampled: discriminator scores for images decoded from gaussian (0, 1) samples
        :param mu: means
        :param variances: tensor of diagonals of the log variances
        :param aux_labels_predicted: auxiliary-classifier log-probabilities for the original images
        :param aux_labels_sampled: auxiliary-classifier log-probabilities for the sampled images
        :param aux_labels_original: ground-truth class indices
        :return:
        """

        # reconstruction error; not used as part of the loss, just to monitor
        nle = 0.5 * (ten_original.view(len(ten_original), -1) - ten_predict.view(len(ten_predict), -1)) ** 2

        # kl-divergence
        kl = -0.5 * torch.sum(-variances.exp() - torch.pow(mu, 2) + variances + 1, 1)

        # mse between intermediate layers
        mse = torch.sum((layer_original - layer_predicted) ** 2, 1)

        # BCE for the decoder & discriminator on original and sampled images;
        # the only term excluded from the losses below is bce_gen_original
        bce_dis_original = -torch.log(labels_original)
        bce_dis_sampled = -torch.log(1 - labels_sampled)

        bce_gen_original = -torch.log(1 - labels_original)
        bce_gen_sampled = -torch.log(labels_sampled)

        aux_criterion = nn.NLLLoss()
        nllloss_aux_original = aux_criterion(aux_labels_predicted, aux_labels_original)
        nllloss_aux_sampled = aux_criterion(aux_labels_sampled, aux_labels_original)

        '''
        bce_gen_predicted = nn.BCEWithLogitsLoss(size_average=False)(labels_predicted,
                Variable(torch.ones_like(labels_predicted.data).cuda(), requires_grad=False))
        bce_gen_sampled = nn.BCEWithLogitsLoss(size_average=False)(labels_sampled,
                Variable(torch.ones_like(labels_sampled.data).cuda(), requires_grad=False))
        bce_dis_original = nn.BCEWithLogitsLoss(size_average=False)(labels_original,
                Variable(torch.ones_like(labels_original.data).cuda(), requires_grad=False))
        bce_dis_predicted = nn.BCEWithLogitsLoss(size_average=False)(labels_predicted,
                Variable(torch.zeros_like(labels_predicted.data).cuda(), requires_grad=False))
        bce_dis_sampled = nn.BCEWithLogitsLoss(size_average=False)(labels_sampled,
                Variable(torch.zeros_like(labels_sampled.data).cuda(), requires_grad=False))
        '''

        return nle, kl, mse, bce_dis_original, bce_dis_sampled, bce_gen_original, bce_gen_sampled, \
            nllloss_aux_original, nllloss_aux_sampled
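For reference, the kl term above is the closed-form KL divergence between the approximate posterior and the standard normal prior. With d = z_size and variances holding log sigma^2, the code computes, per sample:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,I)\big)
  = -\frac{1}{2}\sum_{j=1}^{d}\big(1+\log\sigma_{j}^{2}-\mu_{j}^{2}-\sigma_{j}^{2}\big)

which matches torch.sum(-variances.exp() - mu^2 + variances + 1, 1) multiplied by -0.5.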
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms

from model import VAEGAN
from tensorboardX import SummaryWriter
from torch.autograd import Variable

from torch.optim import RMSprop, Adam, SGD
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

import progressbar
from utils import RollingMeasure


np.random.seed(8)
torch.manual_seed(8)
torch.cuda.manual_seed(8)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="VAEGAN")
    parser.add_argument("--train_folder", action="store", dest="train_folder")
    parser.add_argument("--test_folder", action="store", dest="test_folder")
    parser.add_argument("--model_path", action="store", dest="model_path")
    parser.add_argument("--n_epochs", default=12, action="store", type=int, dest="n_epochs")
    parser.add_argument("--z_size", default=128, action="store", type=int, dest="z_size")
    parser.add_argument("--recon_level", default=3, action="store", type=int, dest="recon_level")
    parser.add_argument("--lambda_mse", default=1e-3, action="store", type=float, dest="lambda_mse")
    parser.add_argument("--lr", default=3e-4, action="store", type=float, dest="lr")
    parser.add_argument("--decay_lr", default=0.75, action="store", type=float, dest="decay_lr")
    parser.add_argument("--decay_mse", default=1, action="store", type=float, dest="decay_mse")
    parser.add_argument("--decay_margin", default=1, action="store", type=float, dest="decay_margin")
    parser.add_argument("--decay_equilibrium", default=1, action="store", type=float, dest="decay_equilibrium")
    # store_true rather than type=bool: bool("False") is truthy, so type=bool silently misparses
    parser.add_argument("--slurm", default=False, action="store_true", dest="slurm")
    parser.add_argument("--batchsize", default=64, action="store", type=int, dest="batchsize")

    args = parser.parse_args()

    train_folder = args.train_folder
    test_folder = args.test_folder
    z_size = args.z_size
    recon_level = args.recon_level
    decay_mse = args.decay_mse
    decay_margin = args.decay_margin
    n_epochs = args.n_epochs
    lambda_mse = args.lambda_mse
    lr = args.lr
    decay_lr = args.decay_lr
    decay_equilibrium = args.decay_equilibrium
    slurm = args.slurm
    batchsize = args.batchsize
    model_path = args.model_path

    # TODO: add to argument parser
    dataset_name = 'cifar10'
    num_classes = 10  # CIFAR-10

    writer = SummaryWriter(comment="_CIFAR10_GAN")
    net = VAEGAN(z_size=z_size, recon_level=recon_level).cuda()

    # DATASET
    if dataset_name == 'cifar10':
        dataset = dset.CIFAR10(
            root=train_folder, download=True,
            transform=transforms.Compose([
                # the encoder/decoder stacks are built for 64x64 inputs
                # (transforms.Scale is deprecated in favour of Resize)
                transforms.Resize(64),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchsize,
                                                 shuffle=True, num_workers=4)
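    # Note: with the normalization above, pixel values lie in [-1, 1], matching
    # the decoder's tanh output; inference.py maps them back to [0, 1] for display.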
    # margin and equilibrium
    margin = 0.35
    equilibrium = 0.68

    # mse_lambda = 1.0
    # OPTIM-LOSS
    # an optimizer for each of the sub-networks, so we can selectively backprop
    # optimizer_encoder = Adam(params=net.encoder.parameters(), lr=lr, betas=(0.9, 0.999))
    optimizer_encoder = RMSprop(params=net.encoder.parameters(), lr=lr, alpha=0.9, eps=1e-8, weight_decay=0, momentum=0,
                                centered=False)
    # lr_encoder = MultiStepLR(optimizer_encoder, milestones=[2], gamma=1)
    lr_encoder = ExponentialLR(optimizer_encoder, gamma=decay_lr)
    # optimizer_decoder = Adam(params=net.decoder.parameters(), lr=lr, betas=(0.9, 0.999))
    optimizer_decoder = RMSprop(params=net.decoder.parameters(), lr=lr, alpha=0.9, eps=1e-8, weight_decay=0, momentum=0,
                                centered=False)
    lr_decoder = ExponentialLR(optimizer_decoder, gamma=decay_lr)
    # lr_decoder = MultiStepLR(optimizer_decoder, milestones=[2], gamma=1)
    # optimizer_discriminator = Adam(params=net.discriminator.parameters(), lr=lr, betas=(0.9, 0.999))
    optimizer_discriminator = RMSprop(params=net.discriminator.parameters(), lr=lr, alpha=0.9, eps=1e-8, weight_decay=0,
                                      momentum=0, centered=False)
    lr_discriminator = ExponentialLR(optimizer_discriminator, gamma=decay_lr)
    # lr_discriminator = MultiStepLR(optimizer_discriminator, milestones=[2], gamma=1)
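    # Three separate optimizers allow the selective backprop below: the encoder
    # is always updated, while the decoder and discriminator are enabled or
    # disabled per batch by the equilibrium/margin heuristic.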
    batch_number = len(dataloader)
    step_index = 0
    widgets = [
        'Batch: ', progressbar.Counter(),
        '/', progressbar.FormatCustomText('%(total)s', {"total": batch_number}),
        ' ', progressbar.Bar(marker="-", left='[', right=']'),
        ' ', progressbar.ETA(),
        ' ', progressbar.DynamicMessage('loss_nle'),
        ' ', progressbar.DynamicMessage('loss_encoder'),
        ' ', progressbar.DynamicMessage('loss_decoder'),
        ' ', progressbar.DynamicMessage('loss_discriminator'),
        ' ', progressbar.DynamicMessage('loss_mse_layer'),
        ' ', progressbar.DynamicMessage('loss_kld'),
        ' ', progressbar.DynamicMessage('loss_aux_classifier'),
        ' ', progressbar.DynamicMessage("epoch")
    ]

    if slurm:
        print(args)

    # for each epoch
    for i in range(n_epochs):

        progress = progressbar.ProgressBar(min_value=0, max_value=batch_number + 1, initial_value=0,
                                           widgets=widgets).start()
        # reset rolling averages
        loss_nle_mean = RollingMeasure()
        loss_encoder_mean = RollingMeasure()
        loss_decoder_mean = RollingMeasure()
        loss_discriminator_mean = RollingMeasure()
        loss_reconstruction_layer_mean = RollingMeasure()
        loss_kld_mean = RollingMeasure()
        loss_aux_classifier_mean = RollingMeasure()
        gan_gen_eq_mean = RollingMeasure()
        gan_dis_eq_mean = RollingMeasure()
        # print("LR:{}".format(lr_encoder.get_lr()))

        # for each batch
        for j, (data_batch, target_batch) in enumerate(dataloader):

            # set to train mode
            net.train()
            # target and input are the same images; CIFAR10 yields (image, label),
            # so the reconstruction target is data_batch, not target_batch
            data_target = Variable(data_batch, requires_grad=False).float().cuda()
            data_in = Variable(data_batch, requires_grad=False).float().cuda()
            aux_label_batch = Variable(target_batch, requires_grad=False).long().cuda()
            # pass num_classes explicitly so the one-hot width does not depend on the batch
            one_hot_class = F.one_hot(aux_label_batch, num_classes).float()

            # get output
            out, out_labels, out_layer, mus, variances, aux_labels = net(data_in, one_hot_class)
            # split so we can get the different parts
            out_layer_predicted = out_layer[:len(out_layer) // 2]
            out_layer_original = out_layer[len(out_layer) // 2:]
            # TODO: set a batch_size variable to get cleaner code here
            out_labels_original = out_labels[:len(out_labels) // 2]
            out_labels_sampled = out_labels[len(out_labels) // 2:]

            # aux-classifier outputs: the discriminator saw cat(original, sampled),
            # so the first half belongs to the original images
            aux_labels_original = aux_labels[:len(aux_labels) // 2]
            aux_labels_sampled = aux_labels[len(aux_labels) // 2:]

            # loss, nothing special here
            nle_value, kl_value, mse_value, bce_dis_original_value, bce_dis_sampled_value, \
                bce_gen_original_value, bce_gen_sampled_value, \
                nllloss_aux_original, nllloss_aux_sampled = VAEGAN.loss(data_target, out,
                                                                        out_layer_original,
                                                                        out_layer_predicted,
                                                                        out_labels_original,
                                                                        out_labels_sampled,
                                                                        mus, variances,
                                                                        aux_labels_original,
                                                                        aux_labels_sampled,
                                                                        aux_label_batch)
            # THIS IS THE MOST IMPORTANT PART OF THE CODE
            loss_encoder = torch.sum(kl_value) + torch.sum(mse_value)
            loss_discriminator = torch.sum(bce_dis_original_value) + torch.sum(bce_dis_sampled_value) \
                + torch.sum(nllloss_aux_original) + torch.sum(nllloss_aux_sampled)

            loss_decoder = torch.sum(lambda_mse * mse_value) - loss_discriminator
            # loss_decoder = torch.sum(mse_lambda * mse_value) + (1.0 - mse_lambda) * (torch.sum(bce_gen_sampled_value)
            #                                                                          + torch.sum(bce_gen_original_value))

            # register mean values of the losses for logging
            loss_nle_mean(torch.mean(nle_value).data.cpu().numpy())

            loss_discriminator_mean((torch.mean(bce_dis_original_value) + torch.mean(bce_dis_sampled_value)
                                     + torch.mean(nllloss_aux_original)
                                     + torch.mean(nllloss_aux_sampled)).data.cpu().numpy())

            loss_decoder_mean((torch.mean(lambda_mse * mse_value) - (
                torch.mean(bce_dis_original_value) + torch.mean(bce_dis_sampled_value))).data.cpu().numpy())
            # loss_decoder_mean((torch.mean(mse_lambda * mse_value) + (1 - mse_lambda) * (
            #     torch.mean(bce_gen_original_value) + torch.mean(bce_gen_sampled_value))).data.cpu().numpy()[0])

            loss_encoder_mean((torch.mean(kl_value) + torch.mean(mse_value)).data.cpu().numpy())
            loss_reconstruction_layer_mean(torch.mean(mse_value).data.cpu().numpy())
            loss_kld_mean(torch.mean(kl_value).data.cpu().numpy())
            loss_aux_classifier_mean(
                (torch.mean(nllloss_aux_original) + torch.mean(nllloss_aux_sampled)).data.cpu().numpy())

            # selectively disable the decoder or the discriminator if they are unbalanced
            train_dis = True
            train_dec = True
            if torch.mean(bce_dis_original_value).data < equilibrium - margin or torch.mean(
                    bce_dis_sampled_value).data < equilibrium - margin:
                train_dis = False
            if torch.mean(bce_dis_original_value).data > equilibrium + margin or torch.mean(
                    bce_dis_sampled_value).data > equilibrium + margin:
                train_dec = False
            if train_dec is False and train_dis is False:
                train_dis = True
                train_dec = True

            # log whether each network was trained this step
            if train_dis:
                gan_dis_eq_mean(1.0)
            else:
                gan_dis_eq_mean(0.0)

            if train_dec:
                gan_gen_eq_mean(1.0)
            else:
                gan_gen_eq_mean(0.0)

            # BACKPROP
            # clean grads
            net.zero_grad()
            # encoder
            loss_encoder.backward(retain_graph=True)
            # some implementations clamp the grad here
            # [p.grad.data.clamp_(-1, 1) for p in net.encoder.parameters()]
            # update parameters
            optimizer_encoder.step()
            # clean the others, so they are not affected by the encoder loss
            net.zero_grad()
            # decoder
            if train_dec:
                loss_decoder.backward(retain_graph=True)
                # [p.grad.data.clamp_(-1, 1) for p in net.decoder.parameters()]
                optimizer_decoder.step()
                # clean the discriminator
                net.discriminator.zero_grad()
            # discriminator
            if train_dis:
                loss_discriminator.backward()
                # [p.grad.data.clamp_(-1, 1) for p in net.discriminator.parameters()]
                optimizer_discriminator.step()

            # LOGGING
            if slurm:
                progress.update(progress.value + 1, loss_nle=loss_nle_mean.measure,
                                loss_encoder=loss_encoder_mean.measure,
                                loss_decoder=loss_decoder_mean.measure,
                                loss_discriminator=loss_discriminator_mean.measure,
                                loss_mse_layer=loss_reconstruction_layer_mean.measure,
                                loss_kld=loss_kld_mean.measure,
                                loss_aux_classifier=loss_aux_classifier_mean.measure,
                                epoch=i + 1)
        # EPOCH END
        if slurm:
            progress.update(progress.value + 1, loss_nle=loss_nle_mean.measure,
                            loss_encoder=loss_encoder_mean.measure,
                            loss_decoder=loss_decoder_mean.measure,
                            loss_discriminator=loss_discriminator_mean.measure,
                            loss_mse_layer=loss_reconstruction_layer_mean.measure,
                            loss_kld=loss_kld_mean.measure,
                            loss_aux_classifier=loss_aux_classifier_mean.measure,
                            epoch=i + 1)
        lr_encoder.step()
        lr_decoder.step()
        lr_discriminator.step()
        margin *= decay_margin
        equilibrium *= decay_equilibrium
        # the margin cannot be higher than the equilibrium
        if margin > equilibrium:
            equilibrium = margin
        lambda_mse *= decay_mse
        if lambda_mse > 1:
            lambda_mse = 1
        progress.finish()

        writer.add_scalar('loss_encoder', loss_encoder_mean.measure, step_index)
        writer.add_scalar('loss_decoder', loss_decoder_mean.measure, step_index)
        writer.add_scalar('loss_discriminator', loss_discriminator_mean.measure, step_index)
        writer.add_scalar('loss_reconstruction', loss_nle_mean.measure, step_index)
        writer.add_scalar('loss_kld', loss_kld_mean.measure, step_index)
        writer.add_scalar('loss_aux_classifier', loss_aux_classifier_mean.measure, step_index)
        writer.add_scalar('gan_gen', gan_gen_eq_mean.measure, step_index)
        writer.add_scalar('gan_dis', gan_dis_eq_mean.measure, step_index)
        step_index += 1

    torch.save(net, model_path)
    exit(0)

--------------------------------------------------------------------------------