├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── examples
│   ├── 1-sineprediction.py
│   └── 4-robojam-touch-generation.py
├── keras_mdn_layer
│   ├── __init__.py
│   └── tests
│       ├── __init__.py
│       └── test_mdn.py
├── notebooks
│   ├── MDN-1D-sine-prediction.ipynb
│   ├── MDN-2D-spiral-prediction.ipynb
│   ├── MDN-RNN-RoboJam-touch-generation.ipynb
│   ├── MDN-RNN-kanji-generation-example.ipynb
│   ├── MDN-RNN-kanji-generation-with-stateless-decoder.ipynb
│   ├── MDN-RNN-time-distributed-MDN-training.ipynb
│   ├── context.py
│   └── figures
│       ├── kanji-mdn-diagram.png
│       ├── kanji_mdn_examples.png
│       ├── kanji_test_1.png
│       ├── kanji_test_2.png
│       ├── microjam.gif
│       ├── robojam-action-diagram.jpg
│       ├── robojam-mdn-diagram.png
│       ├── robojam-mdn-loss.png
│       └── robojam_examples.png
├── poetry.lock
└── pyproject.toml

--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
name: Build and test keras-mdn-layer

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

permissions:
  contents: read

jobs:
  build:
    strategy:
      matrix:
        platform: [ubuntu-latest, macos-latest, windows-latest]

    runs-on: ${{ matrix.platform }}

    steps:
      - uses: actions/checkout@v4
      - name: Install poetry
        run: pipx install poetry
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
          cache: 'poetry'
      - name: Install dependencies
        run: |
          poetry install
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
          poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Run tests.
        run: |
          poetry run coverage run --source=keras_mdn_layer -m pytest
      - name: Upload coverage.
        run: poetry run coveralls
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: refresh coverage badge
        uses: fjogeleit/http-request-action@v1
        with:
          url: https://camo.githubusercontent.com/2cd3e1ce343708e82b3b0423f4b53355b1c10d981934b0f8e6e81fdaa8f536dc/68747470733a2f2f636f766572616c6c732e696f2f7265706f732f6769746875622f63706d70657263757373696f6e2f6b657261732d6d646e2d6c617965722f62616467652e7376673f6272616e63683d6d6173746572
          method: PURGE

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
#
*.svg
*.npz
*.h5
*.keras

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
Copyright 2019 Charles P Martin

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Keras Mixture Density Network Layer

[![Coverage Status](https://coveralls.io/repos/github/cpmpercussion/keras-mdn-layer/badge.svg?branch=master)](https://coveralls.io/github/cpmpercussion/keras-mdn-layer?branch=master)
[![Build and test keras-mdn-layer](https://github.com/cpmpercussion/keras-mdn-layer/actions/workflows/python-app.yml/badge.svg)](https://github.com/cpmpercussion/keras-mdn-layer/actions/workflows/python-app.yml)
![MIT License](https://img.shields.io/github/license/cpmpercussion/keras-mdn-layer.svg?style=flat)
[![DOI](https://zenodo.org/badge/137585470.svg)](https://zenodo.org/badge/latestdoi/137585470)
[![PyPI version](https://badge.fury.io/py/keras-mdn-layer.svg)](https://badge.fury.io/py/keras-mdn-layer)

A mixture density network (MDN) layer for Keras using TensorFlow's distributions module. This makes it a bit simpler to experiment with neural networks that predict multiple real-valued variables that can take on multiple equally likely values.

This layer can help build MDN-RNNs similar to those used in [RoboJam](https://github.com/cpmpercussion/robojam), [Sketch-RNN](https://experiments.withgoogle.com/sketch-rnn-demo), [handwriting generation](https://distill.pub/2016/handwriting/), and maybe even [world models](https://worldmodels.github.io). You can do a lot of cool stuff with MDNs!

One benefit of this implementation is that you can predict any number of real values. TensorFlow's `Mixture`, `Categorical`, and `MultivariateNormalDiag` distribution functions are used to generate the loss function (the probability density function of a mixture of multivariate normal distributions with a diagonal covariance matrix). In previous work, the loss function has often been specified by hand, which is fine for 1D or 2D prediction but becomes a bit more annoying after that.

Two important functions are provided for training and prediction:

- `get_mixture_loss_func(output_dim, num_mixtures)`: This function generates a loss function with the correct output dimension and number of mixtures.
- `sample_from_output(params, output_dim, num_mixtures, temp=1.0)`: This function samples from the mixture distribution output by the model.

## Installation

This project requires Python 3.6+, TensorFlow and TensorFlow Probability. You can easily install this package from [PyPI](https://pypi.org/project/keras-mdn-layer/) via `pip` like so:

    python3 -m pip install keras-mdn-layer

And finally, import the module in Python: `import keras_mdn_layer as mdn`

Alternatively, you can clone or download this repository and then install via `poetry install`, or copy the `keras_mdn_layer` folder into your own project.

## Build

This project builds using `poetry`. To build a wheel, use `poetry build`.

## Examples

Some examples are provided in the notebooks directory.

To run these using `poetry`, run `poetry install` and then open Jupyter with `poetry run jupyter lab`.

There are scripts for fitting multivalued functions, a standard MDN toy problem:

Keras MDN Demo

There's also a script for generating fake kanji characters:

kanji test 1

And finally, for learning how to generate musical touch-screen performances with a temporal component:

Robojam Model Examples

## How to use

The MDN layer should be the last in your network and you should use `get_mixture_loss_func` to generate a loss function. Here's an example of a simple network with one Dense layer followed by the MDN.

    from tensorflow import keras
    import keras_mdn_layer as mdn

    N_HIDDEN = 15  # number of hidden units in the Dense layer
    N_MIXES = 10  # number of mixture components
    OUTPUT_DIMS = 2  # number of real-values predicted by each mixture component

    model = keras.Sequential()
    model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
    model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))
    model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
    model.summary()

Fit as normal:

    history = model.fit(x=x_train, y=y_train)

The predictions from the network are parameters of the mixture models, so you have to apply the `sample_from_output` function to generate samples.

    y_test = model.predict(x_test)
    y_samples = np.apply_along_axis(sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0)
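
Sampling has two temperature controls (see `sample_from_output` in `keras_mdn_layer/__init__.py`): `temp` adjusts the categorical distribution over mixture components, and `sigma_temp` scales the component variances. As a rough sketch, continuing the example above:

    # temp < 1.0 concentrates sampling on the most likely mixture components;
    # sigma_temp < 1.0 draws samples closer to the chosen component's mean.
    y_focused = np.apply_along_axis(sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=0.5, sigma_temp=0.5)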

See the notebooks directory for examples in jupyter notebooks!

### Load/Save Model

Saving models is straightforward:

    model.save('test_save.h5')

But loading requires `custom_objects` to be filled with the MDN layer, and a loss function with the appropriate parameters:

    m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
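
The same pattern works with the native Keras model format; a minimal sketch mirroring the test suite in `keras_mdn_layer/tests/test_mdn.py`:

    model.save('test_save.keras', overwrite=True)
    m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})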

## Acknowledgements

- Hat tip to [Omimo's Keras MDN layer](https://github.com/omimo/Keras-MDN) for a starting point for this code.
- Super hat tip to [hardmaru's MDN explanation, projects, and good ideas for sampling functions](http://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/) etc.
- Many good ideas from [Axel Brando's Master's Thesis](https://github.com/axelbrando/Mixture-Density-Networks-for-distribution-and-uncertainty-estimation)
- Mixture Density Networks in Edward [tutorial](http://edwardlib.org/tutorials/mixture-density-network).

## References

1. Christopher M. Bishop. 1994. Mixture Density Networks. [Technical Report NCRG/94/004](http://publications.aston.ac.uk/373/). Neural Computing Research Group, Aston University. http://publications.aston.ac.uk/373/
2. Axel Brando. 2017. Mixture Density Networks (MDN) for distribution and uncertainty estimation. Master’s thesis. Universitat Politècnica de Catalunya.
3. A. Graves. 2013. Generating Sequences With Recurrent Neural Networks. ArXiv e-prints (Aug. 2013). https://arxiv.org/abs/1308.0850
4. David Ha and Douglas Eck. 2017. A Neural Representation of Sketch Drawings. ArXiv e-prints (April 2017). https://arxiv.org/abs/1704.03477
5. Charles P. Martin and Jim Torresen. 2018. RoboJam: A Musical Mixture Density Network for Collaborative Touchscreen Interaction. In Evolutionary and Biologically Inspired Music, Sound, Art and Design: EvoMUSART ’18, A. Liapis et al. (Ed.). Lecture Notes in Computer Science, Vol. 10783. Springer International Publishing. DOI:[10.1007/978-3-319-77583-8_11](http://dx.doi.org/10.1007/978-3-319-77583-8_11)

--------------------------------------------------------------------------------
/examples/1-sineprediction.py:
--------------------------------------------------------------------------------
# Normal imports for everybody
from tensorflow import keras
import keras_mdn_layer as mdn
import numpy as np


# Generating some data:
NSAMPLE = 3000

y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))

N_HIDDEN = 15
N_MIXES = 10

model = keras.Sequential()
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))
model.add(mdn.MDN(1, N_MIXES))
model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

history = model.fit(x=x_data, y=y_data, batch_size=128, epochs=500, validation_split=0.15)

# Sample on some test data:
x_test = np.float32(np.arange(-15, 15, 0.01))
NTEST = x_test.size
print("Testing:", NTEST, "samples.")
x_test = x_test.reshape(NTEST, 1)  # needs to be a matrix, not a vector

# Make predictions from the model
y_test = model.predict(x_test)
# y_test contains parameters for distributions, not actual points on the graph.
# To find points on the graph, we need to sample from each distribution.

# Sample from the predicted distributions
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES, temp=1.0)

# Split up the mixture parameters (for future fun)
mus = np.apply_along_axis((lambda a: a[:N_MIXES]), 1, y_test)
sigs = np.apply_along_axis((lambda a: a[N_MIXES:2*N_MIXES]), 1, y_test)
pis = np.apply_along_axis((lambda a: mdn.softmax(a[2*N_MIXES:])), 1, y_test)

print("Done.")
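
# A possible visualisation of the samples against the data (assumes matplotlib
# is available; shown commented out, like the plots in the RoboJam example):
# import matplotlib.pyplot as plt
# plt.plot(x_data, y_data, 'ro', x_test, y_samples[:, 0], 'bo', alpha=0.3)
# plt.show()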

--------------------------------------------------------------------------------
/examples/4-robojam-touch-generation.py:
--------------------------------------------------------------------------------
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input
import numpy as np
import tensorflow as tf
import math
import h5py
import random
import time
import pandas as pd
import keras_mdn_layer as mdn
import matplotlib.pyplot as plt


#input_colour = 'darkblue'
#gen_colour = 'firebrick'
#plt.style.use('seaborn-talk')
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

# Allow GPU memory growth. (The original ConfigProto/Session setup is a TF1 API;
# this is the TF2 equivalent.)
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

# Download microjam performance data if needed.
import urllib.request
url = 'https://github.com/cpmpercussion/creative-prediction-datasets/raw/main/datasets/TinyPerformanceCorpus.h5'
urllib.request.urlretrieve(url, './TinyPerformanceCorpus.h5')


# ## Helper functions for touchscreen performances
#
# We need a few helper functions for managing performances:
#
# - Convert performances to and from pandas dataframes.
# - Generate random touches.
# - Sample whole performances from scratch and from a priming performance.
# - Plot performances including dividing into swipes.

SCALE_FACTOR = 1


def perf_df_to_array(perf_df, include_moving=False):
    """Converts a dataframe of a performance into array a,b,dt format."""
    perf_df['dt'] = perf_df.time.diff()
    perf_df.dt = perf_df.dt.fillna(0.0)
    # Clean performance data
    # Tiny Performance bounds defined to be in [[0,1],[0,1]], edit to fix this.
    perf_df.at[perf_df[perf_df.dt > 5].index, 'dt'] = 5.0
    perf_df.at[perf_df[perf_df.dt < 0].index, 'dt'] = 0.0
    perf_df.at[perf_df[perf_df.x > 1].index, 'x'] = 1.0
    perf_df.at[perf_df[perf_df.x < 0].index, 'x'] = 0.0
    perf_df.at[perf_df[perf_df.y > 1].index, 'y'] = 1.0
    perf_df.at[perf_df[perf_df.y < 0].index, 'y'] = 0.0
    if include_moving:
        output = np.array(perf_df[['x', 'y', 'dt', 'moving']])
    else:
        output = np.array(perf_df[['x', 'y', 'dt']])
    return output


def perf_array_to_df(perf_array):
    """Converts an array of a performance (a,b,dt(,moving) format) into a dataframe."""
    perf_array = perf_array.T
    perf_df = pd.DataFrame({'x': perf_array[0], 'y': perf_array[1], 'dt': perf_array[2]})
    if len(perf_array) == 4:
        perf_df['moving'] = perf_array[3]
    else:
        # As a rule of thumb, we could classify touches with dt > 0.1 as taps, dt < 0.1 as moving touches.
        perf_df['moving'] = 1
        perf_df.at[perf_df[perf_df.dt > 0.1].index, 'moving'] = 0
    perf_df['time'] = perf_df.dt.cumsum()
    perf_df['z'] = 38.0
    perf_df = perf_df.set_index(['time'])
    return perf_df[['x', 'y', 'z', 'moving']]


def random_touch(with_moving=False):
    """Generate a random tiny performance touch."""
    if with_moving:
        return np.array([np.random.rand(), np.random.rand(), 0.01, 0])
    else:
        return np.array([np.random.rand(), np.random.rand(), 0.01])


def constrain_touch(touch, with_moving=False):
    """Constrain touch values from the MDRNN."""
    touch[0] = min(max(touch[0], 0.0), 1.0)  # x in [0,1]
    touch[1] = min(max(touch[1], 0.0), 1.0)  # y in [0,1]
    touch[2] = max(touch[2], 0.001)  # dt: define a minimum time step
    if with_moving:
        touch[3] = np.greater(touch[3], 0.5) * 1.0
    return touch


def generate_random_tiny_performance(model, n_mixtures, first_touch, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):
    """Generates a tiny performance up to 5 seconds in length."""
    if predict_moving:
        out_dim = 4
    else:
        out_dim = 3
    time = 0
    steps = 0
    previous_touch = first_touch
    performance = [previous_touch.reshape((out_dim,))]
    while (steps < steps_limit and time < time_limit):
        params = model.predict(previous_touch.reshape(1, 1, out_dim) * SCALE_FACTOR)
        previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR
        output_touch = previous_touch.reshape(out_dim,)
        output_touch = constrain_touch(output_touch, with_moving=predict_moving)
        performance.append(output_touch.reshape((out_dim,)))
        steps += 1
        time += output_touch[2]
    return np.array(performance)
one.""" 119 | if predict_moving: 120 | out_dim = 4 121 | else: 122 | out_dim = 3 123 | time = 0 124 | steps = 0 125 | # condition 126 | for touch in perf: 127 | params = model.predict(touch.reshape(1, 1, out_dim) * SCALE_FACTOR) 128 | previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR 129 | output = [previous_touch.reshape((out_dim,))] 130 | # generate 131 | while (steps < steps_limit and time < time_limit): 132 | params = model.predict(previous_touch.reshape(1, 1, out_dim) * SCALE_FACTOR) 133 | previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR 134 | output_touch = previous_touch.reshape(out_dim,) 135 | output_touch = constrain_touch(output_touch, with_moving=predict_moving) 136 | output.append(output_touch.reshape((out_dim,))) 137 | steps += 1 138 | time += output_touch[2] 139 | net_output = np.array(output) 140 | return net_output 141 | 142 | 143 | def divide_performance_into_swipes(perf_df): 144 | """Divides a performance into a sequence of swipe dataframes for plotting.""" 145 | touch_starts = perf_df[perf_df.moving == 0].index 146 | performance_swipes = [] 147 | remainder = perf_df 148 | for att in touch_starts: 149 | swipe = remainder.iloc[remainder.index < att] 150 | performance_swipes.append(swipe) 151 | remainder = remainder.iloc[remainder.index >= att] 152 | performance_swipes.append(remainder) 153 | return performance_swipes 154 | 155 | 156 | input_colour = "#4388ff" 157 | gen_colour = "#ec0205" 158 | 159 | def plot_perf_on_ax(perf_df, ax, color="#ec0205", linewidth=3, alpha=0.5): 160 | """Plot a 2D representation of a performance 2D""" 161 | swipes = divide_performance_into_swipes(perf_df) 162 | for swipe in swipes: 163 | p = ax.plot(swipe.x, swipe.y, 'o-', alpha=alpha, markersize=linewidth) 164 | plt.setp(p, color=color, linewidth=linewidth) 165 | ax.set_ylim([1.0,0]) 166 | ax.set_xlim([0,1.0]) 167 | ax.set_xticks([]) 168 | ax.set_yticks([]) 169 | 170 | def plot_2D(perf_df, name="foo", saving=False, figsize=(5, 5)): 171 | """Plot a 2D representation of a performance 2D""" 172 | fig, ax = plt.subplots(figsize=(figsize)) 173 | plot_perf_on_ax(perf_df, ax, color=gen_colour, linewidth=5, alpha=0.7) 174 | if saving: 175 | fig.savefig(name+".png", bbox_inches='tight') 176 | 177 | def plot_double_2d(perf1, perf2, name="foo", saving=False, figsize=(8, 8)): 178 | """Plot two performances in 2D""" 179 | fig, ax = plt.subplots(figsize=(figsize)) 180 | plot_perf_on_ax(perf1, ax, color=input_colour, linewidth=5, alpha=0.7) 181 | plot_perf_on_ax(perf2, ax, color=gen_colour, linewidth=5, alpha=0.7) 182 | if saving: 183 | fig.savefig(name+".png", bbox_inches='tight') 184 | 185 | # # Load up the Dataset: 186 | # 187 | # The dataset consists of around 1000 5-second performances from the MicroJam app. 188 | # This is in a sequence of points consisting of an x-location, a y-location, and a time-delta from the previous point. 189 | # When the user swipes, the time-delta is very small, if they tap it's quite large. 

# # Load up the Dataset:
#
# The dataset consists of around 1000 5-second performances from the MicroJam app.
# This is a sequence of points, each consisting of an x-location, a y-location, and a time-delta from the previous point.
# When the user swipes, the time-delta is very small; if they tap, it's quite large.
# Let's have a look at some of the data:

# Load Data
microjam_data_file_name = "./TinyPerformanceCorpus.h5"

with h5py.File(microjam_data_file_name, 'r') as data_file:
    microjam_corpus = data_file['total_performances'][:]

print("Corpus data points between 100 and 120:")
print(perf_array_to_df(microjam_corpus[100:120]))

print("Some statistics about the dataset:")
pd.DataFrame(microjam_corpus).describe()

# - This time, the X and Y locations are *not* differences, but the exact values; only the time is a delta value.
# - The data doesn't have a "pen up" value, but we can just treat touches with dt > 0.1 as taps and dt < 0.1 as moving touches.

# Plot a bit of the data to have a look:
#plot_2D(perf_array_to_df(microjam_corpus[100:200]))

# ## MDN RNN
#
# - Now we're going to build an MDN-RNN to predict MicroJam data.
# - The architecture will be:
#     - 3 inputs (x, y, dt)
#     - 2 layers of 256 LSTM cells each
#     - MDN Layer with 3 dimensions and 5 mixtures.
#     - Training model will have a sequence length of 30 (prediction model: 1 in, 1 out)
#
# ![RoboJam MDN RNN Model](https://preview.ibb.co/cKZk9T/robojam_mdn_diagram.png)
#
# - Here's the model parameters and training data preparation.
# - We end up with 172K training examples.


# Training Hyperparameters:
SEQ_LEN = 30
BATCH_SIZE = 256
HIDDEN_UNITS = 256
EPOCHS = 100
VAL_SPLIT = 0.15

# Set random seed for reproducibility
SEED = 2345
random.seed(SEED)
np.random.seed(SEED)


def slice_sequence_examples(sequence, num_steps):
    xs = []
    for i in range(len(sequence) - num_steps - 1):
        example = sequence[i: i + num_steps]
        xs.append(example)
    return xs


def seq_to_singleton_format(examples):
    xs = []
    ys = []
    for ex in examples:
        xs.append(ex[:-1])
        ys.append(ex[-1])
    return (xs, ys)


sequences = slice_sequence_examples(microjam_corpus, SEQ_LEN + 1)
print("Total training examples:", len(sequences))
X, y = seq_to_singleton_format(sequences)
X = np.array(X)
y = np.array(y)
print("X:", X.shape, "y:", y.shape)

OUTPUT_DIMENSION = 3
NUMBER_MIXTURES = 5

model = keras.Sequential()
model.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(None, SEQ_LEN, OUTPUT_DIMENSION), return_sequences=True))
model.add(keras.layers.LSTM(HIDDEN_UNITS))
model.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))
model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION, NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())
model.summary()

# Define callbacks
filepath = "robojam_mdrnn-E{epoch:02d}-VL{val_loss:.2f}.h5"
checkpoint = keras.callbacks.ModelCheckpoint(filepath, save_weights_only=True, verbose=1, save_best_only=True, mode='min')
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)
callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]

history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_split=VAL_SPLIT)

# Save the Model
model.save('robojam-mdrnn.h5')  # creates a HDF5 file of the model

# Plot the loss
#plt.figure(figsize=(10, 5))
#plt.plot(history.history['loss'])
#plt.plot(history.history['val_loss'])
#plt.xlabel("epochs")
#plt.ylabel("loss")
#plt.show()


# # Try out the model
#
# - Let's try out the model
# - First we will load up a decoding model with a sequence length of 1.
# - The weights are loaded from the trained model file.

# Decoding Model
decoder = keras.Sequential()
decoder.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(1, 1, OUTPUT_DIMENSION), return_sequences=True, stateful=True))
decoder.add(keras.layers.LSTM(HIDDEN_UNITS, stateful=True))
decoder.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))
decoder.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION, NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())
decoder.summary()

# decoder.set_weights(model.get_weights())
decoder.load_weights("robojam-mdrnn.h5")


# Plotting some conditioned performances.
length = 100
t = random.randint(0, len(microjam_corpus) - length)
ex = microjam_corpus[t:t + length]  # sequences[600]

decoder.reset_states()
p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=1.5, sigma_temp=0.05)
#plot_double_2d(perf_array_to_df(ex), perf_array_to_df(p), figsize=(4,4))

# We can also generate unconditioned performances from a random starting point.

decoder.reset_states()
t = random_touch()
p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=1.2, sigma_temp=0.01)
#plot_2D(perf_array_to_df(p), figsize=(4,4))

--------------------------------------------------------------------------------
/keras_mdn_layer/__init__.py:
--------------------------------------------------------------------------------
"""
A Mixture Density Layer for Keras
cpmpercussion: Charles Martin (University of Oslo) 2018
https://github.com/cpmpercussion/keras-mdn-layer

Hat tip to [Omimo's Keras MDN layer](https://github.com/omimo/Keras-MDN)
for a starting point for this code.

Provided under MIT License
"""

import os;os.environ["TF_USE_LEGACY_KERAS"]="1"
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd


def elu_plus_one_plus_epsilon(x):
    """ELU activation with a very small addition to help prevent
    NaN in loss."""
    return keras.backend.elu(x) + 1 + keras.backend.epsilon()


class MDN(layers.Layer):
    """A Mixture Density Network Layer for Keras.
    This layer has a few tricks to avoid NaNs in the loss function when training:
    - Activation for variances is ELU + 1 + 1e-8 (to avoid very small values)
    - Mixture weights (pi) are trained as logits, not in the softmax space.

    A loss function needs to be constructed with the same output dimension and number of mixtures.
    A sampling function is also provided to sample from the distribution parametrised by the MDN outputs.
35 | """ 36 | 37 | def __init__(self, output_dimension, num_mixtures, **kwargs): 38 | self.output_dim = output_dimension 39 | self.num_mix = num_mixtures 40 | with tf.name_scope('MDN'): 41 | self.mdn_mus = layers.Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation 42 | self.mdn_sigmas = layers.Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas') # mix*output vals exp activation 43 | self.mdn_pi = layers.Dense(self.num_mix, name='mdn_pi') # mix vals, logits 44 | super(MDN, self).__init__(**kwargs) 45 | 46 | def build(self, input_shape): 47 | with tf.name_scope('mus'): 48 | self.mdn_mus.build(input_shape) 49 | with tf.name_scope('sigmas'): 50 | self.mdn_sigmas.build(input_shape) 51 | with tf.name_scope('pis'): 52 | self.mdn_pi.build(input_shape) 53 | super(MDN, self).build(input_shape) 54 | 55 | @property 56 | def trainable_weights(self): 57 | return self.mdn_mus.trainable_weights + self.mdn_sigmas.trainable_weights + self.mdn_pi.trainable_weights 58 | 59 | @property 60 | def non_trainable_weights(self): 61 | return self.mdn_mus.non_trainable_weights + self.mdn_sigmas.non_trainable_weights + self.mdn_pi.non_trainable_weights 62 | 63 | def call(self, x, mask=None): 64 | with tf.name_scope('MDN'): 65 | mdn_out = layers.concatenate([self.mdn_mus(x), 66 | self.mdn_sigmas(x), 67 | self.mdn_pi(x)], 68 | name='mdn_outputs') 69 | return mdn_out 70 | 71 | def compute_output_shape(self, input_shape): 72 | """Returns output shape, showing the number of mixture parameters.""" 73 | return (input_shape[0], (2 * self.output_dim * self.num_mix) + self.num_mix) 74 | 75 | def get_config(self): 76 | config = { 77 | "output_dimension": self.output_dim, 78 | "num_mixtures": self.num_mix 79 | } 80 | base_config = super(MDN, self).get_config() 81 | return dict(list(base_config.items()) + list(config.items())) 82 | 83 | # @classmethod 84 | # def from_config(cls, config): 85 | # return cls(**config) 86 | 87 | 88 | def get_mixture_loss_func(output_dim, num_mixes): 89 | """Construct a loss functions for the MDN layer parametrised by number of mixtures.""" 90 | # Construct a loss function with the right number of mixtures and outputs 91 | def mdn_loss_func(y_true, y_pred): 92 | # Reshape inputs in case this is used in a TimeDistribued layer 93 | y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes], name='reshape_ypreds') 94 | y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue') 95 | # Split the inputs into paramaters 96 | out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim, 97 | num_mixes * output_dim, 98 | num_mixes], 99 | axis=-1, name='mdn_coef_split') 100 | # Construct the mixture models 101 | cat = tfd.Categorical(logits=out_pi) 102 | component_splits = [output_dim] * num_mixes 103 | mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1) 104 | sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1) 105 | coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale 106 | in zip(mus, sigs)] 107 | mixture = tfd.Mixture(cat=cat, components=coll) 108 | loss = mixture.log_prob(y_true) 109 | loss = tf.negative(loss) 110 | loss = tf.reduce_mean(loss) 111 | return loss 112 | 113 | # Actually return the loss function 114 | with tf.name_scope('MDN'): 115 | return mdn_loss_func 116 | 117 | 118 | def get_mixture_sampling_fun(output_dim, num_mixes): 119 | """Construct a TensorFlor sampling operation for the MDN layer parametrised 120 | 


def get_mixture_sampling_fun(output_dim, num_mixes):
    """Construct a TensorFlow sampling operation for the MDN layer parametrised
    by mixtures and output dimension. This can be used in a Keras model to
    generate samples directly."""

    def sampling_func(y_pred):
        # Reshape inputs in case this is used in a TimeDistributed layer
        y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes], name='reshape_ypreds')
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = tfd.Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = tfd.Mixture(cat=cat, components=coll)
        samp = mixture.sample()
        # Todo: temperature adjustment for sampling function.
        return samp

    # Actually return the sampling function
    with tf.name_scope('MDNLayer'):
        return sampling_func


def get_mixture_mse_accuracy(output_dim, num_mixes):
    """Construct an MSE accuracy function for the MDN layer
    that takes one sample and compares to the true value."""
    # Construct a loss function with the right number of mixtures and outputs
    def mse_func(y_true, y_pred):
        # Reshape inputs in case this is used in a TimeDistributed layer
        y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes], name='reshape_ypreds')
        y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue')
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = tfd.Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = tfd.Mixture(cat=cat, components=coll)
        samp = mixture.sample()
        mse = tf.reduce_mean(tf.square(samp - y_true), axis=-1)
        # Todo: temperature adjustment for sampling function.
        return mse

    # Actually return the MSE function
    with tf.name_scope('MDNLayer'):
        return mse_func


def split_mixture_params(params, output_dim, num_mixes):
    """Splits up an array of mixture parameters into mus, sigmas, and pis
    depending on the number of mixtures and output dimension.

    Arguments:
    params -- the parameters of the mixture model
    output_dim -- the dimension of the normal models in the mixture model
    num_mixes -- the number of mixtures represented
    """
    assert len(params) == num_mixes + (output_dim * 2 * num_mixes), "The size of params needs to match the mixture configuration"
    mus = params[:num_mixes * output_dim]
    sigs = params[num_mixes * output_dim:2 * num_mixes * output_dim]
    pi_logits = params[-num_mixes:]
    return mus, sigs, pi_logits


def softmax(w, t=1.0):
    """Softmax function for a list or numpy array of logits. Also adjusts temperature.

    Arguments:
    w -- a list or numpy array of logits

    Keyword arguments:
    t -- the temperature to adjust the distribution (default 1.0)
    """
    e = np.array(w) / t  # adjust temperature
    e -= e.max()  # subtract max to protect from exploding exp values.
    e = np.exp(e)
    dist = e / np.sum(e)
    return dist


def sample_from_categorical(dist):
    """Samples from a categorical model PDF.

    Arguments:
    dist -- the parameters of the categorical model

    Returns:
    One sample from the categorical model, or -1 if sampling fails.
    """
    r = np.random.rand(1)  # uniform random number in [0,1]
    accumulate = 0
    for i in range(0, dist.size):
        accumulate += dist[i]
        if accumulate >= r:
            return i
    tf.get_logger().info('Error sampling categorical model.')
    return -1


def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
    """Sample from an MDN output with temperature adjustment.
    This calculation is done outside of the Keras model using
    Numpy.

    Arguments:
    params -- the parameters of the mixture model
    output_dim -- the dimension of the normal models in the mixture model
    num_mixes -- the number of mixtures represented

    Keyword arguments:
    temp -- the temperature for sampling between mixture components (default 1.0)
    sigma_temp -- the temperature for sampling from the normal distribution (default 1.0)

    Returns:
    One sample from the mixture model, a numpy array of length output_dim
    """
    assert len(params) == num_mixes + (output_dim * 2 * num_mixes), "The size of params needs to match the mixture configuration"
    mus, sigs, pi_logits = split_mixture_params(params, output_dim, num_mixes)
    pis = softmax(pi_logits, t=temp)
    m = sample_from_categorical(pis)
    # Alternative way to sample from categorical:
    # m = np.random.choice(range(len(pis)), p=pis)
    mus_vector = mus[m * output_dim:(m + 1) * output_dim]
    sig_vector = sigs[m * output_dim:(m + 1) * output_dim]
    scale_matrix = np.identity(output_dim) * sig_vector  # scale matrix from diag
    cov_matrix = np.matmul(scale_matrix, scale_matrix.T)  # cov is scale squared.
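    # sigma_temp scales the covariance before sampling: values below 1.0 pull
    # samples in towards the chosen component's mean, values above spread them out.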
    cov_matrix = cov_matrix * sigma_temp  # adjust for sigma temperature
    sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
    return sample[0]

--------------------------------------------------------------------------------
/keras_mdn_layer/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/keras_mdn_layer/tests/__init__.py

--------------------------------------------------------------------------------
/keras_mdn_layer/tests/test_mdn.py:
--------------------------------------------------------------------------------
from tensorflow import keras
import keras_mdn_layer as mdn
import numpy as np


def test_build_mdn():
    """Make sure an MDN model can be constructed"""
    N_HIDDEN = 5
    N_MIXES = 5
    model = keras.Sequential()
    model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
    model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))
    model.add(mdn.MDN(1, N_MIXES))
    model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam())
    assert isinstance(model, keras.Sequential)


def test_number_of_weights():
    """Make sure the number of trainable weights is set up correctly"""
    N_HIDDEN = 5
    N_MIXES = 5
    inputs = keras.layers.Input(shape=(1,))
    x = keras.layers.Dense(N_HIDDEN, activation='relu')(inputs)
    m = mdn.MDN(1, N_MIXES)
    predictions = m(x)
    model = keras.Model(inputs=inputs, outputs=predictions)
    model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam())
    num_mdn_params = np.sum([w.get_shape().num_elements() for w in m.trainable_weights])
    assert (num_mdn_params == 90)


def test_save_mdn():
    """Make sure an MDN model can be saved and loaded"""
    N_HIDDEN = 5
    N_MIXES = 5
    model = keras.Sequential()
    model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
    model.add(mdn.MDN(1, N_MIXES))
    model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam())
    model.save('test_save.keras', overwrite=True)
    m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
    assert isinstance(m_2, keras.Sequential)
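

# A small extra property check for the softmax helper (a sketch, using only
# functions defined in keras_mdn_layer):
def test_softmax_sums_to_one():
    """Temperature-adjusted softmax should return a probability distribution."""
    dist = mdn.softmax(np.array([1.0, 2.0, 3.0]), t=0.5)
    assert np.isclose(np.sum(dist), 1.0)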


def test_output_shapes():
    """Checks that the output shapes on an MDN model end up correct.
    Builds a 1-layer LSTM network to do it."""
    # parameters
    N_HIDDEN = 5
    N_MIXES = 10
    N_DIMENSION = 7

    # set up a 1-layer MDRNN
    inputs = keras.layers.Input(shape=(1, N_DIMENSION))
    lstm_1_state_h_input = keras.layers.Input(shape=(N_HIDDEN,))
    lstm_1_state_c_input = keras.layers.Input(shape=(N_HIDDEN,))
    lstm_1_state_input = [lstm_1_state_h_input, lstm_1_state_c_input]
    lstm_1, state_h_1, state_c_1 = keras.layers.LSTM(N_HIDDEN, return_state=True)(inputs, initial_state=lstm_1_state_input)
    lstm_1_state_output = [state_h_1, state_c_1]
    mdn_out = mdn.MDN(N_DIMENSION, N_MIXES)(lstm_1)
    decoder = keras.Model(inputs=[inputs] + lstm_1_state_input, outputs=[mdn_out] + lstm_1_state_output)

    # create starting input and generate one output
    starting_input = np.zeros((1, 1, N_DIMENSION), dtype=np.float32)
    initial_state = [np.zeros((1, N_HIDDEN), dtype=np.float32), np.zeros((1, N_HIDDEN), dtype=np.float32)]
    output_list = decoder([starting_input] + initial_state)  # run the network
    mdn_parameters = output_list[0][0].numpy()

    # sample from the output to test sampling functions
    generated_sample = mdn.sample_from_output(mdn_parameters, N_DIMENSION, N_MIXES)
    print("Sample shape:", generated_sample.shape)
    print("Sample:", generated_sample)
    # test that the length of the generated sample is the same as N_DIMENSION
    assert len(generated_sample) == N_DIMENSION

--------------------------------------------------------------------------------
/notebooks/MDN-1D-sine-prediction.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixture Density Network\n",
    "\n",
    "Reproducing the classic Bishop MDN network tasks in Keras. The idea in this task is to predict the value of an inverse sine function. This function has multiple real-valued solutions at each point, so the ANN model needs to have the capacity to handle this in its loss function. An MDN is a good way to handle the predictions of these multiple output values.\n",
    "\n",
    "There's a couple of other versions of this task, and this implementation owes much to the following:\n",
    "\n",
    "- [David Ha - Mixture Density Networks with TensorFlow](http://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/)\n",
    "- [Mixture Density Networks in Edward](http://edwardlib.org/tutorials/mixture-density-network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normal imports for everybody\n",
    "from context import * # imports the MDN layer \n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Synthetic Data\n",
    "\n",
    "Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Generating some data:\n",
    "NSAMPLE = 3000\n",
    "\n",
    "y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))\n",
    "r_data = np.random.normal(size=NSAMPLE)\n",
    "x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0\n",
    "x_data = x_data.reshape((NSAMPLE, 1))\n",
    "\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.plot(x_data,y_data,'ro', alpha=0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the MDN Model\n",
    "\n",
    "Now we will construct the MDN model in Keras. This uses the `Sequential` model interface in Keras.\n",
    "\n",
    "The `MDN` layer comes after one or more `Dense` layers. You need to define the output dimension and number of mixtures for the MDN like so: `MDN(output_dimension, number_mixtures)`.\n",
    "\n",
    "For this problem, we only need an output dimension of 1 as we are predicting one value (y). Adding more mixtures adds more parameters (the model is more complex and takes longer to train), but might help make the solutions better. You can see from the training data that there are at most 5 different values to predict at each point on the curve, so setting `N_MIXES = 5` is a good place to start.\n",
    "\n",
    "For MDNs, we have to use a special loss function that can handle the mixture parameters: the function has to take into account the number of output dimensions and mixtures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_HIDDEN = 15\n",
    "N_MIXES = 10\n",
    "\n",
    "model = keras.Sequential()\n",
    "model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))\n",
    "model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))\n",
    "model.add(mdn.MDN(1, N_MIXES))\n",
    "model.compile(loss=mdn.get_mixture_loss_func(1,N_MIXES), optimizer='adam') #, metrics=[mdn.get_mixture_mse_accuracy(1,N_MIXES)])\n",
    "model.summary()"
   ]
  },
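  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick sketch of where the MDN layer's parameter count comes from: each mixture\n",
    "# component has one mean and one sigma per output dimension, plus a single pi logit.\n",
    "print(\"MDN output parameters:\", (2 * 1 * N_MIXES) + N_MIXES)  # (2 * output_dim * num_mixes) + num_mixes"
   ]
  },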
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training the model\n",
    "\n",
    "Now we train the model using Keras' normal `fit` command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = model.fit(x=x_data, y=y_data, batch_size=128, epochs=500, validation_split=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model if you want to.\n",
    "model.save(\"MDN-1D-sine-prediction-model.keras\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model if you want to.\n",
    "# To load models from file, you need to supply the layer and loss function as custom_objects:\n",
    "model = keras.models.load_model('MDN-1D-sine-prediction-model.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and Validation Loss\n",
    "\n",
    "It's interesting to see how the model trained. We can see that after a certain point training is rather slow.\n",
    "\n",
    "For this problem a loss value around 1.5 produces quite good results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 5))\n",
    "plt.ylim([0,9])\n",
    "plt.plot(history.history['loss'])\n",
    "plt.plot(history.history['val_loss'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling Functions\n",
    "\n",
    "The MDN model outputs parameters of a mixture model---a list of means (mu), variances (sigma), and weights (pi).\n",
    "\n",
    "The `mdn` module provides a function to sample from these parameters as follows. First the parameters are split up into `mu`s, `sigma`s and `pi`s, then the categorical distribution formed by the `pi`s is sampled to choose which mixture component should be sampled, then that component's `mu`s and `sigma`s are used to sample from a multivariate normal model. Here's the core of the code (lightly condensed from `keras_mdn_layer/__init__.py`):\n",
    "\n",
    "    def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):\n",
    "        \"\"\"Sample from an MDN output with temperature adjustment.\"\"\"\n",
    "        mus, sigs, pi_logits = split_mixture_params(params, output_dim, num_mixes)\n",
    "        pis = softmax(pi_logits, t=temp)\n",
    "        m = sample_from_categorical(pis)\n",
    "        # Alternative way to sample from categorical:\n",
    "        # m = np.random.choice(range(len(pis)), p=pis)\n",
    "        mus_vector = mus[m*output_dim:(m+1)*output_dim]\n",
    "        sig_vector = sigs[m*output_dim:(m+1)*output_dim]\n",
    "        scale_matrix = np.identity(output_dim) * sig_vector  # scale matrix from diag\n",
    "        cov_matrix = np.matmul(scale_matrix, scale_matrix.T)  # cov is scale squared\n",
    "        cov_matrix = cov_matrix * sigma_temp  # adjust for sigma temperature\n",
    "        sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)\n",
    "        return sample[0]\n",
    " \n",
    "If you only have one prediction to sample from, you can use the function as is; but if you need to sample from a lot of predictions at once (as in the following sections), you can use `np.apply_along_axis` to apply it to a whole numpy array of predicted parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try out the MDN Model\n",
    "\n",
    "Now we try out the model by making predictions at 3000 evenly spaced points on the x-axis. \n",
    "\n",
    "Mixture models output lists of parameters, so we're going to sample from these parameters for each point on the x-axis, and also try plotting the parameters themselves so we can have some insight into what the model is learning!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sample on some test data:\n",
    "x_test = np.float32(np.arange(-15,15,0.01))\n",
    "NTEST = x_test.size\n",
    "print(\"Testing:\", NTEST, \"samples.\")\n",
    "x_test = x_test.reshape(NTEST,1) # needs to be a matrix, not a vector\n",
    "\n",
    "# Make predictions from the model\n",
    "y_test = model.predict(x_test)\n",
    "# y_test contains parameters for distributions, not actual points on the graph.\n",
    "# To find points on the graph, we need to sample from each distribution.\n",
    "\n",
    "# Sample from the predicted distributions\n",
    "y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES,temp=1.0)\n",
    "\n",
    "# Split up the mixture parameters (for future fun)\n",
    "mus = np.apply_along_axis((lambda a: a[:N_MIXES]),1, y_test)\n",
    "sigs = np.apply_along_axis((lambda a: a[N_MIXES:2*N_MIXES]),1, y_test)\n",
    "pis = np.apply_along_axis((lambda a: mdn.softmax(a[2*N_MIXES:])),1, y_test)"
   ]
  },
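  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A sketch of temperature-adjusted sampling: temp sharpens the choice of mixture\n",
    "# component, and sigma_temp scales the component variances (both default to 1.0).\n",
    "y_cold = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES, temp=0.1, sigma_temp=0.5)\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.plot(x_data, y_data, 'ro', x_test, y_cold[:,0], 'bo', alpha=0.3)\n",
    "plt.show()"
   ]
  },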
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the samples\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.plot(x_data,y_data,'ro', x_test, y_samples[:,0], 'bo',alpha=0.3)\n",
    "plt.show()\n",
    "# These look pretty good!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the means - this gives us some insight into how the model learns to produce the mixtures.\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.plot(x_data,y_data,'ro', x_test, mus,'bo',alpha=0.3)\n",
    "plt.show()\n",
    "# Cool!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's plot the variances and weightings of the means as well.\n",
    "fig = plt.figure(figsize=(8, 8))\n",
    "ax1 = fig.add_subplot(111)\n",
    "ax1.scatter(x_data,y_data,marker='o', c='r', alpha=0.3)\n",
    "for i in range(N_MIXES):\n",
    "    ax1.scatter(x_test, mus[:,i], marker='o', s=200*sigs[:,i]*pis[:,i],alpha=0.3)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

--------------------------------------------------------------------------------
/notebooks/MDN-2D-spiral-prediction.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2D Mixture Density Network\n",
    "\n",
    "An extension of Bishop's classic MDN prediction task to two dimensions.\n",
    "\n",
    "The idea in this task is to predict the values of two inverse sine functions simultaneously. These functions have multiple real-valued solutions at each point, so the ANN model needs to have the capacity to handle this in its loss function. An MDN is a good way to handle the predictions of these multiple output values.\n",
    "\n",
    "This implementation owes much to the following:\n",
    "\n",
    "- [David Ha - Mixture Density Networks with TensorFlow](http://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/)\n",
    "- [Mixture Density Networks in Edward](http://edwardlib.org/tutorials/mixture-density-network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "from context import * # imports the MDN layer \n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D \n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Synthetic Data\n",
    "\n",
    "Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Generating some data:\n",
    "NSAMPLE = 5000\n",
    "\n",
    "z_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))\n",
    "r_data = np.random.normal(size=NSAMPLE)\n",
    "s_data = np.random.normal(size=NSAMPLE)\n",
    "x_data = np.sin(0.75 * z_data) * 7.0 + z_data * 0.5 + r_data * 1.0\n",
    "y_data = np.cos(0.80 * z_data) * 6.5 + z_data * 0.5 + s_data * 1.0\n",
    "\n",
    "x_input = z_data.reshape((NSAMPLE, 1))\n",
    "y_input = np.array([x_data,y_data])\n",
    "y_input = y_input.T #reshape to (NSAMPLE,2)\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "ax.scatter(x_data, y_data, z_data, alpha=0.3, c='r')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the MDN Model\n",
    "\n",
    "Now we will construct the MDN model in Keras. This uses the `Sequential` model interface in Keras.\n",
    "\n",
    "The `MDN` layer comes after one or more `Dense` layers. You need to define the output dimension and number of mixtures for the MDN like so: `MDN(output_dimension, number_mixtures)`.\n",
    "\n",
    "For this problem, we need an output dimension of 2 as we are predicting two values (x and y) at once. Adding more mixtures adds more parameters (the model is more complex and takes longer to train), but might help make the solutions better. You can see from the training data that there are at most 5 different values to predict at each point on the curve, so setting `N_MIXES = 5` is a good place to start.\n",
    "\n",
    "For MDNs, we have to use a special loss function that can handle the mixture parameters: the function has to take into account the number of output dimensions and mixtures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "N_HIDDEN = 15\n",
    "N_MIXES = 10\n",
    "OUTPUT_DIMS = 2\n",
    "\n",
    "model = keras.Sequential()\n",
    "model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))\n",
    "model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))\n",
    "model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))\n",
    "model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS,N_MIXES), optimizer='adam')\n",
    "model.summary()"
   ]
  },
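  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: with OUTPUT_DIMS = 2 each mixture component now has two means and two\n",
    "# sigmas, plus one pi logit, so the MDN layer outputs more parameters than in 1D.\n",
    "print(\"MDN output parameters:\", (2 * OUTPUT_DIMS * N_MIXES) + N_MIXES)"
   ]
  },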
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {
86 | "scrolled": true
87 | },
88 | "outputs": [],
89 | "source": [
90 | "N_HIDDEN = 15\n",
91 | "N_MIXES = 10\n",
92 | "OUTPUT_DIMS = 2\n",
93 | "\n",
94 | "model = keras.Sequential()\n",
95 | "model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))\n",
96 | "model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))\n",
97 | "model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))\n",
98 | "model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS,N_MIXES), optimizer='adam')\n",
99 | "model.summary()"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "### Training the model\n",
107 | "\n",
108 | "Now we train the model using Keras' normal `fit` command."
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "history = model.fit(x=x_input, y=y_input, batch_size=128, epochs=200, validation_split=0.15, callbacks=[keras.callbacks.TerminateOnNaN()])"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "### Training and Validation Loss\n",
125 | "\n",
126 | "It's interesting to see how the model trained. We can see that after a certain point improvement in training is rather slow.\n",
127 | "\n",
128 | "For this problem a loss value around 3.0 produces quite good results."
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "plt.figure(figsize=(10, 5))\n",
138 | "plt.ylim([0,9])\n",
139 | "plt.plot(history.history['loss'])\n",
140 | "plt.plot(history.history['val_loss'])\n",
141 | "plt.show()"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "## Try out the MDN Model\n",
149 | "\n",
150 | "Now we try out the model by making predictions at 600 evenly spaced points on the x-axis. \n",
151 | "\n",
152 | "Mixture models output lists of parameters, so we're going to sample from these parameters for each point on the x-axis, and also try plotting the parameters themselves so we can have some insight into what the model is learning!
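153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "*(Added illustration, not in the original notebook.)* The MDN layer packs its parameters as `[mus..., sigmas..., pis]`. As a tiny, self-contained demonstration of `mdn.sample_from_output` and of that layout, the next added cell samples once from a hand-made all-zero parameter vector: a degenerate mixture with all means at zero, zero spread, and uniform weights. The real predictions are sampled in the cell after it."
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "# Added demonstration of the MDN parameter layout: [mus, sigmas, pis].\n",
169 | "fake_params = np.zeros(2 * N_MIXES * OUTPUT_DIMS + N_MIXES)\n",
170 | "one_sample = mdn.sample_from_output(fake_params, OUTPUT_DIMS, N_MIXES)\n",
171 | "print(\"One sample from the zero-parameter mixture:\", one_sample)  # approximately [[0. 0.]]"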
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "## Sample on some test data:\n",
162 | "x_test = np.float32(np.arange(-15,15,0.05))\n",
163 | "NTEST = x_test.size\n",
164 | "\n",
165 | "print(\"Testing:\", NTEST, \"samples.\")\n",
166 | "x_test_pred = x_test.reshape(NTEST,1) # needs to be a matrix for predictions, but a vector for display\n",
167 | "\n",
168 | "# Make predictions from the model\n",
169 | "y_test = model.predict(x_test_pred)\n",
170 | "# y_test contains parameters for distributions, not actual points on the graph.\n",
171 | "# To find points on the graph, we need to sample from each distribution.\n",
172 | "\n",
173 | "# Split up the mixture parameters (for future fun)\n",
174 | "mus = np.apply_along_axis((lambda a: a[:N_MIXES*OUTPUT_DIMS]), 1, y_test)\n",
175 | "sigs = np.apply_along_axis((lambda a: a[N_MIXES*OUTPUT_DIMS:2*N_MIXES*OUTPUT_DIMS]), 1, y_test)\n",
176 | "pis = np.apply_along_axis((lambda a: mdn.softmax(a[-N_MIXES:])), 1, y_test)\n",
177 | "\n",
178 | "# Sample from the predicted distributions\n",
179 | "y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0, sigma_temp=1.0)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "# Plot the predicted samples.\n",
189 | "fig = plt.figure(figsize=(8, 8))\n",
190 | "ax = fig.add_subplot(111, projection='3d')\n",
191 | "ax.scatter(x_data, y_data, z_data, alpha=0.05, c='r')\n",
192 | "ax.scatter(y_samples.T[0], y_samples.T[1], x_test, alpha=0.2, c='b')\n",
193 | "plt.show()"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "# Plot the means - this gives us some insight into how the model learns to produce the mixtures.\n",
203 | "# Cool!\n",
204 | "\n",
205 | "# Plot the data and the predicted samples for context.\n",
206 | "fig = plt.figure(figsize=(8, 8))\n",
207 | "ax = fig.add_subplot(111, projection='3d')\n",
208 | "ax.scatter(x_data, y_data, z_data, alpha=0.1, c='r')\n",
209 | "ax.scatter(y_samples.T[0], y_samples.T[1], x_test, alpha=0.1, c='b')\n",
210 | "for m in range(N_MIXES):\n",
211 | "    # the means for mixture m are in columns 2*m (x) and 2*m+1 (y) of mus\n",
212 | "    ax.scatter(mus[:,2*m], mus[:,2*m + 1] , x_test, marker='o',alpha=0.3)\n",
213 | "plt.show()"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "# Let's plot the variances and weightings of the means as well.\n",
223 | "\n",
224 | "# Plot the data for context.\n",
225 | "fig = plt.figure(figsize=(8, 8))\n",
226 | "ax = fig.add_subplot(111, projection='3d')\n",
227 | "ax.scatter(x_data, y_data, z_data, alpha=0.1, c='r')\n",
228 | "for m in range(N_MIXES):\n",
229 | "    # marker size scales with this mixture's sigma and weight at each point\n",
230 | "    ax.scatter(mus[:,2*m], mus[:,2*m + 1] , x_test, s=100*sigs[:,2*m]*pis[:,m], marker='o',alpha=0.3)\n",
231 | "plt.show()"
232 | ]
233 | }
234 | ],
235 | "metadata": {
236 | "kernelspec": {
237 | "display_name": "Python 3 (ipykernel)",
238 | "language": "python",
239 | "name": "python3"
240 | },
241 | "language_info": {
242 | "codemirror_mode": {
243 | "name": "ipython",
244 | 
"version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.11.3" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 4 256 | } 257 | -------------------------------------------------------------------------------- /notebooks/MDN-RNN-RoboJam-touch-generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "slideshow": { 7 | "slide_type": "-" 8 | } 9 | }, 10 | "source": [ 11 | "# Generating touch data in time: RoboJam\n", 12 | "\n", 13 | "- Let's look at some \"touch screen music\" data.\n", 14 | "- E.g., music made with the \"MicroJam\" app:\n", 15 | "\n", 16 | "![MicroJam Performance](https://media.giphy.com/media/XWIkHvErZtPG0/giphy.gif)\n", 17 | "\n", 18 | "These performances are sequences of x and y locations, as well as time!\n", 19 | "\n", 20 | "In MicroJam, you can \"reply\" to performances made by other users; but what if you don't have any friends??? (Too much time spent on Neural Networks). \n", 21 | "\n", 22 | "Let's make \"RoboJam\", a system to respond automatically to performances.\n", 23 | "\n", 24 | "- We need to predict x and y values---as well as time!\n", 25 | "- So, we use a 3 dimensional MDN-RNN.\n", 26 | "\n", 27 | "The idea is to generate \"responses\" to an existing touchscreen sequence. Here's some examples of what it will do:\n", 28 | "\n", 29 | "![Robojam Model Examples](https://preview.ibb.co/mpfa9T/robojam_examples.jpg)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "from tensorflow import keras\n", 39 | "from tensorflow.keras.layers import Dense, Input\n", 40 | "import numpy as np\n", 41 | "import tensorflow as tf\n", 42 | "import math\n", 43 | "import h5py\n", 44 | "import random\n", 45 | "import time\n", 46 | "import pandas as pd\n", 47 | "from context import * # imports MDN\n", 48 | "%matplotlib widget\n", 49 | "import matplotlib.pyplot as plt\n", 50 | "\n", 51 | "\n", 52 | "input_colour = 'darkblue'\n", 53 | "gen_colour = 'firebrick'" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## First, download the dataset\n", 61 | "\n", 62 | "- This is a set of microjam performances collected from our lab" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# Download microjam performance data if needed.\n", 72 | "import urllib.request\n", 73 | "url = 'https://github.com/cpmpercussion/creative-prediction-datasets/raw/main/datasets/TinyPerformanceCorpus.h5'\n", 74 | "urllib.request.urlretrieve(url, './TinyPerformanceCorpus.h5') " 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Helper functions for touchscreen performances\n", 82 | "\n", 83 | "We need a few helper functions for managing performances:\n", 84 | " \n", 85 | "- Convert performances to and from pandas dataframes.\n", 86 | "- Generate random touches.\n", 87 | "- Sample whole performances from scratch and from a priming performance.\n", 88 | "- Plot performances including dividing into swipes." 
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "SCALE_FACTOR = 1\n",
98 | "\n",
99 | "def perf_df_to_array(perf_df, include_moving=False):\n",
100 | "    \"\"\"Converts a dataframe of a performance into array a,b,dt format.\"\"\"\n",
101 | "    perf_df['dt'] = perf_df.time.diff()\n",
102 | "    perf_df.dt = perf_df.dt.fillna(0.0)\n",
103 | "    # Clean performance data\n",
104 | "    # Tiny Performance bounds are defined to be in [[0,1],[0,1]]; clip out-of-range values to enforce this.\n",
105 | "    perf_df.at[perf_df[perf_df.dt > 5].index, 'dt'] = 5.0\n",
106 | "    perf_df.at[perf_df[perf_df.dt < 0].index, 'dt'] = 0.0\n",
107 | "    perf_df.at[perf_df[perf_df.x > 1].index, 'x'] = 1.0\n",
108 | "    perf_df.at[perf_df[perf_df.x < 0].index, 'x'] = 0.0\n",
109 | "    perf_df.at[perf_df[perf_df.y > 1].index, 'y'] = 1.0\n",
110 | "    perf_df.at[perf_df[perf_df.y < 0].index, 'y'] = 0.0\n",
111 | "    if include_moving:\n",
112 | "        output = np.array(perf_df[['x', 'y', 'dt', 'moving']])\n",
113 | "    else:\n",
114 | "        output = np.array(perf_df[['x', 'y', 'dt']])\n",
115 | "    return output\n",
116 | "\n",
117 | "def perf_array_to_df(perf_array):\n",
118 | "    \"\"\"Converts an array of a performance (a,b,dt(,moving) format) into a dataframe.\"\"\"\n",
119 | "    perf_array = perf_array.T\n",
120 | "    perf_df = pd.DataFrame({'x': perf_array[0], 'y': perf_array[1], 'dt': perf_array[2]})\n",
121 | "    if len(perf_array) == 4:\n",
122 | "        perf_df['moving'] = perf_array[3]\n",
123 | "    else:\n",
124 | "        # As a rule of thumb, classify touches with dt>0.1 as taps and dt<0.1 as moving touches.\n",
125 | "        perf_df['moving'] = perf_df['dt'].apply(lambda dt: 0 if dt > 0.1 else 1)\n",
126 | "    perf_df['time'] = perf_df.dt.cumsum()\n",
127 | "    perf_df['z'] = 38.0\n",
128 | "    perf_df = perf_df.set_index(['time'])\n",
129 | "    return perf_df[['x', 'y', 'z', 'moving']]\n",
130 | "\n",
131 | "\n",
132 | "def random_touch(with_moving=False):\n",
133 | "    \"\"\"Generate a random tiny performance touch.\"\"\"\n",
134 | "    if with_moving:\n",
135 | "        return np.array([np.random.rand(), np.random.rand(), 0.01, 0])\n",
136 | "    else:\n",
137 | "        return np.array([np.random.rand(), np.random.rand(), 0.01])\n",
138 | "\n",
139 | "\n",
140 | "def constrain_touch(touch, with_moving=False):\n",
141 | "    \"\"\"Constrain touch values from the MDRNN\"\"\"\n",
142 | "    touch[0] = min(max(touch[0], 0.0), 1.0) # x in [0,1]\n",
143 | "    touch[1] = min(max(touch[1], 0.0), 1.0) # y in [0,1]\n",
144 | "    touch[2] = max(touch[2], 0.001) # dt: enforce a minimum time step\n",
145 | "    if with_moving:\n",
146 | "        touch[3] = np.greater(touch[3], 0.5) * 1.0\n",
147 | "    return touch\n",
148 | "\n",
149 | "\n",
150 | "def generate_random_tiny_performance(model, n_mixtures, first_touch, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):\n",
151 | "    \"\"\"Generates a tiny performance up to 5 seconds in length.\"\"\"\n",
152 | "    if predict_moving:\n",
153 | "        out_dim = 4\n",
154 | "    else:\n",
155 | "        out_dim = 3\n",
156 | "    time = 0\n",
157 | "    steps = 0\n",
158 | "    previous_touch = first_touch\n",
159 | "    performance = [previous_touch.reshape((out_dim,))]\n",
160 | "    while (steps < steps_limit and time < time_limit):\n",
161 | "        net_output = model(previous_touch.reshape(1,1,out_dim) * SCALE_FACTOR)\n",
162 | "        mdn_params = net_output[0].numpy()\n",
163 | "        previous_touch = mdn.sample_from_output(mdn_params, out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
164 | "        output_touch = previous_touch.reshape(out_dim,)\n",
165 | "        output_touch = constrain_touch(output_touch, with_moving=predict_moving)\n",
166 | "        performance.append(output_touch.reshape((out_dim,)))\n",
167 | "        steps += 1\n",
168 | "        time += output_touch[2]\n",
169 | "    return np.array(performance)\n",
170 | "\n",
171 | "\n",
172 | "def condition_and_generate(model, perf, n_mixtures, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):\n",
173 | "    \"\"\"Conditions the network on an existing tiny performance, then generates a new one.\"\"\"\n",
174 | "    if predict_moving:\n",
175 | "        out_dim = 4\n",
176 | "    else:\n",
177 | "        out_dim = 3\n",
178 | "    time = 0\n",
179 | "    steps = 0\n",
180 | "    # condition\n",
181 | "    for touch in perf:\n",
182 | "        net_output = model(touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
183 | "        mdn_params = net_output[0].numpy()\n",
184 | "        previous_touch = mdn.sample_from_output(mdn_params, out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
185 | "    output = [previous_touch.reshape((out_dim,))]\n",
186 | "    # generate\n",
187 | "    while (steps < steps_limit and time < time_limit):\n",
188 | "        net_output = model(previous_touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
189 | "        mdn_params = net_output[0].numpy()\n",
190 | "        previous_touch = mdn.sample_from_output(mdn_params, out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
191 | "        output_touch = previous_touch.reshape(out_dim,)\n",
192 | "        output_touch = constrain_touch(output_touch, with_moving=predict_moving)\n",
193 | "        output.append(output_touch.reshape((out_dim,)))\n",
194 | "        steps += 1\n",
195 | "        time += output_touch[2]\n",
196 | "    net_output = np.array(output)\n",
197 | "    return net_output\n",
198 | "\n",
199 | "\n",
200 | "def divide_performance_into_swipes(perf_df):\n",
201 | "    \"\"\"Divides a performance into a sequence of swipe dataframes for plotting.\"\"\"\n",
202 | "    touch_starts = perf_df[perf_df.moving == 0].index\n",
203 | "    performance_swipes = []\n",
204 | "    remainder = perf_df\n",
205 | "    for att in touch_starts:\n",
206 | "        swipe = remainder.iloc[remainder.index < att]\n",
207 | "        performance_swipes.append(swipe)\n",
208 | "        remainder = remainder.iloc[remainder.index >= att]\n",
209 | "    performance_swipes.append(remainder)\n",
210 | "    return performance_swipes\n",
211 | "\n",
212 | "\n",
213 | "input_colour = \"#4388ff\"\n",
214 | "gen_colour = \"#ec0205\"\n",
215 | "\n",
216 | "def plot_perf_on_ax(perf_df, ax, color=\"#ec0205\", linewidth=3, alpha=0.5):\n",
217 | "    \"\"\"Plot a 2D representation of a performance.\"\"\"\n",
218 | "    swipes = divide_performance_into_swipes(perf_df)\n",
219 | "    for swipe in swipes:\n",
220 | "        p = ax.plot(swipe.x, swipe.y, 'o-', alpha=alpha, markersize=linewidth)\n",
221 | "        plt.setp(p, color=color, linewidth=linewidth)\n",
222 | "    ax.set_ylim([1.0,0])\n",
223 | "    ax.set_xlim([0,1.0])\n",
224 | "    ax.set_xticks([])\n",
225 | "    ax.set_yticks([])\n",
226 | "\n",
227 | "def plot_2D(perf_df, name=\"foo\", saving=False, figsize=(5, 5)):\n",
228 | "    \"\"\"Plot a 2D representation of a performance.\"\"\"\n",
229 | "    fig, ax = plt.subplots(figsize=figsize)\n",
230 | "    plot_perf_on_ax(perf_df, ax, color=gen_colour, linewidth=5, alpha=0.7)\n",
231 | "    if saving:\n",
232 | "        fig.savefig(name+\".png\", bbox_inches='tight')\n",
233 | "\n",
234 | "def plot_double_2d(perf1, perf2, name=\"foo\", saving=False, figsize=(8, 8)):\n",
235 | "    \"\"\"Plot two performances in 2D\"\"\"\n",
236 | "    fig, ax = plt.subplots(figsize=figsize)\n",
237 | "    plot_perf_on_ax(perf1, ax, color=input_colour, linewidth=5, alpha=0.7)\n",
238 | "    plot_perf_on_ax(perf2, ax, color=gen_colour, linewidth=5, alpha=0.7)\n",
239 | "    if saving:\n",
240 | "        fig.savefig(name+\".png\", bbox_inches='tight')\n",
241 | "    \n",
242 | "# fig, ax = plt.subplots(figsize=(5, 5))\n",
243 | "# plot_perf_on_ax(perf_array_to_df(p), ax, color=\"#ec0205\", linewidth=4, alpha=0.7)\n",
244 | "# fig.show()"
245 | ]
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "metadata": {},
250 | "source": [
251 | "# Load up the Dataset:\n",
252 | "\n",
253 | "The dataset consists of around 1000 5-second performances from the MicroJam app.\n",
254 | "\n",
255 | "This is a sequence of points consisting of an x-location, a y-location, and a time-delta from the previous point.\n",
256 | "\n",
257 | "When the user swipes, the time-delta is very small; when they tap, it's quite large.\n",
258 | "\n",
259 | "Let's have a look at some of the data:"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "metadata": {},
266 | "outputs": [],
267 | "source": [
268 | "# Load Data\n",
269 | "microjam_data_file_name = \"./TinyPerformanceCorpus.h5\"\n",
270 | "\n",
271 | "with h5py.File(microjam_data_file_name, 'r') as data_file:\n",
272 | "    microjam_corpus = data_file['total_performances'][:]\n",
273 | "\n",
274 | "print(\"Corpus data points between 100 and 120:\")\n",
275 | "print(perf_array_to_df(microjam_corpus[100:120]))\n",
276 | "\n",
277 | "print(\"Some statistics about the dataset:\")\n",
278 | "pd.DataFrame(microjam_corpus).describe()"
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "- This time, the X and Y locations are *not* differences but absolute values; the time is still a delta value.\n",
286 | "- The data doesn't have a \"pen up\" value, but we can just treat touches with dt>0.1 as taps and dt<0.1 as moving touches."
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "metadata": {},
293 | "outputs": [],
294 | "source": [
295 | "# Plot a bit of the data to have a look:\n",
296 | "plot_2D(perf_array_to_df(microjam_corpus[100:200]))"
297 | ]
298 | },
299 | {
300 | "cell_type": "markdown",
301 | "metadata": {},
302 | "source": [
303 | "## MDN RNN\n",
304 | "\n",
305 | "- Now we're going to build an MDN-RNN to predict MicroJam data.\n",
306 | "- The architecture will be:\n",
307 | "  - 3 inputs (x, y, dt)\n",
308 | "  - 2 layers of 128 LSTM units each (`HIDDEN_UNITS` below)\n",
309 | "  - MDN Layer with 3 dimensions and 5 mixtures.\n",
310 | "  - Training model will have a sequence length of 30 (prediction model: 1 in, 1 out)\n",
311 | "  \n",
312 | "![RoboJam MDN RNN Model](https://preview.ibb.co/cKZk9T/robojam_mdn_diagram.png)\n",
313 | "  \n",
314 | "- Here's the model parameters and training data preparation. \n",
315 | "- We end up with 172K training examples."
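316 | ]
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "metadata": {},
321 | "source": [
322 | "*(Added sketch, not in the original notebook.)* The data preparation below slides a window of length `SEQ_LEN + 1` over the corpus: the first `SEQ_LEN` points of each window become one input example and the final point becomes its target. A toy version of that slicing, on a stand-in sequence, looks like this:"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "# Added toy example of the sliding-window slicing used below (window = 3 inputs + 1 target).\n",
332 | "toy_sequence = np.arange(6)  # stands in for a sequence of touch events\n",
333 | "toy_windows = [toy_sequence[i:i + 4] for i in range(len(toy_sequence) - 4 - 1)]\n",
334 | "for w in toy_windows:\n",
335 | "    print(\"inputs:\", w[:-1], \"target:\", w[-1])"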
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "# Training Hyperparameters:\n",
325 | "SEQ_LEN = 30\n",
326 | "BATCH_SIZE = 256\n",
327 | "HIDDEN_UNITS = 128\n",
328 | "EPOCHS = 10  # short demo run; the full model described below was trained for 100\n",
329 | "VAL_SPLIT=0.15\n",
330 | "\n",
331 | "# Set random seed for reproducibility\n",
332 | "SEED = 2345 \n",
333 | "random.seed(SEED)\n",
334 | "np.random.seed(SEED)\n",
335 | "\n",
336 | "def slice_sequence_examples(sequence, num_steps):\n",
337 | "    xs = []\n",
338 | "    for i in range(len(sequence) - num_steps - 1):\n",
339 | "        example = sequence[i: i + num_steps]\n",
340 | "        xs.append(example)\n",
341 | "    return xs\n",
342 | "\n",
343 | "def seq_to_singleton_format(examples):\n",
344 | "    xs = []\n",
345 | "    ys = []\n",
346 | "    for ex in examples:\n",
347 | "        xs.append(ex[:-1])\n",
348 | "        ys.append(ex[-1])\n",
349 | "    return (xs,ys)\n",
350 | "\n",
351 | "sequences = slice_sequence_examples(microjam_corpus, SEQ_LEN+1)\n",
352 | "print(\"Total training examples:\", len(sequences))\n",
353 | "X, y = seq_to_singleton_format(sequences)\n",
354 | "X = np.array(X)\n",
355 | "y = np.array(y)\n",
356 | "print(\"X:\", X.shape, \"y:\", y.shape)"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {},
362 | "source": [
363 | "Now let's set up the model:"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": null,
369 | "metadata": {},
370 | "outputs": [],
371 | "source": [
372 | "OUTPUT_DIMENSION = 3\n",
373 | "NUMBER_MIXTURES = 5\n",
374 | "\n",
375 | "inputs = keras.layers.Input(shape=(SEQ_LEN,OUTPUT_DIMENSION))\n",
376 | "lstm_1 = keras.layers.LSTM(HIDDEN_UNITS, return_sequences=True)(inputs)\n",
377 | "lstm_2 = keras.layers.LSTM(HIDDEN_UNITS)(lstm_1)\n",
378 | "mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)(lstm_2)\n",
379 | "\n",
380 | "model = keras.Model(inputs=inputs, outputs=mdn_out, name=\"robojam-training\")\n",
381 | "model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer='adam')\n",
382 | "model.summary()"
383 | ]
384 | },
385 | {
386 | "cell_type": "markdown",
387 | "metadata": {},
388 | "source": [
389 | "## Training\n",
390 | "\n",
391 | "- Now we do the training:\n",
392 | "  - batch size of 256\n",
393 | "  - 100 epochs for the full run (`EPOCHS = 10` above keeps this demo short).\n",
394 | "- The full run takes about 110s per epoch on Google Colab (around 3 hours to train).\n",
395 | "\n",
396 | "Here's the training and validation loss from my training run:\n",
397 | "\n",
398 | "![Training and validation loss](figures/robojam-mdn-loss.png)\n",
399 | "\n",
400 | "- The validation loss (in green) tends to jump around a lot in the early epochs.\n",
401 | "- The training loss (in blue) has one strange jump around epoch 15, but then reduces quickly at about epoch 20.\n",
402 | "- It seems to still be learning; this model could probably be improved further."
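403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "metadata": {},
408 | "source": [
409 | "*(Added note, not in the original notebook.)* A quick bit of arithmetic on the full training run described above: 110 seconds per epoch for 100 epochs comes to roughly three hours, which is why the demo default is a much shorter run."
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "metadata": {},
416 | "outputs": [],
417 | "source": [
418 | "# Added arithmetic: estimated wall-clock time for the full 100-epoch run.\n",
419 | "seconds_per_epoch = 110  # as reported above for Google Colab\n",
420 | "full_epochs = 100\n",
421 | "print(\"Approx. full training time: %.1f hours\" % (seconds_per_epoch * full_epochs / 3600.0))"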
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "metadata": {},
409 | "outputs": [],
410 | "source": [
411 | "# Train the model\n",
412 | "\n",
413 | "# Define callbacks\n",
414 | "filepath=\"robojam_mdrnn-E{epoch:02d}-VL{val_loss:.2f}.keras\"\n",
415 | "checkpoint = keras.callbacks.ModelCheckpoint(filepath, verbose=1, save_best_only=True, mode='min')\n",
416 | "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)\n",
417 | "callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]\n",
418 | "\n",
419 | "history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_split=VAL_SPLIT)\n",
420 | "\n",
421 | "# Save the Model\n",
422 | "model.save('robojam-mdrnn.keras')\n",
423 | "\n",
424 | "# Plot the loss\n",
425 | "%matplotlib inline\n",
426 | "plt.figure(figsize=(10, 5))\n",
427 | "plt.plot(history.history['loss'])\n",
428 | "plt.plot(history.history['val_loss'])\n",
429 | "plt.xlabel(\"epochs\")\n",
430 | "plt.ylabel(\"loss\")\n",
431 | "plt.show()"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {},
437 | "source": [
438 | "# Try out the model\n",
439 | "\n",
440 | "- Let's try out the model\n",
441 | "- First we will load up a decoding model with a sequence length of 1.\n",
442 | "- The weights are loaded from the trained model file."
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": null,
448 | "metadata": {},
449 | "outputs": [],
450 | "source": [
451 | "# Stateful Decoding Model\n",
452 | "inputs = keras.layers.Input(batch_input_shape=(1,1,OUTPUT_DIMENSION))\n",
453 | "lstm_1 = keras.layers.LSTM(HIDDEN_UNITS, return_sequences=True, stateful=True)(inputs)\n",
454 | "lstm_2 = keras.layers.LSTM(HIDDEN_UNITS, stateful=True)(lstm_1)\n",
455 | "mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)(lstm_2)\n",
456 | "\n",
457 | "decoder = keras.Model(inputs=inputs, outputs=mdn_out, name=\"robojam-generating\")\n",
458 | "decoder.summary()\n",
459 | "decoder.load_weights(\"robojam-mdrnn.keras\")"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "metadata": {},
465 | "source": [
466 | "Plotting some conditioned performances.\n",
467 | "\n",
468 | "This model seems to work best with a very low temperature for sampling from the Gaussian elements (`sigma_temp=0.05`) and a temperature for choosing between mixtures (pi-temperature) of around 1.0 to 1.5 (the cell below uses `temp=1.5`)."
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "length = 100\n",
478 | "t = random.randint(0,len(microjam_corpus)-length)\n",
479 | "ex =  microjam_corpus[t:t+length]  #sequences[600]\n",
480 | "\n",
481 | "decoder.reset_states()\n",
482 | "p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=1.5, sigma_temp=0.05)\n",
483 | "plot_double_2d(perf_array_to_df(ex), perf_array_to_df(p), figsize=(4,4))"
484 | ]
485 | },
486 | {
487 | "cell_type": "markdown",
488 | "metadata": {},
489 | "source": [
490 | "We can also generate unconditioned performances from a random starting point."
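491 | ]
492 | },
493 | {
494 | "cell_type": "markdown",
495 | "metadata": {},
496 | "source": [
497 | "*(Added experiment, not in the original notebook.)* Sampling temperature is worth exploring before settling on values: `temp` sharpens or flattens the mixture weights, while `sigma_temp` scales the Gaussians' spread. This small added sweep generates one unconditioned performance at each of three sigma temperatures; higher values give visibly noisier touch traces. The single-performance version follows in the next cell."
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "# Added sweep over sigma_temp to show its effect on generated performances.\n",
507 | "for st in [0.01, 0.1, 0.5]:\n",
508 | "    decoder.reset_states()\n",
509 | "    swept = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, random_touch(), temp=1.2, sigma_temp=st)\n",
510 | "    plot_2D(perf_array_to_df(swept), figsize=(3, 3))"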
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "metadata": {
497 | "scrolled": true
498 | },
499 | "outputs": [],
500 | "source": [
501 | "decoder.reset_states()\n",
502 | "t = random_touch()\n",
503 | "p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=1.2, sigma_temp=0.01)\n",
504 | "plot_2D(perf_array_to_df(p), figsize=(4,4))"
505 | ]
506 | },
507 | {
508 | "cell_type": "markdown",
509 | "metadata": {},
510 | "source": [
511 | "# RoboJam in Practice!\n",
512 | "\n",
513 | "- RoboJam is designed for use in a mobile music app called MicroJam!\n",
514 | "- Check it out on the iTunes App Store (only Apple devices! Soz!)\n",
515 | "- [microjam website: https://microjam.info](https://microjam.info)\n",
516 | "\n",
517 | "![RoboJam in action](https://preview.ibb.co/hXg43o/robojam_action_diagram.jpg)"
518 | ]
519 | }
520 | ],
521 | "metadata": {
522 | "kernelspec": {
523 | "display_name": "Python 3 (ipykernel)",
524 | "language": "python",
525 | "name": "python3"
526 | },
527 | "language_info": {
528 | "codemirror_mode": {
529 | "name": "ipython",
530 | "version": 3
531 | },
532 | "file_extension": ".py",
533 | "mimetype": "text/x-python",
534 | "name": "python",
535 | "nbconvert_exporter": "python",
536 | "pygments_lexer": "ipython3",
537 | "version": "3.11.3"
538 | }
539 | },
540 | "nbformat": 4,
541 | "nbformat_minor": 4
542 | }
543 | 
--------------------------------------------------------------------------------
/notebooks/MDN-RNN-kanji-generation-example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Writing Kanji with an MDN-RNN\n",
8 | "\n",
9 | "- What kind of data can be predicted by a mixture density network RNN?\n",
10 | "  - Sequential data that is _continuous_, not categorical.\n",
11 | "- Drawing data is a good example: we tend to want high resolution in 2 dimensions to draw.\n",
12 | "  - not practical for categories\n",
13 | "- Let's try modelling some _drawing_ data using an MDN-RNN.\n",
14 | "- In this case we will use a dataset of Kanji.\n",
15 | "\n",
16 | "This example is similar to hardmaru's Kanji tutorial and the original Sketch-RNN repository:\n",
17 | "\n",
18 | "- http://blog.otoro.net/2015/12/28/recurrent-net-dreams-up-fake-chinese-characters-in-vector-format-with-tensorflow/\n",
19 | "- https://github.com/hardmaru/sketch-rnn\n",
20 | "\n",
21 | "- The idea is to learn how to draw kanji characters from a dataset of vector representations. \n",
22 | "- This means learning how to move a pen in 2D space.\n",
23 | "- The data consists of a sequence of pen movements (locations in 2D) and whether the pen is up or down.\n",
24 | "- In this example, we will use one 3D MDN to model everything!\n",
25 | "\n",
26 | "We will end up with a system that can invent \"new\" kanji---but it won't know how to stop drawing! E.g.:\n",
27 | "\n",
28 | "![Kanji Test 1](figures/kanji_test_1.png)\n"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "from context import * # imports the MDN layer \n",
38 | "import numpy as np\n",
39 | "import pandas as pd\n",
40 | "import random\n",
41 | "import matplotlib.pyplot as plt\n",
42 | "from mpl_toolkits.mplot3d import Axes3D \n",
43 | "%matplotlib widget"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "### First download and process the dataset."
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "# Train from David Ha's Kanji dataset from Sketch-RNN: https://github.com/hardmaru/sketch-rnn-datasets\n",
60 | "# Other datasets in \"Sketch 3\" format should also work.\n",
61 | "import urllib.request\n",
62 | "url = 'https://github.com/hardmaru/sketch-rnn-datasets/raw/master/kanji/kanji.rdp25.npz' \n",
63 | "urllib.request.urlretrieve(url, './kanji.rdp25.npz') "
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "### Dataset:\n",
71 | "\n",
72 | "Includes about 11000 handwritten kanji characters divided into training, validation, and testing sets.\n",
73 | "\n",
74 | "For creative purposes, we may not need the validation or testing sets, and can just focus on the training set."
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "with np.load('./kanji.rdp25.npz', allow_pickle=True) as data:\n",
84 | "    train_set = data['train']\n",
85 | "    valid_set = data['valid']\n",
86 | "    test_set = data['test']\n",
87 | "    \n",
88 | "print(\"Training kanji:\", len(train_set))\n",
89 | "print(\"Validation kanji:\", len(valid_set))\n",
90 | "print(\"Testing kanji:\", len(test_set))"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "### Looking at one example\n",
98 | "\n",
99 | "Let's have a look at one example from the training data.\n",
100 | "\n",
101 | "- Each example is a sequence of pen movements with three numbers:\n",
102 | "  - The movement of the pen in the x-direction (left-right)\n",
103 | "  - The movement of the pen in the y-direction (up-down)\n",
104 | "  - Whether the pen is raised, or lowered touching the paper (1 = up, 0 = down).\n",
105 | "  "
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "# Have a look at the data.\n",
115 | "example = train_set[99]\n",
116 | "print(\"Shape:\", example.shape)\n",
117 | "print(example[:20])"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "### Reconstructing a training example\n",
125 | "\n",
126 | "Let's try to plot this example:"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "plt.plot(example.T[0], example.T[1])\n",
136 | "plt.title(\"Raw values (diffs) for one training example\")\n",
137 | "plt.show()"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {},
143 | "source": [
144 | "- That didn't work very well as we were just plotting the raw values\n",
145 | "  - (the difference between each pen movement)\n",
146 | "- We can transform these into paper locations by using the `cumsum()` function.\n",
147 | "- This adds each value to the cumulative sum of the previous values in the array.\n",
148 | "\n",
149 | "Here's a proper sketch:"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "plt.plot(example.T[0].cumsum(), -1 * example.T[1].cumsum())\n",
159 | "plt.title(\"Accumulated values for one training example\")\n",
160 | "plt.show()"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "(Note that this sketch ignores the pen's touching or not value.)"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "metadata": {},
173 | "source": [
174 | "### Setup an MDN RNN\n",
175 | "\n",
176 | "So let's set up an MDN RNN to learn how to create similar drawings.\n",
177 | "\n",
178 | "Our RNN will have the following settings:\n",
179 | "\n",
180 | "- 2 RNN layers.\n",
181 | "- 64 LSTM units per RNN layer (`HIDDEN_UNITS` below; a full-scale run might use 256).\n",
182 | "- a 3-dimensional mixture layer with 10 mixtures.\n",
183 | "- train for sequence length of 30.\n",
184 | "- training for 100 epochs with a batch size of 64 (`EPOCHS = 10` below for a small run).\n",
185 | "\n",
186 | "Here's a diagram:\n",
187 | "\n",
188 | "![Diagram of the Kanji MDN-RNN](figures/kanji-mdn-diagram.png)\n",
189 | "\n",
190 | "\n",
191 | "\n",
192 | "Why do we need a 3D mixture model?\n",
193 | "\n",
194 | "- One dimension for `pen-X`, one for `pen-Y`, and one for `pen-UpDown`\n",
195 | "- `pen-UpDown` isn't exactly a real number (it's either 0 or 1), but we can _make a simpler model_ by just adding another MDN dimension.\n",
196 | "- When doing predictions, we can just round the `pen-UpDown` value up to 1 or down to 0. Easy!"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "# Training Hyperparameters:\n",
206 | "SEQ_LEN = 30\n",
207 | "BATCH_SIZE = 64\n",
208 | "HIDDEN_UNITS = 64\n",
209 | "EPOCHS = 10  # small training run, change to 100 for \"production\"\n",
210 | "SEED = 2345  # set random seed for reproducibility\n",
211 | "random.seed(SEED)\n",
212 | "np.random.seed(SEED)\n",
213 | "OUTPUT_DIMENSION = 3\n",
214 | "NUMBER_MIXTURES = 10\n",
215 | "\n",
216 | "# Sequential model\n",
217 | "model = keras.Sequential()\n",
218 | "\n",
219 | "# Add two LSTM layers, make sure the input shape of the first one is (?, 30, 3)\n",
220 | "model.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(None,SEQ_LEN,OUTPUT_DIMENSION), return_sequences=True))\n",
221 | "model.add(keras.layers.LSTM(HIDDEN_UNITS))\n",
222 | "\n",
223 | "# Here's the MDN layer, need to specify the output dimension (3) and number of mixtures (10)\n",
224 | "model.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))\n",
225 | "\n",
226 | "# Now we compile the MDN RNN - need to use a special loss function with the right number of dimensions and mixtures.\n",
227 | "model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.legacy.Adam())\n",
228 | "\n",
229 | "# Let's see what we have:\n",
230 | "model.summary()"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {},
236 | "source": [
237 | "## Process the Data and Train the Model\n",
238 | "\n",
239 | "- Chop up the data into slices of the correct length, generate `X` and `y` for the training process.\n",
240 | "- Very similar process to the previous RNN examples!\n",
241 | "- We end up with 330000 examples - a pretty healthy dataset."
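242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "*(Added sanity check, not in the original notebook.)* Before slicing up the data, it's worth checking the pen-state dimension mentioned earlier. The claim is that we can treat the binary pen state as just another MDN dimension and round it at prediction time; the added cell below measures what fraction of points in the training set actually have the pen raised (most points are pen-down movements)."
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "# Added check: fraction of pen-up events (third column) across the training set.\n",
258 | "pen_states = np.concatenate([seq.T[2] for seq in train_set])\n",
259 | "print(\"Pen-up fraction: %.3f\" % pen_states.mean())"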
242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "# Functions for slicing up data\n", 251 | "def slice_sequence_examples(sequence, num_steps):\n", 252 | " xs = []\n", 253 | " for i in range(len(sequence) - num_steps - 1):\n", 254 | " example = sequence[i: i + num_steps]\n", 255 | " xs.append(example)\n", 256 | " return xs\n", 257 | "\n", 258 | "def seq_to_singleton_format(examples):\n", 259 | " xs = []\n", 260 | " ys = []\n", 261 | " for ex in examples:\n", 262 | " xs.append(ex[:-1])\n", 263 | " ys.append(ex[-1])\n", 264 | " return (xs,ys)\n", 265 | "\n", 266 | "# Prepare training data as X and Y.\n", 267 | "slices = []\n", 268 | "for seq in train_set:\n", 269 | " slices += slice_sequence_examples(seq, SEQ_LEN+1)\n", 270 | "X, y = seq_to_singleton_format(slices)\n", 271 | "\n", 272 | "X = np.array(X)\n", 273 | "y = np.array(y)\n", 274 | "\n", 275 | "print(\"Number of training examples:\")\n", 276 | "print(\"X:\", X.shape)\n", 277 | "print(\"y:\", y.shape)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "## Do the training!\n", 285 | "\n", 286 | "- We're not going to train in the tutorial!\n", 287 | "- These settings take about 220 seconds per epoch, about 6 hours for the whole training run." 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "# Fit the model\n", 297 | "history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[keras.callbacks.TerminateOnNaN()])" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "model.save('kanji_mdnrnn_model.keras') # creates a model file" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "plt.figure()\n", 316 | "plt.plot(history.history['loss'])\n", 317 | "plt.show()" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "## Try out the model! Generate some Kanji!\n", 325 | "\n", 326 | "We need to create a decoding model with batch size 1 and sequence length 1." 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "# Decoding Model\n", 336 | "# Sequence length is 1\n", 337 | "# This is the stateful version.\n", 338 | "\n", 339 | "decoder = keras.Sequential()\n", 340 | "decoder.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(1,1,OUTPUT_DIMENSION), return_sequences=True, stateful=True))\n", 341 | "decoder.add(keras.layers.LSTM(HIDDEN_UNITS, stateful=True))\n", 342 | "decoder.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))\n", 343 | "decoder.summary()\n", 344 | "\n", 345 | "decoder.load_weights('kanji_mdnrnn_model.keras') # load weights independently from file" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": {}, 351 | "source": [ 352 | "## Generating drawings\n", 353 | "\n", 354 | "- First need some helper functions to view the output." 
355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "def zero_start_position():\n", 364 | " \"\"\"A zeroed out start position with pen down\"\"\"\n", 365 | " out = np.zeros((1, 1, 3), dtype=np.float32)\n", 366 | " out[0, 0, 2] = 1 # set pen down.\n", 367 | " return out\n", 368 | "\n", 369 | "def generate_sketch(model, start_pos, num_points=100):\n", 370 | " return None\n", 371 | "\n", 372 | "def cutoff_stroke(x):\n", 373 | " return np.greater(x,0.5) * 1.0\n", 374 | "\n", 375 | "def plot_sketch(sketch_array):\n", 376 | " \"\"\"Plot a sketch quickly to see what it looks like.\"\"\"\n", 377 | " sketch_df = pd.DataFrame({'x':sketch_array.T[0],'y':sketch_array.T[1],'z':sketch_array.T[2]})\n", 378 | " sketch_df.x = sketch_df.x.cumsum()\n", 379 | " sketch_df.y = -1 * sketch_df.y.cumsum()\n", 380 | " # Do the plot\n", 381 | " fig = plt.figure(figsize=(8, 8))\n", 382 | " ax1 = fig.add_subplot(111)\n", 383 | " #ax1.scatter(sketch_df.x,sketch_df.y,marker='o', c='r', alpha=1.0)\n", 384 | " # Need to do something with sketch_df.z\n", 385 | " ax1.plot(sketch_df.x,sketch_df.y,'r-')\n", 386 | " plt.show()" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "## SVG Drawing Function\n", 394 | "\n", 395 | "Here's Hardmaru's Drawing Functions from _write-rnn-tensorflow_. Big hat tip to Hardmaru for this!\n", 396 | "\n", 397 | "Here's the source: https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py\n" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "# Hardmaru's Drawing Functions from write-rnn-tensorflow\n", 407 | "# Big hat tip\n", 408 | "# Here's the source:\n", 409 | "# https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py\n", 410 | "\n", 411 | "import svgwrite\n", 412 | "from IPython.display import SVG, display\n", 413 | "\n", 414 | "def get_bounds(data, factor):\n", 415 | " min_x = 0\n", 416 | " max_x = 0\n", 417 | " min_y = 0\n", 418 | " max_y = 0\n", 419 | "\n", 420 | " abs_x = 0\n", 421 | " abs_y = 0\n", 422 | " for i in range(len(data)):\n", 423 | " x = float(data[i, 0]) / factor\n", 424 | " y = float(data[i, 1]) / factor\n", 425 | " abs_x += x\n", 426 | " abs_y += y\n", 427 | " min_x = min(min_x, abs_x)\n", 428 | " min_y = min(min_y, abs_y)\n", 429 | " max_x = max(max_x, abs_x)\n", 430 | " max_y = max(max_y, abs_y)\n", 431 | "\n", 432 | " return (min_x, max_x, min_y, max_y)\n", 433 | "\n", 434 | "def draw_strokes(data, factor=1, svg_filename='sample.svg'):\n", 435 | " min_x, max_x, min_y, max_y = get_bounds(data, factor)\n", 436 | " dims = (50 + max_x - min_x, 50 + max_y - min_y)\n", 437 | "\n", 438 | " dwg = svgwrite.Drawing(svg_filename, size=dims)\n", 439 | " dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))\n", 440 | "\n", 441 | " lift_pen = 1\n", 442 | "\n", 443 | " abs_x = 25 - min_x\n", 444 | " abs_y = 25 - min_y\n", 445 | " p = \"M%s,%s \" % (abs_x, abs_y)\n", 446 | "\n", 447 | " command = \"m\"\n", 448 | "\n", 449 | " for i in range(len(data)):\n", 450 | " if (lift_pen == 1):\n", 451 | " command = \"m\"\n", 452 | " elif (command != \"l\"):\n", 453 | " command = \"l\"\n", 454 | " else:\n", 455 | " command = \"\"\n", 456 | " x = float(data[i, 0]) / factor\n", 457 | " y = float(data[i, 1]) / factor\n", 458 | " lift_pen = data[i, 2]\n", 459 | " p += command + str(x) + \",\" + str(y) + \" \"\n", 460 | "\n", 
461 | "    the_color = \"black\"\n",
462 | "    stroke_width = 1\n",
463 | "\n",
464 | "    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(\"none\"))\n",
465 | "\n",
466 | "    dwg.save()\n",
467 | "    display(SVG(dwg.tostring()))"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": null,
473 | "metadata": {
474 | "scrolled": true
475 | },
476 | "outputs": [],
477 | "source": [
478 | "import time\n",
479 | "\n",
480 | "# Predict a character and plot the result.\n",
481 | "pi_temperature = 2.5 # seems to work well with rather high temperature (2.5)\n",
482 | "sigma_temp = 0.1 # seems to work well with low temp\n",
483 | "\n",
484 | "p = zero_start_position()\n",
485 | "sketch = [p.reshape(3,)]\n",
486 | "decoder.reset_states()\n",
487 | "\n",
488 | "number_of_movements = 25\n",
489 | "\n",
490 | "start_time = time.time()\n",
491 | "for i in range(number_of_movements):\n",
492 | "    output_list = decoder(p.reshape(1,1,3))\n",
493 | "    mdn_values = output_list[0].numpy()\n",
494 | "    p = mdn.sample_from_output(mdn_values, OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=pi_temperature, sigma_temp=sigma_temp)\n",
495 | "    # print(\"Generated:\", p)\n",
496 | "    sketch.append(p.reshape((3,)))\n",
497 | "print(\"Finished. That took\", (time.time() - start_time)/number_of_movements, \"seconds per generation.\")\n",
498 | "\n",
499 | "sketch = np.array(sketch)\n",
500 | "\n",
501 | "sketch.T[2] = cutoff_stroke(sketch.T[2])\n",
502 | "draw_strokes(sketch, factor=0.5)"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": null,
508 | "metadata": {},
509 | "outputs": [],
510 | "source": [
511 | "# Draw the sketch\n",
512 | "plot_sketch(sketch)"
513 | ]
514 | }
515 | ],
516 | "metadata": {
517 | "kernelspec": {
518 | "display_name": "Python 3 (ipykernel)",
519 | "language": "python",
520 | "name": "python3"
521 | },
522 | "language_info": {
523 | "codemirror_mode": {
524 | "name": "ipython",
525 | "version": 3
526 | },
527 | "file_extension": ".py",
528 | "mimetype": "text/x-python",
529 | "name": "python",
530 | "nbconvert_exporter": "python",
531 | "pygments_lexer": "ipython3",
532 | "version": "3.11.3"
533 | }
534 | },
535 | "nbformat": 4,
536 | "nbformat_minor": 4
537 | }
538 | 
--------------------------------------------------------------------------------
/notebooks/MDN-RNN-kanji-generation-with-stateless-decoder.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7e46fc34-498d-4e12-9dc4-4d56a02c4e76",
6 | "metadata": {},
7 | "source": [
8 | "# Kanji Generation with a stateless decoder\n",
9 | "\n",
10 | "This notebook demonstrates how to build an LSTM MDRNN decoder for generation without using `stateful=True` in the LSTM layers.\n",
11 | "\n",
12 | "What we are going to do is create extra inputs and outputs for the decoder model which will let us add in the two LSTM state vectors ($h$ and $c$) for each LSTM layer and collect them at the output.\n",
13 | "\n",
14 | "At the start of generating a Kanji character, the LSTM state vectors are initialised to zero.\n",
15 | "\n",
16 | "In between generating Kanji strokes, we will store the LSTM state vectors in a variable.\n",
17 | "\n",
18 | "The decoder model will be defined with Keras' functional API style and the model generations will use TensorFlow's eager execution.\n",
19 | "\n",
20 | "This file does not train a model; you'll need to look at the Kanji Generation Example notebook for that. Make sure that the hyperparameters are the same as for the model that is being loaded here."
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "9acdc2de-f54b-4ffc-b535-0ef1cf68bac5",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# Decoding Model:\n",
31 | "# uses functional Keras API\n",
32 | "# takes LSTM state as extra inputs (2 per LSTM layer)\n",
33 | "# returns LSTM state as extra outputs (2 per LSTM layer) \n",
34 | "\n",
35 | "# imports\n",
36 | "from context import * # imports the MDN layer \n",
37 | "\n",
38 | "# Hyper parameters\n",
39 | "HIDDEN_UNITS = 64\n",
40 | "OUTPUT_DIMENSION = 3\n",
41 | "NUMBER_MIXTURES = 10\n",
42 | "\n",
43 | "inputs = keras.layers.Input(shape=(1,OUTPUT_DIMENSION))\n",
44 | "lstm_1_state_h_input = keras.layers.Input(shape=(HIDDEN_UNITS,))\n",
45 | "lstm_1_state_c_input = keras.layers.Input(shape=(HIDDEN_UNITS,))\n",
46 | "lstm_1_state_input = [lstm_1_state_h_input, lstm_1_state_c_input]\n",
47 | "lstm_2_state_h_input = keras.layers.Input(shape=(HIDDEN_UNITS,))\n",
48 | "lstm_2_state_c_input = keras.layers.Input(shape=(HIDDEN_UNITS,))\n",
49 | "lstm_2_state_input = [lstm_2_state_h_input, lstm_2_state_c_input]\n",
50 | "lstm_1, state_h_1, state_c_1 = keras.layers.LSTM(HIDDEN_UNITS, return_sequences=True, return_state=True)(inputs, initial_state=lstm_1_state_input)\n",
51 | "lstm_2, state_h_2, state_c_2 = keras.layers.LSTM(HIDDEN_UNITS, return_state=True)(lstm_1, initial_state=lstm_2_state_input)\n",
52 | "lstm_1_state_output = [state_h_1, state_c_1]\n",
53 | "lstm_2_state_output = [state_h_2, state_c_2]\n",
54 | "mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)(lstm_2)\n",
55 | "\n",
56 | "decoder = keras.Model(inputs=[inputs] + lstm_1_state_input + lstm_2_state_input, \n",
57 | "                      outputs=[mdn_out] + lstm_1_state_output + lstm_2_state_output,\n",
58 | "                      name=\"kanji-decoder\")\n",
59 | "decoder.summary()\n",
60 | "decoder.load_weights('kanji_mdnrnn_model.keras') # load weights independently from file"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "id": "dec9c1a8-f956-409d-9a74-49fe2e6fb78e",
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "import numpy as np\n",
71 | "# Let's test that the model is working on just one prediction:\n",
72 | "\n",
73 | "def zero_start_position():\n",
74 | "    \"\"\"A zeroed out start position with pen down\"\"\"\n",
75 | "    out = np.zeros((1, 1, 3), dtype=np.float32)\n",
76 | "    out[0, 0, 2] = 1 # set pen down.\n",
77 | "    return out\n",
78 | "\n",
79 | "def random_start_position():\n",
80 | "    \"\"\"A random start position with pen down\"\"\"\n",
81 | "    limit = 5\n",
82 | "    out = limit - (2*limit) * np.random.rand(1, 1, 3)\n",
83 | "    out[0, 0, 2] = 1 # set pen down.\n",
84 | "    return out\n",
85 | "\n",
86 | "def generate_initial_lstm_states(units):\n",
87 | "    return [np.zeros((1,units), dtype=np.float32), np.zeros((1,units), dtype=np.float32)] \n",
88 | "\n",
89 | "start_pos = random_start_position()\n",
90 | "start_state_1 = generate_initial_lstm_states(HIDDEN_UNITS)\n",
91 | "start_state_2 = generate_initial_lstm_states(HIDDEN_UNITS)\n",
92 | "\n",
93 | "print(\"Start pos shape:\", start_pos.shape)\n",
94 | "print(\"Example state shape:\", start_state_1[0].shape)\n",
95 | "input_list = [start_pos] + start_state_1 + start_state_2\n",
96 | "\n",
97 | "print(\"Input shapes:\")\n",
98 | "for i in input_list:\n",
99 | "    print(\"Shape:\", i.shape)\n",
100 | "# run one prediction\n",
101 | "output_list = 
decoder(input_list)\n", 102 | "print(\"Output list length:\", len(output_list))\n", 103 | "print(\"Output shapes:\")\n", 104 | "for i in output_list:\n", 105 | " print(\"Shape:\", i.shape)\n", 106 | "\n", 107 | "# This test shows that everything... seems to be working..." 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "ec0a1082-457c-4799-a408-c7839d2a05af", 113 | "metadata": {}, 114 | "source": [ 115 | "# Generating Kanji (again)\n", 116 | "\n", 117 | "Now we can generate some Kanji with this model in the same way as in the previous example." 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "8db3187f-f6f0-4352-a1b6-d09e7297954d", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# Hardmaru's Drawing Functions from write-rnn-tensorflow\n", 128 | "# Big hat tip\n", 129 | "# Here's the source:\n", 130 | "# https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py\n", 131 | "\n", 132 | "import svgwrite\n", 133 | "from IPython.display import SVG, display\n", 134 | "\n", 135 | "def get_bounds(data, factor):\n", 136 | " min_x = 0\n", 137 | " max_x = 0\n", 138 | " min_y = 0\n", 139 | " max_y = 0\n", 140 | "\n", 141 | " abs_x = 0\n", 142 | " abs_y = 0\n", 143 | " for i in range(len(data)):\n", 144 | " x = float(data[i, 0]) / factor\n", 145 | " y = float(data[i, 1]) / factor\n", 146 | " abs_x += x\n", 147 | " abs_y += y\n", 148 | " min_x = min(min_x, abs_x)\n", 149 | " min_y = min(min_y, abs_y)\n", 150 | " max_x = max(max_x, abs_x)\n", 151 | " max_y = max(max_y, abs_y)\n", 152 | "\n", 153 | " return (min_x, max_x, min_y, max_y)\n", 154 | "\n", 155 | "def draw_strokes(data, factor=1, svg_filename='sample.svg'):\n", 156 | " min_x, max_x, min_y, max_y = get_bounds(data, factor)\n", 157 | " dims = (50 + max_x - min_x, 50 + max_y - min_y)\n", 158 | "\n", 159 | " dwg = svgwrite.Drawing(svg_filename, size=dims)\n", 160 | " dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))\n", 161 | "\n", 162 | " lift_pen = 1\n", 163 | "\n", 164 | " abs_x = 25 - min_x\n", 165 | " abs_y = 25 - min_y\n", 166 | " p = \"M%s,%s \" % (abs_x, abs_y)\n", 167 | "\n", 168 | " command = \"m\"\n", 169 | "\n", 170 | " for i in range(len(data)):\n", 171 | " if (lift_pen == 1):\n", 172 | " command = \"m\"\n", 173 | " elif (command != \"l\"):\n", 174 | " command = \"l\"\n", 175 | " else:\n", 176 | " command = \"\"\n", 177 | " x = float(data[i, 0]) / factor\n", 178 | " y = float(data[i, 1]) / factor\n", 179 | " lift_pen = data[i, 2]\n", 180 | " p += command + str(x) + \",\" + str(y) + \" \"\n", 181 | "\n", 182 | " the_color = \"black\"\n", 183 | " stroke_width = 1\n", 184 | "\n", 185 | " dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(\"none\"))\n", 186 | "\n", 187 | " dwg.save()\n", 188 | " display(SVG(dwg.tostring()))\n", 189 | "\n", 190 | "def cutoff_stroke(x):\n", 191 | " return np.greater(x,0.5) * 1.0" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "id": "a314d4a9-1540-4fb7-a6c0-5c36feee1e6d", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "import time\n", 202 | "# Predict a character and plot the result.\n", 203 | "pi_temperature = 2.5 # seems to work well with rather high temperature (2.5)\n", 204 | "sigma_temp = 0.1 # seems to work well with low temp\n", 205 | "\n", 206 | "start_pos = random_start_position()\n", 207 | "start_state_1 = generate_initial_lstm_states(HIDDEN_UNITS)\n", 208 | "start_state_2 = generate_initial_lstm_states(HIDDEN_UNITS)\n", 209 | 
"input_list = [start_pos] + start_state_1 + start_state_2 # five inputs\n", 210 | "\n", 211 | "sketch = [start_pos.reshape(3,)] # starting value for the sketch.\n", 212 | "\n", 213 | "number_of_movements = 400\n", 214 | "start_time = time.time()\n", 215 | "\n", 216 | "for i in range(number_of_movements):\n", 217 | " output_list = decoder(input_list)\n", 218 | " mdn_values = output_list[0][0].numpy()\n", 219 | " next_point = mdn.sample_from_output(mdn_values, OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=pi_temperature, sigma_temp=sigma_temp)\n", 220 | " sketch.append(next_point.reshape((3,)))\n", 221 | " states = output_list[1:]\n", 222 | " input_list = [next_point.reshape(1, 1, 3)] + states\n", 223 | "\n", 224 | "print(\"Finished. That took\", round((time.time() - start_time)/number_of_movements, 4), \"seconds per generation.\")\n", 225 | "\n", 226 | "# Draw the sketch\n", 227 | "sketch = np.array(sketch)\n", 228 | "sketch.T[2] = cutoff_stroke(sketch.T[2])\n", 229 | "draw_strokes(sketch, factor=0.5)" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3 (ipykernel)", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.11.3" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /notebooks/MDN-RNN-time-distributed-MDN-training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Writing Kanji with a time distributed MDN-RNN \n", 8 | "\n", 9 | "- This notebook is the same as the Kanji MDN-RNN except that it trains on predictions made over the whole sequence length.\n", 10 | "- The MDN-RNN is also written in Keras' functional API for good measure!\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from tensorflow import keras\n", 20 | "import tensorflow as tf\n", 21 | "from context import * # imports the MDN layer \n", 22 | "import numpy as np\n", 23 | "import random\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from mpl_toolkits.mplot3d import Axes3D \n", 26 | "%matplotlib widget" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "### Download the Dataset:" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Train from David Ha's Kanji dataset from Sketch-RNN: https://github.com/hardmaru/sketch-rnn-datasets\n", 43 | "# Other datasets in \"Sketch 3\" format should also work.\n", 44 | "import urllib.request\n", 45 | "url = 'https://github.com/hardmaru/sketch-rnn-datasets/raw/master/kanji/kanji.rdp25.npz' \n", 46 | "urllib.request.urlretrieve(url, './kanji.rdp25.npz') " 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "### Dataset:\n", 54 | "\n", 55 | "Includes about 11000 handwritten kanji characters divied into training, validation, and testing sets.\n", 56 | "\n", 57 | "For creative purposes, we may not need the validation or testing sets, and can just 
focus on the training set." 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Training kanji: 10358\n", 70 | "Validation kanji: 600\n", 71 | "Testing kanji: 500\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "with np.load('./kanji.rdp25.npz', allow_pickle=True) as data:\n", 77 | " train_set = data['train']\n", 78 | " valid_set = data['valid']\n", 79 | " test_set = data['test']\n", 80 | " \n", 81 | "print(\"Training kanji:\", len(train_set))\n", 82 | "print(\"Validation kanji:\", len(valid_set))\n", 83 | "print(\"Testing kanji:\", len(test_set))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### Setup an MDN RNN" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Model: \"kanji-training-td\"\n", 103 | "_________________________________________________________________\n", 104 | " Layer (type) Output Shape Param # \n", 105 | "=================================================================\n", 106 | " inputs (InputLayer) [(None, 50, 3)] 0 \n", 107 | " \n", 108 | " lstm1 (LSTM) (None, 50, 256) 266240 \n", 109 | " \n", 110 | " lstm2 (LSTM) (None, 50, 256) 525312 \n", 111 | " \n", 112 | " td_mdn (TimeDistributed) (None, 50, 70) 17990 \n", 113 | " \n", 114 | "=================================================================\n", 115 | "Total params: 809542 (3.09 MB)\n", 116 | "Trainable params: 809542 (3.09 MB)\n", 117 | "Non-trainable params: 0 (0.00 Byte)\n", 118 | "_________________________________________________________________\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "# Training Hyperparameters:\n", 124 | "SEQ_LEN = 50\n", 125 | "BATCH_SIZE = 64\n", 126 | "HIDDEN_UNITS = 256\n", 127 | "EPOCHS = 100\n", 128 | "SEED = 2345 # set random seed for reproducibility\n", 129 | "random.seed(SEED)\n", 130 | "np.random.seed(SEED)\n", 131 | "OUTPUT_DIMENSION = 3\n", 132 | "NUMBER_MIXTURES = 10\n", 133 | "\n", 134 | "inputs = keras.layers.Input(shape=(SEQ_LEN,OUTPUT_DIMENSION), name='inputs')\n", 135 | "lstm1_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm1', return_sequences=True)(inputs)\n", 136 | "lstm2_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm2', return_sequences=True)(lstm1_out)\n", 137 | "mdn_out = keras.layers.TimeDistributed(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES, name='mdn_outputs'), name='td_mdn')(lstm2_out)\n", 138 | "\n", 139 | "model = keras.models.Model(inputs=inputs, outputs=mdn_out, name=\"kanji-training-td\")\n", 140 | "model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer='adam')\n", 141 | "model.summary()" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "## Process the Data and Train the Model\n", 149 | "\n", 150 | "- Chop up the data into slices of the correct length, generate `X` and `y` for the training process.\n", 151 | "- Very similar process to the previous RNN examples!\n", 152 | "- We end up with 330000 examples - a pretty healthy dataset." 
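153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "*(Added illustration, not in the original notebook.)* The time-distributed difference in one line: instead of one target per window, the target sequence is the input sequence shifted by one step, so the network is trained to predict the next point at *every* timestep. On a toy sequence:"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "# Added toy example of the overlapping (shift-by-one) X/y format used below.\n",
169 | "toy = np.arange(5)\n",
170 | "print(\"X:\", toy[:-1])  # [0 1 2 3]\n",
171 | "print(\"y:\", toy[1:])   # [1 2 3 4]"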
144 |   {
145 |    "cell_type": "markdown",
146 |    "metadata": {},
147 |    "source": [
148 |     "## Process the Data and Train the Model\n",
149 |     "\n",
150 |     "- Chop up the data into slices of the correct length, and generate `X` and `y` for the training process.\n",
151 |     "- Very similar process to the previous RNN examples!\n",
152 |     "- We end up with around 330,000 examples, a pretty healthy dataset."
153 |    ]
154 |   },
155 |   {
156 |    "cell_type": "code",
157 |    "execution_count": null,
158 |    "metadata": {},
159 |    "outputs": [],
160 |    "source": [
161 |     "# Functions for slicing up data\n",
162 |     "def slice_sequence_examples(sequence, num_steps):\n",
163 |     "    xs = []\n",
164 |     "    for i in range(len(sequence) - num_steps - 1):\n",
165 |     "        example = sequence[i: i + num_steps]\n",
166 |     "        xs.append(example)\n",
167 |     "    return xs\n",
168 |     "\n",
169 |     "def seq_to_overlapping_format(examples):\n",
170 |     "    \"\"\"Takes sequences of seq_len+1 and returns overlapping\n",
171 |     "    sequences of seq_len.\"\"\"\n",
172 |     "    xs = []\n",
173 |     "    ys = []\n",
174 |     "    for ex in examples:\n",
175 |     "        xs.append(ex[:-1])\n",
176 |     "        ys.append(ex[1:])\n",
177 |     "    return (xs,ys)\n",
178 |     "\n",
179 |     "# Prepare training data as X and Y.\n",
180 |     "slices = []\n",
181 |     "for seq in train_set:\n",
182 |     "    slices += slice_sequence_examples(seq, SEQ_LEN+1)\n",
183 |     "X, y = seq_to_overlapping_format(slices)\n",
184 |     "\n",
185 |     "X = np.array(X)\n",
186 |     "y = np.array(y)\n",
187 |     "\n",
188 |     "print(\"Number of training examples:\")\n",
189 |     "print(\"X:\", X.shape)\n",
190 |     "print(\"y:\", y.shape)"
191 |    ]
192 |   },
193 |   {
194 |    "cell_type": "code",
195 |    "execution_count": null,
196 |    "metadata": {},
197 |    "outputs": [],
198 |    "source": [
199 |     "# Prepare validation data as X and Y.\n",
200 |     "slices = []\n",
201 |     "for seq in valid_set:\n",
202 |     "    slices += slice_sequence_examples(seq, SEQ_LEN+1)\n",
203 |     "Xval, yval = seq_to_overlapping_format(slices)\n",
204 |     "\n",
205 |     "Xval = np.array(Xval)\n",
206 |     "yval = np.array(yval)\n",
207 |     "\n",
208 |     "print(\"Number of validation examples:\")\n",
209 |     "print(\"Xval:\", Xval.shape)\n",
210 |     "print(\"yval:\", yval.shape)"
211 |    ]
212 |   },
213 |   {
214 |    "cell_type": "markdown",
215 |    "metadata": {},
216 |    "source": [
217 |     "## Do the training!\n",
218 |     "\n",
219 |     "- We're not going to train in this tutorial!\n",
220 |     "- These settings take about 220 seconds per epoch, about 6 hours for the whole training run."
221 |    ]
222 |   },
223 |   {
224 |    "cell_type": "code",
225 |    "execution_count": null,
226 |    "metadata": {},
227 |    "outputs": [],
228 |    "source": [
229 |     "# Fit the model\n",
230 |     "filepath=\"kanji_mdnrnn-{epoch:02d}.h5\"\n",
231 |     "checkpoint = keras.callbacks.ModelCheckpoint(filepath, save_weights_only=True, verbose=1, save_best_only=True, mode='min')\n",
232 |     "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)\n",
233 |     "callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]\n",
234 |     "history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_data=(Xval,yval))\n",
235 |     "model.save('kanji_mdnrnn_model_time_distributed.h5') # creates an HDF5 file with the trained model"
236 |    ]
237 |   },
238 |   {
239 |    "cell_type": "code",
240 |    "execution_count": null,
241 |    "metadata": {},
242 |    "outputs": [],
243 |    "source": [
244 |     "# summarize history for loss\n",
245 |     "plt.plot(history.history['loss'])\n",
246 |     "plt.plot(history.history['val_loss'])\n",
247 |     "plt.title(\"Kanji MDN-RNN Training\")\n",
248 |     "plt.show()"
249 |    ]
250 |   },
251 |   {
252 |    "cell_type": "code",
253 |    "execution_count": null,
254 |    "metadata": {},
255 |    "outputs": [],
256 |    "source": [
257 |     "!ls"
258 |    ]
259 |   },
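The generation cells below sample with two temperature parameters: `temp` rescales the categorical choice between mixture components (the pis), while `sigma_temp` scales the component standard deviations. A standalone toy illustration of the pi-temperature idea (a sketch of the concept, not the library's exact implementation):

```python
import numpy as np

def softmax_with_temp(logits, temp=1.0):
    """Temperature-scaled softmax: higher temp flattens the distribution over
    mixture components (more adventurous strokes); lower temp sharpens it."""
    scaled = np.asarray(logits, dtype=np.float64) / temp
    exps = np.exp(scaled - scaled.max())
    return exps / exps.sum()

print(softmax_with_temp([2.0, 1.0, 0.1], temp=0.5))  # sharp: most mass on one component
print(softmax_with_temp([2.0, 1.0, 0.1], temp=5.0))  # flat: close to uniform
```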
260 |   {
261 |    "cell_type": "markdown",
262 |    "metadata": {},
263 |    "source": [
264 |     "## Try out the model! Generate some Kanji!\n",
265 |     "\n",
266 |     "We need to create a decoding model with batch size 1 and sequence length 1."
267 |    ]
268 |   },
269 |   {
270 |    "cell_type": "code",
271 |    "execution_count": null,
272 |    "metadata": {},
273 |    "outputs": [],
274 |    "source": [
275 |     "# Decoding Model\n",
276 |     "# Same layers and mixture settings as the training model, but stateful, with batch size 1 and sequence length 1.\n",
277 |     "\n",
278 |     "decoder = keras.Sequential()\n",
279 |     "decoder.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(1,1,OUTPUT_DIMENSION), return_sequences=True, stateful=True))\n",
280 |     "decoder.add(keras.layers.LSTM(HIDDEN_UNITS, stateful=True))\n",
281 |     "decoder.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))\n",
282 |     "decoder.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())\n",
283 |     "decoder.summary()\n",
284 |     "\n",
285 |     "# Load the trained weights from file, either the final model or an epoch checkpoint:\n",
286 |     "# decoder.load_weights('kanji_mdnrnn-99.h5')\n",
287 |     "decoder.load_weights('kanji_mdnrnn_model_time_distributed.h5')"
288 |    ]
289 |   },
290 |   {
291 |    "cell_type": "markdown",
292 |    "metadata": {},
293 |    "source": [
294 |     "## Generating drawings\n",
295 |     "\n",
296 |     "- First we need some helper functions to view the output."
297 |    ]
298 |   },
299 |   {
300 |    "cell_type": "code",
301 |    "execution_count": null,
302 |    "metadata": {},
303 |    "outputs": [],
304 |    "source": [
305 |     "import pandas as pd\n",
306 |     "import matplotlib.pyplot as plt\n",
307 |     "%matplotlib inline\n",
308 |     "\n",
309 |     "def zero_start_position():\n",
310 |     "    \"\"\"A zeroed out start position with pen down\"\"\"\n",
311 |     "    out = np.zeros((1, 1, 3), dtype=np.float32)\n",
312 |     "    out[0, 0, 2] = 1 # set pen down.\n",
313 |     "    return out\n",
314 |     "\n",
315 |     "def generate_sketch(model, start_pos, num_points=100):\n",
316 |     "    return None  # still a stub; a possible implementation is sketched after this cell\n",
317 |     "\n",
318 |     "def cutoff_stroke(x):\n",
319 |     "    return np.greater(x,0.5) * 1.0\n",
320 |     "\n",
321 |     "def plot_sketch(sketch_array):\n",
322 |     "    \"\"\"Plot a sketch quickly to see what it looks like.\"\"\"\n",
323 |     "    sketch_df = pd.DataFrame({'x':sketch_array.T[0],'y':sketch_array.T[1],'z':sketch_array.T[2]})\n",
324 |     "    sketch_df.x = sketch_df.x.cumsum()\n",
325 |     "    sketch_df.y = -1 * sketch_df.y.cumsum()\n",
326 |     "    # Do the plot\n",
327 |     "    fig = plt.figure(figsize=(8, 8))\n",
328 |     "    ax1 = fig.add_subplot(111)\n",
329 |     "    #ax1.scatter(sketch_df.x,sketch_df.y,marker='o', c='r', alpha=1.0)\n",
330 |     "    # Need to do something with sketch_df.z\n",
331 |     "    ax1.plot(sketch_df.x,sketch_df.y,'r-')\n",
332 |     "    plt.show()"
333 |    ]
334 |   },
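The `generate_sketch` helper in the cell above is left as a stub. A possible implementation, assembled from the sampling loop used at the end of this notebook (the temperature values and default point count are illustrative):

```python
def generate_sketch(model, start_pos, num_points=100):
    """Sample num_points from a stateful decoder, starting from start_pos."""
    sketch = [start_pos.reshape(3,)]
    p = start_pos
    for i in range(num_points):
        params = model.predict(p.reshape(1, 1, 3))
        p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES,
                                   temp=1.5, sigma_temp=0.01)
        sketch.append(p.reshape((3,)))
    model.reset_states()  # clear the LSTM state between sketches
    return np.array(sketch)
```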
335 |   {
336 |    "cell_type": "markdown",
337 |    "metadata": {},
338 |    "source": [
339 |     "## SVG Drawing Function\n",
340 |     "\n",
341 |     "Here are Hardmaru's drawing functions from _write-rnn-tensorflow_. Big hat tip to Hardmaru for this!\n",
342 |     "\n",
343 |     "Here's the source: https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py\n"
344 |    ]
345 |   },
346 |   {
347 |    "cell_type": "code",
348 |    "execution_count": null,
349 |    "metadata": {},
350 |    "outputs": [],
351 |    "source": [
352 |     "# Hardmaru's Drawing Functions from write-rnn-tensorflow\n",
353 |     "# Big hat tip\n",
354 |     "# Here's the source:\n",
355 |     "# https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py\n",
356 |     "\n",
357 |     "import svgwrite\n",
358 |     "from IPython.display import SVG, display\n",
359 |     "\n",
360 |     "def get_bounds(data, factor):\n",
361 |     "    min_x = 0\n",
362 |     "    max_x = 0\n",
363 |     "    min_y = 0\n",
364 |     "    max_y = 0\n",
365 |     "\n",
366 |     "    abs_x = 0\n",
367 |     "    abs_y = 0\n",
368 |     "    for i in range(len(data)):\n",
369 |     "        x = float(data[i, 0]) / factor\n",
370 |     "        y = float(data[i, 1]) / factor\n",
371 |     "        abs_x += x\n",
372 |     "        abs_y += y\n",
373 |     "        min_x = min(min_x, abs_x)\n",
374 |     "        min_y = min(min_y, abs_y)\n",
375 |     "        max_x = max(max_x, abs_x)\n",
376 |     "        max_y = max(max_y, abs_y)\n",
377 |     "\n",
378 |     "    return (min_x, max_x, min_y, max_y)\n",
379 |     "\n",
380 |     "def draw_strokes(data, factor=1, svg_filename='sample.svg'):\n",
381 |     "    min_x, max_x, min_y, max_y = get_bounds(data, factor)\n",
382 |     "    dims = (50 + max_x - min_x, 50 + max_y - min_y)\n",
383 |     "\n",
384 |     "    dwg = svgwrite.Drawing(svg_filename, size=dims)\n",
385 |     "    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))\n",
386 |     "\n",
387 |     "    lift_pen = 1\n",
388 |     "\n",
389 |     "    abs_x = 25 - min_x\n",
390 |     "    abs_y = 25 - min_y\n",
391 |     "    p = \"M%s,%s \" % (abs_x, abs_y)\n",
392 |     "\n",
393 |     "    command = \"m\"\n",
394 |     "\n",
395 |     "    for i in range(len(data)):\n",
396 |     "        if (lift_pen == 1):\n",
397 |     "            command = \"m\"\n",
398 |     "        elif (command != \"l\"):\n",
399 |     "            command = \"l\"\n",
400 |     "        else:\n",
401 |     "            command = \"\"\n",
402 |     "        x = float(data[i, 0]) / factor\n",
403 |     "        y = float(data[i, 1]) / factor\n",
404 |     "        lift_pen = data[i, 2]\n",
405 |     "        p += command + str(x) + \",\" + str(y) + \" \"\n",
406 |     "\n",
407 |     "    the_color = \"black\"\n",
408 |     "    stroke_width = 2\n",
409 |     "\n",
410 |     "    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(\"none\"))\n",
411 |     "\n",
412 |     "    dwg.save()\n",
413 |     "    display(SVG(dwg.tostring()))"
414 |    ]
415 |   },
416 |   {
417 |    "cell_type": "code",
418 |    "execution_count": null,
419 |    "metadata": {
420 |     "scrolled": true
421 |    },
422 |    "outputs": [],
423 |    "source": [
424 |     "# Predict a character and plot the result.\n",
425 |     "temperature = 1.5 # pi temperature: rather high values (around 1.5-2.5) seem to work well\n",
426 |     "sigma_temp = 0.01\n",
427 |     "\n",
428 |     "p = zero_start_position()\n",
429 |     "sketch = [p.reshape(3,)]\n",
430 |     "\n",
431 |     "for i in range(100):\n",
432 |     "    params = decoder.predict(p.reshape(1,1,3))\n",
433 |     "    p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=temperature, sigma_temp=sigma_temp)\n",
434 |     "    sketch.append(p.reshape((3,)))\n",
435 |     "\n",
436 |     "sketch = np.array(sketch)\n",
437 |     "decoder.reset_states()\n",
438 |     "\n",
439 |     "sketch.T[2] = cutoff_stroke(sketch.T[2])\n",
440 |     "draw_strokes(sketch, factor=0.5)\n",
441 |     "#plot_sketch(sketch)"
442 |    ]
443 |   }
444 |  ],
445 |  "metadata": {
446 |   "kernelspec": {
447 |    "display_name": "Python 3 (ipykernel)",
448 |    "language": "python",
449 |    "name": "python3"
450 |   },
451 |   "language_info": {
452 |    "codemirror_mode": {
453 |     "name": "ipython",
454 |     "version": 3
455 |    },
456 | "file_extension": ".py", 457 | "mimetype": "text/x-python", 458 | "name": "python", 459 | "nbconvert_exporter": "python", 460 | "pygments_lexer": "ipython3", 461 | "version": "3.11.3" 462 | } 463 | }, 464 | "nbformat": 4, 465 | "nbformat_minor": 4 466 | } 467 | -------------------------------------------------------------------------------- /notebooks/context.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # little path hack to access module which is one directory up. 4 | import sys 5 | import os 6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | import keras_mdn_layer as mdn 8 | from tensorflow import keras 9 | -------------------------------------------------------------------------------- /notebooks/figures/kanji-mdn-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/kanji-mdn-diagram.png -------------------------------------------------------------------------------- /notebooks/figures/kanji_mdn_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/kanji_mdn_examples.png -------------------------------------------------------------------------------- /notebooks/figures/kanji_test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/kanji_test_1.png -------------------------------------------------------------------------------- /notebooks/figures/kanji_test_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/kanji_test_2.png -------------------------------------------------------------------------------- /notebooks/figures/microjam.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/microjam.gif -------------------------------------------------------------------------------- /notebooks/figures/robojam-action-diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/robojam-action-diagram.jpg -------------------------------------------------------------------------------- /notebooks/figures/robojam-mdn-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/robojam-mdn-diagram.png -------------------------------------------------------------------------------- /notebooks/figures/robojam-mdn-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/robojam-mdn-loss.png 
--------------------------------------------------------------------------------
/notebooks/figures/robojam_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cpmpercussion/keras-mdn-layer/bf102dc404d2e04daa975776c195b23c15b27653/notebooks/figures/robojam_examples.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "keras-mdn-layer"
3 | version = "0.5.0"
4 | description = "An MDN Layer for Keras using TensorFlow's distributions module"
5 | authors = ["Charles Martin "]
6 | license = "MIT"
7 | readme = "README.md"
8 | homepage = "https://github.com/cpmpercussion/keras-mdn-layer"
9 | repository = "https://github.com/cpmpercussion/keras-mdn-layer"
10 | keywords = ["mixture density layer", "neural network", "machine learning"]
11 | classifiers = ['Topic :: Scientific/Engineering :: Artificial Intelligence']
12 | 
13 | [tool.poetry.dependencies]
14 | python = "3.11.*"
15 | numpy = "^1.26.4"
16 | tensorflow-probability = "0.24.0"
17 | tensorflow = "2.16.2"
18 | tf-keras = "^2.16.0"
19 | tensorflow-io-gcs-filesystem = [
20 |     {platform="darwin", version = "^0.37.1"},
21 |     {platform="linux", version = "^0.37.1"},
22 |     {platform = "win32", version = "0.31.0"},
23 | ]
24 | tensorflow-intel = {version="^2.16.2", platform = "win32"}
25 | 
26 | [tool.poetry.group.dev.dependencies]
27 | pytest = "^8.1.1"
28 | flake8 = "^7.0.0"
29 | jupyter = "^1.0.0"
30 | matplotlib = "^3.8.4"
31 | ipympl = "^0.9.4"
32 | pandas = "^2.2.2"
33 | svgwrite = "^1.4.3"
34 | black = "^24.8.0"
35 | coveralls = "^4.0.1"
36 | 
37 | [build-system]
38 | requires = ["poetry-core"]
39 | build-backend = "poetry.core.masonry.api"
40 | 
--------------------------------------------------------------------------------
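For reference, here is a minimal end-to-end usage sketch of the package configured above, loosely following `examples/1-sineprediction.py` (it assumes the pinned dependencies are installed, e.g. via `poetry install`; the layer size, epoch count, and mixture count are illustrative):

```python
import numpy as np
from tensorflow import keras
import keras_mdn_layer as mdn

N_MIXES = 5  # illustrative number of mixture components

model = keras.Sequential([
    keras.layers.Input(shape=(1,)),
    keras.layers.Dense(64, activation="relu"),
    mdn.MDN(1, N_MIXES),  # 1-D target modelled as a 5-component Gaussian mixture
])
model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer="adam")

# Toy problem: a noisy sine wave.
x = np.random.uniform(-10.0, 10.0, size=(1000, 1))
y = np.sin(x) + np.random.normal(0.0, 0.1, size=x.shape)
model.fit(x, y, epochs=10, batch_size=128)

# Predict mixture parameters for one input, then draw a concrete sample.
params = model.predict(x[:1])
sample = mdn.sample_from_output(params[0], 1, N_MIXES)
print(sample)
```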