├── README.md ├── Semi-supervised ├── semi-supervised_learning_2.ipynb └── semi-supervised_learning_2_so.ipynb ├── autoencoder ├── .gitignore ├── .ipynb_checkpoints │ ├── Convolutional_Autoencoder-checkpoint.ipynb │ ├── Convolutional_Autoencoder_Solution-checkpoint.ipynb │ ├── Simple_Autoencoder-checkpoint.ipynb │ └── Simple_Autoencoder_Solution-checkpoint.ipynb ├── Convolutional_Autoencoder.ipynb ├── Convolutional_Autoencoder_So.ipynb ├── Simple_Autoencoder.ipynb ├── Simple_Autoencoder_So.ipynb └── assets │ ├── autoencoder_1.png │ ├── compressed.png │ ├── convolutional_autoencoder.png │ ├── denoising.png │ ├── mnist_examples.png │ └── simple_autoencoder.png ├── dcgan-svhn ├── .ipynb_checkpoints │ ├── DCGAN-checkpoint.ipynb │ └── DCGAN_Exercises-checkpoint.ipynb ├── 1511.06434.pdf ├── DCGAN.ipynb ├── DCGAN_Exercises.ipynb ├── assets │ ├── 32x32eg.png │ ├── SVHN_examples.png │ ├── dcgan.png │ └── svhn_gan.png ├── checkpoints │ └── .gitignore └── data │ └── .gitignore ├── gan_mnist ├── .gitignore ├── .ipynb_checkpoints │ ├── Intro_to_GANs_Exercises-checkpoint.ipynb │ └── Intro_to_GANs_Solution-checkpoint.ipynb ├── Intro_to_GANs_Exercises.ipynb ├── Intro_to_GANs_So.ipynb ├── assets │ ├── gan_diagram.png │ └── gan_network.png └── checkpoints │ └── .gitignore ├── seq2seq-twitter-chatbot ├── data │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── cornell_corpus │ │ └── data.py │ └── twitter │ │ ├── __pycache__ │ │ └── data.cpython-36.pyc │ │ ├── data.py │ │ ├── idx_a.npy │ │ ├── idx_q.npy │ │ ├── metadata.pkl │ │ ├── pull │ │ └── pull_raw_data └── main.py └── seq2seq ├── .ipynb_checkpoints └── sequence_to_sequence_implementation-checkpoint.ipynb ├── __pycache__ └── helper.cpython-36.pyc ├── data ├── english1 ├── english_new ├── french1 ├── french_new ├── letters_source.txt ├── letters_target.txt └── sequence_to_sequence_implementation_modify.ipynb ├── helper.py ├── images ├── sequence-to-sequence-inference-decoder.png ├── sequence-to-sequence-training-decoder.png └── sequence-to-sequence.jpg ├── sequence_to_sequence_implementation.ipynb └── sequence_to_sequence_implementation_modify.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Seq2Seq-Gan 2 | **Jianguo Zhang, June 20, 2018** 3 | 4 | Related implementations for **sequence to sequence**, **generative adversarial networks(GAN)** and **Autoencoder** 5 | 6 | ## Sequence to Sequence 7 | 8 | ![image1](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/seq2seq/images/sequence-to-sequence-inference-decoder.png) 9 | 10 | ## Generative Adversarial Networks 11 | 12 | ### gan_diagram 13 | 14 |
15 | 16 | ![image2](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/gan_mnist/assets/gan_diagram.png) 17 |
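The diagram above shows the basic GAN setup: a generator maps noise vectors to fake samples, while a discriminator learns to tell real samples from generated ones. As a rough illustration of how the two losses are usually wired up, here is a generic TensorFlow 1.x sketch (not code taken from the notebooks; `d_logits_real` and `d_logits_fake` are assumed to be the discriminator's outputs on a real batch and a generated batch):

```python
import tensorflow as tf

def gan_losses(d_logits_real, d_logits_fake):
    """Build the usual GAN losses from discriminator logits on real and fake batches."""
    # Discriminator: real samples should be classified as 1, generated samples as 0.
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
    d_loss = d_loss_real + d_loss_fake
    # Generator: tries to make the discriminator label its samples as real (1).
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))
    return d_loss, g_loss
```

Typically the two losses are then minimized by two separate optimizers, each updating only its own network's variables.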
18 | 19 | ### dcgan 20 | 21 |
22 | 23 | ![image6](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/dcgan-svhn/assets/dcgan.png) 24 |
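The DCGAN generator in the figure upsamples a noise vector to an image with strided transposed convolutions, batch normalization, and leaky ReLU activations, finishing with a `tanh`. Below is a minimal TensorFlow 1.x sketch of that idea; the layer sizes and names are illustrative, not necessarily the exact ones used in `DCGAN.ipynb`:

```python
import tensorflow as tf

def generator(z, output_channels=3, alpha=0.2, training=True, reuse=False):
    """DCGAN-style generator: project the noise, then upsample 4x4 -> 8x8 -> 16x16 -> 32x32."""
    with tf.variable_scope('generator', reuse=reuse):
        # Project and reshape the noise vector to a small spatial feature map.
        x = tf.layers.dense(z, 4 * 4 * 512)
        x = tf.reshape(x, (-1, 4, 4, 512))
        x = tf.layers.batch_normalization(x, training=training)
        x = tf.maximum(alpha * x, x)  # leaky ReLU

        x = tf.layers.conv2d_transpose(x, 256, 5, strides=2, padding='same')
        x = tf.layers.batch_normalization(x, training=training)
        x = tf.maximum(alpha * x, x)

        x = tf.layers.conv2d_transpose(x, 128, 5, strides=2, padding='same')
        x = tf.layers.batch_normalization(x, training=training)
        x = tf.maximum(alpha * x, x)

        # Final layer maps to the image channels; tanh keeps outputs in [-1, 1].
        logits = tf.layers.conv2d_transpose(x, output_channels, 5, strides=2, padding='same')
        return tf.tanh(logits)
```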
25 | 26 | 27 | ### svhn_gan 28 | 29 |
30 | 31 | ![image7](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/dcgan-svhn/assets/svhn_gan.png) 32 |
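The discriminator side of the SVHN GAN is roughly the mirror image: strided convolutions shrink the 32x32x3 image, leaky ReLU and batch normalization are applied between layers, and a final dense layer produces a single real/fake logit. A hedged TensorFlow 1.x sketch with illustrative sizes (again not copied from the notebooks):

```python
import tensorflow as tf

def discriminator(x, alpha=0.2, reuse=False):
    """DCGAN-style discriminator for 32x32x3 inputs: strided convs down to a single logit."""
    with tf.variable_scope('discriminator', reuse=reuse):
        h = tf.layers.conv2d(x, 64, 5, strides=2, padding='same')    # 16x16x64
        h = tf.maximum(alpha * h, h)

        h = tf.layers.conv2d(h, 128, 5, strides=2, padding='same')   # 8x8x128
        h = tf.layers.batch_normalization(h, training=True)
        h = tf.maximum(alpha * h, h)

        h = tf.layers.conv2d(h, 256, 5, strides=2, padding='same')   # 4x4x256
        h = tf.layers.batch_normalization(h, training=True)
        h = tf.maximum(alpha * h, h)

        flat = tf.reshape(h, (-1, 4 * 4 * 256))
        logits = tf.layers.dense(flat, 1)
        return tf.sigmoid(logits), logits
```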
33 | 34 | ### gan_network 35 | 36 |
37 | 38 | ![image3](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/gan_mnist/assets/gan_network.png) 39 |
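For the MNIST GAN the two networks in the diagram are simply fully connected: one hidden leaky-ReLU layer each, a `tanh` output for the generator (images rescaled to [-1, 1]) and a single logit for the discriminator. A minimal sketch along those lines, with placeholder hyperparameters rather than the exact values in `Intro_to_GANs_So.ipynb`:

```python
import tensorflow as tf

def generator(z, out_dim=784, n_units=128, alpha=0.01, reuse=False):
    """Noise vector -> hidden leaky-ReLU layer -> tanh image vector."""
    with tf.variable_scope('generator', reuse=reuse):
        h1 = tf.layers.dense(z, n_units)
        h1 = tf.maximum(alpha * h1, h1)   # leaky ReLU
        logits = tf.layers.dense(h1, out_dim)
        return tf.tanh(logits)

def discriminator(x, n_units=128, alpha=0.01, reuse=False):
    """Image vector -> hidden leaky-ReLU layer -> single real/fake logit."""
    with tf.variable_scope('discriminator', reuse=reuse):
        h1 = tf.layers.dense(x, n_units)
        h1 = tf.maximum(alpha * h1, h1)
        logits = tf.layers.dense(h1, 1)
        return tf.sigmoid(logits), logits
```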
40 | 41 | ## Autoencoder 42 | 43 | ### simple_autoencoder 44 | 45 |
46 | 47 | ![image4](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/autoencoder/assets/simple_autoencoder.png) 48 |
49 | 50 | ### convolutional_autoencoder 51 | 52 |
53 | 54 | ![image5](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/blob/master/autoencoder/assets/convolutional_autoencoder.png) 55 |
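In the convolutional autoencoder the encoder stacks convolution + max-pooling layers to shrink a 28x28x1 image down to a small spatial code, and the decoder upsamples (here with nearest-neighbor resizing) and convolves back up to the original resolution. A compact TensorFlow 1.x sketch of that graph, with illustrative filter counts rather than the exact ones from `Convolutional_Autoencoder_So.ipynb`:

```python
import tensorflow as tf

inputs_ = tf.placeholder(tf.float32, (None, 28, 28, 1), name='inputs')
targets_ = tf.placeholder(tf.float32, (None, 28, 28, 1), name='targets')

# Encoder: 28x28x1 -> 14x14x16 -> 7x7x8
conv1 = tf.layers.conv2d(inputs_, 16, (3, 3), padding='same', activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(conv1, (2, 2), (2, 2), padding='same')
conv2 = tf.layers.conv2d(pool1, 8, (3, 3), padding='same', activation=tf.nn.relu)
encoded = tf.layers.max_pooling2d(conv2, (2, 2), (2, 2), padding='same')

# Decoder: upsample back to 28x28 and map to a single output channel.
up1 = tf.image.resize_nearest_neighbor(encoded, (14, 14))
conv3 = tf.layers.conv2d(up1, 8, (3, 3), padding='same', activation=tf.nn.relu)
up2 = tf.image.resize_nearest_neighbor(conv3, (28, 28))
conv4 = tf.layers.conv2d(up2, 16, (3, 3), padding='same', activation=tf.nn.relu)
logits = tf.layers.conv2d(conv4, 1, (3, 3), padding='same', activation=None)
decoded = tf.nn.sigmoid(logits, name='decoded')

# Pixel-wise cross-entropy against the target image, as in the simple autoencoder.
cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_, logits=logits))
opt = tf.train.AdamOptimizer(0.001).minimize(cost)
```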
56 | 57 | 58 | ## Semi-supervised learning 59 | 60 | [This](https://github.com/JianguoZhang1994/Seq2Seq-Gan-Autoencoder/tree/master/Semi-supervised) is an implementation of [Improved Techniques for Training GANs](https://arxiv.org/pdf/1606.03498.pdf), which can be modified for semi-supervised learning. 61 | 62 | ## Seq2Seq Twitter-Chatbot 63 | 64 | [This](https://github.com/jianguoz/Seq2Seq-Gan-Autoencoder/tree/master/seq2seq-twitter-chatbot) is a roughly 200-line Seq2Seq model for a Twitter chatbot; the dataset is already uploaded, and the code can be run directly. 65 | 66 | Requirements: TensorFlow 1.8.0, [TensorLayer>=1.8.4](http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#tensorlayer.layers.Seq2Seq), CUDA>=9.0. 67 | 68 | After tweaking the dataset to bring the overall percentage of unknown tokens down to 3% and increasing the vocabulary to 8000, the results are promising. 69 | 70 | | Query | Reply | 71 | | ------------- | ------------- | 72 | | donald trump won last nights presidential debate according to snap online polls | thought he was a joke | 73 | | trump campaign please never delete this | he was a joke | 74 | | s new scoop is unk the castro connection how trumps company violated the us embargo against cuba | trump is a disgrace to the debate and the media is not a liar | 75 | | who won the first presidential debate | trump will be a better time for a man | 76 | | just wanna live in unk everything is 10x better there | i was just on the same side and i was like it was a good time | 77 | | the lil girl i forgot her name scares the shit out of me n probably always will lmaooo | she was so cute and she was so cute and she was a bitch | 78 | | question of the day who do you think won last nights debate weigh in using | hillary is a trump supporter and a woman who has a better job than a woman | 79 | | trump is a total embarrassment hes like a unk child who happens to be a professional liar and con man | trump is a lot of money to get a great country | 80 | | chill the fuck out | thats what i said | 81 | | so fun to see what the special effects team created for | wow that is the best thing ive ever seen | 82 | | she was fine too unk voice | she is a good guy | 83 | | the important ones are still here | i dont think so | 84 | | been drinking pumpkin spice protein shake every morning for a week and now i literally cannot even | me too i just got it | 85 | | lmao i just lost my job and i cant even leave early to be in the unk video today | literally the same thing | 86 | | hey happy birthday have a nice day | thank you | 87 | 88 | _______________________________________________________________ 89 | -------------------------------------------------------------------------------- /autoencoder/.gitignore: -------------------------------------------------------------------------------- 1 | MNIST_data 2 | -------------------------------------------------------------------------------- /autoencoder/.ipynb_checkpoints/Simple_Autoencoder_Solution-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# A Simple Autoencoder\n", 8 | "\n", 9 | "We'll start off by building a simple autoencoder to compress the MNIST dataset. With autoencoders, we pass input data through an encoder that makes a compressed representation of the input. Then, this representation is passed through a decoder to reconstruct the input data. 
Generally the encoder and decoder will be built with neural networks, then trained on example data.\n", 10 | "\n", 11 | "![Autoencoder](assets/autoencoder_1.png)\n", 12 | "\n", 13 | "In this notebook, we'll build a simple network architecture for the encoder and decoder. Let's get started by importing our libraries and getting the dataset." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "%matplotlib inline\n", 25 | "\n", 26 | "import numpy as np\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n", 41 | "Extracting MNIST_data/train-images-idx3-ubyte.gz\n", 42 | "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n", 43 | "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", 44 | "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n", 45 | "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", 46 | "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n", 47 | "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "from tensorflow.examples.tutorials.mnist import input_data\n", 53 | "mnist = input_data.read_data_sets('MNIST_data', validation_size=0)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Below I'm plotting an example image from the MNIST dataset. These are 28x28 grayscale images of handwritten digits." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "" 72 | ] 73 | }, 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | }, 78 | { 79 | "data": { 80 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADP9JREFUeJzt3V+IXPUZxvHnSfwHieCf4BJtMBGkKkFTWMR/lGibajUS\nvYiYi5JSdXvRSgsVKulFhVqQYlq8ErYkGkuNKRjJEsSgoZgWqyQRTaI2idUUs8akMWLthdQkby/m\nRLZx58xm5syc2X2/H1h25rxz5rwc9tnfOXNm5ueIEIB8ptXdAIB6EH4gKcIPJEX4gaQIP5AU4QeS\nIvxAUoQfSIrwA0md1suN2ebthECXRYQn8riORn7bt9jebftd2w928lwAesvtvrff9nRJeyQtkrRf\n0lZJyyLi7ZJ1GPmBLuvFyH+1pHcj4r2I+K+kZyQt6eD5APRQJ+G/SNIHY+7vL5b9H9tDtrfZ3tbB\ntgBUrOsv+EXEsKRhicN+oJ90MvKPSpoz5v7XimUAJoFOwr9V0qW259k+Q9LdkkaqaQtAt7V92B8R\nR23/WNImSdMlrY6ItyrrDEBXtX2pr62Ncc4PdF1P3uQDYPIi/EBShB9IivADSRF+ICnCDyRF+IGk\nCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiB\npAg/kBThB5Ii/EBShB9IivADSRF+IKm2p+iWJNv7JH0m6ZikoxExWEVTQBWWLl3atPbEE0+Urnv9\n9deX1t988822euonHYW/cGNEHK7geQD0EIf9QFKdhj8kvWR7u+2hKhoC0BudHvbfEBGjti+Q9KLt\nv0fElrEPKP4p8I8B6DMdjfwRMVr8PiTpOUlXj/OY4YgY5MVAoL+0HX7bM2yffeK2pO9I2lVVYwC6\nq5PD/gFJz9k+8TxPR8QLlXQFoOvaDn9EvCfpqgp76aolS5aU1mfNmlVaX7VqVZXtoAeuueaaprW9\ne/f2sJP+xKU+ICnCDyRF+IGkCD+QFOEHkiL8QFJVfKpvUli0aFFpff78+aV1LvX1n2nTyseuyy67\nrGltYGCgdN3i/StTGiM/kBThB5Ii/EBShB9IivADSRF+ICnCDyTliOjdxuzebewkH3/8cWl9586d\npfWFCxdW2A2qcPHFF5fW33///aa1l19+uXTdG2+8sa2e+kFETOhNCoz8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BUms/zt/rsNyafkZGRttfdtYv5ZUgEkBThB5Ii/EBShB9IivADSRF+ICnCDyTV8jq/\n7dWSFks6FBHzi2XnSVonaa6kfZLuiohPutdma2XTMUvSjBkzetQJemXmzJltr7tx48YKO5mcJjLy\nPynplpOWPShpc0RcKmlzcR/AJNIy/BGxRdKRkxYvkbSmuL1G0h0V9wWgy9o95x+IiAPF7Y8klc99\nBKDvdPze/oiIsu/msz0kaajT7QCoVrsj/0HbsyWp+H2o2QMjYjgiBiNisM1tAeiCdsM/Iml5cXu5\npA3VtAOgV1qG3/ZaSX+T9HXb+23fI+kRSYts75X07eI+gEmk5Tl/RCxrUvpWxb10ZOnSpaX1005L\n89UFU8aFF15YWr/gggvafu49e/a0ve5UwTv8gKQIP5AU4QeSIvxAUoQfSIrwA0lNmetfV111VUfr\nb9++vaJOUJWnn366tN7qY9qHDx9uWvv000/b6mkqYeQHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaSm\nzHX+Tr366qt1tzApnXPOOaX1ZcuafSJcuvfee0vXvfLKK9vq6YSHH364ae3IkZO/kzYfRn4gKcIP\nJEX4gaQIP5AU4QeSIvxAUoQfSIrr/IXzzz+/tm1fd911pfXp06eX1hcvXty0Nm/evNJ1zzzzzNL6\nzTffXFq3XVo/evRo09ru3btL1z127Fhpfdq08rFry5YtpfXsGPmBpAg/kBThB5Ii/EBShB9IivAD\nSRF+IClHRPkD7NWSFks6FBHzi2UPSbpP0r+Kh62IiOdbbswu31gHNmzYUFq//fbbS+uff/55ab2b\nn/9uNRV1K8ePH29a++KLL0rX/fDDD0vrW7duLa2/8sorpfWRkZGmtdHR0dJ1P/nkk9L6WWedVVrP\nOi17RJS/+aIwkZH/SUm3jLP8dxGxoPhpGXwA/aVl+CNiiyS+9gSYYjo557/f9g7bq22fW1lHAHqi\n3fA/LukSSQskHZC0stkDbQ/Z3mZ7W5vbAtAFbYU/Ig5GxLGIOC7p95KuLnnscEQMRsRgu00CqF5b\n4bc9e8zdOyXtqqYdAL3S8lqI7bWSFkqaZXu/pF9KWmh7gaSQtE/SD7vYI4AuaHmdv9KNdfE6fyuP\nPvpoaX3hwoW9aaQN69atK63v2LGjaW3Tpk1Vt1OZFStWlNbLvndfav0+gDq/o6FOVV7nBzAFEX4g\nKcIPJEX4gaQIP5AU4QeSSvOZxwceeKDuFnCS2267raP1N27cWFEnOTHyA0kRfiApwg8kRfiBpAg/\nkBThB5Ii/EBSaa7zY+pZu3Zt3S1Maoz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQf\nSIrwA0kRfiApwg8kRfiBpAg/kFTLz/PbniPpKUkDkkLScEQ8Zvs8SeskzZW0T9JdEVE+ZzJwCuzy\nmaYvv/zy0voLL7xQZTtTzkRG/qOSfhYRV0i6RtKPbF8h6UFJmyPiUkmbi/sAJomW4Y+IAxHxenH7\nM0nvSLpI0hJJa4qHrZF0R7eaBFC9Uzrntz1X0jckvSZpICIOFKWP1DgtADBJTPg7/GzPlPSspJ9G\nxL/Hno9FRNiOJusNSRrqtFEA1ZrQyG/7dDWC/8eIWF8sPmh7dlGfLenQeOtGxHBEDEbEYBUNA6hG\ny/C7McSvkvRORPx2TGlE0vLi9nJJG6pvD0C3TOSw/3pJ35O00/YbxbIVkh6R9Cfb90j6p6S7utMi\nsooY90zyS9Om8TaVTrQMf0T8VVKzC67fqrYdAL3Cv04gKcIPJEX4gaQIP5AU4QeSIvxAUkzRjUnr\npptuKq2vXLmyR51MToz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU1/nRt1p9dTc6w8gPJEX4gaQI\nP5AU4QeSIvxAUoQfSIrwA0lxnR+1Wb9+fWn92muv7VEnOTHyA0kRfiApwg8kRfiBpAg/kBThB5Ii\n/EBSbjUHuu05kp6SNCApJA1HxGO2H5J0n6R/FQ9dERHPt3iu8o0B6FhETOiLECYS/tmSZkfE67bP\nlrRd0h2S7pL0n4h4dKJNEX6g+yYa/pbv8IuIA5IOFLc/s/2OpIs6aw9A3U7pnN/2XEnfkPRaseh+\n2ztsr7Z9bpN1hmxvs72to04BVKrlYf+XD7RnSnpZ0q8jYr3tAUmH1Xgd4FdqnBr8oMVzcNgPdFll\n5/ySZPt0SRslbYqI345TnytpY0TMb/E8hB/osomGv+Vh
vxtfobpK0jtjg1+8EHjCnZJ2nWqTAOoz\nkVf7b5D0F0k7JR0vFq+QtEzSAjUO+/dJ+mHx4mDZczHyA11W6WF/VQg/0H2VHfYDmJoIP5AU4QeS\nIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSfV6iu7Dkv455v6sYlk/6tfe\n+rUvid7aVWVvF0/0gT39PP9XNm5vi4jB2hoo0a+99WtfEr21q67eOOwHkiL8QFJ1h3+45u2X6dfe\n+rUvid7aVUtvtZ7zA6hP3SM/gJrUEn7bt9jebftd2w/W0UMztvfZ3mn7jbqnGCumQTtke9eYZefZ\nftH23uL3uNOk1dTbQ7ZHi333hu1ba+ptju0/237b9lu2f1Isr3XflfRVy37r+WG/7emS9khaJGm/\npK2SlkXE2z1tpAnb+yQNRkTt14Rtf1PSfyQ9dWI2JNu/kXQkIh4p/nGeGxE/75PeHtIpztzcpd6a\nzSz9fdW476qc8boKdYz8V0t6NyLei4j/SnpG0pIa+uh7EbFF0pGTFi+RtKa4vUaNP56ea9JbX4iI\nAxHxenH7M0knZpaudd+V9FWLOsJ/kaQPxtzfr/6a8jskvWR7u+2hupsZx8CYmZE+kjRQZzPjaDlz\ncy+dNLN03+y7dma8rhov+H3VDRGxQNJ3Jf2oOLztS9E4Z+unyzWPS7pEjWncDkhaWWczxczSz0r6\naUT8e2ytzn03Tl+17Lc6wj8qac6Y+18rlvWFiBgtfh+S9Jwapyn95OCJSVKL34dq7udLEXEwIo5F\nxHFJv1eN+66YWfpZSX+MiPXF4tr33Xh91bXf6gj/VkmX2p5n+wxJd0saqaGPr7A9o3ghRrZnSPqO\n+m/24RFJy4vbyyVtqLGX/9MvMzc3m1laNe+7vpvxOiJ6/iPpVjVe8f+HpF/U0UOTvi6R9Gbx81bd\nvUlaq8Zh4BdqvDZyj6TzJW2WtFfSS5LO66Pe/qDGbM471Aja7Jp6u0GNQ/odkt4ofm6te9+V9FXL\nfuMdfkBSvOAHJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCp/wE+Awqah6Q+0AAAAABJRU5ErkJg\ngg==\n", 81 | "text/plain": [ 82 | "" 83 | ] 84 | }, 85 | "metadata": {}, 86 | "output_type": "display_data" 87 | } 88 | ], 89 | "source": [ 90 | "img = mnist.train.images[2]\n", 91 | "plt.imshow(img.reshape((28, 28)), cmap='Greys_r')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "We'll train an autoencoder with these images by flattening them into 784 length vectors. The images from this dataset are already normalized such that the values are between 0 and 1. Let's start by building basically the simplest autoencoder with a **single ReLU hidden layer**. This layer will be used as the compressed representation. Then, the encoder is the input layer and the hidden layer. The decoder is the hidden layer and the output layer. Since the images are normalized between 0 and 1, we need to use a **sigmoid activation on the output layer** to get values matching the input.\n", 99 | "\n", 100 | "![Autoencoder architecture](assets/simple_autoencoder.png)\n", 101 | "\n", 102 | "\n", 103 | "> **Exercise:** Build the graph for the autoencoder in the cell below. The input images will be flattened into 784 length vectors. The targets are the same as the inputs. And there should be one hidden layer with a ReLU activation and an output layer with a sigmoid activation. The loss should be calculated with the cross-entropy loss, there is a convenient TensorFlow function for this `tf.nn.sigmoid_cross_entropy_with_logits` ([documentation](https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits)). You should note that `tf.nn.sigmoid_cross_entropy_with_logits` takes the logits, but to get the reconstructed images you'll need to pass the logits through the sigmoid function." 
104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "# Size of the encoding layer (the hidden layer)\n", 115 | "encoding_dim = 32\n", 116 | "\n", 117 | "image_size = mnist.train.images.shape[1]\n", 118 | "\n", 119 | "inputs_ = tf.placeholder(tf.float32, (None, image_size), name='inputs')\n", 120 | "targets_ = tf.placeholder(tf.float32, (None, image_size), name='targets')\n", 121 | "\n", 122 | "# Output of hidden layer\n", 123 | "encoded = tf.layers.dense(inputs_, encoding_dim, activation=tf.nn.relu)\n", 124 | "\n", 125 | "# Output layer logits\n", 126 | "logits = tf.layers.dense(encoded, image_size, activation=None)\n", 127 | "# Sigmoid output from the logits\n", 128 | "decoded = tf.nn.sigmoid(logits, name='output')\n", 129 | "\n", 130 | "loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_, logits=logits)\n", 131 | "cost = tf.reduce_mean(loss)\n", 132 | "opt = tf.train.AdamOptimizer(0.001).minimize(cost)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Training" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": { 146 | "collapsed": true 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "# Create the session\n", 151 | "sess = tf.Session()" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "Here I'll write a bit of code to train the network. I'm not too interested in validation here, so I'll just monitor the training loss and the test loss afterwards. \n", 159 | "\n", 160 | "Calling `mnist.train.next_batch(batch_size)` will return a tuple of `(images, labels)`. We're not concerned with the labels here; we just need the images. Otherwise this is pretty straightforward training with TensorFlow. We initialize the variables with `sess.run(tf.global_variables_initializer())`. Then, run the optimizer and get the loss with `batch_cost, _ = sess.run([cost, opt], feed_dict=feed)`." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "epochs = 20\n", 170 | "batch_size = 200\n", 171 | "sess.run(tf.global_variables_initializer())\n", 172 | "for e in range(epochs):\n", 173 | " for ii in range(mnist.train.num_examples//batch_size):\n", 174 | " batch = mnist.train.next_batch(batch_size)\n", 175 | " feed = {inputs_: batch[0], targets_: batch[0]}\n", 176 | " batch_cost, _ = sess.run([cost, opt], feed_dict=feed)\n", 177 | "\n", 178 | " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", 179 | " \"Training loss: {:.4f}\".format(batch_cost))" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "## Checking out the results\n", 187 | "\n", 188 | "Below I've plotted some of the test images along with their reconstructions. For the most part these look pretty good except for some blurriness in some parts." 
189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABawAAAEsCAYAAAAvofT2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3WeYVFW28PHdIDlDg2RawAAKioAgEgRRQYFRGZSrhBFR\nRxQDimEGBAQj6mC6omMChZlRdBAT4yCggAkRBQEltETJsYEmiP1+uHee9+61FtThVOjT1f/ft7Vc\np2rTZ9c5p7b17JWRl5fnAAAAAAAAAADIb0XyewAAAAAAAAAAADjHgjUAAAAAAAAAICJYsAYAAAAA\nAAAARAIL1gAAAAAAAACASGDBGgAAAAAAAAAQCSxYAwAAAAAAAAAigQVrAAAAAAAAAEAksGANAAAA\nAAAAAIgEFqwBAAAAAAAAAJFwwvEUZ2Zm5mVlZSVpKCjoFixYsC0vL6/q0f478wdHw9xBPJg/iAfz\nB/Fg/iAezB/Eg/mDeDB/EA/mD+IRa/78x3EtWGdlZblvvvkm/KiQ1jIyMtYc678zf3A0zB3Eg/mD\neDB/EA/mD+LB/EE8mD+IB/MH8WD+IB6x5s9/sCUIAAAAAAAAACASjusX1v9XRkZGIseBAiovLy/U\nccwfOMf8QXyYP4hHmPnD3IFzXHsQH+YP4sH8QTyYP4gH8wfxCDN/+IU1AAAAAAAAACASWLAGAAAA\nAAAAAEQCC9YAAAAAAAAAgEhgwRoAAAAAAAAAEAksWAMAAAAAAAAAIoEFawAAAAAAAABAJLBgDQAA\nAAAAAACIBBasAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAAAAAAAEQCC9YAAAAAAAAAgEhgwRoA\nAAAAAAAAEAksWAMAAAAAAAAAIuGE/B4AUJA8+uijKle6dGmVa9GihRe3bt060OtPmzbNi2fNmqVq\nxo0bF+i1AAAAAAAAgIKGX1gDAAAAAAAAACKBBWsAAAAAAAAAQCSwYA0AAAAAAAAAiAQWrAEAAAAA\nAAAAkUDTReAY5s2b58XnnntuqNfJy8sLVNe9e3cvPu+881SNbMzonHPZ2dmhxoX01qRJE5X7/vvv\nVW706NFePGLEiKSNCclXtmxZL540aZKqkdca55xbu3atF19wwQWqZtWqVXGODgAAACgcqlSponKn\nnnrqcb/Ojz/+qHIPPvigysnveosWLVI1n3/++XG/P5Af+IU1AAAAAAAAACASWLAGAAAAAAAAAEQC\nC9YAAAAAAAAAgEhgD2vgf8n9qp0Lv2f1li1bvHjWrFmqpmHDhirXvHlzL65cubKqGTx4sMrdcccd\nxztEFALt2rVTOWs/9XXr1qViOEiRrKwsL+7WrZuqseZB3bp1vbhPnz6qZtSoUfENDvmiffv2Kmf1\nQ6hYsWIqhnNUvXv3VrmvvvrKi3/++edUDQf5pH///ir32muvqdzIkSO9eMyYMarmyJEjiRoWAqpR\no4YXz549W9XMnTtX5R555BEvXrFiRULHlQiVKlVSuR49eqjc5MmTvfjw4cNJGxOA/NO3b18vtp5j\nzjnnHJWz9rWOZdu2bSpnPbedcELsJb4iRfjdKgoGZioAAAAAAAAAIBJYsAYAAAAAAAAARAIL1gAA\nAAAAAACASGDBGgAAAAAAAAAQCTRdRKHUsWNHlWvVqlXM4zZt2qRyHTp0iFmXk5OjaooXL65yq1at\n8uJatWqpmmrVqsUcJ+Cccy1btlQ5q/HPSy+9lIrhIAmqV6+ucu+++24+jARRdtlll6lc0aJF82Ek\nx3bVVVep3C233OLFbdu2TdVwkCLyueaZZ54JdJxsujh27FhVs3///tDjQmxW47CVK1d6cYkSJVSN\n1TysIDRZlP8255wrU6aMyi1YsMCLf/jhh8QOrJCzGs3JxqyNGzdWNaeffrrK0RATzjnXqFEjL77/\n/vtVzRVXXKFyssFhRkZGYgf2f2RmZibttYGo4hfWAAAAAAAAAIBIYMEaAAAAAAAAABAJLFgDAAAA\nAAAAACKhwOxhfcMNN3jx4MGDVc3mzZtVTu5d9+KLL6qa7OxslVu6dOnxDhEFSN26dVXO2nNK7kVt\n7XO9bt26UGN49NFHVc7aj1Z65513Qr0f0p+cn1dffbWqmT59eqqGgwR74IEHVK5Xr14ql5WVlZD3\nu+iii1SuSBH9/7m//fZbL2YP7fwn91Ts3r17Po3k+MydO1fl7rzzTi8uW7asqtm7d2/SxoTkk/Oz\nXLlygY6bM2eOF+fm5iZsTNBOPPFElZs9e7bKlSpVyov/+c9/qpqePXsmbFzJJPdTl3taO+fcfffd\np3LsWZ04t956q8pZz0Ply5eP+VrW+duyZUu4gSGtnHrqqV5s9dRINTk3rTUrRJO1h36dOnVUTn5X\nt3qj/fbbbyr37LPPevHHH3+satLlPsQvrAEAAAAAAAAAkcCCNQAAAAAAAAAgEliwBgAAAAAAAABE\nAgvWAAAAAAAAAIBIKDBNF2WDugoVKqia008/PebrdOvWTeUOHTqkchs2bDiO0aWGbCr55z//WdXM\nmjUrVcMp0CZMmKByVrOn3bt3e/G2bdsSNoYrr7xS5YoWLZqw10fhc+aZZ3pxsWLFVM2rr76aquEg\nwYYNG6ZyeXl5SXu/1q1bB8rt2rXLi61mWlZjLiSPPAf169dXNa+99lqKRhNcZmamyslGbzRdLNhK\nliypciNGjAj1Wi+88IIXJ/N6COc6duyocrJRmeXmm29OxnASrkWLFionG2J9/fXXqmb8+PFJG1Nh\nJBtHP/zww6pGNvYMasqUKSp3xRVXeHEiv+shuaxGsGPGjPFia21k8uTJKnfgwAEvPnjwoKqx1oyK\nFy/uxQsWLFA1sjm5c87NmzfPi63n5H379nkxzzrR0KpVK5WT39E6deqkasJetyyPP/64F1uNGbdu\n3erF8+fPVzW///3vVc6a5/mJX1gDAAAAAAAAACKBBWsAAAAAAAAAQCSwYA0AAAAAAAAAiAQWrAEA\nAAAAAAAAkVBgmi7ecMMNXnz22WermsWLF6tckyZNvPjcc89VNc2aNVO5k046yYv37NmjasqXL28P\nNgZrU/T9+/d7sdVUSI5p4MCBqoami+GtWrUqaa/92GOPqVy1atViHvfzzz+r3PTp0xMyJqSfP/3p\nT14sm4Y659yMGTNSNRzE6bvvvvPijIyMpL5fbm6uF1tNN6yGx5UqVfLimTNnqpoiRfj/48liNX+R\nzVV37Niham6//fakjSks2fwK6adNmzYqV6dOnZjHWc/OkyZNSsiYYKtRo4YX9+3bN9BxQ4cO9eJN\nmzYlbEyJJJssBvkO9be//U3lrGcthCe/MyWy
UVnbtm1Vbt26dV781FNPqZr7779f5aLWmCzdWWsj\n33zzjcrVqlXLi2Vzw6OR36+bNm2qalasWKFysqn16tWrVY11/0I0yebyw4cPVzVWQ8USJUrEfO2c\nnByV+/777714+fLlqubaa69VubVr13pxvXr1VE2ZMmW8uH379qrm7rvvVjnZuDS/8Q0SAAAAAAAA\nABAJLFgDAAAAAAAAACKBBWsAAAAAAAAAQCQUmD2s33rrrWPG8ahSpYrKdezY0YutfV8vvPDCUO8n\n96t2zrkFCxZ4cXZ2tqopWbKkF//000+h3h/J169fPy++4447VE3RokVVbt++fV585513xqxB4XTy\nySerXN26db1427Ztqmbv3r1JGxPCu+yyy1ROns+8vDxVY+WCmDp1qspNmzbNi3ft2qVqLr74YpW7\n8cYbY76f3ANu9OjRMY9BME888YTKFStWzIuvuuoqVWPtpZdqmZmZXnzKKaeomrBzHNEUdB9kadGi\nRQkeCWKR+zV36NBB1cj9f51z7oUXXkjamBKpS5cuXiz3+3TOuU8++cSLrf2NEV6DBg1UrkePHjGP\n27hxo8rJXg2nn356oDHIvWdvvvlmVfPMM8+o3IYNGwK9PsIpXry4F8+ePVvVyP2qnXPu5Zdf9uKw\na0bWftUWa80GBcMHH3ygcueff74XB91Df9myZV5sPbMMGDBA5WT/IIu1937v3r29+O2331Y1sj+I\ntYb0wAMPqNxLL73kxfndh4JfWAMAAAAAAAAAIoEFawAAAAAAAABAJLBgDQAAAAAAAACIBBasAQAA\nAAAAAACRUGCaLibT9u3bVW7KlCkxj0tk48frr7/ei2WDRed0g4n//u//Ttj7I7Fat27txVaDRctH\nH33kxVZjNMA557p37x6zZvfu3SkYCY6X1TDz9ddfV7nSpUuHen3ZLPH9999XNYMGDVK5IA1df/jh\nB5WTTdSscQ8bNsyLrSYmI0aMULnDhw/HHFNhcsMNN6hcixYtVE42XJ05c2bSxhSPp59+2outBouy\nwbT1zIaCo3379jFrjhw5onK33HJLMoaDY5CfR+vzuXXrVpU7ePBg0sYUhHUPGjdunMr16dMn5mtd\neOGFCRkTbNb1QDbbW7lypaqxGvTK5wrrmnHvvfeqXKVKlby4bNmyqmbevHkqJ++9VqNzBFOuXDmV\n+8tf/uLFZ599tqrZv3+/yt19991eHOTZFulHXg/Gjh2rarp27Rrzdaw5NnHiRJWT827v3r0xXzuo\n8uXLq9wJJ/jLuH/+859VzeTJk724QoUKCRtTKvELawAAAAAAAABAJLBgDQAAAAAAAACIBBasAQAA\nAAAAAACRwII1AAAAAAAAACASaLqYD2rUqKFysrFARkaGqhk5cqQX09whGubPn69yZ555ZszjrCZY\n1113XULGhPTXvHnzmDVjxoxJwUhwvEqUKKFyYRssyoZ0zjnXsWNHL968eXOo17asWrVK5Z588kkv\nlg0WnXOuWLFiXnzPPfeoGqvx5LJly453iGmtf//+Kif/ts459/zzz6diOMfFajbao0cPL/7tt99U\nzfDhw72YRpwFh9XQqH79+jGPs86x1fQM+a9Zs2Yqt3jxYi/es2ePqpH3jXh07tzZi+U90DnnTjrp\npJiv88UXXyRsTAimZMmSMWseeeSRQK+Vm5vrxVaTtWuuuUblZNNFq7nogQMHVC6/m4umkwEDBsTM\nWY3krevPzp07EzcwFFiXX365F19//fWBjpPNEq+44gpVM2PGjPADE4oWLerF1jOS9f1IjiHItdRa\nX5w9e7bKRa25Ob+wBgAAAAAAAABEAgvWAAAAAAAAAIBIYMEaAAAAAAAAABAJ7GGdD+6//36Vk/uX\nWntlff/990kbE4KpU6eOyjVu3FjlTjjB/2jt379f1QwePFjlcnJy4hgd0lWXLl1UTu7N5Zxz69ev\n9+I333wzaWNC6q1du1blunXrpnKJ3LM6iIkTJ3pxv379VE29evVSNZy0IvfWPP300wMd98ADDyRj\nOHG59957Va5UqVJevGXLFlUzZcqUpI0JydWmTZtQx02aNCnBI0EYo0aN8uJp06apmrJly6rcKaec\nEvO1J0+eHH5gCSL3uh04cGA+jaTwuvbaa2PW9OrVS+VeeeWVUO9n9VIIwtrfnO9sidOpU6eYNcuX\nL1e51atXJ2E0SAdyb2irR4rlyJEjXtyuXTtVY33PCfJ8bq3vyf4KJ554oqqx1pHKlCkT8/2kffv2\nqdytt96qclHrFcMvrAEAAAAAAAAAkcCCNQAAAAAAAAAgEliwBgAAAAAAAABEAgvWAAAAAAAAAIBI\noOlikl166aUqd/3118c8rnfv3ir39ddfJ2RMCG/27NkqJ5tGWaxGNcuWLUvEkFAIXHLJJSpnzbuf\nf/7Zi3Nzc5M2JiRWRkZGzJqsrKzkDySEIkX8//dt/VuC/PvGjx+vch06dAg/sDRQsmRJLy5Xrpyq\nmTt3bqqGE5fTTjstZs3KlStTMBKkSvv27QPVyUZEY8aMScZwcJzkM69sDuWcc+eff77K9ejRw4v7\n9u2raqwmUm+//fbxDfB/Pffcc1785ZdfBjpONrPnuTz1Xn31VZVr0aKFFzdt2lTVnHXWWSrXunVr\nL7766qtVjbynOqevP1bNVVddpXLPPvusFy9YsEDVIJjOnTvHrGnWrJnKyc++c879/e9/9+I5c+aE\nHxgKLHk/GTx4sKo588wzVa5ChQpefP/996uavLy8mO9v1QT5LmQJ0mDRej+5dnjllVeqmnXr1oUa\nUyrxC2sAAAAAAAAAQCSwYA0AAAAAAAAAiAQWrAEAAAAAAAAAkcCCNQAAAAAAAAAgEmi6mGSXX365\nyskGVc7pRh8ffvhh0saE4P7whz94cd26dQMd99NPP3nxjTfemKghoRBq2bKlylnNFSZOnJiK4SBO\n9913n8oFaeARVX369PHiOnXqqBr577P+vX/84x8TO7A0sHv3bi/esGGDqmnYsKHKZWZmevG2bdsS\nO7AYatSooXLnnntuzONmzJiRjOEgRbp16+bF7dq1C3TcwYMHvXj16tWJGhISaPv27SpnNUqUuf79\n+ydtTM4Fa+hqXTutpnxIrbfeekvlnnzySS+27ifffvttqPdbsmSJysmGirLZqHP6nuqccyNHjvTi\n7t27hxoTnCtdurTKyefEE07Qy1Y33XSTyslnyalTp6qaTz/9VOVkY/Ply5ermvnz56ucZH1nmz59\nuspxn0su2dj3nHPOUTWVK1dWOXn9Oe+881TNrl27VG7NmjVeXKpUKVXTuHFjlatXr57KhfH++++r\n3LXXXuvFO3bsSMh7pRq/sAYAAAAAAAAARAIL1gAAAAAAAACASGDBGgAAAAAAAAAQCexhnWByD6aL\nLrpI1Rw5ckTl7rrrLi8+fPhwYgeGmKpVq6ZyI0aM8OKiRYsGeq2FCxd6cU5OTviBodCpVauWFzdp\n0kT
VWHvSvvzyy0kbExLHui9EUfXq1VWudevWKjdkyJDjfm25t5xzeh9b6L/TunXrVI11Tr7++msv\nfuyxxxI2pjPPPFPl5L58NWvWVDVB9mkvyHu5w7mqVat6cUZGRqDjvvjii2QMB4XEc889F7NGfs9y\nzrlNmzYlYzg4DtazrNzzfMKECaqmZMmSKifvH9b+6v369VO53NxcL37vvfdUjdwL1jnn2rZt68WN\nGjVSNbJHFWyTJk1SubB7zMv7jtVPzMolk/XM+91333mxnE9IPmtPZ9m/LJFmzZqlckH2sD506JDK\n3X///V78xBNPqBprzbEg4hfWAAAAAAAAAIBIYMEaAAAAAAAAABAJLFgDAAAAAAAAACKBBWsAAAAA\nAAAAQCTQdDHBZGOj2rVrq5pFixap3EcffZS0MSGYhx9+WOWCbIQvm1s559yNN96YkDGhcJJN7GQz\nV+ec+/LLL1M1HBRSTz/9tMr17Nkz1Gvt2rXLi62mJtnZ2aFeuzC55ZZbVM5qONaiRYuYNWHJBlXO\n6WZX1jUriMcffzzUcYiGIM2KDhw4oHJjx45NwmiQjv74xz+qXMeOHb3YalC1cePGpI0JifXmm2/G\nrLn++utVTjZwvOGGG1SNdf+SBg8erHJW8/Mg99lOnTrFfD/oRpvOOffKK694sTUvihYtqnLly5f3\n4qDNf5PJeiY699xzvdh65r711luTNiYkl/Vc065du1CvNXToUJV75plnQr1WQcQvrAEAAAAAAAAA\nkcCCNQAAAAAAAAAgEliwBgAAAAAAAABEAgvWAAAAAAAAAIBIoOliHPr27atyN910kxcfPHhQ1dx7\n771JGxPC69evX6jjevXqpXI5OTnxDgeF2MknnxyzZuvWrSkYCQqT7777zovr1q2bsNdes2aNF0+b\nNi1hr12YLFy4UOXatGmjcrKxS6NGjRI2hhdffDFmzcyZM1WuQ4cOMY/bv39/qDEh9bKyslQuSEMh\n2YDVOXu+AJYgjX+/+uorlfvss8+SMRykgNVsL0hjxrCs+9CECRNUTjZdbN68uarJzMz0YtkYEv/j\nyJEjKifvC/JveTTye3mxYsVUzYMPPqhy9erVC/T6iSKbQbZu3Tql74/Euueee7zYat5apEjs3wpv\n3rxZ5f7617+GH1ga4BfWAAAAAAAAAIBIYMEaAAAAAAAAABAJLFgDAAAAAAAAACKBPawDqlatmso9\n9dRTKif3I5o/f76qmT59euIGhnx34oknqtyhQ4cS8to7duxQucOHD6uc3J+rcuXKMV+7atWqKmft\n6RXEr7/+qnJyT/B9+/aFeu3C6Pzzz49Z8/bbbyd/IEgKeZ84Wk665pprAr3+888/78Vly5YNNa68\nvLxAxwXRrFmzhL0WYpszZ84x42RbtmyZygXZw7pVq1YqZ+1Hi/zXtWtXlQtyHXv//feTMRwUEtY+\nr/K5ePjw4akaDgoJ+VzlnHNXXXWVF7dt21bVjBw50otvueWWhI4L2ltvvRWzxtpv/I477vDi3377\nTdV89NFHKvfEE0948ahRo1RNkP4OKDg6d+6scvK8Fy9ePNBryTWjgQMHqpoDBw4cx+jSD7+wBgAA\nAAAAAABEAgvWAAAAAAAAAIBIYMEaAAAAAAAAABAJLFgDAAAAAAAAACKBpotHUbRoUS+2midWrFhR\n5Xbu3OnFN954Y2IHhsj5+uuvk/ban3/+ucqtX79e5WrWrOnFVuOPVHvooYe8+LbbbsunkURbjx49\nVK5MmTL5MBKkyosvvqhy99xzT8zjXn/9dZUL0hgxbPPEsMdNnTo11HFIH2Ebi9JgseDIzMyMWbN/\n/36VGzZsWDKGgzRkzRXr+UjOs88++yxpY0LhZDXgu++++7x41qxZqmbQoEFe/MILL6iaxYsXxzk6\nHK93331X5WTTxSJF9O86L730UpVr0KCBF5966qmhxrRhw4ZQxyH1rrzySpUL0mRRNgh2zrmrr77a\niz/44IPwA0tT/MIaAAAAAAAAABAJLFgDAAAAAAAAACKBBWsAAAAAAAAAQCSwh/VRNG7c2Ivr1KkT\n6LghQ4Z48bJlyxI2JiTXt99+q3ItW7bMh5H8f23atEnYa8n914LuTyv36J43b16g42bOnBlsYIVc\n7969VU7u9WrtW/7Pf/4zaWNCcr388ssqN3jwYJUrXbp0KoZzVNb+s9ZcvOKKK7x47dq1SRsTCgbr\n/hJ2T3REk9V/Qdq+fbvK7dixIxnDQRq66aabAtVZ/V6kChUqqFyVKlW8ODs7O9jAAKe/Dz355JOq\n5u677/biv/71r6qmU6dOKmc9fyFxvvnmG5WT5/O8884L9FqnnXZazBprD3S57tC3b99A74fUsu4d\nAwYMCPVaH3/8scq98847oV6rMOEX1gAAAAAAAACASGDBGgAAAAAAAAAQCSxYAwAAAAAAAAAigQVr\nAAAAAAAAAEAk0HTROdegQQOVmzNnTszjHnvsMZWbOHFiQsaE1GvVqpXKjR071ouLFy8e6rWbNWum\ncm3btg31Wv/6179Ubvny5TGPe+2117x44cKFod4f4ZUpU0blOnfuHPO4KVOmqNyRI0cSMiak3qpV\nq1SuT58+Kicbcl511VVJG5Pl8ccfV7lRo0aldAwomII2DP3111+TPBIkQrFixVSudu3aMY87fPhw\noBwQD3kdufXWW1XNXXfdpXIrV670Yqv5HRDUuHHjVG7gwIFefM4556iapk2bqtyXX36ZuIFBsZpa\nymfsDz74QNU0bNhQ5eR3u127dqmav//97yo3aNCgmONE6pUrV86L161bp2qKFIn9m9+NGzeq3JVX\nXhl+YIUYv7AGAAAAAAAAAEQCC9YAAAAAAAAAgEhgwRoAAAAAAAAAEAksWAMAAAAAAAAAIoGmi865\n++67T+XKly8f8zir+V1eXl5CxoRoGDp0aH4PAWnk0KFDKpeTk6Nya9as8eLhw4cnbUyIhnfffTdm\n7r333lM1t912m8q1aNHCi+fPn69qnnrqKZXLyMjwYpr+IKxevXqp3MGDB1XuiSeeSMVwEKfffvtN\n5ZYsWaJy1atX92J5LwOSoUuXLseMnXNu+vTpKnfzzTcnbUwofDZt2qRyssmibPTpnHOPPvqoynXo\n0CFxA0Mgv/zyixc3a9ZM1dx+++0qd/7553vxTTfdpGqsBnyIpp49e3qxbMLoXLD1Puv7WW5ubviB\nFWL8whoAAAAAAAAAEAksWAMAAAAAAAAAIoEFawAAAAAAAABAJBS6Pax79Oihcn369MmHkQAobA4f\nPqxyDRo0yIeRoCCaPHlyoByQ35YvX65yDz30kMpNmTIlFcNBnI4cOaJyAwYMULmXX37Zi+fOnZu0\nMSH9WXvBWvv9zpo1y4vHjBmjarZt26ZyVl8RIJGys7O9eOnSpaqmdevWKte8eXMvXrBgQWIHhlDG\njRsXKIeC68EHH/TioP3pXn/9dS/m+TZx+IU1AAAAAAAAACAS
WLAGAAAAAAAAAEQCC9YAAAAAAAAA\ngEhgwRoAAAAAAAAAEAmFruni+eefr3LFixePedzOnTsD5QAAAAqzs88+O7+HgCRbu3atyl144YX5\nMBKkq2nTpgXKAQVF27ZtVe7nn39WuSZNmngxTReB1ChbtqwXZ2RkqJp9+/ap3LBhw5I2psKOX1gD\nAAAAAAAAACKBBWsAAAAAAAAAQCSwYA0AAAAAAAAAiAQWrAEAAAAAAAAAkVDomi4G9csvv3jxWWed\npWq2bduWquEAAAAAAIACaNeuXSpXqVKlfBgJAMtzzz3nxffdd5+qefzxx1Vu3bp1SRtTYccvrAEA\nAAAAAAAAkcCCNQAAAAAAAAAgEliwBgAAAAAAAABEQqHbw3rIkCGBcgAAAAAAAADS25/+9Kdjxkg9\nfmENAAAAAAAAAIgEFqwBAAAAAAAAAJHAgjUAAAAAAAAAIBJYsAYAAAAAAAAARELopot5eXmJHAcK\nGeYP4sH8QTyYPwiLuYN4MH8QD+YP4sH8QTyYP4gH8wdh8QtrAAAAAAAAAEAksGANAAAAAAAAAIiE\njOP5eX5GRsZW59ya5A0HBVy9vLy8qkf7j8wfHANzB/Fg/iAezB/Eg/mDeDB/EA/mD+LB/EE8mD+I\nxzHnz38c14I1AAAAAAAAAADJwpYgAAAAAAAAAIBIYMEaAAAAAAAAABAJLFgDAAAAAAAAACKBBWsA\nAAAAAAAAQCSwYA0AAAAAAAAAiAQWrAEAAAAAAAAAkXDC8RRnZmbmZWVlJWkoKOgWLFiwLS8vr+rR\n/jvzB0efyxe1AAAgAElEQVTD3EE8mD+IB/MH8WD+IB7MH8SD+YN4MH8QD+YP4hFr/vzHcS1YZ2Vl\nuW+++Sb8qJDWMjIy1hzrvzN/cDTMHcSD+YN4MH8QD+YP4sH8QTyYP4gH8wfxYP4gHrHmz38c14K1\neIOwhyKN5OXlhTqO+QPnmD+ID/MH8Qgzf5g7cI5rD+LD/EE8mD+IB/MH8WD+IB5h5g97WAMAAAAA\nAAAAIoEFawAAAAAAAABAJLBgDQAAAAAAAACIBBasAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAA\nAAAAAEQCC9YAAAAAAAAAgEhgwRoAAAAAAAAAEAksWAMAAAAAAAAAIoEFawAAAAAAAABAJLBgDQAA\nAAAAAACIBBasAQAAAAAAAACRwII1AAAAAAAAACASTsjvAQD5oWTJkipXvnx5lbv00ku9+Nxzz1U1\n1atXV7kmTZrEfL/Dhw+r3J49e7z41VdfVTUvvfRSzOPy8vJUDQqOjIyMmDVhz3GRIvr/U8r3s15b\n5phjBUexYsUC1f3222/HjJ3jvAMAgPQR5BkYyA/J/D4IFBT8whoAAAAAAAAAEAksWAMAAAAAAAAA\nIoEFawAAAAAAAABAJLBgDQAAAAAAAACIBJouolA44QR/qteqVUvV9O7dW+Wuv/56L65SpYqqKVGi\nhMoVLVrUi61GdxbZOOGBBx5QNfv371c52ZzxwIEDgd4PBYPVdMOadzVq1PDili1bqpoTTzxR5Vas\nWOHFX3/9taqRjT1//fVXe7BIKXltc865s88+24ufeuopVdOwYUOV27lzpxcPGTJE1UyfPl3lmAv5\ny7q/yJx1DbEa9chGm0EasMZDjjNIU9gjR46oGqtBKKLJmovyvFuNqkuVKqVy8lnHej5ibqRekGfe\nZF9bgCAN65L9fvL7oCXsvTjdyb9n0O/SUfzbVaxYUeW6d+/uxVlZWaomJyfHi998801Vs2nTJpXj\nvod0wi+sAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAAAAAAAEQCe1ijUAiyf5W1P+LBgwe92Nob\n2nptue9WsWLFVI21F5d8Lbl3lXPOZWdnq9yhQ4dUDgVXkPlq7Rss513z5s1VzRlnnBHztRYuXBhq\nTEgua3/E8uXLq9zNN9/sxU2bNlU11t7XMif3wnbOuRkzZqgce1gnT5A9HK17l9yrvnr16qrG2ltz\n7dq1Xrx582ZVE+Q+aM1VKyf3Ki5Xrpyqkf9muZ++c87l5uaqHHs4JlfQfdGDkNexm266SdV07dpV\n5aZOnerFL7/8sqrZvXu3F3MvC8+6b1SoUEHlGjRo4MXWs+yOHTtUTn62Dx8+rGqsPewTdU6ta6Kc\nm1YfEOuaKK+d1jUK4QW5x1jn07ovWHMqjLDfB61nKHnfs8Zt5dLp+ib/BlbvHov8ewbtexHkb2fN\nqSZNmnjxxIkTVc0pp5yictb1VJLjHDlypKr55JNPVO6OO+7w4nXr1qmadJorSG/8whoAAAAAAAAA\nEAksWAMAAAAAAAAAIoEFawAAAAAAAABAJLBgDQAAAAAAAACIhALTdFFuTF+8eHFVU7p0aZWTDXys\nphdWAx/ZQIPmPenFauSyZcsWlfvss8+8+Ntvv1U1c+fOVbmtW7d6ceXKlVXNpZdeqnJXX321F+/f\nv1/VbNq0SeVonJDerPNrXZNko5FKlSqpGqvhi2zGsXPnzkDvh9Symgyde+65Knf55Zd7sWxsdzSy\nzmq6aM0peU3iepQ48m8ZtFFQs2bNvPiCCy5QNVajp3feeceLrftiEEGaTzmnmypZDdxkU0mrgRIN\nzcIL0rzMOp9hP+fW+zVu3NiL7777blVjPePLBniTJk1SNbLpIoKTzwvyPDnn3AMPPKBysrnzkiVL\nVM0LL7ygcvJ5OmgzvLDXSXn9sZoYd+vWzYs7d+6salasWKFy7777rhcvXbpU1SSq2V+6s64Z1vWg\nUaNGXlyzZk1V8+WXX6qc/M6WyGcYeY6tRntWM2XJOs5q9plO5Of44MGDqsZqgij/Vtbf17pGyOtd\n+/btVc348eNVrnbt2jHfL4iw1yhrPUHO/euuu07VWNck5L+w8yedv6fzC2sAAAAAAAAAQCSwYA0A\nAAAAAAAAiAQWrAEAAAAAAAAAkRCJPazl/kNly5ZVNaeeeqoXt2zZUtWccsopKif38LH2/7X2Hlu1\napUX//LLL6rG2vta7o1s7ZVs7cVVsWJFL5b78gUdk7UnJfS+VNYeUN99953KTZ061Ys3bNigaoLs\nQbd58+aYNc45d/PNN3tx/fr1VY019xcvXuzFUdhDNsgemOnO+qxLify7yH3UTjvtNFWzbNkylfvi\niy+82LpuFcbzFzXW/tEvvviiysl7qDUPrfMpr5PWXp133HGHyo0bN86Lresde3UmhnXerJ4enTp1\n8uKOHTuqmnnz5qmcfM6w9sgMey2wjpPPLNa9We6ba+3TZ+27L9+Pa1h4Qa8hQVj7Mw4dOtSLrXlg\nkfcq67mc826T59Q6L/I71COPPKJq2rRpo3Lytax9gxcsWKBye/fu9eJEXmuC1FnfPy+++GIvPv30\n01XNjh07VE7e85iHwcm5Wb16dVXz0ksvqVy7du282LpXfPjhhyp35513enEiewXJ44KuC8jnscI4\nf4LsTW9dt+QzkbX
/t+yN4Zz+fv3EE0+oGvk9yxqDda6sZ+CcnBwvtuaG3Kvd6kNk/V1q1KhxzNdB\ncEH6fDin1zOt55guXbqo3MiRI73YmmOHDh1SuR9//NGLp0yZomqmT5/uxevXr1c1ch46p+dUfu+P\nzS+sAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAAAAAAAEQCC9YAAAAAAAAAgEhIedNFuSG5c3pz\nfKuxlGweJhsKOWc3qCtTpowXn3HGGaqmffv2Kic3469QoULM13ZOb7QvG4g459zBgwdVTm6Gb/2d\nlixZ4sU33nijqvnpp59ULshG6YlsqJPfgmyEv3v3blWzZs0alUtUAxiL1bxMbrRvvZ9sKuKcc6+9\n9poX5/fm+M4V3PkTVpCmDIn8m5QsWVLlevTo4cVW89Z//vOfKic/D4Xt3EWVbK5iNdSwmhEFafZp\nkcdZzfys+45szjh+/HhV849//MOLaY6WONbn/IILLvBiq5nYnDlzVE42m0r1vaRWrVoqJ5//tmzZ\nompk8xnnmE9BpfreJZuMO+fc+eefH3NM1lycMGGCF+fm5sY3uEJENgqzvnfI+0tWVpaqsY7bunWr\nF48ZM0bVWM15U/2ZlfPM+o7YvHlzL7Yao82dO1flVq5c6cVReC4vKOS8s5p2Wo3J5Pm0/uYXXnih\nyt17771ePHr0aFWzfft2lQsyX2VN0AbUQZp2co+zn1NlszurYd2+fftUbv/+/V5srddYOXkerIbW\nt912m8rJJtdBrsHyXumcc6eccorKffXVV168aNEiVQP7WUM26axWrZqqadWqlcrJa0vXrl1VjfWd\nzWqkGWtMzukGwCeffLKqGTRokBd/8803qmb48OEqF7X7F7+wBgAAAAAAAABEAgvWAAAAAAAAAIBI\nYMEaAAAAAAAAABAJLFgDAAAAAAAAACIh5U0XrU27ZWMBq1GhbAqWnZ0d6P127drlxdbG+1WqVFE5\n2XjDanBmbZJ+4MABL167dm3MGuecO+uss7xYNmF0zrkmTZp4caNGjVSN1XQxiHRv3CCbJMjGUs45\n9+uvv6pcov4usimZc871799f5WQTHKu5w+uvv65yQZt4IHmCNK6yBJljcl44Zzc/+t3vfufFVpOG\n7777TuWSOX8S9TdId9Y5luezbdu2gY6TrL9vkHuxVWPNKdkUb+zYsaqmWbNmXjxs2DBVs2PHDpVj\nbvispjy9e/dWuapVq3qx1ajw888/Vzk5B4Je14I0u7LIZ53/+q//UjX16tXzYmvczJPwktl00bo+\ndenSReXkPLDeTzbEck4/DyVyHqT7vUv++6xri2yCajU4sz7rH3/8sRevWLFC1YT924VtEm8dJxvW\nPvDAA6pGfkf86KOPVI3VdNFqzgjNagj8wQcfeHGdOnVUTZB5YH2Hsq5J8h5qNbEbMmSIysnGZEG+\nRwad9wX52pIs1rkrVaqUyuXk5Bwzds6+bsnmnjfccIOqsRpDL1y40IvlvHDOnhtB/Pzzz15srStZ\na0ZyrYnrkX2Pk8/Kzjl3zTXXePHAgQNVjdWIUc7FoM/P8lytW7dO1VhN0uX3KtmE0Tl9fa1cubKq\nsZqWP/zww14cdv4mCr+wBgAAAAAAAABEAgvWAAAAAAAAAIBIYMEaAAAAAAAAABAJKd/D2tqTSe6L\nIvedds65Tz/91It/+OGHQO+3b98+L7b2sC5RooTKyf2hzzvvvEDHTZ061YutfWHknmnOOffvf//b\ni8uUKaNq5N471p5M7HkVbI4lW8WKFb34jTfeUDXWXkpy/9Dhw4ermlmzZqkc5z2aErUPqDVXzjnn\nHJWTe6utX79e1cj90BIp3ff8TCZrP7RXXnnFi639oy3yb7x161ZVs3TpUpWT161ffvlF1dStW1fl\n6tev78XWnoK9evXy4sWLF6ua1157TeWsfWsLs5o1a6pcz549VU7OFfkM5ZxzGzZsUDk5d6zPtHU9\nkqz9Ia39J3v06OHFl1xyiaqR+/tZfUCC7pkNzbomB7mWW+Rx1p7HHTp0UDk5p4LsL+qcc9u2bTve\nIYaWbveuIJ91+T3H+m5ikd9PEvm3s64jQV6/Ro0aKiefp629kuV98LnnnlM127dvDzWmwibI/tHO\n6T5SFuv7/IwZM7xYfrd2zrl+/fqpnNz7tV27dqpm0KBBKjd69Ggvtp61kDjly5dXOeuaJHt2BF0D\n2LNnjxd/8cUXqsb6XMtcMvsCWa9t9X6Dvt7Url1b1bz00ksqJ/sFBVmvcU734ZFrgs7ZazhfffWV\nF1t91qy+eU888YQXy354zum/gfX9rHHjxoHeLz/xC2sAAAAAAAAAQCSwYA0AAAAAAAAAiAQWrAEA\nAAAAAAAAkcCCNQAAAAAAAAAgElLedNESZLN62cBDNlN0zm7SInNBm8vIRi5yQ3TnnDt8+LDKWU0g\ngozTaiIkyaYBq1atUjU0+Ug9q4nI9ddf78Vly5ZVNdb8efvtt734qaeeUjVhmzkkqgEgbNbfU37W\nwzYGs5rtde7cWeVkk4Tdu3ermp07d6occyG1rPP5+uuvq5x13ZCsZjLPP/+8F48YMULVWPe9ChUq\neLE1X1u1aqVyd911lxc3bdpU1chmOXfffbeqee+991SusDddlM1errjiClVjNXKWTYcefvhhVWP9\nbYM0Ygv7rGU1qh4yZIgXWw1h5D3PakhD08XEkucv7D3COp+dOnVSOTnPrOecCRMmqFyqG2qnE/k3\ntz7r1atX92KrwZn1uZZNqypVqqRqrGcRyWraWbp0aZWT86VBgwaq5s0331Q52UT44MGDqubRRx/1\n4oULF6oa5mEw1vWgb9++MY+TzfCcc65NmzYqt2LFCi+2nqEGDhyocvI+a30WMjMzVc6aL0iccuXK\nefFVV12lapYvX65yVkPpIOR9LujnWs6XoI1hE/Xdi+9wtpIlS3rx0KFDVU379u1VTn6Xtp5HrMbx\nAwYM8OIff/xR1QR5frZY3xs7duzoxUEaoltzUzaLdC65jUPD4BfWAAAAAAAAAIBIYMEaAAAAAAAA\nABAJLFgDAAAAAAAAACKBBWsAAAAAAAAAQCREsumiJUjzsrCbzlvHyY32rc3Hwzb5kU1MnHOuTp06\nMV9bNn785ZdfQr0/ggvSSKFZs2Yq179/fy+2GjdYTSGGDRsW87gggjZ8QHKFvUbIeScb1jlnN4qQ\nPvroI5UL28TOakITBPNOq1evnsrJJlUWaz49/fTTKicbi1jHWdcI2aTTOm7r1q0q17JlSy9u0qSJ\nqpHNQKpUqaJqatWqpXLr16/34sI2nypWrOjFPXr0UDXW88mzzz7rxWvXrlU1Qf6WVk3YZizWvG/Y\nsKEXW9eZlStXevG6detCvT9sQZpBBf3cyfN36qmnqhrrcy5Z96lPPvkk0BjCsOZdYbvWWJ9r+Tew\n7htWrnHjxl78+OOPqxqraZXVWEqymtRXrVrVi2+44QZVY91zpLlz56qcbIhs
NUxHMFYTROv5VjYB\n69Onj6pZtmxZzPerUaOGyp188skqJ+ew9d1LPotYxyE8q2HcoEGDvPi6665TNc8884zKzZ49O9QY\nglzzrXuFnAfWvOC6kXryvtCzZ09VIxssWnJzc1Vu3LhxKpednR3ztay5IeeU9VmwGkbK52eLnNNW\no9j58+ernPz+l9/PSFxpAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAAAAAAAERCJPawDiKZ+6QE\n2Zs17F601n5s9913n8qVKFHCi639+0aOHOnF1j40SCy531q3bt1UzcMPP6xyct+knTt3qhpr361N\nmzbFHFOQ/bMsYecwtLD7OQclz6e1J3CFChVUTu4v/M4776iasPuiJ/M6Wdh07dpV5YLso7Z3716V\nk/cF54Kdh7DnytrLTe6PHPbzUbJkyVDHpQvrOn7aaad5cWZmpqrZvn27yr333ntenOrPpjUHLr30\nUpWTz0jW/Bo9erQX5+TkxDm6wivoZzPsM7d8lr322mtVjXWtk/PT2q9627ZtKhd2z9Ewr5Nu5N/c\n2sNa7m1p7YV/0kknqZycB7169VI1V155pcrJ55PVq1erGqt/T4sWLby4cuXKqsaaB3I/bNlHxqpB\ncEH6AFl7Q//0009evGTJElVjvZbcD/vtt99WNUGetQ4dOqRylSpVUjm5H/+ePXtUTdieD+lOzo26\ndeuqmltvvdWLrc+1tS/x+PHjvTiRzz/WvSLIHtaF8R6TStb1Xc6pMmXKBHotea6s/cebN2+ucnLP\n/NKlS6sa6xm3VKlSXvz73/9e1TRq1EjlrL2uJTn2efPmqZo5c+aoXNTmK7+wBgAAAAAAAABEAgvW\nAAAAAAAAAIBIYMEaAAAAAAAAABAJLFgDAAAAAAAAACKhwDRdTJSgDevkZuNBNx+Xr3/yySermu7d\nu6ucbMowbdo0VbNw4cJQY0Iw1ub1/fv39+JHHnlE1ViNwoI0jpFNsZwL1kjTGqdsXGU1DEE0Wdck\nOacuueQSVWPNg3//+99enJ2drWpoUpV68h7TrFmzQMfJ+8KECRNUzb59+8IPLEFkU4+wDToL+3XL\n+kzLRitWwyirSbP8fFrnJJGfYfn6FStWVDU33XSTysl5YDV/mTFjxjGPwdEl81puvbZsCmo97wZp\nfjd27FhVE7ZhMGzyvFufqx9++MGLhw4dqmpuueUWlZPffaznZKsZnWzqOHXqVFVjXVvOO+88Lw56\nvfvuu++8eOnSpYGOQzhWA8tvv/1W5WRzPatpp9WA75prrvHirKwsVWOdTzmuLVu2qJpWrVqp3N13\n3+3F1udj8+bNKgf9XNy5c2dVI+8n8ruuc/Z5ad26tRdbjebCPkdYz2lWQ07Jun9xT0scay2vatWq\nXrx7925VE/TeJPXt21flypUr58XWXLGuP/J+Zc1zi3wt6zvUzJkzvXjMmDGqZtOmTSoXtbnJL6wB\nAAAAAAAAAJHAgjUAAAAAAAAAIBJYsAYAAAAAAAAARAIL1gAAAAAAAACASEj7potyI3NrA3SrOYfc\ncD1I4xrnnCtVqpQXX3nllaqmRIkSKrdhwwYvHjFihKqRja2QWHKzfOecu+eee7xYnt+jyc3N9eJx\n48apml9++UXlZBOIoI1j5Eb7NKVKrKCf/0S9dpUqVbzYakZiNZj48ssvvdhqcJMoNCIKTp6rU089\nVdVYTT7k+Xv11VdVTarPg9UMRDZWs+amZDX0WLduncoVxHkW9Lot66yGivJabj3DlClTRuXkOXnz\nzTdVzc6dO1VOzsMgTX6dc658+fJefO+996qa2rVrq5z89/39739XNcm8jiE863N+5plnenGQZlTO\nObd161YvXrRokapJ1LWgIF5TkiFI08WcnBwv/te//qVqPv30U5UrW7bsMd/LOf2c7JxuImwdV6tW\nLZXr06ePF1vP6lZDqgcffNCLrcbn8jrN/AlO/q2sv+/PP/+schdddJEX/+EPf1A11jmW58p6zrAa\njL3zzjtebDX27NKli8pddtllXmw1WPzTn/7kxXyX/x/y/tGkSZOYNdazlbWm8u6773qx/Jw759x7\n772ncrIpX82aNVXNoEGDVO6UU07x4m3btqma999/P+Y4rWcyrjfBWH8n2Wxz2LBhqsaad/J51ppj\nTZs2Vbn69et7sdXQ0XqmDvKdybo/y/vlQw89pGomT57sxdbctK5J8v3yex7yC2sAAAAAAAAAQCSw\nYA0AAAAAAAAAiAQWrAEAAAAAAAAAkZBWe1hbe8AE2f/I2js0yF4t1n6Tbdq08eJevXqpml27dqnc\nI4884sXpspdnQWLtTyb3Erbmj7Un2/Dhw73Y2rsq7D5m1j5GzI38F/YcWHOqYcOGXlynTh1VY+3N\nt3DhQi9O5F7mzLHw5P5n1apVUzXWuZL73Fv73gfZG9naM806Lsj+xRdffLHKdejQ4bjf7+uvv1Y1\n1t5qBVHYz4o1B5YuXerFy5YtUzXWnug9e/b04pYtW6qatWvXqpy8L1nPVXKfR+ecq1Gjhhdfe+21\nqsZ6ZpL3z/Xr16saejIEY91Lkrn/rvV+8jmqePHiqsYaw6xZs7w47L7lQfePh2b9nYLsc209y8q9\nrxNJ7nfunL4myeuRc3qOOefczJkzvTjI83WQz9nRXquws+bKxo0bVU72SbCuI9bfXL7+22+/rWpG\njhypcnv27PHiHj16qJpu3bqpnHy2s77z/+Uvf/Fi6zmuMJLPA1u2bFE18nuO9WxpkT2pxowZo2pG\njRoV83WsZxbrmUheI6xr6YUXXqhyHTt29OLbb79d1ezYsSPmOGFfb+W94o033gj12tY8sPbQP+us\ns7z4uuuuUzVdu3ZVOblntnVts66T11xzjRdb36vkM3ZBfR7iF9YAAAAAAAAAgEhgwRoAAAAAAAAA\nEAksWAMAAAAAAAAAIoEFawAAAAAAAABAJBTYpovWpveySYNzenNxq1FZkA3IrferX7++yslmDrJp\nn3POzZs3T+U++OADL7YaQSJxrLny5z//WeVkQw1rU/8FCxao3CuvvOLFVmNGa95Z80xKZiMXq6FF\nkKZNQcZUUDf6TzbrnLdr186LS5curWqsBgwrV65MyBisMXFNCq9MmTJeXKFCBVVj/c3lcdb9K0gD\nKKthUZCmr02bNlW5v/3tbzHHaZHXCHmNdC58o7V0YZ3fJUuWePH48eNVTadOnVTutNNO8+KsrCxV\nYzVilHNnw4YNqmby5MkqJ+8dQe4lVl2tWrViHse9JLhk/q2sz/0VV1zhxdZ17dChQyr3wgsveHHQ\n5xw5N7h3pb969eqpnGyyuHfvXlVz8803q1xubq4Xc21JriCNhZ1zbsqUKV5sfa6te8zDDz/sxR9+\n+KGqCfKcIZtxOufc6tWrVU7eZ61n9dNPP92LN23apGoKY4NO+Qz6+eefqxrZRE42tXMuWCNGq6Zk\nyZIqF/Y7uLxuWGsMVuO+3r17e7H
1HU42jCyMcyWsRP2trGcI6zoi1/cqV66salq3bq1y8rphrRmN\nHj1a5eTnI+haU0HEL6wBAAAAAAAAAJHAgjUAAAAAAAAAIBJYsAYAAAAAAAAARAIL1gAAAAAAAACA\nSCiwTRet5j1BmruE3Xzcalr1yCOPqFyzZs28eMuWLarm1VdfVbmtW7d6cbpskh5VNWvWVLmGDRuq\nnJxTVlOsF198UeWCNPWwmjKUK1fOi2VDGOfsTfWlIM2trGYADRo0ULk9e/Z4sdV4xGqiJOdwkCZv\nUZXMz6N1benSpYsXWw1D1qxZo3JB5l2QuWHNTdm8gmtUcPLvG7SBUKVKlbxY3l+cc+7TTz+N+f6y\neaxzzpUvX17lLrvsMi8eO3asqilbtmzM97Ps3r3bi61mSIVtTgVpCi2bh33yySeqZv78+SpXu3Zt\nL5bNoZxzrlGjRion7znWecrOzla56tWre/Fdd92lakqVKqVyct5bTZUmTZrkxYVtngRl/V2S2bDy\nxBNPVDl5zbLs379f5YI0DA7y3G9dW7l3JZd1XoIIch6sZ5Hnn39e5WQj4/fff1/VrFu3LtQYEnEM\n/of1t9u8ebPKPffcc1787LPPqhrre4dsjmZdD6z5KsdlNRu+5557VE42ebSe5+W90Wr2Z33XS/d5\nJp935s6dq2pkE98qVaqoGuv5tn79+l58+eWXq5qePXuqnPWMIlnPaUGe8S3y+ta9e3dV89BDD3kx\nTRejwbqOyPk5fPhwVVOnTh2Vk/Nn2bJlqmbatGkqJ9eD0vmawS+sAQAAAAAAAACRwII1AAAAAAAA\nACASWLAGAAAAAAAAAERCgd3D2trDx9rbN1H7uVx88cWBcnIfGmu/SStn7YmE5LH2ySxWrFio16pX\nr57KyX0crT3LrD215Lg+/vhjVfPFF1/EHJO11+3gwYOP+V7OObdv3z6Ve+qpp7z4o48+UjU5OTkq\nJ/ep3Llzpz3YQsTa86patWoq17Rp05jHffXVVyoXZJ9w65oor6dyH8CjjQHByH2I5b7wztl7v8r9\n9ORn0TnnRo0apXJyH7ULLrhA1bRp00blqlat6sXWvtoWOaes/Rh///vfe7G1jy20IL0AduzYoXJy\nz3Brj2DrWUTug2/dE6xriNwf1tov1tp/Ur5W6dKlVY2ch+zhGFyinoGt6/+ZZ56pcvI5ynr/RYsW\nqdyuXbtijsF6LZljbiRXkB4Yzuk9XK1niiDPGdZ+n+ecc07M9/vss88CvV+ipPPeofEIsoe+tRe1\nlQvzftYe6Nb+wnL9wFpPsObUnXfe6cWtWrVSNfI5LjMzU9VY/a7kvd66thXkeSfHbp3z7du3HzM+\nmgitTUcAAAsFSURBVCVLlnixtT+21Y/l0ksv9eIg+507F3zP6lis3kRyDrNeFA3W8+zixYu92OoV\nZs0p+VkfN26cqrH2+i/In//jxS+sAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAAAAAAAEQCC9YA\nAAAAAAAAgEgosE0Xk73RePHixb346aefVjUlSpRQOdmgaPTo0apGNjVC6lnNK6yGLLKZjNVYYejQ\noSo3YMAAL7Y255eNOJzT81o2KnPObs4hm0dYDdyCNJWU8945504++WQvtpr9WU3krMZrhZ01f2Qz\nTOd04zFrbv773/9WubANp4Icl8zmXeneOELeF9566y1VM2TIEJWTzVZOO+00VTNp0qSY72/NO+s8\nBGmsac2VDRs2ePHvfvc7VfP99997cbqf81Sy/payMY913qxnkbCNyeRxGzduVDVNmjSJeZz1/vJz\nEKS5LBLLuob069dP5eQ1xDqfjz32mMqFbSQl53Wiml8huLDXnyCs5vbWdy/JagSeqGePoA2oC9s9\nLsgzRbL/JvL9rDFZ1xo5P4PMaeecW758uRfXrVtX1cjGaw0bNlQ11vvJZt3WdyqrOWRhm3cWeT6t\nxtRffPGFyl100UVebH0ntprMBrkmWOdFPsusWrUq1Gsjuaxm4HPmzFE5+VkPes/54IMPvPiNN94I\ndFxhwpMdAAAAAAAAACASWLAGAAAAAAAAAEQCC9YAAAAAAAAAgEgosHtYJ5K1x4zcm+/EE09UNdbe\nfM8995wXr127Ns7RIRkWLFigckuXLlW5Ro0aebG1d5W1F3WdOnW8OOy+imXKlFG5evXqqZycw0He\nz9qPTe5F65xz8+fP9+LVq1ermm3btgV6/cKuQoUKKterV6+Yx+3atUvlFi5cqHJh97iSx1mvU9j3\nz4qH3E9v3LhxqsaaB1lZWV5sfa4TuWerPMfWPW7KlCkqd8stt3ixtV8g8yd/WX//sPvKWuQ8tObl\noUOHVE7uo/3DDz+omrD7aiNxZJ8M55w7++yzVU7Os507d6oaa+/QROHelVzW3zKRn0+5X33v3r1V\njXVtkeM66aSTAh0nx259H5TP/da+tta1LZ2egYPso2t9P5K9c6y/UzKv79b7Wfe9sNcI+Xex+gdl\nZmZ6cdOmTVWNdX1dtmyZF1v9i6x+DtwvNev8Tps2TeUGDhzoxfK7vHP25z/IdyjrXGVnZ3uxdW9M\np+tIQSGvZaNGjVI1sr+Xc/p6YF1rZs6cqXI9e/aMeVxhxy+sAQAAAAAAAACRwII1AAAAAAAAACAS\nWLAGAAAAAAAAAEQCC9YAAAAAAAAAgEgodE0XrcYRZ5xxhso9+OCDXmw1k9i4caPKjR492otp9hJN\nmzZtUrn27durnGyO0blzZ1XTvHlzlTvrrLO8uEqVKqrGmouyQcjBgwdVzZ49e1QuNzfXi2UjK+f0\nv9lq2vfOO++o3MqVK4/5Xs7R6Oho5DmuWbOmqrEapMjzN2vWLFWzb9++OEd3dMls+MC8sK8/1nXk\n/fff9+JWrVqpGuveFKRBknWOc3JyvPjee+9VNa+88orKWY2NEC3J/tyVL1/ei615aTWPXbx4sRfP\nnTtX1dBEKvXkNUQ2gHVON1RzTp8rq5Ez9xccjWxM3aRJE1Vj3d/keW/Tpo2qKVeunMrt3bvXi63G\njLLJWpCma87peZ5ujbTkeShZsqSqkdcI6+9rfV+Rf6ugn+tkNhC37mnVqlXz4jJlyqiaUqVKeXGt\nWrVUzebNm1VO/g32798fswbByYaHzjk3ePBgL5ZrOs7Z90LZLNb6nv6Pf/xD5SZOnOjFa9euVTU0\nXUwu634iP6P9+/dXNda1TH4eV6xYoWq6du0a8zho/MIaAAAAAAAAABAJLFgDAAAAAAAAACKBBWsA\nAAAAAAAAQCSwYA0AAAAAAAAAiIRC13SxUqVKKvf888+rnGz8YW16P378eJWzmiIgeqymG7LhmHPO\nzZs375hxPKyN/mVONnI4miCNRiRrk3+aJybX7t27Ve5vf/ubysnrzxtvvKFqDh8+nLiBId9ZDena\ntWvnxfXq1VM1Y8aMUblGjRp58fbt21XN008/rXKffvqpF8uGVM5xPYDdbEY2Xfzpp59UjdVE
eMaM\nGV5sXSODNBFFYgVpGGw978rGZFu3blU1pUuXVjl53rnOFE5yblgNfYM0OLSej6ymgLLhX5BrjdUk\n0GoMK8cZpFlkVAUZp3V9l/9m63uHdT9J1N8l7OtY58pqMiu/o1nvJ+eL1fB+2bJlKmc9f0kFZf5E\nkfWZlQ3uu3fvrmpq166tcvK6tX79elWzceNGlQvSUJFznFzWOsuFF17oxdYzi3Utk88x11xzjaqh\niWY4/MIaAAAAAAAAABAJLFgDAAAAAAAAACKBBWsAAAAAAAAAQCSk/R7Wcs+y2267TdU0b95c5eSe\nNtb+xp9//nmco0NhFmS/aGv/PhQc8nxa+5qNGDFC5eSeftY8sPbPQnqR82f16tWqpk+fPikaDfA/\nrHvXtm3bvPjdd99VNdYesnLvTmu/dWuvSSSXvL/I/e2dc+72229XuWrVqnnx/PnzVc3OnTvjHB3S\nldwDdOLEiapmwIABKif3+50+fbqq2bNnj8oFecaW+xkH3Ys63feelf8+a99wee0O+7eLwt/S2qN7\n1apVXhxkX395r3TO7t0g/3b0rUk++Te3zpWVk/M6CvMVmvwsOudc9erVVa5Lly4xj7O+g//www9e\n/OOPPx7vEHEU/MIaAAAAAAAAABAJLFgDAAAAAAAAACKBBWsAAAAAAAAAQCSwYA0AAAAAAAAAiIS0\narpoNXMoXbq0F/fr10/VWI2AJKvp4qJFi1SOjfYBHI11faCxJoCCxLqOyWek2bNnqxrrGU02OSqM\nzcsKAquZmNVYM0jzKc4njkY2Rhw+fLiqeeGFF1SuRIkSXiyb4TmX3Gct69qGgtsc3LpGWc1/5Xw9\ncOCAqilWrJgXn3CCXnqxXvvXX3+NOU5EA/e0gqtOnToqd95553mx/Aw7Z39mJ0yY4MX79u2Lc3T4\nD35hDQAAAAAAAACIBBasAQAAAAAAAACRwII1AAAAAAAAACASWLAGAAAAAAAAAERC2jddzMzM9OKy\nZcuqmiCb5T/55JMqt2vXruMYHQAAQPqjYVT6o6EiEk3OH6tRYnZ2dqqGExjzvnCS5/3w4cOqRjZs\ny83NVTXWccwpILGsJrDWWt7atWu9WDb1dc65Z555RuVef/31OEaHY+EX1gAAAAAAAACASGDBGgAA\nAAAAAAAQCSxYAwAAAAAAAAAiIa32sLb2plm1apUXN2zYUNUUKaLX7ffu3evF1v5SAAAAAACg8JD7\nTFvrEDk5OakaDoBjsPaFX7p0qcq1bt06FcPBceAX1gAAAAAAAACASGDBGgAAAAAAAAAQCSxYAwAA\nAAAAAAAigQVrAAAAAAAAAEAkhG66aG1cDgTF/EE8mD+IB/MHYTF3EA/mD+LB/EE8mD+IB/MH8WD+\nICx+YQ0AAAAAAAAAiAQWrAEAAAAAAAAAkZBxPD/Pz8jI2OqcW5O84aCAq5eXl1f1aP+R+YNjYO4g\nHswfxIP5g3gwfxAP5g/iwfxBPJg/iAfzB/E45vz5j+NasAYAAAAAAAAAIFnYEgQAAAAAAAAAEAks\nWAMAAAAAAAAAIoEFawAAAAAAAABAJLBgDQAAAAAAAACIBBasAQAAAAAAAACRwII1AAAAAAAAACAS\nWLAGAAAAAAAAAEQCC9YAAAAAAAAAgEhgwRoAAAAAAAAAEAn/D4KC5ehFLTFQAAAAAElFTkSuQmCC\n", 199 | "text/plain": [ 200 | "" 201 | ] 202 | }, 203 | "metadata": {}, 204 | "output_type": "display_data" 205 | } 206 | ], 207 | "source": [ 208 | "fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))\n", 209 | "in_imgs = mnist.test.images[:10]\n", 210 | "reconstructed, compressed = sess.run([decoded, encoded], feed_dict={inputs_: in_imgs})\n", 211 | "\n", 212 | "for images, row in zip([in_imgs, reconstructed], axes):\n", 213 | " for img, ax in zip(images, row):\n", 214 | " ax.imshow(img.reshape((28, 28)), cmap='Greys_r')\n", 215 | " ax.get_xaxis().set_visible(False)\n", 216 | " ax.get_yaxis().set_visible(False)\n", 217 | "\n", 218 | "fig.tight_layout(pad=0.1)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 9, 224 | "metadata": { 225 | "collapsed": true 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "sess.close()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "collapsed": true 236 | }, 237 | "source": [ 238 | "## Up Next\n", 239 | "\n", 240 | "We're dealing with images here, so we can (usually) get better performance using convolution layers. So, next we'll build a better autoencoder with convolutional layers.\n", 241 | "\n", 242 | "In practice, autoencoders aren't actually better at compression compared to typical methods like JPEGs and MP3s. But, they are being used for noise reduction, which you'll also build." 
243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "Python 3", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.6.1" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 2 267 | } 268 | -------------------------------------------------------------------------------- /autoencoder/Simple_Autoencoder_So.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# A Simple Autoencoder\n", 8 | "\n", 9 | "We'll start off by building a simple autoencoder to compress the MNIST dataset. With autoencoders, we pass input data through an encoder that makes a compressed representation of the input. Then, this representation is passed through a decoder to reconstruct the input data. Generally the encoder and decoder will be built with neural networks, then trained on example data.\n", 10 | "\n", 11 | "![Autoencoder](assets/autoencoder_1.png)\n", 12 | "\n", 13 | "In this notebook, we'll build a simple network architecture for the encoder and decoder. Let's get started by importing our libraries and getting the dataset." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "%matplotlib inline\n", 25 | "\n", 26 | "import numpy as np\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n", 41 | "Extracting MNIST_data/train-images-idx3-ubyte.gz\n", 42 | "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n", 43 | "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", 44 | "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n", 45 | "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", 46 | "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n", 47 | "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "from tensorflow.examples.tutorials.mnist import input_data\n", 53 | "mnist = input_data.read_data_sets('MNIST_data', validation_size=0)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Below I'm plotting an example image from the MNIST dataset. These are 28x28 grayscale images of handwritten digits." 
61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "" 72 | ] 73 | }, 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | }, 78 | { 79 | "data": { 80 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADP9JREFUeJzt3V+IXPUZxvHnSfwHieCf4BJtMBGkKkFTWMR/lGibajUS\nvYiYi5JSdXvRSgsVKulFhVqQYlq8ErYkGkuNKRjJEsSgoZgWqyQRTaI2idUUs8akMWLthdQkby/m\nRLZx58xm5syc2X2/H1h25rxz5rwc9tnfOXNm5ueIEIB8ptXdAIB6EH4gKcIPJEX4gaQIP5AU4QeS\nIvxAUoQfSIrwA0md1suN2ebthECXRYQn8riORn7bt9jebftd2w928lwAesvtvrff9nRJeyQtkrRf\n0lZJyyLi7ZJ1GPmBLuvFyH+1pHcj4r2I+K+kZyQt6eD5APRQJ+G/SNIHY+7vL5b9H9tDtrfZ3tbB\ntgBUrOsv+EXEsKRhicN+oJ90MvKPSpoz5v7XimUAJoFOwr9V0qW259k+Q9LdkkaqaQtAt7V92B8R\nR23/WNImSdMlrY6ItyrrDEBXtX2pr62Ncc4PdF1P3uQDYPIi/EBShB9IivADSRF+ICnCDyRF+IGk\nCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiB\npAg/kBThB5Ii/EBShB9IivADSRF+IKm2p+iWJNv7JH0m6ZikoxExWEVTQBWWLl3atPbEE0+Urnv9\n9deX1t988822euonHYW/cGNEHK7geQD0EIf9QFKdhj8kvWR7u+2hKhoC0BudHvbfEBGjti+Q9KLt\nv0fElrEPKP4p8I8B6DMdjfwRMVr8PiTpOUlXj/OY4YgY5MVAoL+0HX7bM2yffeK2pO9I2lVVYwC6\nq5PD/gFJz9k+8TxPR8QLlXQFoOvaDn9EvCfpqgp76aolS5aU1mfNmlVaX7VqVZXtoAeuueaaprW9\ne/f2sJP+xKU+ICnCDyRF+IGkCD+QFOEHkiL8QFJVfKpvUli0aFFpff78+aV1LvX1n2nTyseuyy67\nrGltYGCgdN3i/StTGiM/kBThB5Ii/EBShB9IivADSRF+ICnCDyTliOjdxuzebewkH3/8cWl9586d\npfWFCxdW2A2qcPHFF5fW33///aa1l19+uXTdG2+8sa2e+kFETOhNCoz8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BUms/zt/rsNyafkZGRttfdtYv5ZUgEkBThB5Ii/EBShB9IivADSRF+ICnCDyTV8jq/\n7dWSFks6FBHzi2XnSVonaa6kfZLuiohPutdma2XTMUvSjBkzetQJemXmzJltr7tx48YKO5mcJjLy\nPynplpOWPShpc0RcKmlzcR/AJNIy/BGxRdKRkxYvkbSmuL1G0h0V9wWgy9o95x+IiAPF7Y8klc99\nBKDvdPze/oiIsu/msz0kaajT7QCoVrsj/0HbsyWp+H2o2QMjYjgiBiNisM1tAeiCdsM/Iml5cXu5\npA3VtAOgV1qG3/ZaSX+T9HXb+23fI+kRSYts75X07eI+gEmk5Tl/RCxrUvpWxb10ZOnSpaX1005L\n89UFU8aFF15YWr/gggvafu49e/a0ve5UwTv8gKQIP5AU4QeSIvxAUoQfSIrwA0lNmetfV111VUfr\nb9++vaJOUJWnn366tN7qY9qHDx9uWvv000/b6mkqYeQHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaSm\nzHX+Tr366qt1tzApnXPOOaX1ZcuafSJcuvfee0vXvfLKK9vq6YSHH364ae3IkZO/kzYfRn4gKcIP\nJEX4gaQIP5AU4QeSIvxAUoQfSIrr/IXzzz+/tm1fd911pfXp06eX1hcvXty0Nm/evNJ1zzzzzNL6\nzTffXFq3XVo/evRo09ru3btL1z127Fhpfdq08rFry5YtpfXsGPmBpAg/kBThB5Ii/EBShB9IivAD\nSRF+IClHRPkD7NWSFks6FBHzi2UPSbpP0r+Kh62IiOdbbswu31gHNmzYUFq//fbbS+uff/55ab2b\nn/9uNRV1K8ePH29a++KLL0rX/fDDD0vrW7duLa2/8sorpfWRkZGmtdHR0dJ1P/nkk9L6WWedVVrP\nOi17RJS/+aIwkZH/SUm3jLP8dxGxoPhpGXwA/aVl+CNiiyS+9gSYYjo557/f9g7bq22fW1lHAHqi\n3fA/LukSSQskHZC0stkDbQ/Z3mZ7W5vbAtAFbYU/Ig5GxLGIOC7p95KuLnnscEQMRsRgu00CqF5b\n4bc9e8zdOyXtqqYdAL3S8lqI7bWSFkqaZXu/pF9KWmh7gaSQtE/SD7vYI4AuaHmdv9KNdfE6fyuP\nPvpoaX3hwoW9aaQN69atK63v2LGjaW3Tpk1Vt1OZFStWlNbLvndfav0+gDq/o6FOVV7nBzAFEX4g\nKcIPJEX4gaQIP5AU4QeSSvOZxwceeKDuFnCS2267raP1N27cWFEnOTHyA0kRfiApwg8kRfiBpAg/\nkBThB5Ii/EBSaa7zY+pZu3Zt3S1Maoz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQf\nSIrwA0kRfiApwg8kRfiBpAg/kFTLz/PbniPpKUkDkkLScEQ8Zvs8SeskzZW0T9JdEVE+ZzJwCuzy\nmaYvv/zy0voLL7xQZTtTzkRG/qOSfhYRV0i6RtKPbF8h6UFJmyPiUkmbi/sAJomW4Y+IAxHxenH7\nM0nvSLpI0hJJa4qHrZF0R7eaBFC9Uzrntz1X0jckvSZpICIOFKWP1DgtADBJTPg7/GzPlPSspJ9G\nxL/Hno9FRNiOJusNSRrqtFEA1ZrQyG/7dDWC/8eIWF8sPmh7dlGfLenQeOtGxHBEDEbEYBUNA6hG\ny/C7McSvkvRORPx2TGlE0vLi9nJJG6pvD0C3TOSw/3pJ35O00/YbxbIVkh6R9Cfb90j6p6S7utMi\nsooY90zyS9Om8TaVTrQMf0T8VVKzC67fqrYdAL3Cv04gKcIPJEX4gaQIP5AU4QeSIvxAUkzRjUnr\npptuKq2vXLmyR51MToz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU1/nRt1p9dTc6w8gPJEX4gaQI\nP5AU4QeSIvxAUoQfSIrwA0lxnR+1Wb9+fWn92muv7VE
nOTHyA0kRfiApwg8kRfiBpAg/kBThB5Ii\n/EBSbjUHuu05kp6SNCApJA1HxGO2H5J0n6R/FQ9dERHPt3iu8o0B6FhETOiLECYS/tmSZkfE67bP\nlrRd0h2S7pL0n4h4dKJNEX6g+yYa/pbv8IuIA5IOFLc/s/2OpIs6aw9A3U7pnN/2XEnfkPRaseh+\n2ztsr7Z9bpN1hmxvs72to04BVKrlYf+XD7RnSnpZ0q8jYr3tAUmH1Xgd4FdqnBr8oMVzcNgPdFll\n5/ySZPt0SRslbYqI345TnytpY0TMb/E8hB/osomGv+VhvxtfobpK0jtjg1+8EHjCnZJ2nWqTAOoz\nkVf7b5D0F0k7JR0vFq+QtEzSAjUO+/dJ+mHx4mDZczHyA11W6WF/VQg/0H2VHfYDmJoIP5AU4QeS\nIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSfV6iu7Dkv455v6sYlk/6tfe\n+rUvid7aVWVvF0/0gT39PP9XNm5vi4jB2hoo0a+99WtfEr21q67eOOwHkiL8QFJ1h3+45u2X6dfe\n+rUvid7aVUtvtZ7zA6hP3SM/gJrUEn7bt9jebftd2w/W0UMztvfZ3mn7jbqnGCumQTtke9eYZefZ\nftH23uL3uNOk1dTbQ7ZHi333hu1ba+ptju0/237b9lu2f1Isr3XflfRVy37r+WG/7emS9khaJGm/\npK2SlkXE2z1tpAnb+yQNRkTt14Rtf1PSfyQ9dWI2JNu/kXQkIh4p/nGeGxE/75PeHtIpztzcpd6a\nzSz9fdW476qc8boKdYz8V0t6NyLei4j/SnpG0pIa+uh7EbFF0pGTFi+RtKa4vUaNP56ea9JbX4iI\nAxHxenH7M0knZpaudd+V9FWLOsJ/kaQPxtzfr/6a8jskvWR7u+2hupsZx8CYmZE+kjRQZzPjaDlz\ncy+dNLN03+y7dma8rhov+H3VDRGxQNJ3Jf2oOLztS9E4Z+unyzWPS7pEjWncDkhaWWczxczSz0r6\naUT8e2ytzn03Tl+17Lc6wj8qac6Y+18rlvWFiBgtfh+S9Jwapyn95OCJSVKL34dq7udLEXEwIo5F\nxHFJv1eN+66YWfpZSX+MiPXF4tr33Xh91bXf6gj/VkmX2p5n+wxJd0saqaGPr7A9o3ghRrZnSPqO\n+m/24RFJy4vbyyVtqLGX/9MvMzc3m1laNe+7vpvxOiJ6/iPpVjVe8f+HpF/U0UOTvi6R9Gbx81bd\nvUlaq8Zh4BdqvDZyj6TzJW2WtFfSS5LO66Pe/qDGbM471Aja7Jp6u0GNQ/odkt4ofm6te9+V9FXL\nfuMdfkBSvOAHJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCp/wE+Awqah6Q+0AAAAABJRU5ErkJg\ngg==\n", 81 | "text/plain": [ 82 | "" 83 | ] 84 | }, 85 | "metadata": {}, 86 | "output_type": "display_data" 87 | } 88 | ], 89 | "source": [ 90 | "img = mnist.train.images[2]\n", 91 | "plt.imshow(img.reshape((28, 28)), cmap='Greys_r')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "We'll train an autoencoder with these images by flattening them into 784 length vectors. The images from this dataset are already normalized such that the values are between 0 and 1. Let's start by building basically the simplest autoencoder with a **single ReLU hidden layer**. This layer will be used as the compressed representation. Then, the encoder is the input layer and the hidden layer. The decoder is the hidden layer and the output layer. Since the images are normalized between 0 and 1, we need to use a **sigmoid activation on the output layer** to get values matching the input.\n", 99 | "\n", 100 | "![Autoencoder architecture](assets/simple_autoencoder.png)\n", 101 | "\n", 102 | "\n", 103 | "> **Exercise:** Build the graph for the autoencoder in the cell below. The input images will be flattened into 784 length vectors. The targets are the same as the inputs. And there should be one hidden layer with a ReLU activation and an output layer with a sigmoid activation. The loss should be calculated with the cross-entropy loss, there is a convenient TensorFlow function for this `tf.nn.sigmoid_cross_entropy_with_logits` ([documentation](https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits)). You should note that `tf.nn.sigmoid_cross_entropy_with_logits` takes the logits, but to get the reconstructed images you'll need to pass the logits through the sigmoid function." 
104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "# Size of the encoding layer (the hidden layer)\n", 115 | "encoding_dim = 32\n", 116 | "\n", 117 | "image_size = mnist.train.images.shape[1]\n", 118 | "\n", 119 | "inputs_ = tf.placeholder(tf.float32, (None, image_size), name='inputs')\n", 120 | "targets_ = tf.placeholder(tf.float32, (None, image_size), name='targets')\n", 121 | "\n", 122 | "# Output of hidden layer\n", 123 | "encoded = tf.layers.dense(inputs_, encoding_dim, activation=tf.nn.relu)\n", 124 | "\n", 125 | "# Output layer logits\n", 126 | "logits = tf.layers.dense(encoded, image_size, activation=None)\n", 127 | "# Sigmoid output from the logits\n", 128 | "decoded = tf.nn.sigmoid(logits, name='output')\n", 129 | "\n", 130 | "loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_, logits=logits)\n", 131 | "cost = tf.reduce_mean(loss)\n", 132 | "opt = tf.train.AdamOptimizer(0.001).minimize(cost)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Training" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 5, 145 | "metadata": { 146 | "collapsed": true 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "# Create the session\n", 151 | "sess = tf.Session()" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "Here I'll write a bit of code to train the network. I'm not too interested in validation here, so I'll just monitor the training loss and the test loss afterwards. \n", 159 | "\n", 160 | "Calling `mnist.train.next_batch(batch_size)` will return a tuple of `(images, labels)`. We're not concerned with the labels here; we just need the images. Otherwise this is pretty straightforward training with TensorFlow. We initialize the variables with `sess.run(tf.global_variables_initializer())`. Then, run the optimizer and get the loss with `batch_cost, _ = sess.run([cost, opt], feed_dict=feed)`." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "epochs = 20\n", 170 | "batch_size = 200\n", 171 | "sess.run(tf.global_variables_initializer())\n", 172 | "for e in range(epochs):\n", 173 | " for ii in range(mnist.train.num_examples//batch_size):\n", 174 | " batch = mnist.train.next_batch(batch_size)\n", 175 | " feed = {inputs_: batch[0], targets_: batch[0]}\n", 176 | " batch_cost, _ = sess.run([cost, opt], feed_dict=feed)\n", 177 | "\n", 178 | " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", 179 | " \"Training loss: {:.4f}\".format(batch_cost))" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "## Checking out the results\n", 187 | "\n", 188 | "Below I've plotted some of the test images along with their reconstructions. For the most part these look pretty good except for some blurriness in some parts." 
189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "image/png": "<base64-encoded PNG output omitted: 2x10 grid, top row shows ten test digits, bottom row shows their reconstructions>", 199 | "text/plain": [ 200 | "" 201 | ] 202 | }, 203 | "metadata": {}, 204 | "output_type": "display_data" 205 | } 206 | ], 207 | "source": [ 208 | "fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))\n", 209 | "in_imgs = mnist.test.images[:10]\n", 210 | "reconstructed, compressed = sess.run([decoded, encoded], feed_dict={inputs_: in_imgs})\n", 211 | "\n", 212 | "for images, row in zip([in_imgs, reconstructed], axes):\n", 213 | "    for img, ax in zip(images, row):\n", 214 | "        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')\n", 215 | "        ax.get_xaxis().set_visible(False)\n", 216 | "        ax.get_yaxis().set_visible(False)\n", 217 | "\n", 218 | "fig.tight_layout(pad=0.1)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 9, 224 | "metadata": { 225 | "collapsed": true 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "sess.close()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "collapsed": true 236 | }, 237 | "source": [ 238 | "## Up Next\n", 239 | "\n", 240 | "We're dealing with images here, so we can (usually) get better performance using convolution layers. So, next we'll build a better autoencoder with convolutional layers.\n", 241 | "\n", 242 | "In practice, autoencoders aren't actually better at compression compared to typical methods like JPEGs and MP3s. But, they are being used for noise reduction, which you'll also build." 
243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "Python 3", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.6.1" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 2 267 | } 268 | -------------------------------------------------------------------------------- /autoencoder/assets/autoencoder_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/autoencoder/assets/autoencoder_1.png -------------------------------------------------------------------------------- /autoencoder/assets/compressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/autoencoder/assets/compressed.png -------------------------------------------------------------------------------- /autoencoder/assets/convolutional_autoencoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/autoencoder/assets/convolutional_autoencoder.png -------------------------------------------------------------------------------- /autoencoder/assets/denoising.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/autoencoder/assets/denoising.png -------------------------------------------------------------------------------- /autoencoder/assets/mnist_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/autoencoder/assets/mnist_examples.png -------------------------------------------------------------------------------- /autoencoder/assets/simple_autoencoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/autoencoder/assets/simple_autoencoder.png -------------------------------------------------------------------------------- /dcgan-svhn/1511.06434.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/dcgan-svhn/1511.06434.pdf -------------------------------------------------------------------------------- /dcgan-svhn/assets/32x32eg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/dcgan-svhn/assets/32x32eg.png -------------------------------------------------------------------------------- /dcgan-svhn/assets/SVHN_examples.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/dcgan-svhn/assets/SVHN_examples.png -------------------------------------------------------------------------------- /dcgan-svhn/assets/dcgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/dcgan-svhn/assets/dcgan.png -------------------------------------------------------------------------------- /dcgan-svhn/assets/svhn_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/dcgan-svhn/assets/svhn_gan.png -------------------------------------------------------------------------------- /dcgan-svhn/checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | 4 | -------------------------------------------------------------------------------- /dcgan-svhn/data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /gan_mnist/.gitignore: -------------------------------------------------------------------------------- 1 | MNIST_data 2 | train_samples.pkl 3 | -------------------------------------------------------------------------------- /gan_mnist/assets/gan_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/gan_mnist/assets/gan_diagram.png -------------------------------------------------------------------------------- /gan_mnist/assets/gan_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/gan_mnist/assets/gan_network.png -------------------------------------------------------------------------------- /gan_mnist/checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | # from . import twitter 5 | # from . import imagenet_classes 6 | # from . 
import 7 | -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq-twitter-chatbot/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/cornell_corpus/data.py: -------------------------------------------------------------------------------- 1 | EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist 2 | EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\'' 3 | 4 | limit = { 5 | 'maxq' : 25, 6 | 'minq' : 2, 7 | 'maxa' : 25, 8 | 'mina' : 2 9 | } 10 | 11 | UNK = 'unk' 12 | VOCAB_SIZE = 8000 13 | 14 | 15 | import random 16 | 17 | import nltk 18 | import itertools 19 | from collections import defaultdict 20 | 21 | import numpy as np 22 | 23 | import pickle 24 | 25 | 26 | 27 | ''' 28 | 1. Read from 'movie-lines.txt' 29 | 2. Create a dictionary with ( key = line_id, value = text ) 30 | ''' 31 | def get_id2line(): 32 | lines=open('raw_data/movie_lines.txt', encoding='utf-8', errors='ignore').read().split('\n') 33 | id2line = {} 34 | for line in lines: 35 | _line = line.split(' +++$+++ ') 36 | if len(_line) == 5: 37 | id2line[_line[0]] = _line[4] 38 | return id2line 39 | 40 | ''' 41 | 1. Read from 'movie_conversations.txt' 42 | 2. Create a list of [list of line_id's] 43 | ''' 44 | def get_conversations(): 45 | conv_lines = open('raw_data/movie_conversations.txt', encoding='utf-8', errors='ignore').read().split('\n') 46 | convs = [ ] 47 | for line in conv_lines[:-1]: 48 | _line = line.split(' +++$+++ ')[-1][1:-1].replace("'","").replace(" ","") 49 | convs.append(_line.split(',')) 50 | return convs 51 | 52 | ''' 53 | 1. Get each conversation 54 | 2. Get each line from conversation 55 | 3. Save each conversation to file 56 | ''' 57 | def extract_conversations(convs,id2line,path=''): 58 | idx = 0 59 | for conv in convs: 60 | f_conv = open(path + str(idx)+'.txt', 'w') 61 | for line_id in conv: 62 | f_conv.write(id2line[line_id]) 63 | f_conv.write('\n') 64 | f_conv.close() 65 | idx += 1 66 | 67 | ''' 68 | Get lists of all conversations as Questions and Answers 69 | 1. [questions] 70 | 2. [answers] 71 | ''' 72 | def gather_dataset(convs, id2line): 73 | questions = []; answers = [] 74 | 75 | for conv in convs: 76 | if len(conv) %2 != 0: 77 | conv = conv[:-1] 78 | for i in range(len(conv)): 79 | if i%2 == 0: 80 | questions.append(id2line[conv[i]]) 81 | else: 82 | answers.append(id2line[conv[i]]) 83 | 84 | return questions, answers 85 | 86 | 87 | ''' 88 | We need 4 files 89 | 1. train.enc : Encoder input for training 90 | 2. train.dec : Decoder input for training 91 | 3. test.enc : Encoder input for testing 92 | 4. 
test.dec : Decoder input for testing 93 | ''' 94 | def prepare_seq2seq_files(questions, answers, path='',TESTSET_SIZE = 30000): 95 | 96 | # open files 97 | train_enc = open(path + 'train.enc','w') 98 | train_dec = open(path + 'train.dec','w') 99 | test_enc = open(path + 'test.enc', 'w') 100 | test_dec = open(path + 'test.dec', 'w') 101 | 102 | # choose 30,000 (TESTSET_SIZE) items to put into testset 103 | test_ids = random.sample([i for i in range(len(questions))],TESTSET_SIZE) 104 | 105 | for i in range(len(questions)): 106 | if i in test_ids: 107 | test_enc.write(questions[i]+'\n') 108 | test_dec.write(answers[i]+ '\n' ) 109 | else: 110 | train_enc.write(questions[i]+'\n') 111 | train_dec.write(answers[i]+ '\n' ) 112 | if i%10000 == 0: 113 | print('\n>> written {} lines'.format(i)) 114 | 115 | # close files 116 | train_enc.close() 117 | train_dec.close() 118 | test_enc.close() 119 | test_dec.close() 120 | 121 | 122 | 123 | ''' 124 | remove anything that isn't in the vocabulary 125 | return str(pure en) 126 | 127 | ''' 128 | def filter_line(line, whitelist): 129 | return ''.join([ ch for ch in line if ch in whitelist ]) 130 | 131 | 132 | 133 | ''' 134 | filter too long and too short sequences 135 | return tuple( filtered_ta, filtered_en ) 136 | 137 | ''' 138 | def filter_data(qseq, aseq): 139 | filtered_q, filtered_a = [], [] 140 | raw_data_len = len(qseq) 141 | 142 | assert len(qseq) == len(aseq) 143 | 144 | for i in range(raw_data_len): 145 | qlen, alen = len(qseq[i].split(' ')), len(aseq[i].split(' ')) 146 | if qlen >= limit['minq'] and qlen <= limit['maxq']: 147 | if alen >= limit['mina'] and alen <= limit['maxa']: 148 | filtered_q.append(qseq[i]) 149 | filtered_a.append(aseq[i]) 150 | 151 | # print the fraction of the original data, filtered 152 | filt_data_len = len(filtered_q) 153 | filtered = int((raw_data_len - filt_data_len)*100/raw_data_len) 154 | print(str(filtered) + '% filtered from original data') 155 | 156 | return filtered_q, filtered_a 157 | 158 | 159 | ''' 160 | read list of words, create index to word, 161 | word to index dictionaries 162 | return tuple( vocab->(word, count), idx2w, w2idx ) 163 | 164 | ''' 165 | def index_(tokenized_sentences, vocab_size): 166 | # get frequency distribution 167 | freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences)) 168 | # get vocabulary of 'vocab_size' most used words 169 | vocab = freq_dist.most_common(vocab_size) 170 | # index2word 171 | index2word = ['_'] + [UNK] + [ x[0] for x in vocab ] 172 | # word2index 173 | word2index = dict([(w,i) for i,w in enumerate(index2word)] ) 174 | return index2word, word2index, freq_dist 175 | 176 | ''' 177 | filter based on number of unknowns (words not in vocabulary) 178 | filter out the worst sentences 179 | 180 | ''' 181 | def filter_unk(qtokenized, atokenized, w2idx): 182 | data_len = len(qtokenized) 183 | 184 | filtered_q, filtered_a = [], [] 185 | 186 | for qline, aline in zip(qtokenized, atokenized): 187 | unk_count_q = len([ w for w in qline if w not in w2idx ]) 188 | unk_count_a = len([ w for w in aline if w not in w2idx ]) 189 | if unk_count_a <= 2: 190 | if unk_count_q > 0: 191 | if unk_count_q/len(qline) > 0.2: 192 | pass 193 | filtered_q.append(qline) 194 | filtered_a.append(aline) 195 | 196 | # print the fraction of the original data, filtered 197 | filt_data_len = len(filtered_q) 198 | filtered = int((data_len - filt_data_len)*100/data_len) 199 | print(str(filtered) + '% filtered from original data') 200 | 201 | return filtered_q, filtered_a 202 | 203 | 204 | 205 | 206 | 
''' 207 | create the final dataset : 208 | - convert list of items to arrays of indices 209 | - add zero padding 210 | return ( [array_en([indices]), array_ta([indices]) ) 211 | 212 | ''' 213 | def zero_pad(qtokenized, atokenized, w2idx): 214 | # num of rows 215 | data_len = len(qtokenized) 216 | 217 | # numpy arrays to store indices 218 | idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32) 219 | idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32) 220 | 221 | for i in range(data_len): 222 | q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq']) 223 | a_indices = pad_seq(atokenized[i], w2idx, limit['maxa']) 224 | 225 | #print(len(idx_q[i]), len(q_indices)) 226 | #print(len(idx_a[i]), len(a_indices)) 227 | idx_q[i] = np.array(q_indices) 228 | idx_a[i] = np.array(a_indices) 229 | 230 | return idx_q, idx_a 231 | 232 | 233 | ''' 234 | replace words with indices in a sequence 235 | replace with unknown if word not in lookup 236 | return [list of indices] 237 | 238 | ''' 239 | def pad_seq(seq, lookup, maxlen): 240 | indices = [] 241 | for word in seq: 242 | if word in lookup: 243 | indices.append(lookup[word]) 244 | else: 245 | indices.append(lookup[UNK]) 246 | return indices + [0]*(maxlen - len(seq)) 247 | 248 | 249 | 250 | 251 | 252 | def process_data(): 253 | 254 | id2line = get_id2line() 255 | print('>> gathered id2line dictionary.\n') 256 | convs = get_conversations() 257 | print(convs[121:125]) 258 | print('>> gathered conversations.\n') 259 | questions, answers = gather_dataset(convs,id2line) 260 | 261 | # change to lower case (just for en) 262 | questions = [ line.lower() for line in questions ] 263 | answers = [ line.lower() for line in answers ] 264 | 265 | # filter out unnecessary characters 266 | print('\n>> Filter lines') 267 | questions = [ filter_line(line, EN_WHITELIST) for line in questions ] 268 | answers = [ filter_line(line, EN_WHITELIST) for line in answers ] 269 | 270 | # filter out too long or too short sequences 271 | print('\n>> 2nd layer of filtering') 272 | qlines, alines = filter_data(questions, answers) 273 | 274 | for q,a in zip(qlines[141:145], alines[141:145]): 275 | print('q : [{0}]; a : [{1}]'.format(q,a)) 276 | 277 | # convert list of [lines of text] into list of [list of words ] 278 | print('\n>> Segment lines into words') 279 | qtokenized = [ [w.strip() for w in wordlist.split(' ') if w] for wordlist in qlines ] 280 | atokenized = [ [w.strip() for w in wordlist.split(' ') if w] for wordlist in alines ] 281 | print('\n:: Sample from segmented list of words') 282 | 283 | for q,a in zip(qtokenized[141:145], atokenized[141:145]): 284 | print('q : [{0}]; a : [{1}]'.format(q,a)) 285 | 286 | # indexing -> idx2w, w2idx 287 | print('\n >> Index words') 288 | idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE) 289 | 290 | # filter out sentences with too many unknowns 291 | print('\n >> Filter Unknowns') 292 | qtokenized, atokenized = filter_unk(qtokenized, atokenized, w2idx) 293 | print('\n Final dataset len : ' + str(len(qtokenized))) 294 | 295 | 296 | print('\n >> Zero Padding') 297 | idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx) 298 | 299 | print('\n >> Save numpy arrays to disk') 300 | # save them 301 | np.save('idx_q.npy', idx_q) 302 | np.save('idx_a.npy', idx_a) 303 | 304 | # let us now save the necessary dictionaries 305 | metadata = { 306 | 'w2idx' : w2idx, 307 | 'idx2w' : idx2w, 308 | 'limit' : limit, 309 | 'freq_dist' : freq_dist 310 | } 311 | 312 | # write to disk : data control dictionaries 313 | with 
open('metadata.pkl', 'wb') as f: 314 | pickle.dump(metadata, f) 315 | 316 | # count of unknowns 317 | unk_count = (idx_q == 1).sum() + (idx_a == 1).sum() 318 | # count of words 319 | word_count = (idx_q > 1).sum() + (idx_a > 1).sum() 320 | 321 | print('% unknown : {0}'.format(100 * (unk_count/word_count))) 322 | print('Dataset count : ' + str(idx_q.shape[0])) 323 | 324 | 325 | #print '>> gathered questions and answers.\n' 326 | #prepare_seq2seq_files(questions,answers) 327 | 328 | 329 | import numpy as np 330 | from random import sample 331 | 332 | ''' 333 | split data into train (70%), test (15%) and valid(15%) 334 | return tuple( (trainX, trainY), (testX,testY), (validX,validY) ) 335 | 336 | ''' 337 | def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ): 338 | # number of examples 339 | data_len = len(x) 340 | lens = [ int(data_len*item) for item in ratio ] 341 | 342 | trainX, trainY = x[:lens[0]], y[:lens[0]] 343 | testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]] 344 | validX, validY = x[-lens[-1]:], y[-lens[-1]:] 345 | 346 | return (trainX,trainY), (testX,testY), (validX,validY) 347 | 348 | 349 | ''' 350 | generate batches from dataset 351 | yield (x_gen, y_gen) 352 | 353 | TODO : fix needed 354 | 355 | ''' 356 | def batch_gen(x, y, batch_size): 357 | # infinite while 358 | while True: 359 | for i in range(0, len(x), batch_size): 360 | if (i+1)*batch_size < len(x): 361 | yield x[i : (i+1)*batch_size ].T, y[i : (i+1)*batch_size ].T 362 | 363 | ''' 364 | generate batches, by random sampling a bunch of items 365 | yield (x_gen, y_gen) 366 | 367 | ''' 368 | def rand_batch_gen(x, y, batch_size): 369 | while True: 370 | sample_idx = sample(list(np.arange(len(x))), batch_size) 371 | yield x[sample_idx].T, y[sample_idx].T 372 | 373 | #''' 374 | # convert indices of alphabets into a string (word) 375 | # return str(word) 376 | # 377 | #''' 378 | #def decode_word(alpha_seq, idx2alpha): 379 | # return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ]) 380 | # 381 | # 382 | #''' 383 | # convert indices of phonemes into list of phonemes (as string) 384 | # return str(phoneme_list) 385 | # 386 | #''' 387 | #def decode_phonemes(pho_seq, idx2pho): 388 | # return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ]) 389 | 390 | 391 | ''' 392 | a generic decode function 393 | inputs : sequence, lookup 394 | 395 | ''' 396 | def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored 397 | return separator.join([ lookup[element] for element in sequence if element ]) 398 | 399 | 400 | 401 | if __name__ == '__main__': 402 | process_data() 403 | 404 | 405 | def load_data(PATH=''): 406 | # read data control dictionaries 407 | with open(PATH + 'metadata.pkl', 'rb') as f: 408 | metadata = pickle.load(f) 409 | # read numpy arrays 410 | idx_q = np.load(PATH + 'idx_q.npy') 411 | idx_a = np.load(PATH + 'idx_a.npy') 412 | return metadata, idx_q, idx_a 413 | -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq-twitter-chatbot/data/twitter/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/data.py: -------------------------------------------------------------------------------- 1 | 
EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist 2 | EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\'' 3 | 4 | FILENAME = 'data/chat.txt' 5 | 6 | limit = { 7 | 'maxq' : 20, 8 | 'minq' : 0, 9 | 'maxa' : 20, 10 | 'mina' : 3 11 | } 12 | 13 | UNK = 'unk' 14 | VOCAB_SIZE = 6000 15 | 16 | import random 17 | import sys 18 | 19 | import nltk 20 | import itertools 21 | from collections import defaultdict 22 | 23 | import numpy as np 24 | 25 | import pickle 26 | 27 | 28 | def ddefault(): 29 | return 1 30 | 31 | ''' 32 | read lines from file 33 | return [list of lines] 34 | 35 | ''' 36 | def read_lines(filename): 37 | return open(filename).read().split('\n')[:-1] 38 | 39 | 40 | ''' 41 | split sentences in one line 42 | into multiple lines 43 | return [list of lines] 44 | 45 | ''' 46 | def split_line(line): 47 | return line.split('.') 48 | 49 | 50 | ''' 51 | remove anything that isn't in the vocabulary 52 | return str(pure ta/en) 53 | 54 | ''' 55 | def filter_line(line, whitelist): 56 | return ''.join([ ch for ch in line if ch in whitelist ]) 57 | 58 | 59 | ''' 60 | read list of words, create index to word, 61 | word to index dictionaries 62 | return tuple( vocab->(word, count), idx2w, w2idx ) 63 | 64 | ''' 65 | def index_(tokenized_sentences, vocab_size): 66 | # get frequency distribution 67 | freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences)) 68 | # get vocabulary of 'vocab_size' most used words 69 | vocab = freq_dist.most_common(vocab_size) 70 | # index2word 71 | index2word = ['_'] + [UNK] + [ x[0] for x in vocab ] 72 | # word2index 73 | word2index = dict([(w,i) for i,w in enumerate(index2word)] ) 74 | return index2word, word2index, freq_dist 75 | 76 | 77 | ''' 78 | filter too long and too short sequences 79 | return tuple( filtered_ta, filtered_en ) 80 | 81 | ''' 82 | def filter_data(sequences): 83 | filtered_q, filtered_a = [], [] 84 | raw_data_len = len(sequences)//2 85 | 86 | for i in range(0, len(sequences), 2): 87 | qlen, alen = len(sequences[i].split(' ')), len(sequences[i+1].split(' ')) 88 | if qlen >= limit['minq'] and qlen <= limit['maxq']: 89 | if alen >= limit['mina'] and alen <= limit['maxa']: 90 | filtered_q.append(sequences[i]) 91 | filtered_a.append(sequences[i+1]) 92 | 93 | # print the fraction of the original data, filtered 94 | filt_data_len = len(filtered_q) 95 | filtered = int((raw_data_len - filt_data_len)*100/raw_data_len) 96 | print(str(filtered) + '% filtered from original data') 97 | 98 | return filtered_q, filtered_a 99 | 100 | 101 | 102 | 103 | 104 | ''' 105 | create the final dataset : 106 | - convert list of items to arrays of indices 107 | - add zero padding 108 | return ( [array_en([indices]), array_ta([indices]) ) 109 | 110 | ''' 111 | def zero_pad(qtokenized, atokenized, w2idx): 112 | # num of rows 113 | data_len = len(qtokenized) 114 | 115 | # numpy arrays to store indices 116 | idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32) 117 | idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32) 118 | 119 | for i in range(data_len): 120 | q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq']) 121 | a_indices = pad_seq(atokenized[i], w2idx, limit['maxa']) 122 | 123 | #print(len(idx_q[i]), len(q_indices)) 124 | #print(len(idx_a[i]), len(a_indices)) 125 | idx_q[i] = np.array(q_indices) 126 | idx_a[i] = np.array(a_indices) 127 | 128 | return idx_q, idx_a 129 | 130 | 131 | ''' 132 | replace words with indices in a sequence 133 | replace with unknown if word not in lookup 134 | return [list of indices] 
135 | 136 | ''' 137 | def pad_seq(seq, lookup, maxlen): 138 | indices = [] 139 | for word in seq: 140 | if word in lookup: 141 | indices.append(lookup[word]) 142 | else: 143 | indices.append(lookup[UNK]) 144 | return indices + [0]*(maxlen - len(seq)) 145 | 146 | 147 | def process_data(): 148 | 149 | print('\n>> Read lines from file') 150 | lines = read_lines(filename=FILENAME) 151 | 152 | # change to lower case (just for en) 153 | lines = [ line.lower() for line in lines ] 154 | 155 | print('\n:: Sample from read(p) lines') 156 | print(lines[121:125]) 157 | 158 | # filter out unnecessary characters 159 | print('\n>> Filter lines') 160 | lines = [ filter_line(line, EN_WHITELIST) for line in lines ] 161 | print(lines[121:125]) 162 | 163 | # filter out too long or too short sequences 164 | print('\n>> 2nd layer of filtering') 165 | qlines, alines = filter_data(lines) 166 | print('\nq : {0} ; a : {1}'.format(qlines[60], alines[60])) 167 | print('\nq : {0} ; a : {1}'.format(qlines[61], alines[61])) 168 | 169 | 170 | # convert list of [lines of text] into list of [list of words ] 171 | print('\n>> Segment lines into words') 172 | qtokenized = [ wordlist.split(' ') for wordlist in qlines ] 173 | atokenized = [ wordlist.split(' ') for wordlist in alines ] 174 | print('\n:: Sample from segmented list of words') 175 | print('\nq : {0} ; a : {1}'.format(qtokenized[60], atokenized[60])) 176 | print('\nq : {0} ; a : {1}'.format(qtokenized[61], atokenized[61])) 177 | 178 | 179 | # indexing -> idx2w, w2idx : en/ta 180 | print('\n >> Index words') 181 | idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE) 182 | 183 | print('\n >> Zero Padding') 184 | idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx) 185 | 186 | print('\n >> Save numpy arrays to disk') 187 | # save them 188 | np.save('idx_q.npy', idx_q) 189 | np.save('idx_a.npy', idx_a) 190 | 191 | # let us now save the necessary dictionaries 192 | metadata = { 193 | 'w2idx' : w2idx, 194 | 'idx2w' : idx2w, 195 | 'limit' : limit, 196 | 'freq_dist' : freq_dist 197 | } 198 | 199 | # write to disk : data control dictionaries 200 | with open('metadata.pkl', 'wb') as f: 201 | pickle.dump(metadata, f) 202 | 203 | def load_data(PATH=''): 204 | # read data control dictionaries 205 | try: 206 | with open(PATH + 'metadata.pkl', 'rb') as f: 207 | metadata = pickle.load(f) 208 | except: 209 | metadata = None 210 | # read numpy arrays 211 | idx_q = np.load(PATH + 'idx_q.npy') 212 | idx_a = np.load(PATH + 'idx_a.npy') 213 | return metadata, idx_q, idx_a 214 | 215 | import numpy as np 216 | from random import sample 217 | 218 | ''' 219 | split data into train (70%), test (15%) and valid(15%) 220 | return tuple( (trainX, trainY), (testX,testY), (validX,validY) ) 221 | 222 | ''' 223 | def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ): 224 | # number of examples 225 | data_len = len(x) 226 | lens = [ int(data_len*item) for item in ratio ] 227 | 228 | trainX, trainY = x[:lens[0]], y[:lens[0]] 229 | testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]] 230 | validX, validY = x[-lens[-1]:], y[-lens[-1]:] 231 | 232 | return (trainX,trainY), (testX,testY), (validX,validY) 233 | 234 | 235 | ''' 236 | generate batches from dataset 237 | yield (x_gen, y_gen) 238 | 239 | TODO : fix needed 240 | 241 | ''' 242 | def batch_gen(x, y, batch_size): 243 | # infinite while 244 | while True: 245 | for i in range(0, len(x), batch_size): 246 | if (i+1)*batch_size < len(x): 247 | yield x[i : (i+1)*batch_size ].T, y[i : (i+1)*batch_size ].T 248 
| 249 | ''' 250 | generate batches, by random sampling a bunch of items 251 | yield (x_gen, y_gen) 252 | 253 | ''' 254 | def rand_batch_gen(x, y, batch_size): 255 | while True: 256 | sample_idx = sample(list(np.arange(len(x))), batch_size) 257 | yield x[sample_idx].T, y[sample_idx].T 258 | 259 | #''' 260 | # convert indices of alphabets into a string (word) 261 | # return str(word) 262 | # 263 | #''' 264 | #def decode_word(alpha_seq, idx2alpha): 265 | # return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ]) 266 | # 267 | # 268 | #''' 269 | # convert indices of phonemes into list of phonemes (as string) 270 | # return str(phoneme_list) 271 | # 272 | #''' 273 | #def decode_phonemes(pho_seq, idx2pho): 274 | # return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ]) 275 | 276 | 277 | ''' 278 | a generic decode function 279 | inputs : sequence, lookup 280 | 281 | ''' 282 | def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored 283 | return separator.join([ lookup[element] for element in sequence if element ]) 284 | 285 | 286 | 287 | if __name__ == '__main__': 288 | process_data() 289 | -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/idx_a.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq-twitter-chatbot/data/twitter/idx_a.npy -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/idx_q.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq-twitter-chatbot/data/twitter/idx_q.npy -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/metadata.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq-twitter-chatbot/data/twitter/metadata.pkl -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/pull: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget -c 'https://www.dropbox.com/s/tmfwptbs3q180p0/seq2seq.twitter.tar.gz?dl=0' -O seq2seq.twitter.tar.gz 4 | -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/data/twitter/pull_raw_data: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -c https://raw.githubusercontent.com/Marsan-Ma/chat_corpus/master/twitter_en.txt.gz 3 | -------------------------------------------------------------------------------- /seq2seq-twitter-chatbot/main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorlayer as tl 3 | from tensorlayer.layers import * 4 | import time 5 | 6 | 7 | ###========================== prepare data===========================================### 8 | 9 | # preprocessed data 10 | # from datasets.cornell_corpus import data 11 | # import data_utils 12 | # 13 | # # load data from pickle and npy files 14 | # metadata, idx_q, idx_a = data.load_data(PATH='datasets/cornell_corpus/') # 
PATH='datasets/cornell_corpus/' 15 | # (trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(idx_q, idx_a) 16 | from data.twitter import data 17 | metadata, idx_q, idx_a = data.load_data(PATH='data/twitter/') # Twitter 18 | # from data.cornell_corpus import data 19 | # metadata, idx_q, idx_a = data.load_data(PATH='data/cornell_corpus/') # Cornell Moive 20 | (trainX, trainY), (testX, testY), (validX, validY) = data.split_dataset(idx_q, idx_a) 21 | 22 | trainX = trainX.tolist() 23 | trainY = trainY.tolist() 24 | testX = testX.tolist() 25 | testY = testY.tolist() 26 | validX = validX.tolist() 27 | validY = validY.tolist() 28 | 29 | #print(trainX[1]) 30 | 31 | trainX = tl.prepro.remove_pad_sequences(trainX) 32 | trainY = tl.prepro.remove_pad_sequences(trainY) 33 | testX = tl.prepro.remove_pad_sequences(testX) 34 | testY = tl.prepro.remove_pad_sequences(testY) 35 | validX = tl.prepro.remove_pad_sequences(validX) 36 | validY = tl.prepro.remove_pad_sequences(validY) 37 | 38 | # parameters 39 | xseq_len = len(trainX) 40 | yseq_len = len(trainY) 41 | assert xseq_len == yseq_len 42 | 43 | batch_size = 32 44 | n_step = int(xseq_len/batch_size) 45 | xvocab_size = len(metadata['idx2w']) 46 | yvocab_size = xvocab_size 47 | emb_dim = 1024 48 | 49 | w2idx = metadata['w2idx'] # dict word 2 index 50 | idx2w = metadata['idx2w'] # list index 2 word 51 | 52 | unk_id = w2idx['unk'] # 1 53 | pad_id = w2idx['_'] # 0 54 | 55 | start_id = xvocab_size # 8002 56 | end_id = xvocab_size+1 # 8003 57 | 58 | w2idx.update({'start_id': start_id}) 59 | w2idx.update({'end_id': end_id}) 60 | idx2w = idx2w + ['start_id', 'end_id'] 61 | 62 | xvocab_size = yvocab_size = xvocab_size + 2 63 | 64 | 65 | """ A data for Seq2Seq should look like this: 66 | input_seqs : ['how', 'are', 'you', '] 67 | decode_seqs : ['', 'I', 'am', 'fine', '] 68 | target_seqs : ['I', 'am', 'fine', '', '] 69 | target_mask : [1, 1, 1, 1, 0] 70 | """ 71 | 72 | # show trainX[10] as an example 73 | print("encode_seqs", [idx2w[id] for id in trainX[10]]) 74 | target_seqs = tl.prepro.sequences_add_end_id([trainY[10]], end_id=end_id)[0] 75 | # target_seqs = tl.prepro.remove_pad_sequences([target_seqs], pad_id=pad_id)[0] 76 | print("target_seqs", [idx2w[id] for id in target_seqs]) 77 | decode_seqs = tl.prepro.sequences_add_start_id([trainY[10]], start_id=start_id, remove_last=False)[0] 78 | # decode_seqs = tl.prepro.remove_pad_sequences([decode_seqs], pad_id=pad_id)[0] 79 | print("decode_seqs", [idx2w[id] for id in decode_seqs]) 80 | target_mask = tl.prepro.sequences_get_mask([target_seqs])[0] 81 | print("target_mask", target_mask) 82 | print(len(target_seqs), len(decode_seqs), len(target_mask)) 83 | 84 | ###========================== prepare data===========================================### 85 | 86 | 87 | ###========================== main model===========================================### 88 | def model(encode_seqs, decode_seqs, is_train=True, reuse=False): 89 | with tf.variable_scope("model", reuse=reuse): 90 | with tf.variable_scope("embedding") as vs: 91 | net_encode = EmbeddingInputlayer( 92 | inputs = encode_seqs, 93 | vocabulary_size=xvocab_size, 94 | embedding_size=emb_dim, 95 | name = 'seq_embedding' 96 | ) 97 | vs.reuse_variables() 98 | tl.layers.set_name_reuse() #remove if TL version == 1.8.0+ 99 | 100 | net_decode = EmbeddingInputlayer( 101 | inputs = decode_seqs, 102 | vocabulary_size=xvocab_size, 103 | embedding_size=emb_dim, 104 | name = 'seq_embedding' 105 | ) 106 | net_rnn = Seq2Seq(net_encode, net_decode, 107 | 
cell_fn=tf.contrib.rnn.BasicLSTMCell, 108 | n_hidden=emb_dim, 109 | initializer=tf.random_uniform_initializer(-0.1, 0.1), 110 | encode_sequence_length=retrieve_seq_length_op2(encode_seqs), 111 | decode_sequence_length=retrieve_seq_length_op2(decode_seqs), 112 | initial_state_encode=None, 113 | initial_state_decode=None, 114 | dropout=(0.5 if is_train else None), 115 | n_layer=3, 116 | return_seq_2d=True, 117 | name='Seq2Seq') 118 | net_out = DenseLayer(net_rnn, n_units=xvocab_size, act=tf.identity, name='out') 119 | return net_out, net_rnn 120 | 121 | # model for training 122 | encode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='encode_seqs') 123 | decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='decode_seqs') 124 | target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='target_seqs') 125 | target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name='target_mask') 126 | net_out, _ = model(encode_seqs, decode_seqs, is_train=True, reuse=False) 127 | 128 | # model for inference 129 | encode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name='encode_seqs_2') 130 | decode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name='decode_seqs_2') 131 | net_infer, net_rnn = model(encode_seqs2, decode_seqs2, is_train=False, reuse=True) 132 | y = tf.nn.softmax(net_infer.outputs) 133 | 134 | # loss for training 135 | loss = tl.cost.cross_entropy_seq_with_mask(logits=net_out.outputs, target_seqs=target_seqs, input_mask=target_mask, 136 | return_details=False, name='cost') 137 | net_out.print_params(False) 138 | 139 | lr=0.0001 140 | train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss) 141 | 142 | # session 143 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) 144 | tl.layers.initialize_global_variables(sess) 145 | tl.files.load_and_assign_npz(sess=sess, name='n.npa', network=net_infer) 146 | 147 | # =====train========= 148 | n_epochs = 50 149 | for epoch in range(n_epochs): 150 | epoch_time = time.time() 151 | # shuffle training data 152 | from sklearn.utils import shuffle 153 | trainX, trainY = shuffle(trainX, trainY) 154 | # train an epoch 155 | total_err, n_iter = 0, 0 # total_error, #iterations 156 | for X, Y in tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=batch_size, shuffle=False): 157 | step_time = time.time() 158 | 159 | X = tl.prepro.pad_sequences(X) 160 | _target_seqs = tl.prepro.sequences_add_end_id(Y, end_id=end_id) 161 | _target_seqs = tl.prepro.pad_sequences(_target_seqs) 162 | 163 | _decode_seqs = tl.prepro.sequences_add_start_id(Y, start_id=start_id, remove_last=False) 164 | _decode_seqs = tl.prepro.pad_sequences(_decode_seqs) 165 | _target_mask = tl.prepro.sequences_get_mask(_target_seqs) 166 | 167 | _, err = sess.run([train_op, loss], 168 | feed_dict={encode_seqs: X, 169 | decode_seqs: _decode_seqs, 170 | target_seqs: _target_seqs, 171 | target_mask: _target_mask} 172 | ) 173 | 174 | if n_iter % 200 == 0: 175 | print("Epoch[%d/%d] step:[%d/%d] loss:%f took:%.5fs" % (epoch, n_epochs, n_iter, n_step, err, time.time() - step_time)) 176 | total_err += err 177 | n_iter += 1 178 | 179 | # ====inference=== 180 | if n_iter % 1000 == 0: 181 | seeds = ["happy birthday have a nice day", 182 | "so fun to see what the special effects team created for", 183 | "donald trump won last nights presidential debate according to snap online polls"] 184 | for seed in seeds: 185 | print('Query >', seed) 186 | seed_id = 
[w2idx[w] for w in seed.split(" ")] 187 | for _ in range(5): # 1 Query --> 5 reply 188 | # 1. encode, get state 189 | state = sess.run(net_rnn.final_state_encode, 190 | feed_dict={encode_seqs2: [seed_id]} 191 | ) 192 | # 2. decode, feed start_id, get first word 193 | o, state = sess.run([y, net_rnn.final_state_decode], 194 | feed_dict={net_rnn.initial_state_decode: state, 195 | decode_seqs2: [[start_id]]} 196 | ) 197 | w_id = tl.nlp.sample_top(o[0], top_k=3) 198 | w = idx2w[w_id] 199 | # 3. decode, feed state iteratively 200 | sentence = [w] 201 | for _ in range(30): # max_sequence_length = 30 202 | o, state = sess.run([y, net_rnn.final_state_decode], 203 | feed_dict={net_rnn.initial_state_decode:state, 204 | decode_seqs2: [[w_id]]} 205 | ) 206 | w_id = tl.nlp.sample_top(o[0], top_k=2) 207 | w = idx2w[w_id] 208 | if w_id == end_id: 209 | break 210 | sentence = sentence + [w] 211 | print('>', ' '.join(sentence)) 212 | 213 | print("Epoch[%d/%d] averaged loss:%f took:%.5fs" % (epoch, n_epochs, total_err / n_iter, time.time() - epoch_time)) 214 | tl.files.save_npz(net_infer.all_params, name='n.npz', sess=sess) 215 | -------------------------------------------------------------------------------- /seq2seq/.ipynb_checkpoints/sequence_to_sequence_implementation-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Character Sequence to Sequence \n", 10 | "In this notebook, we'll build a model that takes in a sequence of letters, and outputs a sorted version of that sequence. We'll do that using what we've learned so far about Sequence to Sequence models.\n", 11 | "\n", 12 | "\n", 13 | "\n", 14 | "\n", 15 | "## Dataset \n", 16 | "\n", 17 | "The dataset lives in the /data/ folder. At the moment, it is made up of the following files:\n", 18 | " * **letters_source.txt**: The list of input letter sequences. Each sequence is its own line. \n", 19 | " * **letters_target.txt**: The list of target sequences we'll use in the training process. Each sequence here is a response to the input sequence in letters_source.txt with the same line number." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": { 26 | "collapsed": true, 27 | "scrolled": false 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import helper\n", 32 | "\n", 33 | "source_path = 'data/english1'\n", 34 | "target_path = 'data/french1'\n", 35 | "\n", 36 | "source_sentences = helper.load_data(source_path)\n", 37 | "target_sentences = helper.load_data(target_path)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "Let's start by examining the current state of the dataset. `source_sentences` contains the entire input sequence file as text delimited by newline symbols." 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "['6300181537 B000W8KY0G 1574924494 6305958041 B00077']" 56 | ] 57 | }, 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "source_sentences[:50].split('\\n')" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "`target_sentences` contains the entire output sequence file as text delimited by newline symbols. 
Each line corresponds to the line from `source_sentences`. `target_sentences` contains a sorted characters of the line." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "['B002E01LQ6 6304744404 B00008LUNW B000NQJP98 078401']" 83 | ] 84 | }, 85 | "execution_count": 3, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "target_sentences[:50].split('\\n')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "## Preprocess\n", 99 | "To do anything useful with it, we'll need to turn the characters into a list of integers: " 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 21, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "def extract_character_vocab(data):\n", 111 | " special_words = ['', '', '', '<\\s>']\n", 112 | "\n", 113 | " set_words = set([character for line in data.split('\\n') for character in line.split(' ')])\n", 114 | " int_to_vocab = {word_i: word for word_i, word in enumerate(special_words + list(set_words))}\n", 115 | " vocab_to_int = {word: word_i for word_i, word in int_to_vocab.items()}\n", 116 | "\n", 117 | " return int_to_vocab, vocab_to_int\n", 118 | "\n", 119 | "# Build int2letter and letter2int dicts\n", 120 | "source_int_to_letter, source_letter_to_int = extract_character_vocab(source_sentences)\n", 121 | "target_int_to_letter, target_letter_to_int = extract_character_vocab(target_sentences)\n", 122 | "\n", 123 | "# Convert characters to ids\n", 124 | "source_letter_ids = [[source_letter_to_int.get(letter, source_letter_to_int['']) for letter in line.split(' ')] for line in source_sentences.split('\\n')]\n", 125 | "target_letter_ids = [[target_letter_to_int.get(letter, target_letter_to_int['']) for letter in line.split(' ')] for line in target_sentences.split('\\n')]\n" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 22, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "Example source sequence\n", 138 | "[[13912, 45997, 4221, 15839, 46389, 887, 10294, 9024, 3635, 11081, 35851, 34866, 3193, 17286, 47244, 19131, 24487, 36266, 23880, 23845, 28602], [762, 20186, 49438, 3716, 3524, 42601, 27187], [12265, 50943, 32507, 1195, 28834, 32207, 46192]]\n", 139 | "\n", 140 | "\n", 141 | "Example target sequence\n", 142 | "[[50562, 31372, 28959, 19392, 35717, 25840, 25736, 5889, 40752, 42712, 29362, 42665, 48571, 17284, 33690], [10104, 48224, 52024, 32276], [7402, 50696, 24376, 20466, 17934, 45656, 25840, 11080, 28176, 39114, 29680, 28149, 32642, 48459, 40753, 20865, 47631, 3933, 8510, 39345, 5830, 40603]]\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "print(\"Example source sequence\")\n", 148 | "#print(len(source_letter_to_int))\n", 149 | "print(source_letter_ids[:3])\n", 150 | "print(\"\\n\")\n", 151 | "print(\"Example target sequence\")\n", 152 | "print(target_letter_ids[:3])" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "The last step in the preprocessing stage is to determine the the longest sequence size in the dataset we'll be using, then pad all the sequences to that length." 
160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 23, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "Sequence Length\n", 172 | "35\n", 173 | "\n", 174 | "\n", 175 | "Input sequence example\n", 176 | "[[13912, 45997, 4221, 15839, 46389, 887, 10294, 9024, 3635, 11081, 35851, 34866, 3193, 17286, 47244, 19131, 24487, 36266, 23880, 23845, 28602, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [762, 20186, 49438, 3716, 3524, 42601, 27187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [12265, 50943, 32507, 1195, 28834, 32207, 46192, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n", 177 | "\n", 178 | "\n", 179 | "Target sequence example\n", 180 | "[[50562, 31372, 28959, 19392, 35717, 25840, 25736, 5889, 40752, 42712, 29362, 42665, 48571, 17284, 33690, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [10104, 48224, 52024, 32276, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [7402, 50696, 24376, 20466, 17934, 45656, 25840, 11080, 28176, 39114, 29680, 28149, 32642, 48459, 40753, 20865, 47631, 3933, 8510, 39345, 5830, 40603, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "def pad_id_sequences(source_ids, source_letter_to_int, target_ids, target_letter_to_int, sequence_length):\n", 186 | " new_source_ids = [sentence + [source_letter_to_int['']] * (sequence_length - len(sentence)) \\\n", 187 | " for sentence in source_ids]\n", 188 | " new_target_ids = [sentence + [target_letter_to_int['']] * (sequence_length - len(sentence)) \\\n", 189 | " for sentence in target_ids]\n", 190 | "\n", 191 | " return new_source_ids, new_target_ids\n", 192 | "\n", 193 | "\n", 194 | "# Use the longest sequence as sequence length\n", 195 | "sequence_length = max(\n", 196 | " [len(sentence) for sentence in source_letter_ids] + [len(sentence) for sentence in target_letter_ids])\n", 197 | "\n", 198 | "# Pad all sequences up to sequence length\n", 199 | "source_ids, target_ids = pad_id_sequences(source_letter_ids, source_letter_to_int, \n", 200 | " target_letter_ids, target_letter_to_int, sequence_length)\n", 201 | "\n", 202 | "print(\"Sequence Length\")\n", 203 | "print(sequence_length)\n", 204 | "print(\"\\n\")\n", 205 | "print(\"Input sequence example\")\n", 206 | "print(source_ids[:3])\n", 207 | "print(\"\\n\")\n", 208 | "print(\"Target sequence example\")\n", 209 | "print(target_ids[:3])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "This is the final shape we need them to be in. We can now proceed to building the model." 
217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "## Model\n", 224 | "#### Check the Version of TensorFlow\n", 225 | "This will check to make sure you have the correct version of TensorFlow" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 24, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "TensorFlow Version: 1.0.0\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "from distutils.version import LooseVersion\n", 243 | "import tensorflow as tf\n", 244 | "\n", 245 | "# Check TensorFlow Version\n", 246 | "assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer'\n", 247 | "print('TensorFlow Version: {}'.format(tf.__version__))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "### Hyperparameters" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 25, 260 | "metadata": { 261 | "collapsed": true 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "# Number of Epochs\n", 266 | "epochs = 60\n", 267 | "# Batch Size\n", 268 | "batch_size = 128\n", 269 | "# RNN Size\n", 270 | "rnn_size = 256\n", 271 | "# Number of Layers\n", 272 | "num_layers = 2\n", 273 | "# Embedding Size\n", 274 | "encoding_embedding_size = 256\n", 275 | "decoding_embedding_size = 256\n", 276 | "# Learning Rate\n", 277 | "learning_rate = 0.001" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "### Input" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 26, 290 | "metadata": { 291 | "collapsed": true 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])\n", 296 | "targets = tf.placeholder(tf.int32, [batch_size, sequence_length])\n", 297 | "lr = tf.placeholder(tf.float32)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "### Sequence to Sequence\n", 305 | "The decoder is probably the most complex part of this model. We need to declare a decoder for the training phase, and a decoder for the inference/prediction phase. These two decoders will share their parameters (so that all the weights and biases that are set during the training phase can be used when we deploy the model).\n", 306 | "\n", 307 | "\n", 308 | "First, we'll need to define the type of cell we'll be using for our decoder RNNs. We opted for LSTM.\n", 309 | "\n", 310 | "Then, we'll need to hookup a fully connected layer to the output of decoder. The output of this layer tells us which word the RNN is choosing to output at each time step.\n", 311 | "\n", 312 | "Let's first look at the inference/prediction decoder. It is the one we'll use when we deploy our chatbot to the wild (even though it comes second in the actual code).\n", 313 | "\n", 314 | "\n", 315 | "\n", 316 | "We'll hand our encoder hidden state to the inference decoder and have it process its output. TensorFlow handles most of the logic for us. 
We just have to use [`tf.contrib.seq2seq.simple_decoder_fn_inference`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/simple_decoder_fn_inference) and [`tf.contrib.seq2seq.dynamic_rnn_decoder`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_rnn_decoder) and supply them with the appropriate inputs.\n", 317 | "\n", 318 | "Notice that the inference decoder feeds the output of each time step as an input to the next.\n", 319 | "\n", 320 | "As for the training decoder, we can think of it as looking like this:\n", 321 | "\n", 322 | "\n", 323 | "The training decoder **does not** feed the output of each time step to the next. Rather, the inputs to the decoder time steps are the target sequence from the training dataset (the orange letters)." 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "### Encoding\n", 331 | "- Embed the input data using [`tf.contrib.layers.embed_sequence`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/embed_sequence)\n", 332 | "- Pass the embedded input into a stack of RNNs. Save the RNN state and ignore the output." 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 27, 338 | "metadata": { 339 | "collapsed": true 340 | }, 341 | "outputs": [], 342 | "source": [ 343 | "source_vocab_size = len(source_letter_to_int)\n", 344 | "\n", 345 | "# Encoder embedding\n", 346 | "enc_embed_input = tf.contrib.layers.embed_sequence(input_data, source_vocab_size, encoding_embedding_size)\n", 347 | "\n", 348 | "# Encoder\n", 349 | "enc_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(rnn_size)] * num_layers)\n", 350 | "_, enc_state = tf.nn.dynamic_rnn(enc_cell, enc_embed_input, dtype=tf.float32)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "### Process Decoding Input" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 28, 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "Targets\n", 370 | "[[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", 371 | " 24 25 26 27 28 29 30 31 32 33 34]\n", 372 | " [35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58\n", 373 | " 59 60 61 62 63 64 65 66 67 68 69]]\n", 374 | "\n", 375 | "\n", 376 | "Processed Decoding Input\n", 377 | "[[ 2 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22\n", 378 | " 23 24 25 26 27 28 29 30 31 32 33]\n", 379 | " [ 2 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57\n", 380 | " 58 59 60 61 62 63 64 65 66 67 68]]\n" 381 | ] 382 | } 383 | ], 384 | "source": [ 385 | "import numpy as np\n", 386 | "\n", 387 | "# Process the input we'll feed to the decoder\n", 388 | "ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])\n", 389 | "dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['']), ending], 1)\n", 390 | "\n", 391 | "demonstration_outputs = np.reshape(range(batch_size * sequence_length), (batch_size, sequence_length))\n", 392 | "\n", 393 | "sess = tf.InteractiveSession()\n", 394 | "print(\"Targets\")\n", 395 | "print(demonstration_outputs[:2])\n", 396 | "print(\"\\n\")\n", 397 | "print(\"Processed Decoding Input\")\n", 398 | "print(sess.run(dec_input, {targets: demonstration_outputs})[:2])" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "### Decoding\n", 406 | "- Embed the decoding 
input\n", 407 | "- Build the decoding RNNs\n", 408 | "- Build the output layer in the decoding scope, so the weight and bias can be shared between the training and inference decoders." 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 29, 414 | "metadata": { 415 | "collapsed": true 416 | }, 417 | "outputs": [], 418 | "source": [ 419 | "target_vocab_size = len(target_letter_to_int)\n", 420 | "\n", 421 | "# Decoder Embedding\n", 422 | "dec_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))\n", 423 | "dec_embed_input = tf.nn.embedding_lookup(dec_embeddings, dec_input)\n", 424 | "\n", 425 | "# Decoder RNNs\n", 426 | "dec_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(rnn_size)] * num_layers)\n", 427 | "\n", 428 | "with tf.variable_scope(\"decoding\") as decoding_scope:\n", 429 | " # Output Layer\n", 430 | " output_fn = lambda x: tf.contrib.layers.fully_connected(x, target_vocab_size, None, scope=decoding_scope)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "#### Decoder During Training\n", 438 | "- Build the training decoder using [`tf.contrib.seq2seq.simple_decoder_fn_train`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/simple_decoder_fn_train) and [`tf.contrib.seq2seq.dynamic_rnn_decoder`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_rnn_decoder).\n", 439 | "- Apply the output layer to the output of the training decoder" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 30, 445 | "metadata": { 446 | "collapsed": true 447 | }, 448 | "outputs": [], 449 | "source": [ 450 | "with tf.variable_scope(\"decoding\") as decoding_scope:\n", 451 | " # Training Decoder\n", 452 | " train_decoder_fn = tf.contrib.seq2seq.simple_decoder_fn_train(enc_state)\n", 453 | " train_pred, _, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(\n", 454 | " dec_cell, train_decoder_fn, dec_embed_input, sequence_length, scope=decoding_scope)\n", 455 | " \n", 456 | " # Apply output function\n", 457 | " train_logits = output_fn(train_pred)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "#### Decoder During Inference\n", 465 | "- Reuse the weights the biases from the training decoder using [`tf.variable_scope(\"decoding\", reuse=True)`](https://www.tensorflow.org/api_docs/python/tf/variable_scope)\n", 466 | "- Build the inference decoder using [`tf.contrib.seq2seq.simple_decoder_fn_inference`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/simple_decoder_fn_inference) and [`tf.contrib.seq2seq.dynamic_rnn_decoder`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_rnn_decoder).\n", 467 | " - The output function is applied to the output in this step " 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 31, 473 | "metadata": { 474 | "collapsed": true 475 | }, 476 | "outputs": [], 477 | "source": [ 478 | "with tf.variable_scope(\"decoding\", reuse=True) as decoding_scope:\n", 479 | " # Inference Decoder\n", 480 | " infer_decoder_fn = tf.contrib.seq2seq.simple_decoder_fn_inference(\n", 481 | " output_fn, enc_state, dec_embeddings, target_letter_to_int[''], target_letter_to_int['<\\s>'], \n", 482 | " sequence_length - 1, target_vocab_size)\n", 483 | " inference_logits, _, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(dec_cell, infer_decoder_fn, scope=decoding_scope)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 
488 | "metadata": {}, 489 | "source": [ 490 | "### Optimization\n", 491 | "Our loss function is [`tf.contrib.seq2seq.sequence_loss`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/sequence_loss) provided by the tensor flow seq2seq module. It calculates a weighted cross-entropy loss for the output logits." 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 32, 497 | "metadata": { 498 | "collapsed": true 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "# Loss function\n", 503 | "cost = tf.contrib.seq2seq.sequence_loss(\n", 504 | " train_logits,\n", 505 | " targets,\n", 506 | " tf.ones([batch_size, sequence_length]))\n", 507 | "\n", 508 | "# Optimizer\n", 509 | "optimizer = tf.train.AdamOptimizer(lr)\n", 510 | "\n", 511 | "# Gradient Clipping\n", 512 | "gradients = optimizer.compute_gradients(cost)\n", 513 | "capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]\n", 514 | "train_op = optimizer.apply_gradients(capped_gradients)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": {}, 520 | "source": [ 521 | "## Train\n", 522 | "We're now ready to train our model. If you run into OOM (out of memory) issues during training, try to decrease the batch_size." 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 33, 528 | "metadata": { 529 | "scrolled": true 530 | }, 531 | "outputs": [ 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "Epoch 0 Batch 0/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 10.889\n", 537 | "Epoch 0 Batch 1/146 - Train Accuracy: 0.677, Validation Accuracy: 0.689, Loss: 10.737\n", 538 | "Epoch 0 Batch 2/146 - Train Accuracy: 0.731, Validation Accuracy: 0.689, Loss: 10.500\n", 539 | "Epoch 0 Batch 3/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 10.241\n", 540 | "Epoch 0 Batch 4/146 - Train Accuracy: 0.668, Validation Accuracy: 0.689, Loss: 9.987\n", 541 | "Epoch 0 Batch 5/146 - Train Accuracy: 0.717, Validation Accuracy: 0.689, Loss: 9.596\n", 542 | "Epoch 0 Batch 6/146 - Train Accuracy: 0.704, Validation Accuracy: 0.689, Loss: 9.263\n", 543 | "Epoch 0 Batch 7/146 - Train Accuracy: 0.706, Validation Accuracy: 0.689, Loss: 8.884\n", 544 | "Epoch 0 Batch 8/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 8.488\n", 545 | "Epoch 0 Batch 9/146 - Train Accuracy: 0.689, Validation Accuracy: 0.689, Loss: 8.164\n", 546 | "Epoch 0 Batch 10/146 - Train Accuracy: 0.708, Validation Accuracy: 0.689, Loss: 7.727\n", 547 | "Epoch 0 Batch 11/146 - Train Accuracy: 0.744, Validation Accuracy: 0.689, Loss: 7.183\n", 548 | "Epoch 0 Batch 12/146 - Train Accuracy: 0.725, Validation Accuracy: 0.689, Loss: 6.921\n", 549 | "Epoch 0 Batch 13/146 - Train Accuracy: 0.680, Validation Accuracy: 0.689, Loss: 6.875\n", 550 | "Epoch 0 Batch 14/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 6.464\n", 551 | "Epoch 0 Batch 15/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 6.175\n", 552 | "Epoch 0 Batch 16/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 5.901\n", 553 | "Epoch 0 Batch 17/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 5.521\n", 554 | "Epoch 0 Batch 18/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 5.323\n", 555 | "Epoch 0 Batch 19/146 - Train Accuracy: 0.665, Validation Accuracy: 0.689, Loss: 5.299\n", 556 | "Epoch 0 Batch 20/146 - Train Accuracy: 0.697, Validation Accuracy: 0.689, Loss: 4.720\n", 
557 | "Epoch 0 Batch 21/146 - Train Accuracy: 0.715, Validation Accuracy: 0.689, Loss: 4.385\n", 558 | "Epoch 0 Batch 22/146 - Train Accuracy: 0.687, Validation Accuracy: 0.689, Loss: 4.425\n", 559 | "Epoch 0 Batch 23/146 - Train Accuracy: 0.703, Validation Accuracy: 0.689, Loss: 4.102\n", 560 | "Epoch 0 Batch 24/146 - Train Accuracy: 0.683, Validation Accuracy: 0.689, Loss: 4.174\n", 561 | "Epoch 0 Batch 25/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 3.875\n", 562 | "Epoch 0 Batch 26/146 - Train Accuracy: 0.707, Validation Accuracy: 0.689, Loss: 3.702\n", 563 | "Epoch 0 Batch 27/146 - Train Accuracy: 0.711, Validation Accuracy: 0.689, Loss: 3.644\n", 564 | "Epoch 0 Batch 28/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 3.827\n", 565 | "Epoch 0 Batch 29/146 - Train Accuracy: 0.705, Validation Accuracy: 0.689, Loss: 3.723\n", 566 | "Epoch 0 Batch 30/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.854\n", 567 | "Epoch 0 Batch 31/146 - Train Accuracy: 0.692, Validation Accuracy: 0.689, Loss: 3.845\n", 568 | "Epoch 0 Batch 32/146 - Train Accuracy: 0.697, Validation Accuracy: 0.689, Loss: 3.812\n", 569 | "Epoch 0 Batch 33/146 - Train Accuracy: 0.720, Validation Accuracy: 0.689, Loss: 3.461\n", 570 | "Epoch 0 Batch 34/146 - Train Accuracy: 0.707, Validation Accuracy: 0.689, Loss: 3.682\n", 571 | "Epoch 0 Batch 35/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 3.801\n", 572 | "Epoch 0 Batch 36/146 - Train Accuracy: 0.667, Validation Accuracy: 0.689, Loss: 4.197\n", 573 | "Epoch 0 Batch 37/146 - Train Accuracy: 0.669, Validation Accuracy: 0.689, Loss: 4.155\n", 574 | "Epoch 0 Batch 38/146 - Train Accuracy: 0.679, Validation Accuracy: 0.689, Loss: 3.969\n", 575 | "Epoch 0 Batch 39/146 - Train Accuracy: 0.656, Validation Accuracy: 0.689, Loss: 4.246\n", 576 | "Epoch 0 Batch 40/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 3.701\n", 577 | "Epoch 0 Batch 41/146 - Train Accuracy: 0.686, Validation Accuracy: 0.689, Loss: 3.821\n", 578 | "Epoch 0 Batch 42/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.621\n", 579 | "Epoch 0 Batch 43/146 - Train Accuracy: 0.718, Validation Accuracy: 0.689, Loss: 3.424\n", 580 | "Epoch 0 Batch 44/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 3.722\n", 581 | "Epoch 0 Batch 45/146 - Train Accuracy: 0.715, Validation Accuracy: 0.689, Loss: 3.418\n", 582 | "Epoch 0 Batch 46/146 - Train Accuracy: 0.709, Validation Accuracy: 0.689, Loss: 3.487\n", 583 | "Epoch 0 Batch 47/146 - Train Accuracy: 0.726, Validation Accuracy: 0.689, Loss: 3.272\n", 584 | "Epoch 0 Batch 48/146 - Train Accuracy: 0.676, Validation Accuracy: 0.689, Loss: 3.897\n", 585 | "Epoch 0 Batch 49/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 3.787\n", 586 | "Epoch 0 Batch 50/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 3.620\n", 587 | "Epoch 0 Batch 51/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.671\n", 588 | "Epoch 0 Batch 52/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.547\n", 589 | "Epoch 0 Batch 53/146 - Train Accuracy: 0.715, Validation Accuracy: 0.689, Loss: 3.406\n", 590 | "Epoch 0 Batch 54/146 - Train Accuracy: 0.694, Validation Accuracy: 0.689, Loss: 3.654\n", 591 | "Epoch 0 Batch 55/146 - Train Accuracy: 0.716, Validation Accuracy: 0.689, Loss: 3.418\n", 592 | "Epoch 0 Batch 56/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 3.725\n", 593 | "Epoch 0 Batch 57/146 - Train Accuracy: 0.696, 
Validation Accuracy: 0.689, Loss: 3.615\n", 594 | "Epoch 0 Batch 58/146 - Train Accuracy: 0.683, Validation Accuracy: 0.689, Loss: 3.707\n", 595 | "Epoch 0 Batch 59/146 - Train Accuracy: 0.721, Validation Accuracy: 0.689, Loss: 3.320\n", 596 | "Epoch 0 Batch 60/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.665\n", 597 | "Epoch 0 Batch 61/146 - Train Accuracy: 0.664, Validation Accuracy: 0.689, Loss: 4.035\n", 598 | "Epoch 0 Batch 62/146 - Train Accuracy: 0.686, Validation Accuracy: 0.689, Loss: 3.724\n", 599 | "Epoch 0 Batch 63/146 - Train Accuracy: 0.709, Validation Accuracy: 0.689, Loss: 3.452\n", 600 | "Epoch 0 Batch 64/146 - Train Accuracy: 0.713, Validation Accuracy: 0.689, Loss: 3.377\n", 601 | "Epoch 0 Batch 65/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 3.643\n", 602 | "Epoch 0 Batch 66/146 - Train Accuracy: 0.713, Validation Accuracy: 0.689, Loss: 3.377\n", 603 | "Epoch 0 Batch 67/146 - Train Accuracy: 0.693, Validation Accuracy: 0.689, Loss: 3.626\n", 604 | "Epoch 0 Batch 68/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.483\n", 605 | "Epoch 0 Batch 69/146 - Train Accuracy: 0.738, Validation Accuracy: 0.689, Loss: 2.996\n", 606 | "Epoch 0 Batch 70/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 3.427\n", 607 | "Epoch 0 Batch 71/146 - Train Accuracy: 0.709, Validation Accuracy: 0.689, Loss: 3.417\n", 608 | "Epoch 0 Batch 72/146 - Train Accuracy: 0.706, Validation Accuracy: 0.689, Loss: 3.432\n", 609 | "Epoch 0 Batch 73/146 - Train Accuracy: 0.676, Validation Accuracy: 0.689, Loss: 3.828\n", 610 | "Epoch 0 Batch 74/146 - Train Accuracy: 0.690, Validation Accuracy: 0.689, Loss: 3.596\n", 611 | "Epoch 0 Batch 75/146 - Train Accuracy: 0.687, Validation Accuracy: 0.689, Loss: 3.698\n", 612 | "Epoch 0 Batch 76/146 - Train Accuracy: 0.712, Validation Accuracy: 0.689, Loss: 3.368\n", 613 | "Epoch 0 Batch 77/146 - Train Accuracy: 0.712, Validation Accuracy: 0.689, Loss: 3.302\n", 614 | "Epoch 0 Batch 78/146 - Train Accuracy: 0.708, Validation Accuracy: 0.689, Loss: 3.314\n", 615 | "Epoch 0 Batch 79/146 - Train Accuracy: 0.681, Validation Accuracy: 0.689, Loss: 3.650\n", 616 | "Epoch 0 Batch 80/146 - Train Accuracy: 0.680, Validation Accuracy: 0.689, Loss: 3.674\n", 617 | "Epoch 0 Batch 81/146 - Train Accuracy: 0.710, Validation Accuracy: 0.689, Loss: 3.345\n", 618 | "Epoch 0 Batch 82/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 3.509\n", 619 | "Epoch 0 Batch 83/146 - Train Accuracy: 0.718, Validation Accuracy: 0.689, Loss: 3.270\n", 620 | "Epoch 0 Batch 84/146 - Train Accuracy: 0.706, Validation Accuracy: 0.689, Loss: 3.327\n", 621 | "Epoch 0 Batch 85/146 - Train Accuracy: 0.711, Validation Accuracy: 0.689, Loss: 3.291\n", 622 | "Epoch 0 Batch 86/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.327\n", 623 | "Epoch 0 Batch 87/146 - Train Accuracy: 0.712, Validation Accuracy: 0.689, Loss: 3.259\n", 624 | "Epoch 0 Batch 88/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.358\n" 625 | ] 626 | }, 627 | { 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "Epoch 0 Batch 89/146 - Train Accuracy: 0.719, Validation Accuracy: 0.689, Loss: 3.115\n", 632 | "Epoch 0 Batch 90/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 3.530\n", 633 | "Epoch 0 Batch 91/146 - Train Accuracy: 0.704, Validation Accuracy: 0.689, Loss: 3.313\n", 634 | "Epoch 0 Batch 92/146 - Train Accuracy: 0.689, Validation Accuracy: 0.689, Loss: 3.475\n", 635 | "Epoch 0 Batch 
93/146 - Train Accuracy: 0.654, Validation Accuracy: 0.689, Loss: 3.888\n", 636 | "Epoch 0 Batch 94/146 - Train Accuracy: 0.737, Validation Accuracy: 0.689, Loss: 2.939\n", 637 | "Epoch 0 Batch 95/146 - Train Accuracy: 0.686, Validation Accuracy: 0.689, Loss: 3.440\n", 638 | "Epoch 0 Batch 96/146 - Train Accuracy: 0.710, Validation Accuracy: 0.689, Loss: 3.248\n", 639 | "Epoch 0 Batch 97/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 3.490\n", 640 | "Epoch 0 Batch 98/146 - Train Accuracy: 0.677, Validation Accuracy: 0.689, Loss: 3.579\n", 641 | "Epoch 0 Batch 99/146 - Train Accuracy: 0.704, Validation Accuracy: 0.689, Loss: 3.296\n", 642 | "Epoch 0 Batch 100/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 3.311\n", 643 | "Epoch 0 Batch 101/146 - Train Accuracy: 0.693, Validation Accuracy: 0.689, Loss: 3.428\n", 644 | "Epoch 0 Batch 102/146 - Train Accuracy: 0.710, Validation Accuracy: 0.689, Loss: 3.209\n", 645 | "Epoch 0 Batch 103/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.423\n", 646 | "Epoch 0 Batch 104/146 - Train Accuracy: 0.668, Validation Accuracy: 0.689, Loss: 3.618\n", 647 | "Epoch 0 Batch 105/146 - Train Accuracy: 0.708, Validation Accuracy: 0.689, Loss: 3.153\n" 648 | ] 649 | }, 650 | { 651 | "ename": "KeyboardInterrupt", 652 | "evalue": "", 653 | "output_type": "error", 654 | "traceback": [ 655 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 656 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 657 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m batch_train_logits = sess.run(\n\u001b[1;32m 18\u001b[0m \u001b[0minference_logits\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m {input_data: source_batch})\n\u001b[0m\u001b[1;32m 20\u001b[0m batch_valid_logits = sess.run(\n\u001b[1;32m 21\u001b[0m \u001b[0minference_logits\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 658 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 766\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 767\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 659 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 963\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 965\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 966\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 967\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 660 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1015\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 661 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1021\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1022\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1023\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1024\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 662 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1002\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1003\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1004\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1005\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1006\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 663 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 664 | ] 665 | } 666 | ], 667 | "source": [ 668 | "import numpy as np\n", 669 | "\n", 670 | "train_source = source_ids[batch_size:]\n", 671 | "train_target = target_ids[batch_size:]\n", 672 | "\n", 673 | "valid_source = 
source_ids[:batch_size]\n", 674 | "valid_target = target_ids[:batch_size]\n", 675 | "\n", 676 | "sess.run(tf.global_variables_initializer())\n", 677 | "\n", 678 | "for epoch_i in range(epochs):\n", 679 | " for batch_i, (source_batch, target_batch) in enumerate(\n", 680 | " helper.batch_data(train_source, train_target, batch_size)):\n", 681 | " _, loss = sess.run(\n", 682 | " [train_op, cost],\n", 683 | " {input_data: source_batch, targets: target_batch, lr: learning_rate})\n", 684 | " batch_train_logits = sess.run(\n", 685 | " inference_logits,\n", 686 | " {input_data: source_batch})\n", 687 | " batch_valid_logits = sess.run(\n", 688 | " inference_logits,\n", 689 | " {input_data: valid_source})\n", 690 | "\n", 691 | " train_acc = np.mean(np.equal(target_batch, np.argmax(batch_train_logits, 2)))\n", 692 | " valid_acc = np.mean(np.equal(valid_target, np.argmax(batch_valid_logits, 2)))\n", 693 | " print('Epoch {:>3} Batch {:>4}/{} - Train Accuracy: {:>6.3f}, Validation Accuracy: {:>6.3f}, Loss: {:>6.3f}'\n", 694 | " .format(epoch_i, batch_i, len(source_ids) // batch_size, train_acc, valid_acc, loss))" 695 | ] 696 | }, 697 | { 698 | "cell_type": "markdown", 699 | "metadata": {}, 700 | "source": [ 701 | "## Prediction" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 16, 707 | "metadata": {}, 708 | "outputs": [ 709 | { 710 | "name": "stdout", 711 | "output_type": "stream", 712 | "text": [ 713 | "Input\n", 714 | " Word Ids: [20, 18, 28, 28, 10, 0, 0]\n", 715 | " Input Words: ['h', 'e', 'l', 'l', 'o', '', '']\n", 716 | "\n", 717 | "Prediction\n", 718 | " Word Ids: [18, 20, 28, 28, 10, 0, 0]\n", 719 | " Chatbot Answer Words: ['e', 'h', 'l', 'l', 'o', '', '']\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "input_sentence = 'hello'\n", 725 | "\n", 726 | "\n", 727 | "input_sentence = [source_letter_to_int.get(word, source_letter_to_int['']) for word in input_sentence.lower()]\n", 728 | "input_sentence = input_sentence + [0] * (sequence_length - len(input_sentence))\n", 729 | "batch_shell = np.zeros((batch_size, sequence_length))\n", 730 | "batch_shell[0] = input_sentence\n", 731 | "chatbot_logits = sess.run(inference_logits, {input_data: batch_shell})[0]\n", 732 | "\n", 733 | "print('Input')\n", 734 | "print(' Word Ids: {}'.format([i for i in input_sentence]))\n", 735 | "print(' Input Words: {}'.format([source_int_to_letter[i] for i in input_sentence]))\n", 736 | "\n", 737 | "print('\\nPrediction')\n", 738 | "print(' Word Ids: {}'.format([i for i in np.argmax(chatbot_logits, 1)]))\n", 739 | "print(' Chatbot Answer Words: {}'.format([target_int_to_letter[i] for i in np.argmax(chatbot_logits, 1)]))" 740 | ] 741 | } 742 | ], 743 | "metadata": { 744 | "anaconda-cloud": {}, 745 | "kernelspec": { 746 | "display_name": "Python 3", 747 | "language": "python", 748 | "name": "python3" 749 | }, 750 | "language_info": { 751 | "codemirror_mode": { 752 | "name": "ipython", 753 | "version": 3 754 | }, 755 | "file_extension": ".py", 756 | "mimetype": "text/x-python", 757 | "name": "python", 758 | "nbconvert_exporter": "python", 759 | "pygments_lexer": "ipython3", 760 | "version": "3.6.1" 761 | } 762 | }, 763 | "nbformat": 4, 764 | "nbformat_minor": 1 765 | } 766 | -------------------------------------------------------------------------------- /seq2seq/__pycache__/helper.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq/__pycache__/helper.cpython-36.pyc -------------------------------------------------------------------------------- /seq2seq/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_data(path): 5 | input_file = os.path.join(path) 6 | with open(input_file, "r", encoding='utf-8', errors='ignore') as f: 7 | data = f.read() 8 | 9 | return data 10 | 11 | 12 | def extract_vocab(data): 13 | special_words = ['', '', '', '<\s>'] 14 | 15 | set_words = set([word for line in data.split('\n') for word in line.split()]) 16 | int_to_vocab = {word_i: word for word_i, word in enumerate(special_words + list(set_words))} 17 | vocab_to_int = {word: word_i for word_i, word in int_to_vocab.items()} 18 | 19 | return int_to_vocab, vocab_to_int 20 | 21 | 22 | def pad_id_sequences(source_ids, source_vocab_to_int, target_ids, target_vocab_to_int, sequence_length): 23 | new_source_ids = [list(reversed(sentence + [source_vocab_to_int['']] * (sequence_length - len(sentence)))) \ 24 | for sentence in source_ids] 25 | new_target_ids = [sentence + [target_vocab_to_int['']] * (sequence_length - len(sentence)) \ 26 | for sentence in target_ids] 27 | 28 | return new_source_ids, new_target_ids 29 | 30 | 31 | def batch_data(source, target, batch_size): 32 | """ 33 | Batch source and target together 34 | """ 35 | for batch_i in range(0, len(source)//batch_size): 36 | start_i = batch_i * batch_size 37 | source_batch = source[start_i:start_i + batch_size] 38 | target_batch = target[start_i:start_i + batch_size] 39 | yield source_batch, target_batch 40 | -------------------------------------------------------------------------------- /seq2seq/images/sequence-to-sequence-inference-decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq/images/sequence-to-sequence-inference-decoder.png -------------------------------------------------------------------------------- /seq2seq/images/sequence-to-sequence-training-decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq/images/sequence-to-sequence-training-decoder.png -------------------------------------------------------------------------------- /seq2seq/images/sequence-to-sequence.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianguoz/Seq2Seq-Gan-Autoencoder/73881bb50578d714b89bbf584ce4a4e5a44987d3/seq2seq/images/sequence-to-sequence.jpg -------------------------------------------------------------------------------- /seq2seq/sequence_to_sequence_implementation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Character Sequence to Sequence \n", 10 | "In this notebook, we'll build a model that takes in a sequence of letters, and outputs a sorted version of that sequence. 
We'll do that using what we've learned so far about Sequence to Sequence models.\n", 11 | "\n", 12 | "\n", 13 | "\n", 14 | "\n", 15 | "## Dataset \n", 16 | "\n", 17 | "The dataset lives in the /data/ folder. At the moment, it is made up of the following files:\n", 18 | " * **letters_source.txt**: The list of input letter sequences. Each sequence is its own line. \n", 19 | " * **letters_target.txt**: The list of target sequences we'll use in the training process. Each sequence here is a response to the input sequence in letters_source.txt with the same line number." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": { 26 | "collapsed": true, 27 | "scrolled": false 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import helper\n", 32 | "\n", 33 | "source_path = 'data/english1'\n", 34 | "target_path = 'data/french1'\n", 35 | "\n", 36 | "source_sentences = helper.load_data(source_path)\n", 37 | "target_sentences = helper.load_data(target_path)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "Let's start by examining the current state of the dataset. `source_sentences` contains the entire input sequence file as text delimited by newline symbols." 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "['6300181537 B000W8KY0G 1574924494 6305958041 B00077']" 56 | ] 57 | }, 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "source_sentences[:50].split('\\n')" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "`target_sentences` contains the entire output sequence file as text delimited by newline symbols. Each line corresponds to the line from `source_sentences`. `target_sentences` contains a sorted characters of the line." 
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "['B002E01LQ6 6304744404 B00008LUNW B000NQJP98 078401']" 83 | ] 84 | }, 85 | "execution_count": 3, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "target_sentences[:50].split('\\n')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "## Preprocess\n", 99 | "To do anything useful with it, we'll need to turn the characters into a list of integers: " 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 21, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "def extract_character_vocab(data):\n", 111 | " special_words = ['', '', '', '<\\s>']\n", 112 | "\n", 113 | " set_words = set([character for line in data.split('\\n') for character in line.split(' ')])\n", 114 | " int_to_vocab = {word_i: word for word_i, word in enumerate(special_words + list(set_words))}\n", 115 | " vocab_to_int = {word: word_i for word_i, word in int_to_vocab.items()}\n", 116 | "\n", 117 | " return int_to_vocab, vocab_to_int\n", 118 | "\n", 119 | "# Build int2letter and letter2int dicts\n", 120 | "source_int_to_letter, source_letter_to_int = extract_character_vocab(source_sentences)\n", 121 | "target_int_to_letter, target_letter_to_int = extract_character_vocab(target_sentences)\n", 122 | "\n", 123 | "# Convert characters to ids\n", 124 | "source_letter_ids = [[source_letter_to_int.get(letter, source_letter_to_int['']) for letter in line.split(' ')] for line in source_sentences.split('\\n')]\n", 125 | "target_letter_ids = [[target_letter_to_int.get(letter, target_letter_to_int['']) for letter in line.split(' ')] for line in target_sentences.split('\\n')]\n" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 22, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "Example source sequence\n", 138 | "[[13912, 45997, 4221, 15839, 46389, 887, 10294, 9024, 3635, 11081, 35851, 34866, 3193, 17286, 47244, 19131, 24487, 36266, 23880, 23845, 28602], [762, 20186, 49438, 3716, 3524, 42601, 27187], [12265, 50943, 32507, 1195, 28834, 32207, 46192]]\n", 139 | "\n", 140 | "\n", 141 | "Example target sequence\n", 142 | "[[50562, 31372, 28959, 19392, 35717, 25840, 25736, 5889, 40752, 42712, 29362, 42665, 48571, 17284, 33690], [10104, 48224, 52024, 32276], [7402, 50696, 24376, 20466, 17934, 45656, 25840, 11080, 28176, 39114, 29680, 28149, 32642, 48459, 40753, 20865, 47631, 3933, 8510, 39345, 5830, 40603]]\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "print(\"Example source sequence\")\n", 148 | "#print(len(source_letter_to_int))\n", 149 | "print(source_letter_ids[:3])\n", 150 | "print(\"\\n\")\n", 151 | "print(\"Example target sequence\")\n", 152 | "print(target_letter_ids[:3])" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "The last step in the preprocessing stage is to determine the the longest sequence size in the dataset we'll be using, then pad all the sequences to that length." 
160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 23, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "Sequence Length\n", 172 | "35\n", 173 | "\n", 174 | "\n", 175 | "Input sequence example\n", 176 | "[[13912, 45997, 4221, 15839, 46389, 887, 10294, 9024, 3635, 11081, 35851, 34866, 3193, 17286, 47244, 19131, 24487, 36266, 23880, 23845, 28602, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [762, 20186, 49438, 3716, 3524, 42601, 27187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [12265, 50943, 32507, 1195, 28834, 32207, 46192, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n", 177 | "\n", 178 | "\n", 179 | "Target sequence example\n", 180 | "[[50562, 31372, 28959, 19392, 35717, 25840, 25736, 5889, 40752, 42712, 29362, 42665, 48571, 17284, 33690, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [10104, 48224, 52024, 32276, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [7402, 50696, 24376, 20466, 17934, 45656, 25840, 11080, 28176, 39114, 29680, 28149, 32642, 48459, 40753, 20865, 47631, 3933, 8510, 39345, 5830, 40603, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "def pad_id_sequences(source_ids, source_letter_to_int, target_ids, target_letter_to_int, sequence_length):\n", 186 | " new_source_ids = [sentence + [source_letter_to_int['']] * (sequence_length - len(sentence)) \\\n", 187 | " for sentence in source_ids]\n", 188 | " new_target_ids = [sentence + [target_letter_to_int['']] * (sequence_length - len(sentence)) \\\n", 189 | " for sentence in target_ids]\n", 190 | "\n", 191 | " return new_source_ids, new_target_ids\n", 192 | "\n", 193 | "\n", 194 | "# Use the longest sequence as sequence length\n", 195 | "sequence_length = max(\n", 196 | " [len(sentence) for sentence in source_letter_ids] + [len(sentence) for sentence in target_letter_ids])\n", 197 | "\n", 198 | "# Pad all sequences up to sequence length\n", 199 | "source_ids, target_ids = pad_id_sequences(source_letter_ids, source_letter_to_int, \n", 200 | " target_letter_ids, target_letter_to_int, sequence_length)\n", 201 | "\n", 202 | "print(\"Sequence Length\")\n", 203 | "print(sequence_length)\n", 204 | "print(\"\\n\")\n", 205 | "print(\"Input sequence example\")\n", 206 | "print(source_ids[:3])\n", 207 | "print(\"\\n\")\n", 208 | "print(\"Target sequence example\")\n", 209 | "print(target_ids[:3])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "This is the final shape we need them to be in. We can now proceed to building the model." 
217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "## Model\n", 224 | "#### Check the Version of TensorFlow\n", 225 | "This will check to make sure you have the correct version of TensorFlow" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 24, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "TensorFlow Version: 1.0.0\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "from distutils.version import LooseVersion\n", 243 | "import tensorflow as tf\n", 244 | "\n", 245 | "# Check TensorFlow Version\n", 246 | "assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer'\n", 247 | "print('TensorFlow Version: {}'.format(tf.__version__))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "### Hyperparameters" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 25, 260 | "metadata": { 261 | "collapsed": true 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "# Number of Epochs\n", 266 | "epochs = 60\n", 267 | "# Batch Size\n", 268 | "batch_size = 128\n", 269 | "# RNN Size\n", 270 | "rnn_size = 256\n", 271 | "# Number of Layers\n", 272 | "num_layers = 2\n", 273 | "# Embedding Size\n", 274 | "encoding_embedding_size = 256\n", 275 | "decoding_embedding_size = 256\n", 276 | "# Learning Rate\n", 277 | "learning_rate = 0.001" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "### Input" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 26, 290 | "metadata": { 291 | "collapsed": true 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])\n", 296 | "targets = tf.placeholder(tf.int32, [batch_size, sequence_length])\n", 297 | "lr = tf.placeholder(tf.float32)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "### Sequence to Sequence\n", 305 | "The decoder is probably the most complex part of this model. We need to declare a decoder for the training phase, and a decoder for the inference/prediction phase. These two decoders will share their parameters (so that all the weights and biases that are set during the training phase can be used when we deploy the model).\n", 306 | "\n", 307 | "\n", 308 | "First, we'll need to define the type of cell we'll be using for our decoder RNNs. We opted for LSTM.\n", 309 | "\n", 310 | "Then, we'll need to hook up a fully connected layer to the output of the decoder. The output of this layer tells us which word the RNN is choosing to output at each time step.\n", 311 | "\n", 312 | "Let's first look at the inference/prediction decoder. It is the one we'll use when we deploy our chatbot in the wild (even though it comes second in the actual code).\n", 313 | "\n", 314 | "<img src=\"images/sequence-to-sequence-inference-decoder.png\"/>\n", 315 | "\n", 316 | "We'll hand our encoder hidden state to the inference decoder and have it process its output. TensorFlow handles most of the logic for us. 
We just have to use [`tf.contrib.seq2seq.simple_decoder_fn_inference`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/simple_decoder_fn_inference) and [`tf.contrib.seq2seq.dynamic_rnn_decoder`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_rnn_decoder) and supply them with the appropriate inputs.\n", 317 | "\n", 318 | "Notice that the inference decoder feeds the output of each time step as an input to the next.\n", 319 | "\n", 320 | "As for the training decoder, we can think of it as looking like this:\n", 321 | "\n", 322 | "<img src=\"images/sequence-to-sequence-training-decoder.png\"/>\n", 323 | "The training decoder **does not** feed the output of each time step to the next. Rather, the inputs to the decoder time steps are the target sequence from the training dataset (the orange letters)." 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "### Encoding\n", 331 | "- Embed the input data using [`tf.contrib.layers.embed_sequence`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/embed_sequence)\n", 332 | "- Pass the embedded input into a stack of RNNs. Save the RNN state and ignore the output." 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 27, 338 | "metadata": { 339 | "collapsed": true 340 | }, 341 | "outputs": [], 342 | "source": [ 343 | "source_vocab_size = len(source_letter_to_int)\n", 344 | "\n", 345 | "# Encoder embedding\n", 346 | "enc_embed_input = tf.contrib.layers.embed_sequence(input_data, source_vocab_size, encoding_embedding_size)\n", 347 | "\n", 348 | "# Encoder\n", 349 | "enc_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(rnn_size)] * num_layers)\n", 350 | "_, enc_state = tf.nn.dynamic_rnn(enc_cell, enc_embed_input, dtype=tf.float32)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "### Process Decoding Input" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 28, 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "Targets\n", 370 | "[[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", 371 | " 24 25 26 27 28 29 30 31 32 33 34]\n", 372 | " [35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58\n", 373 | " 59 60 61 62 63 64 65 66 67 68 69]]\n", 374 | "\n", 375 | "\n", 376 | "Processed Decoding Input\n", 377 | "[[ 2 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22\n", 378 | " 23 24 25 26 27 28 29 30 31 32 33]\n", 379 | " [ 2 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57\n", 380 | " 58 59 60 61 62 63 64 65 66 67 68]]\n" 381 | ] 382 | } 383 | ], 384 | "source": [ 385 | "import numpy as np\n", 386 | "\n", 387 | "# Process the input we'll feed to the decoder\n", 388 | "ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])\n", 389 | "dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['<s>']), ending], 1)\n", 390 | "\n", 391 | "demonstration_outputs = np.reshape(range(batch_size * sequence_length), (batch_size, sequence_length))\n", 392 | "\n", 393 | "sess = tf.InteractiveSession()\n", 394 | "print(\"Targets\")\n", 395 | "print(demonstration_outputs[:2])\n", 396 | "print(\"\\n\")\n", 397 | "print(\"Processed Decoding Input\")\n", 398 | "print(sess.run(dec_input, {targets: demonstration_outputs})[:2])" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "### Decoding\n", 406 | "- Embed the decoding 
input\n", 407 | "- Build the decoding RNNs\n", 408 | "- Build the output layer in the decoding scope, so the weight and bias can be shared between the training and inference decoders." 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 29, 414 | "metadata": { 415 | "collapsed": true 416 | }, 417 | "outputs": [], 418 | "source": [ 419 | "target_vocab_size = len(target_letter_to_int)\n", 420 | "\n", 421 | "# Decoder Embedding\n", 422 | "dec_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))\n", 423 | "dec_embed_input = tf.nn.embedding_lookup(dec_embeddings, dec_input)\n", 424 | "\n", 425 | "# Decoder RNNs\n", 426 | "dec_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(rnn_size)] * num_layers)\n", 427 | "\n", 428 | "with tf.variable_scope(\"decoding\") as decoding_scope:\n", 429 | " # Output Layer\n", 430 | " output_fn = lambda x: tf.contrib.layers.fully_connected(x, target_vocab_size, None, scope=decoding_scope)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "#### Decoder During Training\n", 438 | "- Build the training decoder using [`tf.contrib.seq2seq.simple_decoder_fn_train`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/simple_decoder_fn_train) and [`tf.contrib.seq2seq.dynamic_rnn_decoder`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_rnn_decoder).\n", 439 | "- Apply the output layer to the output of the training decoder" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 30, 445 | "metadata": { 446 | "collapsed": true 447 | }, 448 | "outputs": [], 449 | "source": [ 450 | "with tf.variable_scope(\"decoding\") as decoding_scope:\n", 451 | " # Training Decoder\n", 452 | " train_decoder_fn = tf.contrib.seq2seq.simple_decoder_fn_train(enc_state)\n", 453 | " train_pred, _, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(\n", 454 | " dec_cell, train_decoder_fn, dec_embed_input, sequence_length, scope=decoding_scope)\n", 455 | " \n", 456 | " # Apply output function\n", 457 | " train_logits = output_fn(train_pred)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "#### Decoder During Inference\n", 465 | "- Reuse the weights and biases from the training decoder using [`tf.variable_scope(\"decoding\", reuse=True)`](https://www.tensorflow.org/api_docs/python/tf/variable_scope)\n", 466 | "- Build the inference decoder using [`tf.contrib.seq2seq.simple_decoder_fn_inference`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/simple_decoder_fn_inference) and [`tf.contrib.seq2seq.dynamic_rnn_decoder`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_rnn_decoder).\n", 467 | " - The output function is applied to the output in this step " 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 31, 473 | "metadata": { 474 | "collapsed": true 475 | }, 476 | "outputs": [], 477 | "source": [ 478 | "with tf.variable_scope(\"decoding\", reuse=True) as decoding_scope:\n", 479 | " # Inference Decoder\n", 480 | " infer_decoder_fn = tf.contrib.seq2seq.simple_decoder_fn_inference(\n", 481 | " output_fn, enc_state, dec_embeddings, target_letter_to_int['<s>'], target_letter_to_int['<\\s>'], \n", 482 | " sequence_length - 1, target_vocab_size)\n", 483 | " inference_logits, _, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(dec_cell, infer_decoder_fn, scope=decoding_scope)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 
488 | "metadata": {}, 489 | "source": [ 490 | "### Optimization\n", 491 | "Our loss function is [`tf.contrib.seq2seq.sequence_loss`](https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/sequence_loss) provided by the tensor flow seq2seq module. It calculates a weighted cross-entropy loss for the output logits." 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 32, 497 | "metadata": { 498 | "collapsed": true 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "# Loss function\n", 503 | "cost = tf.contrib.seq2seq.sequence_loss(\n", 504 | " train_logits,\n", 505 | " targets,\n", 506 | " tf.ones([batch_size, sequence_length]))\n", 507 | "\n", 508 | "# Optimizer\n", 509 | "optimizer = tf.train.AdamOptimizer(lr)\n", 510 | "\n", 511 | "# Gradient Clipping\n", 512 | "gradients = optimizer.compute_gradients(cost)\n", 513 | "capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]\n", 514 | "train_op = optimizer.apply_gradients(capped_gradients)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": {}, 520 | "source": [ 521 | "## Train\n", 522 | "We're now ready to train our model. If you run into OOM (out of memory) issues during training, try to decrease the batch_size." 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 33, 528 | "metadata": { 529 | "scrolled": true 530 | }, 531 | "outputs": [ 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "Epoch 0 Batch 0/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 10.889\n", 537 | "Epoch 0 Batch 1/146 - Train Accuracy: 0.677, Validation Accuracy: 0.689, Loss: 10.737\n", 538 | "Epoch 0 Batch 2/146 - Train Accuracy: 0.731, Validation Accuracy: 0.689, Loss: 10.500\n", 539 | "Epoch 0 Batch 3/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 10.241\n", 540 | "Epoch 0 Batch 4/146 - Train Accuracy: 0.668, Validation Accuracy: 0.689, Loss: 9.987\n", 541 | "Epoch 0 Batch 5/146 - Train Accuracy: 0.717, Validation Accuracy: 0.689, Loss: 9.596\n", 542 | "Epoch 0 Batch 6/146 - Train Accuracy: 0.704, Validation Accuracy: 0.689, Loss: 9.263\n", 543 | "Epoch 0 Batch 7/146 - Train Accuracy: 0.706, Validation Accuracy: 0.689, Loss: 8.884\n", 544 | "Epoch 0 Batch 8/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 8.488\n", 545 | "Epoch 0 Batch 9/146 - Train Accuracy: 0.689, Validation Accuracy: 0.689, Loss: 8.164\n", 546 | "Epoch 0 Batch 10/146 - Train Accuracy: 0.708, Validation Accuracy: 0.689, Loss: 7.727\n", 547 | "Epoch 0 Batch 11/146 - Train Accuracy: 0.744, Validation Accuracy: 0.689, Loss: 7.183\n", 548 | "Epoch 0 Batch 12/146 - Train Accuracy: 0.725, Validation Accuracy: 0.689, Loss: 6.921\n", 549 | "Epoch 0 Batch 13/146 - Train Accuracy: 0.680, Validation Accuracy: 0.689, Loss: 6.875\n", 550 | "Epoch 0 Batch 14/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 6.464\n", 551 | "Epoch 0 Batch 15/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 6.175\n", 552 | "Epoch 0 Batch 16/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 5.901\n", 553 | "Epoch 0 Batch 17/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 5.521\n", 554 | "Epoch 0 Batch 18/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 5.323\n", 555 | "Epoch 0 Batch 19/146 - Train Accuracy: 0.665, Validation Accuracy: 0.689, Loss: 5.299\n", 556 | "Epoch 0 Batch 20/146 - Train Accuracy: 0.697, Validation Accuracy: 0.689, Loss: 4.720\n", 
557 | "Epoch 0 Batch 21/146 - Train Accuracy: 0.715, Validation Accuracy: 0.689, Loss: 4.385\n", 558 | "Epoch 0 Batch 22/146 - Train Accuracy: 0.687, Validation Accuracy: 0.689, Loss: 4.425\n", 559 | "Epoch 0 Batch 23/146 - Train Accuracy: 0.703, Validation Accuracy: 0.689, Loss: 4.102\n", 560 | "Epoch 0 Batch 24/146 - Train Accuracy: 0.683, Validation Accuracy: 0.689, Loss: 4.174\n", 561 | "Epoch 0 Batch 25/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 3.875\n", 562 | "Epoch 0 Batch 26/146 - Train Accuracy: 0.707, Validation Accuracy: 0.689, Loss: 3.702\n", 563 | "Epoch 0 Batch 27/146 - Train Accuracy: 0.711, Validation Accuracy: 0.689, Loss: 3.644\n", 564 | "Epoch 0 Batch 28/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 3.827\n", 565 | "Epoch 0 Batch 29/146 - Train Accuracy: 0.705, Validation Accuracy: 0.689, Loss: 3.723\n", 566 | "Epoch 0 Batch 30/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.854\n", 567 | "Epoch 0 Batch 31/146 - Train Accuracy: 0.692, Validation Accuracy: 0.689, Loss: 3.845\n", 568 | "Epoch 0 Batch 32/146 - Train Accuracy: 0.697, Validation Accuracy: 0.689, Loss: 3.812\n", 569 | "Epoch 0 Batch 33/146 - Train Accuracy: 0.720, Validation Accuracy: 0.689, Loss: 3.461\n", 570 | "Epoch 0 Batch 34/146 - Train Accuracy: 0.707, Validation Accuracy: 0.689, Loss: 3.682\n", 571 | "Epoch 0 Batch 35/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 3.801\n", 572 | "Epoch 0 Batch 36/146 - Train Accuracy: 0.667, Validation Accuracy: 0.689, Loss: 4.197\n", 573 | "Epoch 0 Batch 37/146 - Train Accuracy: 0.669, Validation Accuracy: 0.689, Loss: 4.155\n", 574 | "Epoch 0 Batch 38/146 - Train Accuracy: 0.679, Validation Accuracy: 0.689, Loss: 3.969\n", 575 | "Epoch 0 Batch 39/146 - Train Accuracy: 0.656, Validation Accuracy: 0.689, Loss: 4.246\n", 576 | "Epoch 0 Batch 40/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 3.701\n", 577 | "Epoch 0 Batch 41/146 - Train Accuracy: 0.686, Validation Accuracy: 0.689, Loss: 3.821\n", 578 | "Epoch 0 Batch 42/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.621\n", 579 | "Epoch 0 Batch 43/146 - Train Accuracy: 0.718, Validation Accuracy: 0.689, Loss: 3.424\n", 580 | "Epoch 0 Batch 44/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 3.722\n", 581 | "Epoch 0 Batch 45/146 - Train Accuracy: 0.715, Validation Accuracy: 0.689, Loss: 3.418\n", 582 | "Epoch 0 Batch 46/146 - Train Accuracy: 0.709, Validation Accuracy: 0.689, Loss: 3.487\n", 583 | "Epoch 0 Batch 47/146 - Train Accuracy: 0.726, Validation Accuracy: 0.689, Loss: 3.272\n", 584 | "Epoch 0 Batch 48/146 - Train Accuracy: 0.676, Validation Accuracy: 0.689, Loss: 3.897\n", 585 | "Epoch 0 Batch 49/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 3.787\n", 586 | "Epoch 0 Batch 50/146 - Train Accuracy: 0.695, Validation Accuracy: 0.689, Loss: 3.620\n", 587 | "Epoch 0 Batch 51/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.671\n", 588 | "Epoch 0 Batch 52/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.547\n", 589 | "Epoch 0 Batch 53/146 - Train Accuracy: 0.715, Validation Accuracy: 0.689, Loss: 3.406\n", 590 | "Epoch 0 Batch 54/146 - Train Accuracy: 0.694, Validation Accuracy: 0.689, Loss: 3.654\n", 591 | "Epoch 0 Batch 55/146 - Train Accuracy: 0.716, Validation Accuracy: 0.689, Loss: 3.418\n", 592 | "Epoch 0 Batch 56/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 3.725\n", 593 | "Epoch 0 Batch 57/146 - Train Accuracy: 0.696, 
Validation Accuracy: 0.689, Loss: 3.615\n", 594 | "Epoch 0 Batch 58/146 - Train Accuracy: 0.683, Validation Accuracy: 0.689, Loss: 3.707\n", 595 | "Epoch 0 Batch 59/146 - Train Accuracy: 0.721, Validation Accuracy: 0.689, Loss: 3.320\n", 596 | "Epoch 0 Batch 60/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.665\n", 597 | "Epoch 0 Batch 61/146 - Train Accuracy: 0.664, Validation Accuracy: 0.689, Loss: 4.035\n", 598 | "Epoch 0 Batch 62/146 - Train Accuracy: 0.686, Validation Accuracy: 0.689, Loss: 3.724\n", 599 | "Epoch 0 Batch 63/146 - Train Accuracy: 0.709, Validation Accuracy: 0.689, Loss: 3.452\n", 600 | "Epoch 0 Batch 64/146 - Train Accuracy: 0.713, Validation Accuracy: 0.689, Loss: 3.377\n", 601 | "Epoch 0 Batch 65/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 3.643\n", 602 | "Epoch 0 Batch 66/146 - Train Accuracy: 0.713, Validation Accuracy: 0.689, Loss: 3.377\n", 603 | "Epoch 0 Batch 67/146 - Train Accuracy: 0.693, Validation Accuracy: 0.689, Loss: 3.626\n", 604 | "Epoch 0 Batch 68/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.483\n", 605 | "Epoch 0 Batch 69/146 - Train Accuracy: 0.738, Validation Accuracy: 0.689, Loss: 2.996\n", 606 | "Epoch 0 Batch 70/146 - Train Accuracy: 0.700, Validation Accuracy: 0.689, Loss: 3.427\n", 607 | "Epoch 0 Batch 71/146 - Train Accuracy: 0.709, Validation Accuracy: 0.689, Loss: 3.417\n", 608 | "Epoch 0 Batch 72/146 - Train Accuracy: 0.706, Validation Accuracy: 0.689, Loss: 3.432\n", 609 | "Epoch 0 Batch 73/146 - Train Accuracy: 0.676, Validation Accuracy: 0.689, Loss: 3.828\n", 610 | "Epoch 0 Batch 74/146 - Train Accuracy: 0.690, Validation Accuracy: 0.689, Loss: 3.596\n", 611 | "Epoch 0 Batch 75/146 - Train Accuracy: 0.687, Validation Accuracy: 0.689, Loss: 3.698\n", 612 | "Epoch 0 Batch 76/146 - Train Accuracy: 0.712, Validation Accuracy: 0.689, Loss: 3.368\n", 613 | "Epoch 0 Batch 77/146 - Train Accuracy: 0.712, Validation Accuracy: 0.689, Loss: 3.302\n", 614 | "Epoch 0 Batch 78/146 - Train Accuracy: 0.708, Validation Accuracy: 0.689, Loss: 3.314\n", 615 | "Epoch 0 Batch 79/146 - Train Accuracy: 0.681, Validation Accuracy: 0.689, Loss: 3.650\n", 616 | "Epoch 0 Batch 80/146 - Train Accuracy: 0.680, Validation Accuracy: 0.689, Loss: 3.674\n", 617 | "Epoch 0 Batch 81/146 - Train Accuracy: 0.710, Validation Accuracy: 0.689, Loss: 3.345\n", 618 | "Epoch 0 Batch 82/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 3.509\n", 619 | "Epoch 0 Batch 83/146 - Train Accuracy: 0.718, Validation Accuracy: 0.689, Loss: 3.270\n", 620 | "Epoch 0 Batch 84/146 - Train Accuracy: 0.706, Validation Accuracy: 0.689, Loss: 3.327\n", 621 | "Epoch 0 Batch 85/146 - Train Accuracy: 0.711, Validation Accuracy: 0.689, Loss: 3.291\n", 622 | "Epoch 0 Batch 86/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.327\n", 623 | "Epoch 0 Batch 87/146 - Train Accuracy: 0.712, Validation Accuracy: 0.689, Loss: 3.259\n", 624 | "Epoch 0 Batch 88/146 - Train Accuracy: 0.702, Validation Accuracy: 0.689, Loss: 3.358\n" 625 | ] 626 | }, 627 | { 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "Epoch 0 Batch 89/146 - Train Accuracy: 0.719, Validation Accuracy: 0.689, Loss: 3.115\n", 632 | "Epoch 0 Batch 90/146 - Train Accuracy: 0.688, Validation Accuracy: 0.689, Loss: 3.530\n", 633 | "Epoch 0 Batch 91/146 - Train Accuracy: 0.704, Validation Accuracy: 0.689, Loss: 3.313\n", 634 | "Epoch 0 Batch 92/146 - Train Accuracy: 0.689, Validation Accuracy: 0.689, Loss: 3.475\n", 635 | "Epoch 0 Batch 
93/146 - Train Accuracy: 0.654, Validation Accuracy: 0.689, Loss: 3.888\n", 636 | "Epoch 0 Batch 94/146 - Train Accuracy: 0.737, Validation Accuracy: 0.689, Loss: 2.939\n", 637 | "Epoch 0 Batch 95/146 - Train Accuracy: 0.686, Validation Accuracy: 0.689, Loss: 3.440\n", 638 | "Epoch 0 Batch 96/146 - Train Accuracy: 0.710, Validation Accuracy: 0.689, Loss: 3.248\n", 639 | "Epoch 0 Batch 97/146 - Train Accuracy: 0.685, Validation Accuracy: 0.689, Loss: 3.490\n", 640 | "Epoch 0 Batch 98/146 - Train Accuracy: 0.677, Validation Accuracy: 0.689, Loss: 3.579\n", 641 | "Epoch 0 Batch 99/146 - Train Accuracy: 0.704, Validation Accuracy: 0.689, Loss: 3.296\n", 642 | "Epoch 0 Batch 100/146 - Train Accuracy: 0.696, Validation Accuracy: 0.689, Loss: 3.311\n", 643 | "Epoch 0 Batch 101/146 - Train Accuracy: 0.693, Validation Accuracy: 0.689, Loss: 3.428\n", 644 | "Epoch 0 Batch 102/146 - Train Accuracy: 0.710, Validation Accuracy: 0.689, Loss: 3.209\n", 645 | "Epoch 0 Batch 103/146 - Train Accuracy: 0.691, Validation Accuracy: 0.689, Loss: 3.423\n", 646 | "Epoch 0 Batch 104/146 - Train Accuracy: 0.668, Validation Accuracy: 0.689, Loss: 3.618\n", 647 | "Epoch 0 Batch 105/146 - Train Accuracy: 0.708, Validation Accuracy: 0.689, Loss: 3.153\n" 648 | ] 649 | }, 650 | { 651 | "ename": "KeyboardInterrupt", 652 | "evalue": "", 653 | "output_type": "error", 654 | "traceback": [ 655 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 656 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 657 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m batch_train_logits = sess.run(\n\u001b[1;32m 18\u001b[0m \u001b[0minference_logits\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m {input_data: source_batch})\n\u001b[0m\u001b[1;32m 20\u001b[0m batch_valid_logits = sess.run(\n\u001b[1;32m 21\u001b[0m \u001b[0minference_logits\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 658 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 766\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 767\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 659 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 963\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 965\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 966\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 967\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 660 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1015\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 661 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1021\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1022\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1023\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1024\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 662 | "\u001b[0;32m//anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1002\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1003\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1004\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1005\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1006\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 663 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 664 | ] 665 | } 666 | ], 667 | "source": [ 668 | "import numpy as np\n", 669 | "\n", 670 | "train_source = source_ids[batch_size:]\n", 671 | "train_target = target_ids[batch_size:]\n", 672 | "\n", 673 | "valid_source = 
source_ids[:batch_size]\n", 674 | "valid_target = target_ids[:batch_size]\n", 675 | "\n", 676 | "sess.run(tf.global_variables_initializer())\n", 677 | "\n", 678 | "for epoch_i in range(epochs):\n", 679 | " for batch_i, (source_batch, target_batch) in enumerate(\n", 680 | " helper.batch_data(train_source, train_target, batch_size)):\n", 681 | " _, loss = sess.run(\n", 682 | " [train_op, cost],\n", 683 | " {input_data: source_batch, targets: target_batch, lr: learning_rate})\n", 684 | " batch_train_logits = sess.run(\n", 685 | " inference_logits,\n", 686 | " {input_data: source_batch})\n", 687 | " batch_valid_logits = sess.run(\n", 688 | " inference_logits,\n", 689 | " {input_data: valid_source})\n", 690 | "\n", 691 | " train_acc = np.mean(np.equal(target_batch, np.argmax(batch_train_logits, 2)))\n", 692 | " valid_acc = np.mean(np.equal(valid_target, np.argmax(batch_valid_logits, 2)))\n", 693 | " print('Epoch {:>3} Batch {:>4}/{} - Train Accuracy: {:>6.3f}, Validation Accuracy: {:>6.3f}, Loss: {:>6.3f}'\n", 694 | " .format(epoch_i, batch_i, len(source_ids) // batch_size, train_acc, valid_acc, loss))" 695 | ] 696 | }, 697 | { 698 | "cell_type": "markdown", 699 | "metadata": {}, 700 | "source": [ 701 | "## Prediction" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 16, 707 | "metadata": {}, 708 | "outputs": [ 709 | { 710 | "name": "stdout", 711 | "output_type": "stream", 712 | "text": [ 713 | "Input\n", 714 | " Word Ids: [20, 18, 28, 28, 10, 0, 0]\n", 715 | " Input Words: ['h', 'e', 'l', 'l', 'o', '<pad>', '<pad>']\n", 716 | "\n", 717 | "Prediction\n", 718 | " Word Ids: [18, 20, 28, 28, 10, 0, 0]\n", 719 | " Chatbot Answer Words: ['e', 'h', 'l', 'l', 'o', '<pad>', '<pad>']\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "input_sentence = 'hello'\n", 725 | "\n", 726 | "\n", 727 | "input_sentence = [source_letter_to_int.get(word, source_letter_to_int['<unk>']) for word in input_sentence.lower()]\n", 728 | "input_sentence = input_sentence + [0] * (sequence_length - len(input_sentence))\n", 729 | "batch_shell = np.zeros((batch_size, sequence_length))\n", 730 | "batch_shell[0] = input_sentence\n", 731 | "chatbot_logits = sess.run(inference_logits, {input_data: batch_shell})[0]\n", 732 | "\n", 733 | "print('Input')\n", 734 | "print(' Word Ids: {}'.format([i for i in input_sentence]))\n", 735 | "print(' Input Words: {}'.format([source_int_to_letter[i] for i in input_sentence]))\n", 736 | "\n", 737 | "print('\\nPrediction')\n", 738 | "print(' Word Ids: {}'.format([i for i in np.argmax(chatbot_logits, 1)]))\n", 739 | "print(' Chatbot Answer Words: {}'.format([target_int_to_letter[i] for i in np.argmax(chatbot_logits, 1)]))" 740 | ] 741 | } 742 | ], 743 | "metadata": { 744 | "anaconda-cloud": {}, 745 | "kernelspec": { 746 | "display_name": "Python 3", 747 | "language": "python", 748 | "name": "python3" 749 | }, 750 | "language_info": { 751 | "codemirror_mode": { 752 | "name": "ipython", 753 | "version": 3 754 | }, 755 | "file_extension": ".py", 756 | "mimetype": "text/x-python", 757 | "name": "python", 758 | "nbconvert_exporter": "python", 759 | "pygments_lexer": "ipython3", 760 | "version": "3.6.1" 761 | } 762 | }, 763 | "nbformat": 4, 764 | "nbformat_minor": 1 765 | } 766 | --------------------------------------------------------------------------------