├── slicetca
├── tests
│ ├── __init__.py
│ └── tests.py
├── plotting
│ ├── additional.py
│ ├── __init__.py
│ ├── grid.py
│ └── factors.py
├── core
│ ├── __init__.py
│ ├── helper_functions.py
│ └── decompositions.py
├── invariance
│ ├── __init__.py
│ ├── criteria.py
│ ├── invariance.py
│ ├── analytic_invariance.py
│ ├── helper.py
│ ├── iterative_invariance.py
│ └── transformations.py
├── run
│ ├── __init__.py
│ ├── utils.py
│ ├── decompose.py
│ └── grid_search.py
└── __init__.py
├── setup.cfg
├── setup.py
├── LICENSE.txt
├── documentation.md
├── README.md
└── img
└── decomposition.svg
/slicetca/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slicetca/plotting/additional.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/slicetca/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .decompositions import *
2 |
3 | __all__ = ['PartitionTCA', 'SliceTCA', 'TCA']
4 |
--------------------------------------------------------------------------------
/slicetca/invariance/__init__.py:
--------------------------------------------------------------------------------
1 | from slicetca.invariance.invariance import invariance
2 |
3 | __all__ = ['invariance']
4 |
--------------------------------------------------------------------------------
/slicetca/plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from .factors import plot
2 | from .grid import plot_grid
3 |
4 | __all__ = ['plot', 'plot_grid']
5 |
--------------------------------------------------------------------------------
/slicetca/run/__init__.py:
--------------------------------------------------------------------------------
1 | from .decompose import *
2 | from .grid_search import *
3 | from .utils import *
4 |
5 | __all__ = ['decompose', 'grid_search', 'block_mask']
6 |
--------------------------------------------------------------------------------
/slicetca/__init__.py:
--------------------------------------------------------------------------------
1 | from .invariance import *
2 | from .plotting import *
3 | from .run import *
4 |
5 | __version__ = '0.1.8'
6 | __author__ = 'Arthur Pellegrino, Heike Stein'
7 |
8 | __all__ = ['invariance', 'decompose', 'grid_search', 'block_mask', 'plot', 'plot_grid']
9 |
--------------------------------------------------------------------------------
/slicetca/core/helper_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def squared_difference(x, x_hat):
    """
    Element-wise squared error between a target and its estimate.

    :param x: Target values.
    :param x_hat: Estimated values.
    :return: (x - x_hat) squared.
    """

    residual = x - x_hat
    return residual ** 2
6 |
7 |
def poisson_log_likelihood(spikes, rates, spikes_factorial, activation=torch.nn.functional.softplus):
    """
    Element-wise negative Poisson log-likelihood of spike counts given pre-activation rates.

    The value is computed directly in log space,
        mean - spikes * log(mean) + log(spikes!),   with mean = activation(rates),
    instead of -log(exp(-mean) * mean**spikes / spikes!), whose intermediate factors
    underflow/overflow (yielding nan or inf) for large counts or rates.

    :param spikes: Observed counts.
    :param rates: Unconstrained rates, mapped to the Poisson mean by activation.
    :param spikes_factorial: Element-wise factorial of spikes, precomputed by the caller.
    :param activation: Function mapping rates to the Poisson mean (default softplus).
    :return: Negative log-likelihood of every entry.
    """

    mean = activation(rates)

    # xlogy evaluates spikes * log(mean) and returns 0 wherever spikes == 0 (convention 0 * log(0) = 0).
    return mean - torch.xlogy(spikes, mean) + torch.log(spikes_factorial)
13 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

from pathlib import Path

# The README is reused as the long description of the package.
this_directory = Path(__file__).parent
# The encoding is explicit because README.md contains non-ASCII characters, which a
# platform-dependent default encoding may fail to decode.
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

# NOTE(review): this version differs from slicetca.__version__ ('0.1.8') - confirm which one is current.
setup(
    name='slicetca',
    packages=find_packages(exclude=['tests*']),
    version='1.0.4',

    description='Package to perform Slice Tensor Component Analysis',
    long_description=long_description,
    long_description_content_type='text/markdown',

    url='https://github.com/arthur-pe/slicetca',
    author='Arthur Pellegrino',
    license='MIT',
    install_requires=['torch',
                      'numpy',
                      'matplotlib',
                      'tqdm',
                      'scipy'
                      ],
    python_requires='>=3.10',
    classifiers=[
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Mathematics',
        'License :: OSI Approved :: MIT License',
        # Kept consistent with python_requires above (the package uses match statements).
        'Programming Language :: Python :: 3.10',
    ],
)
33 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2023 Arthur Pellegrino, Heike Stein, N Alex Cayco-Gajic
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
--------------------------------------------------------------------------------
/slicetca/invariance/criteria.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 | import torch
3 | from typing import Sequence
4 |
5 | # Example of criteria to use for L2 optimization.
6 |
7 |
def orthogonality_component_type_wise(reconstructed_tensors_of_each_partition: Sequence[torch.Tensor],
                                      eps: float = 1e-12):
    """
    Penalizes non-orthogonality between the reconstructed tensors of each partition/slicing.

    For every pair of reconstructed tensors the squared cosine similarity is accumulated; the L_2
    regularization of the tensors is added to the result.

    :param reconstructed_tensors_of_each_partition: The sum of the terms of a given partition/slicing.
    :param eps: Lower bound on the product of the norms of a pair, to avoid a division by zero.
    :return: Torch float.
    """

    penalty = 0
    for first, second in combinations(reconstructed_tensors_of_each_partition, 2):
        inner_product = torch.sum(first * second)
        norms = torch.sqrt(torch.sum(first ** 2)) * torch.sqrt(torch.sum(second ** 2))

        # A partition of rank zero is reconstructed as an all-zero tensor (see helper.construct_per_type);
        # without the clamp the pair would evaluate to 0/0 and turn the whole objective into NaN.
        penalty = penalty + torch.square(inner_product / torch.clamp(norms, min=eps))

    return penalty + l2(reconstructed_tensors_of_each_partition)
21 |
22 |
def l2(reconstructed_tensors_of_each_partition: Sequence[torch.Tensor]):
    """
    L_2 regularization: sum, over the reconstructed tensors, of their mean squared entry.

    :param reconstructed_tensors_of_each_partition: The sum of the terms of a given partition/slicing.
    :return: Torch float (the integer 0 if no tensor is given).
    """

    return sum((tensor ** 2).mean() for tensor in reconstructed_tensors_of_each_partition)
35 |
--------------------------------------------------------------------------------
/slicetca/invariance/invariance.py:
--------------------------------------------------------------------------------
1 | from slicetca.invariance.iterative_invariance import sgd_invariance
2 | from slicetca.invariance.analytic_invariance import svd_basis
3 | from slicetca.invariance.criteria import *
4 | from slicetca.core.decompositions import SliceTCA
5 |
6 | dict_L2_invariance_objectives = {'regularization': l2}
7 | dict_L3_invariance_functions = {'svd': svd_basis}
8 |
def invariance(model: SliceTCA,
               L2: str = 'regularization',
               L3: str = 'svd',
               **kwargs):
    """
    High level entry point of the invariance optimization.
    Note: the model is modified in place, deepcopy it beforehand to keep the original components.

    :param model: A sliceTCA model.
    :param L2: Key of the objective in dict_L2_invariance_objectives (currently only 'regularization').
    :param L3: Key of the function in dict_L3_invariance_functions (currently only 'svd').
    :param kwargs: Key-word arguments forwarded to both the L2 and the L3 functions.
                   See iterative_invariance.py
    :return: model with modified components.
    """

    number_nonzero_ranks = sum(1 for rank in model.ranks if rank != 0)

    # The L2 optimization is only run when more than one slice type has a non-zero rank.
    if number_nonzero_ranks > 1:
        l2_objective = dict_L2_invariance_objectives[L2]
        model = sgd_invariance(model, objective_function=l2_objective, **kwargs)

    l3_function = dict_L3_invariance_functions[L3]

    return l3_function(model, **kwargs)
29 |
--------------------------------------------------------------------------------
/documentation.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | Here is a brief list of the functionalities provided by this repository. Additional information is provided in their docstring or by calling `help(function_name)`.
4 |
5 | ## High-level functions
6 |
7 | To get started quickly, the following high-level functions can be used. These can be imported at once with `from slicetca import *`.
8 |
9 | * `decompose` is the high-level function to decompose a data tensor.
10 | * `grid_search` is for determining the number of components.
11 | * `plot` allows plotting sliceTCA and TCA components.
12 |
13 | We recommend having a look at our notebooks for further details.
14 |
15 | ## Low-level functions
16 |
17 | For more specific use-cases, low-level functions might be preferred.
18 |
19 | * `_core.decompositions.SliceTCA`
20 | * `.fit(self, data, ...)` fits the components of a sliceTCA object to some data.
21 | * `.set_components(self, components)` sets the model's components.
22 | * `.get_components(self, detach=False, numpy=False)` returns the model's components. To backpropagate through the tensor, set detach=False.
23 | * `_invariance`
24 | * `.analytic_invariance.svd_basis(model)` sets the vectors of each slice type to an orthonormal basis and sorts them by variance explained.
25 | * `.sgd_invariance(model, objective_function, transformation, ...)` allows optimizing the components w.r.t. some objective function while fixing the overall tensor.
26 |
27 |
--------------------------------------------------------------------------------
/slicetca/invariance/analytic_invariance.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def svd_basis(model, **kwargs):
    """
    Replaces the vectors of each slice type by an orthonormal basis obtained from an SVD.

    For every slice type of non-zero rank, the reconstruction of that type is unfolded into a matrix,
    decomposed by SVD and truncated to the rank of the type: the right singular vectors become the new
    vectors, the left singular vectors scaled by the singular values become the new slices.
    Slice types of rank zero receive all-zero components. The model is modified in place.

    :param model: SliceTCA model
    :param kwargs: ignored
    :return: model with new components.
    """

    device, ranks = model.device, model.ranks
    number_types = len(ranks)

    new_components = []
    for i, rank in enumerate(ranks):
        if rank == 0:
            new_components.append([torch.zeros_like(model.vectors[i][0]),
                                   torch.zeros_like(model.vectors[i][1])])
            continue

        other_axes = [q for q in range(number_types) if q != i]

        # Unfolding: axis i first, remaining axes flattened, then the flattened axes put on the rows.
        partial = model.construct_single_partition(i).permute([i] + other_axes)
        unfolded = partial.reshape(model.dimensions[i], -1).transpose(0, 1)

        # The SVD runs on a detached CPU copy; the truncated factors are moved back to the model device.
        left, singular_values, right_h = torch.linalg.svd(unfolded.detach().cpu(), full_matrices=False)
        left = left[:, :rank].to(device)
        singular_values = singular_values[:rank].to(device)
        right_h = right_h[:rank].to(device)

        scaled_left = left @ torch.diag(singular_values)
        slices = scaled_left.transpose(0, 1).reshape([rank] + [model.dimensions[q] for q in other_axes])

        new_components.append([right_h, slices])

    model.set_components(new_components)

    return model
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SliceTCA
2 |
3 | This library provides tools to perform [sliceTCA](https://www.biorxiv.org/content/10.1101/2023.03.01.530616v1).
4 |
5 | ___
6 |
7 |
8 |
9 |
10 |
11 | ## Installation
12 |
13 | ```commandline
14 | pip install slicetca
15 | ```
16 |
17 | ## Full documentation
18 |
19 | The full documentation can be found [here](https://github.com/arthur-pe/slicetca/blob/master/documentation.md).
20 |
21 | ## Examples
22 |
23 | ### Quick example
24 |
25 | ```python
26 | import slicetca
27 | import torch
28 | from matplotlib import pyplot as plt
29 |
30 | device = ('cuda' if torch.cuda.is_available() else 'cpu')
31 |
32 | # your_data is a numpy array of shape (trials, neurons, time).
33 | data = torch.tensor(your_data, dtype=torch.float, device=device)
34 |
35 | # The tensor is decomposed into 2 trial-, 0 neuron- and 3 time-slicing components.
36 | components, model = slicetca.decompose(data, (2,0,3))
37 |
38 | # For a not positive decomposition, we apply uniqueness constraints
39 | model = slicetca.invariance(model)
40 |
41 | slicetca.plot(model)
42 |
43 | plt.show()
44 | ```
45 |
46 | ### Notebook
47 |
48 | See the [example notebook](https://github.com/arthur-pe/slicetca/blob/master/sliceTCA_notebook_1.ipynb) for an application of sliceTCA to publicly available neural data.
49 |
50 |
51 |
52 |
53 |
54 | ## Reference
55 |
56 | A. Pellegrino@†, H. Stein†, N. A. Cayco-Gajic@. (2024). Dimensionality reduction beyond neural subspaces with slice tensor component analysis. *Nature Neuroscience* [https://www.nature.com/articles/s41593-024-01626-2](https://www.nature.com/articles/s41593-024-01626-2).
--------------------------------------------------------------------------------
/slicetca/invariance/helper.py:
--------------------------------------------------------------------------------
1 | from slicetca.core.decompositions import SliceTCA
2 |
3 | import torch
4 | from typing import Sequence
5 |
def mm(a: torch.Tensor, b: torch.Tensor):
    """
    Generalized matrix multiplication (ijq) x (qkl) -> (ijkl): the last axis of a is contracted
    with the first axis of b, all other axes are kept.

    :param a: Left tensor.
    :param b: Right tensor, whose first axis matches the last axis of a.
    :return: Tensor with the axes of a (without the last) followed by the axes of b (without the first).
    """
    first_letter = ord('i')
    rank_a, rank_b = a.dim(), b.dim()

    letters_a = ''.join(chr(first_letter + offset) for offset in range(rank_a))
    # The letters of b start on the last letter of a, which makes that shared index the contracted one.
    letters_b = ''.join(chr(first_letter + rank_a - 1 + offset) for offset in range(rank_b))

    formula = f'{letters_a},{letters_b}->{letters_a[:-1]}{letters_b[1:]}'
    return torch.einsum(formula, [a, b])
20 |
21 |
def batch_outer(a: torch.Tensor, b: torch.Tensor):
    """
    Outer product batched over the shared first axis: (i...) x (i...) -> (i... ...).

    :param a: Tensor whose first axis is the batch axis.
    :param b: Tensor whose first axis is the batch axis.
    :return: Tensor with the batch axis, then the remaining axes of a, then the remaining axes of b.
    """
    batch = chr(105)
    rank_a = a.dim()

    trailing_a = ''.join(chr(106 + offset) for offset in range(rank_a - 1))
    trailing_b = ''.join(chr(106 + rank_a + offset) for offset in range(b.dim() - 1))

    formula = f'{batch}{trailing_a},{batch}{trailing_b}->{batch}{trailing_a}{trailing_b}'
    return torch.einsum(formula, [a, b])
29 |
30 |
def construct_per_type(model: SliceTCA, components: Sequence[Sequence[torch.Tensor]]):
    """
    Reconstructs, separately for every component type, the sum of its components.

    :param model: SliceTCA model.
    :param components: The components to construct.
    :return: List with one reconstructed tensor per component type (all zeros for a type of rank zero).
    """

    reconstructions = []

    for partition in range(len(components)):
        total = torch.zeros(model.dimensions).to(model.device)
        for k in range(model.ranks[partition]):
            total += construct_single_component(model, components, partition, k)
        reconstructions.append(total)

    return reconstructions
44 |
45 |
def construct_single_component(model: SliceTCA, components: Sequence[Sequence[torch.Tensor]], partition: int, k: int):
    """
    Reconstructs the k-th component of a given component type.

    :param model: SliceTCA model, provides the positive functions, einsum formulas and permutations.
    :param components: The components to construct.
    :param partition: Index of the component type.
    :param k: Index of the component within its type.
    :return: Reconstructed tensor of this single component.
    """

    factors = []
    for q, factor in enumerate(components[partition]):
        positive = model.positive_function[partition][q]
        factors.append(positive(factor[k]))

    term = torch.einsum(model.einsums[partition], factors)

    return term.permute(model.inverse_permutations[partition])
53 |
--------------------------------------------------------------------------------
/slicetca/invariance/iterative_invariance.py:
--------------------------------------------------------------------------------
1 | from .transformations import *
2 | from .criteria import *
3 | from ..core.decompositions import SliceTCA
4 |
5 | import torch
6 | import copy
7 | import tqdm
8 | from typing import Callable
9 |
10 |
11 | def sgd_invariance(model: SliceTCA,
12 | objective_function: Callable = l2,
13 | transformation: object = None,
14 | learning_rate: float = 10**-2,
15 | max_iter: int = 10000,
16 | min_std: float = 10**-3,
17 | iter_std: int = 100,
18 | verbose: bool = False,
19 | progress_bar: bool = True,
20 | **kwargs):
21 | """
22 | Enables optimizing the components w.r.t. some objective function while fixing the overall reconstructed tensor.
23 |
24 | :param model: SliceTCA model.
25 | :param objective_function: The objective to optimize.
26 | :param transformation: transformations.TransformationBetween(model) or
27 | transformations.TransformationWithin(model) or
28 | nn.Sequential(TransformationWithin(model), TransformationBetween(model))
29 | :param learning_rate: Learning rate for the optimizer (default Adam).
30 | :param max_iter: Maximum number of iterations.
31 | :param min_std: Minimum std of the last iter_std iterations under which to assume the model has converged.
32 | :param iter_std: See min_std.
33 | :param verbose: Whether to print the loss.
34 | :param progress_bar: Whether to have a progress bar.
35 | :param kwargs: ignored.
36 | :return: model with the modified components.
37 | """
38 |
39 | if transformation is None: transformation = TransformationBetween(model)
40 |
41 | model.requires_grad_(False)
42 |
43 | optim = torch.optim.Adam(transformation.parameters(), lr=learning_rate)
44 |
45 | components = model.get_components(detach=True)
46 |
47 | losses = []
48 |
49 | iterator = tqdm.tqdm(range(max_iter)) if progress_bar else range(max_iter)
50 |
51 | for iteration in iterator:
52 |
53 | components_transformed = transformation(copy.deepcopy(components))
54 |
55 | components_transformed_constructed = construct_per_type(model, components_transformed)
56 | l = objective_function(components_transformed_constructed)
57 |
58 | if verbose: print('Iteration:', iteration, '\tloss:', l.item())
59 | if progress_bar: iterator.set_description("Invariance loss: " + str(l.item())[:10] + ' ')
60 |
61 | optim.zero_grad()
62 | l.backward()
63 | optim.step()
64 |
65 | losses.append(l.item())
66 |
67 | if len(losses)>iter_std and np.array(losses[-100:]).std()=0):
35 | raise Exception('For all i it should be that train_blocks_dimensions[i]>=test_blocks_dimensions[i].')
36 |
37 | if number_blocks is None:
38 | number_blocks = int(fraction_test * flattened_max_dim / np.prod(2*np.array(test_blocks_dimensions)+1))
39 | else:
40 | warnings.warn('The parameter number_blocks is deprecated, use fraction_test instead.', DeprecationWarning)
41 |
42 | if exact:
43 | start = torch.zeros(flattened_max_dim, device=device)
44 | start[:number_blocks] = 1
45 | start = start[torch.randperm(flattened_max_dim, device=device)]
46 | start = start.reshape(dimensions)
47 | else:
48 | start = (torch.rand(tuple(dimensions), device=device) < fraction_test).long()
49 |
50 | start_index = start.nonzero()
51 | number_blocks = len(start_index)
52 |
53 | # Build outer-blocks mask
54 | a = [[slice(torch.clip(start_index[j][i]-train_blocks_dimensions[i],min=0, max=dimensions[i]),
55 | torch.clip(start_index[j][i]+train_blocks_dimensions[i]+1,min=0, max=dimensions[i]))
56 | for i in range(valence)] for j in range(number_blocks)]
57 |
58 | train_mask = torch.full(dimensions, True, device=device)
59 |
60 | for j in a: train_mask[j] = 0
61 |
62 | # Build inner-blocks tensor
63 | a = [[slice(torch.clip(start_index[j][i]-test_blocks_dimensions[i],min=0, max=dimensions[i]),
64 | torch.clip(start_index[j][i]+test_blocks_dimensions[i]+1,min=0, max=dimensions[i]))
65 | for i in range(valence)] for j in range(number_blocks)]
66 |
67 | test_mask = torch.full(dimensions, False, device=device)
68 |
69 | for j in a: test_mask[j] = 1
70 |
71 | return train_mask, test_mask
72 |
--------------------------------------------------------------------------------
/slicetca/run/decompose.py:
--------------------------------------------------------------------------------
1 | from slicetca.core import SliceTCA, TCA
2 | from slicetca.core.helper_functions import squared_difference, poisson_log_likelihood
3 |
4 | import torch
5 | from typing import Union, Sequence
6 | import numpy as np
7 | import scipy
8 | from functools import partial
9 |
10 |
def decompose(data: Union[torch.Tensor, np.ndarray],
              number_components: Union[Sequence[int], int],
              positive: bool = False,
              initialization: str = 'uniform',
              learning_rate: float = 5*10**-3,
              batch_prop: float = 0.2,
              max_iter: int = 10000,
              min_std: float = 10**-5,
              iter_std: int = 100,
              mask: Union[torch.Tensor, None] = None,
              verbose: bool = False,
              progress_bar: bool = True,
              seed: int = 7,
              weight_decay: Union[float, None] = None,
              batch_prop_decay: int = 1):
    """
    High-level function to decompose a data tensor into a SliceTCA or TCA decomposition.

    :param data: Torch tensor or numpy array. A numpy array is converted to a tensor, on cuda if available.
                    If the dtype is torch.long a Poisson loss is used, otherwise the squared difference.
    :param number_components: If an int, number of TCA components, else number of sliceTCA components.
    :param positive: Whether to use a positive decomposition (passed to the model).
    :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1).
    :param learning_rate: Learning rate of the optimizer.
    :param batch_prop: Proportion of entries used to compute the gradient at every training iteration.
    :param max_iter: Maximum training iterations.
    :param min_std: Minimum std of the loss under which to return.
    :param iter_std: Number of iterations over which this std is computed.
    :param mask: Entries which are not used to compute the gradient at any training iteration.
    :param verbose: Whether to print the loss at every step.
    :param progress_bar: Whether to have a tqdm progress bar.
    :param seed: Torch seed.
    :param weight_decay: Decay of the parameters. If None defaults to Adam, else AdamW.
    :param batch_prop_decay: Number of successive fits; the i-th fit (i = 1, 2, ...) uses the proportion
                    1-(1-batch_prop)**i of the entries.
    :return: components: A list (over component types) of lists (over factors) of rank x component_shape tensors.
    :return: model: A SliceTCA or TCA model. It can be used to access the losses over training and much more.
    """

    torch.manual_seed(seed)

    if isinstance(data, np.ndarray):
        data = torch.tensor(data, device='cuda' if torch.cuda.is_available() else 'cpu')

    if data.dtype != torch.long:
        loss_function = squared_difference
    else:
        # Integer (count) data: the factorial term of the Poisson likelihood is precomputed once.
        spikes_factorial = torch.tensor(scipy.special.factorial(data.numpy(force=True)), device=data.device)
        loss_function = partial(poisson_log_likelihood, spikes_factorial=spikes_factorial)

    dimensions = list(data.shape)

    if isinstance(number_components, int):
        decomposition = TCA
    else:
        decomposition = SliceTCA

    model = decomposition(dimensions, number_components, positive, initialization, device=data.device)

    if weight_decay is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # The same optimizer is reused over the successive fits.
    for i in range(1, batch_prop_decay+1):
        model.fit(data, optimizer, loss_function,
                  1-(1-batch_prop)**i, max_iter, min_std, iter_std, mask, verbose, progress_bar)

    return model.get_components(numpy=True), model
73 |
--------------------------------------------------------------------------------
/slicetca/plotting/grid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from matplotlib import pyplot as plt
4 | from matplotlib.ticker import MaxNLocator
5 |
6 | from typing import Literal, Sequence
7 |
8 |
9 | def plot_grid(loss_grid: np.ndarray,
10 | min_ranks: Sequence = (0, 0, 0),
11 | data: np.ndarray = None,
12 | elbow: float = 0.9,
13 | quantile: float = 0.0,
14 | vmin: float = None,
15 | vmax: float = None,
16 | reduction: Literal['mean', 'min'] = 'min',
17 | variables: Sequence[str] = ('trial', 'neuron', 'time'),
18 | ):
19 | """
20 | :param loss_grid: (trial x neuron x time x batch) grid of cross-validated losses as a function of the number of components
21 | :param min_ranks: minimum ranks of the gridsearch (default (0, 0, 0))
22 | :param data: to construct elbow
23 | :param elbow: fraction of best model performance relative to data squared norm to use for elbow
24 | :param quantile:
25 | :param vmin:
26 | :param vmax:
27 | :param reduction:
28 | :param variables:
29 | :return:
30 | """
31 |
32 | max_ranks = tuple(np.array(min_ranks)+np.array(loss_grid.shape[:3]))
33 |
34 | match reduction:
35 | case 'mean':
36 | reduced_loss_grid = loss_grid.mean(axis=-1)
37 | case 'min':
38 | reduced_loss_grid = loss_grid.min(axis=-1)
39 | case _:
40 | raise Exception('Reduction should be mean or min.')
41 |
42 | min_index = np.unravel_index(np.argmin(reduced_loss_grid), reduced_loss_grid.shape)
43 |
44 | if data is not None:
45 | elbow_threshold = np.min(reduced_loss_grid)*elbow+np.mean(data**2)*(1-elbow)
46 |
47 | nb_plot = reduced_loss_grid.shape[0]
48 | if vmin is None: vmin = np.quantile(reduced_loss_grid, quantile)
49 | if vmax is None: vmax = np.quantile(reduced_loss_grid, 1-quantile)
50 |
51 | fig = plt.figure(figsize=(3*nb_plot,3), constrained_layout=True)
52 | axes = [fig.add_subplot(1, nb_plot, i+1) for i in range(nb_plot)]
53 |
54 | for i in range(nb_plot):
55 | im = axes[i].imshow(reduced_loss_grid[i], cmap='pink_r', origin='lower',
56 | vmin=vmin, vmax=vmax, extent=(min_ranks[2]-0.5, max_ranks[2]-0.5, min_ranks[1]-0.5, max_ranks[1]-0.5))
57 |
58 | if data is not None:
59 | elbow_neuron = np.argmax(reduced_loss_grid[i] < elbow_threshold, axis=1).astype(float)
60 | elbow_time = np.argmax(reduced_loss_grid[i] < elbow_threshold, axis=0).astype(float)
61 | elbow_time[np.all(reduced_loss_grid[i] >= elbow_threshold, axis=0)] = np.nan
62 | elbow_neuron[np.all(reduced_loss_grid[i] >= elbow_threshold, axis=1)] = np.nan
63 | axes[i].plot(np.arange(min_ranks[2], max_ranks[2]), min_ranks[1]+elbow_time, color=(0, 0, 0), alpha=0.3, linewidth=2.0)
64 | axes[i].plot(min_ranks[2]+elbow_neuron, np.arange(min_ranks[1], max_ranks[1]), color=(0, 0, 0), alpha=0.3, linewidth=2.0)
65 |
66 | axes[i].set_aspect('equal')
67 | axes[i].set_title('$R_{'+variables[0]+'}='+f'{min_ranks[0]+i}$')
68 | axes[i].set_ylabel('$R_{'+variables[1]+'}$')
69 | axes[i].set_xlabel('$R_{'+variables[2]+'}$')
70 |
71 | axes[i].yaxis.set_major_locator(MaxNLocator(integer=True))
72 | axes[i].xaxis.set_major_locator(MaxNLocator(integer=True))
73 |
74 | fig.colorbar(im,fraction=0.046, pad=0.04)
75 |
76 | axes[min_index[0]].scatter(min_ranks[2]+min_index[2], min_ranks[1]+min_index[1],
77 | color=(0.9,0.2, 0.2), marker='*', s=200)
78 |
79 | if data is not None:
80 | elbow_index = np.stack(np.meshgrid(*tuple([np.arange(min_ranks[i], max_ranks[i]) for i in range(3)]), indexing='ij'), axis=-1).sum(axis=-1).astype(float)
81 | elbow_index[(reduced_loss_grid>=elbow_threshold)] = np.nan
82 | if not np.all(np.isnan(elbow_index)):
83 | elbow_index = np.unravel_index(np.argmin(elbow_index, axis=None), elbow_index.shape)
84 | axes[elbow_index[0]].scatter(min_ranks[2]+elbow_index[2], min_ranks[1]+elbow_index[1],
85 | color=(0.9,0.2, 0.2), marker='*', s=200, facecolor=(1, 1, 1, 0))
86 |
--------------------------------------------------------------------------------
/slicetca/tests/tests.py:
--------------------------------------------------------------------------------
1 | from slicetca.core.decompositions import PartitionTCA, SliceTCA
2 | from slicetca.invariance.analytic_invariance import svd_basis
3 | from slicetca.invariance.iterative_invariance import TransformationWithin, TransformationBetween, sgd_invariance, l2
4 | from slicetca.run.decompose import decompose
5 | from slicetca.run.grid_search import grid_search
6 |
7 | import torch
8 | from torch import nn
9 | import numpy as np
10 |
11 |
dimensions = (5, 6, 7, 8)
ranks = (1, 0, 2, 3)

# Read as globals by the tests below. Defined at module level so that the tests can also be run by a
# test runner (which never executes the __main__ block) without raising a NameError.
device = ('cuda' if torch.cuda.is_available() else 'cpu')
progress_bar = False
14 |
15 |
def test_analytic_invariance():
    """Checks that svd_basis leaves the tensor reconstructed by the model unchanged."""

    model = SliceTCA(dimensions, ranks, initialization='uniform', device=device)

    reconstruction_before = model.construct().detach()
    model = svd_basis(model)
    reconstruction_after = model.construct()

    error = torch.mean(torch.square(reconstruction_before - reconstruction_after)).item()
    assert error<10**-6, 'Analytic invariance changes L1'

    print('test_analytic_invariance passed')
27 |
28 |
def test_iterative_invariance_within_between():
    """
    Checks that sgd_invariance with a between-type followed by a within-type transformation
    leaves the tensor reconstructed by the model unchanged.
    """

    m = SliceTCA(dimensions, ranks, initialization='uniform', device=device)

    a = m.construct().detach()

    transfo = nn.Sequential(TransformationBetween(m), TransformationWithin(m))

    m = sgd_invariance(m, l2, transformation=transfo, max_iter=3, learning_rate=0.01, progress_bar=progress_bar)
    b = m.construct()

    assert torch.mean(torch.square(a - b)).item()<10**-6, 'Iterative invariance L2, L3 changes L1'

    # The message names this test (it previously reported 'test_invariance_between').
    print('test_iterative_invariance_within_between passed')
43 |
44 |
def test_iterative_invariance_within():
    """Checks that sgd_invariance with a within-type transformation leaves the reconstruction unchanged."""

    model = SliceTCA(dimensions, ranks, initialization='uniform', device=device)

    reconstruction_before = model.construct().detach()

    within = TransformationWithin(model)
    model = sgd_invariance(model, l2, transformation=within, max_iter=3, learning_rate=0.01,
                           progress_bar=progress_bar)

    reconstruction_after = model.construct()

    error = torch.mean(torch.square(reconstruction_before - reconstruction_after)).item()
    assert error<10**-6, 'Iterative invariance L3 changes L1'

    print('test_invariance_within passed')
58 |
59 |
def test_iterative_invariance_between():
    """Checks that sgd_invariance with a between-type transformation leaves the reconstruction unchanged."""

    model = SliceTCA(dimensions, ranks, initialization='uniform', device=device)

    reconstruction_before = model.construct().detach()

    between = TransformationBetween(model)
    model = sgd_invariance(model, l2, transformation=between, max_iter=3, learning_rate=0.01,
                           progress_bar=progress_bar)

    reconstruction_after = model.construct()

    error = torch.mean(torch.square(reconstruction_before - reconstruction_after)).item()
    assert error<10**-6, 'Iterative invariance L2 changes L1'

    print('test_iterative_invariance_between passed')
73 |
74 |
def test_gridsearch():
    """Runs a small grid search on a tensor reconstructed from a random model (no value is asserted)."""

    data = SliceTCA(dimensions, ranks, device=device).construct().detach()

    # Two independent random draws; the test mask is restricted to entries of the train mask.
    mask_train = torch.rand_like(data) < 0.5
    mask_test = mask_train & (torch.rand_like(data) < 0.5)

    search_options = dict(learning_rate=10 ** -3, max_iter=3, sample_size=2,
                          processes_sample=2, processes_grid=2,
                          mask_train=mask_train, mask_test=mask_test)

    loss_grid, seed_grid = grid_search(data, ranks, **search_options)

    print('test_gridsearch passed')
87 |
88 |
def test_decompose():
    """Checks that decompose recovers a tensor reconstructed from a random model of the same ranks."""

    target = SliceTCA(dimensions, ranks, device=device).construct().detach()

    components, model = decompose(target, ranks)

    reconstruction = model.construct()

    error = torch.mean(torch.square(target - reconstruction)).item()
    assert error<10**-6, 'Decompose fails to decompose a reconstructed tensor'

    print('test_decompose passed')
100 |
101 |
def test_fit():
    """Checks that fitting a PartitionTCA model recovers a tensor built by a model of the same structure."""

    shape = dimensions[:3]

    # The target model is built first, then the model to fit, each with its own partitions and ranks lists.
    target = PartitionTCA(shape, [[[0], [1, 2]], [[2, 0], [1]]], [1, 2]).construct().detach()
    model = PartitionTCA(shape, [[[0], [1, 2]], [[2, 0], [1]]], [1, 2])

    optim = torch.optim.Adam(model.parameters(), lr=10 ** -4)

    model.fit(target, optimizer=optim, min_std=10 ** -6, batch_prop=0.1, max_iter=3 * 10 ** 4,
              progress_bar=progress_bar)

    mse = torch.mean(torch.square(target - model.construct())).item()

    assert mse < 10 ** -3, 'Failed to decompose a reconstruction: ' + str(mse)

    print('test_fit passed')
115 |
116 |
if __name__ == '__main__':

    torch.manual_seed(7)

    # Globals read by the test functions.
    device = ('cuda' if torch.cuda.is_available() else 'cpu')
    progress_bar = False

    # Tests run in this order; test_decompose is not part of this run.
    tests_to_run = (
        test_analytic_invariance,
        test_iterative_invariance_within_between,
        test_iterative_invariance_between,
        test_iterative_invariance_within,
        test_gridsearch,
        test_fit,
    )

    for run_test in tests_to_run:
        run_test()

    print('All tests passed.')
133 |
--------------------------------------------------------------------------------
/slicetca/invariance/transformations.py:
--------------------------------------------------------------------------------
1 | from .helper import *
2 |
3 | from itertools import combinations
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 |
class TransformationBetween(nn.Module):
    """
    Transformation between sliceTCA component types.

    For every pair (i, j), i < j, of component types that both have a non-zero rank, the module holds a
    free parameter tensor of shape (rank_i, rank_j, *dims of the legs shared by both slice partitions).
    The forward pass adds a term built from these parameters to the slices of type i and subtracts the
    matching term from the slices of type j.
    """

    def __init__(self, model):
        """
        :param model: Object exposing ranks, partitions, dimensions, device and get_components().
        """
        super().__init__()

        self.number_components = len(model.ranks)
        self.ranks = model.ranks
        self.partitions = model.partitions
        self.dims = model.dimensions
        self.device = model.device

        n = self.number_components

        # Scalar placeholders for every (i, j); entries of the valid pairs are overwritten below.
        placeholder_rows = []
        for _ in range(n):
            scalars = [nn.Parameter(torch.tensor(0.0, device=self.device)) for _ in range(n)]
            placeholder_rows.append(nn.ParameterList(scalars))
        self.free_vectors_combinations = nn.ModuleList(placeholder_rows)

        self.components = model.get_components()

        self.remaining_index_combinations = [[None] * n for _ in range(n)]
        self.remaining_dims_combinations = [[None] * n for _ in range(n)]

        for first, second in combinations(range(n), 2):
            if self.ranks[first] == 0 or self.ranks[second] == 0:
                continue

            # Legs that belong to the second group of the partition of both component types.
            shared_legs = set(self.partitions[first][1]).intersection(set(self.partitions[second][1]))
            self.remaining_index_combinations[first][second] = list(shared_legs)

            shared_dims = [self.dims[leg] for leg in shared_legs]
            self.remaining_dims_combinations[first][second] = shared_dims

            shape = [self.ranks[first], self.ranks[second]] + shared_dims
            self.free_vectors_combinations[first][second] = nn.Parameter(torch.randn(shape, device=self.device))

    def forward(self, components):
        """
        Applies the transformation to the given components.

        :param components: Nested list indexed as components[type][0 or 1]. Entries [type][1] are replaced
                           in the list that is passed in.
        :return: The same list object, with updated [type][1] entries.
        """

        for first, second in combinations(range(self.number_components), 2):
            if self.ranks[first] == 0 or self.ranks[second] == 0:
                continue

            first_leg = self.partitions[first][0][0]
            second_leg = self.partitions[second][0][0]
            shared_legs = self.remaining_index_combinations[first][second]

            # Leg order of each update term before it is permuted to the layout of the target slices.
            legs_into_first = [second_leg] + shared_legs
            legs_into_second = [first_leg] + shared_legs

            perm_first = [legs_into_first.index(leg) for leg in self.partitions[first][1]]
            perm_second = [legs_into_second.index(leg) for leg in self.partitions[second][1]]

            free_vectors = self.free_vectors_combinations[first][second]

            # batch_outer comes from the helper module; its result is summed over the first axis.
            update_second = batch_outer(components[first][0], free_vectors).sum(dim=0).transpose(0, 1)
            update_first = batch_outer(components[second][0], free_vectors.transpose(0, 1)).sum(dim=0).transpose(0, 1)

            update_second = update_second.permute([0] + [1 + p for p in perm_second])
            update_first = update_first.permute([0] + [1 + p for p in perm_first])

            components[first][1] = components[first][1] + update_first
            components[second][1] = components[second][1] - update_second

        return components
77 |
78 |
class TransformationWithin(nn.Module):
    """
    Transformation within sliceTCA component types.

    Holds one free square matrix per component type, initialized as the identity plus Gaussian noise
    scaled by 1/sqrt(3*rank).
    """

    def __init__(self, model):
        """
        :param model: Object exposing ranks and device.
        """
        super().__init__()

        self.ranks = model.ranks
        self.number_components = len(model.ranks)
        self.device = model.device

        matrices = []
        for rank in self.ranks:
            identity = torch.eye(rank, device=self.device)
            noise = torch.randn((rank, rank), device=self.device)/np.sqrt(3*rank)
            matrices.append(nn.Parameter(identity+noise))

        self.free_gl = nn.ParameterList(matrices)

    def forward(self, components):
        """
        Applies the free matrices to the components of every type with non-zero rank.

        :param components: Nested list indexed as components[type][0 or 1]; entries are replaced in place.
        :return: The same list object.
        """

        for index, rank in enumerate(self.ranks):
            if rank == 0:
                continue

            matrix = self.free_gl[index]

            # The transposed matrix is applied to entry 0 and the inverse to entry 1 (mm is a helper).
            components[index][0] = mm(matrix.T, components[index][0])
            components[index][1] = mm(torch.linalg.inv(matrix), components[index][1])

        return components
102 |
--------------------------------------------------------------------------------
/slicetca/run/grid_search.py:
--------------------------------------------------------------------------------
1 | from slicetca.run.decompose import decompose
2 |
3 | import multiprocessing as mp
4 | from functools import partial
5 | from concurrent.futures import ProcessPoolExecutor as Pool
6 | from tqdm import tqdm
7 | import torch
8 | import numpy as np
9 |
10 | from typing import Sequence, Union
11 |
12 |
13 | # To be fixed: high memory usage when using GPU.
14 |
def grid_search(data: Union[torch.Tensor],  # Only works with torch.Tensor atm
                max_ranks: Sequence[int],
                mask_train: torch.Tensor = None,
                mask_test: torch.Tensor = None,
                min_ranks: Sequence[int] = None,
                sample_size: int = 1,
                processes_sample: int = 1,
                processes_grid: int = 1,
                seed: int = 7,
                **kwargs):
    """
    Performs a gridsearch over different number of components (ranks) to see which has the lowest cross-validated loss.

    Note that this function seeds numpy's global random state and forces the multiprocessing start method
    to 'spawn'.

    :param data: Data tensor to decompose.
    :param max_ranks: Maximum number of components of each type (inclusive).
    :param mask_train: Mask representing over which entries to compute the backpropagated loss. None is full tensor.
    :param mask_test: Mask representing over which entries to compute the loss for validation. None is full tensor.
    :param min_ranks: Minimum number of components of each type (inclusive). None is 0 for every type.
    :param sample_size: Number of seeds to use for a given number of components.
    :param processes_sample: Number of worker processes to use for a given number of components across seeds.
    :param processes_grid: Number of worker processes to use over different number of components.
    :param seed: Numpy seed.
    :param kwargs: Same kwargs as decompose.
    :return: A tuple (loss_grid, seed_grid) of ndarrays, both of shape
             (max_rank_1-min_rank_1+1, max_rank_2-min_rank_2+1, ..., sample_size).
             loss_grid (float32) holds the loss over the mask_test entries of every fitted model,
             seed_grid (int) the seed that was used for the corresponding model.
    """

    np.random.seed(seed)

    # Best effort: keep going with the current start method if it cannot be changed.
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass

    if min_ranks is None: min_ranks = [0 for _ in max_ranks]
    upper_ranks = [r+1 for r in max_ranks]  # exclusive upper bounds
    rank_span = [upper_ranks[i]-min_ranks[i] for i in range(len(upper_ranks))]

    # One row per combination of ranks, with a random seed appended as the last column.
    grid = get_grid_sample(min_ranks, upper_ranks)
    grid = np.concatenate([grid, np.random.randint(10**2,10**6, grid.shape[0])[:,np.newaxis]], axis=-1)
    number_entries = grid.shape[0]

    print('Grid shape:', str(rank_span),
          '- Samples:', sample_size,
          '- Grid entries:', number_entries,
          '- Number of models to fit:', number_entries*sample_size)

    dec = partial(decompose_mp_sample, data=data, mask_train=mask_train, mask_test=mask_test, sample_size=sample_size,
                  processes_sample=processes_sample, **kwargs)

    losses, seeds = [], []
    with Pool(max_workers=processes_grid) as pool:
        iterator = tqdm(pool.map(dec, grid), total=number_entries)
        iterator.set_description('Number of components (completed): - ', refresh=True)
        for i, (sample_losses, sample_seeds) in enumerate(iterator):
            losses.append(sample_losses)
            seeds.append(sample_seeds)
            # Read the ranks from the grid row itself so that the display is right when min_ranks != 0.
            completed = tuple(int(r) for r in grid[i, :-1])
            iterator.set_description('Number of components (completed): '+str(completed) + ' ', refresh=True)

    # Losses and seeds are kept in separate arrays so that the seeds never pass through float32.
    loss_grid = np.array(losses, dtype=np.float32).reshape(rank_span+[sample_size])
    seed_grid = np.array(seeds).astype(int).reshape(rank_span+[sample_size])

    return loss_grid, seed_grid
79 |
80 |
def decompose_mp_sample(number_components_seed, data, mask_train, mask_test, sample_size, processes_sample, **kwargs):
    """
    Fits sample_size models that share the same number of components, each with its own random seed.

    :param number_components_seed: 1-d array holding the number of components of each type followed by a seed.
    :param data: Data tensor to decompose.
    :param mask_train: Training mask forwarded to decompose_mp, or None.
    :param mask_test: Validation mask forwarded to decompose_mp, or None.
    :param sample_size: Number of models (seeds) to fit.
    :param processes_sample: Number of worker processes of the pool used across seeds.
    :param kwargs: Forwarded to decompose_mp.
    :return: (losses, seeds), the array of losses returned by decompose_mp and the array of seeds drawn here.
    """

    ranks, base_seed = number_components_seed[:-1], number_components_seed[-1]

    # The per-model seeds are drawn from numpy's global state, seeded with the seed of this grid entry.
    np.random.seed(base_seed)

    # Cloned copies of the tensors are bound to the worker function.
    data_copy = data.clone()
    train_copy = None if mask_train is None else mask_train.clone()
    test_copy = None if mask_test is None else mask_test.clone()

    worker = partial(decompose_mp, data=data_copy, mask_train=train_copy, mask_test=test_copy, **kwargs)

    repeated_ranks = ranks[np.newaxis].repeat(sample_size, axis=0)
    seeds = np.random.randint(10**2,10**6, sample_size)

    # One row per model: the ranks followed by that model's seed.
    jobs = np.concatenate([repeated_ranks, seeds[:,np.newaxis]], axis=-1)

    with Pool(max_workers=processes_sample) as pool:
        loss = np.array(list(pool.map(worker, jobs)))

    return loss, seeds
102 |
103 |
def decompose_mp(number_components_seed, data, mask_train, mask_test, *args, **kwargs):
    """
    Fits a single model and returns its mean squared error over the validation entries.

    :param number_components_seed: 1-d array holding the number of components of each type followed by the seed.
    :param data: Data tensor to decompose.
    :param mask_train: Mask passed to decompose as ``mask``, or None.
    :param mask_test: Boolean mask selecting the entries over which the returned loss is averaged.
                      None averages over the full tensor.
    :param args: Extra positional arguments forwarded to decompose.
    :param kwargs: Extra keyword arguments forwarded to decompose. ``verbose`` and ``progress_bar``
                   default to False here but may be overridden by the caller.
    :return: The loss as a python float.
    """

    number_components, seed = number_components_seed[:-1], number_components_seed[-1]

    if (number_components == np.zeros_like(number_components)).all():
        # A model without any component reconstructs the zero tensor; nothing to fit.
        data_hat = 0
    else:
        # Use setdefault instead of hard-coded keywords: forwarding a caller's own verbose/progress_bar
        # would otherwise raise "got multiple values for keyword argument".
        kwargs.setdefault('verbose', False)
        kwargs.setdefault('progress_bar', False)

        _, model = decompose(data, number_components, *args, mask=mask_train, seed=seed, **kwargs)

        # Only the value of the loss is needed, so the reconstruction is detached from the graph.
        data_hat = model.construct().detach()

    squared_error = (data-data_hat)**2
    if mask_test is not None:
        squared_error = squared_error[mask_test]

    return torch.mean(squared_error).item()
119 |
120 |
def get_grid_sample(min_dims, max_dims):
    """
    Enumerates every integer combination of the ranges [min_dims[j], max_dims[j]).

    :param min_dims: Inclusive lower bound of each axis.
    :param max_dims: Exclusive upper bound of each axis.
    :return: (number of combinations, len(max_dims)) ndarray with one combination per row,
             the last axis varying fastest.
    """

    axis_values = [np.array(list(range(min_dims[j], max_dims[j]))) for j in range(len(max_dims))]

    mesh = np.meshgrid(*axis_values, indexing='ij')

    # Stacking along a trailing axis and flattening the leading ones gives one combination per row.
    return np.stack(mesh, axis=-1).reshape(-1, len(axis_values))
129 |
--------------------------------------------------------------------------------
/slicetca/plotting/factors.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import numpy as np
3 | from typing import Sequence, Union
4 |
5 |
def plot(model,
         components: Sequence[Sequence[np.ndarray]] = None,
         variables: Sequence[str] = ('trial', 'neuron', 'time'),
         colors: Union[Sequence[np.ndarray], Sequence[Sequence[float]]] = (None, None, None),
         sorting_indices: Sequence[np.ndarray] = (None, None, None),
         ticks: Sequence[np.ndarray] = (None, None, None),
         tick_labels: Sequence[np.ndarray] = (None, None, None),
         quantile: float = 0.95,
         factor_height: int = 2,
         aspect: str = 'auto',
         s: int = 10,
         cmap: str = None,
         tight_layout: bool = True,
         dpi: int = 60):
    """
    Plots SliceTCA components. Plotting TCA or PartitionTCA components also works but is not optimized.

    :param model: SliceTCA, TCA or PartitionTCA instance.
    :param components: By default, components = model.get_components(numpy=True).
                       But you may pass pre-processed components (e.g. sorted neurons etc...).
    :param variables: The axes labels, in the same order as the dimensions of the tensor.
    :param colors: The colors of the variable (e.g. trial condition). Used only for 1-tensor factors.
                   None or 1-d variable will default to plt.plot, 2-d (trials x RGBA) to scatter.
                   Note that to generate RGBA colors from integer trial condition you may call:
                   colors = matplotlib.colormaps['hsv'](condition/np.max(condition))
    :param sorting_indices: Sort (e.g. trials) according to indices.
    :param ticks: Can be used instead of the 0,1, ... default indexing.
    :param tick_labels: Requires ticks
    :param quantile: Quantile of imshow cmap.
    :param factor_height: Height of the 1-tensor factors. Their length is 3.
    :param aspect: 'auto' will give a square-looking slice, 'equal' will preserve the ratios.
    :param s: size of scatter dots (see colors parameter).
    :param cmap: matplotlib cmap for 2-tensor factors (slices). Defaults to inferno for positive else seismic.
    :param tight_layout: Whether to call fig.tight_layout() on the figure before returning.
    :param dpi: Figure dpi. Set lower if you have many components of a given type.
    :return: A list of axes which can be used for further customizing the plots.
             The list has the same shape as model.get_components. That is component_type x (slice/factor) x rank
    """

    components = model.get_components(numpy=True) if components is None else components
    partitions = model.partitions
    positive = model.positive
    ranks = model.ranks

    # Pad the variables in case fewer than needed variables are provided
    variables = list(variables)+['variable '+str(i+1) for i in range(len(variables), max(len(variables), len(ranks)))]

    number_nonzero_components = np.sum(np.array(ranks) != 0)

    axes = [[[None for k in j] for j in i] for i in components]

    # Number of grid rows: the tallest column, counting 3 rows per slice and factor_height per 1-tensor factor.
    figure_size = max([sum([j.shape[0]*3 if len(j.shape) == 3 else j.shape[0]*factor_height for j in i]) for i in components])

    fig = plt.figure(figsize=(number_nonzero_components*3, figure_size), dpi=dpi)
    gs = fig.add_gridspec(figure_size, number_nonzero_components)

    # One column per component type with non-zero rank; the components of a type are stacked vertically.
    column = 0
    for i in range(len(ranks)):
        row = 0
        for j in range(ranks[i]):
            for k in range(len(components[i])):
                current_component = components[i][k][j]
                order = len(list(components[i][k].shape))
                legs = partitions[i][k]

                # =========== Plots 1-tensor factors ===========
                if order == 2:
                    ax = fig.add_subplot(gs[row:row+factor_height, column])
                    row += factor_height

                    leg = legs[0]

                    if sorting_indices[leg] is not None:
                        current_component = current_component[sorting_indices[leg]]

                    if isinstance(colors[leg], np.ndarray) and len(colors[leg].shape) == 2:
                        ax.scatter(np.arange(len(current_component)), current_component, color=colors[leg], s=s)
                    else:
                        ax.plot(np.arange(len(current_component)), current_component,
                                color=(0.0, 0.0, 0.0) if colors[leg] is None else colors[leg])

                    if ticks[leg] is not None:
                        ax.set_xticks(ticks[leg], tick_labels[leg])

                    ax.set_xlabel(variables[leg])

                # =========== Plots 2-tensor factors (slices) ===========
                elif order == 3:
                    ax = fig.add_subplot(gs[row:row+3, column])
                    row += 3
                    ax.set_aspect(aspect)

                    p = (positive if isinstance(positive, bool) else positive[i][k])

                    # Rows of the slice follow the first leg of the partition, columns the second.
                    if sorting_indices[legs[0]] is not None:
                        current_component = current_component[sorting_indices[legs[0]]]
                    if sorting_indices[legs[1]] is not None:
                        current_component = current_component.T[sorting_indices[legs[1]]].T

                    if p:
                        ax.imshow(current_component, aspect=aspect, cmap=(cmap if cmap is not None else 'inferno'),
                                  vmin=np.quantile(current_component,1-quantile),
                                  vmax=np.quantile(current_component,quantile))
                    else:
                        # Symmetric color range around zero for signed slices.
                        min_max = np.quantile(np.abs(current_component),quantile)
                        ax.imshow(current_component, aspect=aspect, cmap=(cmap if cmap is not None else 'seismic'),
                                  vmin=-min_max, vmax=min_max)

                    # =========== Axes labels ===========
                    variable_x = variables[legs[1]]
                    variable_y = variables[legs[0]]
                    ax.set_xlabel(variable_x)
                    ax.set_ylabel(variable_y)

                    if ticks[legs[0]] is not None:
                        ax.set_yticks(ticks[legs[0]], tick_labels[legs[0]])
                    if ticks[legs[1]] is not None:
                        ax.set_xticks(ticks[legs[1]], tick_labels[legs[1]])

                # =========== Higher order factors can't be plotted ===========
                elif order >= 4:
                    ax = fig.add_subplot(gs[row:row+factor_height, column])
                    row += factor_height
                    # Raw string: '\g' is not a valid escape sequence in a regular string literal.
                    ax.text(0.5, 0.5, r'3$\geq$ tensor', va='center', ha='center', color='black')
                    ax.axis('off')

                # =========== Store axes ===========
                axes[i][k][j] = ax

        if ranks[i] != 0: column += 1

    if tight_layout: fig.tight_layout()

    return axes
138 |
--------------------------------------------------------------------------------
/slicetca/core/decompositions.py:
--------------------------------------------------------------------------------
1 | from .helper_functions import squared_difference
2 |
3 | import torch
4 | from torch import nn
5 | import numpy as np
6 | import tqdm
7 | from collections.abc import Iterable
8 |
9 | from typing import Sequence, Union, Callable
10 |
11 |
class PartitionTCA(nn.Module):
    """
    Tensor decomposition into a sum of outer products, with one group of terms per partition of the
    legs (dimensions) of the tensor. Parent class of SliceTCA and TCA.
    """

    def __init__(self,
                 dimensions: Sequence[int],
                 partitions: Sequence[Sequence[Sequence[int]]],
                 ranks: Sequence[int],
                 positive: Union[bool, Sequence[Sequence[Callable]]] = False,
                 initialization: str = 'uniform',
                 init_weight: float = None,
                 init_bias: float = None,
                 device: str = 'cpu'):
        """
        Parent class for the sliceTCA and TCA decompositions.

        :param dimensions: Dimensions of the data to decompose.
        :param partitions: List of partitions of the legs of the tensor.
                        [[[0],[1]]] would be a matrix rank decomposition.
        :param ranks: Number of components of each partition.
        :param positive: If False does nothing.
                         If True constrains all components to be positive.
                         If list of list, the list of functions to apply to a given partition and component.
        :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1).
        :param init_weight: Coefficient to multiply the initial component by.
        :param init_bias: Coefficient to add to the initial component.
        :param device: Torch device.
        :raises TypeError: If positive is neither a bool, a list nor a tuple.
        :raises ValueError: If initialization is not one of the supported names.
        """

        super().__init__()

        # Shape (without the rank axis) of every factor of every partition.
        components = [[[dimensions[k] for k in j] for j in i] for i in partitions]

        if init_weight is None:
            if initialization == 'normal': init_weight = 1/np.sqrt(sum(ranks))
            if initialization == 'uniform-positive': init_weight = ((0.5 / sum(ranks)) ** (1 / max([len(p) for p in partitions])))*2
            if initialization == 'uniform': init_weight = 1/np.sqrt(sum(ranks))
        if init_bias is None: init_bias = 0.0

        if isinstance(positive, bool):
            if positive: positive_function = [[torch.abs for j in i] for i in partitions]
            else: positive_function = [[self.identity for j in i] for i in partitions]
        elif isinstance(positive, (tuple, list)):
            positive_function = positive
        else:
            # Fail here with a clear message instead of an UnboundLocalError further down.
            raise TypeError('positive must be a bool or a list of lists of functions, got: ' + str(type(positive)))

        def sample(shape):
            # Draws the initial values of one factor according to the selected initialization.
            if initialization == 'normal':
                return torch.randn(shape, device=device)*init_weight + init_bias
            if initialization == 'uniform':
                return 2*(torch.rand(shape, device=device)-0.5)*init_weight + init_bias
            if initialization == 'uniform-positive':
                return torch.rand(shape, device=device)*init_weight + init_bias
            raise ValueError('Undefined initialization, select one of : normal, uniform, uniform-positive')

        vectors = nn.ModuleList([])

        for i in range(len(ranks)):
            r = ranks[i]

            # k-tensors of the outer product, with the rank as leading axis
            v = [nn.Parameter(positive_function[i][j](sample([r] + d))) for j, d in enumerate(components[i])]

            vectors.append(nn.ParameterList(v))

        self.vectors = vectors

        self.dimensions = dimensions
        self.partitions = partitions
        self.ranks = ranks
        self.positive = positive
        self.initialization = initialization
        self.init_weight = init_weight
        self.init_bias = init_bias
        self.device = device

        self.components = components
        self.positive_function = positive_function
        self.valence = len(dimensions)
        self.entries = np.prod(dimensions)

        self.losses = []

        # For every partition: the legs in the order in which they appear in the partition, and the
        # permutation that brings a tensor with that leg order back to the order of self.dimensions.
        self.inverse_permutations = []
        self.flattened_permutations = []
        for partition in self.partitions:
            flattened = [leg for group in partition for leg in group]
            self.flattened_permutations.append(flattened)
            self.inverse_permutations.append(torch.argsort(torch.tensor(flattened)).tolist())

        self.set_einsums()

    def identity(self, x):
        """Returns x unchanged. Used as the function applied to unconstrained components."""
        return x

    def set_einsums(self):
        """
        Builds, for every partition, the einsum equation of the outer product of its factors.

        Leg k is given the letter chr(105 + k) (leg 0 is 'i'). The output keeps the legs in the order
        in which they appear in the partition.
        """

        self.einsums = []
        for partition in self.partitions:
            operands = [''.join(chr(105 + leg) for leg in group) for group in partition]
            self.einsums.append(','.join(operands) + '->' + ''.join(operands))

    def construct_single_component(self, partition: int, k: int):
        """
        Constructs the kth term of the given partition.

        :param partition: Type of the partition
        :param k: Number of the component
        :return: Tensor of shape self.dimensions
        """

        factors = [self.positive_function[partition][q](self.vectors[partition][q][k])
                   for q in range(len(self.components[partition]))]
        outer = torch.einsum(self.einsums[partition], factors)

        # Bring the legs from the order of the partition back to the order of self.dimensions.
        outer = outer.permute(self.inverse_permutations[partition])

        return outer

    def construct_single_partition(self, partition: int):
        """
        Constructs the sum of the terms of a given type of partition.

        :param partition: Type of the partition
        :return: Tensor of shape self.dimensions
        """

        # Allocated directly on the target device rather than on the CPU and then moved.
        temp = torch.zeros(self.dimensions, device=self.device)
        for j in range(self.ranks[partition]):
            temp += self.construct_single_component(partition, j)

        return temp

    def construct(self):
        """
        Constructs the full tensor.
        :return: Tensor of shape self.dimensions
        """

        # Allocated directly on the target device rather than on the CPU and then moved.
        temp = torch.zeros(self.dimensions, device=self.device)

        for i in range(len(self.partitions)):
            for j in range(self.ranks[i]):
                temp += self.construct_single_component(i, j)

        return temp

    def get_components(self, detach=False, numpy=False):
        """
        Returns the components of the model, after application of the positive functions.

        :param detach: Kept for backward compatibility. The returned tensors are taken from ``.data``
                       and therefore carry no gradient whichever value is passed.
        :param numpy: Whether to cast them to numpy arrays.
        :return: list of list of tensors.
        """

        components = []
        for i, partition_vectors in enumerate(self.vectors):
            row = []
            for j, vector in enumerate(partition_vectors):
                value = self.positive_function[i][j](vector).data
                row.append(value.cpu().numpy() if numpy else value)
            components.append(row)

        return components

    def set_components(self, components: Sequence[Sequence[torch.Tensor]]):  # bug if positive_function != abs
        """
        Set the model's components.
        If the positive functions are abs or the identity model.set_components(model.get_components)
        has no effect besides resetting the gradient.

        :param components: list of list tensors.
        """

        with torch.no_grad():
            for i in range(len(self.vectors)):
                for j in range(len(self.vectors[i])):
                    new_value = components[i][j]
                    if isinstance(new_value, torch.Tensor):
                        self.vectors[i][j].copy_(new_value.to(self.device))
                    else:
                        self.vectors[i][j].copy_(torch.tensor(new_value, device=self.device))
        self.zero_grad()

    def fit(self,
            X: torch.Tensor,
            optimizer: torch.optim.Optimizer,
            loss_function: Callable = squared_difference,
            batch_prop: float = 0.2,
            max_iter: int = 10000,
            min_std: float = 10 ** -3,
            iter_std: int = 100,
            mask: torch.Tensor = None,
            verbose: bool = False,
            progress_bar: bool = True):
        """
        Fits the model to data.

        :param X: The data tensor.
        :param optimizer: A torch optimizer.
        :param loss_function: The final loss is torch.mean(loss_function(X, X_hat)). That is, loss_function: R^n -> R^n.
        :param batch_prop: Proportion of entries used to compute the gradient at every training iteration.
        :param max_iter: Maximum training iterations.
        :param min_std: Minimum std of the loss under which to return.
        :param iter_std: Number of iterations over which this std is computed.
        :param mask: Entries which may be used to compute the gradient (True entries are used, False entries
                     never are). None uses the full tensor. The loss that is tracked for convergence is
                     always computed over the full tensor.
        :param verbose: Whether to print the loss at every step.
        :param progress_bar: Whether to have a tqdm progress bar.
        """

        losses = []

        iterator = tqdm.tqdm(range(max_iter)) if progress_bar else range(max_iter)

        for iteration in iterator:

            X_hat = self.construct()

            loss_entries = loss_function(X, X_hat)

            total_loss = torch.mean(loss_entries)

            # Entries used for this gradient step: a random batch, restricted to the mask if there is one.
            if batch_prop != 1.0:
                batch_mask = torch.rand(self.dimensions, device=self.device) < batch_prop
                total_mask = batch_mask if mask is None else mask & batch_mask
            else:
                total_mask = mask

            if total_mask is None:
                loss = total_loss
            else:
                total_entries = torch.sum(total_mask)
                loss = torch.sum(loss_entries * total_mask)
                # An empty selection would give 0/0 = NaN and propagate NaN into every parameter;
                # in that case the (zero) sum is backpropagated as is.
                if total_entries != 0:
                    loss = loss / total_entries

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss = total_loss.item()

            losses.append(total_loss)

            if verbose: print('Iteration:', iteration, 'Loss:', total_loss)
            if progress_bar: iterator.set_description('Loss: ' + str(total_loss) + ' ')

            # Stop once the loss has been stable over the last iter_std iterations.
            if len(losses) > iter_std and np.array(losses[-iter_std:]).std() < min_std:
                if progress_bar: iterator.set_description('The model converged. Loss: ' + str(total_loss) + ' ')
                break

        self.losses += losses
271 |
272 |
class SliceTCA(PartitionTCA):
    """
    PartitionTCA with one partition per leg of the tensor: that leg on its own against all the others.
    """

    def __init__(self,
                 dimensions: Sequence[int],
                 ranks: Sequence[int],
                 positive: bool = False,
                 initialization: str = 'uniform',
                 init_weight: float = None,
                 init_bias: float = None,
                 device: str = 'cpu'):
        """
        Main sliceTCA decomposition class.

        :param dimensions: Dimensions of the data to decompose.
        :param ranks: Number of components of each slice type.
        :param positive: If False does nothing.
                         If True constrains all components to be positive.
                         If list of list, the list of functions to apply to a given partition and component.
        :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1).
        :param init_weight: Coefficient to multiply the initial component by.
        :param init_bias: Coefficient to add to the initial component.
        :param device: Torch device.
        """

        legs = list(range(len(dimensions)))

        # Slice type number `leg`: a vector along that leg times a slice over the remaining legs.
        partitions = [[[leg], legs[:leg] + legs[leg + 1:]] for leg in legs]

        super().__init__(dimensions=dimensions,
                         ranks=ranks,
                         partitions=partitions,
                         positive=positive,
                         initialization=initialization,
                         init_weight=init_weight,
                         init_bias=init_bias,
                         device=device)
301 |
302 |
class TCA(PartitionTCA):
    """
    PartitionTCA with a single partition in which every leg of the tensor is a group of its own.
    """

    def __init__(self,
                 dimensions: Sequence[int],
                 rank: int,
                 positive: bool = False,
                 initialization: str = 'uniform',
                 init_weight: float = None,
                 init_bias: float = None,
                 device: str = 'cpu'):
        """
        Main TCA decomposition class.

        :param dimensions: Dimensions of the data to decompose.
        :param rank: Number of components.
        :param positive: If False does nothing.
                         If True constrains all components to be positive.
                         If list of list, the list of functions to apply to a given partition and component.
        :param initialization: Components initialization 'uniform'~U(-1,1), 'uniform-positive'~U(0,1), 'normal'~N(0,1).
        :param init_weight: Coefficient to multiply the initial component by.
        :param init_bias: Coefficient to add to the initial component.
        :param device: Torch device.
        """

        # The parent class expects one rank per partition; a bare number is wrapped into a 1-tuple.
        ranks = rank if isinstance(rank, Iterable) else (rank,)

        # A single partition made of one group per leg.
        partitions = [[[leg] for leg in range(len(dimensions))]]

        super().__init__(dimensions=dimensions,
                         ranks=ranks,
                         partitions=partitions,
                         positive=positive,
                         initialization=initialization,
                         init_weight=init_weight,
                         init_bias=init_bias,
                         device=device)
334 |
--------------------------------------------------------------------------------
/img/decomposition.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
597 |
--------------------------------------------------------------------------------