├── .gitignore ├── GSPA_example.ipynb ├── LICENSE.md ├── Overview.png ├── README.md ├── data └── example.npz ├── gspa ├── __init__.py ├── embedding.py ├── graphs.py ├── gspa.py ├── version.py └── wavelets.py ├── requirements.txt ├── setup.py └── tests └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | dist/ 4 | gspa.egg-info/ 5 | gspa/.ipynb_checkpoints/ 6 | gspa/__pycache__/ 7 | -------------------------------------------------------------------------------- /GSPA_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4bcd03c7-c008-4e71-bcb1-398cfa6a3dca", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "2024-06-17 17:53:02.924034: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 14 | "2024-06-17 17:53:02.925743: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", 15 | "2024-06-17 17:53:02.960350: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", 16 | "2024-06-17 17:53:02.961276: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 17 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 18 | "2024-06-17 17:53:03.605873: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "import numpy as np\n", 24 | "import gspa\n", 25 | "import scanpy, phate" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 12, 31 | "id": "ee7d01b2-edd0-4074-bf87-099dffbe42fe", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "graphtools 1.5.3\n", 39 | "tensorflow 2.13.0\n", 40 | "keras 2.13.1\n", 41 | "numpy 1.22.4\n", 42 | "sklearn 1.3.2\n", 43 | "scipy 1.10.1\n", 44 | "tqdm 4.66.4\n", 45 | "scanpy 1.9.3\n", 46 | "phate 1.0.11\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "import graphtools\n", 52 | "print('graphtools', graphtools.__version__)\n", 53 | "import tensorflow\n", 54 | "print('tensorflow', tensorflow.__version__)\n", 55 | "import keras\n", 56 | "print('keras', keras.__version__)\n", 57 | "import numpy as np\n", 58 | "print ('numpy', np.__version__)\n", 59 | "import sklearn\n", 60 | "print ('sklearn', sklearn.__version__)\n", 61 | "import scipy\n", 62 | "print ('scipy', scipy.__version__)\n", 63 | "import tqdm\n", 64 | "print ('tqdm', tqdm.__version__)\n", 65 | "import 
scanpy\n", 66 | "print ('scanpy', scanpy.__version__)\n", 67 | "import phate\n", 68 | "print ('phate', phate.__version__)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "id": "8ab5ee38-60cc-4b48-aa07-3d3b6cdd552b", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "data = np.load(f'data/example.npz')" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "id": "983862b5-c373-4b77-ac46-f8d665abdd16", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "adata = scanpy.AnnData(data['counts'], obs={'pseudotime': data['pseudotime']}, dtype=np.float64)\n", 89 | "scanpy.pp.highly_variable_genes(adata)\n", 90 | "\n", 91 | "# gene_adata stores genes as observations and cells as variables\n", 92 | "gene_adata = adata[:, adata.var['highly_variable']].copy().T" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "id": "f4be7185-e42b-4098-910f-09340340ead6", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "gspa_op = gspa.GSPA(verbose=False)\n", 103 | "gspa_op.construct_graph(adata.to_df())\n", 104 | "gspa_op.build_diffusion_operator()\n", 105 | "gspa_op.build_wavelet_dictionary()" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "id": "08da314d-84e6-4565-9ebb-cb34754eaf95", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stderr", 116 | "output_type": "stream", 117 | "text": [ 118 | "2024-06-17 17:53:16.410180: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" 119 | ] 120 | }, 121 | { 122 | "data": { 123 | "image/png": "", 124 | "text/plain": [ 125 | "
" 126 | ] 127 | }, 128 | "metadata": {}, 129 | "output_type": "display_data" 130 | } 131 | ], 132 | "source": [ 133 | "gene_adata.obsm['X_AE'], gene_adata.obsm['X_PC'] = gspa_op.get_gene_embeddings(gene_adata.to_df())\n", 134 | "gene_adata.obs['GSPA_localization'] = gspa_op.calculate_localization()\n", 135 | "gene_adata.obsm['X_phate'] = phate.PHATE(verbose=False, random_state=42).fit_transform(gene_adata.obsm['X_AE'])\n", 136 | "scanpy.external.pl.phate(gene_adata, color=['GSPA_localization'], cmap='PuBuGn', sort_order=False,\n", 137 | " vmax=np.percentile(gene_adata.obs['GSPA_localization'], 99.5))" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "Python 3 (ipykernel)", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.8.18" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 5 162 | } 163 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ---------------------------------- 2 | 3 | Non-Commercial License 4 | Yale Copyright © 2024 Yale University. 5 | 6 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. 
For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | 10 | ---------------------------------- 11 | -------------------------------------------------------------------------------- /Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/Gene-Signal-Pattern-Analysis/163d4bcde8b48fd79051d4a3ec75f585d6a0349a/Overview.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gene Signal Pattern Analysis 2 | ### Mapping the gene space at single-cell resolution 3 | 4 | Gene Signal Pattern Analysis is a Python package for mapping the gene space from single-cell data. For a detailed explanation of GSPA and potential downstream application, see: 5 | 6 | [**Mapping the gene space at single-cell resolution with Gene Signal Pattern Analysis**. 
Aarthi Venkat, Sam Leone, Scott E. Youlten, Eric Fagerberg, John Attanasio, Nikhil S. Joshi, Michael Perlmutter, Smita Krishnaswamy.](https://www.biorxiv.org/content/10.1101/2023.11.26.568492v1) 7 | 8 | By considering gene expression values as signals on the cell-cell graph, GSPA enables complex analyses of gene-gene relationships, including gene cluster analysis, cell-cell communication, and patient manifold learning from gene-gene graphs. 9 | 10 | ### Installation 11 | 12 | ``` 13 | pip install gspa 14 | ``` 15 | 16 | ### Requirements 17 | 18 | GSPA requires Python >= 3.6. All other requirements are automatically installed by ``pip`` (see also requirements.txt). 19 | 20 | The following configurations have been tested: Python 3.6.15 (graphtools 1.5.3, tensorflow 2.6.2, keras 2.6.0, numpy 1.19.5, sklearn 0.24.2, scipy 1.5.4, tqdm 4.64.1, scanpy 1.7.2, phate 1.0.11) and Python 3.8.18 (graphtools 1.5.3, tensorflow 2.13.0, keras 2.13.1, numpy 1.22.4, sklearn 1.3.2, scipy 1.10.1, tqdm 4.66.4, scanpy 1.9.3, phate 1.0.11). 21 | 22 | ### Usage example 23 | 24 | ``` 25 | import numpy as np 26 | import gspa 27 | 28 | # Create toy data 29 | n_cells = 1000 30 | n_genes = 50 31 | data = np.random.normal(size=(n_cells, n_genes)) 32 | 33 | # GSPA operator constructs wavelet dictionary 34 | gspa_op = gspa.GSPA() 35 | gspa_op.construct_graph(data) 36 | gspa_op.build_diffusion_operator() 37 | gspa_op.build_wavelet_dictionary() 38 | 39 | # Embed gene signals from wavelet dictionary 40 | gene_signals = data.T # embed all measured genes 41 | gene_ae, gene_pc = gspa_op.get_gene_embeddings(gene_signals) 42 | gene_localization = gspa_op.calculate_localization() 43 | ``` 44 | See `GSPA_example.ipynb` [above](https://github.com/KrishnaswamyLab/Gene-Signal-Pattern-Analysis) for a test run on simulated single-cell data. More notebooks that generate the paper figures are available at [https://github.com/KrishnaswamyLab/GSPA-manuscript-analyses](https://github.com/KrishnaswamyLab/GSPA-manuscript-analyses).
45 | 46 | ![](Overview.png) 47 | -------------------------------------------------------------------------------- /data/example.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/Gene-Signal-Pattern-Analysis/163d4bcde8b48fd79051d4a3ec75f585d6a0349a/data/example.npz -------------------------------------------------------------------------------- /gspa/__init__.py: -------------------------------------------------------------------------------- 1 | from .gspa import GSPA 2 | import gspa.embedding 3 | import gspa.graphs 4 | import gspa.wavelets 5 | from .version import __version__ 6 | -------------------------------------------------------------------------------- /gspa/embedding.py: -------------------------------------------------------------------------------- 1 | from sklearn import decomposition 2 | import tensorflow as tf 3 | import keras 4 | import numpy as np 5 | 6 | def project(signals, cell_dictionary): 7 | norms = np.linalg.norm(signals, axis=1).reshape(-1, 1) 8 | norms[norms == 0] = 1 # Avoid division by zero by setting zero norms to one 9 | signals = signals / norms 10 | return np.dot(signals, cell_dictionary) 11 | 12 | def svd(signals, random_state=1234, n_components=2048): 13 | n_components = min(n_components, signals.shape[0], signals.shape[1]) 14 | pc_op = decomposition.PCA(n_components=n_components, random_state=random_state) 15 | data_pc = pc_op.fit_transform(signals) 16 | 17 | # normalize before autoencoder 18 | data_pc_std = data_pc / np.std(data_pc[:, 0]) 19 | 20 | return (data_pc_std) 21 | 22 | def run_ae(data, random_state=1234, act='relu', bias=1, dim=128, num_layers=2, dropout=0.0, lr=0.001, epochs=100, 23 | val_prop=0.05, weight_decay=0, patience=10, verbose=True): 24 | try: 25 | keras.utils.set_random_seed(random_state) 26 | except: 27 | tf.random.set_seed(random_state) # unstable for TF > 2.7 28 | 29 | # encoder 30 | input = 
keras.Input(shape=(data.shape[1])) 31 | encoded = keras.layers.Dense(dim * 2, activation=act, use_bias=bias)(input) 32 | if dropout > 0: 33 | encoded = keras.layers.Dropout(dropout)(encoded) 34 | for i in range(num_layers - 2): 35 | encoded = keras.layers.Dense(dim * 2, activation=act, use_bias=bias)(encoded) 36 | if dropout > 0: 37 | encoded = keras.layers.Dropout(dropout)(encoded) 38 | 39 | encoded = keras.layers.Dense(dim, activation='linear', use_bias=bias)(encoded) 40 | 41 | # decoder 42 | decoded = keras.layers.Dense(dim * 2, activation=act, use_bias=bias)(encoded) 43 | for i in range(num_layers - 2): 44 | decoded = keras.layers.Dense(dim * 2, activation=act, use_bias=bias)(decoded) 45 | decoded = keras.layers.Dense(data.shape[1], activation='linear', use_bias=bias)(decoded) 46 | 47 | # autoencoder 48 | autoencoder = keras.Model(input, decoded) 49 | encoder = keras.Model(input, encoded) 50 | try: 51 | autoencoder.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr, decay=weight_decay), 52 | loss='mean_squared_error') 53 | except ValueError: 54 | 55 | autoencoder.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=lr, decay=weight_decay), 56 | loss='mean_squared_error') 57 | 58 | callback = keras.callbacks.EarlyStopping(monitor="val_loss", patience=patience) 59 | 60 | history = autoencoder.fit(data, data, 61 | verbose=verbose, 62 | validation_split = val_prop, 63 | epochs=epochs, 64 | shuffle=True, 65 | callbacks=[callback]) 66 | 67 | embedding = encoder(data).numpy() 68 | 69 | return (embedding) 70 | -------------------------------------------------------------------------------- /gspa/graphs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import multiscale_phate as mp 3 | 4 | def graph_condensation(data, random_state=42, n_jobs=-1, condensation_threshold=10000, n_pca=100): 5 | mp_op = mp.Multiscale_PHATE(random_state=random_state, n_jobs=n_jobs, n_pca=n_pca) 6 | levels = 
mp_op.fit(data) 7 | number_of_condensed_points = np.array([np.unique(x).shape[0] for x in mp_op.NxTs]) 8 | condensed_level = np.argwhere(number_of_condensed_points <= condensation_threshold)[0][0] 9 | return (np.array(mp_op.NxTs[condensed_level])) 10 | 11 | def aggregate_signals_over_condensed_nodes(data, condensation_groupings): 12 | clust_unique, clust_unique_ids = np.unique(condensation_groupings, return_index=True) 13 | loc = [] 14 | for c in clust_unique: 15 | loc.append(np.where(condensation_groupings == c)[0]) 16 | 17 | counts_condensed = [] 18 | for l in loc: 19 | counts_condensed.append(data[l].mean(axis=0)) 20 | 21 | return (np.array(counts_condensed)) 22 | -------------------------------------------------------------------------------- /gspa/gspa.py: -------------------------------------------------------------------------------- 1 | import tasklogger 2 | import graphtools 3 | from . import graphs, wavelets, embedding 4 | import numpy as np 5 | from tqdm import tqdm 6 | from scipy import sparse, spatial 7 | import os 8 | 9 | os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' 10 | _logger = tasklogger.get_tasklogger("graphtools") 11 | 12 | class GSPA: 13 | """GSPA operator which performs gene dimensionality reduction. 14 | 15 | Gene Signal Pattern Analysis (GSPA) considers genes as signals on a cell-cell graph, enabling mapping the gene space for complex gene-level analyses, including gene cluster analysis, cell-cell communication, and patient manifold learning from gene-gene graphs (Venkat et al. [1]_). 16 | 17 | Parameters 18 | ---------- 19 | graph : graphtools.Graph, optional, default: None 20 | Cell-cell affinity graph. If None, `construct_graph` function will need to be run to construct graph from data directly. 21 | diffusion_operator : array-like, shape=[n_samples, n_samples], default: None 22 | Cell-cell diffusion operator. If None `build_diffusion_operator` will need to be run based on the cell-cell affinity graph. 
23 | qr_decompose : boolean, default: True 24 | If True, composes reduced wavelet dictionary with QR decomposition 25 | qr_epsilon: float, optional, default: 1e-3 26 | If qr_decompose is True, qr_epsilon determines threshold for QR decomposition 27 | wavelet_J: int, optional, default: -1 28 | Maximum number of scales J. If -1, uses J=log(number of cells) based on Tong et al. [2]_. 29 | wavelet_power: int, optional, default: 2 30 | Geometric seqence of ratio wavelet_ower for wavelet transforms. 31 | embedding_dim: int, optional, default: 128 32 | Number of dimensions in which genes will be embedded with autoencoder. 33 | pc_dim: int, optional, default: 2048 34 | Number of dimensions in which genes will be embedded with PCA. 35 | random_state: int, default: 42 36 | Integer random seed for GSPA pipeline. 37 | verbose: boolean, optional, default: True 38 | If True, print status messages 39 | n_jobs: integer, optional, default: -1 40 | The number of jobs to use for computation. If -1 all CPUS are used. 41 | perform_condensation: boolean, optional, default: True 42 | If True, perform graph condensation for large graphs. 43 | condensation_threshold: int, optional, default: 10000 44 | If perform_condensation is True, graph condensation occurrs for graphs with more than condensation_threshold cells. 45 | bc_sample_idx: array-like, shape=[n_samples, 1], default: None 46 | Batch labels. If provided, bc_sample_idx is used to construct mutual nearest neighbors (MNN) graph for batch correction. 47 | bc_theta: float, optional, default: 0.95 48 | If batch labels bc_sample_idx provided, bc_theta is used to parametrize MNN symmetrization. 49 | activation: string, optional, default: relu 50 | Activation function in `keras.activations` between layers for autoencoder. 51 | bias: boolean, optional, default: 1 52 | If 1, autoencoder layers use bias vector. 53 | num_layers: int, optional, default: 2 54 | Number of dense layers within encoder (decoder) of autoencoder. 
55 | dropout: float, optional, default: 0.0 56 | If dropout > 0, adds dropout layers between dense layers with `dropout` fraction of input units dropped. 57 | lr: float, optional, default: 0.001 58 | Learning rate for model Adam optimizer. 59 | weight_decay: float, optional, default: 0.0 60 | If set, weight_decay is applied to Adam optimizer. 61 | epochs: int, optional, default: 100 62 | Number of epochs for model training. 63 | val_prop: float, optional, default: 0.05 64 | Proportion of data heldout for validation set used for early stopping. 65 | patience: int, optional, default: 10 66 | Number of epochs with no improvement to validation loss after which training will be stopped. 67 | 68 | Attributes 69 | ---------- 70 | condensation_groupings: array-like, shape=[n_samples, 1] 71 | If perform_condensation is True and graph size is greater than condensation_threshold, MS PHATE (Kuchroo et al. [3]_) computes a new node assignment for each cell, where cells are grouped into nodes based on diffusion condensation. 72 | wavelet_dictionary: array-like, shape=[n_samples, n_wavelet_dictionary_dimensions] 73 | Stores wavelet dictionary vector for each cell after `build_wavelet_dictionary` is run. 74 | signals_projected: array-like, shape=[n_features, n_wavelet_dictionary_dimensions] 75 | Stores gene signals projected onto wavelet dictionary from `get_gene_embeddings`. 
76 | 77 | Examples 78 | ---------- 79 | 80 | >>> import gspa 81 | >>> import scprep 82 | >>> import numpy as np 83 | >>> data = np.random.normal(size=(1000, 50)) # fake dataset with 1000 cells and 50 genes 84 | >>> gspa_op = gspa.GSPA() 85 | >>> gspa_op.construct_graph(data) 86 | >>> gspa_op.build_diffusion_operator() 87 | >>> gspa_op.build_wavelet_dictionary() 88 | >>> gene_ae, gene_pc = gspa_op.get_gene_embeddings(data.T) 89 | >>> gene_localization = gspa_op.calculate_localization() 90 | >>> gene_phate = phate.PHATE().fit_transform(gene_ae) 91 | >>> scprep.plot.scatter2d(gene_phate, c=gene_localization, cmap='PuBuGn') 92 | 93 | References 94 | ---------- 95 | .. [1] Venkat A, Leone S, Youlten SE, Fagerberg E, Attansio J, Joshi NS, Perlmutter M, Krishnaswamy S, *Mapping the gene space at single-cell resolution with gene signal pattern analysis* `BioRxiv `_. 96 | .. [2] Tong A, Huguet G, Shung D, Natik A, Kuchroo M, Lajoie G, Wolf G, Krishnaswamy S, *Embedding Signals on Knowledge Graphs with Unbalanced Diffusion Earth Mover's Distance* `arXiv `_. 97 | .. [3] Kuchroo et al, *Multiscale PHATE identifies multimodal signatures of COVID-19* ``_. 
98 | """ 99 | 100 | def __init__(self, 101 | graph=None, 102 | diffusion_operator=None, 103 | qr_decompose=True, 104 | qr_epsilon=1e-3, 105 | wavelet_J=-1, 106 | wavelet_power=2, 107 | embedding_dim=128, 108 | pc_dim=2048, 109 | random_state=42, 110 | verbose=True, 111 | n_jobs=-1, 112 | perform_condensation=True, 113 | condensation_threshold=10000, 114 | bc_sample_idx=None, 115 | bc_theta=0.95, 116 | activation='relu', 117 | bias=1, 118 | num_layers=2, 119 | dropout=0.0, 120 | lr=0.001, 121 | weight_decay=0.0, 122 | epochs=100, 123 | val_prop=0.05, 124 | patience=10, 125 | ): 126 | 127 | self.graph = graph 128 | self.diff_op = diffusion_operator 129 | self.qr_decompose = qr_decompose 130 | self.qr_epsilon = qr_epsilon 131 | self.wavelet_J = wavelet_J 132 | self.wavelet_power = wavelet_power 133 | self.embedding_dim = embedding_dim 134 | self.pc_dim = pc_dim 135 | self.random_state = random_state 136 | self.verbose = verbose 137 | self.n_jobs = n_jobs 138 | self.perform_condensation = perform_condensation 139 | self.condensation_threshold = condensation_threshold 140 | self.bc_sample_idx = bc_sample_idx 141 | self.bc_theta = bc_theta 142 | self.activation = activation 143 | self.bias = bias 144 | self.num_layers = num_layers 145 | self.dropout = dropout 146 | self.lr = lr 147 | self.epochs = epochs 148 | self.val_prop = val_prop 149 | self.weight_decay = weight_decay 150 | self.patience = patience 151 | 152 | self.condensation_groupings = None 153 | self.wavelet_dictionary = None 154 | self.wavelet_sizes = None 155 | self.signals_projected = None 156 | 157 | _logger.set_level(self.verbose) 158 | 159 | def construct_graph(self, data): 160 | """Constructs cell-cell affinity graph. 161 | 162 | Parameters 163 | ---------- 164 | data: array-like, shape=[n_samples, n_features] 165 | input data with `n_samples` samples and `n_features` features. Accepted data types: `numpy.ndarray`, `pd.DataFrame`. 
166 | 167 | """ 168 | if (data.shape[0] > self.condensation_threshold) & (self.perform_condensation): 169 | _logger.log_info("Dataset is larger than %s cells. Running graph condensation. Set perform_condensation=False to run exact GSPA." % self.condensation_threshold) 170 | self.condensation_groupings = graphs.graph_condensation(data, random_state=self.random_state, 171 | n_jobs=self.n_jobs, 172 | condensation_threshold=self.condensation_threshold, 173 | n_pca=self.pc_dim) 174 | 175 | data = graphs.aggregate_signals_over_condensed_nodes(data, self.condensation_groupings) 176 | if self.bc_sample_idx is None: 177 | self.graph = graphtools.Graph(data, n_pca=100, random_state=self.random_state, verbose=self.verbose, use_pygsp=True) 178 | else: 179 | _logger.log_info(f"bc_sample_idx used for batch correction") 180 | self.graph = graphtools.Graph(data, n_pca=100, sample_idx=self.bc_sample_idx, kernel_symm='mnn', 181 | theta=self.bc_theta, random_state=self.random_state, verbose=self.verbose, use_pygsp=True) 182 | 183 | def build_diffusion_operator(self): 184 | """Constructs diffusion operator from graph. 185 | """ 186 | if self.graph is None: 187 | raise ValueError('Graph not constructed. Run gspa_op.construct_graph(data) or initialize GSPA operator with graph') 188 | else: 189 | self.graph = self.graph.to_pygsp() 190 | 191 | Dmin1 = np.diag([1/np.sum(row) for row in self.graph.A]) 192 | self.diff_op = 1/2 * (np.eye(self.graph.N)+self.graph.A@Dmin1) 193 | 194 | def build_wavelet_dictionary(self): 195 | """Constructs wavelet dictionary from diffusion operator. 196 | """ 197 | if self.diff_op is None: 198 | raise ValueError('Diffusion operator not constructed. 
Run gspa_op.build_diffusion_operator() or initialize GSPA operator with diffusion_operator') 199 | wavelet_sizes = [] 200 | 201 | if self.graph is not None: 202 | self.graph = self.graph.to_pygsp() 203 | 204 | if sparse.issparse(self.diff_op): 205 | self.diff_op = self.diff_op.toarray() 206 | 207 | N = self.diff_op.shape[0] 208 | if self.wavelet_J == -1: 209 | self.wavelet_J = int(np.log(N)) 210 | I = np.eye(N) 211 | I = wavelets.normalize(I) 212 | wavelet_dictionary = [I] 213 | wavelet_sizes.append(I.shape[1]) 214 | P_j = np.linalg.matrix_power(self.diff_op, self.wavelet_power) 215 | 216 | if self.qr_decompose: 217 | Psi_j_tilde = wavelets.column_subset(I-P_j, epsilon=self.qr_epsilon) 218 | 219 | if Psi_j_tilde.shape[1] == 0: 220 | _logger.log_info(f"Wavelets calculated; J = 1") 221 | self.wavelet_dictionary, self.wavelet_sizes = (wavelets.flatten(wavelet_dictionary, wavelet_sizes)) 222 | 223 | Psi_j_tilde = wavelets.normalize(Psi_j_tilde) 224 | wavelet_sizes.append(Psi_j_tilde.shape[1]) 225 | wavelet_dictionary += [Psi_j_tilde] 226 | 227 | for i in tqdm(range(2,self.wavelet_J), disable=self.verbose==False): 228 | P_j_new = np.linalg.matrix_power(P_j,self.wavelet_power) 229 | Psi_j = P_j - P_j_new 230 | P_j = P_j_new 231 | Psi_j_tilde = wavelets.column_subset(Psi_j, epsilon=self.qr_epsilon) 232 | if Psi_j_tilde.shape[1] == 0: 233 | _logger.log_info("Wavelets calculated; J = %s" %i) 234 | self.wavelet_dictionary, self.wavelet_sizes = (wavelets.flatten(wavelet_dictionary, wavelet_sizes)) 235 | 236 | Psi_j_tilde = wavelets.normalize(Psi_j_tilde) 237 | 238 | wavelet_sizes.append(Psi_j_tilde.shape[1]) 239 | wavelet_dictionary += [Psi_j_tilde] 240 | else: 241 | _logger.log_info("Calculating Wavelets J = %s" % self.wavelet_J) 242 | wavelet_dictionary += [I-P_j] 243 | wavelet_sizes.append((I-P_j).shape[1]) 244 | for i in tqdm(range(2,self.wavelet_J), disable=self.verbose==False): 245 | P_j_new = np.linalg.matrix_power(P_j,self.wavelet_power) 246 | Psi_j = P_j - P_j_new 
247 | P_j = P_j_new 248 | Psi_j = wavelets.normalize(Psi_j) 249 | wavelet_sizes.append(Psi_j.shape[1]) 250 | wavelet_dictionary += [Psi_j] 251 | 252 | self.wavelet_dictionary, self.wavelet_sizes = wavelets.flatten(wavelet_dictionary, wavelet_sizes) 253 | 254 | def get_gene_embeddings(self, signals): 255 | """Get gene features embedded in principle component space and autoencoded space. 256 | 257 | Parameters 258 | ---------- 259 | signals: array-like, shape=[n_features, n_samples] 260 | Input signals defined on nodes of cell-cell graph. Accepted data types: `numpy.ndarray`, `pd.DataFrame`. 261 | 262 | Returns 263 | ---------- 264 | signals_ae: array, shape=[n_features, embedding_dim] 265 | Signals embedded with autoencoder into `embedding_dim`-dimensional space. 266 | signals_pc: array, shape=[n_features, pc_dim] 267 | Signals embedded with PCA into `pc_dim`-dimensional space. 268 | """ 269 | 270 | if self.wavelet_dictionary is None: 271 | raise ValueError('Run gspa_op.build_wavelet_dictionary') 272 | 273 | if self.condensation_groupings is not None: 274 | signals = graphs.aggregate_signals_over_condensed_nodes(signals.T, self.condensation_groupings).T 275 | 276 | self.signals_projected = embedding.project(signals, self.wavelet_dictionary) 277 | signals_pc = embedding.svd(self.signals_projected, n_components=self.pc_dim, random_state=self.random_state) 278 | signals_ae = embedding.run_ae(signals_pc, verbose=self.verbose, random_state=self.random_state, act=self.activation, bias=self.bias, 279 | dim=self.embedding_dim, num_layers=self.num_layers, dropout=self.dropout, lr=self.lr, 280 | epochs=self.epochs, val_prop=self.val_prop, weight_decay=self.weight_decay, patience=self.patience) 281 | 282 | return (signals_ae, signals_pc) 283 | 284 | def calculate_localization(self, signals=None): 285 | """Calculates localization for signals. 
286 | 287 | Parameters 288 | ---------- 289 | signals: array-like, optional, shape=[n_features, n_samples] 290 | Input signals defined on nodes of cell-cell graph. Accepted data types: `numpy.ndarray`, `pd.DataFrame`. If signals is None, calculates localization for gene signals inputted to `get_gene_embeddings`. 291 | 292 | Returns 293 | ---------- 294 | localization_score: array, shape=[n_features,] 295 | Localization score for each gene, where higher score indicates the gene is more localized on the cell-cell graph. 296 | """ 297 | 298 | if self.wavelet_dictionary is None: 299 | raise ValueError('Run gspa_op.build_wavelet_dictionary') 300 | if signals is not None: 301 | _logger.log_info(f"Computing localization with provided signals.") 302 | if self.condensation_groupings is not None: 303 | signals = graphs.aggregate_signals_over_condensed_nodes(signals.T, self.condensation_groupings).T 304 | signals_projected = embedding.project(signals, self.wavelet_dictionary) 305 | uniform_signal = np.ones((1, self.wavelet_dictionary.shape[0])) 306 | uniform_projected = embedding.project(uniform_signal, self.wavelet_dictionary) 307 | localization_score = spatial.distance.cdist(uniform_projected, signals_projected).reshape(-1,) 308 | 309 | else: 310 | if self.signals_projected is None: 311 | raise ValueError('Provide signals to map to dictionary or run gspa_op.get_gene_embeddings') 312 | else: 313 | _logger.log_info(f"Computing localization with signals used for gene embedding.") 314 | uniform_signal = np.ones((1, self.wavelet_dictionary.shape[0])) 315 | uniform_projected = embedding.project(uniform_signal, self.wavelet_dictionary) 316 | localization_score = spatial.distance.cdist(uniform_projected, self.signals_projected).reshape(-1,) 317 | 318 | return (localization_score) 319 | 320 | def calculate_cell_type_specificity(self, cell_type_assignments, cell_type, signals=None): 321 | """Calculates cell type specificity for each signal to provided cell type of interest. 
322 | 323 | Parameters 324 | ---------- 325 | cell_type_assignments: array-like, shape=[n_samples,] 326 | Cluster or cell type assignments to cell nodes. 327 | cell_type: string 328 | Cluster name or cell type of interest. 329 | signals: array-like, optional, shape=[n_features, n_samples] 330 | Input signals defined on nodes of cell-cell graph. Accepted data types: `numpy.ndarray`, `pd.DataFrame`. If signals is None, calculates localization for gene signals inputted to `get_gene_embeddings`. 331 | 332 | Returns 333 | ---------- 334 | specificity_score: array, shape=[n_features,] 335 | Cell type specificity score for each gene, where higher score indicates the gene is more specific to provided cell type. 336 | """ 337 | 338 | cell_type_assignments = np.array(cell_type_assignments) 339 | if cell_type not in cell_type_assignments: 340 | raise ValueError('Cell type not found in cell type assignments') 341 | if self.wavelet_dictionary is None: 342 | raise ValueError('Run gspa_op.build_wavelet_dictionary') 343 | 344 | if signals is not None: 345 | _logger.log_info(f"Computing cell type specificity with provided signals.") 346 | if self.condensation_groupings is not None: 347 | signals = graphs.aggregate_signals_over_condensed_nodes(signals.T, self.condensation_groupings).T 348 | signals_projected = embedding.project(signals, self.wavelet_dictionary) 349 | 350 | cell_type_signal = (cell_type_assignments == cell_type).astype(int).reshape(1, -1) 351 | if self.condensation_groupings is not None: 352 | cell_type_signal = graphs.aggregate_signals_over_condensed_nodes(cell_type_signal.T, self.condensation_groupings).T 353 | cell_type_projected = embedding.project(cell_type_signal, self.wavelet_dictionary) 354 | specificity_score = -1*spatial.distance.cdist(cell_type_projected, signals_projected).reshape(-1,) 355 | 356 | else: 357 | if self.signals_projected is None: 358 | raise ValueError('Provide signals to map to dictionary or run gspa_op.get_gene_embeddings') 359 | else: 360 
| _logger.log_info(f"Computing cell type specificity with signals used for gene embedding") 361 | cell_type_signal = (cell_type_assignments == cell_type).astype(int).reshape(1, -1) 362 | if self.condensation_groupings is not None: 363 | cell_type_signal = graphs.aggregate_signals_over_condensed_nodes(cell_type_signal.T, self.condensation_groupings).T 364 | cell_type_projected = embedding.project(cell_type_signal, self.wavelet_dictionary) 365 | specificity_score = -1*spatial.distance.cdist(cell_type_projected, self.signals_projected).reshape(-1,) 366 | 367 | return (specificity_score) 368 | -------------------------------------------------------------------------------- /gspa/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Krishnaswamy Lab, Yale University 2 | 3 | __version__ = "1.1" 4 | -------------------------------------------------------------------------------- /gspa/wavelets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.linalg import qr 3 | 4 | def normalize(A): 5 | """ 6 | Input: A, an n x m matrix 7 | Output: A with each column divided by its L2 norm 8 | """ 9 | 10 | for i in range(A.shape[1]): 11 | A[:,i]=A[:,i]/np.linalg.norm(A[:,i]) 12 | return A 13 | 14 | def column_subset(A,epsilon): 15 | """ 16 | Input: an m x n matrix A, tolerance epsilon 17 | Output: Subset of A's columns s.t. 
the projection of A into these columns; 18 | can approximate A with error < epsilon |A|_2 19 | """ 20 | 21 | R,P = qr(A,pivoting=True,mode='r') 22 | A_P = A[:,P] 23 | 24 | A_nrm = np.sum(A*A) 25 | tol = epsilon*A_nrm 26 | R_nrm = 0 27 | 28 | for i in range(0,R.shape[0]): 29 | R_nrm += np.sum(R[i]*R[i]) 30 | err = A_nrm-R_nrm 31 | if err < tol: 32 | return A_P[:,:i] 33 | 34 | return A_P 35 | 36 | def flatten(wavelet_list, size_of_wavelets_per_scale): 37 | N = wavelet_list[0].shape[0] 38 | flat_waves = np.zeros((N,np.sum(size_of_wavelets_per_scale))) 39 | curr = 0 40 | for i,wavelet in enumerate(wavelet_list): 41 | last = curr + size_of_wavelets_per_scale[i] 42 | flat_waves[:,curr:last] = wavelet 43 | curr = last 44 | 45 | return (np.array(flat_waves), np.array(size_of_wavelets_per_scale)) 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | graphtools>=1.5.0 2 | tensorflow>=2.6.0 3 | multiscale_phate==0.0 4 | numpy>=1.14.0 5 | scikit_learn 6 | scipy>=1.1.0 7 | tqdm 8 | scanpy 9 | phate 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | install_requires = [ 5 | "graphtools>=1.5.0", 6 | "tensorflow>=2.6.0", 7 | "multiscale_phate==0.0", 8 | "numpy>=1.14.0", 9 | "scikit-learn", 10 | "scipy>=1.1.0", 11 | "tqdm", 12 | "scanpy", 13 | "phate" 14 | ] 15 | 16 | test_requires = [ 17 | "numpy>=1.14.0", 18 | "phate", 19 | ] 20 | 21 | version_py = os.path.join(os.path.dirname(__file__), "gspa", "version.py") 22 | version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip() 23 | 24 | readme = open("README.md").read() 25 | 26 | setup( 27 | name='gspa', 28 | version=version, 29 | description="Gene Signal Pattern Analysis", 30 | 
author='Aarthi Venkat, Krishnaswamy Lab, Yale University', 31 | author_email='aarthi.venkat@yale.edu', 32 | packages=find_packages(), 33 | license="GNU General Public License Version 3", 34 | install_requires=install_requires, 35 | python_requires=">=3.6", 36 | extras_require={"test": test_requires}, 37 | long_description=readme, 38 | long_description_content_type="text/markdown", 39 | url="https://github.com/KrishnaswamyLab/Gene-Signal-Pattern-Analysis", 40 | download_url="https://github.com/KrishnaswamyLab/Gene-Signal-Pattern-Analysis/archive/v{}.tar.gz".format(version), 41 | keywords=["big-data", "manifold-learning", "computational-biology", "dimensionality-reduction", "single-cell"], 42 | classifiers=[ 43 | 'Programming Language :: Python :: 3', 44 | 'Programming Language :: Python :: 3.6', 45 | 'Programming Language :: Python :: 3.7', 46 | 'Programming Language :: Python :: 3.8', 47 | 'Programming Language :: Python :: 3.9', 48 | 'Topic :: Scientific/Engineering :: Bio-Informatics', 49 | 'Intended Audience :: Developers', 50 | 'Intended Audience :: Science/Research', 51 | 'Natural Language :: English', 52 | 'Development Status :: 5 - Production/Stable', 53 | 'Operating System :: OS Independent', 54 | ], 55 | ) 56 | 57 | setup_dir = os.path.dirname(os.path.realpath(__file__)) 58 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import gspa 3 | import numpy as np 4 | import phate 5 | 6 | class TestGSPA(unittest.TestCase): 7 | def __init__(self, methodName='runTest'): 8 | super().__init__(methodName) 9 | 10 | # Test input 11 | self.test_data = np.random.random(size=(1000,50)) 12 | phate_op = phate.PHATE(verbose=False) 13 | phate_op.fit(self.test_data) 14 | 15 | self.bc_sample_idx = [1]*500 + [2]*500 16 | self.cell_type_assignments = ['cellA']*250 + ['cellB']*500 + ['cellC']*250 17 | self.cell_type = 'cellC' 18 | 
19 | # Test setups 20 | self.gspa = gspa.GSPA(qr_decompose=False, pc_dim=20, embedding_dim=2,wavelet_J=3) 21 | self.gspa_qr = gspa.GSPA(pc_dim=20, embedding_dim=2) 22 | self.condensation = gspa.GSPA(perform_condensation=True, condensation_threshold=200, pc_dim=20, embedding_dim=2) 23 | self.batch_correction = gspa.GSPA(bc_sample_idx=self.bc_sample_idx, pc_dim=20, embedding_dim=2) 24 | self.input_graph = gspa.GSPA(graph=phate_op.graph, qr_decompose=False, pc_dim=20, embedding_dim=2, wavelet_J=3) 25 | self.input_diff_op = gspa.GSPA(diffusion_operator=phate_op.graph.diff_op, qr_decompose=False, pc_dim=20, embedding_dim=2, wavelet_J=3) 26 | 27 | def test_construct_graph(self): 28 | # Positive test case, no MS PHATE 29 | self.gspa.construct_graph(self.test_data) 30 | self.assertEqual(self.gspa.graph.N, self.test_data.shape[0]) 31 | 32 | def test_construct_graph_MS_PHATE(self): 33 | # Positive test case, MS PHATE 34 | self.condensation.construct_graph(self.test_data) 35 | self.assertLessEqual(self.condensation.graph.N, 200) 36 | 37 | def test_construct_graph_BC(self): 38 | # Positive test case, BC 39 | self.batch_correction.construct_graph(self.test_data) 40 | self.assertEqual(self.batch_correction.graph.N, self.test_data.shape[0]) 41 | 42 | def test_construct_diff_op_no_graph(self): 43 | with self.assertRaises(ValueError) as context: 44 | self.gspa.build_diffusion_operator() 45 | self.assertEqual(str(context.exception), "Graph not constructed. Run gspa_op.construct_graph(data) or initialize GSPA operator with graph") 46 | 47 | def test_construct_dict_no_graph(self): 48 | with self.assertRaises(ValueError) as context: 49 | self.gspa.build_wavelet_dictionary() 50 | self.assertEqual(str(context.exception), "Diffusion operator not constructed. 
Run gspa_op.build_diffusion_operator() or initialize GSPA operator with diffusion_operator") 51 | 52 | def test_dictionary(self): 53 | self.gspa.construct_graph(self.test_data) 54 | self.gspa.build_diffusion_operator() 55 | self.gspa.build_wavelet_dictionary() 56 | self.assertEqual(self.gspa.wavelet_dictionary.shape, (self.test_data.shape[0], self.test_data.shape[0] * self.gspa.wavelet_J)) 57 | 58 | def test_dictionary_qr(self): 59 | self.gspa_qr.construct_graph(self.test_data) 60 | self.gspa_qr.build_diffusion_operator() 61 | self.gspa_qr.build_wavelet_dictionary() 62 | self.assertEqual(self.gspa_qr.wavelet_dictionary.shape[0], self.test_data.shape[0]) 63 | self.assertLessEqual(self.gspa_qr.wavelet_dictionary.shape[1], self.test_data.shape[0] * self.gspa_qr.wavelet_J) 64 | 65 | def test_dictionary_MS_PHATE(self): 66 | self.condensation.construct_graph(self.test_data) 67 | self.condensation.build_diffusion_operator() 68 | self.condensation.build_wavelet_dictionary() 69 | self.assertLessEqual(self.condensation.wavelet_dictionary.shape[0], 200) 70 | 71 | def test_dictionary_input_graph(self): 72 | self.input_graph.build_diffusion_operator() 73 | self.input_graph.build_wavelet_dictionary() 74 | self.assertEqual(self.input_graph.wavelet_dictionary.shape, (self.test_data.shape[0], self.test_data.shape[0] * self.input_graph.wavelet_J)) 75 | 76 | def test_dictionary_input_diff_op(self): 77 | self.input_diff_op.build_wavelet_dictionary() 78 | self.assertEqual(self.input_diff_op.wavelet_dictionary.shape, (self.test_data.shape[0], self.test_data.shape[0] * self.input_graph.wavelet_J)) 79 | 80 | def test_get_gene_embeddings_no_wavelet(self): 81 | with self.assertRaises(ValueError) as context: 82 | self.gspa.get_gene_embeddings(self.test_data.T) 83 | self.assertEqual(str(context.exception), "Run gspa_op.build_wavelet_dictionary") 84 | 85 | with self.assertRaises(ValueError) as context: 86 | self.gspa.calculate_localization(self.test_data.T) 87 | 
self.assertEqual(str(context.exception), "Run gspa_op.build_wavelet_dictionary") 88 | 89 | with self.assertRaises(ValueError) as context: 90 | self.gspa.calculate_cell_type_specificity(cell_type_assignments=self.cell_type_assignments, cell_type=self.cell_type) 91 | self.assertEqual(str(context.exception), "Run gspa_op.build_wavelet_dictionary") 92 | 93 | def test_loc_no_signals(self): 94 | with self.assertRaises(ValueError) as context: 95 | self.gspa.construct_graph(self.test_data) 96 | self.gspa.build_diffusion_operator() 97 | self.gspa.build_wavelet_dictionary() 98 | self.gspa.calculate_localization() 99 | self.assertEqual(str(context.exception), "Provide signals to map to dictionary or run gspa_op.get_gene_embeddings") 100 | 101 | with self.assertRaises(ValueError) as context: 102 | self.gspa.construct_graph(self.test_data) 103 | self.gspa.build_diffusion_operator() 104 | self.gspa.build_wavelet_dictionary() 105 | self.gspa.calculate_cell_type_specificity(cell_type_assignments=self.cell_type_assignments, cell_type=self.cell_type) 106 | self.assertEqual(str(context.exception), "Provide signals to map to dictionary or run gspa_op.get_gene_embeddings") 107 | 108 | def test_get_gene_embeddings(self): 109 | self.gspa.construct_graph(self.test_data) 110 | self.gspa.build_diffusion_operator() 111 | self.gspa.build_wavelet_dictionary() 112 | out = self.gspa.get_gene_embeddings(self.test_data.T) 113 | self.assertEqual(out[0].shape, (self.test_data.shape[1], 2)) 114 | self.assertEqual(out[1].shape, (self.test_data.shape[1], 20)) 115 | 116 | def test_get_gene_embeddings(self): 117 | self.gspa.construct_graph(self.test_data) 118 | self.gspa.build_diffusion_operator() 119 | self.gspa.build_wavelet_dictionary() 120 | out = self.gspa.get_gene_embeddings(self.test_data.T) 121 | self.assertEqual(out[0].shape, (self.test_data.shape[1], 2)) 122 | self.assertEqual(out[1].shape, (self.test_data.shape[1], 20)) 123 | 124 | def test_get_gene_embeddings_MS_PHATE(self): 125 | 
self.condensation.construct_graph(self.test_data) 126 | self.condensation.build_diffusion_operator() 127 | self.condensation.build_wavelet_dictionary() 128 | out = self.condensation.get_gene_embeddings(self.test_data.T) 129 | self.assertEqual(out[0].shape, (self.test_data.shape[1], 2)) 130 | self.assertEqual(out[1].shape, (self.test_data.shape[1], 20)) 131 | 132 | def test_localization_with_gene_embeddings(self): 133 | self.gspa.construct_graph(self.test_data) 134 | self.gspa.build_diffusion_operator() 135 | self.gspa.build_wavelet_dictionary() 136 | self.gspa.get_gene_embeddings(self.test_data.T) 137 | out = self.gspa.calculate_localization() 138 | self.assertEqual(out.shape[0], self.test_data.shape[1]) 139 | 140 | def test_localization_with_gene_embeddings_MS_PHATE(self): 141 | self.condensation.construct_graph(self.test_data) 142 | self.condensation.build_diffusion_operator() 143 | self.condensation.build_wavelet_dictionary() 144 | self.condensation.get_gene_embeddings(self.test_data.T) 145 | out = self.condensation.calculate_localization() 146 | self.assertEqual(out.shape[0], self.test_data.shape[1]) 147 | 148 | def test_localization_without_gene_embeddings(self): 149 | self.gspa.construct_graph(self.test_data) 150 | self.gspa.build_diffusion_operator() 151 | self.gspa.build_wavelet_dictionary() 152 | out = self.gspa.calculate_localization(self.test_data.T[:20]) 153 | self.assertEqual(out.shape[0], 20) 154 | 155 | def test_localization_without_gene_embeddings_MS_PHATE(self): 156 | self.condensation.construct_graph(self.test_data) 157 | self.condensation.build_diffusion_operator() 158 | self.condensation.build_wavelet_dictionary() 159 | self.condensation.get_gene_embeddings 160 | out = self.condensation.calculate_localization(self.test_data.T[:20]) 161 | self.assertEqual(out.shape[0], 20) 162 | 163 | def test_cell_type_with_gene_embeddings(self): 164 | self.gspa.construct_graph(self.test_data) 165 | self.gspa.build_diffusion_operator() 166 | 
self.gspa.build_wavelet_dictionary() 167 | self.gspa.get_gene_embeddings(self.test_data.T) 168 | out = self.gspa.calculate_cell_type_specificity(cell_type_assignments=self.cell_type_assignments, cell_type=self.cell_type) 169 | self.assertEqual(out.shape[0], self.test_data.shape[1]) 170 | 171 | def test_cell_type_with_gene_embeddings_MS_PHATE(self): 172 | self.condensation.construct_graph(self.test_data) 173 | self.condensation.build_diffusion_operator() 174 | self.condensation.build_wavelet_dictionary() 175 | self.condensation.get_gene_embeddings(self.test_data.T) 176 | out = self.condensation.calculate_cell_type_specificity(cell_type_assignments=self.cell_type_assignments, cell_type=self.cell_type) 177 | self.assertEqual(out.shape[0], self.test_data.shape[1]) 178 | 179 | def test_cell_type_without_gene_embeddings(self): 180 | self.gspa.construct_graph(self.test_data) 181 | self.gspa.build_diffusion_operator() 182 | self.gspa.build_wavelet_dictionary() 183 | out = self.gspa.calculate_cell_type_specificity(cell_type_assignments=self.cell_type_assignments, cell_type=self.cell_type, signals=self.test_data.T[:20]) 184 | self.assertEqual(out.shape[0], 20) 185 | 186 | def test_cell_type_without_gene_embeddings_MS_PHATE(self): 187 | self.condensation.construct_graph(self.test_data) 188 | self.condensation.build_diffusion_operator() 189 | self.condensation.build_wavelet_dictionary() 190 | self.condensation.get_gene_embeddings 191 | out = self.condensation.calculate_cell_type_specificity(cell_type_assignments=self.cell_type_assignments, cell_type=self.cell_type, signals=self.test_data.T[:20]) 192 | self.assertEqual(out.shape[0], 20) 193 | 194 | if __name__ == '__main__': 195 | unittest.main() 196 | --------------------------------------------------------------------------------