├── images
│   ├── table_svhn.jpg
│   ├── table_cifar.jpg
│   ├── generated_svhn.jpeg
│   └── generated_cifar.jpeg
├── train_svhn.py
├── train_cifar.py
├── train_mnist.py
├── README.md
├── .gitignore
├── mnist_models.py
├── cifar_models.py
├── svhn_models.py
├── data.py
└── sgan.py

/images/table_svhn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCI-ML-course-team/gan-manifold-regularization-PyTorch/HEAD/images/table_svhn.jpg
--------------------------------------------------------------------------------
/images/table_cifar.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCI-ML-course-team/gan-manifold-regularization-PyTorch/HEAD/images/table_cifar.jpg
--------------------------------------------------------------------------------
/images/generated_svhn.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCI-ML-course-team/gan-manifold-regularization-PyTorch/HEAD/images/generated_svhn.jpeg
--------------------------------------------------------------------------------
/images/generated_cifar.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCI-ML-course-team/gan-manifold-regularization-PyTorch/HEAD/images/generated_cifar.jpeg
--------------------------------------------------------------------------------
/train_svhn.py:
--------------------------------------------------------------------------------
from data import StandardImgData
from svhn_models import Discriminator, Generator, weights_init
from sgan import SGAN_Manifold_Reg

dataset = 'svhn'
num_classes = 10
latent_dim = 100
batch_size = 128
samples_per_class = 100

img_data = StandardImgData(samples_per_class, batch_size, dataset)

G = Generator(latent_dim).apply(weights_init)
D = Discriminator(num_classes).apply(weights_init)

data_loaders = img_data.get_dataloaders(dataset)
sgan = SGAN_Manifold_Reg(batch_size, latent_dim, num_classes, G, D, data_loaders)
sgan.train(num_epochs=500)
# print(sgan.eval(test=True, epoch_idx=100))
--------------------------------------------------------------------------------
/train_cifar.py:
--------------------------------------------------------------------------------
from data import StandardImgData
from cifar_models import Discriminator, Generator, weights_init
from sgan import SGAN_Manifold_Reg

dataset = 'cifar'
num_classes = 10
latent_dim = 100
batch_size = 128
samples_per_class = 400

img_data = StandardImgData(samples_per_class, batch_size, dataset)

G = Generator(latent_dim).apply(weights_init)
D = Discriminator(num_classes).apply(weights_init)

data_loaders = img_data.get_dataloaders(dataset)
sgan = SGAN_Manifold_Reg(batch_size, latent_dim, num_classes, G, D, data_loaders)
sgan.train(num_epochs=500)
# print(sgan.eval(test=True, epoch_idx=100))
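
# Sketch (our addition, not part of the original script): SGAN_Manifold_Reg
# hardcodes 'models/' for checkpoints and 'logs/' for TensorBoard output, and
# torch.save does not create missing directories, so the first checkpoint save
# fails if 'models/' is absent. A guard like the following, placed before
# sgan.train(...), is a safe addition:
#
#     import os
#     os.makedirs('models', exist_ok=True)
#     os.makedirs('logs', exist_ok=True)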
--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
from data import StandardImgData
from mnist_models import Discriminator, Generator, weights_init
from sgan import SGAN_Manifold_Reg

dataset = 'mnist'
num_classes = 10
latent_dim = 100
batch_size = 128
samples_per_class = 100

img_data = StandardImgData(samples_per_class, batch_size, dataset)

G = Generator(latent_dim).apply(weights_init)
D = Discriminator(num_classes).apply(weights_init)

data_loaders = img_data.get_dataloaders(dataset)
sgan = SGAN_Manifold_Reg(batch_size, latent_dim, num_classes, G, D, data_loaders)
sgan.train(num_epochs=200)
# print(sgan.eval(test=True, epoch_idx=100))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Semi-Supervised Learning With GANs: Revisiting Manifold Regularization

This repository contains a PyTorch implementation of the paper [Semi-Supervised Learning With GANs: Revisiting Manifold Regularization](https://arxiv.org/abs/1805.08957) by Bruno Lecouat, Chuan-Sheng Foo, Houssam Zenati, and Vijay Chandrasekhar.

The tables below show that manifold regularization increases classification accuracy on the CIFAR-10 and SVHN datasets, using 400 and 100 labeled examples per class, respectively. Due to compute constraints, training was stopped at epochs 254 and 105, respectively. This was enough to demonstrate the effect of manifold regularization; higher accuracies are possible with longer training (according to the original paper).

![table_cifar](images/table_cifar.jpg)
![table_svhn](images/table_svhn.jpg)

Examples of the generated images are shown below. As expected, they do not look as visually appealing as samples from conventional GANs: good semi-supervised classification performance requires a "bad" generator.

![generated_cifar](images/generated_cifar.jpeg)
![generated_svhn](images/generated_svhn.jpeg)
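
For reference (our summary of `sgan.py`, not an excerpt from the paper), the manifold regularization term perturbs the latent code with tiny Gaussian noise and penalizes the change in the discriminator's penultimate features:

$$\Omega = \frac{1}{B}\,\mathrm{MSE}\big(f(G(z)),\ f(G(z + \delta))\big), \qquad \delta = 10^{-5}\,\varepsilon,\quad \varepsilon \sim \mathcal{N}(0, I),$$

where `f` is the discriminator's feature extractor and `B` the batch size. The discriminator is trained on `L_supervised + 0.5 * L_unsupervised + 1e-3 * Ω`, while the generator is trained with feature matching.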
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

models/*
metrics/*
tmp/*
data/*
--------------------------------------------------------------------------------
/mnist_models.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class Discriminator(nn.Module):

    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 96, 3, stride=2, padding=1),
            nn.LeakyReLU(),

            nn.Dropout(.2),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(96, 192, 3, stride=2, padding=1),
            nn.LeakyReLU(),

            nn.Dropout(.2),
            nn.Conv2d(192, 192, 3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(192, 192, 1, stride=1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(192, 192, 1, stride=1, padding=0),
            nn.LeakyReLU(),

            nn.MaxPool2d(4, stride=1),
            Flatten()
        )

        self.fc = nn.Linear(192, num_classes)

    def forward(self, x):
        # Return both the penultimate features (used for feature matching and
        # manifold regularization in sgan.py) and the class logits.
        features = self.net(x)
        logits = self.fc(features)
        return features, logits


class Generator(nn.Module):

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.ReLU(),
            Reshape((512, 4, 4)),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class Reshape(nn.Module):
    def __init__(self, target_shape):
        super(Reshape, self).__init__()
        self.target_shape = (-1,) + target_shape

    def forward(self, x):
        return x.view(self.target_shape)


def weights_init(m):
    # Conv2d layers keep PyTorch's default initialization.
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=.0, std=.1)
        nn.init.constant_(m.bias, .0)

    if type(m) == nn.ConvTranspose2d:
        nn.init.normal_(m.weight, mean=0, std=.05)
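
# Shape smoke test (illustrative sketch, not part of the original module):
# the Generator maps a latent batch to 32x32 single-channel images and the
# Discriminator returns the (features, logits) pair consumed by sgan.py.
if __name__ == '__main__':
    import torch
    G = Generator(latent_dim=100).apply(weights_init)
    D = Discriminator(num_classes=10)
    z = torch.randn(4, 100)
    imgs = G(z)                 # expected: (4, 1, 32, 32)
    features, logits = D(imgs)  # expected: (4, 192), (4, 10)
    print(imgs.shape, features.shape, logits.shape)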
--------------------------------------------------------------------------------
/cifar_models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.utils import weight_norm


class Discriminator(nn.Module):

    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Dropout(.2),
            weight_norm(nn.Conv2d(3, 96, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(96, 96, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(96, 96, 3, stride=2, padding=1)),
            nn.LeakyReLU(),

            nn.Dropout(.5),
            weight_norm(nn.Conv2d(96, 192, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(192, 192, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(192, 192, 3, stride=2, padding=1)),
            nn.LeakyReLU(),

            nn.Dropout(.5),
            weight_norm(nn.Conv2d(192, 192, 3, stride=1, padding=0)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(192, 192, 1, stride=1, padding=0)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(192, 192, 1, stride=1, padding=0)),
            nn.LeakyReLU(),

            nn.AdaptiveAvgPool2d(1),
            Flatten()
        )

        self.fc = weight_norm(nn.Linear(192, num_classes))

    def forward(self, x):
        inter_layer = self.net(x)
        logits = self.fc(inter_layer)
        return inter_layer, logits


class Generator(nn.Module):

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.ReLU(),
            Reshape((512, 4, 4)),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            weight_norm(nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1)),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class Reshape(nn.Module):
    def __init__(self, target_shape):
        super(Reshape, self).__init__()
        self.target_shape = (-1,) + target_shape

    def forward(self, x):
        return x.view(self.target_shape)


def weights_init(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=.0, std=.05)
        nn.init.constant_(m.bias, .0)

    if type(m) == nn.ConvTranspose2d:
        nn.init.normal_(m.weight, mean=0, std=.05)
--------------------------------------------------------------------------------
/svhn_models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.utils import weight_norm


class Discriminator(nn.Module):

    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Dropout(.2),
            weight_norm(nn.Conv2d(3, 64, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(64, 64, 3, stride=2, padding=1)),
            nn.LeakyReLU(),

            nn.Dropout(.5),
            weight_norm(nn.Conv2d(64, 128, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(128, 128, 3, stride=2, padding=1)),
            nn.LeakyReLU(),

            nn.Dropout(.5),
            weight_norm(nn.Conv2d(128, 128, 3, stride=1, padding=0)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(128, 128, 1, stride=1, padding=0)),
            nn.LeakyReLU(),
            weight_norm(nn.Conv2d(128, 128, 1, stride=1, padding=0)),
            nn.LeakyReLU(),

            nn.AdaptiveAvgPool2d(1),
            Flatten()
        )

        self.fc = weight_norm(nn.Linear(128, num_classes))

    def forward(self, x):
        inter_layer = self.net(x)
        logits = self.fc(inter_layer)
        return inter_layer, logits


class Generator(nn.Module):

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.ReLU(),
            Reshape((512, 4, 4)),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            weight_norm(nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1)),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class Reshape(nn.Module):
    def __init__(self, target_shape):
        super(Reshape, self).__init__()
        self.target_shape = (-1,) + target_shape

    def forward(self, x):
        return x.view(self.target_shape)


def weights_init(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=.0, std=.05)
        nn.init.constant_(m.bias, .0)

    if type(m) == nn.ConvTranspose2d:
        nn.init.normal_(m.weight, mean=0, std=.05)
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch
from torchvision import datasets, transforms
from torch.utils import data


class StandardImgData():
    def __init__(self, samples_per_class, batch_size, dataset):

        self.root = 'data/%s/' % dataset
        self.dataset = dataset
        self.img_sz = 32
        self.samples_per_class = samples_per_class
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            transforms.Resize(self.img_sz),
            transforms.CenterCrop(self.img_sz),
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5], std=[.5]) if dataset == 'mnist' \
                else transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])

    def get_dataloaders(self, dataset):
        if dataset == 'mnist':
            full_dataset = datasets.MNIST(root=self.root, train=True, transform=self.transform,
                                          target_transform=None, download=True)
            test_dataset = datasets.MNIST(root=self.root, train=False, transform=self.transform,
                                          target_transform=None, download=True)
        elif dataset == 'cifar':
            full_dataset = datasets.CIFAR10(root=self.root, train=True, transform=self.transform,
                                            target_transform=None, download=True)
            test_dataset = datasets.CIFAR10(root=self.root, train=False, transform=self.transform,
                                            target_transform=None, download=True)
        elif dataset == 'svhn':
            full_dataset = datasets.SVHN(root=self.root, split='train', transform=self.transform,
                                         target_transform=None, download=True)
            test_dataset = datasets.SVHN(root=self.root, split='test', transform=self.transform,
                                         target_transform=None, download=True)
        else:
            raise ValueError('Unknown dataset: %s' % dataset)

        train_size = int(0.8 * len(full_dataset))
        valid_size = len(full_dataset) - train_size
        train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [train_size, valid_size])

        # The unlabeled loader sees the whole training split; the labeled loader
        # sees only samples_per_class examples per class, served endlessly.
        train_unl_dataset = train_dataset
        train_lb_dataset = self.__get_samples_per_class(train_dataset, self.samples_per_class)

        train_unl_dataloader = data.DataLoader(train_unl_dataset, batch_size=self.batch_size, shuffle=True)
        train_lb_dataloader = data.DataLoader(train_lb_dataset, batch_size=self.batch_size, shuffle=True)
        train_lb_dataloader = self.__create_infinite_dataloader(train_lb_dataloader)
        valid_dataloader = data.DataLoader(valid_dataset, batch_size=self.batch_size, shuffle=False)
        test_dataloader = data.DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        return train_unl_dataloader, train_lb_dataloader, valid_dataloader, test_dataloader

    @staticmethod
    def __get_samples_per_class(dataset, num_samples):
        # Collect the first num_samples indices of every class (iterates the
        # dataset once, so this is slow, but it only runs at setup time).
        labels = torch.tensor([y for x, y in dataset])
        indices = torch.arange(len(labels))
        indices = torch.cat([indices[labels == y][:num_samples] for y in torch.unique(labels)])
        dataset = data.Subset(dataset, indices)
        return dataset

    @staticmethod
    def __create_infinite_dataloader(dataloader):
        # Generator that restarts the dataloader whenever it is exhausted.
        data_iter = iter(dataloader)
        while True:
            try:
                yield next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
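
# Usage sketch (illustrative, not part of the original module): the labeled
# loader is an infinite generator consumed with next(), while the unlabeled
# loader drives the epoch, mirroring the training loop in sgan.py.
if __name__ == '__main__':
    img_data = StandardImgData(samples_per_class=100, batch_size=128, dataset='mnist')
    train_unl, train_lb, valid, test = img_data.get_dataloaders('mnist')
    for unl_x, _ in train_unl:
        lb_x, lb_y = next(train_lb)
        print(unl_x.shape, lb_x.shape, lb_y.shape)
        break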
--------------------------------------------------------------------------------
/sgan.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch

from tensorboardX import SummaryWriter


class SGAN_Manifold_Reg():
    def __init__(self, batch_size, latent_dim, num_classes, generator, discriminator, data_loaders):

        self.batch_size = batch_size
        self.batch_size_cuda = torch.tensor(self.batch_size).cuda()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.lr = 0.0003

        self.save_path = 'models/'
        self.log_dir = 'logs/'
        self.writer = SummaryWriter(self.log_dir)

        self.G = generator.cuda()
        self.D = discriminator.cuda()
        self.train_unl_loader, self.train_lb_loader, self.valid_loader, self.test_loader = data_loaders

        self.ce_criterion = nn.CrossEntropyLoss().cuda()
        self.mse = nn.MSELoss().cuda()

        self.train_step = 0
        self.test_step = 0

    def train(self, num_epochs):
        opt_G = torch.optim.Adam(self.G.parameters(), lr=self.lr)
        opt_D = torch.optim.Adam(self.D.parameters(), lr=self.lr)

        for epoch_idx in range(num_epochs):

            avg_G_loss = avg_D_loss = 0
            self.G.train()
            self.D.train()
            for unl_train_x, __ in self.train_unl_loader:
                lb_train_x, lb_train_y = next(self.train_lb_loader)
                unl_train_x = unl_train_x.cuda()
                lb_train_x = lb_train_x.cuda()
                lb_train_y = lb_train_y.cuda()

                z, z_perturbed = self.define_noise()

                # Train Discriminator
                opt_D.zero_grad()
                imgs_fake = self.G(z)
                imgs_fake_perturbed = self.G(z_perturbed)

                __, logits_lb = self.D(lb_train_x)
                features_fake, logits_fake = self.D(imgs_fake)
                features_fake_perturbed, __ = self.D(imgs_fake_perturbed)
                features_real, logits_unl = self.D(unl_train_x)

                # Unsupervised GAN loss in logsumexp form (Salimans et al., 2016):
                # push logsumexp(logits) up for real images and down for fakes.
                logits_sum_unl = torch.logsumexp(logits_unl, dim=1)
                logits_sum_fake = torch.logsumexp(logits_fake, dim=1)
                loss_unsupervised = torch.mean(F.softplus(logits_sum_unl)) - torch.mean(logits_sum_unl) + torch.mean(
                    F.softplus(logits_sum_fake))

                loss_supervised = torch.mean(self.ce_criterion(logits_lb, lb_train_y))
                # Manifold regularization: penalize the change in discriminator
                # features under a tiny perturbation of the latent code.
                loss_manifold_reg = self.mse(features_fake, features_fake_perturbed) \
                                    / self.batch_size_cuda

                loss_D = loss_supervised + .5 * loss_unsupervised + 1e-3 * loss_manifold_reg
                loss_D.backward()
                opt_D.step()
                avg_D_loss += loss_D.item()

                # Train Generator
                opt_G.zero_grad()
                opt_D.zero_grad()
                imgs_fake = self.G(z)
                features_fake, __ = self.D(imgs_fake)
                features_real, __ = self.D(unl_train_x)
                m1 = torch.mean(features_real, dim=0)
                m2 = torch.mean(features_fake, dim=0)
                loss_G = torch.mean((m1 - m2) ** 2)  # Feature matching
                loss_G.backward()
                opt_G.step()
                avg_G_loss += loss_G.item()

                self.writer.add_scalar('G_loss', loss_G, self.train_step)
                self.writer.add_scalar('D_loss', loss_D, self.train_step)
                self.train_step += 1

            # Evaluate
            avg_G_loss /= len(self.train_unl_loader)
            avg_D_loss /= len(self.train_unl_loader)

            acc, val_loss = self.eval()

            print('Epoch %d disc_loss %.3f gen_loss %.3f val_loss %.3f acc %.3f' % (
                epoch_idx, avg_D_loss, avg_G_loss, val_loss, acc))

            torch.save(self.D.state_dict(), self.save_path + 'disc_{}.pth'.format(epoch_idx))
            torch.save(self.G.state_dict(), self.save_path + 'gen_{}.pth'.format(epoch_idx))

        self.writer.close()

    def eval(self, test=False, epoch_idx=None):
        if test:
            self.D.load_state_dict(torch.load(self.save_path + 'disc_{}.pth'.format(epoch_idx)))
            eval_loader = self.test_loader
        else:
            eval_loader = self.valid_loader

        val_loss = corrects = total_samples = 0.0
        with torch.no_grad():
            self.D.eval()
            for x, y in eval_loader:
                x, y = x.cuda(), y.cuda()
                __, logits = self.D(x)
                loss = self.ce_criterion(logits, y)
                if not test:
                    self.writer.add_scalar('val_loss', loss, self.test_step)
                    self.test_step += 1
                val_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                corrects += torch.sum(preds == y).item()
                total_samples += len(y)

        val_loss /= len(eval_loader)
        acc = corrects / total_samples

        return acc, val_loss

    def define_noise(self):
        # Draw a latent batch and a copy shifted by tiny Gaussian noise; the
        # pair feeds the manifold regularization term.
        z = torch.randn(self.batch_size, self.latent_dim).cuda()
        z_perturbed = z + torch.randn(self.batch_size, self.latent_dim).cuda() * 1e-5
        return z, z_perturbed
--------------------------------------------------------------------------------
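
Note (appended sketch, not a file from the repository): evaluating a saved
checkpoint on the test set only needs the discriminator side of the model.
On a CUDA-capable machine, and assuming a CIFAR-10 run whose epoch-100 weights
were saved to models/disc_100.pth, something like the following reproduces the
commented-out call at the end of the train scripts:

    from data import StandardImgData
    from cifar_models import Discriminator, Generator, weights_init
    from sgan import SGAN_Manifold_Reg

    img_data = StandardImgData(samples_per_class=400, batch_size=128, dataset='cifar')
    loaders = img_data.get_dataloaders('cifar')
    sgan = SGAN_Manifold_Reg(128, 100, 10, Generator(100), Discriminator(10), loaders)
    acc, test_loss = sgan.eval(test=True, epoch_idx=100)
    print('test accuracy: %.3f' % acc)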