├── .gitignore
├── .gitmodules
├── README.md
├── examples
│   ├── bernoulli_regression.py
│   ├── gif
│   │   ├── lls.png
│   │   ├── posterior_mean.jpg
│   │   ├── synthetic_gif.py
│   │   ├── test_model.gif
│   │   └── true_model.jpg
│   └── synthetic.py
├── pyglm
│   ├── __init__.py
│   ├── models.py
│   ├── networks.py
│   ├── plotting.py
│   ├── regression.py
│   └── utils
│       ├── __init__.py
│       ├── basis.py
│       ├── profiling.py
│       └── utils.py
├── setup.py
└── test
    └── test_generate.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Specific files
/*.pdf

# Gif files
examples/gif/test_model_*.jpg
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "pyglm/deps/graphistician"]
	path = pyglm/deps/graphistician
	url = https://github.com/slinderman/graphistician.git
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyGLM: Bayesian inference for nonlinear autoregressive models of count data

Neural circuits contain heterogeneous groups of neurons that differ in
type, location, connectivity, and basic response properties. However,
traditional methods for dimensionality reduction and clustering are
ill-suited to recovering the structure underlying the organization of
neural circuits. In particular, they do not take advantage of the rich
temporal dependencies in multi-neuron recordings and fail to account
for the noise in neural spike trains. This repository contains tools for
inferring latent structure from simultaneously recorded spike train
data using a hierarchical extension of a multi-neuron point process
model commonly known as the generalized linear model (GLM). In the
statistics and time series analysis communities, these correspond to
nonlinear vector autoregressive processes with count observations.
We combine the GLM with flexible graph-theoretic priors
governing the relationship between latent features and neural
connectivity patterns. Fully Bayesian inference via Pólya-gamma
augmentation of the resulting model allows us to classify neurons and
infer latent dimensions of circuit organization from correlated spike
trains.
We demonstrate the effectiveness of our method with
applications to synthetic data and multi-neuron recordings in primate
retina, revealing latent patterns of neural types and locations from
spike trains alone.

# References
Scott W. Linderman, Ryan P. Adams, and Jonathan W. Pillow. Bayesian latent
structure discovery from multi-neuron recordings. _Advances in Neural Information
Processing Systems_, 2016.

Scott W. Linderman. _Bayesian methods for discovering structure in neural spike
trains_. PhD thesis, Harvard University, 2016.

# Example
We provide a number of classes for building and fitting such models.
Let's walk through a simple example in which we construct a discrete-time
model with four neurons, as in `examples/synthetic.py`.
The neurons are connected via a network such that spikes influence
the rate of subsequent spikes on post-synaptic (downstream) neurons.
```python
from pyglm.utils.basis import cosine_basis
from pyglm.models import SparseBernoulliGLM

# Create a simple, sparse network of four neurons
T = 10000   # Number of time bins to generate
N = 4       # Number of neurons
B = 1       # Number of "basis functions"
L = 100     # Autoregressive window of influence

# Create a cosine basis to model the smooth influence of
# spikes of one neuron on the later spikes of others.
basis = cosine_basis(B=B, L=L) / L

# Generate some data from a model with self inhibition
true_model = SparseBernoulliGLM(N, basis=basis)

# Generate T time bins of events from the model.
# Y is the generated spike train.
# X is the filtered spike train for inference.
X, Y = true_model.generate(T=T, keep=True)

# Plot the model parameters and data
true_model.plot()
```

You should see something like this:

![True Model](examples/gif/true_model.jpg)

Now create a test model and try to infer the network given the spike train.
```python
# Create the test model and add the spike train
test_model = SparseBernoulliGLM(N, basis=basis)
test_model.add_data(Y)

# Initialize the plot
_, _, handles = test_model.plot()

# Run a Gibbs sampler
N_samples = 100
lps = []
for itr in range(N_samples):
    test_model.resample_model()
    lps.append(test_model.log_likelihood())
    test_model.plot(handles=handles)
```

With interactive plotting enabled, you should see something like:
![Test Model](examples/gif/test_model.gif)

Finally, we can plot the log likelihood over iterations to assess the
convergence of the sampling algorithm, at least in a heuristic way.

```python
import matplotlib.pyplot as plt

plt.plot(lps)
plt.xlabel("Iteration")
plt.ylabel("Log Likelihood")
```

![Log Likelihood](examples/gif/lls.png)

It looks like the sampler has converged. Now let's look at the posterior mean
of the network and firing rates.

![Posterior Mean](examples/gif/posterior_mean.jpg)

# Installation
PyGLM requires [pypolyagamma](https://github.com/slinderman/pypolyagamma)
for its Bayesian inference algorithms. This dependency will automatically
be installed if you do not already have it, but by default, `pip` will not
install the parallel version. If you want to use parallel resampling, look
at the [pypolyagamma](https://github.com/slinderman/pypolyagamma) homepage
for instructions on installing from source with OpenMP.
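A quick way to check which build you ended up with is to ask pypolyagamma how
many OpenMP threads it can see, just as `pyglm/regression.py` does when it
allocates its samplers. This is only a minimal sketch: a value of 1 usually
means the serial build (or a single-core machine).

```python
import pypolyagamma as ppg

# pyglm creates one PyPolyaGamma sampler per available OpenMP thread,
# so a value of 1 here means resampling will run serially.
print(ppg.get_omp_num_threads())
```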
111 | 112 | To install `pyglm` from source, first clone the repo 113 | 114 | git clone git@github.com:slinderman/pyglm.git 115 | cd pyglm 116 | 117 | Then install in developer mode: 118 | 119 | pip install -e . 120 | 121 | Or use the standard install: 122 | 123 | python setup.py install 124 | 125 | -------------------------------------------------------------------------------- /examples/bernoulli_regression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Bernoulli regression models. 3 | """ 4 | import numpy as np 5 | # np.random.seed(1) 6 | 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | sns.set_style("white") 10 | 11 | from pybasicbayes.util.text import progprint_xrange 12 | from pyglm.regression import SparseBernoulliRegression 13 | 14 | N = 2 15 | B = 1 16 | T = 1000 17 | 18 | # Make a regression model and simulate data 19 | true_reg = SparseBernoulliRegression(N, B) 20 | X = np.random.randn(T, N*B) 21 | y = true_reg.rvs(X=X) 22 | 23 | # Make a test regression and fit it 24 | test_reg = SparseBernoulliRegression(N, B) 25 | test_reg.a = np.bitwise_not(true_reg.a) 26 | 27 | def _collect(r): 28 | return r.a.copy(), r.W.copy(), r.log_likelihood((X, y)).sum() 29 | 30 | def _update(r): 31 | r.resample([(X,y)]) 32 | return _collect(r) 33 | 34 | smpls = [_collect(test_reg)] 35 | for _ in progprint_xrange(100): 36 | smpls.append(_update(test_reg)) 37 | 38 | smpls = zip(*smpls) 39 | As, Ws, lps = tuple(map(np.array, smpls)) 40 | 41 | # Plot the regression results 42 | plt.figure() 43 | lim = (-3, 3) 44 | npts = 50 45 | x1, x2 = np.meshgrid(np.linspace(*lim, npts), np.linspace(*lim, npts)) 46 | 47 | plt.subplot(121) 48 | mu = true_reg.mean(np.column_stack((x1.ravel(), x2.ravel()))) 49 | plt.imshow(mu.reshape((npts, npts)), 50 | cmap="Greys", vmin=-0, vmax=1, 51 | alpha=0.8, 52 | extent=lim + tuple(reversed(lim))) 53 | plt.scatter(X[:,0], X[:,1], c=y, vmin=0, vmax=1) 54 | plt.xlim(lim) 55 | plt.ylim(lim) 56 | plt.colorbar() 57 | 58 | plt.subplot(122) 59 | mu = test_reg.mean(np.column_stack((x1.ravel(), x2.ravel()))) 60 | plt.imshow(mu.reshape((npts, npts)), 61 | cmap="Greys", vmin=0, vmax=1, 62 | alpha=0.8, 63 | extent=lim + tuple(reversed(lim))) 64 | plt.scatter(X[:,0], X[:,1], c=y, vmin=0, vmax=1) 65 | plt.xlim(lim) 66 | plt.ylim(lim) 67 | plt.colorbar() 68 | 69 | print("True A: {}".format(true_reg.a)) 70 | print("Mean A: {}".format(As.mean(0))) 71 | 72 | plt.show() 73 | 74 | -------------------------------------------------------------------------------- /examples/gif/lls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slinderman/pyglm/cb93ce901e6bb127b6baa408b8e5b9995c932f17/examples/gif/lls.png -------------------------------------------------------------------------------- /examples/gif/posterior_mean.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slinderman/pyglm/cb93ce901e6bb127b6baa408b8e5b9995c932f17/examples/gif/posterior_mean.jpg -------------------------------------------------------------------------------- /examples/gif/synthetic_gif.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | np.random.seed(0) 4 | 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | sns.set_style("white") 8 | sns.set_context("paper") 9 | plt.ion() 10 | 11 | from pybasicbayes.util.text import 
progprint_xrange 12 | 13 | from pyglm.utils.basis import cosine_basis 14 | from pyglm.plotting import plot_glm 15 | from pyglm.models import SparseBernoulliGLM 16 | 17 | T = 10000 # Number of time bins to generate 18 | N = 4 # Number of neurons 19 | B = 1 # Number of "basis functions" 20 | L = 100 # Autoregressive window of influence 21 | 22 | # Create a cosine basis to model smooth influence of 23 | # spikes on one neuron on the later spikes of others. 24 | basis = cosine_basis(B=B, L=L) / L 25 | 26 | # Generate some data from a model with self inhibition 27 | true_model = \ 28 | SparseBernoulliGLM(N, basis=basis, 29 | regression_kwargs=dict(S_w=10.0, mu_b=-2.)) 30 | for n in range(N): 31 | true_model.regressions[n].a[n] = True 32 | true_model.regressions[n].W[n,:] = -2.0 33 | _, Y = true_model.generate(T=T, keep=True) 34 | 35 | # Plot the true model 36 | fig, axs, handles = true_model.plot() 37 | fig.savefig("examples/gif/true_model.jpg") 38 | plt.pause(0.1) 39 | 40 | # Create a test model for fitting 41 | test_model = \ 42 | SparseBernoulliGLM(N, basis=basis, 43 | regression_kwargs=dict(S_w=10.0, mu_b=-2.)) 44 | test_model.add_data(Y) 45 | 46 | # Plot the test model 47 | fig, axs, handles = test_model.plot(title="Sample 0") 48 | plt.pause(0.1) 49 | fig.savefig("examples/gif/test_model_{:03d}.jpg".format(0)) 50 | 51 | # Fit with Gibbs sampling 52 | def _collect(m): 53 | return m.log_likelihood(), m.weights, m.adjacency, m.biases, m.means[0] 54 | 55 | def _update(m, itr): 56 | m.resample_model() 57 | test_model.plot(handles=handles, 58 | pltslice=slice(0, 500), 59 | title="Sample {}".format(itr+1)) 60 | fig.savefig("examples/gif/test_model_{:03d}.jpg".format(itr+1)) 61 | 62 | return _collect(m) 63 | 64 | N_samples = 100 65 | samples = [] 66 | for itr in progprint_xrange(N_samples): 67 | samples.append(_update(test_model, itr)) 68 | 69 | # Create the gif 70 | import subprocess 71 | cmd = "convert -delay 20 -loop 0 examples/gif/test_model_*.jpg examples/gif/test_model.gif" 72 | subprocess.run(cmd, shell=True) 73 | 74 | # Unpack the samples 75 | samples = zip(*samples) 76 | lps, W_smpls, A_smpls, b_smpls, fr_smpls = tuple(map(np.array, samples)) 77 | 78 | # Plot the log likelihood per iteration 79 | fig = plt.figure(figsize=(4,4)) 80 | plt.plot(lps) 81 | plt.xlabel("Iteration") 82 | plt.ylabel("Log Likelihood") 83 | plt.tight_layout() 84 | fig.savefig("examples/gif/lls.jpg") 85 | 86 | # Plot the posterior mean and variance 87 | W_mean = W_smpls[N_samples//2:].mean(0) 88 | A_mean = A_smpls[N_samples//2:].mean(0) 89 | fr_mean = fr_smpls[N_samples//2:].mean(0) 90 | fr_std = fr_smpls[N_samples//2:].std(0) 91 | 92 | fig, _, _ = plot_glm(Y, W_mean, A_mean, fr_mean, 93 | std_firingrates=3*fr_std, title="Posterior Mean") 94 | 95 | fig.savefig("examples/gif/posterior_mean.jpg") -------------------------------------------------------------------------------- /examples/gif/test_model.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slinderman/pyglm/cb93ce901e6bb127b6baa408b8e5b9995c932f17/examples/gif/test_model.gif -------------------------------------------------------------------------------- /examples/gif/true_model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slinderman/pyglm/cb93ce901e6bb127b6baa408b8e5b9995c932f17/examples/gif/true_model.jpg -------------------------------------------------------------------------------- /examples/synthetic.py: 
-------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | np.random.seed(0) 4 | 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | sns.set_style("white") 8 | sns.set_context("paper") 9 | plt.ion() 10 | 11 | from pybasicbayes.util.text import progprint_xrange 12 | 13 | from pyglm.utils.basis import cosine_basis 14 | from pyglm.plotting import plot_glm 15 | from pyglm.models import SparseBernoulliGLM 16 | 17 | T = 10000 # Number of time bins to generate 18 | N = 4 # Number of neurons 19 | B = 1 # Number of "basis functions" 20 | L = 100 # Autoregressive window of influence 21 | 22 | # Create a cosine basis to model smooth influence of 23 | # spikes on one neuron on the later spikes of others. 24 | basis = cosine_basis(B=B, L=L) / L 25 | 26 | # Generate some data from a model with self inhibition 27 | true_model = \ 28 | SparseBernoulliGLM(N, basis=basis, 29 | regression_kwargs=dict(S_w=10.0, mu_b=-2.)) 30 | for n in range(N): 31 | true_model.regressions[n].a[n] = True 32 | true_model.regressions[n].W[n,:] = -2.0 33 | _, Y = true_model.generate(T=T, keep=True) 34 | 35 | # Plot the true model 36 | fig, axs, handles = true_model.plot() 37 | plt.pause(0.1) 38 | 39 | # Create a test model for fitting 40 | test_model = \ 41 | SparseBernoulliGLM(N, basis=basis, 42 | regression_kwargs=dict(S_w=10.0, mu_b=-2.)) 43 | 44 | test_model.add_data(Y) 45 | 46 | # Plot the test model 47 | fig, axs, handles = test_model.plot(title="Sample 0") 48 | plt.pause(0.1) 49 | 50 | # Fit with Gibbs sampling 51 | def _collect(m): 52 | return m.log_likelihood(), m.weights, m.adjacency, m.biases, m.means[0] 53 | 54 | def _update(m, itr): 55 | m.resample_model() 56 | test_model.plot(handles=handles, 57 | pltslice=slice(0, 500), 58 | title="Sample {}".format(itr+1)) 59 | return _collect(m) 60 | 61 | N_samples = 100 62 | samples = [] 63 | for itr in progprint_xrange(N_samples): 64 | samples.append(_update(test_model, itr)) 65 | 66 | # Unpack the samples 67 | samples = zip(*samples) 68 | lps, W_smpls, A_smpls, b_smpls, fr_smpls = tuple(map(np.array, samples)) 69 | 70 | # Plot the log likelihood per iteration 71 | fig = plt.figure(figsize=(4,4)) 72 | plt.plot(lps) 73 | plt.xlabel("Iteration") 74 | plt.ylabel("Log Likelihood") 75 | plt.tight_layout() 76 | 77 | # Plot the posterior mean and variance 78 | W_mean = W_smpls[N_samples//2:].mean(0) 79 | A_mean = A_smpls[N_samples//2:].mean(0) 80 | fr_mean = fr_smpls[N_samples//2:].mean(0) 81 | fr_std = fr_smpls[N_samples//2:].std(0) 82 | 83 | fig, _, _ = plot_glm(Y, W_mean, A_mean, fr_mean, 84 | std_firingrates=3*fr_std, title="Posterior Mean") 85 | 86 | plt.show() -------------------------------------------------------------------------------- /pyglm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slinderman/pyglm/cb93ce901e6bb127b6baa408b8e5b9995c932f17/pyglm/__init__.py -------------------------------------------------------------------------------- /pyglm/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pybasicbayes.abstractions import ModelGibbsSampling 3 | 4 | import pyglm.networks 5 | import pyglm.regression 6 | from pyglm.utils.basis import convolve_with_basis 7 | 8 | class NonlinearAutoregressiveModel(ModelGibbsSampling): 9 | """ 10 | The "generalized linear model" in neuroscience is really 11 | a vector autoregressive model. 
As the name suggests, 12 | the key component in these models is a regression from 13 | inputs, x, to outputs, y. 14 | 15 | When the outputs are discrete random variables, like spike 16 | counts, we typically take the regression to be a generalized 17 | linear model: 18 | 19 | y ~ p(mu(x), theta) 20 | mu(x) = f(w \dot x) 21 | 22 | where 'p' is a discrete distribution, like the Poisson, 23 | and 'f' is a "link" function that maps a linear function of 24 | x to the parameters of 'p'. Hence the name "GLM" in 25 | computational neuroscience. 26 | """ 27 | 28 | def __init__(self, N, regressions, basis=None, B=10): 29 | """ 30 | :param N: Observation dimension 31 | :param regressions: Regression objects, one per observation dim. 32 | :param basis: Basis onto which the preceding activity is projected 33 | In the "identity" case, this is just a lag matrix 34 | :param B: Basis dimensionality. 35 | In the "identity" case, this is the number of lags. 36 | """ 37 | self.N = N 38 | 39 | assert len(regressions) == N 40 | self.regressions = regressions 41 | 42 | # Initialize the basis 43 | if basis is None: 44 | basis = np.eye(B) 45 | else: 46 | assert basis.ndim == 2 47 | self.basis = basis 48 | self.B = self.basis.shape[1] 49 | 50 | # Initialize the data list to empty 51 | self.data_list = [] 52 | 53 | # Expose the autoregressive weights and adjacency matrix 54 | @property 55 | def weights(self): 56 | return np.array([r.W for r in self.regressions]) 57 | 58 | @property 59 | def adjacency(self): 60 | return np.array([r.a for r in self.regressions]) 61 | 62 | @property 63 | def biases(self): 64 | return np.array([r.b for r in self.regressions]).ravel() 65 | 66 | def add_data(self, data, X=None): 67 | N, B = self.N, self.B 68 | assert isinstance(data, np.ndarray) \ 69 | and data.ndim == 2 \ 70 | and data.shape[1] == self.N 71 | T = data.shape[0] 72 | 73 | # Convolve the data with the basis to get regressors 74 | if X is None: 75 | X = convolve_with_basis(data, self.basis) 76 | else: 77 | assert X.shape == (T, N, B) 78 | 79 | # Add the covariates and observations 80 | self.data_list.append((X, data)) 81 | 82 | def log_likelihood(self, datas=None): 83 | if datas is None: 84 | datas = self.data_list 85 | 86 | ll = 0 87 | for data in datas: 88 | if isinstance(data, tuple): 89 | X, Y = data 90 | else: 91 | X, Y = convolve_with_basis(data, self.basis), data 92 | 93 | for n, reg in enumerate(self.regressions): 94 | ll += reg.log_likelihood((X, Y[:,n])).sum() 95 | 96 | return ll 97 | 98 | def generate(self, keep=True, T=100, verbose=False, intvl=10): 99 | """ 100 | Generate data from the model. 101 | 102 | :param keep: Add the data to the model's datalist 103 | :param T: Number of time bins to simulate 104 | :param verbose: Whether or not to print status 105 | :param intvl: Number of intervals between printing status 106 | 107 | :return X: Convolution of data with basis functions 108 | :return Y: Generate data matrix 109 | """ 110 | if T == 0: 111 | return np.zeros((0,self.N)) 112 | assert isinstance(T, int), "Size must be an integer number of time bins" 113 | 114 | N, basis = self.N, self.basis 115 | L, B = basis.shape 116 | 117 | # NOTE: the basis is defined such that the first row is the 118 | # previous time step and the last row is T-L steps in 119 | # the past. Thus, for the dot products below, we need 120 | # to flip the basis matrix. 
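        #       (Concretely: basis[0] is the weight for one bin back, while
        #       the window Y[t-L:t] used below is ordered oldest-first, so
        #       flipping pairs the most recent bin Y[t-1] with that
        #       most-recent-lag weight in X[t] = Y[t-L:t].T.dot(basis).)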
121 | basis = np.flipud(basis) 122 | assert not np.allclose(basis, self.basis) 123 | 124 | # Precompute the weights and biases 125 | W = self.weights.reshape((N, N*B)) # N x NB (post x (pre x B)) 126 | b = self.biases # N (post) 127 | 128 | # Initialize output matrix of spike counts 129 | Y = np.zeros((T+L, N)) 130 | X = np.zeros((T+L, N, B)) 131 | Psi = np.zeros((T+L, N)) 132 | 133 | # Iterate forward in time 134 | for t in range(L,T+L): 135 | if verbose: 136 | if t % intvl == 0: 137 | print("Generate t={}".format(t)) 138 | # 1. Project previous activity window onto the basis 139 | # previous activity is L x N, basis is L x B, 140 | X[t] = Y[t-L:t].T.dot(basis) 141 | 142 | # 2. Compute the activation, W.dot(X[t]) + b 143 | Psi[t] = W.dot(X[t].reshape((N*B,))) + b 144 | 145 | # 3. Sample new data 146 | Y[t] = self.regressions[0].rvs(psi=Psi[t]) 147 | 148 | if keep: 149 | self.add_data(Y[L:], X=X[L:]) 150 | 151 | return X[L:], Y[L:] 152 | 153 | @property 154 | def means(self): 155 | """ 156 | Compute the mean observation for each dataset 157 | """ 158 | mus = [] 159 | for (X,Y) in self.data_list: 160 | mus.append(np.column_stack( 161 | [r.mean(X) for r in self.regressions])) 162 | 163 | return mus 164 | 165 | ### Gibbs sampling 166 | def resample_model(self): 167 | self.resample_regressions() 168 | 169 | def resample_regressions(self): 170 | for n, reg in enumerate(self.regressions): 171 | reg.resample([(X, Y[:,n]) for (X,Y) in self.data_list]) 172 | 173 | ### Plotting 174 | def plot(self, 175 | fig=None, 176 | axs=None, 177 | handles=None, 178 | title=None, 179 | figsize=(6,3), 180 | W_lim=3, 181 | pltslice=slice(0, 500), 182 | N_to_plot=2, 183 | data_index=0): 184 | """ 185 | Plot the parameters of the model 186 | :return: 187 | """ 188 | from pyglm.plotting import plot_glm 189 | return plot_glm( 190 | self.data_list[data_index][1], 191 | self.weights, 192 | self.adjacency, 193 | self.means[0], 194 | fig=fig, 195 | axs=axs, 196 | handles=handles, 197 | title=title, 198 | figsize=figsize, 199 | W_lim=W_lim, 200 | pltslice=pltslice, 201 | N_to_plot=N_to_plot) 202 | 203 | 204 | class HierarchicalNonlinearAutoregressiveModel(NonlinearAutoregressiveModel): 205 | """ 206 | The network GLM is really just a hierarchical AR model. We specify a 207 | prior distribution on the weights of a collection of conditionally 208 | independent AR models. Since these weights are naturally interpreted 209 | as a network, we refer to these as "network" AR models, or "network GLMs". 210 | """ 211 | 212 | def __init__(self, N, network, regressions, basis=None, B=10): 213 | """ 214 | The only difference here is that we also provide a 'network' object, 215 | which specifies a prior distribution on the regression weights. 216 | 217 | :param network: 218 | """ 219 | super(HierarchicalNonlinearAutoregressiveModel, self). 
\ 220 | __init__(N, regressions, basis=basis, B=B) 221 | 222 | self.network = network 223 | 224 | def resample_model(self): 225 | super(HierarchicalNonlinearAutoregressiveModel, self).resample_model() 226 | self.resample_network() 227 | 228 | def resample_network(self): 229 | net = self.network 230 | net.resample((self.adjacency, self.weights)) 231 | 232 | # Update the regression hyperparameters 233 | for n, reg in enumerate(self.regressions): 234 | reg.S_w = net.sigma_W[n] 235 | reg.mu_w = net.mu_W[n] 236 | reg.rho = net.rho[n] 237 | 238 | # Alias the "GLM" and its "Network" extension 239 | GLM = NonlinearAutoregressiveModel 240 | NetworkGLM = HierarchicalNonlinearAutoregressiveModel 241 | 242 | # Define default GLMs for various regression classes 243 | class _DefaultMixin(object): 244 | _network_class = None 245 | _regression_class = None 246 | def __init__(self, N, B=10, basis=None, 247 | network=None, 248 | network_kwargs=None, 249 | regressions=None, 250 | regression_kwargs=None): 251 | """ 252 | :param N: Observation dimension. 253 | :param basis: Basis onto which the preceding activity is projected. 254 | In the "identity" case, this is just a lag matrix 255 | :param B: Basis dimensionality. 256 | In the "identity" case, this is the number of lags. 257 | :param kwargs: arguments to the corresponding regression constructor. 258 | """ 259 | B = B if basis is None else basis.shape[1] 260 | if network is None: 261 | network_kwargs = dict() if network_kwargs is None else network_kwargs 262 | network = self._network_class(N, B, **network_kwargs) 263 | 264 | if regressions is None: 265 | regression_kwargs = dict() if regression_kwargs is None else regression_kwargs 266 | regressions = [self._regression_class(N, B, **regression_kwargs) for _ in range(N)] 267 | super(_DefaultMixin, self).__init__(N, network, regressions, B=B, basis=basis) 268 | 269 | 270 | class GaussianGLM(_DefaultMixin, NetworkGLM): 271 | _network_class = pyglm.networks.NIWDenseNetwork 272 | _regression_class = pyglm.regression.GaussianRegression 273 | 274 | class SparseGaussianGLM(_DefaultMixin, NetworkGLM): 275 | _network_class = pyglm.networks.NIWSparseNetwork 276 | _regression_class = pyglm.regression.SparseGaussianRegression 277 | 278 | class BernoulliGLM(_DefaultMixin, NetworkGLM): 279 | _network_class = pyglm.networks.NIWDenseNetwork 280 | _regression_class = pyglm.regression.BernoulliRegression 281 | 282 | class SparseBernoulliGLM(_DefaultMixin, NetworkGLM): 283 | _network_class = pyglm.networks.NIWSparseNetwork 284 | _regression_class = pyglm.regression.SparseBernoulliRegression 285 | -------------------------------------------------------------------------------- /pyglm/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define some "networks" -- hierarchical prior distributions on the 3 | weights of a set of regression objects. 4 | """ 5 | import abc 6 | import numpy as np 7 | 8 | from pybasicbayes.abstractions import GibbsSampling 9 | from pybasicbayes.distributions import Gaussian 10 | 11 | from pyglm.utils.utils import expand_scalar, expand_cov 12 | 13 | class _NetworkModel(GibbsSampling): 14 | def __init__(self, N, B, **kwargs): 15 | """ 16 | Only extra requirement is that we explicitly tell it the 17 | number of nodes and the dimensionality of the weights in the constructor. 
18 | 19 | :param N: Number of nodes 20 | :param B: Dimensionality of the weights 21 | """ 22 | self.N, self.B = N, B 23 | 24 | @abc.abstractmethod 25 | def resample(self,data=[]): 26 | """ 27 | Every network mixin's resample method should call its parent in its 28 | first line. That way we can ensure that this base method is called 29 | first, and that each mixin is resampled appropriately. 30 | 31 | :param data: an adjacency matrix and a weight matrix 32 | A, W = data 33 | A in [0,1]^{N x N} where rows are incoming and columns are outgoing nodes 34 | W in [0,1]^{N x N x B} where rows are incoming and columns are outgoing nodes 35 | """ 36 | assert isinstance(data, tuple) 37 | A, W = data 38 | N, B = self.N, self.B 39 | assert A.shape == (N, N) 40 | assert A.dtype == bool 41 | assert W.shape == (N, N, B) 42 | 43 | @abc.abstractproperty 44 | def mu_W(self): 45 | """ 46 | NxNxB array of mean weights 47 | """ 48 | raise NotImplementedError 49 | 50 | @abc.abstractproperty 51 | def sigma_W(self): 52 | """ 53 | NxNxBxB array with conditional covariances of each weight 54 | """ 55 | raise NotImplementedError 56 | 57 | @abc.abstractproperty 58 | def rho(self): 59 | """ 60 | Connection probability 61 | :return: NxN matrix with values in [0,1] 62 | """ 63 | pass 64 | 65 | ## TODO: Add properties for info form weight parameters 66 | 67 | def log_likelihood(self, x): 68 | # TODO 69 | return 0 70 | 71 | def rvs(self,size=[]): 72 | # TODO 73 | return None 74 | 75 | ### Weight models 76 | class _IndependentGaussianMixin(_NetworkModel): 77 | """ 78 | Each weight is an independent Gaussian with a shared NIW prior. 79 | Special case the self-connections. 80 | """ 81 | def __init__(self, N, B, 82 | mu_0=0.0, sigma_0=1.0, kappa_0=1.0, nu_0=3.0, 83 | is_diagonal_weight_special=True, 84 | **kwargs): 85 | super(_IndependentGaussianMixin, self).__init__(N, B) 86 | 87 | mu_0 = expand_scalar(mu_0, (B,)) 88 | sigma_0 = expand_cov(sigma_0, (B,B)) 89 | self._gaussian = Gaussian(mu_0=mu_0, sigma_0=sigma_0, kappa_0=kappa_0, nu_0=max(nu_0, B+2.)) 90 | 91 | self.is_diagonal_weight_special = is_diagonal_weight_special 92 | if is_diagonal_weight_special: 93 | self._self_gaussian = \ 94 | Gaussian(mu_0=mu_0, sigma_0=sigma_0, kappa_0=kappa_0, nu_0=nu_0) 95 | 96 | @property 97 | def mu_W(self): 98 | N, B = self.N, self.B 99 | mu = np.zeros((N, N, B)) 100 | if self.is_diagonal_weight_special: 101 | # Set off-diagonal weights 102 | mask = np.ones((N, N), dtype=bool) 103 | mask[np.diag_indices(N)] = False 104 | mu[mask] = self._gaussian.mu 105 | 106 | # set diagonal weights 107 | mask = np.eye(N).astype(bool) 108 | mu[mask] = self._self_gaussian.mu 109 | 110 | else: 111 | mu = np.tile(self._gaussian.mu[None,None,:], (N, N, 1)) 112 | return mu 113 | 114 | @property 115 | def sigma_W(self): 116 | N, B = self.N, self.B 117 | if self.is_diagonal_weight_special: 118 | sigma = np.zeros((N, N, B, B)) 119 | # Set off-diagonal weights 120 | mask = np.ones((N, N), dtype=bool) 121 | mask[np.diag_indices(N)] = False 122 | sigma[mask] = self._gaussian.sigma 123 | 124 | # set diagonal weights 125 | mask = np.eye(N).astype(bool) 126 | sigma[mask] = self._self_gaussian.sigma 127 | 128 | else: 129 | sigma = np.tile(self._gaussian.mu[None, None, :, :], (N, N, 1, 1)) 130 | return sigma 131 | 132 | def resample(self, data=[]): 133 | super(_IndependentGaussianMixin, self).resample(data) 134 | A, W = data 135 | N, B = self.N, self.B 136 | if self.is_diagonal_weight_special: 137 | # Resample prior for off-diagonal weights 138 | mask = np.ones((N, N), 
dtype=bool) 139 | mask[np.diag_indices(N)] = False 140 | mask = mask & A 141 | self._gaussian.resample(W[mask]) 142 | 143 | # Resample prior for diagonal weights 144 | mask = np.eye(N).astype(bool) & A 145 | self._self_gaussian.resample(W[mask]) 146 | 147 | else: 148 | # Resample prior for all weights 149 | self._gaussian.resample(W[A]) 150 | 151 | class _FixedWeightsMixin(_NetworkModel): 152 | def __init__(self, N, B, 153 | mu=0.0, sigma=1.0, 154 | mu_self=None, sigma_self=None, 155 | **kwargs): 156 | super(_FixedWeightsMixin, self).__init__(N, B) 157 | self._mu = expand_scalar(mu, (N, N, B)) 158 | self._sigma = expand_cov(mu, (N, N, B, B)) 159 | 160 | if (mu_self is not None) and (sigma_self is not None): 161 | self._mu[np.arange(N), np.arange(N), :] = expand_scalar(mu_self, (N, B)) 162 | self._sigma[np.arange(N), np.arange(N), :] = expand_cov(sigma_self, (N, B, B)) 163 | 164 | @property 165 | def mu_W(self): 166 | return self._mu 167 | 168 | @property 169 | def sigma_W(self): 170 | return self._sigma 171 | 172 | def resample(self,data=[]): 173 | super(_FixedWeightsMixin, self).resample(data) 174 | 175 | # TODO: Define the stochastic block models 176 | 177 | ### Adjacency models 178 | class _FixedAdjacencyMixin(_NetworkModel): 179 | def __init__(self, N, B, rho=0.5, rho_self=None, **kwargs): 180 | super(_FixedAdjacencyMixin, self).__init__(N, B) 181 | self._rho = expand_scalar(rho, (N, N)) 182 | if rho_self is not None: 183 | self._rho[np.diag_indices(N)] = rho_self 184 | 185 | @property 186 | def rho(self): 187 | return self._rho 188 | 189 | def resample(self,data=[]): 190 | super(_FixedAdjacencyMixin, self).resample(data) 191 | 192 | 193 | 194 | class _DenseAdjacencyMixin(_NetworkModel): 195 | def __init__(self, N, B, **kwargs): 196 | super(_DenseAdjacencyMixin, self).__init__(N, B) 197 | self._rho = np.ones((N,N)) 198 | 199 | @property 200 | def rho(self): 201 | return self._rho 202 | 203 | def resample(self,data=[]): 204 | super(_DenseAdjacencyMixin, self).resample(data) 205 | 206 | 207 | class _IndependentBernoulliMixin(_NetworkModel): 208 | 209 | def __init__(self, N, B, 210 | a_0=1.0, b_0=1.0, 211 | is_diagonal_conn_special=True, 212 | **kwargs): 213 | super(_IndependentBernoulliMixin, self).__init__(N, B) 214 | raise NotImplementedError("TODO: Implement the BetaBernoulli class") 215 | 216 | assert np.isscalar(a_0) 217 | assert np.isscalar(b_0) 218 | self._betabernoulli = BetaBernoulli(a_0, b_0) 219 | 220 | self.is_diagonal_conn_special = is_diagonal_conn_special 221 | if is_diagonal_conn_special: 222 | self._self_betabernoulli = BetaBernoulli(a_0, b_0) 223 | 224 | @property 225 | def rho(self): 226 | N, B = self.N, self.B 227 | rho = np.zeros((N, N)) 228 | if self.is_diagonal_conn_special: 229 | # Set off-diagonal weights 230 | mask = np.ones((N, N), dtype=bool) 231 | mask[np.diag_indices(N)] = False 232 | rho[mask] = self._betabernoulli.rho 233 | 234 | # set diagonal weights 235 | mask = np.eye(N).astype(bool) 236 | rho[mask] = self._self_betabernoulli.rho 237 | 238 | else: 239 | rho = self._betabernoulli.rho * np.ones((N, N)) 240 | return rho 241 | 242 | def resample(self, data=[]): 243 | super(_IndependentBernoulliMixin, self).resample(data) 244 | A, W = data 245 | N, B = self.N, self.B 246 | if self.is_diagonal_conn_special: 247 | # Resample prior for off-diagonal conns 248 | mask = np.ones((N, N), dtype=bool) 249 | mask[np.diag_indices(N)] = False 250 | self._betabernoulli.resample(A[mask]) 251 | 252 | # Resample prior for off-diagonal conns 253 | mask = 
np.eye(N).astype(bool) 254 | self._self_betabernoulli.resample(A[mask]) 255 | 256 | else: 257 | # Resample prior for all conns 258 | mask = np.ones((N, N), dtype=bool) 259 | self._betabernoulli.resample(A[mask]) 260 | 261 | # TODO: Define the distance and block models 262 | 263 | ### Define different combinations of network models 264 | class FixedMeanDenseNetwork(_DenseAdjacencyMixin, 265 | _FixedWeightsMixin): 266 | pass 267 | 268 | class FixedMeanSparseNetwork(_FixedAdjacencyMixin, 269 | _FixedWeightsMixin): 270 | pass 271 | 272 | class NIWDenseNetwork(_DenseAdjacencyMixin, 273 | _IndependentGaussianMixin): 274 | pass 275 | 276 | class NIWSparseNetwork(_FixedAdjacencyMixin, 277 | _IndependentGaussianMixin): 278 | pass 279 | -------------------------------------------------------------------------------- /pyglm/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def plot_glm(data, 4 | weights, 5 | adjacency, 6 | firingrates, 7 | std_firingrates=None, 8 | fig=None, 9 | axs=None, 10 | handles=None, 11 | title=None, 12 | figsize=(6, 3), 13 | W_lim=3, 14 | pltslice=slice(0, 500), 15 | data_index=0, 16 | N_to_plot=2): 17 | """ 18 | Plot the parameters of the model 19 | :return: 20 | """ 21 | Y = data 22 | W, A = weights, adjacency 23 | N = W.shape[0] 24 | 25 | # Do the imports here so that plotting stuff isn't loaded 26 | # unless it is necessary 27 | import matplotlib.pyplot as plt 28 | import matplotlib.gridspec as gridspec 29 | from mpl_toolkits.axes_grid1 import make_axes_locatable 30 | 31 | if handles is None: 32 | # If handles are not given, create a new plot 33 | handles = [] 34 | 35 | fig = plt.figure(figsize=figsize) 36 | gs = gridspec.GridSpec(N_to_plot, 3) 37 | W_ax = fig.add_subplot(gs[:, 0]) 38 | A_ax = fig.add_subplot(gs[:, 1]) 39 | lam_axs = [fig.add_subplot(gs[i, 2]) for i in range(N_to_plot)] 40 | axs = (W_ax, A_ax, lam_axs) 41 | 42 | # Weight matrix 43 | h_W = W_ax.imshow(W[:, :, 0], vmin=-W_lim, vmax=W_lim, cmap="RdBu", interpolation="nearest") 44 | W_ax.set_xlabel("pre") 45 | W_ax.set_ylabel("post") 46 | W_ax.set_xticks(np.arange(N)) 47 | W_ax.set_xticklabels(np.arange(N) + 1) 48 | W_ax.set_yticks(np.arange(N)) 49 | W_ax.set_yticklabels(np.arange(N) + 1) 50 | W_ax.set_title("Weights") 51 | 52 | # Colorbar 53 | divider = make_axes_locatable(W_ax) 54 | cbax = divider.new_horizontal(size="5%", pad=0.05) 55 | fig.add_axes(cbax) 56 | plt.colorbar(h_W, cax=cbax) 57 | handles.append(h_W) 58 | 59 | # Adjacency matrix 60 | h_A = A_ax.imshow(A, vmin=0, vmax=1, cmap="Greys", interpolation="nearest") 61 | A_ax.set_xlabel("pre") 62 | A_ax.set_ylabel("post") 63 | A_ax.set_title("Adjacency") 64 | A_ax.set_xticks(np.arange(N)) 65 | A_ax.set_xticklabels(np.arange(N) + 1) 66 | A_ax.set_yticks(np.arange(N)) 67 | A_ax.set_yticklabels(np.arange(N) + 1) 68 | 69 | # Colorbar 70 | divider = make_axes_locatable(A_ax) 71 | cbax = divider.new_horizontal(size="5%", pad=0.05) 72 | fig.add_axes(cbax) 73 | plt.colorbar(h_A, cax=cbax) 74 | handles.append(h_A) 75 | 76 | # Plot the true and inferred rates 77 | for n in range(min(N, N_to_plot)): 78 | tn = np.where(Y[pltslice, n])[0] 79 | lam_axs[n].plot(tn, np.ones_like(tn), 'ko', markersize=4) 80 | 81 | # If given, plot the mean+-std of the firing rates 82 | if std_firingrates is not None: 83 | sausage_plot(np.arange(pltslice.start, pltslice.stop), 84 | firingrates[pltslice, n], 85 | std_firingrates[pltslice,n], 86 | sgax=lam_axs[n], 87 | alpha=0.5) 88 | 89 | h_fr = 
lam_axs[n].plot(firingrates[pltslice, n], label="True")[0] 90 | lam_axs[n].set_ylim(-0.05, 1.1) 91 | lam_axs[n].set_ylabel("$\lambda_{}(t)$".format(n + 1)) 92 | 93 | if n == 0: 94 | lam_axs[n].set_title("Firing Rates") 95 | 96 | if n == min(N, N_to_plot) - 1: 97 | lam_axs[n].set_xlabel("Time") 98 | handles.append(h_fr) 99 | 100 | if title is not None: 101 | handles.append(fig.suptitle(title)) 102 | 103 | plt.tight_layout() 104 | 105 | else: 106 | # If we are given handles, update the data 107 | handles[0].set_data(W[:, :, 0]) 108 | handles[1].set_data(A) 109 | for n in range(min(N, N_to_plot)): 110 | handles[2 + n].set_data(np.arange(pltslice.start, pltslice.stop), firingrates[pltslice, n]) 111 | 112 | if title is not None: 113 | handles[-1].set_text(title) 114 | plt.pause(0.001) 115 | 116 | return fig, axs, handles 117 | 118 | 119 | def sausage_plot(x, y, yerr, sgax=None, **kwargs): 120 | import matplotlib.pyplot as plt 121 | from matplotlib.patches import Polygon 122 | 123 | T = x.size 124 | assert x.shape == y.shape == yerr.shape == (T,) 125 | 126 | # Get axis 127 | if sgax is None: 128 | sgax = plt.gca() 129 | 130 | # Compute envelope 131 | env = np.zeros((T*2,2)) 132 | env[:,0] = np.concatenate((x, x[::-1])) 133 | env[:,1] = np.concatenate((y + yerr, y[::-1] - yerr[::-1])) 134 | 135 | # Add the patch 136 | sgax.add_patch(Polygon(env, **kwargs)) -------------------------------------------------------------------------------- /pyglm/regression.py: -------------------------------------------------------------------------------- 1 | """ 2 | The "generalized linear models" of computational neuroscience 3 | are ultimately nonlinear vector autoregressive models in 4 | statistics. As the name suggests, the key component in these 5 | models is a regression from inputs, x, to outputs, y. 6 | 7 | When the outputs are discrete random variables, like spike 8 | counts, we typically take the regression to be a generalized 9 | linear model: 10 | 11 | y ~ p(mu(x), theta) 12 | mu(x) = f(w \dot x) 13 | 14 | where 'p' is a discrete distribution, like the Poisson, 15 | and 'f' is a "link" function that maps a linear function of 16 | x to the parameters of 'p'. Hence the name "GLM" in 17 | computational neuroscience. 18 | 19 | Our contribution is a host of hierarchical models for the 20 | weights of the GLM, along with an efficient Bayesian inference 21 | algorithm for inferring the weights, 'w', under count observations. 22 | Specifically, we build hierarchical sparse priors for the weights 23 | and then leverage Polya-gamma augmentation to perform efficient 24 | inference. 25 | 26 | This module implements these sparse regressions. 27 | """ 28 | import abc 29 | import numpy as np 30 | import numpy.random as npr 31 | 32 | from scipy.linalg import block_diag 33 | from scipy.linalg.lapack import dpotrs 34 | 35 | from pybasicbayes.abstractions import GibbsSampling 36 | from pybasicbayes.util.stats import sample_gaussian, sample_discrete_from_log, sample_invgamma 37 | 38 | from pyglm.utils.utils import logistic, expand_scalar, expand_cov 39 | 40 | class _SparseScalarRegressionBase(GibbsSampling): 41 | """ 42 | Base class for the sparse regression. 43 | We assume the outputs are scalar. 
44 | 45 | T: number of observations 46 | N: number of input groups 47 | B: input dimension for each group 48 | inputs: X \in R^{T x N x B} 49 | outputs: y \in R^T 50 | 51 | model: 52 | 53 | y_t = \sum_{n=1}^N a_{n} * (w_{n} \dot x_{t,n}) + b + noise 54 | 55 | where: 56 | 57 | a_n \in {0,1} is a binary indicator 58 | w_n \in R^ B is a weight vector for group n 59 | x_n \in R^B is the input for group n 60 | b \in R is a bias vector 61 | 62 | hyperparameters: 63 | 64 | rho in [0,1]^N probability of a_n for each group n 65 | mu_w in R^{NxB} mean of weight matrices 66 | S_w in R^{NxBxB} covariance for each row of the the weight matrices 67 | mu_b in R mean of the bias vector 68 | S_b in R_+ covariance of the bias vector 69 | 70 | """ 71 | __metaclass__ = abc.ABCMeta 72 | 73 | def __init__(self, N, B, 74 | rho=0.5, 75 | mu_w=0.0, S_w=1.0, 76 | mu_b=0.0, S_b=1.0): 77 | self.N, self.B = N, B 78 | 79 | # Initialize the hyperparameters 80 | self.rho = rho 81 | self.mu_w = mu_w 82 | self.mu_b = mu_b 83 | self.S_w = S_w 84 | self.S_b = S_b 85 | 86 | # Initialize the model parameters with a draw from the prior 87 | self.a = npr.rand(N) < self.rho 88 | self.W = np.zeros((N,B)) 89 | for n in range(N): 90 | self.W[n] = self.a[n] * npr.multivariate_normal(self.mu_w[n], self.S_w[n]) 91 | 92 | self.b = npr.multivariate_normal(self.mu_b, self.S_b) 93 | 94 | # Properties 95 | @property 96 | def rho(self): 97 | return self._rho 98 | 99 | @rho.setter 100 | def rho(self, value): 101 | self._rho = expand_scalar(value, (self.N,)) 102 | 103 | @property 104 | def mu_w(self): 105 | return self._mu_w 106 | 107 | @mu_w.setter 108 | def mu_w(self, value): 109 | N, B = self.N, self.B 110 | self._mu_w = expand_scalar(value, (N, B)) 111 | 112 | @property 113 | def mu_b(self): 114 | return self._mu_b 115 | 116 | @mu_b.setter 117 | def mu_b(self, value): 118 | self._mu_b = expand_scalar(value, (1,)) 119 | 120 | @property 121 | def S_w(self): 122 | return self._S_w 123 | 124 | @S_w.setter 125 | def S_w(self, value): 126 | N, B = self.N, self.B 127 | self._S_w = expand_cov(value, (N, B, B)) 128 | 129 | @property 130 | def S_b(self): 131 | return self._S_b 132 | 133 | @S_b.setter 134 | def S_b(self, value): 135 | assert np.isscalar(value) 136 | self._S_b = expand_cov(value, (1, 1)) 137 | 138 | @property 139 | def natural_params(self): 140 | # Compute information form parameters 141 | N, B = self.N, self.B 142 | J_w = np.zeros((N, B, B)) 143 | h_w = np.zeros((N, B)) 144 | for n in range(N): 145 | J_w[n] = np.linalg.inv(self.S_w[n]) 146 | h_w[n] = J_w[n].dot(self.mu_w[n]) 147 | 148 | J_b = np.linalg.inv(self.S_b) 149 | h_b = J_b.dot(self.mu_b) 150 | 151 | return J_w, h_w, J_b, h_b 152 | 153 | @property 154 | def deterministic_sparsity(self): 155 | return np.all((self.rho < 1e-6) | (self.rho > 1-1e-6)) 156 | 157 | @abc.abstractmethod 158 | def omega(self, X, y): 159 | """ 160 | The "precision" of the observations y. For the standard 161 | homoskedastic Gaussian model, this is a function of model parameters. 162 | """ 163 | raise NotImplementedError 164 | 165 | @abc.abstractmethod 166 | def kappa(self, X, y): 167 | """ 168 | The "normalized" observations, y. For the standard 169 | homoskedastic Gaussian model, this is the data times the precision. 
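        (In the Gaussian subclass below, kappa = y / eta; in the
        Polya-gamma subclasses, kappa = a(y) - b(y) / 2.)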
170 | """ 171 | raise NotImplementedError 172 | 173 | def _flatten_X(self, X): 174 | if X.ndim == 2: 175 | assert X.shape[1] == self.N*self.B 176 | elif X.ndim == 3: 177 | X = np.reshape(X, (-1, self.N * self.B)) 178 | else: 179 | raise Exception 180 | return X 181 | 182 | 183 | def extract_data(self, data): 184 | N, B = self.N, self.B 185 | 186 | assert isinstance(data, tuple) and len(data) == 2 187 | X, y = data 188 | T = X.shape[0] 189 | assert y.shape == (T, 1) or y.shape == (T,) 190 | 191 | # Reshape X such that it is T x NB 192 | X = self._flatten_X(X) 193 | return X, y 194 | 195 | def activation(self, X): 196 | N, B = self.N, self.B 197 | X = self._flatten_X(X) 198 | 199 | W = np.reshape((self.a[:, None] * self.W), (N * B,)) 200 | b = self.b[0] 201 | return X.dot(W) + b 202 | 203 | @abc.abstractmethod 204 | def mean(self, X): 205 | """ 206 | Return the expected value of y given X. 207 | """ 208 | raise NotImplementedError 209 | 210 | def _prior_sufficient_statistics(self): 211 | """ 212 | Compute the prior statistics (information form Gaussian 213 | potentials) for the complete set of weights and biases. 214 | """ 215 | N, B = self.N, self.B 216 | 217 | J_w, h_w, J_b, h_b = self.natural_params 218 | J_prior = block_diag(*J_w, J_b) 219 | assert J_prior.shape == (N*B+1, N*B+1) 220 | 221 | h_prior = np.concatenate((h_w.ravel(), h_b.ravel())) 222 | assert h_prior.shape == (N*B+1,) 223 | return J_prior, h_prior 224 | 225 | def _lkhd_sufficient_statistics(self, datas): 226 | """ 227 | Compute the likelihood statistics (information form Gaussian 228 | potentials) for each dataset. Polya-gamma regressions will 229 | have to override this class. 230 | """ 231 | N, B = self.N, self.B 232 | 233 | J_lkhd = np.zeros((N*B+1, N*B+1)) 234 | h_lkhd = np.zeros(N*B+1) 235 | 236 | # Compute the posterior sufficient statistics 237 | for data in datas: 238 | assert isinstance(data, tuple) 239 | X, y = self.extract_data(data) 240 | T = X.shape[0] 241 | 242 | # Get the precision and the normalized observations 243 | omega = self.omega(X,y) 244 | assert omega.shape == (T,) 245 | kappa = self.kappa(X,y) 246 | assert kappa.shape == (T,) 247 | 248 | # Add the sufficient statistics to J_lkhd 249 | # The last row and column correspond to the 250 | # affine term 251 | XO = X * omega[:,None] 252 | J_lkhd[:N*B, :N*B] += XO.T.dot(X) 253 | Xsum = XO.sum(0) 254 | J_lkhd[:N*B,-1] += Xsum 255 | J_lkhd[-1,:N*B] += Xsum 256 | J_lkhd[-1,-1] += omega.sum() 257 | 258 | # Add the sufficient statisticcs to h_lkhd 259 | h_lkhd[:N*B] += kappa.T.dot(X) 260 | h_lkhd[-1] += kappa.sum() 261 | 262 | return J_lkhd, h_lkhd 263 | 264 | ### Gibbs sampling 265 | def resample(self, datas): 266 | # Compute the prior and posterior sufficient statistics of W 267 | J_prior, h_prior = self._prior_sufficient_statistics() 268 | J_lkhd, h_lkhd = self._lkhd_sufficient_statistics(datas) 269 | 270 | J_post = J_prior + J_lkhd 271 | h_post = h_prior + h_lkhd 272 | 273 | # Resample a 274 | if self.deterministic_sparsity: 275 | self.a = np.round(self.rho).astype(bool) 276 | else: 277 | self._collapsed_resample_a(J_prior, h_prior, J_post, h_post) 278 | 279 | # Resample weights 280 | self._resample_W(J_post, h_post) 281 | 282 | def _collapsed_resample_a(self, J_prior, h_prior, J_post, h_post): 283 | """ 284 | """ 285 | N, B, rho = self.N, self.B, self.rho 286 | perm = npr.permutation(self.N) 287 | 288 | ml_prev = self._marginal_likelihood(J_prior, h_prior, J_post, h_post) 289 | for n in perm: 290 | # TODO: Check if rho is deterministic 291 | 292 | # 
Compute the marginal prob with and without A[m,n] 293 | lps = np.zeros(2) 294 | # We already have the marginal likelihood for the current value of a[m] 295 | # We just need to add the prior 296 | v_prev = int(self.a[n]) 297 | lps[v_prev] += ml_prev 298 | lps[v_prev] += v_prev * np.log(rho[n]) + (1-v_prev) * np.log(1-rho[n]) 299 | 300 | # Now compute the posterior stats for 1-v 301 | v_new = 1 - v_prev 302 | self.a[n] = v_new 303 | 304 | ml_new = self._marginal_likelihood(J_prior, h_prior, J_post, h_post) 305 | 306 | lps[v_new] += ml_new 307 | lps[v_new] += v_new * np.log(rho[n]) + (1-v_new) * np.log(1-rho[n]) 308 | 309 | # Sample from the marginal probability 310 | # max_lps = max(lps[0], lps[1]) 311 | # se_lps = np.sum(np.exp(lps-max_lps)) 312 | # lse_lps = np.log(se_lps) + max_lps 313 | # ps = np.exp(lps - lse_lps) 314 | # v_smpl = npr.rand() < ps[1] 315 | v_smpl = sample_discrete_from_log(lps) 316 | self.a[n] = v_smpl 317 | 318 | # Cache the posterior stats and update the matrix objects 319 | if v_smpl != v_prev: 320 | ml_prev = ml_new 321 | 322 | 323 | def _resample_W(self, J_post, h_post): 324 | """ 325 | Resample the weight of a connection (synapse) 326 | """ 327 | N, B = self.N, self.B 328 | 329 | a = np.concatenate((np.repeat(self.a, self.B), [1])).astype(np.bool) 330 | Jp = J_post[np.ix_(a, a)] 331 | hp = h_post[a] 332 | 333 | # Sample in information form 334 | W = sample_gaussian(J=Jp, h=hp) 335 | 336 | # Set bias and weights 337 | self.W *= 0 338 | self.W[self.a, :] = W[:-1].reshape((-1,B)) 339 | # self.W = np.reshape(W[:-1], (D,N,B)) 340 | self.b = np.reshape(W[-1], (1,)) 341 | 342 | 343 | def _marginal_likelihood(self, J_prior, h_prior, J_post, h_post): 344 | """ 345 | Compute the marginal likelihood as the ratio of log normalizers 346 | """ 347 | a = np.concatenate((np.repeat(self.a, self.B), [1])).astype(np.bool) 348 | 349 | # Extract the entries for which A=1 350 | J0 = J_prior[np.ix_(a, a)] 351 | h0 = h_prior[a] 352 | Jp = J_post[np.ix_(a, a)] 353 | hp = h_post[a] 354 | 355 | # This relates to the mean/covariance parameterization as follows 356 | # log |C| = log |J^{-1}| = -log |J| 357 | # and 358 | # mu^T C^{-1} mu = mu^T h 359 | # = mu C^{-1} C h 360 | # = h^T C h 361 | # = h^T J^{-1} h 362 | # ml = 0 363 | # ml -= 0.5*np.linalg.slogdet(Jp)[1] 364 | # ml += 0.5*np.linalg.slogdet(J0)[1] 365 | # ml += 0.5*hp.dot(np.linalg.solve(Jp, hp)) 366 | # ml -= 0.5*h0.T.dot(np.linalg.solve(J0, h0)) 367 | 368 | # Now compute it even faster using the Cholesky! 369 | L0 = np.linalg.cholesky(J0) 370 | Lp = np.linalg.cholesky(Jp) 371 | 372 | ml = 0 373 | ml -= np.sum(np.log(np.diag(Lp))) 374 | ml += np.sum(np.log(np.diag(L0))) 375 | ml += 0.5*hp.T.dot(dpotrs(Lp, hp, lower=True)[0]) 376 | ml -= 0.5*h0.T.dot(dpotrs(L0, h0, lower=True)[0]) 377 | 378 | return ml 379 | 380 | class SparseGaussianRegression(_SparseScalarRegressionBase): 381 | """ 382 | The standard case of a sparse regression with Gaussian observations. 
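    The observation model is y = activation(X) + N(0, eta), where the noise
    variance eta is given an inverse-gamma(a_0, b_0) prior and is resampled
    along with the weights.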
383 | """ 384 | def __init__(self, N, B, 385 | a_0=2.0, b_0=2.0, eta=None, 386 | **kwargs): 387 | super(SparseGaussianRegression, self).__init__(N, B, **kwargs) 388 | 389 | # Initialize the noise model 390 | assert np.isscalar(a_0) and a_0 > 0 391 | assert np.isscalar(b_0) and a_0 > 0 392 | self.a_0, self.b_0 = a_0, b_0 393 | if eta is not None: 394 | assert np.isscalar(eta) and eta > 0 395 | self.eta = eta 396 | else: 397 | # Sample eta from its inverse gamma prior 398 | self.eta = sample_invgamma(self.a_0, self.b_0) 399 | 400 | def log_likelihood(self, x): 401 | N, B, eta = self.N, self.B, self.eta 402 | 403 | X, y = self.extract_data(x) 404 | return -0.5 * np.log(2*np.pi*eta) -0.5 * (y-self.mean(X))**2 / eta 405 | 406 | def rvs(self,size=[], X=None, psi=None): 407 | N, B = self.N, self.B 408 | 409 | if psi is None: 410 | if X is None: 411 | assert isinstance(size, int) 412 | X = npr.randn(size,N*B) 413 | 414 | X = self._flatten_X(X) 415 | psi = self.mean(X) 416 | 417 | return psi + np.sqrt(self.eta) * npr.randn(*psi.shape) 418 | 419 | def omega(self, X, y): 420 | T = X.shape[0] 421 | return 1./self.eta * np.ones(T) 422 | 423 | def kappa(self, X, y): 424 | return y / self.eta 425 | 426 | def resample(self, datas): 427 | super(SparseGaussianRegression, self).resample(datas) 428 | self._resample_eta(datas) 429 | 430 | def mean(self, X): 431 | return self.activation(X) 432 | 433 | def _resample_eta(self, datas): 434 | N, B = self.N, self.B 435 | 436 | alpha = self.a_0 437 | beta = self.b_0 438 | for data in datas: 439 | X, y = self.extract_data(data) 440 | T = X.shape[0] 441 | 442 | alpha += T / 2.0 443 | beta += np.sum((y-self.mean(X))**2) 444 | 445 | self.eta = sample_invgamma(alpha, beta) 446 | 447 | 448 | class GaussianRegression(SparseGaussianRegression): 449 | """ 450 | The standard scalar regression has dense weights. 451 | """ 452 | def __init__(self, N, B, 453 | **kwargs): 454 | rho = np.ones(N) 455 | kwargs["rho"] = rho 456 | super(GaussianRegression, self).__init__(N, B, **kwargs) 457 | 458 | 459 | class _SparsePGRegressionBase(_SparseScalarRegressionBase): 460 | """ 461 | Extend the sparse scalar regression to handle count observations 462 | by leveraging the Polya-gamma augmentation for logistic regression 463 | models. This supports the subclasses implemented below. 
Namely: 464 | - SparseBernoulliRegression 465 | - SparseBinomialRegression 466 | - SparseNegativeBinomialRegression 467 | """ 468 | __metaclass__ = abc.ABCMeta 469 | 470 | def __init__(self, N, B, **kwargs): 471 | super(_SparsePGRegressionBase, self).__init__(N, B, **kwargs) 472 | 473 | # Initialize Polya-gamma samplers 474 | import pypolyagamma as ppg 475 | num_threads = ppg.get_omp_num_threads() 476 | seeds = npr.randint(2 ** 16, size=num_threads) 477 | self.ppgs = [ppg.PyPolyaGamma(seed) for seed in seeds] 478 | 479 | @abc.abstractmethod 480 | def a_func(self, y): 481 | raise NotImplementedError 482 | 483 | @abc.abstractmethod 484 | def b_func(self, y): 485 | raise NotImplementedError 486 | 487 | @abc.abstractmethod 488 | def c_func(self, y): 489 | raise NotImplementedError 490 | 491 | def log_likelihood(self, x): 492 | X, y = self.extract_data(x) 493 | psi = self.activation(X) 494 | return np.log(self.c_func(y)) + self.a_func(y) * psi - self.b_func(y) * np.log1p(np.exp(psi)) 495 | 496 | def omega(self, X, y): 497 | """ 498 | In the Polya-gamma augmentation, the precision is 499 | given by an auxiliary variable that we must sample 500 | """ 501 | import pypolyagamma as ppg 502 | psi = self.activation(X) 503 | omega = np.zeros(y.size) 504 | ppg.pgdrawvpar(self.ppgs, 505 | self.b_func(y).ravel(), 506 | psi.ravel(), 507 | omega) 508 | return omega.reshape(y.shape) 509 | 510 | def kappa(self, X, y): 511 | return self.a_func(y) - self.b_func(y) / 2.0 512 | 513 | 514 | class SparseBernoulliRegression(_SparsePGRegressionBase): 515 | def a_func(self, data): 516 | return data 517 | 518 | def b_func(self, data): 519 | return np.ones_like(data, dtype=np.float) 520 | 521 | def c_func(self, data): 522 | return 1.0 523 | 524 | def mean(self, X): 525 | psi = self.activation(X) 526 | return logistic(psi) 527 | 528 | def rvs(self, X=None, size=[], psi=None): 529 | if psi is None: 530 | if X is None: 531 | assert isinstance(size, int) 532 | X = npr.randn(size, self.N*self.B) 533 | 534 | X = self._flatten_X(X) 535 | p = self.mean(X) 536 | else: 537 | p = logistic(psi) 538 | 539 | y = npr.rand(*p.shape) < p 540 | 541 | return y 542 | 543 | 544 | class BernoulliRegression(SparseBernoulliRegression): 545 | """ 546 | The standard Bernoulli regression has dense weights. 547 | """ 548 | 549 | def __init__(self, N, B, **kwargs): 550 | rho = np.ones(N) 551 | kwargs["rho"] = rho 552 | super(BernoulliRegression, self).__init__(N, B, **kwargs) 553 | 554 | -------------------------------------------------------------------------------- /pyglm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slinderman/pyglm/cb93ce901e6bb127b6baa408b8e5b9995c932f17/pyglm/utils/__init__.py -------------------------------------------------------------------------------- /pyglm/utils/basis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import scipy.signal as sig 4 | 5 | def convolve_with_basis(S, basis): 6 | """ 7 | Convolve each column of the event count matrix with this basis 8 | :param S: TxN matrix of inputs. 9 | T is the number of time bins 10 | N is the number of input dimensions. 
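    :param basis: R x B matrix of basis functions. Row 0 weights the
        previous time bin; a leading row of zeros is prepended inside this
        function so that bin t does not influence itself (the filter
        stays causal).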
11 | :return: TxNxB tensor of inputs convolved with bases 12 | """ 13 | # TODO: Check that basis is filtered causally 14 | (T,N) = S.shape 15 | (R,B) = basis.shape 16 | 17 | # Concatenate basis with a layer of ones 18 | basis = np.vstack((np.zeros((1, B)), basis)) 19 | 20 | # Initialize array for filtered stimulus 21 | F = np.empty((T,N,B)) 22 | 23 | # Compute convolutions fo each basis vector, one at a time 24 | for b in np.arange(B): 25 | F[:,:,b] = sig.fftconvolve(S, 26 | np.reshape(basis[:,b],(R+1,1)), 27 | 'full')[:T,:] 28 | 29 | # Check for positivity 30 | if np.amin(basis) >= 0 and np.amin(S) >= 0: 31 | np.clip(F, 0, np.inf, out=F) 32 | assert np.amin(F) >= 0, "convolution should be >= 0" 33 | 34 | return F 35 | 36 | def interpolate_basis(basis, dt, dt_max, 37 | norm=True, allow_instantaneous=False): 38 | # Interpolate basis at the resolution of the data 39 | L,B = basis.shape 40 | t_int = np.arange(0.0, dt_max, step=dt) 41 | t_bas = np.linspace(0.0, dt_max, L) 42 | 43 | ibasis = np.zeros((len(t_int), B)) 44 | for b in np.arange(B): 45 | ibasis[:,b] = np.interp(t_int, t_bas, basis[:,b]) 46 | 47 | # Normalize so that the interpolated basis has volume 1 48 | if norm: 49 | # ibasis /= np.trapz(ibasis,t_int,axis=0) 50 | ibasis /= (dt * np.sum(ibasis, axis=0)) 51 | 52 | if not allow_instantaneous: 53 | # Typically, the impulse responses are applied to times 54 | # (t+1:t+R). That means we need to prepend a row of zeros to make 55 | # sure the basis remains causal 56 | ibasis = np.vstack((np.zeros((1,B)), ibasis)) 57 | 58 | return ibasis 59 | 60 | 61 | def cosine_basis(B, 62 | L=100, 63 | orth=False, 64 | norm=True, 65 | n_eye=0, 66 | a=1.0/120, 67 | b=0.5): 68 | """ 69 | Create a basis of raised cosine tuning curves 70 | """ 71 | # Number of cosine basis functions 72 | n_cos = B - n_eye 73 | assert n_cos >= 0 and n_eye >= 0 74 | 75 | 76 | # The first n_eye basis elements are identity vectors in the first time bins 77 | basis = np.zeros((L,B)) 78 | basis[:n_eye,:n_eye] = np.eye(n_eye) 79 | 80 | # The remaining basis elements are raised cosine functions with peaks 81 | # logarithmically warped between [n_eye*dt:dt_max]. 
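    # (The cosines have equal width in the warped coordinate u = log(a*t + b),
    # so in real time the bumps grow wider with lag: fine temporal resolution
    # at short lags, coarse resolution at long lags.)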
82 | nlin = lambda t: np.log(a*t+b) # Nonlinearity 83 | u_ir = nlin(np.arange(L)) # Time in log time 84 | ctrs = u_ir[np.floor(np.linspace(n_eye,(L/2.0),n_cos)).astype(np.int)] 85 | if len(ctrs) == 1: 86 | w = ctrs/2 87 | else: 88 | w = (ctrs[-1]-ctrs[0])/(n_cos-1) # Width of the cosine tuning curves 89 | 90 | # Basis function is a raised cosine centered at c with width w 91 | basis_fn = lambda u,c,w: (np.cos(np.maximum(-np.pi,np.minimum(np.pi,(u-c)*np.pi/w/2.0)))+1)/2.0 92 | for i in np.arange(n_cos): 93 | basis[:,n_eye+i] = basis_fn(u_ir,ctrs[i],w) 94 | 95 | 96 | # Orthonormalize basis (this may decrease the number of effective basis vectors) 97 | if orth: 98 | basis = scipy.linalg.orth(basis) 99 | elif norm: 100 | # We can only normalize nonnegative bases 101 | if np.any(basis<0): 102 | raise Exception("We can only normalize nonnegative impulse responses!") 103 | 104 | # Normalize such that \int_0^1 b(t) dt = 1 105 | basis = basis / np.tile(np.sum(basis,axis=0), [L,1]) / (1.0/L) 106 | 107 | return basis -------------------------------------------------------------------------------- /pyglm/utils/profiling.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import sys, StringIO, inspect, os, functools, time, collections 4 | 5 | 6 | try: 7 | import line_profiler 8 | _prof = line_profiler.LineProfiler() 9 | 10 | def line_profiled(func): 11 | mod = inspect.getmodule(func) 12 | if 'PROFILING' in os.environ or (hasattr(mod,'PROFILING') and mod.PROFILING): 13 | return _prof(func) 14 | return func 15 | 16 | def show_line_stats(stream=None): 17 | _prof.print_stats(stream=stream) 18 | except ImportError: 19 | print "Failed to load line profiler" 20 | line_profiled = lambda x: x -------------------------------------------------------------------------------- /pyglm/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def logistic(x): 4 | return 1./(1+np.exp(-x)) 5 | 6 | # Expand a mean vector 7 | def expand_scalar(x, shp): 8 | if np.isscalar(x): 9 | x *= np.ones(shp) 10 | else: 11 | assert x.shape == shp 12 | return x 13 | 14 | # Expand the covariance matrices 15 | def expand_cov(c, shp): 16 | assert len(shp) >= 2 17 | assert shp[-2] == shp[-1] 18 | d = shp[-1] 19 | if np.isscalar(c): 20 | c = c * np.eye(d) 21 | tshp = np.array(shp) 22 | tshp[-2:] = 1 23 | c = np.tile(c, tshp) 24 | else: 25 | assert c.shape == shp 26 | 27 | return c 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | import numpy as np 4 | 5 | setup(name='pyglm', 6 | version='0.1', 7 | description='Bayesian inference for generalized linear models of neural spike trains', 8 | author='Scott Linderman', 9 | author_email='scott.linderman@columbia.edu', 10 | url='http://www.github.com/slinderman/pyglm', 11 | include_dirs=[np.get_include(),], 12 | install_requires=[ 13 | 'numpy>=1.9.3', 'scipy>=0.16', 'matplotlib', 'pybasicbayes', 'pypolyagamma'], 14 | classifiers=[ 15 | 'Intended Audience :: Science/Research', 16 | 'Programming Language :: Python', 17 | 'Programming Language :: C++', 18 | ], 19 | keywords=[ 20 | 'generalized linear model', 'autoregressive', 'AR', 'computational neuroscience'], 21 | platforms="ALL", 22 | packages=['pyglm', 'pyglm.utils']) 23 | 
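Before moving on to the test file, here is a minimal, self-contained sketch of how the basis utilities in `pyglm/utils/basis.py` above fit together: build a raised-cosine basis with `cosine_basis` and filter a toy spike train with `convolve_with_basis`. This sketch is not part of the repository; the sizes `T`, `N`, `B`, `L` and the 5% firing probability are arbitrary illustration values.

```python
# Illustrative sketch only: exercise cosine_basis and convolve_with_basis
# from pyglm/utils/basis.py with arbitrary sizes.
import numpy as np
import numpy.random as npr

from pyglm.utils.basis import cosine_basis, convolve_with_basis

T, N, B, L = 1000, 4, 3, 50

# (L, B) raised-cosine basis, scaled as in the README example
basis = cosine_basis(B=B, L=L) / L

# Toy binary spike train with roughly 5% of bins active
S = (npr.rand(T, N) < 0.05).astype(float)

# (T, N, B) tensor: each neuron's spike train filtered by each basis function
F = convolve_with_basis(S, basis)
assert F.shape == (T, N, B)
```

This `(T, N, B)` filtered representation is the same shape that `test_generate.py` below checks against the model's stored design tensor.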
-------------------------------------------------------------------------------- /test/test_generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | sns.set_style("white") 5 | 6 | from pyglm.regression import SparseBernoulliRegression 7 | from pyglm.models import NonlinearAutoregressiveModel 8 | from pyglm.utils.basis import cosine_basis 9 | 10 | def test_means(): 11 | N = 2 # Number of neurons 12 | B = 3 # Number of basis functions 13 | L = 10 # Length of basis functions 14 | 15 | basis = cosine_basis(B, L=L) / L 16 | regressions = [SparseBernoulliRegression(N, B, mu_b=-2, S_b=0.1) for n in range(N)] 17 | model = NonlinearAutoregressiveModel(N, regressions, basis=basis) 18 | 19 | X, Y = model.generate(T=1000, keep=False) 20 | 21 | model.add_data(Y) 22 | Xtest = model.data_list[0][0] 23 | 24 | assert np.allclose(X, Xtest) 25 | 26 | means = model.means 27 | model.data_list[0] = (X, Y) 28 | means2 = model.means 29 | assert np.allclose(means, means2) 30 | 31 | plt.figure() 32 | for n in range(N): 33 | plt.subplot(N,1,n+1) 34 | plt.plot(means[0][:,n], lw=4) 35 | plt.plot(means2[0][:,n], lw=1) 36 | tn = np.where(Y[:,n])[0] 37 | plt.plot(tn, np.ones_like(tn), 'ko') 38 | plt.ylim(-0.05, 1.1) 39 | plt.show() 40 | 41 | 42 | def test_basis(): 43 | N = 2 # Number of neurons 44 | B = 3 # Number of basis functions 45 | L = 10 # Length of basis functions 46 | 47 | regressions = [SparseBernoulliRegression(N, B, mu_b=-2, S_b=0.1) for n in range(N)] 48 | model = NonlinearAutoregressiveModel(N, regressions, B=B) 49 | 50 | X, Y = model.generate(T=1000, keep=False) 51 | 52 | # Check that the lags are working properly 53 | for n in range(N): 54 | for b in range(B): 55 | assert np.allclose(Y[:-(b+1),n], X[(b+1):,n,b]) 56 | 57 | 58 | if __name__ == "__main__": 59 | test_basis() 60 | test_means() 61 | --------------------------------------------------------------------------------
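Finally, a note on the Polya-gamma machinery in `pyglm/regression.py` above. Conditioned on the auxiliary variables returned by `omega`, the augmented likelihood is Gaussian in the activation, with precision `omega` and pseudo-observations `kappa / omega`, so the weights of each regression have a closed-form Gaussian conditional. The sketch below writes out that generic conjugate update (Polson, Scott & Windle, 2013). It is illustrative only, not pyglm's actual resampling code; the function name, its arguments, and the prior parameters `mu_w` and `Sigma_w` are assumptions made for the example.

```python
# Illustrative only: the generic Polya-gamma weight update implied by the
# omega() and kappa() methods in pyglm/regression.py. This is NOT pyglm's
# resample code; all names here are hypothetical.
import numpy as np
import numpy.random as npr

def resample_weights_sketch(X, kappa, omega, mu_w, Sigma_w):
    """Draw w ~ p(w | X, omega, kappa) for a PG-augmented GLM regression.

    X       : (T, D) design matrix (e.g., the filtered spike train, flattened)
    kappa   : (T,)   a(y) - b(y)/2, as computed by kappa()
    omega   : (T,)   Polya-gamma auxiliary variables, as drawn by omega()
    mu_w    : (D,)   prior mean of the weights
    Sigma_w : (D, D) prior covariance of the weights
    """
    prior_precision = np.linalg.inv(Sigma_w)
    # Posterior precision: Sigma_w^{-1} + X^T diag(omega) X
    post_precision = prior_precision + (X * omega[:, None]).T.dot(X)
    post_cov = np.linalg.inv(post_precision)
    # Posterior mean solves post_precision * mu = Sigma_w^{-1} mu_w + X^T kappa
    post_mean = post_cov.dot(prior_precision.dot(mu_w) + X.T.dot(kappa))
    return npr.multivariate_normal(post_mean, post_cov)
```

In pyglm itself, a step of this form would be interleaved with fresh draws of `omega` (via `pgdrawvpar`, as in the `omega` method above) and the other latent variables on each Gibbs sweep.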