├── 1804.07455.pdf
├── FusionGAN.ipynb
├── README.md
├── dataset_lists
│   ├── train_datapoint_triplets.pkl
│   └── train_shapeLoss_pairs.pkl
├── model
│   ├── __pycache__
│   │   ├── discriminator.cpython-35.pyc
│   │   └── generator.cpython-35.pyc
│   ├── discriminator.py
│   └── generator.py
├── train.ipynb
├── train.py
└── utils
    ├── __pycache__
    │   ├── dataloader.cpython-35.pyc
    │   ├── loss_functions.cpython-35.pyc
    │   └── train_function.cpython-35.pyc
    ├── dataloader.py
    └── loss_functions.py
/1804.07455.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/1804.07455.pdf -------------------------------------------------------------------------------- /FusionGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "import torch.optim as optim\n", 14 | "import torch.nn.functional as F\n", 15 | "\n", 16 | "from torch.utils.data import Dataset, DataLoader\n", 17 | "from torchvision import transforms, utils\n", 18 | "\n", 19 | "import time\n", 20 | "import numpy as np\n", 21 | "import os\n", 22 | "import pickle\n", 23 | "from PIL import Image\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "plt.ion()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "dataset_dir = './Dataset/'\n", 37 | "batch_size = 4\n", 38 | "epochs = 10\n", 39 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 40 | "# device = 'cpu'\n", 41 | "learning_rate = 0.01" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "collapsed": true 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "class YouTubePose(Dataset):\n", 53 | "    \n", 54 | "    def __init__(self, datapoint_pairs, shapeLoss_datapoint_pairs, dataset_dir, transform=None, mode='train'):\n", 55 | "        self.datapoint_pairs = datapoint_pairs\n", 56 | "        self.shapeLoss_datapoint_pairs = shapeLoss_datapoint_pairs\n", 57 | "        self.dataset_dir = dataset_dir\n", 58 | "        self.transform = transform\n", 59 | "        self.mode = mode\n", 60 | "    \n", 61 | "    def __len__(self):\n", 62 | "        return len(self.datapoint_pairs)\n", 63 | "    \n", 64 | "    def __getitem__(self, idx):\n", 65 | "        image_pair = self.datapoint_pairs[idx]\n", 66 | "        x_gen_path = image_pair[0]\n", 67 | "        x_dis_path = image_pair[1]\n", 68 | "        y_path = image_pair[2]\n", 69 | "        \n", 70 | "        identity_pair = self.shapeLoss_datapoint_pairs[idx]\n", 71 | "        iden_1_path = identity_pair[0]\n", 72 | "        iden_2_path = identity_pair[1]\n", 73 | "        \n", 74 | "        x_gen = Image.open(self.dataset_dir + self.mode + '/' + x_gen_path)\n", 75 | "        x_dis = Image.open(self.dataset_dir + self.mode + '/' + x_dis_path)\n", 76 | "        y = Image.open(self.dataset_dir + self.mode + '/' + y_path)\n", 77 | "        iden_1 = Image.open(self.dataset_dir + self.mode + '/' + iden_1_path)\n", 78 | "        iden_2 = Image.open(self.dataset_dir + self.mode + '/' + iden_2_path)\n", 79 | "        \n", 80 | "        if self.transform:\n", 81 | "            x_gen = self.transform(x_gen)\n", 82 | "            x_dis = self.transform(x_dis)\n", 83 | "            y = self.transform(y)\n", 84 | "            iden_1 = self.transform(iden_1)\n", 85 | "            iden_2 = self.transform(iden_2)\n", 86 | "        \n", 87 | "        sample = {'x_gen' : x_gen, 'x_dis': 
x_dis, 'y': y, 'iden_1': iden_1, 'iden_2':iden_2}\n", 88 | " return sample" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "with open('./dataset_lists/train_datapoint_triplets.pkl', 'rb') as f:\n", 98 | " datapoint_pairs = pickle.load(f)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "with open('./dataset_lists/train_shapeLoss_pairs.pkl', 'rb') as f:\n", 108 | " shapeLoss_datapoint_pairs = pickle.load(f)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "len(datapoint_pairs)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": { 124 | "collapsed": true 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "transform = transforms.Compose([\n", 129 | " transforms.Resize((256, 256)),\n", 130 | " transforms.ToTensor()\n", 131 | "])" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": { 138 | "collapsed": true 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "train_dataset = YouTubePose(datapoint_pairs, shapeLoss_datapoint_pairs, dataset_dir, transform)\n", 143 | "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,\n", 144 | " num_workers=0)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": { 151 | "collapsed": true 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "dataset_sizes = [len(train_dataset)]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "def show_landmarks_batch(sample_batched):\n", 165 | " \n", 166 | " images_batch, landmarks_batch = sample_batched['iden_1'], sample_batched['iden_2']\n", 167 | "# batch_size = len(images_batch)\n", 168 | "# im_size = images_batch.size(2)\n", 169 | " \n", 170 | " grid = utils.make_grid(images_batch)\n", 171 | " grid1 = utils.make_grid(landmarks_batch)\n", 172 | " \n", 173 | "# plt.imshow(grid.numpy().transpose((1, 2, 0)))\n", 174 | " plt.imshow(grid1.numpy().transpose((1, 2, 0)))\n", 175 | " plt.title('Batch from dataloader')\n", 176 | " \n", 177 | "for i_batch, sample_batched in enumerate(train_dataloader):\n", 178 | " print(i_batch, sample_batched['x_gen'].size(), sample_batched['y'].size())\n", 179 | " \n", 180 | " if i_batch == 3:\n", 181 | " plt.figure()\n", 182 | " show_landmarks_batch(sample_batched)\n", 183 | " plt.axis('off')\n", 184 | " plt.ioff()\n", 185 | " plt.show()\n", 186 | " break" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "collapsed": true 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "# 3x3 convolution\n", 198 | "def conv3x3(in_channels, out_channels, stride=1):\n", 199 | " return nn.Conv2d(in_channels, out_channels, kernel_size=3, \n", 200 | " stride=stride, padding=1, bias=False)\n", 201 | "\n", 202 | "# Residual block\n", 203 | "class ResidualBlock(nn.Module):\n", 204 | " def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n", 205 | " super(ResidualBlock, self).__init__()\n", 206 | " self.conv1 = conv3x3(in_channels, out_channels, stride)\n", 207 | " self.bn1 = nn.BatchNorm2d(out_channels)\n", 208 | " self.relu = nn.ReLU(inplace=True)\n", 209 | " 
self.conv2 = conv3x3(out_channels, out_channels)\n", 210 | " self.bn2 = nn.BatchNorm2d(out_channels)\n", 211 | " self.downsample = downsample\n", 212 | " \n", 213 | " def forward(self, x):\n", 214 | "# if torch.cuda.is_available():\n", 215 | "# x = torch.cuda.FloatTensor(x)\n", 216 | " residual = x\n", 217 | " out = self.conv1(x)\n", 218 | " out = self.bn1(out)\n", 219 | " out = self.relu(out)\n", 220 | " out = self.conv2(out)\n", 221 | " out = self.bn2(out)\n", 222 | " if self.downsample:\n", 223 | " residual = self.downsample(x)\n", 224 | " out += residual\n", 225 | " out = self.relu(out)\n", 226 | " return out" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "collapsed": true 234 | }, 235 | "outputs": [], 236 | "source": [ 237 | "class Generator(nn.Module):\n", 238 | " def __init__(self, block):\n", 239 | " super(Generator, self).__init__()\n", 240 | " \n", 241 | " self.conv1_x = nn.Sequential(\n", 242 | " nn.Conv2d(3, 16, 3, padding = 1),\n", 243 | " nn.BatchNorm2d(16),\n", 244 | " nn.ReLU(inplace=True),\n", 245 | " nn.AvgPool2d(2))\n", 246 | " \n", 247 | " self.conv2_x = nn.Sequential(\n", 248 | " nn.Conv2d(16, 32, 3, padding = 1),\n", 249 | " nn.BatchNorm2d(32),\n", 250 | " nn.ReLU(inplace=True),\n", 251 | " nn.AvgPool2d(2))\n", 252 | "\n", 253 | " self.conv3_x = nn.Sequential(\n", 254 | " nn.Conv2d(32, 16, 3, padding = 1),\n", 255 | " nn.BatchNorm2d(16),\n", 256 | " nn.ReLU(inplace=True))\n", 257 | "\n", 258 | " self.conv1_y = nn.Sequential(\n", 259 | " nn.Conv2d(3, 16, 3, padding = 1),\n", 260 | " nn.BatchNorm2d(16),\n", 261 | " nn.ReLU(inplace=True),\n", 262 | " nn.AvgPool2d(2))\n", 263 | " \n", 264 | " self.conv2_y = nn.Sequential(\n", 265 | " nn.Conv2d(16, 32, 3, padding = 1),\n", 266 | " nn.BatchNorm2d(32),\n", 267 | " nn.ReLU(inplace=True),\n", 268 | " nn.AvgPool2d(2))\n", 269 | "\n", 270 | " self.conv3_y = nn.Sequential(\n", 271 | " nn.Conv2d(32, 16, 3, padding = 1),\n", 272 | " nn.BatchNorm2d(16),\n", 273 | " nn.ReLU(inplace=True)) \n", 274 | " \n", 275 | "# self.conv1_x = nn.Conv2d(3, 16, 3, padding = 1)\n", 276 | "# self.conv2_x = nn.Conv2d(16, 32, 3)\n", 277 | "# self.conv3_x = nn.Conv2d(32, 16, 3, padding=1)\n", 278 | "\n", 279 | "# self.conv1_y = nn.Conv2d(3, 16, 3)\n", 280 | "# self.conv2_y = nn.Conv2d(16, 32, 3)\n", 281 | "# self.conv3_y = nn.Conv2d(32, 16, 3, padding = 1)\n", 282 | "\n", 283 | "# self.relu = nn.ReLU()\n", 284 | "# self.avgpool = nn.AvgPool2d(2)\n", 285 | " \n", 286 | " \n", 287 | " # 2 Residual Blocks for Identity Image\n", 288 | " self.block1_x = block(16, 16)\n", 289 | "# downsample_x = nn.Sequential(conv3x3(16, 1, 1), nn.BatchNorm2d(1))\n", 290 | "# self.block2_x = block(16, 1, 1, downsample_x)\n", 291 | " self.block2_x = block(16, 16)\n", 292 | "\n", 293 | " # 2 Residual Blocks for Shape Image\n", 294 | " self.block1_y = block(16, 16)\n", 295 | "# downsample_y = nn.Sequential(conv3x3(16, 1, 1), nn.BatchNorm2d(1))\n", 296 | "# self.block2_y = block(16, 1, 1, downsample_y)\n", 297 | " self.block2_y = block(16, 16)\n", 298 | " # 2 Residual Blocks for Combined(concat) image\n", 299 | " downsample1_concat = nn.Sequential(conv3x3(32, 16, 1), nn.BatchNorm2d(16))\n", 300 | " self.block1_concat = block(32, 16, 1, downsample1_concat)\n", 301 | "\n", 302 | " self.block2_concat = block(16, 16)\n", 303 | " \n", 304 | "# self.deconv1 = nn.ConvTranspose2d(16, 16, 3)\n", 305 | "# self.deconv2 = nn.ConvTranspose2d(16, 3, 3)\n", 306 | " \n", 307 | " \n", 308 | "# self.conv2_y = nn.Sequential(\n", 309 | 
"# nn.Conv2d(16, 32, 3, padding = 1),\n", 310 | "# nn.BatchNorm2d(32),\n", 311 | "# nn.ReLU(inplace=True),\n", 312 | "# nn.AvgPool2d(2))\n", 313 | " \n", 314 | " self.upsample1 = nn.Sequential(\n", 315 | " nn.Upsample(scale_factor=2, mode='bilinear', align_corners = True),\n", 316 | " nn.ConvTranspose2d(16, 32, 3, padding=1),\n", 317 | " nn.BatchNorm2d(32),\n", 318 | " nn.ReLU(inplace=True))\n", 319 | " \n", 320 | " self.upsample2 = nn.Sequential(\n", 321 | " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),\n", 322 | " nn.ConvTranspose2d(32, 3, 3, padding=1),\n", 323 | " nn.BatchNorm2d(3),\n", 324 | " nn.ReLU(inplace=True))\n", 325 | " \n", 326 | " \n", 327 | " def forward(self, x, y):\n", 328 | " \n", 329 | " x = self.conv1_x(x)\n", 330 | " x = self.conv2_x(x)\n", 331 | " x = self.conv3_x(x)\n", 332 | " x = self.block1_x(x)\n", 333 | " x = self.block2_x(x)\n", 334 | " \n", 335 | " y = self.conv1_y(y)\n", 336 | " y = self.conv2_y(y)\n", 337 | " y = self.conv3_y(y)\n", 338 | " y = self.block1_y(y)\n", 339 | " y = self.block2_y(y)\n", 340 | " \n", 341 | "# if torch.cuda.is_available():\n", 342 | "# concat_result = torch.cuda.FloatTensor([x.shape[0], x.shape[1] * 2, x.shape[2], x.shape[3]]).fill_(0)\n", 343 | "# else:\n", 344 | " concat_result = torch.zeros([x.shape[0], x.shape[1] * 2, x.shape[2], x.shape[3]], dtype=x.dtype)\n", 345 | "# print(x.shape, y.shape, concat_result.shape)\n", 346 | " for i in range(batch_size):\n", 347 | " for j in range(x.shape[1]):\n", 348 | " concat_result[i][j] = x[i][j]\n", 349 | " concat_result[i][j + x.shape[1]] = y[i][j]\n", 350 | " if torch.cuda.is_available():\n", 351 | " concat_result = concat_result.cuda()\n", 352 | " concat_result = self.block1_concat(concat_result)\n", 353 | " concat_result = self.block2_concat(concat_result)\n", 354 | " \n", 355 | " upsampled_1 = self.upsample1(concat_result)\n", 356 | " upsampled_2 = self.upsample2(upsampled_1)\n", 357 | "# print(upsample2.shape)\n", 358 | " return upsampled_2\n", 359 | " " 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "class NLayerDiscriminator(nn.Module):\n", 369 | " def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):\n", 370 | " super(NLayerDiscriminator, self).__init__()\n", 371 | " \n", 372 | " use_bias = norm_layer\n", 373 | "\n", 374 | " kw = 3\n", 375 | " padw = 1\n", 376 | " sequence = []\n", 377 | "# sequence = [\n", 378 | "# nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n", 379 | "# nn.LeakyReLU(0.2, True)\n", 380 | "# ]\n", 381 | " self.initial_conv = nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)\n", 382 | " self.initial_relu = nn.LeakyReLU(0.2, True)\n", 383 | " \n", 384 | " \n", 385 | " ndf = 128\n", 386 | " \n", 387 | " nf_mult = 1\n", 388 | " nf_mult_prev = 1\n", 389 | " for n in range(1, n_layers):\n", 390 | " nf_mult_prev = nf_mult\n", 391 | " nf_mult = min(2**n, 8)\n", 392 | " sequence += [\n", 393 | " nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n", 394 | " kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n", 395 | " norm_layer(ndf * nf_mult),\n", 396 | " nn.LeakyReLU(0.2, True)\n", 397 | " ]\n", 398 | "\n", 399 | " nf_mult_prev = nf_mult\n", 400 | " nf_mult = min(2**n_layers, 8)\n", 401 | " sequence += [\n", 402 | " nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n", 403 | " kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n", 404 | " norm_layer(ndf * nf_mult),\n", 405 
| " nn.LeakyReLU(0.2, True)\n", 406 | " ]\n", 407 | "\n", 408 | " sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]\n", 409 | "\n", 410 | " if use_sigmoid:\n", 411 | " sequence += [nn.Sigmoid()]\n", 412 | "\n", 413 | " self.model = nn.Sequential(*sequence)\n", 414 | " \n", 415 | " self.maxpool = nn.MaxPool2d(8)\n", 416 | " def forward(self, x, y):\n", 417 | " x = self.initial_relu(self.initial_conv(x))\n", 418 | " y = self.initial_relu(self.initial_conv(y)) \n", 419 | "\n", 420 | "# if torch.cuda.is_available():\n", 421 | "# concat_fmap = torch.cuda.FloatTensor([x.shape[0], 2 * x.shape[1], x.shape[2], x.shape[3]]).fill_(0)\n", 422 | "# else:\n", 423 | " concat_fmap = torch.zeros([x.shape[0], 2 * x.shape[1], x.shape[2], x.shape[3]], dtype=x.dtype)\n", 424 | "\n", 425 | " # print(x.shape, y.shape, concat_result.shape)\n", 426 | " for i in range(x.shape[0]):\n", 427 | " for j in range(x.shape[1]):\n", 428 | " concat_fmap[i][j] = x[i][j]\n", 429 | " concat_fmap[i][j + 64] = y[i][j]\n", 430 | " if torch.cuda.is_available():\n", 431 | " concat_fmap = concat_fmap.cuda()\n", 432 | " op = self.model(concat_fmap)\n", 433 | " op_neg = - op\n", 434 | " op_neg = self.maxpool(op_neg)\n", 435 | " op_pooled = -op_neg\n", 436 | " \n", 437 | " return op, op_pooled" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": { 444 | "collapsed": true 445 | }, 446 | "outputs": [], 447 | "source": [ 448 | "generator = Generator(ResidualBlock)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "discriminator = NLayerDiscriminator(3)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": { 464 | "collapsed": true 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "# sample = []\n", 469 | "# for i_batch, sample_batched in enumerate(train_dataloader):\n", 470 | "# sample = sample_batched\n", 471 | "# if i_batch == 0:\n", 472 | "# break" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "# op = generator(sample['x_gen'], sample['y'])\n", 482 | "# op.shape" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "metadata": { 489 | "collapsed": true 490 | }, 491 | "outputs": [], 492 | "source": [ 493 | "generator = generator.to(device)\n", 494 | "discriminator = discriminator.to(device)" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "metadata": { 501 | "collapsed": true 502 | }, 503 | "outputs": [], 504 | "source": [ 505 | "# def lossIdentity(real_pooled_op, fake_pooled_op):\n", 506 | "# batch_size = real_pooled_op.size()[0]\n", 507 | "# real_pooled_op = 1 - real_pooled_op\n", 508 | "# real_pooled_op = real_pooled_op ** 2\n", 509 | "# fake_pooled_op = (fake_pooled_op ** 2)\n", 510 | "# real_pooled_op = real_pooled_op.view(batch_size, -1)\n", 511 | "# fake_pooled_op = fake_pooled_op.view(batch_size, -1)\n", 512 | "# real_pooled_op = torch.sum(real_pooled_op, dim = 1)\n", 513 | "# fake_pooled_op = torch.sum(fake_pooled_op, dim = 1)\n", 514 | "# return (real_pooled_op + fake_pooled_op) / batch_size\n", 515 | "\n", 516 | "def lossIdentity(real_pair, fake_pair):\n", 517 | " batch_size = real_pair.size()[0]\n", 518 | " real_pair = 1 - real_pair\n", 519 | " real_pair = real_pair ** 2\n", 520 | " fake_pair = fake_pair ** 2\n", 521 | 
" real_pair = torch.sum(real_pair)\n", 522 | " fake_pair = torch.sum(fake_pair)\n", 523 | " return (real_pair + fake_pair) / batch_size" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": { 530 | "collapsed": true 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "def lossShape(x, y):\n", 535 | " batch_size = x.size()[0]\n", 536 | " diff = x - y\n", 537 | " diff = diff ** 2\n", 538 | " diff = torch.sum(diff) / batch_size\n", 539 | " return diff" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "def train_model(gen, disc, loss_i, loss_s, optimizer_gen, optimizer_disc, alpha = 1, beta = 1, num_epochs = 10):\n", 549 | " for epoch in range(num_epochs):\n", 550 | " print(\"Epoch {}/{}\".format(epoch, num_epochs - 1))\n", 551 | " print('-'*10)\n", 552 | " dataloader = train_dataloader\n", 553 | " gen.train()\n", 554 | " disc.train()\n", 555 | " since = time.time()\n", 556 | " running_loss_iden = 0.0\n", 557 | " running_loss_s1 = 0.0\n", 558 | " running_loss_s2a = 0.0\n", 559 | " running_loss_s2b = 0.0\n", 560 | " running_loss = 0.0\n", 561 | " \n", 562 | " for i_batch, sample_batched in enumerate(dataloader):\n", 563 | " x_gen, y, x_dis = sample_batched['x_gen'], sample_batched['y'], sample_batched['x_dis']\n", 564 | " iden_1, iden_2 = sample_batched['iden_1'], sample_batched['iden_2']\n", 565 | " x_gen = x_gen.to(device)\n", 566 | " y = y.to(device)\n", 567 | " x_dis = x_dis.to(device)\n", 568 | " iden_1 = iden_1.to(device)\n", 569 | " iden_2 = iden_2.to(device)\n", 570 | " \n", 571 | " optimizer_gen.zero_grad()\n", 572 | " optimizer_disc.zero_grad()\n", 573 | " \n", 574 | " with torch.set_grad_enabled(True):\n", 575 | " x_generated = gen(x_gen, y)\n", 576 | " print('forward 1 done')\n", 577 | " fake_op, fake_pooled_op = disc(x_gen, x_generated)\n", 578 | " real_op, real_pooled_op = disc(x_gen, x_dis)\n", 579 | " loss_identity_gen = -loss_i(real_pooled_op, fake_pooled_op)\n", 580 | " print('Loss calculated')\n", 581 | " loss_identity_gen.backward(retain_graph=True)\n", 582 | " optimizer_gen.step()\n", 583 | " print('backward 1.1 done')\n", 584 | " \n", 585 | " optimizer_disc.zero_grad()\n", 586 | " loss_identity_disc = loss_i(real_op, fake_op)\n", 587 | " print('Loss calculated')\n", 588 | " loss_identity_disc.backward(retain_graph=True)\n", 589 | " optimizer_disc.step()\n", 590 | " print('backward 1.2 done')\n", 591 | "\n", 592 | " optimizer_gen.zero_grad()\n", 593 | " optimizer_disc.zero_grad()\n", 594 | " x_ls2a = gen(y, x_generated)\n", 595 | " x_ls2b = gen(x_generated, y)\n", 596 | " print('forward 2 done')\n", 597 | "\n", 598 | " loss_s2a = loss_s(y, x_ls2a)\n", 599 | " loss_s2b = loss_s(x_generated, x_ls2b)\n", 600 | " loss_s2 = loss_s2a + loss_s2b\n", 601 | " print('Loss calculated')\n", 602 | "\n", 603 | " loss_s2.backward()\n", 604 | " optimizer_gen.step()\n", 605 | " print('backward 2 done')\n", 606 | "\n", 607 | " optimizer_gen.zero_grad()\n", 608 | " optimizer_disc.zero_grad()\n", 609 | " \n", 610 | " x_ls1 = generator(iden_1, iden_2)\n", 611 | " print('forward 3 done')\n", 612 | "\n", 613 | " loss_s1 = loss_s(iden_2, x_ls1)\n", 614 | " print('Loss calculated')\n", 615 | " loss_s1.backward()\n", 616 | " optimizer_gen.step()\n", 617 | " print('backward 5 done')\n", 618 | " print()\n", 619 | " running_loss_iden += loss_identity_disc.item() * x_gen.size(0)\n", 620 | " running_loss_s1 += loss_s1.item() * x_gen.size(0)\n", 621 | " 
running_loss_s2a += loss_s2a.item() * x_gen.size(0)    \n", 622 | "            running_loss_s2b += loss_s2b.item() * x_gen.size(0)\n", 623 | "            running_loss = running_loss_iden + beta * (running_loss_s1 + alpha * (running_loss_s2a + running_loss_s2b))\n", 624 | "            print(str(time.time() - since))\n", 625 | "            since = time.time()\n", 626 | "        epoch_loss_iden = running_loss_iden / dataset_sizes[0]\n", 627 | "        epoch_loss_s1 = running_loss_s1 / dataset_sizes[0]\n", 628 | "        epoch_loss_s2a = running_loss_s2a / dataset_sizes[0]\n", 629 | "        epoch_loss_s2b = running_loss_s2b / dataset_sizes[0]\n", 630 | "        epoch_loss = running_loss / dataset_sizes[0]\n", 631 | "        print('Identity Loss: {:.4f} Loss Shape1: {:.4f} Loss Shape2a: {:.4f} \\\n", 632 | "              Loss Shape2b: {:.4f}'.format(epoch_loss_iden, epoch_loss_s1,\n", 633 | "                                           epoch_loss_s2a, epoch_loss_s2b))\n", 634 | "        print('Epoch Loss: {:.4f}'.format(epoch_loss))\n", 635 | "        print('Time Taken: ' + str(time.time() - since))\n", 636 | "    return gen, disc" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": null, 642 | "metadata": {}, 643 | "outputs": [], 644 | "source": [ 645 | "optimizer_gen = optim.SGD(generator.parameters(), lr = learning_rate, momentum=0.9)\n", 646 | "optimizer_disc = optim.SGD(discriminator.parameters(), lr = learning_rate, momentum=0.9)" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [ 655 | "generator, discriminator = train_model(generator, discriminator, lossIdentity, lossShape, optimizer_gen, optimizer_disc, num_epochs=epochs)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "metadata": { 662 | "collapsed": true 663 | }, 664 | "outputs": [], 665 | "source": [] 666 | } 667 | ], 668 | "metadata": { 669 | "kernelspec": { 670 | "display_name": "Python 3", 671 | "language": "python", 672 | "name": "python3" 673 | }, 674 | "language_info": { 675 | "codemirror_mode": { 676 | "name": "ipython", 677 | "version": 3 678 | }, 679 | "file_extension": ".py", 680 | "mimetype": "text/x-python", 681 | "name": "python", 682 | "nbconvert_exporter": "python", 683 | "pygments_lexer": "ipython3", 684 | "version": "3.5.4" 685 | } 686 | }, 687 | "nbformat": 4, 688 | "nbformat_minor": 2 689 | } 690 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of FusionGAN 2 | 3 | This repository contains a PyTorch implementation of the FusionGAN model described in the paper [**Generating a Fusion Image: One's Identity and Another's Shape**](https://arxiv.org/abs/1804.07455) 4 | 5 | ## Dependencies 6 | 7 | The following are the dependencies required by this repository: 8 | 9 | + PyTorch v0.4 10 | + NumPy 11 | + SciPy 12 | + Pickle 13 | + PIL 14 | + Matplotlib 15 | 16 | ## Setup Instructions 17 | 18 | First, get the repository onto your local machine, either by downloading it as an archive or by cloning it from the terminal: 19 | 20 | ``` Batchfile 21 | git clone https://github.com/aarushgupta/FusionGAN.git 22 | ``` 23 | 24 | Next, go through the instructions mentioned in the `Dataset Preparation` section. 25 | ## Dataset Preparation 26 | 27 | As the data is not publicly available in the desired form, the frames of the required YouTube videos have been saved at [**this**](https://drive.google.com/drive/folders/1waOPQYOmQF1k0pT50uqp6STzYDdSv_5N?usp=sharing) Google Drive link.
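The folder can also be fetched from the command line. A minimal sketch, assuming the `gdown` package is installed (its `--folder` flag downloads an entire Drive folder); the URL is the one above:

``` cmd
pip install gdown
gdown --folder "https://drive.google.com/drive/folders/1waOPQYOmQF1k0pT50uqp6STzYDdSv_5N"
```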
28 | 29 | The link contains a compressed `train` folder with the following 3 folders: 30 | 1. class1_cropped 31 | 2. class2_cropped 32 | 3. class3_cropped 33 | 34 | Download the data from the link and place the folders according to the following directory structure: 35 | ``` 36 | /FusionGAN_root_directory/Dataset/train/class1_cropped 37 | /FusionGAN_root_directory/Dataset/train/class2_cropped 38 | /FusionGAN_root_directory/Dataset/train/class3_cropped 39 | ``` 40 | 41 | ## Training Instructions 42 | The hyperparameters of the model have been preset. To start training, simply run `train.py` with the following command 43 | 44 | ``` cmd 45 | python train.py 46 | ``` 47 | The code can also be run interactively using the `train.ipynb` Jupyter Notebook provided in the repository. 48 | 49 | ## To-Do 50 | 51 | 1. [ ] Train the model and add checkpoints. 52 | 2. [ ] Polish and add code for preparing the dataset. 53 | 3. [ ] Add test script for the model. 54 | 4. [ ] Add keypoint estimation for quantitative evaluation. 55 | 5. [ ] Remove the unused images in the dataset. -------------------------------------------------------------------------------- /dataset_lists/train_datapoint_triplets.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/dataset_lists/train_datapoint_triplets.pkl -------------------------------------------------------------------------------- /dataset_lists/train_shapeLoss_pairs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/dataset_lists/train_shapeLoss_pairs.pkl -------------------------------------------------------------------------------- /model/__pycache__/discriminator.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/model/__pycache__/discriminator.cpython-35.pyc -------------------------------------------------------------------------------- /model/__pycache__/generator.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/model/__pycache__/generator.cpython-35.pyc -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class NLayerDiscriminator(nn.Module): 5 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 6 | super(NLayerDiscriminator, self).__init__() 7 | 8 | use_bias = norm_layer != nn.BatchNorm2d  # conv bias is redundant when BatchNorm follows 9 | 10 | kw = 3 11 | padw = 1 12 | sequence = [] 13 | self.initial_conv = nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw) 14 | self.initial_relu = nn.LeakyReLU(0.2, True) 15 | 16 | 17 | ndf = 128 18 | 19 | nf_mult = 1 20 | nf_mult_prev = 1 21 | for n in range(1, n_layers): 22 | nf_mult_prev = nf_mult 23 | nf_mult = min(2**n, 8) 24 | sequence += [ 25 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 26 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 27 | norm_layer(ndf * nf_mult), 28 | nn.LeakyReLU(0.2, True) 29 | ] 30 | 31 | nf_mult_prev = nf_mult 32 | nf_mult = 
min(2**n_layers, 8) 33 | sequence += [ 34 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 35 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 36 | norm_layer(ndf * nf_mult), 37 | nn.LeakyReLU(0.2, True) 38 | ] 39 | 40 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 41 | 42 | if use_sigmoid: 43 | sequence += [nn.Sigmoid()] 44 | 45 | self.model = nn.Sequential(*sequence) 46 | 47 | self.maxpool = nn.MaxPool2d(8) 48 | def forward(self, x, y): 49 | x = self.initial_relu(self.initial_conv(x)) 50 | y = self.initial_relu(self.initial_conv(y)) 51 | 52 | concat_fmap = torch.zeros([x.shape[0], 2 * x.shape[1], x.shape[2], x.shape[3]], dtype=x.dtype) 53 | # print(x.shape, y.shape, concat_result.shape) 54 | for i in range(x.shape[0]): 55 | for j in range(x.shape[1]): 56 | concat_fmap[i][j] = x[i][j] 57 | concat_fmap[i][j + 64] = y[i][j] 58 | if torch.cuda.is_available(): 59 | concat_fmap = concat_fmap.cuda() 60 | op = self.model(concat_fmap) 61 | op_neg = - op 62 | op_neg = self.maxpool(op_neg) 63 | op_pooled = -op_neg 64 | 65 | return op, op_pooled -------------------------------------------------------------------------------- /model/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # 3x3 convolution 6 | def conv3x3(in_channels, out_channels, stride=1): 7 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, 8 | stride=stride, padding=1, bias=False) 9 | 10 | # Residual block 11 | class ResidualBlock(nn.Module): 12 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 13 | super(ResidualBlock, self).__init__() 14 | self.conv1 = conv3x3(in_channels, out_channels, stride) 15 | self.bn1 = nn.BatchNorm2d(out_channels) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.conv2 = conv3x3(out_channels, out_channels) 18 | self.bn2 = nn.BatchNorm2d(out_channels) 19 | self.downsample = downsample 20 | 21 | def forward(self, x): 22 | residual = x 23 | # if torch.cuda.is_available(): 24 | # residual = torch.cuda.FloatTensor(residual) 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | if self.downsample: 31 | residual = self.downsample(x) 32 | out += residual 33 | out = self.relu(out) 34 | return out 35 | 36 | # Generator Definition 37 | 38 | class Generator(nn.Module): 39 | def __init__(self, block): 40 | super(Generator, self).__init__() 41 | 42 | self.conv1_x = nn.Sequential( 43 | nn.Conv2d(3, 16, 3, padding = 1), 44 | nn.BatchNorm2d(16), 45 | nn.ReLU(inplace=True), 46 | nn.AvgPool2d(2)) 47 | 48 | self.conv2_x = nn.Sequential( 49 | nn.Conv2d(16, 32, 3, padding = 1), 50 | nn.BatchNorm2d(32), 51 | nn.ReLU(inplace=True), 52 | nn.AvgPool2d(2)) 53 | 54 | self.conv3_x = nn.Sequential( 55 | nn.Conv2d(32, 16, 3, padding = 1), 56 | nn.BatchNorm2d(16), 57 | nn.ReLU(inplace=True)) 58 | 59 | self.conv1_y = nn.Sequential( 60 | nn.Conv2d(3, 16, 3, padding = 1), 61 | nn.BatchNorm2d(16), 62 | nn.ReLU(inplace=True), 63 | nn.AvgPool2d(2)) 64 | 65 | self.conv2_y = nn.Sequential( 66 | nn.Conv2d(16, 32, 3, padding = 1), 67 | nn.BatchNorm2d(32), 68 | nn.ReLU(inplace=True), 69 | nn.AvgPool2d(2)) 70 | 71 | self.conv3_y = nn.Sequential( 72 | nn.Conv2d(32, 16, 3, padding = 1), 73 | nn.BatchNorm2d(16), 74 | nn.ReLU(inplace=True)) 75 | 76 | # 2 Residual Blocks for Identity Image 77 | self.block1_x = block(16, 16) 78 | # downsample_x = nn.Sequential(conv3x3(16, 1, 1), nn.BatchNorm2d(1)) 79 | # 
self.block2_x = block(16, 1, 1, downsample_x) 80 | self.block2_x = block(16, 16) 81 | 82 | # 2 Residual Blocks for Shape Image 83 | self.block1_y = block(16, 16) 84 | # downsample_y = nn.Sequential(conv3x3(16, 1, 1), nn.BatchNorm2d(1)) 85 | # self.block2_y = block(16, 1, 1, downsample_y) 86 | self.block2_y = block(16, 16) 87 | 88 | # 2 Residual Blocks for Combined(concat) image 89 | downsample1_concat = nn.Sequential(conv3x3(32, 16, 1), nn.BatchNorm2d(16)) 90 | self.block1_concat = block(32, 16, 1, downsample1_concat) 91 | 92 | self.block2_concat = block(16, 16) 93 | 94 | # Upsampling layers 95 | 96 | self.upsample1 = nn.Sequential( 97 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners = True), 98 | nn.ConvTranspose2d(16, 32, 3, padding=1), 99 | nn.BatchNorm2d(32), 100 | nn.ReLU(inplace=True)) 101 | 102 | self.upsample2 = nn.Sequential( 103 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 104 | nn.ConvTranspose2d(32, 3, 3, padding=1), 105 | nn.BatchNorm2d(3), 106 | nn.ReLU(inplace=True)) 107 | 108 | 109 | def forward(self, x, y): 110 | 111 | x = self.conv1_x(x) 112 | x = self.conv2_x(x) 113 | x = self.conv3_x(x) 114 | x = self.block1_x(x) 115 | x = self.block2_x(x) 116 | 117 | y = self.conv1_y(y) 118 | y = self.conv2_y(y) 119 | y = self.conv3_y(y) 120 | y = self.block1_y(y) 121 | y = self.block2_y(y) 122 | 123 | concat_result = torch.zeros([x.shape[0], x.shape[1] * 2, x.shape[2], x.shape[3]], dtype=x.dtype) 124 | # print(x.shape, y.shape, concat_result.shape) 125 | for i in range(x.shape[0]): 126 | for j in range(x.shape[1]): 127 | concat_result[i][j] = x[i][j] 128 | concat_result[i][j + x.shape[1]] = y[i][j] 129 | if torch.cuda.is_available(): 130 | concat_result = concat_result.cuda() 131 | concat_result = self.block1_concat(concat_result) 132 | concat_result = self.block2_concat(concat_result) 133 | 134 | upsampled_1 = self.upsample1(concat_result) 135 | upsampled_2 = self.upsample2(upsampled_1) 136 | # print(upsample2.shape) 137 | return upsampled_2 138 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.optim as optim\n", 13 | "import torch.nn.functional as F\n", 14 | "from torch.utils.data import DataLoader\n", 15 | "from torchvision import transforms, utils\n", 16 | "\n", 17 | "import time\n", 18 | "import numpy as np\n", 19 | "import os\n", 20 | "import pickle\n", 21 | "from PIL import Image\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "plt.ion()\n", 24 | "\n", 25 | "from model.generator import *\n", 26 | "from model.discriminator import NLayerDiscriminator as Discriminator\n", 27 | "from utils.dataloader import YouTubePose\n", 28 | "from utils.loss_functions import lossIdentity, lossShape" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "collapsed": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "dataset_dir = './Dataset/'\n", 40 | "checkpoint_path = \"./model_checkpoints/\"\n", 41 | "batch_size = 4\n", 42 | "epochs = 10\n", 43 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 44 | "learning_rate_generator = 3e-4\n", 45 | "learning_rate_discriminator = 0.1\n", 46 | "alpha = 8\n", 47 | "beta = 0.001" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | 
"execution_count": 3, 53 | "metadata": { 54 | "collapsed": true 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "with open('./dataset_lists/train_datapoint_triplets.pkl', 'rb') as f:\n", 59 | " datapoint_pairs = pickle.load(f)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": { 66 | "collapsed": true 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "with open('./dataset_lists/train_shapeLoss_pairs.pkl', 'rb') as f:\n", 71 | " shapeLoss_datapoint_pairs = pickle.load(f)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 5, 77 | "metadata": { 78 | "collapsed": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "transform = transforms.Compose([\n", 83 | " transforms.Resize((256, 256)),\n", 84 | " transforms.ToTensor(),\n", 85 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", 86 | "])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "metadata": { 93 | "collapsed": true 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "train_dataset = YouTubePose(datapoint_pairs, shapeLoss_datapoint_pairs, dataset_dir, transform)\n", 98 | "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,\n", 99 | " num_workers=0)\n", 100 | "dataset_sizes = [len(train_dataset)]" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 7, 106 | "metadata": { 107 | "collapsed": true 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "generator = Generator(ResidualBlock)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "metadata": { 118 | "collapsed": true 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "discriminator = Discriminator(3)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 9, 128 | "metadata": { 129 | "collapsed": true 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "generator = generator.to(device)\n", 134 | "discriminator = discriminator.to(device)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 10, 140 | "metadata": { 141 | "collapsed": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "optimizer_gen = optim.Adam(generator.parameters(), lr = learning_rate_generator)\n", 146 | "optimizer_disc = optim.SGD(discriminator.parameters(), lr = learning_rate_discriminator, momentum=0.9)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 11, 152 | "metadata": { 153 | "collapsed": true 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "def save_checkpoint(state, dirpath, epoch):\n", 158 | " filename = 'checkpoint-{}.ckpt'.format(epoch)\n", 159 | " checkpoint_path = os.path.join(dirpath, filename)\n", 160 | " torch.save(state, checkpoint_path)\n", 161 | " print('--- checkpoint saved to ' + str(checkpoint_path) + ' ---')" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 12, 167 | "metadata": { 168 | "collapsed": true 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "def train_model(gen, disc, loss_i, loss_s, optimizer_gen, optimizer_disc, alpha = 1, beta = 1, num_epochs = 10):\n", 173 | " for epoch in range(num_epochs):\n", 174 | " print(\"Epoch {}/{}\".format(epoch, num_epochs - 1))\n", 175 | " print('-'*10)\n", 176 | " dataloader = train_dataloader\n", 177 | " gen.train()\n", 178 | " disc.train()\n", 179 | " since = time.time()\n", 180 | " running_loss_iden = 0.0\n", 181 | " running_loss_s1 = 0.0\n", 182 | " running_loss_s2a = 0.0\n", 183 | " running_loss_s2b = 0.0\n", 184 | " 
running_loss = 0.0\n", 185 | "        \n", 186 | "        for i_batch, sample_batched in enumerate(dataloader):\n", 187 | "            x_gen, y, x_dis = sample_batched['x_gen'], sample_batched['y'], sample_batched['x_dis']\n", 188 | "            iden_1, iden_2 = sample_batched['iden_1'], sample_batched['iden_2']\n", 189 | "            x_gen = x_gen.to(device)\n", 190 | "            y = y.to(device)\n", 191 | "            x_dis = x_dis.to(device)\n", 192 | "            iden_1 = iden_1.to(device)\n", 193 | "            iden_2 = iden_2.to(device)\n", 194 | "            \n", 195 | "            optimizer_gen.zero_grad()\n", 196 | "            optimizer_disc.zero_grad()\n", 197 | "            \n", 198 | "            with torch.set_grad_enabled(True):\n", 199 | "                x_generated = gen(x_gen, y)\n", 200 | "                fake_op, fake_pooled_op = disc(x_gen, x_generated)\n", 201 | "                real_op, real_pooled_op = disc(x_gen, x_dis)\n", 202 | "                loss_identity_gen = -loss_i(real_pooled_op, fake_pooled_op)\n", 203 | "                loss_identity_gen.backward(retain_graph=True)\n", 204 | "                optimizer_gen.step()\n", 205 | "                \n", 206 | "                optimizer_disc.zero_grad()\n", 207 | "                loss_identity_disc = loss_i(real_op, fake_op)\n", 208 | "                loss_identity_disc.backward(retain_graph=True)\n", 209 | "                optimizer_disc.step()\n", 210 | "\n", 211 | "                optimizer_gen.zero_grad()\n", 212 | "                optimizer_disc.zero_grad()\n", 213 | "                x_ls2a = gen(y, x_generated)\n", 214 | "                x_ls2b = gen(x_generated, y)\n", 215 | "\n", 216 | "                loss_s2a = loss_s(y, x_ls2a)\n", 217 | "                loss_s2b = loss_s(x_generated, x_ls2b)\n", 218 | "                loss_s2 = loss_s2a + loss_s2b\n", 219 | "\n", 220 | "                loss_s2.backward()\n", 221 | "                optimizer_gen.step()\n", 222 | "\n", 223 | "                optimizer_gen.zero_grad()\n", 224 | "                optimizer_disc.zero_grad()\n", 225 | "                \n", 226 | "                x_ls1 = gen(iden_1, iden_2)\n", 227 | "\n", 228 | "                loss_s1 = loss_s(iden_2, x_ls1)\n", 229 | "                loss_s1.backward()\n", 230 | "                optimizer_gen.step()\n", 231 | "            running_loss_iden += loss_identity_disc.item() * x_gen.size(0)\n", 232 | "            running_loss_s1 += loss_s1.item() * x_gen.size(0)\n", 233 | "            running_loss_s2a += loss_s2a.item() * x_gen.size(0)    \n", 234 | "            running_loss_s2b += loss_s2b.item() * x_gen.size(0)\n", 235 | "            running_loss = running_loss_iden + beta * (running_loss_s1 + alpha * (running_loss_s2a + running_loss_s2b))\n", 236 | "        epoch_loss_iden = running_loss_iden / dataset_sizes[0]\n", 237 | "        epoch_loss_s1 = running_loss_s1 / dataset_sizes[0]\n", 238 | "        epoch_loss_s2a = running_loss_s2a / dataset_sizes[0]\n", 239 | "        epoch_loss_s2b = running_loss_s2b / dataset_sizes[0]\n", 240 | "        epoch_loss = running_loss / dataset_sizes[0]\n", 241 | "        print('Identity Loss: {:.4f} Loss Shape1: {:.4f} Loss Shape2a: {:.4f} Loss Shape2b: {:.4f}'.format(epoch_loss_iden, epoch_loss_s1,\n", 242 | "                                                                                                           epoch_loss_s2a, epoch_loss_s2b))\n", 243 | "        print('Epoch Loss: {:.4f}'.format(epoch_loss))\n", 244 | "        \n", 245 | "        save_checkpoint({\n", 246 | "            'epoch': epoch + 1,\n", 247 | "            'gen_state_dict': gen.state_dict(),\n", 248 | "            'disc_state_dict': disc.state_dict(),\n", 249 | "            'gen_opt': optimizer_gen.state_dict(),\n", 250 | "            'disc_opt': optimizer_disc.state_dict()\n", 251 | "        }, checkpoint_path, epoch + 1)\n", 252 | "        print('Time taken by epoch: {:.0f}m {:.0f}s'.format((time.time() - since) // 60, (time.time() - since) % 60))\n", 253 | "        print()\n", 254 | "        since = time.time()\n", 255 | "\n", 256 | "    return gen, disc" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "Epoch 0/9\n", 269 | "----------\n" 270 | ] 271 | } 272 | ], 273 
| "source": [ 274 | "generator, discriminator = train_model(generator, discriminator, lossIdentity, lossShape, optimizer_gen, optimizer_disc, alpha=alpha, beta=beta, num_epochs=epochs)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "collapsed": true 282 | }, 283 | "outputs": [], 284 | "source": [] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.5.4" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 2 308 | } 309 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ##################################################### 2 | ## 3 | ## IMPORTING MODULES AND LIBRARIES 4 | ## 5 | ##################################################### 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms, utils 12 | 13 | import time 14 | import numpy as np 15 | import os 16 | import pickle 17 | from PIL import Image 18 | import matplotlib.pyplot as plt 19 | plt.ion() 20 | 21 | from model.generator import * 22 | from model.discriminator import NLayerDiscriminator as Discriminator 23 | from utils.dataloader import YouTubePose 24 | from utils.loss_functions import lossIdentity, lossShape 25 | 26 | 27 | ##################################################### 28 | ## 29 | ## IMPORTANT PARAMETERS 30 | ## 31 | ##################################################### 32 | 33 | dataset_dir = './Dataset/' 34 | checkpoint_path = "./model_checkpoints/" 35 | batch_size = 4 36 | epochs = 10 37 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | learning_rate_generator = 3e-4 39 | learning_rate_discriminator = 0.1 40 | alpha = 8 41 | beta = 0.001 42 | 43 | 44 | ##################################################### 45 | ## 46 | ## DATALOADER 47 | ## 48 | ##################################################### 49 | 50 | with open('./dataset_lists/train_datapoint_triplets.pkl', 'rb') as f: 51 | datapoint_pairs = pickle.load(f) 52 | 53 | with open('./dataset_lists/train_shapeLoss_pairs.pkl', 'rb') as f: 54 | shapeLoss_datapoint_pairs = pickle.load(f) 55 | 56 | transform = transforms.Compose([ 57 | transforms.Resize((256, 256)), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 60 | ]) 61 | 62 | train_dataset = YouTubePose(datapoint_pairs, shapeLoss_datapoint_pairs, dataset_dir, transform) 63 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 64 | num_workers=0) 65 | dataset_sizes = [len(train_dataset)] 66 | 67 | 68 | ##################################################### 69 | ## 70 | ## MODEL PREPARATION 71 | ## 72 | ##################################################### 73 | 74 | generator = Generator(ResidualBlock) 75 | 76 | discriminator = Discriminator(3) 77 | 78 | generator = generator.to(device) 79 | 80 | discriminator = discriminator.to(device) 81 | 82 | optimizer_gen = optim.Adam(generator.parameters(), lr = learning_rate_generator) 83 | optimizer_disc 
= optim.SGD(discriminator.parameters(), lr = learning_rate_discriminator, momentum=0.9) 84 | 85 | ##################################################### 86 | ## 87 | ## TRAINING SCRIPT 88 | ## 89 | ##################################################### 90 | 91 | 92 | def save_checkpoint(state, dirpath, epoch): 93 | filename = 'checkpoint-{}.ckpt'.format(epoch) 94 | checkpoint_path = os.path.join(dirpath, filename) 95 | torch.save(state, checkpoint_path) 96 | print('--- checkpoint saved to ' + str(checkpoint_path) + ' ---') 97 | 98 | def train_model(gen, disc, loss_i, loss_s, optimizer_gen, optimizer_disc, alpha = 1, beta = 1, num_epochs = 10): 99 | for epoch in range(num_epochs): 100 | print("Epoch {}/{}".format(epoch, num_epochs - 1)) 101 | print('-'*10) 102 | dataloader = train_dataloader 103 | gen.train() 104 | disc.train() 105 | since = time.time() 106 | running_loss_iden = 0.0 107 | running_loss_s1 = 0.0 108 | running_loss_s2a = 0.0 109 | running_loss_s2b = 0.0 110 | running_loss = 0.0 111 | 112 | for i_batch, sample_batched in enumerate(dataloader): 113 | x_gen, y, x_dis = sample_batched['x_gen'], sample_batched['y'], sample_batched['x_dis'] 114 | iden_1, iden_2 = sample_batched['iden_1'], sample_batched['iden_2'] 115 | x_gen = x_gen.to(device) 116 | y = y.to(device) 117 | x_dis = x_dis.to(device) 118 | iden_1 = iden_1.to(device) 119 | iden_2 = iden_2.to(device) 120 | 121 | optimizer_gen.zero_grad() 122 | optimizer_disc.zero_grad() 123 | 124 | with torch.set_grad_enabled(True): 125 | x_generated = gen(x_gen, y) 126 | fake_op, fake_pooled_op = disc(x_gen, x_generated) 127 | real_op, real_pooled_op = disc(x_gen, x_dis) 128 | loss_identity_gen = -loss_i(real_pooled_op, fake_pooled_op) 129 | loss_identity_gen.backward(retain_graph=True) 130 | optimizer_gen.step() 131 | 132 | optimizer_disc.zero_grad() 133 | loss_identity_disc = loss_i(real_op, fake_op) 134 | loss_identity_disc.backward(retain_graph=True) 135 | optimizer_disc.step() 136 | 137 | optimizer_gen.zero_grad() 138 | optimizer_disc.zero_grad() 139 | x_ls2a = gen(y, x_generated) 140 | x_ls2b = gen(x_generated, y) 141 | 142 | loss_s2a = loss_s(y, x_ls2a) 143 | loss_s2b = loss_s(x_generated, x_ls2b) 144 | loss_s2 = loss_s2a + loss_s2b 145 | 146 | loss_s2.backward() 147 | optimizer_gen.step() 148 | 149 | optimizer_gen.zero_grad() 150 | optimizer_disc.zero_grad() 151 | 152 | x_ls1 = gen(iden_1, iden_2) 153 | 154 | loss_s1 = loss_s(iden_2, x_ls1) 155 | loss_s1.backward() 156 | optimizer_gen.step() 157 | running_loss_iden += loss_identity_disc.item() * x_gen.size(0) 158 | running_loss_s1 += loss_s1.item() * x_gen.size(0) 159 | running_loss_s2a += loss_s2a.item() * x_gen.size(0) 160 | running_loss_s2b += loss_s2b.item() * x_gen.size(0) 161 | running_loss = running_loss_iden + beta * (running_loss_s1 + alpha * (running_loss_s2a + running_loss_s2b)) 162 | epoch_loss_iden = running_loss_iden / dataset_sizes[0] 163 | epoch_loss_s1 = running_loss_s1 / dataset_sizes[0] 164 | epoch_loss_s2a = running_loss_s2a / dataset_sizes[0] 165 | epoch_loss_s2b = running_loss_s2b / dataset_sizes[0] 166 | epoch_loss = running_loss / dataset_sizes[0] 167 | print('Identity Loss: {:.4f} Loss Shape1: {:.4f} Loss Shape2a: {:.4f} Loss Shape2b: {:.4f}'.format(epoch_loss_iden, epoch_loss_s1, 168 | epoch_loss_s2a, epoch_loss_s2b)) 169 | print('Epoch Loss: {:.4f}'.format(epoch_loss)) 170 | 171 | save_checkpoint({ 172 | 'epoch': epoch + 1, 173 | 'gen_state_dict': gen.state_dict(), 174 | 'disc_state_dict': disc.state_dict(), 175 | 'gen_opt': 
optimizer_gen.state_dict(), 176 | 'disc_opt': optimizer_disc.state_dict() 177 | }, checkpoint_path, epoch + 1) 178 | print('Time taken by epoch: {:.0f}m {:.0f}s'.format((time.time() - since) // 60, (time.time() - since) % 60)) 179 | print() 180 | since = time.time() 181 | 182 | return gen, disc 183 | 184 | 185 | ##################################################### 186 | ## 187 | ## MODEL TRAINING 188 | ## 189 | ##################################################### 190 | 191 | 192 | generator, discriminator = train_model(generator, discriminator, lossIdentity, lossShape, optimizer_gen, optimizer_disc, alpha=alpha, beta=beta, num_epochs=epochs) 193 | -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/utils/__pycache__/dataloader.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss_functions.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/utils/__pycache__/loss_functions.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_function.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aarushgupta/FusionGAN/6efeda8aa501a6ebfad90adb50eb812a09198c11/utils/__pycache__/train_function.cpython-35.pyc -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from torchvision import utils 3 | from PIL import Image 4 | 5 | class YouTubePose(Dataset): 6 | 7 | def __init__(self, datapoint_pairs, shapeLoss_datapoint_pairs, dataset_dir, transform=None, mode='train'): 8 | self.datapoint_pairs = datapoint_pairs 9 | self.shapeLoss_datapoint_pairs = shapeLoss_datapoint_pairs 10 | self.dataset_dir = dataset_dir 11 | self.transform = transform 12 | self.mode = mode 13 | 14 | def __len__(self): 15 | return len(self.datapoint_pairs) 16 | 17 | def __getitem__(self, idx): 18 | image_pair = self.datapoint_pairs[idx] 19 | x_gen_path = image_pair[0] 20 | x_dis_path = image_pair[1] 21 | y_path = image_pair[2] 22 | 23 | identity_pair = self.shapeLoss_datapoint_pairs[idx] 24 | iden_1_path = identity_pair[0] 25 | iden_2_path = identity_pair[1] 26 | 27 | x_gen = Image.open(self.dataset_dir + self.mode + '/' + x_gen_path) 28 | x_dis = Image.open(self.dataset_dir + self.mode + '/' + x_dis_path) 29 | y = Image.open(self.dataset_dir + self.mode + '/' + y_path) 30 | iden_1 = Image.open(self.dataset_dir + self.mode + '/' + iden_1_path) 31 | iden_2 = Image.open(self.dataset_dir + self.mode + '/' + iden_2_path) 32 | 33 | if self.transform: 34 | x_gen = self.transform(x_gen) 35 | x_dis = self.transform(x_dis) 36 | y = self.transform(y) 37 | iden_1 = self.transform(iden_1) 38 | iden_2 = self.transform(iden_2) 39 | 40 | sample = {'x_gen' : x_gen, 'x_dis': x_dis, 'y': y, 'iden_1': iden_1, 'iden_2': iden_2} 41 | return sample 42 | -------------------------------------------------------------------------------- /utils/loss_functions.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def lossIdentity(real_pair, fake_pair): 4 | batch_size = real_pair.size()[0] 5 | real_pair = 1 - real_pair 6 | real_pair = real_pair ** 2 7 | fake_pair = fake_pair ** 2 8 | real_pair = torch.sum(real_pair) 9 | fake_pair = torch.sum(fake_pair) 10 | return (real_pair + fake_pair) / batch_size 11 | 12 | 13 | def lossShape(x, y): 14 | batch_size = x.size()[0] 15 | diff = torch.abs(x - y) 16 | diff = torch.sum(diff) / batch_size 17 | return diff --------------------------------------------------------------------------------
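A note on the two loss functions above: `lossIdentity` is a least-squares (LSGAN-style) objective that drives the discriminator's scores for (identity, real) pairs toward 1 and for (identity, generated) pairs toward 0, normalized by the batch size, while `lossShape` here is an L1 distance (the inline version in `FusionGAN.ipynb` squares the difference instead). Below is a minimal sanity check — a sketch assuming it is run from the repository root; the tensor shapes are illustrative only:

```python
import torch

from utils.loss_functions import lossIdentity, lossShape

# A perfect discriminator scores real pairs as 1 and fake pairs as 0,
# so the identity loss vanishes at that optimum.
real = torch.ones(4, 1, 8, 8)    # patch scores for (identity, real) pairs
fake = torch.zeros(4, 1, 8, 8)   # patch scores for (identity, generated) pairs
print(lossIdentity(real, fake))  # tensor(0.)

# Identical images incur no shape loss.
x = torch.rand(4, 3, 256, 256)
print(lossShape(x, x))           # tensor(0.)
```

Since `train.py` writes checkpoints through `save_checkpoint`, a trained model can be restored by reloading the state dicts it stores. A sketch, assuming a checkpoint file such as `./model_checkpoints/checkpoint-1.ckpt` already exists (the key names match those written in `train_model`):

```python
import torch

from model.generator import Generator, ResidualBlock
from model.discriminator import NLayerDiscriminator

generator = Generator(ResidualBlock)
discriminator = NLayerDiscriminator(3)

# Load the dictionary saved by save_checkpoint() in train.py.
state = torch.load('./model_checkpoints/checkpoint-1.ckpt', map_location='cpu')
generator.load_state_dict(state['gen_state_dict'])
discriminator.load_state_dict(state['disc_state_dict'])
start_epoch = state['epoch']  # training could resume from this epoch
```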