├── README.md ├── cDCGAN.py.ipynb ├── doc2vec.ipynb ├── emergent_characters.ipynb ├── main.ipynb ├── symbols_dataset.py ├── train_d2v.py ├── utils.py └── word2vec.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # dimensions_of_dialogue- 2 | -------------------------------------------------------------------------------- /cDCGAN.py.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "import os, time\n", 13 | "import numpy as np\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import itertools\n", 16 | "import pickle\n", 17 | "import imageio\n", 18 | "import torch\n", 19 | "import torch.nn as nn\n", 20 | "import torch.nn.functional as F\n", 21 | "import torch.optim as optim\n", 22 | "from torchvision import datasets, transforms\n", 23 | "from torch.autograd import Variable\n", 24 | "from tqdm import tqdm_notebook\n", 25 | "import kornia\n", 26 | "from torchvision.utils import save_image\n", 27 | "from IPython.core.display import Image, display" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# refactored from https://github.com/znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN/blob/master/pytorch_MNIST_cDCGAN.py\n", 37 | "# G(z)\n", 38 | "class generator(nn.Module):\n", 39 | " # initializers\n", 40 | " def __init__(self, d=128, n_classes=10):\n", 41 | " super(generator, self).__init__()\n", 42 | " self.deconv1_1 = nn.ConvTranspose2d(100, d*2, 4, 1, 0)\n", 43 | " self.deconv1_1_bn = nn.BatchNorm2d(d*2)\n", 44 | " self.deconv1_2 = nn.ConvTranspose2d(n_classes, d*2, 4, 1, 0)\n", 45 | " self.deconv1_2_bn = nn.BatchNorm2d(d*2)\n", 46 | " self.deconv2 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)\n", 47 | " self.deconv2_bn = nn.BatchNorm2d(d*2)\n", 48 | " self.deconv3 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)\n", 49 | " self.deconv3_bn = nn.BatchNorm2d(d)\n", 50 | " self.deconv4 = nn.ConvTranspose2d(d, 1, 4, 2, 1)\n", 51 | "\n", 52 | " # weight_init\n", 53 | " def weight_init(self, mean, std):\n", 54 | " for m in self._modules:\n", 55 | " normal_init(self._modules[m], mean, std)\n", 56 | "\n", 57 | " # forward method\n", 58 | " def forward(self, input, label):\n", 59 | " x = F.relu(self.deconv1_1_bn(self.deconv1_1(input)))\n", 60 | " y = F.relu(self.deconv1_2_bn(self.deconv1_2(label)))\n", 61 | " x = torch.cat([x, y], 1)\n", 62 | " x = F.relu(self.deconv2_bn(self.deconv2(x)))\n", 63 | " x = F.relu(self.deconv3_bn(self.deconv3(x)))\n", 64 | " x = torch.tanh(self.deconv4(x))\n", 65 | " return x\n", 66 | "\n", 67 | "class discriminator(nn.Module):\n", 68 | " # initializers\n", 69 | " def __init__(self, d=128, n_classes=10):\n", 70 | " super(discriminator, self).__init__()\n", 71 | " self.conv1_1 = nn.Conv2d(1, d//2, 4, 2, 1)\n", 72 | " self.conv1_2 = nn.Conv2d(n_classes, d//2, 4, 2, 1)\n", 73 | " self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)\n", 74 | " self.conv2_bn = nn.BatchNorm2d(d*2)\n", 75 | " self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)\n", 76 | " self.conv3_bn = nn.BatchNorm2d(d*4)\n", 77 | "# self.conv4 = nn.Conv2d(d * 4, 101, 4, 1, 0)\n", 78 | " self.conv4 = nn.Conv2d(d * 4, 1, 4, 1, 0)\n", 79 | "\n", 80 | " # weight_init\n", 81 | " def weight_init(self, mean, std):\n", 82 | " for m in self._modules:\n", 83 | " normal_init(self._modules[m], mean, std)\n", 84 | "\n", 85 | " # forward method\n", 86 | " def forward(self, images, labels):\n", 87 | " x = F.leaky_relu(self.conv1_1(images), 0.2)\n", 88 | " y = F.leaky_relu(self.conv1_2(labels), 0.2)\n", 89 | " x = torch.cat([x, y], 1)\n", 90 | " x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)\n", 91 | " x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)\n", 92 | " x = self.conv4(x)\n", 93 | " return torch.sigmoid(x[:, :1])\n", 94 | "\n", 95 | "def normal_init(m, mean, std):\n", 96 | " if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):\n", 97 | " m.weight.data.normal_(mean, std)\n", 98 | " m.bias.data.zero_()" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "# help(nn.Conv2d)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):\n", 117 | " x = range(len(hist['D_losses']))\n", 118 | " y1 = hist['D_losses']\n", 119 | " y2 = hist['G_losses']\n", 120 | " plt.plot(x, y1, label='D_loss')\n", 121 | " plt.plot(x, y2, label='G_loss')\n", 122 | " plt.xlabel('Epoch')\n", 123 | " plt.ylabel('Loss')\n", 124 | " plt.legend(loc=4)\n", 125 | " plt.grid(True)\n", 126 | " plt.tight_layout()\n", 127 | " if save:\n", 128 | " plt.savefig(path)\n", 129 | " if show:\n", 130 | " plt.show()\n", 131 | " else:\n", 132 | " plt.close()" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 5, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# training parameters\n", 142 | "batch_size = 128\n", 143 | "lr = 0.0002\n", 144 | "train_epoch = 50\n", 145 | "\n", 146 | "# data_loader\n", 147 | "img_size = 32\n", 148 | "\n", 149 | "\n", 150 | "# MNIST loader.\n", 151 | "n_classes = 10\n", 152 | "transform = transforms.Compose([\n", 153 | " transforms.Resize(img_size),\n", 154 | " transforms.Grayscale(num_output_channels=1),\n", 155 | " transforms.ToTensor(),\n", 156 | " transforms.Normalize([0.5], [0.5])\n", 157 | "])\n", 158 | "dataloader = torch.utils.data.DataLoader(\n", 159 | " datasets.MNIST('data', train=True, download=True, transform=transform),\n", 160 | " batch_size=batch_size, shuffle=True)\n", 161 | "\n", 162 | "# from symbols_dataset import make_loader\n", 163 | "# n_classes = 22\n", 164 | "# dataloader, dataset = make_loader(batch_size, img_size)\n", 165 | "\n", 166 | "\n", 167 | "# results save folder\n", 168 | "root = 'MNIST_cDCGAN_results/'\n", 169 | "model = 'MNIST_cDCGAN_'\n", 170 | "if not os.path.isdir(root):\n", 171 | " os.mkdir(root)\n", 172 | "if not os.path.isdir(root + 'Fixed_results'):\n", 173 | " os.mkdir(root + 'Fixed_results')\n", 174 | "\n", 175 | "train_hist = {}\n", 176 | "train_hist['D_losses'] = []\n", 177 | "train_hist['G_losses'] = []\n", 178 | "train_hist['per_epoch_ptimes'] = []\n", 179 | "train_hist['total_ptime'] = []\n", 180 | "\n", 181 | "onehot = torch.zeros(n_classes, n_classes)\n", 182 | "onehot = onehot.scatter_(1, torch.LongTensor(np.arange(n_classes)).view(n_classes, 1), 1).view(n_classes, n_classes, 1, 1)\n", 183 | "fill = torch.zeros([n_classes, n_classes, img_size, img_size])\n", 184 | "for i in range(n_classes):\n", 185 | " fill[i, i, :, :] = 1\n" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 6, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# fixed noise & label\n", 195 | "num_fixed = 100\n", 196 | "\n", 197 | "fixed_z = torch.randn((num_fixed, 100, 1, 1)).cuda()\n", 198 | "\n", 199 | "fixed_y_label = torch.zeros(num_fixed, n_classes, 1, 1).cuda()\n", 200 | "for i in range(num_fixed):\n", 201 | " label = (i // 4) % n_classes\n", 202 | " fixed_y_label[i, label, :, :] = 1\n", 203 | " \n", 204 | "def show_result(num_epoch, show = True, save = False, path = 'result.png'):\n", 205 | " G.eval()\n", 206 | " test_images = G(fixed_z, fixed_y_label)\n", 207 | " test_images = (test_images.cpu() + 1) * 0.5\n", 208 | " G.train()\n", 209 | " test_images = test_images[:100].data.view(100, 1, 32, 32)\n", 210 | " save_image(test_images, path, nrow=10, padding=1, pad_value=1, scale_each=False, normalize=False)\n", 211 | " display(Image(path))" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 7, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "import matplotlib.pyplot as plt\n", 221 | "import matplotlib as mpl\n", 222 | "\n", 223 | "def show_dataset(dataset, n=6):\n", 224 | " img = np.hstack([ np.asarray(dataset[i][0][0]) for i in range(n) ])\n", 225 | " plt.figure(figsize = (10,2))\n", 226 | " plt.imshow(img)\n", 227 | " plt.axis('off')" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# show_dataset(dataset, 12)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": { 243 | "scrolled": false 244 | }, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "training start!\n" 251 | ] 252 | }, 253 | { 254 | "data": { 255 | "application/vnd.jupyter.widget-view+json": { 256 | "model_id": "54a826eaa0a64f83be407798619e744f", 257 | "version_major": 2, 258 | "version_minor": 0 259 | }, 260 | "text/plain": [ 261 | "HBox(children=(IntProgress(value=0, max=469), HTML(value='')))" 262 | ] 263 | }, 264 | "metadata": {}, 265 | "output_type": "display_data" 266 | } 267 | ], 268 | "source": [ 269 | "# network\n", 270 | "G = generator(128, n_classes=n_classes)\n", 271 | "D = discriminator(128, n_classes=n_classes)\n", 272 | "G.weight_init(mean=0.0, std=0.02)\n", 273 | "D.weight_init(mean=0.0, std=0.02)\n", 274 | "G.cuda()\n", 275 | "D.cuda()\n", 276 | "\n", 277 | "# Binary Cross Entropy loss\n", 278 | "BCE_loss = nn.BCELoss()\n", 279 | "\n", 280 | "# Adam optimizer\n", 281 | "G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))\n", 282 | "D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))\n", 283 | "\n", 284 | "print('training start!')\n", 285 | "start_time = time.time()\n", 286 | "for epoch in range(train_epoch):\n", 287 | " D_losses = []\n", 288 | " G_losses = []\n", 289 | " \n", 290 | " if (epoch+1) in [10, 16]:\n", 291 | " G_optimizer.param_groups[0]['lr'] /= 10\n", 292 | " D_optimizer.param_groups[0]['lr'] /= 10\n", 293 | " print(\"learning rate change!\")\n", 294 | "\n", 295 | " epoch_start_time = time.time()\n", 296 | " y_real = torch.ones(batch_size).cuda()\n", 297 | " y_fake = torch.zeros(batch_size).cuda()\n", 298 | " \n", 299 | " for i, (images, y_) in enumerate(tqdm_notebook(dataloader)):\n", 300 | " images = images.cuda()\n", 301 | " \n", 302 | " #--------------------------------------------------\n", 303 | " # train discriminator\n", 304 | " #--------------------------------------------------\n", 305 | " D.zero_grad()\n", 306 | " mini_batch = images.size()[0]\n", 307 | "\n", 308 | " if mini_batch != batch_size:\n", 309 | " y_real = torch.ones(mini_batch).cuda()\n", 310 | " y_fake = torch.zeros(mini_batch).cuda()\n", 311 | " \n", 312 | " y_fill = fill[y_].cuda()\n", 313 | "\n", 314 | " # Train desciminator on real images.\n", 315 | " D_result = D(images, y_fill).squeeze()\n", 316 | " D_real_loss = BCE_loss(D_result, y_real)\n", 317 | "\n", 318 | " # Train desciminator on fake, generated images.\n", 319 | " z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).cuda()\n", 320 | " y_ = (torch.rand(mini_batch, 1) * n_classes).long().squeeze()\n", 321 | " y_label_ = onehot[y_].cuda()\n", 322 | " y_fill_ = fill[y_].cuda()\n", 323 | "\n", 324 | " G_result = G(z_, y_label_)\n", 325 | " D_result = D(G_result, y_fill_).squeeze()\n", 326 | " D_fake_loss = BCE_loss(D_result, y_fake)\n", 327 | "\n", 328 | " # Total loss is classifying fake + classifying real.\n", 329 | " D_train_loss = D_real_loss + D_fake_loss\n", 330 | "\n", 331 | " D_train_loss.backward()\n", 332 | " D_optimizer.step()\n", 333 | " D_losses.append(D_train_loss)\n", 334 | "\n", 335 | " #--------------------------------------------------\n", 336 | " # train generator\n", 337 | " #--------------------------------------------------\n", 338 | " G.zero_grad()\n", 339 | "\n", 340 | " z = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1).cuda()\n", 341 | " y = (torch.rand(mini_batch, 1) * n_classes).long().squeeze()\n", 342 | "\n", 343 | " G_result = G(z, onehot[y].cuda())\n", 344 | " D_result = D(G_result, fill[y].cuda()).squeeze()\n", 345 | "\n", 346 | " G_train_loss = BCE_loss(D_result, y_real)\n", 347 | "\n", 348 | " G_train_loss.backward()\n", 349 | " G_optimizer.step()\n", 350 | " G_losses.append(G_train_loss)\n", 351 | " \n", 352 | "\n", 353 | " fixed_p = root + 'Fixed_results/' + model + str(epoch + 1) + '.png'\n", 354 | " show_result((epoch+1), show=True, path=fixed_p)\n", 355 | "\n", 356 | " epoch_end_time = time.time()\n", 357 | " per_epoch_ptime = epoch_end_time - epoch_start_time\n", 358 | "\n", 359 | " print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % \\\n", 360 | " ((epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_losses)),\n", 361 | " torch.mean(torch.FloatTensor(G_losses))))\n", 362 | " \n", 363 | " train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))\n", 364 | " train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))\n", 365 | " train_hist['per_epoch_ptimes'].append(per_epoch_ptime)" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": { 372 | "scrolled": false 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "show_train_hist(train_hist, show=True, save=True, path=root + 'train_hist.png')" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "Python 3", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 | "version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.6.6" 404 | } 405 | }, 406 | "nbformat": 4, 407 | "nbformat_minor": 2 408 | } 409 | -------------------------------------------------------------------------------- /symbols_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets 6 | import PIL 7 | from imgaug import augmenters as iaa 8 | 9 | # transforms = iaa.SomeOf(3, [ 10 | # iaa.AdditiveGaussianNoise(scale=0.2*255), 11 | # iaa.GaussianBlur((0.0, 2.0)), 12 | # iaa.Affine( 13 | # scale={"x": (0.9, 1.1), "y": (0.9, 1.1)}, 14 | # translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, 15 | # rotate=(-5, 5), 16 | # shear=(-2, 2), 17 | # cval=(255, 255), 18 | # ), 19 | # iaa.SaltAndPepper(.05) 20 | # ]) 21 | 22 | class ImgAugTransform: 23 | def __init__(self, img_size=32): 24 | t = .05 25 | self.transforms = iaa.Sequential([ 26 | iaa.Resize(img_size), 27 | iaa.SomeOf(2, [ 28 | # iaa.Affine( 29 | # scale={"x": (0.9, 1.1), "y": (0.9, 1.1)}, 30 | # translate_percent={"x": (-t, t), "y": (-t, t)}, 31 | # rotate=(-8, 8), 32 | # shear=(-4, 4), 33 | # cval=(255, 255), 34 | # ), 35 | # iaa.AdditiveGaussianNoise(scale=0.1*255), 36 | # iaa.GaussianBlur((0.0, 2.0)), 37 | iaa.SaltAndPepper(.1) 38 | ]) 39 | ]) 40 | def __call__(self, img): 41 | img = np.array(img) 42 | return self.transforms.augment_image(img) 43 | 44 | def make_loader(batch_size, img_size=32): 45 | 46 | dataset = datasets.ImageFolder('./dataset', transform=transforms.Compose([ 47 | ImgAugTransform(img_size), 48 | lambda x: PIL.Image.fromarray(x), 49 | transforms.Grayscale(num_output_channels=1), 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.5], [0.5]), 52 | ])) 53 | 54 | return DataLoader( 55 | dataset, 56 | batch_size=batch_size, 57 | shuffle=True, 58 | ), dataset 59 | -------------------------------------------------------------------------------- /train_d2v.py: -------------------------------------------------------------------------------- 1 | # https://medium.com/@mishra.thedeepak/doc2vec-simple-implementation-example-df2afbbfbad5 2 | from gensim.models.doc2vec import Doc2Vec, TaggedDocument 3 | from nltk.tokenize import word_tokenize 4 | from nltk.tokenize import RegexpTokenizer 5 | import nltk 6 | import numpy as np 7 | 8 | nltk.download('punkt') 9 | tokenizer = RegexpTokenizer(r'\w+') 10 | 11 | # https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html 12 | movie_lines_path = "./cornell movie-dialogs corpus/movie_lines.txt" 13 | corpus = [] 14 | 15 | with open(movie_lines_path, 'r', encoding="utf8", errors='ignore') as movie_lines_file: 16 | for line in movie_lines_file.readlines(): 17 | line = line.strip() 18 | split = line.split('+++$+++') 19 | dialogue = split[-1] 20 | corpus.append(dialogue) 21 | 22 | tagged_data = [ TaggedDocument(words=tokenizer.tokenize(_d.lower()), tags=[ str(i) ]) \ 23 | for i, _d in enumerate(corpus) ] 24 | 25 | max_epochs = 100 26 | vec_size = 25 27 | alpha = 0.025 28 | 29 | model = Doc2Vec(size=vec_size, 30 | alpha=alpha, 31 | min_alpha=0.00025, 32 | min_count=1, 33 | dm =1) 34 | 35 | model.build_vocab(tagged_data) 36 | 37 | for epoch in range(max_epochs): 38 | print('iteration {0}'.format(epoch)) 39 | model.train(tagged_data, 40 | total_examples=model.corpus_count, 41 | epochs=model.iter) 42 | # decrease the learning rate 43 | model.alpha -= 0.0002 44 | # fix the learning rate, no decay 45 | model.min_alpha = model.alpha 46 | 47 | model.save("d2v_%i.model"%vec_size) 48 | print("Model Saved") 49 | 50 | dataset = np.zeros((len(tagged_data), vec_size), dtype='float32') 51 | 52 | for i, tagged_words in enumerate(tagged_data): 53 | vec = model.infer_vector(tagged_words.words) 54 | dataset[i] = vec 55 | 56 | np.save('doc_to_vec_%i.npy'%vec_size, dataset) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from torchvision.utils import save_image 3 | from IPython.core.display import Image, display 4 | 5 | def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'): 6 | x = range(len(hist['D_losses'])) 7 | y1 = hist['D_losses'] 8 | y2 = hist['G_losses'] 9 | plt.plot(x, y1, label='D_loss') 10 | plt.plot(x, y2, label='G_loss') 11 | plt.xlabel('Epoch') 12 | plt.ylabel('Loss') 13 | plt.legend() #loc=1 14 | plt.grid(True) 15 | plt.tight_layout() 16 | if save: 17 | plt.savefig(path) 18 | if show: 19 | plt.show() 20 | else: 21 | plt.close() 22 | 23 | def make_images(G, z, x, c): 24 | test_images = G(z, x) 25 | test_images = (test_images.cpu() + 1) * 0.5 26 | test_images = test_images.data.view(test_images.shape[0], c, 32, 32) 27 | return test_images 28 | 29 | def show_images(images, path='result.png', rows=8): 30 | save_image(images, path, nrow=rows, padding=1, pad_value=1.0, scale_each=False, normalize=False) 31 | display(Image(path)) 32 | 33 | 34 | --------------------------------------------------------------------------------