├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── requirements.txt
├── seqnmf
│   ├── __init__.py
│   ├── data
│   │   └── MackeviciusData.mat
│   ├── dev.py
│   ├── helpers.py
│   └── seqnmf.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **.asv
2 | **.m~
3 | *.DS_Store
4 | *ipynb_checkpoints*
5 | .idea*
6 | *__pycache__*
7 | .pypirc
8 | dist/*
9 | *.egg-info*
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Contextual Dynamics Laboratory
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include LICENSE
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # seqNMF
2 | 
3 | This package is a Python port of the [SeqNMF MATLAB Toolbox](https://github.com/FeeLab/seqNMF). It provides a tool for performing unsupervised discovery of temporal sequences in high-dimensional data.
4 | 
5 | Credit for MATLAB toolbox: [Emily Mackevicius, Andrew Bahle, and the Fee Lab](http://web.mit.edu/feelab/).
6 | 
7 | This Python toolbox was developed by [Jeremy Manning](http://www.context-lab.com/) as a hackathon project during the [2018 MIND Summer School](https://summer-mind.github.io/).
8 | 
9 | ### Description
10 | SeqNMF uses regularized convolutional non-negative matrix factorization to extract repeated sequential patterns from high-dimensional data. The algorithm can discover patterns directly from timeseries data without reference to external markers or labels.
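Concretely, the model approximates the data matrix as a sum of each factor's pattern played out at the times indicated by its loadings. The snippet below is a minimal, illustrative numpy version of that convolutional reconstruction; it mirrors the `reconstruct` helper in `seqnmf/helpers.py` (shown later in this listing), minus the zero-padding that the packaged helper uses to handle edge effects:
```
import numpy as np

def convolutional_reconstruction(W, H):
    """Illustrative sketch: X_hat = sum over lags l of W[:, :, l] @ shift(H, l).

    W is an N x K x L tensor of sequence factors; H is a K x T loadings matrix.
    """
    N, K, L = W.shape
    T = H.shape[1]
    X_hat = np.zeros((N, T))
    for l in range(L):
        # shift H to the right by l timepoints (zero-fill rather than wrap around)
        H_shifted = np.zeros_like(H)
        H_shifted[:, l:] = H[:, :T - l]
        X_hat += W[:, :, l] @ H_shifted
    return X_hat
```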
11 | 
12 | For more information please see:
13 | - [**Original MATLAB implementation**](https://github.com/FeeLab/seqNMF)
14 | 
15 | - [**preprint**](https://www.biorxiv.org/content/early/2018/03/02/273128)
16 | 
17 | - [**COSYNE talk**](https://www.youtube.com/watch?reload=9&v=XyWtCtZ_m-8)
18 | 
19 | - tutorial [**video**](https://cbmm.mit.edu/video/unsupervised-discovery-temporal-sequences-high-dimensional-datasets) and [**materials**](https://stellar.mit.edu/S/project/bcs-comp-tut/materials.html)
20 | 
21 | - Simons Foundation [**article**](https://www.simonsfoundation.org/2018/05/04/finding-neural-patterns-in-the-din/)
22 | 
23 | ### Installing the toolbox
24 | To install the latest official version of this toolbox, type
25 | ```
26 | pip install --upgrade seqnmf
27 | ```
28 | 
29 | To install the (bleeding edge) development version, type
30 | ```
31 | pip install --upgrade git+https://github.com/ContextLab/seqnmf
32 | ```
33 | 
34 | ### Using the toolbox
35 | 
36 | Given an N (number of features) by T (number of timepoints) data matrix, `X`, the commands below may be used to factorize the data using seqNMF:
37 | ```
38 | from seqnmf import seqnmf
39 | 
40 | W, H, cost, loadings, power = seqnmf(X, K=20, L=100, Lambda=0.001)
41 | ```
42 | 
43 | Here `K` is the (maximum) number of factors, `L` is the (maximum) sequence length, and `Lambda` is a regularization parameter. The data matrix is factorized into a tensor product of `W` and `H` as follows:
44 | 
45 | 
46 | ```
47 |                                   ----------
48 |                                L /          /|
49 |                                 /          / |
50 |    ----------------            /----------/  |        ----------------
51 |   |                |           |          |  |       |                |
52 | N |       X        |    =    N |    W     |  / (*) K |       H        |
53 |   |                |           |          | /        |                |
54 |    ----------------            /----------/           ----------------
55 |           T                         K                        T
56 | ```
57 | where `W` contains each of the `N` by `L` sequence factors and `H` describes the combination of sequences that are present at each timepoint.
58 | 
59 | The `cost` output stores the reconstruction error after each iteration. The `loadings` variable stores the per-factor loadings (i.e. the explanatory power of each individual factor). The `power` variable provides a measure of how well the original data is captured by the full reconstruction.
60 | 
61 | The `plot` function may be used to visualize the discovered structure by calling `plot(W, H)`.
62 | 
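For example, after the factorization finishes you can check convergence and see how much of the data the model captured. The snippet below is an illustrative sketch (it assumes `W`, `H`, `cost`, `loadings`, and `power` were returned by the `seqnmf` call above):
```
from matplotlib import pyplot as plt
from seqnmf import plot

plt.plot(cost)            # root-mean-squared reconstruction error per iteration
plt.xlabel('iteration')
plt.ylabel('cost')
plt.show()

print('proportion of power explained:', power)
print('per-factor loadings:', loadings)

plot(W, H).show()         # heatmaps of the factors, loadings, and reconstruction
```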
63 | ### Demo
64 | 
65 | An example dataset, ported from the MATLAB toolbox, is provided as part of the seqnmf Python toolbox. To apply seqNMF to the example data and generate a plot, run:
66 | ```
67 | from seqnmf import seqnmf, plot, example_data
68 | 
69 | [W, H, cost, loadings, power] = seqnmf(example_data, K=20, L=100, Lambda=0.001, plot_it=True)
70 | 
71 | plot(W, H).show()
72 | ```
73 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | matplotlib
3 | scipy
4 | seaborn
--------------------------------------------------------------------------------
/seqnmf/__init__.py:
--------------------------------------------------------------------------------
1 | from .seqnmf import seqnmf, plot
2 | 
3 | from scipy.io import loadmat
4 | import os
5 | import pkg_resources
6 | 
7 | DATA_PATH = pkg_resources.resource_filename('seqnmf', 'data/')
8 | example_data = loadmat(os.path.join(DATA_PATH, 'MackeviciusData.mat'))['NEURAL']
9 | 
10 | del DATA_PATH
11 | del os
12 | del loadmat
13 | del pkg_resources
--------------------------------------------------------------------------------
/seqnmf/data/MackeviciusData.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ContextLab/seqnmf/796ec8cdd5ff846f5733c3dd4ac02352a0da3e65/seqnmf/data/MackeviciusData.mat
--------------------------------------------------------------------------------
/seqnmf/dev.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | from matplotlib import pyplot as plt
3 | import os
4 | from scipy.io import loadmat
5 | from seqnmf import seqnmf, plot
6 | 
7 | data = loadmat(os.path.join('data', 'MackeviciusData.mat'))
8 | W, H, cost, loadings, power = seqnmf(data['NEURAL'])
9 | 
10 | h = plot(W, H)
11 | h.show()
12 | 
13 | sns.heatmap(data['NEURAL'], cmap='gray_r')
14 | plt.show()
15 | 
16 | 
--------------------------------------------------------------------------------
/seqnmf/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import warnings
3 | 
4 | def get_shapes(W, H, force_full=False):
5 |     N = W.shape[0]
6 |     T = H.shape[1]
7 |     K = W.shape[1]
8 |     L = W.shape[2]
9 | 
10 |     # trim zero padding along the L and K dimensions
11 |     if not force_full:
12 |         W_sum = W.sum(axis=0).sum(axis=1)
13 |         H_sum = H.sum(axis=1)
14 |         K = 1
15 |         for k in np.arange(W.shape[1]-1, 0, -1):
16 |             if (W_sum[k] > 0) or (H_sum[k] > 0):
17 |                 K = k+1
18 |                 break
19 | 
20 |         L = 2
21 |         for l in np.arange(W.shape[2]-1, 2, -1):
22 |             W_sum = W.sum(axis=1).sum(axis=0)
23 |             if W_sum[l] > 0:
24 |                 L = l+1
25 |                 break
26 | 
27 |     return N, K, L, T
28 | 
29 | def trim_shapes(W, H, N, K, L, T):
30 |     return W[:N, :K, :L], H[:K, :T]
31 | 
32 | def reconstruct(W, H):
33 |     N, K, L, T = get_shapes(W, H, force_full=True)
34 |     W, H = trim_shapes(W, H, N, K, L, T)
35 | 
36 |     H = np.hstack((np.zeros([K, L]), H, np.zeros([K, L])))
37 |     T += 2 * L
38 |     X_hat = np.zeros([N, T])
39 | 
40 |     for t in np.arange(L):
41 |         X_hat += np.dot(W[:, :, t], np.roll(H, t - 1, axis=1))
42 | 
43 |     return X_hat[:, L:-L]
44 | 
45 | 
46 | def shift_factors(W, H):
47 |     warnings.simplefilter('ignore')  # ignore warnings for nan-related errors
48 | 
49 |     N, K, L, T = get_shapes(W, H, force_full=True)
50 |     W, H = trim_shapes(W, H, N, K, L, T)
51 | 
52 |     if L > 1:
53 |         center = int(np.max([np.floor(L / 2), 1]))
54 |         Wpad = np.concatenate((np.zeros([N, K, L]), W, np.zeros([N, K, L])), axis=2)
55 | 
56 |         for i in np.arange(K):
57 |             temp = np.sum(np.squeeze(W[:, i, :]), axis=0)
58 |             # center of mass of this factor's energy along the time (L) axis
59 |             try:
60 |                 cmass = int(np.max([np.floor(np.sum(temp * np.arange(1, L + 1)) / np.sum(temp)), 1]))
61 |             except ValueError:
62 |                 cmass = center
63 |             Wpad[:, i, :] = np.roll(np.squeeze(Wpad[:, i, :]), center - cmass, axis=1)
64 |             H[i, :] = np.roll(H[i, :], cmass - center, axis=0)
65 | 
66 |     return Wpad[:, :, L:-L], H
67 | 
68 | 
69 | def compute_loadings_percent_power(V, W, H):
70 |     N, K, L, T = get_shapes(W, H)
71 |     W, H = trim_shapes(W, H, N, K, L, T)
72 | 
73 |     loadings = np.zeros(K)
74 |     var_v = np.sum(np.power(V, 2))
75 | 
76 |     for i in np.arange(K):
77 |         WH = reconstruct(np.reshape(W[:, i, :], [W.shape[0], 1, W.shape[2]]),\
78 |                          np.reshape(H[i, :], [1, H.shape[1]]))
79 |         loadings[i] = np.divide(np.sum(np.multiply(2 * V.flatten(), WH.flatten()) - np.power(WH.flatten(), 2)), var_v)
80 | 
81 |     loadings[loadings < 0] = 0
82 |     return loadings
83 | 
84 | 
--------------------------------------------------------------------------------
/seqnmf/seqnmf.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter('ignore')  # ignore numpy incompatibility warning (harmless)
3 | 
4 | import numpy as np
5 | import seaborn as sns
6 | from scipy.signal import convolve2d as conv2
7 | from matplotlib import pyplot as plt
8 | from matplotlib import gridspec
9 | from .helpers import reconstruct, shift_factors, compute_loadings_percent_power, get_shapes, trim_shapes
10 | 
11 | 
12 | def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None,
13 |            plot_it=False, max_iter=100, tol=-np.inf, shift=True, sort_factors=True,
14 |            lambda_L1W=0, lambda_L1H=0, lambda_OrthH=0, lambda_OrthW=0, M=None,
15 |            use_W_update=True, W_fixed=False):
16 |     '''
17 |     :param X: an N (features) by T (timepoints) data matrix to be factorized using seqNMF
18 |     :param K: the (maximum) number of factors to search for; any unused factors will be set to all zeros
19 |     :param L: the (maximum) number of timepoints to consider in each factor; any unused timepoints will be set to all zeros
20 |     :param Lambda: regularization parameter (default: 0.001)
21 |     :param W_init: initial factors (if unspecified, use random initialization)
22 |     :param H_init: initial per-timepoint factor loadings (if unspecified, initialize randomly)
23 |     :param plot_it: if True, display progress in each update using a plot (default: False)
24 |     :param max_iter: maximum number of iterations/updates
25 |     :param tol: if cost is within tol of the average of the previous 5 updates, the algorithm will terminate (default: tol = -inf)
26 |     :param shift: allow timepoint shifts in H
27 |     :param sort_factors: if True, sort the returned factors by their loadings (explanatory power), in descending order
28 |     :param lambda_L1W: L1 sparsity regularization parameter for W (default: 0)
29 |     :param lambda_L1H: L1 sparsity regularization parameter for H (default: 0)
30 |     :param lambda_OrthH: orthogonality regularization parameter for H (default: 0)
31 |     :param lambda_OrthW: orthogonality regularization parameter for W (default: 0)
32 |     :param M: binary mask of the same size as X, used to ignore a subset of the data during training (default: use all data)
33 |     :param use_W_update: set to True for more accurate results; set to False for faster results (default: True)
34 |     :param W_fixed: if True, hold the factors (W) fixed, e.g. for cross-validation (default: False)
35 | 
36 |     :return:
37 |     :W: N (features) by K (factors) by L (per-factor timepoints) tensor of factors
38 |     :H: K (factors) by T (timepoints) matrix of factor loadings (i.e. factor timecourses)
39 |     :cost: a vector of length (number-of-iterations + 1) containing the initial cost and the cost after each update (i.e. the reconstruction error)
40 |     :loadings: the per-factor loadings -- i.e. the explanatory power of each individual factor
41 |     :power: the total power (across all factors) explained by the full reconstruction
42 |     '''
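    # Overview of the procedure below: the data are zero-padded by L timepoints
    # on each side, W and H are initialized randomly unless provided, and the
    # main loop then alternates multiplicative updates of H and W. The
    # cross-factor penalty weighted by Lambda discourages redundant factors;
    # iteration stops at max_iter, or earlier if the cost stops improving
    # relative to tol. The padding is then trimmed and, if sort_factors is set,
    # the factors are reordered by their loadings.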
43 |     N = X.shape[0]
44 |     T = X.shape[1] + 2 * L
45 |     X = np.concatenate((np.zeros([N, L]), X, np.zeros([N, L])), axis=1)
46 | 
47 |     if W_init is None:
48 |         W_init = np.max(X) * np.random.rand(N, K, L)
49 |     if H_init is None:
50 |         H_init = np.max(X) * np.random.rand(K, T) / np.sqrt(T / 3)
51 |     if M is None:
52 |         M = np.ones([N, T])
53 | 
54 |     assert np.all(X >= 0), 'all data values must be non-negative!'
55 | 
56 |     W = W_init
57 |     H = H_init
58 | 
59 |     X_hat = reconstruct(W, H)
60 |     mask = M == 0
61 |     X[mask] = X_hat[mask]
62 | 
63 |     smooth_kernel = np.ones([1, (2 * L) - 1])
64 |     eps = np.max(X) * 1e-6
65 |     last_time = False
66 | 
67 |     cost = np.zeros([max_iter + 1, 1])
68 |     cost[0] = np.sqrt(np.mean(np.power(X - X_hat, 2)))
69 | 
70 |     for i in np.arange(max_iter):
71 |         if (i == max_iter - 1) or ((i > 6) and (cost[i + 1] + tol) > np.mean(cost[i - 6:i])):
72 |             cost = cost[:(i + 2)]
73 |             last_time = True
74 |             if i > 0:
75 |                 Lambda = 0
76 | 
77 |         WTX = np.zeros([K, T])
78 |         WTX_hat = np.zeros([K, T])
79 |         for j in np.arange(L):
80 |             X_shifted = np.roll(X, -j + 1, axis=1)
81 |             X_hat_shifted = np.roll(X_hat, -j + 1, axis=1)
82 | 
83 |             WTX += np.dot(W[:, :, j].T, X_shifted)
84 |             WTX_hat += np.dot(W[:, :, j].T, X_hat_shifted)
85 | 
86 |         if Lambda > 0:
87 |             dRdH = np.dot(Lambda * (1 - np.eye(K)), conv2(WTX, smooth_kernel, 'same'))
88 |         else:
89 |             dRdH = 0
90 | 
91 |         if lambda_OrthH > 0:
92 |             dHHdH = np.dot(lambda_OrthH * (1 - np.eye(K)), conv2(H, smooth_kernel, 'same'))
93 |         else:
94 |             dHHdH = 0
95 | 
96 |         dRdH += lambda_L1H + dHHdH
97 | 
98 |         H *= np.divide(WTX, WTX_hat + dRdH + eps)
99 | 
100 |         if shift:
101 |             W, H = shift_factors(W, H)
102 |             W += eps
103 | 
104 |         norms = np.sqrt(np.sum(np.power(H, 2), axis=1)).T
105 |         H = np.dot(np.diag(np.divide(1., norms + eps)), H)
106 |         for j in np.arange(L):
107 |             W[:, :, j] = np.dot(W[:, :, j], np.diag(norms))
108 | 
109 |         if not W_fixed:
110 |             X_hat = reconstruct(W, H)
111 |             mask = M == 0
112 |             X[mask] = X_hat[mask]
113 | 
114 |             if lambda_OrthW > 0:
115 |                 W_flat = np.sum(W, axis=2)
116 |             if (Lambda > 0) and use_W_update:
117 |                 XS = conv2(X, smooth_kernel, 'same')
118 | 
119 |             for j in np.arange(L):
120 |                 H_shifted = np.roll(H, j - 1, axis=1)
121 |                 XHT = np.dot(X, H_shifted.T)
122 |                 X_hat_HT = np.dot(X_hat, H_shifted.T)
123 | 
124 |                 if (Lambda > 0) and use_W_update:
125 |                     dRdW = Lambda * np.dot(np.dot(XS, H_shifted.T), (1. - np.eye(K)))
126 |                 else:
127 |                     dRdW = 0
128 | 
129 |                 if lambda_OrthW > 0:
130 |                     dWWdW = np.dot(lambda_OrthW * W_flat, 1. - np.eye(K))
131 |                 else:
132 |                     dWWdW = 0
133 | 
134 |                 dRdW += lambda_L1W + dWWdW
135 |                 W[:, :, j] *= np.divide(XHT, X_hat_HT + dRdW + eps)
136 | 
137 |         X_hat = reconstruct(W, H)
138 |         mask = M == 0
139 |         X[mask] = X_hat[mask]
140 |         cost[i + 1] = np.sqrt(np.mean(np.power(X - X_hat, 2)))
141 | 
142 |         if plot_it:
143 |             if i > 0:
144 |                 try:
145 |                     plt.close(h)
146 |                 except:
147 |                     pass
148 |             h = plot(W, H)
149 |             h.suptitle(f'iteration {i}', fontsize=8)
150 |             h.show()
151 | 
152 |         if last_time:
153 |             break
154 | 
155 |     X = X[:, L:-L]
156 |     X_hat = X_hat[:, L:-L]
157 |     H = H[:, L:-L]
158 | 
159 |     power = np.divide(np.sum(np.power(X, 2)) - np.sum(np.power(X - X_hat, 2)), np.sum(np.power(X, 2)))
160 | 
161 |     loadings = compute_loadings_percent_power(X, W, H)
162 | 
163 |     if sort_factors:
164 |         inds = np.flip(np.argsort(loadings), 0)
165 |         loadings = loadings[inds]
166 | 
167 |         W = W[:, inds, :]
168 |         H = H[inds, :]
169 | 
170 |     return W, H, cost, loadings, power
171 | 
172 | 
173 | def plot(W, H, cmap='gray_r', factor_cmap='Spectral'):
174 |     '''
175 |     :param W: N (features) by K (factors) by L (per-factor timepoints) tensor of factors
176 |     :param H: K (factors) by T (timepoints) matrix of factor loadings (i.e. factor timecourses)
177 |     :param cmap: colormap used to draw heatmaps for the factors, factor loadings, and data reconstruction
178 |     :param factor_cmap: colormap used to distinguish individual factors
179 |     :return f: matplotlib figure handle
180 |     '''
181 | 
182 |     N, K, L, T = get_shapes(W, H)
183 |     W, H = trim_shapes(W, H, N, K, L, T)
184 | 
185 |     data_recon = reconstruct(W, H)
186 | 
187 |     fig = plt.figure(figsize=(5, 5))
188 |     gs = gridspec.GridSpec(2, 2, width_ratios=[1, 4], height_ratios=[1, 4])
189 |     ax_h = plt.subplot(gs[1])
190 |     ax_w = plt.subplot(gs[2])
191 |     ax_data = plt.subplot(gs[3])
192 | 
193 |     # plot W, H, and data_recon
194 |     sns.heatmap(np.hstack(list(map(np.squeeze, np.split(W, K, axis=1)))), cmap=cmap, ax=ax_w, cbar=False)
195 |     sns.heatmap(H, cmap=cmap, ax=ax_h, cbar=False)
196 |     sns.heatmap(data_recon, cmap=cmap, ax=ax_data, cbar=False)
197 | 
198 |     # add dividing bars for factors of W and H
199 |     factor_colors = sns.color_palette(factor_cmap, K)
200 |     for k in np.arange(K):
201 |         plt.sca(ax_w)
202 |         start_w = k * L
203 |         plt.plot([start_w, start_w], [0, N - 1], '-', color=factor_colors[k])
204 | 
205 |         plt.sca(ax_h)
206 |         plt.plot([0, T - 1], [k, k], '-', color=factor_colors[k])
207 | 
208 |     return fig
209 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | with open('requirements.txt') as f:
4 |     requirements = f.read().splitlines()
5 | 
6 | readme = 'Python implementation of seqNMF. For more information visit https://github.com/ContextLab/seqnmf.'
7 | 
8 | with open('LICENSE') as f:
9 |     license = f.read()
10 | 
11 | setup(
12 |     name='seqnmf',
13 |     version='0.1.3',
14 |     description='Python implementation of seqNMF',
15 |     long_description=readme,
16 |     author='Contextual Dynamics Laboratory',
17 |     author_email='contextualdynamics@gmail.com',
18 |     url='https://www.context-lab.com',
19 |     license=license,
20 |     install_requires=requirements,
21 |     package_data = {'seqnmf':['data/MackeviciusData.mat']},
22 |     packages=find_packages(exclude=('tests', 'docs'))
23 | )
24 | 
--------------------------------------------------------------------------------
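As a quick end-to-end check of the package listed above, the following sketch (illustrative only; it assumes the `seqnmf` package is installed, and the synthetic data, seed, and parameter values are arbitrary choices rather than part of the original toolbox) embeds a repeating sequence in noise, factorizes it, and plots the result:
```
import numpy as np
from seqnmf import seqnmf, plot

np.random.seed(0)
N, T, L_true = 20, 500, 10
X = 0.01 * np.random.rand(N, T)            # low-amplitude noise floor
template = np.eye(N)[:, :L_true]           # a simple diagonal "sequence" pattern
for start in range(0, T - L_true, 50):     # embed the sequence every 50 timepoints
    X[:, start:start + L_true] += template

W, H, cost, loadings, power = seqnmf(X, K=3, L=15, Lambda=0.001)
print('proportion of power explained:', power)
plot(W, H).show()
```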