├── gaussian-loopy-bp
│   ├── __init__.py
│   ├── README.md
│   ├── gauss_bp_utils_test.py
│   ├── gauss_chain_test.py
│   ├── gauss_bp_utils.py
│   ├── gauss_chain.py
│   ├── gauss_factor_graph_test.py
│   ├── gauss_factor_graph.py
│   └── gauss-bp-1d-line.ipynb
├── README.md
├── deprecated-gauss-bp
│   ├── README.md
│   ├── README.md~
│   ├── variable_node.py
│   ├── gaussian.py
│   ├── factor.py
│   ├── factor_graph.py
│   └── gauss-bp-1d-line.ipynb
└── LICENSE

/gaussian-loopy-bp/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pgm-jax
2 | Probabilistic Graphical Models in JAX.
3 | This is a work in progress.
4 | Currently we only support belief propagation in (loopy)
5 | Gaussian PGMs. In the case of chain-structured graphs
6 | with linear-Gaussian potentials,
7 | this gives the same result as Kalman smoothing.
8 | 
9 | 
--------------------------------------------------------------------------------
/deprecated-gauss-bp/README.md:
--------------------------------------------------------------------------------
1 | 
2 | Loopy Belief Propagation for Gaussian Graphical Models in JAX.
3 | 
4 | The code is based on [this PyTorch colab](https://colab.research.google.com/drive/1-nrE95X4UC9FBLR0-cTnsIP_XhA_PZKW?usp=sharing)
5 | by Joseph Ortiz. The translation to JAX is by moloydas@, murphyk@.
6 | The code allows for nonlinear factors by iteratively linearizing the factors.
7 | Thus the MAP estimate corresponds to solving a nonlinear least-squares problem, but we also get approximate posterior marginals.
8 | 
9 | 
10 | 
--------------------------------------------------------------------------------
/gaussian-loopy-bp/README.md:
--------------------------------------------------------------------------------
1 | 
2 | Loopy Belief Propagation for Gaussian Graphical Models in JAX.
3 | 
4 | This code is based on [this PyTorch colab](https://colab.research.google.com/drive/1-nrE95X4UC9FBLR0-cTnsIP_XhA_PZKW?usp=sharing)
5 | by Joseph Ortiz. The translation to JAX is by Giles Harper-Donnelly.
6 | However, it has been completely redesigned to be functionally pure, rather than object-oriented,
7 | which makes it faster. It has also been simplified so that it only works with linear-Gaussian models,
8 | and does not support iterative relinearization or robust potentials. The unit test checks that it gives the same results as Kalman smoothing on a chain-structured model.
--------------------------------------------------------------------------------
/deprecated-gauss-bp/README.md~:
--------------------------------------------------------------------------------
1 | # ggm-jax
2 | Gaussian Graphical Models in JAX.
3 | 
4 | We currently just support Gaussian loopy belief propagation.
5 | The code is based on [this PyTorch colab](https://colab.research.google.com/drive/1-nrE95X4UC9FBLR0-cTnsIP_XhA_PZKW?usp=sharing)
6 | by Joseph Ortiz, which is explained at https://gaussianbp.github.io/.
7 | This allows for nonlinear factors by iteratively linearizing the factors.
8 | Thus the MAP estimate corresponds to solving a nonlinear least-squares problem.
9 | 
10 | In the future we may add exact Gaussian inference based on junction trees, and Gaussian variational inference.
11 | 
12 | Authors: moloydas@, murphyk@. 
13 | 14 | -------------------------------------------------------------------------------- /gaussian-loopy-bp/gauss_bp_utils_test.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | 3 | from gauss_bp.gauss_bp_utils import potential_from_conditional_linear_gaussian, info_condition 4 | 5 | _all_close = lambda x, y: jnp.allclose(x, y, rtol=1e-3, atol=1e-3) 6 | 7 | 8 | def test_clg_potential(): 9 | """Test consistency of conditional linear gaussian potential function 10 | with joint conditioning function. 11 | 12 | p(y|z) = N(y | Az + u, Lambda^{-1}) 13 | """ 14 | # Parameters 15 | A = jnp.array([[1.0, 1.0, 0.0, 1.0], [0.0, 1.0, 2.0, 3.0]]) 16 | z = jnp.ones((4, 1)) 17 | offset = jnp.ones((2, 1)) 18 | Lambda = jnp.ones((2, 2)) * 2 19 | 20 | # Form joint potential \phi(y,z) 21 | (Kzz, Kzy, Kyy), (hz, hy) = potential_from_conditional_linear_gaussian(A, offset, Lambda) 22 | # Condition on z 23 | K_cond, h_cond = info_condition(Kyy, Kzy.T, hy, z) 24 | 25 | # Check that conditioning phi(z,y) on z returns the same parameters as 26 | # explicitly calculating linear conditional. 27 | assert _all_close(Lambda, K_cond) 28 | assert _all_close(Lambda @ (A @ z + offset), h_cond) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /deprecated-gauss-bp/variable_node.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from gaussian import Gaussian 3 | 4 | class VariableNode: 5 | def __init__(self, id: int, dofs: int, properties: dict = {}) -> None: 6 | self.variableID = id 7 | self.properties = properties 8 | self.dofs = dofs 9 | self.adj_factors = [] 10 | self.belief = Gaussian(dofs) 11 | self.prior = Gaussian(dofs) # prior factor, implemented as part of variable node 12 | 13 | def update_belief(self) -> None: 14 | """ Update local belief estimate by taking product of all incoming messages along all edges. 
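In the information (canonical) form used here, this product of Gaussians is just a sum of
parameters, i.e. (a sketch of what the loop below computes):

    belief.eta = prior.eta + sum_f messages[f].eta
    belief.lam = prior.lam + sum_f messages[f].lam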
""" 15 | self.belief.eta = self.prior.eta.clone() # message from prior factor 16 | self.belief.lam = self.prior.lam.clone() 17 | for factor in self.adj_factors: # messages from other adjacent variables 18 | message_ix = factor.adj_vIDs.index(self.variableID) 19 | self.belief.eta += factor.messages[message_ix].eta 20 | self.belief.lam += factor.messages[message_ix].lam 21 | 22 | def get_prior_energy(self) -> float: 23 | energy = 0. 24 | if self.prior.lam[0, 0] != 0.: 25 | residual = self.belief.mean() - self.prior.mean() 26 | energy += 0.5 * residual @ self.prior.lam @ residual 27 | return energy 28 | 29 | -------------------------------------------------------------------------------- /gaussian-loopy-bp/gauss_chain_test.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | from jax import random as jr 3 | 4 | from gauss_bp.gauss_chain import gauss_chain_potentials_from_lgssm, gauss_chain_bp 5 | from ssm_jax.linear_gaussian_ssm.inference import lgssm_sample 6 | from ssm_jax.linear_gaussian_ssm.info_inference import lgssm_info_smoother 7 | from ssm_jax.linear_gaussian_ssm.info_inference_test import build_lgssm_moment_and_info_form 8 | 9 | _all_close = lambda x,y: jnp.allclose(x,y,rtol=1e-3, atol=1e-3) 10 | 11 | def test_gauss_chain_bp(): 12 | """Test that Gaussian chain belief propagation gets the same results as 13 | information form RTS smoother.""" 14 | 15 | lgssm, lgssm_info = build_lgssm_moment_and_info_form() 16 | 17 | key = jr.PRNGKey(111) 18 | num_timesteps = 15 19 | input_size = lgssm.dynamics_input_weights.shape[1] 20 | inputs = jnp.zeros((num_timesteps, input_size)) 21 | x, y = lgssm_sample(key, lgssm, num_timesteps, inputs=inputs) 22 | 23 | lgssm_info_posterior = lgssm_info_smoother(lgssm_info, y, inputs) 24 | 25 | chain_pots = gauss_chain_potentials_from_lgssm(lgssm_info, inputs) 26 | 27 | smoothed_bels = gauss_chain_bp(chain_pots, y) 28 | Ks, hs = smoothed_bels 29 | 30 | assert _all_close(lgssm_info_posterior.smoothed_precisions,Ks) 31 | assert _all_close(lgssm_info_posterior.smoothed_etas,hs) 32 | 33 | 34 | -------------------------------------------------------------------------------- /deprecated-gauss-bp/gaussian.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | from typing import List, Callable, Optional, Union 4 | 5 | class Gaussian: 6 | def __init__(self, dim: int, eta: Optional[jnp.array]=None, lam: Optional[jnp.array]=None, type: jnp.dtype = np.float): 7 | self.dim = dim 8 | 9 | if eta is not None and len(eta) == dim: 10 | self.eta = eta.type(type) 11 | else: 12 | self.eta = jnp.zeros(dim, dtype=type) 13 | 14 | if lam is not None and lam.shape == (dim, dim): 15 | self.lam = lam.type(type) 16 | else: 17 | self.lam = jnp.zeros([dim, dim], dtype=type) 18 | 19 | def mean(self) -> jnp.array: 20 | return jnp.matmul(jnp.linalg.inv(self.lam), self.eta) 21 | 22 | def cov(self) -> jnp.array: 23 | return jnp.linalg.inv(self.lam) 24 | 25 | def mean_and_cov(self) -> List[jnp.array]: 26 | cov = self.cov() 27 | mean = jnp.matmul(cov, self.eta) 28 | return [mean, cov] 29 | 30 | def set_with_cov_form(self, mean: jnp.array, cov: jnp.array) -> None: 31 | self.lam = jnp.linalg.inv(cov) 32 | self.eta = self.lam @ mean 33 | 34 | """ 35 | Defines squared loss functions that correspond to Gaussians. 36 | Robust losses are implemented by scaling the Gaussian covariance. 
37 | """ 38 | 39 | class SquaredLoss(): 40 | def __init__(self, dofs: int, diag_cov: Union[float, jnp.array]) -> None: 41 | """ 42 | dofs: dofs of the measurement 43 | cov: diagonal elements of covariance matrix 44 | """ 45 | assert len(diag_cov) == dofs 46 | mat = jnp.zeros((dofs, dofs), dtype=diag_cov.dtype) 47 | mat = mat.at[:, :].set(diag_cov) 48 | self.cov = mat 49 | self.effective_cov = mat.clone() 50 | 51 | def get_effective_cov(self, residual: jnp.array) -> None: 52 | """ Returns the covariance of the Gaussian (squared loss) that matches the loss at the error value. """ 53 | self.effective_cov = self.cov.clone() 54 | 55 | def robust(self) -> bool: 56 | return not jnp.array_equal(self.cov, self.effective_cov) 57 | 58 | 59 | class HuberLoss(SquaredLoss): 60 | def __init__(self, dofs: int, diag_cov: Union[float, jnp.array], stds_transition: float) -> None: 61 | """ 62 | stds_transition: num standard deviations from minimum at which quadratic loss transitions to linear 63 | """ 64 | SquaredLoss.__init__(self, dofs, diag_cov) 65 | self.stds_transition = stds_transition 66 | 67 | def get_effective_cov(self, residual: jnp.array) -> None: 68 | mahalanobis_dist = jnp.sqrt(residual @ jnp.linalg.inv(self.cov) @ residual) 69 | if mahalanobis_dist > self.stds_transition: 70 | self.effective_cov = self.cov * mahalanobis_dist**2 / (2 * self.stds_transition * mahalanobis_dist - self.stds_transition**2) 71 | else: 72 | self.effective_cov = self.cov.clone() 73 | 74 | 75 | class TukeyLoss(SquaredLoss): 76 | def __init__(self, dofs: int, diag_cov: Union[float, jnp.array], stds_transition: float) -> None: 77 | """ 78 | stds_transition: num standard deviations from minimum at which quadratic loss transitions to constant 79 | """ 80 | SquaredLoss.__init__(self, dofs, diag_cov) 81 | self.stds_transition = stds_transition 82 | 83 | def get_effective_cov(self, residual: jnp.array) -> None: 84 | mahalanobis_dist = jnp.sqrt(residual @ jnp.linalg.inv(self.cov) @ residual) 85 | if mahalanobis_dist > self.stds_transition: 86 | self.effective_cov = self.cov * mahalanobis_dist**2 / self.stds_transition**2 87 | else: 88 | self.effective_cov = self.cov.clone() 89 | 90 | class MeasModel: 91 | def __init__(self, meas_fn: Callable, jac_fn: Callable, loss: SquaredLoss, *args) -> None: 92 | self._meas_fn = meas_fn 93 | self._jac_fn = jac_fn 94 | self.loss = loss 95 | self.args = args 96 | self.linear = True 97 | 98 | def jac_fn(self, x: jnp.array) -> jnp.array: 99 | return self._jac_fn(x, *self.args) 100 | 101 | def meas_fn(self, x: jnp.array) -> jnp.array: 102 | return self._meas_fn(x, *self.args) 103 | -------------------------------------------------------------------------------- /deprecated-gauss-bp/factor.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | from typing import List, Callable, Optional, Union 4 | from variable_node import VariableNode 5 | from gaussian import Gaussian, MeasModel 6 | 7 | class Factor: 8 | def __init__(self, 9 | id: int, 10 | adj_var_nodes: List[VariableNode], 11 | measurement: jnp.array, 12 | meas_model: MeasModel, 13 | type: jnp.dtype = np.float, 14 | properties: dict = {}) -> None: 15 | 16 | self.factorID = id 17 | self.properties = properties 18 | 19 | self.adj_var_nodes = adj_var_nodes 20 | self.dofs = sum([var.dofs for var in adj_var_nodes]) 21 | self.adj_vIDs = [var.variableID for var in adj_var_nodes] 22 | self.messages = [Gaussian(var.dofs) for var in adj_var_nodes] 23 | 24 | self.factor = 
25 |         self.linpoint = jnp.zeros(self.dofs, dtype=type)
26 | 
27 |         self.measurement = measurement
28 |         self.meas_model = meas_model
29 | 
30 |         # For smarter GBP implementations
31 |         self.iters_since_relin = 0
32 | 
33 |         self.compute_factor()
34 | 
35 |     def get_adj_means(self) -> jnp.array:
36 |         adj_belief_means = [var.belief.mean() for var in self.adj_var_nodes]
37 |         return jnp.concatenate(adj_belief_means)
38 | 
39 |     def get_residual(self, eval_point: jnp.array = None) -> jnp.array:
40 |         """ Compute the residual vector. """
41 |         if eval_point is None:
42 |             eval_point = self.get_adj_means()
43 |         return self.meas_model.meas_fn(eval_point) - self.measurement
44 | 
45 |     def get_energy(self, eval_point: jnp.array = None) -> float:
46 |         """ Computes the squared error using the appropriate loss function. """
47 |         residual = self.get_residual(eval_point)
48 |         # print("adj_beliefs", self.get_adj_means())
49 |         # print("pred and meas", self.meas_model.meas_fn(self.get_adj_means()), self.measurement)
50 |         # print("residual", self.get_residual(), self.meas_model.loss.effective_cov)
51 |         return 0.5 * residual @ jnp.linalg.inv(self.meas_model.loss.effective_cov) @ residual
52 | 
53 |     def robust(self) -> bool:
54 |         return self.meas_model.loss.robust()
55 | 
56 |     def compute_factor(self) -> None:
57 |         """
58 |         Compute the factor at the current adjacent beliefs, using the robust (effective) covariance.
59 |         If the measurement model is linear then the factor will always be the same regardless of the linearisation point.
60 |         """
61 |         self.linpoint = self.get_adj_means()
62 |         J = self.meas_model.jac_fn(self.linpoint)
63 |         pred_measurement = self.meas_model.meas_fn(self.linpoint)
64 |         self.meas_model.loss.get_effective_cov(pred_measurement - self.measurement)
65 |         effective_lam = jnp.linalg.inv(self.meas_model.loss.effective_cov)
66 |         self.factor.lam = J.T @ effective_lam @ J
67 |         self.factor.eta = ((J.T @ effective_lam) @ (J @ self.linpoint + self.measurement - pred_measurement)).flatten()
68 |         self.iters_since_relin = 0
69 | 
70 |     def robustify_loss(self) -> None:
71 |         """
72 |         Rescale the variance of the noise in the Gaussian measurement model if necessary and update the factor
73 |         correspondingly.
74 |         """
75 |         old_effective_cov = self.meas_model.loss.effective_cov[0, 0]
76 |         self.meas_model.loss.get_effective_cov(self.get_residual())
77 |         self.factor.eta *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0]
78 |         self.factor.lam *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0]
79 | 
80 |     def compute_messages(self, damping: float = 0.) -> None:
81 |         """ Compute all outgoing messages from the factor.
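        The message to variable v is obtained by multiplying the factor by all incoming
        messages except the one from v, and then marginalising out the other variables.
        In information form (matching the block names used in the code below):

            lam_msg = loo - lono @ inv(lnono) @ lnoo
            eta_msg = eo - lono @ inv(lnono) @ eno

        where the "o" blocks correspond to variable v and the "no" blocks to the rest.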
""" 82 | messages_eta, messages_lam = [], [] 83 | 84 | start_dim = 0 85 | for v in range(len(self.adj_vIDs)): 86 | eta_factor, lam_factor = self.factor.eta.clone(), self.factor.lam.clone() 87 | 88 | # Take product of factor with incoming messages 89 | start = 0 90 | for var in range(len(self.adj_vIDs)): 91 | if var != v: 92 | var_dofs = self.adj_var_nodes[var].dofs 93 | eta_factor = eta_factor.at[start:start + var_dofs].set(eta_factor[start:start + var_dofs] + self.adj_var_nodes[var].belief.eta - self.messages[var].eta) 94 | lam_factor = lam_factor.at[start:start + var_dofs, start:start + var_dofs].set(lam_factor[start:start + var_dofs, start:start + var_dofs] + self.adj_var_nodes[var].belief.lam - self.messages[var].lam) 95 | 96 | start += self.adj_var_nodes[var].dofs 97 | 98 | # Divide up parameters of distribution 99 | mess_dofs = self.adj_var_nodes[v].dofs 100 | eo = eta_factor[start_dim:start_dim + mess_dofs] 101 | eno = jnp.concatenate((eta_factor[:start_dim], eta_factor[start_dim + mess_dofs:])) 102 | 103 | loo = lam_factor[start_dim:start_dim + mess_dofs, start_dim:start_dim + mess_dofs] 104 | lono = jnp.concatenate((lam_factor[start_dim:start_dim + mess_dofs, :start_dim], 105 | lam_factor[start_dim:start_dim + mess_dofs, start_dim + mess_dofs:]), axis=1) 106 | lnoo = jnp.concatenate((lam_factor[:start_dim, start_dim:start_dim + mess_dofs], 107 | lam_factor[start_dim + mess_dofs:, start_dim:start_dim + mess_dofs]), axis=0) 108 | lnono = jnp.concatenate( 109 | ( 110 | jnp.concatenate((lam_factor[:start_dim, :start_dim], lam_factor[:start_dim, start_dim + mess_dofs:]), axis=1), 111 | jnp.concatenate((lam_factor[start_dim + mess_dofs:, :start_dim], lam_factor[start_dim + mess_dofs:, start_dim + mess_dofs:]), axis=1) 112 | ), 113 | axis=0 114 | ) 115 | 116 | new_message_lam = loo - lono @ jnp.linalg.inv(lnono) @ lnoo 117 | new_message_eta = eo - lono @ jnp.linalg.inv(lnono) @ eno 118 | messages_eta.append((1 - damping) * new_message_eta + damping * self.messages[v].eta) 119 | messages_lam.append((1 - damping) * new_message_lam + damping * self.messages[v].lam) 120 | start_dim += self.adj_var_nodes[v].dofs 121 | 122 | for v in range(len(self.adj_vIDs)): 123 | self.messages[v].lam = messages_lam[v] 124 | self.messages[v].eta = messages_eta[v] 125 | -------------------------------------------------------------------------------- /gaussian-loopy-bp/gauss_bp_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from jax import jit 3 | from jax import numpy as jnp 4 | from jax.tree_util import tree_map 5 | 6 | 7 | def info_marginalize(Kxx, Kxy, Kyy, hx, hy): 8 | """Calculate the parameters of marginalized MVN. 9 | 10 | For x, y joint distributed as 11 | p(x, y) = Nc(x,y| h, K), 12 | the marginal distribution of x is given by: 13 | p(y) = \int p(x, y) dx = Nc(y | hy_marg, Ky_marg) 14 | where, 15 | hy_marg = hy - Kyx Kxx^{-1} hx 16 | Ky_marg = Kyy - Kyx Kxx^{-1} Kxy 17 | 18 | Args: 19 | K_blocks: blocks of the joint precision matrix, (Kxx, Kxy, Kyy), 20 | Kxx (dim_x, dim_x), 21 | Kxy (dim_x, dim_y), 22 | Kyy (dim_y, dim_y). 23 | hs (dim_x + dim_y, 1): joint precision weighted mean, (hx, hy): 24 | h1 (dim_x, 1), 25 | h2 (dim_y, 1). 26 | Returns: 27 | Ky_marg (dim_y, dim_y): marginal precision matrix. 28 | hy_marg (dim_y,1): marginal precision weighted mean. 
29 | """ 30 | G = jnp.linalg.solve(Kxx, Kxy) 31 | Ky_marg = Kyy - Kxy.T @ G 32 | hy_marg = hy - G.T @ hx 33 | return Ky_marg, hy_marg 34 | 35 | 36 | def info_condition(Kxx, Kxy, hx, y): 37 | """Calculate the parameters of MVN after conditioning. 38 | 39 | For x,y with joint mvn 40 | p(x,y) = Nc(x,y | h, K), 41 | where h, K can be partitioned into, 42 | h = [hx, hy] 43 | K = [[Kxx, Kxy], 44 | [[Kyx, Kyy]] 45 | the distribution of x condition on a particular value of y is given by, 46 | p(x|y) = Nc(x | hx_cond, Kx_cond), 47 | where 48 | hx_cond = hx - Kxy y 49 | Kx_cond = Kxx 50 | """ 51 | return Kxx, hx - Kxy @ y 52 | 53 | 54 | def potential_from_conditional_linear_gaussian(A, offset, Lambda): 55 | """Express a conditional linear Gaussian as a potential in canonical form. 56 | 57 | p(y|z) = N(y | Az + offset, Lambda^{-1}) 58 | \prop exp( -0.5(y z)^T K (y z) + (y z)^T h ) 59 | where, 60 | K = (Lambda; -Lambda A, -A.T Lambda; A.T Lambda A) 61 | h = (Lambda offset, -A.T Lambda offset) 62 | 63 | Args: 64 | A (dim_y, dim_z) 65 | offset (dim_y,1) 66 | Lambda (dim_y, dim_y) 67 | Returns: 68 | K (dim_z + dim_y, dim_z + dim_y) 69 | h (dim_z + dim_y,1) 70 | """ 71 | Kzy = -A.T @ Lambda 72 | Kzz = -Kzy @ A 73 | Kyy = Lambda 74 | hy = Lambda @ offset 75 | hz = -A.T @ hy 76 | return (Kzz, Kzy, Kyy), (hz, hy) 77 | 78 | 79 | def info_multiply(params1, params2): 80 | """Calculate parameters resulting from multiplying Gaussians potentials. 81 | 82 | As all the resultant parameters are the sum of the parameters of the two 83 | potentials being multiplied, then `params1` and `params2` can be any 84 | PyTree of potential parameters as long as the corresponding parameters 85 | of the two input potentials occupy the same leaves of the PyTree. 86 | 87 | For example, 88 | phi(K1,h2) * phi(K2, h2) = phi(K1 + K2, h1 + h2) 89 | 90 | Args: 91 | params1: PyTree of potential parameters. 92 | params2: PyTree of potential parameters with the same tree structure 93 | as `params1`. 94 | 95 | Returns: 96 | params_out: PyTree of resultant potential parameters. 97 | """ 98 | return tree_map(lambda a, b: a + b, params1, params2) 99 | 100 | 101 | def info_divide(params1, params2): 102 | """Calculate parameters resulting from dividing Gaussian potentials. 103 | 104 | As all the resultant parameters are the difference between the parameters 105 | of the two potentials being divided, then `params1` and `params2` can be 106 | any PyTree of potential parameters as long as the corresponding parameters 107 | of the two input potentials occupy the same leaves of the PyTree. 108 | 109 | For example, 110 | phi(K1,h2) / phi(K2, h2) = phi(K1 - K2, h1 - h2) 111 | 112 | Args: 113 | params1: PyTree of potential parameters. 114 | params2: PyTree of potential parameters with the same tree structure 115 | as `params1`. 116 | 117 | Returns: 118 | params_out: PyTree of resultant potential parameters. 119 | """ 120 | return tree_map(lambda a, b: a - b, params1, params2) 121 | 122 | 123 | @partial(jit, static_argnums=2) 124 | def pair_cpot_condition(cpot, obs, obs_var): 125 | """Convenience function for conditioning Gaussian potentials involving two 126 | variables. 127 | 128 | Args: 129 | cpot: canonical parameters of the potential, stored as nested tuples 130 | of the form, 131 | ((K11, K12, K22), (h1, h2)). 132 | obs: observation. 133 | obs_var (int): the label of the variable being condition on. 134 | 135 | Returns: 136 | cond_pot: canonical parameters of the conditioned potential, 137 | (K_cond, h_cond). 
138 | """ 139 | (K11, K12, K22), (h1, h2) = cpot 140 | if obs_var == 1: 141 | return info_condition(K22, K12.T, h2, obs) 142 | elif obs_var == 2: 143 | return info_condition(K11, K12, h1, obs) 144 | else: 145 | raise ValueError("obs_var must take a value of either 1 or 2.") 146 | 147 | 148 | @partial(jit, static_argnums=1) 149 | def pair_cpot_marginalize(cpot, marginalize_onto): 150 | """Convenience function for marginalizing Gaussian potentials involving two 151 | variables. 152 | 153 | Args: 154 | cpot: canonical parameters of the potential, stored as nested tuples 155 | of the form, 156 | ((K11, K12, K22), (h1, h2)). 157 | marginalize_onto (int): the label of the output marginal variable. 158 | 159 | Returns: 160 | marg_pot: canonical parameters of the marginal potential, 161 | (K_marg, h_marg). 162 | """ 163 | (K11, K12, K22), (h1, h2) = cpot 164 | if marginalize_onto == 1: 165 | return info_marginalize(K22, K12.T, K11, h2, h1) 166 | elif marginalize_onto == 2: 167 | return info_marginalize(K11, K12, K22, h1, h2) 168 | else: 169 | raise ValueError("marg_var must take a value of either 1 or 2.") 170 | 171 | 172 | @partial(jit, static_argnums=2) 173 | def pair_cpot_absorb_message(cpot, message, message_var): 174 | """Convenience function for absorbing a message into a Gaussain potential 175 | involving two variables. 176 | 177 | Args: 178 | cpot: canonical parameters of the potential, stored as nested tuples 179 | of the form, 180 | ((K11, K12, K22), (h1, h1)). 181 | message: the message potential which takes the form, 182 | (K_message, h_message) 183 | message_var (int): the label of the output marginal variable. 184 | 185 | Returns: 186 | cpot_plus_message: canonical parameters of the joint potential after 187 | the message has been incorporated, 188 | ((K11, K12, K22), (h1, h2)) 189 | """ 190 | K_message, h_message = message 191 | if message_var == 1: 192 | padded_message = ((K_message, 0, 0), (h_message, 0)) 193 | elif message_var == 2: 194 | padded_message = ((0, 0, K_message), (0, h_message)) 195 | else: 196 | raise ValueError("message_var must take a value of either 1 or 2.") 197 | 198 | return info_multiply(cpot, padded_message) 199 | -------------------------------------------------------------------------------- /gaussian-loopy-bp/gauss_chain.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import chex 3 | import jax 4 | from jax import vmap, lax, jit 5 | from jax import numpy as jnp 6 | from gauss_bp.gauss_bp_utils import ( 7 | potential_from_conditional_linear_gaussian, 8 | pair_cpot_condition, 9 | pair_cpot_marginalize, 10 | pair_cpot_absorb_message, 11 | info_multiply, 12 | info_divide, 13 | ) 14 | 15 | 16 | @chex.dataclass 17 | class GaussianChainPotentials: 18 | """Container class for Gaussian Chain Potentials. 19 | 20 | Both `latent_pots` and `obs_pots` contain the canonical parameters for a 21 | gaussian potential over a pair of variables. 22 | 23 | It is assumed that the latent and observed variables have the same shape 24 | along the chain (but not necessarily the same as each other). As occurs for 25 | instance in temporal models. This means that the potential parameters can be 26 | stacked as rows for and used with `jax.vmap` and `jax.lax.scan`. 27 | 28 | Attributes: 29 | prior_pot: A tuple containg the parameters for the prior potential over 30 | the first latent state (K, h). 
31 | 
32 |         latent_pots: A tuple containing the parameters for each pairwise latent
33 |             clique potential - ((K11, K12, K22), (h1, h2)). The ith row
34 |             of each array contains parameters for the clique containing
35 |             the pair of latent states at times (i, i+1).
36 |             Arrays have shapes:
37 |                 K11, K12, K22 - (T-1, D_hid, D_hid)
38 |                 h1, h2 - (T-1, D_hid)
39 | 
40 |         obs_pots: A tuple containing the parameters for each pairwise
41 |             emission clique potential - ((K11, K12, K22), (h1, h2)).
42 |             Arrays have shapes:
43 |                 K11 - (T, D_hid, D_hid)
44 |                 K12 - (T, D_hid, D_obs)
45 |                 K22 - (T, D_obs, D_obs)
46 |                 h1 - (T, D_hid)
47 |                 h2 - (T, D_obs)
48 |     """
49 | 
50 |     prior_pot: chex.Array
51 |     latent_pots: chex.Array
52 |     obs_pots: chex.Array
53 | 
54 | 
55 | def gauss_chain_potentials_from_lgssm(lgssm_params, inputs, T=None):
56 |     """Construct pairwise latent and emission clique potentials from the model.
57 | 
58 |     Args:
59 |         lgssm_params: an LGSSMInfoParams instance.
60 |         inputs (T, D_in): array of inputs.
61 |         T (int): number of timesteps to unroll the lgssm, only used if
62 |             `inputs=None`.
63 | 
64 |     Returns:
65 |         gauss_chain_potentials: a GaussianChainPotentials object containing the
66 |             prior, pairwise latent, and pairwise emission potentials.
67 |     """
68 |     if inputs is None:
69 |         if T is not None:
70 |             D_in = lgssm_params.dynamics_input_weights.shape[1]
71 |             inputs = jnp.zeros((T, D_in))
72 |         else:
73 |             raise ValueError("One of `inputs` or `T` must not be None.")
74 | 
75 |     B, b = lgssm_params.dynamics_input_weights, lgssm_params.dynamics_bias
76 |     D, d = lgssm_params.emission_input_weights, lgssm_params.emission_bias
77 |     latent_net_inputs = vmap(jnp.dot, (None, 0))(B, inputs) + b
78 |     emission_net_inputs = vmap(jnp.dot, (None, 0))(D, inputs) + d
79 | 
80 |     Lambda0, mu0 = lgssm_params.initial_precision, lgssm_params.initial_mean
81 |     prior_pot = (Lambda0, Lambda0 @ mu0)
82 | 
83 |     F, Q_prec = lgssm_params.dynamics_matrix, lgssm_params.dynamics_precision
84 |     latent_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))(F, latent_net_inputs[:-1], Q_prec)
85 | 
86 |     H, R_prec = lgssm_params.emission_matrix, lgssm_params.emission_precision
87 |     emission_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))(H, emission_net_inputs, R_prec)
88 | 
89 |     gauss_chain_potentials = GaussianChainPotentials(
90 |         prior_pot=prior_pot, latent_pots=latent_pots, obs_pots=emission_pots
91 |     )
92 |     return gauss_chain_potentials
93 | 
94 | 
95 | def gauss_chain_bp(gauss_chain_pots, obs):
96 |     """Belief propagation on a Gaussian chain.
97 | 
98 |     Calculate the canonical parameters for the marginal probability of the latent
99 |     states conditioned on the full set of observations,
100 |         p(x_t | y_{1:T}).
101 | 
102 |     Args:
103 |         gauss_chain_pots: GaussianChainPotentials object containing the prior
104 |             potential for the first latent state and pairwise
105 |             potentials for the latent and observed variables.
106 |         obs (T, D_obs): Array containing the observations.
107 | 
108 |     Returns:
109 |         smoothed_bels: canonical parameters of the marginal distribution of each latent
110 |             state conditioned on all observations, (K_smoothed, h_smoothed) with
111 |             shapes,
112 |                 K_smoothed (T, D_hid, D_hid)
113 |                 h_smoothed (T, D_hid).
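    Example (usage sketch, mirroring gauss_chain_test.py; `lgssm_info` and
    `inputs` are assumed to come from ssm_jax as in that test):
        >>> chain_pots = gauss_chain_potentials_from_lgssm(lgssm_info, inputs)
        >>> Ks, hs = gauss_chain_bp(chain_pots, obs)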
114 | """ 115 | prior_pot, latent_pots, emission_pots = gauss_chain_pots.to_tuple() 116 | 117 | local_evidence_pots = vmap(partial(pair_cpot_condition, obs_var=2))(emission_pots, obs) 118 | 119 | # Extract first local evidence potential 120 | init_local_evidence_pot = jax.tree_map(lambda a: a[0], local_evidence_pots) 121 | local_evidence_pots_rest = jax.tree_map(lambda a: a[1:], local_evidence_pots) 122 | 123 | # Combine first emission message with prior 124 | init_carry = info_multiply(prior_pot, init_local_evidence_pot) 125 | 126 | def _forward_step(carry, x): 127 | """Gaussian chain belief propagation forward step. 128 | 129 | Carry forward filtered beliefs p(x_{t-1}|y_{1:t-1}) and combine with latent 130 | potential, phi(x_{t-1}, x_t) and local evidence from observation, y_t, 131 | to calculate filtered belief at current step p(x_t|y_{1:t}). 132 | """ 133 | prev_filtered_bel = carry 134 | latent_pot, local_evidence_pot, y = x 135 | 136 | # Calculate latent message 137 | latent_pot = pair_cpot_absorb_message(latent_pot, prev_filtered_bel, message_var=1) 138 | latent_message = pair_cpot_marginalize(latent_pot, marginalize_onto=2) 139 | 140 | # Combine messages 141 | filtered_bel = info_multiply(latent_message, local_evidence_pot) 142 | 143 | return filtered_bel, (filtered_bel, latent_message) 144 | 145 | # Message pass forwards along chain 146 | _, (filtered_bels, forward_messages) = lax.scan( 147 | _forward_step, init_carry, (latent_pots, local_evidence_pots_rest, obs[1:]) 148 | ) 149 | # Append first belief 150 | filtered_bels = jax.tree_map(lambda h, t: jnp.row_stack((h[None, ...], t)), init_carry, filtered_bels) 151 | 152 | # Extract final belief 153 | init_carry = jax.tree_map(lambda a: a[-1], filtered_bels) 154 | filtered_bels_rest = jax.tree_map(lambda a: a[:-1], filtered_bels) 155 | 156 | def _backward_step(carry, x): 157 | """Gaussian chain belief propagation backward step. 158 | 159 | Carry backward smoothed beliefs p(x_t|y_{1:T}) and combine with latent 160 | potential, phi(x_{t-1}, x_t) to calculate smoothed belief at t-1 161 | p(x_{t-1}|y_{1:T}). 
162 | """ 163 | smoothed_bel_present = carry 164 | filtered_bel_past, message_from_past, latent_pot_past_present = x 165 | 166 | # Divide out forward message 167 | bel_minus_message_from_past = info_divide(smoothed_bel_present, message_from_past) 168 | # Absorb into joint potential 169 | latent_pot_past_present = pair_cpot_absorb_message(latent_pot_past_present, bel_minus_message_from_past, message_var=2) 170 | message_to_past = pair_cpot_marginalize(latent_pot_past_present, marginalize_onto=1) 171 | 172 | smoothed_bel_past = info_multiply(filtered_bel_past, message_to_past) 173 | return smoothed_bel_past, smoothed_bel_past 174 | 175 | # Message pass back along chain 176 | _, smoothed_bels = lax.scan( 177 | _backward_step, init_carry, (filtered_bels_rest, forward_messages, latent_pots), reverse=True 178 | ) 179 | # Append final belief 180 | smoothed_bels = jax.tree_map(lambda h, t: jnp.row_stack((h, t[None, ...])), smoothed_bels, init_carry) 181 | 182 | return smoothed_bels 183 | -------------------------------------------------------------------------------- /gaussian-loopy-bp/gauss_factor_graph_test.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures.process import _MAX_WINDOWS_WORKERS 2 | from functools import partial 3 | from jax import numpy as jnp 4 | from jax import random as jr 5 | from jax import vmap, jit 6 | 7 | from ssm_jax.distributions import InverseWishart 8 | from gauss_bp.gauss_bp_utils import info_multiply, potential_from_conditional_linear_gaussian, pair_cpot_condition 9 | from gauss_bp.gauss_factor_graph import (GaussianVariableNode, 10 | CanonicalFactor, 11 | CanonicalPotential, 12 | GaussianFactorGraph, 13 | zeros_canonical_pot, 14 | update_all_messages, 15 | make_factor_graph, 16 | init_messages, 17 | calculate_all_beliefs, 18 | make_canonical_factor) 19 | 20 | from ssm_jax.linear_gaussian_ssm.inference import lgssm_sample 21 | from ssm_jax.linear_gaussian_ssm.info_inference import lgssm_info_smoother 22 | from ssm_jax.linear_gaussian_ssm.info_inference_test import build_lgssm_moment_and_info_form 23 | 24 | _all_close = lambda x,y: jnp.allclose(x,y,rtol=1e-3, atol=1e-3) 25 | 26 | def canonical_factor_from_clg(A, u, Lambda, x, y, factorID): 27 | """Construct a CanoncialFactor from the parameters of a conditional linear Gaussian.""" 28 | (Kxx, Kxy, Kyy), (hx, hy) = potential_from_conditional_linear_gaussian(A, u, Lambda) 29 | K = jnp.block([[Kxx, Kxy], [Kxy.T, Kyy]]) 30 | h = jnp.concatenate((hx, hy)) 31 | cpot = CanonicalPotential(eta=h, Lambda=K) 32 | return make_canonical_factor(factorID, [x, y], cpot) 33 | 34 | 35 | def factor_graph_from_lgssm(lgssm_params, inputs, obs, T=None): 36 | """Unroll a linear gaussian state-space model into a factor graph.""" 37 | if inputs is None: 38 | if T is not None: 39 | D_in = lgssm_params.dynamics_input_weights.shape[1] 40 | inputs = jnp.zeros((T, D_in)) 41 | else: 42 | raise ValueError("One of `inputs` or `T` must not be None.") 43 | 44 | num_timesteps = len(inputs) 45 | Lambda0, mu0 = lgssm_params.initial_precision, lgssm_params.initial_mean 46 | latent_dim = len(mu0) 47 | 48 | latent_vars = [GaussianVariableNode(f"x{i}", latent_dim, zeros_canonical_pot(latent_dim)) 49 | for i in range(num_timesteps)] 50 | # Add informative prior to first time point. 
51 | x0_prior = CanonicalPotential(eta=Lambda0 @ mu0, Lambda=Lambda0) 52 | latent_vars[0] = GaussianVariableNode("x0", latent_dim, x0_prior) 53 | 54 | B, b = lgssm_params.dynamics_input_weights, lgssm_params.dynamics_bias 55 | F, Q_prec = lgssm_params.dynamics_matrix, lgssm_params.dynamics_precision 56 | latent_net_inputs = vmap(jnp.dot, (None, 0))(B, inputs) + b 57 | latent_factors = [ 58 | canonical_factor_from_clg( 59 | F, latent_net_inputs[i], Q_prec, latent_vars[i], latent_vars[i + 1], f"latent_{i},{i+1}" 60 | ) 61 | for i in range(num_timesteps - 1) 62 | ] 63 | 64 | D, d = lgssm_params.emission_input_weights, lgssm_params.emission_bias 65 | H, R_prec = lgssm_params.emission_matrix, lgssm_params.emission_precision 66 | 67 | emission_net_inputs = vmap(jnp.dot, (None, 0))(D, inputs) + d 68 | emission_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))( 69 | H, emission_net_inputs, R_prec) 70 | local_evidence_pot_chain = vmap(partial(pair_cpot_condition, obs_var=2))(emission_pots, obs) 71 | local_evidence_pots = [CanonicalPotential(eta, Lambda) for Lambda, eta in zip(*local_evidence_pot_chain)] 72 | 73 | def incorporate_local_evidence(var, pot): 74 | """Incorporate local evidence potential into Gaussian variable.""" 75 | prior_plus_pot = info_multiply(var.prior, pot) 76 | return GaussianVariableNode(var.varID, var.dim, prior_plus_pot) 77 | 78 | latent_vars = [incorporate_local_evidence(var,pot) for var, pot in zip(latent_vars, local_evidence_pots)] 79 | 80 | fg = make_factor_graph(latent_vars, latent_factors) 81 | 82 | return fg 83 | 84 | def test_gauss_factor_graph_lgssm(): 85 | """Test that Gaussian chain belief propagation gets the same results as 86 | information form RTS smoother.""" 87 | 88 | lgssm, lgssm_info = build_lgssm_moment_and_info_form() 89 | 90 | key = jr.PRNGKey(111) 91 | num_timesteps = 5 # Fewer timesteps so that we can run fewer iterations. 92 | input_size = lgssm.dynamics_input_weights.shape[1] 93 | inputs = jnp.zeros((num_timesteps, input_size)) 94 | _, y = lgssm_sample(key, lgssm, num_timesteps, inputs=inputs) 95 | 96 | lgssm_info_posterior = lgssm_info_smoother(lgssm_info, y, inputs) 97 | 98 | fg = factor_graph_from_lgssm(lgssm_info,inputs, y) 99 | 100 | # Loopy bp. 
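    # (The graph built above is a chain, i.e. a tree, so synchronous message
    # updates converge after roughly one sweep per timestep; hence the
    # num_timesteps rounds of update_all_messages below.)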
101 |     messages = init_messages(fg)
102 |     for _ in range(num_timesteps):
103 |         messages = update_all_messages(fg, messages)
104 | 
105 |     # Calculate final beliefs
106 |     final_beliefs = calculate_all_beliefs(fg, messages)
107 |     fg_etas = jnp.vstack([cpot.eta for cpot in final_beliefs.values()])
108 |     fg_Lambdas = jnp.stack([cpot.Lambda for cpot in final_beliefs.values()])
109 | 
110 |     assert _all_close(fg_etas, lgssm_info_posterior.smoothed_etas)
111 |     assert _all_close(fg_Lambdas, lgssm_info_posterior.smoothed_precisions)
112 | 
113 | 
114 | def test_tree_factor_graph():
115 |     """Test loopy BP on a small tree-structured Gaussian factor graph against closed-form moments."""
116 |     key = jr.PRNGKey(0)
117 |     dim = 2
118 | 
119 |     ### Construct variables in moment form ###
120 |     IW = InverseWishart(dim, jnp.eye(dim) * 0.1)
121 |     key, subkey = jr.split(key)
122 |     covs = jit(IW.sample, static_argnums=0)(5, subkey)
123 | 
124 |     key, subkey1 = jr.split(key)
125 |     mu1 = jr.normal(subkey1, (dim,))
126 |     Sigma1 = covs[0]
127 | 
128 |     key, subkey = jr.split(key)
129 |     mu2 = jr.normal(subkey, (dim,))
130 |     Sigma2 = covs[1]
131 | 
132 |     # x_3 | x_1, x_2 ~ N(x_3 | A_31 x_1 + A_32 x_2, Sigma_{3|1,2})
133 |     key, *subkeys = jr.split(key, 3)
134 |     A31 = jr.normal(subkeys[0], (dim, dim))
135 |     A32 = jr.normal(subkeys[1], (dim, dim))
136 |     Sigma3_cond = covs[2]
137 |     mu3 = A31 @ mu1 + A32 @ mu2
138 |     Sigma3 = Sigma3_cond + A31 @ Sigma1 @ A31.T + A32 @ Sigma2 @ A32.T
139 | 
140 |     # x_4 | x_3 ~ N(x_4 | A_4 x_3, Sigma_{4|3})
141 |     key, subkey = jr.split(key)
142 |     A4 = jr.normal(subkey, (dim, dim))
143 |     Sigma4_cond = covs[3]
144 |     mu4 = A4 @ mu3
145 |     Sigma4 = Sigma4_cond + A4 @ Sigma3 @ A4.T
146 | 
147 |     # x_5 | x_3 ~ N(x_5 | A_5 x_3, Sigma_{5|3})
148 |     key, subkey = jr.split(key)
149 |     A5 = jr.normal(subkey, (dim, dim))
150 |     Sigma5_cond = covs[4]
151 |     mu5 = A5 @ mu3
152 |     Sigma5 = Sigma5_cond + A5 @ Sigma3 @ A5.T
153 | 
154 |     ### Construct variables and factors in Canonical Form ###
155 |     Lambda1 = jnp.linalg.inv(Sigma1)
156 |     eta1 = Lambda1 @ mu1
157 |     prior_x1 = CanonicalPotential(eta1, Lambda1)
158 | 
159 |     Lambda2 = jnp.linalg.inv(Sigma2)
160 |     eta2 = Lambda2 @ mu2
161 |     prior_x2 = CanonicalPotential(eta2, Lambda2)
162 | 
163 |     x1_var = GaussianVariableNode(1, dim, prior_x1)
164 |     x2_var = GaussianVariableNode(2, dim, prior_x2)
165 |     x3_var = GaussianVariableNode(3, dim, zeros_canonical_pot(dim))
166 |     x4_var = GaussianVariableNode(4, dim, zeros_canonical_pot(dim))
167 |     x5_var = GaussianVariableNode(5, dim, zeros_canonical_pot(dim))
168 | 
169 |     offset = jnp.zeros(dim)
170 |     Lambda3_cond = jnp.linalg.inv(Sigma3_cond)
171 |     A3_joint = jnp.hstack((A31, A32))
172 |     (Kxx, Kxy, Kyy), (hx, hy) = potential_from_conditional_linear_gaussian(A3_joint,
173 |                                                                            offset,
174 |                                                                            Lambda3_cond)
175 |     K = jnp.block([[Kxx, Kxy],
176 |                    [Kxy.T, Kyy]])
177 |     h = jnp.concatenate((hx, hy))
178 |     cpot_123 = CanonicalPotential(eta=h, Lambda=K)
179 |     factor_123 = make_canonical_factor("factor_123", [x1_var, x2_var, x3_var], cpot_123)
180 | 
181 |     Lambda4_cond = jnp.linalg.inv(Sigma4_cond)
182 |     factor_34 = canonical_factor_from_clg(A4, offset, Lambda4_cond, x3_var, x4_var, "factor_34")
183 | 
184 |     Lambda5_cond = jnp.linalg.inv(Sigma5_cond)
185 |     factor_35 = canonical_factor_from_clg(A5, offset, Lambda5_cond, x3_var, x5_var, "factor_35")
186 | 
187 |     # Build factor graph.
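    # (The resulting graph is a tree: x1 and x2 feed into x3 via factor_123, and
    # x4 and x5 hang off x3 via factor_34 and factor_35, so loopy BP converges
    # to the exact marginals computed in moment form above.)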
188 |     var_nodes = [x1_var, x2_var, x3_var, x4_var, x5_var]
189 |     factors = [factor_123, factor_34, factor_35]
190 |     fg = make_factor_graph(var_nodes, factors)
191 | 
192 |     # Loopy BP
193 |     messages = init_messages(fg)
194 |     for _ in range(5):
195 |         messages = update_all_messages(fg, messages)
196 | 
197 |     # Extract marginal etas and Lambdas from factor graph.
198 |     final_beliefs = calculate_all_beliefs(fg, messages)
199 |     fg_etas = jnp.vstack([cpot.eta for cpot in final_beliefs.values()])
200 |     fg_Lambdas = jnp.stack([cpot.Lambda for cpot in final_beliefs.values()])
201 | 
202 |     # Convert to moment form
203 |     fg_means = vmap(jnp.linalg.solve)(fg_Lambdas, fg_etas)
204 |     fg_covs = jnp.linalg.inv(fg_Lambdas)
205 | 
206 |     means = jnp.vstack([mu1, mu2, mu3, mu4, mu5])
207 |     covs = jnp.stack([Sigma1, Sigma2, Sigma3, Sigma4, Sigma5])
208 | 
209 |     # Compare to moment form marginals.
210 |     assert jnp.allclose(fg_means, means, rtol=1e-2, atol=1e-2)
211 |     assert jnp.allclose(fg_covs, covs, rtol=1e-2, atol=1e-2)
--------------------------------------------------------------------------------
/gaussian-loopy-bp/gauss_factor_graph.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from functools import partial
3 | from typing import NamedTuple, Union, List, Dict
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 | from jax import jit
9 | from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map
10 | from gauss_bp.gauss_bp_utils import (info_divide,
11 |                                      info_marginalize,
12 |                                      info_multiply)
13 | 
14 | def _tree_reduce(function, tree, initializer=None, is_leaf=None):
15 |     """Copy of jax.tree_util.tree_reduce which accepts an `is_leaf` argument."""
16 |     if initializer is None:
17 |         return functools.reduce(function, tree_leaves(tree, is_leaf))
18 |     else:
19 |         return functools.reduce(function, tree_leaves(tree, is_leaf), initializer)
20 | 
21 | class CanonicalPotential(NamedTuple):
22 |     """Container class for a canonical potential.
23 | 
24 |     eta: (N,) array.
25 |     Lambda: (N, N) array.
26 |     """
27 |     eta: jnp.ndarray
28 |     Lambda: jnp.ndarray
29 | 
30 | def zeros_canonical_pot(dim):
31 |     """Construct a canonical potential with all entries zero."""
32 |     eta = jnp.zeros(dim)
33 |     Lambda = jnp.zeros((dim, dim))
34 |     return CanonicalPotential(eta=eta, Lambda=Lambda)
35 | 
36 | 
37 | @jit
38 | def sum_reduce_cpots(cpots, initializer=None):
39 |     """Sum over the corresponding parameters for each potential in `cpots`."""
40 |     return _tree_reduce(info_multiply, cpots, initializer,
41 |                         is_leaf=lambda l: isinstance(l, CanonicalPotential))
42 | 
43 | 
44 | class GaussianVariableNode(NamedTuple):
45 |     """A Gaussian variable node in a factor graph."""
46 |     varID : Union[int, str]
47 |     dim : int
48 |     prior : CanonicalPotential  # Currently using prior to 'roll in' single-variable factors.
49 | 
50 | class CanonicalFactor(NamedTuple):
51 |     """A canonical factor involving `len(adj_varIDs)` variables."""
52 |     # TODO: Not totally clear what the best way is to handle the var_scopes.
53 |     # Could use a dataclass with a __post_init__ to construct automatically...
54 |     # I'm not sure how this plays with rolling/unrolling...
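    # (Illustration: for two adjacent variables "x0" and "x1" of dims 2 and 3,
    # _calculate_var_scopes below gives var_scopes = {"x0": (0, 2), "x1": (2, 5)},
    # i.e. each variable's (start, stop) slice into the factor's joint parameters.)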
55 |     factorID : Union[int, str]
56 |     adj_varIDs : Union[List[int], List[str]]
57 |     potential : CanonicalPotential
58 |     var_scopes : Dict  # {varID : (var_start, var_stop)}
59 | 
60 | def make_canonical_factor(factorID, var_nodes, cpot):
61 |     """Helper function to construct a CanonicalFactor."""
62 |     # It should be possible to replace this with the right constructor for CanonicalFactor
63 |     varIDs = [var.varID for var in var_nodes]
64 |     var_scopes = _calculate_var_scopes(var_nodes)
65 |     return CanonicalFactor(factorID, varIDs, cpot, var_scopes)
66 | 
67 | class GaussianFactorGraph(NamedTuple):
68 |     """A container for the variables and factors in a factor graph."""
69 |     var_nodes : List[GaussianVariableNode]
70 |     factors : List[CanonicalFactor]
71 |     # factor_to_var_edges is already implicitly contained in factors...
72 |     factor_to_var_edges : Dict
73 |     var_to_factor_edges : Dict
74 | 
75 | def make_factor_graph(var_nodes, factors):
76 |     """Helper function to construct a factor graph."""
77 |     # TODO: Another place where maybe this belongs as a constructor in the GFG class.
78 |     factor_to_var_edges = {factor.factorID: factor.adj_varIDs for factor in factors}
79 |     var_to_factor_edges = {var.varID: [fID for fID, vIDs in factor_to_var_edges.items() if var.varID in vIDs]
80 |                            for var in var_nodes}
81 |     return GaussianFactorGraph(var_nodes, factors,
82 |                                factor_to_var_edges,
83 |                                var_to_factor_edges)
84 | 
85 | def init_messages(factor_graph):
86 |     """Initialise all factor-to-variable messages as zeros."""
87 |     var_dims = {var.varID : var.dim for var in factor_graph.var_nodes}
88 |     return {factor.factorID : {vID: zeros_canonical_pot(var_dims[vID]) for vID in factor.adj_varIDs}
89 |             for factor in factor_graph.factors}
90 | 
91 | @partial(jit, static_argnums=2)
92 | def absorb_canonical_message(cpot, message, message_scope):
93 |     """Absorb a canonical message into a potential."""
94 |     var_start, var_stop = message_scope
95 |     eta = cpot.eta.at[var_start:var_stop].add(message.eta)
96 |     Lambda = cpot.Lambda.at[var_start:var_stop, var_start:var_stop].add(message.Lambda)
97 |     return CanonicalPotential(eta, Lambda)
98 | 
99 | def absorb_var_to_factor_messages(factor, messages):
100 |     """Absorb all the messages into a factor potential."""
101 |     pot = tree_map(jnp.copy, factor.potential)
102 |     for varID, message in messages.items():
103 |         pot = absorb_canonical_message(pot, message, factor.var_scopes[varID])
104 |     return pot
105 | 
106 | def marginalise_onto_var(potential, var_scope):
107 |     """Marginalise a joint canonical-form Gaussian onto a variable."""
108 |     (K11, K12, K22), (h1, h2) = extract_canonical_potential_blocks(potential, var_scope)
109 |     K_marg, h_marg = info_marginalize(K22, K12.T, K11, h2, h1)
110 |     return CanonicalPotential(eta=h_marg, Lambda=K_marg)
111 | 
112 | def update_belief(var, messages):
113 |     """Combine incoming messages to calculate a variable's belief state."""
114 |     return sum_reduce_cpots(messages, initializer=var.prior)
115 | 
116 | def calculate_var_belief(var, messages, var_to_factor_edges):
117 |     """Combine incoming messages to calculate a variable's belief state."""
118 |     incoming_messages = [messages[f][var.varID] for f in var_to_factor_edges[var.varID]]
119 |     return update_belief(var, incoming_messages)
120 | 
121 | def calculate_all_beliefs(factor_graph, messages):
122 |     """Absorb messages for all variables to calculate all belief states."""
123 |     return {var.varID: calculate_var_belief(var, messages, factor_graph.var_to_factor_edges)
124 |             for var in factor_graph.var_nodes}
125 | 
126 | def update_factor_to_var_messages(factor, var_beliefs, messages_to_vars, damping=0):
127 |     """Given the current variable belief states and the previous messages sent to each variable,
128 |     calculate the updated messages to send to the variables.
129 | 
130 |     Note: Ordinarily the message sent to the ith variable, var_i, is given by summing the incoming
131 |     messages of all *other* variables and then marginalising onto var_i.
132 | 
133 |     In canonical Gaussian form we can sum *all* messages, marginalise onto var_i, and then simply
134 |     subtract the message from var_i. This trick allows us to perform a single message absorption
135 |     step instead of repeating it for each variable.
136 | 
137 |     Args:
138 |         factor: A CanonicalFactor object.
139 |         var_beliefs: A dictionary `{varID: CanonicalPotential}` containing marginal variable belief
140 |             states.
141 |         messages_to_vars: A dictionary {varID: CanonicalPotential} containing messages from `factor`
142 |             to each variable from the previous step.
143 |         damping: float.
144 | 
145 |     Returns:
146 |         messages_to_vars: A dictionary {varID: CanonicalPotential} containing the updated messages
147 |             from `factor` to each variable.
148 |     """
149 |     # Divide var beliefs by messages --> message_v_to_f.
150 |     var_messages = info_divide(var_beliefs, messages_to_vars)
151 |     # Absorb all messages into factor.
152 |     pot_plus_messages = absorb_var_to_factor_messages(factor, var_messages)
153 |     for vID in factor.adj_varIDs:
154 |         # Marginalise the factor-plus-messages potential onto var.
155 |         var_marginal = marginalise_onto_var(pot_plus_messages, factor.var_scopes[vID])
156 |         raw_message = info_divide(var_marginal, var_messages[vID])  # Subtract message_v_to_f.
157 |         damped_message = jax.tree_map(
158 |             lambda x, y: damping * x + (1 - damping) * y,
159 |             messages_to_vars[vID], raw_message
160 |         )
161 |         messages_to_vars[vID] = damped_message
162 |     return messages_to_vars
163 | 
164 | def update_all_messages(factor_graph, messages, damping=0):
165 |     """Loop over all factors in the graph and calculate updated messages to the variables.
166 | 
167 |     Args:
168 |         factor_graph: A GaussianFactorGraph object.
169 |         messages: a dict of dicts containing messages for each factor -
170 |             {factorID : {varID: message, ...}, ...}
171 | 
172 |     Returns:
173 |         new_messages: dict with the same form as `messages` containing updated factor-to-variable
174 |             messages.
175 |     """
176 |     var_beliefs = calculate_all_beliefs(factor_graph, messages)
177 |     new_messages = {}
178 |     for factor in factor_graph.factors:
179 |         factor_var_beliefs = {vID: var_beliefs[vID] for vID in factor.adj_varIDs}
180 |         messages_to_vars = messages[factor.factorID]
181 |         new_messages[factor.factorID] = update_factor_to_var_messages(
182 |             factor, factor_var_beliefs, messages_to_vars, damping
183 |         )
184 |     return new_messages
185 | 
186 | def extract_canonical_potential_blocks(can_pot, var_scope):
187 |     """Split a precision matrix into blocks for marginalising / conditioning.
188 | 
189 |     E.g. K = [[ 0,  1,  2,  3,  4],
190 |               [ 5,  6,  7,  8,  9],
191 |               [10, 11, 12, 13, 14],
192 |               [15, 16, 17, 18, 19],
193 |               [20, 21, 22, 23, 24]]
194 |          h = [0, 1, 2, 3, 4]
195 |          idxs = [1, 2]
196 |     gets split into:
197 |          K11 - [[ 6,  7],
198 |                 [11, 12]]
199 | 
200 |          K12 - [[ 5,  8,  9],
201 |                 [10, 13, 14]]
202 | 
203 |          K22 - [[ 0,  3,  4],
204 |                 [15, 18, 19],
205 |                 [20, 23, 24]]
206 |     and
207 |          h1 - [1, 2]
208 |          h2 - [0, 3, 4]
209 | 
210 |     Args:
211 |         can_pot - a CanonicalPotential object (or similar tuple) with elements (h, K) where,
212 |             K - (D x D) precision matrix.
213 |             h - (D,) potential vector.
214 |         idxs - (N,) array of indices in 0,...,D-1.
215 |     Returns:
216 |         (K11, K12, K22), (h1, h2) - blocks of the potential parameters where:
217 |             K11 (N x N) block of precision elements with row and column in `idxs`.
218 |             K12 (N x D-N) block of precision elements with row in `idxs` but column not in `idxs`.
219 |             K22 (D-N x D-N) block of precision elements with neither row nor column in `idxs`.
220 |             h1 (N,) elements of the potential vector in `idxs`.
221 |             h2 (D-N,) elements of the potential vector not in `idxs`.
222 |     """
223 |     # TODO: Investigate using jax.lax.dynamic_slice instead.
224 |     # TODO: also maybe precompute the ~b indices.
225 |     h, K = can_pot
226 |     # Using np instead of jnp so that these aren't traced.
227 |     idxs = np.arange(*var_scope)
228 |     idx_range = np.arange(len(h))
229 |     b = np.isin(idx_range, idxs)
230 |     K11 = K[b, :][:, b]
231 |     K12 = K[b, :][:, ~b]
232 |     K22 = K[~b, :][:, ~b]
233 |     h1 = h[b]
234 |     h2 = h[~b]
235 |     return (K11, K12, K22), (h1, h2)
236 | 
237 | 
238 | def _calculate_var_scopes(var_nodes):
239 |     """Helper function to calculate the variable scopes for a factor involving `var_nodes`."""
240 |     var_ids = np.array([var.varID for var in var_nodes])
241 |     var_dims = np.array([var.dim for var in var_nodes], dtype=np.int32)
242 |     # Use numpy so it is not traced.
243 |     var_starts = np.concatenate((np.zeros(1, dtype=np.int32), np.cumsum(var_dims)[:-1]))
244 |     var_stops = var_starts + var_dims
245 |     var_scopes = {var_id: (start, stop)
246 |                   for var_id, start, stop in zip(var_ids, var_starts, var_stops)}
247 |     return var_scopes
--------------------------------------------------------------------------------
/deprecated-gauss-bp/factor_graph.py:
--------------------------------------------------------------------------------
1 | # factor_graph.py
2 | from variable_node import VariableNode
3 | import jax.numpy as jnp
4 | from typing import List, Callable, Optional, Union
5 | from gaussian import Gaussian, MeasModel
6 | from factor import Factor
7 | import random
8 | 
9 | 
10 | 
11 | """
12 | Defines classes for variable nodes, factor nodes, edges, and the factor graph.
13 | """
14 | class GBPSettings:
15 |     def __init__(self,
16 |                  damping: float = 0.,
17 |                  beta: float = 0.1,
18 |                  num_undamped_iters: int = 5,
19 |                  min_linear_iters: int = 10,
20 |                  dropout: float = 0.,
21 |                  reset_iters_since_relin: List[int] = []) -> None:
22 | 
23 |         # Parameters for damping the eta component of the message
24 |         self.damping = damping
25 |         self.num_undamped_iters = num_undamped_iters  # Number of undamped iterations after relinearisation before damping is set to `damping`.
26 | 
27 |         self.dropout = dropout
28 | 
29 |         # Parameters for just-in-time factor relinearisation
30 |         self.beta = beta  # Threshold absolute distance between linpoint and adjacent belief means for relinearisation.
31 |         self.min_linear_iters = min_linear_iters  # Minimum number of linear iterations before a factor is allowed to relinearise.
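        # (Usage sketch: e.g. GBPSettings(damping=0.4, num_undamped_iters=6) keeps
        # the first 6 iterations after a relinearisation undamped and then damps
        # messages by 0.4; these particular values are illustrative, not defaults.)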
32 |         self.reset_iters_since_relin = reset_iters_since_relin
33 | 
34 |     def get_damping(self, iters_since_relin: int) -> float:
35 |         if iters_since_relin > self.num_undamped_iters:
36 |             return self.damping
37 |         else:
38 |             return 0.
39 | 
40 | class FactorGraph:
41 |     def __init__(self, gbp_settings: GBPSettings = GBPSettings()) -> None:
42 |         self.var_nodes = []
43 |         self.factors = []
44 |         self.gbp_settings = gbp_settings
45 | 
46 |     def add_var_node(self,
47 |                      dofs: int,
48 |                      prior_mean: Optional[jnp.array] = None,
49 |                      prior_diag_cov: Optional[Union[float, jnp.array]] = None,
50 |                      properties: dict = {}) -> None:
51 |         variableID = len(self.var_nodes)
52 |         self.var_nodes.append(VariableNode(variableID, dofs, properties=properties))
53 |         if prior_mean is not None and prior_diag_cov is not None:
54 |             prior_cov = jnp.zeros((dofs, dofs), dtype=prior_diag_cov.dtype)
55 |             prior_cov = prior_cov.at[jnp.arange(dofs), jnp.arange(dofs)].set(prior_diag_cov)  # diagonal prior covariance
56 |             self.var_nodes[-1].prior.set_with_cov_form(prior_mean, prior_cov)
57 |             self.var_nodes[-1].update_belief()
58 | 
59 |     def add_factor(self, adj_var_ids: List[int],
60 |                    measurement: jnp.array,
61 |                    meas_model: MeasModel,
62 |                    properties: dict = {}) -> None:
63 |         factorID = len(self.factors)
64 |         adj_var_nodes = [self.var_nodes[i] for i in adj_var_ids]
65 |         self.factors.append(Factor(factorID, adj_var_nodes, measurement, meas_model, properties=properties))
66 |         for var in adj_var_nodes:
67 |             var.adj_factors.append(self.factors[-1])
68 | 
69 |     def update_all_beliefs(self) -> None:
70 |         for var_node in self.var_nodes:
71 |             var_node.update_belief()
72 | 
73 |     def compute_all_messages(self, apply_dropout: bool = True) -> None:
74 |         for factor in self.factors:
75 |             if apply_dropout and random.random() > self.gbp_settings.dropout or not apply_dropout:
76 |                 damping = self.gbp_settings.get_damping(factor.iters_since_relin)
77 |                 factor.compute_messages(damping)
78 | 
79 |     def linearise_all_factors(self) -> None:
80 |         for factor in self.factors:
81 |             factor.compute_factor()
82 | 
83 |     def robustify_all_factors(self) -> None:
84 |         for factor in self.factors:
85 |             factor.robustify_loss()
86 | 
87 |     def linearisation_only_nonlinear_factors(self) -> None:
88 |         """
89 |         Check for all factors that the current estimate is close to the linearisation point.
90 |         If not, relinearise the factor distribution.
91 |         Relinearisation is only allowed at a maximum frequency of once every min_linear_iters iterations.
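        Concretely (restating the logic below), a nonlinear factor is relinearised when

            ||linpoint - adj_belief_means|| > beta   and   iters_since_relin >= min_linear_iters.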
92 | """ 93 | for factor in self.factors: 94 | if not factor.meas_model.linear: 95 | adj_belief_means = factor.get_adj_means() 96 | factor.iters_since_relin += 1 97 | if torch.norm(factor.linpoint - adj_belief_means) > self.gbp_settings.beta and factor.iters_since_relin >= self.gbp_settings.min_linear_iters: 98 | factor.compute_factor() 99 | 100 | def synchronous_iteration(self) -> None: 101 | self.robustify_all_factors() 102 | self.linearisation_only_nonlinear_factors() # For linear factors, no compute is done 103 | self.compute_all_messages() 104 | self.update_all_beliefs() 105 | 106 | def gbp_solve(self, n_iters: Optional[int] = 20, converged_threshold: Optional[float] = 1e-6, include_priors: bool = True) -> None: 107 | energy_log = [self.energy()] 108 | print(f"\nInitial Energy {energy_log[0]:.5f}") 109 | i = 0 110 | count = 0 111 | not_converged = True 112 | while not_converged and i < n_iters: 113 | self.synchronous_iteration() 114 | if i in self.gbp_settings.reset_iters_since_relin: 115 | for f in self.factors: 116 | f.iters_since_relin = 1 117 | 118 | energy_log.append(self.energy(include_priors=include_priors)) 119 | print( 120 | f"Iter {i+1} --- " 121 | f"Energy {energy_log[-1]:.5f} --- " 122 | # f"Belief means: {self.belief_means().numpy()} --- " 123 | # f"Robust factors: {[factor.meas_model.loss.robust() for factor in self.factors]}" 124 | # f"Relins: {sum([(factor.iters_since_relin==0 and not factor.meas_model.linear) for factor in self.factors])}" 125 | ) 126 | i += 1 127 | if abs(energy_log[-2] - energy_log[-1]) < converged_threshold: 128 | count += 1 129 | if count == 3: 130 | not_converged = False 131 | else: 132 | count = 0 133 | 134 | def energy(self, eval_point: jnp.array = None, include_priors: bool = True) -> float: 135 | """ Computes the sum of all of the squared errors in the graph using the appropriate local loss function. """ 136 | if eval_point is None: 137 | energy = sum([factor.get_energy() for factor in self.factors]) 138 | else: 139 | var_dofs = jnp.array([v.dofs for v in self.var_nodes]) 140 | var_ix = jnp.concatenate([0, jnp.cumsum(var_dofs, axis=0)[:-1]]) 141 | energy = 0. 142 | for f in self.factors: 143 | local_eval_point = jnp.concatenate([eval_point[var_ix[v.variableID]: var_ix[v.variableID] + v.dofs] for v in f.adj_var_nodes]) 144 | energy += f.get_energy(local_eval_point) 145 | if include_priors: 146 | prior_energy = sum([var.get_prior_energy() for var in self.var_nodes]) 147 | energy += prior_energy 148 | return energy 149 | 150 | def get_joint_dim(self) -> int: 151 | return sum([var.dofs for var in self.var_nodes]) 152 | 153 | def get_joint(self) -> Gaussian: 154 | """ 155 | Get the joint distribution over all variables in the information form 156 | If nonlinear factors, it is taken at the current linearisation point. 
157 | """ 158 | dim = self.get_joint_dim() 159 | joint = Gaussian(dim) 160 | 161 | # Priors 162 | var_ix = [0] * len(self.var_nodes) 163 | counter = 0 164 | for var in self.var_nodes: 165 | var_ix[var.variableID] = int(counter) 166 | joint.eta[counter:counter + var.dofs] += var.prior.eta 167 | joint.lam[counter:counter + var.dofs, counter:counter + var.dofs] += var.prior.lam 168 | counter += var.dofs 169 | 170 | # Other factors 171 | for factor in self.factors: 172 | factor_ix = 0 173 | for adj_var_node in factor.adj_var_nodes: 174 | vID = adj_var_node.variableID 175 | # Diagonal contribution of factor 176 | joint.eta[var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ 177 | factor.factor.eta[factor_ix:factor_ix + adj_var_node.dofs] 178 | joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ 179 | factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs] 180 | other_factor_ix = 0 181 | for other_adj_var_node in factor.adj_var_nodes: 182 | if other_adj_var_node.variableID > adj_var_node.variableID: 183 | other_vID = other_adj_var_node.variableID 184 | # Off diagonal contributions of factor 185 | joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs] += \ 186 | factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, other_factor_ix:other_factor_ix + other_adj_var_node.dofs] 187 | joint.lam[var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ 188 | factor.factor.lam[other_factor_ix:other_factor_ix + other_adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs] 189 | other_factor_ix += other_adj_var_node.dofs 190 | factor_ix += adj_var_node.dofs 191 | 192 | return joint 193 | 194 | def MAP(self) -> jnp.array: 195 | return self.get_joint().mean() 196 | 197 | def dist_from_MAP(self) -> jnp.array: 198 | return jnp.linalg.norm(self.get_joint().mean() - self.belief_means()) 199 | 200 | def belief_means(self) -> jnp.array: 201 | """ Get an array containing all current estimates of belief means. """ 202 | return jnp.concatenate([var.belief.mean() for var in self.var_nodes]) 203 | 204 | def belief_covs(self) -> List[jnp.array]: 205 | """ Get a list containing all current estimates of belief covariances. 
""" 206 | covs = [var.belief.cov() for var in self.var_nodes] 207 | return covs 208 | 209 | def print(self, brief=False) -> None: 210 | print("\nFactor Graph:") 211 | print(f"# Variable nodes: {len(self.var_nodes)}") 212 | if not brief: 213 | for i, var in enumerate(self.var_nodes): 214 | print(f"Variable {i}: connects to factors {[f.factorID for f in var.adj_factors]}") 215 | print(f" dofs: {var.dofs}") 216 | print(f" prior mean: {var.prior.mean()}") 217 | print(f" prior covariance: diagonal sigma {jnp.diag(var.prior.cov())}") 218 | print(f"# Factors: {len(self.factors)}") 219 | if not brief: 220 | for i, factor in enumerate(self.factors): 221 | if factor.meas_model.linear: 222 | print("Linear", end =" ") 223 | else: 224 | print("Nonlinear", end =" ") 225 | print(f"Factor {i}: connects to variables {factor.adj_vIDs}") 226 | print(f" measurement model: {type(factor.meas_model).__name__}," 227 | f" {type(factor.meas_model.loss).__name__}," 228 | f" diagonal sigma {jnp.diag(factor.meas_model.loss.effective_cov)}") 229 | print(f" measurement: {factor.measurement}") 230 | print("\n") 231 | -------------------------------------------------------------------------------- /gaussian-loopy-bp/gauss-bp-1d-line.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "\n", 7 | "# Gaussian Belief Propagation applied to denoising a 1d line\n", 8 | "\n", 9 | "\n", 10 | "This example is based on the [PyTorch colab by Joseph Ortiz](https://colab.research.google.com/drive/1-nrE95X4UC9FBLR0-cTnsIP_XhA_PZKW?usp=sharing)\n", 11 | "\n" 12 | ], 13 | "metadata": { 14 | "id": "6dDb7OHiIVOL" 15 | }, 16 | "id": "6dDb7OHiIVOL" 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "id": "08079698-6341-40fb-b67e-3a5e7d744f53", 22 | "metadata": { 23 | "id": "08079698-6341-40fb-b67e-3a5e7d744f53", 24 | "outputId": "7df7103a-ce57-4816-aed1-6f93ec07d1d6", 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | } 28 | }, 29 | "outputs": [ 30 | { 31 | "output_type": "stream", 32 | "name": "stdout", 33 | "text": [ 34 | "Cloning into 'pgm-jax'...\n", 35 | "remote: Enumerating objects: 98, done.\u001b[K\n", 36 | "remote: Counting objects: 100% (98/98), done.\u001b[K\n", 37 | "remote: Compressing objects: 100% (78/78), done.\u001b[K\n", 38 | "remote: Total 98 (delta 36), reused 57 (delta 16), pack-reused 0\u001b[K\n", 39 | "Unpacking objects: 100% (98/98), done.\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "!git clone https://github.com/probml/pgm-jax.git" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "a8c778dc-be87-4465-a622-742cc557f87c", 51 | "metadata": { 52 | "id": "a8c778dc-be87-4465-a622-742cc557f87c", 53 | "outputId": "ec18f235-d018-4ce2-e94c-ebd14af10385", 54 | "colab": { 55 | "base_uri": "https://localhost:8080/" 56 | } 57 | }, 58 | "outputs": [ 59 | { 60 | "output_type": "stream", 61 | "name": "stdout", 62 | "text": [ 63 | "/content/pgm-jax\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "%cd pgm-jax" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "id": "867daea0-1432-4ae3-9fa6-67a9890b76c1", 75 | "metadata": { 76 | "id": "867daea0-1432-4ae3-9fa6-67a9890b76c1" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "import numpy as np\n", 81 | "\n", 82 | "from jax import numpy as jnp\n", 83 | "from jax import random as jr\n", 84 | "from jax import jit\n", 85 | "from matplotlib import pyplot as plt" 86 | ] 87 | }, 88 | 
{ 89 | "cell_type": "markdown", 90 | "id": "bc968c0d-f9d0-4010-bd92-a61752185238", 91 | "metadata": { 92 | "id": "bc968c0d-f9d0-4010-bd92-a61752185238" 93 | }, 94 | "source": [ 95 | "# 1D Smoothing Demo" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "93e8fdd9-fcb4-4f0a-9149-023c4f41183f", 102 | "metadata": { 103 | "id": "93e8fdd9-fcb4-4f0a-9149-023c4f41183f" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "from gauss_bp.gauss_factor_graph import (CanonicalPotential, GaussianVariableNode, CanonicalFactor,\n", 108 | " GaussianFactorGraph, make_canonical_factor, make_factor_graph,\n", 109 | " init_messages, update_all_messages, calculate_all_beliefs)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "id": "6f0f9735-d325-4db4-8994-64a0b12d34bc", 116 | "metadata": { 117 | "id": "6f0f9735-d325-4db4-8994-64a0b12d34bc" 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "def cpot_to_moment_1D(cpot):\n", 122 | " eta, Lambda = cpot\n", 123 | " var = 1/Lambda.squeeze()\n", 124 | " return eta * var, var\n", 125 | "\n", 126 | "def beliefs_to_means_covs(beliefs):\n", 127 | " mean_var_list = [cpot_to_moment_1D(b) for b in beliefs.values()]\n", 128 | " means, variances = (jnp.hstack(x) for x in zip(*mean_var_list))\n", 129 | " return means, variances\n", 130 | "\n", 131 | "def plot_beliefs(fg,messages,xs=None,**kwargs):\n", 132 | " beliefs = calculate_all_beliefs(fg,messages)\n", 133 | " mus, covs = beliefs_to_means_covs(beliefs)\n", 134 | " if xs is None:\n", 135 | " xs = np.arange(len(mus))\n", 136 | " if 'fmt' not in kwargs:\n", 137 | " kwargs['fmt'] = \"-o\"\n", 138 | " plt.errorbar(xs, mus, yerr=np.sqrt(covs), **kwargs); " 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "6c60554c-de6b-429c-8f45-b96c78c72746", 144 | "metadata": { 145 | "id": "6c60554c-de6b-429c-8f45-b96c78c72746" 146 | }, 147 | "source": [ 148 | "### Define measurement factors" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "id": "c293a2f2-37a5-4b69-b943-b87f544b5419", 155 | "metadata": { 156 | "id": "c293a2f2-37a5-4b69-b943-b87f544b5419" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "def gamma(x,g1,g2):\n", 161 | " return (g1-x)/(g1-g2)\n", 162 | "\n", 163 | "def measurement_pot(meas,g1,g2,delta=1.):\n", 164 | " mx, my = meas\n", 165 | " γ = gamma(mx,g1,g2)\n", 166 | " eta = delta * jnp.array([(1-γ)*my, γ*my])\n", 167 | " Lambda = delta * jnp.array([[(1-γ)**2, γ*(1-γ)],\n", 168 | " [γ*(1-γ), γ**2]])\n", 169 | " return CanonicalPotential(eta, Lambda)\n", 170 | "\n", 171 | "def create_measurement_factor(meas,xs,var_nodes,delta=1.):\n", 172 | " mx, my = meas\n", 173 | " i = np.argwhere(xs > mx).min()\n", 174 | " mpot = measurement_pot(meas,xs[i],xs[i-1], delta)\n", 175 | " mfactor = make_canonical_factor(f\"meas_{i}-{i-1}\",(var_nodes[i],var_nodes[i-1]),mpot)\n", 176 | " return mfactor" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "d2f0830c-0d46-4744-9728-ec798aa80dab", 182 | "metadata": { 183 | "id": "d2f0830c-0d46-4744-9728-ec798aa80dab" 184 | }, 185 | "source": [ 186 | "### Set Model Parameters" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 7, 192 | "id": "33ee4f51-1dca-4aef-b599-14bc2376db63", 193 | "metadata": { 194 | "id": "33ee4f51-1dca-4aef-b599-14bc2376db63", 195 | "outputId": "eeb61ded-b99f-455d-db8c-653a6c27ff5d", 196 | "colab": { 197 | "base_uri": "https://localhost:8080/" 198 | } 199 | }, 200 | "outputs": [ 201 | { 202 
| "output_type": "stream", 203 | "name": "stderr", 204 | "text": [ 205 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "n_varnodes = 20\n", 211 | "x_range = 10\n", 212 | "n_measurements = 15\n", 213 | "\n", 214 | "## Parameters ##\n", 215 | "prior_cov = 10.\n", 216 | "prior_prec = 1/prior_cov\n", 217 | "data_cov = 0.05\n", 218 | "data_prec = 1/data_cov\n", 219 | "smooth_cov = 0.1\n", 220 | "smooth_prec = 1/smooth_cov\n", 221 | "data_std = jnp.sqrt(data_cov)\n", 222 | "\n", 223 | "## Evaluation points ##\n", 224 | "xs = jnp.linspace(0, x_range, n_varnodes)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "543f4295-6e8c-4a41-9d34-fcbcb4c3c657", 230 | "metadata": { 231 | "id": "543f4295-6e8c-4a41-9d34-fcbcb4c3c657" 232 | }, 233 | "source": [ 234 | "### Create measurement data" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 8, 240 | "id": "d6b9507b-1f61-48bf-b065-20907717e4d6", 241 | "metadata": { 242 | "id": "d6b9507b-1f61-48bf-b065-20907717e4d6" 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "key = jr.PRNGKey(42)\n", 247 | "mxs = jr.randint(key, (n_measurements,), 0, x_range)\n", 248 | "key, subkey = jr.split(key)\n", 249 | "mys = jnp.sin(mxs) + jr.normal(key, (n_measurements,))*data_std" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 9, 255 | "id": "de7531c7-6aed-44af-a7ba-3204697ae6ee", 256 | "metadata": { 257 | "id": "de7531c7-6aed-44af-a7ba-3204697ae6ee", 258 | "outputId": "d6224549-e40d-4708-aaf3-c09bed423acc", 259 | "colab": { 260 | "base_uri": "https://localhost:8080/", 261 | "height": 265 262 | } 263 | }, 264 | "outputs": [ 265 | { 266 | "output_type": "display_data", 267 | "data": { 268 | "text/plain": [ 269 | "
" 270 | ], 271 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAObklEQVR4nO3df4xlZ13H8ffHWRYEE0C7gbLbsDU26sZf4KThSmImTgktGlZFkjZRC5Esf1BBY6JFEkz4ZzEx/koazKZUqxKKaSGsurHAwg1/OJBOoQJtbVir0F2KDKBo/LXu9usf566dXWa3s9wz99ze5/1KJs/5tff55mTnc5/z3HPPpKqQJC2+bxu6AEnSbBj4ktQIA1+SGmHgS1IjDHxJasSuoQu4mCuuuKL2798/dBmS9LRy//33f7Wq9my1b24Df//+/ayvrw9dhiQ9rST5wsX2OaUjSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgSxLA2hocPty1C2pu78OXpJlZW4PVVTh9GnbvhuPHYTQauqreOcKXpPG4C/uzZ7t2PB66oh1h4EvSygosLUHStSsrQ1e0Iwx8SYIu7De3C8jAl6TxGM6cgaqudUpHkhbUykr3Ye3SUtcu6JSOd+lI0mjU3ZkzHndhv4B36ICBL0md0Whhg/4cp3QkqREGviQ1wsCXpEYY+JLUCANfkhph4C+6Bp4AKGl7vC1zkTXyBEBJ2+MIf5E18gRASdtj4C+yRr4uLml7nNJZZI18XVzS9hj4i66Br4tL2h6ndCSpEQa+JDWil8BPckeSryT53EX2J8kfJjmR5DNJXtpHv5Kk7etrhP8nwPWX2H8DcM3k5xDwrp76lSRtUy+BX1UfB75+iUMOAn9anU8Az0tyZR99S5K2Z1Zz+HuBxzatn5xsO0+SQ0nWk6xvbGzMqDRJasNcfWhbVUeqarmqlvfs2TN0OZK0UGYV+KeAqzat75tskyTNyKwC/yjwi5O7dV4GfKOqHp9R35IkevqmbZL3AivAFUlOAr8FPAOgqv4IOAa8CjgB/Cfw+j76nWtraz7SQNJc6SXwq+qmp9hfwJv66OtpwccSS5pDc/Wh7cLwscSS5pCBvxN8LLGkOeTTMneCjyWWNIcM/J3iY4klzRmndCSpEQa+JDXCwJekRhj4ktQIA1+SGmHgSxrW2hocPty12lHelilpOD6GZKYc4Usajo8hmSkDf6d4mSo9NR9DMlNO6eyEebpM9THNmmc+hmSmDPydsNVl6hD/kefpjUfS4Az8nXDuMvVc0A51mTovbzzSxTgomSkDfyfMy2Xqygrs2gVPPNG1zo9q3jgomSkDf6fMy9Myq85vpXkyL1fDjTDwF9l43I2cqrrW0ZPmzbxcDTfCwF9kjp70dDAvV8MNMPAXmaMnSZsY+IvO0ZOkCb9pK0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9Jjegl8JNcn+SRJCeS3LrF/tcl2UjywOTnDX30K0navqkfrZBkCbgNeAVwErgvydGqeuiCQ99XVbdM258k6VvTxwj/WuBEVT1aVaeBu4CDPbyuJKlHfQT+XuCxTesnJ9su9Jokn0lyd5KrtnqhJIeSrCdZ39jY6KE0SdI5s/rQ9i+B/VX1Q8CHgTu3OqiqjlTVclUt79mzZ0alSVIb+gj8U8DmEfu+ybb/V1Vfq6r/mazeDvxoD/1Kki5DH4F/H3BNkquT7AZuBI5uPiDJlZtWXw083EO/krR41tbg8OGu7dnUd+lU1ZkktwD3AkvAHVX1YJJ3AOtVdRR4c5JXA2eArwOvm7ZfSVo4a2uwuvrknyU9frzXP2DUy1+8qqpjwLELtr190/Jbgbf20ZckLazxuAv7s2e7djzuNfD9pq0kzYuVlW5kv7TUtSsrvb68f9NWkubFaNRN44zHXdj3/PeoDXxJmiejUe9Bf45TOpLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqRGLGfhra3D4cNdKkgDYNXQBvVtbg9VVOH0adu+G48dhNBq6Kkka3OKN8MfjLuzPnu3a8XjoiiRpLixe4K+sdCP7paWuXVkZuiJJmguLN6UzGnXTOONxF/ZO50gSsIiBD13IG/SSdJ5epnSSXJ/kkSQnkty6xf5nJnnfZP8nk+zvo19J0vZNHfhJloDbgBuAA8BNSQ5ccNgvAf9SVd8D/B7w29P2K0m6PH2M8K8FTlTVo1V1GrgLOHjBMQeBOyfLdwOrSdJD35Kkbeoj8PcCj21aPznZtuUxVXUG+AbwXRe+UJJDSdaTrG9sbPRQmiTpnLm6LbOqjlTVclUt79mzZ+hyJGmh9BH4p4CrNq3vm2zb8pgku4DnAl/roW9J0jb1Efj3AdckuTrJbuBG4OgFxxwFbp4s/xzw0aqqHvqWJG3T1PfhV9WZJLcA9wJLwB1V9WCSdwDrVXUUeDfwZ0lOAF+ne1OQJM1QL1+8qqpjwLELtr190/J/A6/toy9J0rdmrj60lSTtHANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNWKqwE/ynUk+nOTzk/b5FznubJIHJj9Hp+lTUo+OHIFXvrJrtfB2TfnvbwWOV9U7k9w6Wf+NLY77r6r6kSn7ktSnI0fgjW/slj/0oa49dGi4erTjpp3SOQjcOVm+E/jpKV9P0qzcc8+l17Vwpg38F1TV45PlLwMvuMhxz0qynuQTSS76ppDk0OS49Y2NjSlLk3RJr3nNpde1cJ5ySifJR4AXbrHrbZtXqqqS1EVe5sVVdSrJdwMfTfLZqvqHCw+qqiPAEYDl5eWLvZakPpybvrnnni7snc5ZeE8Z+FV13cX2JfnnJFdW1eNJrgS+cpHXODVpH00yBl4CfFPgS5qxQ4cM+oZMO6VzFLh5snwz8MELD0jy/CTPnCxfAbwceGjKfiVJl2nawH8n8Ioknweum6yTZDnJ7ZNjvh9YT/J3wMeAd1aVgS9JMzbVbZlV9TVgdYvt68AbJst/C/zgNP1IkqbnN20lqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8aQhra3D4cNdKMzLt45ElXa61NVhdhdOnYfduOH4cRqOhq1IDHOFLszYed2F/9mzXjsdDV6RGGPjSrK2sdCP7paWuXVkZuiI1wikdadZGo24aZzzuwt7pHM2IgS8NYTQy6DVzTulIUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJ
aoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjZgq8JO8NsmDSZ5IsnyJ465P8kiSE0lunaZPSdK3ZtoR/ueAnwU+frEDkiwBtwE3AAeAm5IcmLJfSdJl2jXNP66qhwGSXOqwa4ETVfXo5Ni7gIPAQ9P0LUm6PLOYw98LPLZp/eRk2zdJcijJepL1jY2NGZQmSe14yhF+ko8AL9xi19uq6oN9FlNVR4AjAMvLy9Xna0tS654y8Kvquin7OAVctWl932SbJGmGZjGlcx9wTZKrk+wGbgSOzqBfSdIm096W+TNJTgIj4K+T3DvZ/qIkxwCq6gxwC3Av8DDwF1X14HRlS5Iu17R36XwA+MAW278EvGrT+jHg2DR9SZKm4zdtJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIGvtqytweHDXSs1ZqovXklPK2trsLoKp0/D7t1w/DiMRkNXJc2MI3y1Yzzuwv7s2a4dj4euSJopA1/tWFnpRvZLS127sjJ0RdJMOaWjdoxG3TTOeNyFvdM5aoyBr7aMRga9muWUjiQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWpEqmroGraUZAP4whQvcQXw1Z7KebrzXJzP83E+z8eTFuFcvLiq9my1Y24Df1pJ1qtqeeg65oHn4nyej/N5Pp606OfCKR1JaoSBL0mNWOTAPzJ0AXPEc3E+z8f5PB9PWuhzsbBz+JKk8y3yCF+StImBL0mNWLjAT3J9kkeSnEhy69D1DCnJVUk+luShJA8mecvQNQ0tyVKSTyf5q6FrGVqS5yW5O8nfJ3k4SdPPjU7yq5Pfk88leW+SZw1dU98WKvCTLAG3ATcAB4CbkhwYtqpBnQF+raoOAC8D3tT4+QB4C/Dw0EXMiT8A/qaqvg/4YRo+L0n2Am8GlqvqB4Al4MZhq+rfQgU+cC1woqoerarTwF3AwYFrGkxVPV5Vn5os/zvdL/TeYasaTpJ9wE8Ctw9dy9CSPBf4ceDdAFV1uqr+ddiqBrcL+PYku4BnA18auJ7eLVrg7wUe27R+koYDbrMk+4GXAJ8ctpJB/T7w68ATQxcyB64GNoA/nkxx3Z7kOUMXNZSqOgX8DvBF4HHgG1X1oWGr6t+iBb62kOQ7gHuAX6mqfxu6niEk+SngK1V1/9C1zIldwEuBd1XVS4D/AJr9zCvJ8+lmA64GXgQ8J8nPD1tV/xYt8E8BV21a3zfZ1qwkz6AL+/dU1fuHrmdALwdeneSf6Kb6fiLJnw9b0qBOAier6twV3910bwCtug74x6raqKr/Bd4P/NjANfVu0QL/PuCaJFcn2U33ocvRgWsaTJLQzdE+XFW/O3Q9Q6qqt1bVvqraT/f/4qNVtXAjuO2qqi8DjyX53smmVeChAUsa2heBlyV59uT3ZpUF/BB719AF9KmqziS5BbiX7lP2O6rqwYHLGtLLgV8APpvkgcm236yqYwPWpPnxy8B7JoOjR4HXD1zPYKrqk0nuBj5Fd3fbp1nAxyz4aAVJasSiTelIki7CwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mN+D9b2tZoykufAgAAAABJRU5ErkJggg==\n" 272 | }, 273 | "metadata": { 274 | "needs_background": "light" 275 | } 276 | } 277 | ], 278 | "source": [ 279 | "plt.plot(mxs,mys,'r.');" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "616bb02a-0ede-4044-9111-a8c95914ab53", 285 | "metadata": { 286 | "id": "616bb02a-0ede-4044-9111-a8c95914ab53" 287 | }, 288 | "source": [ 289 | "### Construct factor graph" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 10, 295 | "id": "7d83e762-e86e-404e-aa8e-1817ce01a134", 296 | "metadata": { 297 | "id": "7d83e762-e86e-404e-aa8e-1817ce01a134" 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "## Variable nodes ##\n", 302 | "dim = 1\n", 303 | "prior_Lambda = jnp.array([[prior_prec]])\n", 304 | "var_nodes = [GaussianVariableNode(i,dim,CanonicalPotential(jnp.zeros(dim),prior_Lambda))\n", 305 | " for i in range(n_varnodes)]\n", 306 | "\n", 307 | "## Smoothing factors ##\n", 308 | "smoothing_eta = jnp.zeros(2*dim)\n", 309 | "smoothing_Lambda = jnp.array([[smooth_prec, -smooth_prec],\n", 310 | " [-smooth_prec, smooth_prec]])\n", 311 | "smoothing_pot = CanonicalPotential(smoothing_eta, smoothing_Lambda)\n", 312 | "smoothing_factors = [CanonicalFactor(factorID=f\"smoothing_factor_{i}-{i+1}\",\n", 313 | " adj_varIDs=[i,i+1],\n", 314 | " potential=smoothing_pot,\n", 315 | " var_scopes = {i:(0,1),i+1:(1,2)})\n", 316 | " for i in range(n_varnodes-1)]\n", 317 | "\n", 318 | "## Measurement factors ##\n", 319 | "measurement_factors = [create_measurement_factor(m,xs,var_nodes, delta=data_prec) for m in zip(mxs,mys)]\n", 320 | "\n", 321 | "## All factors ##\n", 322 | "factors = smoothing_factors + measurement_factors\n", 323 | "\n", 324 | "fg = make_factor_graph(var_nodes, factors)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 11, 330 | "id": "5044d1a2-cc32-41cd-8bd7-33b44e06359c", 331 | "metadata": { 332 | "id": 
"5044d1a2-cc32-41cd-8bd7-33b44e06359c" 333 | }, 334 | "outputs": [], 335 | "source": [ 336 | "jit_update_messages = jit(lambda m: update_all_messages(fg,m,damping=0.1))" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "id": "848a00f6-491b-49a8-9926-243ffc91305b", 342 | "metadata": { 343 | "id": "848a00f6-491b-49a8-9926-243ffc91305b" 344 | }, 345 | "source": [ 346 | "### Plot initial belief state" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 12, 352 | "id": "305a7f6f-7aa5-4ae1-8c95-3794f455eb90", 353 | "metadata": { 354 | "id": "305a7f6f-7aa5-4ae1-8c95-3794f455eb90", 355 | "outputId": "dfd8d847-e776-430a-fe30-c276b2e05340", 356 | "colab": { 357 | "base_uri": "https://localhost:8080/", 358 | "height": 265 359 | } 360 | }, 361 | "outputs": [ 362 | { 363 | "output_type": "display_data", 364 | "data": { 365 | "text/plain": [ 366 | "
" 367 | ], 368 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASo0lEQVR4nO3dfZDdVX3H8c+H3Q0bntzYbCVkE0JbwMFoQ+cOYumIBUriQ0VhyoNTpK3DSqsldARHzZQk08mYmShoUSg7JBUs9aGI6GggojJlLElkQxIMBIQulSSuZC0FwpKHm5tv/7ibzSbZZB/u797fnnvfr5lM7v3tPed872T3k7O/e37n54gQACBdx+RdAACgMgQ5ACSOIAeAxBHkAJA4ghwAEtecx6BTp06NWbNm5TE0ACRr3bp1v42I9kOP5xLks2bNUnd3dx5DA0CybP9quOOcWgGAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHEEOQAkLqkgv+LO1briztW0pz3taV/z9hOlhuEkFeQAgMNVHOS2W23/3PZG20/ZXpxFYQCA0clir5Xdki6IiNdtt0j6me0HI2JNBn0DAEZQcZBH+aafrw88bRn4w41AAaBGMjlHbrvJ9gZJ2yU9HBFrh3lNp+1u2919fX1ZDAsAUEZBHhGliJgjqUPSObZnD/OarogoREShvf2w7XQBAOOU6aqViHhF0iOS5mXZLwDgyLJYtdJuu23g8WRJfybpmUr7BQCMTharVqZJutt2k8r/MXw7In6QQb8AgFHIYtXKk5LOzqAWAMA4cGUnACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4ghyAEgcQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIXMVBbnuG7UdsP237KdvzsygMADA6zRn0sVfSpyLiCdsnSlpn++GIeDqDvgEAI6h4Rh4RvRHxxMDjHZI2S5peab8AgNHJ9By57VmSzpa0Nst+AQBHllmQ2z5B0nck3RARrw3z9U7b3ba7+/r6shoWABpeJkFuu0XlEL83Iu4f7jUR0RURhYgotLe3ZzEsAEDZrFqxpOWSNkfELZWXBAAYiyxm5OdJulrSBbY3DPx5Xwb9AgBGoeLlhxHxM0nOoBYAwDhwZScAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDqDm+tesVc+ff1B7+/oOeozxqXg/cgAYi/41a/XitddKxaK2Xj9fOzdtkopFvfTFWzR96efzLi9JzMgB1FTvggVSsShJ2rlhw+Dj11auzLOspBHkAGrKkyYdeBIx+PCYlpYcqqkPBDmAmmqfP1/y4XeHbJ0zJ4dq6gNBPoEd+oHQ397zjzqh/9W8ywIqsn3ZsoNm4vu98fjjOVRTH/iwc4LqX7NWW667TlEsatuNN2nnxo2aunuP3r3m+5Lm5V0eMG6eNKk8I4+QWlqkUknat0/HHHts3qUlixn5BPXSkiWKYlEqlbRz40bFrl1qin1623PdeZcGVOTUe+5W25VXqmnKFE1ftkxtl1+upilT1HHbbXmXliyCfIKasfwuHVcoyK2til27JEnF5hb98IKrc64MqExze7umLbxZZ6x+TCfNm6tpixbqjNWP6fhz35l3acnKJMhtr7C93famLPqDtKfnhcGZ+H6OfTpty+YcqwIwEWU1I/+aOHGbqcFTK5Lc2iq1tKi5VNJZnFoBcIhMgjwiHpX0chZ9TUT7V4+c0P+qZm3ZXJPLiWeuWD547vCUpUvVdtll6p98ov7j/ddVdVwA6anZqhXbnZI6JWnmzJm1GrZiQ1ePXLrrTnX09mh3lNR3+x2atvDmqo27/zzi/jFOmjdXN0y7uGrjAUhXzT7sjIiuiChERKG9vb1Ww1Zs6OqRjt4eTdq7RyqVtOOhh/IuDQAksWplRENXj0zau0dS+Zz1yQsX5lwZAJQR5CMYbvVIlErqX7Mmx6qAyrGVbP3IavnhNyStlnSm7a22P5ZFvxPB0NUjxeYW7W1qkopF7Vi1KufKgPHb/9nP7p4ebbvxpsHHfbffkXdpGIesVq1cFRHTIqIlIjoiYnkW/U4EM1cs1wnnny81NenB93xEz5/6dqmpSScvqv6pFfZaQbW8tGSJYk/58543Nmwo/8ZZKrGVbKI4tTKC3f/do/7HHpMkzX52rX7vxacllUO2mvrXrNWWzk7tfu45bZ1/g7Z0dup3/3ebLvzZfVUdF43hdz7xdwee7N49+HDy7Nk5VINKEeQjyGvVSu+CBeUZk8qb78eePbKk2c/+vKrjojH0LfuCtG/fYcff6OaCsxQR5CPIa9XKQZvvD/mBKzU1VXVcNIaDvr+GHm9mQ9QUEeQjyGvVytT510vHHPzPE5K2nvz7VR0XjaH9hvnSMJOC4845J4dqUCmCfAR5rVoZ7ldfSzp12y+rOi4aw2+/8tXBx/v38pEG7qGJ5BDkIxi658l3516r9W97d3kf5Vtvreq4B/3q29IyeGusvU386ovKDbeXTy2+r1EdBPkIhu6d/PQZBf3wwqtrsnfyqffcrbarrjqw+f4VV6h/8on61gc/WdVx0RjYE7y+ML2boNg0C8BoMSMHgMQR5AAaxtD7CdTT/jIEOYCGMGvLZn3kgS/X5f4yBDmAhvDeR/5dTaW9Uql04NqQOrm3AEEOoCF8/dJP6VcdZ8itrYMX+NXLvQUIcgANYer/9aqjt6cu7y1AkANoCIOnVjTkatY6ubcAQQ6gIXz9shu17u3n1+XVrFwQBKAhvH78m7Tywqt1zcdvl6TBK1rrATNyAEgcQQ4AiUsmyIfes7KersgCgEolEeT77/g99eVeXfrgnXV1RRYAVCqJIN9/c4em2HdgHWidXJEFAJXKJMhtz7P9rO3nbX8miz6HmrH8Lv3Xn3xYH714gT70/iW65uLP6ZHTzhnTFVkPrN+m9S++orUvvKzzlv5UD6zfNqYaaE972jdu+4lSw5E0LVq0qKIObDdJekjSXEmfl/TPixcvfnTRokVHPIHd1dW1qLOzc9Rj3LeyW0v62vTasSdItvpbJqt76uma+usXNOeCkTfCf2D9Nn32/l9o997yrdN27Nqr//xlnzqmTNZbp51Ee9rTnvYTvgZJWrx4ce+iRYu6Dj2exYz8HEnPR0RPROyR9E1Jl2TQ76Avru7V7uaD7/q9u3mSbn/9zaNqv2zVs9pZLB10bGexpGWrnqU97WlP+yRqOJosgny6pC1Dnm8dOHYQ2522u213941xtUnfpBOHP97aNqr2v35l55iO0572tKf9RKvhaGr2YWdEdEVEISIK7e3tY2p7StvkMR2nPe1pT/us2k+UGo4miyDfJmnGkOcdA8cyc9PcMzW5pemgY5NbmnTT3DNpT3va076q7SdKDUeTxV4rj0s6
3fZpKgf4lZI+kkG/gz50dvlMzafve1J7Svs0vW2ybpp75uBx2tOe9rSvVvuJUsPROCIq78R+n6QvSWqStCIilhzt9YVCIbq7u8c8zhV3rpYkfevj7xpHlbSnPe1pP/72E6EG2+sionDo8Ux2P4yIlZJWZtEXAGBskriyEwBwZAQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIgZ0PveMXdrzAemawjBzA+/WvW6sVrr5WKRW29fr52btokFYt66Yu3aPrSz+ddHhLBjBzIUe+CBVKxKEnauWHD4OPXVnJ9HUaPIAdy5ElD9tkfsl3GMS0tOVSDVBHkQI7a58+X7MOOt86Zk0M1SBVBDuRo+7JlB83E93vj8cdzqAapIsgB5bdyxJMmHZiRt7RIx5R/JI859tiqj436QZCj4fWvWast112n3T092nbjTYOP+26/o+pjn3rP3Wq78ko1TZmi6cuWqe3yy9U0ZYo6brut6mOjfhDkaHgvLVmiKBalUkk7N25U7NollUra8dBDVR+7ub1d0xberDNWP6aT5s3VtEULdcbqx3T8ue+s+tioHwQ5Gt6M5XfpuEJBbm0th7gkt7bq5IULc64MGB2CHA1vT88LB2biA6JUUv+aNTlWBYweQY6GN3hqReWZuFpapGJRO1atyrkyYHQIcjS8mSuWD37IeMrSpWq77LLyh4+33pp3acCosNcKGt7+DxynLbxZkgY/dARSwYwcABJHkANA4ioKctt/Yfsp2/tsF7IqCgAwepXOyDdJulTSoxnUAgAYh4o+7IyIzZLkYXZvAwDURs3OkdvutN1tu7uP21gBQGZGnJHb/rGkk4f50oKI+N5oB4qILkldklQoFA7ftxMAMC4jBnlEXFSLQgAA48PyQ0wo3FEeGLtKlx9+2PZWSe+S9EPbbE6BcctzX3AgZRUFeUR8NyI6IuLYiHhLRMzNqjA0njz3BQdSxqkVTBjsCw6MD0GOCYN9wYHxIcgxYbAvODA+BDkmDPYFB8aH/cgxYbAvODA+zMgBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4ghyAEgcQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcRUFue1ltp+x/aTt79puy6owAMDoVDojf1jS7Ih4h6RfSvps5SUBAMaioiCPiB9FxN6Bp2skdVReEgBgLLI8R/43kh7MsD8AwCg0j/QC2z+WdPIwX1oQEd8beM0CSXsl3XuUfjoldUrSzJkzx1UsAOBwIwZ5RFx0tK/b/itJH5B0YUTEUfrpktQlSYVC4YivAwCMzYhBfjS250n6tKTzI+KNbEoCAIxFpefIvyLpREkP295g+18yqAkAMAYVzcgj4g+yKgQAMD5c2QkAiSPIASBxBDkAJI4gB4DEEeQAkDiCHAASR5ADQOIIcgBIHEEOAIkjyAEgcQQ5ACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHEEOQAkjiAHgMQR5ACQOIIcABJHkANA4ioKctv/ZPtJ2xts/8j2KVkVBgAYnUpn5Msi4h0RMUfSDyTdnEFNAIAxqCjII+K1IU+PlxSVlQMAGKvmSjuwvUTSRyW9KulPj/K6TkmdkjRz5sxKhwUADBhxRm77x7Y3DfPnEkmKiAURMUPSvZI+eaR+IqIrIgoRUWhvb8/uHQBAgxtxRh4RF42yr3slrZS0sKKKAABjUumqldOHPL1E0jOVlQMAGKtKz5EvtX2mpH2SfiXpuspLAgCMhSNqv9CkUChEd3d3zccFgJTZXhcRhUOPc2UnACSOIAeAxBHkAJA4ghwAEkeQA0DiCHIASBxBDgCJI8gBIHEEOQAkLpcrO233qXxJ/3hMlfTbDMtJAe+5MfCeG0Ml7/nUiDhs+9hcgrwStruHu0S1nvGeGwPvuTFU4z1zagUAEkeQA0DiUgzyrrwLyAHvuTHwnhtD5u85uXPkAICDpTgjBwAMQZADQOKSCnLb82w/a/t525/Ju55qsz3D9iO2n7b9lO35eddUC7abbK+3/YO8a6kF222277P9jO3Ntt+Vd03VZvsfBr6nN9n+hu3WvGvKmu0Vtrfb3jTk2JttP2z7uYG/p2QxVjJBbrtJ0lclvVfSWZKusn1WvlVV3V5Jn4qIsySdK+kTDfCeJWm+pM15F1FDX5b0UES8VdIfqs7fu+3pkq6XVIiI2ZKaJF2Zb1VV8TVJ8w459hlJP4mI0yX9ZOB5xZIJcknnSHo+InoiYo+kb0q6JOeaqioieiPiiYHHO1T+AZ+eb1XVZbtD0vsl3ZV3LbVg+02S3i1puSRFxJ6IeCXfqmqiWdJk282SjpP065zryVxEPCrp5UMOXyLp7oHHd0v6UBZjpRTk0yVtGfJ8q+o81IayPUvS2ZLW5ltJ1X1J0qcl7cu7kBo5TVKfpH8dOJ10l+3j8y6qmiJim6QvSHpRUq+kVyPiR/lWVTNviYjegce/kfSWLDpNKcgblu0TJH1H0g0R8Vre9VSL7Q9I2h4R6/KupYaaJf2RpDsi4mxJ/cro1+2JauC88CUq/yd2iqTjbf9lvlXVXpTXfmey/julIN8macaQ5x0Dx+qa7RaVQ/zeiLg/73qq7DxJH7T9PyqfOrvA9r/lW1LVbZW0NSL2/6Z1n8rBXs8ukvRCRPRFRFHS/ZL+OOeaauUl29MkaeDv7Vl0mlKQPy7pdNun2Z6k8ocj38+5pqqybZXPnW6OiFvyrqfaIuKzEdEREbNU/vf9aUTU9UwtIn4jaYvtMwcOXSjp6RxLqoUXJZ1r+7iB7/ELVecf8A7xfUnXDDy+RtL3sui0OYtOaiEi9tr+pKRVKn/KvSIinsq5rGo7T9LVkn5he8PAsc9FxMoca0L2/l7SvQMTlB5Jf51zPVUVEWtt3yfpCZVXZq1XHV6qb/sbkt4jaartrZIWSloq6du2P6byVt6XZzIWl+gDQNpSOrUCABgGQQ4AiSPIASBxBDkAJI4gB4DEEeQAkDiCHAAS9/8Eoh6n32UC5gAAAABJRU5ErkJggg==\n" 369 | }, 370 | "metadata": { 371 | "needs_background": "light" 372 | } 373 | } 374 | ], 375 | "source": [ 376 | "plt.plot(mxs,mys,'X',color=\"C3\");\n", 377 | "plot_beliefs(fg,init_messages(fg),xs, fmt=\"o\")" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "id": "fd1b744c-c5a0-4c24-b1a1-6196fc25bcc6", 383 | "metadata": { 384 | "id": "fd1b744c-c5a0-4c24-b1a1-6196fc25bcc6" 385 | }, 386 | "source": [ 387 | "### Loopy belief propagation" 388 | ] 389 | }, 390 | { 391 | 
"cell_type": "code", 392 | "execution_count": 13, 393 | "id": "f21cb1e1-0e33-4a9c-9496-5c34989f80ab", 394 | "metadata": { 395 | "id": "f21cb1e1-0e33-4a9c-9496-5c34989f80ab" 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "# The first run can be quite slow as it jit compiles.\n", 400 | "messages = init_messages(fg)\n", 401 | "for _ in range(10):\n", 402 | " messages = jit_update_messages(messages)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 14, 408 | "id": "4b6c0934-1754-45cd-98cb-cca1204129b1", 409 | "metadata": { 410 | "id": "4b6c0934-1754-45cd-98cb-cca1204129b1", 411 | "outputId": "aeaa2335-34b6-40d3-8766-a7e261e26ba9", 412 | "colab": { 413 | "base_uri": "https://localhost:8080/", 414 | "height": 265 415 | } 416 | }, 417 | "outputs": [ 418 | { 419 | "output_type": "display_data", 420 | "data": { 421 | "text/plain": [ 422 | "
" 423 | ], 424 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dfXRU9b3v8ffXIZj4gIkQMSRwoKeWWggQiCKy5N6l1aBHkQv16a5WOdRqtN5ae4xX7rkal6ss6dIrp9erCAUq2hap1AdOsQRbpXqEoDz6RFSIURI5GmNVTAmZTH73j0zGJE5IwjzsmT2f11os9t6zs/d3JH7nN9+99/dnzjlERMT/jvE6ABERSQ4lfBGRDKGELyKSIZTwRUQyhBK+iEiGGOR1AL0ZNmyYGz16tNdhiIikle3bt3/inMuP9lrKJvzRo0ezbds2r8MQEUkrZvZ+b6+ppCMikiGU8EVEMoQSvohIhlDCFxHJEEr4IiIZQglfRFJSc/VWai+ZRVtjY7dlOXpK+CKScpqrt7K/vJzDtbU03FoRWW58aInXoaU1JXwRSTkfLVyICwYhFOLQ7t24lhYIhTi4YYPXoaU1JXwRSTkjVyzn2NNOA7OOZA9gRv4tt3gbWJpTwheRlPPF+mc5XFMDXSdoco5PH33Uu6B8QAlfRFLKFUu30HDf/4n6Wuu+fUmOxl+U8EUk5Sy/6n8TOOWUr20fduONHkTjH0r4IpJyclqaaf/ii+4bs7Joa2ryJiCfUMIXkZRz4Qu/67hLB7DsbMjKgmCQg1VVHkeW3pTwRSTlPDb3VnIvv5xAXh4jFi0id+5cAnl5FC5e7HVoac1c16vgKaS0tNSpH75I5rli6RYA1lw/zeNI0pOZbXfOlUZ7TSN8EZEMoYQvIpIhlPBFRDKEEr4PqKugiPSHEn6aU1dBke6uWLolcuFXulPCT3PqKigi/aWEn+ZGrljOcaWlWHZ2pKugZWdzamWlx5GJSKpRwk9zrbXvfTWyD3OhEM3V1R5GJSKpSAk/zUVKOugRdEkNqqGnLiX8NDdq5Qo9gi6+8fTOBnZ+8Blb3/uU6Yue5+mdDV6H5CuDvA5AYjMoP5+CyjspqLwTgCEzyyi4S/V7ST9P72xgwZOv0xpqB6Dhs0MsePJ1AGaXFHoZmm9ohC8iKeHeqrc5FAx123YoGOLeqrc9ish/4pLwzWylmX1sZm/08rqZ2f81s71m9pqZTY7HeUXEPz787NCAtsvAxWuE/wgw8wivXwicFv5zHaCngkSkmxG5OQPa7leJvOgdl4TvnHsR+PQIu1wKPOo6VAO5ZlYQj3OLiD9UlI0lJyvQbVtOVoCKsrEeReQ/yarhFwL7u6zXh7d1Y2bXmdk2M9vWqF4wIhlldkkh98wpZnCgIy0V5uZwz5xiXbCNo5S6S8c5twxYBh0ToHgcjogk2eySQla/8gGgCVASIVkj/AZgZJf1ovA234hnx0o9uCIiiZCshL8OuDp8t85ZwOfOuQNJOnfC9exYuffa6/j73n3qWCkiKSVet2WuBrYAY82s3sx+aGblZlYe3uVZoBbYC/wKuDEe500VPTtWDm5rJeDa1bFSRFJKXGr4zrmr+njdAT+Ox7lS0cgVy/mw4rZuTcyCg7IoVMdKEUkhetI2DqJ1rDTXro6VIhkola/BKeHHQc+OlW2BAINCIXWsFJGUooQfBz07Vu4cN4PmnBPVsVIkydRt88hS6j78dNWzY+X694ew/rwfsOasqR5HJpI51G2zbxrhi4gvqNtm35TwRcQX1G2zb0r4IhI3XtbQ1W2zb0r4InJE/W0b0lsNPVlJX902+6aLtiLSq862IS4YpOHWio7nTYJBGh9aErlJodORaugDuWh6tE3TOs9x29rXaA21U5ibQ0XZWF2w7UIjfBHpVc+2Ia6lBUKhqG1DUqGGPrukkJJRuUwdczIv336ukn0PSvhx0vlV98kX9/D6e41srW3i7IXP6T5gSWsjVyznuNJSLDs78iS5ZWdzapS2IX6poafyk7KxUsKPYqD/4J1fe//UMoQFf3yHvxMAMz482JrUGqZIvEVrG+JCoahtQ1RDT31K+HHQ+bV31ellHD6m+2WRZNwH3POi2qbp3+WH96tTp8Tuo4ULca2tHSvHHguBAASDfPGnP31tX81YlfqU8OOg82tvY05e1NcTWcPs2Yt/f3k5wz49wIzqdQk7p2SOYTf9GMwAGDxqVGT7cVOmRN1fNfTUpoQfB51fe/MP/S3q64msYUZGYKEQh3btwrW0EHDtjH/nlYSdUzLHJ//vwchycP9+CHXchXNo506vQpIYKOHHQWdJ55o3/8Sxba3dXhtoDXOgD64M+/GPwXVM/+sOH+74G2g4ZfSA3oNIp67XsAZy0VZSnxJ+HHR2yzz/73XcVXwsJwX/jjlHQc4xA6phHs2DKx/fe28k4XcyYHTDO0f9fkQ6DeSiraQ+Jfw46OyW+a0tm7nqmov49reKOPMbQ9lSeeGAaphH0/zJBg+Ouj0UCETdLjIQPed6ICsLgkHN9ZCmlPBTyNE8uJJ/880dd0504YC6om/HMzTJUD3nesidO5dAXp7mekhTaq2QQkbk5tAQJbkf6aLvJw9+dVHNsrNxoRAWDDLywL6ExCiZpedcD0NmllFwl+r36Uoj/BRyNA+uRBuBNeecyBP/VJ7ocEUkzWiEn0KOpvlTtBHYTwsuSEq8IpJelPBTzOySQla/8gFw9F0DRUSiUUlHRCRDKOGLiIR5OWNXMijhi4jg/YxdyaCELyLC0T34mG6U8EXEV9ZcP+2obniIx4xdsZaEEl1SUsL3Gb/XIEUSJdYZu2ItCSWjpKSEnwBHO8KIVSbUIEUSJdYZu2ItCSWjpBSXhG9mM83sbTPba2a3R3l9npk1mtmu8J9r43HeREjnEXIm1CBFEiXWGbtiLQklYxL4mB+8MrMA8CBwPlAPvGpm65xzb/XYdY1z7qZYz5dIvY2QgbSYuScZvzDib50DntZQO9MXPd/nk97RpPMDg7E8+Hg0vbDi+fP9EY8R/pnAXudcrXOuFXgcuDQOx026dB8hx1qDlMymkmBsYi0JJWMS+Hgk/EJgf5f1+vC2nuaa2WtmttbMRkY7kJldZ2bbzGxbY2NjHEIbmHQfISfjF0b8K90HPF6LtSSUjEngk9VL59+B1c65w2Z2PbAKOLfnTs65ZcAygNLSUtfz9URLxleqRDqa5msindJ9wJMKYu2FleheWvEY4TcAXUfsReFtEc65Jufc4fDqciD6lPce88MIeXZJISWjcpk65mRevv1cJXvpN5UEvzJ6/x5qL5lFW2MjzdVbI8vpLh4J/1XgNDMbY2aDgSuBdV13MLOCLquzgD1xOG/cJeMrlUiq8sOAJx5G79/Df3/6lxyuraXh1gr2l5dzuLaWxoeWeB1azGIu6Tjn2szsJqAKCAArnXNvmtndwDbn3DrgJ2Y2C2gDPgXmxXreRFF7YslUKgl2uPCF3xEItYFr7zaB
+8ENGyLzTqSruNTwnXPPAs/22HZnl+UFwIJ4nKs/rli6BVDCFhkoDXjgsTn/wpwNy/jHxrpIsrfsbE6tTP+pHfWkrXzNFUu3RD40RTLNsL8doOhAbSTZA7hQiObqag+jig8l/BTkVWsGEelS0qFjZE9WFgSDHKyq8jiy2Cnhi4h08djcW9le/F8I5OUxYtEicufOJZCXR+HixV6HFjPNaSsi0sWXx5/Es+f9gGuufwiAITPLKLgr/ev3oBG+iEjGUMIXEckQKumIiHTh5xsmNMIXEckQSvgiIhlCCV9EJEOohu9Dfq5BisjR0whfRCRD+C7hp/Mk5CIiieSrhK85OUVEeuerhK85OUVEeueri7bxmpNTFz0lk+n33798NcLXnJwims9AeuerhK85OUVEeuerko7m5BQR6Z2vEj5oTk4Rkd74qqQjIiK9890IX0TES6lcWdAIX0QkQ/gu4TdXb+WGR+/ghObPaa7eSu0ls2hrbPQ6rLSh1hQi/uWrhN9cvZX95eUM+/QAc/60lP3l5RyuraXxoSVeh5YW1JpCxN98lfA/WrgQFwwScO0UHajFtbRAKMTBDRu8Di0tqDWFiL/5KuGPXLGc40pLaR00mMFtrQBYdjanVlZ6HFl6iFdrChFJTb5K+K2173Fo9+5IsgdwoRDN1dUeRpU+1JpCxN98lfA7SzoAwUFZkJUFwSAHq6o8jiw9qDWFiL/5KuGPWrmC3MsvpznnRJ4q+xG5c+cSyMujcPFir0NLC7NLCrlnTjGDAx2/FoW5Odwzp1itKUR8Ii4PXpnZTOCXQABY7pxb1OP1Y4FHgSlAE3CFc64uHufualB+PgWVd/LTU88HoOD6aRTcpfr9QKg1hYi3Evn/XcwJ38wCwIPA+UA98KqZrXPOvdVltx8Cf3POfdPMrgR+AVwR67l7o0QlIvJ18SjpnAnsdc7VOudagceBS3vscymwKry8FjjPzCwO55YUpH7sIqkpHgm/ENjfZb0+vC3qPs65NuBzYGjPA5nZdWa2zcy2NerpWBGRuEqpi7bOuWXOuVLnXGl+fr7X4YiI+Eo8En4DMLLLelF4W9R9zGwQcBIdF29FJI7UC0mOJB4J/1XgNDMbY2aDgSuBdT32WQdcE17+HvC8c87F4dwiEqZeSNKXmBN+uCZ/E1AF7AF+75x708zuNrNZ4d1WAEPNbC/wM+D2WM8rIt2pF5L0JS734TvnngWe7bHtzi7LLcBl8TiXiESnXkjSl5S6aCsiR0+9kKQvSvgiPqFeSNIXzWkr4hOdPY9uW/saraF2CnNzqCgbq15IEqGEL+Ij6oUkR6KSjohIhlDCFxHJEEr4IiIZQglfRCRDKOGLiGQIJXyRFKP5BCRRlPBFRDKEEr6ISIZQwhcRyRB60la+Rk9oiviTRvgiIhlCCV9EJEMo4YuIZAglfBGRDKGEL3H19M4Gdn7wGVvf+5Tpi57XBNoiKUQJX+Lm6Z0NLHjydVpD7QA0fHaIBU++nnFJX0/KSqpSwpe4ubfqbQ4FQ922HQqGuLfqbY8iEpGudB++xM2Hnx0a0HZJDD1HIb3RCF/iZkRuzoC2i0hyKeFL3FSUjSUnK9BtW05WgIqysR5FJCJdqaQjcTO7pBCA29a+RmuoncLcHCrKxka2i4i3lPAlrmaXFLL6lQ8A1ZJFUo1KOiIiGUIJXySF6ME1SSQlfJEUoQfXJNGU8CXlZOqTqnpwTRItpoRvZieb2XNm9m7477xe9guZ2a7wn3WxnFPEr4704FrTrx+hZsJEWmpqui2LDESsI/zbgb84504D/hJej+aQc25S+M+sGM8p4ku9PaA2fFCIj3/xC1xrK+9ffU1kueHWiiRHKOku1oR/KbAqvLwKmB3j8UQ851VJqbcH165+dW1kvf2LLyLLrfv2JS028YdYE/5w59yB8PJ/AsN72S/bzLaZWbWZ9fqhYGbXhffb1tjYGGNoIulldkkh98wpZnCg43/Lwtwc7plTzD8vXsAxQ4Z8bf+Tr7km2SFKmuvzwSsz+zNwapSX/rXrinPOmZnr5TD/4JxrMLNvAM+b2evOua8NT5xzy4BlAKWlpb0dS8S3oj241vTr57qN7Dt9+R//wXD+Z1Ljk/TWZ8J3zn23t9fM7CMzK3DOHTCzAuDjXo7REP671sw2ASWAvo+K9EPj4sVRt6ukIwMVa0lnHdD5vfIa4JmeO5hZnpkdG14eBkwH3orxvCIZY/Tv1zD4m98EM06eNy+yfMptt3kdmqSZWHvpLAJ+b2Y/BN4HLgcws1Kg3Dl3LXA6sNTM2un4gFnknFPCF+mn7G9/m3/8479H1lXGkaMVU8J3zjUB50XZvg24Nry8GSiO5TwiIhI7PWkrEkfqhSOpTAlfJE7UC0dSnRK+SJyoF46kOiV8kTjRJO6S6pTwReJEk7hLqlPCFxmA5uqt1F4yi7bGxm7LoEncJfVpTluRfmqu3sr+8nJcMEjDrRUc2r0bFwzS+NASCirv1CTukvI0whfpp48WLsQFgxAKdST7lhYIhTi4YUNkn9klhZSMymXqmJN5+fZzlewlpSjhi+8kqr3xyBXLOa60FMvO7kj2gGVnc2plZdzPJZIISvgi/dRa+95XI/swFwrRXF3tYVQi/aeEL3G35vppkda+fhIp6dAxsicrC4JBDlZVeRyZSP/ooq1IP41auYLGh5ZwcMMGTq2spLm6moNVVRT20r74aPnxw1JSgxK+pJTOXjStoXamL3o+pe5yGZSfT0HlnRRU3gnAkJllFNyl+r2kD5V0JGWkQi8aNT8TP1PCl5ThdS+aVPjAEUkkJXxJGV73ovH6A0ck0ZTwJWV43YvG6w8ckURTwpeU4XUvGq8/cEQSTQlfUsbskkLumVPM4EDHr2Vhbg73zClO2l06Xn/giCSabsuUlDK7pJDVr3wAJP9+dDU/E79TwhfpwssPHJFEU0lHRCRDKOGLiGQIJXwRkQyRVjX8YDBIfX09LV3a00pqy87OpqioiKysLK9DEcl4aZXw6+vrOfHEExk9ejRm5nU40gfnHE1NTdTX1zNmzBivw0kaXeyVVJVWJZ2WlhaGDh2qZJ8mzIyhQ4fqG5lIikirhA8o2aeZZP97qdulSO/SLuGL9EbdLkWOzNcJv7l6K7WXzKKtsbHbcizMjO9///uR9ba2NvLz87n44otjDTctbNq0ic2bN3sdRlTqdilyZDElfDO7zMzeNLN2Mys9wn4zzextM9trZrfHcs7+aq7eyv7ycg7X1tJwa0VkufGhJTEd9/jjj+eNN97g0KGODorPPfcchYXePHrf1taW9HOmcsJXt0uRI4t1hP8GMAd4sbcdzCwAPAhcCHwHuMrMvhPjefsUmXA6FOLQ7t24lhYIhTi4YUPMx77oootYv349AKtXr+aqq66KvNbc3Mz8+fM588wzKSkp4ZlnngGgrq6Oc845h8mTJzN58uRI0jxw4AAzZsxg0qRJjB8/npdeegmAE044IXLMtWvXMm/ePADmzZtHeXk5U6dO5bbbbmPfvn3MnDmTKVOmcM4551BTUxPZ74YbbuCss87
iG9/4Bps2bWL+/PmcfvrpkWMBbNy4kWnTpjF58mQuu+wyvvzySwBGjx5NZWUlkydPpri4mJqaGurq6nj44YdZvHgxkyZN4qWXXuKJJ55g/PjxTJw4kRkzZsT83zYW6nYpcmQxJXzn3B7nXF/fl88E9jrnap1zrcDjwKWxnLc/Rq5YznGlpVh2dkeyByw7m1MrY5+D9Morr+Txxx+npaWF1157jalTp0ZeW7hwIeeeey6vvPIKL7zwAhUVFTQ3N3PKKafw3HPPsWPHDtasWcNPfvITAH73u99RVlbGrl272L17N5MmTerz/PX19WzevJn777+f6667jgceeIDt27dz3333ceONN0b2+9vf/saWLVtYvHgxs2bN4pZbbuHNN9/k9ddfZ9euXXzyySf8/Oc/589//jM7duygtLSU+++/P/Lzw4YNY8eOHdxwww3cd999jB49mvLycm655RZ27drFOeecw913301VVRW7d+9m3bp1Mf+3jYW6XYocWTLuwy8E9ndZrwemRtvRzK4DrgMYNWpUTCdtrX3vq5F9mAuFaK6uZsjMspiOPWHCBOrq6li9ejUXXXRRt9c2btzIunXruO+++4COW0k/+OADRowYwU033cSuXbsIBAK88847AJxxxhnMnz+fYDDI7Nmz+5XwL7vsMgKBAF9++SWbN2/msssui7x2+PDhyPIll1yCmVFcXMzw4cMpLi4GYNy4cdTV1VFfX89bb73F9OnTAWhtbWXatK/uIZ8zZw4AU6ZM4cknn4way/Tp05k3bx6XX355ZH+vqNulyJH1mfDN7M/AqVFe+lfn3DPxDMY5twxYBlBaWupiOVakpEPHyN6FQhAMcrCqioK7Yh/lz5o1i1tvvZVNmzbR1NQU2e6c4w9/+ANjx3YfVd51110MHz6c3bt3097eTnZ2NgAzZszgxRdfZP369cybN4+f/exnXH311d1uZ+x5H/vxxx8PQHt7O7m5uezatStqjMceeywAxxxzTGS5c72trY1AIMD555/P6tWrj/jzgUCg1+sFDz/8MFu3bmX9+vVMmTKF7du3M3To0Kj7JoO6XYr0rs+SjnPuu8658VH+9DfZNwAju6wXhbcl1KiVK8i9/HICeXmMWLSI3LlzCeTlUbh4cVyOP3/+fCorKyOj5k5lZWU88MADONfxebVz504APv/8cwoKCjjmmGN47LHHCIU67iZ5//33GT58OD/60Y+49tpr2bFjBwDDhw9nz549tLe389RTT0WNYciQIYwZM4YnnngC6Piw2b17d7/fw1lnncXLL7/M3r17gY7rD53fPHpz4okncvDgwcj6vn37mDp1KnfffTf5+fns37//CD8tIl5Kxm2ZrwKnmdkYMxsMXAkkvNg7KD+fgso7+daWzQyZWUbBXZV8a8tmjj8rajVpwIqKiiJ1+K7uuOMOgsEgEyZMYNy4cdxxxx0A3HjjjaxatYqJEydSU1MTGaVv2rSJiRMnUlJSwpo1a7j55psBWLRoERdffDFnn302BQUFvcbx29/+lhUrVjBx4kTGjRsXuUjcH/n5+TzyyCNcddVVTJgwgWnTpkUu+vbmkksu4amnnopctK2oqKC4uJjx48dz9tlnM3HixH6fX0SSyzpHokf1w2b/DXgAyAc+A3Y558rMbASw3Dl3UXi/i4B/AwLASufcwr6OXVpa6rZt29Zt2549ezj99NOPOl7xxkD/3a5YugU4+pKM1z8v4iUz2+6ci3qbfEwXbZ1zTwFfqzc45z4ELuqy/izwbCznEkkWJXrxK18/aSsiIl9RwhcRyRBp1Q9fMoNKKiKJ4fsR/hVLt0QuwomIZDLfJ3wREeng64SfiMkwAoEAkyZNYuLEid2aoB1JZyO0Dz/8kO9973t97l9RUcG4ceOoqKiIOV4RkU6+reH3NhkGEFNvlZycnEgrg6qqKhYsWMBf//rXfv3siBEjWLt2bZ/7LVu2jE8//ZRAINDnvpmouXorHy1cyKiVKzi8rzayPCg/3+vQRFKab0f4yZgM44svviAvL++rc957L2eccQYTJkygMkpXzrq6OsaPHw9AKBSioqIisv/SpUuBjh49X375JVOmTGHNmjUp1X44FSRqngORTODbEX6iJsM4dOgQkyZNoqWlhQMHDvD8888DHV0y3333XV555RWcc8yaNYsXX3yx1yS9YsUKTjrpJF599VUOHz7M9OnTueCCC1i3bh0nnHBC5FtEcXExVVVVFBYW8tlnn8UUux9EnecAOLhhAwWVd3ocnUhq8+0IP1GTYXSWdGpqatiwYQNXX301zjk2btzIxo0bKSkpYfLkydTU1PDuu+/2epyNGzfy6KOPMmnSJKZOnUpTU1PU/TvbD//qV7+KNFzLZImc50DE73w7wq8oG8uCJ1/vVtaJ92QY06ZN45NPPqGxsRHnHAsWLOD666/v188653jggQcoKztyb/5Uaz/stUTOcyDid74d4c8uKeSeOcUMDnS8xcLcHO6ZUxzXyTBqamoIhUIMHTqUsrIyVq5cGZkisKGhgY8//rjXny0rK2PJkiUEwz3733nnHZqbm7+2n9oPd9dzngOysiLzHIjIkfl2hA+JmQyjs4YPHaP0VatWEQgEuOCCC9izZ09kxqgTTjiB3/zmN5xyyilRj3PttddSV1fH5MmTcc6Rn5/P008//bX9KioqePfdd3HOcd5552V8++FRK1fQ+NASDm7YwKmVlTRXV3Owqipu8xyI+FlM7ZETSe2R/UP/biLJc6T2yL4t6YiISHdK+CIiGSLtEn6qlqAkOv17iaSOtEr42dnZNDU1KYmkCeccTU1NZGdnex2KiJBmd+kUFRVRX19PY2Oj16FIP2VnZ1NUVOR1GCJCmiX8rKwsxowZ43UYIiJpKa1KOiIicvSU8EVEMoQSvohIhkjZJ23NrBF4P4ZDDAM+iVM46SLT3nOmvV/Qe84Usbznf3DORZ0NKGUTfqzMbFtvjxf7Vaa950x7v6D3nCkS9Z5V0hERyRBK+CIiGcLPCX+Z1wF4INPec6a9X9B7zhQJec++reGLiEh3fh7hi4hIF0r4IiIZwncJ38xmmtnbZrbXzG73Op5EM7ORZvaCmb1lZm+a2c1ex5QsZhYws51m9kevY0kGM8s1s7VmVmNme8wsPvN2pjAzuyX8e/2Gma02M9+1XjWzlWb2sZm90WXbyWb2nJm9G/47Lx7n8lXCN7MA8CBwIfAd4Coz+463USVcG/AvzrnvAGcBP86A99zpZmCP10Ek0S+BDc65bwMT8fl7N7NC4CdAqXNuPBAArvQ2qoR4BJjZY9vtwF+cc6cBfwmvx8xXCR84E9jrnKt1zrUCjwOXehxTQjnnDjjndoSXD9KRBAq9jSrxzKwI+CdgudexJIOZnQTMAFYAOOdanXOfeRtVUgwCcsxsEHAc8KHH8cSdc+5F4NMemy8FVoWXVwGz43EuvyX8QmB/l/V6MiD5dTKz0UAJsNXbSJLi34DbgHavA0mSMUAj8OtwGWu5mR3vdVCJ5JxrAO4DPgAOAJ875zZ6G1XSDHfOHQgv/ycwPB4H9V
vCz1hmdgLwB+CnzrkvvI4nkczsYuBj59x2r2NJokHAZGCJc64EaCZOX/NTVbhufSkdH3YjgOPN7PveRpV8ruPe+bjcP++3hN8AjOyyXhTe5mtmlkVHsv+tc+5Jr+NJgunALDOro6Nsd66Z/cbbkBKuHqh3znV+e1tLxweAn30XeM851+icCwJPAmd7HFOyfGRmBQDhvz+Ox0H9lvBfBU4zszFmNpiOCzzrPI4poczM6Kjr7nHO3e91PMngnFvgnCtyzo2m49/4eeecr0d+zrn/BPab2djwpvOAtzwMKRk+AM4ys+PCv+fn4fML1V2sA64JL18DPBOPg6bVFId9cc61mdlNQBUdV/RXOufe9DisRJsO/AB43cx2hbf9L+fcsx7GJInxP4DfhgcztcA/exxPQjnntprZWmAHHXej7cSHbRbMbDXwX4FhZlYPVAKLgN+b2Q/paBN/eVzOpdYKIiKZwW8lHRER6YUSvr+y0dYAAAAmSURBVIhIhlDCFxHJEEr4IiIZQglfRCRDKOGLiGQIJXwRkQzx/wHThIs/nQhxtAAAAABJRU5ErkJggg==\n" 425 | }, 426 | "metadata": { 427 | "needs_background": "light" 428 | } 429 | } 430 | ], 431 | "source": [ 432 | "plt.plot(mxs,mys,'X',color=\"C3\",label=\"Measurements\")\n", 433 | "plot_beliefs(fg,messages,xs, fmt=\"o\",label=\"Beliefs\")\n", 434 | "plt.legend();" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "d0ca733e-5bc6-41ce-862b-34b72bbbab7f", 441 | "metadata": { 442 | "id": "d0ca733e-5bc6-41ce-862b-34b72bbbab7f" 443 | }, 444 | "outputs": [], 445 | "source": [] 446 | } 447 | ], 448 | "metadata": { 449 | "kernelspec": { 450 | "display_name": "Python 3 (ipykernel)", 451 | "language": "python", 452 | "name": "python3" 453 | }, 454 | "language_info": { 455 | "codemirror_mode": { 456 | "name": "ipython", 457 | "version": 3 458 | }, 459 | "file_extension": ".py", 460 | "mimetype": "text/x-python", 461 | "name": "python", 462 | "nbconvert_exporter": "python", 463 | "pygments_lexer": "ipython3", 464 | "version": "3.10.4" 465 | }, 466 | "colab": { 467 | "provenance": [] 468 | } 469 | }, 470 | "nbformat": 4, 471 | "nbformat_minor": 5 472 | } -------------------------------------------------------------------------------- /deprecated-gauss-bp/gauss-bp-1d-line.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "2OtN-aTp8h0_" 17 | }, 18 | "source": [ 19 | "\n", 20 | "# Gaussian Belief Propagation applied to denoising a 1d line\n", 21 | "\n", 22 | "\n", 23 | "This example is based on the [PyTorch colab by Joseph Ortiz](https://colab.research.google.com/drive/1-nrE95X4UC9FBLR0-cTnsIP_XhA_PZKW?usp=sharing)\n", 24 | "\n", 25 | "\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "source": [ 31 | "!git clone https://github.com/probml/pgm-jax.git\n" 32 | ], 33 | "metadata": { 34 | "id": "WdXPgFAa8zGx", 35 | "outputId": "3cbc1a1b-5c19-42a9-8c76-efca830e4fb1", 36 | "colab": { 37 | "base_uri": "https://localhost:8080/" 38 | } 39 | }, 40 | "execution_count": 1, 41 | "outputs": [ 42 | { 43 | "output_type": "stream", 44 | "name": "stdout", 45 | "text": [ 46 | "Cloning into 'pgm-jax'...\n", 47 | "remote: Enumerating objects: 56, done.\u001b[K\n", 48 | "remote: Counting objects: 100% (56/56), done.\u001b[K\n", 49 | "remote: Compressing objects: 100% (48/48), done.\u001b[K\n", 50 | "remote: Total 56 (delta 16), reused 25 (delta 4), pack-reused 0\u001b[K\n", 51 | "Unpacking objects: 100% (56/56), done.\n" 52 | ] 53 | } 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "source": [ 59 | "!ls" 60 | ], 61 | "metadata": { 62 | "id": "kn2X7U6A9KH8", 63 | "outputId": "4080e168-906f-4e23-edbf-4ef4443d1ee8", 64 | "colab": { 65 | "base_uri": "https://localhost:8080/" 66 | } 67 | }, 68 | "execution_count": 2, 69 | 
"outputs": [ 70 | { 71 | "output_type": "stream", 72 | "name": "stdout", 73 | "text": [ 74 | "pgm-jax sample_data\n" 75 | ] 76 | } 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "source": [ 82 | "%cd pgm-jax/gaussian-loopy-bp/\n" 83 | ], 84 | "metadata": { 85 | "id": "UN3-h2I39Ld6", 86 | "outputId": "30467e3e-d727-4b11-f6cd-06749eee9375", 87 | "colab": { 88 | "base_uri": "https://localhost:8080/" 89 | } 90 | }, 91 | "execution_count": 5, 92 | "outputs": [ 93 | { 94 | "output_type": "stream", 95 | "name": "stdout", 96 | "text": [ 97 | "/content/pgm-jax/gaussian-loopy-bp\n" 98 | ] 99 | } 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "metadata": { 106 | "id": "FEnKJ9_w8h1D" 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "\n", 111 | "import numpy as np\n", 112 | "import os\n", 113 | "import matplotlib.pyplot as plt\n", 114 | "\n", 115 | "from typing import List, Callable, Optional, Union\n", 116 | "\n", 117 | "import jax.numpy as jnp\n", 118 | "from jax import random as jrand\n", 119 | "from jax import config\n", 120 | "\n", 121 | "from factor_graph import FactorGraph, GBPSettings\n", 122 | "from gaussian import MeasModel, SquaredLoss\n", 123 | "\n", 124 | "# Uncomment below expression to enforce CPU runtime\n", 125 | "# config.update('jax_platform_name', 'cpu')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": { 131 | "id": "f9uq2Jcr8h1E" 132 | }, 133 | "source": [ 134 | "\n", 135 | "\n", 136 | "## Create Custom factors" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": { 143 | "id": "TlB0zKwl8h1E" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "def height_meas_fn(x: jnp.array, gamma: jnp.array):\n", 148 | " gamma = gamma.squeeze()\n", 149 | " J = jnp.array([1-gamma, gamma])\n", 150 | " return J @ x.reshape(-1,1)\n", 151 | "\n", 152 | "def height_jac_fn(x: jnp.array, gamma: jnp.array):\n", 153 | " gamma = gamma.squeeze()\n", 154 | " return jnp.array([[1-gamma, gamma]])\n", 155 | "\n", 156 | "class HeightMeasurementModel(MeasModel):\n", 157 | " def __init__(self, loss: SquaredLoss, gamma: jnp.array) -> None:\n", 158 | " MeasModel.__init__(self, height_meas_fn, height_jac_fn, loss, gamma)\n", 159 | " self.linear = True\n", 160 | "\n", 161 | "def smooth_meas_fn(x: jnp.array):\n", 162 | " return jnp.array([x[1] - x[0]])\n", 163 | "\n", 164 | "def smooth_jac_fn(x: jnp.array):\n", 165 | " return jnp.array([[-1., 1.]])\n", 166 | "\n", 167 | "class SmoothingModel(MeasModel):\n", 168 | " def __init__(self, loss: SquaredLoss) -> None:\n", 169 | " MeasModel.__init__(self, smooth_meas_fn, smooth_jac_fn, loss)\n", 170 | " self.linear = True" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "id": "hf9y6iVv8h1F" 177 | }, 178 | "source": [ 179 | "## Set parameters" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 8, 185 | "metadata": { 186 | "id": "wRh5SP9M8h1F" 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "n_varnodes = 20\n", 191 | "x_range = 10\n", 192 | "n_measurements = 15\n", 193 | "\n", 194 | "gbp_settings = GBPSettings(\n", 195 | " damping = 0.1,\n", 196 | " beta = 0.01,\n", 197 | " num_undamped_iters = 1,\n", 198 | " min_linear_iters = 1,\n", 199 | " dropout = 0.0,\n", 200 | ")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": { 206 | "id": "s0G9JMjq8h1F" 207 | }, 208 | "source": [ 209 | "## Gaussian noise measurement model parameters:" 210 | ] 211 | }, 212 | { 213 | 
"cell_type": "code", 214 | "execution_count": 9, 215 | "metadata": { 216 | "id": "WliIR4TZ8h1G", 217 | "outputId": "7c34207a-c7f2-4e60-edbd-38f71fb7540c", 218 | "colab": { 219 | "base_uri": "https://localhost:8080/" 220 | } 221 | }, 222 | "outputs": [ 223 | { 224 | "output_type": "stream", 225 | "name": "stderr", 226 | "text": [ 227 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "prior_cov = jnp.array([10.])\n", 233 | "data_cov = jnp.array([0.05]) \n", 234 | "smooth_cov = jnp.array([0.1])\n", 235 | "data_std = jnp.sqrt(data_cov)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": { 241 | "id": "6N6wefi28h1G" 242 | }, 243 | "source": [ 244 | "## Create measurements " 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 10, 250 | "metadata": { 251 | "id": "PhPPPkhq8h1H" 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "key = jrand.PRNGKey(42)\n", 256 | "meas_x = jrand.randint(key, [n_measurements], 0, x_range)\n", 257 | "key, subkey = jrand.split(key)\n", 258 | "meas_y = jnp.sin(meas_x) + jrand.normal(key, [n_measurements])*data_std" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "id": "vbFn9hA98h1H" 265 | }, 266 | "source": [ 267 | "## Plot measurements" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 11, 273 | "metadata": { 274 | "id": "_xrj8IyG8h1H", 275 | "outputId": "14fa299f-4bb5-4961-8673-a7192d4f72fc", 276 | "colab": { 277 | "base_uri": "https://localhost:8080/", 278 | "height": 265 279 | } 280 | }, 281 | "outputs": [ 282 | { 283 | "output_type": "display_data", 284 | "data": { 285 | "text/plain": [ 286 | "
" 287 | ], 288 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUiUlEQVR4nO3df3BV5Z3H8c+XC8gWae1iJkXiCM4yLj8DIUUuDExqdERAcXVpZcYVipb+YnW327qyHUprO4adMuuuHaYMo64s66ALSM0uzIql3tEp0TEIWAWkwUUJpW1KXRUWDEm++8e5iQETft2Te27u837NZJ57fnCfLyfwuc99zrnnmrsLAFD8+iRdAAAgPwh8AAgEgQ8AgSDwASAQBD4ABKJv0gV05/LLL/dhw4YlXQYA9Co7duz4g7uXdLWtYAN/2LBhqq+vT7oMAOhVzOyd7rYxpQMAgSDwASAQBD4ABKJg5/C7curUKTU2NurkyZNJl4LzNGDAAJWVlalfv35JlwIEr1cFfmNjowYNGqRhw4bJzJIuB+fg7jp69KgaGxs1fPjwpMsBgterpnROnjypwYMHE/a9hJlp8ODBvCMDCkSvCnxJhH0vw+8LvUZdnVRTE7VFqldN6QBAj6irk6qrpeZmqX9/ads2KZ1OuqrY9boRftLMTHfeeWfHcktLi0pKSjR79uwEq8qfTCaj7du3J10GEK9MJgr71taozWSSrqhHEPgXaODAgXrjjTd04sQJSdLzzz+voUOHJlJLS0tL3vsk8FGUqqqkVEoyi9qqqqQr6hEE/kWYOXOmNm/eLElat26d5s2b17Ht+PHjWrhwoSZNmqQJEybo2WeflSQdPHhQ06ZNU0VFhSoqKjpC88iRI5o+fbrGjx+vMWPG6KWXXpIkXXrppR3PuWHDBi1YsECStGDBAn3ta1/Ttddeq/vvv18HDhzQjBkzNHHiRE2bNk379u3r2O/rX/+6Jk+erKuvvlqZTEYLFy7UyJEjO55LkrZu3ap0Oq2KigrNnTtXx44dkxTd2mLZsmWqqKjQ2LFjtW/fPh08eFCrVq3Sww8/rPHjx+ull17S+vXrNWbMGJWXl2v69Ok9c8CBfGg/31TM553cvSB/Jk6c6Gfas2fPJ9ad0/bt7g89FLUxGDhwoO/evdtvv/12P3HihJeXl/sLL7zgs2bNcnf3JUuW+Nq1a93d/b333vMRI0b4sWPH/Pjx437ixAl3d9+/f7+3//1WrFjhP/rRj9zdvaWlxT/44IOOftqtX7/e58+f7+7u8+fP91mzZnlLS4u7u1933XW+f/9+d3d/+eWX/Qtf+ELHfl/60pe8ra3Nf/azn/mgQYP89ddf99bWVq+oqPCdO3d6U1OTT5s2zY8dO+bu7suXL/cf/OAH7u5+1VVX+SOPPOLu7itXrvS7777b3d2XLVvmP/7xjztqGzNmjDc2Nnb8fbtyUb83IJ8eesg9lXKXovahh5Ku6KJJqvducrW4T9r20ImYcePG6eDBg1q3bp1mzpx52ratW7eqtrZWK1askBRdSvruu+/qiiuu0OLFi7Vr1y6lUint379fkvT5z39eCxcu1KlTp3Trrbdq/Pjx5+x/7ty5SqVSOnbsmLZv3665c+d2bPvoo486Ht98880yM40dO1alpaUaO3asJGn06NE6ePCgGhsbtWfPHk2dOlWS1NzcrHSn43PbbbdJkiZOnKhnnnmmy1qmTp2qBQsW6Itf/GLH/kCvU1UVZUR7VhTplE5xB35XJ2JiOvN+yy236Nvf/rYymYyOHj3asd7dtXHjRl1zzTWn7f/9739fpaWl2r17t9ra2jRgwABJ0vTp0/Xiiy9q8+bNWrBggb71rW/prrvuOu1yxjOvYx84cKAkqa2tTZdddpl27drVZY2XXHKJJKlPnz4dj9uXW1palEqldMMNN2jdunVn/fOpVKrb8wWrVq3SK6+8os2bN2vixInasWOHBg8e3OW+QMFKp6MBYSYThX0RXqEjFfscfvurdioV+6v2woULtWzZso5Rc7sbb7xRP/nJTxS9s5J27twpSXr//fc1ZMgQ9enTR2vXrlVra6sk6Z133lFpaam+8pWv6J577tFrr70mSSotLdXevXvV1tamTZs2dVnDpz/9aQ0fPlzr16+XFL3Y7N69+7z/DpMnT9Yvf/lLNTQ0SIrOP7S/8+jOoEGD9OGHH3YsHzhwQNdee60efPBBlZSU6NChQ+fdP1BQ0mlpyZKiDXup2AO//VX7hz+M/brasrIy3XvvvZ9Yv3TpUp06dUrjxo3T6NGjtXTpUknSN77xDa1Zs0bl5eXat29fxyg9k8movLxcEyZM0NNPP6377rtPkrR8+XLNnj1bU6ZM0ZAhQ7qt48knn9Rjjz2m8vJyjR49uuMk8fkoKSnRE088oXnz5mncuHFKp9MdJ327c/PNN2vTpk0dJ22/853vaOzYsRozZoymTJmi8vLy8+4fQH5Z+0i00FRWVvqZX4Cyd+9ejRw5MqGKcLH4vQH5Y2Y73L2yq23FPcIHAHQg8AEgEL0u8At1Cgpd4/cFFI5eFfgDBgzQ0aNHCZELceyYdORI1OaZZ++H334JKoBk9arr8MvKytTY2KimpqakS+kdPvpI+t3vJPfo4+KlpVKn6/Hzof0brwAkr1cFfr9+/fjmpAtRUyMtXRp98CyVii5PXbIk6aoAJKRXTengAvXgB88A9D69aoSPCxTIx8UBnB8Cv9il0wQ9AElM6QBAMAh8AAhELIFvZo+b2e/N7I1utpuZPWJmDWb2uplVxNEvAOD8xTXCf0LSjLNsv0nSiOzPIkk/jalfAMB5iiXw3f1FSX88yy5zJP1b9hu4XpZ0mZl1f89fAEDs8jWHP1RS52/GaMyuO42ZLTKzejOr59O0ABCvgjpp6+6r3b3S3StLSkqSLgcAikq+Av+wpCs7LZdl1wEA8iRfgV8r6a7s1TqTJb3v7kfy1DcAQDF90tbM1kmqknS5mTVKWiapnyS5+ypJWyTNlNQg6f8kfTmOfgtaXR23NABQUGIJfHefd47tLumbcfTVK9TVSdXVUnNzdNOymL9AHQAuRkGdtC0amUwU9q2tUZvJJF0RABD4PYLbEgMoQNwtsydwW2IABYjA7ynclhhAgWFKBwACQeADQCAIfAAIBIEPAIEg8AEgEAQ+gGTV1Uk1NVGLHsVlmQCSw21I8ooRPoDkcBuSvCLwewpvU4Fz4zYkecWUTk8opLep3KYZhYzbkOQVgd8TunqbmsQ/5EJ64QGQOAK/J7S/TW0P2qTephbKCw/QHQYleUXg94RCeZtaVSX17Su1tUUt86MoNAxK8orA7ymFcrdM99NboJAUyrvhQBD4xSyTiUZO7lHL6AmFplDeDQeCwC9mjJ7QGxTKu+EAEPjFjNETgE4I/GLH6AlAFp+0BYBAEPgAEAgCHwACQeADQCAIfAAIBIEPAIEg8AEgEAQ+AASCwAeAQMQS+GY2w8zeMrMGM3ugi+0LzKzJzHZl
f+6Jo18AwPnL+dYKZpaStFLSDZIaJb1qZrXuvueMXZ9298W59gcAuDhxjPAnSWpw97fdvVnSU5LmxPC8AIAYxRH4QyUd6rTcmF13ptvN7HUz22BmV3b1RGa2yMzqzay+qakphtIAAO3yddL2PyUNc/dxkp6XtKarndx9tbtXuntlSUlJnkoDgDDEEfiHJXUesZdl13Vw96Pu/lF28VFJE2PoFwBwAeII/FcljTCz4WbWX9Idkmo772BmQzot3iJpbwz9AkDxqauTamqiNmY5X6Xj7i1mtljSc5JSkh539zfN7EFJ9e5eK+leM7tFUoukP0pakGu/AFB06uqk6uqPv5Z027ZYv8Aolm+8cvctkracse57nR4vkbQkjr4AoGhlMlHYt7ZGbSYTa+DzSVsAKBRVVdHIPpWK2qqqWJ+e77QFgEKRTkfTOJlMFPYxfx81gQ8AhSSdjj3o2zGlAwCBIPABIBAEPgAEgsAHgEAQ+AAQCAIfAAJB4ANAIAh8AAgEgQ8AgSDwASAQBD4ABILAB4BAEPgAEAgCHwACQeADQCAIfAAIBIEPAIEg8AEgEAQ+AASCwAeAQBD4ABAIAh8AAlGcgV9XJ9XURC0AQJLUN+kCYldXJ1VXS83NUv/+0rZtUjqddFUAkLjiG+FnMlHYt7ZGbSaTdEUAUBCKL/CrqqKRfSoVtVVVSVcEAAWh+KZ00uloGieTicKe6RwAkFSMgS9FIU/QA8BpYpnSMbMZZvaWmTWY2QNdbL/EzJ7Obn/FzIbF0S8A4PzlHPhmlpK0UtJNkkZJmmdmo87Y7W5J77n7n0l6WNI/5tovAODCxDHCnySpwd3fdvdmSU9JmnPGPnMkrck+3iCp2swshr4BAOcpjsAfKulQp+XG7Lou93H3FknvSxp85hOZ2SIzqzez+qamphhKAwC0K6jLMt19tbtXuntlSUlJ0uUAQFGJI/APS7qy03JZdl2X+5hZX0mfkXQ0hr4BAOcpjsB/VdIIMxtuZv0l3SGp9ox9aiXNzz7+S0m/cHePoW8AwHnK+Tp8d28xs8WSnpOUkvS4u79pZg9Kqnf3WkmPSVprZg2S/qjoRQEAkEexfPDK3bdI2nLGuu91enxS0tw4+gIAXJyCOmkLAOg5BD4ABILAB4BAEPgAEAgCHwACQeADQCAIfAAIBIEPAIEg8AEgEAQ+AASCwAeAQBD4ABAIAh8AAkHgA0AgCHwACASBDwCBIPABIBAEPgAEgsAHgEAQ+AAQCAIfAAJB4ANAIAh8AAgEgQ8AgSDwASAQBD4ABILAB4BAEPgAEAgCHwACQeADQCAIfAAIRE6Bb2Z/ambPm9mvs+1nu9mv1cx2ZX9qc+kTQIxWr5ZuvDFqUfT65vjnH5C0zd2Xm9kD2eW/72K/E+4+Pse+AMRp9Wrpq1+NHm/dGrWLFiVXD3pcrlM6cyStyT5eI+nWHJ8PQL5s3Hj2ZRSdXAO/1N2PZB//VlJpN/sNMLN6M3vZzLp9UTCzRdn96puamnIsDcBZ3X772ZdRdM45pWNmP5f0uS42fbfzgru7mXk3T3OVux82s6sl/cLMfuXuB87cyd1XS1otSZWVld09F4A4tE/fbNwYhT3TOUXvnIHv7td3t83MfmdmQ9z9iJkNkfT7bp7jcLZ928wykiZI+kTgA8izRYsI+oDkOqVTK2l+9vF8Sc+euYOZfdbMLsk+vlzSVEl7cuwXAHCBcg385ZJuMLNfS7o+uywzqzSzR7P7jJRUb2a7Jb0gabm7E/gAkGc5XZbp7kclVXexvl7SPdnH2yWNzaUfAEDu+KQtAASCwAeAQBD4ABAIAh8AAkHgA0AgCHwACASBDyShrk6qqYlaIE9yvT0ygAtVVydVV0vNzVL//tK2bVI6nXRVCAAjfCDfMpko7FtbozaTSboiBILAB/Ktqioa2adSUVtVlXRFCARTOkC+pdPRNE4mE4U90znIEwIfSEI6TdAj75jSAYBAEPgAEAgCHwACQeADQCAIfAAIBIEPAIEg8AEgEAQ+AASCwAeAQBD4ABAIAh8AAkHgA0AgCHwACASBDwCBIPABIBAEPgAEgsAHgEAQ+AAQCAIfAAKRU+Cb2Vwze9PM2sys8iz7zTCzt8yswcweyKVPAMDFyXWE/4ak2yS92N0OZpaStFLSTZJGSZpnZqNy7BcAcIH65vKH3X2vJJnZ2XabJKnB3d/O7vuUpDmS9uTSNwDgwuRjDn+opEOdlhuz6z7BzBaZWb2Z1Tc1NeWhNAAIxzlH+Gb2c0mf62LTd9392TiLcffVklZLUmVlpcf53AAQunMGvrtfn2MfhyVd2Wm5LLsOAJBH+ZjSeVXSCDMbbmb9Jd0hqTYP/QIAOsn1ssy/MLNGSWlJm83suez6K8xsiyS5e4ukxZKek7RX0n+4+5u5lQ0AuFC5XqWzSdKmLtb/RtLMTstbJG3JpS8AQG74pC0ABILAB4BAEPgAEAgCHwACQeADQCAIfAAIBIEPAIEg8BGWujqppiZqgcDk9MEroFepq5Oqq6XmZql/f2nbNimdTroqIG8Y4SMcmUwU9q2tUZvJJF0RkFcEPsJRVRWN7FOpqK2qSroiIK+Y0kE40uloGieTicKe6RwEhsBHWNJpgh7BYkoHAAJB4ANAIAh8AAgEgQ8AgSDwASAQBD4ABMLcPekaumRmTZLeyeEpLpf0h5jK6e04FqfjeJyO4/GxYjgWV7l7SVcbCjbwc2Vm9e5emXQdhYBjcTqOx+k4Hh8r9mPBlA4ABILAB4BAFHPgr066gALCsTgdx+N0HI+PFfWxKNo5fADA6Yp5hA8A6ITAB4BAFF3gm9kMM3vLzBrM7IGk60mSmV1pZi+Y2R4ze9PM7ku6pqSZWcrMdprZfyVdS9LM7DIz22Bm+8xsr5kFfd9oM/vb7P+TN8xsnZkNSLqmuBVV4JtZStJKSTdJGiVpnpmNSraqRLVI+jt3HyVpsqRvBn48JOk+SXuTLqJA/Iuk/3b3P5dUroCPi5kNlXSvpEp3HyMpJemOZKuKX1EFvqRJkhrc/W13b5b0lKQ5CdeUGHc/4u6vZR9/qOg/9NBkq0qOmZVJmiXp0aRrSZqZfUbSdEmPSZK7N7v7/yZbVeL6SvoTM+sr6VOSfpNwPbErtsAfKulQp+VGBRxwnZnZMEkTJL2SbCWJ+mdJ90tqS7qQAjBcUpOkf81OcT1qZgOTLiop7n5Y0gpJ70o6Iul9d9+abFXxK7bARxfM7FJJGyX9jbt/kHQ9STCz2ZJ+7+47kq6lQPSVVCHpp+4+QdJxScGe8zKzzyqaDRgu6QpJA83szmSril+xBf5hSVd2Wi7LrguWmfVTFPZPuvszSdeToKmSbjGzg4qm+q4zs39PtqRENUpqdPf2d3wbFL0AhOp6Sf/j7k3ufkrSM5KmJFxT7Iot8F+VNMLMhptZf0UnXWoTrikxZmaK5mj3uvs/JV1Pktx9ibuXufswRf8ufuHuRTeCO1/u/ltJh8zsmuyqakl7Eiwpae9Kmmxmn8r+v6lWEZ7E7pt0AXFy9xYzWyzpOUVn2R939zcTLitJUyX
9laRfmdmu7Lp/cPctCdaEwvHXkp7MDo7elvTlhOtJjLu/YmYbJL2m6Oq2nSrC2yxwawUACESxTekAALpB4ANAIAh8AAgEgQ8AgSDwASAQBD4ABILAB4BA/D90T+BibHznBgAAAABJRU5ErkJggg==\n" 289 | }, 290 | "metadata": { 291 | "needs_background": "light" 292 | } 293 | } 294 | ], 295 | "source": [ 296 | "plt.scatter(meas_x, meas_y, color=\"red\", label=\"Measurements\", marker=\".\")\n", 297 | "plt.legend()\n", 298 | "plt.savefig('gbp-1d-data.pdf')\n", 299 | "plt.show()" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": { 305 | "id": "z_6gQvzu8h1I" 306 | }, 307 | "source": [ 308 | "## Create factor graph " 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 12, 314 | "metadata": { 315 | "id": "J-4xrK4h8h1I", 316 | "outputId": "dcf272db-a318-46b7-b17e-935a369fce8b", 317 | "colab": { 318 | "base_uri": "https://localhost:8080/" 319 | } 320 | }, 321 | "outputs": [ 322 | { 323 | "output_type": "stream", 324 | "name": "stdout", 325 | "text": [ 326 | "\n", 327 | "Factor Graph:\n", 328 | "# Variable nodes: 20\n", 329 | "# Factors: 34\n", 330 | "\n", 331 | "\n" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "fg = FactorGraph(gbp_settings)\n", 337 | "\n", 338 | "xs = jnp.linspace(0, x_range, n_varnodes).reshape(-1,1)\n", 339 | "\n", 340 | "for i in range(n_varnodes):\n", 341 | " fg.add_var_node(1, jnp.array([0.]), prior_cov)\n", 342 | "\n", 343 | "for i in range(n_varnodes-1):\n", 344 | " fg.add_factor(\n", 345 | " [i, i+1], \n", 346 | " jnp.array([0.]), \n", 347 | " SmoothingModel(SquaredLoss(1, smooth_cov))\n", 348 | " )\n", 349 | "\n", 350 | "for i in range(n_measurements):\n", 351 | " ix2 = np.argmax(xs > meas_x[i])\n", 352 | " ix1 = ix2 - 1\n", 353 | " gamma = (meas_x[i] - xs[ix1]) / (xs[ix2] - xs[ix1])\n", 354 | " fg.add_factor(\n", 355 | " [ix1, ix2], \n", 356 | " meas_y[i], \n", 357 | " HeightMeasurementModel(\n", 358 | " SquaredLoss(1, data_cov), \n", 359 | " gamma \n", 360 | " )\n", 361 | " )\n", 362 | "fg.print(brief=True)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": { 368 | "id": "c715iinJ8h1I" 369 | }, 370 | "source": [ 371 | "\n", 372 | "## Plot initial beliefs and measurements" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 13, 378 | "metadata": { 379 | "id": "eYPVoCkx8h1J", 380 | "outputId": "a4f0dd34-d62d-4511-db32-f96bea639d6e", 381 | "colab": { 382 | "base_uri": "https://localhost:8080/", 383 | "height": 265 384 | } 385 | }, 386 | "outputs": [ 387 | { 388 | "output_type": "display_data", 389 | "data": { 390 | "text/plain": [ 391 | "
" 392 | ], 393 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAYwElEQVR4nO3df3SV1b3n8feXY5QUVLwQEYhM0rmUtvyGQA2M3Eis4XoBsYO2dt0CQ1u0jlM6vY2rjEvBWn906dg747RFljLCKuVSUZCpvRJEI72aovwWAfmhiImoEYsIBRLCd/54knOBBEg4T3Kyk89rLdbJefLs/XwPPz7s7LOffczdERGRcHVIdwEiIpIaBbmISOAU5CIigVOQi4gETkEuIhK4C9Jx0W7dunlOTk46Li0iEqx169Z94u5Zpx9PS5Dn5OSwdu3adFxaRCRYZvZeQ8c1tSIiEjgFuYhI4BTkIiKBS8scuYi0nOrqasrLyzl69Gi6S5FG6tixI9nZ2WRkZDTqfAW5SBtXXl7OxRdfTE5ODmaW7nLkHNyd/fv3U15eTm5ubqPaaGpFpI07evQoXbt2VYgHwszo2rVrk36CUpCLtAMK8bA09c9LQS4iEriggvybj5fxzcfL1F7t1f487K48xO7KQ2lpb2bcMOlbyfbHjx8nKyuLcePGtcj142ifSh+lpaW89tprsdTQkKCCXETC1KlTJ3Zu38rRI0cAWLlyJb169UpLLcePH2/xa9YFeXNJOcjNrKOZvW5mm8zsLTO7N47CRKRt+btrr+PlF1cAsGjRIm655Zbk9w4fPsy0adMYMWIEQ4YM4bnnngNgz549XH311QwdOpQJhf+J9a//GYB9+/YxevRoBg8eTP/+/fnTn/4EQOfOnZN9LlmyhKlTpwIwdepU7v7JDP7z2Gu488472b17N2PHjmXYsGFcffXVbN++PXneD37wA6666iq++MUvUlpayrRp0/jKV76S7AugpKSE/Px8hg4dyk033cShQ9EoOycnh1mzZjF06FAGDBjA9u3b2bNnD3PmzOGXv/wl468ZyRt/fpWnn36a/v37M2jQIEaPHp3y720cyw+PAWPc/ZCZZQD/Zmb/6u5/jqFvEUmHsjIoLYWCAsjPj6XLcRMn8dj/fIjvfXsSmzdvZtq0ackAvv/++xkzZgzz5s3jwIEDjBgxgmuvvZbLL7+clStX0rFjR1at2ciPbp3GTf+wnt/97ncUFRVx1113UVNTw1//+tdzXv/DfRX8/vkX+dIVl1JYWMicOXPo06cPa9as4fbbb+ell14C4C9/+QtlZWUsX76cCRMm8Oqrr/LEE08wfPhwtr65mSt69uTnP/85L774Ip06deIXv/gFjz76KPfccw8A3bp1Y/369fz617/mkUce4YknnuC2226jc+fO3DjlNgAmjslnxYoV9OrViwMHDqT8e5tykHv0oZ91kz4Ztb/0QaAioSorg8JCqKqCCy+EVatiCfMv9+tPxft7WbRoEddff/0p3yspKWH58uU88sgjQLRkcu/evfTs2ZM77riDjRs3UuPGu+/sAmD48OFMmzaN6upqJk6cyODBg895/b8ffyOJRIJDhw7x2muvcdNNNyW/d+zYseTX48ePx8wYMGAA3bt3Z8CAAQD069ePivff48N9FWzdupVRo0YBUFVVRf5Jvz/f+MY3ABg2bBjPPvtsg7WMGjWKqVOncvPNNyfPT0UsNwSZWQJYB/wt8Ct3X9PAOdOB6QC9e/eO47Ii0hxKS6MQr6mJHktLYxuVFxZdz09+8hNKS0vZv39/8ri788wzz9C3b99Tzp89ezbdu3dn06ZN7PzoIP2u7AbA6NGjWb16Nc8//zxTp07lxz/+MZMnTz5l2d7p67AzO3UC4MSJE3Tp0oWNGzc2WONFF10EQIcOHZJf1z0/XlNDh0SCr3/96yxatOis7ROJxBnn4+fMmcOaNWt4/vnnGTZsGOvWraNr164NntsYsbzZ6e417j4YyAZGmFn/Bs6Z6+557p6XlVVvO10RaS0KCqKReCIRPRYUxNb1pG9/h1mzZiVHuXWKiop47LHHiH7Ahw0bNgDw2Wef0aNHDzp06MCypxdRU1MDwHvvvUf37t35/ve/z/e+9z3Wr18PQPfu3dm2bRsnTpxg6dKlDdZwySWXkJuby9NPPw1E/4ls2rSp0a9h8LDhvPrqq+zaFf10cPjwYXbs2HHWNhdffDGff/558vnu3bv52te+xs9+9jOysrJ4//33G339hsS6asXdDwAvA2Pj7FdEWlB+fjSdct99sU2r1OnRsxc//OEP6x2/++67qa6uZuDAgfTr14+7774bgNtvv5358+czaNAg3tm5gy98IRpVl5aWMmjQIIYMGcLixYuZMWMGAA899BDjxo1j5MiR9OjR44x1LFy4kCeffJJBgwbRr1+/5JurjdG1WxZPPfUUt9xyCwMHDiQ/Pz/5ZumZjB8/nqVLlybf7CwuLmbAgAH079+fkSNHMmjQoEZfvyEpT62YWRZQ7e4HzCwT+Drwi1T7FZE0ys+PNcAPHaq/frqgoICC2tF+ZmYmjz/+eL12ffr0YfPmzUC0hvvOe+4DYMqUKUyZMqXe+ZMmTWLSpEn1jj/11FOnXD83N5cXXnihwfPq5OTksGXLlgb7GDNmDG+88Ua99nv27El+nZeXR2lpKQBf+tKX2Lx5c7L9t8YX1WubijjmyHsA82vnyTsAv3f3P8TQr4iINEIcq1Y2A0NiqEVERM6D7uwUEQmcglxEJHAKchGRwCnIRaSeVHdalJalIBeRZpdIJBh/zUjGFUQbTTVmJ8C6DbA++OCDBpcUnq64uJh+/fpRXFyccr2h0Wd2isgplm2oYMPeA1TVnGDUQy9RXNSXiUNS23I2MzOT//dyFN671r/KzJkzeeWVVxrVtmfPnixZsuSc+3jPnTuXTz/9lEQikVKtIdKIXESSlm2oYOazb1JVcwKAigNHmPnsmyzbUBHbNQ4ePMhll12WfP7www8zfPhwBg4cyKxZs+qdv2fPHvr3j3b9qKmpobi4OHl+3U1EEyZM4NChQwwbNozFixfHvk1sa6cRuYgkPbzibY5U15xy7Eh1DQ+veJsF3x1x3v0eOXKE8deM5Nixo3zy8UfJLWNLSkrYuXMnr7/+Ou7OhAkTWL169RnD9+mF87n00kt54403OHbsGKNGjeK6665j+fLldO7cObkR1oABA2LdJra1U5CLSNIHB4406XhjnTy18vGuN5k8eTJbtmyhpKSEkpIShgyJ7ik8dOgQO3fuPGOQ/6n0Jd55eytLliwBok21du7cSW5u7innxb1NbGunIBeRpJ5dMqloILR7dsmM7Rr5+fl88sknVFZW4u7MnDmTW2+9tXGN3XnssccoKjr7XiUNbRMLF521Tcg0Ry4iScVFfcnMOPXNwsyMBMVFfc/Qoum2b99OTU0NXbt2paioiHnz5iU/Kq2iooKPP/74jG2vvqaQ3/zmN1RXVwOwY8cO
Dh8+XO+8uLeJbe00IheRpLrVKXcu2UxVzQl6dclMrlpJ5dPf6+bIATISxvz580kkElx33XVs27Yt+Qk7nTt35re//S2XX355g/3c/I9TObz/Q4YOHYq7k5WVxbJly+qdV1xczM6dO3F3CgsLo21wP6kf+G2FglxETjFxSC8Wvb4XgMW3xrOVbU1NTfI/gv+Y1fmU782YMSO5n/jJTv5A4y1btrC78hAdOnTggQce4IEHHjjj+cAZP2KtrVKQi0g9cQW4tAzNkYuIBE5BLtIO1H0WpoShqX9eCnKRNq5jx47s379fYR4Id2f//v107Nix0W00Ry7SxmVnZ1NeXk5FZTkGVH1yfuupKz8/Bu24fUvW0LFjR7Kzsxvdr4JcpI3LyMggNzeXn5ZE29IuvnXwefUz+/H23b611NAQTa2IiAROQS4iEjgFuYhI4FIOcjO70sxeNrOtZvaWmdW/RUtERJpNHG92Hgf+yd3Xm9nFwDozW+nuW2PoW0REziHlEbm773P39bVffw5sA1L7XCgREWm0WOfIzSwHGAKsibNfERE5s9iC3Mw6A88AP3L3gw18f7qZrTWztZWVlXFdVkSk3YslyM0sgyjEF7p7g/tHuvtcd89z97ysrKw4LisiIsSzasWAJ4Ft7v5o6iWJiEhTxDEiHwV8BxhjZhtrf10fQ78iItIIKS8/dPd/AyyGWkRE5Dzozk4RkcApyEVEAqcgFxEJnIJcRCRwCnIRkcApyEVEAqcgFxEJnIJcRCRwCnIRkcApyEVEAqcgFxEJnIJcRCRwCnIRkcApyEVEAqcgFxEJnIJcRNKjrAwefDB6lJSk/MESIiJNVlYGBQVQXQ0ZGVBaCvn56a4qWBqRi0jLW7AAqqrAPXpcsCDdFQVNQS4iEjgFuYi0vMmT4aKLwCx6nDw53RUFTXPkrV1ZWTR/WFCQ7kpE4pOfDy+//O9/tzU/nhIFeWtWVgaFhdEc4oUXwr1L4ZJL0l2VSDzy8xXgMdHUSmtWWhqFeE1N9HjgQLorEpFWKJYgN7N5ZvaxmW2Joz+pVVAQjcQTieixS5d0VyQirVBcI/KngLEx9SV18vNh1Sq4777oUdMqItKAWObI3X21meXE0Zec5uR5xM26A05E6muxOXIzm25ma81sbWVlZUtdNj5lZbB3Lxw8mO5KRERO0WJB7u5z3T3P3fOysrJa6rLxqFs9sudd2LRJe0OISKuiVSuNUbd6xAE/ET0XEWkltI68MepWjxhgHaDg6nRXJCKSFNfyw0VAGdDXzMrN7Ltx9Ntq1K0eycmFQYN0E4O0HdpKtk2Ia9XKLXH006rl58PmdBchEqPT7xxetUqDlEBpjryx0rVqRSMmaS6lpXDsWHTn8LFjeu8nYJojb4y6kcuN90Rz5GWXtMzIpawMrrnm30dM9z2nm4IkPl27wokT0dcnTkTPJUgakTdGulatLFgQjZTco8ePPmqZ60r7sH8/dKiNgA4doucSJAV5Y9RbtVKQ5oJEYlBQEO0FnkhEj/p7HSwFeWOka9XK5Mm1/4FY9Ni9e8tcV9qH0/fy0RudwdIceWOlY9VKfn40jVO3+b5WzUjctCd4m6Agb+20aZaInIOmVkREAqcgFxEJnIJcRCRwCnIRaV/a4N3SerNTRNqPgweh8MY2t7+MRuQi0n4cOBCFeE1N9NhG9pdRkItI+9GlSzQSTySixzZyN6umVkSk/bjkkmg6pe4muzYwrQIKchFpb9rg3ayaWhERCZyCXEQkcApyEZHAKchFRAIXVpAfPBh9bmYbuiNLRCRV4QR5WRls2gR73o0+P1NhLiICxBTkZjbWzN42s11m9tM4+jzdspKNbLiiD2uy+zNqyq9YVrKxae03VLBh7wHWvPspox56iWUbKtRe7dVe7YOq4UwSs2fPTqkDM0sALwBFwIPA/7733ntXz549u/JMbebOnTt7+vTpjb7Gsg0VzHzzCMcsAWZ83rETr/ilZP9NJ77c49yfKr9sQwUzn32TY8ejTwz//OhxXtlRSfZlmWqv9mqv9kHUAHDvvffumz179tzTj8cxIh8B7HL3d9y9CvgX4IYY+k16eMXbHKk59diRmuh4o9tXn9rBkeoatVd7tVf7YGo4mziCvBfw/knPy2uPncLMppvZWjNbW1l5xsF6gz44cKRJx9Ve7dVe7eNq31pqOJsWe7PT3ee6e56752VlZTWpbc8umU06rvZqr/ZqH1f71lLD2cQR5BXAlSc9z649Fpvior5kZiROOZaZkaC4qK/aq73aq32ztm8tNZxNHJtmvQH0MbNcogD/FvDtGPpNmjgkmqm5c8lmqmpO0KtLJsVFfZPH1V7t1V7tm6t9a6nhbMzdU+/E7Hrgn4EEMM/d7z/b+Xl5eb527domX+ebj0drxxffen47l6m92qu92p9v+9ZQg5mtc/e804/Hso2tu/8R+GMcfYmISNOEc2eniIg0SEEuIhI4BbmISOAU5CIigVOQi4gETkEuIhI4BbmISOAU5CKtQVkZPPigPjBFzkssNwSJSArKyqCgAKqrISMDSksh//zvPpT2RyNykXRbsACqqsA9elywIN0VSWAU5CIigVOQi6Tb5Mlw0UVgFj1OnpzuiiQwmiMXSbf8fHj55WhuvKBA8+PSZApykTplZekL0/x8BbicNwW5CEQhXlgYvdl44YWwapWCVYKhOXIRiEbiVVVQUxM9lpamuyKRRlOQi0A0nXLhhZBIRI8FBemuSKTRNLUiAtE0yqpVesNRgqQgF6mjNxwlUJpaEREJnIJcRCRwCnIRkcClFORmdpOZvWVmJ8wsL66iRESk8VIdkW8BvgGsjqEWERE5DymtWnH3bQBmFk81IiLSZC02R25m081srZmtraysbKnLioi0eecckZvZi8AVDXzrLnd/rrEXcve5wFyAvLw8b3SFIiJyVucMcne/tiUKERGR86PlhyIigUt1+eGNZlYO5APPm9mKeMqSdk2fKC/SJKmuWlkKLI2pFhHtCy5yHjS1Iq2L9gUXaTIFubQu2hdcpMm0ja20LtoXXKTJFOTS+mhfcJEm0dSKiEjgFOQiIoFTkIuIBE5BLiISOAW5iEjgFOQiIoFTkIuIBE5BLiISOAW5iEjgFOQiIoFTkIuIBE5BLiISOAW5iEjgFOQiIoFTkIuIBE5BLiISOAW5iEjgFOQiIoFLKcjN7GEz225mm81sqZl1iaswERFpnFRH5CuB/u4+ENgBzEy9JBERaYqUgtzdS9z9eO3TPwPZqZckIiJNEecc+TTgX2PsT0REGuGCc51gZi8CVzTwrbvc/bnac+4CjgMLz9LPdGA6QO/evc+rWBERqe+cQe7u157t+2Y2FRgHFLq7n6WfucBcgLy8vDOeJyIiTXPOID8bMxsL3An8nbv/NZ6SRESkKVKdI/8/wMXASjPbaGZzYqhJRESaIKURubv/bVyFiIjI+dGdnSIigVOQi4gETkEuIhI4BbmISOAU5CIigVOQi4gETkEuIhI4BbmISOAU5CIigVOQi4gETkEuIhI4BbmISOAU5CIigVOQi4gETkEuIhI
4BbmISOAU5CIigVOQi4gETkEuIhI4BbmISOAU5CIigVOQi4gETkEuIhK4lILczO4zs81mttHMSsysZ1yFiYhI46Q6In/Y3Qe6+2DgD8A9MdQkIiJNkFKQu/vBk552Ajy1ckREpKkuSLUDM7sfmAx8BlxzlvOmA9MBevfuneplRUSk1jlH5Gb2opltaeDXDQDufpe7XwksBO44Uz/uPtfd89w9LysrK75XICLSzp1zRO7u1zayr4XAH4FZKVUkIiJNkuqqlT4nPb0B2J5aOSIi0lSpzpE/ZGZ9gRPAe8BtqZckIiJNYe4tv9AkLy/P165d2+LXFREJmZmtc/e804/rzk4RkcApyEVEAqcgFxEJnIJcRCRwCnIRkcApyEVEAqcgFxEJnIJcRCRwCnIRkcCl5c5OM6skuqX/fHQDPomxnBDoNbcPes3tQyqv+T+4e73tY9MS5Kkws7UN3aLaluk1tw96ze1Dc7xmTa2IiAROQS4iErgQg3xuugtIA73m9kGvuX2I/TUHN0cuIiKnCnFELiIiJ1GQi4gELqggN7OxZva2me0ys5+mu57mZmZXmtnLZrbVzN4ysxnprqklmFnCzDaY2R/SXUtLMLMuZrbEzLab2TYzy093Tc3NzP577d/pLWa2yMw6prumuJnZPDP72My2nHTsb8xspZntrH28LI5rBRPkZpYAfgX8PfBV4BYz+2p6q2p2x4F/cvevAlcB/7UdvGaAGcC2dBfRgv4X8IK7fxkYRBt/7WbWC/ghkOfu/YEE8K30VtUsngLGnnbsp8Aqd+8DrKp9nrJgghwYAexy93fcvQr4F+CGNNfUrNx9n7uvr/36c6J/4L3SW1XzMrNs4B+AJ9JdS0sws0uB0cCTAO5e5e4H0ltVi7gAyDSzC4AvAB+kuZ7Yuftq4NPTDt8AzK/9ej4wMY5rhRTkvYD3T3peThsPtZOZWQ4wBFiT3kqa3T8DdwIn0l1IC8kFKoH/Wzud9ISZdUp3Uc3J3SuAR4C9wD7gM3cvSW9VLaa7u++r/fpDoHscnYYU5O2WmXUGngF+5O4H011PczGzccDH7r4u3bW0oAuAocBv3H0IcJiYftxurWrnhW8g+k+sJ9DJzP4xvVW1PI/Wfsey/jukIK8ArjzpeXbtsTbNzDKIQnyhuz+b7nqa2ShggpntIZo6G2Nmv01vSc2uHCh397qftJYQBXtbdi3wrrtXuns18CwwMs01tZSPzKwHQO3jx3F0GlKQvwH0MbNcM7uQ6M2R5WmuqVmZmRHNnW5z90fTXU9zc/eZ7p7t7jlEf74vuXubHqm5+4fA+2bWt/ZQIbA1jSW1hL3AVWb2hdq/44W08Td4T7IcmFL79RTguTg6vSCOTlqCux83szuAFUTvcs9z97fSXFZzGwV8B3jTzDbWHvsf7v7HNNYk8ftvwMLaAco7wH9Jcz3Nyt3XmNkSYD3RyqwNtMFb9c1sEVAAdDOzcmAW8BDwezP7LtFW3jfHci3doi8iEraQplZERKQBCnIRkcApyEVEAqcgFxEJnIJcRCRwCnIRkcApyEVEAvf/AYidtSiHqt5/AAAAAElFTkSuQmCC\n" 394 | }, 395 | "metadata": { 396 | "needs_background": "light" 397 | } 398 | } 399 | ], 400 | "source": [ 401 | "covs = jnp.sqrt(jnp.concatenate(fg.belief_covs()).flatten())\n", 402 | "plt.errorbar(xs, fg.belief_means(), yerr=covs, fmt='o', color=\"C0\", label='Beliefs')\n", 403 | "plt.scatter(meas_x, meas_y, color=\"red\", label=\"Measurements\", marker=\".\")\n", 404 | "plt.legend()\n", 405 | "plt.show()" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": { 411 | "id": "_Xo2srZE8h1J" 412 | }, 413 | "source": [ 414 | "## Compute posterior beliefs with GBP " 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 14, 420 | "metadata": { 421 | "id": "WaSAQ6gx8h1J", 422 | "outputId": "1310ff72-9600-4cec-b15b-0f5c370f277f", 423 | "colab": { 424 | "base_uri": "https://localhost:8080/" 425 | } 426 | }, 427 | "outputs": [ 428 | { 429 | "output_type": "stream", 430 | "name": "stdout", 431 | "text": [ 432 | "\n", 433 | "Initial Energy 106.23512\n", 434 | "Iter 1 --- Energy 45.52082 --- \n", 435 | "Iter 2 --- Energy 26.09324 --- \n", 436 | "Iter 3 --- Energy 27.78959 --- \n", 437 | "Iter 4 --- Energy 17.66786 --- \n", 438 | "Iter 5 --- Energy 16.32376 --- \n", 439 | "Iter 6 --- Energy 15.39065 --- \n", 440 | "Iter 7 --- Energy 14.81990 --- \n", 441 | "Iter 8 --- Energy 14.48639 --- \n", 442 | "Iter 9 --- Energy 14.44526 --- \n", 443 | "Iter 10 --- Energy 14.43492 --- \n" 444 | ] 445 | } 446 | ], 447 | "source": [ 448 | "fg.gbp_solve(n_iters=10)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": { 454 | "id": "QbAlh3AN8h1K" 455 | }, 456 | "source": [ 457 | "## Plot beliefs and measurements" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 15, 463 | "metadata": { 464 | "id": "OaQNn_Lk8h1K", 465 | "outputId": "e46bbfd3-62ee-4aab-de67-0d064cc1c8a5", 466 | "colab": { 467 | "base_uri": "https://localhost:8080/", 468 | "height": 265 469 | } 470 | }, 471 
| "outputs": [ 472 | { 473 | "output_type": "display_data", 474 | "data": { 475 | "text/plain": [ 476 | "
" 477 | ], 478 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAcqElEQVR4nO3df3RV5b3n8feXA5EUxFiIDCR6oeuyGIUQEhCMLFlZosYyiLlYiszqKJdr0XodnXaMY6ZLcdl2oKOrnTuuVqHIFVtLGRExM7hKFO6pXk3RKCgqKKiIiXiN0Gi1YCB55o9zkiYhPwhnn7P32fvzWitrn73Pzt7P2YRPnjz72c9jzjlERCT8BvldABERyQwFvohIRCjwRUQiQoEvIhIRCnwRkYgY7HcBejNq1Cg3btw4v4shIpJVXnnllU+dc/k9vRfYwB83bhz19fV+F0NEJKuY2Qe9vacmHRGRiFDgi4hEhAJfRCQiFPgiIhGhwBcRiQgFvohIRCjwRUQiQoEvIhIRCnwRCZRFq+pYtKrO72KEkgJfRCQiFPgiElx1dbBiRWIpKQvsWDoiEnF1dTBnDrS0QE4ObNsGZWV+lyqrqYYvIsEUj8NXX0Fra2IZj/tdoqynwBeRYBo5EtraEq/b2hLrkhIFvogE0+HDMCgZUYMGJdYlJQp8EQmm8nI44wyIxRLL8nK/S5T1dNNWRIKprCxxozYeT4S9btimTIEvIsFVVqag95CadEREIkKBLyKBsXlnIzsPNrPj/SPMWrmdzTsbB3wMDc3QOwW+iATC5p2NVG/aTUtroitmY/NRqjftPq3Ql54p8EUkEO7b+jZHj7d22Xb0eCv3bX3bpxKFjwJfRDx1uk0qHzUfHdB2GTgFvogEwti83AFtl4FT4IeFRhWULFdVMZHcIbEu23KHxKiqmOhTicJH/fDDQKMKSghUlhQAcMfG12lpbaMgL5eqiokd2yV1CvwwiMcTYd/amljG4wp8yUqVJQWsf+kgABtu1M+w19SkEwbl5YmafSyWWGrMERHpgWr4YaAxR0TkFCjww0JjjohIP9SkE0B6NFxE0sGTwDeztWb2iZm90cv7Zmb/28z2m9nrZlbqxXlFRMImnRU+r2r4jwBX9vH+N4EJya9lwIMenVdERE6RJ4HvnHsOONLHLlcDj7qEPwJ5ZjbGi3OLiMipyVQbfgHwYaf1huS2LsxsmZnVm1l9U1NThoomIhINgeql45xbDawGmD59uvO5OCLiAz1wlT6ZquE3Aud2Wi9MbhMR8YwXE6iEWaYCvwa4Ltlb5yLgM+fcoQydW0QiQBOo9M+TJh0zWw+UA6PMrAFYDgwBcM49BDwNzAX2A38B/t6L8wZOXR3E4ywaXAojRuhPU4mc9hp2S2sbs1Zuz+jgZ31NoKIB2BI8CXzn3OJ+3nfAP3pxrsDqPGLltf8Diov9LpFIRvVWwwYyErhBmUClvQ99ECt8etLWK51HrHRt0Nzsd4lEMsrvKQo1gUr/FPhe6TxipQ2CvDy/SySSUX7XsDWBSv8C1S0zq3UesXJwMYwY4XeJRDJqbF4ujT2Ee6Zq2JpApX+q4XuprAyqq30Pew2+Jn4IQg27sqSAkvPymDn+67xw56UK+25UwxcRT6iGHXwK/IDxs1ubSKo0RWGwqUknQPTgiIikkwLfY6k82u13tzYRCTcFvodSraH73a1NpFd1dbBiRWIZcmHu9KA2fA+l+mi3393aRHrU+SnynJxE92PNn5yVVMP3UKo19CB0axM5SeenyFtaEuuSlRT4PTjdP+nG5vZ8OU+1hl5ZUsCKBUXkxBLHKcjLZcWCIvXSEX91foo8JyexLllJge+Vujqqnvw5ucePddk80Bp6qg+OaDxw8Vz7U+Q/+pGac7Kc2vC9Eo9T+cZ2aGnhjrm30TJ4CAV5X8tcP/q6OjbX7qK65TxakrcRMj1aoYRYWZmCPgRUw/dK8s/eynf+lZKP9zFz5JDMPdqdvKl238dDOdr1nrG6dYo3ItRLJ8xUw/eKn4OnJW+qfTRiVI9vq1unDMRJ47mrl05oqIbvJb8GTysvh8GDGfv5pz2+rW6dkhL10gkNBX5YOEfVc4+mfNNY5CTqpRMaatIJg3gcWlupfCsOZtwx7/u0MEijFYo3OjdXlperOSeLKfDDoL0G1tJC5Xt/ZP3IHE2iLt7Kol46+rnvnQI/DLrXwF73u0ASZQrc4FLgh0XnGtjr6jonIidT4KeBajgi2SnsExCpl46ICNGYgEiBLyJCNCYgUuCLiBCNCYgU+CG04cYy3UeQ0xLl0VZ7eyJ9IE+qp3r90n39FfgBpMAWP0ShDbsvqU5AlOr1y8T1V+B3E+UajkRbFNqw+5LqBESpXr9MXH91y+ykt9+wEK3x5E8aLVEiIQpt2P2pLClg/UsHgYH//Kd6/TJx/T2p4ZvZlWb2tpntN7M7e3h/iZk1mdmu5NcNXpzXa1Gv4Ui0edGGHWWpXr9MXP+UA9/MYsAvgG8CFwCLzeyCHnbd4Jybmvxak+p500E1HImyVNuwoy7V65eJ6+9FDX8GsN85955zrgX4HXC1B8fNONVwJMpSbcOOulSvXyauvxdt+AXAh53WG4CZPex3jZnNBt4Bvu+c+7D7Dma2DFgGcN5553lQtIGpqphI9abdXZp1VMORKEmlDVtSv37pvv6Z6qXzf4FxzrkpwDPAup52cs6tds5Nd85Nz8/Pz1DR/ko1HBEJMy9q+I3AuZ3WC5PbOjjnDndaXQP8Tw/Omxaq4YhIWHlRw38ZmGBm480sB7gWqOm8g5mN6bQ6H9jjwXlFRNKnrg5WrEgsQyLlGr5z7oSZ3QJsBWLAWufcm2Z2L1DvnKsBbjWz+cAJ4AiwJNXzioikzeefw5y/S0zanpOTmGAoS2b86osnD145554Gnu627e5Or6uBai/OdSr04JCIpKS5ORH2ra2JZTweisDX0AoiIt3l5SVq9rFYYlle7neJPKGhFUREuhsxous80SGo3YMCX0SkZ53niQ4JBb6ISCdhvvenNnwRkYhQDV9EughzDTfqVMMXEYkIBb6ISEQo8EVEIkKBLyISEQp86UKTuIuEV+gCX4F1+nqbxF3XUCQcQtUts7fAAgY0iUlUu6X1NYm7JoERyX6hquH3FVjSP03iLhJuoQp8BVZqNIm7SLiFKvAVWKmpqphI7pBYl22axF0kPEIV+Aqs1GgSd5FwC9VN2/ZgumPj67S0tlGQl0tVxUQF1gBoEneR8ApV4IMCS0SkN6ELfBERPwW5ohmqNnwREemdAl9EJCIU+CIhs2hVHYtW1fldDAmgcAb+55/DwYNQpx96EZF24Qv8ujp47TU48D7MmaPQ94FqmCLBFL7Aj8fBtYEDWloS6yIiEsLALy8HGwQG5OQk1kVEJIT98MvK4F8/h+Zm2LYtsS4iIiEMfIARIxJfCnsRyTLpfHDLkyYdM7vSzN42s/1mdmcP759hZhuS7+8ws3FenFdERE5dyjV8M4sBvwAu
BxqAl82sxjn3Vqfd/gH4k3Pub83sWuCnwKJUz92bID/aLCLiFy9q+DOA/c6595xzLcDvgKu77XM1sC75eiMwx8zMg3OLiMgp8iLwC4APO603JLf1uI9z7gTwGTDSg3OLiMgpClS3TDNbZmb1Zlbf1NTkd3FERELFi8BvBM7ttF6Y3NbjPmY2GDgLONz9QM651c656c656fn5+R4UTURE2nkR+C8DE8xsvJnlANcCNd32qQGuT77+FrDdOec8OLeIiJyilHvpOOdOmNktwFYgBqx1zr1pZvcC9c65GuBh4Ndmth84QuKXggSUejmJhJMnD145554Gnu627e5Or48BC704l4j0bvPORnYebKaltY1ZK7drTmfpIlA3bUXk9G3e2Uj1pt20tLYB0Nh8lOpNu9m8s/stNYkqBb5ISNy39W2OHm/tsu3o8Vbu2/q2TyWSoFHgi4TER81HB7RdokeBLxISY/NyB7RdokeBLxISVRUTyR0S67Itd0iMqoqJPpVIgiacwyOLRFB7b5w7Nr5OS2sbBXm56qUjXSjwRUKksqSA9S8dBPQ8hZxMTToiIhGhwBcRiQgFvniq/UnPHe8fYdbK7XroRyRAFPjiGT3pKRJsCnzxjJ70FAk2Bb54Rk96igSbAl88oyc9RYJNgS+e0ZOeIsGmB6/EM3rSUyTYFPjiKT3pKRJcatIREYkIBb6ISEQo8EUCZtGqOhatqvO7GBJCCnwRkYjQTVuRkNHNcumNavgiIhGhwBfxmNrgJagU+CIiEaHAFxGJCAW+BI6aRETSQ4EvIhIRCnwRkYhQ4IuIRERKgW9mXzezZ8xsX3J5di/7tZrZruRXTSrnFBGR05NqDf9OYJtzbgKwLbnek6POuanJr/kpnlMkmlavhoqKxFLkNKQ6tMLVQHny9TogDvy3FI8pIt2tXg033ph4XVubWC5b5l95JCulWsMf7Zw7lHz9MTC6l/2Gmlm9mf3RzCp7O5iZLUvuV9/U1JRi0UROj5/dQjfvbGTnwWZ2vH+EWSu3s3lnY+KNJ57oumP3dZFT0G8N38yeBf5dD2/9sPOKc86ZmevlMH/jnGs0s28A281st3Pu3e47OedWA6sBpk+f3tuxREJp885GqjftpqW1DYDG5qNUb9oNQOU11/y1Zg9wzTV+FFGyXL+B75y7rLf3zOzfzGyMc+6QmY0BPunlGI3J5XtmFgdKgJMCXyTK7tv6NkePt3bZdvR4K/dtfZvKO5PNN088kQh7NefIaUi1SacGuD75+nrgqe47mNnZZnZG8vUoYBbwVornFQmdj5qP9r192TLYulVhL6ct1cBfCVxuZvuAy5LrmNl0M1uT3Od8oN7MXgP+BVjpnFPgi3QzNi93QNtFBiqlXjrOucPAnB621wM3JF+/CBSlch7JLpqA4/RUVUyketPuLs06uUNiVFVM9LFUEiaa8UokICpLCgC4Y+PrtLS2UZCXS1XFxI7tIqlS4IsESGVJAetfOgjoLyXxnsbSEfFQr/3oRQJAgS/ikd760Sv0JSgU+CIe6asfvUgQKPBFPNJvP3oRnynwRQaqrg5WrEgsO1E/egk6Bb6ETloHP6urgzlz4K67EstOoV9VMZHcIbEuu6sfvQSJAl9kIOJxaGmB1tbEMh7veKuypIAVC4rIiSX+WxXk5bJiQZH60UtgqB++BEp7t8aW1jZmrdwevAePysshJycR9jk5ifVO1I9egkyBL4HR5/DAQQn9sjLYti1Rsy8vT6yLZAkFvgRGn8MDByXwIRHyCnrJQmrDl8BQt0aR9FINXwJjbF4ujT2Eeya7NQbhHoLa/iVdVMOXwPC7W6OGRpCwU+BLYPjdrVFDI0jYqUlHAsXPbo26hyBhpxq+SJKGRpCwU+CLJPl9D0Ek3dSkI5KkKQYl7BT4Ip1oaAQJMzXpiIhERFbV8I8fP05DQwPHjh3zuyhyioYOHUphYSFDhgzxuygikZdVgd/Q0MCZZ57JuHHjMDO/iyP9cM5x+PBhGhoaGD9+vN/FEYm8rGrSOXbsGCNHjlTYZwkzY+TIkfqLTCQgsqqGDyjss0wU/710s1eCKqtq+CL9aR/8bMf7R5i1crvGwRHpRIE/QGbGd77znY71EydOkJ+fz7x583wsVebE43FefPFFv4vRIw1+JtI3Bf4ADRs2jDfeeIOjRxPjqzzzzDMUFPjzYM6JEycyfs4gB74GPxPpW/gDv64OVqxILD0yd+5ctmzZAsD69etZvHhxx3tffvklS5cuZcaMGZSUlPDUU08BcODAAS655BJKS0spLS3tCM1Dhw4xe/Zspk6dyuTJk3n++ecBGD58eMcxN27cyJIlSwBYsmQJN910EzNnzuSOO+7g3Xff5corr2TatGlccskl7N27t2O/733ve1x00UV84xvfIB6Ps3TpUs4///yOYwHU1tZSVlZGaWkpCxcu5IsvvgBg3LhxLF++nNLSUoqKiti7dy8HDhzgoYce4uc//zlTp07l+eef5/HHH2fy5MkUFxcze/Zsz67x6dDgZyJ9S+mmrZktBO4BzgdmOOfqe9nvSuCfgBiwxjm3MpXznrK6Opgz568TTm/b5snUdNdeey333nsv8+bN4/XXX2fp0qUdQf2Tn/yESy+9lLVr19Lc3MyMGTO47LLLOOecc3jmmWcYOnQo+/btY/HixdTX1/Pb3/6WiooKfvjDH9La2spf/vKXfs/f0NDAiy++SCwWY86cOTz00ENMmDCBHTt2cPPNN7N9+3YA/vSnP1FXV0dNTQ3z58/nhRdeYM2aNVx44YXs2rWLwsJCfvzjH/Pss88ybNgwfvrTn/Kzn/2Mu+++G4BRo0bx6quv8stf/pL777+fNWvWcNNNNzF8+HBuv/12AIqKiti6dSsFBQU0NzenfG1TEYQJVESCLNVeOm8AC4BVve1gZjHgF8DlQAPwspnVOOfeSvHc/YvHE2Hf2ppYxuOeBP6UKVM4cOAA69evZ+7cuV3eq62tpaamhvvvvx9IdCU9ePAgY8eO5ZZbbmHXrl3EYjHeeecdAC688EKWLl3K8ePHqaysZOrUqf2ef+HChcRiMb744gtefPFFFi5c2PHeV1991fH6qquuwswoKipi9OjRFBUVATBp0iQOHDhAQ0MDb731FrNmzQKgpaWFsk7XZ8GCBQBMmzaNTZs29ViWWbNmsWTJEr797W937O+XqoqJVG/a3aVZR4OfifxVSoHvnNsD/Xa9mwHsd869l9z3d8DVQPoDv7w8UbNvr+GXl3t26Pnz53P77bcTj8c5fPhwx3bnHE888QQTJ3YNmXvuuYfRo0fz2muv0dbWxtChQwGYPXs2zz33HFu2bGHJkiX84Ac/4LrrrutyTbv3Yx82bBgAbW1t5OXlsWvXrh7LeMYZZwAwaNCgjtft6ydOnCAWi3H55Zezfv36Pr8/Fov1er/goYceYseOHWzZsoVp06bxyiuvMHLkyB73TTcNfibSt0y04RcAH3Zab0huO4mZLTOzejOrb2pqSv3MZWWJZpwf/ciz5px2S5cuZfny5R215nYVFRU88MADOOcA2Ll
zJwCfffYZY8aMYdCgQfz617+mtTVRC/3ggw8YPXo03/3ud7nhhht49dVXARg9ejR79uyhra2NJ598sscyjBgxgvHjx/P4448DiV82r7322il/hosuuogXXniB/fv3A4n7D+1/efTmzDPP5M9//nPH+rvvvsvMmTO59957yc/P58MPP+zju9OvsqSAkvPymDn+67xw56UKe5FO+g18M3vWzN7o4etqrwvjnFvtnJvunJuen5/vzUHLyqC62tOwBygsLOTWW289aftdd93F8ePHmTJlCpMmTeKuu+4C4Oabb2bdunUUFxezd+/ejlp6PB6nuLiYkpISNmzYwG233QbAypUrmTdvHhdffDFjxozptRyPPfYYDz/8MMXFxUyaNKnjJvGpyM/P55FHHmHx4sVMmTKFsrKyjpu+vbnqqqt48sknO27aVlVVUVRUxOTJk7n44ospLi4+5fOLSGZZe000pYOYxYHbe7ppa2ZlwD3OuYrkejWAc25FX8ecPn26q6/verg9e/Zw/vnnp1xeyayB/rstWpXoUXW6T6z6/f0ifjKzV5xz03t6LxNDK7wMTDCz8UAjcC3wHzNwXslSClqR9Ei1W+bfAQ8A+cAWM9vlnKsws7Ekul/Odc6dMLNbgK0kumWudc69mXLJRdJEv3AkrFLtpfMkcNIdRefcR8DcTutPA0+nci4REUlN+J+0FRERIAKBv2hVXcdNOBGRKAt94IuISEKoAz8dY6PHYjGmTp1KcXFxl0HQ+tI+ENpHH33Et771rX73r6qqYtKkSVRVVaVcXhGRdlk349Wp6m1sdCClpy9zc3M7hjLYunUr1dXV/OEPfzil7x07diwbN27sd7/Vq1dz5MgRYrHYaZdTRKS70NbwMzE2+ueff87ZZ5/913Pedx8XXnghU6ZMYfny5Sftf+DAASZPngxAa2srVVVVHfuvWpUYf27+/Pl88cUXTJs2jQ0bNgRq+GERyW6hreGna2z0o0ePMnXqVI4dO8ahQ4c6hiKura1l3759vPTSSzjnmD9/Ps8991yvIf3www9z1lln8fLLL/PVV18xa9YsrrjiCmpqahg+fHjHXxFBGn5YRLJbaGv4vY2BnurY6O1NOnv37uX3v/891113Hc45amtrqa2tpaSkhNLSUvbu3cu+fft6PU5tbS2PPvooU6dOZebMmRw+fLjH/duHH/7Vr37VMeCakJaJbUTCLrQ1/EyMjV5WVsann35KU1MTzjmqq6u58cYbT+l7nXM88MADVFRU9LlfkIYfDow0TWwjEnahreFXlhSwYkERObHERyzIy2XFgiJPh8vdu3cvra2tjBw5koqKCtauXdsxRWBjYyOffPJJr99bUVHBgw8+yPHjxwF45513+PLLL0/aL2jDDwdCTxPbiEi/QlvDh0Tor3/pIODd+CjtbfiQqKWvW7eOWCzGFVdcwZ49ezpmjBo+fDi/+c1vOOecc3o8zg033MCBAwcoLS3FOUd+fj6bN28+ab+qqir27duHc445c+Zo+GFI68Q2ImHmyfDI6aDhkcMjLf9udXWJmn15uZpzRDrxe3hkEe+VlSnoRQYotG34IiLSVdYFflCboKRn+vcSCY6sCvyhQ4dy+PBhhUiWcM5x+PBhhg4d6ndRRIQsa8MvLCykoaGBpqYmv4sip2jo0KEUFhb6XQwRIcsCf8iQIYwfP97vYoiIZKWsatIREZHTp8AXEYkIBb6ISEQE9klbM2sCPkjhEKOATz0qTraI2meO2ucFfeaoSOUz/41zLr+nNwIb+Kkys/reHi8Oq6h95qh9XtBnjop0fWY16YiIRIQCX0QkIsIc+Kv9LoAPovaZo/Z5QZ85KtLymUPbhi8iIl2FuYYvIiKdKPBFRCIidIFvZlea2dtmtt/M7vS7POlmZuea2b+Y2Vtm9qaZ3eZ3mTLFzGJmttPM/p/fZckEM8szs41mttfM9phZ6GeAMbPvJ3+u3zCz9WYWuqFXzWytmX1iZm902vZ1M3vGzPYll2d7ca5QBb6ZxYBfAN8ELgAWm9kF/pYq7U4A/9U5dwFwEfCPEfjM7W4D9vhdiAz6J+D3zrl/DxQT8s9uZgXArcB059xkIAZc62+p0uIR4Mpu2+4EtjnnJgDbkuspC1XgAzOA/c6595xzLcDvgKt9LlNaOecOOedeTb7+M4kQKPC3VOlnZoXAfwDW+F2WTDCzs4DZwMMAzrkW51yzv6XKiMFArpkNBr4GfORzeTznnHsOONJt89XAuuTrdUClF+cKW+AXAB92Wm8gAuHXzszGASXADn9LkhH/C7gDaPO7IBkyHmgC/jnZjLXGzIb5Xah0cs41AvcDB4FDwGfOuVp/S5Uxo51zh5KvPwZGe3HQsAV+ZJnZcOAJ4L845z73uzzpZGbzgE+cc6/4XZYMGgyUAg8650qAL/Hoz/ygSrZbX03il91YYJiZfcffUmWeS/Sd96T/fNgCvxE4t9N6YXJbqJnZEBJh/5hzbpPf5cmAWcB8MztAotnuUjP7jb9FSrsGoME51/7X20YSvwDC7DLgfedck3PuOLAJuNjnMmXKv5nZGIDk8hMvDhq2wH8ZmGBm480sh8QNnhqfy5RWZmYk2nX3OOd+5nd5MsE5V+2cK3TOjSPxb7zdORfqmp9z7mPgQzObmNw0B3jLxyJlwkHgIjP7WvLnfA4hv1HdSQ1wffL19cBTXhw0q6Y47I9z7oSZ3QJsJXFHf61z7k2fi5Vus4D/BOw2s13Jbf/dOfe0j2WS9PjPwGPJysx7wN/7XJ60cs7tMLONwKskeqPtJITDLJjZeqAcGGVmDcByYCXwf8zsH0gME/9tT86loRVERKIhbE06IiLSCwW+iEhEKPBFRCJCgS8iEhEKfBGRiFDgi4hEhAJfRCQi/j/jqujAft1j9wAAAABJRU5ErkJggg==\n" 479 | }, 480 | "metadata": { 481 | "needs_background": "light" 482 | } 483 | } 484 | ], 485 | "source": [ 486 | "covs = jnp.sqrt(jnp.concatenate(fg.belief_covs()).flatten())\n", 487 | "plt.errorbar(xs, fg.belief_means(), yerr=covs, fmt='o', color=\"C0\", label='Beliefs')\n", 488 | "plt.scatter(meas_x, meas_y, color=\"red\", label=\"Measurements\", marker=\".\")\n", 489 | "plt.legend()\n", 490 | "plt.savefig('gbp-1d-posteriors.pdf')\n", 491 | "plt.show()" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": { 498 | "id": "Upg3PCA98h1K" 499 | }, 500 | "outputs": [], 501 | 
"source": [ 502 | "" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": null, 508 | "metadata": { 509 | "id": "uVqWpeIt8h1K" 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "" 514 | ] 515 | } 516 | ], 517 | "metadata": { 518 | "kernelspec": { 519 | "display_name": "py_36", 520 | "language": "python", 521 | "name": "py_36" 522 | }, 523 | "language_info": { 524 | "codemirror_mode": { 525 | "name": "ipython", 526 | "version": 3 527 | }, 528 | "file_extension": ".py", 529 | "mimetype": "text/x-python", 530 | "name": "python", 531 | "nbconvert_exporter": "python", 532 | "pygments_lexer": "ipython3", 533 | "version": "3.6.13" 534 | }, 535 | "colab": { 536 | "name": "gauss-bp-1d-line.ipynb", 537 | "provenance": [], 538 | "include_colab_link": true 539 | } 540 | }, 541 | "nbformat": 4, 542 | "nbformat_minor": 0 543 | } --------------------------------------------------------------------------------