├── .gitattributes ├── README.md ├── image_demo ├── README.md ├── main.py ├── model.py ├── model_resnet.py ├── spectral_normalization.py └── spectral_normalization_nondiff.py ├── results ├── hq_results.png ├── image_results.png └── toy.png └── toy_demo ├── 1d_GMM_exp.ipynb ├── README.md ├── data_loader.py ├── loss.py ├── main.py └── training.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | *.ipynb linguist-vendored -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Transport (CT) 2 | 3 | This is the pytorch implementation of experiments from the paper [Exploiting Chain Rule and Bayes' Theorem to Compare Probability Distributions](https://arxiv.org/abs/2012.14100), accepted in NeurIPS 2021. 4 | 5 | ![Toy results](results/toy.png) 6 | ![Image results](results/image_results.png) 7 | ![HQ results](results/hq_results.png) 8 | 9 | ## Quick starts 10 | 11 | To add CT loss in your existing code base, below we provide the key component of code for simplest usage: 12 | ```python 13 | ######################## Inputs ###################### 14 | # x: data B x C x W x H; 15 | # y: generated samples B x C x W x H; 16 | # netN: navigator network d -> 1 17 | # netD: critic network C x W x H -> d 18 | # rho: balance coefficient of forward-backward, default = 0.5 19 | 20 | def ct_loss(x, y, netN, netD, rho): 21 | ######################## compute cost ###################### 22 | f_x = netD(x) # feature of x: B x d 23 | f_y = netD(y) # feature of y: B x d 24 | cost = torch.norm(f_x[:,None] - f_y, dim=-1).pow(2) # pairwise cost: B x B 25 | 26 | ######################## compute transport map ###################### 27 | mse_n = (f_x[:,None] - f_y).pow(2) # pairwise mse for navigator network: B x B x d 28 | d = netN(mse_n).squeeze().mul(-1) # navigator distance: B x B 29 | forward_map = torch.softmax(d, dim=1) # forward map is in y wise 30 | backward_map = torch.softmax(d, dim=0) # backward map is in x wise 31 | 32 | ######################## compute CT loss ###################### 33 | # element-wise product of cost and transport map 34 | ct = rho * (cost * forward_map).sum(1).mean() + (1-rho) * (cost * backward_map).sum(0).mean() 35 | return ct 36 | ``` 37 | For most existing code for GANs, the modification of the output dimension of discriminator network and adding a MLP as navigator are sufficient to make it ready to run. 38 | 39 | We also provide our implementation for experiment on toy data and image data for demo. 40 | Please refer to corresponding subfolder for more information. 41 | 42 | ### Requirements 43 | - pytorch >= 1.2.0 44 | - seaborn == 0.9.0 45 | - pandas 46 | - sklearn 47 | - Tensorboard (for visualization of toy data experiments, but optional) 48 | 49 | Specifically, we run our experiments with pytorch 1.6.0, CUDA 10.2, cuDNN 7.0, and we also test our code with pytorch 1.2.0 and 1.7.0 to ensure reproducibility. 50 | 51 | 52 | ## Citation 53 | 54 | If you find this code useful, please cite our paper in your work. 55 | 56 | ``` 57 | @article{zheng2020act, 58 | title={Exploiting Chain Rule and Bayes' Theorem to Compare Probability Distributions}, 59 | author={Zheng, Huangjie and Zhou, Mingyuan}, 60 | journal={arXiv preprint arXiv:2012.14100}, 61 | year={2020} 62 | } 63 | ``` -------------------------------------------------------------------------------- /image_demo/README.md: -------------------------------------------------------------------------------- 1 | # Demo code for conditional transport (CT) on image experiments 2 | 3 | ## Requirements 4 | - pytorch >= 1.2.0 5 | - numpy 6 | 7 | ## Implementation Details 8 | This code is built on the implementation from the [repo](https://github.com/christiancosgrove/pytorch-spectral-normalization-gan) that implements both DCGAN-like and ResNet GAN architectures. 9 | In addition, training with standard, Wasserstein, and hinge losses is possible. 10 | 11 | 12 | ## Example usage 13 | To run the DCGAN backbone: 14 | 15 | `$ python main.py --model dcgan --loss ct` 16 | 17 | or run the SNGAN backbone: 18 | 19 | `$ python main.py --model resnet --loss ct` 20 | 21 | Use --help for more options 22 | 23 | `$ python main.py --help` -------------------------------------------------------------------------------- /image_demo/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.optim.lr_scheduler import ExponentialLR 7 | from torchvision import datasets, transforms 8 | import model_resnet 9 | import model 10 | 11 | import numpy as np 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | import matplotlib.pyplot as plt 15 | import matplotlib.gridspec as gridspec 16 | import os 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batch_size', type=int, default=256) 21 | parser.add_argument('--disc_iters', type=int, default=1) 22 | parser.add_argument('--gen_iters', type=int, default=5) 23 | parser.add_argument('--epochs', type=int, default=5000) 24 | parser.add_argument('--dim', type=int, default=256) 25 | parser.add_argument('--lr', type=float, default=2e-4) 26 | parser.add_argument('--rho', type=float, default=0.5) 27 | parser.add_argument('--loss', type=str, default='hinge') 28 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoints') 29 | parser.add_argument('--out_dir', type=str, default='out') 30 | 31 | 32 | parser.add_argument('--model' , type=str, default='resnet') 33 | 34 | args = parser.parse_args() 35 | 36 | loader = torch.utils.data.DataLoader( 37 | datasets.CIFAR10('../data/', train=True, download=True, 38 | transform=transforms.Compose([ 39 | transforms.ToTensor(), 40 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])), 41 | batch_size=args.batch_size, shuffle=True, num_workers=8) 42 | 43 | Z_dim = 128 44 | 45 | # discriminator = torch.nn.DataParallel(Discriminator()).cuda() # TODO: try out multi-gpu training 46 | if args.model == 'resnet': 47 | discriminator = model_resnet.Discriminator(args.dim, loss_type=args.loss).cuda() 48 | discriminator = torch.nn.DataParallel(discriminator) 49 | generator = model_resnet.Generator(Z_dim).cuda() 50 | generator = torch.nn.DataParallel(generator) 51 | else: 52 | discriminator = model.Discriminator(args.dim, loss_type=args.loss).cuda() 53 | discriminator = torch.nn.DataParallel(discriminator) 54 | generator = model.Generator(Z_dim).cuda() 55 | generator = torch.nn.DataParallel(generator) 56 | navigator = model.Navigator(dim=args.dim).cuda() 57 | navigator = torch.nn.DataParallel(navigator) 58 | # because the spectral normalization module creates parameters that don't require gradients (u and v), we don't want to 59 | # optimize these using sgd. We only let the optimizer operate on parameters that _do_ require gradients 60 | optim_disc = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=args.lr, betas=(0.0,0.9)) 61 | optim_gen = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.0,0.9)) 62 | optim_nav = optim.Adam(navigator.parameters(), lr=args.lr, betas=(0.0,0.9)) 63 | 64 | # use an exponentially decaying learning rate 65 | scheduler_d = optim.lr_scheduler.ExponentialLR(optim_disc, gamma=0.99) 66 | scheduler_g = optim.lr_scheduler.ExponentialLR(optim_gen, gamma=0.99) 67 | scheduler_n = optim.lr_scheduler.ExponentialLR(optim_nav, gamma=0.99) 68 | 69 | def train(epoch): 70 | for batch_idx, (data, target) in enumerate(loader): 71 | if data.size()[0] != args.batch_size: 72 | continue 73 | data, target = (data.cuda()), (target.cuda()) 74 | 75 | # update discriminator 76 | for _ in range(args.disc_iters): 77 | z = (torch.randn(args.batch_size, Z_dim).cuda()) 78 | optim_disc.zero_grad() 79 | optim_gen.zero_grad() 80 | if args.loss == 'hinge': 81 | disc_loss = nn.ReLU()(1.0 - discriminator(data)).mean() + nn.ReLU()(1.0 + discriminator(generator(z))).mean() 82 | elif args.loss == 'wasserstein': 83 | disc_loss = -discriminator(data).mean() + discriminator(generator(z)).mean() 84 | elif args.loss == 'ct': 85 | feat = discriminator(generator(z)) 86 | feat_x = discriminator(data) 87 | mse_n = (feat_x[:,None] - feat).pow(2) 88 | cost = mse_n.sum(-1) 89 | d = navigator(mse_n).squeeze().mul(-1) 90 | m_forward = torch.softmax(d, dim=1) 91 | m_backward = torch.softmax(d, dim=0) 92 | disc_loss = - args.rho * (cost * m_forward).sum(1).mean() - (1-args.rho) * (cost * m_backward).sum(0).mean() 93 | else: 94 | disc_loss = nn.BCEWithLogitsLoss()(discriminator(data), (torch.ones(args.batch_size, 1).cuda())) + \ 95 | nn.BCEWithLogitsLoss()(discriminator(generator(z)), (torch.zeros(args.batch_size, 1).cuda())) 96 | disc_loss.backward() 97 | optim_disc.step() 98 | 99 | # update generator 100 | for _ in range(args.gen_iters): 101 | z = (torch.randn(args.batch_size, Z_dim).cuda()) 102 | optim_disc.zero_grad() 103 | optim_gen.zero_grad() 104 | if args.loss == 'hinge' or args.loss == 'wasserstein': 105 | gen_loss = -discriminator(generator(z)).mean() 106 | elif args.loss == 'ct': 107 | optim_nav.zero_grad() 108 | feat = discriminator(generator(z)) 109 | feat_x = discriminator(data) 110 | mse_n = (feat_x[:,None] - feat).pow(2) 111 | cost = mse_n.sum(-1) 112 | d = navigator(mse_n).squeeze().mul(-1) 113 | m_forward = torch.softmax(d, dim=1) 114 | m_backward = torch.softmax(d, dim=0) 115 | gen_loss = args.rho * (cost * m_forward).sum(1).mean() + (1-args.rho) * (cost * m_backward).sum(0).mean() 116 | else: 117 | gen_loss = nn.BCEWithLogitsLoss()(discriminator(generator(z)), (torch.ones(args.batch_size, 1).cuda())) 118 | gen_loss.backward() 119 | if args.loss == 'ct': 120 | optim_nav.step() 121 | optim_gen.step() 122 | 123 | if batch_idx % 100 == 0: 124 | print(f"Epoch {epoch}/{args.epochs}\tIt {batch_idx}/{len(loader)}\t" + 'disc loss', disc_loss.item(), 'gen loss', gen_loss.item()) 125 | 126 | scheduler_d.step() 127 | scheduler_g.step() 128 | scheduler_n.step() 129 | 130 | fixed_z = (torch.randn(args.batch_size, Z_dim).cuda()) 131 | def evaluate(epoch): 132 | 133 | samples = generator(fixed_z).cpu().data.numpy()[:64] 134 | 135 | 136 | fig = plt.figure(figsize=(8, 8)) 137 | gs = gridspec.GridSpec(8, 8) 138 | gs.update(wspace=0.05, hspace=0.05) 139 | 140 | for i, sample in enumerate(samples): 141 | ax = plt.subplot(gs[i]) 142 | plt.axis('off') 143 | ax.set_xticklabels([]) 144 | ax.set_yticklabels([]) 145 | ax.set_aspect('equal') 146 | plt.imshow(sample.transpose((1,2,0)) * 0.5 + 0.5) 147 | 148 | plt.savefig(os.path.join(args.out_dir,'{}.png'.format(str(epoch).zfill(3))), bbox_inches='tight') 149 | plt.close(fig) 150 | 151 | os.makedirs(args.checkpoint_dir, exist_ok=True) 152 | os.makedirs(args.out_dir, exist_ok=True) 153 | 154 | for epoch in range(args.epochs): 155 | train(epoch) 156 | evaluate(epoch) 157 | torch.save(discriminator.state_dict(), os.path.join(args.checkpoint_dir, 'disc_{}'.format(epoch))) 158 | torch.save(generator.state_dict(), os.path.join(args.checkpoint_dir, 'gen_{}'.format(epoch))) 159 | 160 | -------------------------------------------------------------------------------- /image_demo/model.py: -------------------------------------------------------------------------------- 1 | # DCGAN-like generator and discriminator 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from spectral_normalization import SpectralNorm 8 | 9 | channels = 3 10 | leak = 0.1 11 | w_g = 4 12 | 13 | class Generator(nn.Module): 14 | def __init__(self, z_dim): 15 | super(Generator, self).__init__() 16 | self.z_dim = z_dim 17 | 18 | self.model = nn.Sequential( 19 | nn.ConvTranspose2d(z_dim, 512, 4, stride=1), 20 | nn.BatchNorm2d(512), 21 | nn.ReLU(), 22 | nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1,1)), 23 | nn.BatchNorm2d(256), 24 | nn.ReLU(), 25 | nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1,1)), 26 | nn.BatchNorm2d(128), 27 | nn.ReLU(), 28 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1,1)), 29 | nn.BatchNorm2d(64), 30 | nn.ReLU(), 31 | nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1,1)), 32 | nn.Tanh()) 33 | 34 | def forward(self, z): 35 | return self.model(z.view(-1, self.z_dim, 1, 1)) 36 | 37 | class Discriminator(nn.Module): 38 | def __init__(self,dim=256, loss_type='ct'): 39 | super(Discriminator, self).__init__() 40 | self.loss_type = loss_type 41 | 42 | self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1,1))) 43 | 44 | self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1,1))) 45 | self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1,1))) 46 | self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1,1))) 47 | self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1,1))) 48 | self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1,1))) 49 | self.conv7 = SpectralNorm(nn.Conv2d(256, 512, 3, stride=1, padding=(1,1))) 50 | 51 | 52 | self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, dim)) 53 | 54 | def forward(self, x): 55 | bs = x.shape[0] 56 | m = x 57 | m = nn.LeakyReLU(leak)(self.conv1(m)) 58 | m = nn.LeakyReLU(leak)(self.conv2(m)) 59 | m = nn.LeakyReLU(leak)(self.conv3(m)) 60 | m = nn.LeakyReLU(leak)(self.conv4(m)) 61 | m = nn.LeakyReLU(leak)(self.conv5(m)) 62 | m = nn.LeakyReLU(leak)(self.conv6(m)) 63 | m = nn.LeakyReLU(leak)(self.conv7(m)) 64 | m = m.view(bs, -1) 65 | m = self.fc(m) 66 | if self.loss_type == 'ct': 67 | m = m/torch.sqrt(torch.sum(m.square(), dim=-1, keepdim=True)) 68 | return m 69 | 70 | class Navigator(nn.Module): 71 | def __init__(self, hidden=512,dim=256): 72 | super(Navigator, self).__init__() 73 | 74 | self.fc1 = nn.Linear(dim, hidden) 75 | self.fc2 = nn.Linear(hidden, hidden // 2) 76 | self.fc3 = nn.Linear(hidden // 2, 1) 77 | self.act = nn.LeakyReLU() 78 | self.norm1 = nn.BatchNorm1d(hidden) 79 | self.norm2 = nn.BatchNorm1d(hidden//2) 80 | 81 | def forward(self, x): 82 | x = self.fc1(x) 83 | x = self.act(x) 84 | x = self.fc2(x) 85 | x = self.act(x) 86 | 87 | return self.fc3(x) 88 | -------------------------------------------------------------------------------- /image_demo/model_resnet.py: -------------------------------------------------------------------------------- 1 | # ResNet generator and discriminator 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from spectral_normalization import SpectralNorm 7 | import numpy as np 8 | 9 | 10 | channels = 3 11 | 12 | class ResBlockGenerator(nn.Module): 13 | 14 | def __init__(self, in_channels, out_channels, stride=1): 15 | super(ResBlockGenerator, self).__init__() 16 | 17 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 18 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 19 | nn.init.xavier_uniform(self.conv1.weight.data, 1.) 20 | nn.init.xavier_uniform(self.conv2.weight.data, 1.) 21 | 22 | self.model = nn.Sequential( 23 | nn.BatchNorm2d(in_channels), 24 | nn.ReLU(), 25 | nn.Upsample(scale_factor=2), 26 | self.conv1, 27 | nn.BatchNorm2d(out_channels), 28 | nn.ReLU(), 29 | self.conv2 30 | ) 31 | self.bypass = nn.Sequential() 32 | if stride != 1: 33 | self.bypass = nn.Upsample(scale_factor=2) 34 | 35 | def forward(self, x): 36 | return self.model(x) + self.bypass(x) 37 | 38 | 39 | class ResBlockDiscriminator(nn.Module): 40 | 41 | def __init__(self, in_channels, out_channels, stride=1): 42 | super(ResBlockDiscriminator, self).__init__() 43 | 44 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 45 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 46 | nn.init.xavier_uniform(self.conv1.weight.data, 1.) 47 | nn.init.xavier_uniform(self.conv2.weight.data, 1.) 48 | 49 | if stride == 1: 50 | self.model = nn.Sequential( 51 | nn.ReLU(), 52 | SpectralNorm(self.conv1), 53 | nn.ReLU(), 54 | SpectralNorm(self.conv2) 55 | ) 56 | else: 57 | self.model = nn.Sequential( 58 | nn.ReLU(), 59 | SpectralNorm(self.conv1), 60 | nn.ReLU(), 61 | SpectralNorm(self.conv2), 62 | nn.AvgPool2d(2, stride=stride, padding=0) 63 | ) 64 | self.bypass = nn.Sequential() 65 | if stride != 1: 66 | 67 | self.bypass_conv = nn.Conv2d(in_channels,out_channels, 1, 1, padding=0) 68 | nn.init.xavier_uniform(self.bypass_conv.weight.data, np.sqrt(2)) 69 | 70 | self.bypass = nn.Sequential( 71 | SpectralNorm(self.bypass_conv), 72 | nn.AvgPool2d(2, stride=stride, padding=0) 73 | ) 74 | # if in_channels == out_channels: 75 | # self.bypass = nn.AvgPool2d(2, stride=stride, padding=0) 76 | # else: 77 | # self.bypass = nn.Sequential( 78 | # SpectralNorm(nn.Conv2d(in_channels,out_channels, 1, 1, padding=0)), 79 | # nn.AvgPool2d(2, stride=stride, padding=0) 80 | # ) 81 | 82 | 83 | def forward(self, x): 84 | return self.model(x) + self.bypass(x) 85 | 86 | # special ResBlock just for the first layer of the discriminator 87 | class FirstResBlockDiscriminator(nn.Module): 88 | 89 | def __init__(self, in_channels, out_channels, stride=1): 90 | super(FirstResBlockDiscriminator, self).__init__() 91 | 92 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 93 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 94 | self.bypass_conv = nn.Conv2d(in_channels, out_channels, 1, 1, padding=0) 95 | nn.init.xavier_uniform(self.conv1.weight.data, 1.) 96 | nn.init.xavier_uniform(self.conv2.weight.data, 1.) 97 | nn.init.xavier_uniform(self.bypass_conv.weight.data, np.sqrt(2)) 98 | 99 | # we don't want to apply ReLU activation to raw image before convolution transformation. 100 | self.model = nn.Sequential( 101 | SpectralNorm(self.conv1), 102 | nn.ReLU(), 103 | SpectralNorm(self.conv2), 104 | nn.AvgPool2d(2) 105 | ) 106 | self.bypass = nn.Sequential( 107 | nn.AvgPool2d(2), 108 | SpectralNorm(self.bypass_conv), 109 | ) 110 | 111 | def forward(self, x): 112 | return self.model(x) + self.bypass(x) 113 | 114 | GEN_SIZE=128 115 | DISC_SIZE=128 116 | 117 | class Generator(nn.Module): 118 | def __init__(self, z_dim): 119 | super(Generator, self).__init__() 120 | self.z_dim = z_dim 121 | 122 | self.dense = nn.Linear(self.z_dim, 4 * 4 * GEN_SIZE) 123 | self.final = nn.Conv2d(GEN_SIZE, channels, 3, stride=1, padding=1) 124 | nn.init.xavier_uniform(self.dense.weight.data, 1.) 125 | nn.init.xavier_uniform(self.final.weight.data, 1.) 126 | 127 | self.model = nn.Sequential( 128 | ResBlockGenerator(GEN_SIZE, GEN_SIZE, stride=2), 129 | ResBlockGenerator(GEN_SIZE, GEN_SIZE, stride=2), 130 | ResBlockGenerator(GEN_SIZE, GEN_SIZE, stride=2), 131 | nn.BatchNorm2d(GEN_SIZE), 132 | nn.ReLU(), 133 | self.final, 134 | nn.Tanh()) 135 | 136 | def forward(self, z): 137 | return self.model(self.dense(z).view(-1, GEN_SIZE, 4, 4)) 138 | 139 | class Discriminator(nn.Module): 140 | def __init__(self, dim, loss_type='ct'): 141 | super(Discriminator, self).__init__() 142 | self.loss_type = loss_type 143 | 144 | self.model = nn.Sequential( 145 | FirstResBlockDiscriminator(channels, DISC_SIZE, stride=2), 146 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE, stride=2), 147 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE), 148 | ResBlockDiscriminator(DISC_SIZE, DISC_SIZE), 149 | nn.ReLU(), 150 | nn.AvgPool2d(8), 151 | ) 152 | self.fc = nn.Linear(DISC_SIZE, dim) 153 | nn.init.xavier_uniform(self.fc.weight.data, 1.) 154 | self.fc = SpectralNorm(self.fc) 155 | 156 | def forward(self, x): 157 | x = self.model(x).view(-1,DISC_SIZE) 158 | x = self.fc(x) 159 | 160 | if self.loss_type == 'ct': 161 | x = x/torch.sqrt(torch.sum(x.square(), dim=-1, keepdim=True)) 162 | return x -------------------------------------------------------------------------------- /image_demo/spectral_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import Parameter 9 | 10 | def l2normalize(v, eps=1e-12): 11 | return v / (v.norm() + eps) 12 | 13 | 14 | class SpectralNorm(nn.Module): 15 | def __init__(self, module, name='weight', power_iterations=1): 16 | super(SpectralNorm, self).__init__() 17 | self.module = module 18 | self.name = name 19 | self.power_iterations = power_iterations 20 | if not self._made_params(): 21 | self._make_params() 22 | 23 | def _update_u_v(self): 24 | u = getattr(self.module, self.name + "_u") 25 | v = getattr(self.module, self.name + "_v") 26 | w = getattr(self.module, self.name + "_bar") 27 | 28 | height = w.data.shape[0] 29 | for _ in range(self.power_iterations): 30 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 31 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 32 | 33 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 34 | sigma = u.dot(w.view(height, -1).mv(v)) 35 | setattr(self.module, self.name, w / sigma.expand_as(w)) 36 | 37 | def _made_params(self): 38 | try: 39 | u = getattr(self.module, self.name + "_u") 40 | v = getattr(self.module, self.name + "_v") 41 | w = getattr(self.module, self.name + "_bar") 42 | return True 43 | except AttributeError: 44 | return False 45 | 46 | 47 | def _make_params(self): 48 | w = getattr(self.module, self.name) 49 | 50 | height = w.data.shape[0] 51 | width = w.view(height, -1).data.shape[1] 52 | 53 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 54 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 55 | u.data = l2normalize(u.data) 56 | v.data = l2normalize(v.data) 57 | w_bar = Parameter(w.data) 58 | 59 | del self.module._parameters[self.name] 60 | 61 | self.module.register_parameter(self.name + "_u", u) 62 | self.module.register_parameter(self.name + "_v", v) 63 | self.module.register_parameter(self.name + "_bar", w_bar) 64 | 65 | 66 | def forward(self, *args): 67 | self._update_u_v() 68 | return self.module.forward(*args) 69 | -------------------------------------------------------------------------------- /image_demo/spectral_normalization_nondiff.py: -------------------------------------------------------------------------------- 1 | # non-differentiable spectral normalization module 2 | # weight tensors are normalized directly 3 | import torch 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch import Tensor 10 | from torch.nn import Parameter 11 | 12 | def l2normalize(v, eps=1e-12): 13 | return v / (v.norm() + eps) 14 | 15 | class SpectralNorm(nn.Module): 16 | def __init__(self, module, name='weight', power_iterations=1): 17 | super(SpectralNorm, self).__init__() 18 | self.module = module 19 | self.name = name 20 | self.power_iterations = power_iterations 21 | 22 | def _update_u_v(self): 23 | if not self._made_params(): 24 | self._make_params() 25 | w = getattr(self.module, self.name) 26 | u = getattr(self.module, self.name + "_u") 27 | 28 | height = w.data.shape[0] 29 | for _ in range(self.power_iterations): 30 | v = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u)) 31 | u = l2normalize(torch.mv(w.view(height,-1).data, v)) 32 | 33 | setattr(self.module, self.name + "_u", u) 34 | w.data = w.data / torch.dot(u, torch.mv(w.view(height,-1).data, v)) 35 | 36 | def _made_params(self): 37 | try: 38 | u = getattr(self.module, self.name + "_u") 39 | return True 40 | except AttributeError: 41 | return False 42 | 43 | 44 | def _make_params(self): 45 | w = getattr(self.module, self.name) 46 | 47 | height = w.data.shape[0] 48 | width = w.view(height, -1).data.shape[1] 49 | 50 | u = l2normalize(w.data.new(height).normal_(0, 1)) 51 | 52 | self.module.register_buffer(self.name + "_u", u) 53 | 54 | 55 | def forward(self, *args): 56 | self._update_u_v() 57 | return self.module.forward(*args) -------------------------------------------------------------------------------- /results/hq_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JegZheng/CT-pytorch/6821edc93354ae4c533904548b79335b534a8c33/results/hq_results.png -------------------------------------------------------------------------------- /results/image_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JegZheng/CT-pytorch/6821edc93354ae4c533904548b79335b534a8c33/results/image_results.png -------------------------------------------------------------------------------- /results/toy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JegZheng/CT-pytorch/6821edc93354ae4c533904548b79335b534a8c33/results/toy.png -------------------------------------------------------------------------------- /toy_demo/README.md: -------------------------------------------------------------------------------- 1 | ## Demo code for conditional transport (CT) on toy experiments 2 | 3 | "1d_GMM_exp.ipynb": The 1d-GMM experiments 4 | Other files: Experiments on 2d toy datasets 5 | 6 | ### Requirements 7 | - pytorch >= 1.2.0 8 | - seaborn == 0.9.0 9 | - pandas 10 | - sklearn 11 | - Tensorboard (for visualization of toy data experiments, but optional) 12 | 13 | ### Example usage 14 | 15 | `$ python main.py --dataset 8gaussians --method CT` 16 | 17 | or run all methods on all toy datasets with 18 | 19 | `$ python main.py --run_all` 20 | 21 | Use --help for more options 22 | 23 | `$ python main.py --help` -------------------------------------------------------------------------------- /toy_demo/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from sklearn.datasets import make_swiss_roll, make_moons, make_circles, make_s_curve 5 | 6 | def load_data(name='swiss_roll', n_samples=500): 7 | N=n_samples 8 | if name == 'swiss_roll': 9 | temp=make_swiss_roll(n_samples=N, noise=0.05)[0][:,(0,2)] 10 | temp/=abs(temp).max() 11 | elif name == 'half_moons': 12 | temp=make_moons(n_samples=N, noise=0.02)[0] 13 | temp/=abs(temp).max() 14 | elif name == '2gaussians': 15 | scale = 2. 16 | centers = [ 17 | (1. / np.sqrt(2), 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2)) 18 | ] 19 | centers = [(scale * x, scale * y) for x, y in centers] 20 | temp = [] 21 | for i in range(N): 22 | point = np.random.randn(2) * .02 23 | center = centers[np.random.choice(np.arange(len(centers)))] 24 | point[0] += center[0] 25 | point[1] += center[1] 26 | temp.append(point) 27 | temp = np.array(temp, dtype='float32') 28 | temp /= 1.414 # stdev 29 | elif name == '8gaussians': 30 | scale = 2. 31 | centers = [ 32 | (1, 0), (-1, 0), (0, 1), (0, -1), 33 | (1. / np.sqrt(2), 1. / np.sqrt(2)), (1. / np.sqrt(2), -1. / np.sqrt(2)), 34 | (-1. / np.sqrt(2), 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2)) 35 | ] 36 | centers = [(scale * x, scale * y) for x, y in centers] 37 | temp = [] 38 | for i in range(N): 39 | point = np.random.randn(2) * .02 40 | center = centers[np.random.choice(np.arange(len(centers)))] 41 | point[0] += center[0] 42 | point[1] += center[1] 43 | temp.append(point) 44 | temp = np.array(temp, dtype='float32') 45 | temp /= 1.414 # stdev 46 | elif name == '25gaussians': 47 | temp = [] 48 | for i in range(int(N / 25)): 49 | for x in range(-2, 3): 50 | for y in range(-2, 3): 51 | point = np.random.randn(2) * 0.05 52 | point[0] += 2 * x 53 | point[1] += 2 * y 54 | temp.append(point) 55 | temp = np.array(temp, dtype='float32') 56 | np.random.shuffle(temp) 57 | temp /= 2.828 # stdev 58 | elif name == 'circle': 59 | temp,y=make_circles(n_samples=N, noise=0.05) 60 | temp=temp[np.argwhere(y==0).squeeze(),:] 61 | elif name == 's_curve': 62 | temp = make_s_curve(n_samples=500, noise=0.02)[0] 63 | temp = np.stack([temp[:,0], temp[:,2] ],axis=1) 64 | else: 65 | raise Exception("Dataset not found: name must be 'swiss_roll', 'half_moons', 'circle', 's_curve', '8gaussians' or '25gaussians'.") 66 | X=torch.from_numpy(temp).float() 67 | return X 68 | -------------------------------------------------------------------------------- /toy_demo/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.autograd as autograd 5 | 6 | # loss for Sliced Wasserstein Generator 7 | def wasserstein1d(x, y): 8 | n = x.size(0) 9 | x1, _ = torch.sort(x, dim=0) 10 | y1, _ = torch.sort(y, dim=0) 11 | z = (x1-y1) 12 | return (torch.norm(z,dim=0)**2).mean() 13 | 14 | 15 | # gradient penalty for WGAN-GP 16 | def calc_gradient_penalty(netD, real_data, fake_data): 17 | alpha = torch.rand(real_data.size(0), 1) 18 | alpha = alpha.expand(real_data.size()) 19 | alpha = alpha.to(real_data.device) 20 | 21 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 22 | 23 | 24 | interpolates = interpolates.to(real_data.device) 25 | interpolates = autograd.Variable(interpolates, requires_grad=True) 26 | 27 | disc_interpolates = netD(interpolates) 28 | 29 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 30 | grad_outputs=torch.ones(disc_interpolates.size()).to(real_data.device), 31 | create_graph=True, retain_graph=True, only_inputs=True)[0] 32 | 33 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * .1 34 | return gradient_penalty -------------------------------------------------------------------------------- /toy_demo/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | 6 | matplotlib.style.use('ggplot') 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.utils.data as data_utils 12 | import os 13 | from data_loader import load_data 14 | import datetime 15 | from torch.utils.tensorboard import SummaryWriter 16 | from training import * 17 | import seaborn as sns 18 | 19 | parser = argparse.ArgumentParser(description='Conditional transport experiment on toydata') 20 | # dataset options 21 | parser.add_argument('--batchsize', type=int, default=100, metavar='N', 22 | help='input batch size for training (default: 100)') 23 | parser.add_argument('--dataset', type=str, default="swiss_roll", metavar='D', 24 | help='Dataset: swiss_roll|half_moons|circle|s_curve|2gaussians|8gaussians|25gaussians') 25 | parser.add_argument('--toysize', type=int, default=2000, metavar='N', 26 | help='toy dataset size for training (default: 2000)') 27 | # training options 28 | parser.add_argument('--method', type=str, default="ACT", metavar='D', 29 | help='CT|CT_withD|GAN|SWD|MSWD') 30 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 31 | help='number of epochs to train (default: 10000)') 32 | parser.add_argument('--z_dim', type=int, default=50, metavar='N', 33 | help='dimensionality of z (default: 50)') 34 | parser.add_argument('--x_dim', type=int, default=2, metavar='N', 35 | help='dimensionality of x (default: 2)') 36 | parser.add_argument('--p_dim', type=int, default=1, metavar='N', 37 | help='dimensionality of projected x (default: 1)') 38 | parser.add_argument('--d_dim', type=int, default=20, metavar='N', 39 | help='dimensionality of feature x_feat (default: 20)') 40 | parser.add_argument('--learning-rate', type=float, default=2e-4, 41 | help='learning rate for Adam') 42 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 43 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 44 | parser.add_argument('--device', type=str, default="0", metavar='D', 45 | help='which device for training: 0, 1, 2, 3 (GPU) or cpu') 46 | parser.add_argument('--run_all', action='store_true', default=False, help='activate to run all methods on all datasets') 47 | 48 | # CT options 49 | parser.add_argument('--rho', type=float, default=0.5, 50 | help='balance coefficient for forward-backward (default: 0.5)') 51 | parser.add_argument('--use_cos_distance', action='store_true', default=False, help='activate to use cosine distance as feature cost') 52 | 53 | # SWD options 54 | parser.add_argument('--n_projections', type=int, default=1000, metavar='N', 55 | help='number of projections for input x (default: 1000)') 56 | 57 | # saving options 58 | parser.add_argument('--remark', type=str, default="experiment1", metavar='R', 59 | help='leave some remark for this experiment') 60 | parser.add_argument('--save_fig', action='store_true', default=False, help='activate to save sampled and reconstructed figures') 61 | 62 | args = parser.parse_args() 63 | device = 'cuda:' + args.device if torch.cuda.is_available() else 'cpu' 64 | 65 | # Generator architecture 66 | class Generator(torch.nn.Module): 67 | def __init__(self, D_in, H, D_out): 68 | super(Generator, self).__init__() 69 | self.model = nn.Sequential(nn.Linear(D_in, H), 70 | nn.BatchNorm1d(H), 71 | nn.LeakyReLU(), 72 | nn.Linear(H, H//2), 73 | nn.BatchNorm1d(H//2), 74 | nn.LeakyReLU(), 75 | torch.nn.Linear(H//2, D_out) 76 | ) 77 | 78 | def forward(self, x): 79 | mu = self.model(x) 80 | return mu 81 | 82 | 83 | # Navigator/Discriminator/Feature encoder for CT, GAN and WGAN 84 | class Projector(torch.nn.Module): 85 | def __init__(self, D_in, H, D_out=1, use_cos_distance=False): 86 | super(Projector, self).__init__() 87 | self.model = nn.Sequential(nn.Linear(D_in, H), 88 | nn.LeakyReLU(), 89 | nn.Linear(H, H//2), 90 | nn.LeakyReLU(), 91 | nn.Linear(H//2, D_out) 92 | ) 93 | self.use_cos_distance=use_cos_distance 94 | 95 | def forward(self, x): 96 | logit = self.model(x) 97 | if self.use_cos_distance: 98 | logit = F.normalize(logit, p=2, dim=-1) 99 | return logit 100 | 101 | 102 | def main(): 103 | print(args) 104 | # saving path 105 | name = args.method + '_' + args.remark + '_' + str(args.dataset) + '_'+ str(datetime.datetime.now()).replace(' ', '_') 106 | model_path = os.path.join('models', name) 107 | if not os.path.exists(model_path): 108 | os.makedirs(model_path) 109 | img_path = os.path.join('imgs', args.method + '_' + args.remark) 110 | if not os.path.exists(img_path): 111 | os.makedirs(img_path) 112 | 113 | X_data = load_data(name=args.dataset, n_samples=args.toysize) 114 | dataloader = data_utils.DataLoader(X_data, shuffle=True, 115 | batch_size=args.batchsize) 116 | # Tensorboard: optional for visualization 117 | writer = SummaryWriter(log_dir= os.path.join('runs', 'toy', name)) 118 | 119 | G = Generator(D_in=args.z_dim, H=100, D_out=args.x_dim).to(device) 120 | g_opt = optim.Adam(G.parameters(), lr=args.learning_rate) 121 | 122 | if args.method =='CT_withD': 123 | D = Projector(D_in=args.x_dim, H=100, D_out=args.d_dim, use_cos_distance=args.use_cos_distance).to(device) 124 | d_opt = optim.Adam(D.parameters(), lr=args.learning_rate) 125 | P = Projector(D_in=args.d_dim, H=100, D_out=args.p_dim).to(device) 126 | else: 127 | P = Projector(D_in=args.x_dim, H=100, D_out=args.p_dim).to(device) 128 | p_opt = optim.Adam(P.parameters(), lr=args.learning_rate/5) 129 | z_fix = torch.randn(X_data.size(0), args.z_dim, requires_grad=False).to(device) 130 | 131 | p_stats = [] 132 | g_stats = [] 133 | swd_stats = [] 134 | # training 135 | for epoch in range(args.epochs+1): 136 | if args.method == 'CT': 137 | ploss, gloss = train_ct(args, epoch, G, P, g_opt, p_opt, dataloader, device, writer) 138 | p_stats.append(ploss) 139 | g_stats.append(gloss) 140 | elif args.method == 'CT_withD': 141 | ploss, gloss = train_ct_withD(args, epoch, G, P, D, g_opt, p_opt, d_opt, dataloader, device, writer) 142 | p_stats.append(ploss) 143 | g_stats.append(gloss) 144 | elif args.method == 'GAN': 145 | d_criterion = nn.BCEWithLogitsLoss() 146 | ploss, gloss = train_GAN(args, epoch, G, P, g_opt, p_opt, d_criterion, dataloader, device, writer) 147 | p_stats.append(ploss) 148 | g_stats.append(gloss) 149 | elif args.method == 'WGANGP': 150 | ploss, gloss = train_WGANGP(args, epoch, G, P, g_opt, p_opt, dataloader, device, writer) 151 | p_stats.append(ploss) 152 | g_stats.append(gloss) 153 | elif args.method == 'SWD': 154 | gloss = train_SWD(args, epoch, G, g_opt, dataloader, device, writer) 155 | g_stats.append(gloss) 156 | else: 157 | raise Exception("Method not found: name must be 'GAN', 'SWD', 'WGANGP', 'CT', 'CT_withD'.") 158 | 159 | # test 160 | if epoch % 100 == 0: 161 | with torch.no_grad(): 162 | X = X_data.to(device) 163 | xpred = G(z_fix) 164 | # plot true data distribution 165 | if epoch == 0: 166 | fig, (ax) = plt.subplots(1,1,figsize=(6,6)) 167 | sns.kdeplot(X.detach().cpu().numpy()[:,0], X.detach().cpu().numpy()[:,1], ax=ax, cmap="Greens", shade=True, bw=0.1) 168 | writer.add_figure('data_distribution', fig, epoch) 169 | fig.savefig(os.path.join(img_path ,'{}_true.pdf'.format(args.dataset))) 170 | plt.close() 171 | # plot generated data distribution every 100 epochs 172 | if epoch % 100 == 0: 173 | fig, (ax) = plt.subplots(1,1,figsize=(6,6)) 174 | sns.kdeplot(xpred.detach().cpu().numpy()[:,0], xpred.detach().cpu().numpy()[:,1], ax=ax, cmap="Greens", shade=True, bw=0.1) 175 | writer.add_figure('test_distribution', fig, epoch) 176 | fig.savefig(os.path.join(img_path , '{}_fake_{}.pdf'.format(args.dataset,epoch))) 177 | plt.close() 178 | 179 | 180 | # save training checkpoints and status 181 | torch.save(G.state_dict(), model_path + name + '.G') 182 | torch.save(P.state_dict(), model_path + name + '.P') 183 | torch.save(g_stats, model_path + name + '.gstat') 184 | if not args.method =='SWD': 185 | torch.save(p_stats, model_path + name + '.pstat') 186 | 187 | 188 | 189 | if __name__ == '__main__': 190 | if args.run_all: 191 | methods = ['CT', 'CT_withD', 'GAN', 'WGANGP', 'SWD'] 192 | datasets = ['swiss_roll', 'half_moons', '8gaussians', '25gaussians'] 193 | for method in methods: 194 | for dataset in datasets: 195 | args.method = method 196 | args.dataset = dataset 197 | main() 198 | else: 199 | main() 200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /toy_demo/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from loss import * 4 | import numpy as np 5 | 6 | 7 | def train_ct(args, epoch, G, P, g_opt, p_opt, dataloader, device, writer): 8 | # initialization of navgator projection loss and generator loss 9 | n_loss = 0 10 | g_loss = 0 11 | for i, X in enumerate(dataloader): 12 | N = X.size(0) 13 | X = X.view(-1, args.x_dim).to(device) 14 | rho = args.rho 15 | # generate samples B x d 16 | z = torch.randn(N, args.z_dim).to(device) 17 | xpred = G(z) 18 | 19 | diff = (X[:,None]-xpred).pow(2) #pairwise mse for navigator network: B x B x h 20 | cost = diff.sum(-1) #pairwise cost: B x B 21 | tmp = P(diff).squeeze() # navigator distance: B x B 22 | m_backward = torch.nn.functional.softmax(tmp, dim=0) # backward map 23 | m_forward = torch.nn.functional.softmax(tmp, dim=1) # forward map 24 | 25 | p_opt.zero_grad() 26 | g_opt.zero_grad() 27 | gloss = (cost * m_forward).sum(1).mean() # forward transport 28 | nloss = (cost * m_backward).sum(0).mean() # backward transport 29 | loss = rho * gloss + (1-rho) * nloss 30 | loss.backward() 31 | g_opt.step() 32 | p_opt.step() 33 | g_loss += gloss.item() 34 | n_loss += nloss.item() 35 | 36 | n_iter = epoch * len(dataloader) + i 37 | writer.add_scalar('Loss/Generator', gloss.item(), n_iter) 38 | writer.add_scalar('Loss/Sorter', nloss.item(), n_iter) 39 | 40 | 41 | g_loss /= i+1 42 | n_loss /= i+1 43 | 44 | if epoch % 100 == 0: 45 | print("{} {} Epoch {}: \t nloss {} \t gloss {} \t ".format(args.method, args.dataset, epoch, nloss.item(), gloss.item())) 46 | return nloss.item(), gloss.item() 47 | 48 | def train_ct_withD(args, epoch, G, P, D, g_opt, p_opt, d_opt, dataloader, device, writer): 49 | n_loss = 0 50 | g_loss = 0 51 | for i, X in enumerate(dataloader): 52 | N = X.size(0) 53 | X = X.view(-1, args.x_dim).to(device) 54 | rho = args.rho 55 | 56 | z = torch.randn(N, args.z_dim).to(device) 57 | 58 | # ---------------------------------------- 59 | # Update Generator and Navigator network: minimize ct loss 60 | xpred = G(z) 61 | xpred_feat = D(xpred) # feature of generations: B x d 62 | x_feat = D(X) # feature of data: B x d 63 | 64 | cost = torch.norm(x_feat[:,None]-xpred_feat, dim=-1).pow(2) #pairwise cost: B x B 65 | diff = (x_feat[:,None]-xpred_feat).pow(2) #pairwise mse for navigator network: B x B x d 66 | tmp = P(diff).squeeze() # navigator distance: B x B 67 | weight_x = torch.nn.functional.softmax(tmp, dim=0) # backward map 68 | weight_xpred = torch.nn.functional.softmax(tmp, dim=1) # forward map 69 | 70 | g_opt.zero_grad() 71 | p_opt.zero_grad() 72 | gloss = (cost * weight_xpred).sum(1).mean() # forward transport 73 | nloss = (cost * weight_x).sum(0).mean() # backward transport 74 | loss = rho * gloss + (1-rho) * nloss 75 | loss.backward() 76 | g_opt.step() 77 | p_opt.step() 78 | 79 | n_loss += nloss.item() 80 | g_loss += gloss.item() 81 | 82 | 83 | # ---------------------------------------- 84 | # Update Critic network D: maximize ct loss 85 | xpred = G(z) 86 | xpred_feat = D(xpred) 87 | x_feat = D(X) 88 | 89 | cost = torch.norm(x_feat[:,None]-xpred_feat, dim=-1).pow(2) 90 | 91 | diff = (x_feat[:,None]-xpred_feat).pow(2) 92 | tmp = P(diff).squeeze() 93 | weight_x = torch.nn.functional.softmax(tmp, dim=0) # backward map 94 | weight_xpred = torch.nn.functional.softmax(tmp, dim=1) # forward map 95 | 96 | d_opt.zero_grad() 97 | dloss = -((1-rho)*(cost * weight_x).sum(0).mean() + rho*(cost * weight_xpred).sum(1).mean()) 98 | dloss.backward() 99 | d_opt.step() 100 | n_iter = epoch * len(dataloader) + i 101 | writer.add_scalar('Loss/Generator', gloss.item(), n_iter) 102 | writer.add_scalar('Loss/Sorter', nloss.item(), n_iter) 103 | 104 | 105 | g_loss /= i+1 106 | n_loss /= i+1 107 | 108 | if epoch % 100 == 0: 109 | print("{} {} Epoch {}: \t nloss {} \t gloss {} \t ".format(args.method, args.dataset, epoch, nloss.item(), gloss.item())) 110 | return nloss.item(), gloss.item() 111 | 112 | 113 | def train_GAN(args, epoch, G, D, g_opt, d_opt, d_criterion, dataloader, device, writer): 114 | d_loss = 0 115 | g_loss = 0 116 | for i, X in enumerate(dataloader): 117 | N = X.size(0) 118 | X = X.view(-1, args.x_dim).to(device) 119 | 120 | z = torch.randn(N, args.z_dim).to(device) 121 | xpred = G(z).view(-1, args.x_dim) 122 | xpred_1d = D(xpred) 123 | x_1d = D(X) 124 | 125 | g_opt.zero_grad() 126 | gloss = d_criterion(xpred_1d, torch.ones_like(xpred_1d)) 127 | gloss.backward() 128 | g_opt.step() 129 | g_loss += gloss.item() 130 | 131 | 132 | z = torch.randn(N, args.z_dim).to(device) 133 | xpred = G(z).view(-1, args.x_dim) 134 | xpred_1d = D(xpred) 135 | x_1d = D(X) 136 | dloss_fake = d_criterion(xpred_1d, torch.zeros_like(xpred_1d)) 137 | dloss_true = d_criterion(x_1d, torch.ones_like(x_1d)) 138 | dloss = dloss_fake + dloss_true 139 | d_opt.zero_grad() 140 | dloss.backward() 141 | d_opt.step() 142 | 143 | d_loss += dloss.item() 144 | n_iter = epoch * len(dataloader) + i 145 | writer.add_scalar('Loss/Generator', gloss.item(), n_iter) 146 | writer.add_scalar('Loss/Sorter', dloss.item(), n_iter) 147 | 148 | 149 | g_loss /= i+1 150 | d_loss /= i+1 151 | 152 | if epoch % 100 == 0: 153 | print("{} {} Epoch {}: \t dloss {} \t gloss {}".format(args.method, args.dataset, epoch, dloss.item(), gloss.item())) 154 | return dloss.item(), gloss.item() 155 | 156 | 157 | def train_SWD(args, epoch, G, g_opt, dataloader, device, writer): 158 | g_loss = 0 159 | for i, X in enumerate(dataloader): 160 | g_opt.zero_grad() 161 | N = X.size(0) 162 | X = X.view(-1, args.x_dim).to(device) 163 | z = torch.randn(N, args.z_dim).to(device) 164 | theta = torch.randn((args.x_dim, args.n_projections), 165 | requires_grad=False, 166 | device=device) 167 | theta = theta/torch.norm(theta, dim=0)[None, :] 168 | xpred = G(z).view(-1, args.x_dim) 169 | xpred_1d = xpred@theta 170 | x_1d = X@theta 171 | 172 | gloss = wasserstein1d(xpred_1d, x_1d) 173 | gloss.backward() 174 | g_opt.step() 175 | g_loss += gloss.item() 176 | 177 | n_iter = epoch * len(dataloader) + i 178 | writer.add_scalar('Loss/Generator', gloss.item(), n_iter) 179 | 180 | 181 | g_loss /= i+1 182 | 183 | if epoch % 100 == 0: 184 | print("{} {} Epoch {}: \t gloss {}".format(args.method, args.dataset, epoch, gloss.item())) 185 | return gloss.item() 186 | 187 | 188 | def train_WGANGP(args, epoch, G, D, g_opt, d_opt, dataloader, device, writer): 189 | d_loss = 0 190 | g_loss = 0 191 | one = torch.FloatTensor([1]).to(device) 192 | mone = one * -1 193 | 194 | 195 | for i, X in enumerate(dataloader): 196 | N = X.size(0) 197 | X = X.view(-1, args.x_dim).to(device) 198 | 199 | z = torch.randn(N, args.z_dim).to(device) 200 | xpred = G(z).view(-1, args.x_dim) 201 | 202 | 203 | xpred_1d = D(xpred) 204 | x_1d = D(X) 205 | 206 | 207 | d_opt.zero_grad() 208 | D_real = x_1d.mean() 209 | 210 | D_fake = xpred_1d.mean() 211 | gradient_penalty = calc_gradient_penalty(D, X, xpred) 212 | 213 | dloss = D_fake - D_real + gradient_penalty 214 | dloss.backward() 215 | d_opt.step() 216 | 217 | 218 | z = torch.randn(N, args.z_dim).to(device) 219 | xpred = G(z).view(-1, args.x_dim) 220 | xpred_1d = D(xpred) 221 | x_1d = D(X) 222 | 223 | g_opt.zero_grad() 224 | g_loss = -xpred_1d.mean() 225 | g_loss.backward() 226 | gloss = -g_loss 227 | if i % 5 == 0: 228 | g_opt.step() 229 | 230 | n_iter = epoch * len(dataloader) + i 231 | writer.add_scalar('Loss/Generator', gloss.item(), n_iter) 232 | writer.add_scalar('Loss/Sorter', dloss.item(), n_iter) 233 | 234 | 235 | g_loss /= i+1 236 | d_loss /= i+1 237 | 238 | if epoch % 100 == 0: 239 | print("{} {} Epoch {}: \t dloss {} \t gloss {}".format(args.method, args.dataset, epoch, dloss.item(), gloss.item())) 240 | return dloss.item(), gloss.item() 241 | 242 | --------------------------------------------------------------------------------