├── .DS_Store ├── MNIST ├── .ipynb_checkpoints │ └── Untitled-checkpoint.ipynb ├── DCGAN.ipynb └── mnist_Readme.md ├── README.md ├── documentation ├── Progress-presentations │ ├── .DS_Store │ └── progress-presentation-1 │ │ ├── .DS_Store │ │ ├── Progress-Presentation.tex │ │ ├── architecture.png │ │ ├── coco.jpg │ │ ├── gan.png │ │ ├── mnist.png │ │ ├── mnist_ideal.png │ │ └── progress_presentation_i.pdf ├── REPORT.pdf ├── text_to_image_synthesis_using_gan.pdf └── text_to_image_synthesis_using_gan.png ├── texttoimage ├── .DS_Store ├── code │ ├── .DS_Store │ ├── cfg │ │ ├── .DS_Store │ │ ├── birds_s1.yml │ │ ├── birds_s2.yml │ │ ├── coco_eval.yml │ │ ├── coco_s1.yml │ │ ├── coco_s2.yml │ │ ├── flowers_s1.yml │ │ └── flowers_s2.yml │ ├── demo-Copy1.ipynb │ ├── demo_birds.ipynb │ ├── demo_coco.ipynb │ ├── get_embedding.lua │ ├── main.py │ ├── miscc │ │ ├── config.py │ │ ├── datasets.py │ │ └── utils .py │ ├── model.py │ └── trainer.py └── readme.md └── videogeneration ├── .DS_Store ├── data └── .DS_Store ├── preprocess.py ├── readme.md └── src ├── .DS_Store ├── data.py ├── demo_video.ipynb ├── generate_videos.py ├── logger.py ├── models.py ├── run.sh ├── train.py ├── trainers.py └── util.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/.DS_Store -------------------------------------------------------------------------------- /MNIST/.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /MNIST/DCGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "DCGAN.ipynb", 7 | "version": "0.3.2", 8 | "views": {}, 9 | "default_view": {}, 10 | "provenance": [], 11 | "collapsed_sections": [] 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "metadata": { 22 | "id": "5IlgEnKCSx8F", 23 | "colab_type": "text" 24 | }, 25 | "cell_type": "markdown", 26 | "source": [ 27 | "# **DCGAN implementation of MNIST dataset**\n", 28 | "\n", 29 | "Software Requirements:\n", 30 | "python 3\n", 31 | "\n", 32 | "library requirements \n", 33 | "1. Torch\n", 34 | "2. Torchvision\n", 35 | "3. Matplotlib\n", 36 | "4. tensorboardX\n", 37 | "5. numpy\n", 38 | "6. Pillow\n", 39 | "\n", 40 | "\n", 41 | "Running Instructions:\n", 42 | "1. Install all libraries\n", 43 | "2. Run the Code cell below \n", 44 | "\n", 45 | "DCGAN architecture Details\n", 46 | "\n", 47 | "![DCGAN architecture details](https://raw.githubusercontent.com/znxlwm/pytorch-MNIST-CelebA-GAN-DCGAN/master/pytorch_DCGAN.png)\n", 48 | "\n", 49 | "\n", 50 | "Generated images after a few epochs\n", 51 | "\n", 52 | "![alt text](https://i.imgur.com/Vp1w3KS.png)\n", 53 | "\n", 54 | "\n", 55 | "**References**:\n", 56 | "\n", 57 | "1. [DCGAN Paper](https://arxiv.org/pdf/1511.06434.)\n", 58 | "2. [MNIST- celebA implementation](https://github.com/znxlwm/pytorch-MNIST-CelebA-GAN-DCGAN)\n", 59 | "3. 
[utility File, GAN tutorial](https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f)\n", 60 | "\n", 61 | "\n", 62 | "\n", 63 | "\n", 64 | "\n" 65 | ] 66 | }, 67 | { 68 | "metadata": { 69 | "id": "LGt4QJHUWkoT", 70 | "colab_type": "text" 71 | }, 72 | "cell_type": "markdown", 73 | "source": [ 74 | "logger utility class to save images , models and log tensorboard data\n" 75 | ] 76 | }, 77 | { 78 | "metadata": { 79 | "id": "KnbfbabsthKO", 80 | "colab_type": "code", 81 | "colab": { 82 | "autoexec": { 83 | "startup": false, 84 | "wait_interval": 0 85 | } 86 | }, 87 | "cellView": "code" 88 | }, 89 | "cell_type": "code", 90 | "source": [ 91 | "import os\n", 92 | "import numpy as np\n", 93 | "import errno\n", 94 | "import torchvision.utils as vutils\n", 95 | "from tensorboardX import SummaryWriter\n", 96 | "from matplotlib import pyplot as plt\n", 97 | "from IPython import display\n", 98 | "\n", 99 | "\n", 100 | "# utility class to show image plots , save model, and log tensorboard Data\n", 101 | "class Logger:\n", 102 | " def __init__(self, model_name, data_name):\n", 103 | " self.model_name = model_name\n", 104 | " self.data_name = data_name\n", 105 | "\n", 106 | " self.comment = '{}_{}'.format(model_name, data_name)\n", 107 | " self.data_subdir = '{}/{}'.format(model_name, data_name)\n", 108 | "\n", 109 | " # TensorBoard\n", 110 | " self.writer = SummaryWriter(comment=self.comment)\n", 111 | " \n", 112 | " def log(self, d_error, g_error, epoch, n_batch, num_batches):\n", 113 | "\n", 114 | " var_class = torch.autograd.Variable\n", 115 | " if type(d_error) == var_class:\n", 116 | " d_error = d_error.data.cpu().numpy()\n", 117 | " if type(g_error) == var_class:\n", 118 | " g_error = g_error.data.cpu().numpy()\n", 119 | "\n", 120 | " step = Logger._step(epoch, n_batch, num_batches)\n", 121 | " self.writer.add_scalar(\n", 122 | " '{}/D_error'.format(self.comment), d_error, step)\n", 123 | " self.writer.add_scalar(\n", 124 | " '{}/G_error'.format(self.comment), g_error, step)\n", 125 | "\n", 126 | " def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):\n", 127 | " '''\n", 128 | " input images are expected in format (NCHW)\n", 129 | " '''\n", 130 | " if type(images) == np.ndarray:\n", 131 | " images = torch.from_numpy(images)\n", 132 | "\n", 133 | " if format == 'NHWC':\n", 134 | " images = images.transpose(1, 3)\n", 135 | "\n", 136 | " step = Logger._step(epoch, n_batch, num_batches)\n", 137 | " img_name = '{}/images{}'.format(self.comment, '')\n", 138 | "\n", 139 | " # Make horizontal grid from image tensor\n", 140 | " horizontal_grid = vutils.make_grid(\n", 141 | " images, normalize=normalize, scale_each=True)\n", 142 | " # Make vertical grid from image tensor\n", 143 | " nrows = int(np.sqrt(num_images))\n", 144 | " grid = vutils.make_grid(\n", 145 | " images, nrow=nrows, normalize=True, scale_each=True)\n", 146 | "\n", 147 | " # Add horizontal images to tensorboard\n", 148 | " self.writer.add_image(img_name, horizontal_grid, step)\n", 149 | "\n", 150 | " # Save plots\n", 151 | " self.save_torch_images(horizontal_grid, grid, epoch, n_batch)\n", 152 | "\n", 153 | " def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):\n", 154 | " out_dir = './data/images/{}'.format(self.data_subdir)\n", 155 | " Logger._make_dir(out_dir)\n", 156 | "\n", 157 | " # Plot and save horizontal\n", 158 | " fig = plt.figure(figsize=(16, 16))\n", 159 | " 
plt.imshow(np.moveaxis(horizontal_grid.numpy(), 0, -1))\n", 160 | " plt.axis('off')\n", 161 | " if plot_horizontal:\n", 162 | " display.display(plt.gcf())\n", 163 | " self._save_images(fig, epoch, n_batch, 'hori')\n", 164 | " plt.close()\n", 165 | "\n", 166 | " # Save squared\n", 167 | " fig = plt.figure()\n", 168 | " plt.imshow(np.moveaxis(grid.numpy(), 0, -1))\n", 169 | " plt.axis('off')\n", 170 | " self._save_images(fig, epoch, n_batch)\n", 171 | " plt.close()\n", 172 | "\n", 173 | " def _save_images(self, fig, epoch, n_batch, comment=''):\n", 174 | " out_dir = './data/images/{}'.format(self.data_subdir)\n", 175 | " Logger._make_dir(out_dir)\n", 176 | " fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir,\n", 177 | " comment, epoch, n_batch))\n", 178 | "\n", 179 | " def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):\n", 180 | "\n", 181 | " var_class = torch.autograd.Variable\n", 182 | " if type(d_error) == var_class:\n", 183 | " d_error = d_error.data.cpu().numpy()[0]\n", 184 | " if type(g_error) == var_class:\n", 185 | " g_error = g_error.data.cpu().numpy()[0]\n", 186 | " if type(d_pred_real) == var_class:\n", 187 | " d_pred_real = d_pred_real.data\n", 188 | " if type(d_pred_fake) == var_class:\n", 189 | " d_pred_fake = d_pred_fake.data\n", 190 | "\n", 191 | " print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(\n", 192 | " epoch, num_epochs, n_batch, num_batches)\n", 193 | " )\n", 194 | " print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))\n", 195 | " print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))\n", 196 | "\n", 197 | " def save_models(self, generator, discriminator, epoch):\n", 198 | " out_dir = './data/models/{}'.format(self.data_subdir)\n", 199 | " Logger._make_dir(out_dir)\n", 200 | " torch.save(generator.state_dict(),\n", 201 | " '{}/G_epoch_{}'.format(out_dir, epoch))\n", 202 | " torch.save(discriminator.state_dict(),\n", 203 | " '{}/D_epoch_{}'.format(out_dir, epoch))\n", 204 | "\n", 205 | " def close(self):\n", 206 | " self.writer.close()\n", 207 | "\n", 208 | " # Private Functionality\n", 209 | "\n", 210 | " @staticmethod\n", 211 | " def _step(epoch, n_batch, num_batches):\n", 212 | " return epoch * num_batches + n_batch\n", 213 | "\n", 214 | " @staticmethod\n", 215 | " def _make_dir(directory):\n", 216 | " try:\n", 217 | " os.makedirs(directory)\n", 218 | " except OSError as e:\n", 219 | " if e.errno != errno.EEXIST:\n", 220 | " raise\n" 221 | ], 222 | "execution_count": 0, 223 | "outputs": [] 224 | }, 225 | { 226 | "metadata": { 227 | "id": "LnNdivABtvAw", 228 | "colab_type": "text" 229 | }, 230 | "cell_type": "markdown", 231 | "source": [ 232 | "The main DCGAN implementation (data loading, networks, and training loop) starts here." 233 | ] 234 | }, 235 | { 236 | "metadata": { 237 | "id": "XbUQf8U3O8z7", 238 | "colab_type": "code", 239 | "colab": { 240 | "autoexec": { 241 | "startup": false, 242 | "wait_interval": 0 243 | } 244 | } 245 | }, 246 | "cell_type": "code", 247 | "source": [ 248 | "#DCGAN implementation for MNIST\n", 249 | "\n", 250 | "\n", 251 | "import torch\n", 252 | "from torch import nn\n", 253 | "from torch.optim import Adam\n", 254 | "from torch.autograd import Variable\n", 255 | "from IPython import display\n", 256 | "from torchvision import transforms, datasets\n", 257 | "import os\n", 258 | "import numpy as np\n", 259 | "import errno\n", 260 | "import torchvision.utils as vutils\n", 261 | "from tensorboardX import SummaryWriter\n", 262 | "from matplotlib import pyplot as plt\n", 263 | "\n",
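"# Optional (an added convenience, not part of the original DCGAN recipe): fix the random\n",
"# seeds so that repeated runs produce comparable samples; the seed value 0 is arbitrary.\n",
"torch.manual_seed(0)\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed_all(0)\n",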
264 | "\n", 265 | " \n", 266 | "#data generator function to create MNIST data\n", 267 | "#uses inbuilt dataloader from pytorch datasets , downloads the dataset if necessary \n", 268 | "def data_generator():\n", 269 | " # applies transfroms to each image to convert it to a valid format by resizing and bringing it into a (-1,1) range\n", 270 | " compose = transforms.Compose([\n", 271 | " transforms.Resize(64),\n", 272 | " transforms.ToTensor(),\n", 273 | " transforms.Normalize((.5,.5,.5),(.5,.5,.5))\n", 274 | " ])\n", 275 | " out_dir='{}/dataset'.format('MNIST')\n", 276 | " return datasets.MNIST(out_dir,True,compose,download=True)\n", 277 | "\n", 278 | "data = data_generator()\n", 279 | "batch_size = 100\n", 280 | "data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n", 281 | "num_batches = len(data_loader)\n", 282 | "\n", 283 | "\n", 284 | "#discriminator net \n", 285 | "class DiscriminatorNet(torch.nn.Module):\n", 286 | "# Network for discriminator is created in this part of code\n", 287 | "# Six layers have 1, 128, 256, 512, 1024 and 1 channels respectively.\n", 288 | "# Batch Normalisation is applied to all hidden layers.\n", 289 | "# LeakyReLU with negetive slope 0.2 is applied to all the layers except the output layer. \n", 290 | " def __init__(self):\n", 291 | " super(DiscriminatorNet, self).__init__()\n", 292 | " # First convolutional layer takes 1 channel and returns 128 channels, applies a 4*4 kernel with a padding of 1 and stride of 2.\n", 293 | " self.conv1 = nn.Sequential(\n", 294 | " nn.Conv2d(1,128,4,2,1, bias=False),\n", 295 | " nn.LeakyReLU(0.2, inplace=True)\n", 296 | " )\n", 297 | " # Second convolutional layer takes 128 channel input and returns 256 channels, applies a 4*4 kernel with a padding of 1 and stride of 2.\n", 298 | " self.conv2 = nn.Sequential(\n", 299 | " nn.Conv2d(128, 256,4, 2, 1, bias=False),\n", 300 | " nn.BatchNorm2d(256),\n", 301 | " nn.LeakyReLU(0.2, inplace=True)\n", 302 | " )\n", 303 | " # Third convolutional layer takes 256 channels and returns 512 channels, applies a 4*4 kernel with a padding of 1 and stride of 2.\n", 304 | " self.conv3 = nn.Sequential(\n", 305 | " nn.Conv2d(256,512, 4, 2, 1,bias=False),\n", 306 | " nn.BatchNorm2d(512),\n", 307 | " nn.LeakyReLU(0.2, inplace=True)\n", 308 | " )\n", 309 | " # Fourth convolutional layer takes 512 channels and returns 1024 channels,applies a 4*4 kernel with a padding of 1 and stride of 2. \n", 310 | " self.conv4 = nn.Sequential(\n", 311 | " nn.Conv2d(512, 1024, 4, 2, 1, bias=False),\n", 312 | " nn.BatchNorm2d(1024),\n", 313 | " nn.LeakyReLU(0.2, inplace=True)\n", 314 | " )\n", 315 | " # Linear network is used for the last layer.\n", 316 | " # Sigmoid activation function is used. 
\n", 317 | " self.out = nn.Sequential(\n", 318 | " nn.Linear(1024*4*4, 1),\n", 319 | " nn.Sigmoid(),\n", 320 | " )\n", 321 | "\n", 322 | " def forward(self, x):\n", 323 | " x = self.conv1(x)\n", 324 | " x = self.conv2(x)\n", 325 | " x = self.conv3(x)\n", 326 | " x = self.conv4(x)\n", 327 | " # Flatten and apply sigmoid\n", 328 | " x = x.view(-1, 1024*4*4)\n", 329 | " x = self.out(x)\n", 330 | " return x\n", 331 | "\n", 332 | "#generatorNet Class contains our generator network which inherits the nn.Module \n", 333 | "class generatorNet(nn.Module):\n", 334 | " def __init__(self):\n", 335 | " super(generatorNet, self).__init__()\n", 336 | " #a liner layer to convert a noise vector of size 100 to a size of 1024*4*4\n", 337 | " self.linear = torch.nn.Linear(100, 1024 * 4 * 4)\n", 338 | " \n", 339 | " #you can view the architecture in a more feasible manner in the image provided\n", 340 | " #ReLu activation has been used in every layer and A batchnorm has been applied after the output of Deconvolutional layer\n", 341 | " \n", 342 | " #first deconvolutional layer which takes 1024 input channels and returns 512 output channels , applies a 4*4 kernel with a padding of 1 and stride of 2. \n", 343 | " self.deconv1 = nn.Sequential(\n", 344 | " nn.ConvTranspose2d(1024,512, 4, 2, 1, bias=False ),\n", 345 | " nn.BatchNorm2d(512),\n", 346 | " nn.ReLU(inplace=True)\n", 347 | " )\n", 348 | " \n", 349 | " #second deconvolutional layer which takes 512 input channels and returns 256 output channels , applies a 4*4 kernel with a padding of 1 and stride of 2. \n", 350 | " self.deconv2 = nn.Sequential(\n", 351 | " nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),\n", 352 | " nn.BatchNorm2d(256),\n", 353 | " nn.ReLU(inplace=True)\n", 354 | " )\n", 355 | " \n", 356 | " #fourth deconvolutional layer which takes 256 input channels and returns 128 output channels , applies a 4*4 kernel with a padding of 1 and stride of 2. \n", 357 | " self.deconv3 = nn.Sequential(\n", 358 | " nn.ConvTranspose2d(256, 128,4, 2, 1, bias=False),\n", 359 | " nn.BatchNorm2d(128),\n", 360 | " nn.ReLU(inplace=True)\n", 361 | " )\n", 362 | " \n", 363 | " #fifth deconvolutional layer which takes 128 input channels and returns 1 output channel , applies a 4*4 kernel with a padding of 1 and stride of 2.\n", 364 | " #we dont apply batchnorm to this deconvolutional layer.\n", 365 | " self.deconv4 = nn.Sequential(\n", 366 | " nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False)\n", 367 | " )\n", 368 | " \n", 369 | " #tanh activator for the output of fifth deconvolutional layer. 
\n", 370 | " self.out = torch.nn.Tanh()\n", 371 | "\n", 372 | " def forward(self, x):\n", 373 | " # Project and reshape\n", 374 | " x = self.linear(x)\n", 375 | " x = x.view(x.shape[0], 1024, 4, 4)\n", 376 | " x = self.deconv1(x)\n", 377 | " x = self.deconv2(x)\n", 378 | " x = self.deconv3(x)\n", 379 | " x = self.deconv4(x)\n", 380 | " # Apply Tanh\n", 381 | " return self.out(x)\n", 382 | "\n", 383 | "\n", 384 | "#noise generator takes size as a parameter and generates a noise vector with size 100 \n", 385 | "def noise(size):\n", 386 | " n = Variable(torch.randn(size, 100))\n", 387 | " if torch.cuda.is_available(): return n.cuda()\n", 388 | " return n\n", 389 | "\n", 390 | "\n", 391 | "#weight initialization function to initialize the weights of both the generator and discriminator \n", 392 | "def init_weights(m):\n", 393 | " classname = m.__class__.__name__\n", 394 | " if classname.find('Conv') != -1 or classname.find('BatchNorm') != -1:\n", 395 | " m.weight.data.normal_(0.00, 0.02)\n", 396 | "# Get instance of generator and discriminator network \n", 397 | "generator = generatorNet()\n", 398 | "discriminator = DiscriminatorNet()\n", 399 | "# Initialise the weights\n", 400 | "generator.apply(init_weights)\n", 401 | "discriminator.apply(init_weights)\n", 402 | "\n", 403 | "#set it to cuda if cuda enabled GPU is available\n", 404 | "if torch.cuda.is_available():\n", 405 | " generator.cuda()\n", 406 | " discriminator.cuda()\n", 407 | "\n", 408 | "# both the networks use Adam optimizer with learning rate of 0.0002 and beta of 0.5 with NO weight decay.\n", 409 | "d_optimizer = Adam(discriminator.parameters(),lr=0.0002, betas=(0.5,0.999))\n", 410 | "g_optimizer = Adam(generator.parameters(),lr=0.0002,betas=(0.5,0.999))\n", 411 | "\n", 412 | "\n", 413 | "#A binay cross entropy loss has been used as its very similar to the loss described in the DCGAN paper.\n", 414 | "loss = nn.BCELoss()\n", 415 | "num_epochs = 50\n", 416 | "\n", 417 | "\n", 418 | "#this function takes the size as parameter and generates a vector of 1s with the given size , this vector will be our real data target.\n", 419 | "def real_data_target(size):\n", 420 | " if torch.cuda.is_available():\n", 421 | " return Variable(torch.ones(size,1)).cuda()\n", 422 | " return Variable(torch.ones(size,1))\n", 423 | "\n", 424 | "#this function takes the size as parameter and generates a vector of 0s with the given size , this vector will be our fake data target.\n", 425 | "def fake_data_target(size):\n", 426 | " if torch.cuda.is_available():\n", 427 | " return Variable(torch.zeros(size,1)).cuda()\n", 428 | " return Variable(torch.zeros(size,1))\n", 429 | "\n", 430 | "#training function for the discriminator\n", 431 | "#takes input as the optimizer , real image data , and the generated (fake) images data.\n", 432 | "def train_discriminator(optimizer, real_data, fake_data):\n", 433 | " #set the gradients to zero before optimization\n", 434 | " optimizer.zero_grad()\n", 435 | " \n", 436 | " #get the prediction from the discriminator for the real data and calculate the error by comparing it to the real data target \n", 437 | " prediction_real = discriminator(real_data)\n", 438 | " error_real = loss(prediction_real,real_data_target(real_data.size(0)))\n", 439 | " #calculate all the gradients going backward for the disriminator\n", 440 | " error_real.backward()\n", 441 | "\n", 442 | " #get the prediction from the discriminator for the fake data and calculate the error by comparing it to the fake data target \n", 443 | " prediction_fake = 
discriminator(fake_data)\n", 444 | " error_fake = loss(prediction_fake,fake_data_target(real_data.size(0)))\n", 445 | " #calculate the gradients for the discriminator from the fake-data loss\n", 446 | " error_fake.backward()\n", 447 | "\n", 448 | " #update the discriminator weights using the accumulated gradients\n", 449 | " optimizer.step()\n", 450 | "\n", 451 | " return error_real+error_fake,prediction_real,prediction_fake\n", 452 | " \n", 453 | " \n", 454 | "#training function for the generator\n", 455 | "def train_generator(optimizer,fake_data):\n", 456 | " optimizer.zero_grad() # reset the generator gradients before computing the loss\n", 457 | " # get predictions from discriminator by feeding fake data\n", 458 | " prediction = discriminator(fake_data)\n", 459 | " # calculate the error of prediction by comparing it to real data target\n", 460 | " error = loss(prediction,real_data_target(prediction.size(0)))\n", 461 | " # backpropagate the error and update the weights\n", 462 | " error.backward()\n", 463 | "\n", 464 | " optimizer.step()\n", 465 | " return error\n", 466 | "\n", 467 | "#logger configuration.\n", 468 | "logger = Logger(model_name='DCGAN', data_name='MNIST')\n", 469 | "\n", 470 | "#generate noise samples for testing.\n", 471 | "num_test_samples = 16\n", 472 | "test_noise = noise(num_test_samples)\n", 473 | "\n", 474 | "#training loop over the number of epochs\n", 475 | "for epoch in range(num_epochs):\n", 476 | " #loop through one batch of data at a time.\n", 477 | " for n_batch, (real_batch, _) in enumerate(data_loader):\n", 478 | "\n", 479 | " real_data = Variable(real_batch)\n", 480 | " if torch.cuda.is_available(): real_data = real_data.cuda()\n", 481 | " # Generate fake data\n", 482 | " fake_data = generator(noise(real_data.size(0))).detach()\n", 483 | " # Train the discriminatorNet\n", 484 | " d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,\n", 485 | " real_data, fake_data)\n", 486 | "\n", 487 | " # Generate fake data\n", 488 | " fake_data = generator(noise(real_batch.size(0)))\n", 489 | " # Train Generator \n", 490 | " g_error = train_generator(g_optimizer, fake_data)\n", 491 | " # Log error\n", 492 | " logger.log(d_error, g_error, epoch, n_batch, num_batches)\n", 493 | "\n", 494 | " # Display Progress\n", 495 | " if (n_batch) % 10 == 0:\n", 496 | " display.clear_output(True)\n", 497 | " # Display Images\n", 498 | " test_images = generator(test_noise).data.cpu()\n", 499 | " logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches);\n", 500 | " # Display status Logs\n", 501 | " logger.display_status(\n", 502 | " epoch, num_epochs, n_batch, num_batches,\n", 503 | " d_error, g_error, d_pred_real, d_pred_fake\n", 504 | " )\n", 505 | " # Save the model that has been trained \n", 506 | " logger.save_models(generator, discriminator, epoch)" 507 | ], 508 | "execution_count": 0, 509 | "outputs": [] 510 | } 511 | ] 512 | } 513 | -------------------------------------------------------------------------------- /MNIST/mnist_Readme.md: -------------------------------------------------------------------------------- 1 | ## Generating MNIST dataset Images 2 | 3 | Please visit the wiki page with the same title to learn more about this code. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text to Image/Video synthesis using GANs 2 | This repository contains all the code that was developed during the e-Yantra Summer Internship 2018 (eYSIP-2018) at IIT Bombay.
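All of the code targets Python 3 with PyTorch and assumes a CUDA-capable GPU (see the requirements and the warning at the end of this file). The snippet below is a minimal environment check you can run before opening any of the notebooks; it only assumes a standard PyTorch installation.

```python
import torch

# Confirm that this PyTorch build was compiled with CUDA support and can see a GPU.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```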
3 | 4 | ## Objective 5 | In this project we developed models that can generate new images from an existing image dataset, synthesize images from text descriptions, and generate videos of a specified category. 6 | 7 | ## Authors 8 | * Deval Srivastava 9 | * Aishwarya J Kalloli 10 | 11 | ## Mentors 12 | * Aditya Panwar 13 | * Kalind Karia 14 | 15 | ## Folders 16 | * MNIST: Contains all the code for MNIST image generation 17 | * texttoimage: Contains all the code for generation of images from text 18 | * videogeneration: Contains all the code for the generation of videos from the KTH dataset 19 | * documentation: Contains all the progress presentations and posters 20 | 21 | ## Software and libraries to be installed 22 | * Python 3 23 | * cuDNN 24 | * CUDA 25 | * Lua 26 | * Torch 27 | 28 | ## Python requirements 29 | * PyTorch 30 | * torchvision 31 | * NumPy 32 | * Matplotlib 33 | * Jupyter Notebook 34 | * Pandas 35 | * torchfile 36 | * requests 37 | * tensorboardX 38 | * imageio 39 | * OpenCV (cv2) 40 | * zipfile 41 | * Pillow (PIL) 42 | * h5py 43 | 44 | ## WARNING: Please ensure that you have a CUDA-enabled GPU, as all of the code in this repository is designed to run on one -------------------------------------------------------------------------------- /documentation/Progress-presentations/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/.DS_Store -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/.DS_Store -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/Progress-Presentation.tex: -------------------------------------------------------------------------------- 1 | \documentclass[10pt, a4paper]{beamer} 2 | \setbeamertemplate{caption}[numbered] 3 | 4 | 5 | \usetheme{Berkeley} 6 | \usecolortheme{sidebartab} 7 | \usepackage{tabu} 8 | \usepackage{array} 9 | \usepackage{graphicx} 10 | \usepackage{subcaption} 11 | \newcolumntype{P}[1]{>{\centering\arraybackslash}p{#1}} 12 | \begin{document} 13 | \setbeamertemplate{sidebar left}{} 14 | \title{Progress Presentation-I} 15 | \subtitle{e-Yantra Summer Internship-2018 \\ $ $\textbf{Text-to-Image/Video Synthesis using GANs}$ $} 16 | \author{Aishwarya Kalloli\\Deval Srivastava\\ \vspace{1em} 17 | Mentors:\\ Aditya Panwar, Kalind Karia} 18 | \institute{IIT Bombay} 19 | \date{\today} 20 | %\addtobeamertemplate{sidebar left}{}{\includegraphics[scale = 0.3]{logowithtext.png}} 21 | \frame{\titlepage} 22 | 23 | \setbeamertemplate{sidebar left}[sidebar theme] 24 | \section{Overview of Project} 25 | \begin{frame}{Overview of Project} 26 | \begin{itemize} 27 | \item \textcolor{blue}{Project Name:} Text-to-Image / Video Synthesis using GANs 28 | \item \textcolor{blue}{Objective:} To generate image or video from given caption 29 | \item \textcolor{blue}{Deliverables:} 30 | \begin{enumerate} 31 | \item To create a model that can generate new images by getting trained on a given dataset 32 | \item Creation of video from these new set of
images 33 | \item Prepare proper documentation and tutorial of the solution 34 | \end{enumerate} 35 | \end{itemize} 36 | \end{frame} 37 | 38 | \section{Overview of Task} 39 | \begin{frame}{Overview of Task} 40 | \footnotesize{ 41 | \begin{tabular}{|P{1cm}|p{6.2cm}|P{1.7cm}| } 42 | \hline 43 | \multicolumn{3}{|c|}{Project Task List} \\ 44 | \hline 45 | Task & \centering Task & Deadline (days)\\ 46 | \hline 47 | 1 & Understanding the idea and create report on how it can be tackled using Machine Learning: a basic report of 2-5 pages highlighting various algorithms suitable for the task & 2\\ 48 | \hline 49 | 2 & Installing the required software & 1 \\ 50 | \hline 51 | 3 & 52 | Perform a basic experiment to understand GANs (MNIST) 53 | & 2 \\ 54 | \hline 55 | 4 & Gather the required data-set to train the model & 1-2 \\ 56 | \hline 57 | 5 & Design the model, test its feasibility 58 | & 2 \\ 59 | \hline 60 | 6 & 61 | Train the model and calculate the accuracy of the model 62 | & 6 \\ 63 | \hline 64 | 7 & 65 | Generate new data-set of images/scenes from the text 66 | & 4 \\ 67 | \hline 68 | 8 & 69 | Create a video/scene from the set of generated images 70 | & 6\\ 71 | \hline 72 | 9 & 73 | Develop proper tutorial and documentation (with video demo) on the implementation & 4\\ 74 | \hline 75 | 10 & 76 | Text to audio/music generation (optional) 77 | & 3 \\ 78 | \hline 79 | \end{tabular} 80 | 81 | %------------ 82 | } 83 | \end{frame} 84 | 85 | \section{Task Accomplished} 86 | \begin{frame}{Task Accomplished: Report} 87 | \begin{itemize} 88 | \item 89 | {\textcolor{blue}{Understanding the idea and create report on how it can be tackled using Machine Learning.}\linebreak \linebreak 90 | We went through several papers and researched about different GANs, based on that we decided to use DCGAN that will be conditioned on text embeddings generated by a Character RNN to Generate Images from text 91 | } 92 | 93 | \begin{figure} 94 | \includegraphics[width=\linewidth]{architecture.png} 95 | \caption{Architecture\textsuperscript{[1]}} 96 | \label{fig:architecture} 97 | \end{figure} 98 | % \item 99 | % Perform a basic experiment to understand GANs (MNIST) 100 | % \item 101 | % Gather the required data-set to train the model 102 | % \item Design the model 103 | {\scriptsize{[1] Generative Adversarial Text to Image Synthesis by Scott Reed, Zeynep Akata, Xinchen Yan, Lajanugen Logeswaran, Bernt Schiele, Honglak Lee}} 104 | \end{itemize} 105 | \end{frame} 106 | 107 | \begin{frame}{Task Accomplished: Software Installation} 108 | \begin{itemize} 109 | \item { 110 | \textcolor{blue}{Installing the required software}\linebreak 111 | \begin{enumerate} 112 | \item Python, PyTorch and Torchvision were successfully installed 113 | \item A Nvidia GTX 1080Ti was employed for training 114 | \item CudaNN and Nvidia drivers were installed to allow training models on the GPU 115 | \end{enumerate} 116 | 117 | } 118 | 119 | 120 | % \item 121 | % Installing the required software. 
122 | % \item 123 | % Perform a basic experiment to understand GANs (MNIST) 124 | % \item 125 | % Gather the required data-set to train the model 126 | % \item Design the model 127 | 128 | \end{itemize} 129 | \end{frame} 130 | 131 | 132 | \begin{frame}{Task Accomplished: DCGAN} 133 | \begin{figure} 134 | \includegraphics[width=\linewidth]{GAN.png} 135 | \caption{DCGAN block diagram\textsuperscript{[2]}} 136 | % \caption{DCGAN architecture} 137 | \label{fig:DCGAN} 138 | \end{figure} 139 | {\scriptsize{[2] GANs from Scratch - Part 1 on Medium}} 140 | 141 | \end{frame} 142 | 143 | \begin{frame}{Task Accomplished: MNIST example} 144 | \begin{itemize} 145 | \item 146 | {\textcolor{blue}{Perform a basic experiment to understand GANs (MNIST)} 147 | \linebreak \linebreak 148 | We used DCGAN to implement the task of generating new images from original MNIST dataset} 149 | 150 | \begin{figure}[h!] 151 | \centering 152 | \begin{subfigure}[b]{0.4\linewidth} 153 | \includegraphics[width=\linewidth]{mnist_ideal.png} 154 | \caption{original MNIST} 155 | \end{subfigure} 156 | \begin{subfigure}[b]{0.407\linewidth} 157 | \includegraphics[width=\linewidth]{mnist.png} 158 | \caption{generated MNIST} 159 | \end{subfigure} 160 | \caption{Comparison of generated MNIST images} 161 | \label{fig:MNIST} 162 | \end{figure} 163 | 164 | % \item 165 | % Perform a basic experiment to understand GANs (MNIST) 166 | % \item 167 | % Gather the required data-set to train the model 168 | % \item Design the model 169 | 170 | \end{itemize} 171 | \end{frame} 172 | 173 | 174 | 175 | 176 | 177 | \begin{frame}{Task Accomplished: Gather Dataset} 178 | \begin{itemize} 179 | \item 180 | {\textcolor{blue}{Gather the required data-set to train the model for the final solution} 181 | \linebreak \linebreak 182 | We are planning to use COCO image dataset to train out text to image model. COCO is a large-scale object detection, segmentation, and captioning dataset. The dataset has been downloaded and prepared successfully.} 183 | 184 | \begin{figure} 185 | \includegraphics[width=\linewidth]{coco.jpg} 186 | \caption{COCO Examples.\textsuperscript{[3]}} 187 | \label{fig:architecture} 188 | \end{figure} 189 | % \item 190 | % Perform a basic experiment to understand GANs (MNIST) 191 | % \item 192 | % Gather the required data-set to train the model 193 | % \item Design the model 194 | \linebreak 195 | {\scriptsize{[3] mscoco.org/dataset}} 196 | \end{itemize} 197 | \end{frame} 198 | 199 | 200 | \begin{frame}{Task Accomplished: Model Design} 201 | \begin{itemize} 202 | \item 203 | {\textcolor{blue}{Design the model and test its feasibilty}} 204 | \linebreak\linebreak 205 | Designing of model for generating images from text has has been completed but we are yet to test its feasibility and effectiveness. This current model incorporates the standard DCGAN architecture with conditioning on character data. 
206 | \end{itemize} 207 | \end{frame} 208 | 209 | 210 | 211 | 212 | \section{Challenges Faced} 213 | \begin{frame}{Challenges Faced} 214 | \begin{itemize} 215 | 216 | \item Lack of knowledge about PyTorch before starting the internship 217 | \item Choosing hyper-parameters for GAN leading to efficient convergence 218 | \item Since GAN is fairly new and is still being actively researched, it took time to find the right algorithm for the task 219 | \item Finding a right Dataset for the task with captions and creating a dataloader for the same 220 | \end{itemize} 221 | \end{frame} 222 | 223 | \section{Future Plans} 224 | \begin{frame}{Future Plans} 225 | \begin{itemize} 226 | \item Test feasibility and effectiveness of our current model 227 | \item Train the model and calculate the accuracy of the model 228 | \item Generate new data-set of images/scenes from the text 229 | \item Next we plan to use these recent images to generate videos by either using a LSTM network to predict the next frame or a Temporal GAN 230 | 231 | 232 | \end{itemize} 233 | \end{frame} 234 | 235 | 236 | \section{Thank You} 237 | \begin{frame}{Thank You} 238 | \centering THANK YOU !!! 239 | \end{frame} 240 | \end{document} 241 | -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/architecture.png -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/coco.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/coco.jpg -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/gan.png -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/mnist.png -------------------------------------------------------------------------------- /documentation/Progress-presentations/progress-presentation-1/mnist_ideal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/mnist_ideal.png -------------------------------------------------------------------------------- 
/documentation/Progress-presentations/progress-presentation-1/progress_presentation_i.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/Progress-presentations/progress-presentation-1/progress_presentation_i.pdf -------------------------------------------------------------------------------- /documentation/REPORT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/REPORT.pdf -------------------------------------------------------------------------------- /documentation/text_to_image_synthesis_using_gan.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/text_to_image_synthesis_using_gan.pdf -------------------------------------------------------------------------------- /documentation/text_to_image_synthesis_using_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/documentation/text_to_image_synthesis_using_gan.png -------------------------------------------------------------------------------- /texttoimage/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/texttoimage/.DS_Store -------------------------------------------------------------------------------- /texttoimage/code/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/texttoimage/code/.DS_Store -------------------------------------------------------------------------------- /texttoimage/code/cfg/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/texttoimage/code/cfg/.DS_Store -------------------------------------------------------------------------------- /texttoimage/code/cfg/birds_s1.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageI' 2 | 3 | DATASET_NAME: 'birds' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | DATA_DIR: '../data/birds' 8 | IMSIZE: 64 9 | 10 | WORKERS: 4 11 | STAGE: 1 12 | TRAIN: 13 | FLAG: True 14 | BATCH_SIZE: 128 15 | MAX_EPOCH: 500 16 | LR_DECAY_EPOCH: 20 17 | SNAPSHOT_INTERVAL: 10 18 | DISCRIMINATOR_LR: 0.0002 19 | GENERATOR_LR: 0.0002 20 | COEFF: 21 | KL: 2.0 22 | 23 | GAN: 24 | CONDITION_DIM: 128 25 | DF_DIM: 96 26 | GF_DIM: 192 27 | 28 | TEXT: 29 | DIMENSION: 1024 30 | -------------------------------------------------------------------------------- /texttoimage/code/cfg/birds_s2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageII' 2 | 3 | DATASET_NAME: 'birds' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | 
GPU_ID: '0,1' 6 | Z_DIM: 100 7 | STAGE1_G: '../output/birds_stageI_2018_06_19_12_41_44/Model/netG_epoch_500.pth' 8 | DATA_DIR: '../data/birds' 9 | WORKERS: 4 10 | IMSIZE: 256 11 | STAGE: 2 12 | TRAIN: 13 | FLAG: True 14 | BATCH_SIZE: 18 15 | MAX_EPOCH: 600 16 | LR_DECAY_EPOCH: 20 17 | SNAPSHOT_INTERVAL: 5 18 | DISCRIMINATOR_LR: 0.0002 19 | GENERATOR_LR: 0.0002 20 | COEFF: 21 | KL: 2.0 22 | 23 | GAN: 24 | CONDITION_DIM: 128 25 | DF_DIM: 96 26 | GF_DIM: 192 27 | R_NUM: 2 28 | 29 | TEXT: 30 | DIMENSION: 1024 31 | -------------------------------------------------------------------------------- /texttoimage/code/cfg/coco_eval.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageII' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | NET_G: '../models/coco/netG_epoch_120.pth' 8 | DATA_DIR: '../data/coco' 9 | WORKERS: 4 10 | IMSIZE: 256 11 | STAGE: 2 12 | TRAIN: 13 | FLAG: False 14 | BATCH_SIZE: 7 15 | 16 | GAN: 17 | CONDITION_DIM: 128 18 | DF_DIM: 96 19 | GF_DIM: 192 20 | R_NUM: 2 21 | 22 | TEXT: 23 | DIMENSION: 1024 24 | -------------------------------------------------------------------------------- /texttoimage/code/cfg/coco_s1.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageI' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | DATA_DIR: '../data/coco' 8 | IMSIZE: 64 9 | 10 | WORKERS: 4 11 | STAGE: 1 12 | TRAIN: 13 | FLAG: True 14 | BATCH_SIZE: 128 15 | MAX_EPOCH: 500 16 | LR_DECAY_EPOCH: 20 17 | SNAPSHOT_INTERVAL: 10 18 | DISCRIMINATOR_LR: 0.0002 19 | GENERATOR_LR: 0.0002 20 | COEFF: 21 | KL: 2.0 22 | 23 | GAN: 24 | CONDITION_DIM: 128 25 | DF_DIM: 96 26 | GF_DIM: 192 27 | 28 | TEXT: 29 | DIMENSION: 1024 30 | -------------------------------------------------------------------------------- /texttoimage/code/cfg/coco_s2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageII' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | STAGE1_G: '../output/coco_stageI_2018_06_19_12_41_44/Model/netG_epoch_500.pth' 8 | DATA_DIR: '../data/coco' 9 | WORKERS: 4 10 | IMSIZE: 256 11 | STAGE: 2 12 | TRAIN: 13 | FLAG: True 14 | BATCH_SIZE: 18 15 | MAX_EPOCH: 600 16 | LR_DECAY_EPOCH: 20 17 | SNAPSHOT_INTERVAL: 5 18 | DISCRIMINATOR_LR: 0.0002 19 | GENERATOR_LR: 0.0002 20 | COEFF: 21 | KL: 2.0 22 | 23 | GAN: 24 | CONDITION_DIM: 128 25 | DF_DIM: 96 26 | GF_DIM: 192 27 | R_NUM: 2 28 | 29 | TEXT: 30 | DIMENSION: 1024 31 | -------------------------------------------------------------------------------- /texttoimage/code/cfg/flowers_s1.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageI' 2 | 3 | DATASET_NAME: 'flowers' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | DATA_DIR: '../data/flowers' 8 | IMSIZE: 64 9 | 10 | WORKERS: 4 11 | STAGE: 1 12 | TRAIN: 13 | FLAG: True 14 | BATCH_SIZE: 128 15 | MAX_EPOCH: 500 16 | LR_DECAY_EPOCH: 20 17 | SNAPSHOT_INTERVAL: 10 18 | DISCRIMINATOR_LR: 0.0002 19 | GENERATOR_LR: 0.0002 20 | COEFF: 21 | KL: 2.0 22 | 23 | GAN: 24 | CONDITION_DIM: 128 25 | DF_DIM: 96 26 | GF_DIM: 192 27 | 28 | TEXT: 29 | DIMENSION: 1024 30 | -------------------------------------------------------------------------------- /texttoimage/code/cfg/flowers_s2.yml: 
-------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageII' 2 | 3 | DATASET_NAME: 'flowers' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | STAGE1_G: '../output/flowers_stageI_2018_06_19_12_41_44/Model/netG_epoch_600.pth' 8 | DATA_DIR: '../data/flowers' 9 | WORKERS: 4 10 | IMSIZE: 256 11 | STAGE: 2 12 | TRAIN: 13 | FLAG: True 14 | BATCH_SIZE: 18 15 | MAX_EPOCH: 600 16 | LR_DECAY_EPOCH: 20 17 | SNAPSHOT_INTERVAL: 5 18 | DISCRIMINATOR_LR: 0.0002 19 | GENERATOR_LR: 0.0002 20 | COEFF: 21 | KL: 2.0 22 | 23 | GAN: 24 | CONDITION_DIM: 128 25 | DF_DIM: 96 26 | GF_DIM: 192 27 | R_NUM: 2 28 | 29 | TEXT: 30 | DIMENSION: 1024 31 | -------------------------------------------------------------------------------- /texttoimage/code/demo_birds.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "{'CONFIG_NAME': 'stageII',\n", 13 | " 'CUDA': True,\n", 14 | " 'DATASET_NAME': 'birds',\n", 15 | " 'DATA_DIR': '../data/birds',\n", 16 | " 'EMBEDDING_TYPE': 'cnn-rnn',\n", 17 | " 'GAN': {'CONDITION_DIM': 128, 'DF_DIM': 96, 'GF_DIM': 192, 'R_NUM': 2},\n", 18 | " 'GPU_ID': '0',\n", 19 | " 'IMSIZE': 256,\n", 20 | " 'NET_D': '',\n", 21 | " 'NET_G': '../models/birds/netG_epoch_350.pth',\n", 22 | " 'STAGE': 2,\n", 23 | " 'STAGE1_G': '',\n", 24 | " 'TEXT': {'DIMENSION': 1024},\n", 25 | " 'TRAIN': {'BATCH_SIZE': 16,\n", 26 | " 'COEFF': {'KL': 2.0},\n", 27 | " 'DISCRIMINATOR_LR': 0.0002,\n", 28 | " 'FLAG': False,\n", 29 | " 'GENERATOR_LR': 0.0002,\n", 30 | " 'LR_DECAY_EPOCH': 600,\n", 31 | " 'MAX_EPOCH': 600,\n", 32 | " 'PRETRAINED_EPOCH': 600,\n", 33 | " 'PRETRAINED_MODEL': '',\n", 34 | " 'SNAPSHOT_INTERVAL': 50},\n", 35 | " 'VIS_COUNT': 64,\n", 36 | " 'WORKERS': 4,\n", 37 | " 'Z_DIM': 100}\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "import torch \n", 43 | "import os \n", 44 | "from subprocess import call\n", 45 | "from matplotlib.pyplot import imshow\n", 46 | "from miscc.config import cfg,cfg_from_file\n", 47 | "import pprint\n", 48 | "from miscc.utils import weights_init,mkdir_p\n", 49 | "import torchfile\n", 50 | "import numpy as np \n", 51 | "from torch.autograd import Variable\n", 52 | "import torch.nn as nn\n", 53 | "from PIL import Image\n", 54 | "import matplotlib.pyplot as plt \n", 55 | "\n", 56 | "cfg_from_file(\"cfg/demo.yml\")\n", 57 | "cfg.GPU_ID = '0'\n", 58 | "cfg.NET_G = '../models/birds/netG_epoch_350.pth'\n", 59 | "model=\"birds\"\n", 60 | "s_gpus = cfg.GPU_ID.split(',')\n", 61 | "gpus = [int(ix) for ix in s_gpus]\n", 62 | "torch.cuda.set_device(gpus[0])\n", 63 | "pprint.pprint(cfg)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 2, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "def load_network_stageII():\n", 73 | " from model import STAGE1_G, STAGE2_G, STAGE2_D\n", 74 | "\n", 75 | " Stage1_G = STAGE1_G()\n", 76 | " netG = STAGE2_G(Stage1_G)\n", 77 | " netG.apply(weights_init)\n", 78 | " \n", 79 | " if cfg.NET_G != '':\n", 80 | " state_dict = \\\n", 81 | " torch.load(cfg.NET_G,\n", 82 | " map_location=lambda storage, loc: storage)\n", 83 | " netG.load_state_dict(state_dict)\n", 84 | " print('Load from: ', cfg.NET_G)\n", 85 | " elif cfg.STAGE1_G != '':\n", 86 | " state_dict = \\\n", 87 | " torch.load(cfg.STAGE1_G,\n", 88 | " map_location=lambda storage, loc: 
storage)\n", 89 | " netG.STAGE1_G.load_state_dict(state_dict)\n", 90 | " print('Load from: ', cfg.STAGE1_G)\n", 91 | " else:\n", 92 | " print(\"Please give the Stage1_G path\")\n", 93 | " return\n", 94 | "\n", 95 | " netD = STAGE2_D()\n", 96 | " netD.apply(weights_init)\n", 97 | " if cfg.NET_D != '':\n", 98 | " state_dict = \\\n", 99 | " torch.load(cfg.NET_D,\n", 100 | " map_location=lambda storage, loc: storage)\n", 101 | " netD.load_state_dict(state_dict)\n", 102 | " print('Load from: ', cfg.NET_D)\n", 103 | " \n", 104 | "\n", 105 | " if cfg.CUDA:\n", 106 | " netG.cuda()\n", 107 | " netD.cuda()\n", 108 | " return netG, netD\n", 109 | " " 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 3, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def sample(datapath):\n", 119 | "\n", 120 | " netG, _ = load_network_stageII()\n", 121 | " netG.eval()\n", 122 | "\n", 123 | " # Load text embeddings generated from the encoder\n", 124 | " t_file = torchfile.load(datapath)\n", 125 | " captions_list = t_file.raw_txt\n", 126 | " embeddings = np.concatenate(t_file.fea_txt, axis=0)\n", 127 | " num_embeddings = len(captions_list)\n", 128 | " print('Successfully load sentences from: ', datapath)\n", 129 | " print('Total number of sentences:', num_embeddings)\n", 130 | " print('num_embeddings:', num_embeddings, embeddings.shape)\n", 131 | " # path to save generated samples\n", 132 | " save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]\n", 133 | " mkdir_p(save_dir)\n", 134 | "\n", 135 | " batch_size = cfg.TRAIN.BATCH_SIZE\n", 136 | " nz = cfg.Z_DIM\n", 137 | " noise = Variable(torch.FloatTensor(batch_size, nz))\n", 138 | " if cfg.CUDA:\n", 139 | " noise = noise.cuda()\n", 140 | " count = 0\n", 141 | " while count < num_embeddings:\n", 142 | " if count > 3000:\n", 143 | " break\n", 144 | " iend = count + batch_size\n", 145 | " if iend > num_embeddings:\n", 146 | " iend = num_embeddings\n", 147 | " count = num_embeddings - batch_size\n", 148 | " embeddings_batch = embeddings[count:iend]\n", 149 | " # captions_batch = captions_list[count:iend]\n", 150 | " txt_embedding = Variable(torch.FloatTensor(embeddings_batch))\n", 151 | " if cfg.CUDA:\n", 152 | " txt_embedding = txt_embedding.cuda()\n", 153 | "\n", 154 | " noise.data.normal_(0, 1)\n", 155 | " inputs = (txt_embedding, noise)\n", 156 | " _, fake_imgs, mu, logvar = \\\n", 157 | " nn.parallel.data_parallel(netG, inputs,gpus)\n", 158 | " file = open(\"captions_birds.txt\",\"r\")\n", 159 | " lines = file.readlines()\n", 160 | " for i in range(batch_size): \n", 161 | " save_name = '%s/%d.png' % (save_dir, count + i)\n", 162 | " im = fake_imgs[i].data.cpu().numpy()\n", 163 | " im = (im + 1.0) * 127.5\n", 164 | " im = im.astype(np.uint8)\n", 165 | " # print('im', im.shape)\n", 166 | " im = np.transpose(im, (1, 2, 0))\n", 167 | " # print('im', im.shape)\n", 168 | " im = Image.fromarray(im)\n", 169 | " plt.figure()\n", 170 | " plt.title(lines[i])\n", 171 | " plt.imshow(im)\n", 172 | " count += batch_size" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 6, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "this bird has blue body and yellow beak.\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "if \"captions_birds_.txt\" in os.listdir(\".\"):\n", 190 | " os.remove(\"captions_birds_.txt\")\n", 191 | "line = input(\"\")\n", 192 | "with open('captions_birds_.txt', 'a') as the_file:\n", 193 | " the_file.write(line)" 194 | ] 195 
| }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": { 200 | "scrolled": false 201 | }, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "Load from: ../models/birds/netG_epoch_350.pth\n", 208 | "Successfully load sentences from: results.t7\n", 209 | "Total number of sentences: 1\n", 210 | "num_embeddings: 1 (1, 1024)\n" 211 | ] 212 | }, 213 | { 214 | "ename": "RuntimeError", 215 | "evalue": "invalid argument 0: Sizes of tensors must match except in dimension 1. Got 1 and 16 in dimension 0 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:111", 216 | "output_type": "error", 217 | "traceback": [ 218 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 219 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 220 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlines\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreadlines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"th\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"get_embedding.lua\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"results.t7\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 221 | "\u001b[0;32m\u001b[0m in \u001b[0;36msample\u001b[0;34m(datapath)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mnoise\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtxt_embedding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfake_imgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogvar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_parallel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mgpus\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 40\u001b[0m \u001b[0mfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"captions_birds.txt\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"r\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0mlines\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreadlines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 222 | "\u001b[0;32m~/.local/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mdata_parallel\u001b[0;34m(module, inputs, device_ids, output_device, dim, module_kwargs)\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule_kwargs\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mscatter_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule_kwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 156\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmodule_kwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0mused_device_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0mreplicas\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreplicate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mused_device_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 223 | "\u001b[0;32m~/.local/lib/python3.5/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 492\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 224 | "\u001b[0;32m~/texttoimage/code/model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, text_embedding, noise)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_embedding\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mnoise\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstage1_img\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSTAGE1_G\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext_embedding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0mstage1_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstage1_img\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0mencoded_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstage1_img\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 225 | "\u001b[0;32m~/.local/lib/python3.5/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 492\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 226 | "\u001b[0;32m~/texttoimage/code/model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, text_embedding, noise)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_embedding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mc_code\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogvar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mca_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext_embedding\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mz_c_code\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnoise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_code\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0mh_code\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_c_code\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 227 | "\u001b[0;31mRuntimeError\u001b[0m: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 1 and 16 in dimension 0 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:111" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "if \"results.t7\" in os.listdir(\".\"):\n", 233 | " os.remove(\"results.t7\")\n", 234 | "file = open(\"captions_birds_.txt\",\"r\")\n", 235 | "lines = file.readlines()\n", 236 | "call([\"th\",\"get_embedding.lua\"])\n", 237 | "sample(\"results.t7\")" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [] 253 | } 254 | ], 255 | "metadata": { 256 | "kernelspec": { 257 | "display_name": "Python 3", 258 | "language": "python", 259 | "name": "python3" 260 | }, 261 | "language_info": { 262 | "codemirror_mode": { 263 | "name": "ipython", 264 | "version": 3 265 | }, 266 | "file_extension": ".py", 267 | "mimetype": "text/x-python", 268 | "name": "python", 269 | "nbconvert_exporter": "python", 270 | "pygments_lexer": "ipython3", 271 | "version": "3.5.2" 272 | } 273 | }, 274 | "nbformat": 4, 275 | "nbformat_minor": 2 276 | } 277 | -------------------------------------------------------------------------------- /texttoimage/code/get_embedding.lua: -------------------------------------------------------------------------------- 1 | -- Modification from the codebase of scott's icml16 2 | -- please check https://github.com/reedscot/icml2016 for details 3 | 4 | require 'torch' 5 | require 'image' 6 | require 'nn' 7 | require 'nngraph' 8 | require 'cunn' 9 | require 'cutorch' 10 | require 'cudnn' 11 | require 'lfs' 12 | 13 | 14 | 15 | 16 | torch.setdefaulttensortype('torch.FloatTensor') 17 | 18 | local alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} " 19 | local dict = {} 20 | for i = 1,#alphabet do 21 | dict[alphabet:sub(i,i)] = i 22 | end 23 | ivocab = {} 24 | for k,v in pairs(dict) do 25 | ivocab[v] = k 26 | end 27 | 28 | opt = { 29 | filenames = 'results.t7', 30 | doc_length = 201, 31 | queries = 'captions_birds_.txt', 32 | net_txt = 'lm_sje_cub_c10_hybrid_0.00070_1_10_trainvalids.txt.t7', 33 | } 34 | 35 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end 36 | print(opt) 37 | 38 | net_txt = torch.load(opt.net_txt) 39 | 40 | if net_txt.protos ~=nil then net_txt = net_txt.protos.enc_doc end 41 | 42 | 43 | net_txt:evaluate() 44 | 45 | -- Extract all text features. 46 | local fea_txt = {} 47 | -- Decode text for sanity check. 
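-- Each caption read from opt.queries is one-hot encoded into a
-- 1 x doc_length x #alphabet tensor, passed through the pretrained
-- char-CNN-RNN encoder (net_txt), and the resulting 1024-d embeddings
-- are saved to opt.filenames (results.t7) so the Python side can load
-- them back with torchfile.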
48 | local raw_txt = {} 49 | local raw_img = {} 50 | for query_str in io.lines(opt.queries) do 51 | local txt = torch.zeros(1,opt.doc_length,#alphabet) 52 | for t = 1,opt.doc_length do 53 | local ch = query_str:sub(t,t) 54 | local ix = dict[ch] 55 | if ix ~= 0 and ix ~= nil then 56 | txt[{1,t,ix}] = 1 57 | end 58 | end 59 | raw_txt[#raw_txt+1] = query_str 60 | txt = txt:cuda() 61 | 62 | fea_txt[#fea_txt+1] = net_txt:forward(txt):clone() 63 | end 64 | 65 | torch.save(opt.filenames, {raw_txt=raw_txt, fea_txt=fea_txt}) 66 | -------------------------------------------------------------------------------- /texttoimage/code/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.backends.cudnn as cudnn 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | import argparse 7 | import os 8 | import random 9 | import sys 10 | import pprint 11 | import datetime 12 | import dateutil 13 | import dateutil.tz 14 | from PIL import ImageFile 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 19 | sys.path.append(dir_path) 20 | 21 | from miscc.datasets import TextDataset 22 | from miscc.config import cfg, cfg_from_file 23 | from miscc.utils import mkdir_p 24 | from trainer import GANTrainer 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Train a GAN network') 29 | parser.add_argument('--cfg', dest='cfg_file', 30 | help='optional config file', 31 | default='birds_stage1.yml', type=str) 32 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='0') 33 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 34 | parser.add_argument('--manualSeed', type=int, help='manual seed') 35 | args = parser.parse_args() 36 | return args 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | if args.cfg_file is not None: 41 | cfg_from_file(args.cfg_file) 42 | if args.gpu_id != -1: 43 | cfg.GPU_ID = args.gpu_id 44 | if args.data_dir != '': 45 | cfg.DATA_DIR = args.data_dir 46 | print('Using config:') 47 | pprint.pprint(cfg) 48 | if args.manualSeed is None: 49 | args.manualSeed = random.randint(1, 10000) 50 | random.seed(args.manualSeed) 51 | torch.manual_seed(args.manualSeed) 52 | if cfg.CUDA: 53 | torch.cuda.manual_seed_all(args.manualSeed) 54 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 55 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 56 | output_dir = '../output/%s_%s_%s' % \ 57 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 58 | 59 | num_gpu = len(cfg.GPU_ID.split(',')) 60 | if cfg.TRAIN.FLAG: 61 | image_transform = transforms.Compose([ 62 | transforms.RandomCrop(cfg.IMSIZE), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 66 | dataset = TextDataset(cfg.DATA_DIR, 'train', 67 | imsize=cfg.IMSIZE, 68 | transform=image_transform) 69 | assert dataset 70 | dataloader = torch.utils.data.DataLoader( 71 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, 72 | drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) 73 | 74 | algo = GANTrainer(output_dir) 75 | algo.train(dataloader, cfg.STAGE) 76 | else: 77 | datapath= '%s/test/result.t7' % (cfg.DATA_DIR) 78 | algo = GANTrainer(output_dir) 79 | algo.sample(datapath, cfg.STAGE) 80 | -------------------------------------------------------------------------------- /texttoimage/code/miscc/config.py: 
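# How main.py and this module interact: the YAML file passed via --cfg is
# merged into the edict defaults defined below by cfg_from_file(); only keys
# that already exist here are accepted (_merge_a_into_b raises KeyError
# otherwise). cfg.TRAIN.FLAG then decides in main.py whether to train or to
# sample from <DATA_DIR>/test/result.t7.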
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'birds' 14 | __C.EMBEDDING_TYPE = 'cnn-rnn' 15 | __C.CONFIG_NAME = '' 16 | __C.GPU_ID = '0' 17 | __C.CUDA = True 18 | __C.WORKERS = 6 19 | 20 | __C.NET_G = '' 21 | __C.NET_D = '' 22 | __C.STAGE1_G = '' 23 | __C.DATA_DIR = '' 24 | __C.VIS_COUNT = 64 25 | 26 | __C.Z_DIM = 100 27 | __C.IMSIZE = 64 28 | __C.STAGE = 1 29 | 30 | 31 | # Training options 32 | __C.TRAIN = edict() 33 | __C.TRAIN.FLAG = True 34 | __C.TRAIN.BATCH_SIZE = 64 35 | __C.TRAIN.MAX_EPOCH = 600 36 | __C.TRAIN.SNAPSHOT_INTERVAL = 50 37 | __C.TRAIN.PRETRAINED_MODEL = '' 38 | __C.TRAIN.PRETRAINED_EPOCH = 600 39 | __C.TRAIN.LR_DECAY_EPOCH = 600 40 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 41 | __C.TRAIN.GENERATOR_LR = 2e-4 42 | 43 | __C.TRAIN.COEFF = edict() 44 | __C.TRAIN.COEFF.KL = 2.0 45 | 46 | # Modal options 47 | __C.GAN = edict() 48 | __C.GAN.CONDITION_DIM = 128 49 | __C.GAN.DF_DIM = 64 50 | __C.GAN.GF_DIM = 128 51 | __C.GAN.R_NUM = 4 52 | 53 | __C.TEXT = edict() 54 | __C.TEXT.DIMENSION = 1024 55 | 56 | 57 | def _merge_a_into_b(a, b): 58 | """Merge config dictionary a into config dictionary b, clobbering the 59 | options in b whenever they are also specified in a. 60 | """ 61 | if type(a) is not edict: 62 | return 63 | 64 | for k, v in a.items(): 65 | # a must specify keys that are in b 66 | if k not in b: 67 | raise KeyError('{} is not a valid config key'.format(k)) 68 | 69 | # the types must match, too 70 | old_type = type(b[k]) 71 | if old_type is not type(v): 72 | if isinstance(b[k], np.ndarray): 73 | v = np.array(v, dtype=b[k].dtype) 74 | else: 75 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 76 | 'for config key: {}').format(type(b[k]), 77 | type(v), k)) 78 | 79 | # recursively merge dicts 80 | if type(v) is edict: 81 | try: 82 | _merge_a_into_b(a[k], b[k]) 83 | except: 84 | print('Error under config key: {}'.format(k)) 85 | raise 86 | else: 87 | b[k] = v 88 | 89 | 90 | def cfg_from_file(filename): 91 | """Load a config file and merge it into the default options.""" 92 | import yaml 93 | with open(filename, 'r') as f: 94 | yaml_cfg = edict(yaml.load(f)) 95 | 96 | _merge_a_into_b(yaml_cfg, __C) 97 | -------------------------------------------------------------------------------- /texttoimage/code/miscc/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import PIL 10 | import os 11 | import os.path 12 | import pickle 13 | import random 14 | import numpy as np 15 | import pandas as pd 16 | 17 | from miscc.config import cfg 18 | 19 | 20 | class TextDataset(data.Dataset): 21 | def __init__(self, data_dir, split='train', embedding_type='cnn-rnn', 22 | imsize=64, transform=None, target_transform=None): 23 | 24 | self.transform = transform 25 | self.target_transform = target_transform 26 | self.imsize = imsize 27 | self.data = [] 28 | self.data_dir = data_dir 29 | self.bbox = self.load_bbox() 30 | split_dir = os.path.join(data_dir, split) 31 | 32 | self.filenames = self.load_filenames(split_dir) 33 | self.embeddings = self.load_embedding(split_dir, embedding_type) 34 | self.class_id = self.load_class_id(split_dir, len(self.filenames)) 35 | # self.captions = self.load_all_captions() 36 | 37 | def get_img(self, img_path, bbox): 38 | img = Image.open(img_path).convert('RGB') 39 | width, height = img.size 40 | if bbox is not None: 41 | R = int(np.maximum(bbox[2], bbox[3]) * 0.75) 42 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 43 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 44 | y1 = np.maximum(0, center_y - R) 45 | y2 = np.minimum(height, center_y + R) 46 | x1 = np.maximum(0, center_x - R) 47 | x2 = np.minimum(width, center_x + R) 48 | img = img.crop([x1, y1, x2, y2]) 49 | load_size = int(self.imsize * 76 / 64) 50 | img = img.resize((load_size, load_size), PIL.Image.BILINEAR) 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | return img 54 | 55 | def load_bbox(self): 56 | data_dir = self.data_dir 57 | bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') 58 | df_bounding_boxes = pd.read_csv(bbox_path, 59 | delim_whitespace=True, 60 | header=None).astype(int) 61 | # 62 | filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') 63 | df_filenames = \ 64 | pd.read_csv(filepath, delim_whitespace=True, header=None) 65 | filenames = df_filenames[1].tolist() 66 | print('Total filenames: ', len(filenames), filenames[0]) 67 | # 68 | filename_bbox = {img_file[:-4]: [] for img_file in filenames} 69 | numImgs = len(filenames) 70 | for i in range(0, numImgs): 71 | # bbox = [x-left, y-top, width, height] 72 | bbox = df_bounding_boxes.iloc[i][1:].tolist() 73 | 74 | key = filenames[i][:-4] 75 | filename_bbox[key] = bbox 76 | # 77 | return filename_bbox 78 | 79 | def load_all_captions(self): 80 | caption_dict = {} 81 | for key in self.filenames: 82 | caption_name = '%s/text/%s.txt' % (self.data_dir, key) 83 | captions = self.load_captions(caption_name) 84 | caption_dict[key] = 
captions 85 | return caption_dict 86 | 87 | def load_captions(self, caption_name): 88 | cap_path = caption_name 89 | with open(cap_path, "r") as f: 90 | captions = f.read().decode('utf8').split('\n') 91 | captions = [cap.replace("\ufffd\ufffd", " ") 92 | for cap in captions if len(cap) > 0] 93 | return captions 94 | 95 | def load_embedding(self, data_dir, embedding_type): 96 | if embedding_type == 'cnn-rnn': 97 | embedding_filename = '/char-CNN-RNN-embeddings.pickle' 98 | elif embedding_type == 'cnn-gru': 99 | embedding_filename = '/char-CNN-GRU-embeddings.pickle' 100 | elif embedding_type == 'skip-thought': 101 | embedding_filename = '/skip-thought-embeddings.pickle' 102 | 103 | with open(data_dir + embedding_filename, 'rb') as f: 104 | embeddings = pickle.load(f,encoding='latin1') 105 | embeddings = np.array(embeddings) 106 | # embedding_shape = [embeddings.shape[-1]] 107 | print('embeddings: ', embeddings.shape) 108 | return embeddings 109 | 110 | def load_class_id(self, data_dir, total_num): 111 | if os.path.isfile(data_dir + '/class_info.pickle'): 112 | with open(data_dir + '/class_info.pickle', 'rb') as f: 113 | class_id = pickle.load(f,encoding='latin1') 114 | else: 115 | class_id = np.arange(total_num) 116 | return class_id 117 | def load_filenames(self, data_dir): 118 | filepath = os.path.join(data_dir, 'filenames.pickle') 119 | with open(filepath, 'rb') as f: 120 | filenames = pickle.load(f) 121 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 122 | return filenames 123 | 124 | def __getitem__(self, index): 125 | key = self.filenames[index] 126 | # cls_id = self.class_id[index] 127 | # 128 | if self.bbox is not None: 129 | bbox = self.bbox[key] 130 | data_dir = '%s/CUB_200_2011' % self.data_dir 131 | else: 132 | bbox = None 133 | data_dir = self.data_dir 134 | 135 | # captions = self.captions[key] 136 | embeddings = self.embeddings[index, :, :] 137 | img_name = '%s/images/%s.jpg' % (data_dir, key) 138 | img = self.get_img(img_name, bbox) 139 | 140 | embedding_ix = random.randint(0, embeddings.shape[0]-1) 141 | embedding = embeddings[embedding_ix, :] 142 | if self.target_transform is not None: 143 | embedding = self.target_transform(embedding) 144 | return img, embedding 145 | 146 | def __len__(self): 147 | return len(self.filenames) 148 | -------------------------------------------------------------------------------- /texttoimage/code/miscc/utils .py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | 5 | from copy import deepcopy 6 | from miscc.config import cfg 7 | 8 | from torch.nn import init 9 | import torch 10 | import torch.nn as nn 11 | import torchvision.utils as vutils 12 | 13 | 14 | ############################# 15 | def KL_loss(mu, logvar): 16 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 17 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 18 | KLD = torch.mean(KLD_element).mul_(-0.5) 19 | return KLD 20 | 21 | 22 | def compute_discriminator_loss(netD, real_imgs, fake_imgs, 23 | real_labels, fake_labels, 24 | conditions, gpus): 25 | criterion = nn.BCELoss() 26 | batch_size = real_imgs.size(0) 27 | cond = conditions.detach() 28 | fake = fake_imgs.detach() 29 | real_features = nn.parallel.data_parallel(netD, (real_imgs), gpus) 30 | fake_features = nn.parallel.data_parallel(netD, (fake), gpus) 31 | # real pairs 32 | inputs = (real_features, cond) 33 | real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 34 | 
errD_real = criterion(real_logits, real_labels) 35 | # wrong pairs 36 | inputs = (real_features[:(batch_size-1)], cond[1:]) 37 | wrong_logits = \ 38 | nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 39 | errD_wrong = criterion(wrong_logits, fake_labels[1:]) 40 | # fake pairs 41 | inputs = (fake_features, cond) 42 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 43 | errD_fake = criterion(fake_logits, fake_labels) 44 | 45 | if netD.get_uncond_logits is not None: 46 | real_logits = \ 47 | nn.parallel.data_parallel(netD.get_uncond_logits, 48 | (real_features), gpus) 49 | fake_logits = \ 50 | nn.parallel.data_parallel(netD.get_uncond_logits, 51 | (fake_features), gpus) 52 | uncond_errD_real = criterion(real_logits, real_labels) 53 | uncond_errD_fake = criterion(fake_logits, fake_labels) 54 | # 55 | errD = ((errD_real + uncond_errD_real) / 2. + 56 | (errD_fake + errD_wrong + uncond_errD_fake) / 3.) 57 | errD_real = (errD_real + uncond_errD_real) / 2. 58 | errD_fake = (errD_fake + uncond_errD_fake) / 2. 59 | else: 60 | errD = errD_real + (errD_fake + errD_wrong) * 0.5 61 | return errD, errD_real.data[0], errD_wrong.data[0], errD_fake.data[0] 62 | 63 | 64 | def compute_generator_loss(netD, fake_imgs, real_labels, conditions, gpus): 65 | criterion = nn.BCELoss() 66 | cond = conditions.detach() 67 | fake_features = nn.parallel.data_parallel(netD, (fake_imgs), gpus) 68 | # fake pairs 69 | inputs = (fake_features, cond) 70 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 71 | errD_fake = criterion(fake_logits, real_labels) 72 | if netD.get_uncond_logits is not None: 73 | fake_logits = \ 74 | nn.parallel.data_parallel(netD.get_uncond_logits, 75 | (fake_features), gpus) 76 | uncond_errD_fake = criterion(fake_logits, real_labels) 77 | errD_fake += uncond_errD_fake 78 | return errD_fake 79 | 80 | 81 | ############################# 82 | def weights_init(m): 83 | classname = m.__class__.__name__ 84 | if classname.find('Conv') != -1: 85 | m.weight.data.normal_(0.0, 0.02) 86 | elif classname.find('BatchNorm') != -1: 87 | m.weight.data.normal_(1.0, 0.02) 88 | m.bias.data.fill_(0) 89 | elif classname.find('Linear') != -1: 90 | m.weight.data.normal_(0.0, 0.02) 91 | if m.bias is not None: 92 | m.bias.data.fill_(0.0) 93 | 94 | 95 | ############################# 96 | def save_img_results(data_img, fake, epoch, image_dir): 97 | num = cfg.VIS_COUNT 98 | fake = fake[0:num] 99 | # data_img is changed to [0,1] 100 | if data_img is not None: 101 | data_img = data_img[0:num] 102 | vutils.save_image( 103 | data_img, '%s/real_samples.png' % image_dir, 104 | normalize=True) 105 | # fake.data is still [-1, 1] 106 | vutils.save_image( 107 | fake.data, '%s/fake_samples_epoch_%03d.png' % 108 | (image_dir, epoch), normalize=True) 109 | else: 110 | vutils.save_image( 111 | fake.data, '%s/lr_fake_samples_epoch_%03d.png' % 112 | (image_dir, epoch), normalize=True) 113 | 114 | 115 | def save_model(netG, netD, epoch, model_dir): 116 | torch.save( 117 | netG.state_dict(), 118 | '%s/netG_epoch_%d.pth' % (model_dir, epoch)) 119 | torch.save( 120 | netD.state_dict(), 121 | '%s/netD_epoch_last.pth' % (model_dir)) 122 | print('Save G/D models') 123 | 124 | 125 | def mkdir_p(path): 126 | try: 127 | os.makedirs(path) 128 | except OSError as exc: # Python >2.5 129 | if exc.errno == errno.EEXIST and os.path.isdir(path): 130 | pass 131 | else: 132 | raise 133 | -------------------------------------------------------------------------------- /texttoimage/code/model.py: 
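# Note on the RuntimeError shown in the demo notebook above ("Sizes of tensors
# must match ... Got 1 and 16 in dimension 0"): it is raised by the
# torch.cat((noise, c_code), 1) call in STAGE1_G.forward below. results.t7
# held a single caption embedding (num_embeddings: 1) while the noise tensor
# was built with 16 rows, so the batch dimensions of noise and c_code do not
# match; repeating the embedding to the noise batch size, or sampling noise
# with a single row, is one way to avoid the error.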
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from miscc.config import cfg 5 | from torch.autograd import Variable 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | # Upsale the spatial size by a factor of 2 15 | def upBlock(in_planes, out_planes): 16 | block = nn.Sequential( 17 | nn.Upsample(scale_factor=2, mode='nearest'), 18 | conv3x3(in_planes, out_planes), 19 | nn.BatchNorm2d(out_planes), 20 | nn.ReLU(True)) 21 | return block 22 | 23 | 24 | class ResBlock(nn.Module): 25 | def __init__(self, channel_num): 26 | super(ResBlock, self).__init__() 27 | self.block = nn.Sequential( 28 | conv3x3(channel_num, channel_num), 29 | nn.BatchNorm2d(channel_num), 30 | nn.ReLU(True), 31 | conv3x3(channel_num, channel_num), 32 | nn.BatchNorm2d(channel_num)) 33 | self.relu = nn.ReLU(inplace=True) 34 | 35 | def forward(self, x): 36 | residual = x 37 | out = self.block(x) 38 | out += residual 39 | out = self.relu(out) 40 | return out 41 | 42 | 43 | class CA_NET(nn.Module): 44 | # some code is modified from vae examples 45 | # (https://github.com/pytorch/examples/blob/master/vae/main.py) 46 | def __init__(self): 47 | super(CA_NET, self).__init__() 48 | self.t_dim = cfg.TEXT.DIMENSION 49 | self.c_dim = cfg.GAN.CONDITION_DIM 50 | self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True) 51 | self.relu = nn.ReLU() 52 | 53 | def encode(self, text_embedding): 54 | x = self.relu(self.fc(text_embedding)) 55 | mu = x[:, :self.c_dim] 56 | logvar = x[:, self.c_dim:] 57 | return mu, logvar 58 | 59 | def reparametrize(self, mu, logvar): 60 | std = logvar.mul(0.5).exp_() 61 | if cfg.CUDA: 62 | eps = torch.cuda.FloatTensor(std.size()).normal_() 63 | else: 64 | eps = torch.FloatTensor(std.size()).normal_() 65 | eps = Variable(eps) 66 | return eps.mul(std).add_(mu) 67 | 68 | def forward(self, text_embedding): 69 | mu, logvar = self.encode(text_embedding) 70 | c_code = self.reparametrize(mu, logvar) 71 | return c_code, mu, logvar 72 | 73 | 74 | class D_GET_LOGITS(nn.Module): 75 | def __init__(self, ndf, nef, bcondition=True): 76 | super(D_GET_LOGITS, self).__init__() 77 | self.df_dim = ndf 78 | self.ef_dim = nef 79 | self.bcondition = bcondition 80 | if bcondition: 81 | self.outlogits = nn.Sequential( 82 | conv3x3(ndf * 8 + nef, ndf * 8), 83 | nn.BatchNorm2d(ndf * 8), 84 | nn.LeakyReLU(0.2, inplace=True), 85 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 86 | nn.Sigmoid()) 87 | else: 88 | self.outlogits = nn.Sequential( 89 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 90 | nn.Sigmoid()) 91 | 92 | def forward(self, h_code, c_code=None): 93 | # conditioning output 94 | if self.bcondition and c_code is not None: 95 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 96 | c_code = c_code.repeat(1, 1, 4, 4) 97 | # state size (ngf+egf) x 4 x 4 98 | h_c_code = torch.cat((h_code, c_code), 1) 99 | else: 100 | h_c_code = h_code 101 | 102 | output = self.outlogits(h_c_code) 103 | return output.view(-1) 104 | 105 | 106 | # ############# Networks for stageI GAN ############# 107 | class STAGE1_G(nn.Module): 108 | def __init__(self): 109 | super(STAGE1_G, self).__init__() 110 | self.gf_dim = cfg.GAN.GF_DIM * 8 111 | self.ef_dim = cfg.GAN.CONDITION_DIM 112 | self.z_dim = cfg.Z_DIM 113 | self.define_module() 114 | 115 | def define_module(self): 116 | ninput = self.z_dim 
+ self.ef_dim 117 | ngf = self.gf_dim 118 | # TEXT.DIMENSION -> GAN.CONDITION_DIM 119 | self.ca_net = CA_NET() 120 | 121 | # -> ngf x 4 x 4 122 | self.fc = nn.Sequential( 123 | nn.Linear(ninput, ngf * 4 * 4, bias=False), 124 | nn.BatchNorm1d(ngf * 4 * 4), 125 | nn.ReLU(True)) 126 | 127 | # ngf x 4 x 4 -> ngf/2 x 8 x 8 128 | self.upsample1 = upBlock(ngf, ngf // 2) 129 | # -> ngf/4 x 16 x 16 130 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 131 | # -> ngf/8 x 32 x 32 132 | self.upsample3 = upBlock(ngf // 4, ngf // 8) 133 | # -> ngf/16 x 64 x 64 134 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 135 | # -> 3 x 64 x 64 136 | self.img = nn.Sequential( 137 | conv3x3(ngf // 16, 3), 138 | nn.Tanh()) 139 | 140 | def forward(self, text_embedding, noise): 141 | c_code, mu, logvar = self.ca_net(text_embedding) 142 | z_c_code = torch.cat((noise, c_code), 1) 143 | h_code = self.fc(z_c_code) 144 | 145 | h_code = h_code.view(-1, self.gf_dim, 4, 4) 146 | h_code = self.upsample1(h_code) 147 | h_code = self.upsample2(h_code) 148 | h_code = self.upsample3(h_code) 149 | h_code = self.upsample4(h_code) 150 | # state size 3 x 64 x 64 151 | fake_img = self.img(h_code) 152 | return None, fake_img, mu, logvar 153 | 154 | 155 | class STAGE1_D(nn.Module): 156 | def __init__(self): 157 | super(STAGE1_D, self).__init__() 158 | self.df_dim = cfg.GAN.DF_DIM 159 | self.ef_dim = cfg.GAN.CONDITION_DIM 160 | self.define_module() 161 | 162 | def define_module(self): 163 | ndf, nef = self.df_dim, self.ef_dim 164 | self.encode_img = nn.Sequential( 165 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 166 | nn.LeakyReLU(0.2, inplace=True), 167 | # state size. (ndf) x 32 x 32 168 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 169 | nn.BatchNorm2d(ndf * 2), 170 | nn.LeakyReLU(0.2, inplace=True), 171 | # state size (ndf*2) x 16 x 16 172 | nn.Conv2d(ndf*2, ndf * 4, 4, 2, 1, bias=False), 173 | nn.BatchNorm2d(ndf * 4), 174 | nn.LeakyReLU(0.2, inplace=True), 175 | # state size (ndf*4) x 8 x 8 176 | nn.Conv2d(ndf*4, ndf * 8, 4, 2, 1, bias=False), 177 | nn.BatchNorm2d(ndf * 8), 178 | # state size (ndf * 8) x 4 x 4) 179 | nn.LeakyReLU(0.2, inplace=True) 180 | ) 181 | 182 | self.get_cond_logits = D_GET_LOGITS(ndf, nef) 183 | self.get_uncond_logits = None 184 | 185 | def forward(self, image): 186 | img_embedding = self.encode_img(image) 187 | 188 | return img_embedding 189 | 190 | 191 | # ############# Networks for stageII GAN ############# 192 | class STAGE2_G(nn.Module): 193 | def __init__(self, STAGE1_G): 194 | super(STAGE2_G, self).__init__() 195 | self.gf_dim = cfg.GAN.GF_DIM 196 | self.ef_dim = cfg.GAN.CONDITION_DIM 197 | self.z_dim = cfg.Z_DIM 198 | self.STAGE1_G = STAGE1_G 199 | # fix parameters of stageI GAN 200 | for param in self.STAGE1_G.parameters(): 201 | param.requires_grad = False 202 | self.define_module() 203 | 204 | def _make_layer(self, block, channel_num): 205 | layers = [] 206 | for i in range(cfg.GAN.R_NUM): 207 | layers.append(block(channel_num)) 208 | return nn.Sequential(*layers) 209 | 210 | def define_module(self): 211 | ngf = self.gf_dim 212 | # TEXT.DIMENSION -> GAN.CONDITION_DIM 213 | self.ca_net = CA_NET() 214 | # --> 4ngf x 16 x 16 215 | self.encoder = nn.Sequential( 216 | conv3x3(3, ngf), 217 | nn.ReLU(True), 218 | nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False), 219 | nn.BatchNorm2d(ngf * 2), 220 | nn.ReLU(True), 221 | nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False), 222 | nn.BatchNorm2d(ngf * 4), 223 | nn.ReLU(True)) 224 | self.hr_joint = nn.Sequential( 225 | conv3x3(self.ef_dim + ngf * 4, ngf * 4), 226 | 
nn.BatchNorm2d(ngf * 4), 227 | nn.ReLU(True)) 228 | self.residual = self._make_layer(ResBlock, ngf * 4) 229 | # --> 2ngf x 32 x 32 230 | self.upsample1 = upBlock(ngf * 4, ngf * 2) 231 | # --> ngf x 64 x 64 232 | self.upsample2 = upBlock(ngf * 2, ngf) 233 | # --> ngf // 2 x 128 x 128 234 | self.upsample3 = upBlock(ngf, ngf // 2) 235 | # --> ngf // 4 x 256 x 256 236 | self.upsample4 = upBlock(ngf // 2, ngf // 4) 237 | # --> 3 x 256 x 256 238 | self.img = nn.Sequential( 239 | conv3x3(ngf // 4, 3), 240 | nn.Tanh()) 241 | 242 | def forward(self, text_embedding, noise): 243 | _, stage1_img, _, _ = self.STAGE1_G(text_embedding, noise) 244 | stage1_img = stage1_img.detach() 245 | encoded_img = self.encoder(stage1_img) 246 | 247 | c_code, mu, logvar = self.ca_net(text_embedding) 248 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 249 | c_code = c_code.repeat(1, 1, 16, 16) 250 | i_c_code = torch.cat([encoded_img, c_code], 1) 251 | h_code = self.hr_joint(i_c_code) 252 | h_code = self.residual(h_code) 253 | 254 | h_code = self.upsample1(h_code) 255 | h_code = self.upsample2(h_code) 256 | h_code = self.upsample3(h_code) 257 | h_code = self.upsample4(h_code) 258 | 259 | fake_img = self.img(h_code) 260 | return stage1_img, fake_img, mu, logvar 261 | 262 | 263 | class STAGE2_D(nn.Module): 264 | def __init__(self): 265 | super(STAGE2_D, self).__init__() 266 | self.df_dim = cfg.GAN.DF_DIM 267 | self.ef_dim = cfg.GAN.CONDITION_DIM 268 | self.define_module() 269 | 270 | def define_module(self): 271 | ndf, nef = self.df_dim, self.ef_dim 272 | self.encode_img = nn.Sequential( 273 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), # 128 * 128 * ndf 274 | nn.LeakyReLU(0.2, inplace=True), 275 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 276 | nn.BatchNorm2d(ndf * 2), 277 | nn.LeakyReLU(0.2, inplace=True), # 64 * 64 * ndf * 2 278 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 279 | nn.BatchNorm2d(ndf * 4), 280 | nn.LeakyReLU(0.2, inplace=True), # 32 * 32 * ndf * 4 281 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 282 | nn.BatchNorm2d(ndf * 8), 283 | nn.LeakyReLU(0.2, inplace=True), # 16 * 16 * ndf * 8 284 | nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False), 285 | nn.BatchNorm2d(ndf * 16), 286 | nn.LeakyReLU(0.2, inplace=True), # 8 * 8 * ndf * 16 287 | nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False), 288 | nn.BatchNorm2d(ndf * 32), 289 | nn.LeakyReLU(0.2, inplace=True), # 4 * 4 * ndf * 32 290 | conv3x3(ndf * 32, ndf * 16), 291 | nn.BatchNorm2d(ndf * 16), 292 | nn.LeakyReLU(0.2, inplace=True), # 4 * 4 * ndf * 16 293 | conv3x3(ndf * 16, ndf * 8), 294 | nn.BatchNorm2d(ndf * 8), 295 | nn.LeakyReLU(0.2, inplace=True) # 4 * 4 * ndf * 8 296 | ) 297 | 298 | self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True) 299 | self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False) 300 | 301 | def forward(self, image): 302 | img_embedding = self.encode_img(image) 303 | 304 | return img_embedding 305 | -------------------------------------------------------------------------------- /texttoimage/code/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from six.moves import range 3 | from PIL import Image 4 | 5 | import torch.backends.cudnn as cudnn 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import torch.optim as optim 10 | import os 11 | import time 12 | 13 | import numpy as np 14 | import torchfile 15 | 16 | from miscc.config import cfg 17 | from miscc.utils import mkdir_p 18 | from 
miscc.utils import weights_init 19 | from miscc.utils import save_img_results, save_model 20 | from miscc.utils import KL_loss 21 | from miscc.utils import compute_discriminator_loss, compute_generator_loss 22 | 23 | from tensorboard import summary 24 | from tensorboard import FileWriter 25 | 26 | 27 | class GANTrainer(object): 28 | def __init__(self, output_dir): 29 | if cfg.TRAIN.FLAG: 30 | self.model_dir = os.path.join(output_dir, 'Model') 31 | self.image_dir = os.path.join(output_dir, 'Image') 32 | self.log_dir = os.path.join(output_dir, 'Log') 33 | mkdir_p(self.model_dir) 34 | mkdir_p(self.image_dir) 35 | mkdir_p(self.log_dir) 36 | self.summary_writer = FileWriter(self.log_dir) 37 | 38 | self.max_epoch = cfg.TRAIN.MAX_EPOCH 39 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL 40 | 41 | s_gpus = cfg.GPU_ID.split(',') 42 | self.gpus = [int(ix) for ix in s_gpus] 43 | self.num_gpus = len(self.gpus) 44 | self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus 45 | torch.cuda.set_device(self.gpus[0]) 46 | cudnn.benchmark = True 47 | 48 | # ############# For training stageI GAN ############# 49 | def load_network_stageI(self): 50 | from model import STAGE1_G, STAGE1_D 51 | netG = STAGE1_G() 52 | netG.apply(weights_init) 53 | print(netG) 54 | netD = STAGE1_D() 55 | netD.apply(weights_init) 56 | print(netD) 57 | 58 | if cfg.NET_G != '': 59 | state_dict = \ 60 | torch.load(cfg.NET_G, 61 | map_location=lambda storage, loc: storage) 62 | netG.load_state_dict(state_dict) 63 | print('Load from: ', cfg.NET_G) 64 | if cfg.NET_D != '': 65 | state_dict = \ 66 | torch.load(cfg.NET_D, 67 | map_location=lambda storage, loc: storage) 68 | netD.load_state_dict(state_dict) 69 | print('Load from: ', cfg.NET_D) 70 | if cfg.CUDA: 71 | netG.cuda() 72 | netD.cuda() 73 | return netG, netD 74 | 75 | # ############# For training stageII GAN ############# 76 | def load_network_stageII(self): 77 | from model import STAGE1_G, STAGE2_G, STAGE2_D 78 | 79 | Stage1_G = STAGE1_G() 80 | netG = STAGE2_G(Stage1_G) 81 | netG.apply(weights_init) 82 | print(netG) 83 | if cfg.NET_G != '': 84 | state_dict = \ 85 | torch.load(cfg.NET_G, 86 | map_location=lambda storage, loc: storage) 87 | netG.load_state_dict(state_dict) 88 | print('Load from: ', cfg.NET_G) 89 | elif cfg.STAGE1_G != '': 90 | state_dict = \ 91 | torch.load(cfg.STAGE1_G, 92 | map_location=lambda storage, loc: storage) 93 | netG.STAGE1_G.load_state_dict(state_dict) 94 | print('Load from: ', cfg.STAGE1_G) 95 | else: 96 | print("Please give the Stage1_G path") 97 | return 98 | 99 | netD = STAGE2_D() 100 | netD.apply(weights_init) 101 | if cfg.NET_D != '': 102 | state_dict = \ 103 | torch.load(cfg.NET_D, 104 | map_location=lambda storage, loc: storage) 105 | netD.load_state_dict(state_dict) 106 | print('Load from: ', cfg.NET_D) 107 | print(netD) 108 | 109 | if cfg.CUDA: 110 | netG.cuda() 111 | netD.cuda() 112 | return netG, netD 113 | 114 | def train(self, data_loader, stage=1): 115 | if stage == 1: 116 | netG, netD = self.load_network_stageI() 117 | else: 118 | netG, netD = self.load_network_stageII() 119 | 120 | nz = cfg.Z_DIM 121 | batch_size = self.batch_size 122 | noise = Variable(torch.FloatTensor(batch_size, nz)) 123 | fixed_noise = \ 124 | Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1), 125 | volatile=True) 126 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) 127 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) 128 | if cfg.CUDA: 129 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 130 | real_labels, fake_labels 
= real_labels.cuda(), fake_labels.cuda() 131 | 132 | generator_lr = cfg.TRAIN.GENERATOR_LR 133 | discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR 134 | lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH 135 | optimizerD = \ 136 | optim.Adam(netD.parameters(), 137 | lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999)) 138 | netG_para = [] 139 | for p in netG.parameters(): 140 | if p.requires_grad: 141 | netG_para.append(p) 142 | optimizerG = optim.Adam(netG_para, 143 | lr=cfg.TRAIN.GENERATOR_LR, 144 | betas=(0.5, 0.999)) 145 | count = 0 146 | for epoch in range(self.max_epoch): 147 | start_t = time.time() 148 | if epoch % lr_decay_step == 0 and epoch > 0: 149 | generator_lr *= 0.5 150 | for param_group in optimizerG.param_groups: 151 | param_group['lr'] = generator_lr 152 | discriminator_lr *= 0.5 153 | for param_group in optimizerD.param_groups: 154 | param_group['lr'] = discriminator_lr 155 | 156 | for i, data in enumerate(data_loader, 0): 157 | ###################################################### 158 | # (1) Prepare training data 159 | ###################################################### 160 | real_img_cpu, txt_embedding = data 161 | real_imgs = Variable(real_img_cpu) 162 | txt_embedding = Variable(txt_embedding) 163 | if cfg.CUDA: 164 | real_imgs = real_imgs.cuda() 165 | txt_embedding = txt_embedding.cuda() 166 | 167 | ####################################################### 168 | # (2) Generate fake images 169 | ###################################################### 170 | noise.data.normal_(0, 1) 171 | inputs = (txt_embedding, noise) 172 | _, fake_imgs, mu, logvar = \ 173 | nn.parallel.data_parallel(netG, inputs, self.gpus) 174 | 175 | ############################ 176 | # (3) Update D network 177 | ########################### 178 | netD.zero_grad() 179 | errD, errD_real, errD_wrong, errD_fake = \ 180 | compute_discriminator_loss(netD, real_imgs, fake_imgs, 181 | real_labels, fake_labels, 182 | mu, self.gpus) 183 | errD.backward() 184 | optimizerD.step() 185 | ############################ 186 | # (2) Update G network 187 | ########################### 188 | netG.zero_grad() 189 | errG = compute_generator_loss(netD, fake_imgs, 190 | real_labels, mu, self.gpus) 191 | kl_loss = KL_loss(mu, logvar) 192 | errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL 193 | errG_total.backward() 194 | optimizerG.step() 195 | 196 | count = count + 1 197 | if i % 10 == 0: 198 | summary_D = summary.scalar('D_loss', errD.data[0]) 199 | summary_D_r = summary.scalar('D_loss_real', errD_real) 200 | summary_D_w = summary.scalar('D_loss_wrong', errD_wrong) 201 | summary_D_f = summary.scalar('D_loss_fake', errD_fake) 202 | summary_G = summary.scalar('G_loss', errG.data[0]) 203 | summary_KL = summary.scalar('KL_loss', kl_loss.data[0]) 204 | 205 | self.summary_writer.add_summary(summary_D, count) 206 | self.summary_writer.add_summary(summary_D_r, count) 207 | self.summary_writer.add_summary(summary_D_w, count) 208 | self.summary_writer.add_summary(summary_D_f, count) 209 | self.summary_writer.add_summary(summary_G, count) 210 | self.summary_writer.add_summary(summary_KL, count) 211 | 212 | # save the image result for each epoch 213 | inputs = (txt_embedding, fixed_noise) 214 | lr_fake, fake, _, _ = \ 215 | nn.parallel.data_parallel(netG, inputs, self.gpus) 216 | save_img_results(real_img_cpu, fake, epoch, self.image_dir) 217 | if lr_fake is not None: 218 | save_img_results(None, lr_fake, epoch, self.image_dir) 219 | end_t = time.time() 220 | print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f 221 | Loss_real: %.4f 
Loss_wrong:%.4f Loss_fake %.4f 222 | Total Time: %.2fsec 223 | ''' 224 | % (epoch, self.max_epoch, i, len(data_loader), 225 | errD.data[0], errG.data[0], kl_loss.data[0], 226 | errD_real, errD_wrong, errD_fake, (end_t - start_t))) 227 | if epoch % self.snapshot_interval == 0: 228 | save_model(netG, netD, epoch, self.model_dir) 229 | # 230 | save_model(netG, netD, self.max_epoch, self.model_dir) 231 | # 232 | self.summary_writer.close() 233 | 234 | def sample(self, datapath, stage=1): 235 | if stage == 1: 236 | netG, _ = self.load_network_stageI() 237 | else: 238 | netG, _ = self.load_network_stageII() 239 | netG.eval() 240 | 241 | # Load text embeddings generated from the encoder 242 | t_file = torchfile.load(datapath) 243 | captions_list = t_file.raw_txt 244 | embeddings = np.concatenate(t_file.fea_txt, axis=0) 245 | num_embeddings = len(captions_list) 246 | print('Successfully load sentences from: ', datapath) 247 | print('Total number of sentences:', num_embeddings) 248 | print('num_embeddings:', num_embeddings, embeddings.shape) 249 | # path to save generated samples 250 | save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] 251 | mkdir_p(save_dir) 252 | 253 | batch_size = self.batch_size 254 | nz = cfg.Z_DIM 255 | noise = Variable(torch.FloatTensor(batch_size, nz)) 256 | if cfg.CUDA: 257 | noise = noise.cuda() 258 | count = 0 259 | while count < num_embeddings: 260 | if count > 3000: 261 | break 262 | iend = count + batch_size 263 | if iend > num_embeddings: 264 | iend = num_embeddings 265 | count = num_embeddings - batch_size 266 | embeddings_batch = embeddings[count:iend] 267 | # captions_batch = captions_list[count:iend] 268 | txt_embedding = Variable(torch.FloatTensor(embeddings_batch)) 269 | if cfg.CUDA: 270 | txt_embedding = txt_embedding.cuda() 271 | 272 | ####################################################### 273 | # (2) Generate fake images 274 | ###################################################### 275 | noise.data.normal_(0, 1) 276 | inputs = (txt_embedding, noise) 277 | _, fake_imgs, mu, logvar = \ 278 | nn.parallel.data_parallel(netG, inputs, self.gpus) 279 | for i in range(batch_size): 280 | save_name = '%s/%d.png' % (save_dir, count + i) 281 | im = fake_imgs[i].data.cpu().numpy() 282 | im = (im + 1.0) * 127.5 283 | im = im.astype(np.uint8) 284 | # print('im', im.shape) 285 | im = np.transpose(im, (1, 2, 0)) 286 | # print('im', im.shape) 287 | im = Image.fromarray(im) 288 | im.save(save_name) 289 | count += batch_size 290 | 291 | -------------------------------------------------------------------------------- /texttoimage/readme.md: -------------------------------------------------------------------------------- 1 | ## Generation of images based on text captions 2 | 3 | visit the wiki page with the same title to learn more about this code 4 | -------------------------------------------------------------------------------- /videogeneration/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/videogeneration/.DS_Store -------------------------------------------------------------------------------- /videogeneration/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/videogeneration/data/.DS_Store 
-------------------------------------------------------------------------------- /videogeneration/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import h5py 5 | from PIL import Image,ImageSequence 6 | from torchvision.datasets import ImageFolder 7 | import tqdm 8 | import zipfile 9 | from requests import get 10 | 11 | """ 12 | Description: 13 | This file will download the dataset and unzip it and then preprocess it as required. 14 | once you run this file dataset would be prepared in the kth folder , and it will need 15 | to be copied to data/actions 16 | """ 17 | 18 | 19 | 20 | 21 | os.makedirs('kth',exist_ok=True) 22 | 23 | def download(url, file_name): 24 | with open(file_name, "wb") as file: 25 | print("downloading dataset file",file_name) 26 | response = get(url) 27 | file.write(response.content) 28 | print("{} downloading complete".format(file_name)) 29 | with zipfile.ZipFile(file_name, "r") as zip_ref: 30 | zip_ref.extractall('kth') 31 | 32 | def convertFile(inputpath, targetFormat='.gif'): 33 | outputpath = os.path.splitext(inputpath)[0] + targetFormat 34 | reader = imageio.get_reader(inputpath) 35 | fps = reader.get_meta_data()['fps'] 36 | writer = imageio.get_writer(outputpath, fps=fps) 37 | for i,im in enumerate(reader): 38 | sys.stdout.write("\rframe {0}".format(i)) 39 | sys.stdout.flush() 40 | writer.append_data(im) 41 | writer.close() 42 | 43 | dict = {'boxing':0,'handclapping':1,'handwaving':2,'jogging':3,'running':4,'walking':5} 44 | for i in dict.keys(): 45 | download('http://www.nada.kth.se/cvap/actions/{}.zip'.format(i),'kth_{}.zip'.format(i)) 46 | 47 | for i in os.listdir('kth'): 48 | convertFile(i) 49 | 50 | for idx, entry in enumerate(os.listdir('kth')): 51 | if 'd2' in entry: 52 | continue 53 | parts = entry.split("_") 54 | os.makedirs("kth_out/{0}".format(lower(dict[parts[1]])), exist_ok=True) 55 | try: 56 | imageObject = Image.open("kth/{0}".format(entry)) 57 | vid_iter = ImageSequence.Iterator(imageObject) 58 | vid_frames = [cv2.cvtColor(np.array(img.convert('RGB')), cv2.COLOR_BGR2RGB) for img in vid_iter] 59 | vid_frames = [cv2.copyMakeBorder(img, 20, 20, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0]) for img in vid_frames] 60 | idy = str(idx + 100000001) 61 | idy = idy[1:] 62 | image = np.hstack(vid_frames) 63 | cv2.imwrite("kth_out/{0}/{1}.png".format(parts[1],idy),image) 64 | except KeyboardInterrupt: 65 | exit() 66 | except Exception as e: 67 | print("error is", e) 68 | pass 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /videogeneration/readme.md: -------------------------------------------------------------------------------- 1 | ## generation of videos 2 | Visit the wiki page with the same title to learn more about this code. 
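For a quick check of a trained model, the snippet below mirrors what `src/demo_video.ipynb` and `src/generate_videos.py` do; the checkpoint path and the 32-frame length are taken from the notebook and will differ for other runs.

```python
import torch
import imageio
from trainers import videos_to_numpy   # assumes this is run from inside src/

categories = {'boxing': 0, 'handclapping': 1, 'handwaving': 2,
              'jogging': 3, 'running': 4, 'walking': 5}

generator = torch.load("../logs/actions/generator_46000.pytorch")
generator.eval()

# sample one 32-frame 'running' clip and write it out as a gif
v, _ = generator.sample_videos(1, 32, category=categories['running'])
video = videos_to_numpy(v).squeeze().transpose((1, 2, 3, 0))
imageio.mimsave("video_running.gif", video)
```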
3 | -------------------------------------------------------------------------------- /videogeneration/src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eYSIP-2018/Text_to_Image-Video_synthesis_using_GANs/7331a5aa6d898e7481206a406c554dda825bc971/videogeneration/src/.DS_Store -------------------------------------------------------------------------------- /videogeneration/src/data.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import tqdm 5 | import _pickle as pickle 6 | import numpy as np 7 | import torch.utils.data 8 | from torchvision.datasets import ImageFolder 9 | import PIL 10 | 11 | """ 12 | VideoFolderDataset 13 | Description: Loads the dataset folder. 14 | Constructor inputs: 15 | folder: path to folder containing images. 16 | cache: path to cache. 17 | 18 | __getitem__ 19 | input: 20 | item: integer index from 0 to length of dataset. 21 | outputs: 22 | im: image 23 | label: label of the image 24 | 25 | __len__ 26 | output: 27 | size of dataset. 28 | """ 29 | class VideoFolderDataset(torch.utils.data.Dataset): 30 | def __init__(self, folder, cache, min_len=32): 31 | dataset = ImageFolder(folder) 32 | self.total_frames = 0 33 | self.lengths = [] 34 | self.images = [] 35 | print(os.path.exists(cache)) 36 | 37 | if cache is not None and os.path.exists(cache): 38 | with open(cache, 'rb') as f: 39 | self.images, self.lengths = pickle.load(f, encoding='latin1') 40 | else: 41 | for idx, (im, categ) in enumerate( 42 | tqdm.tqdm(dataset, desc="Counting total number of frames")): 43 | img_path, _ = dataset.imgs[idx] 44 | shorter, longer = min(im.width, im.height), max(im.width, im.height) 45 | length = longer // shorter 46 | if length >= min_len: 47 | self.images.append((img_path, categ)) 48 | self.lengths.append(length) 49 | 50 | if cache is not None: 51 | with open(cache, 'wb') as f: 52 | # pickle.dump((self.images, self.lengths), open(cache,"wb")) 53 | pickle.dump((self.images, self.lengths), f) 54 | 55 | self.cumsum = np.cumsum([0] + self.lengths) 56 | print("Total number of frames {}".format(np.sum(self.lengths))) 57 | 58 | def __getitem__(self, item): 59 | path, label = self.images[item] 60 | im = PIL.Image.open(path) 61 | return im, label 62 | 63 | def __len__(self): 64 | return len(self.images) 65 | 66 | """ 67 | ImageDataset 68 | Description: Transforms the image and returns the frames. 69 | 70 | Constructor inputs: 71 | dataset: dataset as returned from VideoFolderDataset. 72 | transform: function to resize image, convert it to tensor and normalise the image. 73 | 74 | __getitem__ 75 | input: 76 | item: integer index from 0 to length of dataset. 77 | output: 78 | dictionary of images(frames of video) and their categories 79 | 80 | __len__ 81 | output: 82 | size of dataset. 
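Note:
    Each entry of the underlying VideoFolderDataset is a single long strip
    image in which all frames of a video are concatenated along the longer
    side (see preprocess.py, which builds these strips with np.hstack);
    __getitem__ uses the cumulative frame counts (dataset.cumsum) to locate
    the video and then cuts the requested square frame out of that strip.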
83 | 84 | """ 85 | class ImageDataset(torch.utils.data.Dataset): 86 | def __init__(self, dataset, transform=None): 87 | self.dataset = dataset 88 | 89 | self.transforms = transform if transform is not None else lambda x: x 90 | 91 | def __getitem__(self, item): 92 | if item != 0: 93 | video_id = np.searchsorted(self.dataset.cumsum, item) - 1 94 | frame_num = item - self.dataset.cumsum[video_id] - 1 95 | else: 96 | video_id = 0 97 | frame_num = 0 98 | 99 | video, target = self.dataset[video_id] 100 | video = np.array(video) 101 | 102 | horizontal = video.shape[1] > video.shape[0] 103 | 104 | if horizontal: 105 | i_from, i_to = video.shape[0] * frame_num, video.shape[0] * (frame_num + 1) 106 | frame = video[:, i_from: i_to, ::] 107 | else: 108 | i_from, i_to = video.shape[1] * frame_num, video.shape[1] * (frame_num + 1) 109 | frame = video[i_from: i_to, :, ::] 110 | 111 | if frame.shape[0] == 0: 112 | print("video {}. From {} to {}. num {}".format(video.shape, i_from, i_to, item)) 113 | 114 | return {"images": self.transforms(frame), "categories": target} 115 | 116 | def __len__(self): 117 | return self.dataset.cumsum[-1] 118 | 119 | """ 120 | VideoDataset 121 | Description: Reduces the video to required size. 122 | 123 | Constructor inputs: 124 | dataset: dataset as returned from VideoFolderDataset 125 | video_length: number of frames required 126 | every_nth: training videos are sampled using every nth frame 127 | transform: video transform function 128 | 129 | __getitem__ 130 | input: 131 | item: integer index from 0 to length of dataset. 132 | output: 133 | dictionary of images after selecting the frames and their categories 134 | 135 | __len__ 136 | output: 137 | size of dataset. 138 | 139 | """ 140 | class VideoDataset(torch.utils.data.Dataset): 141 | def __init__(self, dataset, video_length, every_nth=1, transform=None): 142 | self.dataset = dataset 143 | self.video_length = video_length 144 | self.every_nth = every_nth 145 | self.transforms = transform if transform is not None else lambda x: x 146 | 147 | def __getitem__(self, item): 148 | video, target = self.dataset[item] 149 | video = np.array(video) 150 | 151 | horizontal = video.shape[1] > video.shape[0] 152 | shorter, longer = min(video.shape[0], video.shape[1]), max(video.shape[0], video.shape[1]) 153 | video_len = longer // shorter 154 | 155 | # videos can be of various length, we randomly sample sub-sequences 156 | if video_len > self.video_length * self.every_nth: 157 | needed = self.every_nth * (self.video_length - 1) 158 | gap = video_len - needed 159 | start = 0 if gap == 0 else np.random.randint(0, gap, 1)[0] 160 | subsequence_idx = np.linspace(start, start + needed, self.video_length, endpoint=True, dtype=np.int32) 161 | elif video_len >= self.video_length: 162 | subsequence_idx = np.arange(0, self.video_length) 163 | else: 164 | raise Exception("Length is too short id - {}, len - {}").format(self.dataset[item], video_len) 165 | 166 | frames = np.split(video, video_len, axis=1 if horizontal else 0) 167 | selected = np.array([frames[s_id] for s_id in subsequence_idx]) 168 | 169 | return {"images": self.transforms(selected), "categories": target} 170 | 171 | def __len__(self): 172 | return len(self.dataset) 173 | 174 | 175 | """ 176 | ImageSampler 177 | Description: Transforms the image and returns the frames. 178 | 179 | Constructor inputs: 180 | dataset: dataset as returned from VideoFolderDataset. 181 | transform: function to resize image, convert it to tensor and normalise the image. 
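Note:
    Unlike VideoFolderDataset above, the dataset passed to ImageSampler (and
    to VideoSampler below) is expected to expose a get_data() dict and a
    keys attribute; the sampler indexes into those arrays rather than into
    image files on disk.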
182 | 183 | __getitem__ 184 | input: 185 | item: integer index from 0 to length of dataset. 186 | output: 187 | dictionary of images(frames of video) and their categories 188 | 189 | __len__ 190 | output: 191 | size of dataset. 192 | """ 193 | class ImageSampler(torch.utils.data.Dataset): 194 | def __init__(self, dataset, transform=None): 195 | self.dataset = dataset 196 | self.transforms = transform 197 | 198 | def __getitem__(self, index): 199 | result = {} 200 | for k in self.dataset.keys: 201 | result[k] = np.take(self.dataset.get_data()[k], index, axis=0) 202 | 203 | if self.transforms is not None: 204 | for k, transform in self.transforms.iteritems(): 205 | result[k] = transform(result[k]) 206 | 207 | return result 208 | 209 | def __len__(self): 210 | return self.dataset.get_data()[self.dataset.keys[0]].shape[0] 211 | 212 | """ 213 | VideoSampler 214 | Description: Reduces the video to required size. 215 | 216 | Constructor inputs: 217 | dataset: dataset 218 | video_length: desired length of video 219 | every_nth: training videos are sampled using every nth frame 220 | transform: video transform function 221 | 222 | __getitem__ 223 | input: 224 | item: integer index from 0 to length of dataset. 225 | output: 226 | 227 | __len__ 228 | output: 229 | size of dataset. 230 | """ 231 | 232 | class VideoSampler(torch.utils.data.Dataset): 233 | def __init__(self, dataset, video_length, every_nth=1, transform=None): 234 | self.dataset = dataset 235 | self.video_length = video_length 236 | self.unique_ids = np.unique(self.dataset.get_data()['video_ids']) 237 | self.every_nth = every_nth 238 | self.transforms = transform 239 | 240 | def __getitem__(self, item): 241 | result = {} 242 | ids = self.dataset.get_data()['video_ids'] == self.unique_ids[item] 243 | ids = np.squeeze(np.squeeze(np.argwhere(ids))) 244 | for k in self.dataset.keys: 245 | result[k] = np.take(self.dataset.get_data()[k], ids, axis=0) 246 | 247 | subsequence_idx = None 248 | print(result[k].shape[0]) 249 | 250 | # videos can be of various length, we randomly sample sub-sequences 251 | if result[k].shape[0] > self.video_length: 252 | needed = self.every_nth * (self.video_length - 1) 253 | gap = result[k].shape[0] - needed 254 | start = 0 if gap == 0 else np.random.randint(0, gap, 1)[0] 255 | subsequence_idx = np.linspace(start, start + needed, self.video_length, endpoint=True, dtype=np.int32) 256 | elif result[k].shape[0] == self.video_length: 257 | subsequence_idx = np.arange(0, self.video_length) 258 | else: 259 | print("Length is too short id - {}, len - {}".format(self.unique_ids[item], result[k].shape[0])) 260 | 261 | if subsequence_idx: 262 | for k in self.dataset.keys: 263 | result[k] = np.take(result[k], subsequence_idx, axis=0) 264 | else: 265 | print(result[self.dataset.keys[0]].shape) 266 | 267 | if self.transforms is not None: 268 | for k, transform in self.transforms.iteritems(): 269 | result[k] = transform(result[k]) 270 | 271 | return result 272 | 273 | def __len__(self): 274 | return len(self.unique_ids) -------------------------------------------------------------------------------- /videogeneration/src/demo_video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": 
[ 19 | "/home/text-to-image/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 20 | " from ._conv import register_converters as _register_converters\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import os\n", 26 | "import docopt\n", 27 | "import torch\n", 28 | "import numpy as np\n", 29 | "from trainers import videos_to_numpy\n", 30 | "import cv2\n", 31 | "import subprocess as sp\n", 32 | "import imageio\n", 33 | "from IPython.core.display import Image,HTML,display\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 10, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "running\n" 46 | ] 47 | }, 48 | { 49 | "data": { 50 | "text/html": [ 51 | "" 52 | ], 53 | "text/plain": [ 54 | "" 55 | ] 56 | }, 57 | "metadata": {}, 58 | "output_type": "display_data" 59 | }, 60 | { 61 | "data": { 62 | "text/html": [ 63 | "" 64 | ], 65 | "text/plain": [ 66 | "" 67 | ] 68 | }, 69 | "metadata": {}, 70 | "output_type": "display_data" 71 | }, 72 | { 73 | "data": { 74 | "text/html": [ 75 | "" 76 | ], 77 | "text/plain": [ 78 | "" 79 | ] 80 | }, 81 | "metadata": {}, 82 | "output_type": "display_data" 83 | }, 84 | { 85 | "data": { 86 | "text/html": [ 87 | "" 88 | ], 89 | "text/plain": [ 90 | "" 91 | ] 92 | }, 93 | "metadata": {}, 94 | "output_type": "display_data" 95 | }, 96 | { 97 | "data": { 98 | "text/html": [ 99 | "" 100 | ], 101 | "text/plain": [ 102 | "" 103 | ] 104 | }, 105 | "metadata": {}, 106 | "output_type": "display_data" 107 | }, 108 | { 109 | "data": { 110 | "text/html": [ 111 | "" 112 | ], 113 | "text/plain": [ 114 | "" 115 | ] 116 | }, 117 | "metadata": {}, 118 | "output_type": "display_data" 119 | }, 120 | { 121 | "data": { 122 | "text/html": [ 123 | "" 124 | ], 125 | "text/plain": [ 126 | "" 127 | ] 128 | }, 129 | "metadata": {}, 130 | "output_type": "display_data" 131 | }, 132 | { 133 | "data": { 134 | "text/html": [ 135 | "" 136 | ], 137 | "text/plain": [ 138 | "" 139 | ] 140 | }, 141 | "metadata": {}, 142 | "output_type": "display_data" 143 | }, 144 | { 145 | "data": { 146 | "text/html": [ 147 | "" 148 | ], 149 | "text/plain": [ 150 | "" 151 | ] 152 | }, 153 | "metadata": {}, 154 | "output_type": "display_data" 155 | }, 156 | { 157 | "data": { 158 | "text/html": [ 159 | "" 160 | ], 161 | "text/plain": [ 162 | "" 163 | ] 164 | }, 165 | "metadata": {}, 166 | "output_type": "display_data" 167 | } 168 | ], 169 | "source": [ 170 | " model = \"../logs/actions/generator_46000.pytorch\"\n", 171 | " number_videos = 10\n", 172 | " number_frames = 32\n", 173 | " generator = torch.load(model)\n", 174 | " generator.eval()\n", 175 | " num_videos = int(number_videos)\n", 176 | " category = input()\n", 177 | " dict = {'boxing':0,'handclapping':1,'handwaving':2,'jogging':3,'running':4,'walking':5}\n", 178 | " c = dict[category]\n", 179 | " \n", 180 | " for i in range(num_videos):\n", 181 | " v, cats = generator.sample_videos(1,number_frames ,category=c)\n", 182 | " cat = cats.data.cpu().numpy()\n", 183 | " video = videos_to_numpy(v).squeeze().transpose((1, 2, 3, 0))\n", 184 | " imageio.mimsave(\"video{0}_{1}.gif\".format(i,category),video)\n", 185 | " display(HTML(''.format(i,category)))\n" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [] 
194 | } 195 | ], 196 | "metadata": { 197 | "kernelspec": { 198 | "display_name": "Python 3", 199 | "language": "python", 200 | "name": "python3" 201 | }, 202 | "language_info": { 203 | "codemirror_mode": { 204 | "name": "ipython", 205 | "version": 3 206 | }, 207 | "file_extension": ".py", 208 | "mimetype": "text/x-python", 209 | "name": "python", 210 | "nbconvert_exporter": "python", 211 | "pygments_lexer": "ipython3", 212 | "version": "3.5.2" 213 | } 214 | }, 215 | "nbformat": 4, 216 | "nbformat_minor": 2 217 | } 218 | -------------------------------------------------------------------------------- /videogeneration/src/generate_videos.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Usage: 4 | generate_videos.py [options] 5 | 6 | Options: 7 | -n, --num_videos= number of videos to generate [default: 10] 8 | -o, --output_format= save videos as [default: gif] 9 | -f, --number_of_frames= generate videos with that many frames [default: 16] 10 | -c, --category= generate videos for a specific category [default: running] 11 | 12 | """ 13 | 14 | import os 15 | import docopt 16 | import torch 17 | import numpy as np 18 | from trainers import videos_to_numpy 19 | import cv2 20 | import subprocess as sp 21 | import imageio 22 | 23 | 24 | if __name__ == "__main__": 25 | args = docopt.docopt(__doc__) 26 | generator = torch.load(args[""]) 27 | generator.eval() 28 | num_videos = int(args['--num_videos']) 29 | output_folder = args[''] 30 | category = args['--category'].lower() 31 | 32 | dict = {'boxing':0,'handclapping':1,'handwaving':2,'jogging':3,'running':4,'walking':5} 33 | c = dict[category] 34 | 35 | if not os.path.exists(output_folder): 36 | os.makedirs(output_folder) 37 | 38 | for i in range(num_videos): 39 | v, cats = generator.sample_videos(1, int(args['--number_of_frames']),category=c) 40 | cat = cats.data.cpu().numpy() 41 | video = videos_to_numpy(v).squeeze().transpose((1, 2, 3, 0)) 42 | imageio.mimsave("video{0}_{1}.gif".format(i,category),video) 43 | # save_video(args["--ffmpeg"], video, os.path.join(output_folder, "{}.{}".format(i, args['--output_format']))) 44 | -------------------------------------------------------------------------------- /videogeneration/src/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | 6 | try: 7 | from StringIO import StringIO # Python 2.7 8 | except ImportError: 9 | from io import BytesIO # Python 3.x 10 | 11 | 12 | class Logger(object): 13 | def __init__(self, log_dir, suffix=None): 14 | self.writer = tf.summary.FileWriter(log_dir, filename_suffix=suffix) 15 | 16 | def scalar_summary(self, tag, value, step): 17 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 18 | self.writer.add_summary(summary, step) 19 | 20 | def image_summary(self, tag, images, step): 21 | 22 | img_summaries = [] 23 | for i, img in enumerate(images): 24 | # Write the image to a string 25 | try: 26 | s = StringIO() 27 | except: 28 | s = BytesIO() 29 | scipy.misc.toimage(img).save(s, format="png") 30 | 31 | # Create an Image object 32 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 33 | height=img.shape[0], 34 | width=img.shape[1]) 35 | # Create a Summary value 36 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 37 | 38 | # Create and write Summary 39 | summary = tf.Summary(value=img_summaries) 40 | self.writer.add_summary(summary, step) 41 | 
self.writer.flush() 42 | 43 | def video_summary(self, tag, videos, step): 44 | 45 | sh = list(videos.shape) 46 | sh[-1] = 1 47 | 48 | separator = np.zeros(sh, dtype=videos.dtype) 49 | videos = np.concatenate([videos, separator], axis=-1) 50 | 51 | img_summaries = [] 52 | for i, vid in enumerate(videos): 53 | # Concat a video 54 | try: 55 | s = StringIO() 56 | except: 57 | s = BytesIO() 58 | 59 | v = vid.transpose(1, 2, 3, 0) 60 | v = [np.squeeze(f) for f in np.split(v, v.shape[0], axis=0)] 61 | img = np.concatenate(v, axis=1)[:, :-1, :] 62 | 63 | scipy.misc.toimage(img).save(s, format="png") 64 | 65 | # Create an Image object 66 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 67 | height=img.shape[0], 68 | width=img.shape[1]) 69 | # Create a Summary value 70 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 71 | 72 | # Create and write Summary 73 | summary = tf.Summary(value=img_summaries) 74 | self.writer.add_summary(summary, step) 75 | self.writer.flush() 76 | -------------------------------------------------------------------------------- /videogeneration/src/models.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | 8 | import numpy as np 9 | 10 | if torch.cuda.is_available(): 11 | T = torch.cuda 12 | else: 13 | T = torch 14 | 15 | """ 16 | Description: 17 | Noise class is used to add noise to input before it is forwarded to the next layer 18 | constructor inputs: 19 | use_noise : This variable tells the class weather noise is required or not 20 | sigma : This variable stores noise factor 21 | 22 | input: image tensor 23 | 24 | outputs: 25 | a floatTensor with noise augumented with it. 26 | """ 27 | class Noise(nn.Module): 28 | def __init__(self, use_noise, sigma=0.2): 29 | super(Noise, self).__init__() 30 | self.use_noise = use_noise 31 | self.sigma = sigma 32 | 33 | def forward(self, x): 34 | if self.use_noise: 35 | return x + self.sigma * Variable(T.FloatTensor(x.size()).normal_(), requires_grad=False) 36 | return x 37 | 38 | """ 39 | Description: 40 | ImageDiscriminator disriminates each and every frame of the video as either fake or real it uses several convolutional layers and auguments the input tensor with noise. 41 | 42 | constructor inputs: 43 | n_channels: Number of channels in input image. 44 | ndf: the output channel factor , decides the number of output channels in every layer. 45 | use_noise: Tells the discrimiator to augument noise or not. 46 | noise_sigma: The noise factor if noise is being used. 47 | 48 | input: 49 | Image tensor 50 | 51 | outputs: 52 | a 1 dimensional vector describing if image is real. 
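Example (an illustrative sketch, not part of the original file; it assumes a batch of eight 64x64 RGB frames, the 64-pixel default from train.py, and the noise_sigma value used in run.sh):

    netD = ImageDiscriminator(n_channels=3, ndf=64, use_noise=True, noise_sigma=0.1)
    if torch.cuda.is_available():
        netD.cuda()
    frames = Variable(T.FloatTensor(8, 3, 64, 64).normal_())   # fake "frames" just for shape-checking
    realness_logits, _ = netD(frames)                          # shape (8,): one raw logit per frame

The outputs are raw logits; the trainer applies BCEWithLogitsLoss to them, so no sigmoid is needed here.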
53 | 54 | """ 55 | 56 | 57 | class ImageDiscriminator(nn.Module): 58 | def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None): 59 | super(ImageDiscriminator, self).__init__() 60 | 61 | self.use_noise = use_noise 62 | 63 | self.main = nn.Sequential( 64 | Noise(use_noise, sigma=noise_sigma), 65 | nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False), 66 | nn.LeakyReLU(0.2, inplace=True), 67 | 68 | Noise(use_noise, sigma=noise_sigma), 69 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 70 | nn.BatchNorm2d(ndf * 2), 71 | nn.LeakyReLU(0.2, inplace=True), 72 | 73 | Noise(use_noise, sigma=noise_sigma), 74 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 75 | nn.BatchNorm2d(ndf * 4), 76 | nn.LeakyReLU(0.2, inplace=True), 77 | 78 | Noise(use_noise, sigma=noise_sigma), 79 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 80 | nn.BatchNorm2d(ndf * 8), 81 | nn.LeakyReLU(0.2, inplace=True), 82 | 83 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 84 | ) 85 | 86 | def forward(self, input): 87 | h = self.main(input).squeeze() 88 | return h, None 89 | 90 | 91 | class PatchImageDiscriminator(nn.Module): 92 | def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None): 93 | super(PatchImageDiscriminator, self).__init__() 94 | 95 | self.use_noise = use_noise 96 | 97 | self.main = nn.Sequential( 98 | Noise(use_noise, sigma=noise_sigma), 99 | nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | 102 | Noise(use_noise, sigma=noise_sigma), 103 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 104 | nn.BatchNorm2d(ndf * 2), 105 | nn.LeakyReLU(0.2, inplace=True), 106 | 107 | Noise(use_noise, sigma=noise_sigma), 108 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 109 | nn.BatchNorm2d(ndf * 4), 110 | nn.LeakyReLU(0.2, inplace=True), 111 | 112 | Noise(use_noise, sigma=noise_sigma), 113 | nn.Conv2d(ndf * 4, 1, 4, 2, 1, bias=False), 114 | ) 115 | 116 | def forward(self, input): 117 | h = self.main(input).squeeze() 118 | return h, None 119 | 120 | 121 | class PatchVideoDiscriminator(nn.Module): 122 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64): 123 | super(PatchVideoDiscriminator, self).__init__() 124 | 125 | self.n_channels = n_channels 126 | self.n_output_neurons = n_output_neurons 127 | self.use_noise = use_noise 128 | self.bn_use_gamma = bn_use_gamma 129 | 130 | self.main = nn.Sequential( 131 | Noise(use_noise, sigma=noise_sigma), 132 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 133 | nn.LeakyReLU(0.2, inplace=True), 134 | 135 | Noise(use_noise, sigma=noise_sigma), 136 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 137 | nn.BatchNorm3d(ndf * 2), 138 | nn.LeakyReLU(0.2, inplace=True), 139 | 140 | Noise(use_noise, sigma=noise_sigma), 141 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 142 | nn.BatchNorm3d(ndf * 4), 143 | nn.LeakyReLU(0.2, inplace=True), 144 | 145 | nn.Conv3d(ndf * 4, 1, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 146 | ) 147 | 148 | def forward(self, input): 149 | h = self.main(input).squeeze() 150 | 151 | return h, None 152 | 153 | """ 154 | Description: 155 | the VideoDiscriminator is used to discriminate the overall video as real or fake by using several 3d spatiotemporal convolutions, Depending on the output neurons the video discriminator can be used to categorize the video as well, 156 | 157 | constructor inputs: 158 | n_channels: Number of 
channels in input image. 159 | n_output_neurons: Number of output neurons in the output this can be modified as per the number of categories. 160 | ndf: the output channel factor , decides the number of output channels in every layer. 161 | use_noise: Tells the discrimiator to augument noise or not. 162 | noise_sigma: The noise factor if noise is being used. 163 | 164 | inputs: 165 | Image tensor 166 | 167 | outputs: 168 | a 1 dimensional vector describing if the video is fake or not and belongs to which category. 169 | 170 | 171 | """ 172 | class VideoDiscriminator(nn.Module): 173 | def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64): 174 | super(VideoDiscriminator, self).__init__() 175 | 176 | self.n_channels = n_channels 177 | self.n_output_neurons = n_output_neurons 178 | self.use_noise = use_noise 179 | self.bn_use_gamma = bn_use_gamma 180 | 181 | self.main = nn.Sequential( 182 | Noise(use_noise, sigma=noise_sigma), 183 | nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 184 | nn.LeakyReLU(0.2, inplace=True), 185 | 186 | Noise(use_noise, sigma=noise_sigma), 187 | nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 188 | nn.BatchNorm3d(ndf * 2), 189 | nn.LeakyReLU(0.2, inplace=True), 190 | 191 | Noise(use_noise, sigma=noise_sigma), 192 | nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 193 | nn.BatchNorm3d(ndf * 4), 194 | nn.LeakyReLU(0.2, inplace=True), 195 | 196 | Noise(use_noise, sigma=noise_sigma), 197 | nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False), 198 | nn.BatchNorm3d(ndf * 8), 199 | nn.LeakyReLU(0.2, inplace=True), 200 | 201 | nn.Conv3d(ndf * 8, n_output_neurons, 4, 1, 0, bias=False), 202 | ) 203 | 204 | def forward(self, input): 205 | h = self.main(input).squeeze() 206 | 207 | return h, None 208 | """ 209 | Description: 210 | The categoricalVideoDiscriminator builds upon the videoDiscriminator so that video categories can be found out, it splits the output from the video Discriminator into the label of the video and the categories. 211 | 212 | constructor inputs: 213 | dim_categorical: Number of dimensions of the categories ie the number of the dimensions the video can belong to. 214 | n_channels: Number of channels in input image. 215 | n_output_neurons: Number of output neurons in the output this can be modified as per the number of categories. 216 | use_noise: Tells the discrimiator to augument noise or not. 217 | noise_sigma: The noise factor if noise is being used. 218 | 219 | input: 220 | 4d video tensor 221 | 222 | output: 223 | two tensors describing if the video is fake or not and the category the video belongs to. 
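Example (an illustrative sketch, not part of the original file; the shapes assume 16-frame clips of 64x64 RGB frames and the 6 action categories used elsewhere in this repo):

    netVD = CategoricalVideoDiscriminator(n_channels=3, dim_categorical=6,
                                          use_noise=True, noise_sigma=0.1)
    if torch.cuda.is_available():
        netVD.cuda()
    clips = Variable(T.FloatTensor(4, 3, 16, 64, 64).normal_())   # (batch, channels, frames, height, width)
    realness_logits, category_logits = netVD(clips)               # shapes (4, 1) and (4, 6)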
224 | 225 | 226 | """ 227 | 228 | class CategoricalVideoDiscriminator(VideoDiscriminator): 229 | def __init__(self, n_channels, dim_categorical, n_output_neurons=1, use_noise=False, noise_sigma=None): 230 | super(CategoricalVideoDiscriminator, self).__init__(n_channels=n_channels, 231 | n_output_neurons=n_output_neurons + dim_categorical, 232 | use_noise=use_noise, 233 | noise_sigma=noise_sigma) 234 | 235 | self.dim_categorical = dim_categorical 236 | 237 | def split(self, input): 238 | return input[:, :input.size(1) - self.dim_categorical], input[:, input.size(1) - self.dim_categorical:] 239 | 240 | def forward(self, input): 241 | h, _ = super(CategoricalVideoDiscriminator, self).forward(input) 242 | labels, categ = self.split(h) 243 | return labels, categ 244 | 245 | """ 246 | Description: 247 | videoGenerator is used to generate videos.Firstly a recurrent neural network is used to generate latent variables coresponding to every frame these are the motion variables. This is concatenated with random variables coresponding to the content in the frames 248 | ,if categories have been provided then the category dimension vector will also be concatenated.This concatenaed vector will be forwarded to the generator which will generate a video using upsampling layers. 249 | 250 | constructor inputs: 251 | n_channels: Number of channels in input image. 252 | dim_z_content:Dimensionality of the content vector. 253 | dim_z_category:Dimensionality of the category vector. 254 | dim_z_motion:dimensionality of the motion vector. 255 | video_length:Number of frames in the generated video. 256 | ngf: upscaling factor for the video. 257 | 258 | 259 | input: 260 | the first function is sample videos and it takes the number of samples as input. 261 | 262 | output: 263 | the generated video and the category label of the video. 
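Example (an illustrative sketch, not part of the original file; the latent dimensions mirror the defaults in run.sh, and the 64x64 output size follows from the fixed stack of transposed convolutions):

    gen = VideoGenerator(n_channels=3, dim_z_content=50, dim_z_category=6,
                         dim_z_motion=10, video_length=16)
    if torch.cuda.is_available():
        gen.cuda()
    videos, categories = gen.sample_videos(num_samples=2, video_len=16, category=4)   # 4 -> 'running'
    # videos: Variable of shape (2, 3, 16, 64, 64), values in [-1, 1] from the final Tanh
    # categories: Variable holding the sampled class index of each generated video

Note that, as currently written, passing category=0 ('boxing') falls back to a random class, because sample_z_categ tests `if category:` rather than `if category is not None:`.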
264 | 265 | 266 | """ 267 | 268 | class VideoGenerator(nn.Module): 269 | def __init__(self, n_channels, dim_z_content, dim_z_category, dim_z_motion, 270 | video_length, ngf=64): 271 | super(VideoGenerator, self).__init__() 272 | 273 | self.n_channels = n_channels 274 | self.dim_z_content = dim_z_content 275 | self.dim_z_category = dim_z_category 276 | self.dim_z_motion = dim_z_motion 277 | self.video_length = video_length 278 | 279 | dim_z = dim_z_motion + dim_z_category + dim_z_content 280 | 281 | self.recurrent = nn.GRUCell(dim_z_motion, dim_z_motion) 282 | 283 | self.main = nn.Sequential( 284 | nn.ConvTranspose2d(dim_z, ngf * 8, 4, 1, 0, bias=False), 285 | nn.BatchNorm2d(ngf * 8), 286 | nn.ReLU(True), 287 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 288 | nn.BatchNorm2d(ngf * 4), 289 | nn.ReLU(True), 290 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 291 | nn.BatchNorm2d(ngf * 2), 292 | nn.ReLU(True), 293 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 294 | nn.BatchNorm2d(ngf), 295 | nn.ReLU(True), 296 | nn.ConvTranspose2d(ngf, self.n_channels, 4, 2, 1, bias=False), 297 | nn.Tanh() 298 | ) 299 | 300 | def sample_z_m(self, num_samples, video_len=None): 301 | video_len = video_len if video_len is not None else self.video_length 302 | 303 | h_t = [self.get_gru_initial_state(num_samples)] 304 | 305 | for frame_num in range(video_len): 306 | e_t = self.get_iteration_noise(num_samples) 307 | h_t.append(self.recurrent(e_t, h_t[-1])) 308 | 309 | z_m_t = [h_k.view(-1, 1, self.dim_z_motion) for h_k in h_t] 310 | z_m = torch.cat(z_m_t[1:], dim=1).view(-1, self.dim_z_motion) 311 | 312 | return z_m 313 | 314 | def sample_z_categ(self, num_samples, video_len, category=None): 315 | video_len = video_len if video_len is not None else self.video_length 316 | if category: 317 | classes_to_generate = np.full(num_samples,category) 318 | else: 319 | classes_to_generate = np.random.randint(self.dim_z_category, size=num_samples) 320 | one_hot = np.zeros((num_samples, self.dim_z_category), dtype=np.float32) 321 | one_hot[np.arange(num_samples), classes_to_generate] = 1 322 | one_hot_video = np.repeat(one_hot, video_len, axis=0) 323 | 324 | one_hot_video = torch.from_numpy(one_hot_video) 325 | 326 | if torch.cuda.is_available(): 327 | one_hot_video = one_hot_video.cuda() 328 | 329 | return Variable(one_hot_video), classes_to_generate 330 | 331 | def sample_z_content(self, num_samples, video_len=None): 332 | video_len = video_len if video_len is not None else self.video_length 333 | 334 | content = np.random.normal(0, 1, (num_samples, self.dim_z_content)).astype(np.float32) 335 | content = np.repeat(content, video_len, axis=0) 336 | content = torch.from_numpy(content) 337 | if torch.cuda.is_available(): 338 | content = content.cuda() 339 | return Variable(content) 340 | 341 | def sample_z_video(self, num_samples, video_len=None,category=None): 342 | z_content = self.sample_z_content(num_samples, video_len) 343 | z_category, z_category_labels = self.sample_z_categ(num_samples, video_len, category) 344 | z_motion = self.sample_z_m(num_samples, video_len) 345 | 346 | z = torch.cat([z_content, z_category, z_motion], dim=1) 347 | 348 | return z, z_category_labels 349 | 350 | def sample_videos(self, num_samples, video_len=None,category=None): 351 | video_len = video_len if video_len is not None else self.video_length 352 | 353 | z, z_category_labels = self.sample_z_video(num_samples, video_len,category) 354 | 355 | h = self.main(z.view(z.size(0), z.size(1), 1, 1)) 356 | h = 
h.view(h.size(0) // video_len, video_len, self.n_channels, h.size(3), h.size(3)) 357 | 358 | z_category_labels = torch.from_numpy(z_category_labels) 359 | 360 | if torch.cuda.is_available(): 361 | z_category_labels = z_category_labels.cuda() 362 | 363 | h = h.permute(0, 2, 1, 3, 4) 364 | return h, Variable(z_category_labels, requires_grad=False) 365 | 366 | def sample_images(self, num_samples): 367 | z, z_category_labels = self.sample_z_video(num_samples * self.video_length * 2) 368 | 369 | j = np.sort(np.random.choice(z.size(0), num_samples, replace=False)).astype(np.int64) 370 | z = z[j, ::] 371 | z = z.view(z.size(0), z.size(1), 1, 1) 372 | h = self.main(z) 373 | 374 | return h, None 375 | 376 | def get_gru_initial_state(self, num_samples): 377 | return Variable(T.FloatTensor(num_samples, self.dim_z_motion).normal_()) 378 | 379 | def get_iteration_noise(self, num_samples): 380 | return Variable(T.FloatTensor(num_samples, self.dim_z_motion).normal_()) 381 | -------------------------------------------------------------------------------- /videogeneration/src/run.sh: -------------------------------------------------------------------------------- 1 | 2 | #put in the paramter values as you want them , recommended values are these 3 | 4 | image_batch=32 \ 5 | video_batch=32 \ 6 | noise_sigma=0.1 \ 7 | print_every=100 \ 8 | every_nth=2 \ 9 | dim_z_content=50\ 10 | dim_z_motion=10\ 11 | dim_z_category=6 \ 12 | 13 | 14 | #this shell script runs the train function with all the hyper paramaters 15 | # make sure the logs/actions or logs/shapes contains the dataset 16 | python train.py \ 17 | --image_batch ${image_batch} \ 18 | --video_batch ${video_batch} \ 19 | --use_infogan \ 20 | --use_noise \ 21 | --noise_sigma ${noise_sigma} \ 22 | --image_discriminator PatchImageDiscriminator \ 23 | --video_discriminator CategoricalVideoDiscriminator \ 24 | --print_every ${print_every} \ 25 | --every_nth ${every_nth} \ 26 | --dim_z_content ${dim_z_content} \ 27 | --dim_z_motion ${dim_z_motion} \ 28 | --dim_z_category ${dim_z_category}\ 29 | ../data/actions ../logs/actions -------------------------------------------------------------------------------- /videogeneration/src/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main training function from which training begins it includes the intilialization of the dataloaders and then starts the training. 
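An example invocation (illustrative only; it mirrors the flags set in run.sh and assumes the preprocessed action clips live in ../data/actions with logs written to ../logs/actions):

    python train.py --image_batch 32 --video_batch 32 --use_infogan --use_noise --noise_sigma 0.1 \
        --print_every 100 --every_nth 2 --dim_z_content 50 --dim_z_motion 10 --dim_z_category 6 \
        ../data/actions ../logs/actions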
3 | 4 | 5 | Usage: 6 | train.py [options] 7 | 8 | Options: 9 | --image_dataset= specifies a separate dataset to train for images [default: ] 10 | --image_batch= number of images in image batch [default: 10] 11 | --video_batch= number of videos in video batch [default: 3] 12 | 13 | --image_size= resize all frames to this size [default: 64] 14 | 15 | --use_infogan when specified infogan loss is used 16 | 17 | --use_categories when specified ground truth categories are used to 18 | train CategoricalVideoDiscriminator 19 | 20 | --use_noise when specified instance noise is used 21 | --noise_sigma= when use_noise is specified, noise_sigma controls 22 | the magnitude of the noise [default: 0] 23 | 24 | --image_discriminator= specifies image disciminator type (see models.py for a 25 | list of available models) [default: PatchImageDiscriminator] 26 | 27 | --video_discriminator= specifies video discriminator type (see models.py for a 28 | list of available models) [default: CategoricalVideoDiscriminator] 29 | 30 | --video_length= length of the video [default: 16] 31 | --print_every= print every iterations [default: 1] 32 | --n_channels= number of channels in the input data [default: 3] 33 | --every_nth= sample training videos using every nth frame [default: 4] 34 | --batches= specify number of batches to train [default: 100000] 35 | 36 | --dim_z_content= dimensionality of the content input, ie hidden space [default: 50] 37 | --dim_z_motion= dimensionality of the motion input [default: 10] 38 | --dim_z_category= dimensionality of categorical input [default: 6] 39 | """ 40 | 41 | import os 42 | import docopt 43 | import PIL 44 | 45 | import functools 46 | 47 | import torch 48 | from torch.utils.data import DataLoader 49 | from torchvision import transforms 50 | 51 | import models 52 | 53 | from trainers import Trainer 54 | 55 | import data 56 | 57 | 58 | def build_discriminator(type, **kwargs): 59 | discriminator_type = getattr(models, type) 60 | 61 | if 'Categorical' not in type and 'dim_categorical' in kwargs: 62 | kwargs.pop('dim_categorical') 63 | 64 | return discriminator_type(**kwargs) 65 | 66 | 67 | def video_transform(video, image_transform): 68 | vid = [] 69 | for im in video: 70 | vid.append(image_transform(im)) 71 | 72 | vid = torch.stack(vid).permute(1, 0, 2, 3) 73 | 74 | return vid 75 | 76 | 77 | if __name__ == "__main__": 78 | args = docopt.docopt(__doc__) 79 | print(args) 80 | 81 | n_channels = int(args['--n_channels']) 82 | 83 | image_transforms = transforms.Compose([ 84 | PIL.Image.fromarray, 85 | transforms.Resize(int(args["--image_size"])), 86 | transforms.ToTensor(), 87 | lambda x: x[:n_channels, ::], 88 | transforms.Normalize((0.5, 0.5, .5), (0.5, 0.5, 0.5)), 89 | ]) 90 | 91 | video_transforms = functools.partial(video_transform, image_transform=image_transforms) 92 | 93 | video_length = int(args['--video_length']) 94 | image_batch = int(args['--image_batch']) 95 | video_batch = int(args['--video_batch']) 96 | 97 | dim_z_content = int(args['--dim_z_content']) 98 | dim_z_motion = int(args['--dim_z_motion']) 99 | dim_z_category = int(args['--dim_z_category']) 100 | 101 | print(os.path.join(args[''])) 102 | dataset = data.VideoFolderDataset(args[''], cache=os.path.join(args[''], 'local.db')) 103 | 104 | image_dataset = data.ImageDataset(dataset, image_transforms) 105 | image_loader = DataLoader(image_dataset, batch_size=image_batch, drop_last=True, num_workers=2, shuffle=True) 106 | 107 | video_dataset = data.VideoDataset(dataset, 16, 2, video_transforms) 108 | video_loader = 
DataLoader(video_dataset, batch_size=video_batch, drop_last=True, num_workers=2, shuffle=True) 109 | 110 | generator = models.VideoGenerator(n_channels, dim_z_content, dim_z_category, dim_z_motion, video_length) 111 | 112 | image_discriminator = build_discriminator(args['--image_discriminator'], n_channels=n_channels, 113 | use_noise=args['--use_noise'], noise_sigma=float(args['--noise_sigma'])) 114 | 115 | video_discriminator = build_discriminator(args['--video_discriminator'], dim_categorical=dim_z_category, 116 | n_channels=n_channels, use_noise=args['--use_noise'], 117 | noise_sigma=float(args['--noise_sigma'])) 118 | 119 | if torch.cuda.is_available(): 120 | generator.cuda() 121 | image_discriminator.cuda() 122 | video_discriminator.cuda() 123 | 124 | trainer = Trainer(image_loader, video_loader, 125 | int(args['--print_every']), 126 | int(args['--batches']), 127 | args[''], 128 | use_cuda=torch.cuda.is_available(), 129 | use_infogan=args['--use_infogan'], 130 | use_categories=args['--use_categories']) 131 | 132 | trainer.train(generator, image_discriminator, video_discriminator) 133 | -------------------------------------------------------------------------------- /videogeneration/src/trainers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-ND 4.0 license (https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | 11 | from logger import Logger 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from torch.autograd import Variable 17 | import torch.optim as optim 18 | 19 | if torch.cuda.is_available(): 20 | T = torch.cuda 21 | else: 22 | T = torch 23 | 24 | 25 | 26 | """ 27 | utility functions to convert images and videos to numpy arrays 28 | """ 29 | 30 | def images_to_numpy(tensor): 31 | generated = tensor.data.cpu().numpy().transpose(0, 2, 3, 1) 32 | generated[generated < -1] = -1 33 | generated[generated > 1] = 1 34 | generated = (generated + 1) / 2 * 255 35 | return generated.astype('uint8') 36 | 37 | 38 | def videos_to_numpy(tensor): 39 | generated = tensor.data.cpu().numpy().transpose(0, 1, 2, 3, 4) 40 | generated[generated < -1] = -1 41 | generated[generated > 1] = 1 42 | generated = (generated + 1) / 2 * 255 43 | return generated.astype('uint8') 44 | 45 | 46 | def one_hot_to_class(tensor): 47 | a, b = np.nonzero(tensor) 48 | return np.unique(b).astype(np.int32) 49 | 50 | """ 51 | Description: 52 | Trainer Class which contains all the neccessary functions for training both the generator and discriminator 53 | 54 | constructor inputs: 55 | image_sampler: real Image Dataloader for the discriminator network. 56 | video_sampler: real Video Dataloader for the discriminator network. 57 | log_interval: decides after how many iterations should the model be saved. 58 | train_batches: number of total batches that have to be trained. 59 | log_folder: the folder where log files have been saved. 60 | use_cuda: whether gpu is available or not. 61 | use_infogan: using infogan loss when categories are being used. 62 | use_categories: whether the dataset has specific categories such as shapes or actions. 
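Example (an illustrative sketch; it mirrors how train.py constructs the Trainer from the loaders and command-line options above, with the run.sh values filled in):

    trainer = Trainer(image_loader, video_loader, log_interval=100, train_batches=100000,
                      log_folder='../logs/actions', use_cuda=torch.cuda.is_available(),
                      use_infogan=True, use_categories=True)
    trainer.train(generator, image_discriminator, video_discriminator)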
63 | 64 | 65 | """ 66 | class Trainer(object): 67 | def __init__(self, image_sampler, video_sampler, log_interval, train_batches, log_folder, use_cuda=False, 68 | use_infogan=True, use_categories=True): 69 | 70 | self.use_categories = use_categories 71 | 72 | #binary cross entropy loss has been used as its similar to the gan loss in the paper. 73 | self.gan_criterion = nn.BCEWithLogitsLoss() 74 | self.category_criterion = nn.CrossEntropyLoss() 75 | 76 | self.image_sampler = image_sampler 77 | self.video_sampler = video_sampler 78 | 79 | self.video_batch_size = self.video_sampler.batch_size 80 | self.image_batch_size = self.image_sampler.batch_size 81 | 82 | self.log_interval = log_interval 83 | self.train_batches = train_batches 84 | 85 | self.log_folder = log_folder 86 | 87 | self.use_cuda = use_cuda 88 | self.use_infogan = use_infogan 89 | 90 | self.image_enumerator = None 91 | self.video_enumerator = None 92 | 93 | @staticmethod 94 | def ones_like(tensor, val=1.): 95 | return Variable(T.FloatTensor(tensor.size()).fill_(val), requires_grad=False) 96 | 97 | @staticmethod 98 | def zeros_like(tensor, val=0.): 99 | return Variable(T.FloatTensor(tensor.size()).fill_(val), requires_grad=False) 100 | 101 | 102 | """ 103 | Description: 104 | 105 | """ 106 | def compute_gan_loss(self, discriminator, sample_true, sample_fake, is_video): 107 | real_batch = sample_true() 108 | 109 | batch_size = real_batch['images'].size(0) 110 | fake_batch, generated_categories = sample_fake(batch_size) 111 | 112 | real_labels, real_categorical = discriminator(Variable(real_batch['images'])) 113 | fake_labels, fake_categorical = discriminator(fake_batch) 114 | 115 | fake_gt, real_gt = self.get_gt_for_discriminator(batch_size, real=0.) 116 | 117 | l_discriminator = self.gan_criterion(real_labels, real_gt) + \ 118 | self.gan_criterion(fake_labels, fake_gt) 119 | 120 | # update image discriminator here 121 | 122 | # sample again for videos 123 | 124 | # update video discriminator 125 | 126 | # sample again 127 | # - videos 128 | # - images 129 | 130 | # l_vidoes + l_images -> l 131 | # l.backward() 132 | # opt.step() 133 | 134 | 135 | # sample again and compute for generator 136 | 137 | fake_gt = self.get_gt_for_generator(batch_size) 138 | # to real_gt 139 | l_generator = self.gan_criterion(fake_labels, fake_gt) 140 | 141 | if is_video: 142 | 143 | # Ask the video discriminator to learn categories from training videos 144 | categories_gt = Variable(torch.squeeze(real_batch['categories'].long())) 145 | l_discriminator += self.category_criterion(real_categorical, categories_gt) 146 | 147 | if self.use_infogan: 148 | # Ask the generator to generate categories recognizable by the discriminator 149 | l_generator += self.category_criterion(fake_categorical, generated_categories) 150 | 151 | return l_generator, l_discriminator 152 | 153 | """ 154 | Description: 155 | generates the batch of real images from the dataloader 156 | 157 | output: 158 | tensor array of images 159 | """ 160 | def sample_real_image_batch(self): 161 | if self.image_enumerator is None: 162 | self.image_enumerator = enumerate(self.image_sampler) 163 | 164 | batch_idx, batch = next(self.image_enumerator) 165 | b = batch 166 | if self.use_cuda: 167 | for k, v in batch.items(): 168 | b[k] = v.cuda() 169 | 170 | if batch_idx == len(self.image_sampler) - 1: 171 | self.image_enumerator = enumerate(self.image_sampler) 172 | 173 | return b 174 | """ 175 | Description: 176 | generates the batch of real videos from the dataloader 177 | 178 | outptut: 179 | tensor 
array of videos 180 | """ 181 | def sample_real_video_batch(self): 182 | if self.video_enumerator is None: 183 | self.video_enumerator = enumerate(self.video_sampler) 184 | 185 | batch_idx, batch = next(self.video_enumerator) 186 | b = batch 187 | if self.use_cuda: 188 | for k, v in batch.items(): 189 | b[k] = v.cuda() 190 | 191 | if batch_idx == len(self.video_sampler) - 1: 192 | self.video_enumerator = enumerate(self.video_sampler) 193 | 194 | return b 195 | 196 | def train_discriminator(self, discriminator, sample_true, sample_fake, opt, batch_size, use_categories): 197 | opt.zero_grad() 198 | 199 | real_batch = sample_true() 200 | batch = Variable(real_batch['images'], requires_grad=False) 201 | 202 | # util.show_batch(batch.data) 203 | 204 | fake_batch, generated_categories = sample_fake(batch_size) 205 | 206 | real_labels, real_categorical = discriminator(batch) 207 | fake_labels, fake_categorical = discriminator(fake_batch.detach()) 208 | 209 | ones = self.ones_like(real_labels) 210 | zeros = self.zeros_like(fake_labels) 211 | 212 | l_discriminator = self.gan_criterion(real_labels, ones) + \ 213 | self.gan_criterion(fake_labels, zeros) 214 | 215 | if use_categories: 216 | # Ask the video discriminator to learn categories from training videos 217 | categories_gt = Variable(torch.squeeze(real_batch['categories'].long()), requires_grad=False) 218 | l_discriminator += self.category_criterion(real_categorical.squeeze(), categories_gt) 219 | 220 | l_discriminator.backward() 221 | opt.step() 222 | 223 | return l_discriminator 224 | 225 | def train_generator(self, 226 | image_discriminator, video_discriminator, 227 | sample_fake_images, sample_fake_videos, 228 | opt): 229 | 230 | opt.zero_grad() 231 | 232 | # train on images 233 | 234 | fake_batch, generated_categories = sample_fake_images(self.image_batch_size) 235 | fake_labels, fake_categorical = image_discriminator(fake_batch) 236 | all_ones = self.ones_like(fake_labels) 237 | 238 | l_generator = self.gan_criterion(fake_labels, all_ones) 239 | 240 | # train on videos 241 | 242 | fake_batch, generated_categories = sample_fake_videos(self.video_batch_size) 243 | fake_labels, fake_categorical = video_discriminator(fake_batch) 244 | all_ones = self.ones_like(fake_labels) 245 | 246 | l_generator += self.gan_criterion(fake_labels, all_ones) 247 | 248 | if self.use_infogan: 249 | # Ask the generator to generate categories recognizable by the discriminator 250 | l_generator += self.category_criterion(fake_categorical.squeeze(), generated_categories) 251 | 252 | l_generator.backward() 253 | opt.step() 254 | 255 | return l_generator 256 | """ 257 | Description: 258 | Core training function which calls all the sub train functions, intialises the optimizers and logs all the loss values to tensorboard. Also after an interval it saves the model for both inference and restarting training. 259 | 260 | Input: 261 | generator: Video generator object. 262 | image_discriminator: Image discriminator object. 263 | Video_discriminator: Video Discriminator. 
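A sketch of how the periodically saved weights can be reloaded to restart training (illustrative only; it assumes the 'restore' directory written by the loop below and freshly constructed models of the same architecture):

    generator.load_state_dict(torch.load(os.path.join('restore', 'generator.pth')))
    image_discriminator.load_state_dict(torch.load(os.path.join('restore', 'image_disc.pth')))
    video_discriminator.load_state_dict(torch.load(os.path.join('restore', 'video_disc.pth')))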
264 | 265 | """ 266 | def train(self, generator, image_discriminator, video_discriminator): 267 | if self.use_cuda: 268 | generator.cuda() 269 | image_discriminator.cuda() 270 | video_discriminator.cuda() 271 | 272 | logger = Logger(self.log_folder) 273 | 274 | # create optimizers 275 | opt_generator = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001) 276 | opt_image_discriminator = optim.Adam(image_discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999), 277 | weight_decay=0.00001) 278 | opt_video_discriminator = optim.Adam(video_discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999), 279 | weight_decay=0.00001) 280 | 281 | # training loop 282 | 283 | def sample_fake_image_batch(batch_size): 284 | return generator.sample_images(batch_size) 285 | 286 | def sample_fake_video_batch(batch_size): 287 | return generator.sample_videos(batch_size) 288 | 289 | def init_logs(): 290 | return {'l_gen': 0, 'l_image_dis': 0, 'l_video_dis': 0} 291 | 292 | batch_num = 0 293 | 294 | logs = init_logs() 295 | 296 | start_time = time.time() 297 | 298 | while True: 299 | generator.train() 300 | image_discriminator.train() 301 | video_discriminator.train() 302 | 303 | opt_generator.zero_grad() 304 | 305 | opt_video_discriminator.zero_grad() 306 | 307 | # train image discriminator 308 | l_image_dis = self.train_discriminator(image_discriminator, self.sample_real_image_batch, 309 | sample_fake_image_batch, opt_image_discriminator, 310 | self.image_batch_size, use_categories=False) 311 | 312 | # train video discriminator 313 | l_video_dis = self.train_discriminator(video_discriminator, self.sample_real_video_batch, 314 | sample_fake_video_batch, opt_video_discriminator, 315 | self.video_batch_size, use_categories=self.use_categories) 316 | 317 | # train generator 318 | l_gen = self.train_generator(image_discriminator, video_discriminator, 319 | sample_fake_image_batch, sample_fake_video_batch, 320 | opt_generator) 321 | 322 | logs['l_gen'] += l_gen.data[0] 323 | 324 | logs['l_image_dis'] += l_image_dis.data[0] 325 | logs['l_video_dis'] += l_video_dis.data[0] 326 | 327 | batch_num += 1 328 | 329 | if batch_num % self.log_interval == 0: 330 | 331 | log_string = "Batch %d" % batch_num 332 | for k, v in logs.items(): 333 | log_string += " [%s] %5.3f" % (k, v / self.log_interval) 334 | 335 | log_string += ". 
Took %5.2f" % (time.time() - start_time) 336 | 337 | print(log_string) 338 | 339 | for tag, value in logs.items(): 340 | logger.scalar_summary(tag, value / self.log_interval, batch_num) 341 | 342 | logs = init_logs() 343 | start_time = time.time() 344 | 345 | generator.eval() 346 | 347 | images, _ = sample_fake_image_batch(self.image_batch_size) 348 | logger.image_summary("Images", images_to_numpy(images), batch_num) 349 | 350 | videos, _ = sample_fake_video_batch(self.video_batch_size) 351 | logger.video_summary("Videos", videos_to_numpy(videos), batch_num) 352 | 353 | torch.save(generator, os.path.join(self.log_folder, 'generator_%05d.pytorch' % batch_num)) 354 | torch.save(video_discriminator.state_dict(),os.path.join('restore','video_disc.pth')) 355 | torch.save(image_discriminator.state_dict(),os.path.join('restore','image_disc.pth')) 356 | torch.save(generator.state_dict(),os.path.join('restore','generator.pth')) 357 | if batch_num >= self.train_batches: 358 | torch.save(generator, os.path.join(self.log_folder, 'generator_%05d.pytorch' % batch_num)) 359 | break 360 | 361 | -------------------------------------------------------------------------------- /videogeneration/src/util.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from torchvision import utils as vu 6 | 7 | 8 | def show_batch(batch): 9 | normed = batch * 0.5 + 0.5 10 | is_video_batch = len(normed.size()) > 4 11 | 12 | if is_video_batch: 13 | rows = [vu.make_grid(b.permute(1, 0, 2, 3), nrow=b.size(1)).numpy() for b in normed] 14 | im = np.concatenate(rows, axis=1) 15 | else: 16 | im = vu.make_grid(normed).numpy() 17 | 18 | im = im.transpose((1, 2, 0)) 19 | 20 | plt.imshow(im) 21 | plt.show(block=True) 22 | --------------------------------------------------------------------------------
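A quick way to eyeball a real batch while debugging (an illustrative sketch; video_loader is the DataLoader built in train.py, and the .cpu() call is needed because show_batch converts the tensor to NumPy):

    import util
    batch = next(iter(video_loader))['images']   # (batch, channels, frames, height, width), values in [-1, 1]
    util.show_batch(batch.cpu())

trainers.py hints at the same usage in its commented-out `util.show_batch(batch.data)` call inside train_discriminator.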