├── .gitignore ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── data ├── after.png └── before.png ├── example └── example.py ├── mdn ├── __init__.py └── mdn.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── basic_test.py └── gradient_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | .*swp 2 | __pycache__ 3 | pytorch_mdn.egg-info 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.8" 4 | - "3.9" 5 | - "3.10" 6 | - "3.11" 7 | - "3.12" 8 | - "3.13" 9 | install: make init 10 | script: make test 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Benjamin Bastian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | init: 2 | pip install -r requirements.txt 3 | 4 | test: 5 | py.test tests 6 | 7 | .PHONY: init test 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-mdn 2 | 3 | ![Build Status](https://app.travis-ci.com/sagelywizard/pytorch-mdn.svg?token=dz4Mst8SUgYSUZdMqRsK&branch=master) 4 | 5 | This repo contains the code for [mixture density networks](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.120.5685&rep=rep1&type=pdf). 6 | 7 | ## Usage: 8 | 9 | ```python 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import mdn 13 | 14 | # initialize the model 15 | model = nn.Sequential( 16 | nn.Linear(5, 6), 17 | nn.Tanh(), 18 | mdn.MDN(6, 7, 20) 19 | ) 20 | optimizer = optim.Adam(model.parameters()) 21 | 22 | # train the model 23 | for minibatch, labels in train_set: 24 | model.zero_grad() 25 | pi, sigma, mu = model(minibatch) 26 | loss = mdn.mdn_loss(pi, sigma, mu, labels) 27 | loss.backward() 28 | optimizer.step() 29 | 30 | # sample new points from the trained model 31 | minibatch = next(test_set) 32 | pi, sigma, mu = model(minibatch) 33 | samples = mdn.sample(pi, sigma, mu) 34 | ``` 35 | 36 | ### Example 37 | 38 | Red points are the training data. 39 | 40 | ![before](https://github.com/sagelywizard/pytorch-mdn/raw/master/data/before.png) 41 | 42 | Blue points are samples drawn from a trained MDN.
43 | 44 | ![after](https://github.com/sagelywizard/pytorch-mdn/raw/master/data/after.png) 45 | 46 | For a full example with code, see `example/example.py` 47 | -------------------------------------------------------------------------------- /data/after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagelywizard/pytorch-mdn/b5744b88eea88bc138fc19bc66c87e81dd5e340a/data/after.png -------------------------------------------------------------------------------- /data/before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagelywizard/pytorch-mdn/b5744b88eea88bc138fc19bc66c87e81dd5e340a/data/before.png -------------------------------------------------------------------------------- /example/example.py: -------------------------------------------------------------------------------- 1 | """A script that shows how to use the MDN. It's a simple MDN with a single 2 | nonlinearity that's trained to output 1D samples given a 2D input. 3 | """ 4 | import matplotlib.pyplot as plt 5 | import sys 6 | sys.path.append('../mdn') 7 | import mdn 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | 12 | input_dims = 2 13 | output_dims = 1 14 | num_gaussians = 5 15 | 16 | 17 | def translate_cluster(cluster, dim, amount): 18 | """Translates a cluster in a particular dimension by some amount 19 | """ 20 | translation = torch.ones(cluster.size(0)) * amount 21 | cluster.transpose(0, 1)[dim].add_(translation) 22 | return cluster 23 | 24 | 25 | print("Generating training data... 
", end='') 26 | cluster1 = torch.randn((50, input_dims + output_dims)) / 4 27 | cluster1 = translate_cluster(cluster1, 1, 1.2) 28 | cluster2 = torch.randn((50, input_dims + output_dims)) / 4 29 | cluster2 = translate_cluster(cluster2, 0, -1.2) 30 | cluster3 = torch.randn((50, input_dims + output_dims)) / 4 31 | cluster3 = translate_cluster(cluster3, 2, -1.2) 32 | training_set = torch.cat([cluster1, cluster2, cluster3]) 33 | print('Done') 34 | 35 | print("Initializing model... ", end='') 36 | model = nn.Sequential( 37 | nn.Linear(input_dims, 5), 38 | nn.Tanh(), 39 | mdn.MDN(5, output_dims, num_gaussians) 40 | ) 41 | 42 | optimizer = optim.Adam(model.parameters()) 43 | print('Done') 44 | 45 | print('Training model... ', end='') 46 | sys.stdout.flush() 47 | for epoch in range(1000): 48 | model.zero_grad() 49 | pi, sigma, mu = model(training_set[:, 0:input_dims]) 50 | loss = mdn.mdn_loss(pi, sigma, mu, training_set[:, input_dims:]) 51 | loss.backward() 52 | optimizer.step() 53 | if epoch % 100 == 99: 54 | print(f' {round(epoch/10)}%', end='') 55 | sys.stdout.flush() 56 | print(' Done') 57 | 58 | print('Generating samples... ', end='') 59 | pi, sigma, mu = model(training_set[:, 0:input_dims]) 60 | samples = mdn.sample(pi, sigma, mu) 61 | print('Done') 62 | 63 | print('Saving samples.png... 
', end='') 64 | fig = plt.figure() 65 | ax = fig.add_subplot(projection='3d') 66 | 67 | xs = training_set[:, 0] 68 | ys = training_set[:, 1] 69 | zs = training_set[:, 2] 70 | 71 | ax.scatter(xs, ys, zs, label='target') 72 | ax.scatter(xs, ys, samples, label='samples') 73 | ax.legend() 74 | fig.savefig('samples.png') 75 | print('Done') 76 | -------------------------------------------------------------------------------- /mdn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagelywizard/pytorch-mdn/b5744b88eea88bc138fc19bc66c87e81dd5e340a/mdn/__init__.py -------------------------------------------------------------------------------- /mdn/mdn.py: -------------------------------------------------------------------------------- 1 | """A module for a mixture density network layer 2 | 3 | For more info on MDNs, see _Mixture Desity Networks_ by Bishop, 1994. 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | from torch.distributions import Categorical 10 | import math 11 | 12 | 13 | ONEOVERSQRT2PI = 1.0 / math.sqrt(2 * math.pi) 14 | 15 | 16 | class MDN(nn.Module): 17 | """A mixture density network layer 18 | 19 | The input maps to the parameters of a MoG probability distribution, where 20 | each Gaussian has O dimensions and diagonal covariance. 21 | 22 | Arguments: 23 | in_features (int): the number of dimensions in the input 24 | out_features (int): the number of dimensions in the output 25 | num_gaussians (int): the number of Gaussians per output dimensions 26 | 27 | Input: 28 | minibatch (BxD): B is the batch size and D is the number of input 29 | dimensions. 30 | 31 | Output: 32 | (pi, sigma, mu) (BxG, BxGxO, BxGxO): B is the batch size, G is the 33 | number of Gaussians, and O is the number of dimensions for each 34 | Gaussian. Pi is a multinomial distribution of the Gaussians. 
Sigma 35 | is the standard deviation of each Gaussian. Mu is the mean of each 36 | Gaussian. 37 | """ 38 | 39 | def __init__(self, in_features, out_features, num_gaussians): 40 | super(MDN, self).__init__() 41 | self.in_features = in_features 42 | self.out_features = out_features 43 | self.num_gaussians = num_gaussians 44 | self.pi = nn.Sequential( 45 | nn.Linear(in_features, num_gaussians), 46 | nn.Softmax(dim=1) 47 | ) 48 | self.sigma = nn.Linear(in_features, out_features * num_gaussians) 49 | self.mu = nn.Linear(in_features, out_features * num_gaussians) 50 | 51 | def forward(self, minibatch): 52 | pi = self.pi(minibatch) 53 | sigma = torch.exp(self.sigma(minibatch)) 54 | sigma = sigma.view(-1, self.num_gaussians, self.out_features) 55 | mu = self.mu(minibatch) 56 | mu = mu.view(-1, self.num_gaussians, self.out_features) 57 | return pi, sigma, mu 58 | 59 | 60 | def gaussian_probability(sigma, mu, target): 61 | """Returns the probability of `target` given MoG parameters `sigma` and `mu`. 62 | 63 | Arguments: 64 | sigma (BxGxO): The standard deviation of the Gaussians. B is the batch 65 | size, G is the number of Gaussians, and O is the number of 66 | dimensions per Gaussian. 67 | mu (BxGxO): The means of the Gaussians. B is the batch size, G is the 68 | number of Gaussians, and O is the number of dimensions per Gaussian. 69 | target (BxI): A batch of target. B is the batch size and I is the number of 70 | input dimensions. 71 | 72 | Returns: 73 | probabilities (BxG): The probability of each point in the probability 74 | of the distribution in the corresponding sigma/mu index. 75 | """ 76 | target = target.unsqueeze(1).expand_as(sigma) 77 | ret = ONEOVERSQRT2PI * torch.exp(-0.5 * ((target - mu) / sigma)**2) / sigma 78 | return torch.prod(ret, 2) 79 | 80 | 81 | def mdn_loss(pi, sigma, mu, target): 82 | """Calculates the error, given the MoG parameters and the target 83 | 84 | The loss is the negative log likelihood of the data given the MoG 85 | parameters. 
86 | """ 87 | prob = pi * gaussian_probability(sigma, mu, target) 88 | nll = -torch.log(torch.sum(prob, dim=1)) 89 | return torch.mean(nll) 90 | 91 | 92 | def sample(pi, sigma, mu): 93 | """Draw samples from a MoG. 94 | """ 95 | # Choose which gaussian we'll sample from 96 | pis = Categorical(pi).sample().view(pi.size(0), 1, 1) 97 | # Choose a random sample, one randn for batch X output dims 98 | # Do a (output dims)X(batch size) tensor here, so the broadcast works in 99 | # the next step, but we have to transpose back. 100 | gaussian_noise = torch.randn( 101 | (sigma.size(2), sigma.size(0)), requires_grad=False) 102 | variance_samples = sigma.gather(1, pis).detach().squeeze() 103 | mean_samples = mu.detach().gather(1, pis).squeeze() 104 | return (gaussian_noise * variance_samples + mean_samples).transpose(0, 1) 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.0 2 | pytest>=6.0.0 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="pytorch-mdn", 8 | version="0.0.2", 9 | author="Benjamin Bastian", 10 | description="A mixture density network module for PyTorch", 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/sagelywizard/pytorch-mdn", 14 | project_urls={ 15 | "Bug Tracker": "https://github.com/sagelywizard/pytorch-mdn/issues", 16 | }, 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | package_dir={"": "mdn"}, 23 | packages=setuptools.find_packages(where="mdn"), 
24 | python_requires=">=3.6", 25 | ) 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sagelywizard/pytorch-mdn/b5744b88eea88bc138fc19bc66c87e81dd5e340a/tests/__init__.py -------------------------------------------------------------------------------- /tests/basic_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from mdn import mdn 4 | 5 | 6 | class TestMDNOutputs(unittest.TestCase): 7 | def setUp(self): 8 | self.mdn = mdn.MDN(4, 6, 10) 9 | 10 | def testOutputShape(self): 11 | minibatch = torch.randn((2, 4)) 12 | pi, sigma, mu = self.mdn(minibatch) 13 | self.assertEqual(pi.size(), (2, 10)) 14 | self.assertEqual(sigma.size(), (2, 10, 6)) 15 | self.assertEqual(mu.size(), (2, 10, 6)) 16 | 17 | def testPiSumsToOne(self): 18 | # Pi represents a categorical distirbution across the gaussians, so it 19 | # should sum to 1 20 | minibatch = torch.randn((2, 4)) 21 | pi, _, _ = self.mdn(minibatch) 22 | self.assertTrue( 23 | all(torch.isclose(pi.sum(dim=1), torch.ones(pi.size(0))))) 24 | -------------------------------------------------------------------------------- /tests/gradient_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from mdn import mdn 6 | 7 | 8 | class BackpropDecreasesLossMDN(unittest.TestCase): 9 | def testLossDecreases(self): 10 | model = nn.Sequential( 11 | nn.Linear(1, 5), 12 | nn.Tanh(), 13 | mdn.MDN(5, 1, 2) 14 | ) 15 | 16 | torch.manual_seed(0) 17 | first_loss = None 18 | training_set = torch.randn((100, 2)) 19 | optimizer = optim.Adam(model.parameters()) 20 | 21 | for _ in range(10): 22 | model.zero_grad() 23 | pi, sigma, mu = model(training_set[:, 0:1]) 24 | loss = 
mdn.mdn_loss(pi, sigma, mu, training_set[:, 1:]) 25 | loss.backward() 26 | optimizer.step() 27 | if first_loss is None: 28 | first_loss = loss 29 | 30 | self.assertLess(loss, first_loss) 31 | --------------------------------------------------------------------------------