├── LICENSE.md ├── README.md ├── datasets └── readme.md ├── script ├── README.md ├── run_LRP_script.sh └── run_scMDC_script.sh └── src ├── LRP.py ├── Simulation.R ├── fig1_.png ├── layers.py ├── preprocess.py ├── run_LRP.py ├── run_scMDC.py ├── run_scMDC_batch.py ├── scMDC.py ├── scMDC_batch.py ├── tree.txt └── utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright Xiang Lin 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scMDC 2 | Single Cell Multi-omics deep clustering (**scMDC v1.0.1**) 3 | 4 | We develop a novel multimodal deep learning method, scMDC, for single-cell multi-omics data clustering analysis. scMDC is an end-to-end deep model that explicitly characterizes different data sources and jointly learns latent features of deep embedding for clustering analysis. 
Extensive simulation and real-data experiments reveal that scMDC outperforms existing single-cell single-modal and multimodal clustering methods on different single-cell multimodal datasets. The linear scalability of its running time makes scMDC a promising method for analyzing large multimodal datasets. 5 | 6 | ## Table of contents 7 | - [Network diagram](#network-diagram) 8 | - [Dependencies](#dependencies) 9 | - [Usage](#usage) 10 | - [Output](#output) 11 | - [Arguments](#arguments) 12 | - [Citation](#citation) 13 | - [Contact](#contact) 14 | 15 | ## Network diagram 16 | ![Model structure](https://github.com/xianglin226/scMDC/blob/master/src/fig1_.png?raw=true) 17 | 18 | ## Dependencies 19 | Python 3.8.1 20 | 21 | PyTorch 1.6.0 22 | 23 | Scanpy 1.6.0 24 | 25 | scikit-learn 0.22.1 26 | 27 | NumPy 1.18.1 28 | 29 | h5py 2.9.0 30 | 31 | All experiments of scMDC in this study were conducted on an Nvidia Tesla P100 (16 GB) GPU. 32 | We suggest installing the dependencies in a conda environment (conda create -n scMDC). 33 | Installing the dependencies takes a few minutes. 34 | scMDC takes about 3 minutes to cluster a dataset with 5000 cells. 35 | 36 | ## Usage 37 | 1) Prepare the input data in h5 format (see the readme in the 'datasets' folder). 38 | 2) Run scMDC according to the running scripts in the "script" folder (note the parameter settings if you work on mRNA+ATAC data, and use run_scMDC_batch.py for multi-batch data clustering). 39 | 3) Run DE analysis with run_LRP.py based on the well-trained scMDC model (refer to the LRP running script in the "script" folder). 40 | 41 | ## Output 42 | 1) scMDC outputs a latent representation of the data, which can be used for further downstream analyses and visualized by t-SNE or UMAP; 43 | 2) Multi-batch scMDC outputs a latent representation of the integrated datasets in which the batch effects are corrected; 44 | 3) LRP outputs a gene ranking that indicates the importance of genes for a given cluster and can be used for pathway analysis. 45 | 46 | ## Arguments 47 | --n_clusters: number of clusters (K); scMDC will estimate K if this argument is set to -1. 48 | --cutoff: the fraction of epochs during which only the low-level autoencoders are trained. 49 | --batch_size: batch size. 50 | --data_file: path to the input data. 51 | Data format: H5. 52 | Structure: X1 (RNA), X2 (ADT or ATAC), Y (labels, if they exist), Batch (batch indicator for multi-batch data clustering). 53 | --maxiter: maximum number of training epochs. Default: 10000. 54 | --pretrain_epochs: number of epochs for pre-training. Default: 400. 55 | --gamma: coefficient of the clustering loss. Default: 0.1. 56 | --phi1 and --phi2: coefficients of the KL loss in the pretraining and clustering stages. Default: 0.001 for CITE-Seq; 0.005 for SMAGE-Seq*. 57 | --update_interval: the interval (in epochs) at which performance is checked. Default: 1. 58 | --tol: the stopping criterion; training stops when the fraction of changed labels falls below this value. Default: 0.001. 59 | --ae_weights: path to the pretrained weight file to load. 60 | --save_dir: the directory in which to store the outputs. 61 | --ae_weight_file: the file in which to store the pretrained autoencoder weights. 62 | --resolution: the resolution parameter used to estimate K. Default: 0.2. 63 | --n_neighbors: the n_neighbors parameter used to estimate K. Default: 30. 64 | --embedding_file: whether to save the embedding file. Default: No. 65 | --prediction_file: whether to save the prediction file. Default: No. 66 | --encodeLayer: layers of the low-level encoder for RNA. Default: [256,64,32,16] for CITE-Seq; [256,128,64] for SMAGE-seq. 67 | --decodeLayer1: layers of the decoder for the mRNA data (X1). Default: [16,64,256] for CITE-Seq;
[64,128,256] for SMAGE-seq. 68 | --decodeLayer2: layers of the decoder for the ADT/ATAC data (X2). Default: [16,20] for CITE-Seq; [64,128,256] for SMAGE-seq. 69 | --sigma1: noise added to the RNA data. Default: 2.5. 70 | --sigma2: noise added to the ADT/ATAC data. Default: 1.5 for CITE-Seq; 2.5 for SMAGE-Seq. 71 | --filter1: whether to do feature selection on genes (X1). Default: No. 72 | --filter2: whether to do feature selection on ATAC features (X2). Default: No. 73 | --f1: number of highly variable genes (in X1) used for clustering if feature selection is enabled. Default: 2000. 74 | --f2: number of highly variable features from ATAC (in X2) used for clustering if feature selection is enabled. Default: 2000. 75 | *We denote the 10X Single-Cell Multiome ATAC + Gene Expression technology as SMAGE-seq for convenience. 76 | 77 | 78 | ## Citation 79 | Lin, X., Tian, T., Wei, Z., & Hakonarson, H. (2022). Clustering of single-cell multi-omics data with a multimodal deep learning method. Nature Communications, 13(1), 1-18. 80 | 81 | ## Contact 82 | Xiang Lin 83 | -------------------------------------------------------------------------------- /datasets/readme.md: -------------------------------------------------------------------------------- 1 | # Example Real datasets 2 | Example datasets are in h5 format and can be downloaded from https://drive.google.com/drive/folders/1HTCrb3HZpbN8Nte5xGhY517NMdyzqftw?usp=sharing 3 | 4 | Required objects in the h5 file for running scMDC: 5 | 1) X1: mRNA count matrix 6 | 2) X2: ADT count matrix 7 | 3) Y: true labels (if they exist) 8 | 4) Batch: batch indicator (for multi-batch analysis) 9 | 10 | Other objects in the h5 files: 11 | 1) ADT: feature names in the ADT count matrix (only in CITE-seq data) 12 | 2) GenesFromPeaks: feature names in the gene-to-cell matrix mapped from scATAC-seq (only in SMAGE-seq data) 13 | 3) Genes: feature names in the mRNA count matrix 14 | 4) Cell types: cell type of each cell (if they exist) 15 | 5) Barcodes: cell barcodes (if they exist) 16 | 17 | Note: when using the filtered datasets (Normalized_filtered*), use X1_ and X2_ as inputs. 18 | -------------------------------------------------------------------------------- /script/README.md: -------------------------------------------------------------------------------- 1 | # Note 2 | When using the full-feature datasets, use X1 and X2 as inputs and turn on the filtering function (--filter1 and/or --filter2). 3 | When using the filtered datasets (Normalized_filtered*), use X1_ and X2_ as inputs and turn off the filtering function. 4 | -------------------------------------------------------------------------------- /script/run_LRP_script.sh: -------------------------------------------------------------------------------- 1 | # Run DE analysis (LRP) based on the well-trained scMDC model and its results. 2 | f=../datasets/CITESeq_GSE128639_BMNC_annodata.h5 3 | python -u run_scMDC.py --n_clusters 27 --ae_weight_file ./out_bmnc_full/AE_weights_bmnc.pth.tar --data_file $f --prediction_file --save_dir out_bmnc_full --filter1 4 | python -u run_LRP.py --n_clusters 27 --ae_weights ./out_bmnc_full/AE_weights_bmnc.pth.tar --cluster_index_file ./out_bmnc_full/1_pred.csv --data_file $f --save_dir out_bmnc_full --filter1 5 | -------------------------------------------------------------------------------- /script/run_scMDC_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH --gres=gpu:1 3 | 4 | # Here are the commands for the real-data experiments. We run each dataset ten times.
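# Added illustrative example (not part of the original script): to let scMDC estimate K itself,
# set --n_clusters to -1 and tune --resolution (default 0.2) and --n_neighbors (default 30).
# The output directory name below is a placeholder.
# python -u run_scMDC.py --n_clusters -1 --resolution 0.2 --n_neighbors 30 \
#   --ae_weight_file AE_weights_bmnc.pth.tar --data_file ../datasets/CITESeq_GSE128639_BMNC_annodata.h5 \
#   --save_dir citeseq_bmnc_estK --embedding_file --prediction_file --filter1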
5 | 6 | f=../datasets/CITESeq_GSE128639_BMNC_annodata.h5 7 | echo "Run CITE-seq BMNC" 8 | python -u run_scMDC.py --n_clusters 27 --ae_weight_file AE_weights_bmnc.pth.tar --data_file $f --save_dir citeseq_bmnc \ 9 | --embedding_file --prediction_file --filter1 10 | 11 | f=../datasets/SMAGESeq_10X_pbmc_granulocyte_plus.h5 12 | echo "Run SMAGE-seq PBMC10K" 13 | python -u run_scMDC.py --n_clusters 12 --ae_weight_file AE_weights_pbmc10k.pth.tar --data_file $f --save_dir atac_pbmc10k \ 14 | --embedding_file --prediction_file --filter1 --filter2 --f1 2000 --f2 2000 -el 256 128 64 -dl1 64 128 256 -dl2 64 128 256 --phi1 0.005 --phi2 0.005 --sigma2 2.5 --tau .1 15 | 16 | f=../datasets/CITESeq_realdata_spleen_lymph_111_anno_multiBatch.h5 17 | echo "Run multi-batch CITE-seq SLN111" 18 | python -u run_scMDC_batch.py --n_clusters 35 --ae_weight_file AE_weights_sln111.pth.tar --data_file $f --save_dir citeseq_sln111 \ 19 | --embedding_file --prediction_file --filter1 20 | -------------------------------------------------------------------------------- /src/LRP.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch.functional import norm 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | from torch.nn import Parameter 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.optim.lr_scheduler import * 10 | from torch.utils.data import DataLoader, TensorDataset 11 | from torch.nn.utils import clip_grad_norm_ 12 | import numpy as np 13 | import math, os 14 | 15 | 16 | class ClustDistLayer(nn.Module): 17 | def __init__(self, centroids, n_clusters, clust_list, device): 18 | super(ClustDistLayer, self).__init__() 19 | self.centroids = Variable(centroids).to(device) 20 | self.n_clusters = n_clusters 21 | self.clust_list = clust_list 22 | 23 | def forward(self, x, curr_clust_id): 24 | output = [] 25 | for i in self.clust_list: 26 | if i==curr_clust_id: 27 | continue 28 | weight = 2 * (self.centroids[self.clust_list.index(curr_clust_id)] - self.centroids[self.clust_list.index(i)]) 29 | bias = torch.norm(self.centroids[self.clust_list.index(curr_clust_id)], p=2) - torch.norm(self.centroids[self.clust_list.index(i)], p=2) 30 | h = torch.matmul(x, weight.T) + bias 31 | output.append(h.unsqueeze(1)) 32 | 33 | return torch.cat(output, dim=1) 34 | 35 | 36 | class ClustMinPoolLayer(nn.Module): 37 | def __init__(self, beta): 38 | super(ClustMinPoolLayer, self).__init__() 39 | self.beta = beta 40 | self.eps = 1e-10 41 | 42 | def forward(self, inputs): 43 | return - torch.log(torch.sum(torch.exp(inputs * -self.beta), dim=1) + self.eps) 44 | 45 | 46 | class LRP(nn.Module): 47 | def __init__(self, model, X1, X2, Z, clust_ids, n_clusters, beta=1., device="cuda"): 48 | super(LRP, self).__init__() 49 | #model.freeze_model() 50 | self.model = model 51 | self.clust_ids = clust_ids 52 | self.n_clusters = n_clusters 53 | self.clust_list = np.unique(clust_ids).astype(int).tolist() 54 | self.centroids_ =torch.tensor(self.set_centroids(Z), dtype=torch.float32) 55 | self.X1_ = torch.tensor(X1, dtype=torch.float32) 56 | self.X2_ = torch.tensor(X2, dtype=torch.float32) 57 | self.Z_ = torch.tensor(Z, dtype=torch.float32) 58 | self.distLayer = ClustDistLayer(self.centroids_, n_clusters, self.clust_list, device).to(device) 59 | self.clustMinPool = ClustMinPoolLayer(beta).to(device) 60 | self.device = device 61 | 62 | def set_centroids(self, Z): 63 | centroids = [] 64 | for i in self.clust_list: 65 | clust_Z = 
Z[self.clust_ids==i] 66 | curr_centroid = np.mean(clust_Z, axis=0) 67 | centroids.append(curr_centroid) 68 | 69 | return np.stack(centroids, axis=0) 70 | 71 | def clust_minpoolAct(self, X1, X2, curr_clust_id): 72 | z,_,_,_,_,_,_,_,_ = self.model.forwardAE(X1, X2) 73 | return self.clustMinPool(self.distLayer(z, curr_clust_id)) 74 | 75 | def calc_carlini_wagner_one_vs_one(self, clust_c_id, clust_k_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True): 76 | X1_0 = Variable(self.X1_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) 77 | curr_X1 = Variable(X1_0 + 1e-6, requires_grad=True).to(self.device) 78 | X2_0 = Variable(self.X2_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) 79 | curr_X2 = Variable(X2_0 + 1e-6, requires_grad=True).to(self.device) 80 | optimizer = optim.SGD([curr_X1, curr_X2], lr=lr) 81 | 82 | for iter in range(max_iter): 83 | clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id) 84 | clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id) 85 | clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_k_minpoolAct_tensor 86 | clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor)) 87 | clust_loss = torch.sum(clust_loss_tensor) 88 | 89 | norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1) 90 | 91 | loss = clust_loss * lamda + norm_loss 92 | 93 | optimizer.zero_grad() 94 | loss.backward() 95 | optimizer.step() 96 | 97 | if (iter+1) % 50 == 0: 98 | print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(iter, loss.item(), clust_loss.item(), norm_loss.item())) 99 | 100 | if use_abs: 101 | rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0) 102 | rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0) 103 | else: 104 | rel_score1 = torch.mean(curr_X1 - X1_0, dim=0) 105 | rel_score2 = torch.mean(curr_X2 - X2_0, dim=0) 106 | return rel_score1.data.cpu().numpy(), rel_score2.data.cpu().numpy() 107 | 108 | def calc_carlini_wagner_one_vs_rest(self, clust_c_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True): 109 | X1_0 = Variable(self.X1_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) 110 | curr_X1 = Variable(X1_0 + 1e-6, requires_grad=True).to(self.device) 111 | X2_0 = Variable(self.X2_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) 112 | curr_X2 = Variable(X2_0 + 1e-6, requires_grad=True).to(self.device) 113 | optimizer = optim.SGD([curr_X1, curr_X2], lr=lr) 114 | 115 | for iter in range(max_iter): 116 | clust_rest_minpoolAct_tensor_list = [] 117 | for clust_k_id in self.clust_list: 118 | if clust_k_id == clust_c_id: 119 | continue 120 | clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id) 121 | clust_rest_minpoolAct_tensor_list.append(clust_k_minpoolAct_tensor) 122 | clust_rest_minpoolAct_tensor = clust_rest_minpoolAct_tensor_list[0] 123 | for clust_k_id in range(1, len(clust_rest_minpoolAct_tensor_list)): 124 | clust_rest_minpoolAct_tensor = torch.maximum(clust_rest_minpoolAct_tensor, clust_rest_minpoolAct_tensor_list[clust_k_id]) 125 | 126 | clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id) 127 | 128 | clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_rest_minpoolAct_tensor 129 | clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor)) 130 | clust_loss = torch.sum(clust_loss_tensor) 131 | 132 | norm_loss = torch.norm(curr_X1 - X1_0, 
p=1) + torch.norm(curr_X2 - X2_0, p=1) 133 | 134 | loss = clust_loss * lamda + norm_loss 135 | 136 | optimizer.zero_grad() 137 | loss.backward() 138 | optimizer.step() 139 | 140 | if (iter+1) % 50 == 0: 141 | print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(iter, loss.item(), clust_loss.item(), norm_loss.item())) 142 | 143 | if use_abs: 144 | rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0) 145 | rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0) 146 | else: 147 | rel_score1 = torch.mean(curr_X1 - X1_0, dim=0) 148 | rel_score2 = torch.mean(curr_X2 - X2_0, dim=0) 149 | return rel_score1.data.cpu().numpy(), rel_score2.data.cpu().numpy() 150 | -------------------------------------------------------------------------------- /src/Simulation.R: -------------------------------------------------------------------------------- 1 | library(SymSim) 2 | library(rhdf5) 3 | library(Seurat) 4 | 5 | phyla <- read.tree("tree.txt") 6 | phyla2 <- read.tree("tree.txt") 7 | data(gene_len_pool) 8 | 9 | #This is a simulation script with adding batch effect 10 | #First we need to modify the DivideBatches2 function in SymSim to make the same batch partition in mRNA and ADT data. 11 | DivideBatches2 <- function(observed_counts_res, batchIDs, batch_effect_size=1){ 12 | observed_counts <- observed_counts_res[["counts"]] 13 | meta_cell <- observed_counts_res[["cell_meta"]] 14 | ncells <- dim(observed_counts)[2] 15 | ngenes <- dim(observed_counts)[1] 16 | nbatch <- unique(batchIDs) 17 | meta_cell2 <- data.frame(batch = batchIDs, stringsAsFactors = F) 18 | meta_cell <- cbind(meta_cell, meta_cell2) 19 | mean_matrix <- matrix(0, ngenes, nbatch) 20 | gene_mean <- rnorm(ngenes, 0, 1) 21 | temp <- lapply(1:ngenes, function(igene) { 22 | return(runif(nbatch, min = gene_mean[igene] - batch_effect_size, 23 | max = gene_mean[igene] + batch_effect_size)) 24 | }) 25 | mean_matrix <- do.call(rbind, temp) 26 | batch_factor <- matrix(0, ngenes, ncells) 27 | for (igene in 1:ngenes) { 28 | for (icell in 1:ncells) { 29 | batch_factor[igene, icell] <- rnorm(n = 1, mean = mean_matrix[igene, 30 | batchIDs[icell]], sd = 0.01) 31 | } 32 | } 33 | observed_counts <- round(2^(log2(observed_counts) + batch_factor)) 34 | return(list(counts = observed_counts, cell_meta = meta_cell)) 35 | } 36 | 37 | for(k in 1:10){ 38 | ##RNA 39 | ncells = 1000 40 | nbatchs = 2 41 | batchIDs <- sample(1:nbatchs, ncells, replace = TRUE) 42 | print(k) 43 | print("Simulate RNA") 44 | ngenes <- 2000 45 | gene_len <- sample(gene_len_pool, ngenes, replace = FALSE) 46 | true_RNAcounts_res <- SimulateTrueCounts(ncells_total=ncells, 47 | min_popsize=50, 48 | i_minpop=1, 49 | ngenes=ngenes, 50 | nevf=10, 51 | evf_type="discrete", 52 | n_de_evf=6, 53 | vary="s", 54 | Sigma=0.6, 55 | phyla=phyla, 56 | randseed=k+1000) 57 | 58 | observed_RNAcounts <- True2ObservedCounts(true_counts=true_RNAcounts_res[[1]], 59 | meta_cell=true_RNAcounts_res[[3]], 60 | protocol="UMI", 61 | alpha_mean=0.00075, 62 | alpha_sd=0.0001, 63 | gene_len=gene_len, 64 | depth_mean=50000, 65 | depth_sd=3000, 66 | ) 67 | 68 | batch_RNAcounts <- DivideBatches2(observed_RNAcounts, batchIDs, batch_effect_size = 1) 69 | 70 | ## Add batch effects 71 | print((sum(batch_RNAcounts$counts==0)-sum(true_RNAcounts_res$counts==0))/sum(true_RNAcounts_res$counts>0)) 72 | print(sum(batch_RNAcounts$counts==0)/prod(dim(batch_RNAcounts$counts))) 73 | 74 | ##ADT 75 | print("Simulate ADT") 76 | nadts <- 100 77 | gene_len <- sample(gene_len_pool, nadts, replace = FALSE) 78 | 
#The true counts of the five populations can be simulated: 79 | true_ADTcounts_res <- SimulateTrueCounts(ncells_total=ncells, 80 | min_popsize=50, 81 | i_minpop=1, 82 | ngenes=nadts, 83 | nevf=10, 84 | evf_type="discrete", 85 | n_de_evf=6, 86 | vary="s", 87 | Sigma=0.3, 88 | phyla=phyla2, 89 | randseed=k+1000) 90 | 91 | observed_ADTcounts <- True2ObservedCounts(true_counts=true_ADTcounts_res[[1]], 92 | meta_cell=true_ADTcounts_res[[3]], 93 | protocol="UMI", 94 | alpha_mean=0.045, 95 | alpha_sd=0.01, 96 | gene_len=gene_len, 97 | depth_mean=50000, 98 | depth_sd=3000, 99 | ) 100 | 101 | ## Add batch effects 102 | batch_ADTcounts <- DivideBatches2(observed_ADTcounts, batchIDs, batch_effect_size = 1) 103 | 104 | print((sum(batch_ADTcounts$counts==0)-sum(true_ADTcounts_res$counts==0))/sum(true_ADTcounts_res$counts>0)) 105 | print(sum(batch_ADTcounts$counts==0)/prod(dim(batch_ADTcounts$counts))) 106 | 107 | y1 = batch_RNAcounts$cell_meta$pop 108 | y2 = batch_ADTcounts$cell_meta$pop 109 | batch1 = batch_ADTcounts$cell_meta$batch 110 | batch2 = batch_RNAcounts$cell_meta$batch 111 | print(sum(y1==y2)) 112 | print(sum(batch1 == batch2)) 113 | 114 | counts1 <- batch_RNAcounts[[1]] 115 | counts2 <- batch_ADTcounts[[1]] 116 | 117 | #filter 118 | rownames(counts2) <- paste("G",1:nrow(counts2),sep = "") 119 | colnames(counts2) <- paste("C",1:ncol(counts2),sep = "") 120 | 121 | pbmc <- CreateSeuratObject(counts = counts2, project = "P2", min.cells = 0, min.features = 0) 122 | pbmc <- NormalizeData(pbmc, normalization.method = "LogNormalize") 123 | pbmc <- FindVariableFeatures(pbmc, selection.method = "vst", nfeatures = 30) 124 | counts2 <- counts2[pbmc@assays[["RNA"]]@var.features,] 125 | 126 | h5file = paste("./batch/Simulation.", k, ".h5", sep="") 127 | h5createFile(h5file) 128 | h5write(as.matrix(counts1), h5file,"X1") 129 | h5write(as.matrix(counts2), h5file,"X2") 130 | h5write(y1, h5file,"Y") 131 | h5write(batch1, h5file,"Batch") 132 | } 133 | -------------------------------------------------------------------------------- /src/fig1_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xianglin226/scMDC/d0d4baeefc1fda27342d7b715167cbae4d081ca7/src/fig1_.png -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class NBLoss(nn.Module): 7 | def __init__(self): 8 | super(NBLoss, self).__init__() 9 | 10 | def forward(self, x, mean, disp, scale_factor=1.0): 11 | eps = 1e-10 12 | scale_factor = scale_factor[:, None] 13 | mean = mean * scale_factor 14 | 15 | t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps) 16 | t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps))) 17 | result = t1 + t2 18 | 19 | result = torch.mean(result) 20 | return result 21 | 22 | 23 | class ZINBLoss(nn.Module): 24 | def __init__(self): 25 | super(ZINBLoss, self).__init__() 26 | 27 | def forward(self, x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0): 28 | eps = 1e-10 29 | scale_factor = scale_factor[:, None] 30 | mean = mean * scale_factor 31 | 32 | t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps) 33 | t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps))) 34 | nb_final = t1 + t2 35 | 
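        # Added note: t1 + t2 above is the negative log-likelihood of a negative binomial (NB)
        # distribution with mean `mean` and dispersion `disp`. The lines below add the
        # zero-inflation mixture with dropout probability `pi`: near-zero counts (x <= 1e-8) are
        # scored as -log(pi + (1 - pi) * NB(0)), while positive counts use the NB term minus log(1 - pi).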
36 | nb_case = nb_final - torch.log(1.0-pi+eps) 37 | zero_nb = torch.pow(disp/(disp+mean+eps), disp) 38 | zero_case = -torch.log(pi + ((1.0-pi)*zero_nb)+eps) 39 | result = torch.where(torch.le(x, 1e-8), zero_case, nb_case) 40 | 41 | if ridge_lambda > 0: 42 | ridge = ridge_lambda*torch.square(pi) 43 | result += ridge 44 | 45 | result = torch.mean(result) 46 | return result 47 | 48 | 49 | class GaussianNoise(nn.Module): 50 | def __init__(self, sigma=0): 51 | super(GaussianNoise, self).__init__() 52 | self.sigma = sigma 53 | 54 | def forward(self, x): 55 | if self.training: 56 | x = x + self.sigma * torch.randn_like(x) 57 | return x 58 | 59 | 60 | class MeanAct(nn.Module): 61 | def __init__(self): 62 | super(MeanAct, self).__init__() 63 | 64 | def forward(self, x): 65 | return torch.clamp(torch.exp(x), min=1e-5, max=1e6) 66 | 67 | class DispAct(nn.Module): 68 | def __init__(self): 69 | super(DispAct, self).__init__() 70 | 71 | def forward(self, x): 72 | return torch.clamp(F.softplus(x), min=1e-4, max=1e4) 73 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Goekcen Eraslan 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import pickle, os, numbers 22 | 23 | import numpy as np 24 | import scipy as sp 25 | import pandas as pd 26 | import scanpy as sc 27 | from sklearn.model_selection import train_test_split 28 | from sklearn.preprocessing import scale 29 | import scipy 30 | 31 | #TODO: Fix this 32 | class AnnSequence: 33 | def __init__(self, matrix, batch_size, sf=None): 34 | self.matrix = matrix 35 | if sf is None: 36 | self.size_factors = np.ones((self.matrix.shape[0], 1), 37 | dtype=np.float32) 38 | else: 39 | self.size_factors = sf 40 | self.batch_size = batch_size 41 | 42 | def __len__(self): 43 | return len(self.matrix) // self.batch_size 44 | 45 | def __getitem__(self, idx): 46 | batch = self.matrix[idx*self.batch_size:(idx+1)*self.batch_size] 47 | batch_sf = self.size_factors[idx*self.batch_size:(idx+1)*self.batch_size] 48 | 49 | # return an (X, Y) pair 50 | return {'count': batch, 'size_factors': batch_sf}, batch 51 | 52 | 53 | def read_dataset(adata, transpose=False, test_split=False, copy=False): 54 | 55 | if isinstance(adata, sc.AnnData): 56 | if copy: 57 | adata = adata.copy() 58 | elif isinstance(adata, str): 59 | adata = sc.read(adata) 60 | else: 61 | raise NotImplementedError 62 | 63 | norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.' 
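    # Added note: the assertions below check that adata.X still holds raw (unnormalized) count
    # data, which scMDC's NB/ZINB reconstruction losses expect; the integer-valued check is only
    # run when the matrix is small enough (fewer than 50e6 entries) to test cheaply.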
64 | assert 'n_count' not in adata.obs, norm_error 65 | 66 | if adata.X.size < 50e6: # check if adata.X is integer only if array is small 67 | if sp.sparse.issparse(adata.X): 68 | assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error 69 | else: 70 | assert np.all(adata.X.astype(int) == adata.X), norm_error 71 | 72 | if transpose: adata = adata.transpose() 73 | 74 | if test_split: 75 | train_idx, test_idx = train_test_split(np.arange(adata.n_obs), test_size=0.1, random_state=42) 76 | spl = pd.Series(['train'] * adata.n_obs) 77 | spl.iloc[test_idx] = 'test' 78 | adata.obs['DCA_split'] = spl.values 79 | else: 80 | adata.obs['DCA_split'] = 'train' 81 | 82 | adata.obs['DCA_split'] = adata.obs['DCA_split'].astype('category') 83 | print('### Autoencoder: Successfully preprocessed {} genes and {} cells.'.format(adata.n_vars, adata.n_obs)) 84 | 85 | return adata 86 | 87 | def clr_normalize_each_cell(adata): 88 | """Normalize count vector for each cell, i.e. for each row of .X""" 89 | 90 | def seurat_clr(x): 91 | # TODO: support sparseness 92 | s = np.sum(np.log1p(x[x > 0])) 93 | exp = np.exp(s / len(x)) 94 | return np.log1p(x / exp) 95 | 96 | adata.raw = adata.copy() 97 | sc.pp.normalize_per_cell(adata) 98 | adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts) 99 | 100 | # apply to dense or sparse matrix, along axis. returns dense matrix 101 | adata.X = np.apply_along_axis( 102 | seurat_clr, 1, (adata.raw.X.A if scipy.sparse.issparse(adata.raw.X) else adata.raw.X) 103 | ) 104 | return adata 105 | 106 | def normalize(adata, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True): 107 | 108 | if filter_min_counts: 109 | sc.pp.filter_genes(adata, min_counts=1) 110 | sc.pp.filter_cells(adata, min_counts=1) 111 | 112 | if size_factors or normalize_input or logtrans_input: 113 | adata.raw = adata.copy() 114 | else: 115 | adata.raw = adata 116 | 117 | if size_factors: 118 | sc.pp.normalize_per_cell(adata) 119 | adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts) 120 | else: 121 | adata.obs['size_factors'] = 1.0 122 | 123 | if logtrans_input: 124 | sc.pp.log1p(adata) 125 | 126 | if normalize_input: 127 | sc.pp.scale(adata) 128 | 129 | return adata 130 | 131 | def read_genelist(filename): 132 | genelist = list(set(open(filename, 'rt').read().strip().split('\n'))) 133 | assert len(genelist) > 0, 'No genes detected in genelist file' 134 | print('### Autoencoder: Subset of {} genes will be denoised.'.format(len(genelist))) 135 | 136 | return genelist 137 | 138 | def write_text_matrix(matrix, filename, rownames=None, colnames=None, transpose=False): 139 | if transpose: 140 | matrix = matrix.T 141 | rownames, colnames = colnames, rownames 142 | 143 | pd.DataFrame(matrix, index=rownames, columns=colnames).to_csv(filename, 144 | sep='\t', 145 | index=(rownames is not None), 146 | header=(colnames is not None), 147 | float_format='%.6f') 148 | def read_pickle(inputfile): 149 | return pickle.load(open(inputfile, "rb")) 150 | -------------------------------------------------------------------------------- /src/run_LRP.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import math, os 3 | from sklearn import metrics 4 | from sklearn.cluster import KMeans 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.utils.data 
import DataLoader, TensorDataset 12 | 13 | from scMDC import scMultiCluster 14 | import numpy as np 15 | import collections 16 | import h5py 17 | import scanpy as sc 18 | from preprocess import read_dataset, normalize, clr_normalize_each_cell 19 | from utils import * 20 | from functools import reduce 21 | from LRP import LRP 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | # setting the hyper parameters 27 | import argparse 28 | parser = argparse.ArgumentParser(description='train', 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | parser.add_argument('--n_clusters', default=8, type=int) 31 | parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch') 32 | parser.add_argument('--batch_size', default=256, type=int) 33 | parser.add_argument('--data_file', default='Simulation.1.h5') 34 | parser.add_argument('--cluster_index_file', default='label.txt') 35 | parser.add_argument('--maxiter', default=10000, type=int) 36 | parser.add_argument('--pretrain_epochs', default=400, type=int) 37 | parser.add_argument('--gamma', default=.1, type=float, 38 | help='coefficient of clustering loss') 39 | parser.add_argument('--tau', default=1., type=float, 40 | help='fuzziness of clustering loss') 41 | parser.add_argument('--phi1', default=0.001, type=float, 42 | help='coefficient of KL loss in pretraining stage') 43 | parser.add_argument('--phi2', default=0.001, type=float, 44 | help='coefficient of KL loss in clustering stage') 45 | parser.add_argument('--update_interval', default=1, type=int) 46 | parser.add_argument('--tol', default=0.001, type=float) 47 | parser.add_argument('--ae_weights', default=None) 48 | parser.add_argument('--save_dir', default='results/') 49 | parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar') 50 | parser.add_argument('--resolution', default=0.2, type=float) 51 | parser.add_argument('--n_neighbors', default=30, type=int) 52 | parser.add_argument('--embedding_file', action='store_true', default=False) 53 | parser.add_argument('--prediction_file', action='store_true', default=False) 54 | parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16]) 55 | parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256]) 56 | parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20]) 57 | parser.add_argument('--sigma1', default=2.5, type=float) 58 | parser.add_argument('--sigma2', default=1.5, type=float) 59 | parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection') 60 | parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection') 61 | parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection') 62 | parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection') 63 | parser.add_argument('--run', default=1, type=int) 64 | parser.add_argument('--beta', default=1., type=float, 65 | help='coefficient of the clustering fuzziness') 66 | parser.add_argument('--margin', default=1., type=float, 67 | help='margin of difference between logits') 68 | parser.add_argument('--lamda', default=100., type=float, 69 | help='coefficient of the clustering perturbation loss') 70 | parser.add_argument('--lr', default=0.001, type=int) 71 | parser.add_argument('--device', default='cuda') 72 | 73 | args = parser.parse_args() 74 | print(args) 75 | data_mat = h5py.File(args.data_file) 76 | x1 = np.array(data_mat['X1']) 
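    # Added note: X1/X2 follow the h5 layout described in datasets/readme.md (X1 = mRNA counts,
    # X2 = ADT or ATAC-derived counts). Use the same --filter1/--filter2, --f1/--f2 and layer
    # settings as in the run_scMDC.py run that produced --ae_weights, so that the input
    # dimensions match the saved checkpoint.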
77 | x2 = np.array(data_mat['X2']) 78 | #y = np.array(data_mat['Y']) - 1 79 | data_mat.close() 80 | 81 | 82 | clust_ids = np.loadtxt(args.cluster_index_file, delimiter=",").astype(int) 83 | 84 | #Gene features 85 | if args.filter1: 86 | importantGenes = geneSelection(x1, n=args.f1, plot=False) 87 | x1 = x1[:, importantGenes] 88 | if args.filter2: 89 | importantGenes = geneSelection(x2, n=args.f2, plot=False) 90 | x2 = x2[:, importantGenes] 91 | 92 | adata1 = sc.AnnData(x1) 93 | #adata1.obs['Group'] = y 94 | 95 | adata1 = read_dataset(adata1, 96 | transpose=False, 97 | test_split=False, 98 | copy=True) 99 | 100 | adata1 = normalize(adata1, 101 | size_factors=True, 102 | normalize_input=True, 103 | logtrans_input=True) 104 | 105 | adata2 = sc.AnnData(x2) 106 | #adata2.obs['Group'] = y 107 | adata2 = read_dataset(adata2, 108 | transpose=False, 109 | test_split=False, 110 | copy=True) 111 | 112 | adata2 = normalize(adata2, 113 | size_factors=True, 114 | normalize_input=True, 115 | logtrans_input=True) 116 | 117 | #adata2 = clr_normalize_each_cell(adata2) 118 | 119 | input_size1 = adata1.n_vars 120 | input_size2 = adata2.n_vars 121 | print(adata1.X.shape) 122 | print(adata2.X.shape) 123 | 124 | print(args) 125 | 126 | encodeLayer = list(map(int, args.encodeLayer)) 127 | decodeLayer1 = list(map(int, args.decodeLayer1)) 128 | decodeLayer2 = list(map(int, args.decodeLayer2)) 129 | 130 | model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau, 131 | encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2, 132 | activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma, 133 | cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device) 134 | 135 | print(str(model)) 136 | 137 | if os.path.isfile(args.ae_weights): 138 | print("==> loading checkpoint '{}'".format(args.ae_weights)) 139 | checkpoint = torch.load(args.ae_weights) 140 | model.load_state_dict(checkpoint['ae_state_dict']) 141 | else: 142 | print("==> no checkpoint found at '{}'".format(args.ae_weights)) 143 | raise ValueError 144 | 145 | n_clusters = np.unique(clust_ids).shape[0] 146 | print("n cluster is: " + str(n_clusters)) 147 | 148 | Z = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device)).data.cpu().numpy() 149 | 150 | cluster_list = np.unique(clust_ids).astype(int).tolist() 151 | print(cluster_list) 152 | 153 | model_explainer = LRP(model, X1=adata1.X, X2=adata2.X, Z=Z, clust_ids=clust_ids, n_clusters=n_clusters, beta=args.beta).to(args.device) 154 | 155 | #for clust_c in [cluster_ind[0]]: #range(args.n_clusters): 156 | # for clust_k in [cluster_ind[1]]: #range(clust_c+1, args.n_clusters): 157 | # print("Cluster"+str(clust_c)+" vs Cluster"+str(clust_k)) 158 | # rel_score1, rel_score2 = model_explainer.calc_carlini_wagner_one_vs_one(clust_c, clust_k, margin=args.margin, lamda=args.lamda, max_iter=args.maxiter, lr=args.lr) 159 | # print(rel_score1.shape) 160 | # print(rel_score2.shape) 161 | # np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_"+str(clust_k)+"_rel_mRNA_scores.csv", rel_score1, delimiter=",") 162 | # np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_"+str(clust_k)+"_rel_ADT_scores.csv", rel_score2, delimiter=",") 163 | 164 | for clust_c in cluster_list: 165 | print("Cluster"+str(clust_c)+" vs Rest") 166 | rel_score1, rel_score2 = model_explainer.calc_carlini_wagner_one_vs_rest(clust_c, margin=args.margin, lamda=args.lamda, max_iter=args.maxiter, lr=args.lr) 167 | 
np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_rest_rel_mRNA_scores.csv", rel_score1, delimiter=",") 168 | np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_rest_rel_ADT_scores.csv", rel_score2, delimiter=",") 169 | -------------------------------------------------------------------------------- /src/run_scMDC.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import math, os 3 | from sklearn import metrics 4 | from sklearn.cluster import KMeans 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.utils.data import DataLoader, TensorDataset 12 | 13 | from scMDC import scMultiCluster 14 | import numpy as np 15 | import collections 16 | import h5py 17 | import scanpy as sc 18 | from preprocess import read_dataset, normalize, clr_normalize_each_cell 19 | from utils import * 20 | 21 | if __name__ == "__main__": 22 | 23 | # setting the hyper parameters 24 | import argparse 25 | parser = argparse.ArgumentParser(description='train', 26 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | parser.add_argument('--n_clusters', default=27, type=int) 28 | parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch') 29 | parser.add_argument('--batch_size', default=256, type=int) 30 | parser.add_argument('--data_file', default='Normalized_filtered_BMNC_GSE128639_Seurat.h5') 31 | parser.add_argument('--maxiter', default=5000, type=int) 32 | parser.add_argument('--pretrain_epochs', default=400, type=int) 33 | parser.add_argument('--gamma', default=.1, type=float, 34 | help='coefficient of clustering loss') 35 | parser.add_argument('--tau', default=1., type=float, 36 | help='fuzziness of clustering loss') 37 | parser.add_argument('--phi1', default=0.001, type=float, 38 | help='coefficient of KL loss in pretraining stage') 39 | parser.add_argument('--phi2', default=0.001, type=float, 40 | help='coefficient of KL loss in clustering stage') 41 | parser.add_argument('--update_interval', default=1, type=int) 42 | parser.add_argument('--tol', default=0.001, type=float) 43 | parser.add_argument('--lr', default=1., type=float) 44 | parser.add_argument('--ae_weights', default=None) 45 | parser.add_argument('--save_dir', default='results/') 46 | parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar') 47 | parser.add_argument('--resolution', default=0.2, type=float) 48 | parser.add_argument('--n_neighbors', default=30, type=int) 49 | parser.add_argument('--embedding_file', action='store_true', default=False) 50 | parser.add_argument('--prediction_file', action='store_true', default=False) 51 | parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16]) 52 | parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256]) 53 | parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20]) 54 | parser.add_argument('--sigma1', default=2.5, type=float) 55 | parser.add_argument('--sigma2', default=1.5, type=float) 56 | parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection') 57 | parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection') 58 | parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection') 59 | parser.add_argument('--filter2', action='store_true', 
default=False, help='Do ADT/ATAC selection') 60 | parser.add_argument('--run', default=1, type=int) 61 | parser.add_argument('--device', default='cuda') 62 | parser.add_argument('--no_labels', action='store_true', default=False) 63 | args = parser.parse_args() 64 | print(args) 65 | 66 | data_mat = h5py.File(args.data_file) 67 | x1 = np.array(data_mat['X1']) 68 | x2 = np.array(data_mat['X2']) 69 | if not args.no_labels: 70 | y = np.array(data_mat['Y']) 71 | data_mat.close() 72 | 73 | #Gene filter 74 | if args.filter1: 75 | importantGenes = geneSelection(x1, n=args.f1, plot=False) 76 | x1 = x1[:, importantGenes] 77 | if args.filter2: 78 | importantGenes = geneSelection(x2, n=args.f2, plot=False) 79 | x2 = x2[:, importantGenes] 80 | 81 | # preprocessing scRNA-seq read counts matrix 82 | adata1 = sc.AnnData(x1) 83 | #adata1.obs['Group'] = y 84 | 85 | adata1 = read_dataset(adata1, 86 | transpose=False, 87 | test_split=False, 88 | copy=True) 89 | 90 | adata1 = normalize(adata1, 91 | size_factors=True, 92 | normalize_input=True, 93 | logtrans_input=True) 94 | 95 | adata2 = sc.AnnData(x2) 96 | #adata2.obs['Group'] = y 97 | adata2 = read_dataset(adata2, 98 | transpose=False, 99 | test_split=False, 100 | copy=True) 101 | 102 | adata2 = normalize(adata2, 103 | size_factors=True, 104 | normalize_input=True, 105 | logtrans_input=True) 106 | 107 | #adata2 = clr_normalize_each_cell(adata2) 108 | 109 | input_size1 = adata1.n_vars 110 | input_size2 = adata2.n_vars 111 | 112 | print(args) 113 | 114 | encodeLayer = list(map(int, args.encodeLayer)) 115 | decodeLayer1 = list(map(int, args.decodeLayer1)) 116 | decodeLayer2 = list(map(int, args.decodeLayer2)) 117 | 118 | model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau, 119 | encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2, 120 | activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma, 121 | cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device) 122 | 123 | print(str(model)) 124 | 125 | if not os.path.exists(args.save_dir): 126 | os.makedirs(args.save_dir) 127 | 128 | t0 = time() 129 | if args.ae_weights is None: 130 | model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 131 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, batch_size=args.batch_size, 132 | epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file) 133 | else: 134 | if os.path.isfile(args.ae_weights): 135 | print("==> loading checkpoint '{}'".format(args.ae_weights)) 136 | checkpoint = torch.load(args.ae_weights) 137 | model.load_state_dict(checkpoint['ae_state_dict']) 138 | else: 139 | print("==> no checkpoint found at '{}'".format(args.ae_weights)) 140 | raise ValueError 141 | 142 | print('Pretraining time: %d seconds.' 
% int(time() - t0)) 143 | 144 | #get k 145 | latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device)) 146 | latent = latent.cpu().numpy() 147 | if args.n_clusters == -1: 148 | n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors) 149 | else: 150 | print("n_cluster is defined as " + str(args.n_clusters)) 151 | n_clusters = args.n_clusters 152 | 153 | if not args.no_labels: 154 | y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 155 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=y, 156 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter, 157 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir) 158 | else: 159 | y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 160 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=None, 161 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter, 162 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir) 163 | print('Total time: %d seconds.' % int(time() - t0)) 164 | 165 | if args.prediction_file: 166 | if not args.no_labels: 167 | y_pred_ = best_map(y, y_pred) - 1 168 | np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred_, delimiter=",") 169 | else: 170 | np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred, delimiter=",") 171 | 172 | if args.embedding_file: 173 | final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device)) 174 | final_latent = final_latent.cpu().numpy() 175 | np.savetxt(args.save_dir + "/" + str(args.run) + "_embedding.csv", final_latent, delimiter=",") 176 | 177 | if not args.no_labels: 178 | y_pred_ = best_map(y, y_pred) 179 | ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5) 180 | nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5) 181 | ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5) 182 | print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari)) 183 | else: 184 | print("No labels for evaluation!") 185 | -------------------------------------------------------------------------------- /src/run_scMDC_batch.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import math, os 3 | from sklearn import metrics 4 | from sklearn.cluster import KMeans 5 | from sklearn.preprocessing import OneHotEncoder 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | from torch.nn import Parameter 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader, TensorDataset 13 | 14 | from scMDC_batch import scMultiClusterBatch 15 | import numpy as np 16 | import collections 17 | import h5py 18 | import scanpy as sc 19 | from preprocess import read_dataset, normalize 20 | from utils import * 21 | 22 | if __name__ == "__main__": 23 | 24 | # setting the hyper parameters 25 | import argparse 26 | parser = argparse.ArgumentParser(description='train', 27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument('--n_clusters', default=27, type=int) 29 | parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch') 30 | parser.add_argument('--batch_size', default=256, type=int) 31 | parser.add_argument('--data_file', 
default='Normalized_filtered_BMNC_GSE128639_Seurat.h5') 32 | parser.add_argument('--maxiter', default=5000, type=int) 33 | parser.add_argument('--pretrain_epochs', default=400, type=int) 34 | parser.add_argument('--gamma', default=.1, type=float, 35 | help='coefficient of clustering loss') 36 | parser.add_argument('--tau', default=1., type=float, 37 | help='weight of clustering loss') 38 | parser.add_argument('--phi1', default=0.001, type=float, 39 | help='coefficient of KL loss in pretraining stage') 40 | parser.add_argument('--phi2', default=0.001, type=float, 41 | help='coefficient of KL loss in clustering stage') 42 | parser.add_argument('--update_interval', default=1, type=int) 43 | parser.add_argument('--tol', default=0.001, type=float) 44 | parser.add_argument('--lr', default=1., type=float) 45 | parser.add_argument('--ae_weights', default=None) 46 | parser.add_argument('--save_dir', default='results/') 47 | parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar') 48 | parser.add_argument('--resolution', default=0.2, type=float) 49 | parser.add_argument('--n_neighbors', default=30, type=int) 50 | parser.add_argument('--embedding_file', action='store_true', default=False) 51 | parser.add_argument('--prediction_file', action='store_true', default=False) 52 | parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16]) 53 | parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256]) 54 | parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20]) 55 | parser.add_argument('--sigma1', default=2.5, type=float) 56 | parser.add_argument('--sigma2', default=1.5, type=float) 57 | parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection') 58 | parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection') 59 | parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection') 60 | parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection') 61 | parser.add_argument('--nbatch', default=2, type=int) 62 | parser.add_argument('--run', default=1, type=int) 63 | parser.add_argument('--device', default='cuda') 64 | parser.add_argument('--no_labels', action='store_true', default=False) 65 | args = parser.parse_args() 66 | print(args) 67 | data_mat = h5py.File(args.data_file) 68 | x1 = np.array(data_mat['X1']) 69 | x2 = np.array(data_mat['X2']) 70 | if not args.no_labels: 71 | y = np.array(data_mat['Y']) 72 | b = np.array(data_mat['Batch']) 73 | enc = OneHotEncoder() 74 | enc.fit(b.reshape(-1, 1)) 75 | B = enc.transform(b.reshape(-1, 1)).toarray() 76 | data_mat.close() 77 | 78 | #Gene filter 79 | if args.filter1: 80 | importantGenes = geneSelection(x1, n=args.f1, plot=False) 81 | x1 = x1[:, importantGenes] 82 | if args.filter2: 83 | importantGenes = geneSelection(x2, n=args.f2, plot=False) 84 | x2 = x2[:, importantGenes] 85 | 86 | # preprocessing scRNA-seq read counts matrix 87 | adata1 = sc.AnnData(x1) 88 | #adata1.obs['Group'] = y 89 | 90 | adata1 = read_dataset(adata1, 91 | transpose=False, 92 | test_split=False, 93 | copy=True) 94 | 95 | adata1 = normalize(adata1, 96 | size_factors=True, 97 | normalize_input=True, 98 | logtrans_input=True) 99 | 100 | adata2 = sc.AnnData(x2) 101 | #adata2.obs['Group'] = y 102 | adata2 = read_dataset(adata2, 103 | transpose=False, 104 | test_split=False, 105 | copy=True) 106 | 107 | adata2 = normalize(adata2, 108 | size_factors=True, 109 | 
normalize_input=True, 110 | logtrans_input=True) 111 | 112 | input_size1 = adata1.n_vars 113 | input_size2 = adata2.n_vars 114 | 115 | print(args) 116 | 117 | encodeLayer = list(map(int, args.encodeLayer)) 118 | decodeLayer1 = list(map(int, args.decodeLayer1)) 119 | decodeLayer2 = list(map(int, args.decodeLayer2)) 120 | 121 | model = scMultiClusterBatch(input_dim1=input_size1, input_dim2=input_size2, n_batch = args.nbatch, tau=args.tau, 122 | encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2, 123 | activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma, 124 | cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device) 125 | 126 | print(str(model)) 127 | 128 | if not os.path.exists(args.save_dir): 129 | os.makedirs(args.save_dir) 130 | 131 | 132 | t0 = time() 133 | if args.ae_weights is None: 134 | model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 135 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B = B, batch_size=args.batch_size, 136 | epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file) 137 | else: 138 | if os.path.isfile(args.ae_weights): 139 | print("==> loading checkpoint '{}'".format(args.ae_weights)) 140 | checkpoint = torch.load(args.ae_weights) 141 | model.load_state_dict(checkpoint['ae_state_dict']) 142 | else: 143 | print("==> no checkpoint found at '{}'".format(args.ae_weights)) 144 | raise ValueError 145 | 146 | print('Pretraining time: %d seconds.' % int(time() - t0)) 147 | 148 | #get k 149 | latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device), torch.tensor(B).to(args.device), batch_size=args.batch_size) 150 | latent = latent.cpu().numpy() 151 | if args.n_clusters == -1: 152 | n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors) 153 | else: 154 | print("n_cluster is defined as " + str(args.n_clusters)) 155 | n_clusters = args.n_clusters 156 | 157 | if not args.no_labels: 158 | y_pred,_ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 159 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B=B, y=y, 160 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter, 161 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir) 162 | else: 163 | y_pred,_ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors, 164 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B=B, y=None, 165 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter, 166 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir) 167 | print('Total time: %d seconds.' 
% int(time() - t0)) 168 | 169 | if args.prediction_file: 170 | np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred, delimiter=",") 171 | 172 | if args.embedding_file: 173 | final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device), torch.tensor(B).to(args.device), batch_size=args.batch_size) 174 | final_latent = final_latent.cpu().numpy() 175 | np.savetxt(args.save_dir + "/" + str(args.run) + "_embedding.csv", final_latent, delimiter=",") 176 | 177 | if not args.no_labels: 178 | y_pred_ = best_map(y, y_pred) 179 | ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5) 180 | nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5) 181 | ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5) 182 | print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari)) 183 | else: 184 | print("No labels for evaluation!") 185 | -------------------------------------------------------------------------------- /src/scMDC.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics.pairwise import paired_distances 2 | from sklearn.decomposition import PCA 3 | from sklearn import metrics 4 | from sklearn.cluster import KMeans 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.utils.data import DataLoader, TensorDataset 12 | from layers import NBLoss, ZINBLoss, MeanAct, DispAct 13 | import numpy as np 14 | 15 | import math, os 16 | 17 | def buildNetwork2(layers, type, activation="relu"): 18 | net = [] 19 | for i in range(1, len(layers)): 20 | net.append(nn.Linear(layers[i-1], layers[i])) 21 | net.append(nn.BatchNorm1d(layers[i], affine=True)) 22 | if activation=="relu": 23 | net.append(nn.ReLU()) 24 | elif activation=="selu": 25 | net.append(nn.SELU()) 26 | elif activation=="sigmoid": 27 | net.append(nn.Sigmoid()) 28 | elif activation=="elu": 29 | net.append(nn.ELU()) 30 | return nn.Sequential(*net) 31 | 32 | class scMultiCluster(nn.Module): 33 | def __init__(self, input_dim1, input_dim2, 34 | encodeLayer=[], decodeLayer1=[], decodeLayer2=[], tau=1., t=10, device="cuda", 35 | activation="elu", sigma1=2.5, sigma2=.1, alpha=1., gamma=1., phi1=0.0001, phi2=0.0001, cutoff = 0.5): 36 | super(scMultiCluster, self).__init__() 37 | self.tau=tau 38 | self.input_dim1 = input_dim1 39 | self.input_dim2 = input_dim2 40 | self.cutoff = cutoff 41 | self.activation = activation 42 | self.sigma1 = sigma1 43 | self.sigma2 = sigma2 44 | self.alpha = alpha 45 | self.gamma = gamma 46 | self.phi1 = phi1 47 | self.phi2 = phi2 48 | self.t = t 49 | self.device = device 50 | self.encoder = buildNetwork2([input_dim1+input_dim2]+encodeLayer, type="encode", activation=activation) 51 | self.decoder1 = buildNetwork2(decodeLayer1, type="decode", activation=activation) 52 | self.decoder2 = buildNetwork2(decodeLayer2, type="decode", activation=activation) 53 | self.dec_mean1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), MeanAct()) 54 | self.dec_disp1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), DispAct()) 55 | self.dec_mean2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), MeanAct()) 56 | self.dec_disp2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), DispAct()) 57 | self.dec_pi1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), nn.Sigmoid()) 58 | self.dec_pi2 = nn.Sequential(nn.Linear(decodeLayer2[-1], 
input_dim2), nn.Sigmoid()) 59 | self.zinb_loss = ZINBLoss() 60 | self.z_dim = encodeLayer[-1] 61 | 62 | def save_model(self, path): 63 | torch.save(self.state_dict(), path) 64 | 65 | def load_model(self, path): 66 | pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage) 67 | model_dict = self.state_dict() 68 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 69 | model_dict.update(pretrained_dict) 70 | self.load_state_dict(model_dict) 71 | 72 | def cal_latent(self, z): 73 | sum_y = torch.sum(torch.square(z), dim=1) 74 | num = -2.0 * torch.matmul(z, z.t()) + torch.reshape(sum_y, [-1, 1]) + sum_y 75 | num = num / self.alpha 76 | num = torch.pow(1.0 + num, -(self.alpha + 1.0) / 2.0) 77 | zerodiag_num = num - torch.diag(torch.diag(num)) 78 | latent_p = (zerodiag_num.t() / torch.sum(zerodiag_num, dim=1)).t() 79 | return num, latent_p 80 | 81 | def kmeans_loss(self, z): 82 | dist1 = self.tau*torch.sum(torch.square(z.unsqueeze(1) - self.mu), dim=2) 83 | temp_dist1 = dist1 - torch.reshape(torch.mean(dist1, dim=1), [-1, 1]) 84 | q = torch.exp(-temp_dist1) 85 | q = (q.t() / torch.sum(q, dim=1)).t() 86 | q = torch.pow(q, 2) 87 | q = (q.t() / torch.sum(q, dim=1)).t() 88 | dist2 = dist1 * q 89 | return dist1, torch.mean(torch.sum(dist2, dim=1)) 90 | 91 | def target_distribution(self, q): 92 | p = q**2 / q.sum(0) 93 | return (p.t() / p.sum(1)).t() 94 | 95 | def forward(self, x1, x2): 96 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1) 97 | h = self.encoder(x) 98 | 99 | h1 = self.decoder1(h) 100 | mean1 = self.dec_mean1(h1) 101 | disp1 = self.dec_disp1(h1) 102 | pi1 = self.dec_pi1(h1) 103 | 104 | h2 = self.decoder2(h) 105 | mean2 = self.dec_mean2(h2) 106 | disp2 = self.dec_disp2(h2) 107 | pi2 = self.dec_pi2(h2) 108 | 109 | x0 = torch.cat([x1, x2], dim=-1) 110 | h0 = self.encoder(x0) 111 | num, lq = self.cal_latent(h0) 112 | return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2 113 | 114 | def forwardAE(self, x1, x2): 115 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1) 116 | h = self.encoder(x) 117 | 118 | h1 = self.decoder1(h) 119 | mean1 = self.dec_mean1(h1) 120 | disp1 = self.dec_disp1(h1) 121 | pi1 = self.dec_pi1(h1) 122 | 123 | h2 = self.decoder2(h) 124 | mean2 = self.dec_mean2(h2) 125 | disp2 = self.dec_disp2(h2) 126 | pi2 = self.dec_pi2(h2) 127 | 128 | x0 = torch.cat([x1, x2], dim=-1) 129 | h0 = self.encoder(x0) 130 | num, lq = self.cal_latent(h0) 131 | return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2 132 | 133 | def encodeBatch(self, X1, X2, batch_size=256): 134 | use_cuda = torch.cuda.is_available() 135 | if use_cuda: 136 | self.to(self.device) 137 | encoded = [] 138 | self.eval() 139 | num = X1.shape[0] 140 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size)) 141 | for batch_idx in range(num_batch): 142 | x1batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 143 | x2batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 144 | inputs1 = Variable(x1batch) 145 | inputs2 = Variable(x2batch) 146 | z,_,_,_,_,_,_,_,_ = self.forwardAE(inputs1, inputs2) 147 | encoded.append(z.data) 148 | 149 | encoded = torch.cat(encoded, dim=0) 150 | return encoded 151 | 152 | def kldloss(self, p, q): 153 | c1 = -torch.sum(p * torch.log(q), dim=-1) 154 | c2 = -torch.sum(p * torch.log(p), dim=-1) 155 | return torch.mean(c1 - c2) 156 | 157 | def pretrain_autoencoder(self, X1, X_raw1, sf1, X2, X_raw2, sf2, 158 | 
batch_size=256, lr=0.001, epochs=400, ae_save=True, ae_weights='AE_weights.pth.tar'): 159 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size)) 160 | dataset = TensorDataset(torch.Tensor(X1), torch.Tensor(X_raw1), torch.Tensor(sf1), torch.Tensor(X2), torch.Tensor(X_raw2), torch.Tensor(sf2)) 161 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 162 | print("Pretraining stage") 163 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True) 164 | num = X1.shape[0] 165 | for epoch in range(epochs): 166 | loss_val = 0 167 | recon_loss1_val = 0 168 | recon_loss2_val = 0 169 | kl_loss_val = 0 170 | for batch_idx, (x1_batch, x_raw1_batch, sf1_batch, x2_batch, x_raw2_batch, sf2_batch) in enumerate(dataloader): 171 | x1_tensor = Variable(x1_batch).to(self.device) 172 | x_raw1_tensor = Variable(x_raw1_batch).to(self.device) 173 | sf1_tensor = Variable(sf1_batch).to(self.device) 174 | x2_tensor = Variable(x2_batch).to(self.device) 175 | x_raw2_tensor = Variable(x_raw2_batch).to(self.device) 176 | sf2_tensor = Variable(sf2_batch).to(self.device) 177 | zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forwardAE(x1_tensor, x2_tensor) 178 | recon_loss1 = self.zinb_loss(x=x_raw1_tensor, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sf1_tensor) 179 | recon_loss2 = self.zinb_loss(x=x_raw2_tensor, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sf2_tensor) 180 | lpbatch = self.target_distribution(lqbatch) 181 | lqbatch = lqbatch + torch.diag(torch.diag(z_num)) 182 | lpbatch = lpbatch + torch.diag(torch.diag(z_num)) 183 | kl_loss = self.kldloss(lpbatch, lqbatch) 184 | if epoch+1 >= epochs * self.cutoff: 185 | loss = recon_loss1 + recon_loss2 + kl_loss * self.phi1 186 | else: 187 | loss = recon_loss1 + recon_loss2 188 | optimizer.zero_grad() 189 | loss.backward() 190 | optimizer.step() 191 | 192 | loss_val += loss.item() * len(x1_batch) 193 | recon_loss1_val += recon_loss1.item() * len(x1_batch) 194 | recon_loss2_val += recon_loss2.item() * len(x2_batch) 195 | if epoch+1 >= epochs * self.cutoff: 196 | kl_loss_val += kl_loss.item() * len(x1_batch) 197 | 198 | loss_val = loss_val/num 199 | recon_loss1_val = recon_loss1_val/num 200 | recon_loss2_val = recon_loss2_val/num 201 | kl_loss_val = kl_loss_val/num 202 | if epoch%self.t == 0: 203 | print('Pretrain epoch {}, Total loss:{:.6f}, ZINB loss1:{:.6f}, ZINB loss2:{:.6f}, KL loss:{:.6f}'.format(epoch+1, loss_val, recon_loss1_val, recon_loss2_val, kl_loss_val)) 204 | 205 | if ae_save: 206 | torch.save({'ae_state_dict': self.state_dict(), 207 | 'optimizer_state_dict': optimizer.state_dict()}, ae_weights) 208 | 209 | def save_checkpoint(self, state, index, filename): 210 | newfilename = os.path.join(filename, 'FTcheckpoint_%d.pth.tar' % index) 211 | torch.save(state, newfilename) 212 | 213 | def fit(self, X1, X_raw1, sf1, X2, X_raw2, sf2, y=None, lr=1., n_clusters = 4, 214 | batch_size=256, num_epochs=10, update_interval=1, tol=1e-3, save_dir=""): 215 | '''X: tensor data''' 216 | use_cuda = torch.cuda.is_available() 217 | if use_cuda: 218 | self.to(self.device) 219 | print("Clustering stage") 220 | X1 = torch.tensor(X1).to(self.device) 221 | X_raw1 = torch.tensor(X_raw1).to(self.device) 222 | sf1 = torch.tensor(sf1).to(self.device) 223 | X2 = torch.tensor(X2).to(self.device) 224 | X_raw2 = torch.tensor(X_raw2).to(self.device) 225 | sf2 = torch.tensor(sf2).to(self.device) 226 | self.mu = 
Parameter(torch.Tensor(n_clusters, self.z_dim), requires_grad=True) 227 | optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95) 228 | 229 | print("Initializing cluster centers with kmeans.") 230 | kmeans = KMeans(n_clusters, n_init=20) 231 | Zdata = self.encodeBatch(X1, X2, batch_size=batch_size) 232 | #latent 233 | self.y_pred = kmeans.fit_predict(Zdata.data.cpu().numpy()) 234 | self.y_pred_last = self.y_pred 235 | self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_)) 236 | 237 | if y is not None: 238 | ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5) 239 | nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5) 240 | ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5) 241 | print('Initializing k-means: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari)) 242 | 243 | self.train() 244 | num = X1.shape[0] 245 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size)) 246 | 247 | final_ami, final_nmi, final_ari, final_epoch = 0, 0, 0, 0 248 | 249 | for epoch in range(num_epochs): 250 | if epoch%update_interval == 0: 251 | Zdata = self.encodeBatch(X1, X2, batch_size=batch_size) 252 | dist, _ = self.kmeans_loss(Zdata) 253 | self.y_pred = torch.argmin(dist, dim=1).data.cpu().numpy() 254 | if y is not None: 255 | #acc2 = np.round(cluster_acc(y, self.y_pred), 5) 256 | final_ami = ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5) 257 | final_nmi = nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5) 258 | final_ari = ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5) 259 | final_epoch = epoch+1 260 | print('Clustering %d: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (epoch+1, ami, nmi, ari)) 261 | 262 | # check stop criterion 263 | delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num 264 | self.y_pred_last = self.y_pred 265 | if epoch>0 and delta_label < tol: 266 | print('delta_label ', delta_label, '< tol ', tol) 267 | print("Reach tolerance threshold. 
Stopping training.") 268 | break 269 | 270 | #save current model 271 | # if (epoch>0 and delta_label < tol) or epoch%10 == 0: 272 | # self.save_checkpoint({'epoch': epoch+1, 273 | # 'state_dict': self.state_dict(), 274 | # 'mu': self.mu, 275 | # 'y_pred': self.y_pred, 276 | # 'y_pred_last': self.y_pred_last, 277 | # 'y': y 278 | # }, epoch+1, filename=save_dir) 279 | 280 | # train 1 epoch for clustering loss 281 | loss_val = 0.0 282 | recon_loss1_val = 0.0 283 | recon_loss2_val = 0.0 284 | cluster_loss_val = 0.0 285 | kl_loss_val = 0.0 286 | for batch_idx in range(num_batch): 287 | x1_batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 288 | x_raw1_batch = X_raw1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 289 | sf1_batch = sf1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 290 | x2_batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 291 | x_raw2_batch = X_raw2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 292 | sf2_batch = sf2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 293 | 294 | inputs1 = Variable(x1_batch) 295 | rawinputs1 = Variable(x_raw1_batch) 296 | sfinputs1 = Variable(sf1_batch) 297 | inputs2 = Variable(x2_batch) 298 | rawinputs2 = Variable(x_raw2_batch) 299 | sfinputs2 = Variable(sf2_batch) 300 | 301 | zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forward(inputs1, inputs2) 302 | _, cluster_loss = self.kmeans_loss(zbatch) 303 | recon_loss1 = self.zinb_loss(x=rawinputs1, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sfinputs1) 304 | recon_loss2 = self.zinb_loss(x=rawinputs2, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sfinputs2) 305 | target2 = self.target_distribution(lqbatch) 306 | lqbatch = lqbatch + torch.diag(torch.diag(z_num)) 307 | target2 = target2 + torch.diag(torch.diag(z_num)) 308 | kl_loss = self.kldloss(target2, lqbatch) 309 | loss = recon_loss1 + recon_loss2 + kl_loss * self.phi2 + cluster_loss * self.gamma 310 | optimizer.zero_grad() 311 | loss.backward() 312 | # torch.nn.utils.clip_grad_norm_(self.mu, 1) 313 | optimizer.step() 314 | cluster_loss_val += cluster_loss.data * len(inputs1) 315 | recon_loss1_val += recon_loss1.data * len(inputs1) 316 | recon_loss2_val += recon_loss2.data * len(inputs2) 317 | kl_loss_val += kl_loss.data * len(inputs1) 318 | loss_val = cluster_loss_val + recon_loss1_val + recon_loss2_val + kl_loss_val 319 | 320 | if epoch%self.t == 0: 321 | print("#Epoch %d: Total: %.6f Clustering Loss: %.6f ZINB Loss1: %.6f ZINB Loss2: %.6f KL Loss: %.6f" % ( 322 | epoch + 1, loss_val / num, cluster_loss_val / num, recon_loss1_val / num, recon_loss2_val / num, kl_loss_val / num)) 323 | 324 | return self.y_pred, final_epoch 325 | -------------------------------------------------------------------------------- /src/scMDC_batch.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics.pairwise import paired_distances 2 | from sklearn.decomposition import PCA 3 | from sklearn import metrics 4 | from sklearn.cluster import KMeans 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.utils.data import DataLoader, TensorDataset 12 | from layers import NBLoss, ZINBLoss, MeanAct, DispAct 13 | import numpy as np 14 | 15 | import math, os 16 | 17 | from utils import 
torch_PCA 18 | 19 | from preprocess import read_dataset, normalize 20 | import scanpy as sc 21 | 22 | def buildNetwork1(layers, type, activation="relu"): 23 | net = [] 24 | for i in range(1, len(layers)): 25 | net.append(nn.Linear(layers[i-1], layers[i])) 26 | if type=="encode" and i==len(layers)-1: 27 | break 28 | if activation=="relu": 29 | net.append(nn.ReLU()) 30 | elif activation=="sigmoid": 31 | net.append(nn.Sigmoid()) 32 | elif activation=="elu": 33 | net.append(nn.ELU()) 34 | return nn.Sequential(*net) 35 | 36 | def buildNetwork2(layers, type, activation="relu"): 37 | net = [] 38 | for i in range(1, len(layers)): 39 | net.append(nn.Linear(layers[i-1], layers[i])) 40 | net.append(nn.BatchNorm1d(layers[i], affine=True)) 41 | if activation=="relu": 42 | net.append(nn.ReLU()) 43 | elif activation=="selu": 44 | net.append(nn.SELU()) 45 | elif activation=="sigmoid": 46 | net.append(nn.Sigmoid()) 47 | elif activation=="elu": 48 | net.append(nn.ELU()) 49 | return nn.Sequential(*net) 50 | 51 | class scMultiClusterBatch(nn.Module): 52 | def __init__(self, input_dim1, input_dim2, n_batch, 53 | encodeLayer=[], decodeLayer1=[], decodeLayer2=[], tau=1., t=10, device = "cuda", 54 | activation="elu", sigma1=2.5, sigma2=.1, alpha=1., gamma=1., phi1=0.0001, phi2=0.0001, cutoff = 0.5): 55 | super(scMultiClusterBatch, self).__init__() 56 | self.tau=tau 57 | self.input_dim1 = input_dim1 58 | self.input_dim2 = input_dim2 59 | self.cutoff = cutoff 60 | self.activation = activation 61 | self.sigma1 = sigma1 62 | self.sigma2 = sigma2 63 | self.alpha = alpha 64 | self.gamma = gamma 65 | self.phi1 = phi1 66 | self.phi2 = phi2 67 | self.t=t 68 | self.device = device 69 | self.encoder = buildNetwork2([input_dim1+input_dim2+n_batch]+encodeLayer, type="encode", activation=activation) 70 | self.decoder1 = buildNetwork2([decodeLayer1[0]+n_batch]+decodeLayer1[1:], type="decode", activation=activation) 71 | self.decoder2 = buildNetwork2([decodeLayer2[0]+n_batch]+decodeLayer2[1:], type="decode", activation=activation) 72 | self.dec_mean1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), MeanAct()) 73 | self.dec_disp1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), DispAct()) 74 | self.dec_mean2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), MeanAct()) 75 | self.dec_disp2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), DispAct()) 76 | self.dec_pi1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), nn.Sigmoid()) 77 | self.dec_pi2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), nn.Sigmoid()) 78 | self.zinb_loss = ZINBLoss() 79 | self.NBLoss = NBLoss() 80 | self.mse = nn.MSELoss() 81 | self.z_dim = encodeLayer[-1] 82 | 83 | def save_model(self, path): 84 | torch.save(self.state_dict(), path) 85 | 86 | def load_model(self, path): 87 | pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage) 88 | model_dict = self.state_dict() 89 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 90 | model_dict.update(pretrained_dict) 91 | self.load_state_dict(model_dict) 92 | 93 | def soft_assign(self, z): 94 | q = 1.0 / (1.0 + torch.sum((z.unsqueeze(1) - self.mu)**2, dim=2) / self.alpha) 95 | q = q**((self.alpha+1.0)/2.0) 96 | q = (q.t() / torch.sum(q, dim=1)).t() 97 | return q 98 | 99 | def cal_latent(self, z): 100 | sum_y = torch.sum(torch.square(z), dim=1) 101 | num = -2.0 * torch.matmul(z, z.t()) + torch.reshape(sum_y, [-1, 1]) + sum_y 102 | num = num / self.alpha 103 | num = torch.pow(1.0 + num, -(self.alpha + 1.0) / 2.0) 104 | 
zerodiag_num = num - torch.diag(torch.diag(num)) 105 | latent_p = (zerodiag_num.t() / torch.sum(zerodiag_num, dim=1)).t() 106 | return num, latent_p 107 | 108 | def target_distribution(self, q): 109 | p = q**2 / q.sum(0) 110 | return (p.t() / p.sum(1)).t() 111 | 112 | def kmeans_loss(self, z): 113 | dist1 = self.tau * torch.sum(torch.square(z.unsqueeze(1) - self.mu), dim=2) 114 | temp_dist1 = dist1 - torch.reshape(torch.mean(dist1, dim=1), [-1, 1]) 115 | q = torch.exp(-temp_dist1) 116 | q = (q.t() / torch.sum(q, dim=1)).t() 117 | q = torch.pow(q, 2) 118 | q = (q.t() / torch.sum(q, dim=1)).t() 119 | dist2 = dist1 * q 120 | return dist1, torch.mean(torch.sum(dist2, dim=1)) 121 | 122 | def forward(self, x1, x2, b): 123 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1) 124 | h = self.encoder(torch.cat([x, b], dim=-1)) 125 | h = torch.cat([h, b], dim=-1) 126 | 127 | h1 = self.decoder1(h) 128 | mean1 = self.dec_mean1(h1) 129 | disp1 = self.dec_disp1(h1) 130 | pi1 = self.dec_pi1(h1) 131 | 132 | h2 = self.decoder2(h) 133 | mean2 = self.dec_mean2(h2) 134 | disp2 = self.dec_disp2(h2) 135 | pi2 = self.dec_pi2(h2) 136 | 137 | x0 = torch.cat([x1, x2], dim=-1) 138 | h0 = self.encoder(torch.cat([x0, b], dim=-1)) 139 | q = self.soft_assign(h0) 140 | num, lq = self.cal_latent(h0) 141 | return h0, q, num, lq, mean1, mean2, disp1, disp2, pi1, pi2 142 | 143 | def forwardAE(self, x1, x2, b): 144 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1) 145 | h = self.encoder(torch.cat([x, b], dim=-1)) 146 | h = torch.cat([h, b], dim=-1) 147 | 148 | h1 = self.decoder1(h) 149 | mean1 = self.dec_mean1(h1) 150 | disp1 = self.dec_disp1(h1) 151 | pi1 = self.dec_pi1(h1) 152 | 153 | h2 = self.decoder2(h) 154 | mean2 = self.dec_mean2(h2) 155 | disp2 = self.dec_disp2(h2) 156 | pi2 = self.dec_pi2(h2) 157 | 158 | x0 = torch.cat([x1, x2], dim=-1) 159 | h0 = self.encoder(torch.cat([x0, b], dim=-1)) 160 | num, lq = self.cal_latent(h0) 161 | return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2 162 | 163 | def encodeBatch(self, X1, X2, B, batch_size=256): 164 | use_cuda = torch.cuda.is_available() 165 | if use_cuda: 166 | self.to(self.device) 167 | encoded = [] 168 | self.eval() 169 | num = X1.shape[0] 170 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size)) 171 | for batch_idx in range(num_batch): 172 | x1batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 173 | x2batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 174 | b_batch = B[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 175 | inputs1 = Variable(x1batch).to(self.device) 176 | inputs2 = Variable(x2batch).to(self.device) 177 | b_tensor = Variable(b_batch).to(self.device) 178 | z,_,_,_,_,_,_,_,_ = self.forwardAE(inputs1.float(), inputs2.float(), b_tensor.float()) 179 | encoded.append(z.data) 180 | 181 | encoded = torch.cat(encoded, dim=0) 182 | return encoded 183 | 184 | def cluster_loss(self, p, q): 185 | def kld(target, pred): 186 | return torch.mean(torch.sum(target*torch.log(target/(pred+1e-6)), dim=-1)) 187 | kldloss = kld(p, q) 188 | return kldloss 189 | 190 | def kldloss(self, p, q): 191 | c1 = -torch.sum(p * torch.log(q), dim=-1) 192 | c2 = -torch.sum(p * torch.log(p), dim=-1) 193 | return torch.mean(c1 - c2) 194 | 195 | def SDis_func(self, x, y): 196 | return torch.sum(torch.square(x - y), dim=1) 197 | 198 | def pretrain_autoencoder(self, X1, X_raw1, sf1, X2, X_raw2, sf2, B, 199 | batch_size=256, lr=0.001, 
epochs=400, ae_save=True, ae_weights='AE_weights.pth.tar'):
200 |         num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
201 |         dataset = TensorDataset(torch.Tensor(X1), torch.Tensor(X_raw1), torch.Tensor(sf1), torch.Tensor(X2), torch.Tensor(X_raw2), torch.Tensor(sf2), torch.Tensor(B))
202 |         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
203 |         print("Pretraining stage")
204 |         optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)
205 |         counts = 0
206 |         for epoch in range(epochs):
207 |             loss_val = 0
208 |             recon_loss1_val = 0
209 |             recon_loss2_val = 0
210 |             kl_loss_val = 0
211 |             for batch_idx, (x1_batch, x_raw1_batch, sf1_batch, x2_batch, x_raw2_batch, sf2_batch, b_batch) in enumerate(dataloader):
212 |                 x1_tensor = Variable(x1_batch).to(self.device)
213 |                 x_raw1_tensor = Variable(x_raw1_batch).to(self.device)
214 |                 sf1_tensor = Variable(sf1_batch).to(self.device)
215 |                 x2_tensor = Variable(x2_batch).to(self.device)
216 |                 x_raw2_tensor = Variable(x_raw2_batch).to(self.device)
217 |                 sf2_tensor = Variable(sf2_batch).to(self.device)
218 |                 b_tensor = Variable(b_batch).to(self.device)
219 |                 zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forwardAE(x1_tensor, x2_tensor, b_tensor)
220 |                 #recon_loss1 = self.mse(mean1_tensor, x1_tensor)
221 |                 recon_loss1 = self.zinb_loss(x=x_raw1_tensor, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sf1_tensor)
222 |                 #recon_loss2 = self.mse(mean2_tensor, x2_tensor)
223 |                 recon_loss2 = self.zinb_loss(x=x_raw2_tensor, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sf2_tensor)
224 |                 lpbatch = self.target_distribution(lqbatch)
225 |                 lqbatch = lqbatch + torch.diag(torch.diag(z_num))
226 |                 lpbatch = lpbatch + torch.diag(torch.diag(z_num))
227 |                 kl_loss = self.kldloss(lpbatch, lqbatch)
228 |                 if epoch+1 >= epochs * self.cutoff:
229 |                     loss = recon_loss1 + recon_loss2 + kl_loss * self.phi1
230 |                 else:
231 |                     loss = recon_loss1 + recon_loss2 #+ kl_loss
232 |                 optimizer.zero_grad()
233 |                 loss.backward()
234 |                 optimizer.step()
235 | 
236 |                 loss_val += loss.item() * len(x1_batch)
237 |                 recon_loss1_val += recon_loss1.item() * len(x1_batch)
238 |                 recon_loss2_val += recon_loss2.item() * len(x1_batch)
239 |                 if epoch+1 >= epochs * self.cutoff:
240 |                     kl_loss_val += kl_loss.item() * len(x1_batch)
241 | 
242 |             loss_val = loss_val/X1.shape[0]
243 |             recon_loss1_val = recon_loss1_val/X1.shape[0]
244 |             recon_loss2_val = recon_loss2_val/X1.shape[0]
245 |             kl_loss_val = kl_loss_val/X1.shape[0]
246 |             if epoch%self.t == 0:
247 |                 print('Pretrain epoch {}, Total loss:{:.6f}, ZINB loss1:{:.6f}, ZINB loss2:{:.6f}, KL loss:{:.6f}'.format(epoch+1, loss_val, recon_loss1_val, recon_loss2_val, kl_loss_val))
248 | 
249 |         if ae_save:
250 |             torch.save({'ae_state_dict': self.state_dict(),
251 |                         'optimizer_state_dict': optimizer.state_dict()}, ae_weights)
252 | 
253 |     def save_checkpoint(self, state, index, filename):
254 |         newfilename = os.path.join(filename, 'FTcheckpoint_%d.pth.tar' % index)
255 |         torch.save(state, newfilename)
256 | 
257 |     def fit(self, X1, X_raw1, sf1, X2, X_raw2, sf2, B, y=None, lr=1., n_clusters = 4,
258 |             batch_size=256, num_epochs=10, update_interval=1, tol=1e-3, save_dir=""):
259 |         '''X: tensor data'''
260 |         use_cuda = torch.cuda.is_available()
261 |         if use_cuda:
262 |             self.to(self.device)
263 |         print("Clustering stage")
264 |         X1 = torch.tensor(X1).to(self.device)
265 |         X_raw1 = torch.tensor(X_raw1).to(self.device)
266 |         sf1 =
torch.tensor(sf1).to(self.device) 267 | X2 = torch.tensor(X2).to(self.device) 268 | X_raw2 = torch.tensor(X_raw2).to(self.device) 269 | sf2 = torch.tensor(sf2).to(self.device) 270 | B = torch.tensor(B).to(self.device) 271 | self.mu = Parameter(torch.Tensor(n_clusters, self.z_dim), requires_grad=True) 272 | optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95) 273 | #optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=0.001) 274 | 275 | print("Initializing cluster centers with kmeans.") 276 | kmeans = KMeans(n_clusters, n_init=20) 277 | Zdata = self.encodeBatch(X1, X2, B, batch_size=batch_size) 278 | #latent 279 | self.y_pred = kmeans.fit_predict(Zdata.data.cpu().numpy()) 280 | self.y_pred_last = self.y_pred 281 | self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_)) 282 | if y is not None: 283 | ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5) 284 | nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5) 285 | ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5) 286 | print('Initializing k-means: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari)) 287 | 288 | self.train() 289 | num = X1.shape[0] 290 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size)) 291 | 292 | final_nmi, final_ari, final_epoch = 0, 0, 0 293 | 294 | for epoch in range(num_epochs): 295 | if epoch%update_interval == 0: 296 | # update the targe distribution p 297 | Zdata = self.encodeBatch(X1, X2, B, batch_size=batch_size) 298 | 299 | # evalute the clustering performance 300 | dist, _ = self.kmeans_loss(Zdata) 301 | self.y_pred = torch.argmin(dist, dim=1).data.cpu().numpy() 302 | 303 | if y is not None: 304 | #acc2 = np.round(cluster_acc(y, self.y_pred), 5) 305 | final_ami = ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5) 306 | final_nmi = nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5) 307 | final_ari = ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5) 308 | final_epoch = epoch+1 309 | print('Clustering %d: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (epoch+1, ami, nmi, ari)) 310 | 311 | # check stop criterion 312 | delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num 313 | self.y_pred_last = self.y_pred 314 | if epoch>0 and delta_label < tol: 315 | print('delta_label ', delta_label, '< tol ', tol) 316 | print("Reach tolerance threshold. 
Stopping training.") 317 | break 318 | 319 | # save current model 320 | # if (epoch>0 and delta_label < tol) or epoch%10 == 0: 321 | # self.save_checkpoint({'epoch': epoch+1, 322 | # 'state_dict': self.state_dict(), 323 | # 'mu': self.mu, 324 | # 'y_pred': self.y_pred, 325 | # 'y_pred_last': self.y_pred_last, 326 | # 'y': y 327 | # }, epoch+1, filename=save_dir) 328 | 329 | # train 1 epoch for clustering loss 330 | train_loss = 0.0 331 | recon_loss1_val = 0.0 332 | recon_loss2_val = 0.0 333 | recon_loss_latent_val = 0.0 334 | cluster_loss_val = 0.0 335 | kl_loss_val = 0.0 336 | for batch_idx in range(num_batch): 337 | x1_batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 338 | x_raw1_batch = X_raw1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 339 | sf1_batch = sf1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 340 | x2_batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 341 | x_raw2_batch = X_raw2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 342 | sf2_batch = sf2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 343 | b_batch = B[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)] 344 | optimizer.zero_grad() 345 | inputs1 = Variable(x1_batch) 346 | rawinputs1 = Variable(x_raw1_batch) 347 | sfinputs1 = Variable(sf1_batch) 348 | inputs2 = Variable(x2_batch) 349 | rawinputs2 = Variable(x_raw2_batch) 350 | sfinputs2 = Variable(sf2_batch) 351 | 352 | zbatch, qbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forward(inputs1.float(), inputs2.float(), b_batch.float()) 353 | 354 | _, cluster_loss = self.kmeans_loss(zbatch) 355 | recon_loss1 = self.zinb_loss(x=rawinputs1, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sfinputs1) 356 | recon_loss2 = self.zinb_loss(x=rawinputs2, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sfinputs2) 357 | target2 = self.target_distribution(lqbatch) 358 | lqbatch = lqbatch + torch.diag(torch.diag(z_num)) 359 | target2 = target2 + torch.diag(torch.diag(z_num)) 360 | kl_loss = self.kldloss(target2, lqbatch) 361 | loss = cluster_loss * self.gamma + kl_loss * self.phi2 + recon_loss1 + recon_loss2 362 | loss.backward() 363 | torch.nn.utils.clip_grad_norm_(self.mu, 1) 364 | optimizer.step() 365 | cluster_loss_val += cluster_loss.data * len(inputs1) 366 | recon_loss1_val += recon_loss1.data * len(inputs1) 367 | recon_loss2_val += recon_loss2.data * len(inputs2) 368 | kl_loss_val += kl_loss.data * len(inputs1) 369 | loss_val = cluster_loss_val + recon_loss1_val + recon_loss2_val + kl_loss_val 370 | 371 | if epoch%self.t == 0: 372 | print("#Epoch %d: Total: %.6f Clustering Loss: %.6f ZINB Loss: %.6f ZINB Loss2: %.6f KL Loss: %.6f" % ( 373 | epoch + 1, loss_val / num, cluster_loss_val / num, recon_loss1_val / num, recon_loss2_val / num, kl_loss_val / num)) 374 | 375 | return self.y_pred, final_epoch 376 | -------------------------------------------------------------------------------- /src/tree.txt: -------------------------------------------------------------------------------- 1 | ((A:3,H:3):3,(B:2,C:2):4,(D:1,E:1):5,(H:4,G:4):2); 2 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import scanpy as sc 5 | from scipy import stats, spatial, sparse 6 | from scipy.linalg import norm 7 | from sklearn.metrics.pairwise 
import euclidean_distances 8 | import numpy as np 9 | import random 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | import torch.utils.data as data 14 | from sklearn.neighbors import kneighbors_graph 15 | 16 | def cluster_acc(y_true, y_pred): 17 | """ 18 | Calculate clustering accuracy. Require scikit-learn installed 19 | # Arguments 20 | y: true labels, numpy.array with shape `(n_samples,)` 21 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 22 | # Return 23 | accuracy, in [0,1] 24 | """ 25 | y_true = y_true.astype(np.int64) 26 | assert y_pred.size == y_true.size 27 | D = max(y_pred.max(), y_true.max()) + 1 28 | w = np.zeros((D, D), dtype=np.int64) 29 | for i in range(y_pred.size): 30 | w[y_pred[i], y_true[i]] += 1 31 | from sklearn.utils.linear_assignment_ import linear_assignment 32 | ind = linear_assignment(w.max() - w) 33 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 34 | 35 | def GetCluster(X, res, n): 36 | adata0=sc.AnnData(X) 37 | if adata0.shape[0]>200000: 38 | np.random.seed(adata0.shape[0])#set seed 39 | adata0=adata0[np.random.choice(adata0.shape[0],200000,replace=False)] 40 | sc.pp.neighbors(adata0, n_neighbors=n, use_rep="X") 41 | sc.tl.louvain(adata0,resolution=res) 42 | Y_pred_init=adata0.obs['louvain'] 43 | Y_pred_init=np.asarray(Y_pred_init,dtype=int) 44 | if np.unique(Y_pred_init).shape[0]<=1: 45 | #avoid only a cluster 46 | exit("Error: There is only a cluster detected. The resolution:"+str(res)+"is too small, please choose a larger resolution!!") 47 | else: 48 | print("Estimated n_clusters is: ", np.shape(np.unique(Y_pred_init))[0]) 49 | return(np.shape(np.unique(Y_pred_init))[0]) 50 | 51 | def torch_PCA(X, k, center=True, scale=False): 52 | X = X.t() 53 | n,p = X.size() 54 | ones = torch.ones(n).cuda().view([n,1]) 55 | h = ((1/n) * torch.mm(ones, ones.t())) if center else torch.zeros(n*n).view([n,n]) 56 | H = torch.eye(n).cuda() - h 57 | X_center = torch.mm(H.double(), X.double()) 58 | covariance = 1/(n-1) * torch.mm(X_center.t(), X_center).view(p,p) 59 | scaling = torch.sqrt(1/torch.diag(covariance)).double() if scale else torch.ones(p).cuda().double() 60 | scaled_covariance = torch.mm(torch.diag(scaling).view(p,p), covariance) 61 | eigenvalues, eigenvectors = torch.eig(scaled_covariance, True) 62 | components = (eigenvectors[:, :k]) 63 | #explained_variance = eigenvalues[:k, 0] 64 | return components 65 | 66 | def best_map(L1,L2): 67 | #L1 should be the groundtruth labels and L2 should be the clustering labels we got 68 | Label1 = np.unique(L1) 69 | nClass1 = len(Label1) 70 | Label2 = np.unique(L2) 71 | nClass2 = len(Label2) 72 | nClass = np.maximum(nClass1,nClass2) 73 | G = np.zeros((nClass,nClass)) 74 | for i in range(nClass1): 75 | ind_cla1 = L1 == Label1[i] 76 | ind_cla1 = ind_cla1.astype(float) 77 | for j in range(nClass2): 78 | ind_cla2 = L2 == Label2[j] 79 | ind_cla2 = ind_cla2.astype(float) 80 | G[i,j] = np.sum(ind_cla2 * ind_cla1) 81 | m = Munkres() 82 | index = m.compute(-G.T) 83 | index = np.array(index) 84 | c = index[:,1] 85 | newL2 = np.zeros(L2.shape) 86 | for i in range(nClass2): 87 | newL2[L2 == Label2[i]] = Label1[c[i]] 88 | return newL2 89 | 90 | def geneSelection(data, threshold=0, atleast=10, 91 | yoffset=.02, xoffset=5, decay=1.5, n=None, 92 | plot=True, markers=None, genes=None, figsize=(6,3.5), 93 | markeroffsets=None, labelsize=10, alpha=1, verbose=1): 94 | 95 | if sparse.issparse(data): 96 | zeroRate = 1 - np.squeeze(np.array((data>threshold).mean(axis=0))) 97 | A = 
data.multiply(data>threshold) 98 | A.data = np.log2(A.data) 99 | meanExpr = np.zeros_like(zeroRate) * np.nan 100 | detected = zeroRate < 1 101 | meanExpr[detected] = np.squeeze(np.array(A[:,detected].mean(axis=0))) / (1-zeroRate[detected]) 102 | else: 103 | zeroRate = 1 - np.mean(data>threshold, axis=0) 104 | meanExpr = np.zeros_like(zeroRate) * np.nan 105 | detected = zeroRate < 1 106 | mask = data[:,detected]>threshold 107 | logs = np.zeros_like(data[:,detected]) * np.nan 108 | logs[mask] = np.log2(data[:,detected][mask]) 109 | meanExpr[detected] = np.nanmean(logs, axis=0) 110 | 111 | lowDetection = np.array(np.sum(data>threshold, axis=0)).squeeze() < atleast 112 | zeroRate[lowDetection] = np.nan 113 | meanExpr[lowDetection] = np.nan 114 | 115 | if n is not None: 116 | up = 10 117 | low = 0 118 | for t in range(100): 119 | nonan = ~np.isnan(zeroRate) 120 | selected = np.zeros_like(zeroRate).astype(bool) 121 | selected[nonan] = zeroRate[nonan] > np.exp(-decay*(meanExpr[nonan] - xoffset)) + yoffset 122 | if np.sum(selected) == n: 123 | break 124 | elif np.sum(selected) < n: 125 | up = xoffset 126 | xoffset = (xoffset + low)/2 127 | else: 128 | low = xoffset 129 | xoffset = (xoffset + up)/2 130 | if verbose>0: 131 | print('Chosen offset: {:.2f}'.format(xoffset)) 132 | else: 133 | nonan = ~np.isnan(zeroRate) 134 | selected = np.zeros_like(zeroRate).astype(bool) 135 | selected[nonan] = zeroRate[nonan] > np.exp(-decay*(meanExpr[nonan] - xoffset)) + yoffset 136 | 137 | if plot: 138 | if figsize is not None: 139 | plt.figure(figsize=figsize) 140 | plt.ylim([0, 1]) 141 | if threshold>0: 142 | plt.xlim([np.log2(threshold), np.ceil(np.nanmax(meanExpr))]) 143 | else: 144 | plt.xlim([0, np.ceil(np.nanmax(meanExpr))]) 145 | x = np.arange(plt.xlim()[0], plt.xlim()[1]+.1,.1) 146 | y = np.exp(-decay*(x - xoffset)) + yoffset 147 | if decay==1: 148 | plt.text(.4, 0.2, '{} genes selected\ny = exp(-x+{:.2f})+{:.2f}'.format(np.sum(selected),xoffset, yoffset), 149 | color='k', fontsize=labelsize, transform=plt.gca().transAxes) 150 | else: 151 | plt.text(.4, 0.2, '{} genes selected\ny = exp(-{:.1f}*(x-{:.2f}))+{:.2f}'.format(np.sum(selected),decay,xoffset, yoffset), 152 | color='k', fontsize=labelsize, transform=plt.gca().transAxes) 153 | 154 | plt.plot(x, y, color=sns.color_palette()[1], linewidth=2) 155 | xy = np.concatenate((np.concatenate((x[:,None],y[:,None]),axis=1), np.array([[plt.xlim()[1], 1]]))) 156 | t = plt.matplotlib.patches.Polygon(xy, color=sns.color_palette()[1], alpha=.4) 157 | plt.gca().add_patch(t) 158 | 159 | plt.scatter(meanExpr, zeroRate, s=1, alpha=alpha, rasterized=True) 160 | if threshold==0: 161 | plt.xlabel('Mean log2 nonzero expression') 162 | plt.ylabel('Frequency of zero expression') 163 | else: 164 | plt.xlabel('Mean log2 nonzero expression') 165 | plt.ylabel('Frequency of near-zero expression') 166 | plt.tight_layout() 167 | 168 | if markers is not None and genes is not None: 169 | if markeroffsets is None: 170 | markeroffsets = [(0, 0) for g in markers] 171 | for num,g in enumerate(markers): 172 | i = np.where(genes==g)[0] 173 | plt.scatter(meanExpr[i], zeroRate[i], s=10, color='k') 174 | dx, dy = markeroffsets[num] 175 | plt.text(meanExpr[i]+dx+.1, zeroRate[i]+dy, g, color='k', fontsize=labelsize) 176 | 177 | return selected 178 | --------------------------------------------------------------------------------
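
The scripts above are driven entirely by argparse. For orientation only, the following is a minimal, hypothetical sketch (not part of the repository) of how the modules fit together when called directly, mirroring run_scMDC.py; the input file name, selected gene count, layer sizes, epoch counts and n_clusters are placeholders. It keeps plot=False in geneSelection and does not call best_map, since the plotting branch and the label-matching code in utils.py additionally rely on matplotlib, seaborn and munkres, which utils.py itself does not import.

import h5py
import numpy as np
import torch
import scanpy as sc
from preprocess import read_dataset, normalize
from scMDC import scMultiCluster
from utils import geneSelection

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a two-modality .h5 file with count matrices X1/X2 and labels Y (placeholder file name).
with h5py.File("CITEseq_example.h5", "r") as f:
    x1, x2, y = np.array(f["X1"]), np.array(f["X2"]), np.array(f["Y"])

# Optional mRNA feature selection (placeholder number of genes).
x1 = x1[:, geneSelection(x1, n=2000, plot=False)]

def prep(x):
    # Same preprocessing as the run scripts: size factors, input scaling, log-transform.
    a = read_dataset(sc.AnnData(x), transpose=False, test_split=False, copy=True)
    return normalize(a, size_factors=True, normalize_input=True, logtrans_input=True)

adata1, adata2 = prep(x1), prep(x2)

model = scMultiCluster(input_dim1=adata1.n_vars, input_dim2=adata2.n_vars,
                       encodeLayer=[256, 64, 32, 16], decodeLayer1=[16, 64, 256],
                       decodeLayer2=[16, 20], sigma1=2.5, sigma2=1.5, gamma=0.1,
                       cutoff=0.5, phi1=0.001, phi2=0.001, device=device).to(device)

# Stage 1: ZINB autoencoder pretraining; stage 2: clustering with the
# K-means-style cluster loss plus KL regularization on the latent space.
model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                           X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors,
                           batch_size=256, epochs=400, ae_weights="AE_weights.pth.tar")
y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
                      X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors,
                      y=y, n_clusters=10, batch_size=256, num_epochs=500, save_dir="results/")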