├── scNET ├── Data │ ├── readme │ └── scNET-env.yml ├── KNNs │ └── readme ├── Models │ └── readme ├── Embedding │ └── readme ├── __init__.py ├── KNNDataset.py ├── Utils.py ├── MultyGraphModel.py ├── main.py └── coEmbeddedNetwork.py ├── images ├── readme ├── scNET.jpg ├── scNET.pdf └── scNET.png ├── setup.py ├── LICENSE └── README.md /scNET/Data/readme: -------------------------------------------------------------------------------- 1 | readme 2 | -------------------------------------------------------------------------------- /images/readme: -------------------------------------------------------------------------------- 1 | folder for images 2 | -------------------------------------------------------------------------------- /scNET/KNNs/readme: -------------------------------------------------------------------------------- 1 | folder to store the pruned KNN 2 | -------------------------------------------------------------------------------- /scNET/Models/readme: -------------------------------------------------------------------------------- 1 | folder to hold the trained models 2 | -------------------------------------------------------------------------------- /scNET/Embedding/readme: -------------------------------------------------------------------------------- 1 | In this folder the new embedding will be saved 2 | -------------------------------------------------------------------------------- /images/scNET.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madilabcode/scNET/HEAD/images/scNET.jpg -------------------------------------------------------------------------------- /images/scNET.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madilabcode/scNET/HEAD/images/scNET.pdf -------------------------------------------------------------------------------- /images/scNET.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madilabcode/scNET/HEAD/images/scNET.png -------------------------------------------------------------------------------- /scNET/Data/scNET-env.yml: -------------------------------------------------------------------------------- 1 | name: scNET 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | dependencies: 7 | - python=3.11 8 | - pytorch==2.6.0 9 | - pandas>=2.2.1 10 | - numpy==1.26.4 11 | - networkx>=3.1 12 | - scanpy>=1.11.0 13 | - scikit-learn>=1.4.1 14 | - gseapy>=1.1.6 15 | - matplotlib>=3.8.0 16 | - seaborn 17 | - igraph 18 | - leidenalg 19 | - tqdm 20 | - scipy 21 | - pip 22 | - pip: 23 | - torch-geometric==2.6.1 24 | -------------------------------------------------------------------------------- /scNET/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import run_scNET 2 | from .Utils import load_embeddings, propagation, run_signature 3 | from .coEmbeddedNetwork import build_co_embeded_network, create_reconstructed_obj, pathway_enricment, test_KEGG_prediction, plot_de_pathways, find_downstream_tfs 4 | from scNET.MultyGraphModel import scNET 5 | 6 | __all__ = ['run_scNET', 'load_embeddings', 'build_co_embeded_network', 'scNET', 'create_reconstructed_obj', "test_KEGG_prediction", "pathway_enricment", "plot_de_pathways", "propagation", "run_signature", "find_downstream_tfs"] 7 | -------------------------------------------------------------------------------- /scNET/KNNDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | 5 | class KNNDataset(Dataset): 6 | def __init__(self, edge_index): 7 | self.edge_index = edge_index.T 8 | 9 | def __len__(self): 10 | return self.edge_index.shape[0] 11 | 12 | def __getitem__(self, idx): 13 | return self.edge_index[idx,:] 14 | 15 | 16 | class CellDataset(Dataset): 17 | def __init__(self, x, knn): 18 | self.x = x 19 | self.knn = knn 20 | 21 | 22 | def __len__(self): 23 | return self.x.shape[1] 24 | 25 | def __getitem__(self, idx): 26 | return self.x[:,idx] , idx 27 | 28 | 29 | class CustomDataset(Dataset): 30 | def __init__(self, x): 31 | self.data = x 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def __getitem__(self, index): 37 | return torch.tensor(self.data[index]) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name='scnet', 8 | version='0.2.2.6', 9 | packages=find_packages(), 10 | include_package_data=True, # Include data files 11 | package_data={ 12 | # Include all files within the data directory under the your_package namespace 13 | 'scNET': ['Data/*',"KNNs/*","Embedding/*","Models/*"] 14 | }, 15 | install_requires=[ 16 | 'torch==2.6.0', 17 | 'torch-geometric==2.6.1', 18 | 'pandas>=2.2.1', 19 | 'numpy==1.26.4', 20 | 'networkx>=3.1', 21 | 'scanpy>=1.11.0', 22 | 'scikit-learn>=1.4.1', 23 | 'gseapy>=1.1.6', 24 | 'matplotlib>=3.8.0', 25 | 'igraph', 26 | 'leidenalg', 27 | 'tqdm' 28 | ], 29 | author='Ron Sheinin', 30 | description='Our method employs a unique dual-graph architecture based on graph neural networks (GNNs), enabling the joint representation of gene expression and PPI network data', 31 | long_description=long_description, 32 | long_description_content_type="text/markdown", 33 | url='https://github.com/madilabcode/scNET' 34 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Ron Sheinin 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![Published in Nature Methods](https://img.shields.io/badge/published-Nature%20Methods-brightgreen)](https://www.nature.com/articles/s41592-025-02627-0) [![Nature Methods Research Briefing](https://img.shields.io/badge/Nature%20Methods-Research%20Briefing-blue)](https://www.nature.com/articles/s41592-025-02628-z) 4 | [![PyPI Downloads](https://static.pepy.tech/badge/scnet)](https://pepy.tech/projects/scnet) [![PyPI version](https://img.shields.io/pypi/v/scnet.svg)](https://pypi.org/project/scnet/) 5 | 6 | 7 | # **scNET: Learning Context-Specific Gene and Cell Embeddings by Integrating Single-Cell Gene Expression Data with Protein-Protein Interaction Information** 8 | 9 | ## **Overview** 10 | 11 | Recent advances in single-cell RNA sequencing (scRNA-seq) techniques have provided unprecedented insights into tissue heterogeneity. However, gene expression data alone often fails to capture changes in cellular pathways and complexes, which are more discernible at the protein level. Additionally, analyzing scRNA-seq data presents challenges due to high noise levels and zero inflation. In this study, we propose a novel approach to address these limitations by integrating scRNA-seq datasets with a protein-protein interaction (PPI) network. Our method employs a unique bi-graph architecture based on graph neural networks (GNNs), enabling the joint representation of gene expression and PPI network data. This approach models gene-to-gene relationships under specific biological contexts and refines cell-cell relations using an attention mechanism, resulting in new gene and cell embeddings. 12 | 13 | ![Overview of the scNET Method](https://raw.githubusercontent.com/madilabcode/scNET/bb9385a9945e34e1e2500c8173baf5c8ece91f79/images/scNET.jpg) 14 | ## Download via PIP 15 | `pip install scnet` 16 | 17 | ## Download via git 18 | To clone the repository, use the following command: 19 | `git clone https://github.com/madilabcode/scNET` 20 | 21 | We recommend using the provided Conda environment located at ./Data/scNET-env.yaml. 22 | cd scNET 23 | conda env create -f ./Data/scNET-env.yaml 24 | 25 | ## import scNET 26 | `import scNET` 27 | 28 | ## API 29 | To train scNET on scRNA-seq data, first load an AnnData object using Scanpy, then initialize training with the following command: 30 | 31 | `scNET.run_scNET(obj, pre_processing_flag=False, human_flag=False, number_of_batches=3, split_cells= True, max_epoch=250, model_name = project_name)` 32 | 33 | with the following args: 34 | 35 | * **obj (AnnData, optional)**: AnnData obj. 36 | 37 | * **pre_processing_flag (bool, optional)**: If True, perform pre-processing steps. 38 | 39 | * **human_flag (bool, optional)**: Controls gene name casing in the network. 40 | 41 | * **number_of_batches (int, optional)**: Number of mini-batches for the training. 42 | 43 | * **split_cells (bool, optional)**: If True, split by cells instead of edges during training. 44 | 45 | * **n_neighbors (int, optional)**: Number of neighbors for building the adjacency graph. 46 | 47 | * **max_epoch (int, optional)**: Max number of epochs for model training. 48 | 49 | * **model_name (str, optional)**: Identifier for saving the model outputs. 50 | 51 | * **save_model_flag (bool, optional)**: If True, save the trained model. 52 | 53 | 54 | ### Retrieve embeddings and model outputs with: 55 | 56 | `embedded_genes, embedded_cells, node_features , out_features = scNET.load_embeddings(project_name)` 57 | 58 | where: 59 | * **embedded_genes (np.ndarray)**: Learned gene embeddings. 60 | 61 | * **embedded_cells (np.ndarray)**: Learned cell embeddings. 62 | 63 | * **node_features (pd.DataFrame)**: Original gene expression matrix. 64 | 65 | * **out_features (np.ndarray)**: Reconstructed gene expression matrix 66 | 67 | 68 | ### Create a new AnnData object using model outputs: 69 | 70 | `recon_obj = scNET.create_reconstructed_obj(node_features, out_features, obj)` 71 | 72 | ### Construct a co-embedded network using the gene embeddings: 73 | `scNET.build_co_embeded_network(embedded_genes, node_features)` 74 | ## Tutorials 75 | 76 | For a basic usage example of our framework, please refer to the following notebook: 77 | [scNET Example Notebook](https://colab.research.google.com/github/madilabcode/scNET/blob/main/scNET.ipynb) 78 | 79 | For a uasge example with batch integration using bbknn graph, plese refer to the following notebook: 80 | [scNET Multi Batch Example Notebook](https://github.com/madilabcode/scNET/blob/main/scNET_Integration.ipynb) 81 | 82 | 83 | For a simple usage example on gene inference using scNET gene embedding,please refer to the following notebook: 84 | [scNET Icos embedding](https://github.com/madilabcode/scNET/blob/main/scNET_gene_inference.ipynb) 85 | 86 | 87 | For a simple example of predicting functional annotations using gene embeddings, please refer to the following notebook: 88 | [scNET functional annotations](https://github.com/madilabcode/scNET/blob/main/scNET_Predicting_Annotation_From_Gene_Embedding.ipynb) 89 | 90 | 91 | For a example of how to use scNET to identify CD8+ T Cells subpopulation please refer to the following notebook: 92 | [scNET subpouplation clustring](https://github.com/madilabcode/scNET/blob/main/scNET_CD8_subsets.ipynb) 93 | 94 | -------------------------------------------------------------------------------- /scNET/Utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import scanpy as sc 5 | from scipy.stats import ranksums 6 | from sklearn.metrics import roc_curve, auc 7 | from sklearn.metrics import average_precision_score 8 | import pickle 9 | import pkg_resources 10 | 11 | alpha = 0.9 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | epsilon = 0.0001 14 | 15 | 16 | def save_obj(obj, name): 17 | with open(name + '.pkl', 'wb') as f: 18 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 19 | 20 | def load_obj(name): 21 | with open(name + '.pkl', 'rb') as f: 22 | return pickle.load(f) 23 | 24 | 25 | def normW(W): 26 | sum_rows = pd.DataFrame(W.sum(axis=1)) + epsilon 27 | sum_rows = sum_rows @ sum_rows.T 28 | sum_rows **= 1/2 29 | return W / sum_rows 30 | 31 | def calculate_propagation_matrix(W, epsilon = 0.0001): 32 | # device = torch.device("cpu") 33 | S = [] 34 | W = normW(W) 35 | W = torch.tensor(W.values).to(device) 36 | for index in range(W.shape[0]): 37 | y = torch.zeros(W.shape[0],dtype=torch.float32).to(device) 38 | y[index] = 1 39 | f = y.clone() 40 | flag = True 41 | 42 | while(flag): 43 | next_f = (alpha*(W@f) + (1-alpha)*y).to(device) 44 | 45 | if torch.linalg.norm(next_f - f) <= epsilon: 46 | flag = False 47 | else: 48 | # print(torch.linalg.norm(next_f - f)) 49 | f = next_f 50 | S.append(f) 51 | return torch.concat(S).view(W.shape) 52 | 53 | def propagate_all_genes(W,exp): 54 | S = calculate_propagation_matrix(W) 55 | prop_exp = torch.tensor(exp.values).to(device).T 56 | prop_exp = S @ prop_exp 57 | prop_norm = S @ torch.ones_like(prop_exp) 58 | prop_exp /= prop_norm 59 | prop_exp = pd.DataFrame(prop_exp.T.detach().cpu().numpy(),index = exp.index, columns = exp.columns) 60 | return prop_exp 61 | 62 | def one_step_propagation(W,F): 63 | W = torch.tensor(normW(W).values, dtype= torch.float32) 64 | F = torch.tensor(F,dtype= torch.float32) 65 | prop_exp = (alpha)*W@F + (1-alpha)*F 66 | prop_norm = (alpha)*W@torch.ones_like(F) + (1-alpha)*torch.ones_like(F) 67 | return prop_exp/prop_norm 68 | 69 | def add_noise(obj,alpha = 0.0, drop_out = False): 70 | obj_noise = obj.raw.to_adata() 71 | #obj_noise.X = (1-alpha) *obj_noise.X + alpha*np.random.randn(*obj.X.shape) 72 | if drop_out: 73 | obj_noise.X = obj_noise.X * np.random.binomial(1,(1-alpha),obj.X.shape) 74 | else: 75 | obj_noise.X = ((1-alpha) *obj_noise.X + alpha*np.random.randn(*obj.X.shape)).astype(np.float32) 76 | obj_noise.var["highly_variable"] = True 77 | sc.tl.pca(obj_noise, svd_solver='arpack',use_highly_variable = False) 78 | sc.pp.neighbors(obj_noise,n_pcs=20, n_neighbors=50) 79 | obj_noise.raw = obj_noise 80 | 81 | return obj_noise 82 | 83 | def wilcoxon_enrcment_test(up_sig, down_sig, exp): 84 | gene_exp = exp.loc[exp.index.isin(up_sig)] 85 | if down_sig is None: 86 | backround_exp = exp.loc[~exp.index.isin(up_sig)] 87 | else: 88 | backround_exp = exp.loc[exp.index.isin(down_sig)] 89 | 90 | rank = ranksums(backround_exp,gene_exp,alternative="less")[1] # rank expression of up sig higher than backround 91 | return -1 * np.log(rank) 92 | 93 | 94 | # --------------------------- 95 | # calculates the signature of the data 96 | # 97 | # returns scores vector of signature calculated per cell 98 | # --------------------------- 99 | def signature_values(exp, up_sig, down_sig=None): 100 | up_sig = pd.DataFrame(up_sig).squeeze() 101 | # first letter of gene in upper case 102 | up_sig = up_sig.apply(lambda x: x[0].upper() + x[1:].lower()) 103 | # keep genes in sig that appear in exp data 104 | up_sig = up_sig[up_sig.isin(exp.index)] 105 | 106 | if down_sig is not None: 107 | down_sig = pd.DataFrame(down_sig).squeeze() 108 | down_sig = down_sig.apply(lambda x: x[0].upper() + x[1:].lower()) 109 | down_sig = down_sig[down_sig.isin(exp.index)] 110 | 111 | return exp.apply(lambda cell: wilcoxon_enrcment_test(up_sig, down_sig, cell), axis=0) 112 | 113 | def run_signature(obj, up_sig, down_sig=None, umap_flag = True, alpha = 0.9,prop_exp = None): 114 | """ 115 | Calculate and visualize a propagated signature score for cells in the given object. 116 | Parameters 117 | ---------- 118 | obj : AnnData 119 | The annotated data object containing gene expression matrix and graph data. 120 | up_sig : list or set 121 | A collection of genes used to calculate the up-regulated signature score. 122 | down_sig : list or set, optional 123 | A collection of genes used to calculate the down-regulated signature score. 124 | If None, only the up-regulated signature is used. Default is None. 125 | umap_flag : bool, optional 126 | If True, generates a UMAP plot colored by the calculated signature score. 127 | If False, generates a t-SNE plot. Default is True. 128 | alpha : float, optional 129 | A parameter controlling the smoothing or propagation factor during signature 130 | score calculation. Default is 0.9. 131 | prop_exp : None or other, optional 132 | An unused parameter placeholder, reserved for future use or extended 133 | signature propagation functionality. 134 | Returns 135 | ------- 136 | np.ndarray 137 | An array of propagated signature scores, with one score per cell. The 138 | scores are also stored in obj.obs["SigScore"]. 139 | """ 140 | 141 | exp = obj.to_df().T 142 | graph = obj.obsp["connectivities"].toarray() 143 | sigs_scores = signature_values(exp, up_sig, down_sig) 144 | sigs_scores = propagation(sigs_scores, graph) 145 | obj.obs["SigScore"] = sigs_scores 146 | # color_map = "jet" 147 | if umap_flag: 148 | sc.pl.umap(obj, color=["SigScore"],color_map="magma") 149 | else: 150 | sc.pl.tsne(obj, color=["SigScore"],color_map="magma") 151 | return sigs_scores 152 | 153 | def calculate_roc_auc(idents, predict): 154 | fpr, tpr, _ = roc_curve(idents, predict, pos_label=1) 155 | return auc(fpr, tpr) 156 | 157 | def calculate_aupr(idents, predict): 158 | return average_precision_score(idents, predict) 159 | 160 | def calculate_roc_auc(idents, predict): 161 | fpr, tpr, _ = roc_curve(idents, predict, pos_label=1) 162 | return auc(fpr, tpr) 163 | 164 | def calculate_aupr(idents, predict): 165 | return average_precision_score(idents, predict) 166 | 167 | # --------------------------- 168 | # Y - scores vector of cells 169 | # W - Adjacency matrix 170 | # 171 | # f_t = alpha * (W * f_(t-1)) + (1-alpha)*Y 172 | # 173 | # returns f/f1 174 | # --------------------------- 175 | def propagation(Y, W): 176 | W = normW(W) 177 | f = np.array(Y) 178 | Y = np.array(Y) 179 | # f2 = calculate_propagation_matrix(W) @ Y 180 | 181 | W = np.array(W.values) 182 | 183 | Y1 = np.ones(Y.shape, dtype=np.float64) 184 | f1 = np.ones(Y.shape, dtype=np.float64) 185 | flag = True 186 | 187 | while(flag): 188 | next_f = alpha*(W@f) + (1-alpha)*Y 189 | next_f1 = alpha*(W@f1) + (1-alpha)*Y1 190 | 191 | if np.linalg.norm(next_f - f) <= epsilon and np.linalg.norm(next_f1 - f1) <= epsilon: 192 | flag = False 193 | else: 194 | #print(np.linalg.norm(next_f - f)) 195 | #print(np.linalg.norm(next_f1 - f1)) 196 | f = next_f 197 | f1 = next_f1 198 | # return f1,f2 199 | return np.array(f/f1) 200 | 201 | def crate_anndata(path, pcs = 15,neighbors = 30): 202 | exp = pd.read_csv(path,index_col=0) 203 | #exp = pd.read_table(path, sep='\t') 204 | adata = sc.AnnData(exp.T) 205 | sc.pp.filter_cells(adata, min_genes=200) 206 | sc.pp.filter_genes(adata, min_cells=3) 207 | adata.var['mt'] = adata.var_names.str.startswith('MT-') 208 | sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True) 209 | adata = adata[adata.obs.n_genes_by_counts < 6000, :] 210 | adata = adata[adata.obs.pct_counts_mt < 10, :] 211 | sc.pp.normalize_total(adata, target_sum=1e4) 212 | sc.pp.log1p(adata) 213 | sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5) 214 | adata.raw = adata 215 | sc.pp.regress_out(adata, ['total_counts', 'pct_counts_mt']) 216 | sc.tl.pca(adata, svd_solver='arpack') 217 | sc.pp.neighbors(adata, n_neighbors=neighbors, n_pcs=pcs) 218 | sc.tl.leiden(adata) 219 | sc.tl.tsne(adata) 220 | return adata 221 | 222 | def save_model(path, model): 223 | torch.save(model.state_dict(), path) 224 | 225 | 226 | def load_embeddings(proj_name): 227 | ''' 228 | Loads the embeddings and gene expression data for a given project. 229 | 230 | Args: 231 | proj_name (str): The name of the project. 232 | 233 | Returns: 234 | tuple: A tuple containing: 235 | - embedded_genes (np.ndarray): Learned gene embeddings. 236 | - embedded_cells (np.ndarray): Learned cell embeddings. 237 | - node_features (pd.DataFrame): Original gene expression matrix. 238 | - out_features (np.ndarray): Reconstructed gene expression matrix. 239 | ''' 240 | embeded_genes = load_obj(pkg_resources.resource_filename(__name__,r"./Embedding/row_embedding_" + proj_name)) 241 | embeded_cells = load_obj(pkg_resources.resource_filename(__name__,r"./Embedding/col_embedding_" + proj_name)) 242 | #node_features = pd.read_csv(pkg_resources.resource_filename(__name__,r"./Embedding/node_features_" + proj_name),index_col=0) 243 | node_features = pd.read_pickle(pkg_resources.resource_filename(__name__,r"./Embedding/node_features_" + proj_name)) 244 | out_features = load_obj(pkg_resources.resource_filename(__name__,r"./Embedding/out_features_" + proj_name)) 245 | return embeded_genes, embeded_cells, node_features, out_features -------------------------------------------------------------------------------- /scNET/MultyGraphModel.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import sequential, GATConv, GraphNorm, VGAE, GCNConv, InnerProductDecoder, TransformerConv, GAE,LayerNorm, SAGEConv 6 | from torch_geometric.nn.conv import transformer_conv 7 | from torch_geometric.utils import negative_sampling 8 | from sklearn.metrics import average_precision_score, roc_auc_score 9 | from torch_geometric.utils import add_self_loops, remove_self_loops, softmax 10 | import math 11 | import numpy as np 12 | import pandas as pd 13 | EPS = 1e-15 14 | MAX_LOGSTD = 10 15 | 16 | class FeatureDecoder(torch.nn.Module): 17 | def __init__(self, feature_dim, embd_dim,inter_dim , drop_p = 0.0): 18 | super(FeatureDecoder, self).__init__() 19 | self.feature_dim = feature_dim 20 | self.embd_dim = embd_dim 21 | self.inter_dim = inter_dim 22 | self.decoder = nn.Sequential(nn.Linear(embd_dim, inter_dim), 23 | nn.Dropout(drop_p), 24 | nn.ReLU(), 25 | nn.Linear(inter_dim, inter_dim), 26 | nn.Dropout(drop_p), 27 | nn.ReLU(), 28 | nn.Linear(inter_dim, feature_dim), 29 | nn.Dropout(drop_p)) 30 | 31 | def forward(self, z): 32 | out = self.decoder(z) 33 | return out 34 | 35 | class MutualEncoder(torch.nn.Module): 36 | def __init__(self,col_dim, row_dim,num_layers=4, drop_p = 0.25): 37 | super(MutualEncoder, self).__init__() 38 | self.col_dim = col_dim 39 | self.row_dim = row_dim 40 | self.num_layers = num_layers 41 | 42 | self.rows_layers = nn.ModuleList([ 43 | sequential.Sequential('x,edge_index', [ 44 | (SAGEConv(self.row_dim, self.row_dim), 'x, edge_index -> x1'), 45 | (nn.Dropout(drop_p,inplace=False), 'x1-> x2'), 46 | nn.LeakyReLU(inplace=True), 47 | ]) for _ in range(num_layers)]) 48 | 49 | self.cols_layers = nn.ModuleList([ 50 | sequential.Sequential('x,edge_index', [ 51 | (SAGEConv(self.col_dim, self.col_dim), 'x, edge_index -> x1'), 52 | nn.LeakyReLU(inplace=True), 53 | (nn.Dropout(drop_p,inplace=False), 'x1-> x2'), 54 | ]) for _ in range(num_layers)]) 55 | 56 | 57 | def forward(self, x, knn_edge_index, ppi_edge_index): 58 | 59 | embbded = x.clone() 60 | for i in range(self.num_layers): 61 | embbded = self.cols_layers[i](embbded.T,knn_edge_index).T 62 | embbded = self.rows_layers[i](embbded, ppi_edge_index) 63 | 64 | return embbded 65 | 66 | class TransformerConvReducrLayer(TransformerConv): 67 | def __init__(self, in_channels, out_channels, heads=1, dropout=0 , add_self_loops=True,scale_param = 2, **kwargs): 68 | super().__init__(in_channels, out_channels, heads, dropout, add_self_loops, **kwargs) 69 | self.treshold_alpha = None 70 | self.scale_param = scale_param 71 | 72 | def message(self, query_i, key_j, value_j, 73 | edge_attr, index, ptr, 74 | size_i): 75 | 76 | if self.lin_edge is not None: 77 | assert edge_attr is not None 78 | edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, 79 | self.out_channels) 80 | key_j += edge_attr 81 | 82 | alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels) 83 | if not self.scale_param is None: 84 | alpha = alpha - alpha.mean() 85 | alpha = alpha / ((1/self.scale_param) * alpha.std()) 86 | alpha = F.sigmoid(alpha) 87 | else: 88 | alpha = softmax(alpha, index, ptr, size_i) 89 | self.treshold_alpha = alpha 90 | 91 | self._alpha = alpha 92 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 93 | 94 | out = value_j 95 | if edge_attr is not None: 96 | out += edge_attr 97 | 98 | out *= alpha.view(-1, self.heads, 1) 99 | return out 100 | 101 | class DimEncoder(torch.nn.Module): 102 | def __init__(self,feature_dim, inter_dim, embd_dim,reducer=False,drop_p = 0.2, scale_param=3): 103 | super(DimEncoder, self).__init__() 104 | self.feature_dim = feature_dim 105 | self.embd_dim = embd_dim 106 | self.inter_dim = inter_dim 107 | self.reducer = reducer 108 | 109 | self.encoder = sequential.Sequential('x, edge_index', [ 110 | (GCNConv(self.feature_dim, self.inter_dim), 'x, edge_index -> x1'), 111 | nn.LeakyReLU(inplace=True), 112 | (nn.Dropout(drop_p,inplace=False), 'x1-> x2') 113 | ]) 114 | if self.reducer: 115 | self.atten_layer = TransformerConvReducrLayer(self.inter_dim, self.embd_dim,dropout= drop_p,add_self_loops = False,heads=1, scale_param=scale_param) 116 | else: 117 | self.atten_layer = TransformerConv(self.inter_dim, self.embd_dim,dropout=drop_p) 118 | 119 | self.atten_map = None 120 | self.atten_weights = None 121 | self.plot_count = 0 122 | 123 | 124 | def reduce_network(self, threshold = 0.2, min_connect=10): 125 | self.plot_count += 1 126 | graph = self.atten_weights.cpu().detach().numpy() 127 | threshold_bound = np.percentile(graph, 10) 128 | threshold = min(threshold,threshold_bound) 129 | df = pd.DataFrame({"v1": self.atten_map[0].cpu().detach().numpy(), "v2": self.atten_map[1].cpu().detach().numpy(), "atten": graph.squeeze()}) 130 | saved_edges = df.groupby('v1')['atten'].nlargest(min_connect).index.values 131 | saved_edges = [v2 for _, v2 in saved_edges] 132 | df.iloc[saved_edges,2] = threshold + EPS 133 | indexs = list(df.loc[df.atten >= threshold].index) 134 | atten_map = self.atten_map[:,indexs] 135 | self.atten_map = None 136 | self.atten_weights = None 137 | return atten_map, df 138 | 139 | def forward(self, x, edge_index, infrance=False): 140 | embbded = x.clone() 141 | embbded = self.encoder(embbded,edge_index) 142 | embbded, atten_map = self.atten_layer(embbded, edge_index, return_attention_weights=True) 143 | if self.reducer and not infrance : 144 | if self.atten_map is None: 145 | self.atten_map = atten_map[0].detach() 146 | self.atten_weights = atten_map[1].detach() 147 | else: 148 | self.atten_map = torch.concat([self.atten_map.T, atten_map[0].detach().T]).T 149 | self.atten_weights = torch.concat([self.atten_weights, atten_map[1].detach()]) 150 | 151 | return embbded 152 | 153 | 154 | class scNET(torch.nn.Module): 155 | def __init__(self,col_dim, row_dim,inter_row_dim, embd_row_dim, inter_col_dim,embd_col_dim, 156 | lambda_rows = 1, lambda_cols = 1, num_layers=2, drop_p = 0.25): 157 | 158 | super(scNET, self).__init__() 159 | self.col_dim = col_dim 160 | self.row_dim = row_dim 161 | self.inter_row_dim = inter_row_dim 162 | self.embd_row_dim = embd_row_dim 163 | self.inter_col_dim = inter_col_dim 164 | self.embd_col_dim = embd_col_dim 165 | self.lambda_rows = lambda_rows 166 | self.lambda_cols = lambda_cols 167 | 168 | 169 | self.encoder = MutualEncoder(col_dim, row_dim,num_layers, drop_p) 170 | self.rows_encoder = DimEncoder(row_dim, inter_row_dim, embd_row_dim,drop_p = drop_p, scale_param=None, reducer=False) 171 | 172 | self.cols_encoder = DimEncoder(col_dim, inter_col_dim, embd_col_dim,drop_p=drop_p, reducer=True) 173 | self.feature_decodr = FeatureDecoder(col_dim, embd_col_dim, inter_col_dim, drop_p = 0 ) 174 | self.ipd = InnerProductDecoder() 175 | self.feature_critarion = nn.MSELoss(reduction ='mean') 176 | 177 | def recon_loss(self, z, pos_edge_index, neg_edge_index = None, sig=False) : 178 | if neg_edge_index is None: 179 | neg_edge_index = negative_sampling(pos_edge_index, z.size(0)) 180 | 181 | if not sig: 182 | embd = torch.corrcoef(z) 183 | pos = torch.sigmoid(embd[pos_edge_index[0],pos_edge_index[1]]) 184 | neg = torch.sigmoid(embd[neg_edge_index[0],neg_edge_index[1]]) 185 | pos_loss = -torch.log(pos +EPS).mean() 186 | neg_loss = -torch.log(1 - neg + EPS).mean() 187 | else: 188 | pos_loss = -torch.log( 189 | self.ipd(z, pos_edge_index, sigmoid=sig) + EPS).mean() 190 | 191 | 192 | neg_loss = -torch.log(1 - 193 | self.ipd(z, neg_edge_index, sigmoid=sig) + 194 | EPS).mean() 195 | 196 | return pos_loss + neg_loss 197 | 198 | 199 | def kl_loss(self, mu = None , logstd = None): 200 | 201 | mu = self.rows_encoder.__mu__ if mu is None else mu 202 | logstd = self.rows_encoder.__logstd__ if logstd is None else logstd 203 | return -0.5 * torch.mean( 204 | torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1)) 205 | 206 | def test(self, z, pos_edge_index, neg_edge_index ): 207 | 208 | pos_y = z.new_ones(pos_edge_index.size(1)) 209 | neg_y = z.new_zeros(neg_edge_index.size(1)) 210 | y = torch.cat([pos_y, neg_y], dim=0) 211 | 212 | pos_pred = self.ipd(z, pos_edge_index, sigmoid=True) 213 | neg_pred = self.ipd(z, neg_edge_index, sigmoid=True) 214 | pred = torch.cat([pos_pred, neg_pred], dim=0) 215 | 216 | y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy() 217 | 218 | return roc_auc_score(y, pred), average_precision_score(y, pred) 219 | 220 | 221 | def calculate_loss(self, x ,knn_edge_index, ppi_edge_index, highly_variable_index): 222 | embbed = self.encoder(x, knn_edge_index, ppi_edge_index) 223 | embbed_rows = self.rows_encoder(embbed, ppi_edge_index) 224 | row_loss = self.recon_loss(embbed_rows, ppi_edge_index,sig=True) 225 | 226 | embbed_cols = self.cols_encoder(embbed.T, knn_edge_index) 227 | out_features = self.feature_decodr(embbed_cols) 228 | out_features = (out_features - (out_features.mean(axis=0)))/ (out_features.std(axis=0)+ EPS) 229 | reg = self.recon_loss(out_features.T, ppi_edge_index, sig=False) 230 | 231 | out_features = out_features.T[highly_variable_index.values].T 232 | col_loss = self.feature_critarion(x[highly_variable_index.values].T, out_features) 233 | 234 | 235 | return self.lambda_rows * row_loss + self.lambda_cols * (col_loss + reg), row_loss, col_loss 236 | 237 | 238 | def forward(self, x, knn_edge_index, ppi_edge_index): 239 | embbed = self.encoder(x, knn_edge_index, ppi_edge_index) 240 | embbed_rows = self.rows_encoder(embbed, ppi_edge_index) 241 | embbed_cols = self.cols_encoder(embbed.T, knn_edge_index, infrance=True) 242 | out_features = self.feature_decodr(embbed_cols) 243 | 244 | return embbed_rows, embbed_cols, out_features 245 | 246 | -------------------------------------------------------------------------------- /scNET/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import scanpy as sc 5 | import torch 6 | import networkx as nx 7 | from scNET.MultyGraphModel import scNET 8 | from scNET.Utils import save_model, save_obj 9 | import torch 10 | from torch_geometric.utils import convert 11 | from torch_geometric.data import Data 12 | from torch_geometric.utils import train_test_split_edges 13 | from scNET.KNNDataset import KNNDataset, CellDataset 14 | from torch.utils.data import DataLoader 15 | import warnings 16 | import gc 17 | import scNET.Utils as ut 18 | import pkg_resources 19 | from tqdm import tqdm 20 | import warnings 21 | import random 22 | 23 | INTER_DIM = 250 24 | EMBEDDING_DIM = 75 25 | NETWORK_CUTOFF = 0.5 26 | MAX_CELLS_BATCH_SIZE = 4000 27 | MAX_CELLS_FOR_SPLITING = 10000 28 | DE_GENES_NUM = 3000 29 | EXPRESSION_CUTOFF = 0.0 30 | NUM_LAYERS = 3 31 | 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | warnings.filterwarnings('ignore') 34 | 35 | 36 | def build_network(obj, net, biogrid_flag = False, human_flag = False): 37 | """ 38 | Build a gene-gene network from the provided interaction information. 39 | Args: 40 | obj (anndata.AnnData): Single-cell data object (AnnData) containing gene expression data. 41 | net (pandas.DataFrame): DataFrame containing gene interactions (Source, Target, and Conn columns). 42 | biogrid_flag (bool, optional): If True, columns for net are set to ["Source", "Target"] only. 43 | human_flag (bool, optional): If True, keeps gene names unchanged; otherwise adjusts gene name casing. 44 | Returns: 45 | tuple: 46 | pandas.DataFrame: Filtered interaction DataFrame for valid genes. 47 | networkx.Graph: Graph representation of the gene network. 48 | pandas.DataFrame: Node-level gene expression features. 49 | """ 50 | if not biogrid_flag: 51 | net.columns = ["Source","Target","Conn"] 52 | net = net.loc[net.Conn >= NETWORK_CUTOFF] 53 | 54 | else: 55 | net.columns = ["Source","Target"] 56 | 57 | if not human_flag: 58 | net["Source"] = net["Source"].apply(lambda x: x[0] + x[1:].lower()).astype(str) 59 | net["Target"] = net["Target"].apply(lambda x: x[0] + x[1:].lower()).astype(str) 60 | 61 | 62 | genes = list(pd.concat([net.Source, net.Target]).drop_duplicates()) 63 | genes = obj.var[obj.var.index.isin(genes)].index 64 | node_feature = sc.get.obs_df(obj.raw.to_adata(),list(genes)).T 65 | node_feature["non_zero"] = node_feature.apply(lambda x: x.astype(bool).sum(), axis=1) 66 | node_feature = node_feature.loc[node_feature.non_zero > node_feature.shape[1] * EXPRESSION_CUTOFF] 67 | node_feature.drop("non_zero",axis=1,inplace=True) 68 | 69 | net = net.loc[net.Source != net.Target] 70 | net = net.loc[net.Source.isin(node_feature.index)] 71 | net = net.loc[net.Target.isin(node_feature.index)] 72 | 73 | gp = nx.from_pandas_edgelist(net, "Source", "Target") 74 | 75 | node_feature = node_feature.loc[list(gp.nodes)] 76 | 77 | 78 | return net, gp, node_feature 79 | 80 | def test_recon(model,x, data, knn_edge_index): 81 | """ 82 | Evaluate model reconstruction performance on test edges. 83 | Args: 84 | model (torch.nn.Module): Trained scNET model. 85 | x (torch.Tensor): Input features for the nodes. 86 | data (torch_geometric.data.Data): Graph data object containing positive and negative edges. 87 | knn_edge_index (torch.Tensor): k-NN graph edges for the rows. 88 | Returns: 89 | float: AUC score of edge reconstruction. 90 | """ 91 | model.eval() 92 | with torch.no_grad(): 93 | embbed_rows, _, _ = model(x, knn_edge_index, data.train_pos_edge_index) 94 | return model.test(embbed_rows, data.test_pos_edge_index, data.test_neg_edge_index) 95 | 96 | def pre_processing(adata,n_neighbors): 97 | sc.pp.filter_cells(adata, min_genes=200) 98 | sc.pp.filter_genes(adata, min_cells=3) 99 | 100 | sc.pp.normalize_total(adata, target_sum=1e4) 101 | sc.pp.log1p(adata) 102 | adata.raw = adata.copy() 103 | sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=15) 104 | 105 | return adata 106 | 107 | def crate_knn_batch(knn,idxs,k=15): 108 | """ 109 | Create a mini-batch of the k-NN graph for the given subset of indices. 110 | Args: 111 | knn (scipy.sparse.csr_matrix): Sparse adjacency matrix representing the k-NN graph. 112 | idxs (list[int]): List of indices used to subset the k-NN graph. 113 | k (int, optional): Number of nearest neighbors (used for reference if needed). 114 | Returns: 115 | torch.Tensor: Edge index for the sub-batch of the k-NN graph. 116 | """ 117 | idxs = idxs.cpu().numpy() 118 | adjacency_matrix = torch.tensor(knn[idxs][:,idxs].toarray()) 119 | row_indices, col_indices = torch.nonzero(adjacency_matrix, as_tuple=True) 120 | knn_edge_index = torch.stack((row_indices, col_indices)) 121 | knn_edge_index = torch.unique(knn_edge_index, dim=1) 122 | return knn_edge_index.to(device) 123 | 124 | def train(data, loader, highly_variable_index,number_of_batches=5 , 125 | max_epoch = 500, rduce_interavel = 30,model_name="", cell_flag=False): 126 | """ 127 | Train the scNET model using mini-batches of the k-NN graph or cells. 128 | Args: 129 | data (torch_geometric.data.Data): Graph data including edge information. 130 | loader (torch.utils.data.DataLoader): DataLoader for batches of edges or cells. 131 | highly_variable_index (pandas.Series or np.ndarray): Boolean mask for highly variable genes. 132 | number_of_batches (int, optional): Number of mini-batches. 133 | max_epoch (int, optional): Maximum number of training epochs. 134 | rduce_interavel (int, optional): Interval at which the model attempts graph reduction. 135 | model_name (str, optional): Custom string identifier for saving the model and outputs. 136 | cell_flag (bool, optional): If True, performs mini-batch training by cells rather than by edges. 137 | Returns: 138 | scNET: Trained scNET model instance. 139 | Build a k-NN graph from precomputed distances in the AnnData object. 140 | Args: 141 | obj (anndata.AnnData): Single-cell data object with 'distances' stored in obsp. 142 | Returns: 143 | tuple: 144 | torch.Tensor: Edge index of the k-NN graph. 145 | pandas.Series: Boolean mask for highly variable genes. 146 | Create a mini-batch DataLoader for k-NN edges. 147 | Args: 148 | edge_index (torch.Tensor): All edges of the k-NN graph. 149 | batch_size (int): Number of edges per mini-batch. 150 | Returns: 151 | torch.utils.data.DataLoader: DataLoader object for batching edges. 152 | """ 153 | x_full = data.x.clone() 154 | if cell_flag: 155 | model = scNET(x_full.shape[0], x_full.shape[1]//number_of_batches, 156 | INTER_DIM, EMBEDDING_DIM, INTER_DIM, EMBEDDING_DIM, lambda_rows = 1, lambda_cols=1,num_layers=NUM_LAYERS).to(device) 157 | else: 158 | model = scNET(x_full.shape[0], x_full.shape[1], INTER_DIM, EMBEDDING_DIM, INTER_DIM, EMBEDDING_DIM, 159 | lambda_rows = 1, lambda_cols=1, num_layers=NUM_LAYERS).to(device) 160 | x = x_full.clone() 161 | x = ((x.T - (x.mean(axis=1)))/ (x.std(axis=1)+ 0.00001)).T 162 | 163 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) 164 | 165 | best_auc = 0.5 166 | concat_flag = False 167 | 168 | for epoch in tqdm(range(max_epoch), desc="Training", total=max_epoch): 169 | 170 | total_row_loss = 0 171 | total_col_loss = 0 172 | col_emb_lst = [] 173 | row_emb_lst = [] 174 | imput_lst = [] 175 | out_features_lst = [] 176 | concat_flag = False 177 | 178 | for _,batch in enumerate(loader): 179 | model.train() 180 | 181 | if cell_flag: 182 | x = batch[0].T 183 | x = ((x.T - (x.mean(axis=1)))/ (x.std(axis=1)+ 0.00001)).T 184 | knn_edge_index = crate_knn_batch(loader.dataset.knn, batch[1]) 185 | 186 | else: 187 | knn_edge_index = batch.T.to(device) 188 | 189 | if cell_flag or knn_edge_index.shape[1] == loader.dataset.edge_index.shape[0] // number_of_batches : 190 | 191 | loss, row_loss, col_loss = model.calculate_loss(x.clone().to(device), knn_edge_index.to(device), 192 | data.train_pos_edge_index,highly_variable_index) 193 | optimizer.zero_grad() 194 | loss.backward() 195 | optimizer.step() 196 | 197 | total_row_loss += row_loss 198 | total_col_loss += col_loss 199 | 200 | with torch.no_grad(): 201 | if cell_flag: 202 | row_embed, col_embed, out_features = model(x.clone().to(device), knn_edge_index, data.train_pos_edge_index) 203 | imput = model.encoder(x.to(device), knn_edge_index, data.train_pos_edge_index) 204 | col_emb_lst.append(col_embed.cpu()) 205 | row_emb_lst.append(row_embed.cpu()) 206 | imput_lst.append(imput.T.cpu()) 207 | out_features_lst.append(out_features.cpu()) 208 | else: 209 | row_embed, col_embed, out_features = model(x.to(device),knn_edge_index.to(device), data.train_pos_edge_index) 210 | 211 | else: 212 | concat_flag = True 213 | 214 | gc.collect() 215 | torch.cuda.empty_cache() 216 | 217 | if not cell_flag: 218 | new_knn_edge_index, _ = model.cols_encoder.reduce_network() 219 | 220 | if concat_flag: 221 | new_knn_edge_index = torch.concat([new_knn_edge_index,knn_edge_index], axis=-1) 222 | knn_edge_index = new_knn_edge_index 223 | 224 | if (epoch+1) % rduce_interavel == 0: 225 | #print(new_knn_edge_index.shape[1] / loader.dataset.edge_index.shape[0]) 226 | loader = mini_batch_knn(new_knn_edge_index, new_knn_edge_index.shape[1] // number_of_batches) 227 | 228 | 229 | 230 | if epoch%10 == 0: 231 | if not cell_flag: 232 | knn_edge_index = list(loader)[0].T.to(device) 233 | 234 | auc, ap = test_recon(model, x.to(device), data, knn_edge_index) 235 | 236 | if auc > best_auc: 237 | best_auc = auc 238 | 239 | if cell_flag: 240 | st = torch.stack(row_emb_lst) 241 | row_embed = st.mean(dim=0) 242 | save_obj(torch.concat(col_emb_lst).cpu().detach().numpy(), pkg_resources.resource_filename(__name__,r"Embedding/col_embedding_" + model_name)) 243 | save_obj(row_embed.cpu().detach().numpy(), pkg_resources.resource_filename(__name__,r"Embedding/row_embedding_" + model_name)) 244 | save_obj(torch.concat(out_features_lst).cpu().detach().numpy(), pkg_resources.resource_filename(__name__,r"Embedding/out_features_" + model_name)) 245 | else: 246 | save_obj(new_knn_edge_index.cpu(),pkg_resources.resource_filename(__name__, r"KNNs/best_new_knn_graph_" + model_name)) 247 | save_obj(col_embed.cpu().detach().numpy(), pkg_resources.resource_filename(__name__,r"Embedding/col_embedding_" + model_name)) 248 | save_obj(row_embed.cpu().detach().numpy(), pkg_resources.resource_filename(__name__,r"Embedding/row_embedding_" + model_name)) 249 | save_obj(out_features.cpu().detach().numpy(), pkg_resources.resource_filename(__name__,r"Embedding/out_features_" + model_name)) 250 | 251 | print(f"Best Network AUC: {best_auc}") 252 | # if cell_flag: 253 | # save_obj(loader, "knn_loader"+model_name) 254 | # else: 255 | # save_obj(new_knn_edge_index.cpu(), "new_knn_graph_"+model_name) 256 | 257 | return model 258 | 259 | def build_knn_graph(obj): 260 | graph = obj.obsp["distances"].toarray() 261 | graph = (graph > 0).astype(int) 262 | graph = nx.from_numpy_array(np.matrix(graph)) 263 | ppi_geo = convert.from_networkx(graph) 264 | edge_index = ppi_geo.edge_index 265 | sc.pp.highly_variable_genes(obj, n_top_genes=DE_GENES_NUM) 266 | return edge_index, obj.var.highly_variable 267 | 268 | def mini_batch_knn(edge_index, batch_size): 269 | """ 270 | Create a mini-batch DataLoader for cells and their corresponding edges. 271 | Args: 272 | x (torch.Tensor): Matrix of gene expression features. 273 | edge_index (scipy.sparse.spmatrix): Distance or similarity matrix for cells. 274 | batch_size (int): Number of cells per mini-batch. 275 | Returns: 276 | torch.utils.data.DataLoader: DataLoader object for batching cells. 277 | Convert a NetworkX graph to a PyTorch Geometric edge index. 278 | Args: 279 | G (networkx.Graph): Input NetworkX graph. 280 | mapping (dict, optional): Dictionary mapping original node IDs to new indices. 281 | Returns: 282 | tuple: 283 | torch.Tensor: PyTorch Geometric edge index. 284 | dict: Mapping of graph nodes to tensor indices. 285 | """ 286 | knn_dataset = KNNDataset(edge_index) 287 | knn_loader = DataLoader(knn_dataset,batch_size=batch_size, shuffle=True, drop_last=False) 288 | return knn_loader 289 | 290 | def mini_batch_cells(x,edge_index, batch_size): 291 | cell_dataset = CellDataset(x, edge_index) 292 | cell_loader = DataLoader(cell_dataset,batch_size=batch_size, shuffle=False, drop_last=True) 293 | return cell_loader 294 | 295 | def nx_to_pyg_edge_index(G, mapping=None): 296 | G = G.to_directed() if not nx.is_directed(G) else G 297 | if mapping is None: 298 | mapping = dict(zip(G.nodes(), range(G.number_of_nodes()))) 299 | edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long).to(device) 300 | for i, (src, dst) in enumerate(G.edges()): 301 | edge_index[0, i] = mapping[src] 302 | edge_index[1, i] = mapping[dst] 303 | return edge_index, mapping 304 | 305 | 306 | def run_scNET(obj,pre_processing_flag = True ,biogrid_flag = False, 307 | human_flag=False,number_of_batches=5,split_cells = False, n_neighbors=25, 308 | max_epoch=150, model_name="", save_model_flag = False, bbknn_flag = False): 309 | 310 | """ 311 | Main function to load data, build networks, and run the scNET training pipeline. 312 | Args: 313 | obj (AnnData, optional): AnnData obj. 314 | pre_processing_flag (bool, optional): If True, perform pre-processing steps. 315 | biogrid_flag (bool, optional): If True, use BioGRID-formatted data for network building. 316 | human_flag (bool, optional): Controls gene name casing in the network. 317 | number_of_batches (int, optional): Number of mini-batches for the training. 318 | split_cells (bool, optional): If True, split by cells instead of edges during training. 319 | n_neighbors (int, optional): Number of neighbors for building the adjacency graph. 320 | max_epoch (int, optional): Max number of epochs for model training. 321 | model_name (str, optional): Identifier for saving the model outputs. 322 | save_model_flag (bool, optional): If True, save the trained model. 323 | bbknn_flag (bool, optional): If True, use BBKNN for building the adjacency graph. 324 | Returns: 325 | scNET: A trained scNET model. 326 | """ 327 | random.seed(42) 328 | np.random.seed(42) 329 | torch.manual_seed(42) 330 | torch.cuda.manual_seed(42) 331 | torch.cuda.manual_seed_all(42) 332 | torch.backends.cudnn.deterministic = True 333 | torch.backends.cudnn.benchmark = False 334 | 335 | if pre_processing_flag: 336 | obj = pre_processing(obj,n_neighbors) 337 | 338 | else: 339 | if obj.raw is None: 340 | obj.raw = obj.copy() 341 | sc.pp.log1p(obj) 342 | obj.X = obj.raw.X 343 | if not bbknn_flag: 344 | sc.pp.neighbors(obj, n_neighbors=n_neighbors, n_pcs=15) 345 | 346 | if obj.obs.shape[0] > MAX_CELLS_FOR_SPLITING: 347 | split_cells = True 348 | 349 | if split_cells: 350 | batch_size = obj.obs.shape[0] // number_of_batches 351 | if batch_size > MAX_CELLS_BATCH_SIZE: 352 | number_of_batches = obj.obs.shape[0] // MAX_CELLS_BATCH_SIZE 353 | 354 | 355 | if not biogrid_flag: 356 | print(pkg_resources.resource_filename(__name__,r"Data/format_h_sapiens.csv")) 357 | 358 | net = pd.read_csv(pkg_resources.resource_filename(__name__,r"Data/format_h_sapiens.csv"))[["g1_symbol","g2_symbol","conn"]].drop_duplicates() 359 | net, ppi, node_feature = build_network(obj, net,human_flag=human_flag) 360 | print(f"N genes: {node_feature.shape}") 361 | 362 | else: 363 | print(pkg_resources.resource_filename(__name__,r"Data/BIOGRID.tab.txt")) 364 | net = pd.read_table(pkg_resources.resource_filename(__name__,r"Data/BIOGRID.tab.txt"))[["OFFICIAL_SYMBOL_A","OFFICIAL_SYMBOL_B"]].drop_duplicates() 365 | net, ppi, node_feature = build_network(obj, net, biogrid_flag,human_flag) 366 | print(f"N genes: {node_feature.shape}") 367 | 368 | ppi_edge_index, _ = nx_to_pyg_edge_index(ppi) 369 | ppi_edge_index = ppi_edge_index.to(device) 370 | 371 | if split_cells: 372 | obj = obj[:,node_feature.index] 373 | sc.pp.highly_variable_genes(obj,n_top_genes=DE_GENES_NUM) 374 | highly_variable_index = obj.var.highly_variable 375 | if highly_variable_index.sum() < 1000 or highly_variable_index.sum() > 5000: 376 | obj.var["std"] = sc.get.obs_df(obj.raw.to_adata(),list(obj.var.index)).std() 377 | highly_variable_index = obj.var["std"] >= obj.var["std"].sort_values(ascending=False)[3500] 378 | 379 | print(f"Highly variable genes: {highly_variable_index.sum()}") 380 | 381 | 382 | else: 383 | obj = obj[:,node_feature.index] 384 | knn_edge_index, highly_variable_index = build_knn_graph(obj) 385 | loader = mini_batch_knn(knn_edge_index, knn_edge_index.shape[1] // number_of_batches) 386 | 387 | highly_variable_index = highly_variable_index[node_feature.index] 388 | #node_feature.to_csv(pkg_resources.resource_filename(__name__,r"Embedding/node_features_" + model_name)) 389 | node_feature.to_pickle(pkg_resources.resource_filename(__name__,r"Embedding/node_features_" + model_name)) 390 | 391 | x = node_feature.values 392 | 393 | x = torch.tensor(x, dtype=torch.float32).cpu() 394 | if split_cells: 395 | loader = mini_batch_cells(x, obj.obsp["distances"], x.shape[1] // number_of_batches) 396 | 397 | data = Data(x,ppi_edge_index) 398 | data = train_test_split_edges(data,test_ratio=0.2, val_ratio=0) 399 | model = train(data, loader, highly_variable_index, number_of_batches=number_of_batches, max_epoch=max_epoch, 400 | rduce_interavel=30,model_name=model_name, cell_flag=split_cells) 401 | 402 | if save_model_flag: 403 | save_model(pkg_resources.resource_filename(__name__, r"Models/scNET_" + model_name + ".pt"), model) 404 | 405 | -------------------------------------------------------------------------------- /scNET/coEmbeddedNetwork.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import spearmanr 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | import torch 7 | from sklearn.cluster import KMeans 8 | #import umap.plot 9 | import networkx as nx 10 | from networkx.algorithms import community 11 | import networkx.algorithms.community as nx_comm 12 | from sklearn.metrics import precision_recall_curve, auc 13 | import scNET.Utils as ut 14 | import gseapy as gp 15 | import os 16 | 17 | from sklearn.metrics import average_precision_score, roc_auc_score, adjusted_rand_score 18 | import scanpy as sc 19 | import urllib 20 | import gseapy as gp 21 | import warnings 22 | import pkg_resources 23 | warnings.filterwarnings('ignore') 24 | 25 | 26 | device = torch.device("cpu") 27 | cp = { 28 | '0': '#1f77b4', 29 | '1': '#aec7e8', 30 | '2': '#ff7f0e', 31 | '3': '#ffbb78', 32 | '4': '#2ca02c', 33 | '5': '#98df8a', 34 | '6': '#d62728', 35 | '7': '#ff9896', 36 | '8': '#9467bd', 37 | '9': '#c5b0d5', 38 | '10': '#8c564b', 39 | '11': '#c49c94', 40 | '12': '#e377c2', 41 | '13': '#f7b6d2', 42 | '14': '#7f7f7f', 43 | '15': '#c7c7c7', 44 | '16': '#bcbd22', 45 | '17': '#dbdb8d', 46 | '18': '#17becf', 47 | '19': '#9edae5', 48 | '20': '#1f77b4', 49 | '21': '#ff7f0e', 50 | '22': '#2ca02c', 51 | '23': '#d62728', 52 | '24': '#9467bd', 53 | '25': '#8c564b', 54 | '26': '#e377c2', 55 | '27': '#7f7f7f', 56 | '28': '#bcbd22', 57 | '29': '#17becf' 58 | } 59 | 60 | def create_reconstructed_obj(node_features, out_features, orignal_obj=None): 61 | ''' 62 | Creates an AnnData object from reconstructed gene expression data, normalizes it, and computes PCA, neighbors, clustering, and UMAP. 63 | 64 | Args: 65 | node_features (pd.DataFrame): The original gene expression matrix with genes as columns and cells as rows. 66 | out_features (np.ndarray): The reconstructed gene expression matrix. 67 | original_obj (AnnData, optional): The original AnnData object, if available, to copy cell metadata (obs) from. Defaults to None. 68 | 69 | Returns: 70 | AnnData: An AnnData object containing the reconstructed gene expression data with PCA, neighbors, Leiden clustering, and UMAP embeddings computed. 71 | ''' 72 | embd = pd.DataFrame(out_features,index=node_features.columns[:out_features.shape[0]], columns=node_features.index) 73 | 74 | embd = (embd - embd.min()) / (embd.max() - embd.min()) 75 | 76 | adata = sc.AnnData(embd) 77 | if not orignal_obj is None: 78 | adata.obs = orignal_obj.obs[:embd.shape[0]] 79 | 80 | sc.tl.pca(adata, svd_solver='arpack') 81 | sc.pp.neighbors(adata, n_neighbors=10, n_pcs=15) 82 | sc.tl.leiden(adata,resolution=0.5) 83 | sc.tl.umap(adata) 84 | return adata 85 | 86 | 87 | def calculate_marker_gene_aupr(adata, marker_genes=['Cd4','Cd14',"P2ry12","Ncr1"]\ 88 | , cell_types=[['CD4 Tcells'], ['Macrophages'], ['Microglia'],["NK"]]): 89 | ''' 90 | Calculates the Area Under the Precision-Recall curve (AUPR) for specified marker genes in identifying specific cell types. 91 | 92 | Args: 93 | adata (AnnData): The annotated data matrix (AnnData object) containing gene expression data and cell type information. 94 | marker_genes (list of str, optional): A list of marker genes to evaluate. Defaults to ['Cd4', 'Cd8a', 'Cd14', 'P2ry12', 'Ncr1']. 95 | cell_types (list of lists, optional): A list of lists where each sublist contains the cell types associated with the corresponding marker gene. Defaults to [['CD4 Tcells'], ['CD8 Tcells', 'NK'], ['Macrophages'], ['Microglia'], ['NK']]. 96 | 97 | Returns: 98 | None: The function prints the AUPR score for each marker gene and the corresponding cell type(s). 99 | 100 | ''' 101 | for marker_gene, cell_type in zip(marker_genes, cell_types): 102 | gene_expression = adata[:, marker_gene].X.toarray().flatten() 103 | binary_labels = (adata.obs["Cell Type"].isin(cell_type)).astype(int) 104 | 105 | precision, recall, _ = precision_recall_curve(binary_labels, gene_expression) 106 | aupr = auc(recall, precision) 107 | 108 | print(f"AUPR for {marker_gene} in identifying {cell_type[0]}: {aupr:.4f}") 109 | 110 | 111 | def pathway_enricment(adata, groupby="seurat_clusters", groups=None, gene_sets=None, logfc_threshold=0, pval_threshold=0.05): 112 | ''' 113 | Performs pathway enrichment analysis using KEGG pathways for differentially expressed genes in specific groups. 114 | 115 | Args: 116 | adata (AnnData): The annotated data matrix (AnnData object) containing gene expression data and cell clustering/grouping information. 117 | groupby (str, optional): The key in `adata.obs` to group cells by for differential expression analysis. Defaults to "seurat_clusters". 118 | groups (list, optional): A list of specific groups (clusters or cell types) to analyze. If None, all unique groups in `adata.obs[groupby]` are used. Defaults to None. 119 | gene_sets (dict, optional): A dictionary of gene sets to use for pathway enrichment analysis. If None, the KEGG 2021 Human gene sets are used. Defaults to None. 120 | 121 | Returns: 122 | tuple: A tuple containing: 123 | - de_genes_per_group (dict): A dictionary where keys are group names and values are DataFrames of differentially expressed genes for each group. 124 | - significant_pathways (dict): A dictionary where keys are group names and values are DataFrames of significant KEGG pathways with adjusted p-values for each group. 125 | - filtered_kegg (dict): A dictionary of KEGG pathways filtered for genes present in the dataset. 126 | - enrichment_results (dict): A dictionary where keys are group names and values are full enrichment results from Enrichr for each group. 127 | 128 | Method: 129 | - The function retrieves KEGG pathways and filters them based on the genes present in `adata`. 130 | - Differentially expressed (DE) genes are identified for each group using the Wilcoxon rank-sum test. 131 | - Pathway enrichment analysis is performed using Enrichr, based on the DE genes. 132 | - Pathways with adjusted p-values below 0.05 are considered significant. 133 | ''' 134 | adata.var.index = adata.var.index.str.upper() 135 | if gene_sets is None: 136 | gene_sets = gp.get_library('KEGG_2021_Human') 137 | 138 | filtered_gene_set = {pathway: [gene for gene in genes if gene in adata.var.index] 139 | for pathway, genes in gene_sets.items()} 140 | 141 | filtered_gene_set = {pathway: genes for pathway, genes in filtered_gene_set.items() if len(genes) > 0} 142 | 143 | 144 | if groups is None: 145 | groups = adata.obs[groupby].unique() 146 | 147 | sc.tl.rank_genes_groups(adata, groupby=groupby, method='wilcoxon') 148 | 149 | de_genes_per_group = {} 150 | for group in groups: 151 | dedf = sc.get.rank_genes_groups_df(adata, group=group) 152 | dedf.names = dedf.names.str.upper() 153 | genes = dedf[(dedf['logfoldchanges'] > logfc_threshold) & (dedf["pvals_adj"] < pval_threshold)] 154 | de_genes_per_group[group] = dedf[(dedf['logfoldchanges'] > logfc_threshold) & (dedf["pvals_adj"] < pval_threshold)] 155 | 156 | enrichment_results = {} 157 | significant_pathways = {} 158 | significance_threshold = 0.05 159 | 160 | for group, genes in de_genes_per_group.items(): 161 | 162 | try: 163 | genes = genes['names'].values 164 | enr = gp.enrichr(gene_list=(genes.tolist()), 165 | gene_sets=filtered_gene_set, 166 | background=list(adata.var.index), 167 | organism='Human', 168 | outdir=None) 169 | except: 170 | continue 171 | 172 | significant = enr.results[enr.results['Adjusted P-value'] < significance_threshold] 173 | 174 | enrichment_results[group] = enr.results 175 | significant_pathways[group] = significant[['Term', 'Adjusted P-value']] 176 | 177 | 178 | return de_genes_per_group, significant_pathways, filtered_gene_set , enrichment_results 179 | 180 | 181 | def plot_de_pathways(significant_pathways,enrichment_results, head=20): 182 | ''' 183 | Plots a heatmap of the -log10(Adjusted P-value) for significant pathways across multiple datasets. 184 | 185 | Args: 186 | significant_pathways (dict): A dictionary where keys are dataset names (or groups), and values are DataFrames containing significant pathways and their adjusted p-values. 187 | enrichment_results (dict): A dictionary where keys are dataset names (or groups), and values are DataFrames containing full pathway enrichment results, including adjusted p-values for each pathway. 188 | head (int, optional): The number of top pathways to display in the heatmap. Defaults to 20. 189 | 190 | Returns: 191 | None: The function generates and displays a heatmap showing the significance (-log10(Adjusted P-value)) of the top 20 pathways across different datasets. 192 | 193 | ''' 194 | 195 | data_dict = significant_pathways 196 | combined_df = pd.DataFrame() 197 | 198 | for _, df in enrichment_results.items(): 199 | top5_df = df.sort_values(by='Adjusted P-value').head(head) 200 | for dataset_name, df2 in enrichment_results.items(): 201 | df2 = df2.loc[df2.Term.isin(top5_df.Term)] 202 | df2['Dataset'] = dataset_name 203 | combined_df = pd.concat([combined_df, df2]) 204 | 205 | combined_df['Unique Term'] = combined_df['Term'] 206 | 207 | combined_df['-log10(Adjusted P-value)'] = -np.log(combined_df['Adjusted P-value']) 208 | 209 | # Pivot the data to make a matrix suitable for a heatmap 210 | pivot_df = combined_df.drop_duplicates().pivot(index="Unique Term", columns="Dataset", values="-log10(Adjusted P-value)") 211 | pivot_df.fillna(0,inplace=True) 212 | plt.figure(figsize=(10, 30)) 213 | g = sns.clustermap(pivot_df, annot=False, cmap="YlGnBu", linewidths=.5,figsize=(15,25)) 214 | plt.title('Heatmap of Pathway Significance by Dataset', fontsize=18) 215 | g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=15,rotation=45) 216 | g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize=15) 217 | g.ax_heatmap.set_xlabel('Dataset', fontsize=15) 218 | g.ax_heatmap.set_ylabel('Pathway Term', fontsize=15) 219 | plt.tight_layout() 220 | 221 | 222 | def plot_gene_umap_clustring(embedded_rows): 223 | means_embedd = KMeans(n_clusters=20, random_state=42).fit(embedded_rows) 224 | obj = sc.AnnData(embedded_rows) 225 | obj.obs["cluster"] = means_embedd.labels_ 226 | obj.obs["cluster"] = obj.obs.cluster.astype(str) 227 | sc.pp.neighbors(obj, n_neighbors=12) 228 | sc.tl.leiden(obj) 229 | sc.tl.umap(obj) 230 | sc.pl.umap(obj, color="cluster",palette=cp) 231 | return means_embedd.labels_ 232 | 233 | 234 | def build_co_embeded_network(embedded_rows ,node_features,threshold=99): 235 | ''' 236 | Builds a co-embedded network from the given embedded rows using a correlation-based thresholding approach and detects communities using the Louvain algorithm. 237 | 238 | Args: 239 | embedded_rows (np.ndarray): A matrix of embeddings (e.g., gene embeddings) where each row corresponds to an entity (e.g., gene or cell). 240 | node_features (pd.DataFrame): A DataFrame containing features or identifiers for the nodes, where the index corresponds to the entities in `embedded_rows`. 241 | threshold (int, optional): The percentile threshold to use when binarizing the correlation matrix. Defaults to 99. 242 | 243 | Returns: 244 | tuple: A tuple containing: 245 | - graph (networkx.Graph): The co-embedded network created from the thresholded correlation matrix. 246 | - mod (float): The modularity score of the network, indicating the strength of the community structure. 247 | 248 | Method: 249 | - Computes the absolute Pearson correlation matrix between the rows of `embedded_rows`. 250 | - Applies a threshold (specified by the `threshold` percentile) to binarize the correlation matrix. 251 | - Constructs a graph where nodes are connected if their correlation exceeds the threshold. 252 | - Applies the Louvain algorithm to detect communities within the graph. 253 | - Calculates the modularity score for the detected communities. 254 | - Relabels the graph nodes using the identifiers from `node_fetures`. 255 | ''' 256 | corr = np.corrcoef(embedded_rows) 257 | corr = np.abs(corr) 258 | np.fill_diagonal(corr,0) 259 | mat = (np.abs(corr) > np.percentile(corr, threshold)).astype(np.int64) 260 | graph = nx.from_numpy_array(mat) 261 | comm = nx_comm.louvain_communities(graph,resolution=1, seed=42) 262 | mod = nx_comm.modularity(graph, comm) 263 | map_nodes = {list(graph.nodes)[i]:node_features.index[i] for i in range(len(node_features.index))} 264 | graph = nx.relabel_nodes(graph,map_nodes) 265 | return graph, mod 266 | 267 | 268 | def crate_kegg_annot(all_genes): 269 | ''' 270 | Creates a binary annotation matrix for KEGG pathways, indicating gene-pathway memberships for a given set of genes. 271 | 272 | Args: 273 | all_genes (list): A list of genes to be annotated with KEGG pathway membership. 274 | 275 | Returns: 276 | pd.DataFrame: A binary DataFrame where rows correspond to genes and columns to KEGG pathways. A value of 1 indicates that a gene is part of the pathway, and 0 otherwise. 277 | ''' 278 | KEGG_custom = gp.get_library("KEGG_2021_Human") 279 | filtered_kegg = {pathway: [gene for gene in genes if gene in all_genes] 280 | for pathway, genes in KEGG_custom.items()} 281 | array = [ (gene, key) for key in filtered_kegg for gene in filtered_kegg[key] ] 282 | kegg_df = pd.DataFrame(array) 283 | df = pd.DataFrame(0, index=all_genes, columns=filtered_kegg.keys()) 284 | 285 | for key, values in filtered_kegg.items(): 286 | for value in values: 287 | df.loc[value, key] = 1 288 | 289 | return df 290 | 291 | 292 | def calculate_aupr(pred, vec, test_vec): 293 | pred_test = list(map(lambda x: pred[list(vec.index).index(x)], test_vec.index)) 294 | return average_precision_score(test_vec.values, pred_test) 295 | 296 | 297 | def make_term_predication(graphs, term_vec): 298 | ''' 299 | Propagates gene-term predictions across multiple graphs and evaluates their performance using AUPR. 300 | 301 | Args: 302 | graphs (list of networkx.Graph): A list of graphs (e.g., gene co-expression networks) to propagate term information. 303 | term_vec (pd.Series): A binary vector indicating whether each gene is associated with a specific KEGG term. 304 | 305 | Returns: 306 | list: A list of AUPR scores for each graph's predictions. 307 | 308 | Method: 309 | - Splits the term vector into training and testing sets. 310 | - Uses propagation on each graph to predict term associations for genes. 311 | - Evaluates predictions using AUPR. 312 | ''' 313 | train_vec = term_vec.sample(frac=0.7) 314 | test_vec = term_vec[~term_vec.index.isin(train_vec.index)] 315 | test_pos = test_vec[test_vec == 1] 316 | test_neg = test_vec[test_vec == 0].sample(test_pos.shape[0]) 317 | test_vec = test_vec[list(test_pos.index) + list(test_neg.index)] 318 | vec = term_vec.copy() 319 | vec *= list(map(lambda x: train_vec[x] if x in train_vec.index else float(0), vec.index)) 320 | result_aupr = [] 321 | for graph in graphs: 322 | w = nx.to_pandas_adjacency(graph) 323 | w = w.loc[term_vec.index, term_vec.index] 324 | train_vec = vec.copy() 325 | pred = ut.propagation(train_vec.values, w) 326 | result_aupr.append([calculate_aupr(pred , term_vec, test_vec)]) 327 | return result_aupr 328 | 329 | 330 | def test_KEGG_prediction(gene_embedding, ref): 331 | ''' 332 | Predicts KEGG pathway memberships using gene embeddings and reference data, and evaluates the performance using AUPR. 333 | 334 | Args: 335 | gene_embedding (np.ndarray): The matrix of gene embeddings. 336 | ref (pd.DataFrame): A reference dataset containing gene expression or other relevant features. 337 | 338 | Returns: 339 | pd.DataFrame: A DataFrame containing the AUPR scores for predictions from the gene embeddings and reference data. 340 | 341 | Method: 342 | - Annotates genes with KEGG pathway memberships using `crate_kegg_annot`. 343 | - Filters KEGG pathways to include those with at least 40 gene members. 344 | - Constructs co-embedded networks from both the embeddings and reference data. 345 | - Uses propagation to predict pathway memberships for each graph. 346 | - Evaluates the predictions using AUPR and plots the results. 347 | ''' 348 | ref.index = list(map(lambda x: x.upper(),ref.index)) 349 | annot = crate_kegg_annot(ref.index) 350 | annot_threshold = annot.sum()>=40 351 | annot_threshold = annot_threshold[annot_threshold == True].sort_values(ascending=False).head(50) 352 | graph_embedded,_ = build_co_embeded_network(gene_embedding,ref) 353 | graph_ref,_ =build_co_embeded_network(ref,ref) 354 | kegg_pred = [make_term_predication([graph_embedded,graph_ref], annot[term]) for term in annot_threshold.index] 355 | 356 | kegg_pred = np.array(kegg_pred).squeeze() 357 | df = pd.DataFrame({"AUPR" : kegg_pred.T.reshape(-1), "Method": ["scNET" for i in range(kegg_pred.shape[0])] + ["Counts" for i in range(kegg_pred.shape[0])]}) 358 | 359 | fig, ax = plt.subplots(figsize=[10,7]) 360 | fig.set_dpi(600) 361 | 362 | custom_palette = ['darkturquoise', 'lightsalmon'] 363 | 364 | sns.boxenplot(ax=ax, data=df,x="Method", y="AUPR", palette=custom_palette) 365 | sns.set_theme(style='white',font_scale=1.5) 366 | plt.show() 367 | return df 368 | 369 | 370 | def find_downstream_tfs(net, signature, human_flag=False): 371 | """ 372 | Find downstream transcription factors (TFs) in a given network based on an input node signature. 373 | Parameters 374 | ---------- 375 | net : networkx.Graph 376 | The network graph representing nodes and edges. 377 | signature : list or set 378 | A collection of nodes (e.g., genes) representing a specific signature. 379 | Returns 380 | ------- 381 | pandas.Series 382 | A series of normalized propagation scores for transcription factors present 383 | in the network and in the TF dictionary. Each entry corresponds to the TF 384 | and its corresponding propagation score. 385 | """ 386 | 387 | url = "https://raw.githubusercontent.com/madilabcode/interFLOW/555a374b3057a99cd2d18760a4923499bf58d963/files/TFdictBT1.npy" 388 | local_filename = "TFdictBT1.npy" 389 | urllib.request.urlretrieve(url, local_filename) 390 | tf_dict = np.load(local_filename, allow_pickle=True) 391 | tf_dict = tf_dict.item() 392 | if not human_flag: 393 | tfs = list(map(lambda x: x[0] + x[1:].lower(),tf_dict.keys())) 394 | W = nx.to_numpy_array(net) 395 | v = np.array([1 if x in signature else 0 for x in net.nodes()]) 396 | res = ut.propagation(v,W) 397 | res = pd.Series(res) 398 | res.index = list(net.nodes()) 399 | res = res[res.index.isin(tfs)] 400 | res = (res - res.min()) / (res.max() - res.min()) 401 | res.sort_values(ascending=False) 402 | return res 403 | 404 | 405 | def plot_umap_cells(cell_embedding): 406 | obj = sc.AnnData(cell_embedding) 407 | sc.pp.neighbors(obj, n_neighbors=12) 408 | sc.tl.leiden(obj) 409 | sc.tl.umap(obj) 410 | sc.pl.umap(obj, color="leiden",palette=cp) --------------------------------------------------------------------------------