├── .gitignore
├── .travis.yml
├── Dockerfile
├── README.md
├── docker-compose.yml
├── imgs
│   └── transformer.png
├── ipynb
│   ├── .ipynb_checkpoints
│   │   └── Set_transformer-checkpoint.ipynb
│   └── Set_transformer.ipynb
├── requirements.txt
├── set_transformer
│   ├── __init__.py
│   ├── blocks.py
│   ├── data
│   │   └── simulation.py
│   ├── layers
│   │   ├── __init__.py
│   │   └── attention.py
│   └── model.py
├── setup.py
└── tests
    ├── __init__.py
    ├── test_attention.py
    ├── test_blocks.py
    ├── test_data_generation.py
    ├── test_layers.py
    └── test_model.py

/.gitignore: -------------------------------------------------------------------------------- 1 | ipynb/images_background.zip 2 | ipynb/images_background/ 3 | .idea 4 | ipynb/vision_model.h5 5 | ipynb/.ipynb_checkpoints/* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | cache: pip 3 | python: 4 | - '3.6.9' 5 | 6 | 7 | install: pip install -r requirements.txt 8 | script: pytest -W ignore 9 | after_success: 10 | - codecov # submit coverage 11 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:nightly-gpu-py3-jupyter 2 | 3 | USER root 4 | COPY . /app 5 | 6 | WORKDIR /app 7 | 8 | RUN pip3 install -r requirements.txt 9 | RUN python3 setup.py develop 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # set-transformer 2 | A TensorFlow implementation of the paper 'Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks' 3 | 4 | [![Build Status](https://travis-ci.com/arrigonialberto86/set_transformer.svg?branch=master)](https://travis-ci.com/arrigonialberto86/set_transformer) 5 | 6 | [![PyPI version](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow) 7 | 8 | ![Set Transformer architecture](imgs/transformer.png) 9 | 10 | ## Using Docker (dev-only mode) 11 | 12 | In this project a Dockerfile and a docker-compose.yml file have been added.
You can use the listed services after cloning this project by doing: 13 | 14 | ```docker-compose build``` 15 | 16 | and then to start a Jupyter notebook with this package already installed in `develop` mode: 17 | 18 | ```docker-compose run -p 8001:8001 jupyter``` 19 | 20 | or start a `bash` session: 21 | 22 | ```docker-compose run bash``` 23 | 24 | and execute the automated unit test suite: 25 | 26 | ```pytest -W ignore``` 27 | 28 | ## Basic example usage 29 | 30 | ```python 31 | from set_transformer.data.simulation import gen_max_dataset 32 | from set_transformer.model import BasicSetTransformer 33 | import numpy as np 34 | 35 | train_X, train_y = gen_max_dataset(dataset_size=100000, set_size=9, seed=1) 36 | test_X, test_y = gen_max_dataset(dataset_size=15000, set_size=9, seed=3) 37 | 38 | set_transformer = BasicSetTransformer() 39 | set_transformer.compile(loss='mae', optimizer='adam') 40 | set_transformer.fit(train_X, train_y, epochs=3) 41 | predictions = set_transformer.predict(test_X) 42 | print("MAE on test set is: ", np.abs(test_y - predictions).mean()) 43 | ``` 44 | 45 | Which returns: 46 | 47 | ```bash 48 | Train on 100000 samples 49 | Epoch 1/3 50 | 100000/100000 [==============================] - 27s 270us/sample - loss: 32.8959 51 | Epoch 2/3 52 | 100000/100000 [==============================] - 20s 197us/sample - loss: 6.6131 53 | Epoch 3/3 54 | 100000/100000 [==============================] - 22s 216us/sample - loss: 6.6121 55 | MAE on test set is: 6.558687 56 | ``` -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | jupyter: 4 | image: set_transformer:0.0.1 5 | build: . 6 | command: jupyter notebook --ip=0.0.0.0 --port=8001 --allow-root --NotebookApp.token='' --NotebookApp.password='' --notebook-dir=/app/ 7 | volumes: 8 | - ${TRANSFORMER}:/app 9 | ports: 10 | - 8001:8001 11 | bash: 12 | image: set_transformer:0.0.1 13 | build: . 14 | command: /bin/bash 15 | volumes: 16 | - ${TRANSFORMER}:/app 17 | -------------------------------------------------------------------------------- /imgs/transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arrigonialberto86/set_transformer/ec76b58d43febb85bc30f4d53d58fd630c46f009/imgs/transformer.png -------------------------------------------------------------------------------- /ipynb/.ipynb_checkpoints/Set_transformer-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 10 | "import time\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import tensorflow as tf\n", 14 | "import warnings\n", 15 | "with warnings.catch_warnings():\n", 16 | " warnings.filterwarnings(\"ignore\",category=FutureWarning)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Attention weights are:\n", 29 | "tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)\n", 30 | "Output is:\n", 31 | "tf.Tensor([[10. 
0.]], shape=(1, 2), dtype=float32)\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "# MultiHeadAttention\n", 37 | "# https://www.tensorflow.org/tutorials/text/transformer, appears in \"Attention is all you need\" NeurIPS 2017 paper\n", 38 | "import numpy as np\n", 39 | "import tensorflow as tf\n", 40 | "\n", 41 | "\n", 42 | "def scaled_dot_product_attention(q, k, v, mask):\n", 43 | "    \"\"\"Calculate the attention weights.\n", 44 | "    q, k, v must have matching leading dimensions.\n", 45 | "    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.\n", 46 | "    The mask has different shapes depending on its type (padding or look ahead) \n", 47 | "    but it must be broadcastable for addition.\n", 48 | "\n", 49 | "    Args:\n", 50 | "      q: query shape == (..., seq_len_q, depth)\n", 51 | "      k: key shape == (..., seq_len_k, depth)\n", 52 | "      v: value shape == (..., seq_len_v, depth_v)\n", 53 | "      mask: Float tensor with shape broadcastable \n", 54 | "            to (..., seq_len_q, seq_len_k). Defaults to None.\n", 55 | "\n", 56 | "    Returns:\n", 57 | "      output, attention_weights\n", 58 | "    \"\"\"\n", 59 | "\n", 60 | "    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)\n", 61 | "    \n", 62 | "    # scale matmul_qk\n", 63 | "    dk = tf.cast(tf.shape(k)[-1], tf.float32)\n", 64 | "    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n", 65 | "\n", 66 | "    # add the mask to the scaled tensor.\n", 67 | "    if mask is not None:\n", 68 | "        scaled_attention_logits += (mask * -1e9)  \n", 69 | "\n", 70 | "    # softmax is normalized on the last axis (seq_len_k) so that the scores\n", 71 | "    # add up to 1.\n", 72 | "    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)\n", 73 | "\n", 74 | "    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)\n", 75 | "\n", 76 | "    return output, attention_weights\n", 77 | "\n", 78 | "\n", 79 | "def print_out(q, k, v):\n", 80 | "    temp_out, temp_attn = scaled_dot_product_attention(\n", 81 | "        q, k, v, None)\n", 82 | "    print ('Attention weights are:')\n", 83 | "    print (temp_attn)\n", 84 | "    print ('Output is:')\n", 85 | "    print (temp_out)\n", 86 | "    \n", 87 | "np.set_printoptions(suppress=True)\n", 88 | "\n", 89 | "temp_k = tf.constant([[10,0,0],\n", 90 | "                      [0,10,0],\n", 91 | "                      [0,0,10],\n", 92 | "                      [0,0,10]], dtype=tf.float32)  # (4, 3)\n", 93 | "\n", 94 | "temp_v = tf.constant([[   1,0],\n", 95 | "                      [  10,0],\n", 96 | "                      [ 100,5],\n", 97 | "                      [1000,6]], dtype=tf.float32)  # (4, 2)\n", 98 | "\n", 99 | "# This `query` aligns with the second `key`,\n", 100 | "# so the second `value` is returned.\n", 101 | "temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)\n", 102 | "print_out(temp_q, temp_k, temp_v)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "from tensorflow.keras.layers import Dense\n", 112 | "from tensorflow.keras.models import Model\n", 113 | "\n", 114 | "    \n", 115 | "class MultiHeadAttention(tf.keras.layers.Layer):\n", 116 | "    def __init__(self, d_model, num_heads):\n", 117 | "        super(MultiHeadAttention, self).__init__()\n", 118 | "        self.num_heads = num_heads\n", 119 | "        self.d_model = d_model\n", 120 | "\n", 121 | "        assert d_model % self.num_heads == 0\n", 122 | "        \n", 123 | "        self.depth = d_model // self.num_heads\n", 124 | "        \n", 125 | "        self.wq = tf.keras.layers.Dense(d_model)\n", 126 | "        self.wk = tf.keras.layers.Dense(d_model)\n", 127 | "        self.wv = tf.keras.layers.Dense(d_model)\n", 128 | "\n", 129
| " self.dense = tf.keras.layers.Dense(d_model)\n", 130 | " \n", 131 | " def split_heads(self, x, batch_size):\n", 132 | " \"\"\"Split the last dimension into (num_heads, depth).\n", 133 | " Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)\n", 134 | " \"\"\"\n", 135 | " x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))\n", 136 | " return tf.transpose(x, perm=[0, 2, 1, 3])\n", 137 | " \n", 138 | " def call(self, q, k, v, mask=None):\n", 139 | " batch_size = tf.shape(q)[0]\n", 140 | "\n", 141 | " q = self.wq(q) # (batch_size, seq_len, d_model)\n", 142 | " k = self.wk(k) # (batch_size, seq_len, d_model)\n", 143 | " v = self.wv(v) # (batch_size, seq_len, d_model)\n", 144 | "\n", 145 | " q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)\n", 146 | " k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)\n", 147 | " v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)\n", 148 | "\n", 149 | " # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)\n", 150 | " # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)\n", 151 | " scaled_attention, attention_weights = scaled_dot_product_attention(\n", 152 | " q, k, v, mask)\n", 153 | " \n", 154 | " scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)\n", 155 | "\n", 156 | " concat_attention = tf.reshape(scaled_attention, \n", 157 | " (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)\n", 158 | "\n", 159 | " output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)\n", 160 | "\n", 161 | " return output\n", 162 | " \n", 163 | "\n", 164 | "# temp_mha = MultiHeadAttention(d_model=512, num_heads=8)\n", 165 | "# y = tf.random.uniform((1, 60, 512)) # (batch_size, encoder_sequence, d_model)\n", 166 | "# out = temp_mha(v=y, k=y, q=y)\n", 167 | "# print(out.shape)\n", 168 | "\n", 169 | "\n", 170 | "class RFF(tf.keras.layers.Layer):\n", 171 | " \"\"\"\n", 172 | " Row-wise FeedForward layers.\n", 173 | " \"\"\"\n", 174 | " def __init__(self, d):\n", 175 | " super(RFF, self).__init__()\n", 176 | " \n", 177 | " self.linear_1 = Dense(d, activation='relu')\n", 178 | " self.linear_2 = Dense(d, activation='relu')\n", 179 | " self.linear_3 = Dense(d, activation='relu')\n", 180 | " \n", 181 | " def call(self, x):\n", 182 | " \"\"\"\n", 183 | " Arguments:\n", 184 | " x: a float tensor with shape [b, n, d].\n", 185 | " Returns:\n", 186 | " a float tensor with shape [b, n, d].\n", 187 | " \"\"\"\n", 188 | " return self.linear_3(self.linear_2(self.linear_1(x))) \n", 189 | "\n", 190 | "\n", 191 | "# mlp = RFF(3)\n", 192 | "# y = mlp(tf.ones(shape=(2, 4, 3))) # The first call to the `mlp` will create the weights\n", 193 | "# print('weights:', len(mlp.weights))\n", 194 | "# print('trainable weights:', len(mlp.trainable_weights))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 27, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# Referencing https://arxiv.org/pdf/1810.00825.pdf \n", 204 | "# and the original PyTorch implementation https://github.com/TropComplique/set-transformer/blob/master/blocks.py\n", 205 | "from tensorflow import repeat\n", 206 | "# from tensorflow.keras.backend import repeat_elements\n", 207 | "from tensorflow.keras.layers import LayerNormalization\n", 208 | "\n", 209 | "\n", 210 | "class MultiHeadAttentionBlock(tf.keras.layers.Layer):\n", 211 | " def 
__init__(self, d, h, rff):\n", 212 | " super(MultiHeadAttentionBlock, self).__init__()\n", 213 | " self.multihead = MultiHeadAttention(d, h)\n", 214 | " self.layer_norm1 = LayerNormalization(epsilon=1e-6, dtype='float32')\n", 215 | " self.layer_norm2 = LayerNormalization(epsilon=1e-6, dtype='float32')\n", 216 | " self.rff = rff\n", 217 | " \n", 218 | " def call(self, x, y):\n", 219 | " \"\"\"\n", 220 | " Arguments:\n", 221 | " x: a float tensor with shape [b, n, d].\n", 222 | " y: a float tensor with shape [b, m, d].\n", 223 | " Returns:\n", 224 | " a float tensor with shape [b, n, d].\n", 225 | " \"\"\"\n", 226 | " \n", 227 | " h = self.layer_norm1(x + self.multihead(x, y, y))\n", 228 | " return self.layer_norm2(h + self.rff(h))\n", 229 | "\n", 230 | "# x_data = tf.random.normal(shape=(10, 2, 9))\n", 231 | "# y_data = tf.random.normal(shape=(10, 3, 9))\n", 232 | "# rff = RFF(d=9)\n", 233 | "# mab = MultiHeadAttentionBlock(9, 3, rff=rff)\n", 234 | "# mab(x_data, y_data).shape \n", 235 | "\n", 236 | " \n", 237 | "class SetAttentionBlock(tf.keras.layers.Layer):\n", 238 | " def __init__(self, d, h, rff):\n", 239 | " super(SetAttentionBlock, self).__init__()\n", 240 | " self.mab = MultiHeadAttentionBlock(d, h, rff)\n", 241 | " \n", 242 | " def call(self, x):\n", 243 | " \"\"\"\n", 244 | " Arguments:\n", 245 | " x: a float tensor with shape [b, n, d].\n", 246 | " Returns:\n", 247 | " a float tensor with shape [b, n, d].\n", 248 | " \"\"\"\n", 249 | " return self.mab(x, x)\n", 250 | "\n", 251 | " \n", 252 | "class InducedSetAttentionBlock(tf.keras.layers.Layer):\n", 253 | " def __init__(self, d, m, h, rff1, rff2):\n", 254 | " \"\"\"\n", 255 | " Arguments:\n", 256 | " d: an integer, input dimension.\n", 257 | " m: an integer, number of inducing points.\n", 258 | " h: an integer, number of heads.\n", 259 | " rff1, rff2: modules, row-wise feedforward layers.\n", 260 | " It takes a float tensor with shape [b, n, d] and\n", 261 | " returns a float tensor with the same shape.\n", 262 | " \"\"\"\n", 263 | " super(InducedSetAttentionBlock, self).__init__()\n", 264 | " self.mab1 = MultiHeadAttentionBlock(d, h, rff1)\n", 265 | " self.mab2 = MultiHeadAttentionBlock(d, h, rff2)\n", 266 | " self.inducing_points = tf.random.normal(shape=(1, m, d))\n", 267 | "\n", 268 | " def call(self, x):\n", 269 | " \"\"\"\n", 270 | " Arguments:\n", 271 | " x: a float tensor with shape [b, n, d].\n", 272 | " Returns:\n", 273 | " a float tensor with shape [b, n, d].\n", 274 | " \"\"\"\n", 275 | " b = tf.shape(x)[0] \n", 276 | " p = self.inducing_points\n", 277 | " p = repeat(p, (b), axis=0) # shape [b, m, d] \n", 278 | " \n", 279 | " h = self.mab1(p, x) # shape [b, m, d]\n", 280 | " return self.mab2(x, h) \n", 281 | " \n", 282 | "\n", 283 | "class PoolingMultiHeadAttention(tf.keras.layers.Layer):\n", 284 | "\n", 285 | " def __init__(self, d, k, h, rff, rff_s):\n", 286 | " \"\"\"\n", 287 | " Arguments:\n", 288 | " d: an integer, input dimension.\n", 289 | " k: an integer, number of seed vectors.\n", 290 | " h: an integer, number of heads.\n", 291 | " rff: a module, row-wise feedforward layers.\n", 292 | " It takes a float tensor with shape [b, n, d] and\n", 293 | " returns a float tensor with the same shape.\n", 294 | " \"\"\"\n", 295 | " super(PoolingMultiHeadAttention, self).__init__()\n", 296 | " self.mab = MultiHeadAttentionBlock(d, h, rff)\n", 297 | " self.seed_vectors = tf.random.normal(shape=(1, k, d))\n", 298 | " self.rff_s = rff_s\n", 299 | "\n", 300 | " @tf.function\n", 301 | " def call(self, z):\n", 302 | " 
\"\"\"\n", 303 | " Arguments:\n", 304 | " z: a float tensor with shape [b, n, d].\n", 305 | " Returns:\n", 306 | " a float tensor with shape [b, k, d]\n", 307 | " \"\"\"\n", 308 | " b = tf.shape(z)[0]\n", 309 | " s = self.seed_vectors\n", 310 | " s = repeat(s, (b), axis=0) # shape [b, k, d]\n", 311 | " return self.mab(s, self.rff_s(z))\n", 312 | " \n", 313 | "\n", 314 | "# z = tf.random.normal(shape=(10, 2, 9))\n", 315 | "# rff, rff_s = RFF(d=9), RFF(d=9) \n", 316 | "# pma = PoolingMultiHeadAttention(d=9, k=10, h=3, rff=rff, rff_s=rff_s)\n", 317 | "# pma(z).shape" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 28, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "from tensorflow.keras.layers import Dense\n", 327 | " \n", 328 | "\n", 329 | "class STEncoderBasic(tf.keras.layers.Layer):\n", 330 | " def __init__(self, d=12, m=6, h=6):\n", 331 | " super(STEncoderBasic, self).__init__()\n", 332 | " \n", 333 | " # Embedding part\n", 334 | " self.linear_1 = Dense(d, activation='relu')\n", 335 | " \n", 336 | " # Encoding part\n", 337 | " self.isab_1 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d))\n", 338 | " self.isab_2 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d))\n", 339 | " \n", 340 | " def call(self, x):\n", 341 | " return self.isab_2(self.isab_1(self.linear_1(x)))\n", 342 | "\n", 343 | " \n", 344 | "class STDecoderBasic(tf.keras.layers.Layer):\n", 345 | " def __init__(self, out_dim, d=12, m=6, h=2, k=8):\n", 346 | " super(STDecoderBasic, self).__init__()\n", 347 | " \n", 348 | " self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))\n", 349 | " self.SAB = SetAttentionBlock(d, h, RFF(d))\n", 350 | " self.output_mapper = Dense(out_dim) \n", 351 | " self.k, self.d = k, d\n", 352 | "\n", 353 | " def call(self, x):\n", 354 | " decoded_vec = self.SAB(self.PMA(x))\n", 355 | " decoded_vec = tf.reshape(decoded_vec, [-1, self.k * self.d])\n", 356 | " return tf.reshape(self.output_mapper(decoded_vec), (tf.shape(decoded_vec)[0],))\n" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 29, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "text/plain": [ 367 | "TensorShape([100000, 9, 1])" 368 | ] 369 | }, 370 | "execution_count": 29, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "def gen_max_dataset(dataset_size=100000, set_size=9):\n", 377 | " \"\"\"\n", 378 | " The number of objects per set is constant in this toy example\n", 379 | " \"\"\"\n", 380 | " x = np.random.uniform(1, 100, (dataset_size, set_size))\n", 381 | " y = np.max(x, axis=1)\n", 382 | " x, y = np.expand_dims(x, axis=2), np.expand_dims(y, axis=1)\n", 383 | " return tf.cast(x, 'float32'), tf.cast(y, 'float32')\n", 384 | "\n", 385 | "X, y = gen_max_dataset()\n", 386 | "X.shape" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 30, 392 | "metadata": { 393 | "scrolled": true 394 | }, 395 | "outputs": [ 396 | { 397 | "name": "stdout", 398 | "output_type": "stream", 399 | "text": [ 400 | "(100000, 9, 3)\n", 401 | "(100000,)\n" 402 | ] 403 | } 404 | ], 405 | "source": [ 406 | "# Dimensionality check on encoder-decoder couple\n", 407 | "\n", 408 | "encoder = STEncoderBasic(d=3, m=2, h=1)\n", 409 | "encoded = encoder(X)\n", 410 | "print(encoded.shape)\n", 411 | "\n", 412 | "decoder = STDecoderBasic(out_dim=1, d=1, m=2, h=1, k=1)\n", 413 | "decoded = decoder(encoded)\n", 414 | "print(decoded.shape)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | 
"execution_count": 31, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "# Actual model for max-set prediction\n", 424 | "\n", 425 | "class SetTransformer(tf.keras.Model):\n", 426 | " def __init__(self, ):\n", 427 | " super(SetTransformer, self).__init__()\n", 428 | " self.basic_encoder = STEncoderBasic(d=4, m=3, h=2)\n", 429 | " self.basic_decoder = STDecoderBasic(out_dim=1, d=4, m=2, h=2, k=2)\n", 430 | " \n", 431 | " def call(self, x):\n", 432 | " enc_output = self.basic_encoder(x) # (batch_size, set_len, d_model)\n", 433 | " return self.basic_decoder(enc_output)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 32, 439 | "metadata": { 440 | "scrolled": true 441 | }, 442 | "outputs": [ 443 | { 444 | "name": "stdout", 445 | "output_type": "stream", 446 | "text": [ 447 | "Train on 100000 samples\n", 448 | "Epoch 1/6\n", 449 | "100000/100000 [==============================] - 24s 237us/sample - loss: 29.9259\n", 450 | "Epoch 2/6\n", 451 | "100000/100000 [==============================] - 21s 206us/sample - loss: 2.5411\n", 452 | "Epoch 3/6\n", 453 | "100000/100000 [==============================] - 20s 199us/sample - loss: 0.5547\n", 454 | "Epoch 4/6\n", 455 | "100000/100000 [==============================] - 20s 199us/sample - loss: 0.4607\n", 456 | "Epoch 5/6\n", 457 | "100000/100000 [==============================] - 20s 199us/sample - loss: 0.4181\n", 458 | "Epoch 6/6\n", 459 | "100000/100000 [==============================] - 21s 205us/sample - loss: 0.4109\n" 460 | ] 461 | }, 462 | { 463 | "data": { 464 | "text/plain": [ 465 | "" 466 | ] 467 | }, 468 | "execution_count": 32, 469 | "metadata": {}, 470 | "output_type": "execute_result" 471 | } 472 | ], 473 | "source": [ 474 | "set_transformer = SetTransformer()\n", 475 | "set_transformer.compile(loss='mae', optimizer='adam')\n", 476 | "set_transformer.fit(X, y, epochs=6)" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 116, 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "import tensorflow as tf\n", 486 | "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 119, 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [ 495 | "import numpy as np\n", 496 | "from typing import List, Tuple\n", 497 | "\n", 498 | "def extract_image_set(x_data: np.array, y_data :np.array, agg_fun=np.sum, n_images=3) -> Tuple[np.array, np.array]:\n", 499 | " \"\"\"\n", 500 | " Extract a single set of images with corresponding target\n", 501 | " :param x_data\n", 502 | " \"\"\"\n", 503 | " idxs = np.random.randint(low=0, high=len(x_data)-1, size=n_images)\n", 504 | " return x_data[idxs], agg_fun(y_data[idxs])\n", 505 | "\n", 506 | "\n", 507 | "def generate_dataset(n_samples: int, x_data: np.array, y_data :np.array, agg_fun=np.sum, n_images=3) -> Tuple[List[List[np.array]], np.array]:\n", 508 | " \"\"\"\n", 509 | " :return X,y in format suitable for training/prediction \n", 510 | " \"\"\"\n", 511 | " generated_list = [extract_image_set(x_data, y_data, agg_fun, n_images) for i in range(n_samples)]\n", 512 | " X, y = [i[0] for i in generated_list], np.array([t[1] for t in generated_list])\n", 513 | " output_lists = [[] for i in range(n_images)]\n", 514 | " for image_idx in range(n_images):\n", 515 | " for sample_idx in range(n_samples):\n", 516 | " output_lists[image_idx].append(np.expand_dims(X[sample_idx][image_idx], axis=2))\n", 517 | " return output_lists, y\n", 518 
| "\n", 519 | "X_train_data, y_train_data = generate_dataset(n_samples=100000, x_data=x_train, y_data=y_train, n_images=3)\n", 520 | "X_test_data, y_test_data = generate_dataset(n_samples=20000, x_data=x_test, y_data=y_test, n_images=3)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 37, 526 | "metadata": {}, 527 | "outputs": [ 528 | { 529 | "data": { 530 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIJCAYAAADTd4UyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deZhUxb3/8c8XRTZFJOIe4EICuCRuuGCMolzXGDUgolHjHn+AGjAaFSJCXK4KLrnikoiKeiGiKGqCuXEhuOK+izGCcg3BBcImyE79/qgzpumqmenp7lnr/Xqefob+zDl1qrsP3d+pU+e0OecEAADS0qy+OwAAAOoeBQAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQIAqAEpnZ/mb2pJnNN7PlZvaRmY03sx3qu2/lYGZzzGxMHW5vRzN7LnsunZl1roVtHGpmQyL5eDN7rYD1nZmdW+5+1aVCH2s1bfTM2vnQzNab2fhKlnOR20ulbBtA6Tau7w40Zma2v6Tpkh6RdKakFZJ2kvRTSZ0kza23zjVeoyW1k3S0pOWSPquFbRwq6ThJNxW5fi9Jn5SvO/XiCkmtSmzjB5L2l/SSpM2qWfZ6SZNz7n9V4rYBlIgCoDQDJX0gqb/79xWVnpT0WzOz+utWo9ZD0mPOuadLaSR7/ls451aWp1v/5pwr+1+vZtbKObei3O1Wxjk3uwzN3Oyc+60kFTCaMKc2njcAxeMQQGnaSfrSRS6nmJvFhozNbKSZLci5f1q23B5mNt3Mvjazt7L7bczsbjNbYmYfm9mJVXUqW//BSD7azD6tKE7M7Boze9fMlpnZXDObYGbbFND25Lysd9b3XXKylmZ2nZn9w8xWmdnbZnZkFe12NjMnqaukoVl703N+f252eGWVmc0ys6F56480swXZIZlXJa2U1D+ynZGSfimpU85w9Pi8ZQ4xs3eywxDPm9nOeb/f4PXMtvmcmS3Nbm+ZWbDt/MdqZieZ2b1mtljSH7PfbZQ9lk+zx/q+mf00Z92DsnW3y8lmmNk6M2uXk71rZldV0YcNDgGYWTszG2dm88xsZbb9OypbX5Kcc+ur+j2Aho0CoDRvSDrIzC4zsy5lavMeSX+Q1E+SyQ+b3ilpnvyw9cuS7rWq5xhMknSkmbWpCLIP/eMlPZBTnGwl6WpJP5I0RFIXSdPMrBz7xWRJp2Xt/1jSq5IeM7PdKln+M/mh9c8lTcz+PSjr+9mSbpb0WNbWg5KuN7NL8tpoLf/8jZN0uKRXItsZl7X/ebaNXvLD4RU6yh+GuErSifLP0aTKRnTMrK2kP0n6WP41O07SffLFYXXGyA+F95d/niTpN5KGS/q9/GGQFyRNyCn6Xpa0RtIPs+23lrSnpNXyQ/Iys/aSdpb0XAF9qHCD/HD+UEmHSRomqZzXCR9pZmuzIu2urI8A6pNzjluRN0ltJU2Tf6N08h/St0vqlreck3RuXjZS0oKc+6dly52akx2ZZXflZJvLfwAMrKJfHSStlXRCTtYra6tnJetsJGn7bJkDcvI5ksbk3J8uaXLeur2z9XbJ7vfJ7h+Yt9yzkh6s5jnN314zSf+UdHfecrdKWiKpZc7z6SQdU8DrNkZ+SDo/H589b9/NyY7N2u0Rez0l9czub1aD/aZzts6UvLy9/LyHy/PyxyV9mHN/hqSx2b8PljRf0v2SrsmyoyWtk9S2ij6Ml/Razv33JJ1Xwv+F1ySNr2Jb/SQdIOkCSYskvS5po2K3x40bt9JvjACUwDm3VP7Dbj/5v+BmSzpL0htmtkeRzeYe+56V/ZyWs80l8m/421fRr/nZOgNy4gGSZjvncod9jzCzF81sifwHX8WkxW5F9r3Cf8r/hf2CmW1ccZN/bD1r2NYOkraT/6s/1yT5Aux7OZmT9OfiuvyNOc65j3Luz8zpR8xsScskTTSzY3KH4QswNe/+LvKjGLHH2s3MOmT3n1U2AiD/ofq8pGfysrez/bNQb0m6yMwGmVmpr/8GnHOnOececs4965y7QX6S7B7yozkA6gkFQImcN8M5N9w590P5D7j1ki4rssnFOf9eHckq8pbVtHO/pCPMrG02pN9f/oNEkmRme8kPqc+VdIr8CMG+2a+ra7s6W0raRn6kIvc2UtK3a9jWttnPL/Lyivu5Q8mLnHOrVZrYcy1V8pw45xZJOkRSc0kPSJpvZlMLPCSU/5gKfazPSdolKzZ+mN1/TlJPM2uZk9XEufJns4yQ9GE23+KEGrZRqP+VL5qKLZIBlAEFQJk5596SPxOgR068StImeYtuUctdmSI/h+AY+WO72ymnAJD0E/mRhAHOucecn6H9eQHtrlT1j2Wh/LD9XpHbvqqZitMAt8rLt87ZVoV6+W5r59xLzrnD5Y/795UfQZlYyKp59wt9rC9kP3vLP5/PSnpf/kO1j/wHa40KAOfcYufc+c65bSTtKj/XYIKZ7VSTdgrcVsXj5rvIgXpEAVACM8t/o66YbNdVG/4VN1fSjjnLNJN/o6412V+mT8gP/Q+Q9IFz7p2cRVpJWpPzZixJJxXQ9FxtWNxI/rz6XE/LjwAsc869ln+r0QPx25uncEb/8ZKWSnq3hu1JhY2g1JhzboVz7o+S7pK/HkRNvSfpa8Uf69+zQzsVr+178hP21kl6M3sdn5f0K/nTe2s6AvCNbD+5SP79If+1LpmZHS5pU/l5AADqCdcBKM247MP8IfljwVtIOl3+L6jcN/Epkgab2Zvys8XPkj9+XdsmyX8YLZE0Nu93T0oaYmY3yZ+Ctp+kkwtoc4qkM83sRvlj2AfJz7jPb/svkp40s2vl/zptK2k3+Ul7lxb6AJxz67NT935nZv/K2j5Q/hoMw1xx5/n/TdLWZnaa/AfpAufcnCLakZn9SNIZ8sPnn8rPzThHOfM2CuWcW5i9Hr82s7XyE+v6yk8GzT/18zlJgyX9xTm3LicbLekj51z+YYTqHsfz8q/te/J/mZ8tPyExdiZFxTod5F8Lye/7nczsuOyxTM6W+bn8YbGnJC2QH534ddZu/hwIAHWIAqA0t8rP3h8hf/x2sfyH3WHOuSdylhslP6x7pfxfn2Oz5QbXcv8elZ/ct6X8nIBvOOceN7OLJZ0n/2Y/Q9JRkv5eVYPOual
mNkz+FL2zsm38IvtZsYwzs77yp5INkT+1bqH8RLOba/ognHN3ZMe2f5Hd5kr6pXPuxpq2lXlAvnC5Tv6MiXvkX8dizJL/wLxa/jWeL39a4LAi2xsh/5oNlB/6nyXpZOfc/XnLVRQAz+Zlkh8JqKkZ8s9BZ2WjCpKOcM5VdTXLnbXhhMUu8oclJH/4SfKF8anyZwG0lT/MdK+ky3IKFwD1wDYcAQYAAClgDgAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQoOpOA+QUAZQi+g16dYx9GKVoCPuwxH6M0kT3Y0YAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgqq7EBAauenTpwfZoEGDgmzChAnR9XffffdydwkA0AAwAgAAQIIoAAAASBAFAAAACaIAAAAgQUwCbKRWr14dZHfeeWeQXXrppUG2dOnSINtiiy3K0zEAQKPACAAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQIM4CaKSOOuqoIHvqqaeC7PTTTw+y4cOHB1nHjh3L0zE0Kh9++GGQ9ejRo6B1+/btG80feuihkvoEoG4wAgAAQIIoAAAASBAFAAAACaIAAAAgQUwCbARuvPHGIItN+Bs9enSQDRkyJMg22mij8nQMSXv44YfruwsASsAIAAAACaIAAAAgQRQAAAAkiAIAAIAEMQmwnsyfPz+a9+7dO8iWL18eZDNnzgyybt26BVmzZtR4qFuxqwt27969HnqCxuTZZ58NslmzZhW07sqVK4PsvPPOC7L169cH2RFHHBFt87jjjguynXfeOcj22WefQrrYIPHpAABAgigAAABIEAUAAAAJogAAACBBTAKsA7EJKieddFJ02Tlz5gTZ66+/HmSFfmUrADQkd911VzQfOHBgkK1bt67o7ZhZkMUmRT/xxBPR9WN58+bNg2zfffcNsquvvjrIevXqFd1OfWIEAACABFEAAACQIAoAAAASRAEAAECCzDlX1e+r/CVCX331VZD1798/yJ555pno+i+//HKQff/73y+9Y/UjnIVT99iHq9CvX78gK/Vrfqt5T2lsGsI+LDXS/XjJkiVBtt9++0WXjV1Bsm/fvkHWqVOnIItdtW+nnXYKslGjRgVZbLJgZcaPHx9kixYtCrIWLVoE2c033xxkZ5xxRsHbLlH0QTICAABAgigAAABIEAUAAAAJogAAACBBXAmwBK+++mqQHXbYYUG2YsWKIItd3U+KT1wBGqLYBC0gV2xiXyyTpMGDBwfZddddF2SxCXaFGjNmTNHrStJf//rXIItNAly1alWQTZ06NcgOPfTQINthhx2K7F3NMQIAAECCKAAAAEgQBQAAAAmiAAAAIEFMAixQ7CssL7jggiBbvnx5kL3zzjtBtuOOO5anY0CBbr311iAr5ap/ffr0KaU7wAZiX5teyoS/Qi1dujSaP/roo0FW2QTGfO3atQuy0aNHB1ldTviLYQQAAIAEUQAAAJAgCgAAABJEAQAAQIIoAAAASBBnAUTEvvP55z//eZCtX78+yGbOnBlkPXr0KEu/gFI8/fTTZW2PswBQTg899FCQ/exnPwuyH/zgBwW1F5vd/8gjjwTZTTfdFF0/dvZWTGzG/5lnnhlkXbp0Kai9usQIAAAACaIAAAAgQRQAAAAkiAIAAIAEmXOuqt9X+cum4L777guyU089taB1p0+fHmR77713kL3wwgtB9vjjj0fbjE1Sueiii4IsNsmkefPm0TbrkdV3B5TAPhwTu2RpKZNR+/btG2SxSVtNUEPYh6VGuh9/9dVXQXbOOedEl/3zn/8cZLFL5cYms65ZsybIjjrqqCCLTewzi7/EW2yxRZD1798/yC688MIga4AT/qIPkhEAAAASRAEAAECCKAAAAEgQBQAAAAlKZhJgbCKeJP2///f/guz9998PshNOOCHIttxyyyB74IEHguzLL78Msvbt20f7861vfSvIvvjiiyDbfvvtgyzW73rWECZQNZl9uCb69esXZA8//HDR7d1yyy1BNmjQoKLba0Qawj4sJbAfx95jJ0+eXNZtxD7vYld5leKT+7p27VrW/tQhJgECAACPAgAAgARRAAAAkCAKAAAAEtQkJwHOmTMnyA444IDosnPnzi16O7GrPcUmshxyyCFBtueee0bb3HTTTYPsxhtvDLLY1QFfeeWVINtjjz2i26kjDWECVaPch2ui3Ff9i/nb3/4WZN27dy/rNhqohrAPSwnsx7Grsp5++ull3UZsUuGPf/zj6LIbbbRRWbddz5gECAAAPAoAAAASRAEAAECCKAAAAEjQxvXdgVKtXLkyyIYPHx5kNZns17Zt2yA78cQTgyw2Oa9ly5YFbycm9njmzZsXZB07dgyy73znOyVtGwDqy9SpU8vaXuw9+9hjjy3rNho7RgAAAEgQBQAAAAmiAAAAIEEUAAAAJKjRTwK85JJLguwPf/hDweu3adMmyGbMmBFkO+64Y806Vo358+dH82HDhgXZE088EWT3339/kMUmLwI11bdv3yBL5Kp/qANjx46N5g8++GBZt7N+/fqyttcUMQIAAECCKAAAAEgQBQAAAAmiAAAAIEGN6uuAZ86cGWR77713kH399ddBdsopp0TbvO2224KsdevWBfVn3bp1QfbFF18E2auvvhpklX3NZbt27YJswoQJQdarV69CuljfGsJXqTaofbg29OvXL8gefvjhotuLTQJ86KGHim6vkWsI+7DUSPfjN954I8gqe++KvZ9uueWWQXbqqacG2fXXX19Qf9auXVvQck0QXwcMAAA8CgAAABJEAQAAQIIoAAAASBAFAAAACWqwZwGsWrUqyGIz+SdPnhxkhx56aJD97//+b8HbXrlyZZB9+eWXQXbLLbcE2XXXXRdkzZs3D7LY2QuSdO655wbZCSecEF22EWgIM6gb5ezpmjAr79P8t7/9LcgSvhRwQ9iHpUa6H7/yyitBtt9++0WX3XrrrYNs2rRpQbbddtsF2RZbbFFQfzgLYEOMAAAAkCAKAAAAEkQBAABAgigAAABI0Mb13QEp/r3NF198cZDFJvzFxCYB7r///gX359NPPw2yuXPnBlmXLl2C7LTTTguyYcOGBdl3v/vdgvsDSPFL/pYqdtnfhCf8ocyuvvrqgpc9+eSTgyy2L8YuGRybPD127NiCt50qRgAAAEgQBQAAAAmiAAAAIEEUAAAAJKhBTAJcvHhxkP33f/930e1deOGFQVbZFQ9jV1E76KCDgiw2ke+MM84Isk022aSQLgJAk/LWW28F2dNPP13w+qeffnpBy8Xey5ctW1bwdvBvjAAAAJAgCgAAABJEAQAAQIIoAAAASFCDmAQ4fPjwINtnn32C7OGHHw6yDz74IMhefPHFIOvVq1d02506dQqy2BX+mjWjVkLd+fDDD+tkO3369KmT7aDpGzVqVJCtWLGi7NuJtTl+/PiybycFfKoBAJAgCgAAABJEAQAAQIIoAAAASFCDmAR42223Fb3utttuG2QHH3xwKd0BksEkQJTLwoULy97m2rVrg+zOO+
8s+3ZSxQgAAAAJogAAACBBFAAAACSIAgAAgAQ1iEmAADbUvXv3IKtswl7sCpkxt9xyS0HbAYpx7bXXBlnsq9XXrFkTXf83v/lNkLVu3TrICr3qX7t27QpaLmWMAAAAkCAKAAAAEkQBAABAgigAAABIkDnnqvp9lb8EqmH13QGxD6M0DWEflhrpfnzzzTcH2YUXXhhddt26dUVvZ4sttgiyN954I8i+/e1vF72NRi66HzMCAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIswBQmxrCDGr2YZSiIezDUhPaj6dNmxbNhwwZEmQLFy4MsquuuirIDj744CBLeMZ/DGcBAAAAjwIAAIAEUQAAAJAgCgAAABLEJEDUpoYwgYp9GKVoCPuwxH6M0jAJEAAAeBQAAAAkiAIAAIAEUQAAAJAgCgAAABJEAQAAQIIoAAAASBAFAAAACaIAAAAgQdVdCRAAADRBjAAAAJAgCgAAABJEAQAAQIIoAAAASBAFAAAACaIAAAAgQRQAAAAkiAIAAIAEUQAAAJAgCgAAABJEAQAAQIIoAEpkZvub2ZNmNt/MlpvZR2Y23sx2qO++lYOZzTGzMXW4vR3N7LnsuXRm1rkWtnGomQ2J5OPN7LUC1ndmdm65+1WXCn2s1bTRM2vnQzNbb2bjK1nuMjN7ysyW1tZrCqDmKABKYGb7S5ouaYmkMyUdK2mspB0ldaq/njVqoyW1k3S0pF6SPquFbRwqKSgAaqCXpAfL1Jf6coWk00ps4weS9pf0qqTPq1juHEkbS/pridsDUEYb13cHGrmBkj6Q1N/9+2sVn5T0WzOz+utWo9ZD0mPOuadLaSR7/ls451aWp1v/5px7qdxtmlkr59yKcrdbGefc7DI0c7Nz7reSVM1oQkfn3HozO0q+sAPQADACUJp2kr50ke9Uzs1iQ8ZmNtLMFuTcPy1bbg8zm25mX5vZW9n9NmZ2t5ktMbOPzezEqjqVrR/8hWpmo83s04rixMyuMbN3zWyZmc01swlmtk0BbU/Oy3pnfd8lJ2tpZteZ2T/MbJWZvW1mR1bRbmczc5K6ShqatTc95/fnZodXVpnZLDMbmrf+SDNbkB2SeVXSSkn9I9sZKemXkjpl23D5Q9dmdoiZvZMdhnjezHbO+/0Gr2e2zeeyIe6l2esWbDv/sZrZSWZ2r5ktlvTH7HcbZY/l0+yxvm9mP81Z96Bs3e1yshlmts7M2uVk75rZVVX0YYNDAGbWzszGmdk8M1uZbf+OytaXJOfc+qp+X9PlANQtCoDSvCHpoOwYZ5cytXmPpD9I6ifJJE2WdKekeZKOk/SypHut6jkGkyQdaWZtKoLsQ/94SQ/kFCdbSbpa0o/kh8S7SJpmZuXYLybLDzFfLenH8sPEj5nZbpUs/5n80PrnkiZm/x6U9f1sSTdLeixr60FJ15vZJXlttJZ//sZJOlzSK5HtjMva/zzbRi/54fAKHeUPQ1wl6UT552hSZSM6ZtZW0p8kfSz/mh0n6T754rA6YyR9JV+oXJ1lv5E0XNLv5f9afkHShJyi72VJayT9MNt+a0l7SlotPyQvM2svaWdJzxXQhwo3yA/nD5V0mKRhkoLCFkAT4pzjVuRNUltJ0+TfKJ38h/TtkrrlLecknZuXjZS0IOf+adlyp+ZkR2bZXTnZ5vIfAAOr6FcHSWslnZCT9cra6lnJOhtJ2j5b5oCcfI6kMTn3p0uanLdu72y9XbL7fbL7B+Yt96ykB6t5TvO310zSPyXdnbfcrfJzL1rmPJ9O0jEFvG5jJM2J5OOz5+27OdmxWbs9Yq+npJ7Z/c1qsN90ztaZkpe3l7Rc0uV5+eOSPsy5P0PS2OzfB0uaL+l+Sddk2dGS1klqW0Ufxkt6Lef+e5LOK+H/wmuSxlezzFHZ4+5c7Ha4ceNWvhsjACVwzi2V/7DbT/4vuNmSzpL0hpntUWSzuce+Z2U/p+Vsc4n8G/72VfRrfrbOgJx4gKTZzrncYd8jzOxFM1si/8E3N/tVtyL7XuE/5f/CfsHMNq64yT+2njVsawdJ2ymcdDdJvgD7Xk7mJP25uC5/Y45z7qOc+zNz+hEzW9IySRPN7JjcYfgCTM27v4v8KEbssXYzsw7Z/WeVjQBIOkDS85KeycvezvbPQr0l6SIzG2Rmpb7+ABoBCoASOW+Gc264c+6H8h9w6yVdVmSTi3P+vTqSVeQtq2nnfklHmFnbbEi/v/wHiSTJzPaSH1KfK+kU+RGCfbNfV9d2dbaUtI38SEXubaSkb9ewrW2zn1/k5RX32+dki5xzq1Wa2HMtVfKcOOcWSTpEUnNJD0iab2ZTCzwklP+YCn2sz0naJSs2fpjdf05STzNrmZPVxLmSHpE0QtKH2XyLE2rYBoBGhAKgzJxzb8mfCdAjJ14laZO8Rbeo5a5MkZ9DcIz8sd3tlFMASPqJ/EjCAOfcY87PbK/qVK4KK1X9Y1koP2y/V+S2r2qm4jTArfLyrXO2VaFejlk7515yzh0uf9y/r/wIysRCVs27X+hjfSH72Vv++XxW0vvyIxF9JO2hGhYAzrnFzrnznXPbSNpVfq7BBDPbqSbtAGg8KABKYGb5b9QVk+26asO/4ubKXxugYplm8m/UtSb7y/QJ+aH/AZI+cM69k7NIK0lrnHO5H0InFdD0XG1Y3Ej+vPpcT8uPACxzzr2Wf6vRA/Hbm6dwRv/xkpZKereG7UmFjaDUmHNuhXPuj5LuklTMB+d7kr5W/LH+PTu0U/Havic/YW+dpDez1/F5Sb+SP723piMA38j2k4vk3x/yX2sATQTXASjNuOzD/CH5Y8FbSDpd/i+o3DfxKZIGm9mb8rPFz5I/fl3bJsl/GC2Rv0BRriclDTGzm+RPQdtP0skFtDlF0plmdqP8MeyD5Gfc57f9F0lPmtm18n+dtpW0m/ykvUsLfQDOnz8+UtLvzOxfWdsHyl+DYZgr7jz/v0na2sxOk/8gXeCcm1NEOzKzH0k6Q374/FP5uRnnKGfeRqGccwuz1+PXZrZWfmJdX/nJoPmnfj4nabCkvzjn1uVkoyV95JzLP4xQ3eN4Xv61fU9+ZOJs+QmJsTMpKtbpIP9aSH7f72Rmx2WPZXLOcgfKT0zdM4uOMLP5kmY652YKQL2gACjNrfKz90fIH79dLP9hd5hz7omc5UbJD+teKf/X59hsucG13L9H5Sf3bSk/J+AbzrnHzexiSefJv9nPkJ+l/feqGnTOTTWzYfKn6J2VbeMX2c+KZZyZ9ZU/lWyI/Kl1C+Unmt1c0wfhnLsjO7b9i+w2V9IvnXM31rStzAPyhct18h9M96j4q+LNkv/AvFr+NZ4vf1rgsCLbGyH/mg2UH/qfJelk59z9ectVFADP5mWSHwmoqRnyz0FnZaMKko5wzs2tYp2dteGExS7yhyUkf/ipwij9u1CQ/P+binxkEX0FUAa24QgwAABIAXMAAABIEAUAAAAJogAAACBBFAAAACSourMAmCGIUjSEr0RmH0YpGsI+LLEfozTR/ZgRAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBF
AAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSouq8DBtBEfPLJJ0HWvXv3IFuzZk2Q9evXL8juueee6HbatGlTRO8A1DVGAAAASBAFAAAACaIAAAAgQRQAAAAkiAIAAIAEcRYA0AR9/PHHQXbIIYcE2bp164KsWbPw74IpU6YEWatWraLbvu+++wrpIoB6xggAAAAJogAAACBBFAAAACSIAgAAgAQxCTBi9erVQXbYYYcF2fTp04PMzILskUceCbKjjz66uM4hWV9//XU0f+KJJ4LsjDPOCLIlS5YUtJ2NNw7fFoYPHx5k7dq1K6g9AA0TIwAAACSIAgAAgARRAAAAkCAKAAAAEmTOuap+X+Uvm6qVK1cG2eDBg4Ns8eLFQRab8Pe9730vyF5++eXotlu0aFFIFxuLcEZk3Wsy+/C7774bzXfbbbei2xw4cGCQXXjhhUHWuXPnorfRyDWEfVhqQvtxffryyy+D7Mwzz4wu+6c//SnIli1bFmRt2rQpvWO1L7ofMwIAAECCKAAAAEgQBQAAAAmiAAAAIEHJXwlw/fr1QXbHHXcE2cknnxxkLVu2DLLYJMDY5K21a9dG+9PEJgGijJ5++umyt3niiScGWcIT/tCETJ48OcjOPvvsIFu6dGl0/dhVXZsaRgAAAEgQBQAAAAmiAAAAIEEUAAAAJCj5SYDNmoU10Omnnx5kvXv3DrJ//vOftdElIHrFsVInAX7nO98Jsl122aWkNoG6FrtS60033RRkl19+eZDtvvvuQfbqq6+Wp2ONECMAAAAkiAIAAIAEUQAAAJAgCgAAABKU/CTAQr355ptFrzf+VOUAABxSSURBVNu9e/cg22ijjUrpDpqQFStWBFnsCn2PP/54wW3uvPPOQfbEE08E2eabb15wm0Bd+8c//hFke++9d5B98cUXQTZgwIAg+81vfhNkPXr0KLJ3jR8jAAAAJIgCAACABFEAAACQIAoAAAASxCTAiObNmwfZrrvuGmRvv/12Qe116tQpyJgEiAqrVq0KsppM+IvZcccdg2ybbbYpqU2gHD777LNo/tOf/jTInnnmmYLaHDp0aJBdc801QRb7v+acK2gbTREjAAAAJIgCAACABFEAAACQIAoAAAASRAEAAECCkj8LYP369UF2zz33BNmmm25aUHuHHnpokMUuwbp69ero+rEzEICaOv/88+u7C0jMggULguyWW24JstGjR0fXj10Se//99w+yMWPGBFns8sAxsbMAzCy67E9+8pMga926dUHbaSwYAQAAIEEUAAAAJIgCAACABFEAAACQoOQnAcYuAzlw4MCi24tN+OvTp0+QbbLJJkVvA03L+++/X9L6vXr1CrI999wzyGLfrT5v3rwgu/rqq4Nss802C7Lzzjsv2p+OHTsG2bbbbhtdFg3f3Llzg+zhhx8OsmHDhgVZbGJfu3btott59NFHg+yAAw4IslImSq9du7bgZWPv0ZVNGGysGAEAACBBFAAAACSIAgAAgARRAAAAkKDkJwHGbL/99kH29ddfB9miRYsKam/UqFFBxhX/UOGqq64qaf2zzz47yEaMGBFksUlWs2bNKnq7f/jDH6L59773vSCLXantZz/7WZDtu+++QbbxxrxN1YY1a9YE2bhx44Js8ODBQVboZLjYxMCLL744umyhV1stxdSpU2t9G40JIwAAACSIAgAAgARRAAAAkCAKAAAAEsTsmojYVctiX3VZ6CTArl27ltwnYKuttormt99+e5C98sortd2dSr377rsFZXfeeWeQ7bTTTkH21FNPBdnWW29dZO/SU9lkzdhXRv/rX/8qqM2//vWvQRa7al9DE7saZuxqsJL085//vLa7U+8YAQAAIEEUAAAAJIgCAACABFEAAACQICYBRrz22mv13QUg0KNHj2j+0EMPBdmXX35Z293R5ZdfHs1jk/YWL15cUJszZ84MspdeeinIjjnmmILaS01sktuZZ54ZXXbVqlVBFrvCX2yS3Iknnhhkp5xySpDtt99+QdazZ89of2JfMRz7OuHu3bsX3Ga++++/P8gqu6phoW02ZowAAACQIAoAAAASRAEAAECCKAAAAEhQ8pMADz744PruAhIyd+7cIJs9e3ZB61Y2Wal9+/YFZeU2adKkaD5hwoQgi331L8rv29/+dpD9+c9/ji673XbbBdnrr78eZPfee2+Qvffee0F24403Btno0aODrLIr7xX6FcMxsTZLaU+Sdt111yCLXSX21FNPDbLTTz89yBri11ozAgAAQIIoAAAASBAFAAAACaIAAAAgQQ1vVkIde/PNN+u7C0hIbHJehw4dgmzWrFlB9tVXX0Xb/OKLLwpqs1mz8tb769ati+afffZZWbeD0hx44IEFL/vd7343yE444YSC1v3444+D7JFHHgmy2ERYSfrWt74VZN26dQuy2JUhY5MAY1/XHpvQWJlPPvkkyObMmRNkM2bMCLLYFRB33nnngrddVxgBAAAgQRQAAAAkiAIAAIAEUQAAAJAgCgAAABKU/FkALVq0CLLly5fXQ0+QgtatWwdZ27ZtC1r3jTfeiOaxS7o+8MADQdavX7+CthOb3X/rrbcG2dKlS6PrjxgxoqDtxHTt2jXI9thjj6LbQ93p0qVLkF1wwQVl307//v2LXjd2id7evXtHl917772D7JprrgmyP/3pT0HWqVOnmneuHjACAABAgigAAABIEAUAAAAJogAAACBBVtl3M2eq/GVT8OCDDwZZoZe+LFTs0qhbbbVVWbfRQJX2hdzl0eD34f/7v/8LstiEqppo1apVkG2//fYFrRt7T5g9e3ZJ/YmJTfh75plngiz2Hex1qCHsw1Ij2I8bg2XLlgXZ5ptvHl12wIABQTZx4sSy96mORPdjRgAAAEgQBQAAAAmiAAAAIEEUAAAAJCj5KwH27du3vruAxFU2CakUK1asCLJZs2aVfTuF2mGHHYLs6aefDrJ6nvCHJu6FF16o7y40KIwAAACQIAoAAAASRAEAAECCKAAAAEhQ8pMAmzULa6DY1cgOPPDAuugOEhSbBBj7Supf/vKX0fVvv/32svepEIceemg0HzhwYJAddthhQRb7Km6gNrVr167gZQ8//PBa7EnDwAgAAAAJogAAACBBFAAAACSIAgAAgAQlPwnQLPyWxJ49ewbZ3nvvXVB7c+bMCbIrr7wyyG644Ybo+htvnPxLkpzYPtiyZcsgGzNmTHT92NdXv/XWW0E2ZMiQIDvkkEOCbPjw4dHt5Ntjjz2ieZs2bQpaH6hrHTp0CLLY119XlTcljAAAAJAgCgAAABJEAQAAQIIoAAAASJBVM9Gh6c+CKNDKlSuD7OOPPw6y2GTB2Fezfuc734lu59prrw2yY489tpAuNkTh7La6xz6MUjSEfVhiPy6LZcuWBVllX8c9YMCAIJs4cWLZ+1RHovsxIwAAACSIAgAAgARRAAAAkCAKAAAAEsRl5woUuzJbjx49guySSy4JsthXu1566aXR7Xz22WdF9A4AUE6rVq0KsvXr1wdZ7Cvl
G4vG23MAAFA0CgAAABJEAQAAQIIoAAAASBAFAAAACeJSwKhNDeEyquzDKEVD2Icl9uOyqMmlgGOfjV999VWQtWnTpvSO1T4uBQwAADwKAAAAEkQBAABAgigAAABIEJMAUZsawgQq9mGUoiHswxL7MUrDJEAAAOBRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARVdyVAAADQBDECAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAokZntb2ZPmtl8M1tuZh+Z2Xgz26G++1YOZjbHzMbU4fZ2NLPnsufSmVnnWtjGoWY2JJKPN7PXCljfmdm55e5XXSr0sVbThpnZuWb2vpl9bWb/Z2Y3m1m7cvUTQO3ZuL470JiZ2f6Spkt6RNKZklZI2knSTyV1kjS33jrXeI2W1E7S0ZKWS/qsFrZxqKTjJN1U5Pq9JH1Svu7UiysktSqxjfPkn8Mr5P8fdJN0taSOko4psW0AtYwCoDQDJX0gqb/799cqPinpt2Zm9detRq2HpMecc0+X0kj2/Ldwzq0sT7f+zTn3UrnbNLNWzrkV5W63Ms652WVo5qeSpjjnLs/u/9XMWki60czaOOeWl2EbAGoJhwBK007Sly7yncq5WWzI2MxGmtmCnPunZcvtYWbTsyHVt7L7bczsbjNbYmYfm9mJVXUqW//BSD7azD6tKE7M7Boze9fMlpnZXDObYGbbFND25Lysd9b3XXKylmZ2nZn9w8xWmdnbZnZkFe12NjMnqaukoVl703N+f252eGWVmc0ys6F56480swXZIZlXJa2U1D+ynZGSfimpU7YNZ2bj85Y5xMzeyQ5DPG9mO+f9foPXM9vmc2a2NLu9ZWbBtvMfq5mdZGb3mtliSX/MfrdR9lg+zR7r+2b205x1D8rW3S4nm2Fm63KH3rPX9aoq+rDBIQAza2dm48xsnpmtzLZ/R2XrZ5pLWpKXLZZk2Q1AA0YBUJo3JB1kZpeZWZcytXmPpD9I6if/JjpZ0p2S5skPW78s6V6reo7BJElHmlmbiiD70D9e0gM5xclW8kO2P5I0RFIXSdPMrBz7xWRJp2Xt/1jSq5IeM7PdKln+M/mh9c8lTcz+PSjr+9mSbpb0WNbWg5KuN7NL8tpoLf/8jZN0uKRXItsZl7X/ebaNXvJD2BU6yh+GuErSifLP0aTKRnTMrK2kP0n6WP41O07SffLFYXXGSPpKvlC5Ost+I2m4pN/LHwZ5QdKEnKLvZUlrJP0w235rSXtKWi3pB1nWXtLOkp4roA8VbpC0v6Shkg6TNExSUNjmGSfpeDM70sw2M7PdJV0iabxzblkNtg2gPjjnuBV5k9RW0jT5N0on/yF9u6Ruecs5SefmZSMlLci5f1q23Kk52ZFZdldOtrn8B8DAKvrVQdJaSSfkZL2ytnpWss5GkrbPljkgJ58jaUzO/emSJuet2ztbb5fsfp/s/oF5yz0r6cFqntP87TWT9E9Jd+ctd6v8X58tc55PJ+mYAl63MZLmRPLx2fP23Zzs2KzdHrHXU1LP7P5mNdhvOmfrTMnL28vPe7g8L39c0oc592dIGpv9+2BJ8yXdL+maLDta0jpJbavow3hJr+Xcf0/SeUX8H/hVtq2K/wNTJDUv1/8xbty41d6NEYASOOeWyn/Y7Sf/F9xsSWdJesPM9iiy2dxj37Oyn9NytrlE/g1/+yr6NT9bZ0BOPEDSbOdc7rDvEWb2opktkf/gq5i02K3Ivlf4T/m/sF8ws40rbvKPrWcN29pB0nbyf/XnmiRfgH0vJ3OS/lxcl78xxzn3Uc79mTn9iJktaZmkiWZ2jNVsBvzUvPu7yI9ixB5rNzPrkN1/VtkIgKQDJD0v6Zm87O1s/yzUW5IuMrNBZlbQ65+NSlwm6deSDpR0hqS95EesADRwFAAlct4M59xw59wP5T/g1su/MRZjcc6/V0eyirxlNe3cL+kIM2ubDen3l/8gkSSZ2V7yQ+pzJZ0iP0Kwb/br6tquzpaStpEfqci9jZT07Rq2tW3284u8vOJ++5xskXNutUoTe66lSp4T59wiSYfIHw9/QNJ8M5ta4CGh/MdU6GN9TtIuWbHxw+z+c5J6mlnLnKwmzpU/m2WEpA+z+RYnVLZwtk/dLOm/nXP/5Zx71jl3t/zZMKeUUAADqCMUAGXmnHtL/kyAHjnxKkmb5C26RS13ZYr8HIJj5I/tbqecAkDST+RHEgY45x5zfmb75wW0u1LVP5aF8sP2e0Vu+6pmKk4D3Cov3zpnWxWqO2ZdK5xzLznnDpc/7t9XfgRlYiGr5t0v9LG+kP3sLf98PivpffmRiD6S9lANCwDn3GLn3PnOuW0k7So/12CCme1UySpbSvqW/MhBrjezn11rsn0AdY8CoARmlv9GXTHZrqs2/CturqQdc5ZpJv9GXWuyv0yfkB/6HyDpA+fcOzmLtJK0xjmX+yF0UgFNz9WGxY3kz6vP9bT8CMAy59xr+bcaPRC/vXkKZ/QfL2mppHdr2J5U2AhKjTnnVjjn/ijpLvnrQdTUe5K+Vvyx/j07tFPx2r4nP2FvnaQ3s9fxeflj8hur5iMA38j2k4vk3x/yX+sK87O+5v+lv2f2c06x2wdQN7gOQGnGZR/mD8kfC95C0unyf0HlvolPkTTYzN6Uny1+lvzx69o2Sf7DaImksXm/e1LSEDO7Sf4UtP0knVxAm1MknWlmN8ofwz5IfsZ9ftt/kfSkmV0r/9dpW0m7yU/au7TQB+CcW5+duvc7M/tX1vaB8tdgGOaKO8//b5K2NrPT5D9IFzjn5hTRjszsR/LHvh+R9Kn83IxzlDNvo1DOuYXZ6/FrM1sr6TX5EYUj5c9IyPWcpMGS/uKcW5eTjZb0kXMu/zBCdY/jefnX9j35kYmz5Sckxs6kkHPOmdnv5U/Z/Fp+VKKrpFGSXpL0ek22D6DuUQCU5lb52fsj5I/fLpb/sDvMOfdEznKj5Id1r5T/63NsttzgWu7fo/KT+7aUnxPwDefc42Z2sfzV3M6Wn1l+lKS/V9Wgc26qmQ2TP0XvrGwbv8h+VizjzKyv/KlkQ+RPrVsoP1x8c00fhHPujuzY9i+y21xJv3TO3VjTtjIPyBcu18mfMXGP/OtYjFnyH5hXy7/G8+VPCxxWZHsj5F+zgfJD/7Mkneycuz9vuYoC4Nm8TPIjATU1Q/456KxsVEHSEc65qq5meYmkBfJzSC7Vvx/7r51z64voA4A6ZBuOAAMAgBQwBwAAgARRAAAAkCAKAAAAEkQBAABAgqo7C4AZgihFQ/hGOPZhlKIh7MMS+zFKE92PGQEAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkKCN67sDdcU5F81/97vfBdnAgQODrFOnTkE
2ZcqUINt9992L6B0AAHWLEQAAABJEAQAAQIIoAAAASBAFAAAACbLKJsdlqvxlY7J27dpovskmmxTdZo8ePYLs+eefD7L27dsXvY1Gzuq7A2pC+zDqRUPYhyX247K44oorgmzEiBHRZYcOHRpkN9xwQ9n7VEei+zEjAAAAJIgCAACABFEAAACQIAoAAAASlMyVAOfOnRvNN9100yB75ZVXgmzatGlBdt555wXZlVdeGWTXX399kJk1lLlFaCyWLl0azd9+++0gmzp1apDdd999QTZv3rzSO5YndtXMDz74IMhatWpV9m0DVXnqqaeCrFmz+N/BKbxHMwIAAECCKAAAAEgQBQAAAAmiAAAAIEHJTAKsbKLHsmXLguzvf/97kA0aNCjIPvnkkyC77bbbguyII44IskMOOSTaH0CK74O9e/eOLvv5558XvZ3Y/4vY1TFjE6JWrlwZbfPTTz8NstWrVwcZkwBRm2L/L2bNmlUPPWm4GAEAACBBFAAAACSIAgAAgARRAAAAkKBkJgFutNFGBS/78ssvB9nRRx8dZKNGjQqyvn37Blm/fv2CLHa1QUnaYYcdCukimrhtt902yM4///zosq+//nqQDRw4MMiaN28eZJtvvnmQxfbBhQsXBlm3bt2i/QEagtmzZwdZTSbM/uQnPylndxokRgAAAEgQBQAAAAmiAAAAIEEUAAAAJCiZSYBjx44teNlCr9LXunXrIPuP//iPIIt95XDsK4Il6YYbbgiyFL6WEhvabLPNguySSy6ph55406dPL3jZ7t27B1mLFi3K2BtgQ1988UWQHX/88QWt26dPn2i+1157ldSnxoARAAAAEkQBAABAgigAAABIEAUAAAAJSmYS4Jo1awpedscddyx6O9tss02QnXTSSUEWu4qgFL/61AEHHFB0f4Caevvtt4PsoosuKnj9Y445JshatmxZUp+AqixYsCDIYlf9a9u2bZBdccUV0TZTmLjKCAAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQoGTOAnjsscei+VZbbRVksZmipYhdWriyswB+9atfBdmLL74YZM2aUbuhdrzzzjtB9sknnwRZZTP7hw4dWvY+AVWZNGlSQcs9+uijQbbPPvuUuzuNBp8iAAAkiAIAAIAEUQAAAJAgCgAAABLUJCcBvvLKK0E2e/bs6LIDBw4MslatWpW1P7169QqyyiYBXn755UH26quvBlnKE1dQPitXrgyyK6+8sqB1K/t/svXWW5fUJ6AqsUv83nHHHQWt26VLl3J3p1FjBAAAgARRAAAAkCAKAAAAEkQBAABAgprkJMDYlfOcc9FlBw0aVNvdkZkFWWVXS7vrrruC7KyzzgqyN954I8iaN29eRO+QstGjRwfZRx99VNC6lU28mjlzZpD98Y9/DLKDDjooyPbee++Cto103XrrrUH25ZdfBlmfPn2CbMstt6yVPjVWjAAAAJAgCgAAABJEAQAAQIIoAAAASFCTnAT47rvvBtnGG8cf6vbbb1/b3YnadNNNo/mZZ54ZZCNGjAiyf/zjH0HGVa5QU2vWrCl63fPPPz+az5s3L8hat24dZKecckrR20YaVq1aFWTPP/98Qevut99+QVbZV1inihEAAAASRAEAAECCKAAAAEgQBQAAAAlq9JMAFy1aFGT33HNPkB1++OHR9TfffPOy96kUXbt2DbLYlQQfe+yxIBsyZEit9AmNz9q1a4PspZdeCrL/+q//Knobsa9llaRTTz01yGJfc73ddtsVvW2k4eWXXw6yZ555Jsg6dOgQZIMHD66VPjUljAAAAJAgCgAAABJEAQAAQIIoAAAASJBV9jW5mSp/2RA88sgjQda3b98gmzhxYnT9E044oex9KrfYVdRiVzb87LPPgqxNmza10qcChbMX616D34djFi9eHGSxr9SVpHHjxgXZ8uXLgyz2FdKF6tixY5Dddttt0WWPOOKIorfTADWEfVhqpPtxTcyfPz/IdtpppyBbuHBhkF122WVBNnLkyLL0q4mI7seMAAAAkCAKAAAAEkQBAABAgigAAABIUKO/EmChttpqq/ruQtH69esXZBMmTAiy2NXf0DjdcccdQXbxxRfXybaPO+64ILv99tuDrH379nXRHSRi5cqVQRab8Lf11lsH2cCBA2ulT00dIwAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgpI5C6Axu+SSS4IsdhYAmo7NNtssyGKz8yVp1KhRQbZ69eog23333Qva9ve///0gY8Y/alvsTJOYBx54IMhiZwageowAAACQIAoAAAASRAEAAECCKAAAAEiQOVfl10w3+O+gnj59epAdfPDBQTZo0KDo+mPHji13l8ou9t3wsUlZixYtCrLNN9+8VvpUoIbwXeoNfh+uDTNnzgyyXXbZJchi+8fs2bODLOFJgA1hH5aa0H780ksvRfPDDz88yHbdddcge+qpp4KsefPmpXesaYvux4wAAACQIAoAAAASRAEAAECCKAAAAEhQo78SYK9evYKsY8eOQfa73/0uuv7ll18eZB06dCi9Y2X0+9//PsiOOeaYIItdPQ5pGjFiREHLnXPOOUGW8IQ/1IH/+Z//ieZfffVVkLVp0ybImPBXPowAAACQIAoAAAASRAEAAECCKAAAAEhQo58E2KJFiyCLfa3kUUcdFV3/2GOPDbLHH388yMp9Rb3Y17VK8auwjRw5MshefPHFIGvWjHoO3rRp0wpa7gc/+EEt9wQoTOwrfW+99dZ66Ek6+MQAACBBFAAAACSIAgAAgARRAAAAkKBGPwkwJva1krfddlt02diV0Dp37hxkffv2DbK99tqroP7MnTs3yCZOnBhdds6cOUEWm5S42267FbRtAKgvCxcuDLK77roruuw+++wTZLH3YpQPIwAAACSIAgAAgARRAAAAkCAKAAAAEtQkJwHGnHXWWdH8+9//fpBNmDAhyGJfJ3z33XcX3Z+hQ4dG80GDBgVZly5dit4OUJUDDjigvruAJuzaa68NslWrVtVDTxDDCAAAAAmiAAAAIEEUAAAAJIgCAACABJlzrqrfV/lLoBpW3x1Qovtw+/btg2zRokVBtm7duiDja6U30BD2YamR7scXXHBBkE2aNCm67Ouvvx5k22yzTdn7lKjofsz/dAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBnAWA2tQQZlAnuQ/feeedQbZgwYIg+9WvfhVkZg3hZWswGsqTkeR+jLLhLAAAAOBRAAAAkCAKAAAAEkQBAABAgpgEiNrUECZQsQ+jFA1hH5bYj1EaJgECAACPAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSouisBAgCAJogRAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACTo/wO/xBXFI1slCQAAAABJRU5ErkJggg==\n", 531 | "text/plain": [ 532 | "
" 533 | ] 534 | }, 535 | "metadata": { 536 | "needs_background": "light" 537 | }, 538 | "output_type": "display_data" 539 | } 540 | ], 541 | "source": [ 542 | "import matplotlib.pyplot as plt\n", 543 | "%matplotlib inline \n", 544 | "\n", 545 | "n_sample_viz, n_images = 3, 3\n", 546 | "\n", 547 | "fig, axes = plt.subplots(nrows=n_sample_viz, ncols=n_images, figsize=(9.0, 9.0))\n", 548 | "\n", 549 | "for sample_idx in range(n_sample_viz):\n", 550 | " for im_idx in range(n_images):\n", 551 | " axes[sample_idx, im_idx].imshow(X_train_data[im_idx][sample_idx][:, :, 0], cmap='Greys')\n", 552 | " axes[sample_idx, im_idx].axis('off')\n", 553 | " if im_idx==0:\n", 554 | " axes[sample_idx, 0].set_title(' Sum value for this row is {}'.format(y_train_data[sample_idx]), \n", 555 | " fontsize=15, loc='left')" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": 111, 561 | "metadata": {}, 562 | "outputs": [ 563 | { 564 | "name": "stdout", 565 | "output_type": "stream", 566 | "text": [ 567 | "Train on 60000 samples\n", 568 | "Epoch 1/3\n", 569 | "60000/60000 [==============================] - 77s 1ms/sample - loss: 1.0678\n", 570 | "Epoch 2/3\n", 571 | "60000/60000 [==============================] - 70s 1ms/sample - loss: 0.5569\n", 572 | "Epoch 3/3\n", 573 | "60000/60000 [==============================] - 71s 1ms/sample - loss: 0.4641\n" 574 | ] 575 | }, 576 | { 577 | "data": { 578 | "text/plain": [ 579 | "" 580 | ] 581 | }, 582 | "execution_count": 111, 583 | "metadata": {}, 584 | "output_type": "execute_result" 585 | } 586 | ], 587 | "source": [ 588 | "# First, define the vision modules\n", 589 | "from tensorflow.keras.layers import Dense\n", 590 | "from tensorflow.keras.layers import Input\n", 591 | "from tensorflow.keras.models import Model\n", 592 | "from tensorflow.keras.layers import Conv2D\n", 593 | "from tensorflow.keras.layers import MaxPooling2D\n", 594 | "from tensorflow.keras.layers import Flatten\n", 595 | "from tensorflow.keras.layers import Dropout\n", 596 | "from tensorflow.keras.layers import Add\n", 597 | "from tensorflow.keras.optimizers import Adam\n", 598 | "\n", 599 | "filters = 64\n", 600 | "kernel_size = 3\n", 601 | "\n", 602 | "import tensorflow as tf\n", 603 | "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", 604 | "x_train = [np.expand_dims(t, axis=2) for t in x_train]\n", 605 | "x_test = [np.expand_dims(t, axis=2) for t in x_test]\n", 606 | "\n", 607 | "input_image = Input(shape=(28, 28, 1))\n", 608 | "\n", 609 | "y = Conv2D(32, kernel_size=(3, 3),\n", 610 | " activation='relu',\n", 611 | " input_shape=input_shape)(input_image)\n", 612 | "y = Conv2D(64, (3, 3), activation='relu')(y)\n", 613 | "y = MaxPooling2D(pool_size=(2, 2))(y)\n", 614 | "y = Dropout(0.25)(y)\n", 615 | "y = Flatten()(y)\n", 616 | "y = Dense(32, activation='relu')(y)\n", 617 | "y = Dense(16, activation='relu')(y)\n", 618 | "output_vec = Dense(1)(y)\n", 619 | "\n", 620 | "vision_model = Model(input_image, output_vec)\n", 621 | "vision_model.compile(loss='mae')\n", 622 | "vision_model.fit(np.array(x_train), np.array(y_train), epochs=3, batch_size=64)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 2, 628 | "metadata": {}, 629 | "outputs": [], 630 | "source": [ 631 | "# vision_model.save('vision_model.h5')" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 1, 637 | "metadata": {}, 638 | "outputs": [ 639 | { 640 | "data": { 641 | "application/vnd.jupyter.widget-view+json": { 642 | "model_id": 
"7581c681606f4d65ba471a048804c4a1", 643 | "version_major": 2, 644 | "version_minor": 0 645 | }, 646 | "text/plain": [ 647 | "HBox(children=(FloatProgress(value=0.0, max=964.0), HTML(value='')))" 648 | ] 649 | }, 650 | "metadata": {}, 651 | "output_type": "display_data" 652 | }, 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "\n" 658 | ] 659 | } 660 | ], 661 | "source": [ 662 | "import os\n", 663 | "import numpy as np\n", 664 | "import matplotlib.pyplot as plt\n", 665 | "%matplotlib inline\n", 666 | "from tqdm.notebook import tqdm\n", 667 | "\n", 668 | "\n", 669 | "# Get list of different character paths\n", 670 | "img_dir = './images_background'\n", 671 | "alphabet_names = [a for a in os.listdir(img_dir) if a[0] != '.'] # get folder names\n", 672 | "char_paths = []\n", 673 | "for lang in alphabet_names:\n", 674 | " for char in [a for a in os.listdir(img_dir+'/'+lang) if a[0] != '.']:\n", 675 | " char_paths.append(img_dir+'/'+lang+'/'+char)\n", 676 | "\n", 677 | "char_to_png = {char_path: os.listdir(char_path) for char_path in tqdm(char_paths)}" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 2, 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "from typing import List\n", 687 | "\n", 688 | "def draw_one_sample(char_paths: List, sample_size=6):\n", 689 | " n_chars = np.random.randint(low=1, high=sample_size, size=1)\n", 690 | " selected_chars = np.random.choice(char_paths, size=n_chars, replace=False)\n", 691 | " rep_char_list = selected_chars.tolist() + \\\n", 692 | " np.random.choice(selected_chars, size=sample_size-len(selected_chars), replace=True).tolist()\n", 693 | " sampled_paths = [char_path+'/'+np.random.choice(char_to_png[char_path]) for char_path in rep_char_list]\n", 694 | " return sampled_paths, n_chars[0]\n", 695 | "\n", 696 | "sampled_paths, n_chars = draw_one_sample(char_paths, sample_size=6)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 3, 702 | "metadata": {}, 703 | "outputs": [ 704 | { 705 | "name": "stdout", 706 | "output_type": "stream", 707 | "text": [ 708 | "Number of selected characters is 4\n" 709 | ] 710 | }, 711 | { 712 | "data": { 713 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAV0AAAD3CAYAAAC+eIeLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAT0ElEQVR4nO3df6xkZX3H8fcXF1gWWHcXFaGwWH9ACgY3JIZQWiC1oZCWiBJjJIUuKNU2FqU2EU2sJpZUiQmaWII2DVo0IEYFakGIGmMwJTQ2VULAqGuE5Ycuslt+qbjw9I9zRg/D/Lr3zjzznHPer2Ry750zZ86P7zyf88xzzsyNlBKSpDz2WfYKSFKfGLqSlJGhK0kZGbqSlJGhK0kZGbqSlFHW0I2I7RFxe85lavGsazdZ18XoZU83Ik6LiJ1rfI6jI+LGiNgVEY9GxK0RcUxj+lUR8UTj9uuIeHzta69x5lHX+nk+HRE/iIhnI2L70LS/iojvRsRjEbEzIi6PiHVrXabGm2Ndt9W1e6r+ua0x7ZKI2FHX9cGIuGJRdW1t6C7zhV4vexNwE3AMcChwJ3Dj4DEppXeklA4a3IBrgS8uY33bpIC6AnwP+Fvgf0Y8bAPwbuBFwInA64B/yLKCLbbsukbEflTt83PAZuCzwI31/VC15RNSShuBVwOvAS5eyAqllBZyA44EvgzsAn4BfBLYDtwOfAzYDfwEOLMxzwXAPcDjwA7g7Y1ppwE7gfcCDwPX1Dvvq/Uydte/H9GYZwtwNfBgPf0G4EDgl8CzwBP17XCqA9ClwI/r9b0e2FI/z8uABLwVuA/49ojt3VI/5pAR0w6st+nURe3vXLc+1bXepu1T9sffA/+x7LpY18l1BU4HHgCisbz7gDNG7ItDgK8DVy5iXy+kpxsRL6h36E/rHfB7wHX15BOBH1D1FC4H/i0iop72c+AvgI1UBb0iIk5oPPVLqQpzFPDXVDv+6vrvrVTF+WTj8ddQ9UyOA14CXJFSehI4E3gw/a4n+iDwd8DZwKlURd0N/MvQpp0K/AHwZyM2+xTg4ZTSL0ZMO4fqhfbtEdNao6d1neYU4O5VzFeMntT1OOD7qU7V2vfr+wf74dyIeAx4hKqn+6kpu251FnTUPIkqZNYN3b8d+FHj7w1UR6SXjnmeG4B3NY6cTwPrJyx3G7C7/v0wqqPj5hGPOw3YOXTfPcDrGn8fBvwGWMfvjpwvH7PcI6iOom8ZM/0bwIcWsa9z3npY14k9XeBCqt7ci5ZdG+s6ua7AB4Drhp7j86PaJfAq4MPjtnOtt0WNsxwJ/DSltHfEtIcHv6SUnqoPmgcBRMSZwAeBo6mOihuAuxrz7kop/WrwR0RsAK4AzqB66wJwcH3kPhJ4NKW0e8Z1Pgr4SkQ827jvGarx2oH7h2eKiBcDt1G9Fbl2xPStVC+ai2Zcj5L1pq7TRMTZwD8Df5pSemSl8xemD3V9gqpH3rSRamjkOVJKP4yIu4ErgTfOuD4zW9SJtPuBrSsZPI+I/YEvUY0fHZpS2gTcDETjYcNfifYeqhNZJ6ZqAPyUwdPV67AlIjaNWNyor1a7n2q8alPjtj6l9MC4+SJiM1Xg3pRSumzMpp0HfCeltGPM9DbpRV1n2KYzgH8Fzkop3TXt8S3Qh7reDRzfGBoBOJ7xQ0PrgFeMmbYmiwrdO4GHgI9ExIERsT4iTp4yz37A/lRvc/bWR9HTp8xzMNW40J6I2EJ11AUgpfQQcAtwZURsjoh9I2JQ5J8Bh0TECxvPdRVwWUQcBVUPNiJeP27BEbERuJUqUC+dsI7nA5+Zsh1t0fm61o/ZLyLWU4XBvvV27lNP+xOqt6XnpJTunLIdbdGHun6Lqid8cUTsHxHvrO//Zj3/2yLiJfXvxwLvoxoWnLuFhG5K6RngLOCVVGcIdwJvnjLP41SXaFxPNSh+LtVlHJN8HDiAauD7DuBrQ9PPoxrnuZdq0P/d9bLupbqEa0dE7ImIw4FP1Mu7Larrae+gOokwzhuA1wIXxHOvx906eEBEnEQ13tuJS8V6Uleo3r38EvhD4NP174MA+ADwQuDmRs1vmfJ8RetDXVNKT1OdeDsf2EM1Hn92fT/AycBdEfEkVY/9ZuD9U7ZnVaIeOJYkZdDaD0dIUhsZupKUkaErSRkZupKUkaErSRlNuxjaSxvKEdMfMjPrWo551hWsbUlG1taeriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlNO37dNUQ8buvx/S/KEtaDUN3imbQDt9v8Epaqd6G7rgwlaRF6mToGqiSStX60DVgJbVJ60K3lJB1PLe9Rr2GrKdyKSp0SwnUYTZISfNSVOgug4EqKadOhO6swVlqT1p5WH+VoJWha++0bMPhtux6GbYqSVGhu+zGqW4xbFWiokJXWiuDVqUzdLVwi75Ey6BVmxi6aqV5B63fpaFcDF0tRe7e6XCgjlq+wascDF3NXUqpiLf8BqhK5JeYq5OmBe646SUcLNRthq4WIqX021upSl43dZfDC2q1tQ5lDM9vEJeli19OZOhqKUpqOCWtiyqTDqRtP+Hp8IKkoszyzqXNY+/2dCUtzCzhOOi1rjRI29rjNXTVSSttwG1svKVaTXiuZVltq53DCxLtfrtakkXsxzZcCbMSvQndLp4FlUoy78AdFbRdaLO9CV1pGnu7ecwSnF0I13FaM6Y7rkF0uThavWmvCwN2uUr5qPgytL6n2/XLS7QYHqy1LK3p6U5S2r+HkaRxWt/THWUl1wZKmo9ZrjBoTu9rG+xET3eUQfD2tbDSstjmJmtN6M7yJdSjOJ4rqSStCd1hq/3ooMowz3F4e1bd0vV6tjZ0B/p86UnJ+vI5eq1NH2veiRNpXfqIoKRua31Pt2mWXq/hvFi+65Am60RPt2lSqBq4i7XawPWdivqkUz3daRw3XJyVBK41UJ91MnQnDTMYvMvhPu8Ph5gm69zwwsCkRu6LYv5G7e+SvwfV18BiuF+n62zogr2r3NryEU+/sW4xDNzZdHJ4oWnchyj8mPBitHV/tnW9S7GMwG1rzTofutKAPbG82hqKi9bp4YWmcS8AG6IMB+XUm9AFG1dfRYQH1yUocZ+XsE4OL+BlZF3mJxSXa9L+n3Xfzzsol93eexe6fkFOfxi4+aymXfW1HfZqeGFgVGPr6wugqwzc/Nyns+ll6I5j8HaDgbs87tvpehu6k65mMHzba9b/Dm2NF6fUTyGWondjuuouvzi9LNP2bV+/JCmmbEx3tnSMeZxdzWSeXbOiNmxe1vLVkks07y53J2vbUiNr29vhhYFxDa6wwNUUowLXGqpEDi/gP7lsu1nfrVhflcDQbbBn1G1eo60S9H54Qe1miKptDF21VotOgkq/5fCCOsWwVens6ao3HIpQCQxddcq4T5sZuCqFwwvqpFlD1uEI5WZPV6211sA0cLUMhq5azeBU2xi6ar3VBK9hrWVxTFedMOtHuQ1bLZuhq0
4xVFU6hxckKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyCr/0WZLysacrSRkZupKUkaErSRkZupKUUbbQjYjtEXF7ruUpH2vbTdZ1MXrX042I0yJi5xqf4+iIuDEidkXEoxFxa0Qc05geEfFPEfFARPxfRHwrIo5b+9prknnUtn6ebRHx3Yh4qv65rTHtkojYERGPRcSDEXFFRKxb6zI1XqY2e1VEPNG4/ToiHl/72j9fK0N3mS/yetmbgJuAY4BDgTuBGxsPexNwIfDHwBbgv4Br8q5pOy27thGxH1UtPwdsBj4L3FjfD1XdT0gpbQReDbwGuHgZ69smy64rU9psSukdKaWDBjfgWuCLC1mhlNLcb8CRwJeBXcAvgE8C24HbgY8Bu4GfAGc25rkAuAd4HNgBvL0x7TRgJ/Be4GGqANsMfLVexu769yMa82wBrgYerKffABwI/BJ4Fniivh1OdfC5FPhxvb7XA1vq53kZkIC3AvcB3x6xvVvqxxxS//1e4PrG9OOAXy1iX+e+db22wOnAA9TXsNePuw84Y8S+OAT4OnDlsutiXdfWZoemHVhv06kL2dcLKN4LgO8BV9Qrvx74o7qAvwEuqh/zN/XOHXxA48+BVwABnAo8RdWjGBRwL/BRYH/ggPoFfw6wATiY6qh0Q2M9/hP4Ql3ofQc7cPBiGFrndwF3AEfUz/8p4NqhAv57vT0HjNjms4GHGn8fBXwXOLpe9uXNdWvrrQ+1BS4Bbhl6jq8C72n8fS7wWD3vLuA1y66NdV1bmx2adj7VQSRWsz+n7u8FFPCk+oW4buj+7cCPGn9vqHfMS8c8zw3Auxo7/Wlg/YTlbgN2178fRnVk3DzicaMKeA/wusbfh9UvtnWNAr58zHKPoOoZvaVx337AJ+r59lL1EH5/2Y3L2k6vLfAB4Lqh5/g88KERy3sV8OFx29mWWx/qOjTv89rs0PRvjKr3vG6LGGc5EvhpSmnviGkPD35JKT0VEQAHAUTEmcAHqXqH+1AV+K7GvLtSSr8a/BERG6iOzGdQHRkBDo6IF9Tr8GhKafeM63wU8JWIeLZx3zNUYz8D9w/PFBEvBm6jent5bWPSPwKvrdfjYeAvgW9GxHEppadmXKcS9aG2TwAbh55jI9XbzedIKf0wIu4GrgTeOOP6lKgPdR2sw7g2O5i+lSrkL5pxPVZsESfS7ge2rmTgPCL2B75ENXZ0aEppE3Az1duWgeEviXgP1aD4iak6qXHK4OnqddgSEZtGLG7Ul03cTzVWtalxW59SemDcfBGxmap4N6WULht6vm3AF1JKO1NKe1NKn6F6kR07avtbpA+1vRs4Pup0qR1f3z/KOqq32G3Wh7pOa7MD5wHfSSntGDN9zRYRuncCDwEfiYgDI2J9RJw8ZZ79qMZldgF76yPo6VPmOZhqgH1PRGyhOuICkFJ6CLgFuDIiNkfEvhExKPDPgEMi4oWN57oKuCwijoLqaBgRrx+34IjYCNxKVZxLRzzkv4E3RcShEbFPRJxHNUb1oynbVLrO1xb4FlWP6eKI2D8i3lnf/816/rdFxEvq348F3kf1drTNOl/XGdrswPnAZ6Zsx5rMPXRTSs8AZwGvpDpzuBN485R5Hqe67OZ6qrOW51Jd3jHJx6kG5x+hGlD/2tD086jGeO4Ffg68u17WvVSXg+yIiD0RcTjV+OtNwG31tXl3ACdOWPYbqIYPLhi6tm9rPf2jVCcm/hfYQ3Vy5pyU0p4p21S0PtQ2pfQ01UmW86lqdyFwdn0/wMnAXRHxJFXP7mbg/VO2p2h9qCvT2ywRcRLVeO9iLhUbLKceOJYkZdDKD0dIUlsZupKUkaErSRkZupKU0bTr8jzLVo6Y/pCZWddyzLOuYG1LMrK29nQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKaNF/GPKYj33315V/BJ3STl1KnRHhSoYrJLK0anQHWdcGEtSbo7pSlJGhq4kZdT70HXoQVJOnRrTbZ4wM0wllaj3PV1JyqmzoetlYpJK1NnQlaQSdWpMd6XsDUvKrdOhm1LyhJq0JG1oe8voeDm8IEkZdT50HUKQVJLOhy4YvJKeb1m50Okx3WkiwkCWFmS1bWvSWHAX2msverrQjWJJfdeGk3PT9CZ0JbVDSqnTnaTeh24XjpxSF40L3ra32V6FbpePnlIXdbHN9ip0JWnZDF1JRetab7d3odu1Akp91OZx3V5fp6tuGNUAPbh2S5e+R6V3PV11R0SMbYhdaaDqHkMXG2gbzVIz69ptba1vL4cXuvRWpQtmqcVguGCldfOj3t3RlXbby9DV8q0mPNeyLINXpXB4odaFI2hbLGJfDz46ariqdL0N3VGN0+BdvHnv41FBa/B2U1faZ29DF2ycJZulNtavfQZXnEy68mRWba1/r0N3lK4cTbugrY1K89eldllE6M7jqLdaNmxpeVbb7tvcbpceus2d3qWjmaTZ9K3dLz101S+zXGHQnN7mHo1mNyl4uxbKXqc7gtd1Lp77t58mfcBh0O66FrLD7OlKKkbXAxcMXbWUPeX2Wmvt2l57hxdUvLY3Mj1fH4YRxrGnK2kpVnow7crHvA1dSUsz6ycPuxC2A4aupKWaFqpdG4YwdCUVoS/B64k0LUWXGpHmZ5breNvOnq6yM3A1Sdd7vMWFbhd2qsazvppFl4O3uNBVdy2jsXTtzHefTKpdm4PXMV0tlYGoacaN87Z1jNeerqTidanHa+hqqUpsNCWuk7oTvEUOL7T1bYNWZ1KjmfV1MO+G52uwTF0YaogpK5plK8Y1mFw7cdTyCyzgPFNlaRvXpl5JptfAvHdIcS/cRVh2Zsxo5Eo6vKCsCmsU6pg2HNSLCF0bYr9Yb61Vm6/jLSJ0YfRO9D8Ed5fXz2qRSg7eIk+k5VRycfpgWvCupD6GeL8M6j3puxqajytF8aGb+6xkaQXqO+uhaab9F4rSrmwoZngB8l+HZy9X6oZZ3jGV0t6LCt2cWnKZmKQZtaX9Fhe6i+ztDo52pRzxJM1XG07QFhe6MN8j1qxBW3qhJM2u5PZc/Im0pkX1UEsukKTVaV7dUFIbL7KnC3mCsA1vRSStTWltvFU93XkprQiS+qPo0J12/d1KnkeSSlB06IKBKalbih3TlaQuMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyCv/briTlY09XkjIydCUpI0NXkjIydCUpI0NXkjIydCUpo/8H8O60oq/Odh8AAAAASUVORK5CYII=\n", 714 | "text/plain": [ 715 | "
" 716 | ] 717 | }, 718 | "metadata": { 719 | "needs_background": "light" 720 | }, 721 | "output_type": "display_data" 722 | } 723 | ], 724 | "source": [ 725 | "from matplotlib.figure import Figure\n", 726 | "from numpy import ndarray\n", 727 | "from typing import List\n", 728 | "import matplotlib.image as mpimg\n", 729 | "\n", 730 | "def render_chart(fig: Figure, axis: ndarray, image_id: int, data_list: List):\n", 731 | " image = mpimg.imread(data_list[image_id])\n", 732 | " axis.title.set_text(data_list[image_id].split('/')[-2])\n", 733 | " axis.axis('off')\n", 734 | " axis.imshow(image, cmap='gray')\n", 735 | "\n", 736 | "print('Number of selected characters is {}'.format(n_chars)) \n", 737 | "\n", 738 | "fig, axs = plt.subplots(2, 3)\n", 739 | "render_chart(fig=fig, axis=axs[0, 0], image_id=0, data_list=sampled_paths)\n", 740 | "render_chart(fig=fig, axis=axs[0, 1], image_id=1, data_list=sampled_paths)\n", 741 | "render_chart(fig=fig, axis=axs[0, 2], image_id=2, data_list=sampled_paths)\n", 742 | "render_chart(fig=fig, axis=axs[1, 0], image_id=3, data_list=sampled_paths)\n", 743 | "render_chart(fig=fig, axis=axs[1, 1], image_id=4, data_list=sampled_paths)\n", 744 | "render_chart(fig=fig, axis=axs[1, 2], image_id=5, data_list=sampled_paths)" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 4, 750 | "metadata": { 751 | "scrolled": true 752 | }, 753 | "outputs": [ 754 | { 755 | "data": { 756 | "application/vnd.jupyter.widget-view+json": { 757 | "model_id": "59fff385c8aa44cf8f9b61383551fdfe", 758 | "version_major": 2, 759 | "version_minor": 0 760 | }, 761 | "text/plain": [ 762 | "HBox(children=(FloatProgress(value=0.0, max=60000.0), HTML(value='')))" 763 | ] 764 | }, 765 | "metadata": {}, 766 | "output_type": "display_data" 767 | }, 768 | { 769 | "name": "stdout", 770 | "output_type": "stream", 771 | "text": [ 772 | "\n" 773 | ] 774 | }, 775 | { 776 | "data": { 777 | "application/vnd.jupyter.widget-view+json": { 778 | "model_id": "872f05b8d043413dadfbe238bbd33fb9", 779 | "version_major": 2, 780 | "version_minor": 0 781 | }, 782 | "text/plain": [ 783 | "HBox(children=(FloatProgress(value=0.0, max=15000.0), HTML(value='')))" 784 | ] 785 | }, 786 | "metadata": {}, 787 | "output_type": "display_data" 788 | }, 789 | { 790 | "name": "stdout", 791 | "output_type": "stream", 792 | "text": [ 793 | "\n" 794 | ] 795 | } 796 | ], 797 | "source": [ 798 | "train_size, test_size = 60000, 15000\n", 799 | "\n", 800 | "train_dataset = [draw_one_sample(char_paths, sample_size=6) for i in tqdm(range(train_size))]\n", 801 | "train_X, train_y = [i[0] for i in train_dataset], [i[1] for i in train_dataset]\n", 802 | "\n", 803 | "test_dataset = [draw_one_sample(char_paths, sample_size=6) for i in tqdm(range(test_size))]\n", 804 | "test_X, test_y = [i[0] for i in test_dataset], [i[1] for i in test_dataset]" 805 | ] 806 | }, 807 | { 808 | "cell_type": "code", 809 | "execution_count": 5, 810 | "metadata": {}, 811 | "outputs": [], 812 | "source": [ 813 | "import tensorflow as tf\n", 814 | "from typing import List\n", 815 | "from tensorflow import convert_to_tensor\n", 816 | "AUTOTUNE = tf.data.experimental.AUTOTUNE\n", 817 | "\n", 818 | "\n", 819 | "@tf.function\n", 820 | "def load_image(file_path):\n", 821 | " image = tf.io.read_file(file_path)\n", 822 | " return tf.image.decode_png(image, channels=1)\n", 823 | "\n", 824 | "@tf.function\n", 825 | "def load_image_list(image_list: tf.Tensor):\n", 826 | " return tf.cast(tf.map_fn(lambda x: load_image(x), image_list, dtype=tf.uint8), 
tf.float32)\n", 827 | "\n", 828 | "\n", 829 | "class SetDataGenerator:\n", 830 | " def __init__(self, X, y):\n", 831 | " self.X = X\n", 832 | " self.y = y\n", 833 | " self.dataset = None\n", 834 | "\n", 835 | " def generator_init(self, shuffle_buffer_size=500, repeat=-1, batch_size=64):\n", 836 | " \"\"\"\n", 837 | " :param repeat: -1 (the default) repeats the dataset indefinitely\n", 838 | " \"\"\"\n", 839 | " self.dataset = tf.data.Dataset.from_tensor_slices((self.X, self.y))\n", 840 | " self.dataset = self.dataset.map(lambda x, y: (load_image_list(x), tf.cast(y, tf.float32)))  # decode each list of image paths into a float32 tensor\n", 841 | "\n", 842 | " if shuffle_buffer_size == 0:\n", 843 | " self.dataset = self.dataset.repeat(repeat).batch(batch_size).prefetch(buffer_size=AUTOTUNE)\n", 844 | " else:\n", 845 | " self.dataset = self.dataset.shuffle(shuffle_buffer_size).repeat(repeat).batch(batch_size) \\\n", 846 | " .prefetch(buffer_size=AUTOTUNE)\n", 847 | "\n", 848 | " @property\n", 849 | " def batch(self):\n", 850 | " return self.dataset" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": 6, 856 | "metadata": {}, 857 | "outputs": [ 858 | { 859 | "name": "stdout", 860 | "output_type": "stream", 861 | "text": [ 862 | "Number of unique characters is : 5.0\n" 863 | ] 864 | }, 865 | { 866 | "data": { 867 | "text/plain": [ 868 | "(-0.5, 104.5, 104.5, -0.5)" 869 | ] 870 | }, 871 | "execution_count": 6, 872 | "metadata": {}, 873 | "output_type": "execute_result" 874 | }, 875 | { 876 | "data": { 877 | "image/png": "<base64-encoded PNG omitted: 2x3 grid of character images from one generated batch>\n", 878 | "text/plain": [ 879 | "
" 880 | ] 881 | }, 882 | "metadata": { 883 | "needs_background": "light" 884 | }, 885 | "output_type": "display_data" 886 | } 887 | ], 888 | "source": [ 889 | "# Test data generator\n", 890 | "\n", 891 | "set_data_gen = SetDataGenerator(train_X, train_y)\n", 892 | "set_data_gen.generator_init(batch_size=3)\n", 893 | "batch_data = next(iter(set_data_gen.batch))\n", 894 | "\n", 895 | "im_to_plot = batch_data[0][0].numpy()\n", 896 | "\n", 897 | "print('Number of unique characters is : {}'.format(batch_data[1][0].numpy()))\n", 898 | "\n", 899 | "fig, axs = plt.subplots(2, 3)\n", 900 | "axs[0, 0].imshow(im_to_plot[0, :, :, 0], cmap='gray')\n", 901 | "axs[0, 0].axis('off')\n", 902 | "axs[0, 1].imshow(im_to_plot[1, :, :, 0], cmap='gray')\n", 903 | "axs[0, 1].axis('off')\n", 904 | "axs[0, 2].imshow(im_to_plot[2, :, :, 0], cmap='gray')\n", 905 | "axs[0, 2].axis('off')\n", 906 | "axs[1, 0].imshow(im_to_plot[3, :, :, 0], cmap='gray')\n", 907 | "axs[1, 0].axis('off')\n", 908 | "axs[1, 1].imshow(im_to_plot[4, :, :, 0], cmap='gray')\n", 909 | "axs[1, 1].axis('off')\n", 910 | "axs[1, 2].imshow(im_to_plot[5, :, :, 0], cmap='gray')\n", 911 | "axs[1, 2].axis('off')" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "execution_count": 7, 917 | "metadata": {}, 918 | "outputs": [ 919 | { 920 | "name": "stdout", 921 | "output_type": "stream", 922 | "text": [ 923 | "Train on 3 samples\n", 924 | "Epoch 1/3\n", 925 | "3/3 [==============================] - 7s 2s/sample - loss: 4.7445\n", 926 | "Epoch 2/3\n", 927 | "3/3 [==============================] - 2s 601ms/sample - loss: 3.4197\n", 928 | "Epoch 3/3\n", 929 | "3/3 [==============================] - 2s 654ms/sample - loss: 1.0533\n" 930 | ] 931 | }, 932 | { 933 | "data": { 934 | "text/plain": [ 935 | "" 936 | ] 937 | }, 938 | "execution_count": 7, 939 | "metadata": {}, 940 | "output_type": "execute_result" 941 | } 942 | ], 943 | "source": [ 944 | "from tensorflow.keras.layers import LayerNormalization, Dense\n", 945 | "import tensorflow as tf\n", 946 | "from set_transformer.layers.attention import MultiHeadAttention\n", 947 | "from set_transformer.layers import RFF\n", 948 | "from set_transformer.blocks import SetAttentionBlock, PoolingMultiHeadAttention\n", 949 | "from tensorflow.keras.layers import Conv2D\n", 950 | "from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D, Dropout, Flatten\n", 951 | "from tensorflow.keras.models import Model\n", 952 | "tf.keras.backend.set_floatx('float32')\n", 953 | "\n", 954 | "\n", 955 | "def image_processing_model(input_image_shape = (105, 105, 1), output_len=128):\n", 956 | " input_image = Input(shape=input_image_shape)\n", 957 | " y = Conv2D(64, kernel_size=(3, 3),\n", 958 | " activation='relu',\n", 959 | " input_shape=input_image_shape)(input_image)\n", 960 | " y = Conv2D(64, (3, 3), activation='relu')(y)\n", 961 | " y = Conv2D(64, (3, 3), activation='relu')(y)\n", 962 | " y = Conv2D(64, (3, 3), activation='relu')(y)\n", 963 | " y = MaxPooling2D(pool_size=(2, 2))(y)\n", 964 | " y = Dropout(0.25)(y)\n", 965 | " y = Flatten()(y)\n", 966 | " output_vec = Dense(output_len, activation='relu')(y)\n", 967 | " return Model(input_image, output_vec)\n", 968 | "\n", 969 | "\n", 970 | "class CharEncoder(tf.keras.layers.Layer):\n", 971 | " def __init__(self, d=128, h=8):\n", 972 | " super(CharEncoder, self).__init__()\n", 973 | "\n", 974 | " # Instantiate image processing model\n", 975 | " self.image_model = image_processing_model(output_len=d)\n", 976 | "\n", 977 | " # Encoding part\n", 978 | " self.sab_1 = 
SetAttentionBlock(d, h, RFF(d))\n", 979 | " self.sab_2 = SetAttentionBlock(d, h, RFF(d))\n", 980 | "\n", 981 | " def call(self, x):\n", 982 | " return self.sab_2(self.sab_1(tf.map_fn(self.image_model, x)))\n", 983 | " \n", 984 | "\n", 985 | "class CharDecoder(tf.keras.layers.Layer):\n", 986 | " def __init__(self, out_dim, d=128, h=8, k=32):\n", 987 | " super(CharDecoder, self).__init__()\n", 988 | "\n", 989 | " self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))\n", 990 | " self.SAB = SetAttentionBlock(d, h, RFF(d))\n", 991 | " self.output_mapper = Dense(out_dim)\n", 992 | " self.k, self.d = k, d\n", 993 | "\n", 994 | " def call(self, x):\n", 995 | " decoded_vec = self.SAB(self.PMA(x))\n", 996 | " decoded_vec = tf.reshape(decoded_vec, [-1, self.k * self.d])\n", 997 | " return tf.reshape(self.output_mapper(decoded_vec), (tf.shape(decoded_vec)[0],))\n", 998 | "\n", 999 | " \n", 1000 | "class CharSetTransformer(tf.keras.Model):\n", 1001 | " def __init__(self):\n", 1002 | " super(CharSetTransformer, self).__init__()\n", 1003 | " self.encoder = CharEncoder()\n", 1004 | " self.decoder = CharDecoder(out_dim=1)\n", 1005 | "\n", 1006 | " def call(self, x):\n", 1007 | " enc_output = self.encoder(x) # (batch_size, set_len, d_model)\n", 1008 | " return self.decoder(enc_output)\n", 1009 | " \n", 1010 | "\n", 1011 | "tset_model = CharSetTransformer()\n", 1012 | "tset_model.compile(loss='mae', optimizer='adam')\n", 1013 | "tset_model.fit(batch_data[0], batch_data[1], epochs=3)" 1014 | ] 1015 | }, 1016 | { 1017 | "cell_type": "code", 1018 | "execution_count": 8, 1019 | "metadata": {}, 1020 | "outputs": [], 1021 | "source": [ 1022 | "# set_data_gen = SetDataGenerator(train_X, train_y)\n", 1023 | "# set_data_gen.generator_init(batch_size=64)\n", 1024 | "\n", 1025 | "# tset_model = CharSetTransformer()\n", 1026 | "# tset_model.compile(loss='mae', optimizer='adam')\n", 1027 | "# tset_model.fit(set_data_gen.batch, epochs=3, steps_per_epoch=300)" 1028 | ] 1029 | }, 1030 | { 1031 | "cell_type": "code", 1032 | "execution_count": null, 1033 | "metadata": {}, 1034 | "outputs": [], 1035 | "source": [] 1036 | }, 1037 | { 1038 | "cell_type": "code", 1039 | "execution_count": null, 1040 | "metadata": {}, 1041 | "outputs": [], 1042 | "source": [] 1043 | } 1044 | ], 1045 | "metadata": { 1046 | "kernelspec": { 1047 | "display_name": "Python 3", 1048 | "language": "python", 1049 | "name": "python3" 1050 | }, 1051 | "language_info": { 1052 | "codemirror_mode": { 1053 | "name": "ipython", 1054 | "version": 3 1055 | }, 1056 | "file_extension": ".py", 1057 | "mimetype": "text/x-python", 1058 | "name": "python", 1059 | "nbconvert_exporter": "python", 1060 | "pygments_lexer": "ipython3", 1061 | "version": "3.6.9" 1062 | } 1063 | }, 1064 | "nbformat": 4, 1065 | "nbformat_minor": 2 1066 | } 1067 | -------------------------------------------------------------------------------- /ipynb/Set_transformer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 10 | "import time\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import tensorflow as tf\n", 14 | "import warnings\n", 15 | "with warnings.catch_warnings():\n", 16 | " warnings.filterwarnings(\"ignore\",category=FutureWarning)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | 
"execution_count": 3, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Attention weights are:\n", 29 | "tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)\n", 30 | "Output is:\n", 31 | "tf.Tensor([[10. 0.]], shape=(1, 2), dtype=float32)\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "# MultiHeadAttention\n", 37 | "# https://www.tensorflow.org/tutorials/text/transformer, appears in \"Attention is all you need\" NIPS 2018 paper\n", 38 | "import numpy as np\n", 39 | "import tensorflow as tf\n", 40 | "\n", 41 | "\n", 42 | "def scaled_dot_product_attention(q, k, v, mask):\n", 43 | " \"\"\"Calculate the attention weights.\n", 44 | " q, k, v must have matching leading dimensions.\n", 45 | " k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.\n", 46 | " The mask has different shapes depending on its type(padding or look ahead) \n", 47 | " but it must be broadcastable for addition.\n", 48 | "\n", 49 | " Args:\n", 50 | " q: query shape == (..., seq_len_q, depth)\n", 51 | " k: key shape == (..., seq_len_k, depth)\n", 52 | " v: value shape == (..., seq_len_v, depth_v)\n", 53 | " mask: Float tensor with shape broadcastable \n", 54 | " to (..., seq_len_q, seq_len_k). Defaults to None.\n", 55 | "\n", 56 | " Returns:\n", 57 | " output, attention_weights\n", 58 | " \"\"\"\n", 59 | "\n", 60 | " matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)\n", 61 | " \n", 62 | " # scale matmul_qk\n", 63 | " dk = tf.cast(tf.shape(k)[-1], tf.float32)\n", 64 | " scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n", 65 | "\n", 66 | " # add the mask to the scaled tensor.\n", 67 | " if mask is not None:\n", 68 | " scaled_attention_logits += (mask * -1e9) \n", 69 | "\n", 70 | " # softmax is normalized on the last axis (seq_len_k) so that the scores\n", 71 | " # add up to 1.\n", 72 | " attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)\n", 73 | "\n", 74 | " output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)\n", 75 | "\n", 76 | " return output, attention_weights\n", 77 | "\n", 78 | "\n", 79 | "def print_out(q, k, v):\n", 80 | " temp_out, temp_attn = scaled_dot_product_attention(\n", 81 | " q, k, v, None)\n", 82 | " print ('Attention weights are:')\n", 83 | " print (temp_attn)\n", 84 | " print ('Output is:')\n", 85 | " print (temp_out)\n", 86 | " \n", 87 | "np.set_printoptions(suppress=True)\n", 88 | "\n", 89 | "temp_k = tf.constant([[10,0,0],\n", 90 | " [0,10,0],\n", 91 | " [0,0,10],\n", 92 | " [0,0,10]], dtype=tf.float32) # (4, 3)\n", 93 | "\n", 94 | "temp_v = tf.constant([[ 1,0],\n", 95 | " [ 10,0],\n", 96 | " [ 100,5],\n", 97 | " [1000,6]], dtype=tf.float32) # (4, 2)\n", 98 | "\n", 99 | "# This `query` aligns with the second `key`,\n", 100 | "# so the second `value` is returned.\n", 101 | "temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3)\n", 102 | "print_out(temp_q, temp_k, temp_v)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "from tensorflow.keras.layers import Dense\n", 112 | "from tensorflow.keras.models import Model\n", 113 | "\n", 114 | " \n", 115 | "class MultiHeadAttention(tf.keras.layers.Layer):\n", 116 | " def __init__(self, d_model, num_heads):\n", 117 | " super(MultiHeadAttention, self).__init__()\n", 118 | " self.num_heads = num_heads\n", 119 | " self.d_model = d_model\n", 120 | "\n", 121 | " assert d_model 
% self.num_heads == 0\n", 122 | " \n", 123 | " self.depth = d_model // self.num_heads\n", 124 | " \n", 125 | " self.wq = tf.keras.layers.Dense(d_model)\n", 126 | " self.wk = tf.keras.layers.Dense(d_model)\n", 127 | " self.wv = tf.keras.layers.Dense(d_model)\n", 128 | "\n", 129 | " self.dense = tf.keras.layers.Dense(d_model)\n", 130 | " \n", 131 | " def split_heads(self, x, batch_size):\n", 132 | " \"\"\"Split the last dimension into (num_heads, depth).\n", 133 | " Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)\n", 134 | " \"\"\"\n", 135 | " x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))\n", 136 | " return tf.transpose(x, perm=[0, 2, 1, 3])\n", 137 | " \n", 138 | " def call(self, q, k, v, mask=None):\n", 139 | " batch_size = tf.shape(q)[0]\n", 140 | "\n", 141 | " q = self.wq(q) # (batch_size, seq_len, d_model)\n", 142 | " k = self.wk(k) # (batch_size, seq_len, d_model)\n", 143 | " v = self.wv(v) # (batch_size, seq_len, d_model)\n", 144 | "\n", 145 | " q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)\n", 146 | " k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)\n", 147 | " v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)\n", 148 | "\n", 149 | " # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)\n", 150 | " # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)\n", 151 | " scaled_attention, attention_weights = scaled_dot_product_attention(\n", 152 | " q, k, v, mask)\n", 153 | " \n", 154 | " scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)\n", 155 | "\n", 156 | " concat_attention = tf.reshape(scaled_attention, \n", 157 | " (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)\n", 158 | "\n", 159 | " output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)\n", 160 | "\n", 161 | " return output\n", 162 | " \n", 163 | "\n", 164 | "# temp_mha = MultiHeadAttention(d_model=512, num_heads=8)\n", 165 | "# y = tf.random.uniform((1, 60, 512)) # (batch_size, encoder_sequence, d_model)\n", 166 | "# out = temp_mha(v=y, k=y, q=y)\n", 167 | "# print(out.shape)\n", 168 | "\n", 169 | "\n", 170 | "class RFF(tf.keras.layers.Layer):\n", 171 | " \"\"\"\n", 172 | " Row-wise FeedForward layers.\n", 173 | " \"\"\"\n", 174 | " def __init__(self, d):\n", 175 | " super(RFF, self).__init__()\n", 176 | " \n", 177 | " self.linear_1 = Dense(d, activation='relu')\n", 178 | " self.linear_2 = Dense(d, activation='relu')\n", 179 | " self.linear_3 = Dense(d, activation='relu')\n", 180 | " \n", 181 | " def call(self, x):\n", 182 | " \"\"\"\n", 183 | " Arguments:\n", 184 | " x: a float tensor with shape [b, n, d].\n", 185 | " Returns:\n", 186 | " a float tensor with shape [b, n, d].\n", 187 | " \"\"\"\n", 188 | " return self.linear_3(self.linear_2(self.linear_1(x))) \n", 189 | "\n", 190 | "\n", 191 | "# mlp = RFF(3)\n", 192 | "# y = mlp(tf.ones(shape=(2, 4, 3))) # The first call to the `mlp` will create the weights\n", 193 | "# print('weights:', len(mlp.weights))\n", 194 | "# print('trainable weights:', len(mlp.trainable_weights))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 27, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# Referencing https://arxiv.org/pdf/1810.00825.pdf \n", 204 | "# and the original PyTorch implementation https://github.com/TropComplique/set-transformer/blob/master/blocks.py\n", 205 
| "from tensorflow import repeat\n", 206 | "# from tensorflow.keras.backend import repeat_elements\n", 207 | "from tensorflow.keras.layers import LayerNormalization\n", 208 | "\n", 209 | "\n", 210 | "class MultiHeadAttentionBlock(tf.keras.layers.Layer):\n", 211 | " def __init__(self, d, h, rff):\n", 212 | " super(MultiHeadAttentionBlock, self).__init__()\n", 213 | " self.multihead = MultiHeadAttention(d, h)\n", 214 | " self.layer_norm1 = LayerNormalization(epsilon=1e-6, dtype='float32')\n", 215 | " self.layer_norm2 = LayerNormalization(epsilon=1e-6, dtype='float32')\n", 216 | " self.rff = rff\n", 217 | " \n", 218 | " def call(self, x, y):\n", 219 | " \"\"\"\n", 220 | " Arguments:\n", 221 | " x: a float tensor with shape [b, n, d].\n", 222 | " y: a float tensor with shape [b, m, d].\n", 223 | " Returns:\n", 224 | " a float tensor with shape [b, n, d].\n", 225 | " \"\"\"\n", 226 | " \n", 227 | " h = self.layer_norm1(x + self.multihead(x, y, y))\n", 228 | " return self.layer_norm2(h + self.rff(h))\n", 229 | "\n", 230 | "# x_data = tf.random.normal(shape=(10, 2, 9))\n", 231 | "# y_data = tf.random.normal(shape=(10, 3, 9))\n", 232 | "# rff = RFF(d=9)\n", 233 | "# mab = MultiHeadAttentionBlock(9, 3, rff=rff)\n", 234 | "# mab(x_data, y_data).shape \n", 235 | "\n", 236 | " \n", 237 | "class SetAttentionBlock(tf.keras.layers.Layer):\n", 238 | " def __init__(self, d, h, rff):\n", 239 | " super(SetAttentionBlock, self).__init__()\n", 240 | " self.mab = MultiHeadAttentionBlock(d, h, rff)\n", 241 | " \n", 242 | " def call(self, x):\n", 243 | " \"\"\"\n", 244 | " Arguments:\n", 245 | " x: a float tensor with shape [b, n, d].\n", 246 | " Returns:\n", 247 | " a float tensor with shape [b, n, d].\n", 248 | " \"\"\"\n", 249 | " return self.mab(x, x)\n", 250 | "\n", 251 | " \n", 252 | "class InducedSetAttentionBlock(tf.keras.layers.Layer):\n", 253 | " def __init__(self, d, m, h, rff1, rff2):\n", 254 | " \"\"\"\n", 255 | " Arguments:\n", 256 | " d: an integer, input dimension.\n", 257 | " m: an integer, number of inducing points.\n", 258 | " h: an integer, number of heads.\n", 259 | " rff1, rff2: modules, row-wise feedforward layers.\n", 260 | " It takes a float tensor with shape [b, n, d] and\n", 261 | " returns a float tensor with the same shape.\n", 262 | " \"\"\"\n", 263 | " super(InducedSetAttentionBlock, self).__init__()\n", 264 | " self.mab1 = MultiHeadAttentionBlock(d, h, rff1)\n", 265 | " self.mab2 = MultiHeadAttentionBlock(d, h, rff2)\n", 266 | " self.inducing_points = tf.random.normal(shape=(1, m, d))\n", 267 | "\n", 268 | " def call(self, x):\n", 269 | " \"\"\"\n", 270 | " Arguments:\n", 271 | " x: a float tensor with shape [b, n, d].\n", 272 | " Returns:\n", 273 | " a float tensor with shape [b, n, d].\n", 274 | " \"\"\"\n", 275 | " b = tf.shape(x)[0] \n", 276 | " p = self.inducing_points\n", 277 | " p = repeat(p, (b), axis=0) # shape [b, m, d] \n", 278 | " \n", 279 | " h = self.mab1(p, x) # shape [b, m, d]\n", 280 | " return self.mab2(x, h) \n", 281 | " \n", 282 | "\n", 283 | "class PoolingMultiHeadAttention(tf.keras.layers.Layer):\n", 284 | "\n", 285 | " def __init__(self, d, k, h, rff, rff_s):\n", 286 | " \"\"\"\n", 287 | " Arguments:\n", 288 | " d: an integer, input dimension.\n", 289 | " k: an integer, number of seed vectors.\n", 290 | " h: an integer, number of heads.\n", 291 | " rff: a module, row-wise feedforward layers.\n", 292 | " It takes a float tensor with shape [b, n, d] and\n", 293 | " returns a float tensor with the same shape.\n", 294 | " \"\"\"\n", 295 | " 
super(PoolingMultiHeadAttention, self).__init__()\n", 296 | " self.mab = MultiHeadAttentionBlock(d, h, rff)\n", 297 | " self.seed_vectors = tf.random.normal(shape=(1, k, d))\n", 298 | " self.rff_s = rff_s\n", 299 | "\n", 300 | " @tf.function\n", 301 | " def call(self, z):\n", 302 | " \"\"\"\n", 303 | " Arguments:\n", 304 | " z: a float tensor with shape [b, n, d].\n", 305 | " Returns:\n", 306 | " a float tensor with shape [b, k, d]\n", 307 | " \"\"\"\n", 308 | " b = tf.shape(z)[0]\n", 309 | " s = self.seed_vectors\n", 310 | " s = repeat(s, (b), axis=0) # shape [b, k, d]\n", 311 | " return self.mab(s, self.rff_s(z))\n", 312 | " \n", 313 | "\n", 314 | "# z = tf.random.normal(shape=(10, 2, 9))\n", 315 | "# rff, rff_s = RFF(d=9), RFF(d=9) \n", 316 | "# pma = PoolingMultiHeadAttention(d=9, k=10, h=3, rff=rff, rff_s=rff_s)\n", 317 | "# pma(z).shape" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 28, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "from tensorflow.keras.layers import Dense\n", 327 | " \n", 328 | "\n", 329 | "class STEncoderBasic(tf.keras.layers.Layer):\n", 330 | " def __init__(self, d=12, m=6, h=6):\n", 331 | " super(STEncoderBasic, self).__init__()\n", 332 | " \n", 333 | " # Embedding part\n", 334 | " self.linear_1 = Dense(d, activation='relu')\n", 335 | " \n", 336 | " # Encoding part\n", 337 | " self.isab_1 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d))\n", 338 | " self.isab_2 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d))\n", 339 | " \n", 340 | " def call(self, x):\n", 341 | " return self.isab_2(self.isab_1(self.linear_1(x)))\n", 342 | "\n", 343 | " \n", 344 | "class STDecoderBasic(tf.keras.layers.Layer):\n", 345 | " def __init__(self, out_dim, d=12, m=6, h=2, k=8):\n", 346 | " super(STDecoderBasic, self).__init__()\n", 347 | " \n", 348 | " self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))\n", 349 | " self.SAB = SetAttentionBlock(d, h, RFF(d))\n", 350 | " self.output_mapper = Dense(out_dim) \n", 351 | " self.k, self.d = k, d\n", 352 | "\n", 353 | " def call(self, x):\n", 354 | " decoded_vec = self.SAB(self.PMA(x))\n", 355 | " decoded_vec = tf.reshape(decoded_vec, [-1, self.k * self.d])\n", 356 | " return tf.reshape(self.output_mapper(decoded_vec), (tf.shape(decoded_vec)[0],))\n" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 29, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "text/plain": [ 367 | "TensorShape([100000, 9, 1])" 368 | ] 369 | }, 370 | "execution_count": 29, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "def gen_max_dataset(dataset_size=100000, set_size=9):\n", 377 | " \"\"\"\n", 378 | " The number of objects per set is constant in this toy example\n", 379 | " \"\"\"\n", 380 | " x = np.random.uniform(1, 100, (dataset_size, set_size))\n", 381 | " y = np.max(x, axis=1)\n", 382 | " x, y = np.expand_dims(x, axis=2), np.expand_dims(y, axis=1)\n", 383 | " return tf.cast(x, 'float32'), tf.cast(y, 'float32')\n", 384 | "\n", 385 | "X, y = gen_max_dataset()\n", 386 | "X.shape" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 30, 392 | "metadata": { 393 | "scrolled": true 394 | }, 395 | "outputs": [ 396 | { 397 | "name": "stdout", 398 | "output_type": "stream", 399 | "text": [ 400 | "(100000, 9, 3)\n", 401 | "(100000,)\n" 402 | ] 403 | } 404 | ], 405 | "source": [ 406 | "# Dimensionality check on encoder-decoder couple\n", 407 | "\n", 408 | "encoder = STEncoderBasic(d=3, 
m=2, h=1)\n", 409 | "encoded = encoder(X)\n", 410 | "print(encoded.shape)\n", 411 | "\n", 412 | "decoder = STDecoderBasic(out_dim=1, d=1, m=2, h=1, k=1)\n", 413 | "decoded = decoder(encoded)\n", 414 | "print(decoded.shape)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 31, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "# Actual model for max-set prediction\n", 424 | "\n", 425 | "class SetTransformer(tf.keras.Model):\n", 426 | " def __init__(self):\n", 427 | " super(SetTransformer, self).__init__()\n", 428 | " self.basic_encoder = STEncoderBasic(d=4, m=3, h=2)\n", 429 | " self.basic_decoder = STDecoderBasic(out_dim=1, d=4, m=2, h=2, k=2)\n", 430 | " \n", 431 | " def call(self, x):\n", 432 | " enc_output = self.basic_encoder(x) # (batch_size, set_len, d_model)\n", 433 | " return self.basic_decoder(enc_output)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 32, 439 | "metadata": { 440 | "scrolled": true 441 | }, 442 | "outputs": [ 443 | { 444 | "name": "stdout", 445 | "output_type": "stream", 446 | "text": [ 447 | "Train on 100000 samples\n", 448 | "Epoch 1/6\n", 449 | "100000/100000 [==============================] - 24s 237us/sample - loss: 29.9259\n", 450 | "Epoch 2/6\n", 451 | "100000/100000 [==============================] - 21s 206us/sample - loss: 2.5411\n", 452 | "Epoch 3/6\n", 453 | "100000/100000 [==============================] - 20s 199us/sample - loss: 0.5547\n", 454 | "Epoch 4/6\n", 455 | "100000/100000 [==============================] - 20s 199us/sample - loss: 0.4607\n", 456 | "Epoch 5/6\n", 457 | "100000/100000 [==============================] - 20s 199us/sample - loss: 0.4181\n", 458 | "Epoch 6/6\n", 459 | "100000/100000 [==============================] - 21s 205us/sample - loss: 0.4109\n" 460 | ] 461 | }, 462 | { 463 | "data": { 464 | "text/plain": [ 465 | "" 466 | ] 467 | }, 468 | "execution_count": 32, 469 | "metadata": {}, 470 | "output_type": "execute_result" 471 | } 472 | ], 473 | "source": [ 474 | "set_transformer = SetTransformer()\n", 475 | "set_transformer.compile(loss='mae', optimizer='adam')\n", 476 | "set_transformer.fit(X, y, epochs=6)" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 116, 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "import tensorflow as tf\n", 486 | "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 119, 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [ 495 | "import numpy as np\n", 496 | "from typing import List, Tuple\n", 497 | "\n", 498 | "def extract_image_set(x_data: np.array, y_data: np.array, agg_fun=np.sum, n_images=3) -> Tuple[np.array, np.array]:\n", 499 | " \"\"\"\n", 500 | " Extract a single set of images with corresponding target\n", 501 | " :param x_data: source image array to sample from\n", 502 | " \"\"\"\n", 503 | " idxs = np.random.randint(low=0, high=len(x_data), size=n_images)  # high is exclusive\n", 504 | " return x_data[idxs], agg_fun(y_data[idxs])\n", 505 | "\n", 506 | "\n", 507 | "def generate_dataset(n_samples: int, x_data: np.array, y_data: np.array, agg_fun=np.sum, n_images=3) -> Tuple[List[List[np.array]], np.array]:\n", 508 | " \"\"\"\n", 509 | " :return X,y in format suitable for training/prediction\n", 510 | " \"\"\"\n", 511 | " generated_list = [extract_image_set(x_data, y_data, agg_fun, n_images) for i in range(n_samples)]\n", 512 | " X, y = [i[0] for i in generated_list], np.array([t[1] for t in generated_list])\n", 
513 | " output_lists = [[] for i in range(n_images)]\n", 514 | " for image_idx in range(n_images):\n", 515 | " for sample_idx in range(n_samples):\n", 516 | " output_lists[image_idx].append(np.expand_dims(X[sample_idx][image_idx], axis=2))\n", 517 | " return output_lists, y\n", 518 | "\n", 519 | "X_train_data, y_train_data = generate_dataset(n_samples=100000, x_data=x_train, y_data=y_train, n_images=3)\n", 520 | "X_test_data, y_test_data = generate_dataset(n_samples=20000, x_data=x_test, y_data=y_test, n_images=3)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 37, 526 | "metadata": {}, 527 | "outputs": [ 528 | { 529 | "data": { 530 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIJCAYAAADTd4UyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deZhUxb3/8c8XRTZFJOIe4EICuCRuuGCMolzXGDUgolHjHn+AGjAaFSJCXK4KLrnikoiKeiGiKGqCuXEhuOK+izGCcg3BBcImyE79/qgzpumqmenp7lnr/Xqefob+zDl1qrsP3d+pU+e0OecEAADS0qy+OwAAAOoeBQAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQIAqAEpnZ/mb2pJnNN7PlZvaRmY03sx3qu2/lYGZzzGxMHW5vRzN7LnsunZl1roVtHGpmQyL5eDN7rYD1nZmdW+5+1aVCH2s1bfTM2vnQzNab2fhKlnOR20ulbBtA6Tau7w40Zma2v6Tpkh6RdKakFZJ2kvRTSZ0kza23zjVeoyW1k3S0pOWSPquFbRwq6ThJNxW5fi9Jn5SvO/XiCkmtSmzjB5L2l/SSpM2qWfZ6SZNz7n9V4rYBlIgCoDQDJX0gqb/79xWVnpT0WzOz+utWo9ZD0mPOuadLaSR7/ls451aWp1v/5pwr+1+vZtbKObei3O1Wxjk3uwzN3Oyc+60kFTCaMKc2njcAxeMQQGnaSfrSRS6nmJvFhozNbKSZLci5f1q23B5mNt3Mvjazt7L7bczsbjNbYmYfm9mJVXUqW//BSD7azD6tKE7M7Boze9fMlpnZXDObYGbbFND25Lysd9b3XXKylmZ2nZn9w8xWmdnbZnZkFe12NjMnqaukoVl703N+f252eGWVmc0ys6F56480swXZIZlXJa2U1D+ynZGSfimpU85w9Pi8ZQ4xs3eywxDPm9nOeb/f4PXMtvmcmS3Nbm+ZWbDt/MdqZieZ2b1mtljSH7PfbZQ9lk+zx/q+mf00Z92DsnW3y8lmmNk6M2uXk71rZldV0YcNDgGYWTszG2dm88xsZbb9OypbX5Kcc+ur+j2Aho0CoDRvSDrIzC4zsy5lavMeSX+Q1E+SyQ+b3ilpnvyw9cuS7rWq5xhMknSkmbWpCLIP/eMlPZBTnGwl6WpJP5I0RFIXSdPMrBz7xWRJp2Xt/1jSq5IeM7PdKln+M/mh9c8lTcz+PSjr+9mSbpb0WNbWg5KuN7NL8tpoLf/8jZN0uKRXItsZl7X/ebaNXvLD4RU6yh+GuErSifLP0aTKRnTMrK2kP0n6WP41O07SffLFYXXGyA+F95d/niTpN5KGS/q9/GGQFyRNyCn6Xpa0RtIPs+23lrSnpNXyQ/Iys/aSdpb0XAF9qHCD/HD+UEmHSRomqZzXCR9pZmuzIu2urI8A6pNzjluRN0ltJU2Tf6N08h/St0vqlreck3RuXjZS0oKc+6dly52akx2ZZXflZJvLfwAMrKJfHSStlXRCTtYra6tnJetsJGn7bJkDcvI5ksbk3J8uaXLeur2z9XbJ7vfJ7h+Yt9yzkh6s5jnN314zSf+UdHfecrdKWiKpZc7z6SQdU8DrNkZ+SDo/H589b9/NyY7N2u0Rez0l9czub1aD/aZzts6UvLy9/LyHy/PyxyV9mHN/hqSx2b8PljRf0v2SrsmyoyWtk9S2ij6Ml/Razv33JJ1Xwv+F1ySNr2Jb/SQdIOkCSYskvS5po2K3x40bt9JvjACUwDm3VP7Dbj/5v+BmSzpL0htmtkeRzeYe+56V/ZyWs80l8m/421fRr/nZOgNy4gGSZjvncod9jzCzF81sifwHX8WkxW5F9r3Cf8r/hf2CmW1ccZN/bD1r2NYOkraT/6s/1yT5Aux7OZmT9OfiuvyNOc65j3Luz8zpR8xsScskTTSzY3KH4QswNe/+LvKjGLHH2s3MOmT3n1U2AiD/ofq8pGfysrez/bNQb0m6yMwGmVmpr/8GnHOnOececs4965y7QX6S7B7yozkA6gkFQImcN8M5N9w590P5D7j1ki4rssnFOf9eHckq8pbVtHO/pCPMrG02pN9f/oNEkmRme8kPqc+VdIr8CMG+2a+ra7s6W0raRn6kIvc2UtK3a9jWttnPL/Lyivu5Q8mLnHOrVZrYcy1V8pw45xZJOkRSc0kPSJpvZlMLPCSU/5gKfazPSdolKzZ+mN1/TlJPM2uZk9XEufJns4yQ9GE23+KEGrZRqP+VL5qKLZIBlAEFQJk5596SPxOgR068StImeYtuUctdmSI/h+AY+WO72ymnAJD0E/mRhAHOucecn6H9eQHtrlT1j2Wh/LD9XpHbvqqZitMAt8rLt87ZVoV6+W5r59xLzrnD5Y/795UfQZlYyKp59wt9rC9kP3vLP5/PSnpf/kO1j/wHa40KAOfcYufc+c65bSTtKj/XYIKZ7VSTdgrcVsXj5rvIgXpEAVACM8t/o66YbNdVG/4VN1fSjjnLNJN/o6412V+mT8gP/Q+Q9IFz7p2cRVpJWpPzZixJJxXQ9FxtWNxI/rz6XE/LjwAsc869ln+r0QPx25uncEb/8ZKWSnq3hu1JhY2g1JhzboVz7o+S7pK/HkRNvSfpa8Uf69+zQzsVr+178hP21kl6M3sdn5f0K/nTe2s6AvCNbD+5SP79If+1LpmZHS5pU/l5AADqCdcBKM247MP8IfljwVtIOl3+L6jcN/Epkgab2Zvys8XPkj9+XdsmyX8YLZE0Nu93T0oaYmY3yZ+Ctp+kkwtoc4qkM83sRvlj2AfJz7jPb/svkp40s2vl/zptK2k3+Ul7lxb6AJxz67NT935nZv/K2j5Q
/hoMw1xx5/n/TdLWZnaa/AfpAufcnCLakZn9SNIZ8sPnn8rPzThHOfM2CuWcW5i9Hr82s7XyE+v6yk8GzT/18zlJgyX9xTm3LicbLekj51z+YYTqHsfz8q/te/J/mZ8tPyExdiZFxTod5F8Lye/7nczsuOyxTM6W+bn8YbGnJC2QH534ddZu/hwIAHWIAqA0t8rP3h8hf/x2sfyH3WHOuSdylhslP6x7pfxfn2Oz5QbXcv8elZ/ct6X8nIBvOOceN7OLJZ0n/2Y/Q9JRkv5eVYPOualmNkz+FL2zsm38IvtZsYwzs77yp5INkT+1bqH8RLOba/ognHN3ZMe2f5Hd5kr6pXPuxpq2lXlAvnC5Tv6MiXvkX8dizJL/wLxa/jWeL39a4LAi2xsh/5oNlB/6nyXpZOfc/XnLVRQAz+Zlkh8JqKkZ8s9BZ2WjCpKOcM5VdTXLnbXhhMUu8oclJH/4SfKF8anyZwG0lT/MdK+ky3IKFwD1wDYcAQYAAClgDgAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQoOpOA+QUAZQi+g16dYx9GKVoCPuwxH6M0kT3Y0YAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgqq7EBAauenTpwfZoEGDgmzChAnR9XffffdydwkA0AAwAgAAQIIoAAAASBAFAAAACaIAAAAgQUwCbKRWr14dZHfeeWeQXXrppUG2dOnSINtiiy3K0zEAQKPACAAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQIM4CaKSOOuqoIHvqqaeC7PTTTw+y4cOHB1nHjh3L0zE0Kh9++GGQ9ejRo6B1+/btG80feuihkvoEoG4wAgAAQIIoAAAASBAFAAAACaIAAAAgQUwCbARuvPHGIItN+Bs9enSQDRkyJMg22mij8nQMSXv44YfruwsASsAIAAAACaIAAAAgQRQAAAAkiAIAAIAEMQmwnsyfPz+a9+7dO8iWL18eZDNnzgyybt26BVmzZtR4qFuxqwt27969HnqCxuTZZ58NslmzZhW07sqVK4PsvPPOC7L169cH2RFHHBFt87jjjguynXfeOcj22WefQrrYIPHpAABAgigAAABIEAUAAAAJogAAACBBTAKsA7EJKieddFJ02Tlz5gTZ66+/HmSFfmUrADQkd911VzQfOHBgkK1bt67o7ZhZkMUmRT/xxBPR9WN58+bNg2zfffcNsquvvjrIevXqFd1OfWIEAACABFEAAACQIAoAAAASRAEAAECCzDlX1e+r/CVCX331VZD1798/yJ555pno+i+//HKQff/73y+9Y/UjnIVT99iHq9CvX78gK/Vrfqt5T2lsGsI+LDXS/XjJkiVBtt9++0WXjV1Bsm/fvkHWqVOnIItdtW+nnXYKslGjRgVZbLJgZcaPHx9kixYtCrIWLVoE2c033xxkZ5xxRsHbLlH0QTICAABAgigAAABIEAUAAAAJogAAACBBXAmwBK+++mqQHXbYYUG2YsWKIItd3U+KT1wBGqLYBC0gV2xiXyyTpMGDBwfZddddF2SxCXaFGjNmTNHrStJf//rXIItNAly1alWQTZ06NcgOPfTQINthhx2K7F3NMQIAAECCKAAAAEgQBQAAAAmiAAAAIEFMAixQ7CssL7jggiBbvnx5kL3zzjtBtuOOO5anY0CBbr311iAr5ap/ffr0KaU7wAZiX5teyoS/Qi1dujSaP/roo0FW2QTGfO3atQuy0aNHB1ldTviLYQQAAIAEUQAAAJAgCgAAABJEAQAAQIIoAAAASBBnAUTEvvP55z//eZCtX78+yGbOnBlkPXr0KEu/gFI8/fTTZW2PswBQTg899FCQ/exnPwuyH/zgBwW1F5vd/8gjjwTZTTfdFF0/dvZWTGzG/5lnnhlkXbp0Kai9usQIAAAACaIAAAAgQRQAAAAkiAIAAIAEmXOuqt9X+cum4L777guyU089taB1p0+fHmR77713kL3wwgtB9vjjj0fbjE1Sueiii4IsNsmkefPm0TbrkdV3B5TAPhwTu2RpKZNR+/btG2SxSVtNUEPYh6VGuh9/9dVXQXbOOedEl/3zn/8cZLFL5cYms65ZsybIjjrqqCCLTewzi7/EW2yxRZD1798/yC688MIga4AT/qIPkhEAAAASRAEAAECCKAAAAEgQBQAAAAlKZhJgbCKeJP2///f/guz9998PshNOOCHIttxyyyB74IEHguzLL78Msvbt20f7861vfSvIvvjiiyDbfvvtgyzW73rWECZQNZl9uCb69esXZA8//HDR7d1yyy1BNmjQoKLba0Qawj4sJbAfx95jJ0+eXNZtxD7vYld5leKT+7p27VrW/tQhJgECAACPAgAAgARRAAAAkCAKAAAAEtQkJwHOmTMnyA444IDosnPnzi16O7GrPcUmshxyyCFBtueee0bb3HTTTYPsxhtvDLLY1QFfeeWVINtjjz2i26kjDWECVaPch2ui3Ff9i/nb3/4WZN27dy/rNhqohrAPSwnsx7Grsp5++ull3UZsUuGPf/zj6LIbbbRRWbddz5gECAAAPAoAAAASRAEAAECCKAAAAEjQxvXdgVKtXLkyyIYPHx5kNZns17Zt2yA78cQTgyw2Oa9ly5YFbycm9njmzZsXZB07dgyy73znOyVtGwDqy9SpU8vaXuw9+9hjjy3rNho7RgAAAEgQBQAAAAmiAAAAIEEUAAAAJKjRTwK85JJLguwPf/hDweu3adMmyGbMmBFkO+64Y806Vo358+dH82HDhgXZE088EWT3339/kMUmLwI11bdv3yBL5Kp/qANjx46N5g8++GBZt7N+/fqyttcUMQIAAECCKAAAAEgQBQAAAAmiAAAAIEGN6uuAZ86cGWR77713kH399ddBdsopp0TbvO2224KsdevWBfVn3bp1QfbFF18E2auvvhpklX3NZbt27YJswoQJQdarV69CuljfGsJXqTaofbg29OvXL8gefvjhotuLTQJ86KGHim6vkWsI+7DUSPfjN954I8gqe++KvZ9uueWWQXbqqacG2fXXX19Qf9auXVvQck0QXwcMAAA8CgAAABJEAQAAQIIoAAAASBAFAAAACWqwZwGsWrUqyGIz+SdPnhxkhx56aJD97//+b8HbXrlyZZB9+eWXQXbLLbcE2XXXXRdkzZs3D7LY2QuSdO655wbZCSecEF22EWgIM6gb5ezpmjAr79P8t7/9LcgSvhRwQ9iHpUa6H7/yyitBtt9++0WX3XrrrYNs2rRpQbbddtsF2RZbbFFQfzgLYEOMAAAAkCAKAAAAEkQBAABAgigAAABI0Mb13QEp/r3NF198cZDFJvzFxCYB7r///gX359NPPw2yuXPnBlmXLl2C7LTTTguyYcOGBdl3v/vdgvsDSPFL/pYqdtnfhCf8ocyuvvrqgpc9+eSTgyy2L8YuGRybPD127NiCt50qRgAAAEgQBQAAAAmiAAAAIEEUAAAAJKhBTAJcvHhxkP33f/930e1deOGFQVbZFQ9jV1E76KCDgiw2ke+MM84
Isk022aSQLgJAk/LWW28F2dNPP13w+qeffnpBy8Xey5ctW1bwdvBvjAAAAJAgCgAAABJEAQAAQIIoAAAASFCDmAQ4fPjwINtnn32C7OGHHw6yDz74IMhefPHFIOvVq1d02506dQqy2BX+mjWjVkLd+fDDD+tkO3369KmT7aDpGzVqVJCtWLGi7NuJtTl+/PiybycFfKoBAJAgCgAAABJEAQAAQIIoAAAASFCDmAR42223Fb3utttuG2QHH3xwKd0BksEkQJTLwoULy97m2rVrg+zOO+8s+3ZSxQgAAAAJogAAACBBFAAAACSIAgAAgAQ1iEmAADbUvXv3IKtswl7sCpkxt9xyS0HbAYpx7bXXBlnsq9XXrFkTXf83v/lNkLVu3TrICr3qX7t27QpaLmWMAAAAkCAKAAAAEkQBAABAgigAAABIkDnnqvp9lb8EqmH13QGxD6M0DWEflhrpfnzzzTcH2YUXXhhddt26dUVvZ4sttgiyN954I8i+/e1vF72NRi66HzMCAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIswBQmxrCDGr2YZSiIezDUhPaj6dNmxbNhwwZEmQLFy4MsquuuirIDj744CBLeMZ/DGcBAAAAjwIAAIAEUQAAAJAgCgAAABLEJEDUpoYwgYp9GKVoCPuwxH6M0jAJEAAAeBQAAAAkiAIAAIAEUQAAAJAgCgAAABJEAQAAQIIoAAAASBAFAAAACaIAAAAgQdVdCRAAADRBjAAAAJAgCgAAABJEAQAAQIIoAAAASBAFAAAACaIAAAAgQRQAAAAkiAIAAIAEUQAAAJAgCgAAABJEAQAAQIIoAEpkZvub2ZNmNt/MlpvZR2Y23sx2qO++lYOZzTGzMXW4vR3N7LnsuXRm1rkWtnGomQ2J5OPN7LUC1ndmdm65+1WXCn2s1bTRM2vnQzNbb2bjK1nuMjN7ysyW1tZrCqDmKABKYGb7S5ouaYmkMyUdK2mspB0ldaq/njVqoyW1k3S0pF6SPquFbRwqKSgAaqCXpAfL1Jf6coWk00ps4weS9pf0qqTPq1juHEkbS/pridsDUEYb13cHGrmBkj6Q1N/9+2sVn5T0WzOz+utWo9ZD0mPOuadLaSR7/ls451aWp1v/5px7qdxtmlkr59yKcrdbGefc7DI0c7Nz7reSVM1oQkfn3HozO0q+sAPQADACUJp2kr50ke9Uzs1iQ8ZmNtLMFuTcPy1bbg8zm25mX5vZW9n9NmZ2t5ktMbOPzezEqjqVrR/8hWpmo83s04rixMyuMbN3zWyZmc01swlmtk0BbU/Oy3pnfd8lJ2tpZteZ2T/MbJWZvW1mR1bRbmczc5K6ShqatTc95/fnZodXVpnZLDMbmrf+SDNbkB2SeVXSSkn9I9sZKemXkjpl23D5Q9dmdoiZvZMdhnjezHbO+/0Gr2e2zeeyIe6l2esWbDv/sZrZSWZ2r5ktlvTH7HcbZY/l0+yxvm9mP81Z96Bs3e1yshlmts7M2uVk75rZVVX0YYNDAGbWzszGmdk8M1uZbf+OytaXJOfc+qp+X9PlANQtCoDSvCHpoOwYZ5cytXmPpD9I6ifJJE2WdKekeZKOk/SypHut6jkGkyQdaWZtKoLsQ/94SQ/kFCdbSbpa0o/kh8S7SJpmZuXYLybLDzFfLenH8sPEj5nZbpUs/5n80PrnkiZm/x6U9f1sSTdLeixr60FJ15vZJXlttJZ//sZJOlzSK5HtjMva/zzbRi/54fAKHeUPQ1wl6UT552hSZSM6ZtZW0p8kfSz/mh0n6T754rA6YyR9JV+oXJ1lv5E0XNLv5f9afkHShJyi72VJayT9MNt+a0l7SlotPyQvM2svaWdJzxXQhwo3yA/nD5V0mKRhkoLCFkAT4pzjVuRNUltJ0+TfKJ38h/TtkrrlLecknZuXjZS0IOf+adlyp+ZkR2bZXTnZ5vIfAAOr6FcHSWslnZCT9cra6lnJOhtJ2j5b5oCcfI6kMTn3p0uanLdu72y9XbL7fbL7B+Yt96ykB6t5TvO310zSPyXdnbfcrfJzL1rmPJ9O0jEFvG5jJM2J5OOz5+27OdmxWbs9Yq+npJ7Z/c1qsN90ztaZkpe3l7Rc0uV5+eOSPsy5P0PS2OzfB0uaL+l+Sddk2dGS1klqW0Ufxkt6Lef+e5LOK+H/wmuSxlezzFHZ4+5c7Ha4ceNWvhsjACVwzi2V/7DbT/4vuNmSzpL0hpntUWSzuce+Z2U/p+Vsc4n8G/72VfRrfrbOgJx4gKTZzrncYd8jzOxFM1si/8E3N/tVtyL7XuE/5f/CfsHMNq64yT+2njVsawdJ2ymcdDdJvgD7Xk7mJP25uC5/Y45z7qOc+zNz+hEzW9IySRPN7JjcYfgCTM27v4v8KEbssXYzsw7Z/WeVjQBIOkDS85KeycvezvbPQr0l6SIzG2Rmpb7+ABoBCoASOW+Gc264c+6H8h9w6yVdVmSTi3P+vTqSVeQtq2nnfklHmFnbbEi/v/wHiSTJzPaSH1KfK+kU+RGCfbNfV9d2dbaUtI38SEXubaSkb9ewrW2zn1/k5RX32+dki5xzq1Wa2HMtVfKcOOcWSTpEUnNJD0iab2ZTCzwklP+YCn2sz0naJSs2fpjdf05STzNrmZPVxLmSHpE0QtKH2XyLE2rYBoBGhAKgzJxzb8mfCdAjJ14laZO8Rbeo5a5MkZ9DcIz8sd3tlFMASPqJ/EjCAOfcY87PbK/qVK4KK1X9Y1koP2y/V+S2r2qm4jTArfLyrXO2VaFejlk7515yzh0uf9y/r/wIysRCVs27X+hjfSH72Vv++XxW0vvyIxF9JO2hGhYAzrnFzrnznXPbSNpVfq7BBDPbqSbtAGg8KABKYGb5b9QVk+26asO/4ubKXxugYplm8m/UtSb7y/QJ+aH/AZI+cM69k7NIK0lrnHO5H0InFdD0XG1Y3Ej+vPpcT8uPACxzzr2Wf6vRA/Hbm6dwRv/xkpZKereG7UmFjaDUmHNuhXPuj5LuklTMB+d7kr5W/LH+PTu0U/Havic/YW+dpDez1/F5Sb+SP723piMA38j2k4vk3x/yX2sATQTXASjNuOzD/CH5Y8FbSDpd/i+o3DfxKZIGm9mb8rPFz5I/fl3bJsl/GC2Rv0BRriclDTGzm+RPQdtP0skFtDlF0plmdqP8MeyD5Gfc57f9F0lPmtm18n+dtpW0m/ykvUsLfQDOnz8+UtLvzOxfWdsHyl+DYZgr7jz/v0na2sxOk/8gXeCcm1NEOzKzH0k6Q374/FP5uRnnKGfeRqGccwuz1+PXZrZWfmJdX/nJoPmnfj4nabCkvzjn1uVkoyV95JzLP4xQ3eN4Xv61fU9+ZOJs+QmJsTMpKtbpIP9aSH7f72Rmx2WPZXLOcgfKT0zdM4uOMLP5kmY652YKQL2gACjNrfKz90fIH79dLP9hd5hz7omc5UbJD+teKf/X59hsucG13L9H5Sf3bSk/J+AbzrnHzexiSefJv9nPkJ+l/feqGnTOTTWzYfKn6J2VbeMX2c+KZZyZ9ZU/lWyI/Kl1C+Unmt
1c0wfhnLsjO7b9i+w2V9IvnXM31rStzAPyhct18h9M96j4q+LNkv/AvFr+NZ4vf1rgsCLbGyH/mg2UH/qfJelk59z9ectVFADP5mWSHwmoqRnyz0FnZaMKko5wzs2tYp2dteGExS7yhyUkf/ipwij9u1CQ/P+binxkEX0FUAa24QgwAABIAXMAAABIEAUAAAAJogAAACBBFAAAACSourMAmCGIUjSEr0RmH0YpGsI+LLEfozTR/ZgRAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSouq8DBtBEfPLJJ0HWvXv3IFuzZk2Q9evXL8juueee6HbatGlTRO8A1DVGAAAASBAFAAAACaIAAAAgQRQAAAAkiAIAAIAEcRYA0AR9/PHHQXbIIYcE2bp164KsWbPw74IpU6YEWatWraLbvu+++wrpIoB6xggAAAAJogAAACBBFAAAACSIAgAAgAQxCTBi9erVQXbYYYcF2fTp04PMzILskUceCbKjjz66uM4hWV9//XU0f+KJJ4LsjDPOCLIlS5YUtJ2NNw7fFoYPHx5k7dq1K6g9AA0TIwAAACSIAgAAgARRAAAAkCAKAAAAEmTOuap+X+Uvm6qVK1cG2eDBg4Ns8eLFQRab8Pe9730vyF5++eXotlu0aFFIFxuLcEZk3Wsy+/C7774bzXfbbbei2xw4cGCQXXjhhUHWuXPnorfRyDWEfVhqQvtxffryyy+D7Mwzz4wu+6c//SnIli1bFmRt2rQpvWO1L7ofMwIAAECCKAAAAEgQBQAAAAmiAAAAIEHJXwlw/fr1QXbHHXcE2cknnxxkLVu2DLLYJMDY5K21a9dG+9PEJgGijJ5++umyt3niiScGWcIT/tCETJ48OcjOPvvsIFu6dGl0/dhVXZsaRgAAAEgQBQAAAAmiAAAAIEEUAAAAJCj5SYDNmoU10Omnnx5kvXv3DrJ//vOftdElIHrFsVInAX7nO98Jsl122aWkNoG6FrtS60033RRkl19+eZDtvvvuQfbqq6+Wp2ONECMAAAAkiAIAAIAEUQAAAJAgCgAAABKU/CTAQr355ptFrzf+VOUAABxSSURBVNu9e/cg22ijjUrpDpqQFStWBFnsCn2PP/54wW3uvPPOQfbEE08E2eabb15wm0Bd+8c//hFke++9d5B98cUXQTZgwIAg+81vfhNkPXr0KLJ3jR8jAAAAJIgCAACABFEAAACQIAoAAAASxCTAiObNmwfZrrvuGmRvv/12Qe116tQpyJgEiAqrVq0KsppM+IvZcccdg2ybbbYpqU2gHD777LNo/tOf/jTInnnmmYLaHDp0aJBdc801QRb7v+acK2gbTREjAAAAJIgCAACABFEAAACQIAoAAAASRAEAAECCkj8LYP369UF2zz33BNmmm25aUHuHHnpokMUuwbp69ero+rEzEICaOv/88+u7C0jMggULguyWW24JstGjR0fXj10Se//99w+yMWPGBFns8sAxsbMAzCy67E9+8pMga926dUHbaSwYAQAAIEEUAAAAJIgCAACABFEAAACQoOQnAcYuAzlw4MCi24tN+OvTp0+QbbLJJkVvA03L+++/X9L6vXr1CrI999wzyGLfrT5v3rwgu/rqq4Nss802C7Lzzjsv2p+OHTsG2bbbbhtdFg3f3Llzg+zhhx8OsmHDhgVZbGJfu3btott59NFHg+yAAw4IslImSq9du7bgZWPv0ZVNGGysGAEAACBBFAAAACSIAgAAgARRAAAAkKDkJwHGbL/99kH29ddfB9miRYsKam/UqFFBxhX/UOGqq64qaf2zzz47yEaMGBFksUlWs2bNKnq7f/jDH6L59773vSCLXantZz/7WZDtu+++QbbxxrxN1YY1a9YE2bhx44Js8ODBQVboZLjYxMCLL744umyhV1stxdSpU2t9G40JIwAAACSIAgAAgARRAAAAkCAKAAAAEsTsmojYVctiX3VZ6CTArl27ltwnYKuttormt99+e5C98sortd2dSr377rsFZXfeeWeQ7bTTTkH21FNPBdnWW29dZO/SU9lkzdhXRv/rX/8qqM2//vWvQRa7al9DE7saZuxqsJL085//vLa7U+8YAQAAIEEUAAAAJIgCAACABFEAAACQICYBRrz22mv13QUg0KNHj2j+0EMPBdmXX35Z293R5ZdfHs1jk/YWL15cUJszZ84MspdeeinIjjnmmILaS01sktuZZ54ZXXbVqlVBFrvCX2yS3Iknnhhkp5xySpDtt99+QdazZ89of2JfMRz7OuHu3bsX3Ga++++/P8gqu6phoW02ZowAAACQIAoAAAASRAEAAECCKAAAAEhQ8pMADz744PruAhIyd+7cIJs9e3ZB61Y2Wal9+/YFZeU2adKkaD5hwoQgi331L8rv29/+dpD9+c9/ji673XbbBdnrr78eZPfee2+Qvffee0F24403Btno0aODrLIr7xX6FcMxsTZLaU+Sdt111yCLXSX21FNPDbLTTz89yBri11ozAgAAQIIoAAAASBAFAAAACaIAAAAgQQ1vVkIde/PNN+u7C0hIbHJehw4dgmzWrFlB9tVXX0Xb/OKLLwpqs1mz8tb769ati+afffZZWbeD0hx44IEFL/vd7343yE444YSC1v3444+D7JFHHgmy2ERYSfrWt74VZN26dQuy2JUhY5MAY1/XHpvQWJlPPvkkyObMmRNkM2bMCLLYFRB33nnngrddVxgBAAAgQRQAAAAkiAIAAIAEUQAAAJAgCgAAABKU/FkALVq0CLLly5fXQ0+QgtatWwdZ27ZtC1r3jTfeiOaxS7o+8MADQdavX7+CthOb3X/rrbcG2dKlS6PrjxgxoqDtxHTt2jXI9thjj6LbQ93p0qVLkF1wwQVl307//v2LXjd2id7evXtHl917772D7JprrgmyP/3pT0HWqVOnmneuHjACAABAgigAAABIEAUAAAAJogAAACBBVtl3M2eq/GVT8OCDDwZZoZe+LFTs0qhbbbVVWbfRQJX2hdzl0eD34f/7v/8LstiEqppo1apVkG2//fYFrRt7T5g9e3ZJ/YmJTfh75plngiz2Hex1qCHsw1Ij2I8bg2XLlgXZ5ptvHl12wIABQTZx4sSy96mORPdjRgAAAEgQBQAAAAmiAAAAIEEUAAAAJCj5KwH27du3vruAxFU2CakUK1asCLJZs2aVfTuF2mGHHYLs6aefDrJ6nvCHJu6FF16o7y40KIwAAACQIAoAAAASRAEAAECCKAAAAEhQ8pMAmzULa6DY1cgOPPDAuugOEhSbBBj7Supf/vKX0fVvv/32svepEIceemg0HzhwYJAddthhQRb7Km6gNrVr167gZQ8//PBa7EnDwAgAAAAJogAAACBBFAAAACSIAgAAgAQlPwnQLPyWxJ49ewbZ3nvvXVB7c+bMCbIrr7wyyG644Ybo+htvnPxLkpzYPtiyZcsgGzNmTHT92NdXv/XWW0E2ZMiQIDvkkEOCb
Pjw4dHt5Ntjjz2ieZs2bQpaH6hrHTp0CLLY119XlTcljAAAAJAgCgAAABJEAQAAQIIoAAAASJBVM9Gh6c+CKNDKlSuD7OOPPw6y2GTB2Fezfuc734lu59prrw2yY489tpAuNkTh7La6xz6MUjSEfVhiPy6LZcuWBVllX8c9YMCAIJs4cWLZ+1RHovsxIwAAACSIAgAAgARRAAAAkCAKAAAAEsRl5woUuzJbjx49guySSy4JsthXu1566aXR7Xz22WdF9A4AUE6rVq0KsvXr1wdZ7CvlG4vG23MAAFA0CgAAABJEAQAAQIIoAAAASBAFAAAACeJSwKhNDeEyquzDKEVD2Icl9uOyqMmlgGOfjV999VWQtWnTpvSO1T4uBQwAADwKAAAAEkQBAABAgigAAABIEJMAUZsawgQq9mGUoiHswxL7MUrDJEAAAOBRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARVdyVAAADQBDECAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAokZntb2ZPmtl8M1tuZh+Z2Xgz26G++1YOZjbHzMbU4fZ2NLPnsufSmVnnWtjGoWY2JJKPN7PXCljfmdm55e5XXSr0sVbThpnZuWb2vpl9bWb/Z2Y3m1m7cvUTQO3ZuL470JiZ2f6Spkt6RNKZklZI2knSTyV1kjS33jrXeI2W1E7S0ZKWS/qsFrZxqKTjJN1U5Pq9JH1Svu7UiysktSqxjfPkn8Mr5P8fdJN0taSOko4psW0AtYwCoDQDJX0gqb/799cqPinpt2Zm9detRq2HpMecc0+X0kj2/Ldwzq0sT7f+zTn3UrnbNLNWzrkV5W63Ms652WVo5qeSpjjnLs/u/9XMWki60czaOOeWl2EbAGoJhwBK007Sly7yncq5WWzI2MxGmtmCnPunZcvtYWbTsyHVt7L7bczsbjNbYmYfm9mJVXUqW//BSD7azD6tKE7M7Boze9fMlpnZXDObYGbbFND25Lysd9b3XXKylmZ2nZn9w8xWmdnbZnZkFe12NjMnqaukoVl703N+f252eGWVmc0ys6F56480swXZIZlXJa2U1D+ynZGSfimpU7YNZ2bj85Y5xMzeyQ5DPG9mO+f9foPXM9vmc2a2NLu9ZWbBtvMfq5mdZGb3mtliSX/MfrdR9lg+zR7r+2b205x1D8rW3S4nm2Fm63KH3rPX9aoq+rDBIQAza2dm48xsnpmtzLZ/R2XrZ5pLWpKXLZZk2Q1AA0YBUJo3JB1kZpeZWZcytXmPpD9I6if/JjpZ0p2S5skPW78s6V6reo7BJElHmlmbiiD70D9e0gM5xclW8kO2P5I0RFIXSdPMrBz7xWRJp2Xt/1jSq5IeM7PdKln+M/mh9c8lTcz+PSjr+9mSbpb0WNbWg5KuN7NL8tpoLf/8jZN0uKRXItsZl7X/ebaNXvJD2BU6yh+GuErSifLP0aTKRnTMrK2kP0n6WP41O07SffLFYXXGSPpKvlC5Ost+I2m4pN/LHwZ5QdKEnKLvZUlrJP0w235rSXtKWi3pB1nWXtLOkp4roA8VbpC0v6Shkg6TNExSUNjmGSfpeDM70sw2M7PdJV0iabxzblkNtg2gPjjnuBV5k9RW0jT5N0on/yF9u6Ruecs5SefmZSMlLci5f1q23Kk52ZFZdldOtrn8B8DAKvrVQdJaSSfkZL2ytnpWss5GkrbPljkgJ58jaUzO/emSJuet2ztbb5fsfp/s/oF5yz0r6cFqntP87TWT9E9Jd+ctd6v8X58tc55PJ+mYAl63MZLmRPLx2fP23Zzs2KzdHrHXU1LP7P5mNdhvOmfrTMnL28vPe7g8L39c0oc592dIGpv9+2BJ8yXdL+maLDta0jpJbavow3hJr+Xcf0/SeUX8H/hVtq2K/wNTJDUv1/8xbty41d6NEYASOOeWyn/Y7Sf/F9xsSWdJesPM9iiy2dxj37Oyn9NytrlE/g1/+yr6NT9bZ0BOPEDSbOdc7rDvEWb2opktkf/gq5i02K3Ivlf4T/m/sF8ws40rbvKPrWcN29pB0nbyf/XnmiRfgH0vJ3OS/lxcl78xxzn3Uc79mTn9iJktaZmkiWZ2jNVsBvzUvPu7yI9ixB5rNzPrkN1/VtkIgKQDJD0v6Zm87O1s/yzUW5IuMrNBZlbQ65+NSlwm6deSDpR0hqS95EesADRwFAAlct4M59xw59wP5T/g1su/MRZjcc6/V0eyirxlNe3cL+kIM2ubDen3l/8gkSSZ2V7yQ+pzJZ0iP0Kwb/br6tquzpaStpEfqci9jZT07Rq2tW3284u8vOJ++5xskXNutUoTe66lSp4T59wiSYfIHw9/QNJ8M5ta4CGh/MdU6GN9TtIuWbHxw+z+c5J6mlnLnKwmzpU/m2WEpA+z+RYnVLZwtk/dLOm/nXP/5Zx71jl3t/zZMKeUUAADqCMUAGXmnHtL/kyAHjnxKkmb5C26RS13ZYr8HIJj5I/tbqecAkDST+RHEgY45x5zfmb75wW0u1LVP5aF8sP2e0Vu+6pmKk4D3Cov3zpnWxWqO2ZdK5xzLznnDpc/7t9XfgRlYiGr5t0v9LG+kP3sLf98PivpffmRiD6S9lANCwDn3GLn3PnOuW0k7So/12CCme1UySpbSvqW/MhBrjezn11rsn0AdY8CoARmlv9GXTHZrqs2/CturqQdc5ZpJv9GXWuyv0yfkB/6HyDpA+fcOzmLtJK0xjmX+yF0UgFNz9WGxY3kz6vP9bT8CMAy59xr+bcaPRC/vXkKZ/QfL2mppHdr2J5U2AhKjTnnVjjn/ijpLvnrQdTUe5K+Vvyx/j07tFPx2r4nP2FvnaQ3s9fxeflj8hur5iMA38j2k4vk3x/yX+sK87O+5v+lv2f2c06x2wdQN7gOQGnGZR/mD8kfC95C0unyf0HlvolPkTTYzN6Uny1+lvzx69o2Sf7DaImksXm/e1LSEDO7Sf4UtP0knVxAm1MknWlmN8ofwz5IfsZ9ftt/kfSkmV0r/9dpW0m7yU/au7TQB+CcW5+duvc7M/tX1vaB8tdgGOaKO8//b5K2NrPT5D9IFzjn5hTRjszsR/LHvh+R9Kn83IxzlDNvo1DOuYXZ6/FrM1sr6TX5EYUj5c9IyPWcpMGS/uKcW5eTjZb0kXMu/zBCdY/jefnX9j35kYmz5Sckxs6kkHPOmdnv5U/Z/Fp+VKKrpFGSXpL0ek22D6DuUQCU5lb52fsj5I/fLpb/sDvMOfdEznKj5Id1r5T/63NsttzgWu7fo/KT+7aUnxPwDefc42Z2sfzV3M6Wn1l+lKS/V9Wgc26qmQ2TP0XvrGwbv8h+VizjzKyv/KlkQ+RPrVsoP1x8c00fhHPujuzY9i+y21xJv3TO3VjTtjIPyBcu18mfMXGP/OtYjFnyH5hXy7/G8+VPCxxWZHsj5F+zgfJD/7Mkneycuz9vuYoC4Nm8TPIjATU1Q/456KxsVEHSEc65qq5meYmk
BfJzSC7Vvx/7r51z64voA4A6ZBuOAAMAgBQwBwAAgARRAAAAkCAKAAAAEkQBAABAgqo7C4AZgihFQ/hGOPZhlKIh7MMS+zFKE92PGQEAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSIAgAAgARRAAAAkKCN67sDdcU5F81/97vfBdnAgQODrFOnTkE2ZcqUINt9992L6B0AAHWLEQAAABJEAQAAQIIoAAAASBAFAAAACbLKJsdlqvxlY7J27dpovskmmxTdZo8ePYLs+eefD7L27dsXvY1Gzuq7A2pC+zDqRUPYhyX247K44oorgmzEiBHRZYcOHRpkN9xwQ9n7VEei+zEjAAAAJIgCAACABFEAAACQIAoAAAASlMyVAOfOnRvNN9100yB75ZVXgmzatGlBdt555wXZlVdeGWTXX399kJk1lLlFaCyWLl0azd9+++0gmzp1apDdd999QTZv3rzSO5YndtXMDz74IMhatWpV9m0DVXnqqaeCrFmz+N/BKbxHMwIAAECCKAAAAEgQBQAAAAmiAAAAIEHJTAKsbKLHsmXLguzvf/97kA0aNCjIPvnkkyC77bbbguyII44IskMOOSTaH0CK74O9e/eOLvv5558XvZ3Y/4vY1TFjE6JWrlwZbfPTTz8NstWrVwcZkwBRm2L/L2bNmlUPPWm4GAEAACBBFAAAACSIAgAAgARRAAAAkKBkJgFutNFGBS/78ssvB9nRRx8dZKNGjQqyvn37Blm/fv2CLHa1QUnaYYcdCukimrhtt902yM4///zosq+//nqQDRw4MMiaN28eZJtvvnmQxfbBhQsXBlm3bt2i/QEagtmzZwdZTSbM/uQnPylndxokRgAAAEgQBQAAAAmiAAAAIEEUAAAAJCiZSYBjx44teNlCr9LXunXrIPuP//iPIIt95XDsK4Il6YYbbgiyFL6WEhvabLPNguySSy6ph55406dPL3jZ7t27B1mLFi3K2BtgQ1988UWQHX/88QWt26dPn2i+1157ldSnxoARAAAAEkQBAABAgigAAABIEAUAAAAJSmYS4Jo1awpedscddyx6O9tss02QnXTSSUEWu4qgFL/61AEHHFB0f4Caevvtt4PsoosuKnj9Y445JshatmxZUp+AqixYsCDIYlf9a9u2bZBdccUV0TZTmLjKCAAAAAmiAAAAIEEUAAAAJIgCAACABFEAAACQoGTOAnjsscei+VZbbRVksZmipYhdWriyswB+9atfBdmLL74YZM2aUbuhdrzzzjtB9sknnwRZZTP7hw4dWvY+AVWZNGlSQcs9+uijQbbPPvuUuzuNBp8iAAAkiAIAAIAEUQAAAJAgCgAAABLUJCcBvvLKK0E2e/bs6LIDBw4MslatWpW1P7169QqyyiYBXn755UH26quvBlnKE1dQPitXrgyyK6+8sqB1K/t/svXWW5fUJ6AqsUv83nHHHQWt26VLl3J3p1FjBAAAgARRAAAAkCAKAAAAEkQBAABAgprkJMDYlfOcc9FlBw0aVNvdkZkFWWVXS7vrrruC7KyzzgqyN954I8iaN29eRO+QstGjRwfZRx99VNC6lU28mjlzZpD98Y9/DLKDDjooyPbee++Cto103XrrrUH25ZdfBlmfPn2CbMstt6yVPjVWjAAAAJAgCgAAABJEAQAAQIIoAAAASFCTnAT47rvvBtnGG8cf6vbbb1/b3YnadNNNo/mZZ54ZZCNGjAiyf/zjH0HGVa5QU2vWrCl63fPPPz+az5s3L8hat24dZKecckrR20YaVq1aFWTPP/98Qevut99+QVbZV1inihEAAAASRAEAAECCKAAAAEgQBQAAAAlq9JMAFy1aFGT33HNPkB1++OHR9TfffPOy96kUXbt2DbLYlQQfe+yxIBsyZEit9AmNz9q1a4PspZdeCrL/+q//Knobsa9llaRTTz01yGJfc73ddtsVvW2k4eWXXw6yZ555Jsg6dOgQZIMHD66VPjUljAAAAJAgCgAAABJEAQAAQIIoAAAASJBV9jW5mSp/2RA88sgjQda3b98gmzhxYnT9E044oex9KrfYVdRiVzb87LPPgqxNmza10qcChbMX616D34djFi9eHGSxr9SVpHHjxgXZ8uXLgyz2FdKF6tixY5Dddttt0WWPOOKIorfTADWEfVhqpPtxTcyfPz/IdtpppyBbuHBhkF122WVBNnLkyLL0q4mI7seMAAAAkCAKAAAAEkQBAABAgigAAABIUKO/EmChttpqq/ruQtH69esXZBMmTAiy2NXf0DjdcccdQXbxxRfXybaPO+64ILv99tuDrH379nXRHSRi5cqVQRab8Lf11lsH2cCBA2ulT00dIwAAACSIAgAAgARRAAAAkCAKAAAAEkQBAABAgpI5C6Axu+SSS4IsdhYAmo7NNtssyGKz8yVp1KhRQbZ69eog23333Qva9ve///0gY8Y/alvsTJOYBx54IMhiZwageowAAACQIAoAAAASRAEAAECCKAAAAEiQOVfl10w3+O+gnj59epAdfPDBQTZo0KDo+mPHji13l8ou9t3wsUlZixYtCrLNN9+8VvpUoIbwXeoNfh+uDTNnzgyyXXbZJchi+8fs2bODLOFJgA1hH5aa0H780ksvRfPDDz88yHbdddcge+qpp4KsefPmpXesaYvux4wAAACQIAoAAAASRAEAAECCKAAAAEhQo78SYK9evYKsY8eOQfa73/0uuv7ll18eZB06dCi9Y2X0+9//PsiOOeaYIItdPQ5pGjFiREHLnXPOOUGW8IQ/1IH/+Z//ieZfffVVkLVp0ybImPBXPowAAACQIAoAAAASRAEAAECCKAAAAEhQo58E2KJFiyCLfa3kUUcdFV3/2GOPDbLHH388yMp9Rb3Y17VK8auwjRw5MshefPHFIGvWjHoO3rRp0wpa7gc/+EEt9wQoTOwrfW+99dZ66Ek6+MQAACBBFAAAACSIAgAAgARRAAAAkKBGPwkwJva1krfddlt02diV0Dp37hxkffv2DbK99tqroP7MnTs3yCZOnBhdds6cOUEWm5S42267FbRtAKgvCxcuDLK77roruuw+++wTZLH3YpQPIwAAACSIAgAAgARRAAAAkCAKAAAAEtQkJwHGnHXWWdH8+9//fpBNmDAhyGJfJ3z33XcX3Z+hQ4dG80GDBgVZly5dit4OUJUDDjigvruAJuzaa68NslWrVtVDTxDDCAAAAAmiAAAAIEEUAAAAJIgCAACABJlzrqrfV/lLoBpW3x1Qovtw+/btg2zRokVBtm7duiDja6U30BD2YamR7scXXHBBkE2aNCm67Ouvvx5k22yzTdn7lKjofsz/dAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBnAWA2tQQZlAnuQ/feeedQbZgwYIg+9WvfhVkZg3hZWswGsqTkeR+jLLhLAAAAOBRAAAAkCAKAAAAEkQBAABAgpg
EiNrUECZQsQ+jFA1hH5bYj1EaJgECAACPAgAAgARRAAAAkCAKAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACSouisBAgCAJogRAAAAEkQBAABAgigAAABIEAUAAAAJogAAACBBFAAAACTo/wO/xBXFI1slCQAAAABJRU5ErkJggg==\n", 531 | "text/plain": [ 532 | "
" 533 | ] 534 | }, 535 | "metadata": { 536 | "needs_background": "light" 537 | }, 538 | "output_type": "display_data" 539 | } 540 | ], 541 | "source": [ 542 | "import matplotlib.pyplot as plt\n", 543 | "%matplotlib inline \n", 544 | "\n", 545 | "n_sample_viz, n_images = 3, 3\n", 546 | "\n", 547 | "fig, axes = plt.subplots(nrows=n_sample_viz, ncols=n_images, figsize=(9.0, 9.0))\n", 548 | "\n", 549 | "for sample_idx in range(n_sample_viz):\n", 550 | " for im_idx in range(n_images):\n", 551 | " axes[sample_idx, im_idx].imshow(X_train_data[im_idx][sample_idx][:, :, 0], cmap='Greys')\n", 552 | " axes[sample_idx, im_idx].axis('off')\n", 553 | " if im_idx==0:\n", 554 | " axes[sample_idx, 0].set_title(' Sum value for this row is {}'.format(y_train_data[sample_idx]), \n", 555 | " fontsize=15, loc='left')" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": 111, 561 | "metadata": {}, 562 | "outputs": [ 563 | { 564 | "name": "stdout", 565 | "output_type": "stream", 566 | "text": [ 567 | "Train on 60000 samples\n", 568 | "Epoch 1/3\n", 569 | "60000/60000 [==============================] - 77s 1ms/sample - loss: 1.0678\n", 570 | "Epoch 2/3\n", 571 | "60000/60000 [==============================] - 70s 1ms/sample - loss: 0.5569\n", 572 | "Epoch 3/3\n", 573 | "60000/60000 [==============================] - 71s 1ms/sample - loss: 0.4641\n" 574 | ] 575 | }, 576 | { 577 | "data": { 578 | "text/plain": [ 579 | "" 580 | ] 581 | }, 582 | "execution_count": 111, 583 | "metadata": {}, 584 | "output_type": "execute_result" 585 | } 586 | ], 587 | "source": [ 588 | "# First, define the vision modules\n", 589 | "from tensorflow.keras.layers import Dense\n", 590 | "from tensorflow.keras.layers import Input\n", 591 | "from tensorflow.keras.models import Model\n", 592 | "from tensorflow.keras.layers import Conv2D\n", 593 | "from tensorflow.keras.layers import MaxPooling2D\n", 594 | "from tensorflow.keras.layers import Flatten\n", 595 | "from tensorflow.keras.layers import Dropout\n", 596 | "from tensorflow.keras.layers import Add\n", 597 | "from tensorflow.keras.optimizers import Adam\n", 598 | "\n", 599 | "filters = 64\n", 600 | "kernel_size = 3\n", 601 | "\n", 602 | "import tensorflow as tf\n", 603 | "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", 604 | "x_train = [np.expand_dims(t, axis=2) for t in x_train]\n", 605 | "x_test = [np.expand_dims(t, axis=2) for t in x_test]\n", 606 | "\n", 607 | "input_image = Input(shape=(28, 28, 1))\n", 608 | "\n", 609 | "y = Conv2D(32, kernel_size=(3, 3),\n", 610 | " activation='relu',\n", 611 | " input_shape=input_shape)(input_image)\n", 612 | "y = Conv2D(64, (3, 3), activation='relu')(y)\n", 613 | "y = MaxPooling2D(pool_size=(2, 2))(y)\n", 614 | "y = Dropout(0.25)(y)\n", 615 | "y = Flatten()(y)\n", 616 | "y = Dense(32, activation='relu')(y)\n", 617 | "y = Dense(16, activation='relu')(y)\n", 618 | "output_vec = Dense(1)(y)\n", 619 | "\n", 620 | "vision_model = Model(input_image, output_vec)\n", 621 | "vision_model.compile(loss='mae')\n", 622 | "vision_model.fit(np.array(x_train), np.array(y_train), epochs=3, batch_size=64)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 2, 628 | "metadata": {}, 629 | "outputs": [], 630 | "source": [ 631 | "# vision_model.save('vision_model.h5')" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 1, 637 | "metadata": {}, 638 | "outputs": [ 639 | { 640 | "data": { 641 | "application/vnd.jupyter.widget-view+json": { 642 | "model_id": 
"7581c681606f4d65ba471a048804c4a1", 643 | "version_major": 2, 644 | "version_minor": 0 645 | }, 646 | "text/plain": [ 647 | "HBox(children=(FloatProgress(value=0.0, max=964.0), HTML(value='')))" 648 | ] 649 | }, 650 | "metadata": {}, 651 | "output_type": "display_data" 652 | }, 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "\n" 658 | ] 659 | } 660 | ], 661 | "source": [ 662 | "import os\n", 663 | "import numpy as np\n", 664 | "import matplotlib.pyplot as plt\n", 665 | "%matplotlib inline\n", 666 | "from tqdm.notebook import tqdm\n", 667 | "\n", 668 | "\n", 669 | "# Get list of different character paths\n", 670 | "img_dir = './images_background'\n", 671 | "alphabet_names = [a for a in os.listdir(img_dir) if a[0] != '.'] # get folder names\n", 672 | "char_paths = []\n", 673 | "for lang in alphabet_names:\n", 674 | " for char in [a for a in os.listdir(img_dir+'/'+lang) if a[0] != '.']:\n", 675 | " char_paths.append(img_dir+'/'+lang+'/'+char)\n", 676 | "\n", 677 | "char_to_png = {char_path: os.listdir(char_path) for char_path in tqdm(char_paths)}" 678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 2, 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "from typing import List\n", 687 | "\n", 688 | "def draw_one_sample(char_paths: List, sample_size=6):\n", 689 | " n_chars = np.random.randint(low=1, high=sample_size, size=1)\n", 690 | " selected_chars = np.random.choice(char_paths, size=n_chars, replace=False)\n", 691 | " rep_char_list = selected_chars.tolist() + \\\n", 692 | " np.random.choice(selected_chars, size=sample_size-len(selected_chars), replace=True).tolist()\n", 693 | " sampled_paths = [char_path+'/'+np.random.choice(char_to_png[char_path]) for char_path in rep_char_list]\n", 694 | " return sampled_paths, n_chars[0]\n", 695 | "\n", 696 | "sampled_paths, n_chars = draw_one_sample(char_paths, sample_size=6)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 3, 702 | "metadata": {}, 703 | "outputs": [ 704 | { 705 | "name": "stdout", 706 | "output_type": "stream", 707 | "text": [ 708 | "Number of selected characters is 4\n" 709 | ] 710 | }, 711 | { 712 | "data": { 713 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAV0AAAD3CAYAAAC+eIeLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAT0ElEQVR4nO3df6xkZX3H8fcXF1gWWHcXFaGwWH9ACgY3JIZQWiC1oZCWiBJjJIUuKNU2FqU2EU2sJpZUiQmaWII2DVo0IEYFakGIGmMwJTQ2VULAqGuE5Ycuslt+qbjw9I9zRg/D/Lr3zjzznHPer2Ry750zZ86P7zyf88xzzsyNlBKSpDz2WfYKSFKfGLqSlJGhK0kZGbqSlJGhK0kZGbqSlFHW0I2I7RFxe85lavGsazdZ18XoZU83Ik6LiJ1rfI6jI+LGiNgVEY9GxK0RcUxj+lUR8UTj9uuIeHzta69x5lHX+nk+HRE/iIhnI2L70LS/iojvRsRjEbEzIi6PiHVrXabGm2Ndt9W1e6r+ua0x7ZKI2FHX9cGIuGJRdW1t6C7zhV4vexNwE3AMcChwJ3Dj4DEppXeklA4a3IBrgS8uY33bpIC6AnwP+Fvgf0Y8bAPwbuBFwInA64B/yLKCLbbsukbEflTt83PAZuCzwI31/VC15RNSShuBVwOvAS5eyAqllBZyA44EvgzsAn4BfBLYDtwOfAzYDfwEOLMxzwXAPcDjwA7g7Y1ppwE7gfcCDwPX1Dvvq/Uydte/H9GYZwtwNfBgPf0G4EDgl8CzwBP17XCqA9ClwI/r9b0e2FI/z8uABLwVuA/49ojt3VI/5pAR0w6st+nURe3vXLc+1bXepu1T9sffA/+x7LpY18l1BU4HHgCisbz7gDNG7ItDgK8DVy5iXy+kpxsRL6h36E/rHfB7wHX15BOBH1D1FC4H/i0iop72c+AvgI1UBb0iIk5oPPVLqQpzFPDXVDv+6vrvrVTF+WTj8ddQ9UyOA14CXJFSehI4E3gw/a4n+iDwd8DZwKlURd0N/MvQpp0K/AHwZyM2+xTg4ZTSL0ZMO4fqhfbtEdNao6d1neYU4O5VzFeMntT1OOD7qU7V2vfr+wf74dyIeAx4hKqn+6kpu251FnTUPIkqZNYN3b8d+FHj7w1UR6SXjnmeG4B3NY6cTwPrJyx3G7C7/v0wqqPj5hGPOw3YOXTfPcDrGn8fBvwGWMfvjpwvH7PcI6iOom8ZM/0bwIcWsa9z3npY14k9XeBCqt7ci5ZdG+s6ua7AB4Drhp7j86PaJfAq4MPjtnOtt0WNsxwJ/DSltHfEtIcHv6SUnqoPmgcBRMSZwAeBo6mOihuAuxrz7kop/WrwR0RsAK4AzqB66wJwcH3kPhJ4NKW0e8Z1Pgr4SkQ827jvGarx2oH7h2eKiBcDt1G9Fbl2xPStVC+ai2Zcj5L1pq7TRMTZwD8Df5pSemSl8xemD3V9gqpH3rSRamjkOVJKP4yIu4ErgTfOuD4zW9SJtPuBrSsZPI+I/YEvUY0fHZpS2gTcDETjYcNfifYeqhNZJ6ZqAPyUwdPV67AlIjaNWNyor1a7n2q8alPjtj6l9MC4+SJiM1Xg3pRSumzMpp0HfCeltGPM9DbpRV1n2KYzgH8Fzkop3TXt8S3Qh7reDRzfGBoBOJ7xQ0PrgFeMmbYmiwrdO4GHgI9ExIERsT4iTp4yz37A/lRvc/bWR9HTp8xzMNW40J6I2EJ11AUgpfQQcAtwZURsjoh9I2JQ5J8Bh0TECxvPdRVwWUQcBVUPNiJeP27BEbERuJUqUC+dsI7nA5+Zsh1t0fm61o/ZLyLWU4XBvvV27lNP+xOqt6XnpJTunLIdbdGHun6Lqid8cUTsHxHvrO//Zj3/2yLiJfXvxwLvoxoWnLuFhG5K6RngLOCVVGcIdwJvnjLP41SXaFxPNSh+LtVlHJN8HDiAauD7DuBrQ9PPoxrnuZdq0P/d9bLupbqEa0dE7ImIw4FP1Mu7Larrae+gOokwzhuA1wIXxHOvx906eEBEnEQ13tuJS8V6Uleo3r38EvhD4NP174MA+ADwQuDmRs1vmfJ8RetDXVNKT1OdeDsf2EM1Hn92fT/AycBdEfEkVY/9ZuD9U7ZnVaIeOJYkZdDaD0dIUhsZupKUkaErSRkZupKUkaErSRlNuxjaSxvKEdMfMjPrWo551hWsbUlG1taeriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlZOhKUkaGriRlNO37dNUQ8buvx/S/KEtaDUN3imbQDt9v8Epaqd6G7rgwlaRF6mToGqiSStX60DVgJbVJ60K3lJB1PLe9Rr2GrKdyKSp0SwnUYTZISfNSVOgug4EqKadOhO6swVlqT1p5WH+VoJWha++0bMPhtux6GbYqSVGhu+zGqW4xbFWiokJXWiuDVqUzdLVwi75Ey6BVmxi6aqV5B63fpaFcDF0tRe7e6XCgjlq+wascDF3NXUqpiLf8BqhK5JeYq5OmBe646SUcLNRthq4WIqX021upSl43dZfDC2q1tQ5lDM9vEJeli19OZOhqKUpqOCWtiyqTDqRtP+Hp8IKkoszyzqXNY+/2dCUtzCzhOOi1rjRI29rjNXTVSSttwG1svKVaTXiuZVltq53DCxLtfrtakkXsxzZcCbMSvQndLp4FlUoy78AdFbRdaLO9CV1pGnu7ecwSnF0I13FaM6Y7rkF0uThavWmvCwN2uUr5qPgytL6n2/XLS7QYHqy1LK3p6U5S2r+HkaRxWt/THWUl1wZKmo9ZrjBoTu9rG+xET3eUQfD2tbDSstjmJmtN6M7yJdSjOJ4rqSStCd1hq/3ooMowz3F4e1bd0vV6tjZ0B/p86UnJ+vI5eq1NH2veiRNpXfqIoKRua31Pt2mWXq/hvFi+65Am60RPt2lSqBq4i7XawPWdivqkUz3daRw3XJyVBK41UJ91MnQnDTMYvMvhPu8Ph5gm69zwwsCkRu6LYv5G7e+SvwfV18BiuF+n62zogr2r3NryEU+/sW4xDNzZdHJ4oWnchyj8mPBitHV/tnW9S7GMwG1rzTofutKAPbG82hqKi9bp4YWmcS8AG6IMB+XUm9AFG1dfRYQH1yUocZ+XsE4OL+BlZF3mJxSXa9L+n3Xfzzsol93eexe6fkFOfxi4+aymXfW1HfZqeGFgVGPr6wugqwzc/Nyns+ll6I5j8HaDgbs87tvpehu6k65mMHzba9b/Dm2NF6fUTyGWondjuuouvzi9LNP2bV+/JCmmbEx3tnSMeZxdzWSeXbOiNmxe1vLVkks07y53J2vbUiNr29vhhYFxDa6wwNUUowLXGqpEDi/gP7lsu1nfrVhflcDQbbBn1G1eo60S9H54Qe1miKptDF21VotOgkq/5fCCOsWwVens6ao3HIpQCQxddcq4T5sZuCqFwwvqpFlD1uEI5WZPV6211sA0cLUMhq5azeBU2xi6ar3VBK9hrWVxTFedMOtHuQ1bLZuhq0
4xVFU6hxckKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyCr/0WZLysacrSRkZupKUkaErSRkZupKUUbbQjYjtEXF7ruUpH2vbTdZ1MXrX042I0yJi5xqf4+iIuDEidkXEoxFxa0Qc05geEfFPEfFARPxfRHwrIo5b+9prknnUtn6ebRHx3Yh4qv65rTHtkojYERGPRcSDEXFFRKxb6zI1XqY2e1VEPNG4/ToiHl/72j9fK0N3mS/yetmbgJuAY4BDgTuBGxsPexNwIfDHwBbgv4Br8q5pOy27thGxH1UtPwdsBj4L3FjfD1XdT0gpbQReDbwGuHgZ69smy64rU9psSukdKaWDBjfgWuCLC1mhlNLcb8CRwJeBXcAvgE8C24HbgY8Bu4GfAGc25rkAuAd4HNgBvL0x7TRgJ/Be4GGqANsMfLVexu769yMa82wBrgYerKffABwI/BJ4Fniivh1OdfC5FPhxvb7XA1vq53kZkIC3AvcB3x6xvVvqxxxS//1e4PrG9OOAXy1iX+e+db22wOnAA9TXsNePuw84Y8S+OAT4OnDlsutiXdfWZoemHVhv06kL2dcLKN4LgO8BV9Qrvx74o7qAvwEuqh/zN/XOHXxA48+BVwABnAo8RdWjGBRwL/BRYH/ggPoFfw6wATiY6qh0Q2M9/hP4Ql3ofQc7cPBiGFrndwF3AEfUz/8p4NqhAv57vT0HjNjms4GHGn8fBXwXOLpe9uXNdWvrrQ+1BS4Bbhl6jq8C72n8fS7wWD3vLuA1y66NdV1bmx2adj7VQSRWsz+n7u8FFPCk+oW4buj+7cCPGn9vqHfMS8c8zw3Auxo7/Wlg/YTlbgN2178fRnVk3DzicaMKeA/wusbfh9UvtnWNAr58zHKPoOoZvaVx337AJ+r59lL1EH5/2Y3L2k6vLfAB4Lqh5/g88KERy3sV8OFx29mWWx/qOjTv89rs0PRvjKr3vG6LGGc5EvhpSmnviGkPD35JKT0VEQAHAUTEmcAHqXqH+1AV+K7GvLtSSr8a/BERG6iOzGdQHRkBDo6IF9Tr8GhKafeM63wU8JWIeLZx3zNUYz8D9w/PFBEvBm6jent5bWPSPwKvrdfjYeAvgW9GxHEppadmXKcS9aG2TwAbh55jI9XbzedIKf0wIu4GrgTeOOP6lKgPdR2sw7g2O5i+lSrkL5pxPVZsESfS7ge2rmTgPCL2B75ENXZ0aEppE3Az1duWgeEviXgP1aD4iak6qXHK4OnqddgSEZtGLG7Ul03cTzVWtalxW59SemDcfBGxmap4N6WULht6vm3AF1JKO1NKe1NKn6F6kR07avtbpA+1vRs4Pup0qR1f3z/KOqq32G3Wh7pOa7MD5wHfSSntGDN9zRYRuncCDwEfiYgDI2J9RJw8ZZ79qMZldgF76yPo6VPmOZhqgH1PRGyhOuICkFJ6CLgFuDIiNkfEvhExKPDPgEMi4oWN57oKuCwijoLqaBgRrx+34IjYCNxKVZxLRzzkv4E3RcShEbFPRJxHNUb1oynbVLrO1xb4FlWP6eKI2D8i3lnf/816/rdFxEvq348F3kf1drTNOl/XGdrswPnAZ6Zsx5rMPXRTSs8AZwGvpDpzuBN485R5Hqe67OZ6qrOW51Jd3jHJx6kG5x+hGlD/2tD086jGeO4Ffg68u17WvVSXg+yIiD0RcTjV+OtNwG31tXl3ACdOWPYbqIYPLhi6tm9rPf2jVCcm/hfYQ3Vy5pyU0p4p21S0PtQ2pfQ01UmW86lqdyFwdn0/wMnAXRHxJFXP7mbg/VO2p2h9qCvT2ywRcRLVeO9iLhUbLKceOJYkZdDKD0dIUlsZupKUkaErSRkZupKU0bTr8jzLVo6Y/pCZWddyzLOuYG1LMrK29nQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKaNF/GPKYj33315V/BJ3STl1KnRHhSoYrJLK0anQHWdcGEtSbo7pSlJGhq4kZdT70HXoQVJOnRrTbZ4wM0wllaj3PV1JyqmzoetlYpJK1NnQlaQSdWpMd6XsDUvKrdOhm1LyhJq0JG1oe8voeDm8IEkZdT50HUKQVJLOhy4YvJKeb1m50Okx3WkiwkCWFmS1bWvSWHAX2msverrQjWJJfdeGk3PT9CZ0JbVDSqnTnaTeh24XjpxSF40L3ra32V6FbpePnlIXdbHN9ip0JWnZDF1JRetab7d3odu1Akp91OZx3V5fp6tuGNUAPbh2S5e+R6V3PV11R0SMbYhdaaDqHkMXG2gbzVIz69ptba1vL4cXuvRWpQtmqcVguGCldfOj3t3RlXbby9DV8q0mPNeyLINXpXB4odaFI2hbLGJfDz46ariqdL0N3VGN0+BdvHnv41FBa/B2U1faZ29DF2ycJZulNtavfQZXnEy68mRWba1/r0N3lK4cTbugrY1K89eldllE6M7jqLdaNmxpeVbb7tvcbpceus2d3qWjmaTZ9K3dLz101S+zXGHQnN7mHo1mNyl4uxbKXqc7gtd1Lp77t58mfcBh0O66FrLD7OlKKkbXAxcMXbWUPeX2Wmvt2l57hxdUvLY3Mj1fH4YRxrGnK2kpVnow7crHvA1dSUsz6ycPuxC2A4aupKWaFqpdG4YwdCUVoS/B64k0LUWXGpHmZ5breNvOnq6yM3A1Sdd7vMWFbhd2qsazvppFl4O3uNBVdy2jsXTtzHefTKpdm4PXMV0tlYGoacaN87Z1jNeerqTidanHa+hqqUpsNCWuk7oTvEUOL7T1bYNWZ1KjmfV1MO+G52uwTF0YaogpK5plK8Y1mFw7cdTyCyzgPFNlaRvXpl5JptfAvHdIcS/cRVh2Zsxo5Eo6vKCsCmsU6pg2HNSLCF0bYr9Yb61Vm6/jLSJ0YfRO9D8Ed5fXz2qRSg7eIk+k5VRycfpgWvCupD6GeL8M6j3puxqajytF8aGb+6xkaQXqO+uhaab9F4rSrmwoZngB8l+HZy9X6oZZ3jGV0t6LCt2cWnKZmKQZtaX9Fhe6i+ztDo52pRzxJM1XG07QFhe6MN8j1qxBW3qhJM2u5PZc/Im0pkX1UEsukKTVaV7dUFIbL7KnC3mCsA1vRSStTWltvFU93XkprQiS+qPo0J12/d1KnkeSSlB06IKBKalbih3TlaQuMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyMnQlKSNDV5IyCv/briTlY09XkjIydCUpI0NXkjIydCUpI0NXkjIydCUpo/8H8O60oq/Odh8AAAAASUVORK5CYII=\n", 714 | "text/plain": [ 715 | "
" 716 | ] 717 | }, 718 | "metadata": { 719 | "needs_background": "light" 720 | }, 721 | "output_type": "display_data" 722 | } 723 | ], 724 | "source": [ 725 | "from matplotlib.figure import Figure\n", 726 | "from numpy import ndarray\n", 727 | "from typing import List\n", 728 | "import matplotlib.image as mpimg\n", 729 | "\n", 730 | "def render_chart(fig: Figure, axis: ndarray, image_id: int, data_list: List):\n", 731 | " image = mpimg.imread(data_list[image_id])\n", 732 | " axis.title.set_text(data_list[image_id].split('/')[-2])\n", 733 | " axis.axis('off')\n", 734 | " axis.imshow(image, cmap='gray')\n", 735 | "\n", 736 | "print('Number of selected characters is {}'.format(n_chars)) \n", 737 | "\n", 738 | "fig, axs = plt.subplots(2, 3)\n", 739 | "render_chart(fig=fig, axis=axs[0, 0], image_id=0, data_list=sampled_paths)\n", 740 | "render_chart(fig=fig, axis=axs[0, 1], image_id=1, data_list=sampled_paths)\n", 741 | "render_chart(fig=fig, axis=axs[0, 2], image_id=2, data_list=sampled_paths)\n", 742 | "render_chart(fig=fig, axis=axs[1, 0], image_id=3, data_list=sampled_paths)\n", 743 | "render_chart(fig=fig, axis=axs[1, 1], image_id=4, data_list=sampled_paths)\n", 744 | "render_chart(fig=fig, axis=axs[1, 2], image_id=5, data_list=sampled_paths)" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 4, 750 | "metadata": { 751 | "scrolled": true 752 | }, 753 | "outputs": [ 754 | { 755 | "data": { 756 | "application/vnd.jupyter.widget-view+json": { 757 | "model_id": "59fff385c8aa44cf8f9b61383551fdfe", 758 | "version_major": 2, 759 | "version_minor": 0 760 | }, 761 | "text/plain": [ 762 | "HBox(children=(FloatProgress(value=0.0, max=60000.0), HTML(value='')))" 763 | ] 764 | }, 765 | "metadata": {}, 766 | "output_type": "display_data" 767 | }, 768 | { 769 | "name": "stdout", 770 | "output_type": "stream", 771 | "text": [ 772 | "\n" 773 | ] 774 | }, 775 | { 776 | "data": { 777 | "application/vnd.jupyter.widget-view+json": { 778 | "model_id": "872f05b8d043413dadfbe238bbd33fb9", 779 | "version_major": 2, 780 | "version_minor": 0 781 | }, 782 | "text/plain": [ 783 | "HBox(children=(FloatProgress(value=0.0, max=15000.0), HTML(value='')))" 784 | ] 785 | }, 786 | "metadata": {}, 787 | "output_type": "display_data" 788 | }, 789 | { 790 | "name": "stdout", 791 | "output_type": "stream", 792 | "text": [ 793 | "\n" 794 | ] 795 | } 796 | ], 797 | "source": [ 798 | "train_size, test_size = 60000, 15000\n", 799 | "\n", 800 | "train_dataset = [draw_one_sample(char_paths, sample_size=6) for i in tqdm(range(train_size))]\n", 801 | "train_X, train_y = [i[0] for i in train_dataset], [i[1] for i in train_dataset]\n", 802 | "\n", 803 | "test_dataset = [draw_one_sample(char_paths, sample_size=6) for i in tqdm(range(test_size))]\n", 804 | "test_X, test_y = [i[0] for i in test_dataset], [i[1] for i in test_dataset]" 805 | ] 806 | }, 807 | { 808 | "cell_type": "code", 809 | "execution_count": 5, 810 | "metadata": {}, 811 | "outputs": [], 812 | "source": [ 813 | "import tensorflow as tf\n", 814 | "from typing import List\n", 815 | "from tensorflow import convert_to_tensor\n", 816 | "AUTOTUNE = tf.data.experimental.AUTOTUNE\n", 817 | "\n", 818 | "\n", 819 | "@tf.function\n", 820 | "def load_image(file_path):\n", 821 | " image = tf.io.read_file(file_path)\n", 822 | " return tf.image.decode_png(image, channels=1)\n", 823 | "\n", 824 | "@tf.function\n", 825 | "def load_image_list(image_list: tf.Tensor):\n", 826 | " return tf.cast(tf.map_fn(lambda x: load_image(x), image_list, dtype=tf.uint8), 
tf.float32)\n", 827 | "\n", 828 | "\n", 829 | "class SetDataGenerator:\n", 830 | " def __init__(self, X, y):\n", 831 | " self.X = X\n", 832 | " self.y = y \n", 833 | " self.dataset = None\n", 834 | "\n", 835 | " def generator_init(self, shuffle_buffer_size=500, repeat=-1, batch_size=64): \n", 836 | " \"\"\"\n", 837 | " :param repeat, -1 is the default behaviour to repeat the dataset indefinitely\n", 838 | " \"\"\"\n", 839 | " self.dataset = tf.data.Dataset.from_tensor_slices((self.X, self.y))\n", 840 | " self.dataset = self.dataset.map(lambda x, y: (load_image_list(x), tf.cast(y, tf.float32)))\n", 841 | "\n", 842 | " if shuffle_buffer_size == 0:\n", 843 | " self.dataset = self.dataset.repeat(repeat).batch(batch_size).prefetch(buffer_size=AUTOTUNE)\n", 844 | " else:\n", 845 | " self.dataset = self.dataset.shuffle(shuffle_buffer_size).repeat(repeat).batch(batch_size) \\\n", 846 | " .prefetch(buffer_size=AUTOTUNE)\n", 847 | "\n", 848 | " @property\n", 849 | " def batch(self):\n", 850 | " return self.dataset" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": 6, 856 | "metadata": {}, 857 | "outputs": [ 858 | { 859 | "name": "stdout", 860 | "output_type": "stream", 861 | "text": [ 862 | "Number of unique characters is : 5.0\n" 863 | ] 864 | }, 865 | { 866 | "data": { 867 | "text/plain": [ 868 | "(-0.5, 104.5, 104.5, -0.5)" 869 | ] 870 | }, 871 | "execution_count": 6, 872 | "metadata": {}, 873 | "output_type": "execute_result" 874 | }, 875 | { 876 | "data": { 877 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAJZUlEQVR4nO3d0Q7qthIF0KTq//8yfaioKIdACPZ4PF5Lug/VaTm5sb0zTEyy3263DYAYf40+AICVCF2AQEIXIJDQBQgkdAECCV2AQH9/+HP7yfLYG36Wcc2j5bhum7HN5OXYqnQBAgldgEBCFyCQ0AUIJHQBAgldgEBCFyCQ0AUIJHQBAn36RRqktu/f/aDLQ/sZTegyhW/D9dfPEc70or0AL7QKeXgmdAEChbUXXlUOvsIBq9HTpYRPF3A33MhC6DKFewhe/cYkRMlC6AJpVWxLCl2gm4qh+auhoVttW87qkwn4zJaxhvZ9L3chAdoKC11VIKxFAfKaShdoTuAeC+3pvtv2U4WKnpV9WtvWx6Abaa1PvDukMJawPU97AfiJwP2OfbrAZe8CV9i+ptIFLhG41whdplH5BuxsBO51Qhf4isD9jdBlahZ5HsbiHDfSGOJMq8Aizudo3IzVeUKXtM4E877vXRb8498tULQUWhK6TM8Ntr4Eblt6usAhgdue0AVeErh9aC8wROsXSfayargI3H66h66H0XDFqzkSHcSrzlOB25dKl6kJgbYEbn9CF9i2LUfgZmkr9TQkdO8n1pUTchj9o4cVwvZu6O6FlU40zCYicFd8mWv30J3lLjWsasQaXDFs71L0dP3kEnLpuQ6/ee5GxWAOCd3HAfx0Elud5FGD5aLBTCLXydm/q/oaCu/pVj+hFa/M1BR58+xsdVs9H7Zt8NuABRSwQtA+GtrTvd1ughcSaR2AGfb+ZjP8Rlq1qnfViQTPRu/9zWp46N79MhCe7wDMwqMdAQIJXVhQxLfD589bZXfCJ2naC0CMyPsnQvZPKl2AQEIXUJEGEroAgYQukFaG1za1JnQBAgldIHX1WK3fbMsYsG3bn8GbOez2fU99fO8IXeClTM9OqPRwrLLthSoDBK39Gpr3V+1YY9eUDd1tE7xwpNVPckcG8Kzru0ToztrbgdFarp3HAO4RiFXWeZme7lHPZ+aGO0TotRc24qE6M67vMqG7bbWa7TDSuyD7ZY39GpKv1vj9n2cJ31KhC/SX9W0vs2x5WyJ0Z/wKAtk9r6moED77jfbb44nKiCVCd9v+PwACGNo7E8LWXsHQPXMVFMDQX8+1lbXFcUaJLWO/mHHQgH/d9xvPVDyVq3S37ftdDJ/+3ZkGFFb1uE6/Wf/R67tk6G5b268f7z5DIEM+mddl2dC9Ozr52grACMv2dDNfCYG6yle675wJXq0FoKWlQ/cMwQq0tGx7AWAEoQsQSOgCBBK6AIGELkAgoQsQSOgCBBK6AIGELkAgocu0/FqQGQldpvEYsgKXWXn2AlMRtsxOpQsQSOgCBBK6AIGELkAgoQsQSOgCBBK6AIGELkAgoQsQaPcLH4A4Kl2AQEIXIJDQBQgkdAECCV2AQEIXIJDQBQgkdAECCV2AQEIXIJDQBQgkdAECCV2AQEIXIJDQBQgkdAECCV2AQEIXIJDQBQj094c/9wK1PPaGn2Vc82g5rttmbDN5ObYqXYBAQhcgkNAFCCR0AQIJXYBAQhcgkNAFCPRpn24q+/7ntrfbzbZEYB4qXYBAU1W6r7yqfh+phIFM0lS6+77/97/WnwuQRYrQfQ7Go6C83W6XKlfBC2QxPHSvBKLgBWY1PHSv0qsFZjT0Rtq7NsIZR//e0efu+y6sgaGGVbq/Bu47ghXIakjo9gzcT5+lt1tLjx0v0FOafbqqU856FbJaR8wivNKNrEpUu7V8qmqNKzNIUemOqFBURuM8h+OncRCm6+n1nJVv514PoaE74oE1t9vNboZEjloDcFd9PoS1F0aeSMGaQ/XFxO9WmCPT/jjiW4K3FuN
Zy0q7UELaCxFbxFiDOVPLKkH7qHvoClzeiZwH5uJY3wZs63HJEvDLtBfgSJbFWNk35/jq0wRn0TV0VRbcZQ+27Mc3s7PnNjpsR+VQin26QD3fhO1KwkN3tRPMsei58Pj3qWzHy9DPH0GlS3eZJjwx3o15hsJr5DG4kcYQv056Qc6sVLqUkaGCwg30T4Qu0N2oh1pltEToZj35wHpKh66whVhZ1lyW43hl2tBtcVL1mKAdvdxzlt29YCLEyFxxUFP2Obds6GYfmMpc8NYRPdZn1/XI9b9s6G6b4IVWMqylDMdwxtKhC/ST6We+mb5dTXsj7cpJ9OruWCPeiUesDNVl9p8cP5s2dIG8eoXdt8/lzWip9kLWQSCeudDP6HOb/SHoS4Xuto2fEOTxOBfMi+9FtxYytDJa0F6gi1n6uRmPietmGE+hCzTTI/TOVLgzhO3dcu0FYB7VAnfbhC5BZlsYjFcxcLdNe4EOqtzwINa382bGwN02lS4BZl0cHHNhvU6lCzQRcXG9+ndkukiodIEpVPnGJHRpamRFkamaoa3WgesV7JRVpTqhv6O5Um0OLRe6qqEaqi1E/vU8rhkenNPacqFLPy5otHAP2qoXVqFLN1UXDf1V6uE+E7o08Vzljto+pNoeZ7ZzP+p4hS4/y7bYsh0POWSpdsND14KgpVavbeI7WQJsRktVuhZbezOd033f/zvemY6bfq7Og1/mT9efAd9utzSTO8txrGCWKshLSdurfk4fc+Tqg/qXqHRne1voLGa6kM10rLM4WjuZz3XP9X72s5cI3SMC97qjhZXxnM50rMRrcZH4Zi6VD93MV13asHVsnBmr3da+vXgPCd2oAdFW6GOmynGmYyXG1Yt0q9wqX+m+YsFd83j3/5lzuq4q1e7R/G4977s/xHzUDgbhwBFzIM6MuxnO5tXV/19LVrq0lXVRzVZpzezdHOg1DvcK9F0lmlHJ0FXlcsQc6OdT8LYKxxaf8+s8+OW/L/eOtJmueBUIMR6daSfe/3xEdfzoauvz1zk/LHR79HrsVuhr9nM4+/HP4myYZSiQHufEp+NpNX/KVLoCF/KIvoHeYo1H5URI6B4NQNRXCNZmDozxfN5brveZxzSs0h2xdWzmgaENcyCP+1j03pKVXZn2wrOqA8ZrmZ5ox3tnbqBVXr+hoRuxMCoPFlS3wvoNr3RbB+8KgwTUMaS9ICjp4fmCbp6RUclfpLEuQUt2ZW+ksS7BS2YqXYBAQhcgkNAFCCR0AQIJXYBAQhcgkNAFCCR0AQIJXYBAQhcgkNAFCCR0AQIJXYBAQhcgkNAFCCR0AQIJXYBAu6fsA8RR6QIEEroAgYQuQCChCxBI6AIEEroAgf4BNMKkdv4aDUMAAAAASUVORK5CYII=\n", 878 | "text/plain": [ 879 | "
" 880 | ] 881 | }, 882 | "metadata": { 883 | "needs_background": "light" 884 | }, 885 | "output_type": "display_data" 886 | } 887 | ], 888 | "source": [ 889 | "# Test data generator\n", 890 | "\n", 891 | "set_data_gen = SetDataGenerator(train_X, train_y)\n", 892 | "set_data_gen.generator_init(batch_size=3)\n", 893 | "batch_data = next(iter(set_data_gen.batch))\n", 894 | "\n", 895 | "im_to_plot = batch_data[0][0].numpy()\n", 896 | "\n", 897 | "print('Number of unique characters is : {}'.format(batch_data[1][0].numpy()))\n", 898 | "\n", 899 | "fig, axs = plt.subplots(2, 3)\n", 900 | "axs[0, 0].imshow(im_to_plot[0, :, :, 0], cmap='gray')\n", 901 | "axs[0, 0].axis('off')\n", 902 | "axs[0, 1].imshow(im_to_plot[1, :, :, 0], cmap='gray')\n", 903 | "axs[0, 1].axis('off')\n", 904 | "axs[0, 2].imshow(im_to_plot[2, :, :, 0], cmap='gray')\n", 905 | "axs[0, 2].axis('off')\n", 906 | "axs[1, 0].imshow(im_to_plot[3, :, :, 0], cmap='gray')\n", 907 | "axs[1, 0].axis('off')\n", 908 | "axs[1, 1].imshow(im_to_plot[4, :, :, 0], cmap='gray')\n", 909 | "axs[1, 1].axis('off')\n", 910 | "axs[1, 2].imshow(im_to_plot[5, :, :, 0], cmap='gray')\n", 911 | "axs[1, 2].axis('off')" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "execution_count": 7, 917 | "metadata": {}, 918 | "outputs": [ 919 | { 920 | "name": "stdout", 921 | "output_type": "stream", 922 | "text": [ 923 | "Train on 3 samples\n", 924 | "Epoch 1/3\n", 925 | "3/3 [==============================] - 7s 2s/sample - loss: 4.7445\n", 926 | "Epoch 2/3\n", 927 | "3/3 [==============================] - 2s 601ms/sample - loss: 3.4197\n", 928 | "Epoch 3/3\n", 929 | "3/3 [==============================] - 2s 654ms/sample - loss: 1.0533\n" 930 | ] 931 | }, 932 | { 933 | "data": { 934 | "text/plain": [ 935 | "" 936 | ] 937 | }, 938 | "execution_count": 7, 939 | "metadata": {}, 940 | "output_type": "execute_result" 941 | } 942 | ], 943 | "source": [ 944 | "from tensorflow.keras.layers import LayerNormalization, Dense\n", 945 | "import tensorflow as tf\n", 946 | "from set_transformer.layers.attention import MultiHeadAttention\n", 947 | "from set_transformer.layers import RFF\n", 948 | "from set_transformer.blocks import SetAttentionBlock, PoolingMultiHeadAttention\n", 949 | "from tensorflow.keras.layers import Conv2D\n", 950 | "from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D, Dropout, Flatten\n", 951 | "from tensorflow.keras.models import Model\n", 952 | "tf.keras.backend.set_floatx('float32')\n", 953 | "\n", 954 | "\n", 955 | "def image_processing_model(input_image_shape = (105, 105, 1), output_len=128):\n", 956 | " input_image = Input(shape=input_image_shape)\n", 957 | " y = Conv2D(64, kernel_size=(3, 3),\n", 958 | " activation='relu',\n", 959 | " input_shape=input_image_shape)(input_image)\n", 960 | " y = Conv2D(64, (3, 3), activation='relu')(y)\n", 961 | " y = Conv2D(64, (3, 3), activation='relu')(y)\n", 962 | " y = Conv2D(64, (3, 3), activation='relu')(y)\n", 963 | " y = MaxPooling2D(pool_size=(2, 2))(y)\n", 964 | " y = Dropout(0.25)(y)\n", 965 | " y = Flatten()(y)\n", 966 | " output_vec = Dense(output_len, activation='relu')(y)\n", 967 | " return Model(input_image, output_vec)\n", 968 | "\n", 969 | "\n", 970 | "class CharEncoder(tf.keras.layers.Layer):\n", 971 | " def __init__(self, d=128, h=8):\n", 972 | " super(CharEncoder, self).__init__()\n", 973 | "\n", 974 | " # Instantiate image processing model\n", 975 | " self.image_model = image_processing_model(output_len=d)\n", 976 | "\n", 977 | " # Encoding part\n", 978 | " self.sab_1 = 
SetAttentionBlock(d, h, RFF(d))\n", 979 | " self.sab_2 = SetAttentionBlock(d, h, RFF(d))\n", 980 | "\n", 981 | " def call(self, x):\n", 982 | " return self.sab_2(self.sab_1(tf.map_fn(self.image_model, x)))\n", 983 | " \n", 984 | "\n", 985 | "class CharDecoder(tf.keras.layers.Layer):\n", 986 | " def __init__(self, out_dim, d=128, h=8, k=32):\n", 987 | " super(CharDecoder, self).__init__()\n", 988 | "\n", 989 | " self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))\n", 990 | " self.SAB = SetAttentionBlock(d, h, RFF(d))\n", 991 | " self.output_mapper = Dense(out_dim)\n", 992 | " self.k, self.d = k, d\n", 993 | "\n", 994 | " def call(self, x):\n", 995 | " decoded_vec = self.SAB(self.PMA(x))\n", 996 | " decoded_vec = tf.reshape(decoded_vec, [-1, self.k * self.d])\n", 997 | " return tf.reshape(self.output_mapper(decoded_vec), (tf.shape(decoded_vec)[0],))\n", 998 | "\n", 999 | " \n", 1000 | "class CharSetTransformer(tf.keras.Model):\n", 1001 | " def __init__(self):\n", 1002 | " super(CharSetTransformer, self).__init__()\n", 1003 | " self.encoder = CharEncoder()\n", 1004 | " self.decoder = CharDecoder(out_dim=1)\n", 1005 | "\n", 1006 | " def call(self, x):\n", 1007 | " enc_output = self.encoder(x) # (batch_size, set_len, d_model)\n", 1008 | " return self.decoder(enc_output)\n", 1009 | " \n", 1010 | "\n", 1011 | "tset_model = CharSetTransformer()\n", 1012 | "tset_model.compile(loss='mae', optimizer='adam')\n", 1013 | "tset_model.fit(batch_data[0], batch_data[1], epochs=3)" 1014 | ] 1015 | }, 1016 | { 1017 | "cell_type": "code", 1018 | "execution_count": 8, 1019 | "metadata": {}, 1020 | "outputs": [], 1021 | "source": [ 1022 | "# set_data_gen = SetDataGenerator(train_X, train_y)\n", 1023 | "# set_data_gen.generator_init(batch_size=64)\n", 1024 | "\n", 1025 | "# tset_model = CharSetTransformer()\n", 1026 | "# tset_model.compile(loss='mae', optimizer='adam')\n", 1027 | "# tset_model.fit(set_data_gen.batch, epochs=3, steps_per_epoch=300)" 1028 | ] 1029 | }, 1030 | { 1031 | "cell_type": "code", 1032 | "execution_count": null, 1033 | "metadata": {}, 1034 | "outputs": [], 1035 | "source": [] 1036 | }, 1037 | { 1038 | "cell_type": "code", 1039 | "execution_count": null, 1040 | "metadata": {}, 1041 | "outputs": [], 1042 | "source": [] 1043 | } 1044 | ], 1045 | "metadata": { 1046 | "kernelspec": { 1047 | "display_name": "Python 3", 1048 | "language": "python", 1049 | "name": "python3" 1050 | }, 1051 | "language_info": { 1052 | "codemirror_mode": { 1053 | "name": "ipython", 1054 | "version": 3 1055 | }, 1056 | "file_extension": ".py", 1057 | "mimetype": "text/x-python", 1058 | "name": "python", 1059 | "nbconvert_exporter": "python", 1060 | "pygments_lexer": "ipython3", 1061 | "version": "3.6.9" 1062 | } 1063 | }, 1064 | "nbformat": 4, 1065 | "nbformat_minor": 2 1066 | } 1067 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | codecov 2 | scipy 3 | seaborn 4 | tf-nightly-gpu==2.2.0.dev20200218 5 | tensorflow-datasets 6 | tqdm==4.43.0 7 | pytest 8 | pillow -------------------------------------------------------------------------------- /set_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arrigonialberto86/set_transformer/ec76b58d43febb85bc30f4d53d58fd630c46f009/set_transformer/__init__.py -------------------------------------------------------------------------------- 
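A quick orientation before the blocks module: blocks.py below implements the paper's attention building blocks (MAB, SAB, ISAB, PMA) on top of the multi-head attention layer in set_transformer/layers/attention.py. The following sketch is illustrative only; the shapes and hyperparameters (batch of 4, set size 7, d=12, m=6, k=3, 2 heads) are arbitrary choices, not values prescribed by the repo, and the expected outputs follow the docstrings and tests:

import tensorflow as tf
from set_transformer.blocks import SetAttentionBlock, InducedSetAttentionBlock, PoolingMultiHeadAttention
from set_transformer.layers import RFF

x = tf.random.normal(shape=(4, 7, 12))  # a batch of 4 sets, 7 elements each, dimension d=12

sab = SetAttentionBlock(12, 2, RFF(12))                      # full O(n^2) self-attention over set elements
isab = InducedSetAttentionBlock(12, 6, 2, RFF(12), RFF(12))  # O(n*m) variant with m=6 inducing points
pma = PoolingMultiHeadAttention(12, 3, 2, RFF(12), RFF(12))  # pools each set onto k=3 seed vectors

print(sab(x).shape)   # (4, 7, 12): set size preserved
print(isab(x).shape)  # (4, 7, 12)
print(pma(x).shape)   # (4, 3, 12): fixed-size summary, whatever the input set size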
/set_transformer/blocks.py: -------------------------------------------------------------------------------- 1 | # Referencing https://arxiv.org/pdf/1810.00825.pdf 2 | # and the original PyTorch implementation https://github.com/TropComplique/set-transformer/blob/master/blocks.py 3 | from tensorflow import repeat 4 | from tensorflow.keras.layers import LayerNormalization, Dense 5 | import tensorflow as tf 6 | from set_transformer.layers.attention import MultiHeadAttention 7 | from set_transformer.layers import RFF 8 | 9 | 10 | class MultiHeadAttentionBlock(tf.keras.layers.Layer): 11 | def __init__(self, d: int, h: int, rff: RFF): 12 | super(MultiHeadAttentionBlock, self).__init__() 13 | self.multihead = MultiHeadAttention(d, h) 14 | self.layer_norm1 = LayerNormalization(epsilon=1e-6, dtype='float32') 15 | self.layer_norm2 = LayerNormalization(epsilon=1e-6, dtype='float32') 16 | self.rff = rff 17 | 18 | def call(self, x, y): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [b, n, d]. 22 | y: a float tensor with shape [b, m, d]. 23 | Returns: 24 | a float tensor with shape [b, n, d]. 25 | """ 26 | 27 | h = self.layer_norm1(x + self.multihead(x, y, y)) 28 | return self.layer_norm2(h + self.rff(h)) 29 | 30 | 31 | class SetAttentionBlock(tf.keras.layers.Layer): 32 | def __init__(self, d: int, h: int, rff: RFF): 33 | super(SetAttentionBlock, self).__init__() 34 | self.mab = MultiHeadAttentionBlock(d, h, rff) 35 | 36 | def call(self, x): 37 | """ 38 | Arguments: 39 | x: a float tensor with shape [b, n, d]. 40 | Returns: 41 | a float tensor with shape [b, n, d]. 42 | """ 43 | return self.mab(x, x) 44 | 45 | 46 | class InducedSetAttentionBlock(tf.keras.layers.Layer): 47 | def __init__(self, d: int, m: int, h: int, rff1: RFF, rff2: RFF): 48 | """ 49 | Arguments: 50 | d: an integer, input dimension. 51 | m: an integer, number of inducing points. 52 | h: an integer, number of heads. 53 | rff1, rff2: modules, row-wise feedforward layers. 54 | It takes a float tensor with shape [b, n, d] and 55 | returns a float tensor with the same shape. 56 | """ 57 | super(InducedSetAttentionBlock, self).__init__() 58 | self.mab1 = MultiHeadAttentionBlock(d, h, rff1) 59 | self.mab2 = MultiHeadAttentionBlock(d, h, rff2) 60 | self.inducing_points = tf.random.normal(shape=(1, m, d)) 61 | 62 | def call(self, x): 63 | """ 64 | Arguments: 65 | x: a float tensor with shape [b, n, d]. 66 | Returns: 67 | a float tensor with shape [b, n, d]. 68 | """ 69 | b = tf.shape(x)[0] 70 | p = self.inducing_points 71 | p = repeat(p, (b), axis=0) # shape [b, m, d] 72 | 73 | h = self.mab1(p, x) # shape [b, m, d] 74 | return self.mab2(x, h) 75 | 76 | 77 | class PoolingMultiHeadAttention(tf.keras.layers.Layer): 78 | 79 | def __init__(self, d: int, k: int, h: int, rff: RFF, rff_s: RFF): 80 | """ 81 | Arguments: 82 | d: an integer, input dimension. 83 | k: an integer, number of seed vectors. 84 | h: an integer, number of heads. 85 | rff: a module, row-wise feedforward layers. 86 | It takes a float tensor with shape [b, n, d] and 87 | returns a float tensor with the same shape. 88 | """ 89 | super(PoolingMultiHeadAttention, self).__init__() 90 | self.mab = MultiHeadAttentionBlock(d, h, rff) 91 | self.seed_vectors = tf.random.normal(shape=(1, k, d)) 92 | self.rff_s = rff_s 93 | 94 | @tf.function 95 | def call(self, z): 96 | """ 97 | Arguments: 98 | z: a float tensor with shape [b, n, d]. 
99 | Returns: 100 | a float tensor with shape [b, k, d]. 101 | """ 102 | b = tf.shape(z)[0] 103 | s = self.seed_vectors 104 | s = repeat(s, (b), axis=0) # shape [b, k, d] 105 | return self.mab(s, self.rff_s(z)) 106 | 107 | 108 | class STEncoder(tf.keras.layers.Layer): 109 | def __init__(self, d=12, m=6, h=6): 110 | super(STEncoder, self).__init__() 111 | 112 | # Embedding part 113 | self.linear_1 = Dense(d, activation='relu') 114 | 115 | # Encoding part 116 | self.isab_1 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d)) 117 | self.isab_2 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d)) 118 | 119 | def call(self, x): 120 | return self.isab_2(self.isab_1(self.linear_1(x))) 121 | 122 | 123 | class STDecoder(tf.keras.layers.Layer): 124 | def __init__(self, out_dim, d=12, h=2, k=8): 125 | super(STDecoder, self).__init__() 126 | 127 | self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d)) 128 | self.SAB = SetAttentionBlock(d, h, RFF(d)) 129 | self.output_mapper = Dense(out_dim) 130 | self.k, self.d = k, d 131 | 132 | def call(self, x): 133 | decoded_vec = self.SAB(self.PMA(x)) 134 | decoded_vec = tf.reshape(decoded_vec, [-1, self.k * self.d]) 135 | return tf.reshape(self.output_mapper(decoded_vec), (tf.shape(decoded_vec)[0],)) 136 | -------------------------------------------------------------------------------- /set_transformer/data/simulation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def gen_max_dataset(dataset_size=100000, set_size=9, seed=0): 6 | """ 7 | The number of objects per set is constant in this toy example 8 | """ 9 | np.random.seed(seed) 10 | x = np.random.uniform(1, 100, (dataset_size, set_size)) 11 | y = np.max(x, axis=1) 12 | x, y = np.expand_dims(x, axis=2), np.expand_dims(y, axis=1) 13 | return tf.cast(x, 'float32'), tf.cast(y, 'float32') 14 | -------------------------------------------------------------------------------- /set_transformer/layers/__init__.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Dense 3 | 4 | 5 | class RFF(tf.keras.layers.Layer): 6 | """ 7 | Row-wise FeedForward layers. 8 | """ 9 | 10 | def __init__(self, d): 11 | super(RFF, self).__init__() 12 | 13 | self.linear_1 = Dense(d, activation='relu') 14 | self.linear_2 = Dense(d, activation='relu') 15 | self.linear_3 = Dense(d, activation='relu') 16 | 17 | def call(self, x): 18 | """ 19 | Arguments: 20 | x: a float tensor with shape [b, n, d]. 21 | Returns: 22 | a float tensor with shape [b, n, d]. 23 | """ 24 | return self.linear_3(self.linear_2(self.linear_1(x))) -------------------------------------------------------------------------------- /set_transformer/layers/attention.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | # https://www.tensorflow.org/tutorials/text/transformer, appears in the "Attention Is All You Need" NeurIPS 2017 paper 5 | def scaled_dot_product_attention(q, k, v, mask=None): 6 | """Calculate the attention weights. 7 | q, k, v must have matching leading dimensions. 8 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. 9 | The mask has different shapes depending on its type (padding or look ahead) 10 | but it must be broadcastable for addition.
11 | 12 | Args: 13 | q: query shape == (..., seq_len_q, depth) 14 | k: key shape == (..., seq_len_k, depth) 15 | v: value shape == (..., seq_len_v, depth_v) 16 | mask: Float tensor with shape broadcastable 17 | to (..., seq_len_q, seq_len_k). Defaults to None. 18 | 19 | Returns: 20 | output, attention_weights 21 | """ 22 | 23 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 24 | 25 | # scale matmul_qk 26 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 27 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 28 | 29 | # add the mask to the scaled tensor. 30 | if mask is not None: 31 | scaled_attention_logits += (mask * -1e9) 32 | 33 | # softmax is normalized on the last axis (seq_len_k) so that the scores 34 | # add up to 1. 35 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 36 | 37 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 38 | 39 | return output, attention_weights 40 | 41 | 42 | class MultiHeadAttention(tf.keras.layers.Layer): 43 | def __init__(self, d_model, num_heads): 44 | super(MultiHeadAttention, self).__init__() 45 | self.num_heads = num_heads 46 | self.d_model = d_model 47 | 48 | assert d_model % self.num_heads == 0 49 | 50 | self.depth = d_model // self.num_heads 51 | 52 | self.wq = tf.keras.layers.Dense(d_model) 53 | self.wk = tf.keras.layers.Dense(d_model) 54 | self.wv = tf.keras.layers.Dense(d_model) 55 | 56 | self.dense = tf.keras.layers.Dense(d_model) 57 | 58 | def split_heads(self, x, batch_size): 59 | """Split the last dimension into (num_heads, depth). 60 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) 61 | """ 62 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 63 | return tf.transpose(x, perm=[0, 2, 1, 3]) 64 | 65 | def call(self, q, k, v, mask=None): 66 | batch_size = tf.shape(q)[0] 67 | 68 | q = self.wq(q) # (batch_size, seq_len, d_model) 69 | k = self.wk(k) # (batch_size, seq_len, d_model) 70 | v = self.wv(v) # (batch_size, seq_len, d_model) 71 | 72 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 73 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 74 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 75 | 76 | # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) 77 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 78 | scaled_attention, attention_weights = scaled_dot_product_attention( 79 | q, k, v, mask) 80 | 81 | scaled_attention = tf.transpose(scaled_attention, 82 | perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) 83 | 84 | concat_attention = tf.reshape(scaled_attention, 85 | (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) 86 | 87 | output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model) 88 | 89 | return output 90 | -------------------------------------------------------------------------------- /set_transformer/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from set_transformer.blocks import STEncoder, STDecoder 3 | 4 | 5 | class BasicSetTransformer(tf.keras.Model): 6 | def __init__(self, encoder_d=4, m=3, encoder_h=2, out_dim=1, decoder_d=4, decoder_h=2, k=2): 7 | super(BasicSetTransformer, self).__init__() 8 | self.basic_encoder = STEncoder(d=encoder_d, m=m, h=encoder_h) 9 | self.basic_decoder = STDecoder(out_dim=out_dim, d=decoder_d, 
h=decoder_h, k=k) 10 | 11 | def call(self, x): 12 | enc_output = self.basic_encoder(x) # (batch_size, set_len, d_model) 13 | return self.basic_decoder(enc_output) 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='set_transformer', 4 | version='0.0.1', 5 | description='Set transformer TF implementation', 6 | author='Alberto Arrigoni', 7 | author_email='arrigonialberto86@gmail.com', 8 | url='https://github.com/arrigonialberto86', 9 | # requires=['numpy', 'pandas', 'scipy', 'seaborn', 'pillow'], 10 | packages=find_packages() 11 | ) 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arrigonialberto86/set_transformer/ec76b58d43febb85bc30f4d53d58fd630c46f009/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | from set_transformer.layers.attention import scaled_dot_product_attention, MultiHeadAttention 4 | 5 | 6 | class TestAttention(unittest.TestCase): 7 | def setUp(self) -> None: 8 | self.temp_k = tf.constant([[10, 0, 0], 9 | [0, 10, 0], 10 | [0, 0, 10], 11 | [0, 0, 10]], dtype=tf.float32) # (4, 3) 12 | 13 | self.temp_v = tf.constant([[1, 0], 14 | [10, 0], 15 | [100, 5], 16 | [1000, 6]], dtype=tf.float32) # (4, 2) 17 | 18 | # This `query` aligns with the second `key`, 19 | # so the second `value` is returned. 
20 | self.temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3) 21 | 22 | def test_scaled_dot_product(self): 23 | # Dimensionality output check 24 | temp_out, temp_attn = scaled_dot_product_attention(self.temp_q, self.temp_k, self.temp_v, None) 25 | self.assertEqual(temp_out.shape[0], 1) 26 | self.assertEqual(temp_out.shape[1], 2) 27 | 28 | def test_multi_head_output_dimension(self): 29 | temp_mha = MultiHeadAttention(d_model=512, num_heads=8) 30 | y = tf.random.uniform((1, 60, 512)) # (batch_size, encoder_sequence, d_model) 31 | out = temp_mha(v=y, k=y, q=y) 32 | self.assertEqual(out.shape[0], 1) 33 | self.assertEqual(out.shape[1], 60) 34 | self.assertEqual(out.shape[2], 512) 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /tests/test_blocks.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | from set_transformer.blocks import MultiHeadAttentionBlock, SetAttentionBlock, InducedSetAttentionBlock, \ 4 | PoolingMultiHeadAttention, STEncoder, STDecoder 5 | from set_transformer.layers import RFF 6 | import numpy as np 7 | 8 | 9 | class TestBlocks(unittest.TestCase): 10 | def setUp(self) -> None: 11 | # For Encoder-Decoder tests 12 | self.X = np.random.uniform(1, 100, (10, 3)) 13 | self.X = np.expand_dims(self.X, axis=2) 14 | self.X = tf.cast(self.X, 'float32') 15 | 16 | def test_multi_attention_block_dim(self): 17 | x_data = tf.random.normal(shape=(10, 2, 9)) 18 | y_data = tf.random.normal(shape=(10, 3, 9)) 19 | rff = RFF(d=9) 20 | mab = MultiHeadAttentionBlock(9, 3, rff=rff) 21 | mab_output = mab(x_data, y_data) 22 | self.assertEqual(mab_output.shape[0], 10) 23 | self.assertEqual(mab_output.shape[1], 2) 24 | self.assertEqual(mab_output.shape[2], 9) 25 | 26 | def test_set_attention_block_dim(self): 27 | x_data = tf.random.normal(shape=(10, 2, 9)) 28 | rff = RFF(d=9) 29 | sab = SetAttentionBlock(9, 3, rff=rff) 30 | sab_output = sab(x_data) 31 | self.assertEqual(sab_output.shape[0], 10) 32 | self.assertEqual(sab_output.shape[1], 2) 33 | self.assertEqual(sab_output.shape[2], 9) 34 | 35 | def test_induced_set_attention_block_dim(self): 36 | z = tf.random.normal(shape=(10, 2, 9)) 37 | rff, rff_s = RFF(d=9), RFF(d=9) 38 | isab = InducedSetAttentionBlock(d=9, m=10, h=3, rff1=rff, rff2=rff_s) 39 | output = isab(z) 40 | self.assertEqual(output.shape[0], 10) 41 | self.assertEqual(output.shape[1], 2) 42 | self.assertEqual(output.shape[2], 9) 43 | 44 | def test_pma_block_dim(self): 45 | z = tf.random.normal(shape=(10, 2, 9)) 46 | rff, rff_s = RFF(d=9), RFF(d=9) 47 | pma = PoolingMultiHeadAttention(d=9, k=10, h=3, rff=rff, rff_s=rff_s) 48 | output = pma(z) 49 | self.assertEqual(output.shape[0], 10) 50 | self.assertEqual(output.shape[1], 10) 51 | self.assertEqual(output.shape[2], 9) 52 | 53 | def test_encoder_decoder_dim(self): 54 | encoder = STEncoder(d=3, m=2, h=1) 55 | encoded = encoder(self.X) 56 | self.assertEqual(encoded.shape[0], 10) 57 | self.assertEqual(encoded.shape[1], 3) 58 | self.assertEqual(encoded.shape[2], 3) 59 | decoder = STDecoder(out_dim=1, d=1, h=1, k=1) 60 | decoded = decoder(encoded) 61 | self.assertEqual(decoded.shape[0], 10) 62 | 63 | 64 | if __name__ == '__main__': 65 | unittest.main() -------------------------------------------------------------------------------- /tests/test_data_generation.py: -------------------------------------------------------------------------------- 1 | from
set_transformer.data.simulation import gen_max_dataset 2 | import unittest 3 | 4 | 5 | class TestDataGen(unittest.TestCase): 6 | def setUp(self) -> None: 7 | pass 8 | 9 | def test_max_gen(self): 10 | dataset = gen_max_dataset(dataset_size=10, set_size=3) 11 | self.assertEqual(dataset[0].shape[0], 10) 12 | self.assertEqual(dataset[0].shape[1], 3) 13 | self.assertEqual(dataset[0].shape[2], 1) 14 | 15 | 16 | if __name__ == '__main__': 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | from set_transformer.layers import RFF 4 | 5 | 6 | class TestLayers(unittest.TestCase): 7 | def test_rff(self): 8 | mlp = RFF(3) 9 | y = mlp(tf.ones(shape=(2, 4, 3))) 10 | self.assertEqual(len(mlp.weights), 6) 11 | self.assertEqual(y.shape[0], 2) 12 | self.assertEqual(y.shape[1], 4) 13 | self.assertEqual(y.shape[2], 3) 14 | 15 | 16 | if __name__ == '__main__': 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from set_transformer.model import BasicSetTransformer 3 | from set_transformer.data.simulation import gen_max_dataset 4 | import tensorflow as tf 5 | 6 | 7 | class TestModel(unittest.TestCase): 8 | def setUp(self) -> None: 9 | # Integration test 10 | self.X, self.y = gen_max_dataset(300, 3) 11 | 12 | def test_basic_transformer(self): 13 | set_transformer = BasicSetTransformer() 14 | set_transformer.compile(loss='mae', optimizer='adam') 15 | set_transformer.fit(self.X, self.y, epochs=1) 16 | prediction = set_transformer.predict(tf.expand_dims(self.X[0], axis=0)) 17 | self.assertEqual(prediction.shape[0], 1) 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() --------------------------------------------------------------------------------
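For completeness, a minimal end-to-end run mirroring tests/test_model.py (the dataset size, epoch count and batch size are illustrative choices, not values prescribed by the repo):

import tensorflow as tf
from set_transformer.data.simulation import gen_max_dataset
from set_transformer.model import BasicSetTransformer

X, y = gen_max_dataset(dataset_size=10000, set_size=9)  # X: [10000, 9, 1], y: [10000, 1]
model = BasicSetTransformer()
model.compile(loss='mae', optimizer='adam')
model.fit(X, y, epochs=3, batch_size=64)

# The prediction for one set should approach that set's true maximum as training converges
print(model.predict(tf.expand_dims(X[0], axis=0)), tf.reduce_max(X[0]).numpy())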