├── README.md ├── cgan.py ├── dcgan.py ├── images ├── cGAN │ ├── 0.png │ ├── 10.png │ ├── 100.png │ ├── 150.png │ ├── 199.png │ ├── 20.png │ ├── 30.png │ ├── 40.png │ └── 50.png ├── dcgan │ ├── 0.png │ ├── 10.png │ ├── 20.png │ ├── 30.png │ ├── 40.png │ ├── 50.png │ ├── 60.png │ ├── 70.png │ ├── 80.png │ ├── 90.png │ └── 99.png ├── improved_cgan │ ├── 0.png │ ├── 10.png │ ├── 100.png │ ├── 150.png │ ├── 199.png │ ├── 20.png │ ├── 30.png │ ├── 40.png │ └── 50.png ├── vanilla_gan │ ├── 0.png │ ├── 10.png │ ├── 100.png │ ├── 150.png │ ├── 18.png │ ├── 199.png │ ├── 20.png │ ├── 30.png │ ├── 40.png │ └── 50.png ├── wgan-gp │ ├── 0.png │ ├── 10.png │ ├── 100.png │ ├── 150.png │ ├── 199.png │ ├── 20.png │ ├── 30.png │ ├── 40.png │ └── 50.png └── wgan │ ├── 0.png │ ├── 10.png │ ├── 100.png │ ├── 150.png │ ├── 199.png │ ├── 20.png │ ├── 30.png │ ├── 40.png │ └── 50.png ├── improved_cgan.py ├── vanilla_gan.py ├── wgan-gp.py └── wgan.py /README.md: -------------------------------------------------------------------------------- 1 | ## GANs 2 | Simple PyTorch implementations of the most commonly used Generative Adversarial Network (GAN) variants. 3 | 4 | ## GPU or CPU 5 | Supports both GPU and CPU. 
6 | 7 | ## Dependencies 8 | * [Anaconda (Python 2.7)](https://www.anaconda.com/download/) 9 | * [PyTorch 0.4.0](http://pytorch.org/) 10 | 11 | ## Table of Contents 12 | * [Vanilla GAN (GAN)](https://arxiv.org/pdf/1406.2661.pdf) 13 | * [Conditional GAN (cGAN)](https://arxiv.org/pdf/1411.1784.pdf) 14 | * Improved Conditional GAN (Improved cGAN) 15 | * [Deep Convolutional GAN (DCGAN)](https://arxiv.org/pdf/1511.06434.pdf) 16 | * [Wasserstein GAN (WGAN)](https://arxiv.org/pdf/1701.07875.pdf) 17 | * [Improved Training of Wasserstein GANs (WGAN-GP)](https://arxiv.org/pdf/1704.00028.pdf) 18 | 19 | ## Experiment Results 20 | 21 | ### Vanilla GAN (GAN) 22 | 23 | | epoch 0 | epoch 10 | epoch 20 | epoch 30 | epoch 40 | 24 | | :---: | :---: | :---: | :---: | :---: | 25 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/0.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/10.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/20.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/30.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/40.png?raw=true) | 26 | | epoch 50 | epoch 100 | epoch 150 | epoch 199 | - 27 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/50.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/100.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/150.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/vanilla_gan/199.png?raw=true) | - | 28 | 29 | --- 30 | 31 | ### Conditional GAN (cGAN) 32 | 33 | | epoch 0 | epoch 10 | epoch 20 | epoch 30 | epoch 40 | 34 | | :---: | :---: | :---: | :---: | :---: | 35 | | 
![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/0.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/10.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/20.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/30.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/40.png?raw=true) | 36 | | epoch 50 | epoch 100 | epoch 150 | epoch 199 | - 37 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/50.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/100.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/150.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/cGAN/199.png?raw=true) | - | 38 | 39 | --- 40 | 41 | ### Improved Conditional GAN (Improved cGAN) 42 | 43 | | epoch 0 | epoch 10 | epoch 20 | epoch 30 | epoch 40 | 44 | | :---: | :---: | :---: | :---: | :---: | 45 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/0.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/10.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/20.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/30.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/40.png?raw=true) | 46 | | epoch 50 | epoch 100 | epoch 150 | epoch 199 | - 47 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/50.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/100.png?raw=true) | 
![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/150.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/improved_cgan/199.png?raw=true) | - | 48 | 49 | 50 | ### Deep Convolutional GAN (DCGAN) 51 | 52 | | epoch 0 | epoch 10 | epoch 20 | epoch 30 | epoch 40 | 53 | | :---: | :---: | :---: | :---: | :---: | 54 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/0.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/10.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/20.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/30.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/40.png?raw=true) | 55 | | epoch 50 | epoch 60 | epoch 70 | epoch 80 | epoch 90 | 56 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/50.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/60.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/70.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/80.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/dcgan/90.png?raw=true) | 57 | 58 | 59 | ### Wasserstein GAN (WGAN) 60 | 61 | | epoch 0 | epoch 10 | epoch 20 | epoch 30 | epoch 40 | 62 | | :---: | :---: | :---: | :---: | :---: | 63 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/0.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/10.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/20.png?raw=true) | 
![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/30.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/40.png?raw=true) | 64 | | epoch 50 | epoch 100 | epoch 150 | epoch 199 | - | 65 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/50.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/100.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/150.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan/199.png?raw=true) | - | 66 | 67 | 68 | ### Wasserstein GAN with Gradient Penalty (WGAN-GP) 69 | 70 | | epoch 0 | epoch 10 | epoch 20 | epoch 30 | epoch 40 | 71 | | :---: | :---: | :---: | :---: | :---: | 72 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/0.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/10.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/20.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/30.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/40.png?raw=true) | 73 | | epoch 50 | epoch 100 | epoch 150 | epoch 199 | - | 74 | | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/50.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/100.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/150.png?raw=true) | ![xxx](https://github.com/wangguanan/Pytorch-Basic-GANs/blob/master/images/wgan-gp/199.png?raw=true) | - | 75 | 76 | 77 | ## Acknowledgement 78 | This project accompanies the [GAN Theory and 
Practice](https://study.163.com/course/courseLearn.htm?courseId=1006498024&share=2&shareId=400000000681046#/learn/live?lessonId=1054160393&courseId=1006498024) part of the [Deep Learning Course: from Algorithm to Practice](https://study.163.com/course/courseMain.htm?share=2&shareId=400000000681046&courseId=1006498024&_trace_c_p_k2_=d197343763ee421eae96c4cdb1b129cb). 79 | 80 | ## Contacts 81 | If you have any question about the project, please feel free to contact with me. 82 | 83 | E-mail: guan.wang0706@gmail.com 84 | 85 | 86 | # 87 | -------------------------------------------------------------------------------- /cgan.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | 9 | import numpy as np 10 | 11 | import os 12 | 13 | 14 | # 超参数 15 | gpu_id = None 16 | if gpu_id is not None: 17 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 18 | device = torch.device('cuda') 19 | else: 20 | device = torch.device('cpu') 21 | if os.path.exists('cgan_images') is False: 22 | os.makedirs('cgan_images') 23 | z_dim = 100 24 | batch_size = 64 25 | learning_rate = 0.0002 26 | total_epochs = 200 27 | 28 | 29 | class Discriminator(nn.Module): 30 | '''全连接判别器,用于1x28x28的MNIST数据,输出是数据和类别''' 31 | def __init__(self): 32 | super(Discriminator, self).__init__() 33 | 34 | layers = [] 35 | # 第一层 36 | layers.append(nn.Linear(in_features=28*28+10, out_features=512, bias=True)) 37 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 38 | # 第二层 39 | layers.append(nn.Linear(in_features=512, out_features=256, bias=True)) 40 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 41 | # 输出层 42 | layers.append(nn.Linear(in_features=256, out_features=1, bias=True)) 43 | layers.append(nn.Sigmoid()) 44 | 45 | self.model = nn.Sequential(*layers) 46 | 47 | def forward(self, x, c): 
48 | x = x.view(x.size(0), -1) 49 | validity = self.model(torch.cat([x, c], -1)) 50 | return validity 51 | 52 | class Generator(nn.Module): 53 | '''全连接生成器,用于1x28x28的MNIST数据,输入是噪声和类别''' 54 | def __init__(self, z_dim): 55 | super(Generator, self).__init__() 56 | 57 | layers = [] 58 | # 第一层 59 | layers.append(nn.Linear(in_features=z_dim+10, out_features=128)) 60 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 61 | # 第二层 62 | layers.append(nn.Linear(in_features=128, out_features=256)) 63 | layers.append(nn.BatchNorm1d(256, 0.8)) 64 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 65 | # 第三层 66 | layers.append(nn.Linear(in_features=256, out_features=512)) 67 | layers.append(nn.BatchNorm1d(512, 0.8)) 68 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 69 | # 输出层 70 | layers.append(nn.Linear(in_features=512, out_features=28*28)) 71 | layers.append(nn.Tanh()) 72 | 73 | self.model = nn.Sequential(*layers) 74 | 75 | def forward(self, z, c): 76 | x = self.model(torch.cat([z, c], dim=1)) 77 | x = x.view(-1, 1, 28, 28) 78 | return x 79 | 80 | 81 | def one_hot(labels, class_num): 82 | '''把标签转换成one-hot类型''' 83 | tmp = torch.FloatTensor(labels.size(0), class_num).zero_() 84 | one_hot = tmp.scatter_(dim=1, index=torch.LongTensor(labels.view(-1, 1)), value=1) 85 | return one_hot 86 | 87 | 88 | # 初始化构建判别器和生成器 89 | discriminator = Discriminator().to(device) 90 | generator = Generator(z_dim=z_dim).to(device) 91 | 92 | # 初始化二值交叉熵损失 93 | bce = torch.nn.BCELoss().to(device) 94 | ones = torch.ones(batch_size).to(device) 95 | zeros = torch.zeros(batch_size).to(device) 96 | 97 | # 初始化优化器,使用Adam优化器 98 | g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 99 | d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 100 | 101 | # 加载MNIST数据集 102 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 103 | dataset = torchvision.datasets.MNIST(root='data/', train=True, 
transform=transform, download=True) 104 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) 105 | 106 | #用于生成效果图 107 | # 生成100个one_hot向量,每类10个 108 | fixed_c = torch.FloatTensor(100, 10).zero_() 109 | fixed_c = fixed_c.scatter_(dim=1, index=torch.LongTensor(np.array(np.arange(0, 10).tolist()*10).reshape([100, 1])), value=1) 110 | fixed_c = fixed_c.to(device) 111 | # 生成100个随机噪声向量 112 | fixed_z = torch.randn([100, z_dim]).to(device) 113 | 114 | # 开始训练,一共训练total_epochs 115 | for epoch in range(total_epochs): 116 | 117 | # 在训练阶段,把生成器设置为训练模式;对应于后面的,在测试阶段,把生成器设置为测试模式 118 | generator = generator.train() 119 | 120 | # 训练一个epoch 121 | for i, data in enumerate(dataloader): 122 | 123 | # 加载真实数据 124 | real_images, real_labels = data 125 | real_images = real_images.to(device) 126 | # 把对应的标签转化成 one-hot 类型 127 | tmp = torch.FloatTensor(real_labels.size(0), 10).zero_() 128 | real_labels = tmp.scatter_(dim=1, index=torch.LongTensor(real_labels.view(-1, 1)), value=1) 129 | real_labels = real_labels.to(device) 130 | 131 | # 生成数据 132 | # 用正态分布中采样batch_size个随机噪声 133 | z = torch.randn([batch_size, z_dim]).to(device) 134 | # 生成 batch_size 个 ont-hot 标签 135 | c = torch.FloatTensor(batch_size, 10).zero_() 136 | c = c.scatter_(dim=1, index=torch.LongTensor(np.random.choice(10, batch_size).reshape([batch_size, 1])), value=1) 137 | c = c.to(device) 138 | # 生成数据 139 | fake_images = generator(z,c) 140 | 141 | # 计算判别器损失,并优化判别器 142 | real_loss = bce(discriminator(real_images, real_labels), ones) 143 | fake_loss = bce(discriminator(fake_images.detach(), c), zeros) 144 | d_loss = real_loss + fake_loss 145 | 146 | d_optimizer.zero_grad() 147 | d_loss.backward() 148 | d_optimizer.step() 149 | 150 | # 计算生成器损失,并优化生成器 151 | g_loss = bce(discriminator(fake_images, c), ones) 152 | 153 | g_optimizer.zero_grad() 154 | g_loss.backward() 155 | g_optimizer.step() 156 | 157 | # 输出损失 158 | print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, 
total_epochs, i, len(dataloader), d_loss.item(), g_loss.item())) 159 | 160 | # 把生成器设置为测试模型,生成效果图并保存 161 | generator = generator.eval() 162 | fixed_fake_images = generator(fixed_z, fixed_c) 163 | save_image(fixed_fake_images, 'cgan_images/{}.png'.format(epoch), nrow=10, normalize=True) 164 | 165 | -------------------------------------------------------------------------------- /dcgan.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | 9 | import os 10 | 11 | 12 | # 超参数 13 | gpu_id = None 14 | if gpu_id is not None: 15 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 16 | device = torch.device('cuda') 17 | else: 18 | device = torch.device('cpu') 19 | if os.path.exists('dcgan_images') is False: 20 | os.makedirs('dcgan_images') 21 | z_dim = 100 22 | batch_size = 64 23 | learning_rate = 0.0002 24 | total_epochs = 100 25 | 26 | 27 | def weights_init_normal(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv') != -1: 30 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 31 | elif classname.find('BatchNorm2d') != -1: 32 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 33 | torch.nn.init.constant_(m.bias.data, 0.0) 34 | 35 | class Discriminator(nn.Module): 36 | '''滑动卷积判别器''' 37 | def __init__(self): 38 | super(Discriminator, self).__init__() 39 | 40 | # 定义卷积层 41 | conv = [] 42 | # 第一个滑动卷积层,不使用BN,LRelu激活函数 43 | conv.append(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1)) 44 | conv.append(nn.LeakyReLU(0.2, inplace=True)) 45 | # 第二个滑动卷积层,包含BN,LRelu激活函数 46 | conv.append(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)) 47 | conv.append(nn.BatchNorm2d(32)) 48 | conv.append(nn.LeakyReLU(0.2, inplace=True)) 49 | # 第三个滑动卷积层,包含BN,LRelu激活函数 50 | 
conv.append(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)) 51 | conv.append(nn.BatchNorm2d(64)) 52 | conv.append(nn.LeakyReLU(0.2, inplace=True)) 53 | # 第四个滑动卷积层,包含BN,LRelu激活函数 54 | conv.append(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=1)) 55 | conv.append(nn.BatchNorm2d(128)) 56 | conv.append(nn.LeakyReLU(0.2, inplace=True)) 57 | # 卷积层 58 | self.conv = nn.Sequential(*conv) 59 | 60 | # 全连接层+Sigmoid激活函数 61 | self.linear = nn.Sequential(nn.Linear(in_features=128, out_features=1), nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | x = self.conv(x) 65 | x = x.view(x.size(0), -1) 66 | validity = self.linear(x) 67 | return validity 68 | 69 | class Generator(nn.Module): 70 | '''反滑动卷积生成器''' 71 | 72 | def __init__(self, z_dim): 73 | super(Generator, self).__init__() 74 | 75 | self.z_dim = z_dim 76 | layers = [] 77 | 78 | # 第一层:把输入线性变换成256x4x4的矩阵,并在这个基础上做反卷机操作 79 | self.linear = nn.Linear(self.z_dim, 4*4*256) 80 | # 第二层:bn+relu 81 | layers.append(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=0)) 82 | layers.append(nn.BatchNorm2d(128)) 83 | layers.append(nn.ReLU(inplace=True)) 84 | # 第三层:bn+relu 85 | layers.append(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1)) 86 | layers.append(nn.BatchNorm2d(64)) 87 | layers.append(nn.ReLU(inplace=True)) 88 | # 第四层:不使用BN,使用tanh激活函数 89 | layers.append(nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=2, padding=2)) 90 | layers.append(nn.Tanh()) 91 | 92 | self.model = nn.Sequential(*layers) 93 | 94 | def forward(self, z): 95 | # 把随机噪声经过线性变换,resize成256x4x4的大小 96 | x = self.linear(z) 97 | x = x.view([x.size(0), 256, 4, 4]) 98 | # 生成图片 99 | x = self.model(x) 100 | return x 101 | 102 | # 构建判别器和生成器 103 | discriminator = Discriminator().to(device) 104 | generator = Generator(z_dim=z_dim).to(device) 105 | # 使用均值为0,方差为0.02的正态分布初始化神经网络 106 | generator.apply(weights_init_normal) 107 | 
discriminator.apply(weights_init_normal) 108 | 109 | # 初始化二值交叉熵损失 110 | bce = torch.nn.BCELoss().to(device) 111 | ones = torch.ones(batch_size).to(device) 112 | zeros = torch.zeros(batch_size).to(device) 113 | 114 | # 初始化优化器,使用Adam优化器 115 | g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 116 | d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 117 | 118 | # 加载MNIST数据集 119 | transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 120 | dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) 121 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) 122 | 123 | # 随机产生100个向量,用于生成效果图 124 | fixed_z = torch.randn([100, z_dim]).to(device) 125 | 126 | # 开始训练,一共训练total_epochs 127 | for epoch in range(total_epochs): 128 | 129 | # 在训练阶段,把生成器设置为训练模型;对应于后面的,在测试阶段,把生成器设置为测试模型 130 | generator = generator.train() 131 | 132 | # 训练一个epoch 133 | for i, data in enumerate(dataloader): 134 | 135 | # 加载真实数据,不加载标签 136 | real_images, _ = data 137 | real_images = real_images.to(device) 138 | 139 | # 用正态分布中采样batch_size个噪声,然后生成对应的图片 140 | z = torch.randn([batch_size, z_dim]).to(device) 141 | fake_images = generator(z) 142 | 143 | # 计算判别器损失,并优化判别器 144 | real_loss = bce(discriminator(real_images), ones) 145 | fake_loss = bce(discriminator(fake_images.detach()), zeros) 146 | d_loss = real_loss + fake_loss 147 | 148 | d_optimizer.zero_grad() 149 | d_loss.backward() 150 | d_optimizer.step() 151 | 152 | # 计算生成器损失,并优化生成器 153 | g_loss = bce(discriminator(fake_images), ones) 154 | 155 | g_optimizer.zero_grad() 156 | g_loss.backward() 157 | g_optimizer.step() 158 | 159 | # 输出损失 160 | print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, total_epochs, i, len(dataloader), d_loss.item(), g_loss.item())) 161 | 162 | # 把生成器设置为测试模型,生成效果图并保存 
163 | generator = generator.eval() 164 | fixed_fake_images = generator(fixed_z) 165 | save_image(fixed_fake_images, 'dcgan_images/{}.png'.format(epoch), nrow=10, normalize=True) 166 | 167 | -------------------------------------------------------------------------------- /images/cGAN/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/0.png -------------------------------------------------------------------------------- /images/cGAN/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/10.png -------------------------------------------------------------------------------- /images/cGAN/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/100.png -------------------------------------------------------------------------------- /images/cGAN/150.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/150.png -------------------------------------------------------------------------------- /images/cGAN/199.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/199.png -------------------------------------------------------------------------------- /images/cGAN/20.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/20.png -------------------------------------------------------------------------------- /images/cGAN/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/30.png -------------------------------------------------------------------------------- /images/cGAN/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/40.png -------------------------------------------------------------------------------- /images/cGAN/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/cGAN/50.png -------------------------------------------------------------------------------- /images/dcgan/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/0.png -------------------------------------------------------------------------------- /images/dcgan/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/10.png -------------------------------------------------------------------------------- /images/dcgan/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/20.png 
-------------------------------------------------------------------------------- /images/dcgan/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/30.png -------------------------------------------------------------------------------- /images/dcgan/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/40.png -------------------------------------------------------------------------------- /images/dcgan/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/50.png -------------------------------------------------------------------------------- /images/dcgan/60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/60.png -------------------------------------------------------------------------------- /images/dcgan/70.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/70.png -------------------------------------------------------------------------------- /images/dcgan/80.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/80.png -------------------------------------------------------------------------------- /images/dcgan/90.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/90.png -------------------------------------------------------------------------------- /images/dcgan/99.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/dcgan/99.png -------------------------------------------------------------------------------- /images/improved_cgan/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/0.png -------------------------------------------------------------------------------- /images/improved_cgan/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/10.png -------------------------------------------------------------------------------- /images/improved_cgan/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/100.png -------------------------------------------------------------------------------- /images/improved_cgan/150.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/150.png -------------------------------------------------------------------------------- /images/improved_cgan/199.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/199.png -------------------------------------------------------------------------------- /images/improved_cgan/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/20.png -------------------------------------------------------------------------------- /images/improved_cgan/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/30.png -------------------------------------------------------------------------------- /images/improved_cgan/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/40.png -------------------------------------------------------------------------------- /images/improved_cgan/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/improved_cgan/50.png -------------------------------------------------------------------------------- /images/vanilla_gan/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/0.png -------------------------------------------------------------------------------- /images/vanilla_gan/10.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/10.png -------------------------------------------------------------------------------- /images/vanilla_gan/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/100.png -------------------------------------------------------------------------------- /images/vanilla_gan/150.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/150.png -------------------------------------------------------------------------------- /images/vanilla_gan/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/18.png -------------------------------------------------------------------------------- /images/vanilla_gan/199.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/199.png -------------------------------------------------------------------------------- /images/vanilla_gan/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/20.png -------------------------------------------------------------------------------- /images/vanilla_gan/30.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/30.png -------------------------------------------------------------------------------- /images/vanilla_gan/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/40.png -------------------------------------------------------------------------------- /images/vanilla_gan/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/vanilla_gan/50.png -------------------------------------------------------------------------------- /images/wgan-gp/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/0.png -------------------------------------------------------------------------------- /images/wgan-gp/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/10.png -------------------------------------------------------------------------------- /images/wgan-gp/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/100.png -------------------------------------------------------------------------------- /images/wgan-gp/150.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/150.png -------------------------------------------------------------------------------- /images/wgan-gp/199.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/199.png -------------------------------------------------------------------------------- /images/wgan-gp/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/20.png -------------------------------------------------------------------------------- /images/wgan-gp/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/30.png -------------------------------------------------------------------------------- /images/wgan-gp/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/40.png -------------------------------------------------------------------------------- /images/wgan-gp/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan-gp/50.png -------------------------------------------------------------------------------- /images/wgan/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/0.png 
-------------------------------------------------------------------------------- /images/wgan/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/10.png -------------------------------------------------------------------------------- /images/wgan/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/100.png -------------------------------------------------------------------------------- /images/wgan/150.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/150.png -------------------------------------------------------------------------------- /images/wgan/199.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/199.png -------------------------------------------------------------------------------- /images/wgan/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/20.png -------------------------------------------------------------------------------- /images/wgan/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/30.png -------------------------------------------------------------------------------- /images/wgan/40.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/40.png -------------------------------------------------------------------------------- /images/wgan/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangguanan/Pytorch-Basic-GANs/74c2f8b52be94e943c447fef82313a4152028f4a/images/wgan/50.png -------------------------------------------------------------------------------- /improved_cgan.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | 9 | import numpy as np 10 | 11 | import os 12 | 13 | 14 | # 超参数 15 | gpu_id = None 16 | if gpu_id is not None: 17 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 18 | device = torch.device('cuda') 19 | else: 20 | device = torch.device('cpu') 21 | if os.path.exists('improved_cgan_images') is False: 22 | os.makedirs('improved_cgan_images') 23 | z_dim = 100 24 | batch_size = 64 25 | learning_rate = 0.0002 26 | total_epochs = 200 27 | 28 | 29 | class Discriminator(nn.Module): 30 | '''全连接判别器,用于1x28x28的MNIST数据,输出是数据和类别''' 31 | def __init__(self): 32 | super(Discriminator, self).__init__() 33 | 34 | layers = [] 35 | # 第一层 36 | layers.append(nn.Linear(in_features=28*28+10, out_features=512, bias=True)) 37 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 38 | # 第二层 39 | layers.append(nn.Linear(in_features=512, out_features=256, bias=True)) 40 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 41 | # 输出层 42 | layers.append(nn.Linear(in_features=256, out_features=1, bias=True)) 43 | layers.append(nn.Sigmoid()) 44 | 45 | self.model = nn.Sequential(*layers) 46 | 47 | def forward(self, x, c): 
48 | x = x.view(x.size(0), -1) 49 | validity = self.model(torch.cat([x, c], -1)) 50 | return validity 51 | 52 | class Generator(nn.Module): 53 | '''全连接生成器,用于1x28x28的MNIST数据,输入是噪声和类别''' 54 | def __init__(self, z_dim): 55 | super(Generator, self).__init__() 56 | 57 | layers = [] 58 | # 第一层 59 | layers.append(nn.Linear(in_features=z_dim+10, out_features=128)) 60 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 61 | # 第二层 62 | layers.append(nn.Linear(in_features=128, out_features=256)) 63 | layers.append(nn.BatchNorm1d(256, 0.8)) 64 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 65 | # 第三层 66 | layers.append(nn.Linear(in_features=256, out_features=512)) 67 | layers.append(nn.BatchNorm1d(512, 0.8)) 68 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 69 | # 输出层 70 | layers.append(nn.Linear(in_features=512, out_features=28*28)) 71 | layers.append(nn.Tanh()) 72 | 73 | self.model = nn.Sequential(*layers) 74 | 75 | def forward(self, z, c): 76 | x = self.model(torch.cat([z, c], dim=1)) 77 | x = x.view(-1, 1, 28, 28) 78 | return x 79 | 80 | 81 | def one_hot(labels, class_num): 82 | '''把标签转换成one-hot类型''' 83 | tmp = torch.FloatTensor(labels.size(0), class_num).zero_() 84 | one_hot = tmp.scatter_(dim=1, index=torch.LongTensor(labels.view(-1, 1)), value=1) 85 | return one_hot 86 | 87 | 88 | # 初始化构建判别器和生成器 89 | discriminator = Discriminator().to(device) 90 | generator = Generator(z_dim=z_dim).to(device) 91 | 92 | # 初始化二值交叉熵损失 93 | bce = torch.nn.BCELoss().to(device) 94 | ones = torch.ones(batch_size).to(device) 95 | zeros = torch.zeros(batch_size).to(device) 96 | 97 | # 初始化优化器,使用Adam优化器 98 | g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 99 | d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 100 | 101 | # 加载MNIST数据集 102 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 103 | dataset = torchvision.datasets.MNIST(root='data/', train=True, 
transform=transform, download=True) 104 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) 105 | 106 | #用于生成效果图 107 | # 生成100个one_hot向量,每类10个 108 | fixed_c = torch.FloatTensor(100, 10).zero_() 109 | fixed_c = fixed_c.scatter_(dim=1, index=torch.LongTensor(np.array(np.arange(0,10).tolist()*10).reshape([100, 1])), value=1) 110 | fixed_c = fixed_c.to(device) 111 | # 生成100个随机噪声向量 112 | fixed_z = torch.randn([100, z_dim]).to(device) 113 | 114 | # 开始训练,一共训练total_epochs 115 | for epoch in range(total_epochs): 116 | 117 | # 在训练阶段,把生成器设置为训练模式;对应于后面的,在测试阶段,把生成器设置为测试模式 118 | generator = generator.train() 119 | 120 | # 训练一个epoch 121 | for i, data in enumerate(dataloader): 122 | 123 | # 加载真实数据 124 | real_images, real_labels = data 125 | wrong_labels = [] 126 | for real_label in real_labels: 127 | tmp = np.arange(10).tolist() 128 | tmp.remove(float(real_label.data)) 129 | wrong_labels.append(np.random.choice(tmp, 1)[0]) 130 | wrong_labels = torch.LongTensor(wrong_labels) 131 | # 132 | real_images = real_images.to(device) 133 | # 把对应的标签转化成 one-hot 类型 134 | tmp = torch.FloatTensor(real_labels.size(0), 10).zero_() 135 | real_labels = tmp.scatter_(dim=1, index=torch.LongTensor(real_labels.view(-1, 1)), value=1) 136 | tmp = torch.FloatTensor(real_labels.size(0), 10).zero_() 137 | wrong_labels = tmp.scatter_(dim=1, index=torch.LongTensor(wrong_labels.view(-1, 1)), value=1) 138 | real_labels = real_labels.to(device) 139 | wrong_labels = wrong_labels.to(device) 140 | 141 | # 生成数据 142 | # 用正态分布中采样batch_size个随机噪声 143 | z = torch.randn([batch_size, z_dim]).to(device) 144 | # 生成 batch_size 个 ont-hot 标签 145 | c = torch.FloatTensor(batch_size, 10).zero_() 146 | c = c.scatter_(dim=1, index=torch.LongTensor(np.random.choice(10, batch_size).reshape([batch_size, 1])), value=1) 147 | c = c.to(device) 148 | # 生成数据 149 | fake_images = generator(z,c) 150 | 151 | # 计算判别器损失,并优化判别器 152 | real_loss = bce(discriminator(real_images, real_labels), ones) 
153 | fake_loss_1 = bce(discriminator(real_images, wrong_labels), zeros) 154 | fake_loss_2 = bce(discriminator(fake_images.detach(), c), zeros) 155 | d_loss = real_loss + fake_loss_1 + fake_loss_2 156 | 157 | d_optimizer.zero_grad() 158 | d_loss.backward() 159 | d_optimizer.step() 160 | 161 | # 计算生成器损失,并优化生成器 162 | g_loss = bce(discriminator(fake_images, c), ones) 163 | 164 | g_optimizer.zero_grad() 165 | g_loss.backward() 166 | g_optimizer.step() 167 | 168 | # 输出损失 169 | print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, total_epochs, i, len(dataloader), d_loss.item(), g_loss.item())) 170 | 171 | # 把生成器设置为测试模型,生成效果图并保存 172 | generator = generator.eval() 173 | fixed_fake_images = generator(fixed_z, fixed_c) 174 | save_image(fixed_fake_images, 'improved_cgan_images/{}.png'.format(epoch), nrow=10, normalize=True) 175 | 176 | -------------------------------------------------------------------------------- /vanilla_gan.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | 9 | import os 10 | 11 | 12 | ########################################################## 13 | # 超参数 14 | ########################################################## 15 | gpu_id = None #‘0’ 16 | if gpu_id is not None: 17 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 18 | device = torch.device('cuda') 19 | else: 20 | device = torch.device('cpu') 21 | if os.path.exists('gan_images') is False: 22 | os.makedirs('gan_images') 23 | z_dim = 100 24 | batch_size = 64 25 | learning_rate = 0.0002 26 | total_epochs = 200 27 | 28 | 29 | ########################################################## 30 | # 定义模型 31 | ########################################################## 32 | class Discriminator(nn.Module): 33 | '''全连接判别器,用于1x28x28的MNIST数据''' 34 | def 
__init__(self): 35 | super(Discriminator, self).__init__() 36 | 37 | layers = [] 38 | # 第一层 39 | layers.append(nn.Linear(in_features=28*28, out_features=512, bias=True)) 40 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 41 | # 第二层 42 | layers.append(nn.Linear(in_features=512, out_features=256, bias=True)) 43 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 44 | # 输出层 45 | layers.append(nn.Linear(in_features=256, out_features=1, bias=True)) 46 | layers.append(nn.Sigmoid()) 47 | 48 | self.model = nn.Sequential(*layers) 49 | 50 | def forward(self, x): 51 | x = x.view(x.size(0), -1) 52 | validity = self.model(x) 53 | return validity 54 | 55 | class Generator(nn.Module): 56 | '''全连接生成器,用于1x28x28的MNIST数据''' 57 | def __init__(self, z_dim): 58 | super(Generator, self).__init__() 59 | 60 | layers= [] 61 | # 第一层 62 | layers.append(nn.Linear(in_features=z_dim, out_features=128)) 63 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 64 | # 第二层 65 | layers.append(nn.Linear(in_features=128, out_features=256)) 66 | layers.append(nn.BatchNorm1d(256, 0.8)) 67 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 68 | # 第三层 69 | layers.append(nn.Linear(in_features=256, out_features=512)) 70 | layers.append(nn.BatchNorm1d(512, 0.8)) 71 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 72 | # 输出层 73 | layers.append(nn.Linear(in_features=512, out_features=28*28)) 74 | layers.append(nn.Tanh()) #[-1,1] 75 | 76 | self.model = nn.Sequential(*layers) 77 | 78 | def forward(self, z): 79 | x = self.model(z) 80 | x = x.view(-1, 1, 28, 28) 81 | return x 82 | 83 | # 初始化构建判别器和生成器 84 | discriminator = Discriminator().to(device) 85 | generator = Generator(z_dim=z_dim).to(device) 86 | 87 | 88 | ########################################################## 89 | # 准备工作 90 | ########################################################## 91 | # 初始化二值交叉熵损失 92 | bce = nn.BCELoss().to(device) 93 | ones = torch.ones(batch_size).to(device) 94 | zeros = torch.zeros(batch_size).to(device) 95 | 96 | # 初始化优化器,使用Adam优化器 97 | 
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 98 | d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=[0.5, 0.999]) 99 | 100 | # 加载MNIST数据集 101 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 102 | dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) 103 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) 104 | 105 | # 随机产生100个向量,用于生成效果图 106 | fixed_z = torch.randn([100, z_dim]).to(device) 107 | 108 | 109 | ########################################################## 110 | # 开始训练,一共训练total_epochs 111 | ########################################################## 112 | for epoch in range(total_epochs): 113 | 114 | # 在训练阶段,把生成器设置为训练模式;对应于后面的,在测试阶段,把生成器设置为测试模式 115 | generator = generator.train() 116 | 117 | # 训练一个epoch 118 | for i, data in enumerate(dataloader): 119 | 120 | # 加载真实数据,不加载标签 121 | real_images, _ = data 122 | real_images = real_images.to(device) 123 | 124 | # 从正态分布中采样batch_size个噪声,然后生成对应的图片 125 | z = torch.randn([batch_size, z_dim]).to(device) 126 | fake_images = generator(z) 127 | 128 | # 计算判别器损失,并优化判别器 129 | real_loss = bce(discriminator(real_images), ones) 130 | fake_loss = bce(discriminator(fake_images.detach()), zeros) 131 | d_loss = real_loss + fake_loss 132 | 133 | d_optimizer.zero_grad() 134 | d_loss.backward() 135 | d_optimizer.step() 136 | 137 | # 计算生成器损失,并优化生成器 138 | g_loss = bce(discriminator(fake_images), ones) 139 | 140 | g_optimizer.zero_grad() 141 | g_loss.backward() 142 | g_optimizer.step() 143 | 144 | # 输出损失 145 | print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, total_epochs, i, len(dataloader), d_loss.item(), g_loss.item())) 146 | 147 | # 把生成器设置为测试模型,生成效果图并保存 148 | generator = generator.eval() 149 | fixed_fake_images = generator(fixed_z) 150 | save_image(fixed_fake_images, 
'gan_images/{}.png'.format(epoch), nrow=10, normalize=True) 151 | 152 | -------------------------------------------------------------------------------- /wgan-gp.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.autograd as autograd 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from torchvision.utils import save_image 9 | 10 | import numpy as np 11 | 12 | import os 13 | 14 | 15 | # 超参数 16 | gpu_id = None 17 | if gpu_id is not None: 18 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 19 | device = torch.device('cuda') 20 | else: 21 | device = torch.device('cpu') 22 | if os.path.exists('wgangp_images') is False: 23 | os.makedirs('wgangp_images') 24 | z_dim = 100 25 | batch_size = 64 26 | total_epochs = 200 27 | learning_rate = 0.0001 28 | weight_gp = 10 29 | n_critic = 5 30 | 31 | 32 | class Discriminator(nn.Module): 33 | '''全连接判别器,用于1x28x28的MNIST数据''' 34 | 35 | def __init__(self): 36 | super(Discriminator, self).__init__() 37 | 38 | layers = [] 39 | 40 | # 第一层 41 | layers.append(nn.Linear(in_features=28*28, out_features=512, bias=True)) 42 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 43 | 44 | # 第二层 45 | layers.append(nn.Linear(in_features=512, out_features=256, bias=True)) 46 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 47 | 48 | # 输出层:相比gan,不需要sigmoid 49 | layers.append(nn.Linear(in_features=256, out_features=1, bias=True)) 50 | 51 | self.model = nn.Sequential(*layers) 52 | 53 | def forward(self, x): 54 | x = x.view(x.size(0), -1) 55 | validity = self.model(x) 56 | return validity 57 | 58 | 59 | class Generator(nn.Module): 60 | '''全连接生成器,用于1x28x28的MNIST数据''' 61 | 62 | def __init__(self, z_dim): 63 | super(Generator, self).__init__() 64 | 65 | layers= [] 66 | 67 | # 第一层 68 | layers.append(nn.Linear(in_features=z_dim, out_features=128)) 69 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 70 
| 71 | # 第二层 72 | layers.append(nn.Linear(in_features=128, out_features=256)) 73 | layers.append(nn.BatchNorm1d(256, 0.8)) 74 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 75 | 76 | # 第三层 77 | layers.append(nn.Linear(in_features=256, out_features=512)) 78 | layers.append(nn.BatchNorm1d(512, 0.8)) 79 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 80 | 81 | # 输出层 82 | layers.append(nn.Linear(in_features=512, out_features=28*28)) 83 | layers.append(nn.Tanh()) 84 | 85 | self.model = nn.Sequential(*layers) 86 | 87 | def forward(self, z): 88 | x = self.model(z) 89 | x = x.view(-1, 1, 28, 28) 90 | return x 91 | 92 | # 初始化构建判别器和生成器 93 | discriminator = Discriminator().to(device) 94 | generator = Generator(z_dim=z_dim).to(device) 95 | 96 | # 计算梯度惩罚正则项 97 | def compute_gradient_penalty(D, real_samples, fake_samples): 98 | # 在真实样本所在空间 和 生成样本空间 之间采样样本 (通过插值进行采样) 99 | alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(device) 100 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) 101 | # 计算判别器对于这些样本的梯度 102 | d_interpolates = D(interpolates) 103 | fake = autograd.Variable(torch.Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False).to(device) 104 | gradients = autograd.grad( 105 | outputs=d_interpolates, 106 | inputs=interpolates, 107 | grad_outputs=fake, 108 | create_graph=True, 109 | retain_graph=True, 110 | only_inputs=True, 111 | )[0] 112 | gradients = gradients.view(gradients.size(0), -1) 113 | # 计算梯度损失 114 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 115 | return gradient_penalty 116 | 117 | # 初始化优化器,使用Adam优化器 118 | g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=[0.5, 0.9]) 119 | d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=[0.5, 0.9]) 120 | 121 | # 加载MNIST数据集 122 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 123 | dataset = 
torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) 124 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) 125 | 126 | # 随机产生100个向量,用于生成效果图 127 | fixed_z = torch.randn([100, z_dim]).to(device) 128 | 129 | # 开始训练,一共训练total_epochs 130 | for epoch in range(total_epochs): 131 | 132 | # 在训练阶段,把生成器设置为训练模式;对应于后面的,在测试阶段,把生成器设置为测试模式 133 | generator = generator.train() 134 | 135 | # 训练一个epoch 136 | for i, data in enumerate(dataloader): 137 | 138 | # 加载真实数据,不加载标签 139 | real_images, _ = data 140 | real_images = real_images.to(device) 141 | 142 | # 从正态分布中采样batch_size个噪声,然后生成对应的图片 143 | z = torch.randn([batch_size, z_dim]).to(device) 144 | fake_images = generator(z) 145 | 146 | # 计算判别器损失,并优化判别器 147 | d_gan = - torch.mean(discriminator(real_images)) + torch.mean(discriminator(fake_images)) 148 | d_gp = compute_gradient_penalty(discriminator, real_images.data, fake_images.data) 149 | d_loss = d_gan + weight_gp * d_gp 150 | d_optimizer.zero_grad() 151 | d_loss.backward() 152 | d_optimizer.step() 153 | 154 | # 每优化n_critic次判别器, 优化一次生成器 155 | if i % n_critic == 0: 156 | fake_images = generator(z) 157 | g_loss = - torch.mean(discriminator(fake_images)) 158 | g_optimizer.zero_grad() 159 | g_loss.backward() 160 | g_optimizer.step() 161 | 162 | # 输出损失 163 | print ("[Epoch %d/%d] [Batch %d /%d] [D loss: %f %f] [G loss: %f]" % (epoch, total_epochs, i, len(dataloader), d_gan.item(), d_gp.item(), g_loss.item())) 164 | 165 | # 每训练一个epoch,把生成器设置为测试模型,生成效果图并保存 166 | generator = generator.eval() 167 | fixed_fake_images = generator(fixed_z) 168 | save_image(fixed_fake_images, 'wgangp_images/{}.png'.format(epoch), nrow=10, normalize=True) 169 | 170 | -------------------------------------------------------------------------------- /wgan.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as 
optim 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | 9 | import os 10 | 11 | 12 | # 超参数 13 | gpu_id = None 14 | if gpu_id is not None: 15 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 16 | device = torch.device('cuda') 17 | else: 18 | device = torch.device('cpu') 19 | if os.path.exists('wgan_images') is False: 20 | os.makedirs('wgan_images') 21 | z_dim = 100 22 | batch_size = 64 23 | learning_rate = 0.00005 24 | total_epochs = 200 25 | clip_value = 0.01 26 | n_critic = 5 27 | 28 | 29 | class Discriminator(nn.Module): 30 | '''全连接判别器,用于1x28x28的MNIST数据''' 31 | 32 | def __init__(self): 33 | super(Discriminator, self).__init__() 34 | 35 | layers = [] 36 | # 第一层 37 | layers.append(nn.Linear(in_features=28*28, out_features=512, bias=True)) 38 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 39 | # 第二层 40 | layers.append(nn.Linear(in_features=512, out_features=256, bias=True)) 41 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 42 | # 输出层:相比gan,不需要sigmoid 43 | layers.append(nn.Linear(in_features=256, out_features=1, bias=True)) 44 | 45 | self.model = nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | x = x.view(x.size(0), -1) 49 | validity = self.model(x) 50 | return validity 51 | 52 | 53 | class Generator(nn.Module): 54 | '''全连接生成器,用于1x28x28的MNIST数据''' 55 | 56 | def __init__(self, z_dim): 57 | super(Generator, self).__init__() 58 | 59 | layers= [] 60 | 61 | # 第一层 62 | layers.append(nn.Linear(in_features=z_dim, out_features=128)) 63 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 64 | # 第二层 65 | layers.append(nn.Linear(in_features=128, out_features=256)) 66 | layers.append(nn.BatchNorm1d(256, 0.8)) 67 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 68 | # 第三层 69 | layers.append(nn.Linear(in_features=256, out_features=512)) 70 | layers.append(nn.BatchNorm1d(512, 0.8)) 71 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 72 | # 输出层 73 | layers.append(nn.Linear(in_features=512, 
out_features=28*28)) 74 | layers.append(nn.Tanh()) 75 | 76 | self.model = nn.Sequential(*layers) 77 | 78 | def forward(self, z): 79 | x = self.model(z) 80 | x = x.view(-1, 1, 28, 28) 81 | return x 82 | 83 | # 初始化构建判别器和生成器 84 | discriminator = Discriminator().to(device) 85 | generator = Generator(z_dim=z_dim).to(device) 86 | 87 | # 初始化优化器,使用Adam优化器 88 | g_optimizer = torch.optim.RMSprop(generator.parameters(), lr=learning_rate) 89 | d_optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=learning_rate) 90 | 91 | # 加载MNIST数据集 92 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 93 | dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) 94 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) 95 | 96 | # 随机产生100个向量,用于生成效果图 97 | fixed_z = torch.randn([100, z_dim]).to(device) 98 | 99 | # 开始训练,一共训练total_epochs 100 | for epoch in range(total_epochs): 101 | 102 | # 在训练阶段,把生成器设置为训练模式;对应于后面的,在测试阶段,把生成器设置为测试模式 103 | generator = generator.train() 104 | 105 | # 训练一个epoch 106 | for i, data in enumerate(dataloader): 107 | 108 | # 加载真实数据,不加载标签 109 | real_images, _ = data 110 | real_images = real_images.to(device) 111 | 112 | # 从正态分布中采样batch_size个噪声,然后生成对应的图片 113 | z = torch.randn([batch_size, z_dim]).to(device) 114 | fake_images = generator(z) 115 | 116 | # 计算判别器损失,并优化判别器 117 | d_loss = - torch.mean(discriminator(real_images)) + torch.mean(discriminator(fake_images.detach())) 118 | d_optimizer.zero_grad() 119 | d_loss.backward() 120 | d_optimizer.step() 121 | 122 | # 为了保证利普斯次系数小于一个常数,进行权重截断 123 | for p in discriminator.parameters(): 124 | p.data.clamp_(-clip_value, clip_value) 125 | 126 | # 每优化n_critic次判别器, 优化一次生成器 127 | if i % n_critic == 0: 128 | g_loss = - torch.mean(discriminator(fake_images)) 129 | g_optimizer.zero_grad() 130 | g_loss.backward() 131 | g_optimizer.step() 132 | 133 | # 输出损失 134 | print ("[Epoch 
%d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, total_epochs, i, len(dataloader), d_loss.item(), g_loss.item())) 135 | 136 | # 每训练一个epoch,把生成器设置为测试模型,生成效果图并保存 137 | generator = generator.eval() 138 | fixed_fake_images = generator(fixed_z) 139 | save_image(fixed_fake_images, 'wgan_images/{}.png'.format(epoch), nrow=10, normalize=True) 140 | 141 | --------------------------------------------------------------------------------