├── README.md ├── log ├── StGen_ResDis_fake_samples_epoch_130.png ├── fake_samples_epoch_113.png └── fake_samples_epoch_199.png ├── models ├── __init__.py ├── __init__.pyc ├── models.py ├── models.pyc ├── snres_discriminator.py ├── snres_discriminator.pyc ├── snres_generator.py └── snres_generator.pyc ├── src ├── __init__.py ├── __init__.pyc ├── functions │ ├── __init__.py │ ├── __init__.pyc │ ├── max_sv.py │ └── max_sv.pyc └── snlayers │ ├── __init__.py │ ├── __init__.pyc │ ├── snconv2d.py │ ├── snconv2d.pyc │ ├── snlinear.py │ └── snlinear.pyc ├── test.py ├── train-conditional.py ├── train-res.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # SN-GAN 2 | PyTorch implementation of [Spectral Normalization for Generative Adversarial Networks](https://openreview.net/pdf?id=B1QRgziT-) 3 | 4 | ## Prerequisites 5 | * Python 2.7 or Python 3.4+ 6 | * [PyTorch](http://pytorch.org/) 7 | * [Numpy](http://www.numpy.org/) 8 | 9 | ## Usage 10 | 11 | - Clone this repo: 12 | ```bash 13 | git clone https://github.com/godisboy/SN-GAN.git 14 | cd SN-GAN 15 | ``` 16 | - Train the standard model for 32x32 images (the `--cuda` flag is optional): 17 | ```bash 18 | python train.py --cuda 19 | ``` 20 | - Train the Res-Generator and Res-Discriminator for 64x64 images: 21 | ```bash 22 | python train-res.py --cuda --dataPath /Path/to/yourdataset/ 23 | ``` 24 | 25 | -------------------- 26 | [SNGAN]: https://openreview.net/pdf?id=B1QRgziT- 27 | 28 | 1. Results of SN-GAN on the CIFAR10 dataset 29 | * Generated samples (epoch 199) 30 | 31 | ![Generated samples](log/fake_samples_epoch_199.png) 32 | 33 | * Generated samples (epoch 113) 34 | 35 | ![Generated samples](log/fake_samples_epoch_113.png) 36 | 37 | * Generated samples with the standard Generator and the Res-Discriminator 38 | 39 | ![Generated samples](log/StGen_ResDis_fake_samples_epoch_130.png) 40 | 41 | **Note**: 42 | The ResBlock of the Res-Generator differs from the one implemented in the original paper: 43 | this repo uses `UpsamplingBilinear` instead of unpooling for the upsampling operation. 44 | 45 | ## To Do 46 | - Conditional version of SNGAN with [Conditional BatchNormalization](https://arxiv.org/pdf/1707.03017.pdf) 47 | - ImageNet 1000 classes 48 | - Latent Inference 49 | - ...
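**How spectral normalization is applied** (for reference): the `SNConv2d`/`SNLinear` layers in `src/snlayers` divide the raw weight by an estimate of its largest singular value, obtained with one step of power iteration per forward pass (see `src/functions/max_sv.py`). A minimal sketch of that computation — the function name `spectral_norm_weight` is illustrative and not part of the repo:

```python
import torch

def _l2normalize(v, eps=1e-12):
    return v / (torch.norm(v) + eps)

def spectral_norm_weight(W, u, n_iters=1):
    # Flatten conv kernels to a (out_channels, in_channels * k * k) matrix.
    W_mat = W.view(W.size(0), -1)
    # Power iteration: u and v approach the leading singular vectors of W_mat.
    for _ in range(n_iters):
        v = _l2normalize(torch.matmul(u, W_mat))
        u = _l2normalize(torch.matmul(v, W_mat.t()))
    # sigma ~= u^T W v, the largest singular value of W_mat.
    sigma = torch.sum(torch.matmul(u, W_mat) * v)
    return W / sigma, u
```

In the layers themselves, `u` is kept as a registered buffer, so the estimate is refined a little further on every forward pass instead of being recomputed from scratch.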
50 | ## Acknowledgments 51 | Based on the implementation [DCGAN](https://github.com/pytorch/examples/tree/master/dcgan) and official implementation with [Chainer](https://chainer.org/) [sngan_projection](https://github.com/pfnet-research/sngan_projection) 52 | -------------------------------------------------------------------------------- /log/StGen_ResDis_fake_samples_epoch_130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/log/StGen_ResDis_fake_samples_epoch_130.png -------------------------------------------------------------------------------- /log/fake_samples_epoch_113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/log/fake_samples_epoch_113.png -------------------------------------------------------------------------------- /log/fake_samples_epoch_199.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/log/fake_samples_epoch_199.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/models/__init__.py -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/models/__init__.pyc -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn.modules import conv, Linear 4 | import torch.nn.functional as F 5 | from src.snlayers.snconv2d import SNConv2d 6 | 7 | class _netG(nn.Module): 8 | def __init__(self, nz, nc, ngf): 9 | super(_netG, self).__init__() 10 | self.main = nn.Sequential( 11 | # input is Z, going into a convolution 12 | nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=True), 13 | nn.BatchNorm2d(ngf * 8), 14 | nn.ReLU(True), 15 | # state size. (ngf*8) x 4 x 4 16 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=True), 17 | nn.BatchNorm2d(ngf * 4), 18 | nn.ReLU(True), 19 | # state size. (ngf*4) x 8 x 8 20 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=True), 21 | nn.BatchNorm2d(ngf * 2), 22 | nn.ReLU(True), 23 | # state size. (ngf*2) x 16 x 16 24 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=True), 25 | nn.BatchNorm2d(ngf), 26 | nn.ReLU(True), 27 | # state size. (ngf) x 32 x 32 28 | nn.ConvTranspose2d(ngf, nc, 3, 1, 1, bias=True), 29 | nn.Tanh() 30 | # state size. (nc) x 32 x 32 31 | ) 32 | 33 | def forward(self, input): 34 | output = self.main(input) 35 | return output 36 | 37 | class _netD(nn.Module): 38 | def __init__(self, nc, ndf): 39 | super(_netD, self).__init__() 40 | 41 | self.main = nn.Sequential( 42 | # input is (nc) x 32 x 32 43 | #SNConv2d() 44 | SNConv2d(nc, ndf, 3, 1, 1, bias=True), 45 | nn.LeakyReLU(0.1, inplace=True), 46 | SNConv2d(ndf, ndf, 4, 2, 1, bias=True), 47 | nn.LeakyReLU(0.1, inplace=True), 48 | # state size. 
(ndf) x 16 x 16 49 | SNConv2d(ndf, ndf * 2, 3, 1, 1, bias=True), 50 | nn.LeakyReLU(0.1, inplace=True), 51 | SNConv2d(ndf*2, ndf * 2, 4, 2, 1, bias=True), 52 | #nn.BatchNorm2d(ndf * 2), 53 | nn.LeakyReLU(0.1, inplace=True), 54 | # state size. (ndf*2) x 8 x 8 55 | SNConv2d(ndf * 2, ndf * 4, 3, 1, 1, bias=True), 56 | nn.LeakyReLU(0.1, inplace=True), 57 | SNConv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=True), 58 | nn.LeakyReLU(0.1, inplace=True), 59 | # state size. (ndf*4) x 4 x 4 60 | SNConv2d(ndf * 4, ndf * 8, 3, 1, 1, bias=True), 61 | nn.LeakyReLU(0.1, inplace=True), 62 | SNConv2d(ndf * 8, 1, 4, 1, 0, bias=False), 63 | nn.Sigmoid() 64 | ) 65 | #self.snlinear = nn.Sequential(SNLinear(ndf * 4 * 4 * 4, 1), 66 | # nn.Sigmoid()) 67 | 68 | def forward(self, input): 69 | output = self.main(input) 70 | #output = output.view(output.size(0), -1) 71 | #output = self.snlinear(output) 72 | return output.view(-1, 1).squeeze(1) -------------------------------------------------------------------------------- /models/models.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/models/models.pyc -------------------------------------------------------------------------------- /models/snres_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from src.snlayers.snconv2d import SNConv2d 3 | from src.snlayers.snlinear import SNLinear 4 | import torch.nn.functional as F 5 | 6 | class ResBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, hidden_channels=None, use_BN = False, downsample=False): 8 | super(ResBlock, self).__init__() 9 | #self.conv1 = SNConv2d(n_dim, n_out, kernel_size=3, stride=2) 10 | hidden_channels = in_channels if hidden_channels is None else hidden_channels 11 | self.downsample = downsample 12 | 13 | self.resblock = self.make_res_block(in_channels, out_channels, hidden_channels, use_BN, downsample) 14 | self.residual_connect = self.make_residual_connect(in_channels, out_channels) 15 | def make_res_block(self, in_channels, out_channels, hidden_channels, use_BN, downsample): 16 | model = [] 17 | if use_BN: 18 | model += [nn.BatchNorm2d(in_channels)] 19 | 20 | model += [nn.ReLU()] 21 | model += [SNConv2d(in_channels, hidden_channels, kernel_size=3, padding=1)] 22 | model += [nn.ReLU()] 23 | model += [SNConv2d(hidden_channels, out_channels, kernel_size=3, padding=1)] 24 | if downsample: 25 | model += [nn.AvgPool2d(2)] 26 | return nn.Sequential(*model) 27 | def make_residual_connect(self, in_channels, out_channels): 28 | model = [] 29 | model += [SNConv2d(in_channels, out_channels, kernel_size=1, padding=0)] 30 | if self.downsample: 31 | model += [nn.AvgPool2d(2)] 32 | return nn.Sequential(*model) 33 | else: 34 | return nn.Sequential(*model) 35 | 36 | def forward(self, input): 37 | return self.resblock(input) + self.residual_connect(input) 38 | 39 | class OptimizedBlock(nn.Module): 40 | def __init__(self, in_channels, out_channels): 41 | super(OptimizedBlock, self).__init__() 42 | self.res_block = self.make_res_block(in_channels, out_channels) 43 | self.residual_connect = self.make_residual_connect(in_channels, out_channels) 44 | def make_res_block(self, in_channels, out_channels): 45 | model = [] 46 | model += [SNConv2d(in_channels, out_channels, kernel_size=3, padding=1)] 47 | model += [nn.ReLU()] 48 | model += [SNConv2d(out_channels, out_channels, kernel_size=3, padding=1)] 49 | model += [nn.AvgPool2d(2)] 50 | return nn.Sequential(*model)
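    # Note: OptimizedBlock is the stem of the ResNet discriminator. Unlike ResBlock it applies
    # no activation before its first convolution and always downsamples once via AvgPool2d,
    # so a 64x64 RGB input leaves this block as a 32x32 feature map with ndf channels.
    # With ndlayers=4 downsampling ResBlocks after it, the map reaches 2x2 with ndf*16 channels
    # before the global average pooling and the SNLinear head in SNResDiscriminator below;
    # the hard-coded view(-1, 1024) there assumes the default ndf=64 (64*16 = 1024).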
51 | def make_residual_connect(self, in_channels, out_channels): 52 | model = [] 53 | model += [SNConv2d(in_channels, out_channels, kernel_size=1, padding=0)] 54 | model += [nn.AvgPool2d(2)] 55 | return nn.Sequential(*model) 56 | def forward(self, input): 57 | return self.res_block(input) + self.residual_connect(input) 58 | 59 | class SNResDiscriminator(nn.Module): 60 | def __init__(self, ndf=64, ndlayers=4): 61 | super(SNResDiscriminator, self).__init__() 62 | self.res_d = self.make_model(ndf, ndlayers) 63 | self.fc = nn.Sequential(SNLinear(ndf*16, 1), nn.Sigmoid()) 64 | def make_model(self, ndf, ndlayers): 65 | model = [] 66 | model += [OptimizedBlock(3, ndf)] 67 | tndf = ndf 68 | for i in range(ndlayers): 69 | model += [ResBlock(tndf, tndf*2, downsample=True)] 70 | tndf *= 2 71 | model += [nn.ReLU()] 72 | return nn.Sequential(*model) 73 | def forward(self, input): 74 | out = self.res_d(input) 75 | out = F.avg_pool2d(out, out.size(3), stride=1) 76 | out = out.view(-1, 1024) 77 | return self.fc(out) 78 | -------------------------------------------------------------------------------- /models/snres_discriminator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/models/snres_discriminator.pyc -------------------------------------------------------------------------------- /models/snres_generator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ResBlock(nn.Module): 6 | def __init__(self, in_channels, out_channels, hidden_channels=None, upsample=False): 7 | super(ResBlock, self).__init__() 8 | #self.conv1 = SNConv2d(n_dim, n_out, kernel_size=3, stride=2) 9 | hidden_channels = in_channels 10 | self.upsample = upsample 11 | self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) 12 | self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1) 13 | self.conv_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 14 | self.upsampling = nn.UpsamplingBilinear2d(scale_factor=2) 15 | self.bn1 = nn.BatchNorm2d(in_channels) 16 | self.bn2 = nn.BatchNorm2d(hidden_channels) 17 | self.relu = nn.ReLU() 18 | def forward_residual_connect(self, input): 19 | out = self.conv_sc(input) 20 | if self.upsample: 21 | out = self.upsampling(out) 22 | #out = self.upconv2(out) 23 | return out 24 | def forward(self, input): 25 | out = self.relu(self.bn1(input)) 26 | out = self.conv1(out) 27 | if self.upsample: 28 | out = self.upsampling(out) 29 | #out = self.upconv1(out) 30 | out = self.relu(self.bn2(out)) 31 | out = self.conv2(out) 32 | out_res = self.forward_residual_connect(input) 33 | return out + out_res 34 | 35 | class SNResGenerator(nn.Module): 36 | def __init__(self, ngf, z=128, nlayers=4): 37 | super(SNResGenerator, self).__init__() 38 | self.input_layer = nn.Linear(z, (4 ** 2) * ngf * 16) 39 | self.generator = self.make_model(ngf, nlayers) 40 | 41 | def make_model(self, ngf, nlayers): 42 | model = [] 43 | tngf = ngf*16 44 | for i in range(nlayers): 45 | model += [ResBlock(tngf, tngf/2, upsample=True)] 46 | tngf /= 2 47 | model += [nn.BatchNorm2d(ngf)] 48 | model += [nn.ReLU()] 49 | model += [nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)] 50 | model += [nn.Tanh()] 51 | return nn.Sequential(*model) 52 | 53 | def forward(self, z): 54 | out = self.input_layer(z) 55 | out = out.view(z.size(0), -1, 4, 4) 56 | out = 
self.generator(out) 57 | 58 | return out -------------------------------------------------------------------------------- /models/snres_generator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/models/snres_generator.pyc -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/__init__.py -------------------------------------------------------------------------------- /src/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/__init__.pyc -------------------------------------------------------------------------------- /src/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/functions/__init__.py -------------------------------------------------------------------------------- /src/functions/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/functions/__init__.pyc -------------------------------------------------------------------------------- /src/functions/max_sv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | #define _l2normalization 5 | def _l2normalize(v, eps=1e-12): 6 | return v / (torch.norm(v) + eps) 7 | 8 | def max_singular_value(W, u=None, Ip=1): 9 | """ 10 | power iteration for weight parameter 11 | """ 12 | #xp = W.data 13 | if not Ip >= 1: 14 | raise ValueError("Power iteration should be a positive integer") 15 | if u is None: 16 | u = torch.FloatTensor(1, W.size(0)).normal_(0, 1).cuda() 17 | _u = u 18 | for _ in range(Ip): 19 | _v = _l2normalize(torch.matmul(_u, W.data), eps=1e-12) 20 | _u = _l2normalize(torch.matmul(_v, torch.transpose(W.data, 0, 1)), eps=1e-12) 21 | sigma = torch.sum(F.linear(_u, torch.transpose(W.data, 0, 1)) * _v) 22 | return sigma, _u 23 | -------------------------------------------------------------------------------- /src/functions/max_sv.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/functions/max_sv.pyc -------------------------------------------------------------------------------- /src/snlayers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/snlayers/__init__.py -------------------------------------------------------------------------------- /src/snlayers/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/snlayers/__init__.pyc -------------------------------------------------------------------------------- /src/snlayers/snconv2d.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.modules import conv 6 | from torch.nn.modules.utils import _pair 7 | from ..functions.max_sv import max_singular_value 8 | 9 | class SNConv2d(conv._ConvNd): 10 | 11 | r"""Applies a 2D convolution over an input signal composed of several input 12 | planes. 13 | 14 | In the simplest case, the output value of the layer with input size 15 | :math:`(N, C_{in}, H, W)` and output :math:`(N, C_{out}, H_{out}, W_{out})` 16 | can be precisely described as: 17 | 18 | .. math:: 19 | 20 | \begin{array}{ll} 21 | out(N_i, C_{out_j}) = bias(C_{out_j}) 22 | + \sum_{{k}=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k) 23 | \end{array} 24 | 25 | where :math:`\star` is the valid 2D `cross-correlation`_ operator, 26 | :math:`N` is a batch size, :math:`C` denotes a number of channels, 27 | :math:`H` is a height of input planes in pixels, and :math:`W` is 28 | width in pixels. 29 | 30 | | :attr:`stride` controls the stride for the cross-correlation, a single 31 | number or a tuple. 32 | | :attr:`padding` controls the amount of implicit zero-paddings on both 33 | | sides for :attr:`padding` number of points for each dimension. 34 | | :attr:`dilation` controls the spacing between the kernel points; also 35 | known as the à trous algorithm. It is harder to describe, but this `link`_ 36 | has a nice visualization of what :attr:`dilation` does. 37 | | :attr:`groups` controls the connections between inputs and outputs. 38 | `in_channels` and `out_channels` must both be divisible by `groups`. 39 | | At groups=1, all inputs are convolved to all outputs. 40 | | At groups=2, the operation becomes equivalent to having two conv 41 | layers side by side, each seeing half the input channels, 42 | and producing half the output channels, and both subsequently 43 | concatenated. 44 | At groups=`in_channels`, each input channel is convolved with its 45 | own set of filters (of size `out_channels // in_channels`). 46 | 47 | The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: 48 | 49 | - a single ``int`` -- in which case the same value is used for the height and width dimension 50 | - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, 51 | and the second `int` for the width dimension 52 | 53 | .. note:: 54 | 55 | Depending of the size of your kernel, several (of the last) 56 | columns of the input might be lost, because it is a valid `cross-correlation`_, 57 | and not a full `cross-correlation`_. 58 | It is up to the user to add proper padding. 59 | 60 | .. note:: 61 | 62 | The configuration when `groups == in_channels` and `out_channels = K * in_channels` 63 | where `K` is a positive integer is termed in literature as depthwise convolution. 64 | 65 | In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, if you want a 66 | depthwise convolution with a depthwise multiplier `K`, 67 | then you use the constructor arguments 68 | :math:`(in\_channels=C_{in}, out\_channels=C_{in} * K, ..., groups=C_{in})` 69 | 70 | Args: 71 | in_channels (int): Number of channels in the input image 72 | out_channels (int): Number of channels produced by the convolution 73 | kernel_size (int or tuple): Size of the convolving kernel 74 | stride (int or tuple, optional): Stride of the convolution. 
Default: 1 75 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 76 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 77 | groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 78 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 79 | 80 | Shape: 81 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 82 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where 83 | :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)` 84 | :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)` 85 | 86 | Attributes: 87 | weight (Tensor): the learnable weights of the module of shape 88 | (out_channels, in_channels, kernel_size[0], kernel_size[1]) 89 | bias (Tensor): the learnable bias of the module of shape (out_channels) 90 | 91 | W(Tensor): Spectrally normalized weight 92 | 93 | u (Tensor): the right largest singular value of W. 94 | 95 | .. _cross-correlation: 96 | https://en.wikipedia.org/wiki/Cross-correlation 97 | 98 | .. _link: 99 | https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md 100 | """ 101 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 102 | kernel_size = _pair(kernel_size) 103 | stride = _pair(stride) 104 | padding = _pair(padding) 105 | dilation = _pair(dilation) 106 | super(SNConv2d, self).__init__( 107 | in_channels, out_channels, kernel_size, stride, padding, dilation, 108 | False, _pair(0), groups, bias, padding_mode = 'zeros') 109 | self.register_buffer('u', torch.Tensor(1, out_channels).normal_()) 110 | 111 | @property 112 | def W_(self): 113 | w_mat = self.weight.view(self.weight.size(0), -1) 114 | sigma, _u = max_singular_value(w_mat, self.u) 115 | self.u.copy_(_u) 116 | return self.weight / sigma 117 | 118 | def forward(self, input): 119 | return F.conv2d(input, self.W_, self.bias, self.stride, 120 | self.padding, self.dilation, self.groups) 121 | -------------------------------------------------------------------------------- /src/snlayers/snconv2d.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/snlayers/snconv2d.pyc -------------------------------------------------------------------------------- /src/snlayers/snlinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules import Linear 5 | from ..functions.max_sv import max_singular_value 6 | 7 | class SNLinear(Linear): 8 | r"""Applies a linear transformation to the incoming data: :math:`y = Ax + b` 9 | Args: 10 | in_features: size of each input sample 11 | out_features: size of each output sample 12 | bias: If set to False, the layer will not learn an additive bias. 13 | Default: ``True`` 14 | Shape: 15 | - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of 16 | additional dimensions 17 | - Output: :math:`(N, *, out\_features)` where all but the last dimension 18 | are the same shape as the input. 
19 | Attributes: 20 | weight: the learnable weights of the module of shape 21 | `(out_features x in_features)` 22 | bias: the learnable bias of the module of shape `(out_features)` 23 | 24 | W(Tensor): Spectrally normalized weight 25 | 26 | u (Tensor): the right largest singular value of W. 27 | """ 28 | def __init__(self, in_features, out_features, bias=True): 29 | super(SNLinear, self).__init__(in_features, out_features, bias) 30 | self.register_buffer('u', torch.Tensor(1, out_features).normal_()) 31 | 32 | @property 33 | def W_(self): 34 | w_mat = self.weight.view(self.weight.size(0), -1) 35 | sigma, _u = max_singular_value(w_mat, self.u) 36 | self.u.copy_(_u) 37 | return self.weight / sigma 38 | 39 | def forward(self, input): 40 | return F.linear(input, self.W_, self.bias) 41 | -------------------------------------------------------------------------------- /src/snlayers/snlinear.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godisboy/SN-GAN/2a5c448235be967df1bc6270c7cc24c07c78f388/src/snlayers/snlinear.pyc -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torchvision import datasets, transforms 6 | import torchvision.utils as vutils 7 | from torch.autograd import Variable 8 | import torch.utils.data 9 | from torch.nn.modules import conv 10 | from torch.nn.modules.utils import _pair, _triple 11 | import torch.backends.cudnn as cudnn 12 | 13 | import random 14 | import argparse 15 | import os 16 | from PIL import Image 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | parser = argparse.ArgumentParser(description='train SNDCGAN model') 20 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') 21 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 22 | parser.add_argument('--gpu_ids', default=[0,1,2,3], help='gpu ids: e.g. 
0,1,2, 0,2.') 23 | parser.add_argument('--manualSeed', type=int, help='manual seed') 24 | parser.add_argument('--batchSize', type=int, default=100, help='with batchSize=1 equivalent to instance normalization.') 25 | parser.add_argument('--label_num', type=int, default=200, help='number of labels.') 26 | opt = parser.parse_args() 27 | print(opt) 28 | 29 | dataset = datasets.ImageFolder(root='/media/scw4750/25a01ed5-a903-4298-87f2-a5836dcb6888/AIwalker/dataset/CUB200_object', 30 | transform=transforms.Compose([ 31 | transforms.Scale(64), 32 | transforms.CenterCrop(64), 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 35 | ]) 36 | ) 37 | 38 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, 39 | shuffle=True, num_workers=int(2)) 40 | if opt.manualSeed is None: 41 | opt.manualSeed = random.randint(1, 10000) 42 | print("Random Seed: ", opt.manualSeed) 43 | random.seed(opt.manualSeed) 44 | torch.manual_seed(opt.manualSeed) 45 | 46 | if opt.cuda: 47 | torch.cuda.manual_seed_all(opt.manualSeed) 48 | torch.cuda.set_device(opt.gpu_ids[2]) 49 | 50 | cudnn.benchmark = True 51 | 52 | def _l2normalize(v, eps=1e-12): 53 | return v / (((v**2).sum())**0.5 + eps) 54 | 55 | def max_singular_value(W, u=None, Ip=1): 56 | """ 57 | power iteration for weight parameter 58 | """ 59 | #xp = W.data 60 | if u is None: 61 | u = torch.FloatTensor(1, W.size(0)).normal_(0, 1).cuda() 62 | _u = u 63 | for _ in range(Ip): 64 | #print(_u.size(), W.size()) 65 | _v = _l2normalize(torch.matmul(_u, W.data), eps=1e-12) 66 | _u = _l2normalize(torch.matmul(_v, torch.transpose(W.data, 0, 1)), eps=1e-12) 67 | sigma = torch.matmul(torch.matmul(_v, torch.transpose(W.data, 0, 1)), torch.transpose(_u, 0, 1)) 68 | return sigma, _v 69 | 70 | class SNConv2d(conv._ConvNd): 71 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 72 | kernel_size = _pair(kernel_size) 73 | stride = _pair(stride) 74 | padding = _pair(padding) 75 | dilation = _pair(dilation) 76 | super(SNConv2d, self).__init__( 77 | in_channels, out_channels, kernel_size, stride, padding, dilation, 78 | False, _pair(0), groups, bias) 79 | 80 | def forward(self, input): 81 | w_mat = self.weight.view(self.weight.size(0), -1) 82 | sigma, _ = max_singular_value(w_mat) 83 | #print(sigma.size()) 84 | self.weight.data = self.weight.data / sigma 85 | #print(self.weight.data) 86 | return F.conv2d(input, self.weight, self.bias, self.stride, 87 | self.padding, self.dilation, self.groups) 88 | 89 | class _netG(nn.Module): 90 | def __init__(self, nz, nc, ngf): 91 | super(_netG, self).__init__() 92 | self.main = nn.Sequential( 93 | # input is Z, going into a convolution 94 | nn.ConvTranspose2d(nz+200, ngf * 8, 4, 1, 0, bias=False), 95 | nn.BatchNorm2d(ngf * 8), 96 | nn.ReLU(True), 97 | # state size. (ngf*8) x 4 x 4 98 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 99 | nn.BatchNorm2d(ngf * 4), 100 | nn.ReLU(True), 101 | # state size. (ngf*4) x 8 x 8 102 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 103 | nn.BatchNorm2d(ngf * 2), 104 | nn.ReLU(True), 105 | # state size. (ngf*2) x 16 x 16 106 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 107 | nn.BatchNorm2d(ngf), 108 | nn.ReLU(True), 109 | # state size. (ngf) x 32 x 32 110 | nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), 111 | nn.Tanh() 112 | # state size. 
(nc) x 64 x 64 113 | ) 114 | 115 | def forward(self, input): 116 | output = self.main(input) 117 | return output 118 | 119 | nz = opt.nz 120 | 121 | G = _netG(nz, 3, 64) 122 | print(G) 123 | save_path = 'log/netG_epoch_199.pth' 124 | G.load_state_dict(torch.load(save_path)) 125 | 126 | input = torch.FloatTensor(opt.batchSize, 3, 64, 64) 127 | noise = torch.FloatTensor(opt.batchSize, nz, 1, 1) 128 | label = torch.FloatTensor(opt.batchSize) 129 | real_label = 1 130 | fake_label = 0 131 | 132 | #fixed label 133 | fix_label = torch.FloatTensor(opt.batchSize) 134 | 135 | for i in range(0,100): 136 | fix_label[i] = i; 137 | #fix_label[i] = np.random.randint(1,200); 138 | 139 | fix = torch.LongTensor(opt.batchSize,1).copy_(fix_label) 140 | fix_onehot = torch.FloatTensor(opt.batchSize, 200) 141 | fix_onehot.zero_() 142 | fix_onehot.scatter_(1, fix, 1) 143 | fix_onehot.view(-1, 200, 1, 1) 144 | 145 | fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1) 146 | fixed_input = torch.cat([fixed_noise, fix_onehot],1) 147 | fixed_input = Variable(fixed_input) 148 | 149 | print(fixed_input.size()) 150 | criterion = nn.BCELoss() 151 | 152 | fill = torch.zeros([200, 200, 64, 64]) 153 | for i in range(200): 154 | fill[i, i, :, :] = 1 155 | 156 | if opt.cuda: 157 | G.cuda() 158 | input, label = input.cuda(), label.cuda() 159 | noise, fixed_input = noise.cuda(), fixed_input.cuda() 160 | 161 | #for i, data in enumerate(dataloader, 0): 162 | ############################ 163 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 164 | ########################### 165 | # train with real 166 | #real_cpu, labels = data 167 | batch_size = opt.batchSize 168 | #if opt.cuda: 169 | # real_cpu = real_cpu.cuda() 170 | ''' 171 | y = torch.LongTensor(batch_size, 1).copy_(labels) 172 | y_onehot = torch.zeros(batch_size, 200) 173 | y_onehot.scatter_(1, y, 1) 174 | y_onehot.view(-1, batch_size, 1, 1) 175 | y_onehot = Variable(y_onehot.cuda()) 176 | ''' 177 | # train with fake 178 | noise.resize_(batch_size, 100, 1, 1).normal_(0, 1) 179 | noisev = Variable(noise) 180 | #y_nz = torch.cat([noisev, y_onehot], 1) 181 | fake = G(fixed_input) 182 | 183 | vutils.save_image(fake.data, 184 | '%s/conditional_fake.png' % ('log'), 185 | normalize=True, nrow=10) 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /train-conditional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torchvision import datasets, transforms 6 | import torchvision.utils as vutils 7 | from torch.autograd import Variable 8 | import torch.utils.data 9 | import torch.backends.cudnn as cudnn 10 | 11 | import random 12 | import argparse 13 | from models.models import SNConv2d 14 | import os 15 | from PIL import Image 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | parser = argparse.ArgumentParser(description='train SNDCGAN model') 19 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') 20 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 21 | parser.add_argument('--gpu_ids', default=[0,1,2,3], help='gpu ids: e.g. 
0,1,2, 0,2.') 22 | parser.add_argument('--manualSeed', type=int, help='manual seed') 23 | parser.add_argument('--batchSize', type=int, default=32, help='with batchSize=1 equivalent to instance normalization.') 24 | parser.add_argument('--label_num', type=int, default=200, help='number of labels.') 25 | opt = parser.parse_args() 26 | print(opt) 27 | 28 | dataset = datasets.ImageFolder(root='/home/chao/Downloads/AwA2-data/train-tiny', 29 | transform=transforms.Compose([ 30 | transforms.Scale(64), 31 | transforms.CenterCrop(64), 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 34 | ]) 35 | ) 36 | 37 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, 38 | shuffle=True, num_workers=int(2)) 39 | if opt.manualSeed is None: 40 | opt.manualSeed = random.randint(1, 10000) 41 | print("Random Seed: ", opt.manualSeed) 42 | random.seed(opt.manualSeed) 43 | torch.manual_seed(opt.manualSeed) 44 | 45 | if opt.cuda: 46 | torch.cuda.manual_seed_all(opt.manualSeed) 47 | torch.cuda.set_device(opt.gpu_ids[0]) 48 | 49 | cudnn.benchmark = True 50 | 51 | def weight_filler(m): 52 | classname = m.__class__.__name__ 53 | if classname.find('Conv') != -1: 54 | m.weight.data.normal_(0.0, 0.02) 55 | elif classname.find('BatchNorm') != -1: 56 | m.weight.data.normal_(1.0, 0.02) 57 | m.bias.data.fill_(0) 58 | 59 | class _netG(nn.Module): 60 | def __init__(self, nz, nc, ngf): 61 | super(_netG, self).__init__() 62 | 63 | self.convT1 = nn.Sequential( 64 | nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False), 65 | nn.BatchNorm2d(ngf * 4), 66 | nn.ReLU(True) 67 | ) 68 | self.convT2 = nn.Sequential( 69 | nn.ConvTranspose2d(10, ngf * 4, 4, 1, 0, bias=False), 70 | nn.BatchNorm2d(ngf * 4), 71 | nn.ReLU(True) 72 | ) 73 | self.main = nn.Sequential( 74 | # input is Z, going into a convolution 75 | # state size. (ngf*8) x 4 x 4 76 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 77 | nn.BatchNorm2d(ngf * 4), 78 | nn.ReLU(True), 79 | # state size. (ngf*4) x 8 x 8 80 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 81 | nn.BatchNorm2d(ngf * 2), 82 | nn.ReLU(True), 83 | # state size. (ngf*2) x 16 x 16 84 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 85 | nn.BatchNorm2d(ngf), 86 | nn.ReLU(True), 87 | # state size. (ngf) x 32 x 32 88 | nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), 89 | nn.Tanh() 90 | # state size. (nc) x 64 x 64 91 | ) 92 | 93 | def forward(self, input, input_c): 94 | out1 = self.convT1(input) 95 | out2 = self.convT2(input_c) 96 | output = torch.cat([out1, out2], 1) 97 | output = self.main(output) 98 | 99 | return output 100 | 101 | class _netD(nn.Module): 102 | def __init__(self, nc, ndf): 103 | super(_netD, self).__init__() 104 | 105 | self.conv1_1 = SNConv2d(nc, ndf/2, 3, 1, 1, bias=False) 106 | self.conv1_2 = SNConv2d(10, ndf/2, 3, 1, 1, bias=False) 107 | self.lrelu = nn.LeakyReLU(0.2, inplace=True) 108 | self.main = nn.Sequential( 109 | # input is (nc) x 64 x 64 110 | SNConv2d(ndf, ndf, 4, 2, 1, bias=False), 111 | nn.LeakyReLU(0.2, inplace=True), 112 | # state size. (ndf) x 32 x 32 113 | SNConv2d(ndf, ndf * 2, 3, 1, 1, bias=False), 114 | nn.LeakyReLU(0.2, inplace=True), 115 | SNConv2d(ndf *2 , ndf * 2, 4, 2, 1, bias=False), 116 | nn.LeakyReLU(0.2, inplace=True), 117 | # state size. (ndf*2) x 16 x 16 118 | SNConv2d(ndf * 2, ndf * 4, 3, 1, 1, bias=False), 119 | nn.LeakyReLU(0.2, inplace=True), 120 | SNConv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False), 121 | nn.LeakyReLU(0.2, inplace=True), 122 | # state size. 
(ndf*4) x 8 x 8 123 | SNConv2d(ndf * 4, ndf * 8, 3, 1, 1, bias=False), 124 | nn.LeakyReLU(0.2, inplace=True), 125 | SNConv2d(ndf * 8, ndf * 8, 4, 2, 1, bias=False), 126 | nn.LeakyReLU(0.2, inplace=True), 127 | # state size. (ndf*8) x 4 x 4 128 | SNConv2d(ndf * 8, 1, 4, 1, 0, bias=False), 129 | #nn.LeakyReLU(0.2, inplace=True) 130 | #nn.Softplus() 131 | ) 132 | 133 | def forward(self, input, input_c): 134 | out1 = self.lrelu(self.conv1_1(input)) 135 | out2 = self.lrelu(self.conv1_2(input_c)) 136 | output = torch.cat([out1, out2], 1) 137 | output = self.main(output) 138 | return output.view(-1, 1).squeeze(1) 139 | 140 | nz = opt.nz 141 | 142 | G = _netG(nz, 3, 64) 143 | SND = _netD(3, 64) 144 | print(G) 145 | print(SND) 146 | G.apply(weight_filler) 147 | SND.apply(weight_filler) 148 | 149 | input = torch.FloatTensor(32, 3, 64, 64) 150 | noise = torch.FloatTensor(32, nz, 1, 1) 151 | label = torch.FloatTensor(32) 152 | real_label = 1 153 | fake_label = 0 154 | 155 | #fixed label 156 | fix_label = torch.FloatTensor(opt.batchSize, 1) 157 | 158 | for i in range(0, 4): 159 | #label_y = np.random.randint(1,200) 160 | for j in range(0, 8): 161 | fix_label[i*8+j] = j 162 | #fix_label[i] = np.random.randint(1,200); 163 | 164 | fix = torch.LongTensor(32,1).copy_(fix_label) 165 | fix_onehot = torch.FloatTensor(opt.batchSize, 10) 166 | fix_onehot.zero_() 167 | fix_onehot.scatter_(1, fix, 1) 168 | fix_onehot = fix_onehot.view(-1, 10, 1, 1) 169 | 170 | fixed_noise = torch.FloatTensor(32, nz, 1, 1).normal_(0, 1) 171 | #fixed_input = torch.cat([fixed_noise, fix_onehot],1) 172 | fixed_noise, fix_onehot = Variable(fixed_noise), Variable(fix_onehot) 173 | 174 | criterion = nn.BCELoss() 175 | 176 | fill = torch.zeros([10, 10, 64, 64]) 177 | for i in range(10): 178 | fill[i, i, :, :] = 1 179 | 180 | if opt.cuda: 181 | G.cuda() 182 | SND.cuda() 183 | criterion.cuda() 184 | input, label = input.cuda(), label.cuda() 185 | noise, fixed_noise, fix_onehot = noise.cuda(), fixed_noise.cuda(), fix_onehot.cuda() 186 | 187 | optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999)) 188 | optimizerSND = optim.Adam(SND.parameters(), lr=0.0002, betas=(0.5, 0.999)) 189 | 190 | for epoch in range(300): 191 | for i, data in enumerate(dataloader, 0): 192 | ############################ 193 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 194 | ########################### 195 | # train with real 196 | SND.zero_grad() 197 | real_cpu, labels = data 198 | batch_size = real_cpu.size(0) 199 | #if opt.cuda: 200 | # real_cpu = real_cpu.cuda() 201 | y = torch.LongTensor(batch_size, 1).copy_(labels.view(-1, 1)) 202 | y_onehot = torch.zeros(batch_size, 10) 203 | y_onehot.scatter_(1, y, 1) 204 | y_onehot_v = y_onehot.view(batch_size, -1, 1, 1) 205 | #print(y_onehot_v.size()) 206 | y_onehot_v = Variable(y_onehot_v.cuda()) 207 | 208 | y_fill = fill[labels] 209 | y_fill = Variable(y_fill.cuda()) 210 | 211 | input.resize_(real_cpu.size()).copy_(real_cpu) 212 | label.resize_(batch_size).fill_(real_label) 213 | inputv = Variable(input) 214 | labelv = Variable(label) 215 | output = SND(inputv, y_fill) 216 | #print(output) 217 | errD_real = torch.mean(F.softplus(-output).mean()) 218 | #errD_real = criterion(output, labelv) 219 | errD_real.backward() 220 | D_x = output.data.mean() 221 | 222 | # train with fake 223 | noise.resize_(batch_size, 100, 1, 1).normal_(0, 1) 224 | noisev = Variable(noise) 225 | #y_nz = torch.cat([noisev, y_onehot], 1) 226 | fake = G(noisev, y_onehot_v) 227 | labelv = Variable(label.fill_(fake_label)) 
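        # fake.detach() keeps the generator out of the discriminator's backward pass here;
        # the non-detached `fake` is reused below when the generator itself is updated.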
228 | output = SND(fake.detach(), y_fill) 229 | errD_fake = torch.mean(F.softplus(output)) 230 | #errD_fake = criterion(output, labelv) 231 | errD_fake.backward() 232 | D_G_z1 = output.data.mean() 233 | errD = errD_real + errD_fake 234 | optimizerSND.step() 235 | 236 | ############################ 237 | # (2) Update G network: maximize log(D(G(z))) 238 | ########################### 239 | G.zero_grad() 240 | labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost 241 | output = SND(fake, y_fill) 242 | errG = torch.mean(F.softplus(-output)) 243 | #errG = criterion(output, labelv) 244 | errG.backward() 245 | D_G_z2 = output.data.mean() 246 | optimizerG.step() 247 | if i % 20 == 0: 248 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 249 | % (epoch, 200, i, len(dataloader), 250 | errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2)) 251 | if i % 100 == 0: 252 | vutils.save_image(real_cpu, 253 | '%s/real_samples.png' % 'log', 254 | normalize=True) 255 | fake = G(fixed_noise, fix_onehot) 256 | vutils.save_image(fake.data, 257 | '%s/fake_samples_epoch_%03d.png' % ('log', epoch), 258 | normalize=True) 259 | 260 | # do checkpointing 261 | torch.save(G.state_dict(), '%s/netG_epoch_%d.pth' % ('log', epoch)) 262 | torch.save(SND.state_dict(), '%s/netD_epoch_%d.pth' % ('log', epoch)) 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | -------------------------------------------------------------------------------- /train-res.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torchvision import datasets, transforms 6 | import torchvision.utils as vutils 7 | from torch.autograd import Variable 8 | import torch.utils.data 9 | import torch.backends.cudnn as cudnn 10 | 11 | import random 12 | import argparse 13 | from models.snres_generator import SNResGenerator 14 | from models.snres_discriminator import SNResDiscriminator 15 | 16 | parser = argparse.ArgumentParser(description='train SNDCGAN model') 17 | parser.add_argument('--dataPath', required=True, help='path to dataset') 18 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 19 | parser.add_argument('--gpu_ids', default=[0,1,2,3], help='gpu ids: e.g. 
0,1,2, 0,2.') 20 | parser.add_argument('--manualSeed', type=int, help='manual seed') 21 | parser.add_argument('--n_dis', type=int, default=1, help='discriminator critic iters') 22 | parser.add_argument('--nz', type=int, default=128, help='dimention of lantent noise') 23 | parser.add_argument('--batchsize', type=int, default=32, help='training batch size') 24 | 25 | 26 | opt = parser.parse_args() 27 | print(opt) 28 | 29 | dataset = datasets.ImageFolder(root=opt.dataPath, 30 | transform=transforms.Compose([ 31 | transforms.Scale(64), 32 | transforms.CenterCrop(64), 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 35 | ]) 36 | ) 37 | ''' 38 | dataset = datasets.CIFAR10(root='dataset', download=True, 39 | transform=transforms.Compose([ 40 | transforms.Scale(32), 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 43 | ])) 44 | ''' 45 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchsize, 46 | shuffle=True, num_workers=int(2)) 47 | 48 | if opt.manualSeed is None: 49 | opt.manualSeed = random.randint(1, 10000) 50 | print("Random Seed: ", opt.manualSeed) 51 | random.seed(opt.manualSeed) 52 | torch.manual_seed(opt.manualSeed) 53 | 54 | if opt.cuda: 55 | torch.cuda.manual_seed_all(opt.manualSeed) 56 | torch.cuda.set_device(opt.gpu_ids[0]) 57 | 58 | cudnn.benchmark = True 59 | 60 | def weight_filler(m): 61 | classname = m.__class__.__name__ 62 | if classname.find('Conv' or 'SNConv') != -1: 63 | m.weight.data.normal_(0.0, 0.02) 64 | elif classname.find('BatchNorm') != -1: 65 | m.weight.data.normal_(1.0, 0.02) 66 | m.bias.data.fill_(0) 67 | 68 | n_dis = opt.n_dis 69 | nz = opt.nz 70 | 71 | G = SNResGenerator(64, nz, 4) 72 | SND = SNResDiscriminator(64, 4) 73 | print(G) 74 | print(SND) 75 | G.apply(weight_filler) 76 | SND.apply(weight_filler) 77 | 78 | input = torch.FloatTensor(opt.batchsize, 3, 64, 64) 79 | noise = torch.FloatTensor(opt.batchsize, nz) 80 | fixed_noise = torch.FloatTensor(opt.batchsize, nz).normal_(0, 1) 81 | label = torch.FloatTensor(opt.batchsize) 82 | real_label = 1 83 | fake_label = 0 84 | 85 | fixed_noise = Variable(fixed_noise) 86 | criterion = nn.BCELoss() 87 | 88 | if opt.cuda: 89 | G.cuda() 90 | SND.cuda() 91 | criterion.cuda() 92 | input, label = input.cuda(), label.cuda() 93 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 94 | 95 | optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999)) 96 | optimizerSND = optim.Adam(SND.parameters(), lr=0.0002, betas=(0.5, 0.999)) 97 | 98 | for epoch in range(200): 99 | for i, data in enumerate(dataloader, 0): 100 | step = epoch * len(dataloader) + i 101 | ############################ 102 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 103 | ########################### 104 | # train with real 105 | SND.zero_grad() 106 | real_cpu, _ = data 107 | batch_size = real_cpu.size(0) 108 | #if opt.cuda: 109 | # real_cpu = real_cpu.cuda() 110 | input.resize_(real_cpu.size()).copy_(real_cpu) 111 | label.resize_(batch_size).fill_(real_label) 112 | inputv = Variable(input) 113 | labelv = Variable(label) 114 | output = SND(inputv) 115 | 116 | #errD_real = torch.mean(F.softplus(-output)) 117 | errD_real = criterion(output, labelv) 118 | errD_real.backward() 119 | 120 | D_x = output.data.mean() 121 | # train with fake 122 | noise.resize_(batch_size, nz).normal_(0, 1) 123 | noisev = Variable(noise) 124 | fake = G(noisev) 125 | labelv = Variable(label.fill_(fake_label)) 126 | output = SND(fake.detach()) 127 | #errD_fake = 
torch.mean(F.softplus(output)) 128 | errD_fake = criterion(output, labelv) 129 | errD_fake.backward() 130 | D_G_z1 = output.data.mean() 131 | errD = errD_real + errD_fake 132 | 133 | optimizerSND.step() 134 | ############################ 135 | # (2) Update G network: maximize log(D(G(z))) 136 | ########################### 137 | if step % n_dis == 0: 138 | G.zero_grad() 139 | labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost 140 | output = SND(fake) 141 | #errG = torch.mean(F.softplus(-output)) 142 | errG = criterion(output, labelv) 143 | errG.backward() 144 | D_G_z2 = output.data.mean() 145 | optimizerG.step() 146 | if i % 20 == 0: 147 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 148 | % (epoch, 200, i, len(dataloader), 149 | errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2)) 150 | if i % 100 == 0: 151 | vutils.save_image(real_cpu, 152 | '%s/real_samples.png' % 'log', 153 | normalize=True) 154 | fake = G(fixed_noise) 155 | vutils.save_image(fake.data, 156 | '%s/fake_samples_epoch_%03d.png' % ('log', epoch), 157 | normalize=True) 158 | 159 | # do checkpointing 160 | torch.save(G.state_dict(), '%s/netG_epoch_%d.pth' % ('log', epoch)) 161 | torch.save(SND.state_dict(), '%s/netD_epoch_%d.pth' % ('log', epoch)) 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torchvision import datasets, transforms 6 | import torchvision.utils as vutils 7 | from torch.autograd import Variable 8 | import torch.utils.data 9 | import torch.backends.cudnn as cudnn 10 | 11 | import random 12 | import argparse 13 | from models.models import _netG, _netD 14 | 15 | parser = argparse.ArgumentParser(description='train SNDCGAN model') 16 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 17 | parser.add_argument('--gpu_ids', default=[0,1,2,3], help='gpu ids: e.g. 
0,1,2, 0,2.') 18 | parser.add_argument('--manualSeed', type=int, help='manual seed') 19 | parser.add_argument('--n_dis', type=int, default=1, help='discriminator critic iters') 20 | parser.add_argument('--nz', type=int, default=128, help='dimention of lantent noise') 21 | parser.add_argument('--batchsize', type=int, default=64, help='training batch size') 22 | 23 | opt = parser.parse_args() 24 | print(opt) 25 | 26 | # dataset = datasets.ImageFolder(root='/home/chao/zero/datasets/cfp-dataset/Data/Images', 27 | # transform=transforms.Compose([ 28 | # transforms.Scale(32), 29 | # transforms.CenterCrop(32), 30 | # transforms.ToTensor(), 31 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 32 | # ]) 33 | # ) 34 | 35 | dataset = datasets.CIFAR10(root='dataset', download=True, 36 | transform=transforms.Compose([ 37 | transforms.Scale(32), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 40 | ])) 41 | 42 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchsize, 43 | shuffle=True, num_workers=int(2)) 44 | 45 | if opt.manualSeed is None: 46 | opt.manualSeed = random.randint(1, 10000) 47 | print("Random Seed: ", opt.manualSeed) 48 | random.seed(opt.manualSeed) 49 | torch.manual_seed(opt.manualSeed) 50 | 51 | if opt.cuda: 52 | torch.cuda.manual_seed_all(opt.manualSeed) 53 | torch.cuda.set_device(opt.gpu_ids[0]) 54 | 55 | cudnn.benchmark = True 56 | 57 | def weight_filler(m): 58 | classname = m.__class__.__name__ 59 | if classname.find('Conv' or 'SNConv') != -1: 60 | m.weight.data.normal_(0.0, 0.02) 61 | elif classname.find('BatchNorm') != -1: 62 | m.weight.data.normal_(1.0, 0.02) 63 | m.bias.data.fill_(0) 64 | 65 | n_dis = opt.n_dis 66 | nz = opt.nz 67 | 68 | G = _netG(nz, 3, 64) 69 | SND = _netD(3, 64) 70 | print(G) 71 | print(SND) 72 | G.apply(weight_filler) 73 | SND.apply(weight_filler) 74 | 75 | input = torch.FloatTensor(opt.batchsize, 3, 32, 32) 76 | noise = torch.FloatTensor(opt.batchsize, nz, 1, 1) 77 | fixed_noise = torch.FloatTensor(opt.batchsize, nz, 1, 1).normal_(0, 1) 78 | label = torch.FloatTensor(opt.batchsize) 79 | real_label = 1 80 | fake_label = 0 81 | 82 | fixed_noise = Variable(fixed_noise) 83 | criterion = nn.BCELoss() 84 | 85 | if opt.cuda: 86 | G.cuda() 87 | SND.cuda() 88 | criterion.cuda() 89 | input, label = input.cuda(), label.cuda() 90 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 91 | 92 | optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0, 0.9)) 93 | optimizerSND = optim.Adam(SND.parameters(), lr=0.0002, betas=(0, 0.9)) 94 | 95 | for epoch in range(200): 96 | for i, data in enumerate(dataloader, 0): 97 | step = epoch * len(dataloader) + i 98 | ############################ 99 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 100 | ########################### 101 | # train with real 102 | SND.zero_grad() 103 | real_cpu, _ = data 104 | batch_size = real_cpu.size(0) 105 | #if opt.cuda: 106 | # real_cpu = real_cpu.cuda() 107 | input.resize_(real_cpu.size()).copy_(real_cpu) 108 | label.resize_(batch_size).fill_(real_label) 109 | inputv = Variable(input) 110 | labelv = Variable(label) 111 | output = SND(inputv) 112 | 113 | #errD_real = torch.mean(F.softplus(-output)) 114 | errD_real = criterion(output, labelv) 115 | errD_real.backward() 116 | 117 | D_x = output.data.mean() 118 | # train with fake 119 | noise.resize_(batch_size, noise.size(1), noise.size(2), noise.size(3)).normal_(0, 1) 120 | noisev = Variable(noise) 121 | fake = G(noisev) 122 | labelv = 
Variable(label.fill_(fake_label)) 123 | output = SND(fake.detach()) 124 | #errD_fake = torch.mean(F.softplus(output)) 125 | errD_fake = criterion(output, labelv) 126 | errD_fake.backward() 127 | D_G_z1 = output.data.mean() 128 | errD = errD_real + errD_fake 129 | 130 | optimizerSND.step() 131 | ############################ 132 | # (2) Update G network: maximize log(D(G(z))) 133 | ########################### 134 | if step % n_dis == 0: 135 | G.zero_grad() 136 | labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost 137 | output = SND(fake) 138 | #errG = torch.mean(F.softplus(-output)) 139 | errG = criterion(output, labelv) 140 | errG.backward() 141 | D_G_z2 = output.data.mean() 142 | optimizerG.step() 143 | if i % 20 == 0: 144 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 145 | % (epoch, 200, i, len(dataloader), 146 | errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2)) 147 | if i % 100 == 0: 148 | vutils.save_image(real_cpu, 149 | '%s/real_samples.png' % 'log', 150 | normalize=True) 151 | fake = G(fixed_noise) 152 | vutils.save_image(fake.data, 153 | '%s/fake_samples_epoch_%03d.png' % ('log', epoch), 154 | normalize=True) 155 | 156 | # do checkpointing 157 | torch.save(G.state_dict(), '%s/netG_epoch_%d.pth' % ('log', epoch)) 158 | torch.save(SND.state_dict(), '%s/netD_epoch_%d.pth' % ('log', epoch)) 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | --------------------------------------------------------------------------------
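For completeness, a minimal sketch of how samples could be drawn from a generator checkpoint written by `train.py` (the checkpoint path, epoch number, and sample count below are illustrative; `test.py` above covers the conditional case):

```python
import torch
import torchvision.utils as vutils
from models.models import _netG

nz = 128                      # matches the --nz default in train.py
G = _netG(nz, 3, 64)          # same constructor call train.py uses
G.load_state_dict(torch.load('log/netG_epoch_199.pth', map_location='cpu'))
G.eval()

with torch.no_grad():
    z = torch.randn(64, nz, 1, 1)   # one latent vector per sample
    samples = G(z)                  # (64, 3, 32, 32) images in [-1, 1]
vutils.save_image(samples, 'log/samples_from_checkpoint.png', normalize=True)
```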