├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── __init__.py
│   ├── ex_1d.png
│   └── ex_1d.py
└── src
    ├── __init__.py
    └── blocks.py

/.gitignore:
--------------------------------------------------------------------------------
*.swp
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Tony Duan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Mixture Density Network

Last update: December 2022.

---

Lightweight implementation of a mixture density network [1] in PyTorch.

#### Setup

Suppose we want to regress a response $\mathbf{y} \in \mathbb{R}^{d}$ on covariates $\mathbf{x} \in \mathbb{R}^n$.

We model the conditional distribution as a mixture of Gaussians
```math
p_\theta(\mathbf{y}|\mathbf{x}) = \sum_{k=1}^K \pi_k\, N(\mathbf{y};\boldsymbol\mu^{(k)}, {\boldsymbol\Sigma}^{(k)}),
```
where the mixture parameters are output by a neural network that depends on $\mathbf{x}$:
```math
\left(\boldsymbol\pi,\ \{\boldsymbol\mu^{(k)}\}_{k=1}^K,\ \{\boldsymbol\Sigma^{(k)}\}_{k=1}^K\right) = f_\theta(\mathbf{x}), \qquad \boldsymbol\pi\in\Delta^{K-1},\quad \boldsymbol\mu^{(k)}\in\mathbb{R}^{d},\quad \boldsymbol\Sigma^{(k)}\in \mathrm{S}_+^d.
```
The training objective is to maximize the log-likelihood; note that this objective is non-convex.
```math
\begin{align*}
\log p_\theta(\mathbf{y}|\mathbf{x})
& \propto\log \sum_{k}\left(\pi_k\exp\left(-\frac{1}{2}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right)^\top {\boldsymbol\Sigma^{(k)}}^{-1}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right) -\frac{1}{2}\log\det \boldsymbol\Sigma^{(k)}\right)\right)\\
& = \mathrm{logsumexp}_k\left(\log\pi_k - \frac{1}{2}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right)^\top {\boldsymbol\Sigma^{(k)}}^{-1}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right) -\frac{1}{2}\log\det \boldsymbol\Sigma^{(k)}\right)
\end{align*}
```
Importantly, for numerical stability we compute $\log \boldsymbol\pi$ directly with `torch.log_softmax(...)` rather than taking the log of a softmax, and we evaluate the sum over components with `torch.logsumexp(...)`.
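Concretely, this is the computation implemented by `MixtureDensityNetwork.loss` in `src/blocks.py` below. As a minimal standalone sketch of the diagonal-covariance case (the function name and tensor shapes here are illustrative, not part of the library):
```python
import torch

def mixture_nll(y, log_pi, mu, sigma):
    # y: (bsz, d); log_pi: (bsz, K); mu, sigma: (bsz, K, d) with diagonal covariances.
    z = (y.unsqueeze(1) - mu) / sigma                      # standardized residual per component
    component_loglik = (
        -0.5 * (z ** 2).sum(dim=-1)                        # quadratic (Mahalanobis) term
        - torch.log(sigma).sum(dim=-1)                     # equals -1/2 log det Sigma^(k) for diagonal Sigma
    )
    return -torch.logsumexp(log_pi + component_loglik, dim=-1)  # per-example negative log-likelihood
```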
#### Noise Model

There are several choices we can make to constrain the noise model $\boldsymbol\Sigma^{(k)}$.

1. No assumptions: $\boldsymbol\Sigma^{(k)} \in \mathrm{S}_+^d$ is a full covariance matrix.
2. Fully factored: let $\boldsymbol\Sigma^{(k)} = \mathrm{diag}({\boldsymbol\sigma^{(k)}}^{2})$, ${\boldsymbol\sigma^{(k)}}^{2}\in\mathbb{R}_+^d$, where the noise level of each dimension is predicted separately.
3. Isotropic: let $\boldsymbol\Sigma^{(k)} = {\sigma^{(k)}}^{2}\mathbf{I}$, ${\sigma^{(k)}}^{2}\in\mathbb{R}_+$, which assumes the same noise level across all $d$ dimensions.
4. Isotropic across clusters: let $\boldsymbol\Sigma^{(k)} = \sigma^2\mathbf{I}$, $\sigma^2\in\mathbb{R}_+$, which assumes the same noise level across all $d$ dimensions *and* all clusters.
5. Fixed isotropic: same as above, but $\sigma^2$ is fixed rather than learned.

These correspond to the following objectives.
```math
\begin{align*}
\log p_\theta(\mathbf{y}|\mathbf{x}) & = \mathrm{logsumexp}_k\left(\log\pi_k - \frac{1}{2}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right)^\top {\boldsymbol\Sigma^{(k)}}^{-1}\left(\mathbf{y}-\boldsymbol\mu^{(k)}\right) -\frac{1}{2}\log\det \boldsymbol\Sigma^{(k)}\right) \tag{1}\\
& = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\boldsymbol\sigma^{(k)}}\right\|^2-\sum_{j=1}^{d}\log\sigma^{(k)}_j\right) \tag{2}\\
& = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\sigma^{(k)}}\right\|^2-d\log(\sigma^{(k)})\right) \tag{3}\\
& = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\sigma}\right\|^2-d\log(\sigma)\right) \tag{4}\\
& = \mathrm{logsumexp}_k \left(\log\pi_k - \frac{1}{2}\left\|\frac{\mathbf{y}-\boldsymbol\mu^{(k)}}{\sigma}\right\|^2\right) \tag{5}
\end{align*}
```
In option (5) the $d\log(\sigma)$ term is constant, so it drops out of the objective. In this repository we implement options (2), (3), (4), and (5).
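These options only change how many noise parameters the network head has to produce. For reference, the sizing used in `MixtureDensityNetwork.__init__` (reproduced from `src/blocks.py`, with comments keyed to the options above; the example dimensions are illustrative):
```python
from src.blocks import NoiseType

dim_out, n_components = 2, 5  # illustrative values

num_sigma_channels = {
    NoiseType.DIAGONAL: dim_out * n_components,   # option (2): one sigma per dimension per component
    NoiseType.ISOTROPIC: n_components,            # option (3): one sigma per component
    NoiseType.ISOTROPIC_ACROSS_CLUSTERS: 1,       # option (4): a single shared sigma
    NoiseType.FIXED: 0,                           # option (5): sigma is fixed, nothing is predicted
}
print(num_sigma_channels[NoiseType.DIAGONAL])     # 10; the network also outputs dim_out * n_components means
```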
#### Miscellaneous

Recall that the objective is non-convex. For example, one poor local optimum is to ignore all modes but one and place a single, diffuse Gaussian over the marginal outcome (i.e. a high ${\sigma}^{(k)}$).

For this reason it's often preferable to over-parameterize the model and set `n_components` higher than the hypothesized true number of modes.

#### Usage

```python
import torch
from src.blocks import MixtureDensityNetwork

x = torch.randn(5, 1)
y = torch.randn(5, 1)

# 1D input, 1D output, 3 mixture components
model = MixtureDensityNetwork(1, 1, n_components=3, hidden_dim=50)
pred_parameters = model(x)

# per-example negative log-likelihood; take the mean and backprop to train
loss = model.loss(x, y)

# use this to sample from a trained model
samples = model.sample(x)
```
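`model.loss` returns the negative log-likelihood per example, so a typical training loop averages it and backpropagates. Below is a condensed version of the loop in `examples/ex_1d.py`; the random data and the hyperparameters here are placeholders only (the example script also uses a cosine learning-rate schedule).
```python
import torch
import torch.optim as optim

from src.blocks import MixtureDensityNetwork, NoiseType

# Placeholder data; examples/ex_1d.py generates a multimodal 1D dataset instead.
x, y = torch.randn(512, 1), torch.randn(512, 1)

model = MixtureDensityNetwork(1, 1, n_components=3, hidden_dim=50, noise_type=NoiseType.DIAGONAL)
optimizer = optim.Adam(model.parameters(), lr=0.005)

for i in range(2000):
    optimizer.zero_grad()
    loss = model.loss(x, y).mean()   # average per-example negative log-likelihood
    loss.backward()
    optimizer.step()

with torch.no_grad():
    samples = model.sample(x)        # one draw of y for each row of x
```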
For further details see the `examples/` folder. Below is a model fit with 3 components from `ex_1d.py`.

![ex_model](examples/ex_1d.png "Example model output")

#### References

[1] Bishop, C. M. Mixture density networks. (1994).

[2] Ha, D. & Schmidhuber, J. Recurrent World Models Facilitate Policy Evolution. in *Advances in Neural Information Processing Systems 31* (eds. Bengio, S. et al.) 2450–2462 (Curran Associates, Inc., 2018).

#### License

This code is available under the MIT License.
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/mixture-density-network/5bfeae42de2adf42d680bdfee24bb8c2ce52f259/examples/__init__.py
--------------------------------------------------------------------------------
/examples/ex_1d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/mixture-density-network/5bfeae42de2adf42d680bdfee24bb8c2ce52f259/examples/ex_1d.png
--------------------------------------------------------------------------------
/examples/ex_1d.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
import logging

from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.optim as optim

from src.blocks import MixtureDensityNetwork, NoiseType


def gen_data(n=512):
    y = np.linspace(-1, 1, n)
    x = 7 * np.sin(5 * y) + 0.5 * y + 0.5 * np.random.randn(*y.shape)
    return x[:, np.newaxis], y[:, np.newaxis]

def plot_data(x, y):
    plt.hist2d(x, y, bins=35)
    plt.xlim(-8, 8)
    plt.ylim(-1, 1)
    plt.axis('off')


if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--n-iterations", type=int, default=2000)
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    x, y = gen_data()
    x = torch.Tensor(x)
    y = torch.Tensor(y)

    model = MixtureDensityNetwork(1, 1, n_components=3, hidden_dim=50, noise_type=NoiseType.DIAGONAL)
    optimizer = optim.Adam(model.parameters(), lr=0.005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_iterations)

    for i in range(args.n_iterations):
        optimizer.zero_grad()
        loss = model.loss(x, y).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 100 == 0:
            logger.info(f"Iter: {i}\tLoss: {loss.item():.2f}")

    with torch.no_grad():
        y_hat = model.sample(x)

    plt.figure(figsize=(8, 3))
    plt.subplot(1, 2, 1)
    plot_data(x[:, 0].numpy(), y[:, 0].numpy())
    plt.title("Observed data")
    plt.subplot(1, 2, 2)
    plot_data(x[:, 0].numpy(), y_hat[:, 0].numpy())
    plt.title("Sampled data")
    plt.show()
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/mixture-density-network/5bfeae42de2adf42d680bdfee24bb8c2ce52f259/src/__init__.py
--------------------------------------------------------------------------------
/src/blocks.py:
--------------------------------------------------------------------------------
from enum import Enum, auto

import torch
import torch.nn as nn
import torch.nn.functional as F


class NoiseType(Enum):
    DIAGONAL = auto()
    ISOTROPIC = auto()
    ISOTROPIC_ACROSS_CLUSTERS = auto()
    FIXED = auto()


class MixtureDensityNetwork(nn.Module):
    """
    Mixture density network.

    [ Bishop, 1994 ]

    Parameters
    ----------
    dim_in: int; dimensionality of the covariates
    dim_out: int; dimensionality of the response variable
    n_components: int; number of components in the mixture model
    hidden_dim: int; size of the hidden layers in the pi and normal networks
    noise_type: NoiseType; parameterization of the component covariances
    fixed_noise_level: float; noise level to use when noise_type is NoiseType.FIXED
    """
    def __init__(self, dim_in, dim_out, n_components, hidden_dim, noise_type=NoiseType.DIAGONAL, fixed_noise_level=None):
        super().__init__()
        assert (fixed_noise_level is not None) == (noise_type is NoiseType.FIXED)
        num_sigma_channels = {
            NoiseType.DIAGONAL: dim_out * n_components,
            NoiseType.ISOTROPIC: n_components,
            NoiseType.ISOTROPIC_ACROSS_CLUSTERS: 1,
            NoiseType.FIXED: 0,
        }[noise_type]
        self.dim_in, self.dim_out, self.n_components = dim_in, dim_out, n_components
        self.noise_type, self.fixed_noise_level = noise_type, fixed_noise_level
        self.pi_network = nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_components),
        )
        self.normal_network = nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim_out * n_components + num_sigma_channels)
        )

    def forward(self, x, eps=1e-6):
        #
        # Returns
        # -------
        # log_pi: (bsz, n_components)
        # mu: (bsz, n_components, dim_out)
        # sigma: (bsz, n_components, dim_out)
        #
        log_pi = torch.log_softmax(self.pi_network(x), dim=-1)
        normal_params = self.normal_network(x)
        mu = normal_params[..., :self.dim_out * self.n_components]
        sigma = normal_params[..., self.dim_out * self.n_components:]
        if self.noise_type is NoiseType.DIAGONAL:
            sigma = torch.exp(sigma + eps)
        if self.noise_type is NoiseType.ISOTROPIC:
            # One sigma per component, shared across the dim_out output dimensions.
            sigma = torch.exp(sigma + eps).repeat_interleave(self.dim_out, dim=-1)
        if self.noise_type is NoiseType.ISOTROPIC_ACROSS_CLUSTERS:
            # A single sigma shared across all components and output dimensions.
            sigma = torch.exp(sigma + eps).repeat(1, self.n_components * self.dim_out)
        if self.noise_type is NoiseType.FIXED:
            sigma = torch.full_like(mu, fill_value=self.fixed_noise_level)
        mu = mu.reshape(-1, self.n_components, self.dim_out)
        sigma = sigma.reshape(-1, self.n_components, self.dim_out)
        return log_pi, mu, sigma

    def loss(self, x, y):
        # Per-example negative log-likelihood, shape (bsz,).
        log_pi, mu, sigma = self.forward(x)
        z_score = (y.unsqueeze(1) - mu) / sigma
        normal_loglik = (
            -0.5 * torch.einsum("bij,bij->bi", z_score, z_score)
            -torch.sum(torch.log(sigma), dim=-1)
        )
        loglik = torch.logsumexp(log_pi + normal_loglik, dim=-1)
        return -loglik

    def sample(self, x):
        # Draw a component index per row by inverting the mixture CDF, then sample from that component's Gaussian.
        log_pi, mu, sigma = self.forward(x)
        cum_pi = torch.cumsum(torch.exp(log_pi), dim=-1)
        rvs = torch.rand(len(x), 1).to(x)
        rand_pi = torch.searchsorted(cum_pi, rvs)
        rand_normal = torch.randn_like(mu) * sigma + mu
        samples = torch.take_along_dim(rand_normal, indices=rand_pi.unsqueeze(-1), dim=1).squeeze(dim=1)
        return samples
--------------------------------------------------------------------------------