## Tutorial: Mixture Density Networks with JAX ##

![mdn_image](mdn_image.png)

*April 2020*

Tutorial Notebook: [mixture_density_networks_jax.ipynb](mixture_density_networks_jax.ipynb)

Reference paper: [Mixture Density Networks](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf) (Bishop, 1994)

Related posts:
- JavaScript [Tutorial](http://blog.otoro.net/2015/06/14/mixture-density-networks/)
- TensorFlow [Tutorial](http://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/)
- PyTorch [Tutorial](https://github.com/hardmaru/pytorch_notebooks/blob/master/mixture_density_networks.ipynb)

This tutorial is based on the recent PyTorch notebook, with many improvements added by [kylemcdonald](https://github.com/kylemcdonald).

*Note: This notebook uses a slightly different loss formulation from the previous tutorials. It is much more numerically stable, and it is the formulation used in most of my other recent [projects](https://otoro.net/ml/) that need MDNs. A sketch of the underlying trick is given in the appendix below.*

JAX is a minimal framework for automatically computing gradients of native Python and NumPy / SciPy functions. It is a nice tool to have in the machine learning research toolbox. (A toy example is included in the appendix below.)

Recommended [JAX](https://github.com/google/jax/) tutorials: [Getting started with JAX](https://roberttlange.github.io/posts/2020/03/blog-post-10/) and [You don't know JAX](https://colinraffel.com/blog/you-don-t-know-jax.html).

## License

MIT
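## Appendix: the numerically stable loss trick, sketched

The note above mentions a loss formulation that is more numerically stable than the ones in the earlier tutorials. The standard trick is to evaluate the mixture negative log-likelihood entirely in log space with `logsumexp`, instead of summing raw Gaussian densities that can underflow. The sketch below only illustrates that idea; it is not code taken from the notebook, and all function and argument names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def mdn_neg_log_likelihood(logits, mu, log_sigma, y):
    """NLL of targets y under a 1-D Gaussian mixture (hypothetical sketch).

    logits:    (batch, k) unnormalized mixture weights
    mu:        (batch, k) component means
    log_sigma: (batch, k) log standard deviations
    y:         (batch,)   target values
    """
    # Normalized log mixture weights, computed without leaving log space.
    log_pi = logits - logsumexp(logits, axis=-1, keepdims=True)
    # Per-component log N(y | mu, sigma), also kept in log space.
    z = (y[:, None] - mu) * jnp.exp(-log_sigma)
    log_prob = -0.5 * z**2 - log_sigma - 0.5 * jnp.log(2.0 * jnp.pi)
    # logsumexp over components replaces log(sum(pi * N)), avoiding underflow.
    return -jnp.mean(logsumexp(log_pi + log_prob, axis=-1))
```

Compared with summing `pi * N(y | mu, sigma)` and taking the log afterwards, this version stays finite even when every component density is vanishingly small.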
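## Appendix: a toy `jax.grad` example

To make the description of JAX above concrete: `jax.grad` takes an ordinary scalar-valued function written with `jax.numpy` and returns a new function that evaluates its gradient. A minimal example (nothing here is from the notebook):

```python
import jax
import jax.numpy as jnp

def f(x):
    # An ordinary NumPy-style function with a scalar output.
    return jnp.sum(jnp.tanh(x) ** 2)

df = jax.grad(f)  # df(x) computes the gradient of f at x

print(df(jnp.arange(3.0)))  # gradient vector, same shape as the input
```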