├── .idea
├── misc.xml
├── mnist_gan.iml
├── modules.xml
└── workspace.xml
├── README.md
├── mnist_data.py
├── mnist_loss.py
├── mnist_net.py
├── mnist_train.py
└── mnist_visual.py
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/mnist_gan.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
112 |
113 |
114 |
115 | cuda
116 | resume
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 | 1542509554264
209 |
210 |
211 | 1542509554264
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #### 项目结构
2 | * mnist_data.py:数据输入模块
3 | * mnist_net.py:网络模型模块
4 | * mnist_loss.py:Loss计算模块
5 | * mnist_train.py:迭代训练模块
6 | * mnist_visual.py:可视化模块
7 | #### 生成器生成的图片
8 | 
9 | #### MNIST真实的图片
10 | 
11 | #### 网络结构
12 | 
13 |
--------------------------------------------------------------------------------
/mnist_data.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | from torchvision import transforms
3 | from torchvision import datasets
4 |
5 |
class Mnist:
    """Factory for the torchvision MNIST train/test datasets.

    Every sample is converted to a float tensor and normalised from [0, 1] to
    [-1, 1] (mean 0.5, std 0.5), the same range as the generator's Tanh output.
    """

    def __init__(self, data_path, download=True):
        """
        Args:
            data_path: root directory where torchvision stores / looks for MNIST.
            download: forwarded to ``datasets.MNIST``; when True the dataset is
                fetched into ``data_path`` if it is not already there (an
                existing copy is reused).
        """
        self.data_path = data_path
        self.download = download
        # MNIST images have a single channel, so mean/std must have exactly one
        # entry each. With 3-tuples, current torchvision versions try to
        # broadcast the (1, H, W) tensor to (3, H, W) in-place and raise a
        # RuntimeError. Other augmentations (flip, shift, crop, ...) could be
        # appended to this pipeline.
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def train_data(self):
        """Return the MNIST training split with ``img_transform`` applied."""
        return datasets.MNIST(self.data_path, train=True, transform=self.img_transform,
                              download=self.download)

    def test_data(self):
        """Return the MNIST test split with ``img_transform`` applied."""
        return datasets.MNIST(self.data_path, train=False, transform=self.img_transform,
                              download=self.download)
21 |
--------------------------------------------------------------------------------
/mnist_loss.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import torch.nn as nn
3 |
4 |
# Loss function
class Loss(nn.Module):
    """Binary cross-entropy criterion used for both GAN players.

    ``outputs`` are probabilities and ``targets`` hold 0/1 labels of the same
    shape; the result is the mean BCE over all elements.
    """

    def __init__(self):
        super(Loss, self).__init__()
        criterion = nn.BCELoss()
        self.loss = criterion

    def forward(self, outputs, targets):
        bce = self.loss
        value = bce(outputs, targets)
        return value
15 |
--------------------------------------------------------------------------------
/mnist_net.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import torch.nn as nn
3 |
4 |
# Discriminator
class Discriminator(nn.Module):
    """CNN mapping a batch of 1x28x28 images to per-image probabilities of being real."""

    def __init__(self):
        super(Discriminator, self).__init__()
        # Block 1: 1 -> 32 channels, 5x5 kernel, padding 2 keeps 28x28;
        # LeakyReLU with negative slope 0.2; 2x2 average pooling (stride 2).
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),  # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),  # batch, 32, 14, 14
        )
        # Block 2: 32 -> 64 channels, same kernel/padding, pooled again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),  # batch, 64, 7, 7
        )
        # Two fully connected layers: 64*7*7 -> 1024 -> 1, squashed by Sigmoid.
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Args:
            x: image batch in NCHW layout, shape (batch, 1, 28, 28). The first
               conv takes 1 channel and ``fc`` expects 64*7*7 features, which
               fixes the spatial size at 28x28.

        Returns:
            Tensor of shape (batch,) with values in (0, 1).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # batch, 64*7*7
        x = self.fc(x)  # batch, 1
        # Remove only the singleton feature dim so the result matches the
        # (batch,) label tensors. A bare squeeze() would also drop the batch
        # dim when batch == 1 and return a 0-d tensor.
        return x.squeeze(1)
46 |
47 |
# Generator
class Generator(nn.Module):
    """Map a noise vector to a single-channel image with pixels in [-1, 1].

    A linear layer projects the noise to ``num_feature`` values. These are
    reshaped to a 1 x S x S map with ``S = sqrt(num_feature)`` and refined by
    three conv layers, the last of which halves the spatial size. With
    ``num_feature=3136`` (S = 56) the output is 1 x 28 x 28, i.e. MNIST-sized.
    """

    def __init__(self, input_size, num_feature):
        """
        Args:
            input_size: dimension of the input noise vector.
            num_feature: size of the projected feature map; must be a positive
                perfect square (3136 = 56 * 56 for 28x28 output).

        Raises:
            ValueError: if ``num_feature`` is not a positive perfect square.
        """
        super(Generator, self).__init__()
        if num_feature <= 0:
            raise ValueError(
                'num_feature must be a positive perfect square, got {}'.format(num_feature))
        side = int(round(num_feature ** 0.5))
        if side * side != num_feature:
            raise ValueError(
                'num_feature must be a positive perfect square, got {}'.format(num_feature))
        # Side length S of the square feature map built in forward().
        self.side = side
        self.fc = nn.Linear(input_size, num_feature)  # batch, S*S
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )
        # 1 -> 50 channels, 3x3 kernel, stride 1, padding 1: spatial size kept.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, S, S
            nn.BatchNorm2d(50),
            nn.ReLU(True),
        )
        # 50 -> 25 channels, same kernel/stride/padding.
        self.conv2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, S, S
            nn.BatchNorm2d(25),
            nn.ReLU(True),
        )
        # 25 -> 1 channel, 2x2 kernel, stride 2, no padding: halves S.
        # Tanh keeps pixels in [-1, 1], the range real images are normalised to.
        self.conv3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, S//2, S//2
            nn.Tanh(),
        )

    def forward(self, x):
        """x: noise of shape (batch, input_size). Returns (batch, 1, S//2, S//2)."""
        x = self.fc(x)
        x = x.view(x.size(0), 1, self.side, self.side)
        x = self.br(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x
87 |
--------------------------------------------------------------------------------
/mnist_train.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | from argparse import ArgumentParser
3 | import os
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torch.autograd import Variable
8 |
9 | from mnist_data import Mnist
10 | from mnist_loss import Loss
11 | from mnist_net import Discriminator, Generator
12 | from mnist_visual import Visual
13 |
14 |
def _load_checkpoint(D, G, savedir):
    """Restore D and G weights from ``savedir`` and return the epoch stored with D.

    Raises:
        IOError: if either checkpoint file is missing.
    """
    d_path = os.path.join(savedir, 'discriminator.pth')
    if not os.path.exists(d_path):
        raise IOError("Error: resume option was used but discriminator.pth was not found in folder")
    g_path = os.path.join(savedir, 'generator.pth')
    if not os.path.exists(g_path):
        raise IOError("Error: resume option was used but generator.pth was not found in folder")
    # Load onto the CPU first; load_state_dict copies into the models' own
    # device, so a GPU-saved checkpoint can also be read without the same GPU.
    d_checkpoint = torch.load(d_path, map_location='cpu')
    D.load_state_dict(d_checkpoint['state_dict'])
    G.load_state_dict(torch.load(g_path, map_location='cpu'))
    return d_checkpoint['epoch']


def _save_checkpoint(D, G, savedir, epoch):
    """Save D (together with ``epoch``) and G weights into ``savedir``."""
    torch.save({
        'epoch': epoch,
        'state_dict': D.state_dict()
    }, os.path.join(savedir, 'discriminator.pth'))
    torch.save(G.state_dict(), os.path.join(savedir, 'generator.pth'))


def main(args):
    """Train the MNIST GAN configured by the parsed command-line ``args``.

    Reads these attributes from ``args``: datadir, savedir, visualdir,
    batch_size, num_workers, z_dimension, cuda, resume, num_epochs,
    steps_loss, epochs_visual, epochs_save.

    Epochs are numbered from 1 and training runs up to and including
    ``args.num_epochs``. Checkpoints go to ``<savedir>/discriminator.pth``
    (dict with 'epoch' and 'state_dict') and ``<savedir>/generator.pth``
    (plain state dict).
    """
    # 1. Make sure the model, dataset and visualisation directories exist.
    for path in (args.savedir, args.datadir, args.visualdir):
        if not os.path.exists(path):
            os.makedirs(path)

    # Fall back to the CPU instead of crashing when CUDA is requested but absent.
    use_cuda = args.cuda and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print("=> CUDA requested but not available, training on CPU")

    # 2. Data loading.
    dataset_train = Mnist(args.datadir).train_data()
    loader_train = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers)

    # 3. Models; G projects the noise to 3136 = 56 * 56 features.
    D = Discriminator()
    G = Generator(args.z_dimension, 3136)
    if use_cuda:
        D = torch.nn.DataParallel(D).cuda()
        G = torch.nn.DataParallel(G).cuda()

    # 4. Optimizers.
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

    # 5. Loss function.
    criterion = Loss()

    # 6. Visualisation.
    visual = Visual(args.visualdir)

    # 7. Optionally resume from the last checkpoint.
    start_epoch = 0
    if args.resume:
        start_epoch = _load_checkpoint(D, G, args.savedir)
        print("=> Loaded checkpoint at epoch {}".format(start_epoch))

    # 8. Training.
    print("========== TRAINING START===========")
    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        for i, (img, _) in enumerate(loader_train):
            num_img = img.size(0)
            real_img = img
            real_label = torch.ones(num_img)   # real samples -> 1
            fake_label = torch.zeros(num_img)  # fake samples -> 0
            d_z = torch.randn(num_img, args.z_dimension)  # noise for the D step
            g_z = torch.randn(num_img, args.z_dimension)  # noise for the G step
            if use_cuda:
                real_img = real_img.cuda()
                real_label = real_label.cuda()
                fake_label = fake_label.cuda()
                d_z = d_z.cuda()
                g_z = g_z.cuda()

            # ---- Discriminator step
            real_out = D(real_img)  # closer to 1 is better for D
            d_loss_real = criterion(real_out, real_label)
            # detach(): the D update must not back-propagate into G.
            fake_out = D(G(d_z).detach())  # closer to 0 is better for D
            d_loss_fake = criterion(fake_out, fake_label)

            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step: G is rewarded when D labels its samples real.
            fake_img = G(g_z)
            g_loss = criterion(D(fake_img), real_label)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # ---- Logging
            if (i + 1) % args.steps_loss == 0:
                print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                      'D real: {:.6f}, D fake: {:.6f}'
                      .format(epoch, args.num_epochs, d_loss.item(), g_loss.item(),
                              real_out.mean().item(), fake_out.mean().item()))

        # ---- Visualisation (uses the last batch of the epoch)
        if epoch == 1:
            visual.save_img(real_img.detach().cpu(), 'real_images.png')
            visual.show_img(real_img[0].detach().cpu(), 'real_image')
        if epoch % args.epochs_visual == 0:
            visual.save_img(fake_img.detach().cpu(), 'fake_images-{}.png'.format(epoch))
            visual.show_img(fake_img[0].detach().cpu(), 'fake_image (epoch: %d)' % epoch)

        # ---- Checkpointing
        if epoch % args.epochs_save == 0:
            _save_checkpoint(D, G, args.savedir, epoch)

    print("========== TRAINING FINISHED ===========")
142 |
143 |
def build_parser():
    """Build the command-line parser for the MNIST GAN training script."""
    parser = ArgumentParser()
    # Dataset root directory
    parser.add_argument('--datadir', default='./data')
    # Directory for logs and model checkpoints
    parser.add_argument('--savedir', default='./model')
    # Directory for saved visualisation images
    parser.add_argument('--visualdir', default='./visual')

    # Loss printing interval, in steps
    parser.add_argument('--steps-loss', type=int, default=100)
    # Visualisation interval, in epochs
    parser.add_argument('--epochs-visual', type=int, default=1)
    # Checkpoint interval, in epochs
    parser.add_argument('--epochs-save', type=int, default=1)

    # Number of training epochs
    parser.add_argument('--num-epochs', type=int, default=100)
    # Number of data-loading worker processes
    parser.add_argument('--num-workers', type=int, default=4)
    # Training batch size
    parser.add_argument('--batch-size', type=int, default=128)
    # Dimension of the input noise vector
    parser.add_argument('--z-dimension', type=int, default=100)

    # CUDA usage: on by default when a GPU is available; --no-cuda forces the
    # CPU. A store_true flag defaulting to True alone could never be disabled.
    parser.add_argument('--cuda', dest='cuda', action='store_true')
    parser.add_argument('--no-cuda', dest='cuda', action='store_false')
    parser.set_defaults(cuda=torch.cuda.is_available())
    # Resume from the checkpoints in --savedir
    parser.add_argument('--resume', action='store_true')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
175 |
--------------------------------------------------------------------------------
/mnist_visual.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import os
3 | from torchvision.utils import save_image
4 | from visdom import Visdom
5 | import numpy as np
6 |
7 |
class Visual:
    """Save image grids to disk and push single images to a Visdom server.

    Live view:
        1. start the server in the PyTorch environment: ``python -m visdom.server``
        2. open http://localhost:8097 in a browser
    """

    def __init__(self, path):
        """path: directory into which ``save_img`` writes its PNG files."""
        self.path = path
        self.vis = Visdom()

    @staticmethod
    def _denormalize(img):
        """Map a tensor from the [-1, 1] training range back to [0, 1], clamped."""
        return (0.5 * (img + 1)).clamp(0, 1)

    def save_img(self, img, sub_path):
        """Save ``img`` (holding N*28*28 values) as an image grid at ``path/sub_path``."""
        img = self._denormalize(img).view(-1, 1, 28, 28)
        save_image(img, os.path.join(self.path, sub_path))

    def show_img(self, img, name):
        """Show one image tensor (e.g. 1x28x28, values in [-1, 1]) in Visdom env 'images'.

        The tensor is denormalised to [0, 1] first, as in ``save_img``. Visdom
        treats float images with max <= 1 as [0, 1] data, so negative pixels
        from the raw training range would otherwise be rendered incorrectly.
        """
        img = self._denormalize(img).numpy()
        self.vis.image(img, env='images', opts=dict(title=name))
27 |
--------------------------------------------------------------------------------