├── LICENSE ├── README.md ├── environment.yml └── portal ├── __init__.py ├── model.py ├── networks.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 YangLabHKUST 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Portal 2 | [![DOI](https://zenodo.org/badge/423325112.svg)](https://zenodo.org/badge/latestdoi/423325112) [![PyPI](https://img.shields.io/pypi/v/portal-sc?color=green)](https://pypi.python.org/pypi/portal-sc/) [![PyPi license](https://badgen.net/pypi/license/portal-sc/)](https://pypi.org/project/portal-sc/) [![Downloads](https://static.pepy.tech/personalized-badge/portal-sc?period=total&units=international_system&left_color=grey&right_color=orange&left_text=downloads)](https://pepy.tech/project/portal-sc) [![Stars](https://img.shields.io/github/stars/YangLabHKUST/Portal?logo=GitHub&color=yellow)](https://github.com/YangLabHKUST/Portal/stargazers) 3 | 4 | *Adversarial domain translation networks for integrating large-scale atlas-level single-cell datasets* 5 | 6 | An efficient, accurate and flexible method for single-cell data integration. 7 | 8 | Check out our manuscript in Nature Computational Science: 9 | + [Nature Computational Science website](https://www.nature.com/articles/s43588-022-00251-y) 10 | + [Read fulltext link](https://rdcu.be/cOCbU) 11 | + [Preprint in bioRxiv](https://www.biorxiv.org/content/10.1101/2021.11.16.468892v2) 12 | 13 | ## Reproducibility 14 | We provide [source codes](https://github.com/jiazhao97/Portal-reproducibility) for reproducing the experiments of the paper "Adversarial domain translation networks for fast and accurate integration of large-scale atlas-level single-cell datasets". 15 | 16 | + [Integration of mouse spleen datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-Spleen.html) (we reproduce the result of performance metrics in this notebook as an example). [Benchmarking](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Benchmarking-Spleen.html). 17 | + [Integration of mouse marrow datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-Marrow.html). 
18 | + [Integration of mouse bladder datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-Bladder.html). 19 | + [Integration of mouse brain cerebellum datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-MouseBrain-CB.html). 20 | + [Integration of mouse brain hippocampus datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-MouseBrain-HC.html). 21 | + [Integration of mouse brain thalamus datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-MouseBrain-TH.html). 22 | + [Integration of human PBMC datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-sensitivity.html) (sensitivity analysis). 23 | + [Integration of entire mouse cell atlases from the Tabula Muris project](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-TabulaMuris-full.html). 24 | + [Integration of mouse brain scRNA-seq and snRNA-seq datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-MouseBrain-CellNuclei.html). 25 | + [Integration of human PBMC scRNA-seq and human brain snRNA-seq datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-BloodCell-BrainNuclei.html). 26 | + [Integration of scRNA-seq and scATAC-seq datasets](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-PBMC-ATACseq.html). 27 | + [Integration of developmental trajectories](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-trajectory.html). 28 | + [Integration of the spermatogenesis differentiation process across multiple species](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Reproduce-Spermatogenesis.html). Gene lists from Ensembl BioMart (we only use genes assigned the type "ortholog_one2one" in these lists): [orthologues (human vs mouse)](https://github.com/jiazhao97/Portal-reproducibility/raw/main/orthologues_human_mouse.txt), [orthologues (human vs macaque)](https://github.com/jiazhao97/Portal-reproducibility/raw/main/orthologues_human_macaque.txt). 29 | 30 | ## Installation 31 | * Portal can be installed from PyPI: 32 | ```bash 33 | pip install portal-sc 34 | ``` 35 | 36 | * Alternatively, Portal can be installed from the GitHub repository: 37 | ```bash 38 | git clone https://github.com/YangLabHKUST/Portal.git 39 | cd Portal 40 | conda env update -f environment.yml 41 | conda activate portal 42 | ``` 43 | 44 | Installation normally takes less than 5 minutes. 45 | 46 | ## Quick Start 47 | ### Basic Usage 48 | Starting with raw count matrices formatted as AnnData objects, Portal preprocesses data with the standard pipeline adopted by Seurat and Scanpy, followed by PCA for dimensionality reduction. After preprocessing, Portal can be trained via `model.train()`.
49 | ```python 50 | import portal 51 | import scanpy as sc 52 | 53 | # read AnnData 54 | adata_1 = sc.read_h5ad("adata_1.h5ad") 55 | adata_2 = sc.read_h5ad("adata_2.h5ad") 56 | 57 | model = portal.model.Model() 58 | model.preprocess(adata_1, adata_2) # preprocessing and PCA 59 | model.train() # train the model 60 | model.eval() # get integrated latent representation of cells 61 | ``` 62 | The evaluation step `model.eval()` saves the integrated latent representation of cells in `model.latent`, which can be used for downstream integrative analysis (a minimal downstream-analysis sketch is given at the end of this Quick Start section). 63 | 64 | #### Parameters in `portal.model.Model()`: 65 | * `lambdacos`: Coefficient of the regularizer for preserving cosine similarity across domains. *Default*: `20.0`. 66 | * `training_steps`: Number of training steps. *Default*: `2000`. Use `training_steps=1000` for datasets with sample size < 20,000. 67 | * `npcs`: Dimensionality of the embeddings in each domain (number of PCs). *Default*: `30`. 68 | * `n_latent`: Dimensionality of the shared latent space. *Default*: `20`. 69 | * `batch_size`: Batch size for training. *Default*: `500`. 70 | * `seed`: Random seed. *Default*: `1234`. 71 | 72 | The default setting of `lambdacos` works well in general. Tuning this parameter can further improve performance; see [**Tuning `lambdacos` (optional)**](#tuning-lambdacos-optional). For integration tasks where cosine similarity is not a reliable cross-domain correspondence (such as cross-species integration), we recommend using a lower value such as `lambdacos=10.0`. 73 | 74 | ### Memory-efficient Version 75 | To handle large single-cell datasets, we also developed a memory-efficient version that reads mini-batches from disk: 76 | ```python 77 | model = portal.model.Model() 78 | model.preprocess_memory_efficient(adata_A_path="adata_1.h5ad", adata_B_path="adata_2.h5ad") 79 | model.train_memory_efficient() 80 | model.eval_memory_efficient() 81 | ``` 82 | 83 | ### Integrating Multiple Datasets 84 | Portal integrates multiple datasets incrementally. Given a list of AnnData objects `adata_list = [adata_1, ..., adata_n]`, they can be integrated by running: 85 | ```python 86 | lowdim_list = portal.utils.preprocess_datasets(adata_list) 87 | integrated_data = portal.utils.integrate_datasets(lowdim_list) 88 | ``` 89 | 90 | ### Tuning `lambdacos` (optional) 91 | Optionally, the parameter `lambdacos` can be tuned over the range [15.0, 50.0]. Users can run the following command to search for the value that yields the best integration result in terms of the mixing metric: 92 | ```python 93 | lowdim_list = portal.utils.preprocess_datasets(adata_list) 94 | integrated_data = portal.utils.integrate_datasets(lowdim_list, search_cos=True) 95 | ``` 96 | 97 | ### Recovering Expression Matrices 98 | Portal can provide harmonized expression matrices (at the scaled or log-normalized level): 99 | ```python 100 | lowdim_list, hvg, mean, std, pca = portal.utils.preprocess_recover_expression(adata_list) 101 | expression_scaled, expression_log_normalized = portal.utils.integrate_recover_expression(lowdim_list, mean, std, pca) 102 | ``` 103 | 104 | ### Demos 105 | We provide demos to help users get started quickly: [Demo 1](https://jiazhao97.github.io/Portal_demo1/), [Demo 2](https://htmlpreview.github.io/?https://github.com/jiazhao97/Portal-reproducibility/blob/main/Portal_recover_expression.html).
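For reference, here is a minimal sketch of a downstream analysis using the integrated representation. It is not part of Portal's API; it only uses standard AnnData/Scanpy calls (`leiden` relies on the `leidenalg` package listed in `environment.yml`). `model.latent` stacks the cells of `adata_1` first, followed by the cells of `adata_2`:

```python
import anndata
import scanpy as sc

# wrap the integrated latent representation in an AnnData object;
# model.latent stacks the cells of adata_1 first, then those of adata_2
adata_latent = anndata.AnnData(model.latent)
adata_latent.obs["batch"] = ["adata_1"] * adata_1.shape[0] + ["adata_2"] * adata_2.shape[0]

# standard Scanpy workflow on the integrated space
sc.pp.neighbors(adata_latent, use_rep="X")
sc.tl.umap(adata_latent)
sc.tl.leiden(adata_latent)
sc.pl.umap(adata_latent, color=["batch", "leiden"])
```

The same workflow applies to the output of `portal.utils.integrate_datasets`, which returns the integrated representation directly.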
106 | 107 | ## Development 108 | This package is developed by Jia Zhao (jzhaoaz@connect.ust.hk) and Gefei Wang (gwangas@connect.ust.hk). 109 | 110 | ## Citation 111 | Jia Zhao, Gefei Wang, Jingsi Ming, Zhixiang Lin, Yang Wang, The Tabula Microcebus Consortium, Angela Ruohao Wu, Can Yang. Adversarial domain translation networks for integrating large-scale atlas-level single-cell datasets. Nature Computational Science 2, 317–330 (2022). 112 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: portal 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - python=3.7.10 8 | - pytorch=1.8.1 9 | - scanpy=1.7.2 10 | - anndata=0.7.6 11 | - scikit-learn=0.24.1 12 | - numpy=1.19.2 13 | - scipy==1.6.2 14 | - scikit-learn==0.24.1 15 | - statsmodels==0.12.2 16 | - louvain==0.7.0 17 | - leidenalg==0.7.0 18 | - pandas=1.1.5 19 | - umap-learn=0.4.6 20 | - numba=0.49.1 21 | - pytables=3.6.1 22 | - scikit-misc=0.1.3 23 | - rpy2=3.4.5 24 | - anndata2ri=1.0.6 -------------------------------------------------------------------------------- /portal/__init__.py: -------------------------------------------------------------------------------- 1 | import portal.model 2 | import portal.networks 3 | import portal.utils 4 | 5 | __version__ = '1.0.2' 6 | -------------------------------------------------------------------------------- /portal/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import scanpy as sc 5 | import pandas as pd 6 | import tables 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from sklearn.decomposition import PCA, IncrementalPCA 12 | 13 | from portal.networks import * 14 | 15 | class Model(object): 16 | def __init__(self, batch_size=500, training_steps=2000, seed=1234, npcs=30, n_latent=20, lambdacos=20.0, 17 | model_path="models", data_path="data", result_path="results"): 18 | 19 | # add device 20 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 21 | 22 | # set random seed 23 | torch.manual_seed(seed) 24 | np.random.seed(seed) 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed_all(seed) 27 | torch.backends.cudnn.benchmark = True 28 | 29 | self.batch_size = batch_size 30 | self.training_steps = training_steps 31 | self.npcs = npcs 32 | self.n_latent = n_latent 33 | self.lambdacos = lambdacos 34 | self.lambdaAE = 10.0 35 | self.lambdaLA = 10.0 36 | self.lambdaGAN = 1.0 37 | self.margin = 5.0 38 | self.model_path = model_path 39 | self.data_path = data_path 40 | self.result_path = result_path 41 | 42 | 43 | def preprocess(self, 44 | adata_A_input, 45 | adata_B_input, 46 | hvg_num=4000, # number of highly variable genes for each anndata 47 | save_embedding=False # save low-dimensional embeddings or not 48 | ): 49 | ''' 50 | Performing preprocess for a pair of datasets. 
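Shared highly variable genes are selected from both datasets; each dataset is then library-size normalized, log1p-transformed, subset to the shared genes and scaled, and the concatenated data are embedded with PCA (self.npcs components), stored in self.emb_A and self.emb_B.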
51 | To integrate multiple datasets, use function preprocess_multiple_anndata in utils.py 52 | ''' 53 | adata_A = adata_A_input.copy() 54 | adata_B = adata_B_input.copy() 55 | 56 | print("Finding highly variable genes...") 57 | sc.pp.highly_variable_genes(adata_A, flavor='seurat_v3', n_top_genes=hvg_num) 58 | sc.pp.highly_variable_genes(adata_B, flavor='seurat_v3', n_top_genes=hvg_num) 59 | hvg_A = adata_A.var[adata_A.var.highly_variable == True].sort_values(by="highly_variable_rank").index 60 | hvg_B = adata_B.var[adata_B.var.highly_variable == True].sort_values(by="highly_variable_rank").index 61 | hvg_total = hvg_A & hvg_B 62 | if len(hvg_total) < 100: 63 | raise ValueError("The total number of highly variable genes is smaller than 100 (%d). Try to set a larger hvg_num." % len(hvg_total)) 64 | 65 | print("Normalizing and scaling...") 66 | sc.pp.normalize_total(adata_A, target_sum=1e4) 67 | sc.pp.log1p(adata_A) 68 | adata_A = adata_A[:, hvg_total] 69 | sc.pp.scale(adata_A, max_value=10) 70 | 71 | sc.pp.normalize_total(adata_B, target_sum=1e4) 72 | sc.pp.log1p(adata_B) 73 | adata_B = adata_B[:, hvg_total] 74 | sc.pp.scale(adata_B, max_value=10) 75 | 76 | adata_total = adata_A.concatenate(adata_B, index_unique=None) 77 | 78 | print("Dimensionality reduction via PCA...") 79 | pca = PCA(n_components=self.npcs, svd_solver="arpack", random_state=0) 80 | adata_total.obsm["X_pca"] = pca.fit_transform(adata_total.X) 81 | 82 | self.emb_A = adata_total.obsm["X_pca"][:adata_A.shape[0], :self.npcs].copy() 83 | self.emb_B = adata_total.obsm["X_pca"][adata_A.shape[0]:, :self.npcs].copy() 84 | 85 | if not os.path.exists(self.data_path): 86 | os.makedirs(self.data_path) 87 | 88 | if save_embedding: 89 | np.save(os.path.join(self.data_path, "lowdim_A.npy"), self.emb_A) 90 | np.save(os.path.join(self.data_path, "lowdim_B.npy"), self.emb_B) 91 | 92 | 93 | def preprocess_memory_efficient(self, 94 | adata_A_path, 95 | adata_B_path, 96 | hvg_num=4000, 97 | chunk_size=20000, 98 | save_embedding=True 99 | ): 100 | ''' 101 | Performing preprocess for a pair of datasets with efficient memory usage. 102 | To improve time efficiency, use a larger chunk_size. 
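Highly variable genes are selected on random subsamples of at most chunk_size cells, per-gene means and standard deviations are accumulated chunk by chunk, and the embedding is computed with IncrementalPCA, so the full matrices are never loaded into memory at once.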
103 | ''' 104 | adata_A_input = sc.read_h5ad(adata_A_path, backed="r+", chunk_size=chunk_size) 105 | adata_B_input = sc.read_h5ad(adata_B_path, backed="r+", chunk_size=chunk_size) 106 | 107 | print("Finding highly variable genes...") 108 | subsample_idx_A = np.random.choice(adata_A_input.shape[0], size=np.minimum(adata_A_input.shape[0], chunk_size), replace=False) 109 | subsample_idx_B = np.random.choice(adata_B_input.shape[0], size=np.minimum(adata_B_input.shape[0], chunk_size), replace=False) 110 | 111 | adata_A_subsample = adata_A_input[subsample_idx_A].to_memory().copy() 112 | adata_B_subsample = adata_B_input[subsample_idx_B].to_memory().copy() 113 | 114 | sc.pp.highly_variable_genes(adata_A_subsample, flavor='seurat_v3', n_top_genes=hvg_num) 115 | sc.pp.highly_variable_genes(adata_B_subsample, flavor='seurat_v3', n_top_genes=hvg_num) 116 | 117 | hvg_A = adata_A_subsample.var[adata_A_subsample.var.highly_variable == True].sort_values(by="highly_variable_rank").index 118 | hvg_B = adata_B_subsample.var[adata_B_subsample.var.highly_variable == True].sort_values(by="highly_variable_rank").index 119 | hvg = hvg_A & hvg_B 120 | 121 | del adata_A_subsample, adata_B_subsample, subsample_idx_A, subsample_idx_B 122 | 123 | print("Normalizing and scaling...") 124 | adata_A = adata_A_input.copy(adata_A_path) 125 | adata_B = adata_B_input.copy(adata_B_path) 126 | 127 | adata_A_hvg_idx = adata_A.var.index.get_indexer(hvg) 128 | adata_B_hvg_idx = adata_B.var.index.get_indexer(hvg) 129 | 130 | mean_A = np.zeros((1, len(hvg))) 131 | sq_A = np.zeros((1, len(hvg))) 132 | mean_B = np.zeros((1, len(hvg))) 133 | sq_B = np.zeros((1, len(hvg))) 134 | 135 | for i in range(adata_A.shape[0] // chunk_size): 136 | X_norm = sc.pp.normalize_total(adata_A[i * chunk_size: (i + 1) * chunk_size].to_memory(), target_sum=1e4, inplace=False)["X"] 137 | X_norm = X_norm[:, adata_A_hvg_idx] 138 | X_norm = sc.pp.log1p(X_norm) 139 | mean_A = mean_A + X_norm.sum(axis=0) / adata_A.shape[0] 140 | sq_A = sq_A + X_norm.power(2).sum(axis=0) / adata_A.shape[0] 141 | 142 | if (adata_A.shape[0] % chunk_size) > 0: 143 | X_norm = sc.pp.normalize_total(adata_A[(adata_A.shape[0] // chunk_size) * chunk_size: adata_A.shape[0]].to_memory(), target_sum=1e4, inplace=False)["X"] 144 | X_norm = X_norm[:, adata_A_hvg_idx] 145 | X_norm = sc.pp.log1p(X_norm) 146 | mean_A = mean_A + X_norm.sum(axis=0) / adata_A.shape[0] 147 | sq_A = sq_A + X_norm.power(2).sum(axis=0) / adata_A.shape[0] 148 | 149 | std_A = np.sqrt(sq_A - np.square(mean_A)) 150 | 151 | for i in range(adata_B.shape[0] // chunk_size): 152 | X_norm = sc.pp.normalize_total(adata_B[i * chunk_size: (i + 1) * chunk_size].to_memory(), target_sum=1e4, inplace=False)["X"] 153 | X_norm = X_norm[:, adata_B_hvg_idx] 154 | X_norm = sc.pp.log1p(X_norm) 155 | mean_B = mean_B + X_norm.sum(axis=0) / adata_B.shape[0] 156 | sq_B = sq_B + X_norm.power(2).sum(axis=0) / adata_B.shape[0] 157 | 158 | if (adata_B.shape[0] % chunk_size) > 0: 159 | X_norm = sc.pp.normalize_total(adata_B[(adata_B.shape[0] // chunk_size) * chunk_size: adata_B.shape[0]].to_memory(), target_sum=1e4, inplace=False)["X"] 160 | X_norm = X_norm[:, adata_B_hvg_idx] 161 | X_norm = sc.pp.log1p(X_norm) 162 | mean_B = mean_B + X_norm.sum(axis=0) / adata_B.shape[0] 163 | sq_B = sq_B + X_norm.power(2).sum(axis=0) / adata_B.shape[0] 164 | 165 | std_B = np.sqrt(sq_B - np.square(mean_B)) 166 | 167 | del X_norm, sq_A, sq_B 168 | 169 | print("Dimensionality reduction via Incremental PCA...") 170 | ipca = IncrementalPCA(n_components=self.npcs, 
batch_size=chunk_size) 171 | total_ncells = adata_A.shape[0] + adata_B.shape[0] 172 | order = np.arange(total_ncells) 173 | np.random.RandomState(1234).shuffle(order) 174 | 175 | for i in range(total_ncells // chunk_size): 176 | idx = order[i * chunk_size : (i + 1) * chunk_size] 177 | idx_is_A = (idx < adata_A.shape[0]) 178 | data_A = sc.pp.normalize_total(adata_A[idx[idx_is_A]].to_memory(), target_sum=1e4, inplace=False)["X"] 179 | data_A = data_A[:, adata_A_hvg_idx] 180 | data_A = sc.pp.log1p(data_A) 181 | data_A = np.clip((data_A - mean_A) / std_A, -10, 10) 182 | idx_is_B = (idx >= adata_A.shape[0]) 183 | data_B = sc.pp.normalize_total(adata_B[idx[idx_is_B] - adata_A.shape[0]].to_memory(), target_sum=1e4, inplace=False)["X"] 184 | data_B = data_B[:, adata_B_hvg_idx] 185 | data_B = sc.pp.log1p(data_B) 186 | data_B = np.clip((data_B - mean_B) / std_B, -10, 10) 187 | data = np.concatenate((data_A, data_B), axis=0) 188 | ipca.partial_fit(data) 189 | 190 | if (total_ncells % chunk_size) > 0: 191 | idx = order[(total_ncells // chunk_size) * chunk_size: total_ncells] 192 | idx_is_A = (idx < adata_A.shape[0]) 193 | data_A = sc.pp.normalize_total(adata_A[idx[idx_is_A]].to_memory(), target_sum=1e4, inplace=False)["X"] 194 | data_A = data_A[:, adata_A_hvg_idx] 195 | data_A = sc.pp.log1p(data_A) 196 | data_A = np.clip((data_A - mean_A) / std_A, -10, 10) 197 | idx_is_B = (idx >= adata_A.shape[0]) 198 | data_B = sc.pp.normalize_total(adata_B[idx[idx_is_B] - adata_A.shape[0]].to_memory(), target_sum=1e4, inplace=False)["X"] 199 | data_B = data_B[:, adata_B_hvg_idx] 200 | data_B = sc.pp.log1p(data_B) 201 | data_B = np.clip((data_B - mean_B) / std_B, -10, 10) 202 | data = np.concatenate((data_A, data_B), axis=0) 203 | ipca.partial_fit(data) 204 | 205 | if not os.path.exists(self.data_path): 206 | os.makedirs(self.data_path) 207 | 208 | if save_embedding: 209 | h5filename_A = os.path.join(self.data_path, "lowdim_A.h5") 210 | f = tables.open_file(h5filename_A, mode='w') 211 | atom = tables.Float64Atom() 212 | f.create_earray(f.root, 'data', atom, (0, self.npcs)) 213 | f.close() 214 | # transform 215 | f = tables.open_file(h5filename_A, mode='a') 216 | for i in range(adata_A.shape[0] // chunk_size): 217 | data_A = sc.pp.normalize_total(adata_A[i * chunk_size: (i + 1) * chunk_size].to_memory(), target_sum=1e4, inplace=False)["X"] 218 | data_A = data_A[:, adata_A_hvg_idx] 219 | data_A = sc.pp.log1p(data_A) 220 | data_A = np.clip((data_A - mean_A) / std_A, -10, 10) 221 | data_A = ipca.transform(data_A) 222 | f.root.data.append(data_A) 223 | if (adata_A.shape[0] % chunk_size) > 0: 224 | data_A = sc.pp.normalize_total(adata_A[(adata_A.shape[0] // chunk_size) * chunk_size: adata_A.shape[0]].to_memory(), target_sum=1e4, inplace=False)["X"] 225 | data_A = data_A[:, adata_A_hvg_idx] 226 | data_A = sc.pp.log1p(data_A) 227 | data_A = np.clip((data_A - mean_A) / std_A, -10, 10) 228 | data_A = ipca.transform(data_A) 229 | f.root.data.append(data_A) 230 | f.close() 231 | del data_A 232 | 233 | h5filename_B = os.path.join(self.data_path, "lowdim_B.h5") 234 | f = tables.open_file(h5filename_B, mode='w') 235 | atom = tables.Float64Atom() 236 | f.create_earray(f.root, 'data', atom, (0, self.npcs)) 237 | f.close() 238 | # transform 239 | f = tables.open_file(h5filename_B, mode='a') 240 | for i in range(adata_B.shape[0] // chunk_size): 241 | data_B = sc.pp.normalize_total(adata_B[i * chunk_size: (i + 1) * chunk_size].to_memory(), target_sum=1e4, inplace=False)["X"] 242 | data_B = data_B[:, adata_B_hvg_idx] 243 | data_B = 
sc.pp.log1p(data_B) 244 | data_B = np.clip((data_B - mean_B) / std_B, -10, 10) 245 | data_B = ipca.transform(data_B) 246 | f.root.data.append(data_B) 247 | if (adata_B.shape[0] % chunk_size) > 0: 248 | data_B = sc.pp.normalize_total(adata_B[(adata_B.shape[0] // chunk_size) * chunk_size: adata_B.shape[0]].to_memory(), target_sum=1e4, inplace=False)["X"] 249 | data_B = data_B[:, adata_B_hvg_idx] 250 | data_B = sc.pp.log1p(data_B) 251 | data_B = np.clip((data_B - mean_B) / std_B, -10, 10) 252 | data_B = ipca.transform(data_B) 253 | f.root.data.append(data_B) 254 | f.close() 255 | del data_B 256 | 257 | 258 | def train(self): 259 | begin_time = time.time() 260 | print("Begining time: ", time.asctime(time.localtime(begin_time))) 261 | self.E_A = encoder(self.npcs, self.n_latent).to(self.device) 262 | self.E_B = encoder(self.npcs, self.n_latent).to(self.device) 263 | self.G_A = generator(self.npcs, self.n_latent).to(self.device) 264 | self.G_B = generator(self.npcs, self.n_latent).to(self.device) 265 | self.D_A = discriminator(self.npcs).to(self.device) 266 | self.D_B = discriminator(self.npcs).to(self.device) 267 | params_G = list(self.E_A.parameters()) + list(self.E_B.parameters()) + list(self.G_A.parameters()) + list(self.G_B.parameters()) 268 | optimizer_G = optim.Adam(params_G, lr=0.001, weight_decay=0.) 269 | params_D = list(self.D_A.parameters()) + list(self.D_B.parameters()) 270 | optimizer_D = optim.Adam(params_D, lr=0.001, weight_decay=0.) 271 | self.E_A.train() 272 | self.E_B.train() 273 | self.G_A.train() 274 | self.G_B.train() 275 | self.D_A.train() 276 | self.D_B.train() 277 | 278 | N_A = self.emb_A.shape[0] 279 | N_B = self.emb_B.shape[0] 280 | 281 | for step in range(self.training_steps): 282 | index_A = np.random.choice(np.arange(N_A), size=self.batch_size) 283 | index_B = np.random.choice(np.arange(N_B), size=self.batch_size) 284 | x_A = torch.from_numpy(self.emb_A[index_A, :]).float().to(self.device) 285 | x_B = torch.from_numpy(self.emb_B[index_B, :]).float().to(self.device) 286 | z_A = self.E_A(x_A) 287 | z_B = self.E_B(x_B) 288 | x_AtoB = self.G_B(z_A) 289 | x_BtoA = self.G_A(z_B) 290 | x_Arecon = self.G_A(z_A) 291 | x_Brecon = self.G_B(z_B) 292 | z_AtoB = self.E_B(x_AtoB) 293 | z_BtoA = self.E_A(x_BtoA) 294 | 295 | # discriminator loss: 296 | optimizer_D.zero_grad() 297 | if step <= 5: 298 | # Warm-up 299 | loss_D_A = (torch.log(1 + torch.exp(-self.D_A(x_A))) + torch.log(1 + torch.exp(self.D_A(x_BtoA)))).mean() 300 | loss_D_B = (torch.log(1 + torch.exp(-self.D_B(x_B))) + torch.log(1 + torch.exp(self.D_B(x_AtoB)))).mean() 301 | else: 302 | loss_D_A = (torch.log(1 + torch.exp(-torch.clamp(self.D_A(x_A), -self.margin, self.margin))) + torch.log(1 + torch.exp(torch.clamp(self.D_A(x_BtoA), -self.margin, self.margin)))).mean() 303 | loss_D_B = (torch.log(1 + torch.exp(-torch.clamp(self.D_B(x_B), -self.margin, self.margin))) + torch.log(1 + torch.exp(torch.clamp(self.D_B(x_AtoB), -self.margin, self.margin)))).mean() 304 | loss_D = loss_D_A + loss_D_B 305 | loss_D.backward(retain_graph=True) 306 | optimizer_D.step() 307 | 308 | # autoencoder loss: 309 | loss_AE_A = torch.mean((x_Arecon - x_A)**2) 310 | loss_AE_B = torch.mean((x_Brecon - x_B)**2) 311 | loss_AE = loss_AE_A + loss_AE_B 312 | 313 | # cosine correspondence: 314 | loss_cos_A = (1 - torch.sum(F.normalize(x_AtoB, p=2) * F.normalize(x_A, p=2), 1)).mean() 315 | loss_cos_B = (1 - torch.sum(F.normalize(x_BtoA, p=2) * F.normalize(x_B, p=2), 1)).mean() 316 | loss_cos = loss_cos_A + loss_cos_B 317 | 318 | # latent align loss: 
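# x_AtoB = G_B(E_A(x_A)) and z_AtoB = E_B(x_AtoB), so this term encourages translation followed by
# re-encoding to return the original latent code (E_B(G_B(z_A)) ≈ z_A), and symmetrically for domain B.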
319 | loss_LA_AtoB = torch.mean((z_A - z_AtoB)**2) 320 | loss_LA_BtoA = torch.mean((z_B - z_BtoA)**2) 321 | loss_LA = loss_LA_AtoB + loss_LA_BtoA 322 | 323 | # generator loss 324 | optimizer_G.zero_grad() 325 | if step <= 5: 326 | # Warm-up 327 | loss_G_GAN = (torch.log(1 + torch.exp(-self.D_A(x_BtoA))) + torch.log(1 + torch.exp(-self.D_B(x_AtoB)))).mean() 328 | else: 329 | loss_G_GAN = (torch.log(1 + torch.exp(-torch.clamp(self.D_A(x_BtoA), -self.margin, self.margin))) + torch.log(1 + torch.exp(-torch.clamp(self.D_B(x_AtoB), -self.margin, self.margin)))).mean() 330 | loss_G = self.lambdaGAN * loss_G_GAN + self.lambdacos * loss_cos + self.lambdaAE * loss_AE + self.lambdaLA * loss_LA 331 | loss_G.backward() 332 | optimizer_G.step() 333 | 334 | if not step % 200: 335 | print("step %d, loss_D=%f, loss_GAN=%f, loss_AE=%f, loss_cos=%f, loss_LA=%f" 336 | % (step, loss_D, loss_G_GAN, self.lambdaAE*loss_AE, self.lambdacos*loss_cos, self.lambdaLA*loss_LA)) 337 | 338 | end_time = time.time() 339 | print("Ending time: ", time.asctime(time.localtime(end_time))) 340 | self.train_time = end_time - begin_time 341 | print("Training takes %.2f seconds" % self.train_time) 342 | 343 | if not os.path.exists(self.model_path): 344 | os.makedirs(self.model_path) 345 | 346 | state = {'D_A': self.D_A.state_dict(), 'D_B': self.D_B.state_dict(), 347 | 'E_A': self.E_A.state_dict(), 'E_B': self.E_B.state_dict(), 348 | 'G_A': self.G_A.state_dict(), 'G_B': self.G_B.state_dict()} 349 | 350 | torch.save(state, os.path.join(self.model_path, "ckpt.pth")) 351 | 352 | def train_memory_efficient(self): 353 | import tables 354 | f_A = tables.open_file(os.path.join(self.data_path, "lowdim_A.h5")) 355 | f_B = tables.open_file(os.path.join(self.data_path, "lowdim_B.h5")) 356 | 357 | self.emb_A = np.array(f_A.root.data) 358 | self.emb_B = np.array(f_B.root.data) 359 | 360 | N_A = self.emb_A.shape[0] 361 | N_B = self.emb_B.shape[0] 362 | 363 | begin_time = time.time() 364 | print("Begining time: ", time.asctime(time.localtime(begin_time))) 365 | self.E_A = encoder(self.npcs, self.n_latent).to(self.device) 366 | self.E_B = encoder(self.npcs, self.n_latent).to(self.device) 367 | self.G_A = generator(self.npcs, self.n_latent).to(self.device) 368 | self.G_B = generator(self.npcs, self.n_latent).to(self.device) 369 | self.D_A = discriminator(self.npcs).to(self.device) 370 | self.D_B = discriminator(self.npcs).to(self.device) 371 | params_G = list(self.E_A.parameters()) + list(self.E_B.parameters()) + list(self.G_A.parameters()) + list(self.G_B.parameters()) 372 | optimizer_G = optim.Adam(params_G, lr=0.001, weight_decay=0.) 373 | params_D = list(self.D_A.parameters()) + list(self.D_B.parameters()) 374 | optimizer_D = optim.Adam(params_D, lr=0.001, weight_decay=0.) 
375 | self.E_A.train() 376 | self.E_B.train() 377 | self.G_A.train() 378 | self.G_B.train() 379 | self.D_A.train() 380 | self.D_B.train() 381 | 382 | for step in range(self.training_steps): 383 | index_A = np.random.choice(np.arange(N_A), size=self.batch_size, replace=False) 384 | index_B = np.random.choice(np.arange(N_B), size=self.batch_size, replace=False) 385 | x_A = torch.from_numpy(self.emb_A[index_A, :]).float().to(self.device) 386 | x_B = torch.from_numpy(self.emb_B[index_B, :]).float().to(self.device) 387 | z_A = self.E_A(x_A) 388 | z_B = self.E_B(x_B) 389 | x_AtoB = self.G_B(z_A) 390 | x_BtoA = self.G_A(z_B) 391 | x_Arecon = self.G_A(z_A) 392 | x_Brecon = self.G_B(z_B) 393 | z_AtoB = self.E_B(x_AtoB) 394 | z_BtoA = self.E_A(x_BtoA) 395 | 396 | # discriminator loss: 397 | optimizer_D.zero_grad() 398 | if step <= 5: 399 | # Warm-up 400 | loss_D_A = (torch.log(1 + torch.exp(-self.D_A(x_A))) + torch.log(1 + torch.exp(self.D_A(x_BtoA)))).mean() 401 | loss_D_B = (torch.log(1 + torch.exp(-self.D_B(x_B))) + torch.log(1 + torch.exp(self.D_B(x_AtoB)))).mean() 402 | else: 403 | loss_D_A = (torch.log(1 + torch.exp(-torch.clamp(self.D_A(x_A), -self.margin, self.margin))) + torch.log(1 + torch.exp(torch.clamp(self.D_A(x_BtoA), -self.margin, self.margin)))).mean() 404 | loss_D_B = (torch.log(1 + torch.exp(-torch.clamp(self.D_B(x_B), -self.margin, self.margin))) + torch.log(1 + torch.exp(torch.clamp(self.D_B(x_AtoB), -self.margin, self.margin)))).mean() 405 | loss_D = loss_D_A + loss_D_B 406 | loss_D.backward(retain_graph=True) 407 | optimizer_D.step() 408 | 409 | # autoencoder loss: 410 | loss_AE_A = torch.mean((x_Arecon - x_A)**2) 411 | loss_AE_B = torch.mean((x_Brecon - x_B)**2) 412 | loss_AE = loss_AE_A + loss_AE_B 413 | 414 | # cosine correspondence: 415 | loss_cos_A = (1 - torch.sum(F.normalize(x_AtoB, p=2) * F.normalize(x_A, p=2), 1)).mean() 416 | loss_cos_B = (1 - torch.sum(F.normalize(x_BtoA, p=2) * F.normalize(x_B, p=2), 1)).mean() 417 | loss_cos = loss_cos_A + loss_cos_B 418 | 419 | # latent align loss: 420 | loss_LA_AtoB = torch.mean((z_A - z_AtoB)**2) 421 | loss_LA_BtoA = torch.mean((z_B - z_BtoA)**2) 422 | loss_LA = loss_LA_AtoB + loss_LA_BtoA 423 | 424 | # generator loss 425 | optimizer_G.zero_grad() 426 | if step <= 5: 427 | # Warm-up 428 | loss_G_GAN = (torch.log(1 + torch.exp(-self.D_A(x_BtoA))) + torch.log(1 + torch.exp(-self.D_B(x_AtoB)))).mean() 429 | else: 430 | loss_G_GAN = (torch.log(1 + torch.exp(-torch.clamp(self.D_A(x_BtoA), -self.margin, self.margin))) + torch.log(1 + torch.exp(-torch.clamp(self.D_B(x_AtoB), -self.margin, self.margin)))).mean() 431 | loss_G = self.lambdaGAN * loss_G_GAN + self.lambdacos * loss_cos + self.lambdaAE * loss_AE + self.lambdaLA * loss_LA 432 | loss_G.backward() 433 | optimizer_G.step() 434 | 435 | if not step % 200: 436 | print("step %d, loss_D=%f, loss_GAN=%f, loss_AE=%f, loss_cos=%f, loss_LA=%f" 437 | % (step, loss_D, loss_G_GAN, self.lambdaAE*loss_AE, self.lambdacos*loss_cos, self.lambdaLA*loss_LA)) 438 | 439 | f_A.close() 440 | f_B.close() 441 | 442 | end_time = time.time() 443 | print("Ending time: ", time.asctime(time.localtime(end_time))) 444 | self.train_time = end_time - begin_time 445 | print("Training takes %.2f seconds" % self.train_time) 446 | 447 | if not os.path.exists(self.model_path): 448 | os.makedirs(self.model_path) 449 | 450 | state = {'D_A': self.D_A.state_dict(), 'D_B': self.D_B.state_dict(), 451 | 'E_A': self.E_A.state_dict(), 'E_B': self.E_B.state_dict(), 452 | 'G_A': self.G_A.state_dict(), 'G_B': 
self.G_B.state_dict()} 453 | 454 | torch.save(state, os.path.join(self.model_path, "ckpt.pth")) 455 | 456 | 457 | def eval(self, D_score=False, save_results=False): 458 | begin_time = time.time() 459 | print("Begining time: ", time.asctime(time.localtime(begin_time))) 460 | 461 | self.E_A = encoder(self.npcs, self.n_latent).to(self.device) 462 | self.E_B = encoder(self.npcs, self.n_latent).to(self.device) 463 | self.G_A = generator(self.npcs, self.n_latent).to(self.device) 464 | self.G_B = generator(self.npcs, self.n_latent).to(self.device) 465 | self.E_A.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['E_A']) 466 | self.E_B.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['E_B']) 467 | self.G_A.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['G_A']) 468 | self.G_B.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['G_B']) 469 | 470 | x_A = torch.from_numpy(self.emb_A).float().to(self.device) 471 | x_B = torch.from_numpy(self.emb_B).float().to(self.device) 472 | 473 | z_A = self.E_A(x_A) 474 | z_B = self.E_B(x_B) 475 | 476 | x_AtoB = self.G_B(z_A) 477 | x_BtoA = self.G_A(z_B) 478 | 479 | end_time = time.time() 480 | 481 | print("Ending time: ", time.asctime(time.localtime(end_time))) 482 | self.eval_time = end_time - begin_time 483 | print("Evaluating takes %.2f seconds" % self.eval_time) 484 | 485 | self.latent = np.concatenate((z_A.detach().cpu().numpy(), z_B.detach().cpu().numpy()), axis=0) 486 | self.data_Aspace = np.concatenate((self.emb_A, x_BtoA.detach().cpu().numpy()), axis=0) 487 | self.data_Bspace = np.concatenate((x_AtoB.detach().cpu().numpy(), self.emb_B), axis=0) 488 | 489 | if D_score: 490 | self.D_A = discriminator(self.npcs).to(self.device) 491 | self.D_B = discriminator(self.npcs).to(self.device) 492 | self.D_A.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['D_A']) 493 | self.D_B.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['D_B']) 494 | 495 | score_D_A_A = self.D_A(x_A) 496 | score_D_B_A = self.D_B(x_AtoB) 497 | score_D_B_B = self.D_B(x_B) 498 | score_D_A_B = self.D_A(x_BtoA) 499 | 500 | self.score_Aspace = np.concatenate((score_D_A_A.detach().cpu().numpy(), score_D_A_B.detach().cpu().numpy()), axis=0) 501 | self.score_Bspace = np.concatenate((score_D_B_A.detach().cpu().numpy(), score_D_B_B.detach().cpu().numpy()), axis=0) 502 | 503 | if save_results: 504 | if not os.path.exists(self.result_path): 505 | os.makedirs(self.result_path) 506 | 507 | np.save(os.path.join(self.result_path, "latent_A.npy"), z_A.detach().cpu().numpy()) 508 | np.save(os.path.join(self.result_path, "latent_B.npy"), z_B.detach().cpu().numpy()) 509 | np.save(os.path.join(self.result_path, "x_AtoB.npy"), x_AtoB.detach().cpu().numpy()) 510 | np.save(os.path.join(self.result_path, "x_BtoA.npy"), x_BtoA.detach().cpu().numpy()) 511 | if D_score: 512 | np.save(os.path.join(self.result_path, "score_Aspace_A.npy"), score_D_A_A.detach().cpu().numpy()) 513 | np.save(os.path.join(self.result_path, "score_Bspace_A.npy"), score_D_B_A.detach().cpu().numpy()) 514 | np.save(os.path.join(self.result_path, "score_Bspace_B.npy"), score_D_B_B.detach().cpu().numpy()) 515 | np.save(os.path.join(self.result_path, "score_Aspace_B.npy"), score_D_A_B.detach().cpu().numpy()) 516 | 517 | 518 | def eval_memory_efficient(self): 519 | begin_time = time.time() 520 | print("Begining time: ", time.asctime(time.localtime(begin_time))) 521 | 522 | self.E_A = encoder(self.npcs, self.n_latent).to(self.device) 
523 | self.E_B = encoder(self.npcs, self.n_latent).to(self.device) 524 | self.E_A.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['E_A']) 525 | self.E_B.load_state_dict(torch.load(os.path.join(self.model_path, "ckpt.pth"))['E_B']) 526 | 527 | if not os.path.exists(self.result_path): 528 | os.makedirs(self.result_path) 529 | 530 | f_A = tables.open_file(os.path.join(self.data_path, "lowdim_A.h5")) 531 | f_B = tables.open_file(os.path.join(self.data_path, "lowdim_B.h5")) 532 | 533 | N_A = f_A.root.data.shape[0] 534 | N_B = f_B.root.data.shape[0] 535 | 536 | h5_latent_A = os.path.join(self.result_path, "latent_A.h5") 537 | f_latent_A = tables.open_file(h5_latent_A, mode='w') 538 | atom = tables.Float64Atom() 539 | f_latent_A.create_earray(f_latent_A.root, 'data', atom, (0, self.n_latent)) 540 | f_latent_A.close() 541 | 542 | f_latent_A = tables.open_file(h5_latent_A, mode='a') 543 | # f_x_AtoB = tables.open_file(h5_x_AtoB, mode='a') 544 | for i in range(N_A // self.batch_size): 545 | x_A = torch.from_numpy(f_A.root.data[i * self.batch_size: (i + 1) * self.batch_size]).float().to(self.device) 546 | z_A = self.E_A(x_A) 547 | f_latent_A.root.data.append(z_A.detach().cpu().numpy()) 548 | if (N_A % self.batch_size) > 0: 549 | x_A = torch.from_numpy(f_A.root.data[(N_A // self.batch_size) * self.batch_size: N_A]).float().to(self.device) 550 | z_A = self.E_A(x_A) 551 | f_latent_A.root.data.append(z_A.detach().cpu().numpy()) 552 | f_latent_A.close() 553 | 554 | h5_latent_B = os.path.join(self.result_path, "latent_B.h5") 555 | f_latent_B = tables.open_file(h5_latent_B, mode='w') 556 | atom = tables.Float64Atom() 557 | f_latent_B.create_earray(f_latent_B.root, 'data', atom, (0, self.n_latent)) 558 | f_latent_B.close() 559 | 560 | f_latent_B = tables.open_file(h5_latent_B, mode='a') 561 | for i in range(N_B // self.batch_size): 562 | x_B = torch.from_numpy(f_B.root.data[i * self.batch_size: (i + 1) * self.batch_size]).float().to(self.device) 563 | z_B = self.E_B(x_B) 564 | f_latent_B.root.data.append(z_B.detach().cpu().numpy()) 565 | if (N_B % self.batch_size) > 0: 566 | x_B = torch.from_numpy(f_B.root.data[(N_B // self.batch_size) * self.batch_size: N_B]).float().to(self.device) 567 | z_B = self.E_B(x_B) 568 | f_latent_B.root.data.append(z_B.detach().cpu().numpy()) 569 | f_latent_B.close() 570 | 571 | end_time = time.time() 572 | 573 | f_A.close() 574 | f_B.close() 575 | 576 | print("Ending time: ", time.asctime(time.localtime(end_time))) 577 | self.eval_time = end_time - begin_time 578 | print("Evaluating takes %.2f seconds" % self.eval_time) 579 | -------------------------------------------------------------------------------- /portal/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class encoder(nn.Module): 8 | def __init__(self, n_input, n_latent): 9 | super(encoder, self).__init__() 10 | self.n_input = n_input 11 | self.n_latent = n_latent 12 | n_hidden = 512 13 | 14 | self.W_1 = nn.Parameter(torch.Tensor(n_hidden, self.n_input).normal_(mean=0.0, std=0.1)) 15 | self.b_1 = nn.Parameter(torch.Tensor(n_hidden).normal_(mean=0.0, std=0.1)) 16 | 17 | self.W_2 = nn.Parameter(torch.Tensor(self.n_latent, n_hidden).normal_(mean=0.0, std=0.1)) 18 | self.b_2 = nn.Parameter(torch.Tensor(self.n_latent).normal_(mean=0.0, std=0.1)) 19 | 20 | def forward(self, x): 21 | h = F.relu(F.linear(x, self.W_1, self.b_1)) 22 | z = F.linear(h, 
self.W_2, self.b_2) 23 | return z 24 | 25 | class generator(nn.Module): 26 | def __init__(self, n_input, n_latent): 27 | super(generator, self).__init__() 28 | self.n_input = n_input 29 | self.n_latent = n_latent 30 | n_hidden = 512 31 | 32 | self.W_1 = nn.Parameter(torch.Tensor(n_hidden, self.n_latent).normal_(mean=0.0, std=0.1)) 33 | self.b_1 = nn.Parameter(torch.Tensor(n_hidden).normal_(mean=0.0, std=0.1)) 34 | 35 | self.W_2 = nn.Parameter(torch.Tensor(self.n_input, n_hidden).normal_(mean=0.0, std=0.1)) 36 | self.b_2 = nn.Parameter(torch.Tensor(self.n_input).normal_(mean=0.0, std=0.1)) 37 | 38 | def forward(self, z): 39 | h = F.relu(F.linear(z, self.W_1, self.b_1)) 40 | x = F.linear(h, self.W_2, self.b_2) 41 | return x 42 | 43 | class discriminator(nn.Module): 44 | def __init__(self, n_input): 45 | super(discriminator, self).__init__() 46 | self.n_input = n_input 47 | n_hidden = 512 48 | 49 | self.W_1 = nn.Parameter(torch.Tensor(n_hidden, self.n_input).normal_(mean=0.0, std=0.1)) 50 | self.b_1 = nn.Parameter(torch.Tensor(n_hidden).normal_(mean=0.0, std=0.1)) 51 | 52 | self.W_2 = nn.Parameter(torch.Tensor(n_hidden, n_hidden).normal_(mean=0.0, std=0.1)) 53 | self.b_2 = nn.Parameter(torch.Tensor(n_hidden).normal_(mean=0.0, std=0.1)) 54 | 55 | self.W_3 = nn.Parameter(torch.Tensor(1, n_hidden).normal_(mean=0.0, std=0.1)) 56 | self.b_3 = nn.Parameter(torch.Tensor(1).normal_(mean=0.0, std=0.1)) 57 | 58 | def forward(self, x): 59 | h = F.relu(F.linear(x, self.W_1, self.b_1)) 60 | h = F.relu(F.linear(h, self.W_2, self.b_2)) 61 | score = F.linear(h, self.W_3, self.b_3) 62 | return torch.clamp(score, min=-50.0, max=50.0) -------------------------------------------------------------------------------- /portal/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scanpy as sc 4 | import pandas as pd 5 | import anndata 6 | import umap 7 | from portal.model import * 8 | import rpy2.robjects as ro 9 | import rpy2.robjects.numpy2ri 10 | import anndata2ri 11 | from sklearn.metrics import normalized_mutual_info_score as NMI 12 | from sklearn.metrics import f1_score 13 | from sklearn.metrics.cluster import silhouette_score 14 | from sklearn.linear_model import LinearRegression 15 | from scipy.sparse.csgraph import connected_components 16 | from sklearn.neighbors import NearestNeighbors 17 | from scipy.spatial.distance import cdist 18 | 19 | def preprocess_datasets(adata_list, # list of anndata to be integrated 20 | hvg_num=4000, # number of highly variable genes for each anndata 21 | save_embedding=False, # save low-dimensional embeddings or not 22 | data_path="data" 23 | ): 24 | 25 | if len(adata_list) < 2: 26 | raise ValueError("There should be at least two datasets for integration!") 27 | 28 | sample_size_list = [] 29 | 30 | print("Finding highly variable genes...") 31 | for i, adata in enumerate(adata_list): 32 | sample_size_list.append(adata.shape[0]) 33 | # adata = adata_input.copy() 34 | sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hvg_num) 35 | hvg = adata.var[adata.var.highly_variable == True].sort_values(by="highly_variable_rank").index 36 | if i == 0: 37 | hvg_total = hvg 38 | else: 39 | hvg_total = hvg_total & hvg 40 | if len(hvg_total) < 100: 41 | raise ValueError("The total number of highly variable genes is smaller than 100 (%d). Try to set a larger hvg_num." 
% len(hvg_total)) 42 | 43 | print("Normalizing and scaling...") 44 | for i, adata in enumerate(adata_list): 45 | # adata = adata_input.copy() 46 | sc.pp.normalize_total(adata, target_sum=1e4) 47 | sc.pp.log1p(adata) 48 | adata = adata[:, hvg_total] 49 | sc.pp.scale(adata, max_value=10) 50 | if i == 0: 51 | adata_total = adata 52 | else: 53 | adata_total = adata_total.concatenate(adata, index_unique=None) 54 | 55 | print("Dimensionality reduction via PCA...") 56 | npcs = 30 57 | pca = PCA(n_components=npcs, svd_solver="arpack", random_state=0) 58 | adata_total.obsm["X_pca"] = pca.fit_transform(adata_total.X) 59 | 60 | indices = np.cumsum(sample_size_list) 61 | 62 | data_path = os.path.join(data_path, "preprocess") 63 | 64 | if not os.path.exists(data_path): 65 | os.makedirs(data_path) 66 | 67 | if save_embedding: 68 | for i in range(len(indices)): 69 | if i == 0: 70 | np.save(os.path.join(data_path, "lowdim_1.npy"), 71 | adata_total.obsm["X_pca"][:indices[0], :npcs]) 72 | else: 73 | np.save(os.path.join(data_path, "lowdim_%d.npy" % (i + 1)), 74 | adata_total.obsm["X_pca"][indices[i-1]:indices[i], :npcs]) 75 | 76 | lowdim = adata_total.obsm["X_pca"].copy() 77 | lowdim_list = [lowdim[:indices[0], :npcs] if i == 0 else lowdim[indices[i - 1]:indices[i], :npcs] for i in range(len(indices))] 78 | 79 | return lowdim_list 80 | 81 | 82 | def preprocess_recover_expression(adata_list, # list of anndata to be integrated 83 | hvg_num=4000, # number of highly variable genes for each anndata 84 | save_embedding=False, # save low-dimensional embeddings or not 85 | data_path="data" 86 | ): 87 | 88 | if len(adata_list) < 2: 89 | raise ValueError("There should be at least two datasets for integration!") 90 | 91 | sample_size_list = [] 92 | 93 | print("Finding highly variable genes...") 94 | for i, adata in enumerate(adata_list): 95 | sample_size_list.append(adata.shape[0]) 96 | # adata = adata_input.copy() 97 | sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hvg_num) 98 | hvg = adata.var[adata.var.highly_variable == True].sort_values(by="highly_variable_rank").index 99 | if i == 0: 100 | hvg_total = hvg 101 | else: 102 | hvg_total = hvg_total & hvg 103 | if len(hvg_total) < 100: 104 | raise ValueError("The total number of highly variable genes is smaller than 100 (%d). Try to set a larger hvg_num." 
% len(hvg_total)) 105 | 106 | print("Normalizing and scaling...") 107 | for i, adata in enumerate(adata_list): 108 | # adata = adata_input.copy() 109 | sc.pp.normalize_total(adata, target_sum=1e4) 110 | sc.pp.log1p(adata) 111 | adata = adata[:, hvg_total] 112 | sc.pp.scale(adata, max_value=10) 113 | if i == 0: 114 | adata_total = adata 115 | mean = adata.var["mean"] 116 | std = adata.var["std"] 117 | else: 118 | adata_total = adata_total.concatenate(adata, index_unique=None) 119 | 120 | print("Dimensionality reduction via PCA...") 121 | npcs = 30 122 | pca = PCA(n_components=npcs, svd_solver="arpack", random_state=0) 123 | adata_total.obsm["X_pca"] = pca.fit_transform(adata_total.X) 124 | 125 | indices = np.cumsum(sample_size_list) 126 | 127 | data_path = os.path.join(data_path, "preprocess") 128 | 129 | if not os.path.exists(data_path): 130 | os.makedirs(data_path) 131 | 132 | if save_embedding: 133 | for i in range(len(indices)): 134 | if i == 0: 135 | np.save(os.path.join(data_path, "lowdim_1.npy"), 136 | adata_total.obsm["X_pca"][:indices[0], :npcs]) 137 | else: 138 | np.save(os.path.join(data_path, "lowdim_%d.npy" % (i + 1)), 139 | adata_total.obsm["X_pca"][indices[i-1]:indices[i], :npcs]) 140 | 141 | lowdim = adata_total.obsm["X_pca"].copy() 142 | lowdim_list = [lowdim[:indices[0], :npcs] if i == 0 else lowdim[indices[i - 1]:indices[i], :npcs] for i in range(len(indices))] 143 | 144 | return lowdim_list, hvg_total, mean.values.reshape(1, -1), std.values.reshape(1, -1), pca 145 | 146 | 147 | def integrate_datasets(lowdim_list, # list of low-dimensional representations 148 | search_cos=False, # searching for an optimal lambdacos 149 | lambda_cos=20.0, 150 | training_steps=2000, 151 | space=None, # None or "reference" or "latent" 152 | data_path="data", 153 | mixingmetric_subsample=True 154 | ): 155 | 156 | if space == None: 157 | if len(lowdim_list) == 2: 158 | space = "latent" 159 | else: 160 | space = "reference" 161 | 162 | print("Incrementally integrating %d datasets..." % len(lowdim_list)) 163 | 164 | if not search_cos: 165 | # if not search hyperparameter lambdacos 166 | if isinstance(lambda_cos, float) or isinstance(lambda_cos, int): 167 | lambda_cos_tmp = lambda_cos 168 | 169 | for i in range(len(lowdim_list) - 1): 170 | 171 | if isinstance(lambda_cos, list): 172 | lambda_cos_tmp = lambda_cos[i] 173 | 174 | print("Integrating the %d-th dataset to the 1-st dataset..." % (i + 2)) 175 | model = Model(lambdacos=lambda_cos_tmp, 176 | training_steps=training_steps, 177 | data_path=os.path.join(data_path, "preprocess"), 178 | model_path="models/%d_datasets" % (i + 2), 179 | result_path="results/%d_datasets" % (i + 2)) 180 | if i == 0: 181 | model.emb_A = lowdim_list[0] 182 | else: 183 | model.emb_A = emb_total 184 | model.emb_B = lowdim_list[i + 1] 185 | model.train() 186 | model.eval() 187 | emb_total = model.data_Aspace 188 | if space == "reference": 189 | return emb_total 190 | elif space == "latent": 191 | return model.latent 192 | else: 193 | raise ValueError("Space should be either 'reference' or 'latent'.") 194 | else: 195 | for i in range(len(lowdim_list) - 1): 196 | print("Integrating the %d-th dataset to the 1-st dataset..." 
% (i + 2)) 197 | for lambda_cos in [15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0]: 198 | model = Model(lambdacos=lambda_cos, 199 | training_steps=training_steps, 200 | data_path=os.path.join(data_path, "preprocess"), 201 | model_path="models/%d_datasets" % (i + 2), 202 | result_path="results/%d_datasets" % (i + 2)) 203 | if i == 0: 204 | model.emb_A = lowdim_list[0] 205 | else: 206 | model.emb_A = emb_total 207 | model.emb_B = lowdim_list[i + 1] 208 | model.train() 209 | model.eval() 210 | meta = pd.DataFrame(index=np.arange(model.emb_A.shape[0] + model.emb_B.shape[0])) 211 | meta["method"] = ["A"] * model.emb_A.shape[0] + ["B"] * model.emb_B.shape[0] 212 | mixing = calculate_mixing_metric(model.latent, meta, k=5, max_k=300, methods=list(set(meta.method)), subsample=mixingmetric_subsample) 213 | print("lambda_cos: %f, mixing metric: %f \n" % (lambda_cos, mixing)) 214 | if lambda_cos == 15.0: 215 | model_opt = model 216 | mixing_metric_opt = mixing 217 | elif mixing < mixing_metric_opt: 218 | model_opt = model 219 | mixing_metric_opt = mixing 220 | emb_total = model_opt.data_Aspace 221 | if space == "reference": 222 | return emb_total 223 | elif space == "latent": 224 | return model_opt.latent 225 | else: 226 | raise ValueError("Space should be either 'reference' or 'latent'.") 227 | 228 | 229 | def integrate_recover_expression(lowdim_list, # list of low-dimensional representations 230 | mean, std, pca, # information for recovering expression 231 | search_cos=False, # searching for an optimal lambdacos 232 | lambda_cos=20.0, 233 | training_steps=2000, 234 | data_path="data", 235 | mixingmetric_subsample=True 236 | ): 237 | 238 | print("Incrementally integrating %d datasets..." % len(lowdim_list)) 239 | 240 | if not search_cos: 241 | # if not search hyperparameter lambdacos 242 | if isinstance(lambda_cos, float) or isinstance(lambda_cos, int): 243 | lambda_cos_tmp = lambda_cos 244 | 245 | for i in range(len(lowdim_list) - 1): 246 | 247 | if isinstance(lambda_cos, list): 248 | lambda_cos_tmp = lambda_cos[i] 249 | 250 | print("Integrating the %d-th dataset to the 1-st dataset..." % (i + 2)) 251 | model = Model(lambdacos=lambda_cos_tmp, 252 | training_steps=training_steps, 253 | data_path=os.path.join(data_path, "preprocess"), 254 | model_path="models/%d_datasets" % (i + 2), 255 | result_path="results/%d_datasets" % (i + 2)) 256 | if i == 0: 257 | model.emb_A = lowdim_list[0] 258 | else: 259 | model.emb_A = emb_total 260 | model.emb_B = lowdim_list[i + 1] 261 | model.train() 262 | model.eval() 263 | emb_total = model.data_Aspace 264 | else: 265 | for i in range(len(lowdim_list) - 1): 266 | print("Integrating the %d-th dataset to the 1-st dataset..." 
% (i + 2)) 267 | for lambda_cos in [15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0]: 268 | model = Model(lambdacos=lambda_cos, 269 | training_steps=training_steps, 270 | data_path=os.path.join(data_path, "preprocess"), 271 | model_path="models/%d_datasets" % (i + 2), 272 | result_path="results/%d_datasets" % (i + 2)) 273 | if i == 0: 274 | model.emb_A = lowdim_list[0] 275 | else: 276 | model.emb_A = emb_total 277 | model.emb_B = lowdim_list[i + 1] 278 | model.train() 279 | model.eval() 280 | meta = pd.DataFrame(index=np.arange(model.emb_A.shape[0] + model.emb_B.shape[0])) 281 | meta["method"] = ["A"] * model.emb_A.shape[0] + ["B"] * model.emb_B.shape[0] 282 | mixing = calculate_mixing_metric(model.latent, meta, k=5, max_k=300, methods=list(set(meta.method)), subsample=mixingmetric_subsample) 283 | print("lambda_cos: %f, mixing metric: %f \n" % (lambda_cos, mixing)) 284 | if lambda_cos == 15.0: 285 | model_opt = model 286 | mixing_metric_opt = mixing 287 | elif mixing < mixing_metric_opt: 288 | model_opt = model 289 | mixing_metric_opt = mixing 290 | emb_total = model_opt.data_Aspace 291 | 292 | expression_scaled = pca.inverse_transform(emb_total) 293 | expression_log_normalized = expression_scaled * std + mean 294 | 295 | return expression_scaled, expression_log_normalized 296 | 297 | 298 | def calculate_mixing_metric(data, meta, methods, k=5, max_k=300, subsample=True): 299 | if subsample: 300 | if data.shape[0] >= 1e4: 301 | np.random.seed(1234) 302 | subsample_idx = np.random.choice(data.shape[0], 10000, replace=False) 303 | data = data[subsample_idx] 304 | meta = meta.iloc[subsample_idx] 305 | meta.index = np.arange(len(subsample_idx)) 306 | lowdim = data 307 | 308 | nbrs = NearestNeighbors(n_neighbors=max_k, algorithm='kd_tree').fit(lowdim) 309 | _, indices = nbrs.kneighbors(lowdim) 310 | indices = indices[:, 1:] 311 | mixing = np.zeros((data.shape[0], 2)) 312 | for i in range(data.shape[0]): 313 | if len(np.where(meta.method[indices[i, :]] == methods[0])[0]) > k-1: 314 | mixing[i, 0] = np.where(meta.method[indices[i, :]] == methods[0])[0][k-1] 315 | else: mixing[i, 0] = max_k - 1 316 | if len(np.where(meta.method[indices[i, :]] == methods[1])[0]) > k-1: 317 | mixing[i, 1] = np.where(meta.method[indices[i, :]] == methods[1])[0][k-1] 318 | else: mixing[i, 1] = max_k - 1 319 | return np.mean(np.median(mixing, axis=1) + 1) 320 | 321 | 322 | def calculate_ARI(data, meta, anno_A="drop_subcluster", anno_B="subcluster"): 323 | # np.random.seed(1234) 324 | if data.shape[0] > 1e5: 325 | np.random.seed(1234) 326 | subsample_idx = np.random.choice(data.shape[0], 50000, replace=False) 327 | data = data[subsample_idx] 328 | meta = meta.iloc[subsample_idx] 329 | lowdim = data 330 | 331 | cellid = meta.index.astype(str) 332 | method = meta["method"].astype(str) 333 | cluster_A = meta[anno_A].astype(str) 334 | if (anno_B != anno_A): 335 | cluster_B = meta[anno_B].astype(str) 336 | 337 | rpy2.robjects.numpy2ri.activate() 338 | nr, nc = lowdim.shape 339 | lowdim = ro.r.matrix(lowdim, nrow=nr, ncol=nc) 340 | ro.r.assign("data", lowdim) 341 | rpy2.robjects.numpy2ri.deactivate() 342 | 343 | cellid = ro.StrVector(cellid) 344 | ro.r.assign("cellid", cellid) 345 | method = ro.StrVector(method) 346 | ro.r.assign("method", method) 347 | cluster_A = ro.StrVector(cluster_A) 348 | ro.r.assign("cluster_A", cluster_A) 349 | if (anno_B != anno_A): 350 | cluster_B = ro.StrVector(cluster_B) 351 | ro.r.assign("cluster_B", cluster_B) 352 | 353 | ro.r("set.seed(1234)") 354 | ro.r['library']("Seurat") 355 | 
ro.r['library']("mclust") 356 | 357 | ro.r("comb_normalized <- t(data)") 358 | ro.r('''rownames(comb_normalized) <- paste("gene", 1:nrow(comb_normalized), sep = "")''') 359 | ro.r("colnames(comb_normalized) <- as.vector(cellid)") 360 | 361 | ro.r("comb_raw <- matrix(0, nrow = nrow(comb_normalized), ncol = ncol(comb_normalized))") 362 | ro.r("rownames(comb_raw) <- rownames(comb_normalized)") 363 | ro.r("colnames(comb_raw) <- colnames(comb_normalized)") 364 | 365 | ro.r("comb <- CreateSeuratObject(comb_raw)") 366 | ro.r('''scunitdata <- Seurat::CreateDimReducObject( 367 | embeddings = t(comb_normalized), 368 | stdev = as.numeric(apply(comb_normalized, 2, stats::sd)), 369 | assay = "RNA", 370 | key = "scunit")''') 371 | ro.r('''comb[["scunit"]] <- scunitdata''') 372 | 373 | ro.r("comb@meta.data$method <- method") 374 | 375 | ro.r("comb@meta.data$cluster_A <- cluster_A") 376 | if (anno_B != anno_A): 377 | ro.r("comb@meta.data$cluster_B <- cluster_B") 378 | 379 | ro.r('''comb <- FindNeighbors(comb, reduction = "scunit", dims = 1:ncol(data), force.recalc = TRUE, verbose = FALSE)''') 380 | ro.r('''comb <- FindClusters(comb, verbose = FALSE)''') 381 | 382 | if (anno_B != anno_A): 383 | method_set = pd.unique(meta["method"]) 384 | method_A = method_set[0] 385 | ro.r.assign("method_A", method_A) 386 | method_B = method_set[1] 387 | ro.r.assign("method_B", method_B) 388 | ro.r('''indx_A <- which(comb$method == method_A)''') 389 | ro.r('''indx_B <- which(comb$method == method_B)''') 390 | 391 | ro.r("ARI_A <- adjustedRandIndex(comb$cluster_A[indx_A], comb$seurat_clusters[indx_A])") 392 | ro.r("ARI_B <- adjustedRandIndex(comb$cluster_B[indx_B], comb$seurat_clusters[indx_B])") 393 | ARI_A = np.array(ro.r("ARI_A"))[0] 394 | ARI_B = np.array(ro.r("ARI_B"))[0] 395 | 396 | return ARI_A, ARI_B 397 | else: 398 | ro.r("ARI_A <- adjustedRandIndex(comb$cluster_A, comb$seurat_clusters)") 399 | ARI_A = np.array(ro.r("ARI_A"))[0] 400 | 401 | return ARI_A 402 | 403 | 404 | def calculate_kBET(data, meta): 405 | cellid = meta.index.astype(str) 406 | method = meta["method"].astype(str) 407 | 408 | rpy2.robjects.numpy2ri.activate() 409 | nr, nc = data.shape 410 | data = ro.r.matrix(data, nrow=nr, ncol=nc) 411 | ro.r.assign("data", data) 412 | rpy2.robjects.numpy2ri.deactivate() 413 | 414 | cellid = ro.StrVector(cellid) 415 | ro.r.assign("cellid", cellid) 416 | method = ro.StrVector(method) 417 | ro.r.assign("method", method) 418 | 419 | ro.r("set.seed(1234)") 420 | ro.r['library']("kBET") 421 | 422 | accept_rate = [] 423 | for _ in range(100): 424 | ro.r("subset_id <- sample.int(n = length(method), size = 1000, replace=FALSE)") 425 | 426 | ro.r("batch.estimate <- kBET(data[subset_id,], method[subset_id], do.pca = FALSE, plot=FALSE)") 427 | accept_rate.append(np.array(ro.r("mean(batch.estimate$results$kBET.pvalue.test > 0.05)"))) 428 | 429 | return np.median(accept_rate) 430 | 431 | 432 | def calculate_ASW(data, meta, anno_A="drop_subcluster", anno_B="subcluster"): 433 | if data.shape[0] >= 1e5: 434 | np.random.seed(1234) 435 | subsample_idx = np.random.choice(data.shape[0], 50000, replace=False) 436 | data = data[subsample_idx] 437 | meta = meta.iloc[subsample_idx] 438 | lowdim = data 439 | 440 | cellid = meta.index.astype(str) 441 | method = meta["method"].astype(str) 442 | cluster_A = meta[anno_A].astype(str) 443 | if (anno_B != anno_A): 444 | cluster_B = meta[anno_B].astype(str) 445 | 446 | rpy2.robjects.numpy2ri.activate() 447 | nr, nc = lowdim.shape 448 | lowdim = ro.r.matrix(lowdim, nrow=nr, ncol=nc) 449 | 
ro.r.assign("data", lowdim) 450 | rpy2.robjects.numpy2ri.deactivate() 451 | 452 | cellid = ro.StrVector(cellid) 453 | ro.r.assign("cellid", cellid) 454 | method = ro.StrVector(method) 455 | ro.r.assign("method", method) 456 | cluster_A = ro.StrVector(cluster_A) 457 | ro.r.assign("cluster_A", cluster_A) 458 | if (anno_B != anno_A): 459 | cluster_B = ro.StrVector(cluster_B) 460 | ro.r.assign("cluster_B", cluster_B) 461 | 462 | ro.r("set.seed(1234)") 463 | ro.r['library']("cluster") 464 | 465 | if (anno_B != anno_A): 466 | method_set = pd.unique(meta["method"]) 467 | method_A = method_set[0] 468 | ro.r.assign("method_A", method_A) 469 | method_B = method_set[1] 470 | ro.r.assign("method_B", method_B) 471 | ro.r('''indx_A <- which(method == method_A)''') 472 | ro.r('''indx_B <- which(method == method_B)''') 473 | ro.r('''ASW_A <- summary(silhouette(as.numeric(as.factor(cluster_A[indx_A])), dist(data[indx_A, 1:20])))[["avg.width"]]''') 474 | ro.r('''ASW_B <- summary(silhouette(as.numeric(as.factor(cluster_B[indx_B])), dist(data[indx_B, 1:20])))[["avg.width"]]''') 475 | ASW_A = np.array(ro.r("ASW_A"))[0] 476 | ASW_B = np.array(ro.r("ASW_B"))[0] 477 | 478 | return ASW_A, ASW_B 479 | else: 480 | ro.r('''ASW_A <- summary(silhouette(as.numeric(as.factor(cluster_A)), dist(data[, 1:20])))[["avg.width"]]''') 481 | ASW_A = np.array(ro.r("ASW_A"))[0] 482 | 483 | return ASW_A 484 | 485 | 486 | def calculate_cellcycleconservation(data, meta, adata_raw, organism="mouse", resources_path="./cell_cycle_resources"): 487 | #adata 488 | cellid = list(meta.index.astype(str)) 489 | geneid = ["gene_"+str(i) for i in range(data.shape[1])] 490 | adata = anndata.AnnData(X=data, obs=cellid, var=geneid) 491 | 492 | #score cell cycle 493 | cc_files = {'mouse': [os.path.join(resources_path, 's_genes_tirosh.txt'), 494 | os.path.join(resources_path, 'g2m_genes_tirosh.txt')]} 495 | with open(cc_files[organism][0], "r") as f: 496 | s_genes = [x.strip() for x in f.readlines() if x.strip() in adata_raw.var.index] 497 | with open(cc_files[organism][1], "r") as f: 498 | g2m_genes = [x.strip() for x in f.readlines() if x.strip() in adata_raw.var.index] 499 | sc.tl.score_genes_cell_cycle(adata_raw, s_genes, g2m_genes) 500 | 501 | adata_raw.obs["method"] = meta["method"].values.astype(str) 502 | adata.obs["method"] = meta["method"].values.astype(str) 503 | batches = adata_raw.obs["method"].unique() 504 | 505 | scores_final = [] 506 | scores_before = [] 507 | scores_after = [] 508 | for batch in batches: 509 | raw_sub = adata_raw[adata_raw.obs["method"] == batch] 510 | int_sub = adata[adata.obs["method"] == batch].copy() 511 | int_sub = int_sub.X 512 | 513 | #regression variable 514 | covariate_values = raw_sub.obs[['S_score', 'G2M_score']] 515 | if pd.api.types.is_numeric_dtype(covariate_values): 516 | covariate_values = np.array(covariate_values).reshape(-1, 1) 517 | else: 518 | covariate_values = pd.get_dummies(covariate_values) 519 | 520 | #PCR on data before integration 521 | n_comps = 50 522 | svd_solver = 'arpack' 523 | pca = sc.tl.pca(raw_sub.X, n_comps=n_comps, use_highly_variable=False, return_info=True, svd_solver=svd_solver, copy=True) 524 | X_pca = pca[0].copy() 525 | pca_var = pca[3].copy() 526 | del pca 527 | 528 | r2 = [] 529 | for i in range(n_comps): 530 | pc = X_pca[:, [i]] 531 | lm = LinearRegression() 532 | lm.fit(covariate_values, pc) 533 | r2_score = np.maximum(0, lm.score(covariate_values, pc)) 534 | r2.append(r2_score) 535 | 536 | Var = pca_var / sum(pca_var) * 100 537 | before = sum(r2 * Var) / 100 538 | 539 
| #PCR on data after integration 540 | n_comps = min(data.shape) 541 | svd_solver = 'full' 542 | pca = sc.tl.pca(int_sub, n_comps=n_comps, use_highly_variable=False, return_info=True, svd_solver=svd_solver, copy=True) 543 | X_pca = pca[0].copy() 544 | pca_var = pca[3].copy() 545 | del pca 546 | 547 | r2 = [] 548 | for i in range(n_comps): 549 | pc = X_pca[:, [i]] 550 | lm = LinearRegression() 551 | lm.fit(covariate_values, pc) 552 | r2_score = np.maximum(0, lm.score(covariate_values, pc)) 553 | r2.append(r2_score) 554 | 555 | Var = pca_var / sum(pca_var) * 100 556 | after = sum(r2 * Var) / 100 557 | 558 | #scale result 559 | score = 1 - abs(before - after) / before 560 | if score < 0: 561 | score = 0 562 | scores_before.append(before) 563 | scores_after.append(after) 564 | scores_final.append(score) 565 | 566 | score_out = np.mean(scores_final) 567 | return score_out 568 | 569 | 570 | def calculate_isolatedASW(data, meta, anno): 571 | tmp = meta[[anno, "method"]].drop_duplicates() 572 | batch_per_lab = tmp.groupby(anno).agg({"method": "count"}) 573 | iso_threshold = batch_per_lab.min().tolist()[0] 574 | labels = batch_per_lab[batch_per_lab["method"] <= iso_threshold].index.tolist() 575 | 576 | scores = {} 577 | for label_tar in labels: 578 | iso_label = np.array(meta[anno] == label_tar).astype(int) 579 | asw = silhouette_score( 580 | X=data, 581 | labels=iso_label, 582 | metric='euclidean' 583 | ) 584 | asw = (asw + 1) / 2 585 | scores[label_tar] = asw 586 | 587 | scores = pd.Series(scores) 588 | score = scores.mean() 589 | return score 590 | 591 | 592 | def calculate_isolatedF1(data, meta, anno): 593 | if data.shape[0] > 1e5: 594 | np.random.seed(1234) 595 | subsample_idx = np.random.choice(data.shape[0], 50000, replace=False) 596 | data = data[subsample_idx] 597 | meta = meta.iloc[subsample_idx] 598 | lowdim = data 599 | 600 | tmp = meta[[anno, "method"]].drop_duplicates() 601 | batch_per_lab = tmp.groupby(anno).agg({"method": "count"}) 602 | iso_threshold = batch_per_lab.min().tolist()[0] 603 | labels = batch_per_lab[batch_per_lab["method"] <= iso_threshold].index.tolist() 604 | 605 | cellid = meta.index.astype(str) 606 | method = meta["method"].astype(str) 607 | cluster_A = meta[anno].astype(str) 608 | 609 | rpy2.robjects.numpy2ri.activate() 610 | nr, nc = lowdim.shape 611 | lowdim = ro.r.matrix(lowdim, nrow=nr, ncol=nc) 612 | ro.r.assign("data", lowdim) 613 | rpy2.robjects.numpy2ri.deactivate() 614 | 615 | cellid = ro.StrVector(cellid) 616 | ro.r.assign("cellid", cellid) 617 | method = ro.StrVector(method) 618 | ro.r.assign("method", method) 619 | cluster_A = ro.StrVector(cluster_A) 620 | ro.r.assign("cluster_A", cluster_A) 621 | 622 | ro.r("set.seed(1234)") 623 | ro.r['library']("Seurat") 624 | 625 | ro.r("comb_normalized <- t(data)") 626 | ro.r('''rownames(comb_normalized) <- paste("gene", 1:nrow(comb_normalized), sep = "")''') 627 | ro.r("colnames(comb_normalized) <- as.vector(cellid)") 628 | 629 | ro.r("comb_raw <- matrix(0, nrow = nrow(comb_normalized), ncol = ncol(comb_normalized))") 630 | ro.r("rownames(comb_raw) <- rownames(comb_normalized)") 631 | ro.r("colnames(comb_raw) <- colnames(comb_normalized)") 632 | 633 | ro.r("comb <- CreateSeuratObject(comb_raw)") 634 | ro.r('''scunitdata <- Seurat::CreateDimReducObject( 635 | embeddings = t(comb_normalized), 636 | stdev = as.numeric(apply(comb_normalized, 2, stats::sd)), 637 | assay = "RNA", 638 | key = "scunit")''') 639 | ro.r('''comb[["scunit"]] <- scunitdata''') 640 | 641 | ro.r('''comb <- FindNeighbors(comb, reduction = 
"scunit", dims = 1:ncol(data), force.recalc = TRUE, verbose = FALSE)''') 642 | ro.r('''comb <- FindClusters(comb, verbose = FALSE)''') 643 | 644 | louvain_clusters = np.array(ro.r("comb$seurat_clusters")).astype("str") 645 | louvain_list = list(set(louvain_clusters)) 646 | 647 | scores = {} 648 | for label_tar in labels: 649 | max_f1 = 0 650 | for cluster in louvain_list: 651 | y_pred = louvain_clusters == cluster 652 | y_true = meta[anno].values.astype(str) == label_tar 653 | f1 = f1_score(y_pred, y_true) 654 | if f1 > max_f1: 655 | max_f1 = f1 656 | scores[label_tar] = max_f1 657 | 658 | scores = pd.Series(scores) 659 | score = scores.mean() 660 | return score 661 | 662 | 663 | def calculate_graphconnectivity(data, meta, anno): 664 | cellid = list(meta.index.astype(str)) 665 | geneid = ["gene_"+str(i) for i in range(data.shape[1])] 666 | adata = anndata.AnnData(X=data, obs=cellid, var=geneid) 667 | 668 | adata.obsm["X_emb"] = data 669 | sc.pp.neighbors(adata, n_neighbors=15, use_rep="X_emb") 670 | 671 | adata.obs["anno"] = meta[anno].values.astype(str) 672 | anno_list = list(set(adata.obs["anno"])) 673 | 674 | clust_res = [] 675 | 676 | for label in anno_list: 677 | adata_sub = adata[adata.obs["anno"].isin([label])] 678 | _, labels = connected_components( 679 | adata_sub.obsp['connectivities'], 680 | connection='strong' 681 | ) 682 | tab = pd.value_counts(labels) 683 | clust_res.append(tab.max() / sum(tab)) 684 | 685 | score = np.mean(clust_res) 686 | return score 687 | 688 | 689 | def calculate_PCRbatch(data, meta, data_before=None): 690 | covariate_values = meta["method"] 691 | 692 | n_comps = min(data.shape) 693 | svd_solver = 'full' 694 | pca = sc.tl.pca(data, n_comps=n_comps, use_highly_variable=False, return_info=True, svd_solver=svd_solver, copy=True) 695 | X_pca = pca[0].copy() 696 | pca_var = pca[3].copy() 697 | del pca 698 | 699 | if pd.api.types.is_numeric_dtype(covariate_values): 700 | covariate_values = np.array(covariate_values).reshape(-1, 1) 701 | else: 702 | covariate_values = pd.get_dummies(covariate_values) 703 | 704 | r2 = [] 705 | for i in range(n_comps): 706 | pc = X_pca[:, [i]] 707 | lm = LinearRegression() 708 | lm.fit(covariate_values, pc) 709 | r2_score = np.maximum(0, lm.score(covariate_values, pc)) 710 | r2.append(r2_score) 711 | 712 | Var = pca_var / sum(pca_var) * 100 713 | R2Var = sum(r2 * Var) / 100 714 | 715 | if data_before is not None: 716 | n_comps = 50 717 | svd_solver = 'arpack' 718 | pca = sc.tl.pca(data_before, n_comps=n_comps, use_highly_variable=False, return_info=True, svd_solver=svd_solver, copy=True) 719 | X_pca = pca[0].copy() 720 | pca_var = pca[3].copy() 721 | del pca 722 | 723 | r2 = [] 724 | for i in range(n_comps): 725 | pc = X_pca[:, [i]] 726 | lm = LinearRegression() 727 | lm.fit(covariate_values, pc) 728 | r2_score = np.maximum(0, lm.score(covariate_values, pc)) 729 | r2.append(r2_score) 730 | 731 | Var = pca_var / sum(pca_var) * 100 732 | R2Var_before = sum(r2 * Var) / 100 733 | 734 | score = (R2Var_before - R2Var) / R2Var_before 735 | return score, R2Var, R2Var_before 736 | else: 737 | return R2Var 738 | 739 | 740 | def calculate_NMI(data, meta, anno_A="drop_subcluster", anno_B="subcluster"): 741 | # np.random.seed(1234) 742 | if data.shape[0] > 1e5: 743 | np.random.seed(1234) 744 | subsample_idx = np.random.choice(data.shape[0], 50000, replace=False) 745 | data = data[subsample_idx] 746 | meta = meta.iloc[subsample_idx] 747 | lowdim = data 748 | 749 | cellid = meta.index.astype(str) 750 | method = meta["method"].astype(str) 751 | 
cluster_A = meta[anno_A].astype(str) 752 | if (anno_B != anno_A): 753 | cluster_B = meta[anno_B].astype(str) 754 | 755 | rpy2.robjects.numpy2ri.activate() 756 | nr, nc = lowdim.shape 757 | lowdim = ro.r.matrix(lowdim, nrow=nr, ncol=nc) 758 | ro.r.assign("data", lowdim) 759 | rpy2.robjects.numpy2ri.deactivate() 760 | 761 | cellid = ro.StrVector(cellid) 762 | ro.r.assign("cellid", cellid) 763 | method = ro.StrVector(method) 764 | ro.r.assign("method", method) 765 | cluster_A = ro.StrVector(cluster_A) 766 | ro.r.assign("cluster_A", cluster_A) 767 | if (anno_B != anno_A): 768 | cluster_B = ro.StrVector(cluster_B) 769 | ro.r.assign("cluster_B", cluster_B) 770 | 771 | ro.r("set.seed(1234)") 772 | ro.r['library']("Seurat") 773 | 774 | ro.r("comb_normalized <- t(data)") 775 | ro.r('''rownames(comb_normalized) <- paste("gene", 1:nrow(comb_normalized), sep = "")''') 776 | ro.r("colnames(comb_normalized) <- as.vector(cellid)") 777 | 778 | ro.r("comb_raw <- matrix(0, nrow = nrow(comb_normalized), ncol = ncol(comb_normalized))") 779 | ro.r("rownames(comb_raw) <- rownames(comb_normalized)") 780 | ro.r("colnames(comb_raw) <- colnames(comb_normalized)") 781 | 782 | ro.r("comb <- CreateSeuratObject(comb_raw)") 783 | ro.r('''scunitdata <- Seurat::CreateDimReducObject( 784 | embeddings = t(comb_normalized), 785 | stdev = as.numeric(apply(comb_normalized, 2, stats::sd)), 786 | assay = "RNA", 787 | key = "scunit")''') 788 | ro.r('''comb[["scunit"]] <- scunitdata''') 789 | 790 | ro.r("comb@meta.data$method <- method") 791 | ro.r("comb@meta.data$cluster_A <- cluster_A") 792 | if (anno_B != anno_A): 793 | ro.r("comb@meta.data$cluster_B <- cluster_B") 794 | 795 | ro.r('''comb <- FindNeighbors(comb, reduction = "scunit", dims = 1:ncol(data), force.recalc = TRUE, verbose = FALSE)''') 796 | ro.r('''comb <- FindClusters(comb, verbose = FALSE)''') 797 | 798 | np.random.seed(1234) 799 | if (anno_B != anno_A): 800 | method_set = pd.unique(meta["method"]) 801 | method_A = method_set[0] 802 | ro.r.assign("method_A", method_A) 803 | method_B = method_set[1] 804 | ro.r.assign("method_B", method_B) 805 | ro.r('''indx_A <- which(comb$method == method_A)''') 806 | ro.r('''indx_B <- which(comb$method == method_B)''') 807 | 808 | #A 809 | louvain_A = np.array(ro.r("comb$seurat_clusters[indx_A]")).astype("str") 810 | cluster_A = np.array(ro.r("comb$cluster_A[indx_A]")).astype("str") 811 | df_A = pd.DataFrame({'louvain_A': louvain_A, 'cluster_A': cluster_A}) 812 | df_A.louvain_A = pd.Categorical(df_A.louvain_A) 813 | df_A.cluster_A = pd.Categorical(df_A.cluster_A) 814 | df_A['louvain_code'] = df_A.louvain_A.cat.codes 815 | df_A['A_code'] = df_A.cluster_A.cat.codes 816 | NMI_A = NMI(df_A['A_code'].values, df_A['louvain_code'].values) 817 | 818 | #B 819 | louvain_B = np.array(ro.r("comb$seurat_clusters[indx_B]")).astype("str") 820 | cluster_B = np.array(ro.r("comb$cluster_B[indx_B]")).astype("str") 821 | df_B = pd.DataFrame({'louvain_B': louvain_B, 'cluster_B': cluster_B}) 822 | df_B.louvain_B = pd.Categorical(df_B.louvain_B) 823 | df_B.cluster_B = pd.Categorical(df_B.cluster_B) 824 | df_B['louvain_code'] = df_B.louvain_B.cat.codes 825 | df_B['B_code'] = df_B.cluster_B.cat.codes 826 | NMI_B = NMI(df_B['B_code'].values, df_B['louvain_code'].values) 827 | 828 | return NMI_A, NMI_B 829 | else: 830 | louvain_clusters = np.array(ro.r("comb$seurat_clusters")).astype("str") 831 | cluster_A = np.array(ro.r("comb$cluster_A")).astype("str") 832 | 833 | df_fornmi = pd.DataFrame({'louvain_clusters': louvain_clusters, 834 | 'cluster_A': 
cluster_A}) 835 | df_fornmi.louvain_clusters = pd.Categorical(df_fornmi.louvain_clusters) 836 | df_fornmi.cluster_A = pd.Categorical(df_fornmi.cluster_A) 837 | df_fornmi['louvain_code'] = df_fornmi.louvain_clusters.cat.codes 838 | df_fornmi['A_code'] = df_fornmi.cluster_A.cat.codes 839 | 840 | NMI_A = NMI(df_fornmi['A_code'].values, df_fornmi['louvain_code'].values) 841 | return NMI_A 842 | 843 | 844 | def annotate_by_nn(vec_tar, vec_ref, label_ref, k=20, metric='cosine'): 845 | dist_mtx = cdist(vec_tar, vec_ref, metric=metric) 846 | idx = dist_mtx.argsort()[:, :k] 847 | labels = [max(list(label_ref[i]), key=list(label_ref[i]).count) for i in idx] 848 | return labels 849 | 850 | def plot_UMAP(data, meta, space="latent", score=None, colors=["method"], subsample=False, 851 | save=False, result_path=None, filename_suffix=None): 852 | if filename_suffix is not None: 853 | filenames = [os.path.join(result_path, "%s-%s-%s.pdf" % (space, c, filename_suffix)) for c in colors] 854 | else: 855 | filenames = [os.path.join(result_path, "%s-%s.pdf" % (space, c)) for c in colors] 856 | 857 | if subsample: 858 | if data.shape[0] >= 1e5: 859 | np.random.seed(1234) 860 | subsample_idx = np.random.choice(data.shape[0], 50000, replace=False) 861 | data = data[subsample_idx] 862 | meta = meta.iloc[subsample_idx] 863 | if score is not None: 864 | score = score[subsample_idx] 865 | 866 | adata = anndata.AnnData(X=data) 867 | adata.obs.index = meta.index 868 | adata.obs = pd.concat([adata.obs, meta], axis=1) 869 | adata.var.index = "dim-" + adata.var.index 870 | adata.obsm["latent"] = data 871 | 872 | # run UMAP 873 | reducer = umap.UMAP(n_neighbors=30, 874 | n_components=2, 875 | metric="correlation", 876 | n_epochs=None, 877 | learning_rate=1.0, 878 | min_dist=0.3, 879 | spread=1.0, 880 | set_op_mix_ratio=1.0, 881 | local_connectivity=1, 882 | repulsion_strength=1, 883 | negative_sample_rate=5, 884 | a=None, 885 | b=None, 886 | random_state=1234, 887 | metric_kwds=None, 888 | angular_rp_forest=False, 889 | verbose=True) 890 | embedding = reducer.fit_transform(adata.obsm["latent"]) 891 | adata.obsm["X_umap"] = embedding 892 | 893 | n_cells = embedding.shape[0] 894 | if n_cells >= 10000: 895 | size = 120000 / n_cells 896 | else: 897 | size = 12 898 | 899 | for i, c in enumerate(colors): 900 | groups = sorted(set(adata.obs[c].astype(str))) 901 | if "nan" in groups: 902 | groups.remove("nan") 903 | palette = "rainbow" 904 | if save: 905 | fig = sc.pl.umap(adata, color=c, palette=palette, groups=groups, return_fig=True, size=size) 906 | fig.savefig(filenames[i], bbox_inches='tight', dpi=300) 907 | else: 908 | sc.pl.umap(adata, color=c, palette=palette, groups=groups, size=size) 909 | 910 | if space == "Aspace": 911 | method_set = pd.unique(meta["method"]) 912 | adata.obs["score"] = score 913 | adata.obs["margin"] = (score < -5.0) * 1 914 | fig = sc.pl.umap(adata[meta["method"]==method_set[1]], color="score", palette=palette, groups=groups, return_fig=True, size=size) 915 | fig.savefig(os.path.join(result_path, "%s-score.pdf" % space), bbox_inches='tight', dpi=300) 916 | fig = sc.pl.umap(adata[meta["method"]==method_set[1]], color="margin", palette=palette, groups=groups, return_fig=True, size=size) 917 | fig.savefig(os.path.join(result_path, "%s-margin.pdf" % space), bbox_inches='tight', dpi=300) 918 | if space == "Bspace": 919 | method_set = pd.unique(meta["method"]) 920 | adata.obs["score"] = score 921 | adata.obs["margin"] = (score < -5.0) * 1 922 | fig = sc.pl.umap(adata[meta["method"]==method_set[0]], 
color="score", palette=palette, groups=groups, return_fig=True, size=size) 923 | fig.savefig(os.path.join(result_path, "%s-score.pdf" % space), bbox_inches='tight', dpi=300) 924 | fig = sc.pl.umap(adata[meta["method"]==method_set[0]], color="margin", palette=palette, groups=groups, return_fig=True, size=size) 925 | fig.savefig(os.path.join(result_path, "%s-margin.pdf" % space), bbox_inches='tight', dpi=300) 926 | 927 | --------------------------------------------------------------------------------