├── query ├── __init__.py └── query.py ├── data ├── __init__.py └── data.py ├── evals ├── __init__.py └── evals.py ├── train ├── __init__.py ├── train_acgan_full.py └── train_acgan_semi.py ├── utils ├── __init__.py ├── utils.py ├── logger.py └── drs.py ├── network ├── __init__.py ├── gan_loss.py ├── lenet.py └── acgan.py ├── README.md ├── rejection.ipynb └── main.py /query/__init__.py: -------------------------------------------------------------------------------- 1 | from .query import gold_acquistiion -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import get_transform, load_base_dataset, split_dataset 2 | 3 | -------------------------------------------------------------------------------- /evals/__init__.py: -------------------------------------------------------------------------------- 1 | from .evals import get_base_message, train_classifier, eval_classifier 2 | 3 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_acgan_full import train_acgan_full 2 | from .train_acgan_semi import train_acgan_semi 3 | 4 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .utils import make_z, make_y, make_fixed_z, make_fixed_y 3 | from .utils import count_classes, save_to_logger, normalize_info, to_numpy_image 4 | from .utils import gold_score, entropy, accuracy 5 | 6 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from .gan_loss import GANLoss 2 | from .acgan import ACGAN_Toy_Generator, ACGAN_Toy_Discriminator 3 | from .acgan import ACGAN_MNIST_Generator, ACGAN_MNIST_Discriminator 4 | from .acgan import ACGAN_CIFAR10_Generator, ACGAN_CIFAR10_Discriminator 5 | from .lenet import LeNet 6 | 7 | -------------------------------------------------------------------------------- /network/gan_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 2 | # https://github.com/caogang/wgan-gp/blob/master/gan_mnist.py 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class GANLoss(nn.Module): 8 | def __init__(self, target_real_label=1.0, target_fake_label=0.0, reduction='mean'): 9 | super(GANLoss, self).__init__() 10 | self.register_buffer('real_label', torch.tensor(target_real_label)) 11 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 12 | self.loss = nn.BCEWithLogitsLoss(reduction=reduction) 13 | 14 | def get_target_tensor(self, prediction, target_is_real): 15 | if target_is_real: 16 | target_tensor = self.real_label 17 | else: 18 | target_tensor = self.fake_label 19 | return target_tensor.expand_as(prediction) 20 | 21 | def __call__(self, prediction, target_is_real): 22 | target_tensor = self.get_target_tensor(prediction, target_is_real) 23 | loss = self.loss(prediction, target_tensor) 24 | return loss 25 | 26 | 27 | -------------------------------------------------------------------------------- /network/lenet.py: 
-------------------------------------------------------------------------------- 1 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class LeNet_32(nn.Module): 8 | def __init__(self, n_channels=3, n_classes=10): 9 | super().__init__() 10 | self.conv1 = nn.Conv2d(n_channels, 6, 5) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16*5*5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, n_classes) 15 | 16 | def forward(self, x): 17 | out = F.relu(self.conv1(x)) 18 | out = F.max_pool2d(out, 2) 19 | out = F.relu(self.conv2(out)) 20 | out = F.max_pool2d(out, 2) 21 | out = out.view(out.size(0), -1) 22 | out = F.relu(self.fc1(out)) 23 | out = F.relu(self.fc2(out)) 24 | out = self.fc3(out) 25 | return out 26 | 27 | 28 | # Create LeNet classifier 29 | def LeNet(image_size, n_channels=3, n_classes=10): 30 | if image_size == 32: 31 | return LeNet_32(n_channels, n_classes) 32 | else: 33 | raise NotImplementedError 34 | 35 | -------------------------------------------------------------------------------- /query/query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data_utils 3 | from utils import entropy 4 | 5 | 6 | 7 | def gold_acquistiion(pool, netD, args, device): 8 | def gold_score_unlabel(x): 9 | out_D, out_C = netD(x) # B x 1, B x nc 10 | score_C = entropy(out_C) # B 11 | return out_D.view(-1) + score_C # B 12 | 13 | query_idx = score_based_acquisition(pool, gold_score_unlabel, args, device) 14 | return query_idx 15 | 16 | 17 | def score_based_acquisition(pool, score_func, args, device): 18 | loader = data_utils.DataLoader(pool, batch_size=args.pool_batch_size, shuffle=False) 19 | 20 | scores = [] 21 | for batch_idx, (real_x, _) in enumerate(loader): 22 | with torch.no_grad(): 23 | real_x = real_x.to(device) # B x nc x H x W 24 | score = score_func(real_x).cpu().numpy() # B 25 | for i in range(len(score)): 26 | idx = batch_idx * args.pool_batch_size + i # index in dataset 27 | scores.append((score[i], idx)) 28 | 29 | query_idx = [x[1] for x in sorted(scores)[-args.per_size:]] # take the per_size samples with the highest scores 30 | query_idx = [pool.indices[i] for i in query_idx] # pool idx -> base_dataset idx 31 | 32 | return query_idx 33 | 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mining GOLD Samples for Conditional GANs 2 | 3 | PyTorch implementation of ["Mining GOLD Samples for Conditional GANs"](https://arxiv.org/abs/1910.09170) (NeurIPS 2019). 4 | 5 | ## Run experiments 6 | 7 | Run example re-weighting experiments 8 | ``` 9 | python main.py --name reweight_base --dataset mnist --epochs 20 --mode acgan_semi 10 | python main.py --name reweight_gold --dataset mnist --epochs 20 --mode acgan_semi_gold 11 | ``` 12 | 13 | Run rejection sampling experiments 14 | ``` 15 | See rejection.ipynb 16 | ``` 17 | 18 | Run active learning experiments 19 | ``` 20 | python main.py --name active_base --dataset mnist --init_size 10 --per_size 2 --max_size 18 --mode acgan_semi --lambda_C_fake 0.01 --query_type random 21 | python main.py --name active_gold --dataset mnist --init_size 10 --per_size 2 --max_size 18 --mode acgan_semi --lambda_C_fake 0.01 --query_type gold 22 | ``` 23 | 24 | 25 | ## Citation 26 | If you use this code for your research, please cite our paper.
27 | ``` 28 | @inproceedings{ 29 | mo2019mining, 30 | title={Mining GOLD Samples for Conditional GANs}, 31 | author={Mo, Sangwoo and Kim, Chiheon and Kim, Sungwoong and Cho, Minsu and Shin, Jinwoo}, 32 | booktitle={Advances in Neural Information Processing Systems}, 33 | year={2019}, 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.utils.data as data_utils 8 | 9 | 10 | """"""""""""""" 11 | Generate noise 12 | """"""""""""""" 13 | 14 | def make_z(size, nz): 15 | """Return B x nz noise vector""" 16 | return torch.randn(size, nz) # B x nz 17 | 18 | def make_y(size, ny, value=None): 19 | """Return B condition vector""" 20 | if value is None: 21 | return torch.randint(ny, [size]).long() # B (random value) 22 | else: 23 | return torch.LongTensor(size).fill_(value) # B (given value) 24 | 25 | def make_fixed_z(size, nz, ny): 26 | """Return (B * ny) x nz noise vector (for visualization)""" 27 | z = make_z(size, nz) # B x nz 28 | return torch.cat([z] * ny, dim=0) # (B x ny) x nz 29 | 30 | def make_fixed_y(size, ny): 31 | """Return (B * ny) condition vector (for visualization)""" 32 | y = [torch.LongTensor(size).fill_(i) for i in range(ny)] # list of B tensors 33 | return torch.cat(y, dim=0) # (B * ny) 34 | 35 | 36 | """"""""""""""" 37 | Helper functions (I/O) 38 | """"""""""""""" 39 | 40 | def count_classes(dataset, class_num): 41 | count = [0] * class_num 42 | for _, y in dataset: 43 | count[y] += 1 44 | return count 45 | 46 | def save_to_logger(logger, info, step): 47 | for key, val in info.items(): 48 | if isinstance(val, np.ndarray): 49 | logger.image_summary(key, val, step) 50 | else: 51 | logger.scalar_summary(key, val, step) 52 | 53 | def normalize_info(info): 54 | num = info.pop('num') 55 | for key, val in info.items(): 56 | info[key] /= num 57 | return info 58 | 59 | def gold_score(netD, x, y, eps=1e-6): 60 | out_D, out_C = netD(x) # B x 1, B x nc 61 | out_C = torch.softmax(out_C, dim=1) # B x nc 62 | score_C = torch.log(out_C[torch.arange(len(out_C)), y] + eps) # B 63 | return out_D.view(-1) + score_C # B 64 | 65 | def entropy(outs, eps=0): 66 | probs = F.softmax(outs, dim=1) # B x nc 67 | entropy = -(probs * torch.log(probs + eps)).sum(-1) # B 68 | return entropy # B 69 | 70 | def accuracy(out, tgt): 71 | _, pred = out.max(1) 72 | acc = pred.eq(tgt).sum().item() / len(out) 73 | return acc 74 | 75 | def to_numpy_image(x): 76 | # convert torch tensor [-1,1] to numpy image [0,255] 77 | x = x.cpu().numpy().transpose(0, 2, 3, 1) # C x H x W -> H x W x C 78 | x = ((x + 1) / 2).clip(0, 1) # [-1,1] -> [0,1] 79 | x = (x * 255).astype(np.uint8) # uint8 numpy image 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | 6 | try: 7 | from StringIO import StringIO # Python 2.7 8 | except ImportError: 9 | from io import BytesIO # Python 3.x 10 | 11 | 12 | class Logger(object): 13 | 14 | def __init__(self, log_dir): 15 | """Create a summary writer logging to log_dir.""" 16 | self.writer = 
tf.summary.FileWriter(log_dir) 17 | 18 | def scalar_summary(self, tag, value, step): 19 | """Log a scalar variable.""" 20 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 21 | self.writer.add_summary(summary, step) 22 | 23 | def image_summary(self, tag, images, step): 24 | """Log a list of images.""" 25 | 26 | img_summaries = [] 27 | for i, img in enumerate(images): 28 | # Write the image to a string 29 | try: 30 | s = StringIO() 31 | except: 32 | s = BytesIO() 33 | 34 | # convert B/W image to RGB image 35 | if img.shape[-1] == 1: 36 | img = np.concatenate((img,) * 3, axis=-1) 37 | 38 | scipy.misc.toimage(img).save(s, format="png") 39 | 40 | # Create an Image object 41 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 42 | height=img.shape[0], 43 | width=img.shape[1]) 44 | # Create a Summary value 45 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 46 | 47 | # Create and write Summary 48 | summary = tf.Summary(value=img_summaries) 49 | self.writer.add_summary(summary, step) 50 | 51 | def histo_summary(self, tag, values, step, bins=1000): 52 | """Log a histogram of the tensor of values.""" 53 | 54 | # Create a histogram using numpy 55 | counts, bin_edges = np.histogram(values, bins=bins) 56 | 57 | # Fill the fields of the histogram proto 58 | hist = tf.HistogramProto() 59 | hist.min = float(np.min(values)) 60 | hist.max = float(np.max(values)) 61 | hist.num = int(np.prod(values.shape)) 62 | hist.sum = float(np.sum(values)) 63 | hist.sum_squares = float(np.sum(values ** 2)) 64 | 65 | # Drop the start of the first bin 66 | bin_edges = bin_edges[1:] 67 | 68 | # Add bin edges and counts 69 | for edge in bin_edges: 70 | hist.bucket_limit.append(edge) 71 | for c in counts: 72 | hist.bucket.append(c) 73 | 74 | # Create and write Summary 75 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 76 | self.writer.add_summary(summary, step) 77 | self.writer.flush() -------------------------------------------------------------------------------- /train/train_acgan_full.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as data_utils 4 | 5 | import network 6 | from utils import make_z, make_y 7 | from utils import gold_score, normalize_info 8 | 9 | 10 | 11 | def train_acgan_full(trainset, model, args, device, use_gold=False): 12 | # preprocess dataset 13 | if len(trainset) < args.per_epoch: 14 | n_iter = args.per_epoch // len(trainset) 15 | trainset = data_utils.ConcatDataset([trainset] * n_iter) 16 | loader = data_utils.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=8) 17 | 18 | # preprocess model 19 | netG = model['net_G'] 20 | netD = model['net_D'] 21 | optimizerG = model['optim_G'] 22 | optimizerD = model['optim_D'] 23 | 24 | # initialize criterion 25 | criterionGAN = network.GANLoss(reduction='none').to(device) 26 | criterionCE = nn.CrossEntropyLoss(reduction='none').to(device) 27 | 28 | # initialize loss info 29 | info = {'num': 0, 'loss_G': 0, 'loss_G_cls': 0, 'loss_D_real': 0, 'loss_D_fake': 0, 'loss_C_real': 0, 'loss_C_fake': 0} 30 | 31 | # train one epoch 32 | for i, (real_x, real_y) in enumerate(loader): 33 | # forward 34 | real_x = real_x.to(device) # B x nc x H x W 35 | real_y = real_y.to(device) # B 36 | fake_z = make_z(len(real_x), args.nz).to(device) # B x nz 37 | fake_y = make_y(len(real_x), args.ny).to(device) # B 38 | 39 | ######################### 40 | # (1) Update 
D network 41 | ######################### 42 | 43 | optimizerD.zero_grad() 44 | 45 | # real loss 46 | out_D, out_C = netD(real_x) # B x 1, B x nc 47 | loss_D_real = torch.mean(criterionGAN(out_D, True)) 48 | loss_C_real = torch.mean(criterionCE(out_C, real_y)) 49 | 50 | # fake loss 51 | fake_x = netG(fake_z, fake_y) # B x nc x H x W 52 | out_D, out_C = netD(fake_x.detach()) # B x 1, B x nc 53 | with torch.no_grad(): 54 | gold = gold_score(netD, fake_x, fake_y) 55 | 56 | if use_gold: 57 | weight = gold 58 | else: 59 | weight = torch.ones(len(gold)).to(device) 60 | 61 | loss_D_fake = torch.mean(criterionGAN(out_D, False) * weight) 62 | loss_C_fake = torch.mean(criterionCE(out_C, fake_y) * weight) * args.lambda_C_fake 63 | 64 | loss_D = loss_D_real + loss_D_fake + loss_C_real + loss_C_fake 65 | loss_D.backward() 66 | optimizerD.step() 67 | 68 | ######################### 69 | # (2) Update G network 70 | ######################### 71 | 72 | optimizerG.zero_grad() 73 | 74 | # GAN & classification loss 75 | fake_x = netG(fake_z, fake_y) # B x nc x H x W 76 | out_D, out_C = netD(fake_x) # B x 1, B x nc 77 | loss_G = torch.mean(criterionGAN(out_D, True)) 78 | loss_G_cls = torch.mean(criterionCE(out_C, fake_y)) 79 | 80 | # backward loss 81 | loss_G_total = loss_G + loss_G_cls 82 | loss_G_total.backward() 83 | optimizerG.step() 84 | 85 | # update loss info 86 | info['num'] += 1 87 | 88 | info['loss_G'] += loss_G.item() 89 | info['loss_G_cls'] += loss_G_cls.item() 90 | 91 | info['loss_D_real'] += loss_D_real.item() 92 | info['loss_D_fake'] += loss_D_fake.item() 93 | 94 | info['loss_C_real'] += loss_C_real.item() 95 | info['loss_C_fake'] += loss_C_fake.item() 96 | 97 | info = normalize_info(info) 98 | return info 99 | 100 | -------------------------------------------------------------------------------- /evals/evals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.utils.data as data_utils 5 | 6 | import network 7 | from utils import make_z, make_y 8 | from utils import accuracy, normalize_info 9 | 10 | 11 | """"""""""""""" 12 | Evaluation metrics 13 | """"""""""""""" 14 | 15 | def get_base_message(epoch, info): 16 | message = "Epoch: {}".format(epoch) 17 | message += " G: {:.4f}".format(info['loss_G']) 18 | message += " G (cls): {:.4f}".format(info['loss_G_cls']) 19 | message += " D (real): {:.4f}".format(info['loss_D_real']) 20 | message += " D (fake): {:.4f}".format(info['loss_D_fake']) 21 | message += " C (real): {:.4f}".format(info['loss_C_real']) 22 | message += " C (fake): {:.4f}".format(info['loss_C_fake']) 23 | return message 24 | 25 | 26 | def adjust_learning_rate(optimizer, epoch, base_lr, lr_decay_period=20, lr_decay_rate=0.1): 27 | lr = base_lr * (lr_decay_rate ** (epoch // lr_decay_period)) 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = lr 30 | 31 | 32 | def train_classifier(netG, args, device, testset=None): 33 | print('\nTraining a classifier') 34 | loader = data_utils.DataLoader(range(args.netC_per_epoch), batch_size=args.netC_batch_size, shuffle=True, num_workers=8) 35 | netC = network.LeNet(args.image_size, args.nc, args.ny).to(device) 36 | optimizerC = optim.Adam(netC.parameters(), lr=args.netC_lr, betas=(args.netC_beta1, args.netC_beta2)) 37 | criterionCE = nn.CrossEntropyLoss().to(device) 38 | 39 | for epoch in range(1, args.netC_epochs + 1): 40 | adjust_learning_rate(optimizerC, epoch, args.netC_lr, args.netC_lr_period) 41 | info = 
{'num': 0, 'loss_C': 0, 'acc': 0} 42 | 43 | # train network 44 | netC.train() 45 | for i, x in enumerate(loader): 46 | # forward 47 | fake_z = make_z(len(x), args.nz).to(device) # B x nz 48 | fake_y = make_y(len(x), args.ny).to(device) # B 49 | with torch.no_grad(): 50 | fake_x = netG(fake_z, fake_y) # B x nc x H x W 51 | out_fake = netC(fake_x) # B x nc 52 | loss_C = criterionCE(out_fake, fake_y) 53 | acc = accuracy(out_fake, fake_y) 54 | 55 | # backward 56 | optimizerC.zero_grad() 57 | loss_C.backward() 58 | optimizerC.step() 59 | 60 | # update loss info 61 | info['num'] += 1 62 | info['loss_C'] += loss_C.item() 63 | info['acc'] += acc 64 | 65 | # evaluate performance 66 | info = normalize_info(info) 67 | message = "Epoch: {} C: {:.4f} acc (train): {:.4f}".format(epoch, info['loss_C'], info['acc']) 68 | 69 | if testset and epoch % args.netC_eval_period == 0: 70 | test_acc = eval_classifier(netC, args, device, testset) 71 | message += " acc (test): {:.4f}".format(test_acc) 72 | 73 | print(message) 74 | print('') 75 | 76 | return netC 77 | 78 | 79 | def eval_classifier(netC, args, device, testset): 80 | loader = data_utils.DataLoader(testset, batch_size=args.netC_eval_batch_size, shuffle=False, num_workers=8) 81 | netC.eval() 82 | 83 | info = {'num': 0, 'acc': 0} # loss info 84 | for i, (real_x, real_y) in enumerate(loader): 85 | real_x = real_x.to(device) # B x nc x H x W 86 | real_y = real_y.to(device) # B 87 | with torch.no_grad(): 88 | out = netC(real_x) # B x nc 89 | _, pred = out.max(1) 90 | correct = pred.eq(real_y).sum().item() 91 | 92 | info['num'] += 1 93 | info['acc'] += correct / len(real_x) 94 | 95 | acc = info['acc'] / info['num'] 96 | return acc 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /rejection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "import math\n", 13 | "import numpy as np\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import torch\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "import torch.utils.data as data_utils\n", 19 | "from torchvision import datasets\n", 20 | "from torchvision import transforms as T\n", 21 | "\n", 22 | "%matplotlib inline\n", 23 | "import network\n", 24 | "from utils import make_z, make_y, make_fixed_z, make_fixed_y\n", 25 | "from utils.drs import drs, fitting_capacity" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "collapsed": true, 33 | "scrolled": false 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "nz, ny, nc = 100, 10, 1\n", 38 | "dataset = datasets.MNIST('./dataset/mnist', train=True, download=True)\n", 39 | "testset = datasets.MNIST('./dataset/mnist', train=False, download=True)\n", 40 | "testset.transform = T.Compose([\n", 41 | " T.Resize(32),\n", 42 | " T.ToTensor(),\n", 43 | " T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", 44 | "])\n", 45 | "\n", 46 | "netG = network.ACGAN_MNIST_Generator(nz, nc, ny).cuda()\n", 47 | "netD = network.ACGAN_MNIST_Discriminator(nc, ny, use_sn=True).cuda()\n", 48 | "netG.eval()\n", 49 | "netD.eval()\n", 50 | "\n", 51 | "def load_network(name, it=1):\n", 52 | " netG_path = './{}/netG_{}.pth'.format(name, it)\n", 53 | " netD_path = './{}/netD_{}.pth'.format(name, it)\n", 54 | " 
netG.load_state_dict(torch.load(netG_path))\n", 55 | " netD.load_state_dict(torch.load(netD_path))\n", 56 | "\n", 57 | "load_network('logs/reweight_base')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "num_samples = 5000 # sample per class\n", 67 | "z = make_z(num_samples * ny, nz).cuda()\n", 68 | "y = make_fixed_y(num_samples, ny).cuda()\n", 69 | "with torch.no_grad():\n", 70 | " x = netG(z, y)\n", 71 | "\n", 72 | "samples = data_utils.TensorDataset(x.cpu(), y.cpu())\n", 73 | "acc = fitting_capacity(samples, testset)\n", 74 | "print(acc)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "num_samples = 5000 # sample per class\n", 84 | "x = drs(netG, netD, num_samples, perc=10)\n", 85 | "samples = data_utils.TensorDataset(x.cpu(), y.cpu())\n", 86 | "acc = fitting_capacity(samples, testset)\n", 87 | "print(acc)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "collapsed": true 95 | }, 96 | "outputs": [], 97 | "source": [] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.6.3" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 2 121 | } 122 | -------------------------------------------------------------------------------- /train/train_acgan_semi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as data_utils 4 | from copy import deepcopy 5 | 6 | import network 7 | from utils import make_z, make_y 8 | from utils import gold_score, normalize_info 9 | 10 | 11 | 12 | def train_acgan_semi(trainset, pool, model, args, device=None, use_gold=False): 13 | # preprocess dataset (labels of pool = -1) 14 | if args.dataset != 'lsun': 15 | pool = deepcopy(pool) 16 | 17 | if args.dataset == 'synthetic': 18 | ones = torch.ones(len(pool.dataset)).long() 19 | pool.dataset.tensors = (pool.dataset.tensors[0], -ones) 20 | else: 21 | ones = torch.ones(len(pool.dataset)).long() 22 | pool.dataset.train_labels = -ones 23 | 24 | if args.dataset == 'cifar10': 25 | trainset = deepcopy(trainset) 26 | trainset.dataset.train_labels = torch.tensor(trainset.dataset.train_labels).long() 27 | 28 | dataset = data_utils.ConcatDataset([trainset, pool]) 29 | if len(dataset) < args.per_epoch: 30 | n_iter = args.per_epoch // len(dataset) 31 | dataset = data_utils.ConcatDataset([dataset] * n_iter) 32 | loader = data_utils.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8) 33 | 34 | # preprocess model 35 | netG = model['net_G'] 36 | netD = model['net_D'] 37 | optimizerG = model['optim_G'] 38 | optimizerD = model['optim_D'] 39 | 40 | # initialize criterion 41 | criterionGAN = network.GANLoss(reduction='none').to(device) 42 | criterionCE = nn.CrossEntropyLoss(reduction='none').to(device) 43 | 44 | # initialize loss info 45 | info = {'num': 0, 'loss_G': 0, 'loss_G_cls': 0, 'loss_D_real': 0, 'loss_D_fake': 0, 'loss_C_real': 0, 'loss_C_fake': 0} 46 | 47 | # train one epoch 48 | 
for i, (real_x, real_y) in enumerate(loader): 49 | idx_l = [i for i in range(len(real_x)) if real_y[i] != -1] 50 | 51 | # forward 52 | real_x = real_x.to(device) # B x nc x H x W 53 | real_y = real_y.to(device) # B 54 | fake_z = make_z(len(real_x), args.nz).to(device) # B x nz 55 | fake_y = make_y(len(real_x), args.ny).to(device) # B 56 | 57 | ######################### 58 | # (1) Update D network 59 | ######################### 60 | 61 | optimizerD.zero_grad() 62 | 63 | # real loss 64 | out_D, out_C = netD(real_x) # B x 1, B x nc 65 | loss_D_real = torch.mean(criterionGAN(out_D, True)) 66 | if len(idx_l) > 0: 67 | loss_C_real = torch.mean(criterionCE(out_C[idx_l], real_y[idx_l])) 68 | else: 69 | loss_C_real = 0 70 | 71 | # fake loss 72 | fake_x = netG(fake_z, fake_y) # B x nc x H x W 73 | out_D, out_C = netD(fake_x.detach()) # B x 1, B x nc 74 | with torch.no_grad(): 75 | gold = gold_score(netD, fake_x, fake_y) 76 | 77 | if use_gold: 78 | weight = gold 79 | else: 80 | weight = torch.ones(len(gold)).to(device) 81 | 82 | loss_D_fake = torch.mean(criterionGAN(out_D, False) * weight) 83 | if len(idx_l) > 0: 84 | loss_C_fake = torch.mean(criterionCE(out_C[idx_l], fake_y[idx_l]) * weight[idx_l]) * args.lambda_C_fake 85 | else: 86 | loss_C_fake = 0 87 | 88 | loss_D = loss_D_real + loss_D_fake + loss_C_real + loss_C_fake 89 | loss_D.backward() 90 | optimizerD.step() 91 | 92 | ######################### 93 | # (2) Update G network 94 | ######################### 95 | 96 | optimizerG.zero_grad() 97 | 98 | # GAN & classification loss 99 | fake_x = netG(fake_z, fake_y) # B x nc x H x W 100 | out_D, out_C = netD(fake_x) # B x 1, B x nc 101 | loss_G = torch.mean(criterionGAN(out_D, True)) 102 | loss_G_cls = torch.mean(criterionCE(out_C, fake_y)) 103 | 104 | # backward loss 105 | loss_G_total = loss_G + loss_G_cls 106 | loss_G_total.backward() 107 | optimizerG.step() 108 | 109 | # update loss info 110 | info['num'] += 1 111 | 112 | info['loss_G'] += loss_G.item() 113 | info['loss_G_cls'] += loss_G_cls.item() 114 | 115 | info['loss_D_real'] += loss_D_real.item() 116 | info['loss_D_fake'] += loss_D_fake.item() 117 | 118 | if len(idx_l) > 0: 119 | info['loss_C_real'] += loss_C_real.item() 120 | info['loss_C_fake'] += loss_C_fake.item() 121 | 122 | info = normalize_info(info) 123 | return info 124 | 125 | -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.datasets 3 | import torch 4 | import torch.utils.data as data_utils 5 | from torchvision import datasets 6 | from torchvision import transforms as T 7 | 8 | 9 | def get_transform(image_size, transform_type): 10 | if transform_type == 'none': 11 | return lambda x: x 12 | elif transform_type == 'base': 13 | return T.Compose([ 14 | T.Resize(image_size), 15 | T.CenterCrop(image_size), 16 | T.ToTensor(), 17 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 18 | ]) 19 | elif transform_type == 'random_crop': 20 | base_size = int(image_size * 1.1) 21 | return T.Compose([ 22 | T.Resize(base_size), 23 | T.RandomCrop(image_size), 24 | T.Resize(image_size), 25 | T.ToTensor(), 26 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 27 | ]) 28 | elif transform_type == 'random_crop_and_flip': 29 | base_size = int(image_size * 1.1) 30 | return T.Compose([ 31 | T.Resize(base_size), 32 | T.RandomCrop(image_size), 33 | T.RandomHorizontalFlip(), 34 | T.Resize(image_size), 35 | T.ToTensor(), 36 | T.Normalize((0.5, 0.5, 
0.5), (0.5, 0.5, 0.5)) 37 | ]) 38 | else: 39 | raise NotImplementedError 40 | 41 | 42 | def load_base_dataset(args): 43 | if args.dataset == 'synthetic': 44 | base_dataset, test_dataset = generate_synthetic_dataset(args) 45 | elif args.dataset == 'mnist': 46 | base_dataset = datasets.MNIST('./dataset/mnist', train=True, download=True) 47 | test_dataset = datasets.MNIST('./dataset/mnist', train=False, download=True) 48 | elif args.dataset == 'fmnist': 49 | base_dataset = datasets.FashionMNIST('./dataset/fmnist', train=True, download=True) 50 | test_dataset = datasets.FashionMNIST('./dataset/fmnist', train=False, download=True) 51 | elif args.dataset == 'svhn': 52 | base_dataset = datasets.SVHN('./dataset/svhn', split='train', download=True) 53 | test_dataset = datasets.SVHN('./dataset/svhn', split='test', download=True) 54 | elif args.dataset == 'cifar10': 55 | base_dataset = datasets.CIFAR10('./dataset/cifar10', train=True, download=True) 56 | test_dataset = datasets.CIFAR10('./dataset/cifar10', train=False, download=True) 57 | elif args.dataset == 'stl10': 58 | base_dataset = datasets.STL10('./dataset/stl10', split='train', download=True) 59 | test_dataset = datasets.STL10('./dataset/stl10', split='test', download=True) 60 | elif args.dataset == 'lsun': 61 | train_transform = get_transform(args.image_size, args.train_transform) 62 | test_transform = get_transform(args.image_size, args.test_transform) 63 | base_dataset = datasets.LSUN('./dataset/lsun', classes='val', transform=train_transform) 64 | test_dataset = datasets.LSUN('./dataset/lsun', classes='val', transform=test_transform) 65 | else: 66 | raise NotImplementedError 67 | 68 | return base_dataset, test_dataset 69 | 70 | 71 | def generate_synthetic_dataset(args): 72 | n_base = args.n_samples_base 73 | n_test = args.n_samples_test 74 | 75 | full_dataset = sklearn.datasets.make_blobs(n_base + n_test, cluster_std=0.5, centers=6) 76 | full_dataset = (full_dataset[0], full_dataset[1] % 2) 77 | 78 | xs = torch.FloatTensor(full_dataset[0][:n_base]) 79 | ys = torch.LongTensor(full_dataset[1][:n_base]) 80 | base_dataset = data_utils.TensorDataset(xs, ys) 81 | 82 | xs = torch.FloatTensor(full_dataset[0][n_base:]) 83 | ys = torch.LongTensor(full_dataset[1][n_base:]) 84 | test_dataset = data_utils.TensorDataset(xs, ys) 85 | 86 | return base_dataset, test_dataset 87 | 88 | 89 | def split_dataset(base_dataset, num_classes, init_size, val_size): 90 | shuffled_idx = np.random.permutation(len(base_dataset)) 91 | train_idx = pick_samples(base_dataset, num_classes, shuffled_idx, init_size) 92 | 93 | shuffled_idx = list(set(shuffled_idx) - set(train_idx)) 94 | val_idx = pick_samples(base_dataset, num_classes, shuffled_idx, val_size) 95 | 96 | pool_idx = list(set(shuffled_idx) - set(val_idx)) 97 | 98 | return train_idx, val_idx, pool_idx 99 | 100 | 101 | def pick_samples(base_dataset, num_classes, base_idx, size): 102 | sub_idx = [] 103 | for cls in range(num_classes): 104 | for idx in base_idx: 105 | if len(sub_idx) == (size // num_classes) * (cls + 1): 106 | break 107 | if base_dataset[idx][1] == cls: 108 | sub_idx.append(idx) 109 | return sub_idx 110 | 111 | 112 | -------------------------------------------------------------------------------- /utils/drs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | 
import network 11 | from utils import make_z, make_y 12 | from utils import normalize_info, accuracy 13 | 14 | 15 | 16 | def get_score_stats(netG, netD, sample=50000): 17 | score_D = sample_scores(netG, netD, wd=1, wc=0, sample_size=sample) 18 | score_C = sample_scores(netG, netD, wd=0, wc=1, sample_size=sample) 19 | 20 | M = np.exp(np.max(score_D)) 21 | w = np.std(score_D) / np.sqrt(np.mean(np.square(score_C))) 22 | 23 | return M, w 24 | 25 | 26 | def sample_scores(netG, netD, nz=100, ny=10, wd=1, wc=1, sample_size=50000, batch_size=100): 27 | scores = [] 28 | for i in range(sample_size // batch_size): 29 | z = make_z(batch_size, nz).cuda() 30 | y = make_y(batch_size, ny).cuda() 31 | with torch.no_grad(): 32 | x = netG(z, y) 33 | s = gold(netD, x, y, wd, wc) 34 | scores.append(s) 35 | scores = np.concatenate(scores, axis=0) 36 | return scores 37 | 38 | 39 | def gold(netD, x, y, wd=1, wc=1, verbose=False): 40 | with torch.no_grad(): 41 | out_D, out_C = netD(x) # B x 1, B x nc 42 | 43 | score_D = out_D.view(-1) * wd 44 | out_C = torch.softmax(out_C, dim=1) 45 | out_C = out_C[torch.arange(len(out_C)), y] 46 | score_C = torch.log(out_C) * wc 47 | 48 | if verbose: 49 | plt.hist(score_D.cpu().numpy()) 50 | plt.hist(score_C.cpu().numpy()) 51 | 52 | return (score_D + score_C).cpu().numpy() 53 | 54 | 55 | def drs(netG, netD, num_samples=10, perc=10, nz=100, ny=10, batch_size=100, eps=1e-6): 56 | M, w = get_score_stats(netG, netD) 57 | ones = np.ones(batch_size).astype('int64') 58 | 59 | images = [[] for _ in range(ny)] 60 | for cls in range(10): 61 | while len(images[cls]) < num_samples: 62 | z = make_z(batch_size, nz).cuda() 63 | y = make_y(batch_size, ny, cls).cuda() 64 | with torch.no_grad(): 65 | x = netG(z, y) 66 | r = np.exp(gold(netD, x, y, 1, w)) 67 | 68 | p = np.minimum(ones, r/M) 69 | f = np.log(p + eps) - np.log(1 - p + eps) # inverse sigmoid 70 | f = (f - np.percentile(f, perc)) 71 | p = [1 / (1 + math.exp(-x)) for x in f] # sigmoid 72 | accept = np.random.binomial(ones, p) 73 | 74 | for i in range(batch_size): 75 | if accept[i] and len(images[cls]) < num_samples: 76 | images[cls].append(x[i].detach().cpu()) 77 | 78 | images = torch.stack([x for l in images for x in l]) 79 | return images 80 | 81 | 82 | def adjust_learning_rate(optimizer, epoch, base_lr, lr_decay_period=20, lr_decay_rate=0.1): 83 | lr = base_lr * (lr_decay_rate ** (epoch // lr_decay_period)) 84 | for param_group in optimizer.param_groups: 85 | param_group['lr'] = lr 86 | 87 | 88 | def fitting_capacity(samples, testset, nc=1, ny=10, epochs=40, eval_period=10, verbose=False): 89 | netC = network.LeNet(32, nc, ny) 90 | netC = nn.DataParallel(netC, [0]).cuda() 91 | optimizerC = optim.Adam(netC.parameters(), lr=0.001, betas=(0.5, 0.999)) 92 | criterionCE = nn.CrossEntropyLoss() 93 | 94 | loader = data_utils.DataLoader(samples, batch_size=128, shuffle=True, num_workers=8) 95 | test_acc = 0 96 | for epoch in range(1, epochs + 1): 97 | adjust_learning_rate(optimizerC, epoch, 0.001, epochs//2) 98 | info = {'num': 0, 'loss_C': 0, 'acc': 0} 99 | 100 | # train network 101 | netC.train() 102 | for i, (x, y) in enumerate(loader): 103 | # forward 104 | x = x.cuda() 105 | y = y.cuda() 106 | out = netC(x) # B x nc 107 | loss_C = criterionCE(out, y) 108 | 109 | # backward 110 | optimizerC.zero_grad() 111 | loss_C.backward() 112 | optimizerC.step() 113 | 114 | # update loss info 115 | info['num'] += 1 116 | info['loss_C'] += loss_C.item() 117 | info['acc'] += accuracy(out, y) 118 | 119 | # evaluate performance 120 | info = 
normalize_info(info) 121 | message = "Epoch: {} C: {:.4f} acc (train): {:.4f}".format(epoch, info['loss_C'], info['acc']) 122 | if epoch % eval_period == 0: 123 | test_acc = eval_classifier(netC, testset) 124 | message += " acc (test): {:.4f}".format(test_acc) 125 | 126 | if verbose: 127 | print(message) 128 | 129 | return test_acc 130 | 131 | 132 | def eval_classifier(netC, testset): 133 | loader = data_utils.DataLoader(testset, batch_size=128, shuffle=False, num_workers=8) 134 | netC.eval() 135 | 136 | info = {'num': 0, 'acc': 0} # loss info 137 | for i, (x, y) in enumerate(loader): 138 | x = x.cuda() # B x nc x H x W 139 | y = y.cuda() # B 140 | with torch.no_grad(): 141 | pred = netC(x).max(1)[1] 142 | correct = pred.eq(y).sum().item() 143 | 144 | info['num'] += 1 145 | info['acc'] += correct / len(x) 146 | 147 | acc = info['acc'] / info['num'] 148 | return acc 149 | 150 | -------------------------------------------------------------------------------- /network/acgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_norm(use_sn): 6 | if use_sn: # spectral normalization 7 | return nn.utils.spectral_norm 8 | else: # identity mapping 9 | return lambda x: x 10 | 11 | # https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/utils.py 12 | def weights_init(net): 13 | for m in net.modules(): 14 | if isinstance(m, nn.Conv2d): 15 | m.weight.data.normal_(0, 0.02) 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.ConvTranspose2d): 18 | m.weight.data.normal_(0, 0.02) 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | m.weight.data.normal_(0, 0.02) 22 | m.bias.data.zero_() 23 | 24 | # https://github.com/gitlimlab/ACGAN-PyTorch/blob/master/utils.py 25 | def weights_init_3channel(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Conv') != -1: 28 | m.weight.data.normal_(0.0, 0.02) 29 | elif classname.find('BatchNorm') != -1: 30 | m.weight.data.normal_(1.0, 0.02) 31 | m.bias.data.fill_(0) 32 | 33 | def onehot(y, class_num): 34 | eye = torch.eye(class_num).type_as(y) # ny x ny 35 | onehot = eye[y.view(-1)].float() # B -> B x ny 36 | return onehot 37 | 38 | 39 | # https://github.com/caogang/wgan-gp/blob/master/gan_toy.py 40 | class ACGAN_Toy_Generator(nn.Module): 41 | def __init__(self, nz=2, nc=2, ny=2, dim=512): 42 | super().__init__() 43 | self.class_num = ny 44 | self.net = nn.Sequential( 45 | nn.Linear(nz + ny, dim), 46 | nn.ReLU(True), 47 | nn.Linear(dim, dim), 48 | nn.ReLU(True), 49 | nn.Linear(dim, dim), 50 | nn.ReLU(True), 51 | nn.Linear(dim, nc), 52 | ) 53 | weights_init(self) 54 | 55 | def forward(self, x, y): 56 | y = onehot(y, self.class_num) # B -> B x ny 57 | x = torch.cat([x, y], dim=1) # B x (nz + ny) 58 | return self.net(x) 59 | 60 | # https://github.com/caogang/wgan-gp/blob/master/gan_toy.py 61 | class ACGAN_Toy_Discriminator(nn.Module): 62 | def __init__(self, nc=2, ny=2, dim=512, use_sn=False): 63 | super().__init__() 64 | norm = get_norm(use_sn) 65 | self.net = nn.Sequential( 66 | nn.Linear(nc, dim), 67 | nn.ReLU(True), 68 | nn.Linear(dim, dim), 69 | nn.ReLU(True), 70 | nn.Linear(dim, dim), 71 | nn.ReLU(True), 72 | ) 73 | self.out_d = nn.Linear(dim, 1) 74 | self.out_c = nn.Linear(dim, ny) 75 | weights_init(self) 76 | 77 | def forward(self, x, y=None, get_feature=False): 78 | x = self.net(x) 79 | if get_feature: 80 | return x 81 | else: 82 | return self.out_d(x), self.out_c(x) 83 | 84 | 85 | # 
https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/ACGAN.py 86 | class ACGAN_MNIST_Generator(nn.Module): 87 | def __init__(self, nz=100, nc=1, ny=10, image_size=32): 88 | super().__init__() 89 | self.class_num = ny 90 | self.image_size = image_size 91 | self.fc = nn.Sequential( 92 | nn.Linear(nz + ny, 1024), 93 | nn.BatchNorm1d(1024), 94 | nn.ReLU(), 95 | nn.Linear(1024, 128 * (image_size // 4) * (image_size // 4)), 96 | nn.BatchNorm1d(128 * (image_size // 4) * (image_size // 4)), 97 | nn.ReLU(), 98 | ) 99 | self.deconv = nn.Sequential( 100 | nn.ConvTranspose2d(128, 64, 4, 2, 1), 101 | nn.BatchNorm2d(64), 102 | nn.ReLU(), 103 | nn.ConvTranspose2d(64, nc, 4, 2, 1), 104 | nn.Tanh(), 105 | ) 106 | weights_init(self) 107 | 108 | def forward(self, x, y): 109 | y = onehot(y, self.class_num) # B -> B x ny 110 | x = torch.cat([x, y], 1) 111 | x = self.fc(x) 112 | x = x.view(-1, 128, (self.image_size // 4), (self.image_size // 4)) 113 | x = self.deconv(x) 114 | return x 115 | 116 | # https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/ACGAN.py 117 | class ACGAN_MNIST_Discriminator(nn.Module): 118 | def __init__(self, nc=1, ny=10, image_size=32, use_sn=False): 119 | super().__init__() 120 | self.class_num = ny 121 | self.image_size = image_size 122 | norm = get_norm(use_sn) 123 | self.conv = nn.Sequential( 124 | norm(nn.Conv2d(nc, 64, 4, 2, 1)), # # use spectral norm 125 | nn.LeakyReLU(0.2), 126 | norm(nn.Conv2d(64, 128, 4, 2, 1)), # use spectral norm 127 | nn.BatchNorm2d(128), 128 | nn.LeakyReLU(0.2), 129 | ) 130 | self.fc = nn.Sequential( 131 | nn.Linear(128 * (image_size // 4) * (image_size // 4), 1024), 132 | nn.BatchNorm1d(1024), 133 | nn.LeakyReLU(0.2), 134 | ) 135 | self.out_d = nn.Linear(1024, 1) 136 | self.out_c = nn.Linear(1024, self.class_num) 137 | weights_init(self) 138 | 139 | def forward(self, x, y=None, get_feature=False): 140 | x = self.conv(x) 141 | x = x.view(-1, 128 * (self.image_size // 4) * (self.image_size // 4)) 142 | x = self.fc(x) 143 | if get_feature: 144 | return x 145 | else: 146 | return self.out_d(x), self.out_c(x) 147 | 148 | 149 | # https://github.com/gitlimlab/ACGAN-PyTorch/blob/master/network.py 150 | class ACGAN_CIFAR10_Generator(nn.Module): 151 | def __init__(self, nz=100, nc=3, ny=10): 152 | super().__init__() 153 | self.class_num = ny 154 | self.fc = nn.Linear(nz + ny, 384) 155 | self.tconv = nn.Sequential( 156 | # tconv1 157 | nn.ConvTranspose2d(384, 192, 4, 1, 0, bias=False), 158 | nn.BatchNorm2d(192), 159 | nn.ReLU(True), 160 | # tconv2 161 | nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False), 162 | nn.BatchNorm2d(96), 163 | nn.ReLU(True), 164 | # tconv3 165 | nn.ConvTranspose2d(96, 48, 4, 2, 1, bias=False), 166 | nn.BatchNorm2d(48), 167 | nn.ReLU(True), 168 | # tconv4 169 | nn.ConvTranspose2d(48, nc, 4, 2, 1, bias=False), 170 | nn.Tanh(), 171 | ) 172 | weights_init_3channel(self) 173 | 174 | def forward(self, x, y): 175 | y = onehot(y, self.class_num) # B -> B x ny 176 | x = torch.cat([x, y], dim=1) # B x (nz + ny) 177 | x = self.fc(x) 178 | x = x.view(-1, 384, 1, 1) 179 | x = self.tconv(x) 180 | return x 181 | 182 | # https://github.com/gitlimlab/ACGAN-PyTorch/blob/master/network.py 183 | class ACGAN_CIFAR10_Discriminator(nn.Module): 184 | def __init__(self, nc=3, ny=10, use_sn=False): 185 | super().__init__() 186 | norm = get_norm(use_sn) 187 | self.conv = nn.Sequential( 188 | # conv1 189 | norm(nn.Conv2d(nc, 16, 3, 2, 1, bias=False)), # use spectral norm 190 | nn.LeakyReLU(0.2, inplace=True), 191 | 
nn.Dropout(0.5, inplace=False), 192 | # conv2 193 | norm(nn.Conv2d(16, 32, 3, 1, 1, bias=False)), # use spectral norm 194 | nn.BatchNorm2d(32), 195 | nn.LeakyReLU(0.2, inplace=True), 196 | nn.Dropout(0.5, inplace=False), 197 | # conv3 198 | norm(nn.Conv2d(32, 64, 3, 2, 1, bias=False)), # use spectral norm 199 | nn.BatchNorm2d(64), 200 | nn.LeakyReLU(0.2, inplace=True), 201 | nn.Dropout(0.5, inplace=False), 202 | # conv4 203 | norm(nn.Conv2d(64, 128, 3, 1, 1, bias=False)), # use spectral norm 204 | nn.BatchNorm2d(128), 205 | nn.LeakyReLU(0.2, inplace=True), 206 | nn.Dropout(0.5, inplace=False), 207 | # conv5 208 | norm(nn.Conv2d(128, 256, 3, 2, 1, bias=False)), # use spectral norm 209 | nn.BatchNorm2d(256), 210 | nn.LeakyReLU(0.2, inplace=True), 211 | nn.Dropout(0.5, inplace=False), 212 | # conv6 213 | norm(nn.Conv2d(256, 512, 3, 1, 1, bias=False)), # use spectral norm 214 | nn.BatchNorm2d(512), 215 | nn.LeakyReLU(0.2, inplace=True), 216 | nn.Dropout(0.5, inplace=False), 217 | ) 218 | self.out_d = nn.Linear(4 * 4 * 512, 1) 219 | self.out_c = nn.Linear(4 * 4 * 512, ny) 220 | weights_init_3channel(self) 221 | 222 | def forward(self, x, y=None, get_feature=False): 223 | x = self.conv(x) 224 | x = x.view(-1, 4*4*512) 225 | if get_feature: 226 | return x 227 | else: 228 | return self.out_d(x), self.out_c(x) 229 | 230 | 231 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.utils.data as data_utils 9 | 10 | import data 11 | import network 12 | import train 13 | import evals 14 | import query 15 | from utils import * 16 | 17 | 18 | 19 | def add_parser(parser): 20 | # base arguments 21 | parser.add_argument('--name', type=str, default='temp', help='experiment name') 22 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') 23 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 24 | parser.add_argument('--dataset', type=str, default='mnist', help='dataset') 25 | parser.add_argument('--image_size', type=int, default=32, help='image size (default: 32)') 26 | parser.add_argument('--train_transform', type=str, default='random_crop', help='data augmentaion (training)') 27 | parser.add_argument('--test_transform', type=str, default='base', help='data augmentation (test)') 28 | parser.add_argument('--n_samples_base', type=int, default=1000, help='number of base dataset (default: 1000)') 29 | parser.add_argument('--n_samples_test', type=int, default=1000, help='number of test dataset (default: 1000)') 30 | parser.add_argument('--mode', type=str, default='acgan', help='train method (acgan|acgan_gold|etc.)') 31 | parser.add_argument('--network', type=str, default='acgan_sn', help='network architecture for GAN') 32 | parser.add_argument('--nz', type=int, default=100, help='dimension of noise vector (default: 100)') 33 | parser.add_argument('--ny', type=int, default=10, help='number of classes (default: 10)') 34 | parser.add_argument('--nc', type=int, default=1, help='number of channels of image (default: 1)') 35 | 36 | # training arguments 37 | parser.add_argument('--epochs', type=int, default=100, help='epochs (default: 100)') 38 | parser.add_argument('--per_epoch', type=int, default=10000, help='# of total training samples (default: 10000)') 39 | parser.add_argument('--batch_size', type=int, default=128, help='batch size (default: 128)') 40 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate (default: 0.0002)') 41 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam (default: 0.5)') 42 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for adam (default: 0.999)') 43 | parser.add_argument('--lambda_C_fake', type=float, default=0.1, help='weight for loss_C_fake (default: 0.1)') 44 | parser.add_argument('--use_final_model', action='store_true', help='use final instead of best (default: False)') 45 | parser.add_argument('--compare_metric', type=str, default='cap_val', help='compare models (default: cap_val)') 46 | 47 | # evaluation arguments 48 | parser.add_argument('--eval_period', type=int, default=10, help='evaluation period for heavy stuffs (default: 10)') 49 | parser.add_argument('--netC_network', type=str, default='lenet', help='network architecture for classifier') 50 | parser.add_argument('--netC_epochs', type=int, default=40, help='epochs for classifier (default: 40)') 51 | parser.add_argument('--netC_per_epoch', type=int, default=10000, help='per epoch for classifier (default: 10000)') 52 | parser.add_argument('--netC_batch_size', type=int, default=128, help='batch size for classifier (default: 128)') 53 | parser.add_argument('--netC_lr', type=float, default=0.001, help='learning rate for classifier (default: 0.001)') 54 | parser.add_argument('--netC_lr_period', type=int, default=20, help='lr decay period for classifier (default: 20)') 55 | parser.add_argument('--netC_beta1', type=float, default=0.5, help='beta1 for classifier (default: 0.5)') 56 | parser.add_argument('--netC_beta2', type=float, default=0.999, help='beta2 for classifier (default: 0.999)') 57 | parser.add_argument('--netC_eval_period', type=int, default=10, help='eval period ofr classifier (default: 10)') 58 | parser.add_argument('--netC_eval_batch_size', type=int, default=1000, help='eval batch size for classifier (default: 
1000)') 59 | 60 | # query arguments 61 | parser.add_argument('--init_size', type=int, default=None, help='size of initial training set (default: None)') 62 | parser.add_argument('--per_size', type=int, default=None, help='size of query for each acquisition (default: None)') 63 | parser.add_argument('--max_size', type=int, default=None, help='size of maximum training set (default: None)') 64 | parser.add_argument('--val_size', type=int, default=100, help='size of validation set (default: 100)') 65 | parser.add_argument('--query_type', type=str, default='random', help='acquisition algorithm (random|maxent|etc.)') 66 | parser.add_argument('--pool_batch_size', type=int, default=1000, help='batch size for query selection (default: 1000)') 67 | parser.add_argument('--reinit_type', type=str, default='cont_G', help='re-initialization for each query iteration') 68 | 69 | return parser 70 | 71 | 72 | class BaseModel(object): 73 | def __init__(self, args): 74 | self.args = args 75 | self.set_device() 76 | self.logger = Logger('./logs/{}'.format(self.args.name)) 77 | 78 | def set_device(self): 79 | str_ids = self.args.gpu_ids.split(',') 80 | self.gpu_ids = [] 81 | for str_id in str_ids: 82 | if int(str_id) >= 0: 83 | self.gpu_ids.append(int(str_id)) 84 | if len(self.gpu_ids) > 0: 85 | torch.cuda.set_device(self.gpu_ids[0]) 86 | self.device = torch.device('cuda:{}'.format(self.args.gpu_ids[0])) 87 | else: 88 | self.device = torch.device('cpu') 89 | 90 | """"""""""""""" 91 | Run model 92 | """"""""""""""" 93 | 94 | def run(self): 95 | # initialize setting 96 | self.init_data() 97 | self.init_model() 98 | 99 | # run experiment 100 | self.it = 0 # iteration (start from 0) 101 | while len(self.trainset) <= self.args.max_size: 102 | self.it += 1 # update iteration 103 | self.show_status() 104 | 105 | # train network 106 | if len(self.trainset) > 0: 107 | self.reinit_model() 108 | self.train() 109 | 110 | # add new query 111 | if len(self.trainset) < self.args.max_size: 112 | self.add_query() 113 | else: 114 | break 115 | 116 | """"""""""""""" 117 | Initialize model 118 | """"""""""""""" 119 | 120 | def init_data(self): 121 | print('Initialize dataset...') 122 | self.train_transform = data.get_transform(self.args.image_size, self.args.train_transform) 123 | self.test_transform = data.get_transform(self.args.image_size, self.args.test_transform) 124 | 125 | # load base dataset 126 | self.base_dataset, self.test_dataset = data.load_base_dataset(self.args) 127 | self.base_dataset.transform = self.train_transform 128 | self.test_dataset.transform = self.test_transform 129 | 130 | # split to train/val/pool set 131 | if self.args.init_size is None: 132 | self.train_idx = list(range(len(self.base_dataset))) 133 | self.val_idx = [] 134 | self.pool_idx = [] 135 | self.args.init_size = len(self.base_dataset) 136 | self.args.per_size = 0 137 | self.args.max_size = len(self.base_dataset) 138 | else: 139 | self.train_idx, self.val_idx, self.pool_idx = data.split_dataset( 140 | self.base_dataset, self.args.ny, self.args.init_size, self.args.val_size) 141 | 142 | if self.args.max_size is None: 143 | self.args.per_size = 0 144 | self.args.max_size = self.args.init_size 145 | 146 | # define trainset and pool 147 | self.trainset = data_utils.Subset(self.base_dataset, self.train_idx) 148 | self.valset = data_utils.Subset(self.base_dataset, self.val_idx) 149 | self.pool = data_utils.Subset(self.base_dataset, self.pool_idx) 150 | 151 | def show_status(self): 152 | # update count 153 | if self.it == 1: 154 | self.count = 
count_classes(self.trainset, self.args.ny) 155 | self.prev_count = self.count 156 | self.count = count_classes(self.trainset, self.args.ny) 157 | 158 | message = '\n# of classes:' 159 | for i in range(self.args.ny): 160 | message += ' {}: {} (+{})'.format(i, self.count[i], self.count[i] - self.prev_count[i]) 161 | message += ' sum: {}'.format(sum(self.count)) 162 | print(message) 163 | 164 | def init_model(self): 165 | print('Initialize networks...') 166 | self._init_model_G() 167 | self._init_model_D() 168 | self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 169 | self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 170 | 171 | def reinit_model(self): 172 | if self.args.reinit_type == 'random': 173 | self._init_model_G() 174 | self._init_model_D() 175 | self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 176 | self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 177 | elif self.args.reinit_type == 'cont_G': 178 | self._init_model_D() 179 | self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 180 | self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.args.lr, betas=(self.args.beta1, self.args.beta2)) 181 | else: 182 | raise NotImplementedError 183 | 184 | def _init_model_G(self): 185 | if self.args.dataset in ['synthetic']: 186 | self.netG = network.ACGAN_Toy_Generator(self.args.nz, self.args.nc, self.args.ny).to(self.device) 187 | elif self.args.dataset in ['mnist', 'fmnist']: 188 | self.netG = network.ACGAN_MNIST_Generator(self.args.nz, self.args.nc, self.args.ny).to(self.device) 189 | else: 190 | self.netG = network.ACGAN_CIFAR10_Generator(self.args.nz, self.args.nc, self.args.ny).to(self.device) 191 | 192 | def _init_model_D(self): 193 | if self.args.dataset in ['synthetic']: 194 | self.netD = network.ACGAN_Toy_Discriminator(self.args.nc, self.args.ny, use_sn=True).to(self.device) 195 | elif self.args.dataset in ['mnist', 'fmnist']: 196 | self.netD = network.ACGAN_MNIST_Discriminator(self.args.nc, self.args.ny, use_sn=True).to(self.device) 197 | else: 198 | self.netD = network.ACGAN_CIFAR10_Discriminator(self.args.nc, self.args.ny, use_sn=True).to(self.device) 199 | 200 | """"""""""""""" 201 | Main functions 202 | """"""""""""""" 203 | 204 | def train(self): 205 | print('Train networks...') 206 | 207 | # train GAN networks 208 | best_epoch = 0 209 | best_score = 0 210 | metric = self.args.compare_metric 211 | for epoch in range(1, self.args.epochs + 1): 212 | info = self._train_sub(epoch) 213 | info = self._eval_sub(epoch, info) 214 | 215 | # save the best model 216 | if metric in info and info[metric] >= best_score: 217 | netG_best = self.netG.state_dict() 218 | netD_best = self.netD.state_dict() 219 | best_epoch = epoch 220 | best_score = info[metric] 221 | 222 | # load the best model 223 | if not self.args.use_final_model and best_epoch > 0: 224 | print('\nUse the best networks (epoch: {}, score: {:.3f})'.format(best_epoch, best_score)) 225 | self.netG.load_state_dict(netG_best) 226 | self.netD.load_state_dict(netD_best) 227 | 228 | # save network 229 | netG_path = './logs/{}/netG_{}.pth'.format(self.args.name, self.it) 230 | netD_path = './logs/{}/netD_{}.pth'.format(self.args.name, self.it) 231 | torch.save(self.netG.cpu().state_dict(), netG_path) 232 | 
torch.save(self.netD.cpu().state_dict(), netD_path) 233 | self.netG = self.netG.to(self.device) 234 | self.netD = self.netD.to(self.device) 235 | 236 | def _train_sub(self, epoch): 237 | self.start_time_train = time.time() 238 | self.base_dataset.transform = self.train_transform 239 | self.netG.train() 240 | self.netD.train() 241 | 242 | model = { 243 | 'net_G': self.netG, 244 | 'net_D': self.netD, 245 | 'optim_G': self.optimizerG, 246 | 'optim_D': self.optimizerD, 247 | } 248 | 249 | # train networks by 1 epoch 250 | if self.args.mode == 'acgan_semi': 251 | return train.train_acgan_semi(self.trainset, self.pool, model, self.args, self.device) 252 | elif self.args.mode == 'acgan_semi_gold': 253 | if epoch <= self.args.epochs // 2: 254 | return train.train_acgan_semi(self.trainset, self.pool, model, self.args, self.device) 255 | else: 256 | return train.train_acgan_semi(self.trainset, self.pool, model, self.args, self.device, use_gold=True) 257 | else: 258 | raise NotImplementedError 259 | 260 | def _eval_sub(self, epoch, info): 261 | self.start_time_eval = time.time() 262 | self.base_dataset.transform = self.test_transform 263 | self.netG.eval() 264 | self.netD.eval() 265 | 266 | # compute evaluation metrics 267 | message = evals.get_base_message(epoch, info) 268 | if epoch % self.args.eval_period == 0: 269 | self.netC = evals.train_classifier(self.netG, self.args, self.device, testset=self.test_dataset) 270 | 271 | cap_test = evals.eval_classifier(self.netC, self.args, self.device, testset=self.test_dataset) 272 | message += " cap (test): {:.4f}".format(cap_test) 273 | info['cap_test'] = cap_test 274 | 275 | if len(self.valset) > 0: 276 | cap_test = evals.eval_classifier(self.netC, self.args, self.device, testset=self.valset) 277 | message += " cap (val): {:.4f}".format(cap_test) 278 | info['cap_val'] = cap_test 279 | 280 | eval_time = int(time.time() - self.start_time_eval) 281 | train_time = int(time.time() - self.start_time_train) - eval_time 282 | message += ' {:d}s/{:d}s elapsed'.format(train_time, eval_time) 283 | 284 | print(message) 285 | step = self.args.epochs * (self.it - 1) + epoch 286 | save_to_logger(self.logger, info, step) 287 | 288 | return info 289 | 290 | def add_query(self): 291 | print('\nSelect queries... 
(query type: {})'.format(self.args.query_type)) 292 | start_time = time.time() 293 | self.netG.eval() 294 | self.netD.eval() 295 | 296 | # get query & update train/pool 297 | if self.args.query_type == 'random': 298 | query_idx = np.random.permutation(len(self.pool))[:self.args.per_size] 299 | elif self.args.query_type == 'gold': 300 | query_idx = query.gold_acquistiion(self.pool, self.netD, self.args, self.device) 301 | else: 302 | raise NotImplementedError 303 | 304 | self.train_idx = list(set(self.train_idx) | set(query_idx)) 305 | self.pool_idx = list(set(self.pool_idx) - set(query_idx)) 306 | 307 | self.trainset = data_utils.Subset(self.base_dataset, self.train_idx) 308 | self.pool = data_utils.Subset(self.base_dataset, self.pool_idx) 309 | 310 | # print computation time 311 | query_time = int(time.time() - start_time) 312 | print('{:d}s elapsed'.format(query_time)) 313 | 314 | 315 | def main(): 316 | # get arguments 317 | parser = argparse.ArgumentParser() 318 | parser = add_parser(parser) 319 | args = parser.parse_args() 320 | 321 | # set random seed 322 | random.seed(args.seed) 323 | np.random.seed(args.seed) 324 | torch.manual_seed(args.seed) 325 | torch.cuda.manual_seed_all(args.seed) 326 | torch.backends.cudnn.deterministic = True 327 | torch.backends.cudnn.benchmark = False 328 | 329 | # define model and run 330 | model = BaseModel(args) 331 | model.run() 332 | 333 | 334 | if __name__ == '__main__': 335 | main() 336 | 337 | 338 | --------------------------------------------------------------------------------
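For reference, the GOLD score that drives the re-weighting in train/, the rejection sampling in utils/drs.py, and the acquisition in query/ is computed by utils.gold_score as the raw discriminator output plus the log class-conditional probability of the conditioned label. The snippet below is a minimal, self-contained sketch of how that score could be used to rank samples drawn from a trained checkpoint; it is not part of the repository, and the checkpoint path (a run saved by the README's re-weighting command), the batch size of 64, and the top-16 cutoff are illustrative assumptions.

```
import torch

import network
from utils import make_z, make_y, gold_score

nz, ny, nc = 100, 10, 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# build the MNIST ACGAN pair defined in network/acgan.py
netG = network.ACGAN_MNIST_Generator(nz, nc, ny).to(device).eval()
netD = network.ACGAN_MNIST_Discriminator(nc, ny, use_sn=True).to(device).eval()

# load a checkpoint written by main.py (path and iteration index are assumed)
netG.load_state_dict(torch.load('./logs/reweight_gold/netG_1.pth', map_location=device))
netD.load_state_dict(torch.load('./logs/reweight_gold/netD_1.pth', map_location=device))

with torch.no_grad():
    z = make_z(64, nz).to(device)       # B x nz noise
    y = make_y(64, ny).to(device)       # B random class labels
    x = netG(z, y)                      # B x nc x 32 x 32 generated images
    scores = gold_score(netD, x, y)     # B GOLD scores: D output + log p_C(y|x)

# keep the 16 samples the discriminator pair rates highest
top_idx = scores.argsort(descending=True)[:16]
top_images, top_labels = x[top_idx], y[top_idx]
```

The same score with the classifier term replaced by the prediction entropy (since pool labels are unknown) is what query/query.py maximizes when selecting unlabeled samples for annotation.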