├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── cifar10 │ ├── netD_epoch_249.pth │ └── netG_epoch_249.pth ├── data └── fake_samples_epoch_060.png └── odegan.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # Additional 85 | /cifar10/ 86 | /images/ 87 | /mnist/ 88 | 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Somshubra Majumdar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ODE GAN (Prototype) in PyTorch 2 | Partial implementation of the ODE-GAN technique from the paper [Training Generative Adversarial Networks by Solving Ordinary Differential Equations](https://arxiv.org/abs/2010.15040). 3 | 4 | # Caveat 5 | This is **not a faithful reproduction of the paper**! 6 | 7 | - One of the many major differences is the use of gradient normalization to stabilize training (and avoid exploding gradients, which lead to NaNs in the generator + discriminator). 8 | - Another difference may be the implementation of the regularization component. 9 | - Finally, this is a prototype to demonstrate the training regimen, without any focus on optimization of any kind - there is a lot of duplication of weights, caches, etc. throughout the code. 10 | 11 | # Training Regimen 12 | By default, the model is trained on the CIFAR-10 dataset, with most of the parameters set via argparse. 13 | 14 | Here is a TensorBoard log of a model trained using RK2 (Heun's ODE step) for 250 epochs (~187,500 update steps) - [Tensorboard Dev Log](https://tensorboard.dev/experiment/E9VIqTYgT9umwIbiMVj33Q/#scalars&runSelectionState=eyIyMDIwLTExLTEwLTE3LTU1LTAxIjp0cnVlLCIyMDIwLTExLTEwLTE3LTU1LTAxXFwxNjA1MDU5NzA1LjkyNjM2NTEiOmZhbHNlfQ%3D%3D) 15 | 16 | # Generated images 17 | Training has not completed yet; here are images from the 60th epoch of training.
Assuming nothing crashes in the next 200 epochs, there might be better results in later epochs. 18 | 19 |
20 | 21 |
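Below is a minimal sketch of the Heun (RK2) update that `odegan.py` applies to the generator and discriminator parameters, shown here on a single tensor with a toy gradient function. The helper names (`heun_step`, `grad_fn`) and the quadratic example are illustrative assumptions, not the actual implementation - the real script obtains the gradients from a full GAN forward/backward pass, applies the same step to every parameter of both networks, and additionally regularizes the discriminator using the gradient of the generator's gradient norm (`--disc_reg`).

```python
import torch


def normalize_grad(grad: torch.Tensor) -> torch.Tensor:
    # Rescale the gradient to unit norm (in place) if its norm exceeds 1,
    # mirroring the stabilization trick mentioned in the Caveat section.
    norm = grad.norm()
    if norm > 1.0:
        grad.div_(norm)
    return grad


def heun_step(theta: torch.Tensor, grad_fn, step_size: float) -> torch.Tensor:
    # One Heun (RK2) step: Euler predictor, then average the two slopes.
    g1 = grad_fn(theta)                   # slope at the current point
    theta_pred = theta - step_size * g1   # Euler predictor
    g2 = grad_fn(theta_pred)              # slope at the predicted point
    g = normalize_grad(g1 + g2)           # sum of the two slopes, norm-clipped
    return theta - 0.5 * step_size * g    # trapezoidal (averaged) corrector


# Toy usage: minimize f(x) = x^2, whose gradient is 2x.
x = torch.tensor([3.0])
for _ in range(100):
    x = heun_step(x, lambda t: 2 * t, step_size=0.1)
print(x)  # tends towards 0
```

To train the real models, the script is run directly, e.g. `python odegan.py --dataset cifar10 --dataroot ./data --cuda --ode heun`; the `--ode rk4` flag switches to the fourth-order Runge-Kutta step implemented in `rk4_ode_step`.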
22 | -------------------------------------------------------------------------------- /checkpoints/cifar10/netD_epoch_249.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/pytorch_odegan/2bbbd124f1065dd679bc0bcbf11cebe9939cbe18/checkpoints/cifar10/netD_epoch_249.pth -------------------------------------------------------------------------------- /checkpoints/cifar10/netG_epoch_249.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/pytorch_odegan/2bbbd124f1065dd679bc0bcbf11cebe9939cbe18/checkpoints/cifar10/netG_epoch_249.pth -------------------------------------------------------------------------------- /data/fake_samples_epoch_060.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/pytorch_odegan/2bbbd124f1065dd679bc0bcbf11cebe9939cbe18/data/fake_samples_epoch_060.png -------------------------------------------------------------------------------- /odegan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/pytorch/examples/blob/master/dcgan/main.py 3 | """ 4 | from __future__ import print_function 5 | import argparse 6 | import os 7 | import random 8 | import copy 9 | import datetime 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.utils.data 15 | import torchvision.datasets as dset 16 | import torchvision.transforms as transforms 17 | import torchvision.utils as vutils 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dataset', required=True, default='cifar10', 23 | help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake') 24 | parser.add_argument('--dataroot', required=False, help='path to dataset') 25 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=8) 26 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 27 | parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network') 28 | parser.add_argument('--nz', type=int, default=128, help='size of the latent z vector') 29 | parser.add_argument('--ngf', type=int, default=64) 30 | parser.add_argument('--ndf', type=int, default=64) 31 | parser.add_argument('--niter', type=int, default=250, help='number of epochs to train for') 32 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 33 | parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works') 34 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') 35 | parser.add_argument('--netG', default='', help="path to netG (to continue training)") 36 | parser.add_argument('--netD', default='', help="path to netD (to continue training)") 37 | parser.add_argument('--outf', default='./images/', help='folder to output images and model checkpoints') 38 | parser.add_argument('--manualSeed', type=int, help='manual seed') 39 | parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set') 40 | 41 | # ODE Params 42 | parser.add_argument('--ode', default='heun', choices=['heun', 'rk4'], help='Type of ode step to take') 43 | 
parser.add_argument('--step_size', type=float, default=0.01, help='Fixed step optimizer step size') 44 | parser.add_argument('--disc_reg', default=0.01, type=float, 45 | help='Fixed weight decay of theta (discriminator)') 46 | 47 | opt = parser.parse_args() 48 | print(opt) 49 | 50 | opt.outf = os.path.join(opt.outf, opt.dataset + "_" + opt.ode) 51 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 52 | logdir = os.path.join(opt.outf, 'logs', timestamp) 53 | 54 | writer = SummaryWriter(log_dir=logdir) 55 | 56 | try: 57 | os.makedirs(opt.outf, exist_ok=True) 58 | os.makedirs(logdir, exist_ok=True) 59 | except OSError: 60 | pass 61 | 62 | if opt.manualSeed is None: 63 | opt.manualSeed = random.randint(1, 10000) 64 | print("Random Seed: ", opt.manualSeed) 65 | random.seed(opt.manualSeed) 66 | torch.manual_seed(opt.manualSeed) 67 | 68 | cudnn.benchmark = True 69 | 70 | if torch.cuda.is_available() and not opt.cuda: 71 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 72 | 73 | if opt.dataroot is None and str(opt.dataset).lower() != 'fake': 74 | raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset) 75 | 76 | if opt.dataset in ['imagenet', 'folder', 'lfw']: 77 | # folder dataset 78 | dataset = dset.ImageFolder(root=opt.dataroot, 79 | transform=transforms.Compose([ 80 | transforms.Resize(opt.imageSize), 81 | transforms.CenterCrop(opt.imageSize), 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 84 | ])) 85 | nc = 3 86 | elif opt.dataset == 'lsun': 87 | classes = [c + '_train' for c in opt.classes.split(',')] 88 | dataset = dset.LSUN(root=opt.dataroot, classes=classes, 89 | transform=transforms.Compose([ 90 | transforms.Resize(opt.imageSize), 91 | transforms.CenterCrop(opt.imageSize), 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 94 | ])) 95 | nc = 3 96 | elif opt.dataset == 'cifar10': 97 | dataset = dset.CIFAR10(root=opt.dataroot, download=True, 98 | transform=transforms.Compose([ 99 | transforms.Resize(opt.imageSize), 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 102 | ])) 103 | nc = 3 104 | 105 | elif opt.dataset == 'mnist': 106 | dataset = dset.MNIST(root=opt.dataroot, download=True, 107 | transform=transforms.Compose([ 108 | transforms.Resize(opt.imageSize), 109 | transforms.ToTensor(), 110 | transforms.Normalize((0.5,), (0.5,)), 111 | ])) 112 | nc = 1 113 | 114 | elif opt.dataset == 'fake': 115 | dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize), 116 | transform=transforms.ToTensor()) 117 | nc = 3 118 | 119 | assert dataset 120 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 121 | shuffle=True, num_workers=int(opt.workers)) 122 | 123 | device = torch.device("cuda:0" if opt.cuda else "cpu") 124 | ngpu = int(opt.ngpu) 125 | nz = int(opt.nz) 126 | ngf = int(opt.ngf) 127 | ndf = int(opt.ndf) 128 | 129 | # Conv Initialization from SNGAN codebase 130 | def weights_init(m): 131 | classname = m.__class__.__name__ 132 | if classname.find('Conv') != -1: 133 | torch.nn.init.xavier_uniform_(m.weight, gain=1.0) 134 | elif classname.find('BatchNorm') != -1: 135 | torch.nn.init.normal_(m.weight, 1.0, 0.02) 136 | torch.nn.init.zeros_(m.bias) 137 | 138 | 139 | class Generator(nn.Module): 140 | def __init__(self, ngpu): 141 | super(Generator, self).__init__() 142 | self.ngpu = ngpu 143 | self.project = nn.Conv2d(nz, ngf * 8 * 4 * 4, 1, 1, 0, bias=False) 144 
| self.main = nn.Sequential( 145 | # state size. (ngf*8) x 4 x 4 146 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 147 | nn.BatchNorm2d(ngf * 4), 148 | nn.ReLU(True), 149 | # state size. (ngf*4) x 8 x 8 150 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 151 | nn.BatchNorm2d(ngf * 2), 152 | nn.ReLU(True), 153 | # state size. (ngf*2) x 16 x 16 154 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 155 | nn.BatchNorm2d(ngf), 156 | nn.ReLU(True), 157 | # # state size. (ngf) x 32 x 32 158 | nn.Conv2d(ngf, nc, 3, 1, 1, bias=False), 159 | nn.Tanh() 160 | # state size. (nc) x 32 x 32 161 | ) 162 | 163 | def forward(self, input): 164 | if input.is_cuda and self.ngpu > 1: 165 | raise NotImplemented() 166 | # output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 167 | else: 168 | x = self.project(input) 169 | x = x.view(-1, ngf * 8, 4, 4) 170 | output = self.main(x) 171 | 172 | return output 173 | 174 | 175 | # ODE GAN 176 | netG = Generator(ngpu) 177 | netG.apply(weights_init) 178 | netG = netG.to(device) 179 | 180 | if opt.netG != '': 181 | netG.load_state_dict(torch.load(opt.netG)) 182 | print(netG) 183 | 184 | 185 | class Discriminator(nn.Module): 186 | def __init__(self, ngpu): 187 | super(Discriminator, self).__init__() 188 | self.ngpu = ngpu 189 | self.main = nn.Sequential( 190 | # input is (nc) x 32 x 32 191 | nn.Conv2d(nc, ndf, 3, 1, 1, bias=False), 192 | # nn.BatchNorm2d(ndf), 193 | nn.LeakyReLU(0.1, inplace=True), 194 | nn.Conv2d(ndf, ndf, 4, 2, 1, bias=False), 195 | # nn.BatchNorm2d(ndf), 196 | nn.LeakyReLU(0.1, inplace=True), 197 | # state size. (ndf) x 16 x 16 198 | nn.Conv2d(ndf, ndf * 2, 3, 1, 1, bias=False), 199 | # nn.BatchNorm2d(ndf * 2), 200 | nn.LeakyReLU(0.1, inplace=True), 201 | nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1, bias=False), 202 | # nn.BatchNorm2d(ndf * 2), 203 | nn.LeakyReLU(0.1, inplace=True), 204 | # state size. (ndf*2) x 8 x 8 205 | nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1, bias=False), 206 | # nn.BatchNorm2d(ndf * 4), 207 | nn.LeakyReLU(0.1, inplace=True), 208 | nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False), 209 | # nn.BatchNorm2d(ndf * 4), 210 | nn.LeakyReLU(0.1, inplace=True), 211 | # state size. (ndf*4) x 4 x 4 212 | nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1, bias=False), 213 | # nn.BatchNorm2d(ndf * 8), 214 | nn.LeakyReLU(0.1, inplace=True), 215 | # state size. 
(ndf*8) x 2 x 2 216 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 217 | # nn.Sigmoid() 218 | ) 219 | 220 | def forward(self, input): 221 | if input.is_cuda and self.ngpu > 1: 222 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 223 | else: 224 | output = self.main(input) 225 | 226 | return output.view(-1, 1).squeeze(1) 227 | 228 | 229 | netD = Discriminator(ngpu).to(device) 230 | netD.apply(weights_init) 231 | 232 | netD = netD.to(device) 233 | 234 | if opt.netD != '': 235 | netD.load_state_dict(torch.load(opt.netD)) 236 | print(netD) 237 | 238 | criterion = nn.BCEWithLogitsLoss() 239 | 240 | # fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device) 241 | real_label = 1 242 | fake_label = 0 243 | 244 | if opt.dry_run: 245 | opt.niter = 1 246 | 247 | # deep copies model + grads of model 248 | def grad_clone(source: torch.nn.Module) -> torch.nn.Module: 249 | dest = copy.deepcopy(source) 250 | dest.requires_grad_(True) 251 | 252 | for s_p, d_p in zip(source.parameters(), dest.parameters()): 253 | if s_p.grad is not None: 254 | d_p.grad = s_p.grad.clone() 255 | 256 | return dest 257 | 258 | # Inplace normalizes gradient; if grad_norm > 1 259 | def normalize_grad(grad: torch.Tensor) -> torch.Tensor: 260 | # normalize gradient 261 | grad_norm = grad.norm() 262 | if grad_norm > 1.: 263 | grad.div_(grad_norm) 264 | return grad 265 | 266 | # Heun's ODE Step 267 | def heun_ode_step(G: Generator, D: Discriminator, data: torch.Tensor, step_size: float, disc_reg: float): 268 | # Compute first step of Heun 269 | theta_1, phi_1, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(G, D, data, detach_err=False, retain_graph=True) 270 | 271 | # Compute the L2 norm using the prior computation graph 272 | grad_norm = None 273 | for phi_0_param in G.parameters(): 274 | if phi_0_param.grad is not None: 275 | if grad_norm is None: 276 | grad_norm = phi_0_param.grad.square().sum() 277 | else: 278 | grad_norm = grad_norm + phi_0_param.grad.square().sum() 279 | 280 | grad_norm = grad_norm.sqrt() 281 | 282 | # Preserve gradients for regularization in cache 283 | D_norm_grads = torch.autograd.grad(grad_norm, list(D.parameters())) 284 | grad_norm = grad_norm.detach() 285 | 286 | # Compute norm of the gradients of the discriminator for logging 287 | disc_grad_norm = torch.tensor(0.0, device=device) 288 | for d_grad, in zip(D_norm_grads): 289 | # compute discriminator norm 290 | disc_grad_norm = disc_grad_norm + d_grad.detach().square().sum().sqrt() 291 | 292 | # Detach graph 293 | errD = errD.detach() 294 | errG = errG.detach() 295 | 296 | # preserve theta, phi for next computation 297 | theta_0 = grad_clone(theta_1) 298 | phi_0 = grad_clone(phi_1) 299 | 300 | # Update theta and phi for first heun step] 301 | for d_param, theta_1_param in zip(D.parameters(), theta_1.parameters()): 302 | if theta_1_param.grad is not None: 303 | theta_1_param.data = d_param.data + (step_size * -theta_1_param.grad) 304 | 305 | for g_param, phi_1_param in zip(G.parameters(), phi_1.parameters()): 306 | if phi_1_param.grad is not None: 307 | phi_1_param.data = g_param.data + (step_size * -phi_1_param.grad) 308 | 309 | # Compute second step of Heun 310 | theta_2, phi_2, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_1, theta_1, data) 311 | 312 | # Compute grad norm and update discriminator 313 | for d_param, theta_0_param, theta_1_param in zip(D.parameters(), theta_0.parameters(), theta_2.parameters()): 314 | if theta_1_param.grad is not None: 315 | grad = theta_0_param.grad + theta_1_param.grad 316 | 317 | # simulate 
regularization with weight decay 318 | # if disc_reg > 0: 319 | # grad += disc_reg * d_param.data 320 | 321 | # normalize gradient 322 | grad = normalize_grad(grad) 323 | 324 | d_param.data = d_param.data + (step_size * 0.5 * -(grad)) 325 | 326 | for g_param, phi_0_param, phi_1_param in zip(G.parameters(), phi_0.parameters(), phi_2.parameters()): 327 | if phi_1_param.grad is not None: 328 | grad = phi_0_param.grad + phi_1_param.grad 329 | 330 | # normalize gradient 331 | grad = normalize_grad(grad) 332 | 333 | g_param.data = g_param.data + (step_size * 0.5 * -(grad)) 334 | 335 | # Regularization step 336 | for d_param, d_grad in zip(D.parameters(), D_norm_grads): 337 | d_param.data = d_param.data - step_size * disc_reg * d_grad 338 | 339 | del theta_0, theta_1, theta_2 340 | del phi_0, phi_1, phi_2 341 | del D_norm_grads 342 | 343 | return G, D, errD, errG, D_x, D_G_z1, D_G_z2, grad_norm.detach(), disc_grad_norm.detach() 344 | 345 | 346 | def rk4_ode_step(G: Generator, D: Discriminator, data: torch.Tensor, step_size: float, disc_reg: float): 347 | # Compute first step of RK4 348 | theta_1_cache, phi_1_cache, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(G, D, data, 349 | detach_err=False, 350 | retain_graph=True) 351 | 352 | # Compute the L2 norm using the prior computation graph 353 | grad_norm = None # errG 354 | for phi_0_param in G.parameters(): 355 | if phi_0_param.grad is not None: 356 | if grad_norm is None: 357 | grad_norm = phi_0_param.grad.square().sum() 358 | else: 359 | grad_norm = grad_norm + phi_0_param.grad.square().sum() 360 | 361 | grad_norm = grad_norm.sqrt() 362 | 363 | # Preserve gradients for regularization in cache 364 | D_norm_grads = torch.autograd.grad(grad_norm, list(D.parameters())) 365 | grad_norm = grad_norm.detach() 366 | 367 | # Compute norm of the gradients of the discriminator for logging 368 | disc_grad_norm = torch.tensor(0.0, device=device) 369 | for d_grad, in zip(D_norm_grads): 370 | # compute discriminator norm 371 | disc_grad_norm = disc_grad_norm + d_grad.detach().square().sum().sqrt() 372 | 373 | # Detach graph 374 | errD = errD.detach() 375 | errG = errG.detach() 376 | 377 | # preserve theta1, phi1 for next computation 378 | theta_1 = grad_clone(theta_1_cache) 379 | phi_1 = grad_clone(phi_1_cache) 380 | 381 | # Update theta and phi for second RK step] 382 | for d_param, theta_1_param in zip(D.parameters(), theta_1.parameters()): 383 | if theta_1_param.grad is not None: 384 | theta_1_param.data = d_param.data + (step_size * 0.5 * -theta_1_param.grad) 385 | 386 | for g_param, phi_1_param in zip(G.parameters(), phi_1.parameters()): 387 | if phi_1_param.grad is not None: 388 | phi_1_param.data = g_param.data + (step_size * 0.5 * -phi_1_param.grad) 389 | 390 | # Compute second step of RK 4 391 | theta_2_cache, phi_2_cache, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_1, theta_1, data) 392 | 393 | # preserve theta2, phi2 394 | theta_2 = grad_clone(theta_2_cache) 395 | phi_2 = grad_clone(phi_2_cache) 396 | 397 | # Update theta and phi for third RK step] 398 | for d_param, theta_2_param in zip(D.parameters(), theta_2.parameters()): 399 | if theta_2_param.grad is not None: 400 | theta_2_param.data = d_param.data + (step_size * 0.5 * -theta_2_param.grad) 401 | 402 | for g_param, phi_2_param in zip(G.parameters(), phi_2.parameters()): 403 | if phi_2_param.grad is not None: 404 | phi_2_param.data = g_param.data + (step_size * 0.5 * -phi_2_param.grad) 405 | 406 | # Compute third step of RK 4 407 | theta_3_cache, phi_3_cache, errD, errG, D_x, D_G_z1, D_G_z2 = 
gan_step(phi_2, theta_2, data) 408 | 409 | # preserve theta3, phi3 410 | theta_3 = grad_clone(theta_3_cache) 411 | phi_3 = grad_clone(phi_3_cache) 412 | 413 | # Update theta and phi for fourth RK step] 414 | for d_param, theta_3_param in zip(D.parameters(), theta_3.parameters()): 415 | if theta_3_param.grad is not None: 416 | theta_3_param.data = d_param.data + (step_size * -theta_3_param.grad) 417 | 418 | for g_param, phi_3_param in zip(G.parameters(), phi_3.parameters()): 419 | if phi_3_param.grad is not None: 420 | phi_3_param.data = g_param.data + (step_size * -phi_3_param.grad) 421 | 422 | # Compute fourth step of RK 4 423 | theta_4, phi_4, errD, errG, D_x, D_G_z1, D_G_z2 = gan_step(phi_3, theta_3, data) 424 | 425 | # Compute grad norm and update discriminator 426 | for d_param, theta_1_param, theta_2_param, theta_3_param, theta_4_param in zip(D.parameters(), 427 | theta_1_cache.parameters(), 428 | theta_2_cache.parameters(), 429 | theta_3_cache.parameters(), 430 | theta_4.parameters()): 431 | if theta_1_param.grad is not None: 432 | grad = (theta_1_param.grad + 2 * theta_2_param.grad + 2 * theta_3_param.grad + theta_4_param.grad) 433 | 434 | # simulate regularization with weight decay 435 | # if disc_reg > 0: 436 | # grad += disc_reg * d_param.data 437 | 438 | # normalize gradient 439 | grad = normalize_grad(grad) 440 | 441 | d_param.data = d_param.data + (step_size / 6. * -(grad)) 442 | 443 | for g_param, phi_1_param, phi_2_param, phi_3_param, phi_4_param in zip(G.parameters(), 444 | phi_1_cache.parameters(), 445 | phi_2_cache.parameters(), 446 | phi_3_cache.parameters(), 447 | phi_4.parameters()): 448 | if phi_1_param.grad is not None: 449 | grad = (phi_1_param.grad + 2 * phi_2_param.grad + 2 * phi_3_param.grad + phi_4_param.grad) 450 | 451 | # normalize gradient 452 | grad = normalize_grad(grad) 453 | 454 | g_param.data = g_param.data + (step_size / 6.0 * -(grad)) 455 | 456 | # Regularization step 457 | for d_param, d_grad in zip(D.parameters(), D_norm_grads): 458 | if d_param.grad is not None: 459 | d_param.data = d_param.data - step_size * disc_reg * d_grad 460 | 461 | del theta_1, theta_1_cache, theta_2, theta_2_cache, theta_3, theta_3_cache, theta_4 462 | del phi_1, phi_1_cache, phi_2, phi_2_cache, phi_3, phi_3_cache, phi_4 463 | del D_norm_grads 464 | 465 | return G, D, errD, errG, D_x, D_G_z1, D_G_z2, grad_norm.detach(), disc_grad_norm.detach() 466 | 467 | 468 | def gan_step(G: Generator, D: Discriminator, data, detach_err: bool = True, retain_graph: bool = False) -> ( 469 | Discriminator, Generator, torch.Tensor, torch.Tensor, torch.Tensor): 470 | ############################ 471 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 472 | ########################### 473 | # train with real 474 | D.zero_grad() 475 | 476 | real_cpu = data[0].to(device) 477 | batch_size = real_cpu.size(0) 478 | label = torch.full((batch_size,), real_label, 479 | dtype=real_cpu.dtype, device=device) 480 | 481 | output = D(real_cpu) 482 | errD_real = criterion(output, label) 483 | errD_real.backward() 484 | D_x = output.mean().detach() 485 | 486 | # train with fake 487 | noise = torch.randn(batch_size, nz, 1, 1, device=device) 488 | fake = G(noise) 489 | label.fill_(fake_label) 490 | output = D(fake.detach()) 491 | errD_fake = criterion(output, label) 492 | errD_fake.backward() 493 | D_G_z1 = output.mean().detach() 494 | errD = errD_real + errD_fake 495 | 496 | if detach_err: 497 | errD = errD.detach() 498 | 499 | DISC_GRAD_CACHE = grad_clone(D) 500 | 501 | ############################ 
502 | # (2) Update G network: maximize log(D(G(z))) 503 | ########################### 504 | G.zero_grad() 505 | 506 | label.fill_(real_label) # fake labels are real for generator cost 507 | output = D(fake) 508 | errG = criterion(output, label) 509 | errG.backward(create_graph=retain_graph) 510 | D_G_z2 = output.mean().detach() 511 | 512 | if detach_err: 513 | errG = errG.detach() 514 | 515 | GEN_GRAD_CACHE = grad_clone(G) 516 | 517 | return DISC_GRAD_CACHE, GEN_GRAD_CACHE, errD, errG, D_x, D_G_z1, D_G_z2 518 | 519 | # Save hyper parameters 520 | writer.add_hparams(vars(opt), metric_dict={}) 521 | 522 | step_size = opt.step_size 523 | global_step = 0 524 | 525 | for epoch in range(opt.niter): 526 | for i, data in enumerate(dataloader, 0): 527 | # Schedule 528 | if global_step < 500: 529 | step_size = opt.step_size 530 | elif global_step >= 500 and global_step <= 400000: 531 | step_size = opt.step_size * 4 532 | elif global_step > 400000: 533 | step_size = opt.step_size * 2 534 | 535 | if opt.ode == 'heun': 536 | 537 | netG, netD, errD, errG, D_x, D_G_z1, D_G_z2, gen_grad_norm, disc_grad_norm = heun_ode_step(netG, netD, 538 | data, 539 | step_size=step_size, 540 | disc_reg=opt.disc_reg) 541 | 542 | elif opt.ode == 'rk4': 543 | netG, netD, errD, errG, D_x, D_G_z1, D_G_z2, gen_grad_norm, disc_grad_norm = rk4_ode_step(netG, netD, 544 | data, 545 | step_size=step_size, 546 | disc_reg=opt.disc_reg) 547 | 548 | else: 549 | raise ValueError("Only support ode steps are - heun and rk4") 550 | 551 | # Cast logits to sigmoid probabilities 552 | D_x = D_x.sigmoid().item() 553 | D_G_z1 = D_G_z1.sigmoid().item() 554 | D_G_z2 = D_G_z2.sigmoid().item() 555 | 556 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f ' 557 | 'Gen Grad Norm: %0.4f Disc Grad Norm: %0.4f' 558 | % (epoch, opt.niter, i, len(dataloader), 559 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, gen_grad_norm, disc_grad_norm)) 560 | 561 | writer.add_scalar('loss/discriminator', errD.item(), global_step=global_step) 562 | writer.add_scalar('loss/generator', errG.item(), global_step=global_step) 563 | writer.add_scalar('acc/D(x)', D_x, global_step=global_step) 564 | writer.add_scalar('acc/D(G(z))-fake', D_G_z1, global_step=global_step) 565 | writer.add_scalar('acc/D(G(z))-real', D_G_z2, global_step=global_step) 566 | writer.add_scalar('norm/gen_grad_norm', gen_grad_norm, global_step=global_step) 567 | writer.add_scalar('norm/disc_grad_norm', disc_grad_norm, global_step=global_step) 568 | writer.add_scalar('step_size', step_size, global_step=global_step) 569 | 570 | global_step += 1 571 | 572 | if i % 100 == 0: 573 | real_cpu = data[0].to(device) 574 | vutils.save_image(real_cpu, 575 | '%s/real_samples.png' % opt.outf, 576 | normalize=True) 577 | 578 | # fake = netG(fixed_noise) 579 | random_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device) 580 | 581 | fake = netG(random_noise) 582 | 583 | vutils.save_image(fake.detach(), 584 | '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch), 585 | normalize=True) 586 | 587 | if opt.dry_run: 588 | break 589 | # do checkpointing 590 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch)) 591 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch)) 592 | 593 | writer.flush() 594 | --------------------------------------------------------------------------------
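As a usage note, the released `checkpoints/cifar10/netG_epoch_249.pth` weights can be loaded back into the `Generator` to sample images. A hedged sketch follows, assuming the `Generator` class has been made importable (in `odegan.py` it is defined inside the `if __name__ == '__main__':` block, so it would need to be copied into a separate module - `odegan_models` below is a hypothetical name) and that training used the default sizes `nz=128`, `ngf=64`, `nc=3`:

```python
import torch
import torchvision.utils as vutils

# Hypothetical module holding a copy of the Generator class from odegan.py.
from odegan_models import Generator

nz = 128  # default latent size used by the training script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

netG = Generator(ngpu=1).to(device)
netG.load_state_dict(torch.load("checkpoints/cifar10/netG_epoch_249.pth",
                                map_location=device))
netG.eval()

with torch.no_grad():
    # Same noise shape the training loop uses when saving fake samples.
    noise = torch.randn(64, nz, 1, 1, device=device)
    fake = netG(noise)

vutils.save_image(fake, "samples.png", normalize=True)
```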