├── LICENSE
├── README.md
├── Minimal character-level Tensorflow RNN model.ipynb
├── keras_mnist_generator.ipynb
├── tf-minimal-Char-RNN.ipynb
├── rnn_face_tests
│   ├── tf-LFW-load-tensor.ipynb
│   └── LFW Model Face V0.3.ipynb
├── word2vec
│   └── word2vec.ipynb
├── Fizz Buzz.ipynb
└── images2gif.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | Copyright (c) 2016, Damien Henry
 2 | All rights reserved.
 3 | 
 4 | Redistribution and use in source and binary forms, with or without
 5 | modification, are permitted provided that the following conditions are met:
 6 | 
 7 | * Redistributions of source code must retain the above copyright notice, this
 8 |   list of conditions and the following disclaimer.
 9 | 
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 |   this list of conditions and the following disclaimer in the documentation
12 |   and/or other materials provided with the distribution.
13 | 
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # ML-Tutorial-Notebooks
 2 | Machine learning tutorials from scratch.
 3 | 
 4 | This repo contains tutorials for real beginners who want to understand machine learning by reading code.
 5 | 
 6 | ## Minimal character-level Vanilla RNN model, explained in a notebook
 7 | 
 8 | RNN stands for "Recurrent Neural Network".
 9 | To understand why RNNs are so hot you _must_ read [this](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)!
10 | 
11 | [This notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/RNN.ipynb) explains the _[Minimal character-level Vanilla RNN model](https://gist.github.com/karpathy/d4dee566867f8291f086)_ written by __Andrej Karpathy__.
12 | This code creates an RNN that generates text, char after char, after learning char after char from a text file.
13 | 
14 | I love this _character-level Vanilla RNN_ code because it doesn't use any library except numpy.
15 | All the NN magic is in 112 lines of code, with no dependencies to dig through. Everything is there! I'll try to explain every line of it in detail. Disclaimer: I still need to use some external links for reference.
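
To make "learning char after char" concrete, here is a minimal sketch of the data preparation (my own illustration, not code from the notebook; `input.txt` stands in for any plain-text file). It mirrors the `char_to_ix` helper that the notebooks build:

```python
import numpy as np

data = open('input.txt').read()  # any plain text file
chars = list(set(data))          # the vocabulary
char_to_ix = {ch: i for i, ch in enumerate(chars)}

def one_hot(ch):
    """Encode one character as a one-hot column vector."""
    x = np.zeros((len(chars), 1))
    x[char_to_ix[ch]] = 1
    return x

# The RNN learns to predict data[i+1] from data[i], one character at a time.
```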
16 | 
17 | ## Minimal character-level TensorFlow RNN model
18 | 112 lines of code to implement a character-level RNN in TensorFlow.
19 | [Here is the code!](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/Minimal%20character-level%20Tensorflow%20RNN%20model.ipynb)
20 | This is an adaptation of the _[Minimal character-level Vanilla RNN model](https://gist.github.com/karpathy/d4dee566867f8291f086)_ written by __Andrej Karpathy__.
21 | 
22 | ## Character-level TensorFlow RNN model
23 | If you want to go deeper into TensorFlow for RNNs, [this notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/tf-char-RNN.ipynb) tries to explain [this original code](https://github.com/sherjilozair/char-rnn-tensorflow) from __Sherjil Ozair__.
24 | 
25 | ## Linear regression with TensorFlow
26 | 
27 | [This notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/tf-linear-regression.ipynb) is a good way to start with TensorFlow, using a simple example from __Aymeric Damien__.
28 | 
29 | ## Fizz Buzz with TensorFlow
30 | 
31 | [This notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/Fizz%20Buzz.ipynb) explains the [code](https://github.com/joelgrus/fizz-buzz-tensorflow/blob/master/fizz_buzz.py) from the [Fizz Buzz in Tensor Flow](http://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow/) blog post written by __Joel Grus__.
32 | You should read his post first!
33 | 
34 | His [code](https://github.com/joelgrus/fizz-buzz-tensorflow/blob/master/fizz_buzz.py) tries to play the Fizz Buzz game using machine learning.
35 | 
36 | ## Temperature
37 | 
38 | Temperature is a concept used when you need to generate a random sample from a probability vector but want to over-emphasize the samples that have the highest probability.
39 | 
40 | [This notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/Temperature.ipynb) shows the effects in practice; a minimal sketch of the idea follows.
41 | 
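The sketch below is my own illustration (not code from the Temperature notebook): dividing the log of the probabilities by a temperature below 1.0 sharpens the distribution before sampling, while a temperature above 1.0 flattens it.

```python
import numpy as np

def sample_with_temperature(p, temperature=1.0):
    """Sample an index from the probability vector p.

    temperature < 1.0 over-emphasizes the most probable entries,
    temperature > 1.0 flattens the distribution,
    temperature == 1.0 samples from p unchanged.
    """
    logits = np.log(p) / temperature               # temperature-scaled log-probabilities
    p_t = np.exp(logits) / np.sum(np.exp(logits))  # re-normalize (a softmax)
    return np.random.choice(len(p_t), p=p_t)

p = np.array([0.1, 0.2, 0.7])
print(sample_with_temperature(p, temperature=0.5))  # returns 2 most of the time
```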
\n", 10 | "\n", 11 | "[The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) is a great source of inspiration to understand the power of RNN.\n", 12 | "\n", 13 | "This notebook is for beginners who whant to understand RNN and the basis of TensorFlow by reading code.\n", 14 | "\n", 15 | "More ressources:\n", 16 | "* [Minimal character-level Vanilla RNN model explained in a notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/RNN.ipynb)\n", 17 | "* [RNN in TensorFlow explained in a notebook](https://github.com/dh7/ML-Tutorial-Notebooks/blob/master/tf-char-RNN.ipynb)\n", 18 | "* [A model that implement LSTM with 2 layer](https://github.com/sherjilozair/char-rnn-tensorflow) from __Sherjil Ozair__\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "collapsed": false 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "\"\"\"\n", 30 | "Minimal character-level TensorFlow RNN model.\n", 31 | "Original code written by Andrej Karpathy (@karpathy), \n", 32 | "adapted to TensorFlow by Damien Henry (@dh7net)\n", 33 | "BSD License\n", 34 | "\"\"\"\n", 35 | "import numpy as np\n", 36 | "import tensorflow as tf\n", 37 | "tf.reset_default_graph() # Useful in Jupyter, to run the code several times\n", 38 | "\n", 39 | "# data I/O\n", 40 | "data = open('methamorphosis.txt', 'r').read() # should be simple plain text file\n", 41 | "chars = list(set(data))\n", 42 | "data_size, vocab_size = len(data), len(chars)\n", 43 | "print 'data has %d characters, %d unique.' % (data_size, vocab_size)\n", 44 | "char_to_ix = { ch:i for i,ch in enumerate(chars) } # to convert a char to an ID\n", 45 | "ix_to_char = { i:ch for i,ch in enumerate(chars) } # to convert an ID back to a char\n", 46 | "\n", 47 | "# hyperparameters\n", 48 | "hidden_size = 100 # size of hidden layer of neurons\n", 49 | "seq_length = 25 # number of steps to unroll the RNN for\n", 50 | "learning_rate = 0.002\n", 51 | "decay_rate = 0.98 # \n", 52 | "\n", 53 | "# model parameters\n", 54 | "Wxh = tf.Variable(tf.random_uniform((hidden_size, vocab_size))*0.01, name='Wxh') #input to hidden\n", 55 | "Whh = tf.Variable(tf.random_uniform((hidden_size, hidden_size))*0.01, name='Whh')#hidden to hidden\n", 56 | "Why = tf.Variable(tf.random_uniform((vocab_size, hidden_size))*0.01, name='Why') #hidden to output\n", 57 | "bh = tf.Variable(tf.zeros((hidden_size, 1)), name='bh') # hidden bias\n", 58 | "by = tf.Variable(tf.zeros((vocab_size, 1)), name='by') # output bias\n", 59 | "\n", 60 | "# Define placeholder to for the input and the target & create the sequences\n", 61 | "input_data = tf.placeholder(tf.float32, [seq_length, vocab_size], name='input_data')\n", 62 | "xs = tf.split(0, seq_length, input_data)\n", 63 | "target_data = tf.placeholder(tf.float32, [seq_length, vocab_size], name='target_data') \n", 64 | "targets = tf.split(0, seq_length, target_data) \n", 65 | "# initial_state & loss\n", 66 | "initial_state = tf.zeros((hidden_size, 1))\n", 67 | "loss = tf.zeros([1], name='loss')\n", 68 | "# unroll recursion to create the forward pass graph\n", 69 | "hs, ys, ps = {}, {}, {}\n", 70 | "hs[-1] = initial_state \n", 71 | "for t in xrange(seq_length):\n", 72 | " xs_t = tf.transpose(xs[t])\n", 73 | " targets_t = tf.transpose(targets[t]) \n", 74 | " hs[t] = tf.tanh(tf.matmul(Wxh, xs_t) + tf.matmul(Whh, hs[t-1]) + bh) # hidden state\n", 75 | " ys[t] = tf.matmul(Why, hs[t]) + by # unnormalized log probabilities for next chars\n", 76 | " ps[t] = 
tf.exp(ys[t]) / tf.reduce_sum(tf.exp(ys[t])) # probabilities for next chars\n", 77 | " loss += -tf.log(tf.reduce_sum(tf.mul(ps[t], targets_t))) # softmax (cross-entropy loss)\n", 78 | "\n", 79 | "cost = loss / seq_length\n", 80 | "final_state = hs[seq_length-1]\n", 81 | "lr = tf.Variable(0.0, trainable=False, name='learning_rate')\n", 82 | "tvars = tf.trainable_variables()\n", 83 | "# Calculation of gradient is done by TensorFlow using \"tf.gradients(cost, tvars)\"\n", 84 | "grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5) # clip exploding gradients\n", 85 | "optimizer = tf.train.AdamOptimizer(lr) \n", 86 | "train_op = optimizer.apply_gradients(zip(grads, tvars))\n", 87 | "\n", 88 | "def sample(h, seed_ix, n):\n", 89 | " \"\"\" \n", 90 | " sample a sequence of integers from the model \n", 91 | " h is memory state, seed_ix is seed letter for first time step\n", 92 | " \"\"\"\n", 93 | " x = np.zeros((vocab_size, 1))\n", 94 | " x[seed_ix] = 1\n", 95 | " ixes = []\n", 96 | " for t in xrange(n):\n", 97 | " h = np.tanh(np.dot(Wxh.eval(), x) + np.dot(Whh.eval(), h) + bh.eval())\n", 98 | " y = np.dot(Why.eval(), h) + by.eval()\n", 99 | " p = np.exp(y) / np.sum(np.exp(y))\n", 100 | " ix = np.random.choice(range(vocab_size), p=p.ravel())\n", 101 | " x = np.zeros((vocab_size, 1))\n", 102 | " x[ix] = 1\n", 103 | " ixes.append(ix)\n", 104 | " return ixes\n", 105 | "\n", 106 | "def vectorize(x): # take an array of IX and return an array of vector\n", 107 | " vectorized = np.zeros((len(x), vocab_size))\n", 108 | " for i in range(0, len(x)):\n", 109 | " vectorized[i][x[i]] = 1\n", 110 | " return vectorized\n", 111 | "\n", 112 | "n, p, epoch = 0, 0, 0\n", 113 | "smooth_loss = -np.log(1.0/vocab_size)*seq_length # loss at iteration 0\n", 114 | "with tf.Session() as sess:\n", 115 | " tf.initialize_all_variables().run()\n", 116 | " print \"all variable initialized\"\n", 117 | " while True:\n", 118 | " # prepare inputs (we're sweeping from left to right in steps seq_length long)\n", 119 | " if p+seq_length+1 >= len(data) or n == 0: \n", 120 | " state = initial_state.eval() # reset RNN memory\n", 121 | " sess.run(tf.assign(lr, learning_rate * (decay_rate ** epoch)))\n", 122 | " p = 0 # go from start of data\n", 123 | " epoch += 1 # increase epoch number\n", 124 | " x = vectorize([char_to_ix[ch] for ch in data[p:p+seq_length]])\n", 125 | " y = vectorize([char_to_ix[ch] for ch in data[p+1:p+seq_length+1]])\n", 126 | " # Create the structure for the learning data\n", 127 | " feed = {input_data: x, target_data: y, initial_state: state}\n", 128 | " # Run a session using train_op\n", 129 | " [train_loss], state, _ = sess.run([cost, final_state, train_op], feed)\n", 130 | " smooth_loss = smooth_loss * 0.999 + train_loss * 0.001\n", 131 | " # sample from the model now and then\n", 132 | " if n % 1000 == 0:\n", 133 | " print 'iter %d, loss: %f' % (n, smooth_loss) # print progress\n", 134 | " sample_ix = sample(state, char_to_ix[data[p]], 200)\n", 135 | " txt = ''.join(ix_to_char[ix] for ix in sample_ix)\n", 136 | " print '----\\n %s \\n----' % (txt, )\n", 137 | "\n", 138 | " p += seq_length # move data pointer\n", 139 | " n += 1 # iteration counter \n", 140 | " " 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "Feedback welcome __@dh7net__" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python 2", 154 | "language": "python", 155 | "name": "python2" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 
159 | "name": "ipython", 160 | "version": 2 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython2", 167 | "version": "2.7.10" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 0 172 | } 173 | -------------------------------------------------------------------------------- /keras_mnist_generator.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# A simple network using Keras" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import tensorflow as tf\n", 17 | "import random\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "from tensorflow.python import keras\n", 21 | "\n", 22 | "from tensorflow.python.keras.datasets import mnist\n", 23 | "from tensorflow.python.keras.utils import np_utils" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Load MNIST" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", 43 | "11493376/11490434 [==============================] - 5s 0us/step\n", 44 | "Training matrix X shape (60000, 784)\n", 45 | "Testing matrix X shape (10000, 784)\n", 46 | "Training matrix Y shape (60000, 10)\n", 47 | "Testing matrix Y shape (10000, 10)\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 53 | "\n", 54 | "X_train = x_train.reshape(60000, 784)\n", 55 | "X_test = x_test.reshape(10000, 784)\n", 56 | "X_train = X_train.astype('float32')\n", 57 | "X_test = X_test.astype('float32')\n", 58 | "X_train /= 255\n", 59 | "X_test /= 255\n", 60 | "print(\"Training matrix X shape\", X_train.shape)\n", 61 | "print(\"Testing matrix X shape\", X_test.shape)\n", 62 | "\n", 63 | "Y_train = np_utils.to_categorical(y_train, 10)\n", 64 | "Y_test = np_utils.to_categorical(y_test, 10)\n", 65 | "\n", 66 | "print(\"Training matrix Y shape\", Y_train.shape)\n", 67 | "print(\"Testing matrix Y shape\", Y_test.shape)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "### Display a sample" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABNtJREFUeJzt3dFN3FgYgNHxanuBOtBUAFVACSmAEqAKqAClDqYa70teVoLrBGMP4TvnkauRHSefrpSf65nmeT4A398/574BYB9ihwixQ4TYIULsEPHvnhebpsl//cPG5nme3vq5nR0ixA4RYocIsUOE2CFC7BAhdogQO0SIHSLEDhFihwixQ4TYIULsECF2iBA7RIgdIsQOEWKHCLFDhNghQuwQseurpPl6Xl9fh+sXFxfD9cfHx+H63d3dH98T27CzQ4TYIULsECF2iBA7RIgdIsQOEdM87/ctyr6yeX9r5+hrnU6nd9cuLy83vXaVr2yGOLFDhNghQuwQIXaIEDtEiB0izNm/gdEsfes5+hrOwm/DnB3ixA4RYocIsUOE2CFC7BAhdojw3vi/wMPDw3D9K8/S+Trs7BAhdogQO0SIHSLEDhFihwijt7ibm5vh+vF4HK7f3t5++NpLn3XE9XPZ2SFC7BAhdogQO0SIHSLEDhFihwhz9r/AlvPm5+fn4frSnJ2/h50dIsQOEWKHCLFDhNghQuwQIXaIMGf/Br7rue/r6+vh+tLvCPB/dnaIEDtEiB0ixA4RYocIsUOE2CFimud5v4tN034X41Mszbqfnp6G66fT6d21pa+anqZpuM7b5nl+88HZ2SFC7BAhdogQO0SIHSLEDhFihwhzdja15t+XOfvHmLNDnNghQuwQIXaIEDtEiB0ixA4RYocIsUOE2CFC7BAhdogQO0SIHSJ8ZTNn8/j4eO5bSLGzQ4TYIULsECF2iBA7RIgdIsQOEWKHCLFDhNghQuwQIXaIEDtEiB0ixA4RzrOzysPDw4c/+/Ly8ol3whI7O0SIHSLEDhFihwixQ4TYIULsEDHN87zfxaZpv4v9oaV58e3t7WbXXnp/+t3d3WbXXnJ9fT1cf3p6Gq6fTqd31y4vLz90T4zN8zy99XM7O0SIHSLEDhFihwixQ4TYIcIR1y9gaay3tL7leOv+/n7V53/8+LHq83weOztEiB0ixA4RYocIsUOE2CFC7BDhiOsvr6+vw/WLi4t316bpzROFv+0rH69de+21z+arWvMK7cNh22PLjrhCnNghQuwQIXaIEDtEiB0ixA4R5uy/nHPOvmTpdc6jM+ej+97D6Nks/bm2tPac/tJzHb1j4HDY9jXa5uwQJ3aIEDtEiB0ixA4RYocIsUOE98b/8vPnz+H6aK66dLZ57dnl5+fn4frxeHx37dxz9j1/j2NPNzc3w/Wlv7NzsLNDhNghQuwQIXaIEDtEiB0iHHH9TaMjsFuPt7Z83fPSUcylkeQaV1dXw/W1z3X03LZ8lfO5OeIKcWKHCLFDhNghQuwQIXaIEDtEmLN/gnN+5fLhMJ6VL83Jv/O8ucqcHeLEDhFihwixQ4TYIULsECF2iDBnh2/GnB3ixA4RYocIsUOE2CFC7BAhdogQO0SIHSLEDhFihwixQ4TYIULsECF2iBA7RIgdIsQOEWKHCLFDhNghQuwQIXaIEDtEiB0ixA4RYocIsUOE2CFC7BAhdogQO0SIHSLEDhFihwixQ4TYIULsEDHN83zuewB2YGeHCLFDhNghQuwQIXaIEDtEiB0ixA4RYocIsUOE2CFC7BAhdogQO0SIHSLEDhFihwixQ4TYIULsECF2iBA7RIgdIv4DuELWBQhVSx4AAAAASUVORK5CYII=\n", 85 | "text/plain": [ 86 | "
" 87 | ] 88 | }, 89 | "metadata": { 90 | "needs_background": "light" 91 | }, 92 | "output_type": "display_data" 93 | } 94 | ], 95 | "source": [ 96 | "%matplotlib inline\n", 97 | "import matplotlib.pyplot as plt\n", 98 | "\n", 99 | "random_index = random.randint(0,len(x_train))\n", 100 | "\n", 101 | "plt.imshow(x_train[random_index], cmap='Greys_r')\n", 102 | "plt.axis('off')\n", 103 | "plt.show()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Create a simple fully connect network\n", 111 | "Take as in imput a digit number (hot encoded)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "from tensorflow.python.keras.models import Sequential\n", 121 | "from tensorflow.python.keras.layers import Dense, Activation, Dropout\n", 122 | "\n", 123 | "model = Sequential()\n", 124 | "model.add(Dense(256, input_shape=(10,)))\n", 125 | "model.add(Activation('tanh'))\n", 126 | "model.add(Dropout(0.2))\n", 127 | "model.add(Dense(784))\n", 128 | "model.add(Activation('sigmoid'))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### compile" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 7, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "model.compile(loss='categorical_crossentropy', optimizer='adam')" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "### train" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 8, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "Train on 60000 samples, validate on 10000 samples\n", 164 | "Epoch 1/4\n", 165 | "60000/60000 [==============================] - 5s 81us/step - loss: 591.6621 - val_loss: 583.4598\n", 166 | "Epoch 2/4\n", 167 | "60000/60000 [==============================] - 4s 73us/step - loss: 576.9601 - val_loss: 583.0143\n", 168 | "Epoch 3/4\n", 169 | "60000/60000 [==============================] - 4s 72us/step - loss: 576.7511 - val_loss: 583.0249\n", 170 | "Epoch 4/4\n", 171 | "60000/60000 [==============================] - 4s 72us/step - loss: 576.6814 - val_loss: 582.9404\n" 172 | ] 173 | }, 174 | { 175 | "data": { 176 | "text/plain": [ 177 | "" 178 | ] 179 | }, 180 | "execution_count": 8, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "model.fit(x=Y_train, y=X_train,\n", 187 | " batch_size=128, epochs=4,\n", 188 | " verbose=1,\n", 189 | " validation_data=(Y_test, X_test))" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "# Generate digits" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 9, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "def predict_and_show(digit):\n", 206 | " predicted = model.predict(digit)\n", 207 | " predicted *= 255.\n", 208 | " predicted = predicted.reshape(28,28)\n", 209 | " #print(predicted)\n", 210 | "\n", 211 | " plt.imshow(predicted, cmap='Greys_r')\n", 212 | " plt.axis('off')\n", 213 | " plt.show()" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 13, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACUxJREFUeJzt3dtLVG0cxfE95ozjIS3zVDpFRSVBVloQdVNdBNX/GtRlFxIRBRGYdMJKq+lkmpkyap5m3hvfu/b6SaY2ru/nssVTc3C5od9+np2pVCoJgJ2vZrtfAICtQdkBE5QdMEHZAROUHTBRu5X/WCaT4b/+gU1WqVQyv/tzruyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6Y2NJHNmNzZDK/fUJvkiRJsmvXLrm2rq5O5o2NjTLPZrMyV5aXl2W+sLCwofUrKyup2erqqlxbqey8p4tzZQdMUHbABGUHTFB2wARlB0xQdsAEZQdMMGffAmoOniRJUlOjf+dGs/DOzs7U7MKFC3JtlJ8/f17mHR0dMlfvLZqjj46Oyvz+/fsyHx4eTs3evHkj105OTsr8169fMi+XyzLfjjk+V3bABGUHTFB2wARlB0xQdsAEZQdMUHbABHP2vyCak+dyOZm3tLTIvL+/X+bXrl1LzS5fvizXHjx4UOYNDQ0yj967Eu0pb2trk3mhUJB5b29vavbgwQO5dmhoSOafPn2S+dzcnMwXFxdTs82awXNlB0xQdsAEZQdMUHbABGUHTFB2wASjt3VSI6ZoC2pra6vM+/r6ZH7z5k2ZX7lyJTVrb2+Xa6OjpqNtqGqElCT6uOdoxBRtDW5qapL5kSNHUrNoC+vs7KzMo/cdvTf1uUQjyT/FlR0wQdkBE5QdMEHZAROUHTBB2QETlB0wwZx9ndQ8Op/Py7VdXV0yP3v2rMxPnTolc7UNNTryONqqGW31fPXqlczVPLmnp0euHRgYkHl0jLX6Xvbs2SPXRnk045+ZmZF5dA/BZuDKDpig7IAJyg6YoOyACcoOmKDsgAnKDphgzr4mmnuqOXu0n/3QoUMy7+7ulnlzc7PMV1ZWUrPx8XG59t69ezJ/8uSJzKPHKqvXHu35Pnz4sMyjvfpqX3i0T39paUnm6v6B6N/eLlzZAROUHTBB2QETlB0wQdkBE5QdMEHZARPM2ddJzeGz2axcG83ho/3w0TxanWE+PT0t13779k3mX79+lXltrf4RUnv5o8dFR3n0uah7DD5//izXfvnyRebRfvXokc3lclnmm4ErO2CCsgMmKDtggrIDJig7YIKyAyYoO2CCOfsWiPZO//jxQ+bz8/Myr6+vT82i888LhYLMozl9LpeT+aVLl1Kz6Fz46Gz2YrEoc3Xm/YsXL+TasbExmUefS/T8dubsADYNZQdMUHbABGUHTFB2wARlB0wwelsndXRwqVSSa9+9eyfz/fv3y/zkyZN/vD46jrm1tVXm586d29D6AwcOpGbR+CnaZvro0SOZP378ODV7/fq1XDs1NSXz6FHY2zFai3BlB0xQdsAEZQdMUHbABGUHTFB2wARlB0wwZ18THUus5qbRFtbo2OLnz5/LPJqVnzhxIjVraWmRa/fu3Svz6HHT6lHWSaI/1+j+gzt37sh8cHBQ5mob6+zsrFwbPbL5X5yjR7iyAyYoO2CCsgMmKDtggrIDJig7YIKyAyaYs6+TmqtGM9lopjs6Oirzp0+fyvzGjRupWTQHjx43rR5VnSTx/QnqGOxbt27JtXfv3pX5yMiIzNXnHn1n0fuqRlzZAROUHTBB2QETlB0wQdkBE5QdMEHZARPM2f+CjeyFT5IkWV1dlXlNjf6drPJobTRHj0SvXZ39Ht1fEJ0DEJ0jsLKykppF3xlzdgBVi7IDJig7YIKyAyYoO2CCsgMmKDtggjn7XxDNqvP5vMy7urpkPjAwIPO2tjaZK2oWvZ482qv/8ePH1Czaax99brW1+sdX3WOw0XPfq3EOz5UdMEHZAROUHTBB2QETlB0wQdkBE4ze1kmN1+rq6uTazs5OmV+9enVDuRpRLS4uyrU/f/6U+cTEhMy/f/8u8/Hx8dQsl8vJtdFIcWpqSubqvUejtyhn9Abgn0XZAROUHTBB2QETlB0wQdkBE5QdMMGcfZ3Udszm5ma59uLFizK/fv26zDs6OmS+vLycmhWLRbl2cHBQ5tEsO5qFq22o0RbX6P6F+vp6mas5e7R1txrn6BGu7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJ5uxrouOgs9lsalYoFOTa/v5+mff09Mg8mgmPjY2lZrdv35ZrR0ZGZN7U1CTzaE+6moVH7yv6uxsaGmReKpVSs2jGHz2KuhpxZQdMUHbABGUHTFB2wARlB0xQdsAEZQdMMGdfE83Z1bz46NGjcm13d/cfvab/TU9Py/zhw4ep2fv37+XaaJ4c7SmPcjXPjtbu3r1b5tE9AOpzi77vnYgrO2CCsgMmKDtggrIDJig7YIKyAyYYva2pqdG/9xobG1OzaKtlJNrqqbbXJkmSnDlzJjWLjnqO/u7W1laZR8doz8zMpGZLS0tybfQ46YgaK3KUNIAdi7IDJig7YIKyAyYoO2CCsgMmKDtggjn7mmiuqma28/Pzcu3k5KTMjx8/LvN9+/bJXB1FHR2ZHCmXyzKfm5uTuTrmWj3OOUmSZGFhQebRHF6tj7b2MmcHULUoO2CCsgMmKDtggrIDJig7YIKyAyaYs6+J5slqX/azZ8/k2vb2dpkfO3ZM5tGedHWkcrRfPRLtOY9m4R8+fEjNXr58KdcODw/LfGJiQubqtUXfN3N2AFWLsgMmKDtggrIDJig7YIKyAyYoO2Ais5XzxEwmU7XDS/WI33w+L9dGc/K+vj6Znz59Wua9vb2pWfS46OXlZZkXi0WZDw0N/XH+9u1bubZUKsk8ugcgmqXvVJVK5bc/rFzZAROUHTBB2QETlB0wQdkBE5QdMEHZARPM2atA9Ox4dQ9AJFob/XxsJN+Je8b/BczZAXOUHTBB2QETlB0wQdkBE5QdMMFR0lXAdasm/i6u7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyAiS09ShrA9uHKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmKDsgAnKDpig7IAJyg6YoOyACcoOmPgPMwNuSg2BkqYAAAAASUVORK5CYII=\n", 224 | "text/plain": [ 225 | "
" 226 | ] 227 | }, 228 | "metadata": { 229 | "needs_background": "light" 230 | }, 231 | "output_type": "display_data" 232 | } 233 | ], 234 | "source": [ 235 | "my_digit = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 2, 0]])\n", 236 | "\n", 237 | "predict_and_show(my_digit)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.5.2" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 2 269 | } 270 | -------------------------------------------------------------------------------- /tf-minimal-Char-RNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "data has 119163 characters, 61 unique.\n", 15 | "all variable initialized\n", 16 | "iter 0, loss: 102.673186\n", 17 | "----\n", 18 | " .Dt) Bhtfq(CpHsI,G\n", 19 | "zYpTTD;.,JdB lOU\n", 20 | "qi.?uy:Erer!\n", 21 | "aUhwVBo;rqFLzz:,aQ!BngUi:GbfAWf\n", 22 | "A.uvUODoUEqpBHMmoxNAGoCm\"AeIlyGy\n", 23 | "LvIMU'PfajH)njUcz,LQgi?AbF(aSGArCEWtLpDP;DINryil?LJW)IWAbtDpnt'ynq b\n", 24 | "o?!CSSMS'(zgPeNlt \n", 25 | "----\n", 26 | "iter 100, loss: 93.292775\n", 27 | "----\n", 28 | " akn'ifudvyAVip hEyniqheuYan i.\n", 29 | "Guke,if;zob plo;nvg iuu;\"iw;xaw; se xu\n", 30 | "s,SieYofeYmt Fxak(yQ(obCYeFlyeePcdiheyzcn!hl;biS oupbdUubiirknoce;ooki\"S yrW h.t\n", 31 | ":ope,uarSSf;gFan za \"iad 'feylJ ify:kff-AMsn oge \n", 32 | "----\n", 33 | "iter 200, loss: 84.733920\n", 34 | "----\n", 35 | " slgisnfrgenug?ertrctss fdtn aehsa eeubs l voo o qe wuvepon eyoh owseaea ilbaee neau as tfnnb s h ehetWaieefa teaai rbeYa g t tleeie nsellaad gH oh lcneide e seheosi nttieta 'tuzs tkaicgdgrhnoks \n", 36 | "----\n", 37 | "iter 300, loss: 76.980630\n", 38 | "----\n", 39 | " ftamtucmn np ekhaunaopg tye evod\"cso e delueiinnce cumme\"sotlmce new g demtaTt,astpee c s, ay e ,y aifoiGtabniiogm \n", 40 | "gs fdiroo\n", 41 | " eyp \n", 42 | "oei,srh maeuettob m eukrdGaem EVo ec ,homo u d,s hlief sneoicp g \n", 43 | "----\n" 44 | ] 45 | }, 46 | { 47 | "ename": "KeyboardInterrupt", 48 | "evalue": "", 49 | "output_type": "error", 50 | "traceback": [ 51 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 52 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 53 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0mfeed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0minput_data\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_data\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial_state\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;31m# Run a session using 
train_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 97\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0mtrain_loss\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcost\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfinal_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_op\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 98\u001b[0m \u001b[0msmooth_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmooth_loss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m0.999\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m0.001\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;31m# sample from the model now and then\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 54 | "\u001b[0;32m/Users/damienhenry/code/jupyter/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 339\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 340\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 341\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 55 | "\u001b[0;32m/Users/damienhenry/code/jupyter/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 563\u001b[0m results = self._do_run(handle, target_list, unique_fetches,\n\u001b[0;32m--> 564\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 565\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;31m# The movers are no longer used. 
Delete them.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 56 | "\u001b[0;32m/Users/damienhenry/code/jupyter/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 635\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 636\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m--> 637\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 638\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 57 | "\u001b[0;32m/Users/damienhenry/code/jupyter/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 644\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 645\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mStatusNotOK\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 646\u001b[0m \u001b[0merror_message\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merror_message\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 58 | "\u001b[0;32m/Users/damienhenry/code/jupyter/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m return tf_session.TF_Run(\n\u001b[0;32m--> 628\u001b[0;31m session, None, feed_dict, fetch_list, target_list, None)\n\u001b[0m\u001b[1;32m 629\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 630\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 59 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "\"\"\"\n", 65 | "Minimal character-level TensorFlow RNN model.\n", 66 | "Original code written by Andrej Karpathy (@karpathy), adapted to TensorFlow by Damien Henry (@dh7net)\n", 67 | "BSD License\n", 68 | "\"\"\"\n", 69 | "import numpy as np\n", 70 | "import tensorflow as tf\n", 71 | "tf.reset_default_graph() # Only usefull when used in Jupyter if 
you want to run the code several times\n", 72 | "\n", 73 | "# data I/O\n", 74 | "data = open('methamorphosis.txt', 'r').read() # should be simple plain text file\n", 75 | "chars = list(set(data))\n", 76 | "data_size, vocab_size = len(data), len(chars)\n", 77 | "print 'data has %d characters, %d unique.' % (data_size, vocab_size)\n", 78 | "char_to_ix = { ch:i for i,ch in enumerate(chars) } # to convert a char to an ID\n", 79 | "ix_to_char = { i:ch for i,ch in enumerate(chars) } # to convert an ID back to a char\n", 80 | "\n", 81 | "# hyperparameters\n", 82 | "hidden_size = 100 # size of hidden layer of neurons\n", 83 | "seq_length = 25 # number of steps to unroll the RNN for\n", 84 | "learning_rate = 1e-2\n", 85 | "\n", 86 | "# model parameters\n", 87 | "Wxh = tf.Variable(tf.random_uniform((hidden_size, vocab_size))*0.01, name='Wxh') # input to hidden\n", 88 | "Whh = tf.Variable(tf.random_uniform((hidden_size, hidden_size))*0.01, name='Whh') # hidden to hidden\n", 89 | "Why = tf.Variable(tf.random_uniform((vocab_size, hidden_size))*0.01, name='Why') # hidden to output\n", 90 | "bh = tf.Variable(tf.zeros((hidden_size, 1)), name='bh') # hidden bias\n", 91 | "by = tf.Variable(tf.zeros((vocab_size, 1)), name='by') # output bias\n", 92 | "\n", 93 | "# loss function\n", 94 | "# Define placeholder to for the input and the target & create the sequences\n", 95 | "input_data = tf.placeholder(tf.float32, [seq_length, vocab_size], name='input_data')\n", 96 | "xs = tf.split(0, seq_length, input_data)\n", 97 | "target_data = tf.placeholder(tf.float32, [seq_length, vocab_size], name='target_data') \n", 98 | "targets = tf.split(0, seq_length, target_data) \n", 99 | "# initial_state & loss\n", 100 | "initial_state = tf.zeros((hidden_size, 1))\n", 101 | "loss = tf.zeros([1], name='loss')\n", 102 | "# unroll recursion to create the loss\n", 103 | "hs, ys, ps = {}, {}, {}\n", 104 | "hs[-1] = initial_state\n", 105 | "# forward pass \n", 106 | "for t in xrange(seq_length):\n", 107 | " xs_t = tf.transpose(xs[t])\n", 108 | " targets_t = tf.transpose(targets[t]) \n", 109 | " hs[t] = tf.tanh(tf.matmul(Wxh, xs_t) + tf.matmul(Whh, hs[t-1]) + bh) # hidden state\n", 110 | " ys[t] = tf.matmul(Why, hs[t]) + by # unnormalized log probabilities for next chars\n", 111 | " ps[t] = tf.exp(ys[t]) / tf.reduce_sum(tf.exp(ys[t])) # probabilities for next chars\n", 112 | " loss += -tf.log(tf.reduce_sum(tf.mul(ps[t], targets_t))) # softmax (cross-entropy loss)\n", 113 | "\n", 114 | "cost = loss / seq_length\n", 115 | "final_state = hs[seq_length-1]\n", 116 | "tvars = tf.trainable_variables()\n", 117 | "grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5) # clip to mitigate exploding gradients\n", 118 | "optimizer = tf.train.AdamOptimizer(learning_rate)\n", 119 | "train_op = optimizer.apply_gradients(zip(grads, tvars))\n", 120 | "\n", 121 | "def sample(h, seed_ix, n):\n", 122 | " \"\"\" \n", 123 | " sample a sequence of integers from the model \n", 124 | " h is memory state, seed_ix is seed letter for first time step\n", 125 | " \"\"\"\n", 126 | " x = np.zeros((vocab_size, 1))\n", 127 | " x[seed_ix] = 1\n", 128 | " ixes = []\n", 129 | " for t in xrange(n):\n", 130 | " h = np.tanh(np.dot(Wxh.eval(), x) + np.dot(Whh.eval(), h) + bh.eval())\n", 131 | " y = np.dot(Why.eval(), h) + by.eval()\n", 132 | " p = np.exp(y) / np.sum(np.exp(y))\n", 133 | " ix = np.random.choice(range(vocab_size), p=p.ravel())\n", 134 | " x = np.zeros((vocab_size, 1))\n", 135 | " x[ix] = 1\n", 136 | " ixes.append(ix)\n", 137 | " return ixes\n", 138 | 
"\n", 139 | "def vectorize(x): # take an array of IX and return an array of vector\n", 140 | " vectorized = np.zeros((len(x), vocab_size))\n", 141 | " for i in range(0, len(x)):\n", 142 | " vectorized[i][x[i]] = 1\n", 143 | " return vectorized\n", 144 | "\n", 145 | "n, p = 0, 0\n", 146 | "smooth_loss = -np.log(1.0/vocab_size)*seq_length # loss at iteration 0\n", 147 | "with tf.Session() as sess:\n", 148 | " tf.initialize_all_variables().run()\n", 149 | " print \"all variable initialized\"\n", 150 | " while True:\n", 151 | " # prepare inputs (we're sweeping from left to right in steps seq_length long)\n", 152 | " if p+seq_length+1 >= len(data) or n == 0: \n", 153 | " state = initial_state.eval() # reset RNN memory\n", 154 | " p = 0 # go from start of data\n", 155 | " x = vectorize([char_to_ix[ch] for ch in data[p:p+seq_length]])\n", 156 | " y = vectorize([char_to_ix[ch] for ch in data[p+1:p+seq_length+1]])\n", 157 | " # Create the structure for the learning data\n", 158 | " feed = {input_data: x, target_data: y, initial_state: state}\n", 159 | " # Run a session using train_op\n", 160 | " [train_loss], state, _ = sess.run([cost, final_state, train_op], feed)\n", 161 | " smooth_loss = smooth_loss * 0.999 + train_loss * 0.001\n", 162 | " # sample from the model now and then\n", 163 | " if n % 1000 == 0:\n", 164 | " print 'iter %d, loss: %f' % (n, smooth_loss) # print progress\n", 165 | " #sample(sess)\n", 166 | " sample_ix = sample(state, char_to_ix['A'], 200)\n", 167 | " txt = ''.join(ix_to_char[ix] for ix in sample_ix)\n", 168 | " print '----\\n %s \\n----' % (txt, )\n", 169 | "\n", 170 | " p += seq_length # move data pointer\n", 171 | " n += 1 # iteration counter \n", 172 | " " 173 | ] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "Python 2", 179 | "language": "python", 180 | "name": "python2" 181 | }, 182 | "language_info": { 183 | "codemirror_mode": { 184 | "name": "ipython", 185 | "version": 2 186 | }, 187 | "file_extension": ".py", 188 | "mimetype": "text/x-python", 189 | "name": "python", 190 | "nbconvert_exporter": "python", 191 | "pygments_lexer": "ipython2", 192 | "version": "2.7.10" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 0 197 | } 198 | -------------------------------------------------------------------------------- /rnn_face_tests/tf-LFW-load-tensor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tensor Flow to encode LFW set\n", 8 | "To learn how to encode a simple image and a GIF" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## Import needed for Tensorflow" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "collapsed": false 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import numpy as np\n", 27 | "import tensorflow as tf" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Import needed for Jupiter:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": { 41 | "collapsed": true 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "%matplotlib notebook\n", 46 | "import matplotlib\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "\n", 49 | "from IPython.display import Image" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": { 56 | "collapsed": false 57 | }, 58 
| "outputs": [], 59 | "source": [ 60 | "size = 10\n", 61 | "\n", 62 | "tf.reset_default_graph()\n", 63 | "\n", 64 | "all_faces = tf.Variable(0,validate_shape=False, dtype=tf.float32, name='all_faces')\n", 65 | "saver = tf.train.Saver()\n", 66 | "sess = tf.Session()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": { 73 | "collapsed": false 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "(u'all_faces:0', ())\n", 81 | "0.0\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "with sess.as_default():\n", 87 | " tf.initialize_all_variables().run()\n", 88 | " for var in tf.all_variables():\n", 89 | " print (var.name, var.eval().shape)\n", 90 | " print all_faces.eval()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": { 97 | "collapsed": false 98 | }, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "./tmp/model.ckpt\n", 105 | "Model restored.\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "with sess.as_default():\n", 111 | " ckpt = tf.train.get_checkpoint_state(\"./tmp/\")\n", 112 | " if ckpt and ckpt.model_checkpoint_path:\n", 113 | " print ckpt.model_checkpoint_path\n", 114 | " saver.restore(sess, ckpt.model_checkpoint_path)\n", 115 | " print(\"Model restored.\")\n", 116 | " else:\n", 117 | " print (\"Model not restored.\")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "metadata": { 124 | "collapsed": false 125 | }, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "(u'all_faces:0', (10, 250, 250))\n", 132 | "[[[ 0.14536607 0.15289439 0.14419065 ..., 0.92481071 0.95753837\n", 133 | " 0.95830441]\n", 134 | " [ 0.14536607 0.15289439 0.14419065 ..., 0.90830153 0.95753837\n", 135 | " 0.95830441]\n", 136 | " [ 0.14536607 0.15289439 0.14419065 ..., 0.89179236 0.94102919\n", 137 | " 0.94179523]\n", 138 | " ..., \n", 139 | " [ 0.44253123 0.45005956 0.44135582 ..., 1.15593922 1.17215765\n", 140 | " 1.17292368]\n", 141 | " [ 0.47554958 0.46656874 0.457865 ..., 1.15593922 1.17215765\n", 142 | " 1.17292368]\n", 143 | " [ 0.47554958 0.46656874 0.47437418 ..., 1.15593922 1.17215765\n", 144 | " 1.17292368]]\n", 145 | "\n", 146 | " [[ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 147 | " 1.20594203]\n", 148 | " [ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 149 | " 1.20594203]\n", 150 | " [ 0.73969638 0.74722475 0.75503016 ..., 1.1724484 1.18866682\n", 151 | " 1.18943286]\n", 152 | " ..., \n", 153 | " [ 0.3930037 0.35100451 0.24324571 ..., -2.11287761 -2.12967753\n", 154 | " -2.14542055]\n", 155 | " [ 0.37649453 0.11987605 0.21022736 ..., -2.2119329 -2.22873259\n", 156 | " -2.3105123 ]\n", 157 | " [ 0.35998535 0.10336687 0.19371818 ..., -2.410043 -2.45986104\n", 158 | " -2.55814981]]\n", 159 | "\n", 160 | " [[ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 161 | " 1.20594203]\n", 162 | " [ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 163 | " 1.20594203]\n", 164 | " [ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 165 | " 1.20594203]\n", 166 | " ..., \n", 167 | " [-1.24140465 -1.25038552 -1.24258006 ..., 1.04037499 0.97404754\n", 168 | " 0.85924935]\n", 169 | " [-1.2083863 -1.21736717 -1.20956171 ..., 1.07339334 1.04008424\n", 170 | " 0.99132276]\n", 171 | " [-1.2083863 -1.21736717 -1.20956171 ..., 1.07339334 1.12263012\n", 172 | " 
1.07386863]]\n", 173 | "\n", 174 | " ..., \n", 175 | " [[-1.58809733 -1.56405985 -1.50672686 ..., -0.90770781 -0.39621404\n", 176 | " -0.14781035]\n", 177 | " [-1.62111568 -1.5970782 -1.55625439 ..., -1.30392802 -1.00705349\n", 178 | " -0.79166818]\n", 179 | " [-1.65413404 -1.63009655 -1.58927274 ..., -1.66712987 -1.43629205\n", 180 | " -1.27043426]\n", 181 | " ..., \n", 182 | " [-2.71072149 -2.86828494 -3.00906181 ..., -1.60109317 -1.66742063\n", 183 | " -1.56759942]\n", 184 | " [-2.64468479 -2.67017484 -2.76142406 ..., -1.6506207 -1.63440228\n", 185 | " -1.53458107]\n", 186 | " [-2.49610209 -2.53810143 -2.62935066 ..., -1.53505647 -1.56836557\n", 187 | " -1.48505354]]\n", 188 | "\n", 189 | " [[-1.50555146 -1.49802315 -1.45719934 ..., -0.34639582 -0.34668651\n", 190 | " -0.32941127]\n", 191 | " [-1.43951476 -1.48151398 -1.49021769 ..., -0.362905 -0.34668651\n", 192 | " -0.34592047]\n", 193 | " [-1.30744135 -1.41547728 -1.49021769 ..., -0.39592335 -0.34668651\n", 194 | " -0.34592047]\n", 195 | " ..., \n", 196 | " [ 0.54158628 0.53260541 0.54041088 ..., 1.13943005 1.07310259\n", 197 | " 1.02434111]\n", 198 | " [ 0.54158628 0.53260541 0.54041088 ..., 1.1064117 1.05659342\n", 199 | " 0.97481358]\n", 200 | " [ 0.54158628 0.53260541 0.54041088 ..., 1.1064117 1.02357507\n", 201 | " 0.94179523]]\n", 202 | "\n", 203 | " [[ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 204 | " 1.20594203]\n", 205 | " [ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 206 | " 1.20594203]\n", 207 | " [ 0.73969638 0.74722475 0.75503016 ..., 1.18895757 1.205176\n", 208 | " 1.20594203]\n", 209 | " ..., \n", 210 | " [-2.54562974 -2.58762884 -2.62935066 ..., -0.47846922 -0.06603053\n", 211 | " -0.0487553 ]\n", 212 | " [-2.59515715 -2.62064719 -2.66236901 ..., -0.77563441 -0.34668651\n", 213 | " -0.13130118]\n", 214 | " [-2.59515715 -2.62064719 -2.66236901 ..., -1.07279956 -0.56130582\n", 215 | " -0.32941127]]]\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "with sess.as_default():\n", 221 | " for var in tf.all_variables():\n", 222 | " print (var.name, var.eval().shape)\n", 223 | " print all_faces.eval()" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## A function to save a picture" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 7, 236 | "metadata": { 237 | "collapsed": true 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "#need to be called within a session\n", 242 | "def write_png(tensor, name):\n", 243 | " casted_to_uint8 = tf.cast(tensor, tf.uint8)\n", 244 | " converted_to_png = tf.image.encode_png(casted_to_uint8)\n", 245 | " f = open(name, \"wb+\")\n", 246 | " f.write(converted_to_png.eval())\n", 247 | " f.close() " 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "## Create a test pictures\n", 255 | "### Encode the input (a number)\n", 256 | "This example convert the number to a binary representation" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 8, 262 | "metadata": { 263 | "collapsed": false 264 | }, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "(1, 250, 250)\n", 271 | "(250, 250, 1)\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "# Init size\n", 277 | "with sess.as_default():\n", 278 | " a_face = tf.gather(all_faces,[0])\n", 279 | " print a_face.eval().shape\n", 280 | " # remove channel dimension and add index dimension\n", 281 | " pict_face = 
tf.expand_dims(tf.squeeze(a_face, squeeze_dims=[0]),2)\n", 282 | " print pict_face.eval().shape\n", 283 | " write_png(pict_face, 'a_face.png')\n" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 9, 289 | "metadata": { 290 | "collapsed": false 291 | }, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPoAAAD6CAAAAACthwXhAAAOu0lEQVR4nO1dyQKlOgoN3f3/v0wv\nnDIwB7zWe8XiDhojRwghgArNRNBQ3AwNwdbTfWRHZN+O/TH6T0mvPqpBplIadJ/Qv0BfkLpKNZf1\nj4BeQ3+h/xvpV9B/ZNV72oLOzPa2Q39Of4TC/3NdmloCTsFSoH9AeyViwP/jpc6LJQf6h8WOjbMV\n/3ipN9ZK/u9dLn5A7PTwb5A6Q1+QOnjmbWxZpiVD6tVGrsjpNUP/sBEPkl3q0BpMzkGSOL4foAJo\ns/B/wDRAlgI6oH9gndlayxt6HjPHYP9TrYBX4adNrX1HG7z0L3Zp/kL/NBUZkz8BehH9hf46fWBa\n+Ct1hY75G3GWFp6fezL8kU/klDqNERu6wf9e49MUfgPK91duCnkBXO39CpNEngBV8pj8tcp7li/K\n/l9D8dJ/Ta1g+SE1ogihtYZ8JdYPyHbSp9XI+RFL1fswKITUSY0+OaXu3unh+2XZ78XhoaFRXNMF\nIA/iqg7niyef00xe6EO6gAbQoDXsGxJyHxMJFxZjxaUrZSF04201QSe4pXxdpddbjmTBadHUYZI6\nd31IKSHzW6azq2OV8M6g3xvrMjSXgcNJs+rhux1ZM0uagzoHeJEdKEVXITvTiv2XUFy2wIGGo+V2\nVpn7yS11UZbYI2dLlxorSKT/SD3FaXPlRqXg0LAW45pw2yuwxxT+mlkJ5MMe4RosLvCtMCTMpLm8\nJwt09pLD6t0Tjgw4Z+YbfPFoTzNzvchbawiPjVOBzxivqb3TgAK3xjDWqSzjsg3bOFDvXLxooo6j\nLLCA+b1B21LHO9u6ItATscNceB6BT7fiQbu0HZuDdogbB/t2bW0BGR1lKxO+AoU3MLbckvZwBvcS\nrQN+x2Lc2ffOn72+r5P/BPpqgggvzTqlRRjyzhBWCo31ZSAGVZu6VOu8XhXuDJq5Ue5z4MHbmRr+\nKKGEeT0uFbi/Se+wmNzQu3AcdFPahtDvrjp6Y8UenNwQRwMuI5+rLK/NXX/EruqUhnppiQZUCIr1\nOMnZia25HLvhmctw792TWyO9DQI5MSmauoNnK7DhaQ9ybtEXMXP8AlLEZ2QX+8HDYEwxAxGpz+Fm\nmLf6OFsu5NjfNkwuTJaQX99DTqzb0jPuNEPReX1R+hi7XJwq1JnzLBUVVGbGXyqoCMf7VpeaGjy0\nmQ6xpDGQRRGpM+nD4b8B1jsi5ykC3ZQQlBXDU25WdYkCkxuyy3XyyATOa1Q+JSJbaaeRTjyT5AvW\nZySeigft2L0U4fUxEsm08jnRGiqKZOzP67XI55xO4vm2ob8xRdnPMSuFpCS70N+fnD0ZZ/FOSBX6\nq+EyloYspgB+MolN4v/PuAtiyNqaFQ26T4J253XTVJqYRHD0I+srJOTcbCMiOm48dstzatClrp9K\nlcOld1uC34lDErk/aBlJZg8H9f48OwChdQlxaK3OzFERvVBkwilt8RSX/h3fu1Jn12yzgv96dd7R\neTX9Uj9dSzuUTNBHEYPUo11Lwiu3KfV/nxnuNFOSM+QuULCSpvCcZRU2kzmjbd6tJt5+opRaGvrP\nPeXj3p2fwhm2KLMy+oDap1ULK2EUMjzCRYNOlraQ/0fEhbUClj4tD6/JSDxBa6w/m4F96MOet9Qo\nw6VZuen+1hR079A1NwZcGm+Rck5R853UNC4V9Z013txw7kwbl1kbvgO9Y4OddT/kv86UVRQ+YUe5\nECg+6SWKve4Bi9LqcfmtFWLd3hGY1v0m128DulATNe9VaTmcQ4iwG/S4qUbqqqYbSLpTLAW7Hzqp\nTGwhHNPiC+R3aQLIgSmXNJ+pO3pX4M/xcYWnahznHfIawBlQ9x8n95VbMsidhdsQoqzpLam0QB8F\nDPDe/59dgWJ3KABd5sgynY8tsU+N9fAf7WIrRvmTUeLABvgUtYdWboIm8+WzZP3N/SGXDxDRDzd1\nTw87eAlAX601zY5cCOKstpz9PrHtESphfKJ7c3i9znQscOPvJNMXWLMCETPXK+jkc9HGX4OAT1aI\nvhq+NTDdxbI1aOElHxOmJlIgE4cWyV4gClcMIQi9K3KgLTeMNlvrjk4zoO3onkbZyjfGppQRPaZj\neGLDszGUaMTlh4EclypcD99a67njByqOe+bh0B++HsZ1mkExqV8G+wlDyzPZQrNysGpxhSELFn6b\ni1agNqtseoy0Azks4VrRB8wLVTwF02S+8WkmTeawod9exVCgE911ftU5wE9+GXWYCcmfwpYdXZfE\nvmXhb+u+Bid2uu1PcK9xLM19Z618/NY2oTx2RmJcZf7ghLFuuAuoa+vq2Tm1b0aB1d2XOcOuiTcL\nd/aktTtlbjPvTIf8ubbX6/FRPXNjifS4OlTIr/Ag/HN1MM1yQAjulrgeDmR9IpbD1CePURvMBxO+\nC3BtUyjoyI7frvUFCwLF/XvYKQ5DoQpojdNB8Soo7J+6mf58Naa/4N2NAwhrWN46OuQr5M3ksCuj\n6Fj36t/yPjRpCS92jlkDf8PMORjwDmDlsZRO14XpTBnrOQ/64xdBw3w2uAuSP+iyBkD7m3uOLIyP\nP+WeCql20xQ59tffLwpWemGFt85rxrXH8nI+atJv3EIxRGU3gByyZIGvI3ZuOmn1BlyGjT1vbo4H\nTbvlaajSP58OJY+OQ4fWoJfM7H0bepCnaHaRlOTybCg84Gk7oTX01sI+D1AU0xmVlHvHkyNReLs0\n9zVb4qm9Je1Cv1l+blTqtx9/+t3XZscIxlWN6XqAojBYkSMbCCgR+ZvaAVB1i59NLZdIV3mks6M4\ndFxWcOMixYZ9XsawBaJO7gy0JfVkhQRqZXJ6b5/Iud3UP2eO4iwwV1FGLtqXShsuzWHjhRaabfYE\nHb6m8O1Cz0AYsVNI6bXeM1VWUt2tABMNOjKMkXn1De15Qm0l+G3oIHoxanSGSu8MX3wnu5Qzr7OB\nJmojYheZY5++NgyVOGcC7Ss8IKxKu1JfbXOpMqxZZPbRsfm0CR1cU+58fYglLzblZqk80vg2LMG6\nz8jxQzfURmOq1U2vWXiNhBjzV82cIZMQVt5arc9auUmz2J7UEroQOt7Y35ow2vUgfddgetnjbeuq\nhnqG1E8jrzGore
DfW6iflBKHB1sqSCqHeh34Dyw87bLwyAdb6YpJa2qYdeu+PQM4vaawrZGZxfMB\nZPaYWOupP/Uv5nXBqRv/XpDzRkPvfG5CHxYZKRyOIdnpFabAtR2Tn3JpwrV3c3LrRuAUk4GRVetl\n4ZMY2G3WezMspDNrZO+beJ6UsHdK5jGl6/4edA5igg+y2YXlcMWlkUfN9T2/fC1KSW8Xb6aVRWL2\nJYPrvJdDsKvAh00FusSK7LRLfzjKdNbJWX2ICcct/Bo8E0W2K0//8feExwRBw9CplKhTR1wUOX59\nT1wPJ2esG1yJTWUOXTnZdORI3eS77Mk9JvUlIdj9DldVzOe4f7I8QmuNvc2tgpQpN7+0QNYjc8xl\nTtAFB4xwvuz715sYWz32L3463dypH8tpYTgHeYDWIUELU3zqzERrW1g2LwsE4gg9W9VtDSm8CZXv\nnQ32hly/o0GTWp2UNNbJi7H7RM0Ng2g5c/TBybZ22rBfG3ZqW4xcb2QcxUI/jqjd6HySETu78bu6\n4BbU+vVhwiVKI/UImbAtICNOPOVlkk4tffjKk95qJS/jS/stT4j055yhijADJfbeTLQVdrCh5wno\nXc4o+qGnSsTaS8QQt5s5xYGTd3ZGW7X6ZwM2Oh8iArDRkd05McJtce6vV3JsnedHidom9b30FhPD\n0TxSwo310H1WLiiVZFJUjW/m0ftca7wnucBTc1UX/o03e8H0fUVPBOb6TMtxBQwicsUc29svNbuD\nB5QNp5rfMVSD1L2R7FczrY+x4Uokn8DKbRZhSecl0adeZUdF9qGFDJLhWmWYOccSCIlf/aYunoTj\nHq/cGaZ2QxUTObjSYgmc1YeC12m8NtZPIw3XjEW1EAi8otcDDTkPTjZJ5ClnIEGosM53M6dZvNfN\nHHeVLJFOgOPFYQHdh7XDN7y5S117F2Wma4vlRjkU94osvezN3Se9YiZEWsClxQDQ3M9oWakyIjtS\nt1xbsPs9dBDC0iIPD2VJXWN+TowMIZBQ1Au6FwaGKG+2lBdveBnodj2Lkh6xAxZdGaS1QDfESCYT\nHQWTU3ddAhI7ghe6CJ4IVfwO+hWg6pYkE9tu6I1Gv4TrYfhqLdXCq4wODjyV/WXY1eiZ7c2vJ3f0\nbuTB0xJJcY3Meyy/OtvD/dH9T6MA9o5ZbDArg2vWI0fQzN3T/ZfW67ti0G8rH6jo4SyWlqtbDYK5\nNtEx4I2NfyZ1BuH2ukwwdlPfydA9C+oh7Nr10PcRWaMN4IXVQnrsw9hjd2/HbOONeT6JHmsPpxdF\neDX5Cm9kdnUx7uP3gxHjW3OZdV6+1L19drewHyLKEDt/JO3Pp1EQ+zX4B+xbiU5iWzH0uNzbEXZP\ngi5Hvb4BXVjOV+RdDqqZ1338PlmHOpwEFbk0XuxwHAQNG2JRjm2mb/jwB3aPF5pAZc+b8x6QZdmo\nrukQ3jekPi9l8jV+wQ5QBt3JOzvdZowAJC/lR6R+Dnc0Zla8fVPJTayD7ucfnjjKGza+UOp+lYd2\nlppsrltXRoigZN1Yb0HsNf7bauVqx7pXXpU2nuDlM2buoCuNloub7q0UeqD0h5BOioO3hILKHUdv\n97f/3mffEqI2XSfmABXE6jfi9CxYa6Y3ZwriPexwYn/jjLYbQN4T/AkbkXY+M8lq5vbBH3O2pfpn\nLg5NcWrWvuwWPpgNmDsxVZ6crk2m2KH7PEiGDuwfP53xVkMBQoP7ElXqvAi9YP1ooA77Uguz2e3Q\nkcul2WTBdHinmLt517nn6fwSdCJX47N2c4jZc9RVePOM/WxyO7IMePqiYLfbMFtjb4nLh1eCD38V\nq2sagcZHTEHnzGc9lYoaNxx01xmX/O18yiP6ohv3Pv0KZ2q4zMhz0MVKPJZUdNYrejs1TAFCBoVe\nWylUeM08xkrhrkGfK/bJlfRZeMMu2dzZ6B4mwATWYjSxVnDbj4zdrPLl67f/A8ybODI9gRODAAAA\nAElFTkSuQmCC\n", 296 | "text/plain": [ 297 | "" 298 | ] 299 | }, 300 | "execution_count": 9, 301 | "metadata": {}, 302 | "output_type": "execute_result" 303 | } 304 | ], 305 | "source": [ 306 | "Image(\"a_face.png\")" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "Feedback wellcome @dh7net" 314 | ] 315 | } 316 | ], 317 | "metadata": { 318 | "kernelspec": { 319 | "display_name": "Python 2", 320 | "language": "python", 321 | "name": "python2" 322 | }, 323 | "language_info": { 324 | "codemirror_mode": { 325 | "name": "ipython", 326 | "version": 2 327 | }, 328 | "file_extension": ".py", 329 | "mimetype": "text/x-python", 330 | "name": "python", 331 | "nbconvert_exporter": "python", 332 | "pygments_lexer": "ipython2", 333 | "version": "2.7.10" 334 | } 335 | }, 336 | "nbformat": 4, 337 | "nbformat_minor": 0 338 | } 339 | -------------------------------------------------------------------------------- /word2vec/word2vec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "deletable": true, 8 | "editable": true 9 | }, 10 | "source": [ 11 | "# A Word2Vec playground\n", 12 | "\n", 13 | "To play with this notebook, you'll need Numpy, Annoy, Gensim, and the [GoogleNews word2vec model]( https://code.google.com/archive/p/word2vec/)\n", 14 | "\n", 15 | "* pip install numpy\n", 16 | "* pip install annoy \n", 17 | "* pip install gensim \n", 
18 | "* you can find the GoogleNews vector by googling _./GoogleNews-vectors-negative300.bin_ \n", 19 | " \n", 20 | "\n", 21 | "Inspired by: https://github.com/chrisjmccormick/inspect_word2vec\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": { 28 | "collapsed": false, 29 | "deletable": true, 30 | "editable": true 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "# import and init\n", 35 | "from annoy import AnnoyIndex\n", 36 | "import gensim\n", 37 | "import os.path\n", 38 | "import numpy as np\n", 39 | "\n", 40 | "prefix_filename = 'word2vec'\n", 41 | "ann_filename = prefix_filename + '.ann'\n", 42 | "i2k_filename = prefix_filename + '_i2k.npy'\n", 43 | "k2i_filename = prefix_filename + '_k2i.npy'" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "deletable": true, 50 | "editable": true 51 | }, 52 | "source": [ 53 | "## Create a model or load it" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": { 60 | "collapsed": false, 61 | "deletable": true, 62 | "editable": true 63 | }, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "load GoogleNews Model\n", 70 | "loading done\n", 71 | "model size= 3000000\n", 72 | "vector size= 300\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "# Load Google's pre-trained Word2Vec model.\n", 78 | "print \"load GoogleNews Model\"\n", 79 | "model = gensim.models.KeyedVectors.load_word2vec_format('./GoogleNews-vectors-negative300.bin', binary=True) \n", 80 | "print \"loading done\"\n", 81 | "\n", 82 | "hello = model['hello']\n", 83 | "vector_size = len(hello)\n", 84 | "print 'model size=', len(model.vocab)\n", 85 | "print 'vector size=', vector_size" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": { 92 | "collapsed": false, 93 | "deletable": true, 94 | "editable": true 95 | }, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "creating indexes\n", 102 | "10000 DeLille_Cellars\n", 103 | "20000 igned\n", 104 | "30000 industrial_Ruhr\n", 105 | "40000 ANSI_ASHRAE_IESNA_Standard\n", 106 | "50000 coach_Jay_Vidovich\n", 107 | "60000 Kizil\n", 108 | "70000 Nanakshahi\n", 109 | "80000 iSink_U_Facebook\n", 110 | "90000 Renfrey\n", 111 | "100000 Doctorate_Degree\n", 112 | "110000 Synthetic_Cannabinoids\n", 113 | "120000 Employee_Jeff_Colucy\n", 114 | "130000 Kolbek\n", 115 | "140000 dunce_hat\n", 116 | "150000 Irn_Bru_First\n", 117 | "160000 model_Maggie_Rizer\n", 118 | "170000 OTTAWA_Karlheinz_Schreiber\n", 119 | "180000 BGiles\n", 120 | "190000 prMac.com_Vienna_Austria\n", 121 | "200000 Tina_Pisnik\n", 122 | "210000 undersigned_Rubin_Lublin\n", 123 | "220000 Willnett_Crockett\n", 124 | "230000 Sony_Pictures_Studios\n", 125 | "240000 Voices\n", 126 | "250000 salmon_Delta_smelt\n", 127 | "260000 Yasuaki_Iwamoto_auto\n", 128 | "270000 Ambrose\n", 129 | "280000 DeLamatre\n", 130 | "290000 BY_JOYCE_J._PERSICO\n", 131 | "300000 Austin_Ruse\n", 132 | "310000 Adeline_Teoh\n", 133 | "320000 1_Utama_Shopping\n", 134 | "330000 iSimCity\n", 135 | "340000 symbol_TOT.UN\n", 136 | "350000 southeastwards\n", 137 | "360000 Whitchurch_Heath\n", 138 | "370000 WEXFORD\n", 139 | "380000 Kirk_Baert\n", 140 | "390000 church_renounced_polygamy\n", 141 | "400000 Whitney_Otis_elevators\n", 142 | "410000 Fonze\n", 143 | "420000 Fabian_Babich\n", 144 | "430000 Desmodur_®\n", 145 | "440000 Michael_Egholm_Ph.D.\n", 146 | "450000 Cookie_Zip\n", 147 | 
"460000 David_Dal_Maso\n", 148 | "470000 Santa_Barbara_Botanic_Garden\n", 149 | "480000 Jellicoe_Road\n", 150 | "490000 E_coli\n", 151 | "500000 Burleith\n", 152 | "510000 LSBK\n", 153 | "520000 Maidie\n", 154 | "530000 Buddha_Nallah\n", 155 | "540000 Coast_Guard_watchstanders\n", 156 | "550000 EPIRB_signal\n", 157 | "560000 Comtech_Telecommunications_Corp.\n", 158 | "570000 Silver_oz.\n", 159 | "580000 Dominique_Darden\n", 160 | "590000 Raiders_George_Blanda\n", 161 | "600000 attract_younger_hipper\n", 162 | "610000 backstabbed\n", 163 | "620000 Omro_Rushford_Fire\n", 164 | "630000 Lasse\n", 165 | "640000 BY_DENNIS_BARTLOW\n", 166 | "650000 Yorkshire_Casualty_Reduction\n", 167 | "660000 Korleone_Young\n", 168 | "670000 David_Dzhokhadze\n", 169 | "680000 Rudy_Currence\n", 170 | "690000 Archive_retrievals\n", 171 | "700000 Dousis\n", 172 | "710000 Albert_Kligman\n", 173 | "720000 actress_Archie_Panjabi\n", 174 | "730000 Wolf_Haldenstein\n", 175 | "740000 Lendel_Thomas\n", 176 | "750000 Grubka\n", 177 | "760000 Rene_Charlebois\n", 178 | "770000 arsenic\n", 179 | "780000 Minkus_Electronic_Display\n", 180 | "790000 Ibi_Kaslik\n", 181 | "800000 Valhall\n", 182 | "810000 visibly_irritated_plainclothes\n", 183 | "820000 Norwegian_expressionist_Edvard\n", 184 | "830000 troy_oz\n", 185 | "840000 PfEMP1\n", 186 | "850000 μ_velOSity\n", 187 | "860000 active_RFID_RTLS\n", 188 | "870000 Pitched_Perfectly\n", 189 | "880000 JRT\n", 190 | "890000 scotch_carts\n", 191 | "900000 Insurer_Humana\n", 192 | "910000 Monica_Isley\n", 193 | "920000 Julie_Deardorff\n", 194 | "930000 Doc_Paulin\n", 195 | "940000 http://www.opengeospatial.org\n", 196 | "950000 Microcell_Telecommunications_Inc.\n", 197 | "960000 PRNewswire_FirstCall_FNDS####_Corp\n", 198 | "970000 MENNONITE\n", 199 | "980000 Gabelsville\n", 200 | "990000 PLN_##.#mn\n", 201 | "1000000 Graebe\n", 202 | "1010000 midges_swarmed\n", 203 | "1020000 Gokool\n", 204 | "1030000 Laura_Stotler_writes\n", 205 | "1040000 remotely_piloted_Predators\n", 206 | "1050000 Antlered_deer\n", 207 | "1060000 desolate_desert\n", 208 | "1070000 DRCE\n", 209 | "1080000 Dhiya_al_Kenani\n", 210 | "1090000 Lily_Ledbetter\n", 211 | "1100000 Flatford\n", 212 | "1110000 nurdled\n", 213 | "1120000 Electron_Optics\n", 214 | "1130000 Hostos_Community_College\n", 215 | "1140000 NVIDIA_PhysX\n", 216 | "1150000 Eurythmics\n", 217 | "1160000 Crazed_Fan\n", 218 | "1170000 Tullett_Liberty\n", 219 | "1180000 Hunka_Hunka_Burnin\n", 220 | "1190000 Eat_Your_Own\n", 221 | "1200000 ENSOR\n", 222 | "1210000 nemesis_Captain_Hook\n", 223 | "1220000 U._Shrinivas\n", 224 | "1230000 penile_enlargement\n", 225 | "1240000 Antolin_Alcaraz\n", 226 | "1250000 parimutuel_betting\n", 227 | "1260000 Shell_Petroleum\n", 228 | "1270000 Woodenbong\n", 229 | "1280000 Kishoreganj_district\n", 230 | "1290000 tee'd\n", 231 | "1300000 Niweigha\n", 232 | "1310000 Engine_Overhaul\n", 233 | "1320000 Jen_Marlowe\n", 234 | "1330000 Nitsch\n", 235 | "1340000 Tiananmen_dissident\n", 236 | "1350000 expletive_laden_banter\n", 237 | "1360000 eurocentric\n", 238 | "1370000 provincial_spokesman_Zulmi\n", 239 | "1380000 Thurrock_Harriers\n", 240 | "1390000 Psychosocial_Factors\n", 241 | "1400000 Imagi_Mangal\n", 242 | "1410000 Marco_Rubio\n", 243 | "1420000 distinguishes\n", 244 | "1430000 Malaman\n", 245 | "1440000 bowled_Peter_Ongondo\n", 246 | "1450000 Patricia_A._Vinchesi\n", 247 | "1460000 Keltbray\n", 248 | "1470000 Jolliff\n", 249 | "1480000 SERIOUS_MAN\n", 250 | "1490000 Egyptians_resealed\n", 251 | "1500000 BROKE_INTO\n", 
252 | "1510000 Texas_hold'em\n", 253 | "1520000 Caruso_Benefits\n", 254 | "1530000 ENTrigue_Surgical\n", 255 | "1540000 Southern_Tagalog_Arterial\n", 256 | "1550000 OHL'_murt\n", 257 | "1560000 DeSean_Jackson\n", 258 | "1570000 Adepoju\n", 259 | "1580000 Kachkar\n", 260 | "1590000 &_quotWe\n", 261 | "1600000 Eyewear_Collection\n", 262 | "1610000 novitiate\n", 263 | "1620000 Inferiority_Complex\n", 264 | "1630000 papal_nuncio_Benedict\n", 265 | "1640000 Lyddiard_website\n", 266 | "1650000 MTM###_MPEG_Transport\n", 267 | "1660000 gypsum_stack\n", 268 | "1670000 Taiwan_Straits_ARATS\n", 269 | "1680000 AWARD_FOR_BEST\n", 270 | "1690000 Elmaghraby\n", 271 | "1700000 fright_fests\n", 272 | "1710000 IMMUNE_SYSTEMS\n", 273 | "1720000 spokesman_Larry_Solters\n", 274 | "1730000 burial\n", 275 | "1740000 BY_RAMONA_SHELBURNE\n", 276 | "1750000 Maaroufi\n", 277 | "1760000 bin_Qasim\n", 278 | "1770000 assassinate_Sheik_Jaber\n", 279 | "1780000 Marc_Axton\n", 280 | "1790000 Lockerman\n", 281 | "1800000 Isberto\n", 282 | "1810000 Gatson\n", 283 | "1820000 Jennifer_Klinkert\n", 284 | "1830000 Memento_mori\n", 285 | "1840000 Desi_Arnez\n", 286 | "1850000 HEALTH_PLANS\n", 287 | "1860000 Finger_Lakes_Riesling\n", 288 | "1870000 Colonel_Mengistu_Haile_Mariam\n", 289 | "1880000 Hotel_Sacher\n", 290 | "1890000 Monticello_Ky.\n", 291 | "1900000 Dr._Josyann_Abisaab\n", 292 | "1910000 By_Lawerence_Synett\n", 293 | "1920000 Aprimo_Marketing_Studio\n", 294 | "1930000 Severe_thunderstorms\n", 295 | "1940000 raster_imagery\n", 296 | "1950000 settlement_blocs_Maaleh_Adumim\n", 297 | "1960000 postherpetic_neuralgia\n", 298 | "1970000 Dave_Betras\n", 299 | "1980000 HKY_FLA\n", 300 | "1990000 Gulbudin_Hekmatyar\n", 301 | "2000000 Klip_south\n", 302 | "2010000 crumbling_lakeside\n", 303 | "2020000 Philadelphia_Pa_Lippincott\n", 304 | "2030000 JOBE\n", 305 | "2040000 FABIO_Capello\n", 306 | "2050000 T2_weighted\n", 307 | "2060000 Burke_Badenhop\n", 308 | "2070000 Weather_Stations\n", 309 | "2080000 Sogluizzo\n", 310 | "2090000 AUTHORITY_OF\n", 311 | "2100000 Slim_Goodbody\n", 312 | "2110000 Nizam_Mir\n", 313 | "2120000 NewYork_Presbyterian_Hospital\n", 314 | "2130000 Chairman_Jimmy_Iovine\n", 315 | "2140000 Quail_Run_Elementary\n", 316 | "2150000 neckware_line\n", 317 | "2160000 aminopyralid\n", 318 | "2170000 Essex_Fells\n", 319 | "2180000 DOJ_OIG\n", 320 | "2190000 BONDING\n", 321 | "2200000 fking\n", 322 | "2210000 REALTOR_®_Lockbox_NXT\n", 323 | "2220000 therapeutic_compound_GAMMAGARD\n", 324 | "2230000 Loof\n", 325 | "2240000 Nomir_Medical_Technologies\n", 326 | "2250000 Guilia\n", 327 | "2260000 MetaSphere_application\n", 328 | "2270000 Qi_Ji\n", 329 | "2280000 brewers_Anheuser_Busch\n", 330 | "2290000 8dec\n", 331 | "2300000 Balachander\n", 332 | "2310000 Baoyu\n", 333 | "2320000 forwards_Colby_Armstrong\n", 334 | "2330000 Hanjin_Heavy_Industries\n", 335 | "2340000 Permanently_extending\n", 336 | "2350000 Mohsenian\n", 337 | "2360000 dark_Lord_Voldemort\n", 338 | "2370000 visit_http://www.elpaso.com\n", 339 | "2380000 CONTACT_Universal_Stainless\n", 340 | "2390000 Steve_Schale_Democratic\n", 341 | "2400000 BY_ROB_STEIN\n", 342 | "2410000 INTERNET_TELEPHONY_Conference\n", 343 | "2420000 Cloux_France\n", 344 | "2430000 Superwire_Inc.\n", 345 | "2440000 Rothiemurchus\n", 346 | "2450000 Menarik_Property\n", 347 | "2460000 TOM_MACK\n", 348 | "2470000 noni_juice\n", 349 | "2480000 Larder_Lake_Property\n", 350 | "2490000 confortable\n", 351 | "2500000 Gowad\n", 352 | "2510000 racewinner\n", 353 | "2520000 
Hazard_Elimination\n", 354 | "2530000 wet_distiller_grains\n", 355 | "2540000 downtown_honky_tonks\n", 356 | "2550000 potency_dosing\n", 357 | "2560000 Lupercalia\n", 358 | "2570000 Niuatoputapu_wiping\n", 359 | "2580000 lambasted\n", 360 | "2590000 caisse\n", 361 | "2600000 evrything\n", 362 | "2610000 Molecular_Pharmacology_Physiology\n", 363 | "2620000 Bill_Kostroun_FILE\n", 364 | "2630000 directorial_reigns\n", 365 | "2640000 Calle_Ridderwall\n", 366 | "2650000 Schlappi\n", 367 | "2660000 Orin_Hatch\n", 368 | "2670000 cricketer_Anil_Kumble\n", 369 | "2680000 Yoshinori_Nagano_strategist\n", 370 | "2690000 Brahim_Boulami\n", 371 | "2700000 Klaasen\n", 372 | "2710000 Dovetail_Solar\n", 373 | "2720000 Southwark_diocese\n", 374 | "2730000 AND_COUPLE_DANCING\n", 375 | "2740000 disassociates_itself\n", 376 | "2750000 R._Madhavan\n", 377 | "2760000 Kibwezi\n", 378 | "2770000 KIMBERLY_EDDS\n", 379 | "2780000 Sunway_Lagoon_Surf\n", 380 | "2790000 PKR_supreme\n", 381 | "2800000 nonsmokers_groused\n", 382 | "2810000 Glyco\n", 383 | "2820000 IronStone\n", 384 | "2830000 Billings_Forge\n", 385 | "2840000 Famer_Gordie_Howe\n", 386 | "2850000 By_Eric_Mchugh\n", 387 | "2860000 Ingrid_Beckles\n", 388 | "2870000 Savory_Spice\n", 389 | "2880000 Jamnong\n", 390 | "2890000 TVonics\n", 391 | "2900000 MENAFN_Arab\n", 392 | "2910000 Imam_Khomeini_mausoleum\n", 393 | "2920000 Mayor_Arturo_Garino\n", 394 | "2930000 Dogwood_Trail\n", 395 | "2940000 Joel_Scodnick\n", 396 | "2950000 Brenner_CSO\n", 397 | "2960000 Brownscombe\n", 398 | "2970000 Ezra_Cray\n", 399 | "2980000 Katrina_Relief_Efforts\n", 400 | "2990000 EMILY_KAISER\n", 401 | "3000000 Kenneth_Klinge\n", 402 | "building 10 trees\n", 403 | "save files\n", 404 | "done\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "# process the model and save a model\n", 410 | "# or load the model directly\n", 411 | "vocab = model.vocab.keys()\n", 412 | "#indexNN = AnnoyIndex(vector_size, metric='angular')\n", 413 | "indexNN = AnnoyIndex(vector_size)\n", 414 | "index2key = [None]*len(model.vocab)\n", 415 | "key2index = {}\n", 416 | "\n", 417 | "if not os.path.isfile(ann_filename): \n", 418 | " print 'creating indexes'\n", 419 | " i = 0\n", 420 | " try:\n", 421 | " for key in vocab:\n", 422 | " indexNN.add_item(i, model[key])\n", 423 | " key2index[key]=i\n", 424 | " index2key[i]=key\n", 425 | " i=i+1\n", 426 | " if (i%10000==0):\n", 427 | " print i, key\n", 428 | " except TypeError:\n", 429 | " print 'Error with key', key\n", 430 | " print 'building 10 trees'\n", 431 | " indexNN.build(10) # 10 trees\n", 432 | " print 'save files'\n", 433 | " indexNN.save(ann_filename)\n", 434 | " np.save(i2k_filename, index2key)\n", 435 | " np.save(k2i_filename, key2index)\n", 436 | " print 'done'\n", 437 | "else:\n", 438 | " print \"loading files\"\n", 439 | " indexNN.load(ann_filename)\n", 440 | " index2key = np.load(i2k_filename)\n", 441 | " key2index = np.load(k2i_filename)\n", 442 | " print \"loading done:\", indexNN.get_n_items(), \"items\"" 443 | ] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "metadata": { 448 | "deletable": true, 449 | "editable": true 450 | }, 451 | "source": [ 452 | "## King - Male + Female = Queen?\n", 453 | "Nope!\n", 454 | "\n", 455 | "At least not based on a word2vec that is trained on the News..." 
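Before moving on: for readers who want the Annoy workflow from the index-building cell above in isolation, here is a minimal sketch on a made-up toy vocabulary. The words and random vectors are illustrative only; the AnnoyIndex calls mirror the cell above, where the metric argument is left at its default (angular) as in the notebook's annoy version.

```python
# Minimal sketch of the Annoy workflow used above, on a toy vocabulary.
# The words and vectors are made up; only the AnnoyIndex calls mirror the notebook.
import numpy as np
from annoy import AnnoyIndex

toy_vocab = ['king', 'queen', 'boy', 'girl']
toy_vectors = np.random.rand(len(toy_vocab), 300)  # stand-ins for model[word]

toy_index = AnnoyIndex(300)          # same constructor as the cell above
for i, vector in enumerate(toy_vectors):
    toy_index.add_item(i, vector)    # map an integer id to a vector
toy_index.build(10)                  # more trees give better recall, slower build

# ids of the 2 approximate nearest neighbours of a query vector:
neighbours = toy_index.get_nns_by_vector(toy_vectors[0], 2)
print([toy_vocab[i] for i in neighbours])
```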
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 10,
461 | "metadata": {
462 | "collapsed": false,
463 | "deletable": true,
464 | "editable": true
465 | },
466 | "outputs": [
467 | {
468 | "name": "stdout",
469 | "output_type": "stream",
470 | "text": [
471 | "king\n"
472 | ]
473 | }
474 | ],
475 | "source": [
476 | "what_vec = model['king'] - model['male'] + model['female']\n",
477 | "\n",
478 | "what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
479 | "print index2key[what_indexes[0]]"
480 | ]
481 | },
482 | {
483 | "cell_type": "markdown",
484 | "metadata": {
485 | "deletable": true,
486 | "editable": true
487 | },
488 | "source": [
489 | "## King - boy + girl = Queen?\n",
490 | "Yes :) \n",
491 | "but it doesn't work with man & women :("
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": 12,
497 | "metadata": {
498 | "collapsed": false,
499 | "deletable": true,
500 | "editable": true
501 | },
502 | "outputs": [
503 | {
504 | "name": "stdout",
505 | "output_type": "stream",
506 | "text": [
507 | "queen\n"
508 | ]
509 | }
510 | ],
511 | "source": [
512 | "what_vec = model['king'] - model['boy'] + model['girl']\n",
513 | "\n",
514 | "what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
515 | "print index2key[what_indexes[0]]"
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": 15,
521 | "metadata": {
522 | "collapsed": false,
523 | "deletable": true,
524 | "editable": true
525 | },
526 | "outputs": [
527 | {
528 | "name": "stdout",
529 | "output_type": "stream",
530 | "text": [
531 | "absolute_monarch\n"
532 | ]
533 | }
534 | ],
535 | "source": [
536 | "what_vec = model['king'] - model['man'] + model['women']\n",
537 | "\n",
538 | "what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
539 | "print index2key[what_indexes[0]]"
540 | ]
541 | },
542 | {
543 | "cell_type": "markdown",
544 | "metadata": {
545 | "deletable": true,
546 | "editable": true
547 | },
548 | "source": [
549 | "## Berlin - Germany + France = Paris?\n",
550 | "Yes!\n",
551 | "\n",
552 | "This makes me happy, but if someone understands why, please tell me!"
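One common intuition for why such analogies can work: pairs of words that stand in the same relation (male → female, country → capital) tend to be separated by roughly the same offset vector, so b − a + c lands near the word that completes "a is to b as c is to ?". A small hypothetical helper wrapping the pattern of the cells above (analogy() is our own name, not a gensim or annoy function):

```python
# Hypothetical wrapper around the pattern used in the cells above:
# answers "a is to b as c is to ?" using the annoy index already built.
def analogy(a, b, c, n=1):
    query_vec = model[b] - model[a] + model[c]
    indexes = indexNN.get_nns_by_vector(query_vec, n)
    return [index2key[i] for i in indexes]

print(analogy('Germany', 'Berlin', 'France'))  # hopefully ['Paris']
```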
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": 14,
558 | "metadata": {
559 | "collapsed": false,
560 | "deletable": true,
561 | "editable": true
562 | },
563 | "outputs": [
564 | {
565 | "name": "stdout",
566 | "output_type": "stream",
567 | "text": [
568 | "Paris\n"
569 | ]
570 | }
571 | ],
572 | "source": [
573 | "what_vec = model['Berlin'] - model['Germany'] + model['France']\n",
574 | "\n",
575 | "what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
576 | "print index2key[what_indexes[0]]"
577 | ]
578 | },
579 | {
580 | "cell_type": "markdown",
581 | "metadata": {
582 | "deletable": true,
583 | "editable": true
584 | },
585 | "source": [
586 | "## Trump - USA + Germany = Hitler?\n",
587 | "FAKE NEWS"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": 12,
593 | "metadata": {
594 | "collapsed": false,
595 | "deletable": true,
596 | "editable": true
597 | },
598 | "outputs": [
599 | {
600 | "name": "stdout",
601 | "output_type": "stream",
602 | "text": [
603 | "Dean_Gitter\n"
604 | ]
605 | }
606 | ],
607 | "source": [
608 | "what_vec = model['Trump'] + model['Germany'] - model['USA']\n",
609 | "what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
610 | "\n",
611 | "for i in what_indexes:\n",
612 | " print index2key[i]"
613 | ]
614 | },
615 | {
616 | "cell_type": "markdown",
617 | "metadata": {
618 | "deletable": true,
619 | "editable": true
620 | },
621 | "source": [
622 | "# Let's explore the stereotypes hidden in the news:"
623 | ]
624 | },
625 | {
626 | "cell_type": "code",
627 | "execution_count": 53,
628 | "metadata": {
629 | "collapsed": false,
630 | "deletable": true,
631 | "editable": true
632 | },
633 | "outputs": [
634 | {
635 | "name": "stdout",
636 | "output_type": "stream",
637 | "text": [
638 | "king for him, queen for her.\n",
639 | "prince for him, duchess for her.\n",
640 | "male for him, female for her.\n",
641 | "boy for him, girl for her.\n",
642 | "dad for him, motherly_instincts for her.\n",
643 | "father for him, mother for her.\n",
644 | "president for him, president for her.\n",
645 | "dentist for him, plastic_surgeon for her.\n",
646 | "scientist for him, linguistics_professor for her.\n",
647 | "efficient for him, efficient for her.\n",
648 | "teacher for him, teacher for her.\n",
649 | "doctor for him, doctor for her.\n",
650 | "minister for him, minister for her.\n",
651 | "lover for him, seductress for her.\n"
652 | ]
653 | }
654 | ],
655 | "source": [
656 | "man2women = - model['boy'] + model['girl'] \n",
657 | "\n",
658 | "word_list = [\"king\",\"prince\", \"male\", \"boy\",\"dad\", \"father\", \"president\", \"dentist\",\n",
659 | " \"scientist\", \"efficient\", \"teacher\", \"doctor\", \"minister\", \"lover\"]\n",
660 | "for word in word_list:\n",
661 | " what_vec = model[word] + man2women\n",
662 | " what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
663 | " print word, \"for him,\", index2key[what_indexes[0]], \"for her.\""
664 | ]
665 | },
666 | {
667 | "cell_type": "code",
668 | "execution_count": 54,
669 | "metadata": {
670 | "collapsed": false,
671 | "deletable": true,
672 | "editable": true
673 | },
674 | "outputs": [
675 | {
676 | "name": "stdout",
677 | "output_type": "stream",
678 | "text": [
679 | "Berlin is the capital of Germany\n",
680 | "Paris is the capital of France\n",
681 | "Rome is the capital of Italy\n",
682 | "Teen_Poetry_Slam is the capital of USA\n",
683 | "Moscow is the capital of Russia\n",
684 | "kids is the capital of boys\n",
685 | "paddywagon is the capital of cars\n",
686 | "flower is the capital of flowers\n",
687 | "civilians is the capital of soldiers\n",
688 | "Humberto_Campins is the capital of scientists\n"
689 | ]
690 | }
691 | ],
692 | "source": [
693 | "capital = model['Berlin'] - model['Germany'] \n",
694 | "\n",
695 | "word_list = [\"Germany\", \"France\", \"Italy\", \"USA\", \"Russia\", \"boys\", \"cars\", \"flowers\", \"soldiers\",\n",
696 | " \"scientists\", ]\n",
697 | "for word in word_list:\n",
698 | " what_vec = model[word] + capital\n",
699 | " what_indexes = indexNN.get_nns_by_vector(what_vec, 1)\n",
700 | " print index2key[what_indexes[0]], \"is the capital of\", word"
701 | ]
702 | },
703 | {
704 | "cell_type": "markdown",
705 | "metadata": {
706 | "deletable": true,
707 | "editable": true
708 | },
709 | "source": [
710 | "If you play with this notebook and find good word2vec equations, please tweet them to me! \n",
711 | "__@dh7net__"
712 | ]
713 | }
714 | ],
715 | "metadata": {
716 | "kernelspec": {
717 | "display_name": "Python 2",
718 | "language": "python",
719 | "name": "python2"
720 | },
721 | "language_info": {
722 | "codemirror_mode": {
723 | "name": "ipython",
724 | "version": 2
725 | },
726 | "file_extension": ".py",
727 | "mimetype": "text/x-python",
728 | "name": "python",
729 | "nbconvert_exporter": "python",
730 | "pygments_lexer": "ipython2",
731 | "version": "2.7.12+"
732 | }
733 | },
734 | "nbformat": 4,
735 | "nbformat_minor": 0
736 | }
737 |
--------------------------------------------------------------------------------
/Fizz Buzz.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Fizz Buzz with Tensor Flow.\n",
8 | "\n",
9 | "This notebook explains the [code](https://github.com/joelgrus/fizz-buzz-tensorflow/blob/master/fizz_buzz.py) from the [Fizz Buzz in Tensor Flow](http://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow/) blog post written by __Joel Grus__ \n",
10 | "You should read his post first, it is super funny! \n",
11 | "\n",
12 | "His [code](https://github.com/joelgrus/fizz-buzz-tensorflow/blob/master/fizz_buzz.py) tries to play the Fizz Buzz game by using machine learning. \n",
13 | "\n",
14 | "This notebook is for real beginners who want to understand the basics of TensorFlow by reading code. \n",
15 | "Feedback welcome __@dh7net__\n",
16 | " \n",
17 | "## Let's start! \n",
\n", 18 | "\n", 19 | "The [code](https://github.com/joelgrus/fizz-buzz-tensorflow/blob/master/fizz_buzz.py) contain several part:\n", 20 | "* Create the training set\n", 21 | " * Encode the input (a number)\n", 22 | " * Encode the result (fizz or buzz, none or both?)\n", 23 | " * create the training set\n", 24 | "* Build a model\n", 25 | "* Train the model\n", 26 | " * Create a cost function\n", 27 | " * Iterate\n", 28 | "* Make prediction" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": { 35 | "collapsed": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "import numpy as np\n", 40 | "import tensorflow as tf" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Create the trainning set\n", 48 | "### Encode the input (a number)\n", 49 | "This example convert the number to a binary representation" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": { 56 | "collapsed": false 57 | }, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "0 [0 0 0 0 0 0 0 0 0 0]\n", 64 | "1 [1 0 0 0 0 0 0 0 0 0]\n", 65 | "2 [0 1 0 0 0 0 0 0 0 0]\n", 66 | "3 [1 1 0 0 0 0 0 0 0 0]\n", 67 | "4 [0 0 1 0 0 0 0 0 0 0]\n", 68 | "5 [1 0 1 0 0 0 0 0 0 0]\n", 69 | "6 [0 1 1 0 0 0 0 0 0 0]\n", 70 | "7 [1 1 1 0 0 0 0 0 0 0]\n", 71 | "8 [0 0 0 1 0 0 0 0 0 0]\n", 72 | "9 [1 0 0 1 0 0 0 0 0 0]\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "NUM_DIGITS = 10\n", 78 | "\n", 79 | "def binary_encode(i, num_digits):\n", 80 | " return np.array([i >> d & 1 for d in range(num_digits)])\n", 81 | "\n", 82 | "#Let's check if it works\n", 83 | "for i in range(10):\n", 84 | " print i, binary_encode(i, NUM_DIGITS)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## Encode the result (fizz or buzz, none or both?)\n", 92 | "* The fizz_buzz function calculate what the output should be, an encoded it to a 4 dimention vector. 
\n", 93 | "* The fizz_buzz function take a number and a prediction, and output a string" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": { 100 | "collapsed": false 101 | }, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "1 [1 0 0 0]\n", 108 | "2 [1 0 0 0]\n", 109 | "3 [0 1 0 0]\n", 110 | "4 [1 0 0 0]\n", 111 | "5 [0 0 1 0]\n", 112 | "6 [0 1 0 0]\n", 113 | "7 [1 0 0 0]\n", 114 | "8 [1 0 0 0]\n", 115 | "9 [0 1 0 0]\n", 116 | "10 [0 0 1 0]\n", 117 | "11 [1 0 0 0]\n", 118 | "12 [0 1 0 0]\n", 119 | "13 [1 0 0 0]\n", 120 | "14 [1 0 0 0]\n", 121 | "15 [0 0 0 1]\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "def fizz_buzz_encode(i):\n", 127 | " if i % 15 == 0: return np.array([0, 0, 0, 1])\n", 128 | " elif i % 5 == 0: return np.array([0, 0, 1, 0])\n", 129 | " elif i % 3 == 0: return np.array([0, 1, 0, 0])\n", 130 | " else: return np.array([1, 0, 0, 0])\n", 131 | " \n", 132 | "def fizz_buzz(i, prediction):\n", 133 | " return [str(i), \"fizz\", \"buzz\", \"fizzbuzz\"][prediction]\n", 134 | " \n", 135 | "# let'see how the encoding works\n", 136 | "for i in range(1, 16):\n", 137 | " print i, fizz_buzz_encode(i)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": { 144 | "collapsed": false 145 | }, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "1 0 1\n", 152 | "2 0 2\n", 153 | "3 1 fizz\n", 154 | "4 0 4\n", 155 | "5 2 buzz\n", 156 | "6 1 fizz\n", 157 | "7 0 7\n", 158 | "8 0 8\n", 159 | "9 1 fizz\n", 160 | "10 2 buzz\n", 161 | "11 0 11\n", 162 | "12 1 fizz\n", 163 | "13 0 13\n", 164 | "14 0 14\n", 165 | "15 3 fizzbuzz\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# and the decoding\n", 171 | "for i in range(1, 16):\n", 172 | " fizz_or_buzz_number = np.argmax(fizz_buzz_encode(i))\n", 173 | " print i, fizz_or_buzz_number, fizz_buzz(i, fizz_or_buzz_number)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### Create the training set" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 5, 186 | "metadata": { 187 | "collapsed": false 188 | }, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "Size of the set: 1024\n", 195 | "First 15 values:\n", 196 | "101 [0 1 0 1 0 0 1 1 0 0] [1 0 0 0]\n", 197 | "102 [1 1 0 1 0 0 1 1 0 0] [1 0 0 0]\n", 198 | "103 [0 0 1 1 0 0 1 1 0 0] [0 1 0 0]\n", 199 | "104 [1 0 1 1 0 0 1 1 0 0] [0 0 1 0]\n", 200 | "105 [0 1 1 1 0 0 1 1 0 0] [1 0 0 0]\n", 201 | "106 [1 1 1 1 0 0 1 1 0 0] [0 1 0 0]\n", 202 | "107 [0 0 0 0 1 0 1 1 0 0] [1 0 0 0]\n", 203 | "108 [1 0 0 0 1 0 1 1 0 0] [1 0 0 0]\n", 204 | "109 [0 1 0 0 1 0 1 1 0 0] [0 0 0 1]\n", 205 | "110 [1 1 0 0 1 0 1 1 0 0] [1 0 0 0]\n", 206 | "111 [0 0 1 0 1 0 1 1 0 0] [1 0 0 0]\n", 207 | "112 [1 0 1 0 1 0 1 1 0 0] [0 1 0 0]\n", 208 | "113 [0 1 1 0 1 0 1 1 0 0] [1 0 0 0]\n", 209 | "114 [1 1 1 0 1 0 1 1 0 0] [0 0 1 0]\n", 210 | "115 [0 0 0 1 1 0 1 1 0 0] [0 1 0 0]\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "training_size = 2 ** NUM_DIGITS\n", 216 | "print \"Size of the set:\", training_size\n", 217 | "trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(101, training_size)])\n", 218 | "trY = np.array([fizz_buzz_encode(i) for i in range(101, training_size)])\n", 219 | "\n", 220 | "print \"First 15 values:\"\n", 221 | "for i in range(101, 116):\n", 222 | " print i, trX[i], trY[i]" 
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {},
228 | "source": [
229 | "## Creation of the model\n",
230 | "\n",
231 | "The model is made of:\n",
232 | "* one hidden layer that contains 100 neurons\n",
233 | "* one output layer \n",
234 | "\n",
235 | "The input is fully connected to the hidden layer and a relu function is applied \n",
236 | "The relu function is a [rectifier](https://en.wikipedia.org/wiki/Rectifier_%28neural_networks%29) that just outputs zero if the input is negative.\n",
237 | "\n",
238 | "First we'll define a helper function to initialize parameters with random values "
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 6,
244 | "metadata": {
245 | "collapsed": true
246 | },
247 | "outputs": [],
248 | "source": [
249 | "def init_weights(shape):\n",
250 | " return tf.Variable(tf.random_normal(shape, stddev=0.01))"
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {},
256 | "source": [
257 | "__X__ is the input \n",
258 | "__Y__ is the output \n",
259 | "__w_h__ are the parameters between the input and the hidden layer \n",
260 | "__w_o__ are the parameters between the hidden layer and the output "
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 7,
266 | "metadata": {
267 | "collapsed": true
268 | },
269 | "outputs": [],
270 | "source": [
271 | "NUM_HIDDEN = 100 #Number of neurons in the hidden layer\n",
272 | "\n",
273 | "X = tf.placeholder(\"float\", [None, NUM_DIGITS])\n",
274 | "Y = tf.placeholder(\"float\", [None, 4])\n",
275 | "\n",
276 | "w_h = init_weights([NUM_DIGITS, NUM_HIDDEN])\n",
277 | "w_o = init_weights([NUM_HIDDEN, 4])"
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {},
283 | "source": [
284 | "To create the model we apply the __w_h__ parameters to the input, \n",
285 | "and then we apply the relu function to calculate the value of the hidden layer.\n",
286 | " \n",
287 | "The __w_o__ coefficients are used to calculate the output layer. No rectification is applied \n",
288 | "__py_x__ is the predicted value for a given input, represented as a vector (dimension 4)"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 8,
294 | "metadata": {
295 | "collapsed": true
296 | },
297 | "outputs": [],
298 | "source": [
299 | "def model(X, w_h, w_o):\n",
300 | " h = tf.nn.relu(tf.matmul(X, w_h))\n",
301 | " return tf.matmul(h, w_o)\n",
302 | "\n",
303 | "py_x = model(X, w_h, w_o)"
304 | ]
305 | },
306 | {
307 | "cell_type": "markdown",
308 | "metadata": {},
309 | "source": [
310 | "# Training\n",
311 | "## Create the cost function\n",
312 | "The cost function measures how bad the model is. \n",
313 | "It is the distance between the prediction (py_x) and the reality (Y).\n"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 9,
319 | "metadata": {
320 | "collapsed": true
321 | },
322 | "outputs": [],
323 | "source": [
324 | "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "* __softmax_cross_entropy_with_logits(py_x, Y)__ measures the distance between py_x and Y. [SoftMax](https://en.wikipedia.org/wiki/Softmax_function) turns the raw outputs into probabilities, and the cross-entropy is the classical way to measure the distance between a predicted result and the actual result in a cost function. \n",
332 | "* __reduce_mean__ calculates the mean of a tensor, in this case the mean of the distance over the whole training set\n",
333 | "\n",
334 | "## Train the model\n",
335 | "Training a model in TensorFlow is extremely simple: you just define a trainer operator! "
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 10,
341 | "metadata": {
342 | "collapsed": false
343 | },
344 | "outputs": [],
345 | "source": [
346 | "train_op = tf.train.GradientDescentOptimizer(0.05).minimize(cost)"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "metadata": {},
352 | "source": [
353 | "This operator will minimize the cost using the [Gradient Descent](https://en.wikipedia.org/wiki/Gradient_descent) algorithm, which is the most common way to find parameters that minimize the cost.\n",
354 | "\n",
355 | "We'll also define a prediction operator that will be able to output a prediction.\n",
356 | "* 0 means no fizz no buzz\n",
357 | "* 1 means fizz\n",
358 | "* 2 means buzz\n",
359 | "* 3 means fizzbuzz\n"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 11,
365 | "metadata": {
366 | "collapsed": true
367 | },
368 | "outputs": [],
369 | "source": [
370 | "predict_op = tf.argmax(py_x, 1)"
371 | ]
372 | },
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {},
376 | "source": [
377 | "## Iterate until the model is good enough\n",
378 | "\n",
379 | "One epoch consists of one full training cycle on the training set.\n",
380 | "Once every sample in the set is seen, you start again - marking the beginning of the 2nd epoch. [source](http://stackoverflow.com/questions/31155388/meaning-of-an-epoch-in-neural-networks-training) \n",
381 | "\n",
382 | "The training set is randomly permuted between each epoch.\n",
383 | "\n",
384 | "The learning is not done on the full set at once. \n",
385 | "Instead the learning set is divided into small batches and the learning is done on each of them.\n"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 12,
391 | "metadata": {
392 | "collapsed": true
393 | },
394 | "outputs": [],
395 | "source": [
396 | "BATCH_SIZE = 128"
397 | ]
398 | },
399 | {
400 | "cell_type": "markdown",
401 | "metadata": {},
402 | "source": [
403 | "Here is an example of the indexes used for one epoch:"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": 13,
409 | "metadata": {
410 | "collapsed": false
411 | },
412 | "outputs": [
413 | {
414 | "name": "stdout",
415 | "output_type": "stream",
416 | "text": [
417 | "Batch starting at 0\n",
418 | "[417 122 664 322 289 778 679 654 550 363 489 172 24 795 105 333 700 198\n",
419 | " 404 339 471 402 240 859 907 666 495 661 153 296 23 77 423 61 603 9\n",
420 | " 895 254 434 642 807 594 624 248 596 611 35 315 532 866 862 231 705 108\n",
421 | " 204 908 499 552 408 657 651 405 525 734 130 228 53 569 593 391 653 464\n",
422 | " 803 415 632 62 837 458 775 85 148 452 620 367 628 892 576 675 825 73\n",
423 | " 758 441 200 512 529 393 640 138 112 47 717 890 432 854 896 119 521 430\n",
424 | " 874 416 708 650 167 707 94 149 390 765 462 541 201 210 217 223 488 44\n",
425 | " 286 58]\n",
426 | "Batch starting at 128\n",
427 | "[718 36 401 291 680 782 145 886 508 898 563 481 351 400 539 757 513 809\n",
428 | " 540 92 213 706 451 163 534 601 371 186 68 909 93 362 728 482 698 337\n",
429 | " 438 900 45 828 143 395 602 487 365 236 66 466 287 715 811 799 299 480\n",
430 | " 591 251 668 135 867 460 870 492 134 192 566 743 822 597 577 440 600 791\n",
431 | " 738 326 67 733 607 334 370 789 629 102 776 2 677 701 683 83 899 814\n",
432 | " 414 695 835 188 104 522 604 469 574 730 283 781 769 794 465 109 903 65\n",
433 | " 411 801 338 290 741 118 638 273 181 302 581 721 536 468 303 344 332 484\n",
434 | " 245 265]\n",
435 | "Batch starting at 256\n",
436 | "[ 97 505 307 90 472 843 136 726 146 547 241 709 735 687 323 369 524 846\n",
437 | " 678 88 151 711 125 229 314 788 59 819 357 572 724 792 450 313 324 686\n",
438 | " 437 796 614 630 54 79 821 779 384 635 221 787 647 378 644 166 479 179\n",
439 | " 546 353 444 266 72 320 203 209 714 684 838 412 298 463 761 260 573 582\n",
440 | " 257 249 641 562 11 375 270 84 847 117 304 774 195 531 199 387 158 818\n",
441 | " 114 920 162 381 356 606 259 335 912 377 627 282 191 76 60 301 518 128\n",
442 | " 316 37 740 374 622 349 455 592 21 100 857 623 855 749 493 46 475 202\n",
443 | " 906 501]\n",
444 | "Batch starting at 384\n",
445 | "[823 528 634 771 16 71 720 295 588 831 461 427 261 780 808 889 609 506\n",
446 | " 618 154 567 615 570 824 691 820 29 159 293 551 311 643 99 507 403 883\n",
447 | " 503 856 425 802 568 6 913 610 810 246 915 269 113 851 688 40 225 659\n",
448 | " 860 277 214 748 319 490 584 234 565 230 674 509 64 914 876 272 5 511\n",
449 | " 354 858 504 330 187 599 278 267 665 439 616 553 813 878 137 18 587 312\n",
450 | " 621 554 523 7 537 772 388 797 285 682 91 790 555 853 916 182 237 545\n",
451 | " 699 13 459 215 15 474 98 1 731 871 258 454 389 861 347 494 219 756\n",
452 | " 350 81]\n",
453 | "Batch starting at 512\n",
454 | "[737 238 232 255 864 905 276 716 396 739 20 613 893 51 713 111 120 842\n",
455 | " 428 86 655 839 904 564 280 306 736 147 331 32 308 189 150 190 227 300\n",
456 | " 784 921 798 346 770 325 514 161 262 127 38 881 63 631 141 252 762 844\n",
457 | " 559 359 342 443 732 722 806 39 841 271 502 168 873 755 208 800 264 
413\n", 458 | " 885 394 827 729 542 918 773 619 543 380 561 115 702 183 829 321 583 56\n", 459 | " 279 180 341 431 343 55 207 764 133 783 205 719 498 284 470 826 863 447\n", 460 | " 693 419 845 317 4 392 305 767 176 101 517 446 226 406 672 852 57 919\n", 461 | " 608 598]\n", 462 | "Batch starting at 640\n", 463 | "[123 491 329 667 681 422 12 865 379 426 435 78 636 696 22 578 637 256\n", 464 | " 177 744 922 164 8 87 89 694 420 250 473 515 193 222 697 887 612 328\n", 465 | " 456 585 558 382 777 785 424 723 268 727 891 442 467 617 595 410 690 25\n", 466 | " 297 14 17 361 368 399 373 385 646 364 185 793 834 310 519 160 496 42\n", 467 | " 850 383 662 575 107 888 52 712 318 648 548 884 30 786 48 156 516 768\n", 468 | " 660 242 911 397 358 376 340 235 658 483 309 549 882 449 703 103 74 211\n", 469 | " 747 586 671 69 902 429 910 804 348 742 901 233 704 247 50 544 605 142\n", 470 | " 33 327]\n", 471 | "Batch starting at 768\n", 472 | "[212 830 216 751 0 144 633 366 833 418 917 849 140 486 880 80 692 75\n", 473 | " 639 131 178 124 579 590 877 110 759 589 95 875 10 485 372 526 894 763\n", 474 | " 560 184 218 836 27 448 663 132 710 152 868 497 126 3 165 294 157 139\n", 475 | " 848 872 49 530 673 121 355 288 174 407 535 170 816 34 453 155 753 41\n", 476 | " 457 626 43 897 253 220 433 538 82 752 436 274 805 760 175 685 676 500\n", 477 | " 476 244 409 527 817 840 652 19 556 869 533 670 520 239 173 129 243 725\n", 478 | " 386 689 197 28 26 96 812 510 292 360 224 171 106 746 571 336 766 169\n", 479 | " 478 649]\n", 480 | "Batch starting at 896\n", 481 | "[345 116 398 281 263 206 477 196 352 656 754 625 745 669 879 445 275 421\n", 482 | " 750 832 31 580 815 194 557 645 70]\n" 483 | ] 484 | } 485 | ], 486 | "source": [ 487 | "#random permutation of the index will be used during the training for each epoch\n", 488 | "permutation_index = np.random.permutation(range(len(trX)))\n", 489 | "for start in range(0, len(trX), BATCH_SIZE):\n", 490 | " end = start + BATCH_SIZE\n", 491 | " print \"Batch starting at\", start\n", 492 | " print permutation_index[start:end]\n" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 14, 498 | "metadata": { 499 | "collapsed": false 500 | }, 501 | "outputs": [ 502 | { 503 | "name": "stdout", 504 | "output_type": "stream", 505 | "text": [ 506 | "(0, 0.5297941495124594)\n", 507 | "(100, 0.53412784398699886)\n", 508 | "(200, 0.53412784398699886)\n", 509 | "(300, 0.53412784398699886)\n", 510 | "(400, 0.53412784398699886)\n", 511 | "(500, 0.53412784398699886)\n", 512 | "(600, 0.53412784398699886)\n", 513 | "(700, 0.53846153846153844)\n", 514 | "(800, 0.53954496208017333)\n", 515 | "(900, 0.54929577464788737)\n", 516 | "(1000, 0.55254604550379194)\n", 517 | "(1100, 0.55579631635969662)\n", 518 | "(1200, 0.56338028169014087)\n", 519 | "(1300, 0.59046587215601298)\n", 520 | "(1400, 0.61971830985915488)\n", 521 | "(1500, 0.64138678223185264)\n", 522 | "(1600, 0.6619718309859155)\n", 523 | "(1700, 0.6912242686890574)\n", 524 | "(1800, 0.72156013001083419)\n", 525 | "(1900, 0.71722643553629473)\n", 526 | "(2000, 0.78439869989165767)\n", 527 | "(2100, 0.83748645720476711)\n", 528 | "(2200, 0.84507042253521125)\n", 529 | "(2300, 0.83423618634886243)\n", 530 | "(2400, 0.90357529794149516)\n", 531 | "(2500, 0.90465872156013005)\n", 532 | "(2600, 0.90465872156013005)\n", 533 | "(2700, 0.9263271939328277)\n", 534 | "(2800, 0.93824485373781152)\n", 535 | "(2900, 0.93282773564463706)\n", 536 | "(3000, 0.94907908992416035)\n", 537 | "(3100, 0.94366197183098588)\n", 538 | 
"(3200, 0.93824485373781152)\n", 539 | "(3300, 0.96424702058504874)\n", 540 | "(3400, 0.96099674972914406)\n", 541 | "(3500, 0.9707475622968581)\n", 542 | "(3600, 0.94474539544962077)\n", 543 | "(3700, 0.97616468039003246)\n", 544 | "(3800, 0.95882990249187428)\n", 545 | "(3900, 0.97724810400866735)\n", 546 | "(4000, 0.97508125677139756)\n", 547 | "(4100, 0.98374864572047671)\n", 548 | "(4200, 0.98158179848320692)\n", 549 | "(4300, 0.99133261105092096)\n", 550 | "(4400, 0.97616468039003246)\n", 551 | "(4500, 0.971830985915493)\n", 552 | "(4600, 0.9848320693391116)\n", 553 | "(4700, 0.99241603466955575)\n", 554 | "(4800, 0.99349945828819064)\n", 555 | "(4900, 0.99674972914409532)\n" 556 | ] 557 | } 558 | ], 559 | "source": [ 560 | "# Launch the graph in a session\n", 561 | "sess = tf.Session()\n", 562 | "tf.initialize_all_variables().run(session=sess)\n", 563 | "\n", 564 | "for epoch in range(5000):\n", 565 | " # Shuffle the data before each training iteration.\n", 566 | " p = np.random.permutation(range(len(trX)))\n", 567 | " trX, trY = trX[p], trY[p]\n", 568 | "\n", 569 | " # Train in batches of 128 inputs.\n", 570 | " for start in range(0, len(trX), BATCH_SIZE):\n", 571 | " end = start + BATCH_SIZE\n", 572 | " sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})\n", 573 | "\n", 574 | " # And print the current accuracy on the training data.\n", 575 | " if (epoch%100==0): # each 100 epoch, to not overflow the jupyter log\n", 576 | " # np.mean(A==B) return a number between 0 and 1. (true_count/total_count)\n", 577 | " print(epoch, np.mean(np.argmax(trY, axis=1) ==\n", 578 | " sess.run(predict_op, feed_dict={X: trX, Y: trY})))\n" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 15, 584 | "metadata": { 585 | "collapsed": false 586 | }, 587 | "outputs": [ 588 | { 589 | "name": "stdout", 590 | "output_type": "stream", 591 | "text": [ 592 | "['1' '2' 'fizz' '4' '5' '6' '7' '8' 'fizz' '10' '11' 'fizz' '13' '14'\n", 593 | " 'fizzbuzz' '16' '17' 'fizz' '19' 'buzz' 'fizz' '22' '23' 'fizz' 'fizz'\n", 594 | " '26' 'fizz' '28' '29' 'fizzbuzz' '31' '32' 'fizz' '34' '35' '36' '37' '38'\n", 595 | " '39' '40' '41' '42' '43' '44' 'fizzbuzz' '46' '47' 'fizz' '49' '50' 'fizz'\n", 596 | " '52' 'fizz' '54' 'fizz' '56' 'fizz' '58' '59' 'fizzbuzz' '61' '62' 'fizz'\n", 597 | " '64' 'buzz' 'fizz' '67' '68' 'fizz' 'buzz' '71' '72' '73' '74' 'fizzbuzz'\n", 598 | " '76' '77' 'fizz' '79' '80' 'fizz' '82' 'fizz' 'fizz' 'buzz' '86' 'fizz'\n", 599 | " '88' '89' 'fizzbuzz' '91' '92' 'fizz' '94' 'buzz' '96' '97' 'buzz' '99'\n", 600 | " 'buzz']\n" 601 | ] 602 | } 603 | ], 604 | "source": [ 605 | "# And now for some fizz buzz\n", 606 | "numbers = np.arange(1, 101)\n", 607 | "teX = np.transpose(binary_encode(numbers, NUM_DIGITS))\n", 608 | "teY = sess.run(predict_op, feed_dict={X: teX})\n", 609 | "\n", 610 | "output = np.vectorize(fizz_buzz)(numbers, teY)\n", 611 | "print output" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 16, 617 | "metadata": { 618 | "collapsed": true 619 | }, 620 | "outputs": [], 621 | "source": [ 622 | "sess.close() # don't forget to close the session if you don't use it anymore. Or use the *with* statement." 
623 | ]
624 | },
625 | {
626 | "cell_type": "code",
627 | "execution_count": 17,
628 | "metadata": {
629 | "collapsed": false
630 | },
631 | "outputs": [
632 | {
633 | "name": "stdout",
634 | "output_type": "stream",
635 | "text": [
636 | "accuracy 0.81\n",
637 | "1 1 1 True\n",
638 | "2 2 2 True\n",
639 | "3 fizz fizz True\n",
640 | "4 4 4 True\n",
641 | "5 buzz 5 False\n",
642 | "6 fizz 6 False\n",
643 | "7 7 7 True\n",
644 | "8 8 8 True\n",
645 | "9 fizz fizz True\n",
646 | "10 buzz 10 False\n",
647 | "11 11 11 True\n",
648 | "12 fizz fizz True\n",
649 | "13 13 13 True\n",
650 | "14 14 14 True\n",
651 | "15 fizzbuzz fizzbuzz True\n",
652 | "16 16 16 True\n",
653 | "17 17 17 True\n",
654 | "18 fizz fizz True\n",
655 | "19 19 19 True\n",
656 | "20 buzz buzz True\n",
657 | "21 fizz fizz True\n",
658 | "22 22 22 True\n",
659 | "23 23 23 True\n",
660 | "24 fizz fizz True\n",
661 | "25 buzz fizz False\n",
662 | "26 26 26 True\n",
663 | "27 fizz fizz True\n",
664 | "28 28 28 True\n",
665 | "29 29 29 True\n",
666 | "30 fizzbuzz fizzbuzz True\n",
667 | "31 31 31 True\n",
668 | "32 32 32 True\n",
669 | "33 fizz fizz True\n",
670 | "34 34 34 True\n",
671 | "35 buzz 35 False\n",
672 | "36 fizz 36 False\n",
673 | "37 37 37 True\n",
674 | "38 38 38 True\n",
675 | "39 fizz 39 False\n",
676 | "40 buzz 40 False\n",
677 | "41 41 41 True\n",
678 | "42 fizz 42 False\n",
679 | "43 43 43 True\n",
680 | "44 44 44 True\n",
681 | "45 fizzbuzz fizzbuzz True\n",
682 | "46 46 46 True\n",
683 | "47 47 47 True\n",
684 | "48 fizz fizz True\n",
685 | "49 49 49 True\n",
686 | "50 buzz 50 False\n",
687 | "51 fizz fizz True\n",
688 | "52 52 52 True\n",
689 | "53 53 fizz False\n",
690 | "54 fizz 54 False\n",
691 | "55 buzz fizz False\n",
692 | "56 56 56 True\n",
693 | "57 fizz fizz True\n",
694 | "58 58 58 True\n",
695 | "59 59 59 True\n",
696 | "60 fizzbuzz fizzbuzz True\n",
697 | "61 61 61 True\n",
698 | "62 62 62 True\n",
699 | "63 fizz fizz True\n",
700 | "64 64 64 True\n",
701 | "65 buzz buzz True\n",
702 | "66 fizz fizz True\n",
703 | "67 67 67 True\n",
704 | "68 68 68 True\n",
705 | "69 fizz fizz True\n",
706 | "70 buzz buzz True\n",
707 | "71 71 71 True\n",
708 | "72 fizz 72 False\n",
709 | "73 73 73 True\n",
710 | "74 74 74 True\n",
711 | "75 fizzbuzz fizzbuzz True\n",
712 | "76 76 76 True\n",
713 | "77 77 77 True\n",
714 | "78 fizz fizz True\n",
715 | "79 79 79 True\n",
716 | "80 buzz 80 False\n",
717 | "81 fizz fizz True\n",
718 | "82 82 82 True\n",
719 | "83 83 fizz False\n",
720 | "84 fizz fizz True\n",
721 | "85 buzz buzz True\n",
722 | "86 86 86 True\n",
723 | "87 fizz fizz True\n",
724 | "88 88 88 True\n",
725 | "89 89 89 True\n",
726 | "90 fizzbuzz fizzbuzz True\n",
727 | "91 91 91 True\n",
728 | "92 92 92 True\n",
729 | "93 fizz fizz True\n",
730 | "94 94 94 True\n",
731 | "95 buzz buzz True\n",
732 | "96 fizz 96 False\n",
733 | "97 97 97 True\n",
734 | "98 98 buzz False\n",
735 | "99 fizz 99 False\n"
736 | ]
737 | }
738 | ],
739 | "source": [
740 | "# Let's check the quality\n",
741 | "Y = np.array([fizz_buzz_encode(i) for i in range(1,101)])\n",
742 | "print \"accuracy\", np.mean(np.argmax(Y, axis=1) == teY)\n",
743 | "\n",
744 | "for i in range(1,101): # up to and including 100\n",
745 | " actual = fizz_buzz(i, np.argmax(fizz_buzz_encode(i)))\n",
746 | " predicted = output[i-1]\n",
747 | " ok = True\n",
748 | " if actual != predicted: ok = False\n",
749 | " print i, \"{:>8}\".format(actual), \"{:>8}\".format(predicted), ok"
750 | ]
751 | },
752 | {
753 | "cell_type": "markdown",
754 | "metadata": {},
755 | "source": [
756 | "# Conclusion\n",
757 | "Using Tensor Flow to solve fizz buzz is overkill and not very accurate. \n",
758 | "But it is fun and a nice way to learn Tensor Flow!\n",
759 | "\n",
760 | "Feedback welcome __@dh7net__"
761 | ]
762 | }
763 | ],
764 | "metadata": {
765 | "kernelspec": {
766 | "display_name": "Python 2",
767 | "language": "python",
768 | "name": "python2"
769 | },
770 | "language_info": {
771 | "codemirror_mode": {
772 | "name": "ipython",
773 | "version": 2
774 | },
775 | "file_extension": ".py",
776 | "mimetype": "text/x-python",
777 | "name": "python",
778 | "nbconvert_exporter": "python",
779 | "pygments_lexer": "ipython2",
780 | "version": "2.7.10"
781 | }
782 | },
783 | "nbformat": 4,
784 | "nbformat_minor": 0
785 | }
786 |
--------------------------------------------------------------------------------
/rnn_face_tests/LFW Model Face V0.3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# RNN for picture generation\n",
8 | "This notebook is an experiment. I tried to generate a picture pixel by pixel using an RNN. \n",
9 | "Each pixel can be black or white. \n",
10 | "__WIP__"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "## Imports needed for Jupyter"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 38,
23 | "metadata": {
24 | "collapsed": true
25 | },
26 | "outputs": [],
27 | "source": [
28 | "%matplotlib notebook\n",
29 | "import matplotlib\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "\n",
32 | "from IPython.display import Image"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "## Imports needed for the code"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 39,
45 | "metadata": {
46 | "collapsed": false
47 | },
48 | "outputs": [],
49 | "source": [
50 | "import numpy as np\n",
51 | "import tensorflow as tf\n",
52 | "\n",
53 | "import fnmatch, os\n",
54 | "import time"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "## Helper functions\n",
62 | "to save a picture"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 40,
68 | "metadata": {
69 | "collapsed": true
70 | },
71 | "outputs": [],
72 | "source": [
73 | "#needs to be called within a session\n",
74 | "def write_png(tensor, name):\n",
75 | " casted_to_uint8 = tf.cast(tensor, tf.uint8)\n",
76 | " converted_to_png = tf.image.encode_png(casted_to_uint8)\n",
77 | " f = open(name, \"wb+\")\n",
78 | " f.write(converted_to_png.eval())\n",
79 | " f.close() "
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {
85 | "collapsed": true
86 | },
87 | "source": [
88 | "## A class to define all args"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 42,
94 | "metadata": {
95 | "collapsed": false
96 | },
97 | "outputs": [],
98 | "source": [
99 | "class Args():\n",
100 | " def __init__(self):\n",
101 | " '''directory to store checkpointed models'''\n",
102 | " self.save_dir = 'save_face_training_v0.3'\n",
103 | " \n",
104 | " '''Picture size'''\n",
105 | " self.picture_size = 32\n",
106 | " \n",
107 | " '''size of RNN hidden state'''\n",
108 | " self.rnn_size = self.picture_size*3 \n",
109 | " '''minibatch size'''\n",
110 | " self.batch_size = 1\n",
111 | " '''RNN sequence length'''\n",
112 | " self.seq_length = self.picture_size\n",
113 | " '''number of epochs'''\n",
114 | " self.num_epochs = 10 # was 5\n",
115 | " '''save frequency'''\n",
116 | " self.save_every = 100 # was 500\n",
117 | " '''Print frequency'''\n",
118 | " self.print_every = 10\n",
119 | " '''clip gradients at this value'''\n",
120 | " self.grad_clip = 5.\n",
121 | " '''learning rate'''\n",
122 | " self.learning_rate = 0.002 # was 0.002\n",
123 | " '''decay rate for rmsprop'''\n",
124 | " self.decay_rate = 0.98\n",
125 | " \"\"\"continue training from saved model at this path.\n",
126 | " Path must contain files saved by previous training process: \"\"\"\n",
127 | " #self.init_from = 'save_face_training'\n",
128 | " self.init_from = None\n",
129 | " \n",
130 | " '''number of lines to sample'''\n",
131 | " self.n = 250\n"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 43,
137 | "metadata": {
138 | "collapsed": true
139 | },
140 | "outputs": [],
141 | "source": [
142 | "class FaceLoader:\n",
143 | " def prepare_reading_faces(self):\n",
144 | " self.matches = []\n",
145 | " \n",
146 | " for root, dirnames, filenames in os.walk('./lfw/'):\n",
147 | " #print filenames\n",
148 | " for filename in fnmatch.filter(filenames, '*.jpg'):\n",
149 | " self.matches.append(os.path.join(root, filename))\n",
150 | "\n",
151 | " size = len(self.matches)\n",
152 | "\n",
153 | " filenames = tf.constant(self.matches)\n",
154 | " self.filename_queue = tf.train.string_input_producer(filenames)\n",
155 | " self.image_reader = tf.WholeFileReader()\n",
156 | " return size\n",
157 | " \n",
158 | " def do_when_session(self): \n",
159 | " # For some reason, we need a coordinator and some threads\n",
160 | " self.coord = tf.train.Coordinator()\n",
161 | " self.threads = tf.train.start_queue_runners(coord=self.coord)\n",
162 | "\n",
163 | " def stop_reading_faces(self):\n",
164 | " # Finish off the filename queue coordinator.\n",
165 | " self.coord.request_stop()\n",
166 | " self.coord.join(self.threads)\n",
167 | " \n",
168 | " def load_one_face(self, image_size):\n",
169 | " # read and decode image, will give a uint8 with shape [250, 250, 1]\n",
170 | " filename, image_file = self.image_reader.read(self.filename_queue) \n",
171 | " image = tf.image.decode_jpeg(image_file, channels=1)\n",
172 | " #resize\n",
173 | " image = tf.image.resize_images(image, image_size, image_size)\n",
174 | "\n",
175 | " # remove channel dimension\n",
176 | " tensor_uint8 = tf.squeeze(image, squeeze_dims=[2])\n",
177 | "\n",
178 | " # convert to float32 and scale\n",
179 | " face = tf.cast(tensor_uint8, tf.float32)/255.0\n",
180 | " self.picture = tf.constant(face.eval())\n",
181 | " #print self.picture\n",
182 | " \n",
183 | " def get_bw_picts(self, level): \n",
184 | " bw = (tf.sign(self.picture-level)+1)/2\n",
185 | " #print bw.eval()\n",
186 | " return bw\n",
187 | " \n",
188 | " def get_training_set(self):\n",
189 | " xdata = a_vector_face.eval()\n",
190 | " ydata = np.copy(xdata)\n",
191 | " ydata[:-1] = xdata[1:]\n",
192 | " ydata[-1] = xdata[0]\n",
193 | " self.x_batches = np.squeeze(np.split(xdata, image_size, 0))\n",
194 | " self.y_batches = np.squeeze(np.split(ydata, image_size, 0))\n",
195 | " \n",
196 | " def next_batch(self):\n",
197 | " return self.x_batches, self.y_batches"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "## This code checks that the formulas are working.\n",
205 | "It creates a list of pictures \n",
206 | "Useless for now."
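The get_bw_picts trick above is worth spelling out: sign(x − level) is −1 or +1, so (sign(x − level) + 1) / 2 maps every pixel to 0 or 1. A tiny numpy sketch of the same thresholding on made-up values (outside TensorFlow, so it can be run directly):

```python
# Same thresholding as get_bw_picts, in plain numpy on made-up values:
# pixels above `level` become 1 (white), the rest become 0 (black).
import numpy as np

picture = np.array([[0.1, 0.6],
                    [0.4, 0.9]])        # toy grayscale values in [0, 1]
level = 0.5
bw = (np.sign(picture - level) + 1) / 2
print(bw)                               # [[0. 1.] [0. 1.]]
```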
207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 53, 212 | "metadata": { 213 | "collapsed": false 214 | }, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "'\\ntf.reset_default_graph()\\nfaceloader = FaceLoader()\\nface_count = faceloader.prepare_reading_faces()\\nwith tf.Session() as sess:\\n tf.initialize_all_variables().run()\\n faceloader.do_when_session()\\n faceloader.load_one_face(250)\\n for i in range(255/2):\\n bw = faceloader.get_bw_picts(i*2/255.)\\n bw = tf.expand_dims(bw, 2)\\n write_png(bw*255., \"generated{:06}.png\".format(i))\\n'" 220 | ] 221 | }, 222 | "execution_count": 53, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 | } 226 | ], 227 | "source": [ 228 | "'''\n", 229 | "tf.reset_default_graph()\n", 230 | "faceloader = FaceLoader()\n", 231 | "face_count = faceloader.prepare_reading_faces()\n", 232 | "with tf.Session() as sess:\n", 233 | " tf.initialize_all_variables().run()\n", 234 | " faceloader.do_when_session()\n", 235 | " faceloader.load_one_face(250)\n", 236 | " for i in range(255/2):\n", 237 | " bw = faceloader.get_bw_picts(i*2/255.)\n", 238 | " bw = tf.expand_dims(bw, 2)\n", 239 | " write_png(bw*255., \"generated{:06}.png\".format(i))\n", 240 | "'''" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 54, 246 | "metadata": { 247 | "collapsed": false 248 | }, 249 | "outputs": [ 250 | { 251 | "data": { 252 | "text/plain": [ 253 | "'\\nfrom PIL import Image, ImageSequence\\nimport glob, sys, os\\nos.chdir(\".\")\\nframes = []\\nfor file in glob.glob(\"gene*.png\"):\\n print(file)\\n im = Image.open(file)\\n frames.append(im)\\n\\nfrom images2gif import writeGif\\nwriteGif(\"generated.gif\", frames, duration=0.1)\\n'" 254 | ] 255 | }, 256 | "execution_count": 54, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "'''\n", 263 | "from PIL import Image, ImageSequence\n", 264 | "import glob, sys, os\n", 265 | "os.chdir(\".\")\n", 266 | "frames = []\n", 267 | "for file in glob.glob(\"gene*.png\"):\n", 268 | " print(file)\n", 269 | " im = Image.open(file)\n", 270 | " frames.append(im)\n", 271 | "\n", 272 | "from images2gif import writeGif\n", 273 | "writeGif(\"generated.gif\", frames, duration=0.1)\n", 274 | "'''" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 55, 280 | "metadata": { 281 | "collapsed": false 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "class Model():\n", 286 | " def __init__(self, args, infer=False):\n", 287 | " self.args = args\n", 288 | " #if infer:\n", 289 | " # '''Infer is true when the model is used for sampling'''\n", 290 | " # args.seq_length = 1\n", 291 | " \n", 292 | " hidden_size = args.rnn_size\n", 293 | " vector_size = args.picture_size\n", 294 | " \n", 295 | " # define place holder to for the input data and the target.\n", 296 | " self.input_data = tf.placeholder(tf.float32, [ args.seq_length, vector_size], name='input_data')\n", 297 | " self.target_data = tf.placeholder(tf.float32, [ args.seq_length, vector_size], name='target_data') \n", 298 | " # define the input xs\n", 299 | " xs = tf.split(0, args.seq_length, self.input_data)\n", 300 | " # define the target\n", 301 | " targets = tf.split(0, args.seq_length, self.target_data) \n", 302 | " #initial_state\n", 303 | " self.initial_state = tf.zeros((hidden_size,1))\n", 304 | " #last_state = tf.placeholder(tf.float32, (hidden_size, 1))\n", 305 | " \n", 306 | " # model parameters\n", 307 | " Wxh = 
tf.Variable(tf.random_uniform((hidden_size, vector_size))*0.01, name='Wxh') # input to hidden\n", 308 | " Wph = tf.Variable(tf.random_uniform((hidden_size, vector_size))*0.01, name='Wph') # position to hidden\n", 309 | " Whh = tf.Variable(tf.random_uniform((hidden_size, hidden_size))*0.01, name='Whh') # hidden to hidden\n", 310 | " Why = tf.Variable(tf.random_uniform((vector_size, hidden_size))*0.01, name='Why') # hidden to output\n", 311 | " bh = tf.Variable(tf.zeros((hidden_size, 1)), name='bh') # hidden bias\n", 312 | " by = tf.Variable(tf.zeros((vector_size, 1)), name='by') # output bias\n", 313 | " loss = tf.zeros([1], name='loss')\n", 314 | " self.pos = tf.Variable(0.0, trainable=False, name='pos')\n", 315 | " hs, ys, ps = {}, {}, {}\n", 316 | " \n", 317 | " hs[-1] = self.initial_state\n", 318 | " # forward pass \n", 319 | " for t in xrange(args.seq_length):\n", 320 | " xs_t = tf.transpose(xs[t])\n", 321 | " if infer and t>0:\n", 322 | " xs_t = ys[t-1]\n", 323 | " targets_t = tf.transpose(targets[t])\n", 324 | " indices = [[t, 0]]\n", 325 | " values = [1.0]\n", 326 | " shape = [args.seq_length, 1]\n", 327 | " delta = tf.SparseTensor(indices, values, shape) \n", 328 | " position = tf.zeros([vector_size, 1]) + tf.sparse_tensor_to_dense(delta)\n", 329 | " \n", 330 | " hs[t] = tf.sigmoid(tf.matmul(Wxh, xs_t) \n", 331 | " + tf.matmul(Whh, hs[t-1]) \n", 332 | " + tf.matmul(Wph, position)\n", 333 | " + bh) # hidden state\n", 334 | " ys[t] = tf.matmul(Why, hs[t]) + by # unnormalized log probabilities for next line\n", 335 | " ys[t] = tf.sigmoid(ys[t])\n", 336 | " #ps[t] = tf.exp(ys[t]) / tf.reduce_sum(tf.exp(ys[t])) # probabilities for next chars\n", 337 | " loss += tf.reduce_sum(tf.abs(ys[t]-targets_t))\n", 338 | " \n", 339 | "\n", 340 | " self.probs = tf.pack([ys[key] for key in ys])\n", 341 | " self.cost = loss / args.batch_size / args.seq_length\n", 342 | " self.final_state = hs[args.seq_length-1]\n", 343 | " self.lr = tf.Variable(0.0, trainable=False, name='learning_rate')\n", 344 | " tvars = tf.trainable_variables()\n", 345 | " grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),\n", 346 | " args.grad_clip)\n", 347 | " optimizer = tf.train.AdamOptimizer(self.lr)\n", 348 | " self.train_op = optimizer.apply_gradients(zip(grads, tvars))\n", 349 | "\n", 350 | " def sample(self, sess):\n", 351 | " size = self.args.picture_size\n", 352 | " picture_vect = np.zeros((size, size))\n", 353 | " state = model.initial_state.eval()\n", 354 | " x = np.random.random([1, size])\n", 355 | " feed = {self.input_data: x, self.initial_state:state}\n", 356 | " [probs, state] = sess.run([self.probs, self.final_state], feed)\n", 357 | " for n in range(0):\n", 358 | " line = np.transpose(probs)\n", 359 | " feed = {self.input_data: line, self.initial_state:state}\n", 360 | " [probs, state] = sess.run([self.probs, self.final_state], feed)\n", 361 | " for n in range(size):\n", 362 | " line = np.transpose(probs)\n", 363 | " feed = {self.input_data: line, self.initial_state:state}\n", 364 | " [probs, state] = sess.run([self.probs, self.final_state], feed)\n", 365 | " #print probs\n", 366 | " #line = (np.sign(probs-0.5)+1)/2\n", 367 | " #print line\n", 368 | " picture_vect[n] = np.squeeze(line) \n", 369 | " picture = picture_vect*255\n", 370 | " return tf.expand_dims(picture,2)\n", 371 | " \n", 372 | " def inspect(self, draw=False):\n", 373 | " for var in tf.all_variables():\n", 374 | " if var in tf.trainable_variables():\n", 375 | " print ('t', var.name, var.eval().shape)\n", 376 | " if draw:\n", 377 | " 
plt.figure(figsize=(1,1))\n", 378 | " plt.figimage(var.eval())\n", 379 | " plt.show()\n", 380 | " else:\n", 381 | " print ('nt', var.name, var.eval().shape)\n", 382 | " \n", 383 | " " 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "## Training" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": { 397 | "collapsed": false 398 | }, 399 | "outputs": [ 400 | { 401 | "name": "stdout", 402 | "output_type": "stream", 403 | "text": [ 404 | "model created\n", 405 | "('faces count', 13233)\n", 406 | "variable initialized\n", 407 | "0/132330 (epoch 0), train_loss = 11.846004, time/batch = 0.299\n", 408 | "model saved to save_face_training/model.ckpt\n", 409 | "10/132330 (epoch 0), train_loss = 7.609998, time/batch = 0.150\n", 410 | "20/132330 (epoch 0), train_loss = 7.214982, time/batch = 0.150\n", 411 | "30/132330 (epoch 0), train_loss = 4.345308, time/batch = 0.148\n", 412 | "40/132330 (epoch 0), train_loss = 7.447813, time/batch = 0.154\n", 413 | "50/132330 (epoch 0), train_loss = 7.482661, time/batch = 0.154\n", 414 | "60/132330 (epoch 0), train_loss = 6.279762, time/batch = 0.153\n", 415 | "70/132330 (epoch 0), train_loss = 7.785715, time/batch = 0.156\n", 416 | "80/132330 (epoch 0), train_loss = 4.964209, time/batch = 0.161\n", 417 | "90/132330 (epoch 0), train_loss = 7.774941, time/batch = 0.163\n", 418 | "100/132330 (epoch 0), train_loss = 10.457413, time/batch = 0.162\n", 419 | "model saved to save_face_training/model.ckpt\n", 420 | "110/132330 (epoch 0), train_loss = 7.914732, time/batch = 0.165\n", 421 | "120/132330 (epoch 0), train_loss = 7.111222, time/batch = 0.167\n", 422 | "130/132330 (epoch 0), train_loss = 7.023408, time/batch = 0.168\n", 423 | "140/132330 (epoch 0), train_loss = 4.848790, time/batch = 0.169\n", 424 | "150/132330 (epoch 0), train_loss = 7.302919, time/batch = 0.171\n", 425 | "160/132330 (epoch 0), train_loss = 3.674625, time/batch = 0.174\n", 426 | "170/132330 (epoch 0), train_loss = 8.263390, time/batch = 0.176\n", 427 | "180/132330 (epoch 0), train_loss = 6.758085, time/batch = 0.182\n", 428 | "190/132330 (epoch 0), train_loss = 3.910988, time/batch = 0.179\n", 429 | "200/132330 (epoch 0), train_loss = 5.194339, time/batch = 0.180\n", 430 | "model saved to save_face_training/model.ckpt\n", 431 | "210/132330 (epoch 0), train_loss = 4.551933, time/batch = 0.183\n", 432 | "220/132330 (epoch 0), train_loss = 6.324289, time/batch = 0.190\n", 433 | "230/132330 (epoch 0), train_loss = 5.959601, time/batch = 0.186\n", 434 | "240/132330 (epoch 0), train_loss = 7.450030, time/batch = 0.188\n", 435 | "250/132330 (epoch 0), train_loss = 5.124949, time/batch = 0.195\n", 436 | "260/132330 (epoch 0), train_loss = 4.154919, time/batch = 0.200\n", 437 | "270/132330 (epoch 0), train_loss = 4.003494, time/batch = 0.196\n", 438 | "280/132330 (epoch 0), train_loss = 5.118635, time/batch = 0.205\n", 439 | "290/132330 (epoch 0), train_loss = 4.427326, time/batch = 0.205\n", 440 | "300/132330 (epoch 0), train_loss = 4.765638, time/batch = 0.208\n", 441 | "model saved to save_face_training/model.ckpt\n", 442 | "310/132330 (epoch 0), train_loss = 3.735344, time/batch = 0.211\n", 443 | "320/132330 (epoch 0), train_loss = 4.571231, time/batch = 0.206\n", 444 | "330/132330 (epoch 0), train_loss = 5.801054, time/batch = 0.207\n", 445 | "340/132330 (epoch 0), train_loss = 5.453507, time/batch = 0.210\n", 446 | "350/132330 (epoch 0), train_loss = 6.524479, time/batch = 0.214\n", 447 | 
"360/132330 (epoch 0), train_loss = 4.466758, time/batch = 0.218\n", 448 | "370/132330 (epoch 0), train_loss = 3.093597, time/batch = 0.215\n", 449 | "380/132330 (epoch 0), train_loss = 6.006961, time/batch = 0.220\n", 450 | "390/132330 (epoch 0), train_loss = 5.414852, time/batch = 0.221\n", 451 | "400/132330 (epoch 0), train_loss = 7.497438, time/batch = 0.225\n", 452 | "model saved to save_face_training/model.ckpt\n", 453 | "410/132330 (epoch 0), train_loss = 6.339993, time/batch = 0.224\n", 454 | "420/132330 (epoch 0), train_loss = 6.619037, time/batch = 0.230\n", 455 | "430/132330 (epoch 0), train_loss = 5.751145, time/batch = 0.228\n", 456 | "440/132330 (epoch 0), train_loss = 5.581079, time/batch = 0.231\n", 457 | "450/132330 (epoch 0), train_loss = 11.383739, time/batch = 0.233\n", 458 | "460/132330 (epoch 0), train_loss = 5.102082, time/batch = 0.233\n", 459 | "470/132330 (epoch 0), train_loss = 4.909813, time/batch = 0.243\n", 460 | "480/132330 (epoch 0), train_loss = 7.428552, time/batch = 0.245\n", 461 | "490/132330 (epoch 0), train_loss = 6.681951, time/batch = 0.240\n", 462 | "500/132330 (epoch 0), train_loss = 5.431161, time/batch = 0.245\n", 463 | "model saved to save_face_training/model.ckpt\n", 464 | "510/132330 (epoch 0), train_loss = 6.323401, time/batch = 0.248\n", 465 | "520/132330 (epoch 0), train_loss = 6.324689, time/batch = 0.255\n", 466 | "530/132330 (epoch 0), train_loss = 5.522001, time/batch = 0.254\n", 467 | "540/132330 (epoch 0), train_loss = 4.788240, time/batch = 0.255\n", 468 | "550/132330 (epoch 0), train_loss = 5.925033, time/batch = 0.259\n", 469 | "560/132330 (epoch 0), train_loss = 7.541386, time/batch = 0.259\n", 470 | "570/132330 (epoch 0), train_loss = 4.742616, time/batch = 0.270\n", 471 | "580/132330 (epoch 0), train_loss = 3.831331, time/batch = 0.262\n", 472 | "590/132330 (epoch 0), train_loss = 8.389158, time/batch = 0.278\n", 473 | "600/132330 (epoch 0), train_loss = 6.273591, time/batch = 0.268\n", 474 | "model saved to save_face_training/model.ckpt\n", 475 | "610/132330 (epoch 0), train_loss = 5.904160, time/batch = 0.270\n", 476 | "620/132330 (epoch 0), train_loss = 7.140617, time/batch = 0.268\n", 477 | "630/132330 (epoch 0), train_loss = 5.809810, time/batch = 0.267\n", 478 | "640/132330 (epoch 0), train_loss = 5.787280, time/batch = 0.272\n", 479 | "650/132330 (epoch 0), train_loss = 8.121449, time/batch = 0.278\n", 480 | "660/132330 (epoch 0), train_loss = 3.951409, time/batch = 0.281\n", 481 | "670/132330 (epoch 0), train_loss = 5.007551, time/batch = 0.280\n", 482 | "680/132330 (epoch 0), train_loss = 5.078654, time/batch = 0.277\n", 483 | "690/132330 (epoch 0), train_loss = 4.589311, time/batch = 0.291\n", 484 | "700/132330 (epoch 0), train_loss = 4.262353, time/batch = 0.295\n", 485 | "model saved to save_face_training/model.ckpt\n", 486 | "710/132330 (epoch 0), train_loss = 4.878027, time/batch = 0.284\n", 487 | "720/132330 (epoch 0), train_loss = 5.099174, time/batch = 0.287\n", 488 | "730/132330 (epoch 0), train_loss = 8.126123, time/batch = 0.292\n", 489 | "740/132330 (epoch 0), train_loss = 4.916740, time/batch = 0.291\n", 490 | "750/132330 (epoch 0), train_loss = 4.964235, time/batch = 0.293\n", 491 | "760/132330 (epoch 0), train_loss = 5.020784, time/batch = 0.297\n", 492 | "770/132330 (epoch 0), train_loss = 5.582267, time/batch = 0.301\n", 493 | "780/132330 (epoch 0), train_loss = 3.607513, time/batch = 0.298\n", 494 | "790/132330 (epoch 0), train_loss = 4.957433, time/batch = 0.304\n", 495 | "800/132330 (epoch 0), 
train_loss = 8.756006, time/batch = 0.303\n", 496 | "model saved to save_face_training/model.ckpt\n", 497 | "810/132330 (epoch 0), train_loss = 5.832938, time/batch = 0.314\n", 498 | "820/132330 (epoch 0), train_loss = 7.171200, time/batch = 0.310\n", 499 | "830/132330 (epoch 0), train_loss = 5.898413, time/batch = 0.322\n", 500 | "840/132330 (epoch 0), train_loss = 4.798182, time/batch = 0.318\n", 501 | "850/132330 (epoch 0), train_loss = 3.033494, time/batch = 0.326\n", 502 | "860/132330 (epoch 0), train_loss = 4.715693, time/batch = 0.323\n", 503 | "870/132330 (epoch 0), train_loss = 4.410551, time/batch = 0.449\n", 504 | "880/132330 (epoch 0), train_loss = 3.671957, time/batch = 0.340\n", 505 | "890/132330 (epoch 0), train_loss = 3.211787, time/batch = 0.410\n", 506 | "900/132330 (epoch 0), train_loss = 5.762773, time/batch = 0.347\n", 507 | "model saved to save_face_training/model.ckpt\n", 508 | "910/132330 (epoch 0), train_loss = 6.225863, time/batch = 0.348\n", 509 | "920/132330 (epoch 0), train_loss = 6.778830, time/batch = 0.416\n", 510 | "930/132330 (epoch 0), train_loss = 4.799048, time/batch = 0.447\n", 511 | "940/132330 (epoch 0), train_loss = 4.508432, time/batch = 0.343\n", 512 | "950/132330 (epoch 0), train_loss = 6.060311, time/batch = 0.364\n", 513 | "960/132330 (epoch 0), train_loss = 5.704145, time/batch = 0.363\n", 514 | "970/132330 (epoch 0), train_loss = 6.299800, time/batch = 0.376\n", 515 | "980/132330 (epoch 0), train_loss = 4.064652, time/batch = 0.376\n", 516 | "990/132330 (epoch 0), train_loss = 5.421255, time/batch = 0.351\n", 517 | "1000/132330 (epoch 0), train_loss = 3.249725, time/batch = 0.374\n", 518 | "model saved to save_face_training/model.ckpt\n", 519 | "1010/132330 (epoch 0), train_loss = 5.766055, time/batch = 0.370\n", 520 | "1020/132330 (epoch 0), train_loss = 8.624304, time/batch = 0.368\n", 521 | "1030/132330 (epoch 0), train_loss = 3.750180, time/batch = 0.365\n", 522 | "1040/132330 (epoch 0), train_loss = 5.541049, time/batch = 0.368\n", 523 | "1050/132330 (epoch 0), train_loss = 5.823425, time/batch = 0.366\n", 524 | "1060/132330 (epoch 0), train_loss = 3.467629, time/batch = 0.382\n", 525 | "1070/132330 (epoch 0), train_loss = 4.423463, time/batch = 0.396\n", 526 | "1080/132330 (epoch 0), train_loss = 7.298230, time/batch = 0.418\n", 527 | "1090/132330 (epoch 0), train_loss = 6.507596, time/batch = 0.405\n", 528 | "1100/132330 (epoch 0), train_loss = 4.531975, time/batch = 0.421\n", 529 | "model saved to save_face_training/model.ckpt\n", 530 | "1110/132330 (epoch 0), train_loss = 3.758112, time/batch = 0.430\n", 531 | "1120/132330 (epoch 0), train_loss = 3.868862, time/batch = 0.475\n", 532 | "1130/132330 (epoch 0), train_loss = 7.370373, time/batch = 0.409\n", 533 | "1140/132330 (epoch 0), train_loss = 4.399141, time/batch = 0.415\n", 534 | "1150/132330 (epoch 0), train_loss = 4.341646, time/batch = 0.395\n", 535 | "1160/132330 (epoch 0), train_loss = 4.806339, time/batch = 0.402\n", 536 | "1170/132330 (epoch 0), train_loss = 4.351828, time/batch = 0.431\n", 537 | "1180/132330 (epoch 0), train_loss = 3.155379, time/batch = 0.403\n", 538 | "1190/132330 (epoch 0), train_loss = 5.923162, time/batch = 0.412\n", 539 | "1200/132330 (epoch 0), train_loss = 3.506185, time/batch = 0.429\n", 540 | "model saved to save_face_training/model.ckpt\n", 541 | "1210/132330 (epoch 0), train_loss = 4.326102, time/batch = 0.413\n", 542 | "1220/132330 (epoch 0), train_loss = 5.531622, time/batch = 0.421\n", 543 | "1230/132330 (epoch 0), train_loss = 7.466490, 
time/batch = 0.407\n", 544 | "1240/132330 (epoch 0), train_loss = 5.292984, time/batch = 0.434\n", 545 | "1250/132330 (epoch 0), train_loss = 9.296682, time/batch = 0.442\n", 546 | "1260/132330 (epoch 0), train_loss = 7.542732, time/batch = 0.463\n", 547 | "1270/132330 (epoch 0), train_loss = 5.754212, time/batch = 0.416\n", 548 | "1280/132330 (epoch 0), train_loss = 6.197461, time/batch = 0.441\n", 549 | "1290/132330 (epoch 0), train_loss = 4.116388, time/batch = 0.452\n", 550 | "1300/132330 (epoch 0), train_loss = 6.148061, time/batch = 0.456\n", 551 | "model saved to save_face_training/model.ckpt\n", 552 | "1310/132330 (epoch 0), train_loss = 4.099864, time/batch = 0.446\n", 553 | "1320/132330 (epoch 0), train_loss = 4.672244, time/batch = 0.461\n", 554 | "1330/132330 (epoch 0), train_loss = 4.554517, time/batch = 0.463\n", 555 | "1340/132330 (epoch 0), train_loss = 4.426936, time/batch = 0.456\n", 556 | "1350/132330 (epoch 0), train_loss = 8.222431, time/batch = 0.448\n", 557 | "1360/132330 (epoch 0), train_loss = 2.372349, time/batch = 0.474\n", 558 | "1370/132330 (epoch 0), train_loss = 4.806358, time/batch = 0.453\n", 559 | "1380/132330 (epoch 0), train_loss = 2.799022, time/batch = 0.478\n", 560 | "1390/132330 (epoch 0), train_loss = 3.600282, time/batch = 0.461\n", 561 | "1400/132330 (epoch 0), train_loss = 5.314572, time/batch = 0.487\n", 562 | "model saved to save_face_training/model.ckpt\n", 563 | "1410/132330 (epoch 0), train_loss = 3.716839, time/batch = 0.474\n", 564 | "1420/132330 (epoch 0), train_loss = 5.867580, time/batch = 0.502\n", 565 | "1430/132330 (epoch 0), train_loss = 3.307088, time/batch = 0.483\n", 566 | "1440/132330 (epoch 0), train_loss = 5.661429, time/batch = 0.524\n", 567 | "1450/132330 (epoch 0), train_loss = 5.752923, time/batch = 0.488\n", 568 | "1460/132330 (epoch 0), train_loss = 5.262812, time/batch = 0.516\n", 569 | "1470/132330 (epoch 0), train_loss = 4.028398, time/batch = 0.525\n", 570 | "1480/132330 (epoch 0), train_loss = 3.386873, time/batch = 0.508\n", 571 | "1490/132330 (epoch 0), train_loss = 7.146141, time/batch = 0.537\n", 572 | "1500/132330 (epoch 0), train_loss = 2.592293, time/batch = 0.503\n", 573 | "model saved to save_face_training/model.ckpt\n", 574 | "1510/132330 (epoch 0), train_loss = 4.420721, time/batch = 0.540\n", 575 | "1520/132330 (epoch 0), train_loss = 6.661725, time/batch = 0.538\n", 576 | "1530/132330 (epoch 0), train_loss = 3.700504, time/batch = 0.534\n", 577 | "1540/132330 (epoch 0), train_loss = 6.506403, time/batch = 0.519\n", 578 | "1550/132330 (epoch 0), train_loss = 4.536848, time/batch = 0.544\n", 579 | "1560/132330 (epoch 0), train_loss = 5.640853, time/batch = 0.551\n", 580 | "1570/132330 (epoch 0), train_loss = 5.638056, time/batch = 0.520\n", 581 | "1580/132330 (epoch 0), train_loss = 7.173485, time/batch = 0.541\n", 582 | "1590/132330 (epoch 0), train_loss = 2.604686, time/batch = 0.538\n", 583 | "1600/132330 (epoch 0), train_loss = 3.820972, time/batch = 0.534\n", 584 | "model saved to save_face_training/model.ckpt\n", 585 | "1610/132330 (epoch 0), train_loss = 3.754754, time/batch = 0.554\n", 586 | "1620/132330 (epoch 0), train_loss = 2.470329, time/batch = 0.541\n", 587 | "1630/132330 (epoch 0), train_loss = 5.454981, time/batch = 0.544\n", 588 | "1640/132330 (epoch 0), train_loss = 3.191897, time/batch = 0.556\n", 589 | "1650/132330 (epoch 0), train_loss = 3.515225, time/batch = 0.558\n", 590 | "1660/132330 (epoch 0), train_loss = 2.818658, time/batch = 0.554\n", 591 | "1670/132330 (epoch 0), 
train_loss = 3.251315, time/batch = 0.559\n", 592 | "1680/132330 (epoch 0), train_loss = 4.379750, time/batch = 0.575\n", 593 | "1690/132330 (epoch 0), train_loss = 6.107162, time/batch = 0.581\n", 594 | "1700/132330 (epoch 0), train_loss = 10.520744, time/batch = 0.573\n", 595 | "model saved to save_face_training/model.ckpt\n", 596 | "1710/132330 (epoch 0), train_loss = 5.984098, time/batch = 0.576\n", 597 | "1720/132330 (epoch 0), train_loss = 3.026716, time/batch = 0.578\n", 598 | "1730/132330 (epoch 0), train_loss = 5.400842, time/batch = 0.577\n", 599 | "1740/132330 (epoch 0), train_loss = 4.510511, time/batch = 0.601\n", 600 | "1750/132330 (epoch 0), train_loss = 5.325934, time/batch = 0.585\n", 601 | "1760/132330 (epoch 0), train_loss = 5.739883, time/batch = 0.597\n", 602 | "1770/132330 (epoch 0), train_loss = 2.015807, time/batch = 0.606\n", 603 | "1780/132330 (epoch 0), train_loss = 2.713206, time/batch = 0.585\n", 604 | "1790/132330 (epoch 0), train_loss = 5.520914, time/batch = 0.639\n", 605 | "1800/132330 (epoch 0), train_loss = 2.586047, time/batch = 0.644\n", 606 | "model saved to save_face_training/model.ckpt\n", 607 | "1810/132330 (epoch 0), train_loss = 5.679832, time/batch = 0.645\n", 608 | "1820/132330 (epoch 0), train_loss = 3.191137, time/batch = 0.657\n", 609 | "1830/132330 (epoch 0), train_loss = 3.479776, time/batch = 0.647\n", 610 | "1840/132330 (epoch 0), train_loss = 6.601597, time/batch = 0.668\n", 611 | "1850/132330 (epoch 0), train_loss = 3.503648, time/batch = 0.658\n", 612 | "1860/132330 (epoch 0), train_loss = 3.836609, time/batch = 0.663\n", 613 | "1870/132330 (epoch 0), train_loss = 3.950723, time/batch = 0.663\n", 614 | "1880/132330 (epoch 0), train_loss = 5.474760, time/batch = 0.691\n", 615 | "1890/132330 (epoch 0), train_loss = 4.600953, time/batch = 0.659\n", 616 | "1900/132330 (epoch 0), train_loss = 5.153743, time/batch = 0.668\n", 617 | "model saved to save_face_training/model.ckpt\n", 618 | "1910/132330 (epoch 0), train_loss = 4.155967, time/batch = 0.638\n", 619 | "1920/132330 (epoch 0), train_loss = 3.450524, time/batch = 0.714\n", 620 | "1930/132330 (epoch 0), train_loss = 3.735923, time/batch = 0.656\n", 621 | "1940/132330 (epoch 0), train_loss = 3.744927, time/batch = 0.642\n", 622 | "1950/132330 (epoch 0), train_loss = 3.671499, time/batch = 0.744\n", 623 | "1960/132330 (epoch 0), train_loss = 3.720388, time/batch = 0.656\n", 624 | "1970/132330 (epoch 0), train_loss = 5.815224, time/batch = 0.670\n", 625 | "1980/132330 (epoch 0), train_loss = 3.059415, time/batch = 0.672\n", 626 | "1990/132330 (epoch 0), train_loss = 3.465191, time/batch = 0.677\n", 627 | "2000/132330 (epoch 0), train_loss = 2.620191, time/batch = 0.690\n", 628 | "model saved to save_face_training/model.ckpt\n", 629 | "2010/132330 (epoch 0), train_loss = 4.495060, time/batch = 0.685\n" 630 | ] 631 | } 632 | ], 633 | "source": [ 634 | "tf.reset_default_graph()\n", 635 | "args = Args()\n", 636 | "model = Model(args)\n", 637 | "print (\"model created\")\n", 638 | "faceloader = FaceLoader()\n", 639 | "face_count = faceloader.prepare_reading_faces()\n", 640 | "print ('faces count', face_count)\n", 641 | "\n", 642 | "cost_optimisation = []\n", 643 | "\n", 644 | "with tf.Session() as sess:\n", 645 | " tf.initialize_all_variables().run()\n", 646 | " print (\"variable initialized\")\n", 647 | " faceloader.do_when_session()\n", 648 | " saver = tf.train.Saver(tf.all_variables())\n", 649 | "\n", 650 | " # restore model\n", 651 | " if args.init_from is not None:\n", 652 | " ckpt = 
tf.train.get_checkpoint_state(args.init_from)\n", 653 | " assert ckpt,\"No checkpoint found\"\n", 654 | " assert ckpt.model_checkpoint_path,\"No model path found in checkpoint\"\n", 655 | " saver.restore(sess, ckpt.model_checkpoint_path)\n", 656 | " print (\"model restored\")\n", 657 | " for e in range(args.num_epochs):\n", 658 | " faceloader.image_reader.reset()\n", 659 | " sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))\n", 660 | " state = model.initial_state.eval()\n", 661 | " for b in range(face_count):\n", 662 | " start = time.time()\n", 663 | " # Get learning data\n", 664 | " faceloader.load_one_face(args.picture_size)\n", 665 | " x, y = faceloader.next_batch()\n", 666 | " # Create the structure for the learning data\n", 667 | " feed = {model.input_data: x, model.target_data: y, model.initial_state: state}\n", 668 | " # Run a session using train_op\n", 669 | " [train_loss], state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)\n", 670 | " end = time.time()\n", 671 | " if (e * face_count + b) % args.print_every == 0:\n", 672 | " cost_optimisation.append(train_loss)\n", 673 | " print(\"{}/{} (epoch {}), train_loss = {:.6f}, time/batch = {:.3f}\" \\\n", 674 | " .format(e * face_count + b,\n", 675 | " args.num_epochs * face_count,\n", 676 | " e, train_loss, end - start))\n", 677 | " if (e * face_count + b) % args.save_every == 0:\n", 678 | " checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')\n", 679 | " saver.save(sess, checkpoint_path, global_step = e * face_count + b)\n", 680 | " print(\"model saved to {}\".format(checkpoint_path))\n", 681 | " np.save('cost', cost_optimisation)\n" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": null, 687 | "metadata": { 688 | "collapsed": false 689 | }, 690 | "outputs": [], 691 | "source": [ 692 | "cost_optimisation = np.load('cost.npy')\n", 693 | "plt.figure(figsize=(12,5))\n", 694 | "plt.plot(range(len(cost_optimisation)), cost_optimisation, label='cost')\n", 695 | "plt.legend()\n", 696 | "plt.show()" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": null, 702 | "metadata": { 703 | "collapsed": false 704 | }, 705 | "outputs": [], 706 | "source": [ 707 | "tf.reset_default_graph()\n", 708 | "args = Args()\n", 709 | "model = Model(args, True) # True to generate the model in sampling mode\n", 710 | "with tf.Session() as sess:\n", 711 | " tf.initialize_all_variables().run()\n", 712 | " saver = tf.train.Saver(tf.all_variables())\n", 713 | " ckpt = tf.train.get_checkpoint_state(args.save_dir)\n", 714 | " print (ckpt)\n", 715 | " \n", 716 | " model.inspect(draw=True)" 717 | ] 718 | }, 719 | { 720 | "cell_type": "markdown", 721 | "metadata": {}, 722 | "source": [ 723 | "## Sampling" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": 10, 729 | "metadata": { 730 | "collapsed": false 731 | }, 732 | "outputs": [ 733 | { 734 | "name": "stdout", 735 | "output_type": "stream", 736 | "text": [ 737 | "initialisation done\n", 738 | "model_checkpoint_path: \"save_face_training/model.ckpt-2000\"\n", 739 | "all_model_checkpoint_paths: \"save_face_training/model.ckpt-1600\"\n", 740 | "all_model_checkpoint_paths: \"save_face_training/model.ckpt-1700\"\n", 741 | "all_model_checkpoint_paths: \"save_face_training/model.ckpt-1800\"\n", 742 | "all_model_checkpoint_paths: \"save_face_training/model.ckpt-1900\"\n", 743 | "all_model_checkpoint_paths: \"save_face_training/model.ckpt-2000\"\n", 744 | "\n" 745 | ] 746 | } 747 | ], 748 | "source": [ 
749 | "tf.reset_default_graph()\n", 750 | "args = Args()\n", 751 | "model = Model(args, infer=True)\n", 752 | "with tf.Session() as sess:\n", 753 | " tf.initialize_all_variables().run()\n", 754 | " print 'intialisation done'\n", 755 | " saver = tf.train.Saver(tf.all_variables())\n", 756 | " ckpt = tf.train.get_checkpoint_state(args.save_dir)\n", 757 | " print (ckpt)\n", 758 | " \n", 759 | " if ckpt and ckpt.model_checkpoint_path:\n", 760 | " saver.restore(sess, ckpt.model_checkpoint_path)\n", 761 | " \n", 762 | " state = model.initial_state.eval()\n", 763 | " x = np.random.random([args.picture_size, args.picture_size])\n", 764 | " feed = {model.input_data: x, model.initial_state: state}\n", 765 | " [lines] = sess.run([model.probs], feed)\n", 766 | " pict = tf.expand_dims(lines*255,2)\n", 767 | " #print(pict.eval())\n", 768 | " write_png(lines*255, 'a_face.png')\n", 769 | " " 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 11, 775 | "metadata": { 776 | "collapsed": false 777 | }, 778 | "outputs": [ 779 | { 780 | "data": { 781 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAB6klEQVQ4jW1SQbLbOgwDQDpOMtMb\n/0v+s3Ta1I4IdOG8ThfVSsSAEgCS/z3urWSt1zv3cj2Ogc9Dt6I8nN7veyuWViKKqoAbtr0pm9Nb\nbyULkcYwPEOi+rZJNldXSWIsxpPJmiHDqi7ZZFdVyRHg92C6zkVBkFQy0FUtMSN4MXafQ7ETkCHd\nFElCBJIM7SGYAAADtsdAEtT2EFzPc1FWEYkT91GsSlZ033dZz3Mx71XIxMb02awKxqh+tPXYFn28\nmGHGmZ5VQGIbkCBJBJkE9mB6ZhjEscg/BAAIEqRvt1spWGuO169yPY8h3nNTSR5U3++75MA5z5+c\nfl1BNarKBfVaVUnAvu83TT3PRc705RPug2khtvbt2dbjHPr4VbA1k9W2B5di2YZtxnHCJEhXt8pw\n1pyrpt7HIte6BVQACwBBIFfgCIBcRZCgJZGACGfZfr8XOQYkIVBfPQBIllhVIQPyQtF2ksQJ8VeS\nF2i6cX2XeLDirLXIubqAS0MxIfyeJTfORZkdkCTVkiiDiFdob+dQ0UcV2V/WvgDKBGGbvDTkkxlI\nEbo4QBI7uN66pv/v0yBIgiSBrwuBqwJakqjos/cfdxcDJFsqyREJE0Pb0EcVAPb/2yNY83p9//GD\n4PaYN3S7f3vutXm29RsylZQD43DXWAAAAABJRU5ErkJggg==\n", 782 | "text/plain": [ 783 | "" 784 | ] 785 | }, 786 | "execution_count": 11, 787 | "metadata": {}, 788 | "output_type": "execute_result" 789 | } 790 | ], 791 | "source": [ 792 | "Image(\"a_face.png\")" 793 | ] 794 | }, 795 | { 796 | "cell_type": "markdown", 797 | "metadata": {}, 798 | "source": [ 799 | "Feedback wellcome __@dh7net__" 800 | ] 801 | } 802 | ], 803 | "metadata": { 804 | "kernelspec": { 805 | "display_name": "Python 2", 806 | "language": "python", 807 | "name": "python2" 808 | }, 809 | "language_info": { 810 | "codemirror_mode": { 811 | "name": "ipython", 812 | "version": 2 813 | }, 814 | "file_extension": ".py", 815 | "mimetype": "text/x-python", 816 | "name": "python", 817 | "nbconvert_exporter": "python", 818 | "pygments_lexer": "ipython2", 819 | "version": "2.7.10" 820 | } 821 | }, 822 | "nbformat": 4, 823 | "nbformat_minor": 0 824 | } 825 | -------------------------------------------------------------------------------- /images2gif.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (C) 2012, Almar Klein, Ant1, Marius van Voorden 3 | # 4 | # This code is subject to the (new) BSD license: 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 
10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of the nor the 14 | # names of its contributors may be used to endorse or promote products 15 | # derived from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY 21 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | """ Module images2gif 29 | Provides functionality for reading and writing animated GIF images. 30 | Use writeGif to write a series of numpy arrays or PIL images as an 31 | animated GIF. Use readGif to read an animated gif as a series of numpy 32 | arrays. 33 | Note that since July 2004, all patents on the LZW compression algorithm have 34 | expired. Therefore the GIF format may now be used freely. 35 | Acknowledgements 36 | ---------------- 37 | Many thanks to Ant1 for: 38 | * noting the use of "palette=PIL.Image.ADAPTIVE", which significantly 39 | improves the results. 40 | * the modifications to save each image with its own palette, or optionally 41 | the global palette (if it's the same). 42 | Many thanks to Marius van Voorden for porting the NeuQuant quantization 43 | algorithm of Anthony Dekker to Python (See the NeuQuant class for its 44 | license). 45 | Many thanks to Alex Robinson for implementing the concept of subrectangles, 46 | which (depending on image content) can give a very significant reduction in 47 | file size. 48 | This code is based on gifmaker (in the scripts folder of the source 49 | distribution of PIL) 50 | Useful links 51 | ------------- 52 | * http://tronche.com/computer-graphics/gif/ 53 | * http://en.wikipedia.org/wiki/Graphics_Interchange_Format 54 | * http://www.w3.org/Graphics/GIF/spec-gif89a.txt 55 | """ 56 | # todo: This module should be part of imageio (or at least based on it) 57 | 58 | import os, time 59 | 60 | def encode(x): 61 | if False: # pass-through shim; enable to return UTF-8 bytes instead of str 62 | return x.encode('utf-8') 63 | return x 64 | 65 | try: 66 | import PIL 67 | from PIL import Image 68 | from PIL.GifImagePlugin import getheader, getdata 69 | except ImportError: 70 | PIL = None 71 | 72 | try: 73 | import numpy as np 74 | except ImportError: 75 | np = None 76 | 77 | def get_cKDTree(): 78 | try: 79 | from scipy.spatial import cKDTree 80 | except ImportError: 81 | cKDTree = None 82 | return cKDTree 83 | 84 | 85 | # getheader gives an 87a header and a color palette (two elements in a list). 86 | # getdata()[0] gives the Image Descriptor up to (including) "LZW min code size". 87 | # getdata()[1:] is the image data itself in chunks of 256 bytes (well 88 | # technically the first byte says how many bytes follow, after which that 89 | # amount (max 255) follows). 
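#
# Quick usage sketch (added for illustration; the filename is hypothetical).
# The module's two public entry points, writeGif and readGif, are defined
# further down in this file. Assuming numpy and PIL are importable:
#
#     import numpy as np
#     frames = [np.zeros((64, 64), dtype=np.uint8) + i * 50 for i in range(5)]
#     writeGif('fade.gif', frames, duration=0.2, repeat=True)
#     arrays = readGif('fade.gif', asNumpy=True)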
90 | 91 | def checkImages(images): 92 | """ checkImages(images) 93 | Check numpy images and correct intensity range etc. 94 | The same for all movie formats. 95 | """ 96 | # Init results 97 | images2 = [] 98 | 99 | for im in images: 100 | if PIL and isinstance(im, PIL.Image.Image): 101 | # We assume PIL images are all right 102 | images2.append(im) 103 | 104 | elif np and isinstance(im, np.ndarray): 105 | # Check and convert dtype 106 | if im.dtype == np.uint8: 107 | images2.append(im) # Ok 108 | elif im.dtype in [np.float32, np.float64]: 109 | im = im.copy() 110 | im[im<0] = 0 111 | im[im>1] = 1 112 | im *= 255 113 | images2.append( im.astype(np.uint8) ) 114 | else: 115 | im = im.astype(np.uint8) 116 | images2.append(im) 117 | # Check size 118 | if im.ndim == 2: 119 | pass # ok 120 | elif im.ndim == 3: 121 | if im.shape[2] not in [3,4]: 122 | raise ValueError('This array can not represent an image.') 123 | else: 124 | raise ValueError('This array can not represent an image.') 125 | else: 126 | raise ValueError('Invalid image type: ' + str(type(im))) 127 | 128 | # Done 129 | return images2 130 | 131 | 132 | def intToBin(i): 133 | """ Integer to two bytes """ 134 | # divide into two parts (bytes) 135 | i1 = i % 256 136 | i2 = int( i/256) 137 | # make string (little endian) 138 | return chr(i1) + chr(i2) 139 | 140 | 141 | class GifWriter: 142 | """ GifWriter() 143 | Class that contains methods for helping write the animated GIF file. 144 | """ 145 | 146 | def getheaderAnim(self, im): 147 | """ getheaderAnim(im) 148 | Get animation header. To replace PIL's getheader()[0] 149 | """ 150 | bb = "GIF89a" 151 | bb += intToBin(im.size[0]) 152 | bb += intToBin(im.size[1]) 153 | bb += "\x87\x00\x00" 154 | return bb 155 | 156 | 157 | def getImageDescriptor(self, im, xy=None): 158 | """ getImageDescriptor(im, xy=None) 159 | Used for the local color table properties per image. 160 | Otherwise the global color table applies to all frames, irrespective of 161 | whether additional colors come into play that require a redefined 162 | palette. Still a maximum of 256 colors per frame, obviously. 163 | Written by Ant1 on 2010-08-22 164 | Modified by Alex Robinson in January 2011 to implement subrectangles. 165 | """ 166 | 167 | # Default: use full image and place at upper left 168 | if xy is None: 169 | xy = (0,0) 170 | 171 | # Image separator, 172 | bb = '\x2C' 173 | 174 | # Image position and size 175 | bb += intToBin( xy[0] ) # Left position 176 | bb += intToBin( xy[1] ) # Top position 177 | bb += intToBin( im.size[0] ) # image width 178 | bb += intToBin( im.size[1] ) # image height 179 | 180 | # packed field: local color table flag1, interlace0, sorted table0, 181 | # reserved00, lct size111=7=2^(7+1)=256. 182 | bb += '\x87' 183 | 184 | # LZW minimum size code now comes later, beginning of [image data] blocks 185 | return bb 186 | 187 | 188 | def getAppExt(self, loops=float('inf')): 189 | """ getAppExt(loops=float('inf')) 190 | Application extension. This part specifies the amount of loops. 191 | If loops is 0 or inf, it goes on infinitely. 
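        For example (worked out from the body below): in the infinite-loop
        case the bytes written are
        '\x21\xFF\x0B' + 'NETSCAPE2.0' + '\x03\x01' + intToBin(2**16-1) + '\x00'.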
192 | """ 193 | 194 | if loops==0 or loops==float('inf'): 195 | loops = 2**16-1 196 | #bb = "" # application extension should not be used 197 | # (the extension interprets zero loops 198 | # to mean an infinite number of loops) 199 | # Mmm, does not seem to work 200 | if True: 201 | bb = "\x21\xFF\x0B" # application extension 202 | bb += "NETSCAPE2.0" 203 | bb += "\x03\x01" 204 | bb += intToBin(loops) 205 | bb += '\x00' # end 206 | return bb 207 | 208 | 209 | def getGraphicsControlExt(self, duration=0.1, dispose=2): 210 | """ getGraphicsControlExt(duration=0.1, dispose=2) 211 | Graphics Control Extension. A sort of header at the start of 212 | each image. Specifies duration and transparancy. 213 | Dispose 214 | ------- 215 | * 0 - No disposal specified. 216 | * 1 - Do not dispose. The graphic is to be left in place. 217 | * 2 - Restore to background color. The area used by the graphic 218 | must be restored to the background color. 219 | * 3 - Restore to previous. The decoder is required to restore the 220 | area overwritten by the graphic with what was there prior to 221 | rendering the graphic. 222 | * 4-7 -To be defined. 223 | """ 224 | 225 | bb = '\x21\xF9\x04' 226 | bb += chr((dispose & 3) << 2) # low bit 1 == transparency, 227 | # 2nd bit 1 == user input , next 3 bits, the low two of which are used, 228 | # are dispose. 229 | bb += intToBin( int(duration*100) ) # in 100th of seconds 230 | bb += '\x00' # no transparant color 231 | bb += '\x00' # end 232 | return bb 233 | 234 | 235 | def handleSubRectangles(self, images, subRectangles): 236 | """ handleSubRectangles(images) 237 | Handle the sub-rectangle stuff. If the rectangles are given by the 238 | user, the values are checked. Otherwise the subrectangles are 239 | calculated automatically. 240 | """ 241 | 242 | if isinstance(subRectangles, (tuple,list)): 243 | # xy given directly 244 | 245 | # Check xy 246 | xy = subRectangles 247 | if xy is None: 248 | xy = (0,0) 249 | if hasattr(xy, '__len__'): 250 | if len(xy) == len(images): 251 | xy = [xxyy for xxyy in xy] 252 | else: 253 | raise ValueError("len(xy) doesn't match amount of images.") 254 | else: 255 | xy = [xy for im in images] 256 | xy[0] = (0,0) 257 | 258 | else: 259 | # Calculate xy using some basic image processing 260 | 261 | # Check Numpy 262 | if np is None: 263 | raise RuntimeError("Need Numpy to use auto-subRectangles.") 264 | 265 | # First make numpy arrays if required 266 | for i in range(len(images)): 267 | im = images[i] 268 | if isinstance(im, Image.Image): 269 | tmp = im.convert() # Make without palette 270 | a = np.asarray(tmp) 271 | if len(a.shape)==0: 272 | raise MemoryError("Too little memory to convert PIL image to array") 273 | images[i] = a 274 | 275 | # Determine the sub rectangles 276 | images, xy = self.getSubRectangles(images) 277 | 278 | # Done 279 | return images, xy 280 | 281 | 282 | def getSubRectangles(self, ims): 283 | """ getSubRectangles(ims) 284 | Calculate the minimal rectangles that need updating each frame. 285 | Returns a two-element tuple containing the cropped images and a 286 | list of x-y positions. 287 | Calculating the subrectangles takes extra time, obviously. However, 288 | if the image sizes were reduced, the actual writing of the GIF 289 | goes faster. In some cases applying this method produces a GIF faster. 290 | """ 291 | 292 | # Check image count 293 | if len(ims) < 2: 294 | return ims, [(0,0) for i in ims] 295 | 296 | # We need numpy 297 | if np is None: 298 | raise RuntimeError("Need Numpy to calculate sub-rectangles. 
") 299 | 300 | # Prepare 301 | ims2 = [ims[0]] 302 | xy = [(0,0)] 303 | t0 = time.time() 304 | 305 | # Iterate over images 306 | prev = ims[0] 307 | for im in ims[1:]: 308 | 309 | # Get difference, sum over colors 310 | diff = np.abs(im-prev) 311 | if diff.ndim==3: 312 | diff = diff.sum(2) 313 | # Get begin and end for both dimensions 314 | X = np.argwhere(diff.sum(0)) 315 | Y = np.argwhere(diff.sum(1)) 316 | # Get rect coordinates 317 | if X.size and Y.size: 318 | x0, x1 = X[0], X[-1]+1 319 | y0, y1 = Y[0], Y[-1]+1 320 | else: # No change ... make it minimal 321 | x0, x1 = 0, 2 322 | y0, y1 = 0, 2 323 | 324 | # Cut out and store 325 | im2 = im[y0:y1,x0:x1] 326 | prev = im 327 | ims2.append(im2) 328 | xy.append((x0,y0)) 329 | 330 | # Done 331 | #print('%1.2f seconds to determine subrectangles of %i images' % 332 | # (time.time()-t0, len(ims2)) ) 333 | return ims2, xy 334 | 335 | 336 | def convertImagesToPIL(self, images, dither, nq=0): 337 | """ convertImagesToPIL(images, nq=0) 338 | Convert images to Paletted PIL images, which can then be 339 | written to a single animaged GIF. 340 | """ 341 | 342 | # Convert to PIL images 343 | images2 = [] 344 | for im in images: 345 | if isinstance(im, Image.Image): 346 | images2.append(im) 347 | elif np and isinstance(im, np.ndarray): 348 | if im.ndim==3 and im.shape[2]==3: 349 | im = Image.fromarray(im,'RGB') 350 | elif im.ndim==3 and im.shape[2]==4: 351 | im = Image.fromarray(im[:,:,:3],'RGB') 352 | elif im.ndim==2: 353 | im = Image.fromarray(im,'L') 354 | images2.append(im) 355 | 356 | # Convert to paletted PIL images 357 | images, images2 = images2, [] 358 | if nq >= 1: 359 | # NeuQuant algorithm 360 | for im in images: 361 | im = im.convert("RGBA") # NQ assumes RGBA 362 | nqInstance = NeuQuant(im, int(nq)) # Learn colors from image 363 | if dither: 364 | im = im.convert("RGB").quantize(palette=nqInstance.paletteImage()) 365 | else: 366 | im = nqInstance.quantize(im) # Use to quantize the image itself 367 | images2.append(im) 368 | else: 369 | # Adaptive PIL algorithm 370 | AD = Image.ADAPTIVE 371 | for im in images: 372 | im = im.convert('P', palette=AD, dither=dither) 373 | images2.append(im) 374 | 375 | # Done 376 | return images2 377 | 378 | 379 | def writeGifToFile(self, fp, images, durations, loops, xys, disposes): 380 | """ writeGifToFile(fp, images, durations, loops, xys, disposes) 381 | Given a set of images writes the bytes to the specified stream. 
382 | """ 383 | 384 | # Obtain palette for all images and count each occurance 385 | palettes, occur = [], [] 386 | for im in images: 387 | #palette = getheader(im)[1] 388 | palette = getheader(im)[0][-1] 389 | if not palette: 390 | #palette = PIL.ImagePalette.ImageColor 391 | palette = im.palette.tobytes() 392 | palettes.append(palette) 393 | for palette in palettes: 394 | occur.append( palettes.count( palette ) ) 395 | 396 | # Select most-used palette as the global one (or first in case no max) 397 | globalPalette = palettes[ occur.index(max(occur)) ] 398 | 399 | # Init 400 | frames = 0 401 | firstFrame = True 402 | 403 | 404 | for im, palette in zip(images, palettes): 405 | 406 | if firstFrame: 407 | # Write header 408 | 409 | # Gather info 410 | header = self.getheaderAnim(im) 411 | appext = self.getAppExt(loops) 412 | 413 | # Write 414 | fp.write(encode(header)) 415 | fp.write(globalPalette) 416 | fp.write(encode(appext)) 417 | 418 | # Next frame is not the first 419 | firstFrame = False 420 | 421 | if True: 422 | # Write palette and image data 423 | 424 | # Gather info 425 | data = getdata(im) 426 | imdes, data = data[0], data[1:] 427 | graphext = self.getGraphicsControlExt(durations[frames], 428 | disposes[frames]) 429 | # Make image descriptor suitable for using 256 local color palette 430 | lid = self.getImageDescriptor(im, xys[frames]) 431 | 432 | # Write local header 433 | if (palette != globalPalette) or (disposes[frames] != 2): 434 | # Use local color palette 435 | fp.write(encode(graphext)) 436 | fp.write(encode(lid)) # write suitable image descriptor 437 | fp.write(palette) # write local color table 438 | fp.write(encode('\x08')) # LZW minimum size code 439 | else: 440 | # Use global color palette 441 | fp.write(encode(graphext)) 442 | fp.write(imdes) # write suitable image descriptor 443 | 444 | # Write image data 445 | for d in data: 446 | fp.write(d) 447 | 448 | # Prepare for next round 449 | frames = frames + 1 450 | 451 | fp.write(encode(";")) # end gif 452 | return frames 453 | 454 | 455 | 456 | 457 | ## Exposed functions 458 | 459 | def writeGif(filename, images, duration=0.1, repeat=True, dither=False, 460 | nq=0, subRectangles=True, dispose=None): 461 | """ writeGif(filename, images, duration=0.1, repeat=True, dither=False, 462 | nq=0, subRectangles=True, dispose=None) 463 | Write an animated gif from the specified images. 464 | Parameters 465 | ---------- 466 | filename : string 467 | The name of the file to write the image to. 468 | images : list 469 | Should be a list consisting of PIL images or numpy arrays. 470 | The latter should be between 0 and 255 for integer types, and 471 | between 0 and 1 for float types. 472 | duration : scalar or list of scalars 473 | The duration for all frames, or (if a list) for each frame. 474 | repeat : bool or integer 475 | The amount of loops. If True, loops infinitetely. 476 | dither : bool 477 | Whether to apply dithering 478 | nq : integer 479 | If nonzero, applies the NeuQuant quantization algorithm to create 480 | the color palette. This algorithm is superior, but slower than 481 | the standard PIL algorithm. The value of nq is the quality 482 | parameter. 1 represents the best quality. 10 is in general a 483 | good tradeoff between quality and speed. When using this option, 484 | better results are usually obtained when subRectangles is False. 485 | subRectangles : False, True, or a list of 2-element tuples 486 | Whether to use sub-rectangles. 
If True, the minimal rectangle that 487 | is required to update each frame is automatically detected. This 488 | can give significant reductions in file size, particularly if only 489 | a part of the image changes. One can also give a list of x-y 490 | coordinates if you want to do the cropping yourself. The default 491 | is True. 492 | dispose : int 493 | How to dispose each frame. 1 means that each frame is to be left 494 | in place. 2 means the background color should be restored after 495 | each frame. 3 means the decoder should restore the previous frame. 496 | If subRectangles==False, the default is 2, otherwise it is 1. 497 | """ 498 | 499 | # Check PIL 500 | if PIL is None: 501 | raise RuntimeError("Need PIL to write animated gif files.") 502 | 503 | # Check images 504 | images = checkImages(images) 505 | 506 | # Instantiate writer object 507 | gifWriter = GifWriter() 508 | 509 | # Check loops 510 | if repeat is False: 511 | loops = 1 512 | elif repeat is True: 513 | loops = 0 # zero means infinite 514 | else: 515 | loops = int(repeat) 516 | 517 | # Check duration 518 | if hasattr(duration, '__len__'): 519 | if len(duration) == len(images): 520 | duration = [d for d in duration] 521 | else: 522 | raise ValueError("len(duration) doesn't match amount of images.") 523 | else: 524 | duration = [duration for im in images] 525 | 526 | # Check subrectangles 527 | if subRectangles: 528 | images, xy = gifWriter.handleSubRectangles(images, subRectangles) 529 | defaultDispose = 1 # Leave image in place 530 | else: 531 | # Normal mode 532 | xy = [(0,0) for im in images] 533 | defaultDispose = 2 # Restore to background color. 534 | 535 | # Check dispose 536 | if dispose is None: 537 | dispose = defaultDispose 538 | if hasattr(dispose, '__len__'): 539 | if len(dispose) != len(images): 540 | raise ValueError("len(dispose) doesn't match amount of images.") 541 | else: 542 | dispose = [dispose for im in images] 543 | 544 | 545 | # Make images in a format that we can write easily 546 | images = gifWriter.convertImagesToPIL(images, dither, nq) 547 | 548 | # Write 549 | fp = open(filename, 'wb') 550 | try: 551 | gifWriter.writeGifToFile(fp, images, duration, loops, xy, dispose) 552 | finally: 553 | fp.close() 554 | 555 | 556 | 557 | def readGif(filename, asNumpy=True): 558 | """ readGif(filename, asNumpy=True) 559 | Read images from an animated GIF file. Returns a list of numpy 560 | arrays, or, if asNumpy is false, a list of PIL images. 
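    Example (illustrative; 'movie.gif' is a placeholder filename):
        frames = readGif('movie.gif', asNumpy=True)   # list of uint8 arrays
        pils = readGif('movie.gif', asNumpy=False)    # list of PIL images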
561 | """ 562 | 563 | # Check PIL 564 | if PIL is None: 565 | raise RuntimeError("Need PIL to read animated gif files.") 566 | 567 | # Check Numpy 568 | if np is None: 569 | raise RuntimeError("Need Numpy to read animated gif files.") 570 | 571 | # Check whether it exists 572 | if not os.path.isfile(filename): 573 | raise IOError('File not found: '+str(filename)) 574 | 575 | # Load file using PIL 576 | pilIm = PIL.Image.open(filename) 577 | pilIm.seek(0) 578 | 579 | # Read all images inside 580 | images = [] 581 | try: 582 | while True: 583 | # Get image as numpy array 584 | tmp = pilIm.convert() # Make without palette 585 | a = np.asarray(tmp) 586 | if len(a.shape)==0: 587 | raise MemoryError("Too little memory to convert PIL image to array") 588 | # Store, and next 589 | images.append(a) 590 | pilIm.seek(pilIm.tell()+1) 591 | except EOFError: 592 | pass 593 | 594 | # Convert to normal PIL images if needed 595 | if not asNumpy: 596 | images2 = images 597 | images = [] 598 | for im in images2: 599 | images.append( PIL.Image.fromarray(im) ) 600 | 601 | # Done 602 | return images 603 | 604 | 605 | class NeuQuant: 606 | """ NeuQuant(image, samplefac=10, colors=256) 607 | samplefac should be an integer number of 1 or higher, 1 608 | being the highest quality, but the slowest performance. 609 | With avalue of 10, one tenth of all pixels are used during 610 | training. This value seems a nice tradeof between speed 611 | and quality. 612 | colors is the amount of colors to reduce the image to. This 613 | should best be a power of two. 614 | See also: 615 | http://members.ozemail.com.au/~dekker/NEUQUANT.HTML 616 | License of the NeuQuant Neural-Net Quantization Algorithm 617 | --------------------------------------------------------- 618 | Copyright (c) 1994 Anthony Dekker 619 | Ported to python by Marius van Voorden in 2010 620 | NEUQUANT Neural-Net quantization algorithm by Anthony Dekker, 1994. 621 | See "Kohonen neural networks for optimal colour quantization" 622 | in "network: Computation in Neural Systems" Vol. 5 (1994) pp 351-367. 623 | for a discussion of the algorithm. 624 | See also http://members.ozemail.com.au/~dekker/NEUQUANT.HTML 625 | Any party obtaining a copy of these files from the author, directly or 626 | indirectly, is granted, free of charge, a full and unrestricted irrevocable, 627 | world-wide, paid up, royalty-free, nonexclusive right and license to deal 628 | in this software and documentation files (the "Software"), including without 629 | limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 630 | and/or sell copies of the Software, and to permit persons who receive 631 | copies from any such party to do so, with the only requirement being 632 | that this copyright notice remain intact. 
633 | """ 634 | 635 | NCYCLES = None # Number of learning cycles 636 | NETSIZE = None # Number of colours used 637 | SPECIALS = None # Number of reserved colours used 638 | BGCOLOR = None # Reserved background colour 639 | CUTNETSIZE = None 640 | MAXNETPOS = None 641 | 642 | INITRAD = None # For 256 colours, radius starts at 32 643 | RADIUSBIASSHIFT = None 644 | RADIUSBIAS = None 645 | INITBIASRADIUS = None 646 | RADIUSDEC = None # Factor of 1/30 each cycle 647 | 648 | ALPHABIASSHIFT = None 649 | INITALPHA = None # biased by 10 bits 650 | 651 | GAMMA = None 652 | BETA = None 653 | BETAGAMMA = None 654 | 655 | network = None # The network itself 656 | colormap = None # The network itself 657 | 658 | netindex = None # For network lookup - really 256 659 | 660 | bias = None # Bias and freq arrays for learning 661 | freq = None 662 | 663 | pimage = None 664 | 665 | # Four primes near 500 - assume no image has a length so large 666 | # that it is divisible by all four primes 667 | PRIME1 = 499 668 | PRIME2 = 491 669 | PRIME3 = 487 670 | PRIME4 = 503 671 | MAXPRIME = PRIME4 672 | 673 | pixels = None 674 | samplefac = None 675 | 676 | a_s = None 677 | 678 | 679 | def setconstants(self, samplefac, colors): 680 | self.NCYCLES = 100 # Number of learning cycles 681 | self.NETSIZE = colors # Number of colours used 682 | self.SPECIALS = 3 # Number of reserved colours used 683 | self.BGCOLOR = self.SPECIALS-1 # Reserved background colour 684 | self.CUTNETSIZE = self.NETSIZE - self.SPECIALS 685 | self.MAXNETPOS = self.NETSIZE - 1 686 | 687 | self.INITRAD = self.NETSIZE/8 # For 256 colours, radius starts at 32 688 | self.RADIUSBIASSHIFT = 6 689 | self.RADIUSBIAS = 1 << self.RADIUSBIASSHIFT 690 | self.INITBIASRADIUS = self.INITRAD * self.RADIUSBIAS 691 | self.RADIUSDEC = 30 # Factor of 1/30 each cycle 692 | 693 | self.ALPHABIASSHIFT = 10 # Alpha starts at 1 694 | self.INITALPHA = 1 << self.ALPHABIASSHIFT # biased by 10 bits 695 | 696 | self.GAMMA = 1024.0 697 | self.BETA = 1.0/1024.0 698 | self.BETAGAMMA = self.BETA * self.GAMMA 699 | 700 | self.network = np.empty((self.NETSIZE, 3), dtype='float64') # The network itself 701 | self.colormap = np.empty((self.NETSIZE, 4), dtype='int32') # The network itself 702 | 703 | self.netindex = np.empty(256, dtype='int32') # For network lookup - really 256 704 | 705 | self.bias = np.empty(self.NETSIZE, dtype='float64') # Bias and freq arrays for learning 706 | self.freq = np.empty(self.NETSIZE, dtype='float64') 707 | 708 | self.pixels = None 709 | self.samplefac = samplefac 710 | 711 | self.a_s = {} 712 | 713 | def __init__(self, image, samplefac=10, colors=256): 714 | 715 | # Check Numpy 716 | if np is None: 717 | raise RuntimeError("Need Numpy for the NeuQuant algorithm.") 718 | 719 | # Check image 720 | if image.size[0] * image.size[1] < NeuQuant.MAXPRIME: 721 | raise IOError("Image is too small") 722 | if image.mode != "RGBA": 723 | raise IOError("Image mode should be RGBA.") 724 | 725 | # Initialize 726 | self.setconstants(samplefac, colors) 727 | self.pixels = np.fromstring(image.tostring(), np.uint32) 728 | self.setUpArrays() 729 | 730 | self.learn() 731 | self.fix() 732 | self.inxbuild() 733 | 734 | def writeColourMap(self, rgb, outstream): 735 | for i in range(self.NETSIZE): 736 | bb = self.colormap[i,0]; 737 | gg = self.colormap[i,1]; 738 | rr = self.colormap[i,2]; 739 | outstream.write(rr if rgb else bb) 740 | outstream.write(gg) 741 | outstream.write(bb if rgb else rr) 742 | return self.NETSIZE 743 | 744 | def setUpArrays(self): 745 | self.network[0,0] = 
0.0 # Black 746 | self.network[0,1] = 0.0 747 | self.network[0,2] = 0.0 748 | 749 | self.network[1,0] = 255.0 # White 750 | self.network[1,1] = 255.0 751 | self.network[1,2] = 255.0 752 | 753 | # RESERVED self.BGCOLOR # Background 754 | 755 | for i in range(self.SPECIALS): 756 | self.freq[i] = 1.0 / self.NETSIZE 757 | self.bias[i] = 0.0 758 | 759 | for i in range(self.SPECIALS, self.NETSIZE): 760 | p = self.network[i] 761 | p[:] = (255.0 * (i-self.SPECIALS)) / self.CUTNETSIZE 762 | 763 | self.freq[i] = 1.0 / self.NETSIZE 764 | self.bias[i] = 0.0 765 | 766 | # Omitted: setPixels 767 | 768 | def altersingle(self, alpha, i, b, g, r): 769 | """Move neuron i towards biased (b,g,r) by factor alpha""" 770 | n = self.network[i] # Alter hit neuron 771 | n[0] -= (alpha*(n[0] - b)) 772 | n[1] -= (alpha*(n[1] - g)) 773 | n[2] -= (alpha*(n[2] - r)) 774 | 775 | def geta(self, alpha, rad): 776 | try: 777 | return self.a_s[(alpha, rad)] 778 | except KeyError: 779 | length = rad*2-1 780 | mid = int(length//2) 781 | q = np.array(list(range(mid-1,-1,-1))+list(range(-1,mid))) 782 | a = alpha*(rad*rad - q*q)/(rad*rad) 783 | a[mid] = 0 784 | self.a_s[(alpha, rad)] = a 785 | return a 786 | 787 | def alterneigh(self, alpha, rad, i, b, g, r): 788 | if i-rad >= self.SPECIALS-1: 789 | lo = i-rad 790 | start = 0 791 | else: 792 | lo = self.SPECIALS-1 793 | start = (self.SPECIALS-1 - (i-rad)) 794 | 795 | if i+rad <= self.NETSIZE: 796 | hi = i+rad 797 | end = rad*2-1 798 | else: 799 | hi = self.NETSIZE 800 | end = (self.NETSIZE - (i+rad)) 801 | 802 | a = self.geta(alpha, rad)[start:end] 803 | 804 | p = self.network[lo+1:hi] 805 | p -= np.transpose(np.transpose(p - np.array([b, g, r])) * a) 806 | 807 | #def contest(self, b, g, r): 808 | # """ Search for biased BGR values 809 | # Finds closest neuron (min dist) and updates self.freq 810 | # finds best neuron (min dist-self.bias) and returns position 811 | # for frequently chosen neurons, self.freq[i] is high and self.bias[i] is negative 812 | # self.bias[i] = self.GAMMA*((1/self.NETSIZE)-self.freq[i])""" 813 | # 814 | # i, j = self.SPECIALS, self.NETSIZE 815 | # dists = abs(self.network[i:j] - np.array([b,g,r])).sum(1) 816 | # bestpos = i + np.argmin(dists) 817 | # biasdists = dists - self.bias[i:j] 818 | # bestbiaspos = i + np.argmin(biasdists) 819 | # self.freq[i:j] -= self.BETA * self.freq[i:j] 820 | # self.bias[i:j] += self.BETAGAMMA * self.freq[i:j] 821 | # self.freq[bestpos] += self.BETA 822 | # self.bias[bestpos] -= self.BETAGAMMA 823 | # return bestbiaspos 824 | def contest(self, b, g, r): 825 | """ Search for biased BGR values 826 | Finds closest neuron (min dist) and updates self.freq 827 | finds best neuron (min dist-self.bias) and returns position 828 | for frequently chosen neurons, self.freq[i] is high and self.bias[i] is negative 829 | self.bias[i] = self.GAMMA*((1/self.NETSIZE)-self.freq[i])""" 830 | i, j = self.SPECIALS, self.NETSIZE 831 | dists = abs(self.network[i:j] - np.array([b,g,r])).sum(1) 832 | bestpos = i + np.argmin(dists) 833 | biasdists = dists - self.bias[i:j] 834 | bestbiaspos = i + np.argmin(biasdists) 835 | self.freq[i:j] *= (1-self.BETA) 836 | self.bias[i:j] += self.BETAGAMMA * self.freq[i:j] 837 | self.freq[bestpos] += self.BETA 838 | self.bias[bestpos] -= self.BETAGAMMA 839 | return bestbiaspos 840 | 841 | 842 | 843 | 844 | def specialFind(self, b, g, r): 845 | for i in range(self.SPECIALS): 846 | n = self.network[i] 847 | if n[0] == b and n[1] == g and n[2] == r: 848 | return i 849 | return -1 850 | 851 | def learn(self): 852 | 
biasRadius = self.INITBIASRADIUS 853 | alphadec = 30 + ((self.samplefac-1)/3) 854 | lengthcount = self.pixels.size 855 | samplepixels = lengthcount / self.samplefac 856 | delta = samplepixels / self.NCYCLES 857 | alpha = self.INITALPHA 858 | 859 | i = 0; 860 | rad = biasRadius * 2**self.RADIUSBIASSHIFT 861 | if rad <= 1: 862 | rad = 0 863 | 864 | print("Beginning 1D learning: samplepixels = %1.2f rad = %i" % 865 | (samplepixels, rad) ) 866 | step = 0 867 | pos = 0 868 | if lengthcount%NeuQuant.PRIME1 != 0: 869 | step = NeuQuant.PRIME1 870 | elif lengthcount%NeuQuant.PRIME2 != 0: 871 | step = NeuQuant.PRIME2 872 | elif lengthcount%NeuQuant.PRIME3 != 0: 873 | step = NeuQuant.PRIME3 874 | else: 875 | step = NeuQuant.PRIME4 876 | 877 | i = 0 878 | printed_string = '' 879 | while i < samplepixels: 880 | if i%100 == 99: 881 | tmp = '\b'*len(printed_string) 882 | printed_string = str((i+1)*100/samplepixels)+"%\n" 883 | print(tmp + printed_string) 884 | p = self.pixels[pos] 885 | r = (p >> 16) & 0xff 886 | g = (p >> 8) & 0xff 887 | b = (p ) & 0xff 888 | 889 | if i == 0: # Remember background colour 890 | self.network[self.BGCOLOR] = [b, g, r] 891 | 892 | j = self.specialFind(b, g, r) 893 | if j < 0: 894 | j = self.contest(b, g, r) 895 | 896 | if j >= self.SPECIALS: # Don't learn for specials 897 | a = (1.0 * alpha) / self.INITALPHA 898 | self.altersingle(a, j, b, g, r) 899 | if rad > 0: 900 | self.alterneigh(a, rad, j, b, g, r) 901 | 902 | pos = (pos+step)%lengthcount 903 | 904 | i += 1 905 | if i%delta == 0: 906 | alpha -= alpha / alphadec 907 | biasRadius -= biasRadius / self.RADIUSDEC 908 | rad = biasRadius * 2**self.RADIUSBIASSHIFT 909 | if rad <= 1: 910 | rad = 0 911 | 912 | finalAlpha = (1.0*alpha)/self.INITALPHA 913 | print("Finished 1D learning: final alpha = %1.2f!" % finalAlpha) 914 | 915 | def fix(self): 916 | for i in range(self.NETSIZE): 917 | for j in range(3): 918 | x = int(0.5 + self.network[i,j]) 919 | x = max(0, x) 920 | x = min(255, x) 921 | self.colormap[i,j] = x 922 | self.colormap[i,3] = i 923 | 924 | def inxbuild(self): 925 | previouscol = 0 926 | startpos = 0 927 | for i in range(self.NETSIZE): 928 | p = self.colormap[i] 929 | q = None 930 | smallpos = i 931 | smallval = p[1] # Index on g 932 | # Find smallest in i..self.NETSIZE-1 933 | for j in range(i+1, self.NETSIZE): 934 | q = self.colormap[j] 935 | if q[1] < smallval: # Index on g 936 | smallpos = j 937 | smallval = q[1] # Index on g 938 | 939 | q = self.colormap[smallpos] 940 | # Swap p (i) and q (smallpos) entries 941 | if i != smallpos: 942 | p[:],q[:] = q, p.copy() 943 | 944 | # smallval entry is now in position i 945 | if smallval != previouscol: 946 | self.netindex[previouscol] = (startpos+i) >> 1 947 | for j in range(previouscol+1, smallval): 948 | self.netindex[j] = i 949 | previouscol = smallval 950 | startpos = i 951 | self.netindex[previouscol] = (startpos+self.MAXNETPOS) >> 1 952 | for j in range(previouscol+1, 256): # Really 256 953 | self.netindex[j] = self.MAXNETPOS 954 | 955 | 956 | def paletteImage(self): 957 | """ PIL weird interface for making a paletted image: create an image which 958 | already has the palette, and use that in Image.quantize. This function 959 | returns this palette image. 
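        Illustrative call (as done in quantize_with_scipy below, where
        someImage is a placeholder):
            someImage.convert("RGB").quantize(palette=self.paletteImage())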
""" 960 | if self.pimage is None: 961 | palette = [] 962 | for i in range(self.NETSIZE): 963 | palette.extend(self.colormap[i][:3]) 964 | 965 | palette.extend([0]*(256-self.NETSIZE)*3) 966 | 967 | # a palette image to use for quant 968 | self.pimage = Image.new("P", (1, 1), 0) 969 | self.pimage.putpalette(palette) 970 | return self.pimage 971 | 972 | 973 | def quantize(self, image): 974 | """ Use a kdtree to quickly find the closest palette colors for the pixels """ 975 | if get_cKDTree(): 976 | return self.quantize_with_scipy(image) 977 | else: 978 | print('Scipy not available, falling back to slower version.') 979 | return self.quantize_without_scipy(image) 980 | 981 | 982 | def quantize_with_scipy(self, image): 983 | w,h = image.size 984 | px = np.asarray(image).copy() 985 | px2 = px[:,:,:3].reshape((w*h,3)) 986 | 987 | cKDTree = get_cKDTree() 988 | kdtree = cKDTree(self.colormap[:,:3],leafsize=10) 989 | result = kdtree.query(px2) 990 | colorindex = result[1] 991 | print("Distance: %1.2f" % (result[0].sum()/(w*h)) ) 992 | px2[:] = self.colormap[colorindex,:3] 993 | 994 | return Image.fromarray(px).convert("RGB").quantize(palette=self.paletteImage()) 995 | 996 | 997 | def quantize_without_scipy(self, image): 998 | """" This function can be used if no scipy is availabe. 999 | It's 7 times slower though. 1000 | """ 1001 | w,h = image.size 1002 | px = np.asarray(image).copy() 1003 | memo = {} 1004 | for j in range(w): 1005 | for i in range(h): 1006 | key = (px[i,j,0],px[i,j,1],px[i,j,2]) 1007 | try: 1008 | val = memo[key] 1009 | except KeyError: 1010 | val = self.convert(*key) 1011 | memo[key] = val 1012 | px[i,j,0],px[i,j,1],px[i,j,2] = val 1013 | return Image.fromarray(px).convert("RGB").quantize(palette=self.paletteImage()) 1014 | 1015 | def convert(self, *color): 1016 | i = self.inxsearch(*color) 1017 | return self.colormap[i,:3] 1018 | 1019 | def inxsearch(self, r, g, b): 1020 | """Search for BGR values 0..255 and return colour index""" 1021 | dists = (self.colormap[:,:3] - np.array([r,g,b])) 1022 | a= np.argmin((dists*dists).sum(1)) 1023 | return a 1024 | 1025 | 1026 | 1027 | if __name__ == '__main__': 1028 | im = np.zeros((200,200), dtype=np.uint8) 1029 | im[10:30,:] = 100 1030 | im[:,80:120] = 255 1031 | im[-50:-40,:] = 50 1032 | 1033 | images = [im*1.0, im*0.8, im*0.6, im*0.4, im*0] 1034 | writeGif('lala3.gif',images, duration=0.5, dither=0) 1035 | --------------------------------------------------------------------------------