├── slicetca
│   ├── tests
│   │   ├── __init__.py
│   │   └── tests.py
│   ├── plotting
│   │   ├── additional.py
│   │   ├── __init__.py
│   │   ├── grid.py
│   │   └── factors.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── helper_functions.py
│   │   └── decompositions.py
│   ├── invariance
│   │   ├── __init__.py
│   │   ├── criteria.py
│   │   ├── invariance.py
│   │   ├── analytic_invariance.py
│   │   ├── helper.py
│   │   ├── iterative_invariance.py
│   │   └── transformations.py
│   ├── run
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── decompose.py
│   │   └── grid_search.py
│   └── __init__.py
├── setup.cfg
├── setup.py
├── LICENSE.txt
├── documentation.md
├── README.md
└── img
    └── decomposition.svg

--------------------------------------------------------------------------------
/slicetca/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/slicetca/plotting/additional.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/slicetca/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .decompositions import *
2 | 
3 | __all__ = ['PartitionTCA', 'SliceTCA', 'TCA']
--------------------------------------------------------------------------------
/slicetca/invariance/__init__.py:
--------------------------------------------------------------------------------
1 | from slicetca.invariance.invariance import invariance
2 | 
3 | __all__ = ['invariance']
--------------------------------------------------------------------------------
/slicetca/plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from .factors import plot
2 | from .grid import plot_grid
3 | 
4 | __all__ = ['plot', 'plot_grid']
--------------------------------------------------------------------------------
/slicetca/run/__init__.py:
--------------------------------------------------------------------------------
1 | from .decompose import *
2 | from .grid_search import *
3 | from .utils import *
4 | 
5 | __all__ = ['decompose', 'grid_search', 'block_mask']
--------------------------------------------------------------------------------
/slicetca/__init__.py:
--------------------------------------------------------------------------------
1 | from .invariance import *
2 | from .plotting import *
3 | from .run import *
4 | 
5 | __version__ = '0.1.8'
6 | __author__ = 'Arthur Pellegrino, Heike Stein'
7 | 
8 | __all__ = ['invariance', 'decompose', 'grid_search', 'block_mask', 'plot', 'plot_grid']
--------------------------------------------------------------------------------
/slicetca/core/helper_functions.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def squared_difference(x, x_hat):
 5 |     return (x - x_hat) ** 2
 6 | 
 7 | 
 8 | def poisson_log_likelihood(spikes, rates, spikes_factorial, activation=torch.nn.functional.softplus):
 9 | 
10 |     likelihood = torch.exp(-activation(rates)) * torch.pow(activation(rates), spikes) / spikes_factorial
11 | 
12 |     return -torch.log(likelihood)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | 
 3 | from pathlib import Path
 4 | this_directory = Path(__file__).parent
 5 | long_description = (this_directory / "README.md").read_text()
 6 | 
 7 | setup(
 8 |     name='slicetca',
 9 |     packages=find_packages(exclude=['tests*']),
10 |     version='1.0.4',
11 | 
12 |     description='Package to perform Slice Tensor Component Analysis',
13 |     long_description=long_description,
14 |     long_description_content_type='text/markdown',
15 | 
16 |     url='https://github.com/arthur-pe/slicetca',
17 |     author='Arthur Pellegrino',
18 |     license='MIT',
19 |     install_requires=['torch',
20 |                       'numpy',
21 |                       'matplotlib',
22 |                       'tqdm',
23 |                       'scipy'
24 |                       ],
25 |     python_requires='>=3.10',
26 |     classifiers=[
27 |         'Intended Audience :: Science/Research',
28 |         'Topic :: Scientific/Engineering :: Mathematics',
29 |         'License :: OSI Approved :: MIT License',
30 |         'Programming Language :: Python :: 3.10',
31 |     ],
32 | )
33 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
 1 | The MIT License (MIT)
 2 | 
 3 | Copyright (c) 2023 Arthur Pellegrino, Heike Stein, N Alex Cayco-Gajic
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
--------------------------------------------------------------------------------
/slicetca/invariance/criteria.py:
--------------------------------------------------------------------------------
 1 | from itertools import combinations
 2 | import torch
 3 | from typing import Sequence
 4 | 
 5 | # Example of criteria to use for L2 optimization.
 6 | 
 7 | 
 8 | def orthogonality_component_type_wise(reconstructed_tensors_of_each_partition: Sequence[torch.Tensor]):
 9 |     """
10 |     Penalizes non-orthogonality between the reconstructed tensors of each partition/slicing.
11 | 
12 |     :param reconstructed_tensors_of_each_partition: The sum of the terms of a given partition/slicing.
13 |     :return: Torch float.
14 |     """
15 | 
16 |     l = 0
17 |     for combo in combinations(reconstructed_tensors_of_each_partition, 2):
18 |         l += torch.square(torch.sum(combo[0] * combo[1]) / torch.sqrt(torch.sum(combo[0] ** 2)) / torch.sqrt(
19 |             torch.sum(combo[1] ** 2)))
20 |     return l + l2(reconstructed_tensors_of_each_partition)
21 | 
22 | 
23 | def l2(reconstructed_tensors_of_each_partition: Sequence[torch.Tensor]):
24 |     """
25 |     Classic L_2 regularization, per reconstructed tensor of each partition/slicing.
26 | 
27 |     :param reconstructed_tensors_of_each_partition: The sum of the terms of a given partition/slicing.
28 |     :return: Torch float.
29 |     """
30 | 
31 |     l = 0
32 |     for t in reconstructed_tensors_of_each_partition:
33 |         l += (t ** 2).mean()
34 |     return l
35 | 
--------------------------------------------------------------------------------
/slicetca/invariance/invariance.py:
--------------------------------------------------------------------------------
 1 | from slicetca.invariance.iterative_invariance import sgd_invariance
 2 | from slicetca.invariance.analytic_invariance import svd_basis
 3 | from slicetca.invariance.criteria import *
 4 | from slicetca.core.decompositions import SliceTCA
 5 | 
 6 | dict_L2_invariance_objectives = {'regularization': l2}
 7 | dict_L3_invariance_functions = {'svd': svd_basis}
 8 | 
 9 | def invariance(model: SliceTCA,
10 |                L2: str = 'regularization',
11 |                L3: str = 'svd',
12 |                **kwargs):
13 |     """
14 |     High-level function for invariance optimization.
15 |     Note: modifies the model in place; deepcopy your model if you want a copy of the components before invariance optimization.
16 | 
17 |     :param model: A sliceTCA model.
18 |     :param L2: String, currently only supports 'regularization', you may add additional objectives.
19 |     :param L3: String, currently only supports 'svd'.
20 |     :param kwargs: Key-word arguments to be passed to the L2 and L3 optimization functions. See iterative_invariance.py.
21 |     :return: model with modified components.
22 |     """
23 | 
24 |     if sum([r != 0 for r in model.ranks]) > 1:
25 |         model = sgd_invariance(model, objective_function=dict_L2_invariance_objectives[L2], **kwargs)
26 |     model = dict_L3_invariance_functions[L3](model, **kwargs)
27 | 
28 |     return model
--------------------------------------------------------------------------------
/documentation.md:
--------------------------------------------------------------------------------
 1 | # Documentation
 2 | 
 3 | Here is a brief list of the functionalities provided by this repository. Additional information is provided in their docstrings or by calling `help(function_name)`.
 4 | 
 5 | ## High-level functions
 6 | 
 7 | To get started quickly, the following high-level functions can be used. They can all be imported at once with `from slicetca import *`.
 8 | 
 9 | * `decompose` is the high-level function to decompose a data tensor.
10 | * `grid_search` is for determining the number of components.
11 | * `plot` allows plotting sliceTCA and TCA components.
12 | 
13 | We recommend having a look at our notebooks for further details.
14 | 
15 | ## Low-level functions
16 | 
17 | For more specific use-cases, low-level functions might be preferred.
18 | 
19 | * `_core.decompositions.SliceTCA`
20 |   * `.fit(self, data, ...)` fits the components of a sliceTCA object to some data.
21 |   * `.set_components(self, components)` sets the model's components.
22 |   * `.get_components(self, detach=False, numpy=False)` returns the model's components. To backpropagate through the tensor, set detach=False.
23 | * `_invariance`
24 |   * `.analytic_invariance.svd_basis(model)` sets the vectors of each slice type to an orthonormal basis and sorts them by variance explained.
25 |   * `.sgd_invariance(model, objective_function, transformation, ...)` allows optimizing the components w.r.t. some objective function while fixing the overall tensor (a sketch of low-level usage follows below).
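
As an illustration, here is a minimal sketch of the low-level workflow; the data shape, ranks, learning rate and iteration budget are arbitrary placeholders (the high-level `decompose` wraps these steps for you):

```python
import torch

from slicetca.core.decompositions import SliceTCA
from slicetca.invariance.analytic_invariance import svd_basis

data = torch.randn(10, 20, 30)  # e.g. a (trial, neuron, time) tensor

# A model with 2 trial-, 0 neuron- and 3 time-slicing components.
model = SliceTCA(dimensions=list(data.shape), ranks=(2, 0, 3))

# Low-level fitting: the optimizer is supplied by the user.
optimizer = torch.optim.Adam(model.parameters(), lr=5*10**-3)
model.fit(data, optimizer, max_iter=1000)

# Rotate the vectors of each slice type to an orthonormal basis.
model = svd_basis(model)

# List (over component types) of lists (over factors) of numpy arrays.
components = model.get_components(numpy=True)
```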
26 | 27 | -------------------------------------------------------------------------------- /slicetca/invariance/analytic_invariance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def svd_basis(model, **kwargs): 5 | """ 6 | Sets the vectors of each slice type to an orthonormal basis. 7 | 8 | :param model: SliceTCA model 9 | :param kwargs: ignored 10 | :return: model with new components. 11 | """ 12 | 13 | device = model.device 14 | ranks = model.ranks 15 | 16 | new_components = [[None, None] for i in range(len(ranks))] 17 | for i in range(len(ranks)): 18 | if ranks[i] != 0: 19 | constructed = model.construct_single_partition(i) 20 | flattened_constructed = constructed.permute([i]+[q for q in range(len(ranks)) if q != i]) 21 | flattened_constructed = flattened_constructed.reshape(model.dimensions[i],-1).transpose(0,1) 22 | 23 | U, S, V = torch.linalg.svd(flattened_constructed.detach().cpu(), full_matrices=False) 24 | U, S, V = U[:,:ranks[i]], S[:ranks[i]], V[:ranks[i]] 25 | U, S, V = U.to(device), S.to(device), V.to(device) 26 | 27 | US = (U @ torch.diag(S)) 28 | slice = US.transpose(0,1).reshape([ranks[i]]+[model.dimensions[q] for q in range(len(ranks)) if q != i]) 29 | 30 | new_components[i][0] = V 31 | new_components[i][1] = slice 32 | else: 33 | new_components[i][0] = torch.zeros_like(model.vectors[i][0]) 34 | new_components[i][1] = torch.zeros_like(model.vectors[i][1]) 35 | 36 | model.set_components(new_components) 37 | 38 | return model 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SliceTCA 2 | 3 | This library provides tools to perform [sliceTCA](https://www.biorxiv.org/content/10.1101/2023.03.01.530616v1). 4 | 5 | ___ 6 | 7 |
 8 | [figure: schematic of the sliceTCA decomposition — see img/decomposition.svg]
 9 | 
10 | 
11 | ## Installation
12 | 
13 | ```commandline
14 | pip install slicetca
15 | ```
16 | 
17 | ## Full documentation
18 | 
19 | The full documentation can be found [here](https://github.com/arthur-pe/slicetca/blob/master/documentation.md).
20 | 
21 | ## Examples
22 | 
23 | ### Quick example
24 | 
25 | ```python
26 | import slicetca
27 | import torch
28 | from matplotlib import pyplot as plt
29 | 
30 | device = ('cuda' if torch.cuda.is_available() else 'cpu')
31 | 
32 | # your_data is a numpy array of shape (trials, neurons, time).
33 | data = torch.tensor(your_data, dtype=torch.float, device=device)
34 | 
35 | # The tensor is decomposed into 2 trial-, 0 neuron- and 3 time-slicing components.
36 | components, model = slicetca.decompose(data, (2,0,3))
37 | 
38 | # For a non-positive decomposition, we apply uniqueness (invariance) constraints.
39 | model = slicetca.invariance(model)
40 | 
41 | slicetca.plot(model)
42 | 
43 | plt.show()
44 | ```
45 | 
46 | ### Notebook
47 | 
48 | See the [example notebook](https://github.com/arthur-pe/slicetca/blob/master/sliceTCA_notebook_1.ipynb) for an application of sliceTCA to publicly available neural data.
49 | 
50 | 
51 | Open In Colab
52 | 
53 | 
54 | ## Reference
55 | 
56 | A. Pellegrino@, H. Stein, N. A. Cayco-Gajic@. (2024). Dimensionality reduction beyond neural subspaces with slice tensor component analysis. *Nature Neuroscience*. [https://www.nature.com/articles/s41593-024-01626-2](https://www.nature.com/articles/s41593-024-01626-2)
--------------------------------------------------------------------------------
/slicetca/invariance/helper.py:
--------------------------------------------------------------------------------
 1 | from slicetca.core.decompositions import SliceTCA
 2 | 
 3 | import torch
 4 | from typing import Sequence
 5 | 
 6 | def mm(a: torch.Tensor, b: torch.Tensor):
 7 |     """
 8 |     Performs generalized matrix multiplication, e.g. (ijq) x (qkl) -> (ijkl).
 9 |     :param a: torch.Tensor of arbitrary order.
10 |     :param b: torch.Tensor whose first dimension matches the last dimension of a.
11 |     :return: The contraction of a and b over that shared dimension.
12 |     """
13 |     temp1 = [chr(105+i) for i in range(len(a.size()))]
14 |     temp2 = [chr(105+len(a.size())-1+i) for i in range(len(b.size()))]
15 |     indexes1 = ''.join(temp1)
16 |     indexes2 = ''.join(temp2)
17 |     rhs = ''.join(temp1[:-1])+''.join(temp2[1:])
18 |     formula = indexes1+','+indexes2+'->'+rhs
19 |     return torch.einsum(formula,[a,b])
20 | 
21 | 
22 | def batch_outer(a: torch.Tensor, b: torch.Tensor):
23 |     temp1 = [chr(105 + i + 1) for i in range(len(a.size()) - 1)]
24 |     temp2 = [chr(105 + len(a.size()) + i + 1) for i in range(len(b.size()) - 1)]
25 |     indexes1 = ''.join(temp1)
26 |     indexes2 = ''.join(temp2)
27 |     formula = chr(105) + indexes1 + ',' + chr(105) + indexes2 + '->' + chr(105) + indexes1 + indexes2
28 |     return torch.einsum(formula, [a, b])
29 | 
30 | 
31 | def construct_per_type(model: SliceTCA, components: Sequence[Sequence[torch.Tensor]]):
32 |     """
33 |     :param model: SliceTCA model.
34 |     :param components: The components to construct.
35 |     :return: List of reconstructed tensors, one per component type.
36 | """ 37 | 38 | temp = [torch.zeros(model.dimensions).to(model.device) for i in range(len(components))] 39 | 40 | for i in range(len(components)): 41 | for j in range(model.ranks[i]): 42 | temp[i] += construct_single_component(model, components, i, j) 43 | return temp 44 | 45 | 46 | def construct_single_component(model: SliceTCA, components: Sequence[Sequence[torch.Tensor]], partition: int, k: int): 47 | 48 | temp2 = [model.positive_function[partition][q](components[partition][q][k]) for q in range(len(components[partition]))] 49 | outer = torch.einsum(model.einsums[partition], temp2) 50 | outer = outer.permute(model.inverse_permutations[partition]) 51 | 52 | return outer 53 | -------------------------------------------------------------------------------- /slicetca/invariance/iterative_invariance.py: -------------------------------------------------------------------------------- 1 | from .transformations import * 2 | from .criteria import * 3 | from ..core.decompositions import SliceTCA 4 | 5 | import torch 6 | import copy 7 | import tqdm 8 | from typing import Callable 9 | 10 | 11 | def sgd_invariance(model: SliceTCA, 12 | objective_function: Callable = l2, 13 | transformation: object = None, 14 | learning_rate: float = 10**-2, 15 | max_iter: int = 10000, 16 | min_std: float = 10**-3, 17 | iter_std: int = 100, 18 | verbose: bool = False, 19 | progress_bar: bool = True, 20 | **kwargs): 21 | """ 22 | Enables optimizing the components w.r.t. some objective function while fixing the overall reconstructed tensor. 23 | 24 | :param model: SliceTCA model. 25 | :param objective_function: The objective to optimize. 26 | :param transformation: transformations.TransformationBetween(model) or 27 | transformations.TransformationWithin(model) or 28 | nn.Sequential(TransformationWithin(model), TransformationBetween(model)) 29 | :param learning_rate: Learning rate for the optimizer (default Adam). 30 | :param max_iter: Maximum number of iterations. 31 | :param min_std: Minimum std of the last iter_std iterations under which to assume the model has converged. 32 | :param iter_std: See min_std. 33 | :param verbose: Whether to print the loss. 34 | :param progress_bar: Whether to have a progress bar. 35 | :param kwargs: ignored. 36 | :return: model with the modified components. 
37 |     """
38 | 
39 |     if transformation is None: transformation = TransformationBetween(model)
40 | 
41 |     model.requires_grad_(False)
42 | 
43 |     optim = torch.optim.Adam(transformation.parameters(), lr=learning_rate)
44 | 
45 |     components = model.get_components(detach=True)
46 | 
47 |     losses = []
48 | 
49 |     iterator = tqdm.tqdm(range(max_iter)) if progress_bar else range(max_iter)
50 | 
51 |     for iteration in iterator:
52 | 
53 |         components_transformed = transformation(copy.deepcopy(components))
54 | 
55 |         components_transformed_constructed = construct_per_type(model, components_transformed)
56 |         l = objective_function(components_transformed_constructed)
57 | 
58 |         if verbose: print('Iteration:', iteration, '\tloss:', l.item())
59 |         if progress_bar: iterator.set_description("Invariance loss: " + str(l.item())[:10] + ' ')
60 | 
61 |         optim.zero_grad()
62 |         l.backward()
63 |         optim.step()
64 | 
65 |         losses.append(l.item())
66 | 
67 |         if len(losses)>iter_std and np.array(losses[-100:]).std()<min_std: break
68 | 
69 |     model.set_components(components_transformed)
70 | 
71 |     return model
--------------------------------------------------------------------------------
/slicetca/run/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | import warnings
 4 | 
 5 | from typing import Sequence
 6 | 
 7 | 
 8 | def block_mask(dimensions: Sequence[int],
 9 |                train_blocks_dimensions: Sequence[int],
10 |                test_blocks_dimensions: Sequence[int],
11 |                fraction_test: float = 0.1,
12 |                exact: bool = False,
13 |                number_blocks: int = None,
14 |                device: str = 'cpu'):
15 |     """
16 |     Builds random train and test masks whose held-out entries are grouped into blocks.
17 |     Each test block is surrounded by a larger block excluded from the train mask,
18 |     limiting leakage between neighboring train and test entries.
19 | 
20 |     :param dimensions: Dimensions of the data tensor to be masked.
21 |     :param train_blocks_dimensions: Blocks excluded from the train mask have size (2*train_blocks_dimensions+1).
22 |     :param test_blocks_dimensions: Test blocks have size (2*test_blocks_dimensions+1).
23 |     :param fraction_test: Fraction of the tensor entries to assign to test blocks.
24 |     :param exact: Whether to sample exactly the derived number of blocks, or in expectation.
25 |     :param number_blocks: Deprecated, use fraction_test instead.
26 |     :param device: Torch device.
27 |     :return: (train_mask, test_mask), boolean tensors of shape dimensions.
28 |     """
29 | 
30 |     dimensions = list(dimensions)
31 |     valence = len(dimensions)
32 |     flattened_max_dim = int(np.prod(dimensions))
33 | 
34 |     if not np.all(np.array(train_blocks_dimensions)-np.array(test_blocks_dimensions)>=0):
35 |         raise Exception('For all i it should be that train_blocks_dimensions[i]>=test_blocks_dimensions[i].')
36 | 
37 |     if number_blocks is None:
38 |         number_blocks = int(fraction_test * flattened_max_dim / np.prod(2*np.array(test_blocks_dimensions)+1))
39 |     else:
40 |         warnings.warn('The parameter number_blocks is deprecated, use fraction_test instead.', DeprecationWarning)
41 | 
42 |     if exact:
43 |         start = torch.zeros(flattened_max_dim, device=device)
44 |         start[:number_blocks] = 1
45 |         start = start[torch.randperm(flattened_max_dim, device=device)]
46 |         start = start.reshape(dimensions)
47 |     else:
48 |         start = (torch.rand(tuple(dimensions), device=device) < fraction_test).long()
49 | 
50 |     start_index = start.nonzero()
51 |     number_blocks = len(start_index)
52 | 
53 |     # Build outer-blocks mask
54 |     a = [[slice(torch.clip(start_index[j][i]-train_blocks_dimensions[i],min=0, max=dimensions[i]),
55 |                 torch.clip(start_index[j][i]+train_blocks_dimensions[i]+1,min=0, max=dimensions[i]))
56 |           for i in range(valence)] for j in range(number_blocks)]
57 | 
58 |     train_mask = torch.full(dimensions, True, device=device)
59 | 
60 |     for j in a: train_mask[j] = 0
61 | 
62 |     # Build inner-blocks tensor
63 |     a = [[slice(torch.clip(start_index[j][i]-test_blocks_dimensions[i],min=0, max=dimensions[i]),
64 |                 torch.clip(start_index[j][i]+test_blocks_dimensions[i]+1,min=0, max=dimensions[i]))
65 |           for i in range(valence)] for j in range(number_blocks)]
66 | 
67 |     test_mask = torch.full(dimensions, False, device=device)
68 | 
69 |     for j in a: test_mask[j] = 1
70 | 
71 |     return train_mask, test_mask
72 | 
--------------------------------------------------------------------------------
/slicetca/run/decompose.py:
--------------------------------------------------------------------------------
 1 | from slicetca.core import SliceTCA, TCA
 2 | from slicetca.core.helper_functions import squared_difference, poisson_log_likelihood
 3 | 
 4 | import torch
 5 | from typing import Union, Sequence
 6 | import numpy as np
 7 | import scipy
 8 | from functools import partial
 9 | 
10 | 
11 | def decompose(data: Union[torch.Tensor, np.ndarray],
12 |               number_components: Union[Sequence[int], int],
13 |               positive: bool = False,
14 |               initialization: str = 'uniform',
15 |               learning_rate: float = 5*10**-3,
16 |               batch_prop: float = 0.2,
17 |               max_iter: int = 10000,
18 |               min_std: float = 10**-5,
19 |               iter_std: int = 100,
20 |               mask: torch.Tensor = None,
21 |               verbose: bool = False,
22 |               progress_bar: bool = True,
23 |               seed: int = 7,
24 |               weight_decay: float = None,
25 |               batch_prop_decay: int = 
1): 26 | """ 27 | High-level function to decompose a data tensor into a SliceTCA or TCA decomposition. 28 | 29 | :param data: Torch tensor. 30 | :param number_components: If list or tuple number of sliceTCA components, else number of TCA components. 31 | :param positive: Whether to use a positive decomposition. Defaults the initialization to 'uniform-positive'. 32 | :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1). 33 | :param learning_rate: Learning rate of the optimizer. 34 | :param batch_prop: Proportion of entries used to compute the gradient at every training iteration. 35 | :param max_iter: Maximum training iterations. 36 | :param min_std: Minimum std of the loss under which to return. 37 | :param iter_std: Number of iterations over which this std is computed. 38 | :param mask: Entries which are not used to compute the gradient at any training iteration. 39 | :param verbose: Whether to print the loss at every step. 40 | :param progress_bar: Whether to have a tqdm progress bar. 41 | :param seed: Torch seed. 42 | :param weight_decay: Decay of the parameters. If None defaults to Adam, else AdamW. 43 | :param batch_prop_decay: Exponential decay steps of the proportion of entries not used to compute the gradient. 44 | :return: components: A list (over component types) of lists (over factors) of rank x component_shape tensors. 45 | :return: model: A SliceTCA or TCA model. It can be used to access the losses over training and much more. 46 | """ 47 | 48 | torch.manual_seed(seed) 49 | 50 | if isinstance(data, np.ndarray): data = torch.tensor(data, device='cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | if data.dtype != torch.long: 53 | loss_function = squared_difference 54 | else: 55 | spikes_factorial = torch.tensor(scipy.special.factorial(data.numpy(force=True)), device=data.device) 56 | loss_function = partial(poisson_log_likelihood, spikes_factorial=spikes_factorial) 57 | 58 | dimensions = list(data.shape) 59 | 60 | if isinstance(number_components, int): decomposition = TCA 61 | else: decomposition = SliceTCA 62 | 63 | model = decomposition(dimensions, number_components, positive, initialization, device=data.device) 64 | 65 | if weight_decay is None: optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 66 | else: optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 67 | 68 | for i in range(1,batch_prop_decay+1): 69 | model.fit(data, optimizer, loss_function, 70 | 1-(1-batch_prop)**i, max_iter, min_std, iter_std, mask, verbose, progress_bar) 71 | 72 | return model.get_components(numpy=True), model 73 | -------------------------------------------------------------------------------- /slicetca/plotting/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from matplotlib import pyplot as plt 4 | from matplotlib.ticker import MaxNLocator 5 | 6 | from typing import Literal, Sequence 7 | 8 | 9 | def plot_grid(loss_grid: np.ndarray, 10 | min_ranks: Sequence = (0, 0, 0), 11 | data: np.ndarray = None, 12 | elbow: float = 0.9, 13 | quantile: float = 0.0, 14 | vmin: float = None, 15 | vmax: float = None, 16 | reduction: Literal['mean', 'min'] = 'min', 17 | variables: Sequence[str] = ('trial', 'neuron', 'time'), 18 | ): 19 | """ 20 | :param loss_grid: (trial x neuron x time x batch) grid of cross-validated losses as a function of the number of components 21 | :param min_ranks: minimum ranks of the gridsearch 
(default (0, 0, 0)).
22 |     :param data: Data tensor, used to compute the elbow threshold relative to its squared norm.
23 |     :param elbow: Fraction of the best model's performance, relative to the data's squared norm, used for the elbow.
24 |     :param quantile: Quantile of the reduced losses used for the default color limits.
25 |     :param vmin: Lower color limit; defaults to the quantile of the reduced losses.
26 |     :param vmax: Upper color limit; defaults to the (1-quantile) quantile of the reduced losses.
27 |     :param reduction: How to aggregate the losses over seeds, 'mean' or 'min'.
28 |     :param variables: Axis labels of the three tensor dimensions.
29 |     :return: None, draws the grid of cross-validated losses.
30 |     """
31 | 
32 |     max_ranks = tuple(np.array(min_ranks)+np.array(loss_grid.shape[:3]))
33 | 
34 |     match reduction:
35 |         case 'mean':
36 |             reduced_loss_grid = loss_grid.mean(axis=-1)
37 |         case 'min':
38 |             reduced_loss_grid = loss_grid.min(axis=-1)
39 |         case _:
40 |             raise Exception('Reduction should be mean or min.')
41 | 
42 |     min_index = np.unravel_index(np.argmin(reduced_loss_grid), reduced_loss_grid.shape)
43 | 
44 |     if data is not None:
45 |         elbow_threshold = np.min(reduced_loss_grid)*elbow+np.mean(data**2)*(1-elbow)
46 | 
47 |     nb_plot = reduced_loss_grid.shape[0]
48 |     if vmin is None: vmin = np.quantile(reduced_loss_grid, quantile)
49 |     if vmax is None: vmax = np.quantile(reduced_loss_grid, 1-quantile)
50 | 
51 |     fig = plt.figure(figsize=(3*nb_plot,3), constrained_layout=True)
52 |     axes = [fig.add_subplot(1, nb_plot, i+1) for i in range(nb_plot)]
53 | 
54 |     for i in range(nb_plot):
55 |         im = axes[i].imshow(reduced_loss_grid[i], cmap='pink_r', origin='lower',
56 |                             vmin=vmin, vmax=vmax, extent=(min_ranks[2]-0.5, max_ranks[2]-0.5, min_ranks[1]-0.5, max_ranks[1]-0.5))
57 | 
58 |         if data is not None:
59 |             elbow_neuron = np.argmax(reduced_loss_grid[i] < elbow_threshold, axis=1).astype(float)
60 |             elbow_time = np.argmax(reduced_loss_grid[i] < elbow_threshold, axis=0).astype(float)
61 |             elbow_time[np.all(reduced_loss_grid[i] >= elbow_threshold, axis=0)] = np.nan
62 |             elbow_neuron[np.all(reduced_loss_grid[i] >= elbow_threshold, axis=1)] = np.nan
63 |             axes[i].plot(np.arange(min_ranks[2], max_ranks[2]), min_ranks[1]+elbow_time, color=(0, 0, 0), alpha=0.3, linewidth=2.0)
64 |             axes[i].plot(min_ranks[2]+elbow_neuron, np.arange(min_ranks[1], max_ranks[1]), color=(0, 0, 0), alpha=0.3, linewidth=2.0)
65 | 
66 |         axes[i].set_aspect('equal')
67 |         axes[i].set_title('$R_{'+variables[0]+'}='+f'{min_ranks[0]+i}$')
68 |         axes[i].set_ylabel('$R_{'+variables[1]+'}$')
69 |         axes[i].set_xlabel('$R_{'+variables[2]+'}$')
70 | 
71 |         axes[i].yaxis.set_major_locator(MaxNLocator(integer=True))
72 |         axes[i].xaxis.set_major_locator(MaxNLocator(integer=True))
73 | 
74 |     fig.colorbar(im,fraction=0.046, pad=0.04)
75 | 
76 |     axes[min_index[0]].scatter(min_ranks[2]+min_index[2], min_ranks[1]+min_index[1],
77 |                                color=(0.9,0.2, 0.2), marker='*', s=200)
78 | 
79 |     if data is not None:
80 |         elbow_index = np.stack(np.meshgrid(*tuple([np.arange(min_ranks[i], max_ranks[i]) for i in range(3)]), indexing='ij'), axis=-1).sum(axis=-1).astype(float)
81 |         elbow_index[(reduced_loss_grid>=elbow_threshold)] = np.nan
82 |         if not np.all(np.isnan(elbow_index)):
83 |             elbow_index = np.unravel_index(np.argmin(elbow_index, axis=None), elbow_index.shape)
84 |             axes[elbow_index[0]].scatter(min_ranks[2]+elbow_index[2], min_ranks[1]+elbow_index[1],
85 |                                          color=(0.9,0.2, 0.2), marker='*', s=200, facecolor=(1, 1, 1, 0))
86 | 
--------------------------------------------------------------------------------
/slicetca/tests/tests.py:
--------------------------------------------------------------------------------
  1 | from slicetca.core.decompositions import PartitionTCA, SliceTCA
  2 | from slicetca.invariance.analytic_invariance import svd_basis
  3 | from slicetca.invariance.iterative_invariance import TransformationWithin, TransformationBetween, sgd_invariance, l2
  4 | from slicetca.run.decompose import 
decompose 5 | from slicetca.run.grid_search import grid_search 6 | 7 | import torch 8 | from torch import nn 9 | import numpy as np 10 | 11 | 12 | dimensions = (5, 6, 7, 8) 13 | ranks = (1, 0, 2, 3) 14 | 15 | 16 | def test_analytic_invariance(): 17 | 18 | m = SliceTCA(dimensions, ranks, initialization='uniform', device=device) 19 | 20 | a = m.construct().detach() 21 | m = svd_basis(m) 22 | b = m.construct() 23 | 24 | assert torch.mean(torch.square(a - b)).item()<10**-6, 'Analytic invariance changes L1' 25 | 26 | print('test_analytic_invariance passed') 27 | 28 | 29 | def test_iterative_invariance_within_between(): 30 | 31 | m = SliceTCA(dimensions, ranks, initialization='uniform', device=device) 32 | 33 | a = m.construct().detach() 34 | 35 | transfo = nn.Sequential(TransformationBetween(m), TransformationWithin(m)) 36 | 37 | m = sgd_invariance(m, l2, transformation=transfo, max_iter=3, learning_rate=0.01, progress_bar=progress_bar) 38 | b = m.construct() 39 | 40 | assert torch.mean(torch.square(a - b)).item()<10**-6, 'Iterative invariance L2, L3 changes L1' 41 | 42 | print('test_invariance_between passed') 43 | 44 | 45 | def test_iterative_invariance_within(): 46 | m = SliceTCA(dimensions, ranks, initialization='uniform', device=device) 47 | 48 | a = m.construct().detach() 49 | 50 | transfo = TransformationWithin(m) 51 | 52 | m = sgd_invariance(m, l2, transformation=transfo, max_iter=3, learning_rate=0.01, progress_bar=progress_bar) 53 | b = m.construct() 54 | 55 | assert torch.mean(torch.square(a - b)).item()<10**-6, 'Iterative invariance L3 changes L1' 56 | 57 | print('test_invariance_within passed') 58 | 59 | 60 | def test_iterative_invariance_between(): 61 | m = SliceTCA(dimensions, ranks, initialization='uniform', device=device) 62 | 63 | a = m.construct().detach() 64 | 65 | transfo = TransformationBetween(m) 66 | 67 | m = sgd_invariance(m, l2, transformation=transfo, max_iter=3, learning_rate=0.01, progress_bar=progress_bar) 68 | b = m.construct() 69 | 70 | assert torch.mean(torch.square(a - b)).item()<10**-6, 'Iterative invariance L2 changes L1' 71 | 72 | print('test_iterative_invariance_between passed') 73 | 74 | 75 | def test_gridsearch(): 76 | 77 | data = SliceTCA(dimensions, ranks, device=device).construct().detach() 78 | 79 | mask_train = torch.rand_like(data) < 0.5 80 | mask_test = mask_train & (torch.rand_like(data) < 0.5) 81 | 82 | loss_grid, seed_grid = grid_search(data, ranks, learning_rate=10 ** -3, max_iter=3, sample_size=2, 83 | processes_sample=2, 84 | processes_grid=2, mask_train=mask_train, mask_test=mask_test) 85 | 86 | print('test_gridsearch passed') 87 | 88 | 89 | def test_decompose(): 90 | 91 | a = SliceTCA(dimensions, ranks, device=device).construct().detach() 92 | 93 | c, m = decompose(a, ranks) 94 | 95 | b = m.construct() 96 | 97 | assert torch.mean(torch.square(a - b)).item()<10**-6, 'Decompose fails to decompose a reconstructed tensor' 98 | 99 | print('test_decompose passed') 100 | 101 | 102 | def test_fit(): 103 | t = PartitionTCA(dimensions[:3], [[[0], [1, 2]], [[2, 0], [1]]], [1, 2]).construct().detach() 104 | p = PartitionTCA(dimensions[:3], [[[0], [1, 2]], [[2, 0], [1]]], [1, 2]) 105 | 106 | optim = torch.optim.Adam(p.parameters(), lr=10 ** -4) 107 | 108 | p.fit(t, optimizer=optim, min_std=10 ** -6, batch_prop=0.1, max_iter=3 * 10 ** 4, progress_bar=progress_bar) 109 | 110 | mse = torch.mean(torch.square(t - p.construct())).item() 111 | 112 | assert mse < 10 ** -3, 'Failed to decompose a reconstruction: ' + str(mse) 113 | 114 | print('test_fit passed') 
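
# Illustrative sketch (hypothetical, not part of the original test suite):
# decompose() switches to a Poisson likelihood when the data tensor has an
# integer dtype (see slicetca/run/decompose.py). The shapes, iteration budget
# and assertion below are assumptions; to run it, call it from the __main__
# block below alongside the other tests.
def test_decompose_poisson_smoke():

    rates = SliceTCA(dimensions, ranks, positive=True, device=device).construct().detach()
    spikes = torch.poisson(rates).long()  # integer dtype triggers the Poisson loss

    _, m = decompose(spikes, ranks, max_iter=3, progress_bar=progress_bar)

    assert m.construct().shape == spikes.shape, 'Poisson decomposition returned the wrong shape'

    print('test_decompose_poisson_smoke passed')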
115 | 116 | 117 | if __name__ == '__main__': 118 | 119 | torch.manual_seed(7) 120 | 121 | device = ('cuda' if torch.cuda.is_available() else 'cpu') 122 | 123 | progress_bar = False 124 | 125 | test_analytic_invariance() 126 | test_iterative_invariance_within_between() 127 | test_iterative_invariance_between() 128 | test_iterative_invariance_within() 129 | test_gridsearch() 130 | test_fit() 131 | 132 | print('All tests passed.') 133 | -------------------------------------------------------------------------------- /slicetca/invariance/transformations.py: -------------------------------------------------------------------------------- 1 | from .helper import * 2 | 3 | from itertools import combinations 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | 9 | class TransformationBetween(nn.Module): 10 | """ 11 | Transformation between sliceTCA component types. 12 | """ 13 | 14 | def __init__(self, model): 15 | super(TransformationBetween, self).__init__() 16 | 17 | self.number_components = len(model.ranks) 18 | self.ranks = model.ranks 19 | self.partitions = model.partitions 20 | self.dims = model.dimensions 21 | self.device = model.device 22 | 23 | self.free_vectors_combinations = nn.ModuleList( 24 | [nn.ParameterList([nn.Parameter(torch.tensor(0.0, device=self.device)) for j in range(self.number_components)]) for i in 25 | range(self.number_components)]) 26 | 27 | self.components = model.get_components() 28 | 29 | self.remaining_index_combinations = [[None for j in range(self.number_components)] for i in 30 | range(self.number_components)] 31 | self.remaining_dims_combinations = [[None for j in range(self.number_components)] for i in 32 | range(self.number_components)] 33 | 34 | for combination in combinations(list(range(self.number_components)), 2): 35 | if self.ranks[combination[0]] != 0 and self.ranks[combination[1]] != 0: 36 | temp = set(self.partitions[combination[0]][1]) 37 | temp = temp.intersection(set(self.partitions[combination[1]][1])) 38 | self.remaining_index_combinations[combination[0]][combination[1]] = list(temp) 39 | 40 | self.remaining_dims_combinations[combination[0]][combination[1]] = [self.dims[i] for i in temp] 41 | remaining_dims = self.remaining_dims_combinations[combination[0]][combination[1]] 42 | 43 | free_vectors_dim = [self.ranks[combination[0]], self.ranks[combination[1]]] + remaining_dims 44 | free_vectors = nn.Parameter(torch.randn(free_vectors_dim, device=self.device)) 45 | self.free_vectors_combinations[combination[0]][combination[1]] = free_vectors 46 | 47 | def forward(self, components): 48 | 49 | for combination in combinations(list(range(self.number_components)), 2): 50 | if self.ranks[combination[0]] != 0 and self.ranks[combination[1]] != 0: 51 | a_index = self.partitions[combination[0]][0][0] 52 | b_index = self.partitions[combination[1]][0][0] 53 | 54 | A_indexes = [b_index] + self.remaining_index_combinations[combination[0]][combination[1]] 55 | B_indexes = [a_index] + self.remaining_index_combinations[combination[0]][combination[1]] 56 | 57 | perm_B = [A_indexes.index(i) for i in self.partitions[combination[0]][1]] 58 | perm_A = [B_indexes.index(i) for i in self.partitions[combination[1]][1]] 59 | 60 | free_vectors = self.free_vectors_combinations[combination[0]][combination[1]] 61 | A = batch_outer(components[combination[0]][0], free_vectors) 62 | B = batch_outer(components[combination[1]][0], free_vectors.transpose(0, 1)) 63 | 64 | A = A.sum(dim=0) 65 | B = B.sum(dim=0) 66 | 67 | A = A.transpose(0, 1) 68 | B = B.transpose(0, 
1) 69 | 70 | A = A.permute([0] + [1 + i for i in perm_A]) 71 | B = B.permute([0] + [1 + i for i in perm_B]) 72 | 73 | components[combination[0]][1] = components[combination[0]][1] + B 74 | components[combination[1]][1] = components[combination[1]][1] - A 75 | 76 | return components 77 | 78 | 79 | class TransformationWithin(nn.Module): 80 | """ 81 | Transformation within sliceTCA component types. 82 | """ 83 | 84 | def __init__(self, model): 85 | super().__init__() 86 | 87 | self.ranks = model.ranks 88 | self.number_components = len(model.ranks) 89 | self.device = model.device 90 | 91 | self.free_gl = nn.ParameterList([nn.Parameter(torch.eye(i, device=self.device)+torch.randn((i,i), 92 | device=self.device)/np.sqrt(3*i)) for i in self.ranks]) 93 | 94 | def forward(self, components): 95 | 96 | for i in range(self.number_components): 97 | if self.ranks[i] != 0: 98 | components[i][0] = mm(self.free_gl[i].T, components[i][0]) 99 | components[i][1] = mm(torch.linalg.inv(self.free_gl[i]), components[i][1]) 100 | 101 | return components 102 | -------------------------------------------------------------------------------- /slicetca/run/grid_search.py: -------------------------------------------------------------------------------- 1 | from slicetca.run.decompose import decompose 2 | 3 | import multiprocessing as mp 4 | from functools import partial 5 | from concurrent.futures import ProcessPoolExecutor as Pool 6 | from tqdm import tqdm 7 | import torch 8 | import numpy as np 9 | 10 | from typing import Sequence, Union 11 | 12 | 13 | # To be fixed: high memory usage when using GPU. 14 | 15 | def grid_search(data: Union[torch.Tensor], # Only works with torch.Tensor atm 16 | max_ranks: Sequence[int], 17 | mask_train: torch.Tensor = None, 18 | mask_test: torch.Tensor = None, 19 | min_ranks: Sequence[int] = None, 20 | sample_size: int = 1, 21 | processes_sample: int = 1, 22 | processes_grid: int = 1, 23 | seed: int = 7, 24 | **kwargs): 25 | """ 26 | Performs a gridsearch over different number of components (ranks) to see which has the lowest cross-validated loss. 27 | 28 | :param data: Data tensor to decompose. 29 | :param max_ranks: Maximum number of components of each type. 30 | :param mask_train: Mask representing over which entries to compute the backpropagated loss. None is full tensor. 31 | :param mask_test: Mask representing over which entries to compute the loss for validation. None is full tensor. 32 | :param min_ranks: Minimum number of components of each type. 33 | :param sample_size: Number of seeds to use for a given number of components. 34 | :param processes_sample: Number of processes (threads) to use for a given number of components across seeds. 35 | :param processes_grid: Number of processes (threads) to use over different number of components. 36 | :param seed: Numpy seed. 37 | :param kwargs: Same kwargs as decompose. 38 | :return: A (max_rank_1-min_rank_1, max_rank_2-min_rank_2, ..., sample_size) ndarray of losses masked entries. 
39 | """ 40 | 41 | np.random.seed(seed) 42 | 43 | try: 44 | mp.set_start_method('spawn', force=True) 45 | except RuntimeError: 46 | pass 47 | 48 | if min_ranks is None: min_ranks = [0 for i in max_ranks] 49 | max_ranks = [i+1 for i in max_ranks] 50 | rank_span = [max_ranks[i]-min_ranks[i] for i in range(len(max_ranks))] 51 | 52 | grid = get_grid_sample(min_ranks, max_ranks) 53 | grid = np.concatenate([grid, np.random.randint(10**2,10**6, grid.shape[0])[:,np.newaxis]], axis=-1) 54 | 55 | print('Grid shape:', str(rank_span), 56 | '- Samples:', sample_size, 57 | '- Grid entries:', torch.tensor(grid).size()[0], 58 | '- Number of models to fit:', torch.tensor(grid).size()[0]*sample_size) 59 | 60 | dec = partial(decompose_mp_sample, data=data, mask_train=mask_train, mask_test=mask_test, sample_size=sample_size, 61 | processes_sample=processes_sample, **kwargs) 62 | 63 | out_grid = [] 64 | with Pool(max_workers=processes_grid) as pool: 65 | iterator = tqdm(pool.map(dec, grid), total=torch.tensor(grid).size()[0]) 66 | iterator.set_description('Number of components (completed): - ', refresh=True) 67 | for i, p in enumerate(iterator): 68 | out_grid.append(p) 69 | iterator.set_description('Number of components (completed): '+str(np.unravel_index(i, tuple(max_ranks))) + ' ', refresh=True) 70 | out_grid = np.array(out_grid, dtype=np.float32) 71 | 72 | loss_grid = out_grid[:,0] 73 | seed_grid = out_grid[:,1].astype(int) 74 | 75 | loss_grid = loss_grid.reshape(rank_span+[sample_size]) 76 | seed_grid = seed_grid.reshape(rank_span+[sample_size]) 77 | 78 | return loss_grid, seed_grid 79 | 80 | 81 | def decompose_mp_sample(number_components_seed, data, mask_train, mask_test, sample_size, processes_sample, **kwargs): 82 | 83 | number_components = number_components_seed[:-1] 84 | seed = number_components_seed[-1] 85 | 86 | np.random.seed(seed) 87 | 88 | dec = partial(decompose_mp, 89 | data=data.clone(), 90 | mask_train=(mask_train.clone() if mask_train is not None else None), 91 | mask_test=(mask_test.clone() if mask_test is not None else None), 92 | **kwargs) 93 | 94 | sample = number_components[np.newaxis].repeat(sample_size, 0) 95 | seeds = np.random.randint(10**2,10**6, sample_size) 96 | 97 | sample = np.concatenate([sample, seeds[:,np.newaxis]], axis=-1) 98 | 99 | with Pool(max_workers=processes_sample) as pool: loss = np.array(list(pool.map(dec, sample))) 100 | 101 | return loss, seeds 102 | 103 | 104 | def decompose_mp(number_components_seed, data, mask_train, mask_test, *args, **kwargs): 105 | 106 | number_components, seed = number_components_seed[:-1], number_components_seed[-1] 107 | 108 | if (number_components == np.zeros_like(number_components)).all(): 109 | data_hat = 0 110 | else: 111 | _, model = decompose(data, number_components, mask=mask_train, verbose=False, progress_bar=False, *args, 112 | seed=seed, **kwargs) 113 | data_hat = model.construct() 114 | 115 | if mask_test is None: loss = torch.mean((data-data_hat)**2).item() 116 | else: loss = torch.mean(((data-data_hat)[mask_test])**2).item() 117 | 118 | return loss 119 | 120 | 121 | def get_grid_sample(min_dims, max_dims): 122 | 123 | grid = np.meshgrid(*[np.array([i for i in range(min_dims[j],max_dims[j])]) for j in range(len(max_dims))], 124 | indexing='ij') 125 | 126 | grid = np.stack(grid) 127 | 128 | return grid.reshape(grid.shape[0], -1).T 129 | -------------------------------------------------------------------------------- /slicetca/plotting/factors.py: -------------------------------------------------------------------------------- 
 1 | from matplotlib import pyplot as plt
 2 | import numpy as np
 3 | from typing import Sequence, Union
 4 | 
 5 | 
 6 | def plot(model,
 7 |          components: Sequence[Sequence[np.ndarray]] = None,
 8 |          variables: Sequence[str] = ('trial', 'neuron', 'time'),
 9 |          colors: Union[Sequence[np.ndarray], Sequence[Sequence[float]]] = (None, None, None),
10 |          sorting_indices: Sequence[np.ndarray] = (None, None, None),
11 |          ticks: Sequence[np.ndarray] = (None, None, None),
12 |          tick_labels: Sequence[np.ndarray] = (None, None, None),
13 |          quantile: float = 0.95,
14 |          factor_height: int = 2,
15 |          aspect: str = 'auto',
16 |          s: int = 10,
17 |          cmap: str = None,
18 |          tight_layout: bool = True,
19 |          dpi: int = 60):
20 |     """
21 |     Plots SliceTCA components. Plotting TCA or PartitionTCA components also works but is not optimized.
22 | 
23 |     :param model: SliceTCA, TCA or PartitionTCA instance.
24 |     :param components: By default, components = model.get_components(numpy=True).
25 |                        But you may pass pre-processed components (e.g. sorted neurons etc.).
26 |     :param variables: The axes labels, in the same order as the dimensions of the tensor.
27 |     :param colors: The colors of the variable (e.g. trial condition). Used only for 1-tensor factors.
28 |                    None or a 1-d variable will default to plt.plot, 2-d (trials x RGBA) to scatter.
29 |                    Note that to generate RGBA colors from an integer trial condition you may call:
30 |                    colors = matplotlib.colormaps['hsv'](condition/np.max(condition))
31 |     :param sorting_indices: Sort (e.g. trials) according to indices.
32 |     :param ticks: Can be used instead of the 0, 1, ... default indexing.
33 |     :param tick_labels: Tick labels; requires ticks to be set.
34 |     :param quantile: Quantile of the imshow cmap.
35 |     :param factor_height: Height of the 1-tensor factors. Their length is 3.
36 |     :param aspect: 'auto' will give a square-looking slice, 'equal' will preserve the ratios.
37 |     :param s: Size of the scatter dots (see the colors parameter).
38 |     :param cmap: matplotlib cmap for 2-tensor factors (slices). Defaults to inferno if positive, else seismic.
39 |     :param tight_layout: To call plt.tight_layout(). Note that constrained_layout does not work well with the gridspec used here.
40 |     :param dpi: Figure dpi. Set lower if you have many components of a given type.
41 |     :return: A list of axes which can be used for further customizing the plots.
42 |              The list has the same shape as model.get_components.
That is component_type x (slice/factor) x rank 43 | """ 44 | 45 | components = model.get_components(numpy=True) if components is None else components 46 | partitions = model.partitions 47 | positive = model.positive 48 | ranks = model.ranks 49 | 50 | # Pad the variables in case fewer than needed variables are provided 51 | variables = list(variables)+['variable '+str(i+1) for i in range(len(variables), max(len(variables), len(ranks)))] 52 | 53 | number_nonzero_components = np.sum(np.array(ranks) != 0) 54 | 55 | axes = [[[None for k in j] for j in i] for i in components] 56 | 57 | figure_size = max([sum([j.shape[0]*3 if len(j.shape) == 3 else j.shape[0]*factor_height for j in i]) for i in components]) 58 | 59 | fig = plt.figure(figsize=(number_nonzero_components*3, figure_size), dpi=dpi) 60 | gs = fig.add_gridspec(figure_size, number_nonzero_components) 61 | 62 | column = 0 63 | for i in range(len(ranks)): 64 | row = 0 65 | for j in range(ranks[i]): 66 | for k in range(len(components[i])): 67 | current_component = components[i][k][j] 68 | 69 | # =========== Plots 1-tensor factors =========== 70 | if len(list(components[i][k].shape)) == 2: 71 | ax = fig.add_subplot(gs[row:row+factor_height, column]) 72 | row += factor_height 73 | 74 | leg = partitions[i][k][0] 75 | 76 | if sorting_indices[leg] is not None: 77 | current_component = current_component[sorting_indices[leg]] 78 | 79 | if isinstance(colors[leg], np.ndarray) and len(colors[leg].shape) == 2: 80 | ax.scatter(np.arange(len(current_component)), current_component, color=colors[leg], s=s) 81 | else: 82 | ax.plot(np.arange(len(current_component)), current_component, 83 | color=(0.0, 0.0, 0.0) if colors[leg] is None else colors[leg]) 84 | 85 | if ticks[partitions[i][k][0]] is not None: 86 | ax.set_xticks(ticks[partitions[i][k][0]], tick_labels[partitions[i][k][0]]) 87 | 88 | ax.set_xlabel(variables[leg]) 89 | 90 | # =========== Plots 2-tensor factors (slices) =========== 91 | elif len(list(components[i][k].shape)) == 3: 92 | ax = fig.add_subplot(gs[row:row+3, column]) 93 | row += 3 94 | ax.set_aspect(aspect) 95 | 96 | p = (positive if isinstance(positive, bool) else positive[i][k]) 97 | 98 | if sorting_indices[partitions[i][k][0]] is not None: 99 | current_component = current_component[sorting_indices[partitions[i][k][0]]] 100 | if sorting_indices[partitions[i][k][1]] is not None: 101 | current_component = current_component.T[sorting_indices[partitions[i][k][1]]].T 102 | 103 | if p: 104 | ax.imshow(current_component, aspect=aspect, cmap=(cmap if cmap is not None else 'inferno'), 105 | vmin=np.quantile(current_component,1-quantile), 106 | vmax=np.quantile(current_component,quantile)) 107 | else: 108 | min_max = np.quantile(np.abs(current_component),quantile) 109 | ax.imshow(current_component, aspect=aspect, cmap=(cmap if cmap is not None else 'seismic'), 110 | vmin=-min_max, vmax=min_max) 111 | 112 | # =========== Axes labels =========== 113 | variable_x = variables[partitions[i][k][1]] 114 | variable_y = variables[partitions[i][k][0]] 115 | ax.set_xlabel(variable_x) 116 | ax.set_ylabel(variable_y) 117 | 118 | if ticks[partitions[i][k][0]] is not None: 119 | ax.set_yticks(ticks[partitions[i][k][0]], tick_labels[partitions[i][k][0]]) 120 | if ticks[partitions[i][k][1]] is not None: 121 | ax.set_xticks(ticks[partitions[i][k][1]], tick_labels[partitions[i][k][1]]) 122 | 123 | # =========== Higher order factors can't be plotted =========== 124 | elif len(list(components[i][k].shape)) >= 4: 125 | ax = fig.add_subplot(gs[row:row+factor_height, 
column]) 126 | row += factor_height 127 | ax.text(0.5, 0.5, '3$\geq$ tensor', va='center', ha='center', color='black') 128 | ax.axis('off') 129 | 130 | # =========== Store axes =========== 131 | axes[i][k][j] = ax 132 | 133 | if ranks[i] != 0: column += 1 134 | 135 | if tight_layout: fig.tight_layout() 136 | 137 | return axes 138 | -------------------------------------------------------------------------------- /slicetca/core/decompositions.py: -------------------------------------------------------------------------------- 1 | from .helper_functions import squared_difference 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | import tqdm 7 | from collections.abc import Iterable 8 | 9 | from typing import Sequence, Union, Callable 10 | 11 | 12 | class PartitionTCA(nn.Module): 13 | 14 | def __init__(self, 15 | dimensions: Sequence[int], 16 | partitions: Sequence[Sequence[Sequence[int]]], 17 | ranks: Sequence[int], 18 | positive: Union[bool, Sequence[Sequence[Callable]]] = False, 19 | initialization: str = 'uniform', 20 | init_weight: float = None, 21 | init_bias: float = None, 22 | device: str = 'cpu'): 23 | """ 24 | Parent class for the sliceTCA and TCA decompositions. 25 | 26 | :param dimensions: Dimensions of the data to decompose. 27 | :param partitions: List of partitions of the legs of the tensor. 28 | [[[0],[1]]] would be a matrix rank decomposition. 29 | :param ranks: Number of components of each partition. 30 | :param positive: If False does nothing. 31 | If True constrains all components to be positive. 32 | If list of list, the list of functions to apply to a given partition and component. 33 | :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1). 34 | :param init_weight: Coefficient to multiply the initial component by. 35 | :param init_bias: Coefficient to add to the initial component. 36 | :param device: Torch device. 
37 | """ 38 | 39 | super(PartitionTCA, self).__init__() 40 | 41 | components = [[[dimensions[k] for k in j] for j in i] for i in partitions] 42 | 43 | if init_weight is None: 44 | if initialization == 'normal': init_weight = 1/np.sqrt(sum(ranks)) 45 | if initialization == 'uniform-positive': init_weight = ((0.5 / sum(ranks)) ** (1 / max([len(p) for p in partitions])))*2 46 | if initialization == 'uniform': init_weight = 1/np.sqrt(sum(ranks)) 47 | if init_bias is None: init_bias = 0.0 48 | 49 | if isinstance(positive, bool): 50 | if positive: positive_function = [[torch.abs for j in i] for i in partitions] 51 | else: positive_function = [[self.identity for j in i] for i in partitions] 52 | elif isinstance(positive, tuple) or isinstance(positive, list): positive_function = positive 53 | 54 | vectors = nn.ModuleList([]) 55 | 56 | for i in range(len(ranks)): 57 | r = ranks[i] 58 | dim = components[i] 59 | 60 | # k-tensors of the outer product 61 | if initialization == 'normal': 62 | v = [nn.Parameter(positive_function[i][j](torch.randn([r]+d, device=device)*init_weight + init_bias)) for j, d in enumerate(dim)] 63 | elif initialization == 'uniform': 64 | v = [nn.Parameter(positive_function[i][j](2*(torch.rand([r] + d, device=device)-0.5)*init_weight + init_bias)) for j, d in enumerate(dim)] 65 | elif initialization == 'uniform-positive': 66 | v = [nn.Parameter(positive_function[i][j](torch.rand([r] + d, device=device)*init_weight + init_bias)) for j, d in enumerate(dim)] 67 | else: 68 | raise Exception('Undefined initialization, select one of : normal, uniform, uniform-positive') 69 | 70 | vectors.append(nn.ParameterList(v)) 71 | 72 | self.vectors = vectors 73 | 74 | self.dimensions = dimensions 75 | self.partitions = partitions 76 | self.ranks = ranks 77 | self.positive = positive 78 | self.initialization = initialization 79 | self.init_weight = init_weight 80 | self.init_bias = init_bias 81 | self.device = device 82 | 83 | self.components = components 84 | self.positive_function = positive_function 85 | self.valence = len(dimensions) 86 | self.entries = np.prod(dimensions) 87 | 88 | self.losses = [] 89 | 90 | self.inverse_permutations = [] 91 | self.flattened_permutations = [] 92 | for i in self.partitions: 93 | temp = [] 94 | for j in i: 95 | for k in j: 96 | temp.append(k) 97 | self.flattened_permutations.append(temp) 98 | self.inverse_permutations.append(torch.argsort(torch.tensor(temp)).tolist()) 99 | 100 | self.set_einsums() 101 | 102 | def identity(self, x): 103 | return x 104 | 105 | def set_einsums(self): 106 | 107 | self.einsums = [] 108 | for i in self.partitions: 109 | lhs = '' 110 | rhs = '' 111 | for j in range(len(i)): 112 | for k in i[j]: 113 | lhs += chr(105 + k) 114 | rhs += chr(105 + k) 115 | if j != len(i) - 1: 116 | lhs += ',' 117 | self.einsums.append(lhs + '->' + rhs) 118 | 119 | def construct_single_component(self, partition: int, k: int): 120 | """ 121 | Constructs the kth term of the given partition. 122 | 123 | :param partition: Type of the partition 124 | :param k: Number of the component 125 | :return: Tensor of shape self.dimensions 126 | """ 127 | 128 | temp = [self.positive_function[partition][q](self.vectors[partition][q][k]) for q in range(len(self.components[partition]))] 129 | outer = torch.einsum(self.einsums[partition], temp) 130 | outer = outer.permute(self.inverse_permutations[partition]) 131 | 132 | return outer 133 | 134 | def construct_single_partition(self, partition: int): 135 | """ 136 | Constructs the sum of the terms of a given type of partition. 
137 | 
138 |         :param partition: Type of the partition.
139 |         :return: Tensor of shape self.dimensions.
140 |         """
141 | 
142 |         temp = torch.zeros(self.dimensions).to(self.device)
143 |         for j in range(self.ranks[partition]):
144 |             temp += self.construct_single_component(partition, j)
145 | 
146 |         return temp
147 | 
148 |     def construct(self):
149 |         """
150 |         Constructs the full tensor.
151 |         :return: Tensor of shape self.dimensions.
152 |         """
153 | 
154 |         temp = torch.zeros(self.dimensions).to(self.device)
155 | 
156 |         for i in range(len(self.partitions)):
157 |             for j in range(self.ranks[i]):
158 |                 temp += self.construct_single_component(i, j)
159 | 
160 |         return temp
161 | 
162 |     def get_components(self, detach=False, numpy=False):
163 |         """
164 |         Returns the components of the model.
165 | 
166 |         :param detach: Whether to detach the components from the autograd graph.
167 |         :param numpy: Whether to cast them to numpy arrays.
168 |         :return: list of list of tensors.
169 |         """
170 | 
171 |         temp = [[] for i in range(len(self.vectors))]
172 | 
173 |         for i in range(len(self.vectors)):
174 |             for j in range(len(self.vectors[i])):
175 |                 if numpy:
176 |                     temp[i].append(self.positive_function[i][j](self.vectors[i][j]).detach().cpu().numpy())
177 |                 else:
178 |                     if not detach: temp[i].append(self.positive_function[i][j](self.vectors[i][j]))  # keeps the graph for backpropagation
179 |                     else: temp[i].append(self.positive_function[i][j](self.vectors[i][j]).detach())
180 | 
181 |         return temp
182 | 
183 |     def set_components(self, components: Sequence[Sequence[torch.Tensor]]):  # bug if positive_function != abs
184 |         """
185 |         Set the model's components.
186 |         If the positive functions are abs or the identity, model.set_components(model.get_components())
187 |         has no effect besides resetting the gradient.
188 | 
189 |         :param components: list of lists of tensors.
190 |         """
191 | 
192 |         for i in range(len(self.vectors)):
193 |             for j in range(len(self.vectors[i])):
194 |                 with torch.no_grad():
195 |                     if isinstance(components[i][j], torch.Tensor):
196 |                         self.vectors[i][j].copy_(components[i][j].to(self.device))
197 |                     else:
198 |                         self.vectors[i][j].copy_(torch.tensor(components[i][j], device=self.device))
199 |         self.zero_grad()
200 | 
201 |     def fit(self,
202 |             X: torch.Tensor,
203 |             optimizer: torch.optim.Optimizer,
204 |             loss_function: Callable = squared_difference,
205 |             batch_prop: float = 0.2,
206 |             max_iter: int = 10000,
207 |             min_std: float = 10 ** -3,
208 |             iter_std: int = 100,
209 |             mask: torch.Tensor = None,
210 |             verbose: bool = False,
211 |             progress_bar: bool = True):
212 |         """
213 |         Fits the model to data.
214 | 
215 |         :param X: The data tensor.
216 |         :param optimizer: A torch optimizer.
217 |         :param loss_function: The final loss is torch.mean(loss_function(X, X_hat)); that is, loss_function: R^n -> R^n.
218 |         :param batch_prop: Proportion of entries used to compute the gradient at every training iteration.
219 |         :param max_iter: Maximum training iterations.
220 |         :param min_std: Minimum std of the loss under which to return.
221 |         :param iter_std: Number of iterations over which this std is computed.
222 |         :param mask: Boolean tensor of the entries over which to compute the gradient; False entries are never used.
223 |         :param verbose: Whether to print the loss at every step.
224 |         :param progress_bar: Whether to have a tqdm progress bar.
225 | """ 226 | 227 | losses = [] 228 | 229 | iterator = tqdm.tqdm(range(max_iter)) if progress_bar else range(max_iter) 230 | 231 | for iteration in iterator: 232 | 233 | X_hat = self.construct() 234 | 235 | loss_entries = loss_function(X, X_hat) 236 | 237 | total_loss = torch.mean(loss_entries) 238 | 239 | if batch_prop != 1.0: batch_mask = torch.rand(self.dimensions, device=self.device) < batch_prop 240 | 241 | if mask is None and batch_prop == 1.0: 242 | loss = total_loss 243 | else: 244 | if mask is None: 245 | total_mask = batch_mask 246 | else: 247 | if batch_prop == 1.0: 248 | total_mask = mask 249 | else: 250 | total_mask = mask & batch_mask 251 | 252 | total_entries = torch.sum(total_mask) 253 | loss = torch.sum(loss_entries * total_mask) / total_entries 254 | 255 | optimizer.zero_grad() 256 | loss.backward() 257 | optimizer.step() 258 | 259 | total_loss = total_loss.item() 260 | 261 | losses.append(total_loss) 262 | 263 | if verbose: print('Iteration:', iteration, 'Loss:', total_loss) 264 | if progress_bar: iterator.set_description('Loss: ' + str(total_loss) + ' ') 265 | 266 | if len(losses) > iter_std and np.array(losses[-iter_std:]).std() < min_std: 267 | if progress_bar: iterator.set_description('The model converged. Loss: ' + str(total_loss) + ' ') 268 | break 269 | 270 | self.losses += losses 271 | 272 | 273 | class SliceTCA(PartitionTCA): 274 | def __init__(self, 275 | dimensions: Sequence[int], 276 | ranks: Sequence[int], 277 | positive: bool = False, 278 | initialization: str = 'uniform', 279 | init_weight: float = None, 280 | init_bias: float = None, 281 | device: str = 'cpu'): 282 | """ 283 | Main sliceTCA decomposition class. 284 | 285 | :param dimensions: Dimensions of the data to decompose. 286 | :param ranks: Number of components of each slice type. 287 | :param positive: If False does nothing. 288 | If True constrains all components to be positive. 289 | If list of list, the list of functions to apply to a given partition and component. 290 | :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1). 291 | :param init_weight: Coefficient to multiply the initial component by. 292 | :param init_bias: Coefficient to add to the initial component. 293 | :param device: Torch device. 294 | """ 295 | 296 | valence = len(dimensions) 297 | partitions = [[[i], [j for j in range(valence) if j != i]] for i in range(valence)] 298 | 299 | super().__init__(dimensions=dimensions, ranks=ranks, partitions=partitions, positive=positive, 300 | initialization=initialization, init_weight=init_weight, init_bias=init_bias, device=device) 301 | 302 | 303 | class TCA(PartitionTCA): 304 | def __init__(self, 305 | dimensions: Sequence[int], 306 | rank: int, 307 | positive: bool = False, 308 | initialization: str = 'uniform', 309 | init_weight: float = None, 310 | init_bias: float = None, 311 | device: str = 'cpu'): 312 | """ 313 | Main TCA decomposition class. 314 | 315 | :param dimensions: Dimensions of the data to decompose. 316 | :param rank: Number of components. 317 | :param positive: If False does nothing. 318 | If True constrains all components to be positive. 319 | If list of list, the list of functions to apply to a given partition and component. 320 | :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1). 321 | :param init_weight: Coefficient to multiply the initial component by. 322 | :param init_bias: Coefficient to add to the initial component. 
323 |         :param device: Torch device.
324 |         """
325 | 
326 |         if not isinstance(rank, Iterable):
327 |             rank = (rank,)
328 | 
329 |         valence = len(dimensions)
330 |         partitions = [[[j] for j in range(valence)]]
331 | 
332 |         super().__init__(dimensions=dimensions, ranks=rank, partitions=partitions, positive=positive,
333 |                          initialization=initialization, init_weight=init_weight, init_bias=init_bias, device=device)
334 | 
--------------------------------------------------------------------------------
/img/decomposition.svg:
--------------------------------------------------------------------------------
[SVG figure: schematic of the sliceTCA decomposition — a data tensor (axis labeled "Trials") written as a sum "+...+" of slice components; SVG markup omitted]
--------------------------------------------------------------------------------
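
Taken together, the modules above form a model-selection-then-fit pipeline: build cross-validation masks with `block_mask`, search over numbers of components with `grid_search`, inspect the losses with `plot_grid`, then fit at the chosen ranks with `decompose`, apply the `invariance` transformations, and `plot` the components. The following is a minimal end-to-end sketch; the data shape, block sizes, rank choices and iteration budgets are placeholder assumptions, not recommended settings.

```python
import torch
import slicetca
from matplotlib import pyplot as plt

data = torch.randn(20, 50, 100)  # synthetic (trial, neuron, time) tensor

# Held-out test entries are grouped into blocks, each surrounded by a larger
# block excluded from the train mask to limit train/test leakage.
train_mask, test_mask = slicetca.block_mask(dimensions=data.shape,
                                            train_blocks_dimensions=(1, 5, 10),
                                            test_blocks_dimensions=(1, 2, 5),
                                            fraction_test=0.1)

# Cross-validated losses over a (tiny) grid of numbers of components.
loss_grid, seed_grid = slicetca.grid_search(data, max_ranks=(2, 1, 2),
                                            mask_train=train_mask, mask_test=test_mask,
                                            sample_size=2, max_iter=100)
slicetca.plot_grid(loss_grid)

# Fit at the chosen ranks, then apply the uniqueness (invariance) constraints.
components, model = slicetca.decompose(data, (2, 0, 2), mask=train_mask)
model = slicetca.invariance(model)

slicetca.plot(model)
plt.show()
```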