├── .dockerignore
├── .gitignore
├── 02_01_deep_learning_deep_neural_network.ipynb
├── 02_02_deep_learning_convolutions.ipynb
├── 02_03_deep_learning_conv_neural_network.ipynb
├── 03_01_autoencoder_train.ipynb
├── 03_02_autoencoder_analysis.ipynb
├── 03_03_vae_digits_train.ipynb
├── 03_04_vae_digits_analysis.ipynb
├── 03_05_vae_faces_train.ipynb
├── 03_06_vae_faces_analysis.ipynb
├── 04_01_gan_camel_train.ipynb
├── 04_02_wgan_cifar_train.ipynb
├── 04_03_wgangp_faces_train.ipynb
├── 05_01_cyclegan_train.ipynb
├── 06_01_lstm_text_train.ipynb
├── 06_02_qa_train.ipynb
├── 06_03_qa_analysis.ipynb
├── 07_01_notation_compose.ipynb
├── 07_02_lstm_compose_train.ipynb
├── 07_03_lstm_compose_predict.ipynb
├── 07_04_musegan_train.ipynb
├── 07_05_musegan_analysis.ipynb
├── 09_01_positional_encoding.ipynb
├── Dockerfile.cpu
├── Dockerfile.gpu
├── LICENSE
├── README.md
├── colab
│   ├── 03_01_02_autoencoder.ipynb
│   ├── 03_03_04_vae_digits.ipynb
│   ├── 04_01_gan_camel.ipynb
│   ├── 04_02_wgan_cifar.ipynb
│   ├── 05_01_cyclegan_train.ipynb
│   ├── 06_01_lstm_text_train.ipynb
│   └── README.md
├── data
│   ├── .gitignore
│   └── qa_test
│       └── my_test.csv
├── launch-docker-cpu.sh
├── launch-docker-gpu.sh
├── models
│   ├── AE.py
│   ├── GAN.py
│   ├── MuseGAN.py
│   ├── RNNAttention.py
│   ├── VAE.py
│   ├── WGAN.py
│   ├── WGANGP.py
│   ├── cycleGAN.py
│   └── layers
│       └── layers.py
├── requirements.txt
├── run
│   ├── compose
│   │   └── .gitignore
│   ├── gan
│   │   └── .gitignore
│   ├── paint
│   │   └── .gitignore
│   ├── vae
│   │   └── .gitignore
│   └── write
│       └── .gitignore
├── scripts
│   ├── download_cyclegan_data.sh
│   └── download_gutenburg_data.sh
└── utils
    ├── callbacks.py
    ├── loaders.py
    └── write.py

/.dockerignore:
--------------------------------------------------------------------------------
1 | data/
2 | run/

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | Dockerfile
2 |
3 | run/gan/*
4 | run/compose/*
5 | run/paint/*
6 | run/write/*
7 | run/vae/*
8 |
9 |
10 |
11 | archive/*
12 |
13 | models/archive/*
14 |
15 |
16 | .DS_Store
17 |
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | .hypothesis/
65 | .pytest_cache/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # pyenv
93 | .python-version
94 |
95 | # celery beat schedule file
96 | celerybeat-schedule
97 |
98 | # SageMath parsed files
99 | *.sage.py
100 |
101 | # Environments
102 | .env
103 | .venv
104 | env/
105 | venv/
106 | ENV/
107 | env.bak/
108 | venv.bak/
109 |
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 |
114 | # Rope project settings
115 | .ropeproject
116 |
117 | # mkdocs documentation
118 | /site
119 |
120 | # mypy
121 | .mypy_cache/
122 |

--------------------------------------------------------------------------------
/02_01_deep_learning_deep_neural_network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Your first deep neural network"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "# imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import numpy as np\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "\n",
26 | "from tensorflow.keras.layers import Input, Flatten, Dense, Conv2D\n",
27 | "from tensorflow.keras.models import Model\n",
28 | "from tensorflow.keras.optimizers import Adam\n",
29 | "from tensorflow.keras.utils import to_categorical\n",
30 | "\n",
31 | "from tensorflow.keras.datasets import cifar10"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "# data"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "NUM_CLASSES = 10"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "(x_train, y_train), (x_test, y_test) = cifar10.load_data()"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "x_train = x_train.astype('float32') / 255.0\n",
66 | "x_test = x_test.astype('float32') / 255.0\n",
67 | "\n",
68 | "y_train = to_categorical(y_train, NUM_CLASSES)\n",
69 | "y_test = to_categorical(y_test, NUM_CLASSES)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "x_train[54, 12, 13, 1] "
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "# architecture"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "input_layer = Input((32,32,3))\n",
95 | "\n",
96 | "x = Flatten()(input_layer)\n",
97 | "\n",
98 | "x = Dense(200, activation = 'relu')(x)\n",
99 | "x = Dense(150, activation = 'relu')(x)\n",
100 | "\n",
101 | "output_layer = Dense(NUM_CLASSES, activation = 'softmax')(x)\n",
102 | "\n",
103 | "model = Model(input_layer, output_layer)\n",
104 | "\n",
105 | "model.summary()"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "# train"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "opt = Adam(lr=0.0005)\n",
122 | "model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "model.fit(x_train\n",
132 | " , y_train\n",
133 | " , batch_size=32\n",
134 | " , epochs=10\n",
135 | " , shuffle=True)"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {},
141 | "source": [
142 | "# analysis"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "model.evaluate(x_test, y_test)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "CLASSES = np.array(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])\n",
161 | "\n",
162 | "preds = model.predict(x_test)\n",
163 | "preds_single = CLASSES[np.argmax(preds, axis = -1)]\n",
164 | "actual_single = CLASSES[np.argmax(y_test, axis = -1)]"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {
171 | "scrolled": true
172 | },
173 | "outputs": [],
174 | "source": [
175 | "\n",
176 | "n_to_show = 10\n",
177 | "indices = np.random.choice(range(len(x_test)), n_to_show)\n",
178 | "\n",
179 | "fig = plt.figure(figsize=(15, 3))\n",
180 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
181 | "\n",
182 | "for i, idx in enumerate(indices):\n",
183 | " img = x_test[idx]\n",
184 | " ax = fig.add_subplot(1, n_to_show, i+1)\n",
185 | " ax.axis('off')\n",
186 | " ax.text(0.5, -0.35, 'pred = ' + str(preds_single[idx]), fontsize=10, ha='center', transform=ax.transAxes) \n",
187 | " ax.text(0.5, -0.7, 'act = ' + str(actual_single[idx]), fontsize=10, ha='center', transform=ax.transAxes)\n",
188 | " ax.imshow(img)\n"
189 | ]
190 | }
191 | ],
192 | "metadata": {
193 | "kernelspec": {
194 | "display_name": "gdl_code",
195 | "language": "python",
196 | "name": "gdl_code"
197 | },
198 | "language_info": {
199 | "codemirror_mode": {
200 | "name": "ipython",
201 | "version": 3
202 | },
203 | "file_extension": ".py",
204 | "mimetype": "text/x-python",
205 | "name": "python",
206 | "nbconvert_exporter": "python",
207 | "pygments_lexer": "ipython3",
208 | "version": "3.7.5"
209 | }
210 | },
211 | "nbformat": 4,
212 | "nbformat_minor": 2
213 | }
214 |

--------------------------------------------------------------------------------
/02_02_deep_learning_convolutions.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# imports"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%matplotlib inline\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "from scipy.ndimage import correlate\n",
19 | "import numpy as np\n",
20 | "from skimage import data\n", 21 | "from skimage.color import rgb2gray\n", 22 | "from skimage.transform import rescale,resize" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# original image input" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "im = rgb2gray(data.coffee())\n", 39 | "im = resize(im, (64,64))\n", 40 | "print(im.shape)\n", 41 | "\n", 42 | "plt.axis('off')\n", 43 | "plt.imshow(im, cmap = 'gray');\n", 44 | "\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# horizontal edge filter" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "filter1 = np.array([\n", 61 | " [ 1, 1, 1],\n", 62 | " [ 0, 0, 0],\n", 63 | " [-1, -1, -1]\n", 64 | "])\n", 65 | "\n", 66 | "new_image = np.zeros(im.shape)\n", 67 | "\n", 68 | "im_pad = np.pad(im, 1, 'constant')\n", 69 | "\n", 70 | "for i in range(im.shape[0]):\n", 71 | " for j in range(im.shape[1]):\n", 72 | " try:\n", 73 | " new_image[i,j] = \\\n", 74 | " im_pad[i-1,j-1] * filter1[0,0] + \\\n", 75 | " im_pad[i-1,j] * filter1[0,1] + \\\n", 76 | " im_pad[i-1,j+1] * filter1[0,2] + \\\n", 77 | " im_pad[i,j-1] * filter1[1,0] + \\\n", 78 | " im_pad[i,j] * filter1[1,1] + \\\n", 79 | " im_pad[i,j+1] * filter1[1,2] +\\\n", 80 | " im_pad[i+1,j-1] * filter1[2,0] + \\\n", 81 | " im_pad[i+1,j] * filter1[2,1] + \\\n", 82 | " im_pad[i+1,j+1] * filter1[2,2] \n", 83 | " except:\n", 84 | " pass\n", 85 | "\n", 86 | "plt.axis('off')\n", 87 | "plt.imshow(new_image, cmap='Greys');" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "# vertical edge filter" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "filter2 = np.array([\n", 104 | " [ -1, 0, 1],\n", 105 | " [ -1, 0, 1],\n", 106 | " [ -1, 0, 1]\n", 107 | "])\n", 108 | "\n", 109 | "new_image = np.zeros(im.shape)\n", 110 | "\n", 111 | "im_pad = np.pad(im,1, 'constant')\n", 112 | "\n", 113 | "for i in range(im.shape[0]):\n", 114 | " for j in range(im.shape[1]):\n", 115 | " try:\n", 116 | " new_image[i,j] = \\\n", 117 | " im_pad[i-1,j-1] * filter2[0,0] + \\\n", 118 | " im_pad[i-1,j] * filter2[0,1] + \\\n", 119 | " im_pad[i-1,j+1] * filter2[0,2] + \\\n", 120 | " im_pad[i,j-1] * filter2[1,0] + \\\n", 121 | " im_pad[i,j] * filter2[1,1] + \\\n", 122 | " im_pad[i,j+1] * filter2[1,2] +\\\n", 123 | " im_pad[i+1,j-1] * filter2[2,0] + \\\n", 124 | " im_pad[i+1,j] * filter2[2,1] + \\\n", 125 | " im_pad[i+1,j+1] * filter2[2,2] \n", 126 | " except:\n", 127 | " pass\n", 128 | "\n", 129 | "plt.axis('off')\n", 130 | "plt.imshow(new_image, cmap='Greys');" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "# horizontal edge filter with stride 2" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "filter1 = np.array([\n", 147 | " [ 1, 1, 1],\n", 148 | " [ 0, 0, 0],\n", 149 | " [-1, -1, -1]\n", 150 | "])\n", 151 | "\n", 152 | "stride = 2\n", 153 | "\n", 154 | "new_image = np.zeros((int(im.shape[0] / stride), int(im.shape[1] / stride)))\n", 155 | "\n", 156 | "im_pad = np.pad(im,1, 'constant')\n", 157 | "\n", 158 | "for i in range(0,im.shape[0],stride):\n", 159 | " for j 
in range(0,im.shape[1],stride):\n", 160 | " try:\n", 161 | " new_image[int(i/stride),int(j/stride)] = \\\n", 162 | " im_pad[i-1,j-1] * filter1[0,0] + \\\n", 163 | " im_pad[i-1,j] * filter1[0,1] + \\\n", 164 | " im_pad[i-1,j+1] * filter1[0,2] + \\\n", 165 | " im_pad[i,j-1] * filter1[1,0] + \\\n", 166 | " im_pad[i,j] * filter1[1,1] + \\\n", 167 | " im_pad[i,j+1] * filter1[1,2] +\\\n", 168 | " im_pad[i+1,j-1] * filter1[2,0] + \\\n", 169 | " im_pad[i+1,j] * filter1[2,1] + \\\n", 170 | " im_pad[i+1,j+1] * filter1[2,2] \n", 171 | " except:\n", 172 | " pass\n", 173 | "\n", 174 | "plt.axis('off')\n", 175 | "plt.imshow(new_image, cmap='Greys');" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "# vertical edge filter with stride 2" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "filter2 = np.array([\n", 192 | " [ -1, 0, 1],\n", 193 | " [ -1, 0, 1],\n", 194 | " [ -1, 0, 1]\n", 195 | "])\n", 196 | "\n", 197 | "stride = 2\n", 198 | "\n", 199 | "new_image = np.zeros((int(im.shape[0] / stride), int(im.shape[1] / stride)))\n", 200 | "\n", 201 | "im_pad = np.pad(im,1, 'constant')\n", 202 | "\n", 203 | "for i in range(0,im.shape[0],stride):\n", 204 | " for j in range(0,im.shape[1],stride):\n", 205 | " try:\n", 206 | " new_image[int(i/stride),int(j/stride)] = \\\n", 207 | " im_pad[i-1,j-1] * filter2[0,0] + \\\n", 208 | " im_pad[i-1,j] * filter2[0,1] + \\\n", 209 | " im_pad[i-1,j+1] * filter2[0,2] + \\\n", 210 | " im_pad[i,j-1] * filter2[1,0] + \\\n", 211 | " im_pad[i,j] * filter2[1,1] + \\\n", 212 | " im_pad[i,j+1] * filter2[1,2] +\\\n", 213 | " im_pad[i+1,j-1] * filter2[2,0] + \\\n", 214 | " im_pad[i+1,j] * filter2[2,1] + \\\n", 215 | " im_pad[i+1,j+1] * filter2[2,2] \n", 216 | " except:\n", 217 | " pass\n", 218 | "\n", 219 | "plt.axis('off')\n", 220 | "plt.imshow(new_image, cmap='Greys');" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "gdl_code", 227 | "language": "python", 228 | "name": "gdl_code" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.7.5" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 2 245 | } 246 | -------------------------------------------------------------------------------- /02_03_deep_learning_conv_neural_network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Your first convolutional neural network" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "\n", 25 | "from tensorflow.keras.layers import Input, Flatten, Dense, Conv2D, BatchNormalization, LeakyReLU, Dropout, Activation\n", 26 | "from tensorflow.keras.models import Model\n", 27 | "from tensorflow.keras.optimizers import Adam\n", 28 | "from tensorflow.keras.utils import to_categorical\n", 29 | "import tensorflow.keras.backend as K \n", 30 | "\n", 31 | "from tensorflow.keras.datasets import cifar10" 32 | ] 33 | }, 34 | 
{ 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "NUM_CLASSES = 10" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "(x_train, y_train), (x_test, y_test) = cifar10.load_data()" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "x_train = x_train.astype('float32') / 255.0\n", 66 | "x_test = x_test.astype('float32') / 255.0\n", 67 | "\n", 68 | "y_train = to_categorical(y_train, NUM_CLASSES)\n", 69 | "y_test = to_categorical(y_test, NUM_CLASSES)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "x_train[54, 12, 13, 1] " 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "# architecture" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "input_layer = Input(shape=(32,32,3))\n", 95 | "\n", 96 | "conv_layer_1 = Conv2D(\n", 97 | " filters = 10\n", 98 | " , kernel_size = (4,4)\n", 99 | " , strides = 2\n", 100 | " , padding = 'same'\n", 101 | " )(input_layer)\n", 102 | "\n", 103 | "conv_layer_2 = Conv2D(\n", 104 | " filters = 20\n", 105 | " , kernel_size = (3,3)\n", 106 | " , strides = 2\n", 107 | " , padding = 'same'\n", 108 | " )(conv_layer_1)\n", 109 | "\n", 110 | "flatten_layer = Flatten()(conv_layer_2)\n", 111 | "\n", 112 | "output_layer = Dense(units=10, activation = 'softmax')(flatten_layer)\n", 113 | "\n", 114 | "model = Model(input_layer, output_layer)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "model.summary()" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "input_layer = Input((32,32,3))\n", 133 | "\n", 134 | "x = Conv2D(filters = 32, kernel_size = 3, strides = 1, padding = 'same')(input_layer)\n", 135 | "x = BatchNormalization()(x)\n", 136 | "x = LeakyReLU()(x)\n", 137 | "\n", 138 | "\n", 139 | "x = Conv2D(filters = 32, kernel_size = 3, strides = 2, padding = 'same')(x)\n", 140 | "x = BatchNormalization()(x)\n", 141 | "x = LeakyReLU()(x)\n", 142 | "\n", 143 | "\n", 144 | "x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = 'same')(x)\n", 145 | "x = BatchNormalization()(x)\n", 146 | "x = LeakyReLU()(x)\n", 147 | "\n", 148 | "\n", 149 | "x = Conv2D(filters = 64, kernel_size = 3, strides = 2, padding = 'same')(x)\n", 150 | "x = BatchNormalization()(x)\n", 151 | "x = LeakyReLU()(x)\n", 152 | "\n", 153 | "\n", 154 | "x = Flatten()(x)\n", 155 | "\n", 156 | "x = Dense(128)(x)\n", 157 | "x = BatchNormalization()(x)\n", 158 | "x = LeakyReLU()(x)\n", 159 | "x = Dropout(rate = 0.5)(x)\n", 160 | "\n", 161 | "x = Dense(NUM_CLASSES)(x)\n", 162 | "output_layer = Activation('softmax')(x)\n", 163 | "\n", 164 | "model = Model(input_layer, output_layer)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "model.summary()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | 
"source": [ 180 | "# train" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "opt = Adam(lr=0.0005)\n", 190 | "model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "model.fit(x_train\n", 200 | " , y_train\n", 201 | " , batch_size=32\n", 202 | " , epochs=10\n", 203 | " , shuffle=True\n", 204 | " , validation_data = (x_test, y_test))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "model.layers[6].get_weights()" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "# analysis" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "model.evaluate(x_test, y_test, batch_size=1000)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "CLASSES = np.array(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])\n", 239 | "\n", 240 | "preds = model.predict(x_test)\n", 241 | "preds_single = CLASSES[np.argmax(preds, axis = -1)]\n", 242 | "actual_single = CLASSES[np.argmax(y_test, axis = -1)]" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "import matplotlib.pyplot as plt\n", 252 | "\n", 253 | "n_to_show = 10\n", 254 | "indices = np.random.choice(range(len(x_test)), n_to_show)\n", 255 | "\n", 256 | "fig = plt.figure(figsize=(15, 3))\n", 257 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 258 | "\n", 259 | "for i, idx in enumerate(indices):\n", 260 | " img = x_test[idx]\n", 261 | " ax = fig.add_subplot(1, n_to_show, i+1)\n", 262 | " ax.axis('off')\n", 263 | " ax.text(0.5, -0.35, 'pred = ' + str(preds_single[idx]), fontsize=10, ha='center', transform=ax.transAxes) \n", 264 | " ax.text(0.5, -0.7, 'act = ' + str(actual_single[idx]), fontsize=10, ha='center', transform=ax.transAxes)\n", 265 | " ax.imshow(img)\n" 266 | ] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "gdl_code", 272 | "language": "python", 273 | "name": "gdl_code" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.7.5" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 2 290 | } 291 | -------------------------------------------------------------------------------- /03_01_autoencoder_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Autoencoder" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "\n", 18 | "from utils.loaders import load_mnist\n", 19 | "from models.AE import Autoencoder" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | 
"metadata": {}, 25 | "source": [ 26 | "## Set parameters" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# run params\n", 36 | "SECTION = 'vae'\n", 37 | "RUN_ID = '0001'\n", 38 | "DATA_NAME = 'digits'\n", 39 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 40 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 41 | "\n", 42 | "if not os.path.exists(RUN_FOLDER):\n", 43 | " os.mkdir(RUN_FOLDER)\n", 44 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 45 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 46 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 47 | "\n", 48 | "MODE = 'build' #'load' #" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Load the data" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "(x_train, y_train), (x_test, y_test) = load_mnist()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Define the structure of the neural network" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "AE = Autoencoder(\n", 81 | " input_dim = (28,28,1)\n", 82 | " , encoder_conv_filters = [32,64,64, 64]\n", 83 | " , encoder_conv_kernel_size = [3,3,3,3]\n", 84 | " , encoder_conv_strides = [1,2,2,1]\n", 85 | " , decoder_conv_t_filters = [64,64,32,1]\n", 86 | " , decoder_conv_t_kernel_size = [3,3,3,3]\n", 87 | " , decoder_conv_t_strides = [1,2,2,1]\n", 88 | " , z_dim = 2\n", 89 | ")\n", 90 | "\n", 91 | "if MODE == 'build':\n", 92 | " AE.save(RUN_FOLDER)\n", 93 | "else:\n", 94 | " AE.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "AE.encoder.summary()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "AE.decoder.summary()" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Train the autoencoder" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "LEARNING_RATE = 0.0005\n", 129 | "BATCH_SIZE = 32\n", 130 | "INITIAL_EPOCH = 0" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "AE.compile(LEARNING_RATE)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "AE.train( \n", 149 | " x_train[:1000]\n", 150 | " , batch_size = BATCH_SIZE\n", 151 | " , epochs = 200\n", 152 | " , run_folder = RUN_FOLDER\n", 153 | " , initial_epoch = INITIAL_EPOCH\n", 154 | ")" 155 | ] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "gdl_code_2", 161 | "language": "python", 162 | "name": "gdl_code_2" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": "ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.7.3" 175 | } 176 | }, 177 | 
"nbformat": 4, 178 | "nbformat_minor": 2 179 | } 180 | -------------------------------------------------------------------------------- /03_02_autoencoder_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Autoencoder Analysis" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import numpy as np\n", 26 | "import os\n", 27 | "from scipy.stats import norm\n", 28 | "\n", 29 | "from models.AE import Autoencoder\n", 30 | "from utils.loaders import load_mnist, load_model" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# run params\n", 40 | "SECTION = 'vae'\n", 41 | "RUN_ID = '0001'\n", 42 | "DATA_NAME = 'digits'\n", 43 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 44 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Load the data" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "(x_train, y_train), (x_test, y_test) = load_mnist()" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Load the model architecture" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "AE = load_model(Autoencoder, RUN_FOLDER)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## reconstructing original paintings" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "n_to_show = 10\n", 93 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 94 | "example_images = x_test[example_idx]\n", 95 | "\n", 96 | "z_points = AE.encoder.predict(example_images)\n", 97 | "\n", 98 | "reconst_images = AE.decoder.predict(z_points)\n", 99 | "\n", 100 | "fig = plt.figure(figsize=(15, 3))\n", 101 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 102 | "\n", 103 | "for i in range(n_to_show):\n", 104 | " img = example_images[i].squeeze()\n", 105 | " ax = fig.add_subplot(2, n_to_show, i+1)\n", 106 | " ax.axis('off')\n", 107 | " ax.text(0.5, -0.35, str(np.round(z_points[i],1)), fontsize=10, ha='center', transform=ax.transAxes) \n", 108 | " ax.imshow(img, cmap='gray_r')\n", 109 | "\n", 110 | "for i in range(n_to_show):\n", 111 | " img = reconst_images[i].squeeze()\n", 112 | " ax = fig.add_subplot(2, n_to_show, i+n_to_show+1)\n", 113 | " ax.axis('off')\n", 114 | " ax.imshow(img, cmap='gray_r')\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## Mr N. 
Coder's wall" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "n_to_show = 5000\n", 131 | "grid_size = 15\n", 132 | "figsize = 12\n", 133 | "\n", 134 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 135 | "example_images = x_test[example_idx]\n", 136 | "example_labels = y_test[example_idx]\n", 137 | "\n", 138 | "z_points = AE.encoder.predict(example_images)\n", 139 | "\n", 140 | "min_x = min(z_points[:, 0])\n", 141 | "max_x = max(z_points[:, 0])\n", 142 | "min_y = min(z_points[:, 1])\n", 143 | "max_y = max(z_points[:, 1])\n", 144 | "\n", 145 | "plt.figure(figsize=(figsize, figsize))\n", 146 | "plt.scatter(z_points[:, 0] , z_points[:, 1], c='black', alpha=0.5, s=2)\n", 147 | "plt.show()" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### The new generated art exhibition" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "figsize = 5\n", 164 | "\n", 165 | "plt.figure(figsize=(figsize, figsize))\n", 166 | "plt.scatter(z_points[:, 0] , z_points[:, 1], c='black', alpha=0.5, s=2)\n", 167 | "\n", 168 | "grid_size = 10\n", 169 | "grid_depth = 3\n", 170 | "figsize = 15\n", 171 | "\n", 172 | "x = np.random.uniform(min_x,max_x, size = grid_size * grid_depth)\n", 173 | "y = np.random.uniform(min_y,max_y, size = grid_size * grid_depth)\n", 174 | "z_grid = np.array(list(zip(x, y)))\n", 175 | "reconst = AE.decoder.predict(z_grid)\n", 176 | "\n", 177 | "plt.scatter(z_grid[:, 0] , z_grid[:, 1], c = 'red', alpha=1, s=20)\n", 178 | "plt.show()\n", 179 | "\n", 180 | "fig = plt.figure(figsize=(figsize, grid_depth))\n", 181 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 182 | "\n", 183 | "for i in range(grid_size*grid_depth):\n", 184 | " ax = fig.add_subplot(grid_depth, grid_size, i+1)\n", 185 | " ax.axis('off')\n", 186 | " ax.text(0.5, -0.35, str(np.round(z_grid[i],1)), fontsize=10, ha='center', transform=ax.transAxes)\n", 187 | " \n", 188 | " ax.imshow(reconst[i, :,:,0], cmap = 'Greys')" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "n_to_show = 5000\n", 198 | "grid_size = 15\n", 199 | "figsize = 12\n", 200 | "\n", 201 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 202 | "example_images = x_test[example_idx]\n", 203 | "example_labels = y_test[example_idx]\n", 204 | "\n", 205 | "z_points = AE.encoder.predict(example_images)\n", 206 | "\n", 207 | "plt.figure(figsize=(figsize, figsize))\n", 208 | "plt.scatter(z_points[:, 0] , z_points[:, 1] , cmap='rainbow' , c= example_labels\n", 209 | " , alpha=0.5, s=2)\n", 210 | "plt.colorbar()\n", 211 | "plt.show()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "n_to_show = 5000\n", 221 | "grid_size = 20\n", 222 | "figsize = 8\n", 223 | "\n", 224 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 225 | "example_images = x_test[example_idx]\n", 226 | "example_labels = y_test[example_idx]\n", 227 | "\n", 228 | "z_points = AE.encoder.predict(example_images)\n", 229 | "\n", 230 | "plt.figure(figsize=(5, 5))\n", 231 | "plt.scatter(z_points[:, 0] , z_points[:, 1] , cmap='rainbow' , c= example_labels\n", 232 | " , alpha=0.5, s=2)\n", 233 | "plt.colorbar()\n", 
234 | "\n", 235 | "# x = norm.ppf(np.linspace(0.05, 0.95, 10))\n", 236 | "# y = norm.ppf(np.linspace(0.05, 0.95, 10))\n", 237 | "x = np.linspace(min(z_points[:, 0]), max(z_points[:, 0]), grid_size)\n", 238 | "y = np.linspace(max(z_points[:, 1]), min(z_points[:, 1]), grid_size)\n", 239 | "xv, yv = np.meshgrid(x, y)\n", 240 | "xv = xv.flatten()\n", 241 | "yv = yv.flatten()\n", 242 | "z_grid = np.array(list(zip(xv, yv)))\n", 243 | "\n", 244 | "reconst = AE.decoder.predict(z_grid)\n", 245 | "\n", 246 | "plt.scatter(z_grid[:, 0] , z_grid[:, 1], c = 'black'#, cmap='rainbow' , c= example_labels\n", 247 | " , alpha=1, s=5)\n", 248 | "\n", 249 | "\n", 250 | "\n", 251 | "\n", 252 | "plt.show()\n", 253 | "\n", 254 | "\n", 255 | "fig = plt.figure(figsize=(figsize, figsize))\n", 256 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 257 | "for i in range(grid_size**2):\n", 258 | " ax = fig.add_subplot(grid_size, grid_size, i+1)\n", 259 | " ax.axis('off')\n", 260 | " ax.imshow(reconst[i, :,:,0], cmap = 'Greys')" 261 | ] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "gdl_code", 267 | "language": "python", 268 | "name": "gdl_code" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.7.5" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 2 285 | } 286 | -------------------------------------------------------------------------------- /03_03_vae_digits_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# VAE Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import os\n", 34 | "\n", 35 | "from models.VAE import VariationalAutoencoder\n", 36 | "from utils.loaders import load_mnist" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# run params\n", 46 | "SECTION = 'vae'\n", 47 | "RUN_ID = '0002'\n", 48 | "DATA_NAME = 'digits'\n", 49 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 50 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 51 | "\n", 52 | "if not os.path.exists(RUN_FOLDER):\n", 53 | " os.mkdir(RUN_FOLDER)\n", 54 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 55 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 56 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 57 | "\n", 58 | "mode = 'build' #'load' #" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "## data" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "(x_train, y_train), (x_test, y_test) = load_mnist()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## architecture" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | 
"metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "vae = VariationalAutoencoder(\n", 91 | " input_dim = (28,28,1)\n", 92 | " , encoder_conv_filters = [32,64,64, 64]\n", 93 | " , encoder_conv_kernel_size = [3,3,3,3]\n", 94 | " , encoder_conv_strides = [1,2,2,1]\n", 95 | " , decoder_conv_t_filters = [64,64,32,1]\n", 96 | " , decoder_conv_t_kernel_size = [3,3,3,3]\n", 97 | " , decoder_conv_t_strides = [1,2,2,1]\n", 98 | " , z_dim = 2\n", 99 | " , r_loss_factor = 1000\n", 100 | ")\n", 101 | "\n", 102 | "if mode == 'build':\n", 103 | " vae.save(RUN_FOLDER)\n", 104 | "else:\n", 105 | " vae.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "vae.encoder.summary()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "vae.decoder.summary()" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "## training" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "LEARNING_RATE = 0.0005" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "vae.compile(LEARNING_RATE)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "BATCH_SIZE = 32\n", 158 | "EPOCHS = 200\n", 159 | "PRINT_EVERY_N_BATCHES = 100\n", 160 | "INITIAL_EPOCH = 0" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "scrolled": false 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "vae.train( \n", 172 | " x_train\n", 173 | " , batch_size = BATCH_SIZE\n", 174 | " , epochs = EPOCHS\n", 175 | " , run_folder = RUN_FOLDER\n", 176 | " , print_every_n_batches = PRINT_EVERY_N_BATCHES\n", 177 | " , initial_epoch = INITIAL_EPOCH\n", 178 | ")" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "gdl_code", 192 | "language": "python", 193 | "name": "gdl_code" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 3 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython3", 205 | "version": "3.7.5" 206 | } 207 | }, 208 | "nbformat": 4, 209 | "nbformat_minor": 2 210 | } 211 | -------------------------------------------------------------------------------- /03_04_vae_digits_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# VAE Analysis" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 
| "outputs": [], 32 | "source": [ 33 | "import numpy as np\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import numpy as np\n", 36 | "import os\n", 37 | "from scipy.stats import norm\n", 38 | "\n", 39 | "from models.VAE import VariationalAutoencoder\n", 40 | "from utils.loaders import load_mnist, load_model" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# run params\n", 50 | "SECTION = 'vae'\n", 51 | "RUN_ID = '0002'\n", 52 | "DATA_NAME = 'digits'\n", 53 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 54 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## data" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "(x_train, y_train), (x_test, y_test) = load_mnist()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## architecture" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "vae = load_model(VariationalAutoencoder, RUN_FOLDER)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## reconstructing original paintings" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "n_to_show = 10\n", 103 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 104 | "example_images = x_test[example_idx]\n", 105 | "\n", 106 | "_,_,z_points = vae.encoder.predict(example_images)\n", 107 | "\n", 108 | "reconst_images = vae.decoder.predict(z_points)\n", 109 | "\n", 110 | "fig = plt.figure(figsize=(15, 3))\n", 111 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 112 | "\n", 113 | "for i in range(n_to_show):\n", 114 | " img = example_images[i].squeeze()\n", 115 | " sub = fig.add_subplot(2, n_to_show, i+1)\n", 116 | " sub.axis('off')\n", 117 | " sub.text(0.5, -0.35, str(np.round(z_points[i],1)), fontsize=10, ha='center', transform=sub.transAxes)\n", 118 | " \n", 119 | " sub.imshow(img, cmap='gray_r')\n", 120 | "\n", 121 | "for i in range(n_to_show):\n", 122 | " img = reconst_images[i].squeeze()\n", 123 | " sub = fig.add_subplot(2, n_to_show, i+n_to_show+1)\n", 124 | " sub.axis('off')\n", 125 | " sub.imshow(img, cmap='gray_r')\n" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Mr N. 
Coder's wall" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "n_to_show = 5000\n", 142 | "figsize = 12\n", 143 | "\n", 144 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 145 | "example_images = x_test[example_idx]\n", 146 | "example_labels = y_test[example_idx]\n", 147 | "\n", 148 | "_,_,z_points = vae.encoder.predict(example_images)\n", 149 | "\n", 150 | "min_x = min(z_points[:, 0])\n", 151 | "max_x = max(z_points[:, 0])\n", 152 | "min_y = min(z_points[:, 1])\n", 153 | "max_y = max(z_points[:, 1])\n", 154 | "\n", 155 | "plt.figure(figsize=(figsize, figsize))\n", 156 | "plt.scatter(z_points[:, 0] , z_points[:, 1], c='black', alpha=0.5, s=2)\n", 157 | "plt.show()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "### The new generated art exhibition" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "figsize = 8\n", 174 | "plt.figure(figsize=(figsize, figsize))\n", 175 | "plt.scatter(z_points[:, 0] , z_points[:, 1], c='black', alpha=0.5, s=2)\n", 176 | "\n", 177 | "\n", 178 | "grid_size = 15\n", 179 | "grid_depth = 2\n", 180 | "figsize = 15\n", 181 | "\n", 182 | "x = np.random.normal(size = grid_size * grid_depth)\n", 183 | "y = np.random.normal(size = grid_size * grid_depth)\n", 184 | "\n", 185 | "z_grid = np.array(list(zip(x, y)))\n", 186 | "reconst = vae.decoder.predict(z_grid)\n", 187 | "\n", 188 | "plt.scatter(z_grid[:, 0] , z_grid[:, 1], c = 'red', alpha=1, s=20)\n", 189 | "plt.show()\n", 190 | "\n", 191 | "fig = plt.figure(figsize=(figsize, grid_depth))\n", 192 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 193 | "\n", 194 | "for i in range(grid_size*grid_depth):\n", 195 | " ax = fig.add_subplot(grid_depth, grid_size, i+1)\n", 196 | " ax.axis('off')\n", 197 | " ax.text(0.5, -0.35, str(np.round(z_grid[i],1)), fontsize=8, ha='center', transform=ax.transAxes)\n", 198 | " \n", 199 | " ax.imshow(reconst[i, :,:,0], cmap = 'Greys')" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "n_to_show = 5000\n", 209 | "grid_size = 15\n", 210 | "fig_height = 7\n", 211 | "fig_width = 15\n", 212 | "\n", 213 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 214 | "example_images = x_test[example_idx]\n", 215 | "example_labels = y_test[example_idx]\n", 216 | "\n", 217 | "_,_,z_points = vae.encoder.predict(example_images)\n", 218 | "p_points = norm.cdf(z_points)\n", 219 | "\n", 220 | "fig = plt.figure(figsize=(fig_width, fig_height))\n", 221 | "\n", 222 | "ax = fig.add_subplot(1, 2, 1)\n", 223 | "plot_1 = ax.scatter(z_points[:, 0] , z_points[:, 1] , cmap='rainbow' , c= example_labels\n", 224 | " , alpha=0.5, s=2)\n", 225 | "plt.colorbar(plot_1)\n", 226 | "\n", 227 | "ax = fig.add_subplot(1, 2, 2)\n", 228 | "plot_2 = ax.scatter(p_points[:, 0] , p_points[:, 1] , cmap='rainbow' , c= example_labels\n", 229 | " , alpha=0.5, s=5)\n", 230 | "\n", 231 | "\n", 232 | "\n", 233 | "plt.show()" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "n_to_show = 5000\n", 243 | "grid_size = 20\n", 244 | "figsize = 8\n", 245 | "\n", 246 | "example_idx = np.random.choice(range(len(x_test)), n_to_show)\n", 247 | "example_images = 
x_test[example_idx]\n", 248 | "example_labels = y_test[example_idx]\n", 249 | "\n", 250 | "_,_,z_points = vae.encoder.predict(example_images)\n", 251 | "\n", 252 | "plt.figure(figsize=(5, 5))\n", 253 | "plt.scatter(z_points[:, 0] , z_points[:, 1] , cmap='rainbow' , c= example_labels\n", 254 | " , alpha=0.5, s=2)\n", 255 | "plt.colorbar()\n", 256 | "\n", 257 | "x = norm.ppf(np.linspace(0.01, 0.99, grid_size))\n", 258 | "y = norm.ppf(np.linspace(0.01, 0.99, grid_size))\n", 259 | "xv, yv = np.meshgrid(x, y)\n", 260 | "xv = xv.flatten()\n", 261 | "yv = yv.flatten()\n", 262 | "z_grid = np.array(list(zip(xv, yv)))\n", 263 | "\n", 264 | "reconst = vae.decoder.predict(z_grid)\n", 265 | "\n", 266 | "plt.scatter(z_grid[:, 0] , z_grid[:, 1], c = 'black'#, cmap='rainbow' , c= example_labels\n", 267 | " , alpha=1, s=2)\n", 268 | "\n", 269 | "\n", 270 | "\n", 271 | "\n", 272 | "plt.show()\n", 273 | "\n", 274 | "\n", 275 | "fig = plt.figure(figsize=(figsize, figsize))\n", 276 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 277 | "for i in range(grid_size**2):\n", 278 | " ax = fig.add_subplot(grid_size, grid_size, i+1)\n", 279 | " ax.axis('off')\n", 280 | " ax.imshow(reconst[i, :,:,0], cmap = 'Greys')" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "gdl_code", 287 | "language": "python", 288 | "name": "gdl_code" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.7.5" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /03_05_vae_faces_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# VAE Training - Faces dataset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import os\n", 34 | "from glob import glob\n", 35 | "import numpy as np\n", 36 | "\n", 37 | "from models.VAE import VariationalAutoencoder\n", 38 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# run params\n", 48 | "section = 'vae'\n", 49 | "run_id = '0001'\n", 50 | "data_name = 'faces'\n", 51 | "RUN_FOLDER = 'run/{}/'.format(section)\n", 52 | "RUN_FOLDER += '_'.join([run_id, data_name])\n", 53 | "\n", 54 | "if not os.path.exists(RUN_FOLDER):\n", 55 | " os.mkdir(RUN_FOLDER)\n", 56 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 57 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 58 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 59 | "\n", 60 | "mode = 'build' #'load' #\n", 61 | "\n", 62 | "\n", 63 | "DATA_FOLDER = './data/celeb/'" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## data" 71 | ] 72 | }, 73 | { 74 | 
"cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "INPUT_DIM = (128,128,3)\n", 80 | "BATCH_SIZE = 32\n", 81 | "\n", 82 | "filenames = np.array(glob(os.path.join(DATA_FOLDER, '*/*.jpg')))\n", 83 | "\n", 84 | "NUM_IMAGES = len(filenames)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "data_gen = ImageDataGenerator(rescale=1./255)\n", 94 | "\n", 95 | "data_flow = data_gen.flow_from_directory(DATA_FOLDER\n", 96 | " , target_size = INPUT_DIM[:2]\n", 97 | " , batch_size = BATCH_SIZE\n", 98 | " , shuffle = True\n", 99 | " , class_mode = 'input'\n", 100 | " , subset = \"training\"\n", 101 | " )" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## architecture" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "vae = VariationalAutoencoder(\n", 118 | " input_dim = INPUT_DIM\n", 119 | " , encoder_conv_filters=[32,64,64, 64]\n", 120 | " , encoder_conv_kernel_size=[3,3,3,3]\n", 121 | " , encoder_conv_strides=[2,2,2,2]\n", 122 | " , decoder_conv_t_filters=[64,64,32,3]\n", 123 | " , decoder_conv_t_kernel_size=[3,3,3,3]\n", 124 | " , decoder_conv_t_strides=[2,2,2,2]\n", 125 | " , z_dim=200\n", 126 | " , use_batch_norm=True\n", 127 | " , use_dropout=True\n", 128 | " , r_loss_factor = 10000\n", 129 | " )\n", 130 | "\n", 131 | "if mode == 'build':\n", 132 | " vae.save(RUN_FOLDER)\n", 133 | "else:\n", 134 | " vae.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "vae.encoder.summary()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "vae.decoder.summary()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "## training" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "LEARNING_RATE = 0.0005\n", 169 | "EPOCHS = 200\n", 170 | "PRINT_EVERY_N_BATCHES = 100\n", 171 | "INITIAL_EPOCH = 0" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "vae.compile(LEARNING_RATE)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "vae.train_with_generator( \n", 190 | " data_flow\n", 191 | " , epochs = EPOCHS\n", 192 | " , steps_per_epoch = NUM_IMAGES / BATCH_SIZE\n", 193 | " , run_folder = RUN_FOLDER\n", 194 | " , print_every_n_batches = PRINT_EVERY_N_BATCHES\n", 195 | " , initial_epoch = INITIAL_EPOCH\n", 196 | ")" 197 | ] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "gdl_code", 203 | "language": "python", 204 | "name": "gdl_code" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.7.5" 217 | } 218 | }, 219 | "nbformat": 4, 220 | 
"nbformat_minor": 2 221 | } 222 | -------------------------------------------------------------------------------- /04_01_gan_camel_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# GAN Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "from models.GAN import GAN\n", 27 | "from utils.loaders import load_safari\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# run params\n", 37 | "SECTION = 'gan'\n", 38 | "RUN_ID = '0001'\n", 39 | "DATA_NAME = 'camel'\n", 40 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 41 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 42 | "\n", 43 | "if not os.path.exists(RUN_FOLDER):\n", 44 | " os.mkdir(RUN_FOLDER)\n", 45 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 46 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 47 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 48 | "\n", 49 | "mode = 'build' #'load' #" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## data" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "(x_train, y_train) = load_safari(DATA_NAME)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "x_train.shape" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "plt.imshow(x_train[200,:,:,0], cmap = 'gray')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## architecture" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "gan = GAN(input_dim = (28,28,1)\n", 100 | " , discriminator_conv_filters = [64,64,128,128]\n", 101 | " , discriminator_conv_kernel_size = [5,5,5,5]\n", 102 | " , discriminator_conv_strides = [2,2,2,1]\n", 103 | " , discriminator_batch_norm_momentum = None\n", 104 | " , discriminator_activation = 'relu'\n", 105 | " , discriminator_dropout_rate = 0.4\n", 106 | " , discriminator_learning_rate = 0.0008\n", 107 | " , generator_initial_dense_layer_size = (7, 7, 64)\n", 108 | " , generator_upsample = [2,2, 1, 1]\n", 109 | " , generator_conv_filters = [128,64, 64,1]\n", 110 | " , generator_conv_kernel_size = [5,5,5,5]\n", 111 | " , generator_conv_strides = [1,1, 1, 1]\n", 112 | " , generator_batch_norm_momentum = 0.9\n", 113 | " , generator_activation = 'relu'\n", 114 | " , generator_dropout_rate = None\n", 115 | " , generator_learning_rate = 0.0004\n", 116 | " , optimiser = 'rmsprop'\n", 117 | " , z_dim = 100\n", 118 | " )\n", 119 | "\n", 120 | "if mode == 'build':\n", 121 | " gan.save(RUN_FOLDER)\n", 122 | "else:\n", 123 | " gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "gan.discriminator.summary()" 
133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "gan.generator.summary()" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "## training" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "BATCH_SIZE = 64\n", 158 | "EPOCHS = 6000\n", 159 | "PRINT_EVERY_N_BATCHES = 5" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "scrolled": false 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "gan.train( \n", 171 | " x_train\n", 172 | " , batch_size = BATCH_SIZE\n", 173 | " , epochs = EPOCHS\n", 174 | " , run_folder = RUN_FOLDER\n", 175 | " , print_every_n_batches = PRINT_EVERY_N_BATCHES\n", 176 | ")" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "fig = plt.figure()\n", 186 | "plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)\n", 187 | "\n", 188 | "plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)\n", 189 | "plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)\n", 190 | "plt.plot([x[0] for x in gan.g_losses], color='orange', linewidth=0.25)\n", 191 | "\n", 192 | "plt.xlabel('batch', fontsize=18)\n", 193 | "plt.ylabel('loss', fontsize=16)\n", 194 | "\n", 195 | "plt.xlim(0, 2000)\n", 196 | "plt.ylim(0, 2)\n", 197 | "\n", 198 | "plt.show()\n" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "fig = plt.figure()\n", 208 | "plt.plot([x[3] for x in gan.d_losses], color='black', linewidth=0.25)\n", 209 | "plt.plot([x[4] for x in gan.d_losses], color='green', linewidth=0.25)\n", 210 | "plt.plot([x[5] for x in gan.d_losses], color='red', linewidth=0.25)\n", 211 | "plt.plot([x[1] for x in gan.g_losses], color='orange', linewidth=0.25)\n", 212 | "\n", 213 | "plt.xlabel('batch', fontsize=18)\n", 214 | "plt.ylabel('accuracy', fontsize=16)\n", 215 | "\n", 216 | "plt.xlim(0, 2000)\n", 217 | "\n", 218 | "plt.show()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [] 227 | } 228 | ], 229 | "metadata": { 230 | "kernelspec": { 231 | "display_name": "gdl_code", 232 | "language": "python", 233 | "name": "gdl_code" 234 | }, 235 | "language_info": { 236 | "codemirror_mode": { 237 | "name": "ipython", 238 | "version": 3 239 | }, 240 | "file_extension": ".py", 241 | "mimetype": "text/x-python", 242 | "name": "python", 243 | "nbconvert_exporter": "python", 244 | "pygments_lexer": "ipython3", 245 | "version": "3.7.5" 246 | } 247 | }, 248 | "nbformat": 4, 249 | "nbformat_minor": 2 250 | } 251 | -------------------------------------------------------------------------------- /04_02_wgan_cifar_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# WGAN Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | 
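The two plots above read fixed positions out of `gan.d_losses` and `gan.g_losses`. The notebook never states the layout, but the indices feeding the 'loss' and 'accuracy' axes imply it; treat the unpacking below as that inference, not documented API:

```python
# inferred per-batch history layout (total, real-batch, fake-batch):
d_total, d_real, d_fake, d_acc, d_acc_real, d_acc_fake = gan.d_losses[-1]
g_loss, g_acc = gan.g_losses[-1]
```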
"\n", 25 | "import os\n", 26 | "import numpy as np\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "\n", 29 | "from models.WGAN import WGAN\n", 30 | "from utils.loaders import load_cifar\n", 31 | "\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# run params\n", 41 | "SECTION = 'gan'\n", 42 | "RUN_ID = '0002'\n", 43 | "DATA_NAME = 'horses'\n", 44 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 45 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 46 | "\n", 47 | "if not os.path.exists(RUN_FOLDER):\n", 48 | " os.mkdir(RUN_FOLDER)\n", 49 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 50 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 51 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 52 | "\n", 53 | "mode = 'build' #'load' #\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## data" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "if DATA_NAME == 'cars':\n", 70 | " label = 1\n", 71 | "elif DATA_NAME == 'horses':\n", 72 | " label = 7\n", 73 | "(x_train, y_train) = load_cifar(label, 10)\n" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "plt.imshow((x_train[150,:,:,:]+1)/2)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## architecture" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "\n", 99 | "if mode == 'build':\n", 100 | "\n", 101 | " gan = WGAN(input_dim = (32,32,3)\n", 102 | " , critic_conv_filters = [32,64,128,128]\n", 103 | " , critic_conv_kernel_size = [5,5,5,5]\n", 104 | " , critic_conv_strides = [2,2,2,1]\n", 105 | " , critic_batch_norm_momentum = None\n", 106 | " , critic_activation = 'leaky_relu'\n", 107 | " , critic_dropout_rate = None\n", 108 | " , critic_learning_rate = 0.00005\n", 109 | " , generator_initial_dense_layer_size = (4, 4, 128)\n", 110 | " , generator_upsample = [2,2, 2,1]\n", 111 | " , generator_conv_filters = [128,64,32,3]\n", 112 | " , generator_conv_kernel_size = [5,5,5,5]\n", 113 | " , generator_conv_strides = [1,1, 1,1]\n", 114 | " , generator_batch_norm_momentum = 0.8\n", 115 | " , generator_activation = 'leaky_relu'\n", 116 | " , generator_dropout_rate = None\n", 117 | " , generator_learning_rate = 0.00005\n", 118 | " , optimiser = 'rmsprop'\n", 119 | " , z_dim = 100\n", 120 | " )\n", 121 | " gan.save(RUN_FOLDER)\n", 122 | "\n", 123 | "else:\n", 124 | " gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))\n" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "gan.critic.summary()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "gan.generator.summary()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## training" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "BATCH_SIZE = 128\n", 159 | "EPOCHS = 6000\n", 160 | "PRINT_EVERY_N_BATCHES = 5\n", 161 | "N_CRITIC = 5\n", 162 | "CLIP_THRESHOLD = 0.01" 163 | ] 164 | }, 
165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "scrolled": false 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "gan.train( \n", 174 | " x_train\n", 175 | " , batch_size = BATCH_SIZE\n", 176 | " , epochs = EPOCHS\n", 177 | " , run_folder = RUN_FOLDER\n", 178 | " , print_every_n_batches = PRINT_EVERY_N_BATCHES\n", 179 | " , n_critic = N_CRITIC\n", 180 | " , clip_threshold = CLIP_THRESHOLD\n", 181 | ")" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "gan.sample_images(RUN_FOLDER)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "fig = plt.figure()\n", 200 | "plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)\n", 201 | "\n", 202 | "plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)\n", 203 | "plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)\n", 204 | "plt.plot(gan.g_losses, color='orange', linewidth=0.25)\n", 205 | "\n", 206 | "plt.xlabel('batch', fontsize=18)\n", 207 | "plt.ylabel('loss', fontsize=16)\n", 208 | "\n", 209 | "# plt.xlim(0, 2000)\n", 210 | "# plt.ylim(0, 2)\n", 211 | "\n", 212 | "plt.show()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "def compare_images(img1, img2):\n", 222 | " return np.mean(np.abs(img1 - img2))" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "\n", 232 | "r, c = 5, 5\n", 233 | "\n", 234 | "idx = np.random.randint(0, x_train.shape[0], BATCH_SIZE)\n", 235 | "true_imgs = (x_train[idx] + 1) *0.5\n", 236 | "\n", 237 | "fig, axs = plt.subplots(r, c, figsize=(15,15))\n", 238 | "cnt = 0\n", 239 | "\n", 240 | "for i in range(r):\n", 241 | " for j in range(c):\n", 242 | " axs[i,j].imshow(true_imgs[cnt], cmap = 'gray_r')\n", 243 | " axs[i,j].axis('off')\n", 244 | " cnt += 1\n", 245 | "fig.savefig(os.path.join(RUN_FOLDER, \"images/real.png\"))\n", 246 | "plt.close()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "r, c = 5, 5\n", 256 | "noise = np.random.normal(0, 1, (r * c, gan.z_dim))\n", 257 | "gen_imgs = gan.generator.predict(noise)\n", 258 | "\n", 259 | "#Rescale images 0 - 1\n", 260 | "\n", 261 | "gen_imgs = 0.5 * (gen_imgs + 1)\n", 262 | "# gen_imgs = np.clip(gen_imgs, 0, 1)\n", 263 | "\n", 264 | "fig, axs = plt.subplots(r, c, figsize=(15,15))\n", 265 | "cnt = 0\n", 266 | "\n", 267 | "for i in range(r):\n", 268 | " for j in range(c):\n", 269 | " axs[i,j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]), cmap = 'gray_r')\n", 270 | " axs[i,j].axis('off')\n", 271 | " cnt += 1\n", 272 | "fig.savefig(os.path.join(RUN_FOLDER, \"images/sample.png\"))\n", 273 | "plt.close()\n", 274 | "\n", 275 | "\n", 276 | "fig, axs = plt.subplots(r, c, figsize=(15,15))\n", 277 | "cnt = 0\n", 278 | "\n", 279 | "for i in range(r):\n", 280 | " for j in range(c):\n", 281 | " c_diff = 99999\n", 282 | " c_img = None\n", 283 | " for k_idx, k in enumerate((x_train + 1) * 0.5):\n", 284 | " \n", 285 | " diff = compare_images(gen_imgs[cnt, :,:,:], k)\n", 286 | " if diff < c_diff:\n", 287 | " c_img = np.copy(k)\n", 288 | " c_diff = diff\n", 289 | " axs[i,j].imshow(c_img, cmap = 'gray_r')\n", 
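The closest-match cell here compares each generated image against the whole training set with nested Python loops, which gets slow as the training set grows. The same L1 check vectorised over the training set (a sketch; `gen_imgs` and the rescaled `x_train` as in this notebook):

```python
import numpy as np

def closest_training_images(gen_imgs, x_train_01):
    # for each generated image, broadcast the comparison across every training image
    out = np.empty_like(gen_imgs)
    for i, img in enumerate(gen_imgs):
        diffs = np.abs(x_train_01 - img).mean(axis=(1, 2, 3))
        out[i] = x_train_01[diffs.argmin()]
    return out
```

Either way, the point of the check stands: if the closest training images look identical to the samples, the generator is memorising rather than generalising.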
290 | " axs[i,j].axis('off')\n", 291 | " cnt += 1\n", 292 | "\n", 293 | "fig.savefig(os.path.join(RUN_FOLDER, \"images/sample_closest.png\"))\n", 294 | "plt.close()" 295 | ] 296 | } 297 | ], 298 | "metadata": { 299 | "kernelspec": { 300 | "display_name": "gdl_code", 301 | "language": "python", 302 | "name": "gdl_code" 303 | }, 304 | "language_info": { 305 | "codemirror_mode": { 306 | "name": "ipython", 307 | "version": 3 308 | }, 309 | "file_extension": ".py", 310 | "mimetype": "text/x-python", 311 | "name": "python", 312 | "nbconvert_exporter": "python", 313 | "pygments_lexer": "ipython3", 314 | "version": "3.7.5" 315 | } 316 | }, 317 | "nbformat": 4, 318 | "nbformat_minor": 2 319 | } 320 | -------------------------------------------------------------------------------- /04_03_wgangp_faces_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# WGAN-GP Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "\n", 25 | "import os\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "\n", 28 | "from models.WGANGP import WGANGP\n", 29 | "from utils.loaders import load_celeb\n", 30 | "\n", 31 | "import pickle\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# run params\n", 41 | "SECTION = 'gan'\n", 42 | "RUN_ID = '0003'\n", 43 | "DATA_NAME = 'celeb'\n", 44 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 45 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 46 | "\n", 47 | "if not os.path.exists(RUN_FOLDER):\n", 48 | " os.mkdir(RUN_FOLDER)\n", 49 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 50 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 51 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 52 | "\n", 53 | "mode = 'build' #'load' #" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## data" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "BATCH_SIZE = 64\n", 70 | "IMAGE_SIZE = 64" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "x_train = load_celeb(DATA_NAME, IMAGE_SIZE, BATCH_SIZE)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "x_train[0][0][0]" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "plt.imshow((x_train[0][0][0]+1)/2)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## architecture" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "gan = WGANGP(input_dim = (IMAGE_SIZE,IMAGE_SIZE,3)\n", 114 | " , critic_conv_filters = [64,128,256,512]\n", 115 | " , critic_conv_kernel_size = [5,5,5,5]\n", 116 | " , critic_conv_strides = [2,2,2,2]\n", 117 | " , critic_batch_norm_momentum = None\n", 118 | " , critic_activation = 'leaky_relu'\n", 119 | " , 
critic_dropout_rate = None\n", 120 | " , critic_learning_rate = 0.0002\n", 121 | " , generator_initial_dense_layer_size = (4, 4, 512)\n", 122 | " , generator_upsample = [1,1,1,1]\n", 123 | " , generator_conv_filters = [256,128,64,3]\n", 124 | " , generator_conv_kernel_size = [5,5,5,5]\n", 125 | " , generator_conv_strides = [2,2,2,2]\n", 126 | " , generator_batch_norm_momentum = 0.9\n", 127 | " , generator_activation = 'leaky_relu'\n", 128 | " , generator_dropout_rate = None\n", 129 | " , generator_learning_rate = 0.0002\n", 130 | " , optimiser = 'adam'\n", 131 | " , grad_weight = 10\n", 132 | " , z_dim = 100\n", 133 | " , batch_size = BATCH_SIZE\n", 134 | " )\n", 135 | "\n", 136 | "if mode == 'build':\n", 137 | " gan.save(RUN_FOLDER)\n", 138 | "\n", 139 | "else:\n", 140 | " gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))\n" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "gan.critic.summary()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "gan.generator.summary()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "## training" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "EPOCHS = 6000\n", 175 | "PRINT_EVERY_N_BATCHES = 5\n", 176 | "N_CRITIC = 5\n", 177 | "BATCH_SIZE = 64" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "gan.train( \n", 187 | " x_train\n", 188 | " , batch_size = BATCH_SIZE\n", 189 | " , epochs = EPOCHS\n", 190 | " , run_folder = RUN_FOLDER\n", 191 | " , print_every_n_batches = PRINT_EVERY_N_BATCHES\n", 192 | " , n_critic = N_CRITIC\n", 193 | " , using_generator = True\n", 194 | ")" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "fig = plt.figure()\n", 204 | "plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)\n", 205 | "\n", 206 | "plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)\n", 207 | "plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)\n", 208 | "plt.plot(gan.g_losses, color='orange', linewidth=0.25)\n", 209 | "\n", 210 | "plt.xlabel('batch', fontsize=18)\n", 211 | "plt.ylabel('loss', fontsize=16)\n", 212 | "\n", 213 | "plt.xlim(0, 2000)\n", 214 | "# plt.ylim(0, 2)\n", 215 | "\n", 216 | "plt.show()\n" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "gdl_code", 223 | "language": "python", 224 | "name": "gdl_code" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.7.5" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | -------------------------------------------------------------------------------- /05_01_cyclegan_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# CycleGAN train" 8 | ] 9 | }, 
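One detail from the WGAN-GP notebook above is worth unpacking: `grad_weight = 10` is the coefficient on the gradient penalty that replaces weight clipping. The penalty scores the critic at random interpolations between real and generated images and pushes its gradient norm towards 1. A sketch in TF2 style (the working version lives in `models/WGANGP.py`):

```python
import tensorflow as tf

def gradient_penalty(critic, real_imgs, fake_imgs, grad_weight=10.0):
    batch_size = tf.shape(real_imgs)[0]
    # random points on the straight lines between real and generated images
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real_imgs + (1.0 - alpha) * fake_imgs

    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        pred = critic(interpolated, training=True)

    grads = tape.gradient(pred, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    # penalise the critic wherever its gradient norm drifts away from 1
    return grad_weight * tf.reduce_mean(tf.square(norm - 1.0))
```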
10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import os\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "\n", 29 | "from models.cycleGAN import CycleGAN\n", 30 | "from utils.loaders import DataLoader" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "\n", 40 | "# run params\n", 41 | "SECTION = 'paint'\n", 42 | "RUN_ID = '0001'\n", 43 | "DATA_NAME = 'apple2orange'\n", 44 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 45 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 46 | "\n", 47 | "if not os.path.exists(RUN_FOLDER):\n", 48 | " os.mkdir(RUN_FOLDER)\n", 49 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 50 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 51 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 52 | "\n", 53 | "mode = 'build' # 'build' # " 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# data" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "IMAGE_SIZE = 128" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "\n", 79 | "data_loader = DataLoader(dataset_name=DATA_NAME, img_res=(IMAGE_SIZE, IMAGE_SIZE))\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "# architecture" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "gan = CycleGAN(\n", 96 | " input_dim = (IMAGE_SIZE,IMAGE_SIZE,3)\n", 97 | " ,learning_rate = 0.0002\n", 98 | " , buffer_max_length = 50\n", 99 | " , lambda_validation = 1\n", 100 | " , lambda_reconstr = 10\n", 101 | " , lambda_id = 2\n", 102 | " , generator_type = 'unet'\n", 103 | " , gen_n_filters = 32\n", 104 | " , disc_n_filters = 32\n", 105 | " )\n", 106 | "if mode == 'build':\n", 107 | " gan.save(RUN_FOLDER)\n", 108 | "else:\n", 109 | " gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))\n", 110 | " " 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "gan.g_BA.summary()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "gan.g_AB.summary()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "gan.d_A.summary()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "gan.d_B.summary()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "# train" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "BATCH_SIZE = 1\n", 163 | "EPOCHS = 200\n", 164 | "PRINT_EVERY_N_BATCHES = 10\n", 165 | "\n", 166 | "TEST_A_FILE = 'n07740461_14740.jpg'\n", 167 | "TEST_B_FILE = 'n07749192_4241.jpg'" 168 | ] 
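The three `lambda_*` arguments weight the components of the combined generator objective: an adversarial validity term, the cycle-consistency reconstruction term, and an identity term that discourages needless colour shifts. As a plain function (a sketch of the weighting only; the compiled model in `models/cycleGAN.py` expresses the same thing through Keras `loss_weights`):

```python
def generator_objective(validity_loss, reconstr_loss, id_loss,
                        lambda_validation=1, lambda_reconstr=10, lambda_id=2):
    # with these defaults, cycle consistency carries 10x the weight of fooling the discriminators
    return (lambda_validation * validity_loss
            + lambda_reconstr * reconstr_loss
            + lambda_id * id_loss)
```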
169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "scrolled": false 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "gan.train(data_loader\n", 179 | " , run_folder = RUN_FOLDER\n", 180 | " , epochs=EPOCHS\n", 181 | " , test_A_file = TEST_A_FILE\n", 182 | " , test_B_file = TEST_B_FILE\n", 183 | " , batch_size=BATCH_SIZE\n", 184 | " , sample_interval=PRINT_EVERY_N_BATCHES)\n", 185 | " " 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "# loss" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "fig = plt.figure(figsize=(20,10))\n", 202 | "\n", 203 | "plt.plot([x[1] for x in gan.g_losses], color='green', linewidth=0.1) #DISCRIM LOSS\n", 204 | "# plt.plot([x[2] for x in gan.g_losses], color='orange', linewidth=0.1)\n", 205 | "plt.plot([x[3] for x in gan.g_losses], color='blue', linewidth=0.1) #CYCLE LOSS\n", 206 | "# plt.plot([x[4] for x in gan.g_losses], color='orange', linewidth=0.25)\n", 207 | "plt.plot([x[5] for x in gan.g_losses], color='red', linewidth=0.25) #ID LOSS\n", 208 | "# plt.plot([x[6] for x in gan.g_losses], color='orange', linewidth=0.25)\n", 209 | "\n", 210 | "plt.plot([x[0] for x in gan.g_losses], color='black', linewidth=0.25)\n", 211 | "\n", 212 | "# plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)\n", 213 | "\n", 214 | "plt.xlabel('batch', fontsize=18)\n", 215 | "plt.ylabel('loss', fontsize=16)\n", 216 | "\n", 217 | "plt.ylim(0, 5)\n", 218 | "\n", 219 | "plt.show()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "gdl_code", 233 | "language": "python", 234 | "name": "gdl_code" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.7.5" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /06_01_lstm_text_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# LSTM Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "import re\n", 25 | "from IPython.display import clear_output\n", 26 | "\n", 27 | "from tensorflow.keras.layers import Dense, LSTM, Input, Embedding, Dropout\n", 28 | "from tensorflow.keras.utils import to_categorical\n", 29 | "from tensorflow.keras.models import Model, load_model\n", 30 | "from tensorflow.keras.optimizers import Adam, RMSprop\n", 31 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", 32 | "from tensorflow.keras.preprocessing.text import Tokenizer\n", 33 | "from tensorflow.keras.callbacks import LambdaCallback" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | 
"outputs": [], 41 | "source": [ 42 | "load_saved_model = False\n", 43 | "train_model = True" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "token_type = 'word'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## data" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "#load in the text and perform some cleanup\n", 69 | "\n", 70 | "seq_length = 20\n", 71 | "\n", 72 | "filename = \"./data/aesop/data.txt\"\n", 73 | "\n", 74 | "with open(filename, encoding='utf-8-sig') as f:\n", 75 | " text = f.read()\n", 76 | " \n", 77 | " \n", 78 | "#removing text before and after the main stories\n", 79 | "start = text.find(\"THE FOX AND THE GRAPES\\n\\n\\n\")\n", 80 | "end = text.find(\"ILLUSTRATIONS\\n\\n\\n[\")\n", 81 | "text = text[start:end]" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "start_story = '| ' * seq_length\n", 91 | " \n", 92 | "text = start_story + text\n", 93 | "text = text.lower()\n", 94 | "text = text.replace('\\n\\n\\n\\n\\n', start_story)\n", 95 | "text = text.replace('\\n', ' ')\n", 96 | "text = re.sub(' +', '. ', text).strip()\n", 97 | "text = text.replace('..', '.')\n", 98 | "\n", 99 | "text = re.sub('([!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~])', r' \\1 ', text)\n", 100 | "text = re.sub('\\s{2,}', ' ', text)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "len(text)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "text" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "\n", 128 | "if token_type == 'word':\n", 129 | " tokenizer = Tokenizer(char_level = False, filters = '')\n", 130 | "else:\n", 131 | " tokenizer = Tokenizer(char_level = True, filters = '', lower = False)\n", 132 | " \n", 133 | " \n", 134 | "tokenizer.fit_on_texts([text])\n", 135 | "\n", 136 | "total_words = len(tokenizer.word_index) + 1\n", 137 | "\n", 138 | "token_list = tokenizer.texts_to_sequences([text])[0]\n" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "total_words" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "scrolled": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "print(tokenizer.word_index)\n", 159 | "print(token_list)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "def generate_sequences(token_list, step):\n", 169 | " \n", 170 | " X = []\n", 171 | " y = []\n", 172 | "\n", 173 | " for i in range(0, len(token_list) - seq_length, step):\n", 174 | " X.append(token_list[i: i + seq_length])\n", 175 | " y.append(token_list[i + seq_length])\n", 176 | " \n", 177 | "\n", 178 | " y = to_categorical(y, num_classes = total_words)\n", 179 | " \n", 180 | " num_seq = len(X)\n", 181 | " print('Number of sequences:', num_seq, \"\\n\")\n", 182 | " \n", 183 | " return X, y, num_seq\n", 184 | "\n", 185 | "step = 1\n", 186 | 
"seq_length = 20\n", 187 | "\n", 188 | "X, y, num_seq = generate_sequences(token_list, step)\n", 189 | "\n", 190 | "X = np.array(X)\n", 191 | "y = np.array(y)\n" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "X.shape" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "y.shape" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "## Define the LSTM model" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "if load_saved_model:\n", 226 | " # model = load_model('./saved_models/lstm_aesop_1.h5')\n", 227 | " model = load_model('./saved_models/aesop_dropout_100.h5')\n", 228 | "\n", 229 | "else:\n", 230 | "\n", 231 | " n_units = 256\n", 232 | " embedding_size = 100\n", 233 | "\n", 234 | " text_in = Input(shape = (None,))\n", 235 | " embedding = Embedding(total_words, embedding_size)\n", 236 | " x = embedding(text_in)\n", 237 | " x = LSTM(n_units)(x)\n", 238 | " # x = Dropout(0.2)(x)\n", 239 | " text_out = Dense(total_words, activation = 'softmax')(x)\n", 240 | "\n", 241 | " model = Model(text_in, text_out)\n", 242 | "\n", 243 | " opti = RMSprop(lr = 0.001)\n", 244 | " model.compile(loss='categorical_crossentropy', optimizer=opti)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "model.summary()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "def sample_with_temp(preds, temperature=1.0):\n", 263 | " # helper function to sample an index from a probability array\n", 264 | " preds = np.asarray(preds).astype('float64')\n", 265 | " preds = np.log(preds) / temperature\n", 266 | " exp_preds = np.exp(preds)\n", 267 | " preds = exp_preds / np.sum(exp_preds)\n", 268 | " probas = np.random.multinomial(1, preds, 1)\n", 269 | " return np.argmax(probas)\n", 270 | "\n", 271 | "\n", 272 | "\n", 273 | "def generate_text(seed_text, next_words, model, max_sequence_len, temp):\n", 274 | " output_text = seed_text\n", 275 | " \n", 276 | " seed_text = start_story + seed_text\n", 277 | " \n", 278 | " for _ in range(next_words):\n", 279 | " token_list = tokenizer.texts_to_sequences([seed_text])[0]\n", 280 | " token_list = token_list[-max_sequence_len:]\n", 281 | " token_list = np.reshape(token_list, (1, max_sequence_len))\n", 282 | " \n", 283 | " probs = model.predict(token_list, verbose=0)[0]\n", 284 | " y_class = sample_with_temp(probs, temperature = temp)\n", 285 | " \n", 286 | " if y_class == 0:\n", 287 | " output_word = ''\n", 288 | " else:\n", 289 | " output_word = tokenizer.index_word[y_class]\n", 290 | " \n", 291 | " if output_word == \"|\":\n", 292 | " break\n", 293 | " \n", 294 | " if token_type == 'word':\n", 295 | " output_text += output_word + ' '\n", 296 | " seed_text += output_word + ' '\n", 297 | " else:\n", 298 | " output_text += output_word + ' '\n", 299 | " seed_text += output_word + ' '\n", 300 | " \n", 301 | " \n", 302 | " return output_text" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "def on_epoch_end(epoch, logs):\n", 312 | " seed_text = \"\"\n", 
313 | " gen_words = 500\n", 314 | "\n", 315 | " print('Temp 0.2')\n", 316 | " print (generate_text(seed_text, gen_words, model, seq_length, temp = 0.2))\n", 317 | " print('Temp 0.33')\n", 318 | " print (generate_text(seed_text, gen_words, model, seq_length, temp = 0.33))\n", 319 | " print('Temp 0.5')\n", 320 | " print (generate_text(seed_text, gen_words, model, seq_length, temp = 0.5))\n", 321 | " print('Temp 1.0')\n", 322 | " print (generate_text(seed_text, gen_words, model, seq_length, temp = 1))\n", 323 | "\n", 324 | " \n", 325 | " \n", 326 | "if train_model:\n", 327 | " epochs = 1000\n", 328 | " batch_size = 32\n", 329 | " num_batches = int(len(X) / batch_size)\n", 330 | " callback = LambdaCallback(on_epoch_end=on_epoch_end)\n", 331 | " model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks = [callback], shuffle = True)\n", 332 | "\n", 333 | "\n" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "model.summary()" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "seed_text = \"the frog and the snake . \"\n", 352 | "gen_words = 500\n", 353 | "temp = 0.1\n", 354 | "\n", 355 | "print (generate_text(seed_text, gen_words, model, seq_length, temp))" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "def generate_human_led_text(model, max_sequence_len):\n", 365 | " \n", 366 | " output_text = ''\n", 367 | " seed_text = start_story\n", 368 | " \n", 369 | " while 1:\n", 370 | " token_list = tokenizer.texts_to_sequences([seed_text])[0]\n", 371 | " token_list = token_list[-max_sequence_len:]\n", 372 | " token_list = np.reshape(token_list, (1, max_sequence_len))\n", 373 | " \n", 374 | " probs = model.predict(token_list, verbose=0)[0]\n", 375 | "\n", 376 | " top_10_idx = np.flip(np.argsort(probs)[-10:])\n", 377 | " top_10_probs = [probs[x] for x in top_10_idx]\n", 378 | " top_10_words = tokenizer.sequences_to_texts([[x] for x in top_10_idx])\n", 379 | " \n", 380 | " for prob, word in zip(top_10_probs, top_10_words):\n", 381 | " print('{:<6.1%} : {}'.format(prob, word))\n", 382 | "\n", 383 | " chosen_word = input()\n", 384 | " \n", 385 | " if chosen_word == '|':\n", 386 | " break\n", 387 | " \n", 388 | " \n", 389 | " seed_text += chosen_word + ' '\n", 390 | " output_text += chosen_word + ' '\n", 391 | " \n", 392 | " clear_output()\n", 393 | "\n", 394 | " print (output_text)\n", 395 | " \n", 396 | " \n", 397 | " " 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "generate_human_led_text(model, 20)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "# model.save('./saved_models/aesop_no_dropout_100.h5')" 416 | ] 417 | } 418 | ], 419 | "metadata": { 420 | "kernelspec": { 421 | "display_name": "gdl_code", 422 | "language": "python", 423 | "name": "gdl_code" 424 | }, 425 | "language_info": { 426 | "codemirror_mode": { 427 | "name": "ipython", 428 | "version": 3 429 | }, 430 | "file_extension": ".py", 431 | "mimetype": "text/x-python", 432 | "name": "python", 433 | "nbconvert_exporter": "python", 434 | "pygments_lexer": "ipython3", 435 | "version": "3.7.5" 436 | } 437 | }, 438 | "nbformat": 4, 439 | 
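The human-led generator above relies on `np.flip(np.argsort(probs)[-10:])` to list the ten most probable next words, biggest first. The idiom on a small array:

```python
import numpy as np

probs = np.array([0.1, 0.4, 0.05, 0.25, 0.2])

top_3_idx = np.flip(np.argsort(probs)[-3:])
print(top_3_idx)         # [1 3 4] -- indices of the three largest values, descending
print(probs[top_3_idx])  # [0.4  0.25 0.2 ]
```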
"nbformat_minor": 2 440 | } 441 | -------------------------------------------------------------------------------- /06_03_qa_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import importlib\n", 10 | "import os\n", 11 | "\n", 12 | "from tensorflow.keras.layers import Input, Embedding, GRU, GRUCell, Bidirectional, TimeDistributed, Dense, Lambda\n", 13 | "from tensorflow.keras.models import Model, load_model\n", 14 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", 15 | "from tensorflow.keras.optimizers import Adam\n", 16 | "import tensorflow.keras.backend as K\n", 17 | "from tensorflow.keras.utils import plot_model\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import random\n", 21 | "\n", 22 | "from utils.write import training_data, test_data, collapse_documents, expand_answers, _read_data, get_glove, START_TOKEN, END_TOKEN, look_up_token\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# run params\n", 34 | "SECTION = 'write'\n", 35 | "RUN_ID = '0001'\n", 36 | "DATA_NAME = 'qa'\n", 37 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 38 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 39 | "\n", 40 | "if not os.path.exists(RUN_FOLDER):\n", 41 | " os.mkdir(RUN_FOLDER)\n", 42 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 43 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 44 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 45 | "\n", 46 | "mode = 'build' #'load' #" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "# data" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "#### LOAD DATA ####\n", 63 | "\n", 64 | "test_data_gen = test_data()\n", 65 | "batch = next(test_data_gen)\n", 66 | "batch = collapse_documents(batch)\n", 67 | "glove = get_glove()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "# parameters" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "VOCAB_SIZE = glove.shape[0]\n", 84 | "EMBEDDING_DIMENS = glove.shape[1]\n", 85 | "\n", 86 | "GRU_UNITS = 100\n", 87 | "MAX_DOC_SIZE = None\n", 88 | "MAX_ANSWER_SIZE = None\n", 89 | "MAX_Q_SIZE = None" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "# architecture" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "#### TRAINING MODEL ####\n", 106 | "\n", 107 | "document_tokens = Input(shape=(MAX_DOC_SIZE,), name=\"document_tokens\")\n", 108 | "\n", 109 | "embedding = Embedding(input_dim = VOCAB_SIZE, output_dim = EMBEDDING_DIMENS, weights=[glove], mask_zero = True, name = 'embedding')\n", 110 | "document_emb = embedding(document_tokens)\n", 111 | "\n", 112 | "answer_outputs = Bidirectional(GRU(GRU_UNITS, return_sequences=True), name = 'answer_outputs')(document_emb)\n", 113 | "answer_tags = Dense(2, activation = 'softmax', name = 'answer_tags')(answer_outputs)\n", 114 | "\n", 115 | "encoder_input_mask = Input(shape=(MAX_ANSWER_SIZE, MAX_DOC_SIZE), 
name=\"encoder_input_mask\")\n", 116 | "encoder_inputs = Lambda(lambda x: K.batch_dot(x[0], x[1]), name=\"encoder_inputs\")([encoder_input_mask, answer_outputs])\n", 117 | "encoder_cell = GRU(2 * GRU_UNITS, name = 'encoder_cell')(encoder_inputs)\n", 118 | "\n", 119 | "decoder_inputs = Input(shape=(MAX_Q_SIZE,), name=\"decoder_inputs\")\n", 120 | "decoder_emb = embedding(decoder_inputs)\n", 121 | "decoder_emb.trainable = False\n", 122 | "decoder_cell = GRU(2 * GRU_UNITS, return_sequences = True, name = 'decoder_cell')\n", 123 | "decoder_states = decoder_cell(decoder_emb, initial_state = [encoder_cell])\n", 124 | "\n", 125 | "decoder_projection = Dense(VOCAB_SIZE, name = 'decoder_projection', activation = 'softmax', use_bias = False)\n", 126 | "decoder_outputs = decoder_projection(decoder_states)\n", 127 | "\n", 128 | "total_model = Model([document_tokens, decoder_inputs, encoder_input_mask], [answer_tags, decoder_outputs])\n", 129 | "plot_model(total_model, to_file=os.path.join(RUN_FOLDER ,'viz/model.png'),show_shapes=True)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "#### INFERENCE MODEL ####\n", 139 | "\n", 140 | "decoder_inputs_dynamic = Input(shape=(1,), name=\"decoder_inputs_dynamic\")\n", 141 | "decoder_emb_dynamic = embedding(decoder_inputs_dynamic)\n", 142 | "decoder_init_state_dynamic = Input(shape=(2 * GRU_UNITS,), name = 'decoder_init_state_dynamic') #the embedding of the previous word\n", 143 | "decoder_states_dynamic = decoder_cell(decoder_emb_dynamic, initial_state = [decoder_init_state_dynamic])\n", 144 | "decoder_outputs_dynamic = decoder_projection(decoder_states_dynamic)\n", 145 | "\n", 146 | "answer_model = Model(document_tokens, [answer_tags])\n", 147 | "decoder_initial_state_model = Model([document_tokens, encoder_input_mask], [encoder_cell])\n", 148 | "question_model = Model([decoder_inputs_dynamic, decoder_init_state_dynamic], [decoder_outputs_dynamic, decoder_states_dynamic])\n", 149 | "\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "### LOAD MODEL WEIGHTS ####\n", 159 | "\n", 160 | "model_num = 1\n", 161 | "\n", 162 | "total_model.load_weights(os.path.join(RUN_FOLDER, 'weights/weights_{}.h5'.format(model_num)), by_name = True)\n", 163 | "question_model.load_weights(os.path.join(RUN_FOLDER, 'weights/weights_{}.h5'.format(model_num)), by_name = True)\n", 164 | "answer_model.load_weights(os.path.join(RUN_FOLDER, 'weights/weights_{}.h5'.format(model_num)), by_name = True)\n", 165 | "decoder_initial_state_model.load_weights(os.path.join(RUN_FOLDER, 'weights/weights_{}.h5'.format(model_num)), by_name = True)\n", 166 | "\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "# testing" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "# answer placement predictions\n", 183 | "\n", 184 | "plt.figure(figsize=(15,5))\n", 185 | "idx = 0\n", 186 | "\n", 187 | "answer_preds = answer_model.predict(batch[\"document_tokens\"])\n", 188 | "\n", 189 | "print('Predicted answer probabilities')\n", 190 | "ax = plt.gca()\n", 191 | "ax.xaxis.grid(True)\n", 192 | "plt.plot(answer_preds[idx, :, 1])\n", 193 | "plt.show()\n", 194 | "\n", 195 | "for i in range(len(batch['document_words'][idx])):\n", 196 | " print(i, 
batch['document_words'][idx][i], np.round(answer_preds[idx][i][1],2))" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "# set chosen answer position\n", 206 | "\n", 207 | "start_answer = 37\n", 208 | "end_answer = 39\n", 209 | "\n", 210 | "print(batch['document_words'][idx][start_answer:(1+end_answer)])" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "\n", 220 | "answer_preds = answer_model.predict(batch[\"document_tokens\"])\n", 221 | "\n", 222 | "answers = [[0] * len(answer_preds[idx])]\n", 223 | "for i in range(start_answer, end_answer + 1):\n", 224 | " answers[0][i] = 1\n", 225 | "\n", 226 | "answer_batch = expand_answers(batch, answers)\n", 227 | "\n", 228 | "next_decoder_init_state = decoder_initial_state_model.predict([answer_batch['document_tokens'][[idx]], answer_batch['answer_masks'][[idx]]])\n", 229 | "\n", 230 | "word_tokens = [START_TOKEN]\n", 231 | "questions = [look_up_token(START_TOKEN)]\n", 232 | "\n", 233 | "ended = False\n", 234 | "counter = 0\n", 235 | "\n", 236 | "while not ended:\n", 237 | " \n", 238 | " counter += 1\n", 239 | "\n", 240 | " word_preds, next_decoder_init_state = question_model.predict([np.array(word_tokens), next_decoder_init_state])\n", 241 | "\n", 242 | " next_decoder_init_state = np.squeeze(next_decoder_init_state, axis = 1)\n", 243 | " word_tokens = np.argmax(word_preds, 2)[0]\n", 244 | "\n", 245 | " questions.append(look_up_token(word_tokens[0]))\n", 246 | "\n", 247 | " if word_tokens[0] == END_TOKEN or counter > 20 :\n", 248 | " ended = True\n", 249 | "\n", 250 | "questions = ' '.join(questions)\n", 251 | "\n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "questions" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "gdl_code", 274 | "language": "python", 275 | "name": "gdl_code" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.7.5" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 2 292 | } 293 | -------------------------------------------------------------------------------- /07_01_notation_compose.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\n", 10 | "from music21 import converter, note, chord\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "# Getting the data\n", 18 | "\n", 19 | "You can find midi files for each of the 36 movements in the J.S. 
Bach Cello Suites here:\n", 20 | "\n", 21 | "http://www.jsbach.net/midi/midi_solo_cello.html\n", 22 | "\n", 23 | "Save these inside the './data/cello' folder" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Musical notation software\n", 31 | "\n", 32 | "You'll also need to download some software to view and listen to the music generated by the model.\n", 33 | "\n", 34 | "Musescore can be freely downloaded here:\n", 35 | "\n", 36 | "https://musescore.org/en" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Viewing the data" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "dataset_name = 'cello'\n", 53 | "filename = 'cs1-2all'\n", 54 | "file = \"./data/{}/{}.mid\".format(dataset_name, filename)\n", 55 | "\n", 56 | "original_score = converter.parse(file).chordify()" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "original_score.show()" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "scrolled": true 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "original_score.show('text')" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# Extracting the data" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "notes = []\n", 93 | "durations = []\n", 94 | "\n", 95 | "for element in original_score.flat:\n", 96 | " \n", 97 | " if isinstance(element, chord.Chord):\n", 98 | " notes.append('.'.join(n.nameWithOctave for n in element.pitches))\n", 99 | " durations.append(element.duration.quarterLength)\n", 100 | "\n", 101 | " if isinstance(element, note.Note):\n", 102 | " if element.isRest:\n", 103 | " notes.append(str(element.name))\n", 104 | " durations.append(element.duration.quarterLength)\n", 105 | " else:\n", 106 | " notes.append(str(element.nameWithOctave))\n", 107 | " durations.append(element.duration.quarterLength)\n", 108 | "\n", 109 | " " 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "print('\\nduration', 'pitch')\n", 119 | "for n,d in zip(notes,durations):\n", 120 | " print(d, '\\t', n)" 121 | ] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "gdl_code", 127 | "language": "python", 128 | "name": "gdl_code" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.7.5" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 2 145 | } 146 | -------------------------------------------------------------------------------- /07_02_lstm_compose_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Compose: Training a model to generate music" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import pickle\n", 18 | 
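The extraction loop in the notebook above flattens a chordified score into two parallel lists: pitch strings (chord notes joined with dots) and quarter-length durations. The same encoding on two hand-built music21 objects:

```python
from music21 import chord, note

c = chord.Chord(['C3', 'E3', 'G3'])
n = note.Note('D4')
n.duration.quarterLength = 0.5

print('.'.join(p.nameWithOctave for p in c.pitches), c.duration.quarterLength)  # C3.E3.G3 1.0
print(n.nameWithOctave, n.duration.quarterLength)                               # D4 0.5
```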
"import numpy\n", 19 | "from music21 import note, chord\n", 20 | "\n", 21 | "from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping\n", 22 | "from tensorflow.keras.utils import plot_model\n", 23 | "\n", 24 | "from models.RNNAttention import get_distinct, create_lookups, prepare_sequences, get_music_list, create_network" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Set parameters" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# run params\n", 41 | "section = 'compose'\n", 42 | "run_id = '0006'\n", 43 | "music_name = 'cello'\n", 44 | "\n", 45 | "run_folder = 'run/{}/'.format(section)\n", 46 | "run_folder += '_'.join([run_id, music_name])\n", 47 | "\n", 48 | "\n", 49 | "store_folder = os.path.join(run_folder, 'store')\n", 50 | "data_folder = os.path.join('data', music_name)\n", 51 | "\n", 52 | "if not os.path.exists(run_folder):\n", 53 | " os.mkdir(run_folder)\n", 54 | " os.mkdir(os.path.join(run_folder, 'store'))\n", 55 | " os.mkdir(os.path.join(run_folder, 'output'))\n", 56 | " os.mkdir(os.path.join(run_folder, 'weights'))\n", 57 | " os.mkdir(os.path.join(run_folder, 'viz'))\n", 58 | " \n", 59 | "\n", 60 | "\n", 61 | "mode = 'build' # 'load' # \n", 62 | "\n", 63 | "# data params\n", 64 | "intervals = range(1)\n", 65 | "seq_len = 32\n", 66 | "\n", 67 | "# model params\n", 68 | "embed_size = 100\n", 69 | "rnn_units = 256\n", 70 | "use_attention = True" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## Extract the notes" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "if mode == 'build':\n", 87 | " \n", 88 | " music_list, parser = get_music_list(data_folder)\n", 89 | " print(len(music_list), 'files in total')\n", 90 | "\n", 91 | " notes = []\n", 92 | " durations = []\n", 93 | "\n", 94 | " for i, file in enumerate(music_list):\n", 95 | " print(i+1, \"Parsing %s\" % file)\n", 96 | " original_score = parser.parse(file).chordify()\n", 97 | " \n", 98 | "\n", 99 | " for interval in intervals:\n", 100 | "\n", 101 | " score = original_score.transpose(interval)\n", 102 | "\n", 103 | " notes.extend(['START'] * seq_len)\n", 104 | " durations.extend([0]* seq_len)\n", 105 | "\n", 106 | " for element in score.flat:\n", 107 | " \n", 108 | " if isinstance(element, note.Note):\n", 109 | " if element.isRest:\n", 110 | " notes.append(str(element.name))\n", 111 | " durations.append(element.duration.quarterLength)\n", 112 | " else:\n", 113 | " notes.append(str(element.nameWithOctave))\n", 114 | " durations.append(element.duration.quarterLength)\n", 115 | "\n", 116 | " if isinstance(element, chord.Chord):\n", 117 | " notes.append('.'.join(n.nameWithOctave for n in element.pitches))\n", 118 | " durations.append(element.duration.quarterLength)\n", 119 | "\n", 120 | " with open(os.path.join(store_folder, 'notes'), 'wb') as f:\n", 121 | " pickle.dump(notes, f) #['G2', 'D3', 'B3', 'A3', 'B3', 'D3', 'B3', 'D3', 'G2',...]\n", 122 | " with open(os.path.join(store_folder, 'durations'), 'wb') as f:\n", 123 | " pickle.dump(durations, f) \n", 124 | "else:\n", 125 | " with open(os.path.join(store_folder, 'notes'), 'rb') as f:\n", 126 | " notes = pickle.load(f) #['G2', 'D3', 'B3', 'A3', 'B3', 'D3', 'B3', 'D3', 'G2',...]\n", 127 | " with open(os.path.join(store_folder, 'durations'), 'rb') as f:\n", 128 | " durations = 
pickle.load(f) " 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "## Create the lookup tables" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# get the distinct sets of notes and durations\n", 145 | "note_names, n_notes = get_distinct(notes)\n", 146 | "duration_names, n_durations = get_distinct(durations)\n", 147 | "distincts = [note_names, n_notes, duration_names, n_durations]\n", 148 | "\n", 149 | "with open(os.path.join(store_folder, 'distincts'), 'wb') as f:\n", 150 | " pickle.dump(distincts, f)\n", 151 | "\n", 152 | "# make the lookup dictionaries for notes and dictionaries and save\n", 153 | "note_to_int, int_to_note = create_lookups(note_names)\n", 154 | "duration_to_int, int_to_duration = create_lookups(duration_names)\n", 155 | "lookups = [note_to_int, int_to_note, duration_to_int, int_to_duration]\n", 156 | "\n", 157 | "with open(os.path.join(store_folder, 'lookups'), 'wb') as f:\n", 158 | " pickle.dump(lookups, f)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "print('\\nnote_to_int')\n", 168 | "note_to_int" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "print('\\nduration_to_int')\n", 178 | "duration_to_int" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "## Prepare the sequences used by the Neural Network" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "network_input, network_output = prepare_sequences(notes, durations, lookups, distincts, seq_len)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "print('pitch input')\n", 204 | "print(network_input[0][0])\n", 205 | "print('duration input')\n", 206 | "print(network_input[1][0])\n", 207 | "print('pitch output')\n", 208 | "print(network_output[0][0])\n", 209 | "print('duration output')\n", 210 | "print(network_output[1][0])" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "## Create the structure of the neural network" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "model, att_model = create_network(n_notes, n_durations, embed_size, rnn_units, use_attention)\n", 227 | "model.summary()" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "#Currently errors in TF2.2\n", 237 | "#plot_model(model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes = True, show_layer_names = True)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "## Train the neural network" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "weights_folder = os.path.join(run_folder, 'weights')\n", 254 | "# model.load_weights(os.path.join(weights_folder, \"weights.h5\"))" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | 
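With `use_attention = True`, `create_network` (in `models/RNNAttention.py`) adds an alignment layer over the LSTM's per-step hidden states, so each prediction uses a learned weighted sum of the whole 32-note window instead of only the final state. A numpy sketch of that pooling, with hidden states of shape `(batch, seq_len, units)`:

```python
import numpy as np

def attention_pool(hidden_states, e_weights, e_bias):
    # alignment score per time step: a Dense(1, tanh) layer in the Keras version
    e = np.tanh(hidden_states @ e_weights + e_bias)           # (batch, seq_len, 1)
    alpha = np.exp(e) / np.exp(e).sum(axis=1, keepdims=True)  # softmax over time steps
    return (alpha * hidden_states).sum(axis=1)                # (batch, units) context vector
```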
"execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "weights_folder = os.path.join(run_folder, 'weights')\n", 264 | "\n", 265 | "checkpoint1 = ModelCheckpoint(\n", 266 | " os.path.join(weights_folder, \"weights-improvement-{epoch:02d}-{loss:.4f}-bigger.h5\"),\n", 267 | " monitor='loss',\n", 268 | " verbose=0,\n", 269 | " save_best_only=True,\n", 270 | " mode='min'\n", 271 | ")\n", 272 | "\n", 273 | "checkpoint2 = ModelCheckpoint(\n", 274 | " os.path.join(weights_folder, \"weights.h5\"),\n", 275 | " monitor='loss',\n", 276 | " verbose=0,\n", 277 | " save_best_only=True,\n", 278 | " mode='min'\n", 279 | ")\n", 280 | "\n", 281 | "early_stopping = EarlyStopping(\n", 282 | " monitor='loss'\n", 283 | " , restore_best_weights=True\n", 284 | " , patience = 10\n", 285 | ")\n", 286 | "\n", 287 | "\n", 288 | "callbacks_list = [\n", 289 | " checkpoint1\n", 290 | " , checkpoint2\n", 291 | " , early_stopping\n", 292 | " ]\n", 293 | "\n", 294 | "model.save_weights(os.path.join(weights_folder, \"weights.h5\"))\n", 295 | "model.fit(network_input, network_output\n", 296 | " , epochs=2000000, batch_size=32\n", 297 | " , validation_split = 0.2\n", 298 | " , callbacks=callbacks_list\n", 299 | " , shuffle=True\n", 300 | " )\n", 301 | "\n" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [] 310 | } 311 | ], 312 | "metadata": { 313 | "kernelspec": { 314 | "display_name": "gdl_code", 315 | "language": "python", 316 | "name": "gdl_code" 317 | }, 318 | "language_info": { 319 | "codemirror_mode": { 320 | "name": "ipython", 321 | "version": 3 322 | }, 323 | "file_extension": ".py", 324 | "mimetype": "text/x-python", 325 | "name": "python", 326 | "nbconvert_exporter": "python", 327 | "pygments_lexer": "ipython3", 328 | "version": "3.7.5" 329 | } 330 | }, 331 | "nbformat": 4, 332 | "nbformat_minor": 2 333 | } 334 | -------------------------------------------------------------------------------- /07_03_lstm_compose_predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# LSTM - Analysis" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pickle as pkl\n", 24 | "import time\n", 25 | "import os\n", 26 | "import numpy as np\n", 27 | "import sys\n", 28 | "from music21 import instrument, note, stream, chord, duration\n", 29 | "from models.RNNAttention import create_network, sample_with_temp\n", 30 | "\n", 31 | "import matplotlib.pyplot as plt\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# parameters" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# run params\n", 48 | "section = 'compose'\n", 49 | "run_id = '0006'\n", 50 | "music_name = 'cello'\n", 51 | "run_folder = 'run/{}/'.format(section)\n", 52 | "run_folder += '_'.join([run_id, music_name])\n", 53 | "\n", 54 | "# model params\n", 55 | "embed_size = 100\n", 56 | "rnn_units = 256\n", 57 | "use_attention = True\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## load the lookup tables" 65 | ] 66 | }, 67 | { 68 | 
"cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "\n", 74 | "store_folder = os.path.join(run_folder, 'store')\n", 75 | "\n", 76 | "with open(os.path.join(store_folder, 'distincts'), 'rb') as filepath:\n", 77 | " distincts = pkl.load(filepath)\n", 78 | " note_names, n_notes, duration_names, n_durations = distincts\n", 79 | "\n", 80 | "with open(os.path.join(store_folder, 'lookups'), 'rb') as filepath:\n", 81 | " lookups = pkl.load(filepath)\n", 82 | " note_to_int, int_to_note, duration_to_int, int_to_duration = lookups" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## build the model" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "weights_folder = os.path.join(run_folder, 'weights')\n", 99 | "weights_file = 'weights.h5'\n", 100 | "\n", 101 | "model, att_model = create_network(n_notes, n_durations, embed_size, rnn_units, use_attention)\n", 102 | "\n", 103 | "# Load the weights to each node\n", 104 | "weight_source = os.path.join(weights_folder,weights_file)\n", 105 | "model.load_weights(weight_source)\n", 106 | "model.summary()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## build your own phrase" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "# prediction params\n", 123 | "notes_temp=0.5\n", 124 | "duration_temp = 0.5\n", 125 | "max_extra_notes = 50\n", 126 | "max_seq_len = 32\n", 127 | "seq_len = 32\n", 128 | "\n", 129 | "# notes = ['START', 'D3', 'D3', 'E3', 'D3', 'G3', 'F#3','D3', 'D3', 'E3', 'D3', 'G3', 'F#3','D3', 'D3', 'E3', 'D3', 'G3', 'F#3','D3', 'D3', 'E3', 'D3', 'G3', 'F#3']\n", 130 | "# durations = [0, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2]\n", 131 | "\n", 132 | "\n", 133 | "# notes = ['START', 'F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3','F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3']\n", 134 | "# durations = [0, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2]\n", 135 | "\n", 136 | "\n", 137 | "notes = ['START']\n", 138 | "durations = [0]\n", 139 | "\n", 140 | "if seq_len is not None:\n", 141 | " notes = ['START'] * (seq_len - len(notes)) + notes\n", 142 | " durations = [0] * (seq_len - len(durations)) + durations\n", 143 | "\n", 144 | "\n", 145 | "sequence_length = len(notes)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "## Generate notes from the neural network based on a sequence of notes" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "prediction_output = []\n", 162 | "notes_input_sequence = []\n", 163 | "durations_input_sequence = []\n", 164 | "\n", 165 | "overall_preds = []\n", 166 | "\n", 167 | "for n, d in zip(notes,durations):\n", 168 | " note_int = note_to_int[n]\n", 169 | " duration_int = duration_to_int[d]\n", 170 | " \n", 171 | " notes_input_sequence.append(note_int)\n", 172 | " durations_input_sequence.append(duration_int)\n", 173 | " \n", 174 | " prediction_output.append([n, d])\n", 175 | " \n", 176 | " if n != 'START':\n", 177 | " midi_note = note.Note(n)\n", 178 | 
"\n", 179 | " new_note = np.zeros(128)\n", 180 | " new_note[midi_note.pitch.midi] = 1\n", 181 | " overall_preds.append(new_note)\n", 182 | "\n", 183 | "\n", 184 | "att_matrix = np.zeros(shape = (max_extra_notes+sequence_length, max_extra_notes))\n", 185 | "\n", 186 | "for note_index in range(max_extra_notes):\n", 187 | "\n", 188 | " prediction_input = [\n", 189 | " np.array([notes_input_sequence])\n", 190 | " , np.array([durations_input_sequence])\n", 191 | " ]\n", 192 | "\n", 193 | " notes_prediction, durations_prediction = model.predict(prediction_input, verbose=0)\n", 194 | " if use_attention:\n", 195 | " att_prediction = att_model.predict(prediction_input, verbose=0)[0]\n", 196 | " att_matrix[(note_index-len(att_prediction)+sequence_length):(note_index+sequence_length), note_index] = att_prediction\n", 197 | " \n", 198 | " new_note = np.zeros(128)\n", 199 | " \n", 200 | " for idx, n_i in enumerate(notes_prediction[0]):\n", 201 | " try:\n", 202 | " note_name = int_to_note[idx]\n", 203 | " midi_note = note.Note(note_name)\n", 204 | " new_note[midi_note.pitch.midi] = n_i\n", 205 | " \n", 206 | " except:\n", 207 | " pass\n", 208 | " \n", 209 | " overall_preds.append(new_note)\n", 210 | " \n", 211 | " \n", 212 | " i1 = sample_with_temp(notes_prediction[0], notes_temp)\n", 213 | " i2 = sample_with_temp(durations_prediction[0], duration_temp)\n", 214 | " \n", 215 | "\n", 216 | " note_result = int_to_note[i1]\n", 217 | " duration_result = int_to_duration[i2]\n", 218 | " \n", 219 | " prediction_output.append([note_result, duration_result])\n", 220 | "\n", 221 | " notes_input_sequence.append(i1)\n", 222 | " durations_input_sequence.append(i2)\n", 223 | " \n", 224 | " if len(notes_input_sequence) > max_seq_len:\n", 225 | " notes_input_sequence = notes_input_sequence[1:]\n", 226 | " durations_input_sequence = durations_input_sequence[1:]\n", 227 | " \n", 228 | "# print(note_result)\n", 229 | "# print(duration_result)\n", 230 | " \n", 231 | " if note_result == 'START':\n", 232 | " break\n", 233 | "\n", 234 | "overall_preds = np.transpose(np.array(overall_preds)) \n", 235 | "print('Generated sequence of {} notes'.format(len(prediction_output)))" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "fig, ax = plt.subplots(figsize=(15,15))\n", 245 | "ax.set_yticks([int(j) for j in range(35,70)])\n", 246 | "\n", 247 | "plt.imshow(overall_preds[35:70,:], origin=\"lower\", cmap='coolwarm', vmin = -0.5, vmax = 0.5, extent=[0, max_extra_notes, 35,70]\n", 248 | " \n", 249 | " )" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "## convert the output from the prediction to notes and create a midi file from the notes " 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "output_folder = os.path.join(run_folder, 'output')\n", 266 | "\n", 267 | "midi_stream = stream.Stream()\n", 268 | "\n", 269 | "# create note and chord objects based on the values generated by the model\n", 270 | "for pattern in prediction_output:\n", 271 | " note_pattern, duration_pattern = pattern\n", 272 | " # pattern is a chord\n", 273 | " if ('.' 
in note_pattern):\n", 274 | " notes_in_chord = note_pattern.split('.')\n", 275 | " chord_notes = []\n", 276 | " for current_note in notes_in_chord:\n", 277 | " new_note = note.Note(current_note)\n", 278 | " new_note.duration = duration.Duration(duration_pattern)\n", 279 | " new_note.storedInstrument = instrument.Violoncello()\n", 280 | " chord_notes.append(new_note)\n", 281 | " new_chord = chord.Chord(chord_notes)\n", 282 | " midi_stream.append(new_chord)\n", 283 | " elif note_pattern == 'rest':\n", 284 | " # pattern is a rest\n", 285 | " new_note = note.Rest()\n", 286 | " new_note.duration = duration.Duration(duration_pattern)\n", 287 | " new_note.storedInstrument = instrument.Violoncello()\n", 288 | " midi_stream.append(new_note)\n", 289 | " elif note_pattern != 'START':\n", 290 | " # pattern is a note\n", 291 | " new_note = note.Note(note_pattern)\n", 292 | " new_note.duration = duration.Duration(duration_pattern)\n", 293 | " new_note.storedInstrument = instrument.Violoncello()\n", 294 | " midi_stream.append(new_note)\n", 295 | "\n", 296 | "\n", 297 | "\n", 298 | "midi_stream = midi_stream.chordify()\n", 299 | "timestr = time.strftime(\"%Y%m%d-%H%M%S\")\n", 300 | "midi_stream.write('midi', fp=os.path.join(output_folder, 'output-' + timestr + '.mid'))" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "## attention plot\n", 310 | "if use_attention:\n", 311 | " fig, ax = plt.subplots(figsize=(20,20))\n", 312 | "\n", 313 | " im = ax.imshow(att_matrix[(seq_len-2):,], cmap='coolwarm', interpolation='nearest')\n", 314 | "\n", 315 | "\n", 316 | " \n", 317 | "\n", 318 | " # Minor ticks\n", 319 | " ax.set_xticks(np.arange(-.5, len(prediction_output)- seq_len, 1), minor=True);\n", 320 | " ax.set_yticks(np.arange(-.5, len(prediction_output)- seq_len, 1), minor=True);\n", 321 | "\n", 322 | " # Gridlines based on minor ticks\n", 323 | " ax.grid(which='minor', color='black', linestyle='-', linewidth=1)\n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " # We want to show all ticks...\n", 329 | " ax.set_xticks(np.arange(len(prediction_output) - seq_len))\n", 330 | " ax.set_yticks(np.arange(len(prediction_output)- seq_len+2))\n", 331 | " # ... 
and label them with the respective list entries\n", 332 | " ax.set_xticklabels([n[0] for n in prediction_output[(seq_len):]])\n", 333 | " ax.set_yticklabels([n[0] for n in prediction_output[(seq_len - 2):]])\n", 334 | "\n", 335 | " # ax.grid(color='black', linestyle='-', linewidth=1)\n", 336 | "\n", 337 | " ax.xaxis.tick_top()\n", 338 | "\n", 339 | "\n", 340 | " \n", 341 | " plt.setp(ax.get_xticklabels(), rotation=90, ha=\"left\", va = \"center\",\n", 342 | " rotation_mode=\"anchor\")\n", 343 | "\n", 344 | " plt.show()" 345 | ] 346 | } 347 | ], 348 | "metadata": { 349 | "kernelspec": { 350 | "display_name": "gdl_code", 351 | "language": "python", 352 | "name": "gdl_code" 353 | }, 354 | "language_info": { 355 | "codemirror_mode": { 356 | "name": "ipython", 357 | "version": 3 358 | }, 359 | "file_extension": ".py", 360 | "mimetype": "text/x-python", 361 | "name": "python", 362 | "nbconvert_exporter": "python", 363 | "pygments_lexer": "ipython3", 364 | "version": "3.7.5" 365 | } 366 | }, 367 | "nbformat": 4, 368 | "nbformat_minor": 2 369 | } 370 | -------------------------------------------------------------------------------- /07_04_musegan_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MuseGAN Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import numpy as np\n", 26 | "import types\n", 27 | "\n", 28 | "from models.MuseGAN import MuseGAN\n", 29 | "from utils.loaders import load_music\n", 30 | "\n", 31 | "from music21 import midi\n", 32 | "from music21 import note, stream, duration" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# run params\n", 42 | "SECTION = 'compose'\n", 43 | "RUN_ID = '001'\n", 44 | "DATA_NAME = 'chorales'\n", 45 | "FILENAME = 'Jsb16thSeparated.npz'\n", 46 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 47 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n", 48 | "\n", 49 | "\n", 50 | "\n", 51 | "if not os.path.exists(RUN_FOLDER):\n", 52 | " os.mkdir(RUN_FOLDER)\n", 53 | " os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n", 54 | " os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n", 55 | " os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n", 56 | " os.mkdir(os.path.join(RUN_FOLDER, 'samples'))\n", 57 | "\n", 58 | "mode = 'build' # ' 'load' # " 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "## data" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "BATCH_SIZE = 64\n", 75 | "n_bars = 2\n", 76 | "n_steps_per_bar = 16\n", 77 | "n_pitches = 84\n", 78 | "n_tracks = 4\n", 79 | "\n", 80 | "data_binary, data_ints, raw_data = load_music(DATA_NAME, FILENAME, n_bars, n_steps_per_bar)\n", 81 | "data_binary = np.squeeze(data_binary)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## architecture" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "gan = MuseGAN(input_dim = data_binary.shape[1:]\n", 98 | " , 
critic_learning_rate = 0.001\n", 99 | " , generator_learning_rate = 0.001\n", 100 | " , optimiser = 'adam'\n", 101 | " , grad_weight = 10\n", 102 | " , z_dim = 32\n", 103 | " , batch_size = BATCH_SIZE\n", 104 | " , n_tracks = n_tracks\n", 105 | " , n_bars = n_bars\n", 106 | " , n_steps_per_bar = n_steps_per_bar\n", 107 | " , n_pitches = n_pitches\n", 108 | " )\n", 109 | "\n", 110 | "if mode == 'build':\n", 111 | " gan.save(RUN_FOLDER)\n", 112 | "else: \n", 113 | " gan.load_weights(RUN_FOLDER)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "gan.chords_tempNetwork.summary()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "gan.barGen[0].summary()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "gan.generator.summary()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": { 147 | "scrolled": false 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "gan.critic.summary()" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "## training" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "\n", 168 | "EPOCHS = 6000\n", 169 | "PRINT_EVERY_N_BATCHES = 10\n", 170 | "\n", 171 | "gan.epoch = 0" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "scrolled": true 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "gan.train( \n", 183 | " data_binary\n", 184 | " , batch_size = BATCH_SIZE\n", 185 | " , epochs = EPOCHS\n", 186 | " , run_folder = RUN_FOLDER\n", 187 | " , print_every_n_batches = PRINT_EVERY_N_BATCHES\n", 188 | ")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "fig = plt.figure()\n", 198 | "plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)\n", 199 | "\n", 200 | "plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)\n", 201 | "plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)\n", 202 | "plt.plot(gan.g_losses, color='orange', linewidth=0.25)\n", 203 | "\n", 204 | "plt.xlabel('batch', fontsize=18)\n", 205 | "plt.ylabel('loss', fontsize=16)\n", 206 | "\n", 207 | "plt.xlim(0, len(gan.d_losses))\n", 208 | "# plt.ylim(0, 2)\n", 209 | "\n", 210 | "plt.show()\n" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "gdl_code", 217 | "language": "python", 218 | "name": "gdl_code" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.7.5" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 2 235 | } 236 | -------------------------------------------------------------------------------- /07_05_musegan_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | 
"outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "from music21 import midi\n", 14 | "from music21 import note, stream, duration\n", 15 | "from music21 import converter\n", 16 | "\n", 17 | "from models.MuseGAN import MuseGAN\n", 18 | "\n", 19 | "from utils.loaders import load_music\n", 20 | "\n", 21 | "from tensorflow.keras.models import load_model" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# run params\n", 31 | "SECTION = 'compose'\n", 32 | "RUN_ID = '001'\n", 33 | "DATA_NAME = 'chorales'\n", 34 | "FILENAME = 'Jsb16thSeparated.npz'\n", 35 | "RUN_FOLDER = 'run/{}/'.format(SECTION)\n", 36 | "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## data" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "BATCH_SIZE = 64\n", 53 | "n_bars = 2\n", 54 | "n_steps_per_bar = 16\n", 55 | "n_pitches = 84\n", 56 | "n_tracks = 4\n", 57 | "\n", 58 | "data_binary, data_ints, raw_data = load_music(DATA_NAME, FILENAME, n_bars, n_steps_per_bar)\n", 59 | "# data_binary = np.squeeze(data_binary)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "gan = MuseGAN(input_dim = data_binary.shape[1:]\n", 69 | " , critic_learning_rate = 0.001\n", 70 | " , generator_learning_rate = 0.001\n", 71 | " , optimiser = 'adam'\n", 72 | " , grad_weight = 10\n", 73 | " , z_dim = 32\n", 74 | " , batch_size = BATCH_SIZE\n", 75 | " , n_tracks = n_tracks\n", 76 | " , n_bars = n_bars\n", 77 | " , n_steps_per_bar = n_steps_per_bar\n", 78 | " , n_pitches = n_pitches\n", 79 | " )\n", 80 | "\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "gan.load_weights(RUN_FOLDER, None)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "scrolled": false 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "gan.generator.summary()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "gan.critic.summary()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "# view sample score" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "chords_noise = np.random.normal(0, 1, (1, gan.z_dim))\n", 126 | "style_noise = np.random.normal(0, 1, (1, gan.z_dim))\n", 127 | "melody_noise = np.random.normal(0, 1, (1, gan.n_tracks, gan.z_dim))\n", 128 | "groove_noise = np.random.normal(0, 1, (1, gan.n_tracks, gan.z_dim))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "gen_scores = gan.generator.predict([chords_noise, style_noise, melody_noise, groove_noise])" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "np.argmax(gen_scores[0,0,0:4,:,3], axis = 1)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 
152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "gen_scores[0,0,0:4,60,3] = 0.02347812" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "filename = 'example'\n", 165 | "gan.notes_to_midi(RUN_FOLDER, gen_scores, filename)\n", 166 | "gen_score = converter.parse(os.path.join(RUN_FOLDER, 'samples/{}.midi'.format(filename)))\n", 167 | "gen_score.show()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "gan.draw_score(gen_scores, 0)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "# find the closest match" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "def find_closest(data_binary, score):\n", 193 | " current_dist = 99999999\n", 194 | " current_i = -1\n", 195 | " for i, d in enumerate(data_binary):\n", 196 | " dist = np.sqrt(np.sum(pow((d - score),2)))\n", 197 | " if dist < current_dist:\n", 198 | " current_i = i\n", 199 | " current_dist = dist\n", 200 | " \n", 201 | " return current_i\n", 202 | " " 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "closest_idx = find_closest(data_binary, gen_scores[0])\n", 212 | "closest_data = data_binary[[closest_idx]]\n", 213 | "print(closest_idx)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "filename = 'closest'\n", 223 | "gan.notes_to_midi(RUN_FOLDER, closest_data,filename)\n", 224 | "closest_score = converter.parse(os.path.join(RUN_FOLDER, 'samples/{}.midi'.format(filename)))\n", 225 | "print('original')\n", 226 | "gen_score.show()\n", 227 | "print('closest')\n", 228 | "closest_score.show()\n" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "# changing the chords noise" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "chords_noise_2 = 5 * np.ones((1, gan.z_dim))" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "chords_scores = gan.generator.predict([chords_noise_2, style_noise, melody_noise, groove_noise])" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "filename = 'changing_chords'\n", 263 | "gan.notes_to_midi(RUN_FOLDER, chords_scores, filename)\n", 264 | "chords_score = converter.parse(os.path.join(RUN_FOLDER, 'samples/{}.midi'.format(filename)))\n", 265 | "print('original')\n", 266 | "gen_score.show()\n", 267 | "print('chords noise changed')\n", 268 | "chords_score.show()" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "# changing the style noise" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "style_noise_2 = 5 * np.ones((1, gan.z_dim))" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 
| "source": [ 293 | "style_scores = gan.generator.predict([chords_noise, style_noise_2, melody_noise, groove_noise])" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "filename = 'changing_style'\n", 303 | "gan.notes_to_midi(RUN_FOLDER, style_scores, filename)\n", 304 | "style_score = converter.parse(os.path.join(RUN_FOLDER, 'samples/{}.midi'.format(filename)))\n", 305 | "print('original')\n", 306 | "gen_score.show()\n", 307 | "print('style noise changed')\n", 308 | "style_score.show()" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "# changing the melody noise" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "melody_noise_2 = np.copy(melody_noise)\n", 325 | "melody_noise_2[0,0,:] = 5 * np.ones(gan.z_dim) " 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "melody_scores = gan.generator.predict([chords_noise, style_noise, melody_noise_2, groove_noise])" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "filename = 'changing_melody'\n", 344 | "gan.notes_to_midi(RUN_FOLDER, melody_scores, filename)\n", 345 | "melody_score = converter.parse(os.path.join(RUN_FOLDER, 'samples/{}.midi'.format(filename)))\n", 346 | "print('original')\n", 347 | "gen_score.show()\n", 348 | "print('melody noise changed')\n", 349 | "melody_score.show()" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "# changing the groove noise" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "groove_noise_2 = np.copy(groove_noise)\n", 366 | "groove_noise_2[0,3,:] = 5 * np.ones(gan.z_dim)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "groove_scores = gan.generator.predict([chords_noise, style_noise, melody_noise, groove_noise_2])" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "filename = 'changing_groove'\n", 385 | "gan.notes_to_midi(RUN_FOLDER, groove_scores, filename)\n", 386 | "groove_score = converter.parse(os.path.join(RUN_FOLDER, 'samples/{}.midi'.format(filename)))\n", 387 | "print('original')\n", 388 | "gen_score.show()\n", 389 | "print('groove noise changed')\n", 390 | "groove_score.show()" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [] 399 | } 400 | ], 401 | "metadata": { 402 | "kernelspec": { 403 | "display_name": "gdl_code", 404 | "language": "python", 405 | "name": "gdl_code" 406 | }, 407 | "language_info": { 408 | "codemirror_mode": { 409 | "name": "ipython", 410 | "version": 3 411 | }, 412 | "file_extension": ".py", 413 | "mimetype": "text/x-python", 414 | "name": "python", 415 | "nbconvert_exporter": "python", 416 | "pygments_lexer": "ipython3", 417 | "version": "3.7.5" 418 | } 419 | }, 420 | "nbformat": 4, 421 | "nbformat_minor": 2 422 | } 423 | 
-------------------------------------------------------------------------------- /09_01_positional_encoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "seq_len = 128\n", 20 | "d_model = 512\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "pe = np.zeros((seq_len, d_model))" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "for pos in range(seq_len):\n", 39 | "    for i in range(int(d_model / 2)):\n", 40 | "        pe[pos,2*i] = np.sin(pos/(pow(10000,((2*i)/d_model))))\n", 41 | "        pe[pos,2*i+1] = np.cos(pos/(pow(10000,((2*i)/d_model))))\n", 42 | "        " 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "scrolled": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "fig= plt.figure(figsize=(20,5))\n", 54 | "# plt.xlabel('2i (d_model = 512)')\n", 55 | "# plt.ylabel('pos')\n", 56 | "plt.imshow(pe, cmap = 'Greys')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "enc = np.zeros((seq_len, d_model))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "enc = np.random.rand(seq_len, d_model)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "fig= plt.figure(figsize=(20,5))\n", 84 | "plt.imshow(enc, cmap = 'Greys')" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "out = enc + pe" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "fig= plt.figure(figsize=(20,5))\n", 103 | "plt.imshow(out, cmap = 'Greys')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [] 112 | } 113 | ], 114 | "metadata": { 115 | "kernelspec": { 116 | "display_name": "gdl_code", 117 | "language": "python", 118 | "name": "gdl_code" 119 | }, 120 | "language_info": { 121 | "codemirror_mode": { 122 | "name": "ipython", 123 | "version": 3 124 | }, 125 | "file_extension": ".py", 126 | "mimetype": "text/x-python", 127 | "name": "python", 128 | "nbconvert_exporter": "python", 129 | "pygments_lexer": "ipython3", 130 | "version": "3.7.5" 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 2 135 | } 136 | -------------------------------------------------------------------------------- /Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | # See all tag variants at https://hub.docker.com/r/tensorflow/tensorflow/tags/ 2 | # build with `ln -sf Dockerfile.cpu Dockerfile && docker build --network=host -t gdl-image-cpu .` 3 | # Note: 'host' networking isn't supported on macOS/Windows - https://docs.docker.com/network/host/ 4 | # on macOS build
with `ln -sf Dockerfile.cpu Dockerfile && docker build -t gdl-image-cpu .` 5 | FROM tensorflow/tensorflow:latest-py3-jupyter 6 | 7 | ## modify below 8 | ARG username=gdl 9 | ARG groupid=1000 10 | ARG userid=1000 11 | ## end modify 12 | 13 | RUN apt-get update 14 | RUN apt-get -y install graphviz 15 | 16 | COPY ./requirements.txt / 17 | RUN python3 -m pip install --upgrade pip 18 | RUN pip install --no-cache-dir -r /requirements.txt 19 | 20 | # -m option creates a fake writable home folder for Jupyter. 21 | RUN groupadd -g $groupid $username \ 22 |     && useradd -m -r -u $userid -g $username $username 23 | USER $username 24 | 25 | VOLUME ["/GDL"] 26 | WORKDIR /GDL 27 | 28 | CMD ["jupyter", "notebook", "--no-browser", "--ip=0.0.0.0", "/GDL"] 29 | -------------------------------------------------------------------------------- /Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | # See all tag variants at https://hub.docker.com/r/tensorflow/tensorflow/tags/ 2 | # build with `ln -sf Dockerfile.gpu Dockerfile && docker build --network=host -t {container-name} .` 3 | # NOTE(jwd) - if you wish to use this implementation, you must install nvidia-docker v2.0 4 | # see https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0) for steps 5 | FROM tensorflow/tensorflow:latest-gpu-py3-jupyter 6 | 7 | ## modify below 8 | ARG username=gdl 9 | ARG groupid=1000 10 | ARG userid=1000 11 | ## end modify 12 | 13 | RUN apt-get update 14 | RUN apt-get -y install graphviz 15 | 16 | COPY ./requirements.txt / 17 | RUN python3 -m pip install --upgrade pip 18 | RUN pip install --no-cache-dir -r /requirements.txt 19 | 20 | # -m option creates a fake writable home folder for Jupyter. 21 | RUN groupadd -g $groupid $username \ 22 |     && useradd -m -r -u $userid -g $username $username 23 | USER $username 24 | 25 | VOLUME ["/GDL"] 26 | WORKDIR /GDL 27 | 28 | CMD ["jupyter", "notebook", "--no-browser", "--ip=0.0.0.0", "/GDL"] 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Deep Learning 2 | ### Teaching Machines to Paint, Write, Compose and Play 3 | 4 | The official code repository for examples in the O'Reilly book 'Generative Deep Learning' 5 | 6 | https://learning.oreilly.com/library/view/generative-deep-learning/9781492041931/ 7 | 8 | https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1492041947/ref=sr_1_1 9 | 10 | ## TensorFlow 2.0 11 | 12 | This branch uses Keras within TensorFlow 2.0. 13 | 14 | ## Structure 15 | 16 | This repository is structured as follows: 17 | 18 | The notebooks for each chapter are in the root of the repository, prefixed with the chapter number.
19 | 20 | The `data` folder is where to download relevant data sources (chapter 3 onwards). 21 | The `run` folder stores output from the generative models (chapter 3 onwards). 22 | The `utils` folder stores useful functions that are sourced by the main notebooks. 23 | 24 | ## Book Contents 25 | Part 1: Introduction to Generative Deep Learning 26 | * Chapter 1: Generative Modeling 27 | * Chapter 2: Deep Learning 28 | * Chapter 3: Variational Autoencoders 29 | * Chapter 4: Generative Adversarial Networks 30 | 31 | Part 2: Teaching Machines to Paint, Write, Compose and Play 32 | * Chapter 5: Paint 33 | * Chapter 6: Write 34 | * Chapter 7: Compose 35 | * Chapter 8: Play 36 | * Chapter 9: The Future of Generative Modeling 37 | * Chapter 10: Conclusion 38 | 39 | 40 | ## Getting started 41 | 42 | To get started, first install the required libraries inside a virtual environment: 43 | 44 | `pip install -r requirements.txt` 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /colab/README.md: -------------------------------------------------------------------------------- 1 | # Generative Deep Learning (Google Colab Notebooks) 2 | 3 | Google Colab notebooks 4 | 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/karaage0703/GDL_code/blob/karaage) 6 | 7 | ## 03_01_02_autoencoder.ipynb 8 | 9 | AE (autoencoder) training and analysis. 10 | 11 | Open this notebook on Google Colab. 12 | 13 | ## 03_03_04_vae_digits.ipynb 14 | 15 | VAE training and analysis. 16 | 17 | Open this notebook on Google Colab. 18 | 19 | ## 04_01_gan_camel.ipynb 20 | 21 | GAN training. 22 | 23 | Download `camel.npy` from [NumPy Quick Draw](https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap) and upload it to Google Drive under the `My Drive/dl` directory. 24 | 25 | Then open this notebook on Google Colab. 26 | 27 | ## 04_02_wgan_cifar.ipynb 28 | 29 | WGAN training. 30 | 31 | Open this notebook on Google Colab. 32 | 33 | ## 05_01_cyclegan_train.ipynb 34 | 35 | CycleGAN training. 36 | 37 | Open this notebook on Google Colab. 38 | 39 | ## 06_01_lstm_text_train.ipynb 40 | 41 | LSTM text training. 42 | 43 | Open this notebook on Google Colab. 44 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | aesop/* 3 | apple2orange/* 4 | camel/* 5 | celeb/* 6 | cello/* 7 | chorales/* 8 | glove/* 9 | monet2photo/* 10 | qa/* 11 | 12 | !.gitignore 13 | -------------------------------------------------------------------------------- /data/qa_test/my_test.csv: -------------------------------------------------------------------------------- 1 | story_id,story_text,question,answer_token_ranges 2 | ./cnn/stories/dave.story,"The winning goal was scored by 23-year-old striker Joe Bloggs during the match between Arsenal and Barcelona . Arsenal recently signed the striker for 50 million pounds . The next match is in two weeks time, on July 31st 2005 . ",How much money was spent on the striker ?,24:27 3 | -------------------------------------------------------------------------------- /launch-docker-cpu.sh: -------------------------------------------------------------------------------- 1 | # USAGE - ./launch-docker-cpu.sh {abs-path-to-GDL-code} 2 | # - eg.
to run from current directory: 3 | # ./launch-docker-cpu.sh $(pwd) 4 | if [[ "$OSTYPE" == "darwin"* ]]; then 5 | docker run --rm -p 8888:8888 -it -v $1:/GDL gdl-image-cpu 6 | else 7 | docker run --rm --network=host -it -v $1:/GDL gdl-image-cpu 8 | fi 9 | -------------------------------------------------------------------------------- /launch-docker-gpu.sh: -------------------------------------------------------------------------------- 1 | # USAGE - ./launch-docker-gpu.sh {abs-path-to-GDL-code} 2 | docker run --rm --runtime=nvidia --network=host -it -v $1:/GDL gdl-image 3 | -------------------------------------------------------------------------------- /models/AE.py: -------------------------------------------------------------------------------- 1 | 2 | from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout 3 | from tensorflow.keras.models import Model 4 | from tensorflow.keras import backend as K 5 | from tensorflow.keras.optimizers import Adam 6 | from tensorflow.keras.callbacks import ModelCheckpoint 7 | from tensorflow.keras.utils import plot_model 8 | 9 | from utils.callbacks import CustomCallback, step_decay_schedule 10 | 11 | import numpy as np 12 | import json 13 | import os 14 | import pickle 15 | 16 | 17 | class Autoencoder(): 18 | def __init__(self 19 | , input_dim 20 | , encoder_conv_filters 21 | , encoder_conv_kernel_size 22 | , encoder_conv_strides 23 | , decoder_conv_t_filters 24 | , decoder_conv_t_kernel_size 25 | , decoder_conv_t_strides 26 | , z_dim 27 | , use_batch_norm = False 28 | , use_dropout = False 29 | ): 30 | 31 | self.name = 'autoencoder' 32 | 33 | self.input_dim = input_dim 34 | self.encoder_conv_filters = encoder_conv_filters 35 | self.encoder_conv_kernel_size = encoder_conv_kernel_size 36 | self.encoder_conv_strides = encoder_conv_strides 37 | self.decoder_conv_t_filters = decoder_conv_t_filters 38 | self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size 39 | self.decoder_conv_t_strides = decoder_conv_t_strides 40 | self.z_dim = z_dim 41 | 42 | self.use_batch_norm = use_batch_norm 43 | self.use_dropout = use_dropout 44 | 45 | self.n_layers_encoder = len(encoder_conv_filters) 46 | self.n_layers_decoder = len(decoder_conv_t_filters) 47 | 48 | self._build() 49 | 50 | def _build(self): 51 | 52 | ### THE ENCODER 53 | encoder_input = Input(shape=self.input_dim, name='encoder_input') 54 | 55 | x = encoder_input 56 | 57 | for i in range(self.n_layers_encoder): 58 | conv_layer = Conv2D( 59 | filters = self.encoder_conv_filters[i] 60 | , kernel_size = self.encoder_conv_kernel_size[i] 61 | , strides = self.encoder_conv_strides[i] 62 | , padding = 'same' 63 | , name = 'encoder_conv_' + str(i) 64 | ) 65 | 66 | x = conv_layer(x) 67 | 68 | x = LeakyReLU()(x) 69 | 70 | if self.use_batch_norm: 71 | x = BatchNormalization()(x) 72 | 73 | if self.use_dropout: 74 | x = Dropout(rate = 0.25)(x) 75 | 76 | shape_before_flattening = K.int_shape(x)[1:] 77 | 78 | x = Flatten()(x) 79 | encoder_output= Dense(self.z_dim, name='encoder_output')(x) 80 | 81 | self.encoder = Model(encoder_input, encoder_output) 82 | 83 | 84 | ### THE DECODER 85 | decoder_input = Input(shape=(self.z_dim,), name='decoder_input') 86 | 87 | x = Dense(np.prod(shape_before_flattening))(decoder_input) 88 | x = Reshape(shape_before_flattening)(x) 89 | 90 | for i in range(self.n_layers_decoder): 91 | conv_t_layer = Conv2DTranspose( 92 | filters = self.decoder_conv_t_filters[i] 93 | , kernel_size = 
self.decoder_conv_t_kernel_size[i] 94 | , strides = self.decoder_conv_t_strides[i] 95 | , padding = 'same' 96 | , name = 'decoder_conv_t_' + str(i) 97 | ) 98 | 99 | x = conv_t_layer(x) 100 | 101 | if i < self.n_layers_decoder - 1: 102 | x = LeakyReLU()(x) 103 | 104 | if self.use_batch_norm: 105 | x = BatchNormalization()(x) 106 | 107 | if self.use_dropout: 108 | x = Dropout(rate = 0.25)(x) 109 | else: 110 | x = Activation('sigmoid')(x) 111 | 112 | decoder_output = x 113 | 114 | self.decoder = Model(decoder_input, decoder_output) 115 | 116 | ### THE FULL AUTOENCODER 117 | model_input = encoder_input 118 | model_output = self.decoder(encoder_output) 119 | 120 | self.model = Model(model_input, model_output) 121 | 122 | 123 | def compile(self, learning_rate): 124 | self.learning_rate = learning_rate 125 | 126 | optimizer = Adam(lr=learning_rate) 127 | 128 | def r_loss(y_true, y_pred): 129 | return K.mean(K.square(y_true - y_pred), axis = [1,2,3]) 130 | 131 | self.model.compile(optimizer=optimizer, loss = r_loss) 132 | 133 | def save(self, folder): 134 | 135 | if not os.path.exists(folder): 136 | os.makedirs(folder) 137 | os.makedirs(os.path.join(folder, 'viz')) 138 | os.makedirs(os.path.join(folder, 'weights')) 139 | os.makedirs(os.path.join(folder, 'images')) 140 | 141 | with open(os.path.join(folder, 'params.pkl'), 'wb') as f: 142 | pickle.dump([ 143 | self.input_dim 144 | , self.encoder_conv_filters 145 | , self.encoder_conv_kernel_size 146 | , self.encoder_conv_strides 147 | , self.decoder_conv_t_filters 148 | , self.decoder_conv_t_kernel_size 149 | , self.decoder_conv_t_strides 150 | , self.z_dim 151 | , self.use_batch_norm 152 | , self.use_dropout 153 | ], f) 154 | 155 | self.plot_model(folder) 156 | 157 | 158 | 159 | 160 | def load_weights(self, filepath): 161 | self.model.load_weights(filepath) 162 | 163 | 164 | def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 100, initial_epoch = 0, lr_decay = 1): 165 | 166 | custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self) 167 | lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) 168 | 169 | checkpoint2 = ModelCheckpoint(os.path.join(run_folder, 'weights/weights.h5'), save_weights_only = True, verbose=1) 170 | 171 | callbacks_list = [checkpoint2, custom_callback, lr_sched] 172 | 173 | self.model.fit( 174 | x_train 175 | , x_train 176 | , batch_size = batch_size 177 | , shuffle = True 178 | , epochs = epochs 179 | , initial_epoch = initial_epoch 180 | , callbacks = callbacks_list 181 | ) 182 | 183 | def plot_model(self, run_folder): 184 | plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes = True, show_layer_names = True) 185 | plot_model(self.encoder, to_file=os.path.join(run_folder ,'viz/encoder.png'), show_shapes = True, show_layer_names = True) 186 | plot_model(self.decoder, to_file=os.path.join(run_folder ,'viz/decoder.png'), show_shapes = True, show_layer_names = True) 187 | 188 | 189 | -------------------------------------------------------------------------------- /models/GAN.py: -------------------------------------------------------------------------------- 1 | 2 | from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D 3 | 4 | from tensorflow.keras.models import Model, Sequential 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras.optimizers import 
Adam, RMSprop 7 | from tensorflow.keras.utils import plot_model 8 | from tensorflow.keras.initializers import RandomNormal 9 | 10 | import numpy as np 11 | import json 12 | import os 13 | import pickle as pkl 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | class GAN(): 18 | def __init__(self 19 | , input_dim 20 | , discriminator_conv_filters 21 | , discriminator_conv_kernel_size 22 | , discriminator_conv_strides 23 | , discriminator_batch_norm_momentum 24 | , discriminator_activation 25 | , discriminator_dropout_rate 26 | , discriminator_learning_rate 27 | , generator_initial_dense_layer_size 28 | , generator_upsample 29 | , generator_conv_filters 30 | , generator_conv_kernel_size 31 | , generator_conv_strides 32 | , generator_batch_norm_momentum 33 | , generator_activation 34 | , generator_dropout_rate 35 | , generator_learning_rate 36 | , optimiser 37 | , z_dim 38 | ): 39 | 40 | self.name = 'gan' 41 | 42 | self.input_dim = input_dim 43 | self.discriminator_conv_filters = discriminator_conv_filters 44 | self.discriminator_conv_kernel_size = discriminator_conv_kernel_size 45 | self.discriminator_conv_strides = discriminator_conv_strides 46 | self.discriminator_batch_norm_momentum = discriminator_batch_norm_momentum 47 | self.discriminator_activation = discriminator_activation 48 | self.discriminator_dropout_rate = discriminator_dropout_rate 49 | self.discriminator_learning_rate = discriminator_learning_rate 50 | 51 | self.generator_initial_dense_layer_size = generator_initial_dense_layer_size 52 | self.generator_upsample = generator_upsample 53 | self.generator_conv_filters = generator_conv_filters 54 | self.generator_conv_kernel_size = generator_conv_kernel_size 55 | self.generator_conv_strides = generator_conv_strides 56 | self.generator_batch_norm_momentum = generator_batch_norm_momentum 57 | self.generator_activation = generator_activation 58 | self.generator_dropout_rate = generator_dropout_rate 59 | self.generator_learning_rate = generator_learning_rate 60 | 61 | self.optimiser = optimiser 62 | self.z_dim = z_dim 63 | 64 | self.n_layers_discriminator = len(discriminator_conv_filters) 65 | self.n_layers_generator = len(generator_conv_filters) 66 | 67 | self.weight_init = RandomNormal(mean=0., stddev=0.02) 68 | 69 | self.d_losses = [] 70 | self.g_losses = [] 71 | 72 | self.epoch = 0 73 | 74 | self._build_discriminator() 75 | self._build_generator() 76 | 77 | self._build_adversarial() 78 | 79 | def get_activation(self, activation): 80 | if activation == 'leaky_relu': 81 | layer = LeakyReLU(alpha = 0.2) 82 | else: 83 | layer = Activation(activation) 84 | return layer 85 | 86 | def _build_discriminator(self): 87 | 88 | ### THE discriminator 89 | discriminator_input = Input(shape=self.input_dim, name='discriminator_input') 90 | 91 | x = discriminator_input 92 | 93 | for i in range(self.n_layers_discriminator): 94 | 95 | x = Conv2D( 96 | filters = self.discriminator_conv_filters[i] 97 | , kernel_size = self.discriminator_conv_kernel_size[i] 98 | , strides = self.discriminator_conv_strides[i] 99 | , padding = 'same' 100 | , name = 'discriminator_conv_' + str(i) 101 | , kernel_initializer = self.weight_init 102 | )(x) 103 | 104 | if self.discriminator_batch_norm_momentum and i > 0: 105 | x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x) 106 | 107 | x = self.get_activation(self.discriminator_activation)(x) 108 | 109 | if self.discriminator_dropout_rate: 110 | x = Dropout(rate = self.discriminator_dropout_rate)(x) 111 | 112 | x = Flatten()(x) 113 | 114 | 
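        # The flattened feature vector is mapped to a single sigmoid unit just
        # below: the discriminator's output is its estimated probability that
        # the input image is real, trained against the 'binary_crossentropy'
        # loss compiled in _build_adversarial.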
discriminator_output = Dense(1, activation='sigmoid', kernel_initializer = self.weight_init)(x) 115 | 116 | self.discriminator = Model(discriminator_input, discriminator_output) 117 | 118 | 119 | def _build_generator(self): 120 | 121 | ### THE generator 122 | 123 | generator_input = Input(shape=(self.z_dim,), name='generator_input') 124 | 125 | x = generator_input 126 | 127 | x = Dense(np.prod(self.generator_initial_dense_layer_size), kernel_initializer = self.weight_init)(x) 128 | 129 | if self.generator_batch_norm_momentum: 130 | x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x) 131 | 132 | x = self.get_activation(self.generator_activation)(x) 133 | 134 | x = Reshape(self.generator_initial_dense_layer_size)(x) 135 | 136 | if self.generator_dropout_rate: 137 | x = Dropout(rate = self.generator_dropout_rate)(x) 138 | 139 | for i in range(self.n_layers_generator): 140 | 141 | if self.generator_upsample[i] == 2: 142 | x = UpSampling2D()(x) 143 | x = Conv2D( 144 | filters = self.generator_conv_filters[i] 145 | , kernel_size = self.generator_conv_kernel_size[i] 146 | , padding = 'same' 147 | , name = 'generator_conv_' + str(i) 148 | , kernel_initializer = self.weight_init 149 | )(x) 150 | else: 151 | 152 | x = Conv2DTranspose( 153 | filters = self.generator_conv_filters[i] 154 | , kernel_size = self.generator_conv_kernel_size[i] 155 | , padding = 'same' 156 | , strides = self.generator_conv_strides[i] 157 | , name = 'generator_conv_' + str(i) 158 | , kernel_initializer = self.weight_init 159 | )(x) 160 | 161 | if i < self.n_layers_generator - 1: 162 | 163 | if self.generator_batch_norm_momentum: 164 | x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x) 165 | 166 | x = self.get_activation(self.generator_activation)(x) 167 | 168 | 169 | else: 170 | 171 | x = Activation('tanh')(x) 172 | 173 | 174 | generator_output = x 175 | 176 | self.generator = Model(generator_input, generator_output) 177 | 178 | 179 | def get_opti(self, lr): 180 | if self.optimiser == 'adam': 181 | opti = Adam(lr=lr, beta_1=0.5) 182 | elif self.optimiser == 'rmsprop': 183 | opti = RMSprop(lr=lr) 184 | else: 185 | opti = Adam(lr=lr) 186 | 187 | return opti 188 | 189 | def set_trainable(self, m, val): 190 | m.trainable = val 191 | for l in m.layers: 192 | l.trainable = val 193 | 194 | 195 | def _build_adversarial(self): 196 | 197 | ### COMPILE DISCRIMINATOR 198 | 199 | self.discriminator.compile( 200 | optimizer=self.get_opti(self.discriminator_learning_rate) 201 | , loss = 'binary_crossentropy' 202 | , metrics = ['accuracy'] 203 | ) 204 | 205 | ### COMPILE THE FULL GAN 206 | 207 | self.set_trainable(self.discriminator, False) 208 | 209 | model_input = Input(shape=(self.z_dim,), name='model_input') 210 | model_output = self.discriminator(self.generator(model_input)) 211 | self.model = Model(model_input, model_output) 212 | 213 | self.model.compile(optimizer=self.get_opti(self.generator_learning_rate) , loss='binary_crossentropy', metrics=['accuracy'] 214 | , experimental_run_tf_function=False 215 | ) 216 | 217 | self.set_trainable(self.discriminator, True) 218 | 219 | 220 | 221 | 222 | def train_discriminator(self, x_train, batch_size, using_generator): 223 | 224 | valid = np.ones((batch_size,1)) 225 | fake = np.zeros((batch_size,1)) 226 | 227 | if using_generator: 228 | true_imgs = next(x_train)[0] 229 | if true_imgs.shape[0] != batch_size: 230 | true_imgs = next(x_train)[0] 231 | else: 232 | idx = np.random.randint(0, x_train.shape[0], batch_size) 233 | true_imgs = x_train[idx] 234 
| 235 | noise = np.random.normal(0, 1, (batch_size, self.z_dim)) 236 | gen_imgs = self.generator.predict(noise) 237 | 238 | d_loss_real, d_acc_real = self.discriminator.train_on_batch(true_imgs, valid) 239 | d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake) 240 | d_loss = 0.5 * (d_loss_real + d_loss_fake) 241 | d_acc = 0.5 * (d_acc_real + d_acc_fake) 242 | 243 | return [d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake] 244 | 245 | def train_generator(self, batch_size): 246 | valid = np.ones((batch_size,1)) 247 | noise = np.random.normal(0, 1, (batch_size, self.z_dim)) 248 | return self.model.train_on_batch(noise, valid) 249 | 250 | 251 | def train(self, x_train, batch_size, epochs, run_folder 252 | , print_every_n_batches = 50 253 | , using_generator = False): 254 | 255 | for epoch in range(self.epoch, self.epoch + epochs): 256 | 257 | d = self.train_discriminator(x_train, batch_size, using_generator) 258 | g = self.train_generator(batch_size) 259 | 260 | print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (epoch, d[0], d[1], d[2], d[3], d[4], d[5], g[0], g[1])) 261 | 262 | self.d_losses.append(d) 263 | self.g_losses.append(g) 264 | 265 | if epoch % print_every_n_batches == 0: 266 | self.sample_images(run_folder) 267 | self.model.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (epoch))) 268 | self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5')) 269 | self.save_model(run_folder) 270 | 271 | self.epoch += 1 272 | 273 | 274 | def sample_images(self, run_folder): 275 | r, c = 5, 5 276 | noise = np.random.normal(0, 1, (r * c, self.z_dim)) 277 | gen_imgs = self.generator.predict(noise) 278 | 279 | gen_imgs = 0.5 * (gen_imgs + 1) 280 | gen_imgs = np.clip(gen_imgs, 0, 1) 281 | 282 | fig, axs = plt.subplots(r, c, figsize=(15,15)) 283 | cnt = 0 284 | 285 | for i in range(r): 286 | for j in range(c): 287 | axs[i,j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]), cmap = 'gray') 288 | axs[i,j].axis('off') 289 | cnt += 1 290 | fig.savefig(os.path.join(run_folder, "images/sample_%d.png" % self.epoch)) 291 | plt.close() 292 | 293 | 294 | 295 | 296 | 297 | def plot_model(self, run_folder): 298 | plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes = True, show_layer_names = True) 299 | plot_model(self.discriminator, to_file=os.path.join(run_folder ,'viz/discriminator.png'), show_shapes = True, show_layer_names = True) 300 | plot_model(self.generator, to_file=os.path.join(run_folder ,'viz/generator.png'), show_shapes = True, show_layer_names = True) 301 | 302 | 303 | 304 | def save(self, folder): 305 | 306 | with open(os.path.join(folder, 'params.pkl'), 'wb') as f: 307 | pkl.dump([ 308 | self.input_dim 309 | , self.discriminator_conv_filters 310 | , self.discriminator_conv_kernel_size 311 | , self.discriminator_conv_strides 312 | , self.discriminator_batch_norm_momentum 313 | , self.discriminator_activation 314 | , self.discriminator_dropout_rate 315 | , self.discriminator_learning_rate 316 | , self.generator_initial_dense_layer_size 317 | , self.generator_upsample 318 | , self.generator_conv_filters 319 | , self.generator_conv_kernel_size 320 | , self.generator_conv_strides 321 | , self.generator_batch_norm_momentum 322 | , self.generator_activation 323 | , self.generator_dropout_rate 324 | , self.generator_learning_rate 325 | , self.optimiser 326 | , self.z_dim 327 | ], f) 328 | 329 | self.plot_model(folder) 330 | 331 | def save_model(self, run_folder): 
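        # Unlike the periodic save_weights calls made during training, this
        # writes complete .h5 files (architecture, weights and optimizer
        # state) for the combined model, the discriminator and the generator.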
332 | self.model.save(os.path.join(run_folder, 'model.h5')) 333 | self.discriminator.save(os.path.join(run_folder, 'discriminator.h5')) 334 | self.generator.save(os.path.join(run_folder, 'generator.h5')) 335 | 336 | def load_weights(self, filepath): 337 | self.model.load_weights(filepath) 338 | 339 | 340 | 341 | 342 | 343 | -------------------------------------------------------------------------------- /models/RNNAttention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | 5 | from music21 import corpus, converter 6 | 7 | from tensorflow.keras.layers import LSTM, Input, Dropout, Dense, Activation, Embedding, Concatenate, Reshape 8 | from tensorflow.keras.layers import Flatten, RepeatVector, Permute, TimeDistributed 9 | from tensorflow.keras.layers import Multiply, Lambda, Softmax 10 | import tensorflow.keras.backend as K 11 | from tensorflow.keras.models import Model 12 | from tensorflow.keras.optimizers import RMSprop 13 | 14 | from tensorflow.keras.utils import to_categorical 15 | 16 | def get_music_list(data_folder): 17 | 18 | if data_folder == 'chorales': 19 | file_list = ['bwv' + str(x['bwv']) for x in corpus.chorales.ChoraleList().byBWV.values()] 20 | parser = corpus 21 | else: 22 | file_list = glob.glob(os.path.join(data_folder, "*.mid")) 23 | parser = converter 24 | 25 | return file_list, parser 26 | 27 | def create_network(n_notes, n_durations, embed_size = 100, rnn_units = 256, use_attention = False): 28 | """ create the structure of the neural network """ 29 | 30 | notes_in = Input(shape = (None,)) 31 | durations_in = Input(shape = (None,)) 32 | 33 | x1 = Embedding(n_notes, embed_size)(notes_in) 34 | x2 = Embedding(n_durations, embed_size)(durations_in) 35 | 36 | x = Concatenate()([x1,x2]) 37 | 38 | x = LSTM(rnn_units, return_sequences=True)(x) 39 | # x = Dropout(0.2)(x) 40 | 41 | if use_attention: 42 | 43 | x = LSTM(rnn_units, return_sequences=True)(x) 44 | # x = Dropout(0.2)(x) 45 | 46 | e = Dense(1, activation='tanh')(x) 47 | e = Reshape([-1])(e) 48 | alpha = Activation('softmax')(e) 49 | 50 | alpha_repeated = Permute([2, 1])(RepeatVector(rnn_units)(alpha)) 51 | 52 | c = Multiply()([x, alpha_repeated]) 53 | c = Lambda(lambda xin: K.sum(xin, axis=1), output_shape=(rnn_units,))(c) 54 | 55 | else: 56 | c = LSTM(rnn_units)(x) 57 | # c = Dropout(0.2)(c) 58 | 59 | notes_out = Dense(n_notes, activation = 'softmax', name = 'pitch')(c) 60 | durations_out = Dense(n_durations, activation = 'softmax', name = 'duration')(c) 61 | 62 | model = Model([notes_in, durations_in], [notes_out, durations_out]) 63 | 64 | 65 | if use_attention: 66 | att_model = Model([notes_in, durations_in], alpha) 67 | else: 68 | att_model = None 69 | 70 | 71 | opti = RMSprop(lr = 0.001) 72 | model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy'], optimizer=opti) 73 | 74 | return model, att_model 75 | 76 | 77 | def get_distinct(elements): 78 | # Get all pitch names 79 | element_names = sorted(set(elements)) 80 | n_elements = len(element_names) 81 | return (element_names, n_elements) 82 | 83 | def create_lookups(element_names): 84 | # create dictionary to map notes and durations to integers 85 | element_to_int = dict((element, number) for number, element in enumerate(element_names)) 86 | int_to_element = dict((number, element) for number, element in enumerate(element_names)) 87 | 88 | return (element_to_int, int_to_element) 89 | 90 | 91 | def prepare_sequences(notes, durations, lookups, distincts, seq_len 
=32): 92 | """ Prepare the sequences used to train the Neural Network """ 93 | 94 | note_to_int, int_to_note, duration_to_int, int_to_duration = lookups 95 | note_names, n_notes, duration_names, n_durations = distincts 96 | 97 | notes_network_input = [] 98 | notes_network_output = [] 99 | durations_network_input = [] 100 | durations_network_output = [] 101 | 102 | # create input sequences and the corresponding outputs 103 | for i in range(len(notes) - seq_len): 104 | notes_sequence_in = notes[i:i + seq_len] 105 | notes_sequence_out = notes[i + seq_len] 106 | notes_network_input.append([note_to_int[char] for char in notes_sequence_in]) 107 | notes_network_output.append(note_to_int[notes_sequence_out]) 108 | 109 | durations_sequence_in = durations[i:i + seq_len] 110 | durations_sequence_out = durations[i + seq_len] 111 | durations_network_input.append([duration_to_int[char] for char in durations_sequence_in]) 112 | durations_network_output.append(duration_to_int[durations_sequence_out]) 113 | 114 | n_patterns = len(notes_network_input) 115 | 116 | # reshape the input into a format compatible with LSTM layers 117 | notes_network_input = np.reshape(notes_network_input, (n_patterns, seq_len)) 118 | durations_network_input = np.reshape(durations_network_input, (n_patterns, seq_len)) 119 | network_input = [notes_network_input, durations_network_input] 120 | 121 | notes_network_output = to_categorical(notes_network_output, num_classes=n_notes) 122 | durations_network_output = to_categorical(durations_network_output, num_classes=n_durations) 123 | network_output = [notes_network_output, durations_network_output] 124 | 125 | return (network_input, network_output) 126 | 127 | 128 | def sample_with_temp(preds, temperature): 129 | 130 | if temperature == 0: 131 | return np.argmax(preds) 132 | else: 133 | preds = np.log(preds) / temperature 134 | exp_preds = np.exp(preds) 135 | preds = exp_preds / np.sum(exp_preds) 136 | return np.random.choice(len(preds), p=preds) 137 | -------------------------------------------------------------------------------- /models/VAE.py: -------------------------------------------------------------------------------- 1 | 2 | from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, Layer 3 | from tensorflow.keras.models import Model 4 | from tensorflow.keras import backend as K 5 | from tensorflow.keras.optimizers import Adam 6 | from tensorflow.keras.callbacks import ModelCheckpoint 7 | from tensorflow.keras.utils import plot_model 8 | 9 | import tensorflow as tf 10 | 11 | from utils.callbacks import CustomCallback, step_decay_schedule 12 | 13 | import numpy as np 14 | import json 15 | import os 16 | import pickle 17 | 18 | class Sampling(Layer): 19 | def call(self, inputs): 20 | mu, log_var = inputs 21 | epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.) 
22 | return mu + K.exp(log_var / 2) * epsilon 23 | 24 | 25 | class VAEModel(Model): 26 | def __init__(self, encoder, decoder, r_loss_factor, **kwargs): 27 | super(VAEModel, self).__init__(**kwargs) 28 | self.encoder = encoder 29 | self.decoder = decoder 30 | self.r_loss_factor = r_loss_factor 31 | 32 | def train_step(self, data): 33 | if isinstance(data, tuple): 34 | data = data[0] 35 | with tf.GradientTape() as tape: 36 | z_mean, z_log_var, z = self.encoder(data) 37 | reconstruction = self.decoder(z) 38 | reconstruction_loss = tf.reduce_mean( 39 | tf.square(data - reconstruction), axis = [1,2,3] 40 | ) 41 | reconstruction_loss *= self.r_loss_factor 42 | kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) 43 | kl_loss = tf.reduce_sum(kl_loss, axis = 1) 44 | kl_loss *= -0.5 45 | total_loss = reconstruction_loss + kl_loss 46 | grads = tape.gradient(total_loss, self.trainable_weights) 47 | self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) 48 | return { 49 | "loss": total_loss, 50 | "reconstruction_loss": reconstruction_loss, 51 | "kl_loss": kl_loss, 52 | } 53 | 54 | def call(self,inputs): 55 | latent = self.encoder(inputs) 56 | return self.decoder(latent) 57 | 58 | 59 | 60 | 61 | 62 | class VariationalAutoencoder(): 63 | def __init__(self 64 | , input_dim 65 | , encoder_conv_filters 66 | , encoder_conv_kernel_size 67 | , encoder_conv_strides 68 | , decoder_conv_t_filters 69 | , decoder_conv_t_kernel_size 70 | , decoder_conv_t_strides 71 | , z_dim 72 | , r_loss_factor 73 | , use_batch_norm = False 74 | , use_dropout= False 75 | ): 76 | 77 | self.name = 'variational_autoencoder' 78 | 79 | self.input_dim = input_dim 80 | self.encoder_conv_filters = encoder_conv_filters 81 | self.encoder_conv_kernel_size = encoder_conv_kernel_size 82 | self.encoder_conv_strides = encoder_conv_strides 83 | self.decoder_conv_t_filters = decoder_conv_t_filters 84 | self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size 85 | self.decoder_conv_t_strides = decoder_conv_t_strides 86 | self.z_dim = z_dim 87 | self.r_loss_factor = r_loss_factor 88 | 89 | self.use_batch_norm = use_batch_norm 90 | self.use_dropout = use_dropout 91 | 92 | self.n_layers_encoder = len(encoder_conv_filters) 93 | self.n_layers_decoder = len(decoder_conv_t_filters) 94 | 95 | self._build() 96 | 97 | def _build(self): 98 | 99 | ### THE ENCODER 100 | encoder_input = Input(shape=self.input_dim, name='encoder_input') 101 | 102 | x = encoder_input 103 | 104 | for i in range(self.n_layers_encoder): 105 | conv_layer = Conv2D( 106 | filters = self.encoder_conv_filters[i] 107 | , kernel_size = self.encoder_conv_kernel_size[i] 108 | , strides = self.encoder_conv_strides[i] 109 | , padding = 'same' 110 | , name = 'encoder_conv_' + str(i) 111 | ) 112 | 113 | x = conv_layer(x) 114 | 115 | if self.use_batch_norm: 116 | x = BatchNormalization()(x) 117 | 118 | x = LeakyReLU()(x) 119 | 120 | if self.use_dropout: 121 | x = Dropout(rate = 0.25)(x) 122 | 123 | shape_before_flattening = K.int_shape(x)[1:] 124 | 125 | x = Flatten()(x) 126 | self.mu = Dense(self.z_dim, name='mu')(x) 127 | self.log_var = Dense(self.z_dim, name='log_var')(x) 128 | 129 | self.z = Sampling(name='encoder_output')([self.mu, self.log_var]) 130 | 131 | self.encoder = Model(encoder_input, [self.mu, self.log_var, self.z], name = 'encoder') 132 | 133 | 134 | 135 | ### THE DECODER 136 | 137 | decoder_input = Input(shape=(self.z_dim,), name='decoder_input') 138 | 139 | x = Dense(np.prod(shape_before_flattening))(decoder_input) 140 | x = 
Reshape(shape_before_flattening)(x) 141 | 142 | for i in range(self.n_layers_decoder): 143 | conv_t_layer = Conv2DTranspose( 144 | filters = self.decoder_conv_t_filters[i] 145 | , kernel_size = self.decoder_conv_t_kernel_size[i] 146 | , strides = self.decoder_conv_t_strides[i] 147 | , padding = 'same' 148 | , name = 'decoder_conv_t_' + str(i) 149 | ) 150 | 151 | x = conv_t_layer(x) 152 | 153 | if i < self.n_layers_decoder - 1: 154 | if self.use_batch_norm: 155 | x = BatchNormalization()(x) 156 | x = LeakyReLU()(x) 157 | if self.use_dropout: 158 | x = Dropout(rate = 0.25)(x) 159 | else: 160 | x = Activation('sigmoid')(x) 161 | 162 | 163 | 164 | decoder_output = x 165 | 166 | self.decoder = Model(decoder_input, decoder_output, name = 'decoder') 167 | 168 | ### THE FULL VAE 169 | 170 | self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor) 171 | 172 | 173 | 174 | def compile(self, learning_rate): 175 | self.learning_rate = learning_rate 176 | optimizer = Adam(lr=learning_rate) 177 | self.model.compile(optimizer=optimizer) 178 | 179 | 180 | def save(self, folder): 181 | 182 | if not os.path.exists(folder): 183 | os.makedirs(folder) 184 | os.makedirs(os.path.join(folder, 'viz')) 185 | os.makedirs(os.path.join(folder, 'weights')) 186 | os.makedirs(os.path.join(folder, 'images')) 187 | 188 | with open(os.path.join(folder, 'params.pkl'), 'wb') as f: 189 | pickle.dump([ 190 | self.input_dim 191 | , self.encoder_conv_filters 192 | , self.encoder_conv_kernel_size 193 | , self.encoder_conv_strides 194 | , self.decoder_conv_t_filters 195 | , self.decoder_conv_t_kernel_size 196 | , self.decoder_conv_t_strides 197 | , self.z_dim, self.r_loss_factor # r_loss_factor must be pickled too, since utils.loaders.load_model calls model_class(*params) and __init__ expects it after z_dim 198 | , self.use_batch_norm 199 | , self.use_dropout 200 | ], f) 201 | 202 | self.plot_model(folder) 203 | 204 | 205 | def load_weights(self, filepath): 206 | self.model.load_weights(filepath) 207 | 208 | def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 100, initial_epoch = 0, lr_decay = 1): 209 | 210 | custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self) 211 | lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) 212 | 213 | checkpoint_filepath=os.path.join(run_folder, "weights/weights-{epoch:03d}-{loss:.2f}.h5") 214 | checkpoint1 = ModelCheckpoint(checkpoint_filepath, save_weights_only = True, verbose=1) 215 | checkpoint2 = ModelCheckpoint(os.path.join(run_folder, 'weights/weights.h5'), save_weights_only = True, verbose=1) 216 | 217 | callbacks_list = [checkpoint1, checkpoint2, custom_callback, lr_sched] 218 | 219 | self.model.fit( 220 | x_train 221 | , x_train 222 | , batch_size = batch_size 223 | , shuffle = True 224 | , epochs = epochs 225 | , initial_epoch = initial_epoch 226 | , callbacks = callbacks_list 227 | ) 228 | 229 | 230 | 231 | def train_with_generator(self, data_flow, epochs, steps_per_epoch, run_folder, print_every_n_batches = 100, initial_epoch = 0, lr_decay = 1): 232 | 233 | custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self) 234 | lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) 235 | 236 | checkpoint_filepath=os.path.join(run_folder, "weights/weights-{epoch:03d}-{loss:.2f}.h5") 237 | checkpoint1 = ModelCheckpoint(checkpoint_filepath, save_weights_only = True, verbose=1) 238 | checkpoint2 = ModelCheckpoint(os.path.join(run_folder, 'weights/weights.h5'), save_weights_only = True, verbose=1) 239 | 240 | callbacks_list = [checkpoint1, checkpoint2, 
custom_callback, lr_sched] 241 | 242 | self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5')) 243 | 244 | 245 | self.model.fit( 246 | data_flow 247 | , shuffle = True 248 | , epochs = epochs 249 | , initial_epoch = initial_epoch 250 | , callbacks = callbacks_list 251 | , steps_per_epoch=steps_per_epoch 252 | ) 253 | 254 | 255 | 256 | 257 | 258 | def plot_model(self, run_folder): 259 | plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes = True, show_layer_names = True) 260 | plot_model(self.encoder, to_file=os.path.join(run_folder ,'viz/encoder.png'), show_shapes = True, show_layer_names = True) 261 | plot_model(self.decoder, to_file=os.path.join(run_folder ,'viz/decoder.png'), show_shapes = True, show_layer_names = True) 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | -------------------------------------------------------------------------------- /models/layers/layers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import tensorflow as tf 4 | import keras 5 | 6 | from tensorflow.keras.layers import Layer, InputSpec 7 | import keras.backend as K 8 | 9 | class ReflectionPadding2D(Layer): 10 | def __init__(self, padding=(1, 1), **kwargs): 11 | self.padding = tuple(padding) 12 | self.input_spec = [InputSpec(ndim=4)] 13 | super(ReflectionPadding2D, self).__init__(**kwargs) 14 | 15 | def compute_output_shape(self, s): 16 | """ If you are using "channels_last" configuration""" 17 | return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3]) 18 | 19 | def call(self, x, mask=None): 20 | w_pad,h_pad = self.padding 21 | return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT') 22 | 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | appnope==0.1.0 3 | astor==0.8.0 4 | astunparse==1.6.3 5 | attrs==19.2.0 6 | backcall==0.1.0 7 | bleach==3.1.0 8 | cachetools==4.1.1 9 | certifi==2020.6.20 10 | chardet==3.0.4 11 | cycler==0.10.0 12 | decorator==4.4.0 13 | defusedxml==0.6.0 14 | entrypoints==0.3 15 | gast==0.3.3 16 | google-auth==1.18.0 17 | google-auth-oauthlib==0.4.1 18 | google-pasta==0.2.0 19 | h5py==2.10.0 20 | idna==2.10 21 | imageio==2.6.1 22 | importlib-metadata==0.23 23 | ipykernel==5.1.2 24 | ipython==7.8.0 25 | ipython-genutils==0.2.0 26 | ipywidgets==7.5.1 27 | jedi==0.15.1 28 | Jinja2==2.10.3 29 | jsonschema==3.1.1 30 | jupyter==1.0.0 31 | jupyter-client==5.3.4 32 | jupyter-console==6.0.0 33 | jupyter-core==4.6.0 34 | Keras==2.3.1 35 | Keras-Applications==1.0.8 36 | Keras-Preprocessing==1.1.0 37 | kiwisolver==1.1.0 38 | Markdown==3.1.1 39 | MarkupSafe==1.1.1 40 | matplotlib==3.1.1 41 | mistune==0.8.4 42 | more-itertools==7.2.0 43 | music21==5.7.0 44 | nbconvert==5.6.0 45 | nbformat==4.4.0 46 | networkx==2.3 47 | notebook==6.0.1 48 | numpy==1.17.2 49 | oauthlib==3.1.0 50 | opt-einsum==3.1.0 51 | pandas==0.25.1 52 | pandocfilters==1.4.2 53 | parso==0.5.1 54 | pexpect==4.7.0 55 | pickleshare==0.7.5 56 | Pillow==6.2.0 57 | prometheus-client==0.7.1 58 | prompt-toolkit==2.0.10 59 | protobuf==3.10.0 60 | ptyprocess==0.6.0 61 | pyasn1==0.4.8 62 | pyasn1-modules==0.2.8 63 | pydot==1.4.1 64 | pydotplus==2.0.2 65 | Pygments==2.4.2 66 | pyparsing==2.4.2 67 | pyrsistent==0.15.4 68 | python-dateutil==2.8.0 69 | pytz==2019.3 70 | PyYAML==5.1.2 71 | pyzmq==18.1.0 72 | qtconsole==4.5.5 73 | 
requests==2.24.0 74 | requests-oauthlib==1.3.0 75 | rsa==4.6 76 | scikit-image==0.17.2 77 | scipy==1.4.1 78 | Send2Trash==1.5.0 79 | six==1.12.0 80 | tensorboard==2.2.2 81 | tensorboard-plugin-wit==1.7.0 82 | tensorflow==2.2.0 83 | tensorflow-addons==0.10.0 84 | tensorflow-estimator==2.2.0 85 | termcolor==1.1.0 86 | terminado==0.8.2 87 | testpath==0.4.2 88 | tornado==6.0.3 89 | traitlets==4.3.3 90 | typeguard==2.9.1 91 | urllib3==1.25.9 92 | wcwidth==0.1.7 93 | webencodings==0.5.1 94 | Werkzeug==0.16.0 95 | widgetsnbextension==3.5.1 96 | wrapt==1.11.2 97 | zipp==0.6.0 98 | -------------------------------------------------------------------------------- /run/compose/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /run/gan/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /run/paint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /run/vae/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /run/write/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /scripts/download_cyclegan_data.sh: -------------------------------------------------------------------------------- 1 | 2 | FILE=$1 3 | 4 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then 5 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" 6 | exit 1 7 | fi 8 | 9 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 10 | ZIP_FILE=./data/$FILE.zip 11 | TARGET_DIR=./data/$FILE/ 12 | wget -N $URL -O $ZIP_FILE 13 | mkdir $TARGET_DIR 14 | unzip $ZIP_FILE -d ./data/ 15 | rm $ZIP_FILE 16 | -------------------------------------------------------------------------------- /scripts/download_gutenburg_data.sh: -------------------------------------------------------------------------------- 1 | 2 | FILE=$1 3 | NAME=$2 4 | 5 | URL=http://www.gutenberg.org/cache/epub/$FILE/pg$FILE.txt 6 | 7 | TARGET_DIR=./data/$NAME/ 8 | mkdir $TARGET_DIR 9 | TXT_FILE=./data/$NAME/data.txt 10 | wget -N $URL -O $TXT_FILE 11 | 12 | -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.callbacks import Callback, LearningRateScheduler 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | 6 | #### CALLBACKS 7 | class CustomCallback(Callback): 8 | 9 | def __init__(self, run_folder, print_every_n_batches, 
initial_epoch, vae): 10 | self.epoch = initial_epoch 11 | self.run_folder = run_folder 12 | self.print_every_n_batches = print_every_n_batches 13 | self.vae = vae 14 | 15 | def on_train_batch_end(self, batch, logs={}): 16 | if batch % self.print_every_n_batches == 0: 17 | z_new = np.random.normal(size = (1,self.vae.z_dim)) 18 | reconst = self.vae.decoder.predict(np.array(z_new))[0].squeeze() 19 | 20 | filepath = os.path.join(self.run_folder, 'images', 'img_' + str(self.epoch).zfill(3) + '_' + str(batch) + '.jpg') 21 | if len(reconst.shape) == 2: 22 | plt.imsave(filepath, reconst, cmap='gray_r') 23 | else: 24 | plt.imsave(filepath, reconst) 25 | 26 | def on_epoch_begin(self, epoch, logs={}): 27 | self.epoch += 1 28 | 29 | 30 | 31 | def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1): 32 | ''' 33 | Wrapper function to create a LearningRateScheduler with step decay schedule. 34 | ''' 35 | def schedule(epoch): 36 | new_lr = initial_lr * (decay_factor ** np.floor(epoch/step_size)) 37 | 38 | return new_lr 39 | 40 | return LearningRateScheduler(schedule) -------------------------------------------------------------------------------- /utils/loaders.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | 4 | from tensorflow.keras.datasets import mnist, cifar100,cifar10 5 | from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, save_img, img_to_array 6 | 7 | import pandas as pd 8 | from PIL import Image 9 | import numpy as np 10 | from os import walk, getcwd 11 | import h5py 12 | 13 | import imageio 14 | from glob import glob 15 | 16 | from tensorflow.keras.applications import vgg19 17 | from tensorflow.keras import backend as K 18 | from tensorflow.keras.utils import to_categorical 19 | 20 | import pdb 21 | 22 | 23 | class ImageLabelLoader(): 24 | def __init__(self, image_folder, target_size): 25 | self.image_folder = image_folder 26 | self.target_size = target_size 27 | 28 | def build(self, att, batch_size, label = None): 29 | 30 | data_gen = ImageDataGenerator(rescale=1./255) 31 | if label: 32 | data_flow = data_gen.flow_from_dataframe( 33 | att 34 | , self.image_folder 35 | , x_col='image_id' 36 | , y_col=label 37 | , target_size=self.target_size 38 | , class_mode='other' 39 | , batch_size=batch_size 40 | , shuffle=True 41 | ) 42 | else: 43 | data_flow = data_gen.flow_from_dataframe( 44 | att 45 | , self.image_folder 46 | , x_col='image_id' 47 | , target_size=self.target_size 48 | , class_mode='input' 49 | , batch_size=batch_size 50 | , shuffle=True 51 | ) 52 | 53 | return data_flow 54 | 55 | 56 | 57 | 58 | class DataLoader(): 59 | def __init__(self, dataset_name, img_res=(256, 256)): 60 | self.dataset_name = dataset_name 61 | self.img_res = img_res 62 | 63 | def load_data(self, domain, batch_size=1, is_testing=False): 64 | data_type = "train%s" % domain if not is_testing else "test%s" % domain 65 | path = glob('./data/%s/%s/*' % (self.dataset_name, data_type)) 66 | 67 | batch_images = np.random.choice(path, size=batch_size) 68 | 69 | imgs = [] 70 | for img_path in batch_images: 71 | img = self.imread(img_path) 72 | if not is_testing: 73 | img = np.array(Image.fromarray(img).resize(self.img_res)) 74 | 75 | if np.random.random() > 0.5: 76 | img = np.fliplr(img) 77 | else: 78 | img = np.array(Image.fromarray(img).resize(self.img_res)) 79 | imgs.append(img) 80 | 81 | imgs = np.array(imgs)/127.5 - 1. 
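# The line above maps uint8 pixels from [0, 255] to [-1, 1], the same convention used by load_mnist_gan and load_cifar below.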
82 | 83 | return imgs 84 | 85 | def load_batch(self, batch_size=1, is_testing=False): 86 | data_type = "train" if not is_testing else "val" 87 | path_A = glob('./data/%s/%sA/*' % (self.dataset_name, data_type)) 88 | path_B = glob('./data/%s/%sB/*' % (self.dataset_name, data_type)) 89 | 90 | self.n_batches = int(min(len(path_A), len(path_B)) / batch_size) 91 | total_samples = self.n_batches * batch_size 92 | 93 | # Sample n_batches * batch_size from each path list so that model sees all 94 | # samples from both domains 95 | path_A = np.random.choice(path_A, total_samples, replace=False) 96 | path_B = np.random.choice(path_B, total_samples, replace=False) 97 | 98 | for i in range(self.n_batches-1): 99 | batch_A = path_A[i*batch_size:(i+1)*batch_size] 100 | batch_B = path_B[i*batch_size:(i+1)*batch_size] 101 | imgs_A, imgs_B = [], [] 102 | for img_A, img_B in zip(batch_A, batch_B): 103 | img_A = self.imread(img_A) 104 | img_B = self.imread(img_B) 105 | 106 | img_A = np.array(Image.fromarray(img_A).resize(self.img_res)) 107 | img_B = np.array(Image.fromarray(img_B).resize(self.img_res)) 108 | 109 | if not is_testing and np.random.random() > 0.5: 110 | img_A = np.fliplr(img_A) 111 | img_B = np.fliplr(img_B) 112 | 113 | imgs_A.append(img_A) 114 | imgs_B.append(img_B) 115 | 116 | imgs_A = np.array(imgs_A)/127.5 - 1. 117 | imgs_B = np.array(imgs_B)/127.5 - 1. 118 | 119 | yield imgs_A, imgs_B 120 | 121 | def load_img(self, path): 122 | img = self.imread(path) 123 | img = np.array(Image.fromarray(img).resize(self.img_res)) 124 | img = img/127.5 - 1. 125 | return img[np.newaxis, :, :, :] 126 | 127 | def imread(self, path): 128 | return imageio.imread(path, pilmode='RGB').astype(np.uint8) 129 | 130 | 131 | 132 | 133 | def load_model(model_class, folder): 134 | 135 | with open(os.path.join(folder, 'params.pkl'), 'rb') as f: 136 | params = pickle.load(f) 137 | 138 | model = model_class(*params) 139 | 140 | model.load_weights(os.path.join(folder, 'weights/weights.h5')) 141 | 142 | return model 143 | 144 | 145 | def load_mnist(): 146 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 147 | 148 | x_train = x_train.astype('float32') / 255. 149 | x_train = x_train.reshape(x_train.shape + (1,)) 150 | x_test = x_test.astype('float32') / 255. 
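# The reshape below appends a channels axis, (n, 28, 28) -> (n, 28, 28, 1), so the grayscale images fit the NHWC format expected by Conv2D layers.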
151 | x_test = x_test.reshape(x_test.shape + (1,)) 152 | 153 | return (x_train, y_train), (x_test, y_test) 154 | 155 | def load_mnist_gan(): 156 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 157 | 158 | x_train = (x_train.astype('float32') - 127.5) / 127.5 159 | x_train = x_train.reshape(x_train.shape + (1,)) 160 | x_test = (x_test.astype('float32') - 127.5) / 127.5 161 | x_test = x_test.reshape(x_test.shape + (1,)) 162 | 163 | return (x_train, y_train), (x_test, y_test) 164 | 165 | 166 | 167 | def load_fashion_mnist(input_rows, input_cols, path='./data/fashion/fashion-mnist_train.csv'): 168 | #read the csv data 169 | df = pd.read_csv(path) 170 | #extract the image pixels 171 | X_train = df.drop(columns = ['label']) 172 | X_train = X_train.values 173 | X_train = (X_train.astype('float32') - 127.5) / 127.5 174 | X_train = X_train.reshape(X_train.shape[0], input_rows, input_cols, 1) 175 | #extract the labels 176 | y_train = df['label'].values 177 | 178 | return X_train, y_train 179 | 180 | def load_safari(folder): 181 | 182 | mypath = os.path.join("./data", folder) 183 | txt_name_list = [] 184 | for (dirpath, dirnames, filenames) in walk(mypath): 185 | for f in filenames: 186 | if f != '.DS_Store': 187 | txt_name_list.append(f) 188 | break 189 | 190 | slice_train = int(80000/len(txt_name_list)) ###Setting value to be 80000 for the final dataset 191 | i = 0 192 | seed = np.random.randint(1, 10e6) 193 | 194 | for txt_name in txt_name_list: 195 | txt_path = os.path.join(mypath,txt_name) 196 | x = np.load(txt_path) 197 | x = (x.astype('float32') - 127.5) / 127.5 198 | # x = x.astype('float32') / 255.0 199 | 200 | x = x.reshape(x.shape[0], 28, 28, 1) 201 | 202 | y = [i] * len(x) 203 | np.random.seed(seed) 204 | np.random.shuffle(x) 205 | np.random.seed(seed) 206 | np.random.shuffle(y) 207 | x = x[:slice_train] 208 | y = y[:slice_train] 209 | if i != 0: 210 | xtotal = np.concatenate((x,xtotal), axis=0) 211 | ytotal = np.concatenate((y,ytotal), axis=0) 212 | else: 213 | xtotal = x 214 | ytotal = y 215 | i += 1 216 | 217 | return xtotal, ytotal 218 | 219 | 220 | 221 | def load_cifar(label, num): 222 | if num == 10: 223 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 224 | else: 225 | (x_train, y_train), (x_test, y_test) = cifar100.load_data(label_mode = 'fine') 226 | 227 | train_mask = [y[0]==label for y in y_train] 228 | test_mask = [y[0]==label for y in y_test] 229 | 230 | x_data = np.concatenate([x_train[train_mask], x_test[test_mask]]) 231 | y_data = np.concatenate([y_train[train_mask], y_test[test_mask]]) 232 | 233 | x_data = (x_data.astype('float32') - 127.5) / 127.5 234 | 235 | return (x_data, y_data) 236 | 237 | 238 | def load_celeb(data_name, image_size, batch_size): 239 | data_folder = os.path.join("./data", data_name) 240 | 241 | data_gen = ImageDataGenerator(preprocessing_function=lambda x: (x.astype('float32') - 127.5) / 127.5) 242 | 243 | x_train = data_gen.flow_from_directory(data_folder 244 | , target_size = (image_size,image_size) 245 | , batch_size = batch_size 246 | , shuffle = True 247 | , class_mode = 'input' 248 | , subset = "training" 249 | ) 250 | 251 | return x_train 252 | 253 | 254 | def load_music(data_name, filename, n_bars, n_steps_per_bar): 255 | file = os.path.join("./data", data_name, filename) 256 | 257 | with np.load(file, encoding='bytes', allow_pickle = True) as f: 258 | data = f['train'] 259 | 260 | data_ints = [] 261 | 262 | for x in data: 263 | counter = 0 264 | cont = True 265 | while cont: 266 | if not 
np.any(np.isnan(x[counter:(counter+4)])): 267 | cont = False 268 | else: 269 | counter += 4 270 | 271 | if n_bars * n_steps_per_bar < x.shape[0]: 272 | data_ints.append(x[counter:(counter + (n_bars * n_steps_per_bar)),:]) 273 | 274 | 275 | data_ints = np.array(data_ints) 276 | 277 | n_songs = data_ints.shape[0] 278 | n_tracks = data_ints.shape[2] 279 | 280 | data_ints = data_ints.reshape([n_songs, n_bars, n_steps_per_bar, n_tracks]) 281 | 282 | max_note = 83 283 | 284 | where_are_NaNs = np.isnan(data_ints) 285 | data_ints[where_are_NaNs] = max_note + 1 286 | max_note = max_note + 1 287 | 288 | data_ints = data_ints.astype(int) 289 | 290 | num_classes = max_note + 1 291 | 292 | 293 | data_binary = np.eye(num_classes)[data_ints] 294 | data_binary[data_binary==0] = -1 295 | data_binary = np.delete(data_binary, max_note,-1) 296 | 297 | data_binary = data_binary.transpose([0,1,2, 4,3]) 298 | 299 | 300 | 301 | 302 | 303 | return data_binary, data_ints, data 304 | 305 | 306 | def preprocess_image(data_name, file, img_nrows, img_ncols): 307 | 308 | image_path = os.path.join('./data', data_name, file) 309 | 310 | img = load_img(image_path, target_size=(img_nrows, img_ncols)) 311 | img = img_to_array(img) 312 | img = np.expand_dims(img, axis=0) 313 | img = vgg19.preprocess_input(img) 314 | return img 315 | 316 | -------------------------------------------------------------------------------- /utils/write.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from collections import Counter 4 | 5 | import csv 6 | 7 | import random 8 | 9 | import numpy as np 10 | 11 | 12 | _MAX_BATCH_SIZE = 128 13 | _MAX_DOC_LENGTH = 200 14 | 15 | PADDING_WORD = "<PAD>" 16 | UNKNOWN_WORD = "<UNK>" 17 | START_WORD = "<START>" 18 | END_WORD = "<END>" 19 | 20 | _word_to_idx = {} 21 | _idx_to_word = [] 22 | 23 | 24 | def _add_word(word): 25 | idx = len(_idx_to_word) 26 | _word_to_idx[word] = idx 27 | _idx_to_word.append(word) 28 | return idx 29 | 30 | 31 | PADDING_TOKEN = _add_word(PADDING_WORD) 32 | UNKNOWN_TOKEN = _add_word(UNKNOWN_WORD) 33 | START_TOKEN = _add_word(START_WORD) 34 | END_TOKEN = _add_word(END_WORD) 35 | 36 | def get_glove(): 37 | 38 | embeddings_path = './data/glove/glove.6B.100d.trimmed.txt' 39 | 40 | with open(embeddings_path) as f: 41 | line = f.readline() 42 | chunks = line.split(" ") 43 | dimensions = len(chunks) - 1 44 | f.seek(0) 45 | 46 | vocab_size = sum(1 for line in f) 47 | vocab_size += 4 # the four special tokens added above 48 | f.seek(0) 49 | 50 | glove = np.ndarray((vocab_size, dimensions), dtype=np.float32) 51 | glove[PADDING_TOKEN] = np.random.normal(0, 0.02, dimensions) 52 | glove[UNKNOWN_TOKEN] = np.random.normal(0, 0.02, dimensions) 53 | glove[START_TOKEN] = np.random.normal(0, 0.02, dimensions) 54 | glove[END_TOKEN] = np.random.normal(0, 0.02, dimensions) 55 | 56 | for line in f: 57 | chunks = line.split(" ") 58 | idx = _add_word(chunks[0]) 59 | glove[idx] = [float(chunk) for chunk in chunks[1:]] 60 | if len(_idx_to_word) >= vocab_size: 61 | break 62 | 63 | return glove 64 | 65 | 66 | 67 | 68 | def look_up_word(word): 69 | return _word_to_idx.get(word, UNKNOWN_TOKEN) 70 | 71 | 72 | def look_up_token(token): 73 | return _idx_to_word[token] 74 | 75 | 76 | 77 | def _tokenize(string): 78 | return [word.lower() for word in string.split(" ")] 79 | 80 | 81 | def _prepare_batch(batch): 82 | id_to_indices = {} 83 | document_ids = [] 84 | document_text = [] 85 | document_words = [] 86 | answer_text = [] 87 | answer_indices = [] 88 | question_text = [] 89 | question_input_words = [] 90 
| question_output_words = [] 91 | for i, entry in enumerate(batch): 92 | id_to_indices.setdefault(entry["document_id"], []).append(i) 93 | document_ids.append(entry["document_id"]) 94 | document_text.append(entry["document_text"]) 95 | document_words.append(entry["document_words"]) 96 | answer_text.append(entry["answer_text"]) 97 | answer_indices.append(entry["answer_indices"]) 98 | question_text.append(entry["question_text"]) 99 | 100 | question_words = entry["question_words"] 101 | question_input_words.append([START_WORD] + question_words) 102 | question_output_words.append(question_words + [END_WORD]) 103 | 104 | batch_size = len(batch) 105 | max_document_len = max((len(document) for document in document_words), default=0) 106 | max_answer_len = max((len(answer) for answer in answer_indices), default=0) 107 | max_question_len = max((len(question) for question in question_input_words), default=0) 108 | 109 | document_tokens = np.zeros((batch_size, max_document_len), dtype=np.int32) 110 | document_lengths = np.zeros(batch_size, dtype=np.int32) 111 | answer_labels = np.zeros((batch_size, max_document_len), dtype=np.int32) 112 | answer_masks = np.zeros((batch_size, max_answer_len, max_document_len), dtype=np.int32) 113 | answer_lengths = np.zeros(batch_size, dtype=np.int32) 114 | question_input_tokens = np.zeros((batch_size, max_question_len), dtype=np.int32) 115 | question_output_tokens = np.zeros((batch_size, max_question_len), dtype=np.int32) 116 | question_lengths = np.zeros(batch_size, dtype=np.int32) 117 | 118 | for i in range(batch_size): 119 | for j, word in enumerate(document_words[i]): 120 | document_tokens[i, j] = look_up_word(word) 121 | document_lengths[i] = len(document_words[i]) 122 | 123 | for j, index in enumerate(answer_indices[i]): 124 | for shared_i in id_to_indices[batch[i]["document_id"]]: 125 | answer_labels[shared_i, index] = 1 126 | answer_masks[i, j, index] = 1 127 | answer_lengths[i] = len(answer_indices[i]) 128 | 129 | for j, word in enumerate(question_input_words[i]): 130 | question_input_tokens[i, j] = look_up_word(word) 131 | for j, word in enumerate(question_output_words[i]): 132 | question_output_tokens[i, j] = look_up_word(word) 133 | question_lengths[i] = len(question_input_words[i]) 134 | 135 | return { 136 | "size": batch_size, 137 | "document_ids": document_ids, 138 | "document_text": document_text, 139 | "document_words": document_words, 140 | "document_tokens": document_tokens, 141 | "document_lengths": document_lengths, 142 | "answer_text": answer_text, 143 | "answer_indices": answer_indices, 144 | "answer_labels": answer_labels, 145 | "answer_masks": answer_masks, 146 | "answer_lengths": answer_lengths, 147 | "question_text": question_text, 148 | "question_input_tokens": question_input_tokens, 149 | "question_output_tokens": question_output_tokens, 150 | "question_lengths": question_lengths, 151 | } 152 | 153 | 154 | def collapse_documents(batch): 155 | seen_ids = set() 156 | keep = [] 157 | 158 | for i in range(batch["size"]): 159 | id = batch["document_ids"][i] 160 | if id in seen_ids: 161 | continue 162 | 163 | keep.append(i) 164 | seen_ids.add(id) 165 | 166 | result = {} 167 | for key, value in batch.items(): 168 | if key == "size": 169 | result[key] = len(keep) 170 | elif isinstance(value, np.ndarray): 171 | result[key] = value[keep] 172 | else: 173 | result[key] = [value[i] for i in keep] 174 | return result 175 | 176 | 177 | def expand_answers(batch, answers): 178 | new_batch = [] 179 | 180 | for i in range(batch["size"]): 181 | split_answers 
= [] 182 | last = None 183 | for j, tag in enumerate(answers[i]): 184 | if tag: 185 | if last != j - 1: 186 | split_answers.append([]) 187 | split_answers[-1].append(j) 188 | last = j 189 | 190 | if len(split_answers) > 0: 191 | 192 | answer_indices = split_answers[0] 193 | # for answer_indices in split_answers: 194 | document_id = batch["document_ids"][i] 195 | document_text = batch["document_text"][i] 196 | document_words = batch["document_words"][i] 197 | answer_text = " ".join(document_words[i] for i in answer_indices) 198 | new_batch.append({ 199 | "document_id": document_id, 200 | "document_text": document_text, 201 | "document_words": document_words, 202 | "answer_text": answer_text, 203 | "answer_indices": answer_indices, 204 | "question_text": "", 205 | "question_words": [], 206 | }) 207 | else: 208 | new_batch.append({ 209 | "document_id": batch["document_ids"][i], 210 | "document_text": batch["document_text"][i], 211 | "document_words": batch["document_words"][i], 212 | "answer_text": "", 213 | "answer_indices": [], 214 | "question_text": "", 215 | "question_words": [], 216 | }) 217 | 218 | return _prepare_batch(new_batch) 219 | 220 | 221 | def _read_data(path): 222 | stories = {} 223 | 224 | with open(path) as f: 225 | header_seen = False 226 | for row in csv.reader(f): 227 | if not header_seen: 228 | header_seen = True 229 | continue 230 | 231 | document_id = row[0] 232 | 233 | existing_stories = stories.setdefault(document_id, []) 234 | 235 | document_text = row[1] 236 | if existing_stories and document_text == existing_stories[0]["document_text"]: 237 | # Save memory by sharing identical documents 238 | document_text = existing_stories[0]["document_text"] 239 | document_words = existing_stories[0]["document_words"] 240 | else: 241 | document_words = _tokenize(document_text) 242 | document_words = document_words[:_MAX_DOC_LENGTH] 243 | 244 | question_text = row[2] 245 | question_words = _tokenize(question_text) 246 | 247 | answer = row[3] 248 | answer_indices = [] 249 | for chunk in answer.split(","): 250 | start, end = (int(index) for index in chunk.split(":")) 251 | if end < _MAX_DOC_LENGTH: 252 | answer_indices.extend(range(start, end)) 253 | answer_text = " ".join(document_words[i] for i in answer_indices) 254 | 255 | if len(answer_indices) > 0: 256 | existing_stories.append({ 257 | "document_id": document_id, 258 | "document_text": document_text, 259 | "document_words": document_words, 260 | "answer_text": answer_text, 261 | "answer_indices": answer_indices, 262 | "question_text": question_text, 263 | "question_words": question_words, 264 | }) 265 | 266 | 267 | 268 | return stories 269 | 270 | 271 | def _process_stories(stories): 272 | batch = [] 273 | vals = list(stories.values()) 274 | random.shuffle(vals) 275 | 276 | for story in vals: 277 | if len(batch) + len(story) > _MAX_BATCH_SIZE: 278 | yield _prepare_batch(batch) 279 | batch = [] 280 | batch.extend(story) 281 | 282 | if batch: 283 | yield _prepare_batch(batch) 284 | 285 | 286 | _training_stories = None 287 | _test_stories = None 288 | 289 | def _load_training_stories(): 290 | global _training_stories 291 | _training_stories = _read_data("./data/qa/train.csv") 292 | return _training_stories 293 | 294 | def _load_test_stories(): 295 | global _test_stories 296 | _test_stories = _read_data("./data/qa_test/my_test.csv") 297 | return _test_stories 298 | 299 | def training_data(): 300 | return _process_stories(_load_training_stories()) 301 | 302 | def test_data(): 303 | return _process_stories(_load_test_stories()) 
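# A minimal usage sketch, assuming the CSVs referenced above exist on disk;
# each yielded batch is the dict built by _prepare_batch:
#
#   for batch in training_data():
#       print(batch["size"], batch["document_tokens"].shape)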
304 | 305 | 306 | def trim_embeddings(): 307 | document_counts = Counter() 308 | question_counts = Counter() 309 | for data in [_load_training_stories().values(), _load_test_stories().values()]: 310 | 311 | for stories in data: 312 | 313 | if len(stories) > 0: 314 | document_counts.update(stories[0]["document_words"]) 315 | for story in stories: 316 | question_counts.update(story["question_words"]) 317 | 318 | keep = set() 319 | for word, count in question_counts.most_common(5000): 320 | keep.add(word) 321 | for word, count in document_counts.most_common(): 322 | if len(keep) >= 10000: 323 | break 324 | keep.add(word) 325 | 326 | with open("./data/glove/glove.6B.100d.txt", encoding="utf-8") as f: 327 | with open("./data/glove/glove.6B.100d.trimmed.txt", "w") as f2: 328 | for line in f: 329 | if line.split(" ")[0] in keep: 330 | f2.write(line) 331 | 332 | 333 | if __name__ == '__main__': 334 | trim_embeddings() 335 | --------------------------------------------------------------------------------