├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── cifar10
│       ├── netD_epoch_249.pth
│       └── netG_epoch_249.pth
├── data
│   └── fake_samples_epoch_060.png
└── odegan.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # Additional
85 | /cifar10/
86 | /images/
87 | /mnist/
88 |
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Somshubra Majumdar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ODE GAN (Prototype) in PyTorch
2 | Partial implementation of the ODE-GAN technique from the paper [Training Generative Adversarial Networks by Solving Ordinary Differential Equations](https://arxiv.org/abs/2010.15040).
3 |
4 | # Caveat
5 | This is **not a faithful reproduction of the paper**!
6 |
7 | - One of the major differences is the use of gradient normalization to stabilize training (and to avoid exploding gradients, which lead to NaNs in the generator and discriminator); a minimal sketch is shown after this list.
8 | - Another difference might be the implementation of the regularization component.
9 | - Finally, this is a prototype to demonstrate the training regimen, without any focus on optimization of any kind; there is a lot of duplication of weights, caches, etc. throughout the code.
10 |
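The gradient normalization referred to above is a simple per-tensor rescaling; a minimal sketch, mirroring `normalize_grad` in `odegan.py`:

```python
import torch

def normalize_grad(grad: torch.Tensor) -> torch.Tensor:
    # Rescale to unit norm only when the norm exceeds 1, i.e. g <- g / max(1, ||g||)
    grad_norm = grad.norm()
    if grad_norm > 1.0:
        grad = grad / grad_norm
    return grad
```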
11 | # Training Regimen
12 | By default, the model is trained on the CIFAR-10 dataset, with most of the parameters exposed through argparse.
13 |
14 | Here is a TensorBoard log of a model being trained using RK2 (Heun's ODE step) for 250 epochs (~187,500 update steps) - [Tensorboard Dev Log](https://tensorboard.dev/experiment/E9VIqTYgT9umwIbiMVj33Q/#scalars&runSelectionState=eyIyMDIwLTExLTEwLTE3LTU1LTAxIjp0cnVlLCIyMDIwLTExLTEwLTE3LTU1LTAxXFwxNjA1MDU5NzA1LjkyNjM2NTEiOmZhbHNlfQ%3D%3D)
15 |
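A typical invocation with the defaults in `odegan.py` might look like the following (the data path is a placeholder; pass `--ode rk4` to use the RK4 step instead of Heun):

```
python odegan.py --dataset cifar10 --dataroot ./data --cuda --ode heun
```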
16 | # Generated images
17 | Training has not completed yet; here are images from the 60th epoch of training. Assuming nothing crashes in the remaining ~200 epochs, later epochs may yield better results.
18 |
19 |
20 | ![Generated samples at epoch 60](data/fake_samples_epoch_060.png)
21 |
22 |
--------------------------------------------------------------------------------
/checkpoints/cifar10/netD_epoch_249.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/pytorch_odegan/2bbbd124f1065dd679bc0bcbf11cebe9939cbe18/checkpoints/cifar10/netD_epoch_249.pth
--------------------------------------------------------------------------------
/checkpoints/cifar10/netG_epoch_249.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/pytorch_odegan/2bbbd124f1065dd679bc0bcbf11cebe9939cbe18/checkpoints/cifar10/netG_epoch_249.pth
--------------------------------------------------------------------------------
/data/fake_samples_epoch_060.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/pytorch_odegan/2bbbd124f1065dd679bc0bcbf11cebe9939cbe18/data/fake_samples_epoch_060.png
--------------------------------------------------------------------------------
/odegan.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/pytorch/examples/blob/master/dcgan/main.py
3 | """
4 | from __future__ import print_function
5 | import argparse
6 | import os
7 | import random
8 | import copy
9 | import datetime
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.utils.data
15 | import torchvision.datasets as dset
16 | import torchvision.transforms as transforms
17 | import torchvision.utils as vutils
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--dataset', required=False, default='cifar10',
23 | help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
24 | parser.add_argument('--dataroot', required=False, help='path to dataset')
25 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
26 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
27 | parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
28 | parser.add_argument('--nz', type=int, default=128, help='size of the latent z vector')
29 | parser.add_argument('--ngf', type=int, default=64)
30 | parser.add_argument('--ndf', type=int, default=64)
31 | parser.add_argument('--niter', type=int, default=250, help='number of epochs to train for')
32 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
33 | parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works')
34 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
35 | parser.add_argument('--netG', default='', help="path to netG (to continue training)")
36 | parser.add_argument('--netD', default='', help="path to netD (to continue training)")
37 | parser.add_argument('--outf', default='./images/', help='folder to output images and model checkpoints')
38 | parser.add_argument('--manualSeed', type=int, help='manual seed')
39 | parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
40 |
41 | # ODE Params
42 | parser.add_argument('--ode', default='heun', choices=['heun', 'rk4'], help='Type of ode step to take')
43 | parser.add_argument('--step_size', type=float, default=0.01, help='Fixed step optimizer step size')
44 | parser.add_argument('--disc_reg', default=0.01, type=float,
45 | help='Fixed weight decay of theta (discriminator)')
46 |
47 | opt = parser.parse_args()
48 | print(opt)
49 |
50 | opt.outf = os.path.join(opt.outf, opt.dataset + "_" + opt.ode)
51 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
52 | logdir = os.path.join(opt.outf, 'logs', timestamp)
53 |
54 | writer = SummaryWriter(log_dir=logdir)
55 |
56 | try:
57 | os.makedirs(opt.outf, exist_ok=True)
58 | os.makedirs(logdir, exist_ok=True)
59 | except OSError:
60 | pass
61 |
62 | if opt.manualSeed is None:
63 | opt.manualSeed = random.randint(1, 10000)
64 | print("Random Seed: ", opt.manualSeed)
65 | random.seed(opt.manualSeed)
66 | torch.manual_seed(opt.manualSeed)
67 |
68 | cudnn.benchmark = True
69 |
70 | if torch.cuda.is_available() and not opt.cuda:
71 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
72 |
73 | if opt.dataroot is None and str(opt.dataset).lower() != 'fake':
74 | raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset)
75 |
76 | if opt.dataset in ['imagenet', 'folder', 'lfw']:
77 | # folder dataset
78 | dataset = dset.ImageFolder(root=opt.dataroot,
79 | transform=transforms.Compose([
80 | transforms.Resize(opt.imageSize),
81 | transforms.CenterCrop(opt.imageSize),
82 | transforms.ToTensor(),
83 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
84 | ]))
85 | nc = 3
86 | elif opt.dataset == 'lsun':
87 | classes = [c + '_train' for c in opt.classes.split(',')]
88 | dataset = dset.LSUN(root=opt.dataroot, classes=classes,
89 | transform=transforms.Compose([
90 | transforms.Resize(opt.imageSize),
91 | transforms.CenterCrop(opt.imageSize),
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
94 | ]))
95 | nc = 3
96 | elif opt.dataset == 'cifar10':
97 | dataset = dset.CIFAR10(root=opt.dataroot, download=True,
98 | transform=transforms.Compose([
99 | transforms.Resize(opt.imageSize),
100 | transforms.ToTensor(),
101 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
102 | ]))
103 | nc = 3
104 |
105 | elif opt.dataset == 'mnist':
106 | dataset = dset.MNIST(root=opt.dataroot, download=True,
107 | transform=transforms.Compose([
108 | transforms.Resize(opt.imageSize),
109 | transforms.ToTensor(),
110 | transforms.Normalize((0.5,), (0.5,)),
111 | ]))
112 | nc = 1
113 |
114 | elif opt.dataset == 'fake':
115 | dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
116 | transform=transforms.ToTensor())
117 | nc = 3
118 |
119 | assert dataset
120 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
121 | shuffle=True, num_workers=int(opt.workers))
122 |
123 | device = torch.device("cuda:0" if opt.cuda else "cpu")
124 | ngpu = int(opt.ngpu)
125 | nz = int(opt.nz)
126 | ngf = int(opt.ngf)
127 | ndf = int(opt.ndf)
128 |
129 | # Conv Initialization from SNGAN codebase
130 | def weights_init(m):
131 | classname = m.__class__.__name__
132 | if classname.find('Conv') != -1:
133 | torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
134 | elif classname.find('BatchNorm') != -1:
135 | torch.nn.init.normal_(m.weight, 1.0, 0.02)
136 | torch.nn.init.zeros_(m.bias)
137 |
138 |
139 | class Generator(nn.Module):
140 | def __init__(self, ngpu):
141 | super(Generator, self).__init__()
142 | self.ngpu = ngpu
143 | self.project = nn.Conv2d(nz, ngf * 8 * 4 * 4, 1, 1, 0, bias=False)
144 | self.main = nn.Sequential(
145 | # state size. (ngf*8) x 4 x 4
146 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
147 | nn.BatchNorm2d(ngf * 4),
148 | nn.ReLU(True),
149 | # state size. (ngf*4) x 8 x 8
150 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
151 | nn.BatchNorm2d(ngf * 2),
152 | nn.ReLU(True),
153 | # state size. (ngf*2) x 16 x 16
154 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
155 | nn.BatchNorm2d(ngf),
156 | nn.ReLU(True),
157 | # state size. (ngf) x 32 x 32
158 | nn.Conv2d(ngf, nc, 3, 1, 1, bias=False),
159 | nn.Tanh()
160 | # state size. (nc) x 32 x 32
161 | )
162 |
163 | def forward(self, input):
164 | if input.is_cuda and self.ngpu > 1:
165 | raise NotImplementedError()
166 | # output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
167 | else:
168 | x = self.project(input)
169 | x = x.view(-1, ngf * 8, 4, 4)
170 | output = self.main(x)
171 |
172 | return output
173 |
174 |
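# Note: the generator maps a latent vector of shape (batch, nz, 1, 1) to an image of shape
# (batch, nc, 32, 32); the 1x1 'project' conv acts as a linear layer whose output is
# reshaped to an (ngf * 8) x 4 x 4 feature map before the transposed convolutions.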
175 | # ODE GAN
176 | netG = Generator(ngpu)
177 | netG.apply(weights_init)
178 | netG = netG.to(device)
179 |
180 | if opt.netG != '':
181 | netG.load_state_dict(torch.load(opt.netG))
182 | print(netG)
183 |
184 |
185 | class Discriminator(nn.Module):
186 | def __init__(self, ngpu):
187 | super(Discriminator, self).__init__()
188 | self.ngpu = ngpu
189 | self.main = nn.Sequential(
190 | # input is (nc) x 32 x 32
191 | nn.Conv2d(nc, ndf, 3, 1, 1, bias=False),
192 | # nn.BatchNorm2d(ndf),
193 | nn.LeakyReLU(0.1, inplace=True),
194 | nn.Conv2d(ndf, ndf, 4, 2, 1, bias=False),
195 | # nn.BatchNorm2d(ndf),
196 | nn.LeakyReLU(0.1, inplace=True),
197 | # state size. (ndf) x 16 x 16
198 | nn.Conv2d(ndf, ndf * 2, 3, 1, 1, bias=False),
199 | # nn.BatchNorm2d(ndf * 2),
200 | nn.LeakyReLU(0.1, inplace=True),
201 | nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1, bias=False),
202 | # nn.BatchNorm2d(ndf * 2),
203 | nn.LeakyReLU(0.1, inplace=True),
204 | # state size. (ndf*2) x 8 x 8
205 | nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1, bias=False),
206 | # nn.BatchNorm2d(ndf * 4),
207 | nn.LeakyReLU(0.1, inplace=True),
208 | nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
209 | # nn.BatchNorm2d(ndf * 4),
210 | nn.LeakyReLU(0.1, inplace=True),
211 | # state size. (ndf*4) x 4 x 4
212 | nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1, bias=False),
213 | # nn.BatchNorm2d(ndf * 8),
214 | nn.LeakyReLU(0.1, inplace=True),
215 | # state size. (ndf*8) x 2 x 2
216 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
217 | # nn.Sigmoid()
218 | )
219 |
220 | def forward(self, input):
221 | if input.is_cuda and self.ngpu > 1:
222 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
223 | else:
224 | output = self.main(input)
225 |
226 | return output.view(-1, 1).squeeze(1)
227 |
228 |
229 | netD = Discriminator(ngpu).to(device)
230 | netD.apply(weights_init)
231 |
232 | netD = netD.to(device)
233 |
234 | if opt.netD != '':
235 | netD.load_state_dict(torch.load(opt.netD))
236 | print(netD)
237 |
238 | criterion = nn.BCEWithLogitsLoss()
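# BCEWithLogitsLoss applies the sigmoid internally, which is why the final nn.Sigmoid()
# in the discriminator is commented out and D returns raw logits.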
239 |
240 | # fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
241 | real_label = 1
242 | fake_label = 0
243 |
244 | if opt.dry_run:
245 | opt.niter = 1
246 |
247 | # deep copies model + grads of model
248 | def grad_clone(source: torch.nn.Module) -> torch.nn.Module:
249 | dest = copy.deepcopy(source)
250 | dest.requires_grad_(True)
251 |
252 | for s_p, d_p in zip(source.parameters(), dest.parameters()):
253 | if s_p.grad is not None:
254 | d_p.grad = s_p.grad.clone()
255 |
256 | return dest
257 |
258 | # In-place gradient normalization: rescales to unit norm only if the norm exceeds 1
259 | def normalize_grad(grad: torch.Tensor) -> torch.Tensor:
260 | # normalize gradient
261 | grad_norm = grad.norm()
262 | if grad_norm > 1.:
263 | grad.div_(grad_norm)
264 | return grad
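# Note: normalize_grad is applied per parameter tensor in the update loops below (i.e.
# g <- g / max(1, ||g||) for each tensor), not to the concatenated gradient vector.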
265 |
266 | # Heun's ODE Step
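# A Heun (RK2) update for parameters w with vector field f(w) = -grad L(w) and step size h:
#   w_tilde = w + h * f(w)
#   w_new   = w + (h / 2) * (f(w) + f(w_tilde))
# Below, theta_* hold discriminator parameters/gradients and phi_* hold generator
# parameters/gradients. The extra step at the end descends the gradient (w.r.t. the
# discriminator) of the generator-gradient norm, scaled by disc_reg, which approximates
# the paper's regularization component as noted in the README.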
267 | def heun_ode_step(G: Generator, D: Discriminator, data: torch.Tensor, step_size: float, disc_reg: float):
268 | # Compute first step of Heun
269 | theta_1, phi_1, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(G, D, data, detach_err=False, retain_graph=True)
270 |
271 | # Compute the L2 norm using the prior computation graph
272 | grad_norm = None
273 | for phi_0_param in G.parameters():
274 | if phi_0_param.grad is not None:
275 | if grad_norm is None:
276 | grad_norm = phi_0_param.grad.square().sum()
277 | else:
278 | grad_norm = grad_norm + phi_0_param.grad.square().sum()
279 |
280 | grad_norm = grad_norm.sqrt()
281 |
282 | # Preserve gradients for regularization in cache
283 | D_norm_grads = torch.autograd.grad(grad_norm, list(D.parameters()))
284 | grad_norm = grad_norm.detach()
285 |
286 | # Compute norm of the gradients of the discriminator for logging
287 | disc_grad_norm = torch.tensor(0.0, device=device)
288 | for d_grad in D_norm_grads:
289 | # compute discriminator norm
290 | disc_grad_norm = disc_grad_norm + d_grad.detach().square().sum().sqrt()
291 |
292 | # Detach graph
293 | errD = errD.detach()
294 | errG = errG.detach()
295 |
296 | # preserve theta, phi for next computation
297 | theta_0 = grad_clone(theta_1)
298 | phi_0 = grad_clone(phi_1)
299 |
300 | # Update theta and phi for the first Heun step
301 | for d_param, theta_1_param in zip(D.parameters(), theta_1.parameters()):
302 | if theta_1_param.grad is not None:
303 | theta_1_param.data = d_param.data + (step_size * -theta_1_param.grad)
304 |
305 | for g_param, phi_1_param in zip(G.parameters(), phi_1.parameters()):
306 | if phi_1_param.grad is not None:
307 | phi_1_param.data = g_param.data + (step_size * -phi_1_param.grad)
308 |
309 | # Compute second step of Heun
310 | theta_2, phi_2, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_1, theta_1, data)
311 |
312 | # Compute grad norm and update discriminator
313 | for d_param, theta_0_param, theta_2_param in zip(D.parameters(), theta_0.parameters(), theta_2.parameters()):
314 | if theta_2_param.grad is not None:
315 | grad = theta_0_param.grad + theta_2_param.grad
316 |
317 | # simulate regularization with weight decay
318 | # if disc_reg > 0:
319 | # grad += disc_reg * d_param.data
320 |
321 | # normalize gradient
322 | grad = normalize_grad(grad)
323 |
324 | d_param.data = d_param.data + (step_size * 0.5 * -(grad))
325 |
326 | for g_param, phi_0_param, phi_2_param in zip(G.parameters(), phi_0.parameters(), phi_2.parameters()):
327 | if phi_2_param.grad is not None:
328 | grad = phi_0_param.grad + phi_2_param.grad
329 |
330 | # normalize gradient
331 | grad = normalize_grad(grad)
332 |
333 | g_param.data = g_param.data + (step_size * 0.5 * -(grad))
334 |
335 | # Regularization step
336 | for d_param, d_grad in zip(D.parameters(), D_norm_grads):
337 | d_param.data = d_param.data - step_size * disc_reg * d_grad
338 |
339 | del theta_0, theta_1, theta_2
340 | del phi_0, phi_1, phi_2
341 | del D_norm_grads
342 |
343 | return G, D, errD, errG, D_x, D_G_z1, D_G_z2, grad_norm.detach(), disc_grad_norm.detach()
344 |
345 |
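# Classical RK4 update for parameters w with vector field f(w) = -grad L(w) and step size h:
#   k1 = f(w); k2 = f(w + h/2 * k1); k3 = f(w + h/2 * k2); k4 = f(w + h * k3)
#   w_new = w + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
# Each call to gan_step() below evaluates one k_i for both the generator (phi) and the
# discriminator (theta); the *_cache copies keep the gradients of each stage so they can
# be combined in the weighted sum at the end.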
346 | def rk4_ode_step(G: Generator, D: Discriminator, data: torch.Tensor, step_size: float, disc_reg: float):
347 | # Compute first step of RK4
348 | theta_1_cache, phi_1_cache, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(G, D, data,
349 | detach_err=False,
350 | retain_graph=True)
351 |
352 | # Compute the L2 norm using the prior computation graph
353 | grad_norm = None # errG
354 | for phi_0_param in G.parameters():
355 | if phi_0_param.grad is not None:
356 | if grad_norm is None:
357 | grad_norm = phi_0_param.grad.square().sum()
358 | else:
359 | grad_norm = grad_norm + phi_0_param.grad.square().sum()
360 |
361 | grad_norm = grad_norm.sqrt()
362 |
363 | # Preserve gradients for regularization in cache
364 | D_norm_grads = torch.autograd.grad(grad_norm, list(D.parameters()))
365 | grad_norm = grad_norm.detach()
366 |
367 | # Compute norm of the gradients of the discriminator for logging
368 | disc_grad_norm = torch.tensor(0.0, device=device)
369 | for d_grad in D_norm_grads:
370 | # compute discriminator norm
371 | disc_grad_norm = disc_grad_norm + d_grad.detach().square().sum().sqrt()
372 |
373 | # Detach graph
374 | errD = errD.detach()
375 | errG = errG.detach()
376 |
377 | # preserve theta1, phi1 for next computation
378 | theta_1 = grad_clone(theta_1_cache)
379 | phi_1 = grad_clone(phi_1_cache)
380 |
381 | # Update theta and phi for the second RK step
382 | for d_param, theta_1_param in zip(D.parameters(), theta_1.parameters()):
383 | if theta_1_param.grad is not None:
384 | theta_1_param.data = d_param.data + (step_size * 0.5 * -theta_1_param.grad)
385 |
386 | for g_param, phi_1_param in zip(G.parameters(), phi_1.parameters()):
387 | if phi_1_param.grad is not None:
388 | phi_1_param.data = g_param.data + (step_size * 0.5 * -phi_1_param.grad)
389 |
390 | # Compute second step of RK 4
391 | theta_2_cache, phi_2_cache, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_1, theta_1, data)
392 |
393 | # preserve theta2, phi2
394 | theta_2 = grad_clone(theta_2_cache)
395 | phi_2 = grad_clone(phi_2_cache)
396 |
397 | # Update theta and phi for the third RK step
398 | for d_param, theta_2_param in zip(D.parameters(), theta_2.parameters()):
399 | if theta_2_param.grad is not None:
400 | theta_2_param.data = d_param.data + (step_size * 0.5 * -theta_2_param.grad)
401 |
402 | for g_param, phi_2_param in zip(G.parameters(), phi_2.parameters()):
403 | if phi_2_param.grad is not None:
404 | phi_2_param.data = g_param.data + (step_size * 0.5 * -phi_2_param.grad)
405 |
406 | # Compute third step of RK 4
407 | theta_3_cache, phi_3_cache, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_2, theta_2, data)
408 |
409 | # preserve theta3, phi3
410 | theta_3 = grad_clone(theta_3_cache)
411 | phi_3 = grad_clone(phi_3_cache)
412 |
413 | # Update theta and phi for the fourth RK step
414 | for d_param, theta_3_param in zip(D.parameters(), theta_3.parameters()):
415 | if theta_3_param.grad is not None:
416 | theta_3_param.data = d_param.data + (step_size * -theta_3_param.grad)
417 |
418 | for g_param, phi_3_param in zip(G.parameters(), phi_3.parameters()):
419 | if phi_3_param.grad is not None:
420 | phi_3_param.data = g_param.data + (step_size * -phi_3_param.grad)
421 |
422 | # Compute fourth step of RK 4
423 | theta_4, phi_4, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_3, theta_3, data)
424 |
425 | # Compute grad norm and update discriminator
426 | for d_param, theta_1_param, theta_2_param, theta_3_param, theta_4_param in zip(D.parameters(),
427 | theta_1_cache.parameters(),
428 | theta_2_cache.parameters(),
429 | theta_3_cache.parameters(),
430 | theta_4.parameters()):
431 | if theta_1_param.grad is not None:
432 | grad = (theta_1_param.grad + 2 * theta_2_param.grad + 2 * theta_3_param.grad + theta_4_param.grad)
433 |
434 | # simulate regularization with weight decay
435 | # if disc_reg > 0:
436 | # grad += disc_reg * d_param.data
437 |
438 | # normalize gradient
439 | grad = normalize_grad(grad)
440 |
441 | d_param.data = d_param.data + (step_size / 6. * -(grad))
442 |
443 | for g_param, phi_1_param, phi_2_param, phi_3_param, phi_4_param in zip(G.parameters(),
444 | phi_1_cache.parameters(),
445 | phi_2_cache.parameters(),
446 | phi_3_cache.parameters(),
447 | phi_4.parameters()):
448 | if phi_1_param.grad is not None:
449 | grad = (phi_1_param.grad + 2 * phi_2_param.grad + 2 * phi_3_param.grad + phi_4_param.grad)
450 |
451 | # normalize gradient
452 | grad = normalize_grad(grad)
453 |
454 | g_param.data = g_param.data + (step_size / 6.0 * -(grad))
455 |
456 | # Regularization step
457 | for d_param, d_grad in zip(D.parameters(), D_norm_grads):
458 | if d_param.grad is not None:
459 | d_param.data = d_param.data - step_size * disc_reg * d_grad
460 |
461 | del theta_1, theta_1_cache, theta_2, theta_2_cache, theta_3, theta_3_cache, theta_4
462 | del phi_1, phi_1_cache, phi_2, phi_2_cache, phi_3, phi_3_cache, phi_4
463 | del D_norm_grads
464 |
465 | return G, D, errD, errG, D_x, D_G_z1, D_G_z2, grad_norm.detach(), disc_grad_norm.detach()
466 |
467 |
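# One standard GAN loss evaluation: computes the discriminator loss on a real batch and a
# fake batch and the generator loss on the same fake batch, backpropagates both, and returns
# deep copies of D and G carrying the resulting gradients (these act as the vector-field
# evaluations consumed by the ODE steps above), along with the losses and the mean logits
# D(x), D(G(z)) for logging.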
468 | def gan_step(G: Generator, D: Discriminator, data, detach_err: bool = True, retain_graph: bool = False) -> (
469 | Discriminator, Generator, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
470 | ############################
471 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
472 | ###########################
473 | # train with real
474 | D.zero_grad()
475 |
476 | real_cpu = data[0].to(device)
477 | batch_size = real_cpu.size(0)
478 | label = torch.full((batch_size,), real_label,
479 | dtype=real_cpu.dtype, device=device)
480 |
481 | output = D(real_cpu)
482 | errD_real = criterion(output, label)
483 | errD_real.backward()
484 | D_x = output.mean().detach()
485 |
486 | # train with fake
487 | noise = torch.randn(batch_size, nz, 1, 1, device=device)
488 | fake = G(noise)
489 | label.fill_(fake_label)
490 | output = D(fake.detach())
491 | errD_fake = criterion(output, label)
492 | errD_fake.backward()
493 | D_G_z1 = output.mean().detach()
494 | errD = errD_real + errD_fake
495 |
496 | if detach_err:
497 | errD = errD.detach()
498 |
499 | DISC_GRAD_CACHE = grad_clone(D)
500 |
501 | ############################
502 | # (2) Update G network: maximize log(D(G(z)))
503 | ###########################
504 | G.zero_grad()
505 |
506 | label.fill_(real_label) # fake labels are real for generator cost
507 | output = D(fake)
508 | errG = criterion(output, label)
509 | errG.backward(create_graph=retain_graph)
510 | D_G_z2 = output.mean().detach()
511 |
512 | if detach_err:
513 | errG = errG.detach()
514 |
515 | GEN_GRAD_CACHE = grad_clone(G)
516 |
517 | return DISC_GRAD_CACHE, GEN_GRAD_CACHE, errD, errG, D_x, D_G_z1, D_G_z2
518 |
519 | # Save hyper parameters
520 | writer.add_hparams(vars(opt), metric_dict={})
521 |
522 | step_size = opt.step_size
523 | global_step = 0
524 |
525 | for epoch in range(opt.niter):
526 | for i, data in enumerate(dataloader, 0):
527 | # Schedule
528 | if global_step < 500:
529 | step_size = opt.step_size
530 | elif global_step >= 500 and global_step <= 400000:
531 | step_size = opt.step_size * 4
532 | elif global_step > 400000:
533 | step_size = opt.step_size * 2
534 |
535 | if opt.ode == 'heun':
536 |
537 | netG, netD, errD, errG, D_x, D_G_z1, D_G_z2, gen_grad_norm, disc_grad_norm = heun_ode_step(netG, netD,
538 | data,
539 | step_size=step_size,
540 | disc_reg=opt.disc_reg)
541 |
542 | elif opt.ode == 'rk4':
543 | netG, netD, errD, errG, D_x, D_G_z1, D_G_z2, gen_grad_norm, disc_grad_norm = rk4_ode_step(netG, netD,
544 | data,
545 | step_size=step_size,
546 | disc_reg=opt.disc_reg)
547 |
548 | else:
549 | raise ValueError("Only supported ODE steps are 'heun' and 'rk4'")
550 |
551 | # Cast logits to sigmoid probabilities
552 | D_x = D_x.sigmoid().item()
553 | D_G_z1 = D_G_z1.sigmoid().item()
554 | D_G_z2 = D_G_z2.sigmoid().item()
555 |
556 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f '
557 | 'Gen Grad Norm: %0.4f Disc Grad Norm: %0.4f'
558 | % (epoch, opt.niter, i, len(dataloader),
559 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, gen_grad_norm, disc_grad_norm))
560 |
561 | writer.add_scalar('loss/discriminator', errD.item(), global_step=global_step)
562 | writer.add_scalar('loss/generator', errG.item(), global_step=global_step)
563 | writer.add_scalar('acc/D(x)', D_x, global_step=global_step)
564 | writer.add_scalar('acc/D(G(z))-fake', D_G_z1, global_step=global_step)
565 | writer.add_scalar('acc/D(G(z))-real', D_G_z2, global_step=global_step)
566 | writer.add_scalar('norm/gen_grad_norm', gen_grad_norm, global_step=global_step)
567 | writer.add_scalar('norm/disc_grad_norm', disc_grad_norm, global_step=global_step)
568 | writer.add_scalar('step_size', step_size, global_step=global_step)
569 |
570 | global_step += 1
571 |
572 | if i % 100 == 0:
573 | real_cpu = data[0].to(device)
574 | vutils.save_image(real_cpu,
575 | '%s/real_samples.png' % opt.outf,
576 | normalize=True)
577 |
578 | # fake = netG(fixed_noise)
579 | random_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
580 |
581 | fake = netG(random_noise)
582 |
583 | vutils.save_image(fake.detach(),
584 | '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
585 | normalize=True)
586 |
587 | if opt.dry_run:
588 | break
589 | # do checkpointing
590 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
591 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
592 |
593 | writer.flush()
594 |
--------------------------------------------------------------------------------