├── LICENSE.md
├── README.md
├── datasets
│   └── readme.md
├── script
│   ├── README.md
│   ├── run_LRP_script.sh
│   └── run_scMDC_script.sh
└── src
    ├── LRP.py
    ├── Simulation.R
    ├── fig1_.png
    ├── layers.py
    ├── preprocess.py
    ├── run_LRP.py
    ├── run_scMDC.py
    ├── run_scMDC_batch.py
    ├── scMDC.py
    ├── scMDC_batch.py
    ├── tree.txt
    └── utils.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright Xiang Lin
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scMDC
2 | Single Cell Multi-omics deep clustering (**scMDC v1.0.1**)
3 |
4 | We develop a novel multimodal deep learning method, scMDC, for single-cell multi-omics data clustering analysis. scMDC is an end-to-end deep model that explicitly characterizes different data sources and jointly learns latent features of deep embedding for clustering analysis. Extensive simulation and real-data experiments reveal that scMDC outperforms existing single-cell single-modal and multimodal clustering methods on different single-cell multimodal datasets. The linear scalability of running time makes scMDC a promising method for analyzing large multimodal datasets.
5 |
6 | ## Table of contents
7 | - [Network diagram](#network-diagram)
8 | - [Dependencies](#dependencies)
9 | - [Usage](#usage)
10 | - [Output](#output)
11 | - [Arguments](#arguments)
12 | - [Citation](#citation)
13 | - [Contact](#contact)
14 |
15 | ## Network diagram
16 | 
17 |
18 | ## Dependencies
19 | Python 3.8.1
20 |
21 | Pytorch 1.6.0
22 |
23 | Scanpy 1.6.0
24 |
25 | SKlearn 0.22.1
26 |
27 | Numpy 1.18.1
28 |
29 | h5py 2.9.0
30 |
31 | All experiments of scMDC in this study were conducted on an Nvidia Tesla P100 (16 GB) GPU.
32 | We suggest installing the dependencies in a conda environment (conda create -n scMDC).
33 | Installing the dependencies takes a few minutes.
34 | scMDC takes about 3 minutes to cluster a dataset with 5000 cells.
35 |
36 | ## Usage
37 | 1) Prepare the input data in h5 format (see the readme in the 'datasets' folder; a minimal sketch of building such a file is shown below).
38 | 2) Run scMDC according to the running script in the "script" folder (note the parameter settings if you work on mRNA+ATAC data, and use run_scMDC_batch.py for multi-batch data clustering).
39 | 3) Run DE analysis with run_LRP.py based on the well-trained scMDC model (refer to the LRP running script in the "script" folder).
40 |
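For step 1, here is a minimal sketch (not a script from this repository) of assembling an input file with the required objects using h5py; the count matrices, labels, and the file name my_dataset.h5 are placeholders.

```python
import h5py
import numpy as np

# Placeholder matrices: rows are cells, columns are features,
# matching how run_scMDC.py reads X1 and X2.
x1 = np.random.poisson(1.0, size=(1000, 2000))  # mRNA counts (cells x genes)
x2 = np.random.poisson(1.0, size=(1000, 25))    # ADT counts (cells x proteins)
y = np.random.randint(0, 8, size=1000)          # true labels (optional)
batch = np.random.randint(1, 3, size=1000)      # batch indicator (only needed for run_scMDC_batch.py)

with h5py.File("my_dataset.h5", "w") as f:
    f.create_dataset("X1", data=x1)
    f.create_dataset("X2", data=x2)
    f.create_dataset("Y", data=y)
    f.create_dataset("Batch", data=batch)
```
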
41 | ## Output
42 | 1) scMDC outputs a latent representation of the data, which can be used for further downstream analyses and visualized by t-SNE or UMAP (see the sketch below);
43 | 2) Multi-batch scMDC outputs a latent representation of the integrated datasets on which the batch effects are corrected;
44 | 3) LRP outputs a gene ranking that indicates the importance of each gene for a given cluster and can be used for pathway analysis.
45 |
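As an illustration (assuming the default --save_dir results/ and run index 1, so the outputs are named 1_embedding.csv and 1_pred.csv), the saved latent representation can be visualized with scanpy roughly as follows:

```python
import numpy as np
import pandas as pd
import scanpy as sc

# Load the latent representation and cluster assignments saved by run_scMDC.py.
z = np.loadtxt("results/1_embedding.csv", delimiter=",")
pred = np.loadtxt("results/1_pred.csv", delimiter=",").astype(int)

adata = sc.AnnData(z)
adata.obs["cluster"] = pd.Categorical(pred.astype(str))
sc.pp.neighbors(adata, use_rep="X")  # kNN graph built directly on the latent space
sc.tl.umap(adata)
sc.pl.umap(adata, color="cluster")
```
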
46 | ## Arguments
47 | - --n_clusters: number of clusters (K); scMDC will estimate K if this argument is set to -1 (a rough sketch of such an estimate is shown after this list).
48 | - --cutoff: the fraction of epochs before which the model trains only the low-level autoencoders.
49 | - --batch_size: batch size.
50 | - --data_file: path to the input data.
51 |   - Data format: H5.
52 |   - Structure: X1 (RNA), X2 (ADT or ATAC), Y (labels, if they exist), Batch (batch indicator for multi-batch data clustering).
53 | - --maxiter: maximum number of training epochs. Default: 10000.
54 | - --pretrain_epochs: number of epochs for pre-training. Default: 400.
55 | - --gamma: coefficient of the clustering loss. Default: 0.1.
56 | - --phi1 and --phi2: coefficients of the KL loss in the pretraining and clustering stages. Default: 0.001 for CITE-Seq; 0.005 for SMAGE-Seq*.
57 | - --update_interval: the interval at which performance is checked. Default: 1.
58 | - --tol: the stopping criterion, expressed as the fraction of changed labels. Default: 0.001.
59 | - --ae_weights: path to the weight file.
60 | - --save_dir: the directory in which to store the outputs.
61 | - --ae_weight_file: the file in which to store the autoencoder weights.
62 | - --resolution: the resolution parameter used to estimate K. Default: 0.2.
63 | - --n_neighbors: the n_neighbors parameter used to estimate K. Default: 30.
64 | - --embedding_file: whether to save the embedding file. Default: No.
65 | - --prediction_file: whether to save the prediction file. Default: No.
66 | - --encodeLayer: layers of the encoder. Default: [256,64,32,16] for CITE-Seq; [256,128,64] for SMAGE-seq.
67 | - --decodeLayer1: layers of the decoder for the RNA data. Default: [16,64,256] for CITE-Seq; [64,128,256] for SMAGE-seq.
68 | - --decodeLayer2: layers of the decoder for the ADT/ATAC data. Default: [16,20] for CITE-Seq; [64,128,256] for SMAGE-seq.
69 | - --sigma1: noise added to the RNA data. Default: 2.5.
70 | - --sigma2: noise added to the ADT data. Default: 1.5 for CITE-Seq; 2.5 for SMAGE-Seq.
71 | - --filter1: whether to perform feature selection on genes. Default: No.
72 | - --filter2: whether to perform feature selection on ATAC. Default: No.
73 | - --f1: number of highly variable genes (in X1) used for clustering when feature selection is enabled. Default: 2000.
74 | - --f2: number of highly variable genes from ATAC (in X2) used for clustering when feature selection is enabled. Default: 2000.
75 | *We denote the 10X Single-Cell Multiome ATAC + Gene Expression technology as SMAGE-seq for convenience.
76 |
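As referenced above for --n_clusters, estimating K from the latent space is essentially graph-based community detection controlled by --resolution and --n_neighbors. The sketch below only illustrates that idea with scanpy's Louvain clustering (it requires the python-igraph and louvain packages); it is not the repository's utils.GetCluster implementation.

```python
import numpy as np
import scanpy as sc

def estimate_k(latent, resolution=0.2, n_neighbors=30):
    """Rough illustration: cluster the latent space with Louvain and count the clusters."""
    adata = sc.AnnData(latent)
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep="X")
    sc.tl.louvain(adata, resolution=resolution)
    return adata.obs["louvain"].nunique()

# Example: estimate K from a saved embedding.
# k = estimate_k(np.loadtxt("results/1_embedding.csv", delimiter=","))
```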
77 |
78 | ## Citation
79 | Lin, X., Tian, T., Wei, Z., & Hakonarson, H. (2022). Clustering of single-cell multi-omics data with a multimodal deep learning method. Nature Communications, 13(1), 1-18.
80 |
81 | ## Contact
82 | Xiang Lin
83 |
--------------------------------------------------------------------------------
/datasets/readme.md:
--------------------------------------------------------------------------------
1 | # Example Real datasets
2 | Example datasets are in h5 format and can be downloaded from https://drive.google.com/drive/folders/1HTCrb3HZpbN8Nte5xGhY517NMdyzqftw?usp=sharing
3 |
4 | Required objects in the h5 file for running scMDC (a minimal reading sketch is shown at the end of this readme):
5 | 1) X1: mRNA count matrix
6 | 2) X2: ADT count matrix
7 | 3) Y: true labels (if they exist)
8 | 4) Batch: batch indicator (for multi-batch analysis)
9 |
10 | Other objects in the h5 files:
11 | 1) ADT: feature names in the ADT count matrix (only in CITE-seq data)
12 | 2) GenesFromPeaks: feature names in the gene-to-cell matrix mapped from scATAC-seq (only in SMAGE-seq data)
13 | 3) Genes: feature names in the mRNA count matrix
14 | 4) Cell types: cell type of each cell (if available)
15 | 5) Barcodes: cell barcodes (if available)
16 |
17 | Note: for the filtered datasets (Normalized_filtered*), use X1_ and X2_ as inputs.
18 |
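A minimal sketch of inspecting one of the example files with h5py (the file name below is the CITE-seq BMNC dataset referenced in the running scripts; adjust the path to wherever the download is stored):

```python
import h5py

with h5py.File("CITESeq_GSE128639_BMNC_annodata.h5", "r") as f:
    print(list(f.keys()))   # e.g. X1, X2, Y, ADT, Genes, ...
    x1 = f["X1"][:]         # mRNA counts (cells x genes)
    x2 = f["X2"][:]         # ADT counts (cells x proteins)
    print(x1.shape, x2.shape)
```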
--------------------------------------------------------------------------------
/script/README.md:
--------------------------------------------------------------------------------
1 | # Note
2 | To use the full-feature datasets, use X1 and X2 as inputs and turn on the filtering function (--filter1 and/or --filter2).
3 | To use the filtered datasets (Normalized_filtered*), use X1_ and X2_ as inputs and turn off the filtering function.
4 |
--------------------------------------------------------------------------------
/script/run_LRP_script.sh:
--------------------------------------------------------------------------------
1 | # Run DE analysis (LRP) based on the well-trained scMDC model and its results.
2 | f=../datasets/CITESeq_GSE128639_BMNC_annodata.h5
3 | python -u run_scMDC.py --n_clusters 27 --ae_weight_file ./out_bmnc_full/AE_weights_bmnc.pth.tar --data_file $f --prediction_file --save_dir out_bmnc_full --filter1
4 | python -u run_LRP.py --n_clusters 27 --ae_weights ./out_bmnc_full/AE_weights_bmnc.pth.tar --cluster_index_file ./out_bmnc_full/1_pred.csv --data_file $f --save_dir out_bmnc_full --filter1
5 |
--------------------------------------------------------------------------------
/script/run_scMDC_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | #SBATCH --gres=gpu:1
3 |
4 | #Here are the commands for the real data experiments. We test ten times on each dataset.
5 |
6 | f=../datasets/CITESeq_GSE128639_BMNC_annodata.h5
7 | echo "Run CITE-seq BMNC"
8 | python -u run_scMDC.py --n_clusters 27 --ae_weight_file AE_weights_bmnc.pth.tar --data_file $f --save_dir citeseq_bmnc \
9 | --embedding_file --prediction_file --filter1
10 |
11 | f=../datasets/SMAGESeq_10X_pbmc_granulocyte_plus.h5
12 | echo "Run SMAGE-seq PBMC10K"
13 | python -u run_scMDC.py --n_clusters 12 --ae_weight_file AE_weights_pbmc10k.pth.tar --data_file $f --save_dir atac_pbmc10k \
14 | --embedding_file --prediction_file --filter1 --filter2 --f1 2000 --f2 2000 -el 256 128 64 -dl1 64 128 256 -dl2 64 128 256 --phi1 0.005 --phi2 0.005 --sigma2 2.5 --tau .1
15 |
16 | f=../datasets/CITESeq_realdata_spleen_lymph_111_anno_multiBatch.h5
17 | echo "Run multi-batch CITE-seq SLN111"
18 | python -u run_scMDC_batch.py --n_clusters 35 --ae_weight_file AE_weights_sln111.pth.tar --data_file $f --save_dir citeseq_sln111 \
19 | --embedding_file --prediction_file --filter1
20 |
--------------------------------------------------------------------------------
/src/LRP.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from torch.functional import norm
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | from torch.nn import Parameter
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torch.optim.lr_scheduler import *
10 | from torch.utils.data import DataLoader, TensorDataset
11 | from torch.nn.utils import clip_grad_norm_
12 | import numpy as np
13 | import math, os
14 |
15 |
16 | class ClustDistLayer(nn.Module):
17 | def __init__(self, centroids, n_clusters, clust_list, device):
18 | super(ClustDistLayer, self).__init__()
19 | self.centroids = Variable(centroids).to(device)
20 | self.n_clusters = n_clusters
21 | self.clust_list = clust_list
22 |
23 | def forward(self, x, curr_clust_id):
24 | output = []
25 | for i in self.clust_list:
26 | if i==curr_clust_id:
27 | continue
28 | weight = 2 * (self.centroids[self.clust_list.index(curr_clust_id)] - self.centroids[self.clust_list.index(i)])
29 | bias = torch.norm(self.centroids[self.clust_list.index(curr_clust_id)], p=2) - torch.norm(self.centroids[self.clust_list.index(i)], p=2)
30 | h = torch.matmul(x, weight.T) + bias
31 | output.append(h.unsqueeze(1))
32 |
33 | return torch.cat(output, dim=1)
34 |
35 |
36 | class ClustMinPoolLayer(nn.Module):
37 | def __init__(self, beta):
38 | super(ClustMinPoolLayer, self).__init__()
39 | self.beta = beta
40 | self.eps = 1e-10
41 |
42 | def forward(self, inputs):
43 | return - torch.log(torch.sum(torch.exp(inputs * -self.beta), dim=1) + self.eps)
44 |
45 |
46 | class LRP(nn.Module):
47 | def __init__(self, model, X1, X2, Z, clust_ids, n_clusters, beta=1., device="cuda"):
48 | super(LRP, self).__init__()
49 | #model.freeze_model()
50 | self.model = model
51 | self.clust_ids = clust_ids
52 | self.n_clusters = n_clusters
53 | self.clust_list = np.unique(clust_ids).astype(int).tolist()
54 | self.centroids_ =torch.tensor(self.set_centroids(Z), dtype=torch.float32)
55 | self.X1_ = torch.tensor(X1, dtype=torch.float32)
56 | self.X2_ = torch.tensor(X2, dtype=torch.float32)
57 | self.Z_ = torch.tensor(Z, dtype=torch.float32)
58 | self.distLayer = ClustDistLayer(self.centroids_, n_clusters, self.clust_list, device).to(device)
59 | self.clustMinPool = ClustMinPoolLayer(beta).to(device)
60 | self.device = device
61 |
62 | def set_centroids(self, Z):
63 | centroids = []
64 | for i in self.clust_list:
65 | clust_Z = Z[self.clust_ids==i]
66 | curr_centroid = np.mean(clust_Z, axis=0)
67 | centroids.append(curr_centroid)
68 |
69 | return np.stack(centroids, axis=0)
70 |
71 | def clust_minpoolAct(self, X1, X2, curr_clust_id):
72 | z,_,_,_,_,_,_,_,_ = self.model.forwardAE(X1, X2)
73 | return self.clustMinPool(self.distLayer(z, curr_clust_id))
74 |
75 | def calc_carlini_wagner_one_vs_one(self, clust_c_id, clust_k_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True):
76 | X1_0 = Variable(self.X1_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device)
77 | curr_X1 = Variable(X1_0 + 1e-6, requires_grad=True).to(self.device)
78 | X2_0 = Variable(self.X2_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device)
79 | curr_X2 = Variable(X2_0 + 1e-6, requires_grad=True).to(self.device)
80 | optimizer = optim.SGD([curr_X1, curr_X2], lr=lr)
81 |
82 | for iter in range(max_iter):
83 | clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id)
84 | clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id)
85 | clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_k_minpoolAct_tensor
86 | clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor))
87 | clust_loss = torch.sum(clust_loss_tensor)
88 |
89 | norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1)
90 |
91 | loss = clust_loss * lamda + norm_loss
92 |
93 | optimizer.zero_grad()
94 | loss.backward()
95 | optimizer.step()
96 |
97 | if (iter+1) % 50 == 0:
98 | print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(iter, loss.item(), clust_loss.item(), norm_loss.item()))
99 |
100 | if use_abs:
101 | rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0)
102 | rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0)
103 | else:
104 | rel_score1 = torch.mean(curr_X1 - X1_0, dim=0)
105 | rel_score2 = torch.mean(curr_X2 - X2_0, dim=0)
106 | return rel_score1.data.cpu().numpy(), rel_score2.data.cpu().numpy()
107 |
108 | def calc_carlini_wagner_one_vs_rest(self, clust_c_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True):
109 | X1_0 = Variable(self.X1_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device)
110 | curr_X1 = Variable(X1_0 + 1e-6, requires_grad=True).to(self.device)
111 | X2_0 = Variable(self.X2_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device)
112 | curr_X2 = Variable(X2_0 + 1e-6, requires_grad=True).to(self.device)
113 | optimizer = optim.SGD([curr_X1, curr_X2], lr=lr)
114 |
115 | for iter in range(max_iter):
116 | clust_rest_minpoolAct_tensor_list = []
117 | for clust_k_id in self.clust_list:
118 | if clust_k_id == clust_c_id:
119 | continue
120 | clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id)
121 | clust_rest_minpoolAct_tensor_list.append(clust_k_minpoolAct_tensor)
122 | clust_rest_minpoolAct_tensor = clust_rest_minpoolAct_tensor_list[0]
123 | for clust_k_id in range(1, len(clust_rest_minpoolAct_tensor_list)):
124 | clust_rest_minpoolAct_tensor = torch.maximum(clust_rest_minpoolAct_tensor, clust_rest_minpoolAct_tensor_list[clust_k_id])
125 |
126 | clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id)
127 |
128 | clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_rest_minpoolAct_tensor
129 | clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor))
130 | clust_loss = torch.sum(clust_loss_tensor)
131 |
132 | norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1)
133 |
134 | loss = clust_loss * lamda + norm_loss
135 |
136 | optimizer.zero_grad()
137 | loss.backward()
138 | optimizer.step()
139 |
140 | if (iter+1) % 50 == 0:
141 | print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(iter, loss.item(), clust_loss.item(), norm_loss.item()))
142 |
143 | if use_abs:
144 | rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0)
145 | rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0)
146 | else:
147 | rel_score1 = torch.mean(curr_X1 - X1_0, dim=0)
148 | rel_score2 = torch.mean(curr_X2 - X2_0, dim=0)
149 | return rel_score1.data.cpu().numpy(), rel_score2.data.cpu().numpy()
150 |
--------------------------------------------------------------------------------
/src/Simulation.R:
--------------------------------------------------------------------------------
1 | library(SymSim)
2 | library(rhdf5)
3 | library(Seurat)
4 |
5 | phyla <- read.tree("tree.txt")
6 | phyla2 <- read.tree("tree.txt")
7 | data(gene_len_pool)
8 |
9 | #This is a simulation script that adds batch effects.
10 | #First, we modify SymSim's DivideBatches function (as DivideBatches2) so that the mRNA and ADT data share the same batch partition.
11 | DivideBatches2 <- function(observed_counts_res, batchIDs, batch_effect_size=1){
12 | observed_counts <- observed_counts_res[["counts"]]
13 | meta_cell <- observed_counts_res[["cell_meta"]]
14 | ncells <- dim(observed_counts)[2]
15 | ngenes <- dim(observed_counts)[1]
16 |   nbatch <- length(unique(batchIDs))
17 | meta_cell2 <- data.frame(batch = batchIDs, stringsAsFactors = F)
18 | meta_cell <- cbind(meta_cell, meta_cell2)
19 | mean_matrix <- matrix(0, ngenes, nbatch)
20 | gene_mean <- rnorm(ngenes, 0, 1)
21 | temp <- lapply(1:ngenes, function(igene) {
22 | return(runif(nbatch, min = gene_mean[igene] - batch_effect_size,
23 | max = gene_mean[igene] + batch_effect_size))
24 | })
25 | mean_matrix <- do.call(rbind, temp)
26 | batch_factor <- matrix(0, ngenes, ncells)
27 | for (igene in 1:ngenes) {
28 | for (icell in 1:ncells) {
29 | batch_factor[igene, icell] <- rnorm(n = 1, mean = mean_matrix[igene,
30 | batchIDs[icell]], sd = 0.01)
31 | }
32 | }
33 | observed_counts <- round(2^(log2(observed_counts) + batch_factor))
34 | return(list(counts = observed_counts, cell_meta = meta_cell))
35 | }
36 |
37 | for(k in 1:10){
38 | ##RNA
39 | ncells = 1000
40 | nbatchs = 2
41 | batchIDs <- sample(1:nbatchs, ncells, replace = TRUE)
42 | print(k)
43 | print("Simulate RNA")
44 | ngenes <- 2000
45 | gene_len <- sample(gene_len_pool, ngenes, replace = FALSE)
46 | true_RNAcounts_res <- SimulateTrueCounts(ncells_total=ncells,
47 | min_popsize=50,
48 | i_minpop=1,
49 | ngenes=ngenes,
50 | nevf=10,
51 | evf_type="discrete",
52 | n_de_evf=6,
53 | vary="s",
54 | Sigma=0.6,
55 | phyla=phyla,
56 | randseed=k+1000)
57 |
58 | observed_RNAcounts <- True2ObservedCounts(true_counts=true_RNAcounts_res[[1]],
59 | meta_cell=true_RNAcounts_res[[3]],
60 | protocol="UMI",
61 | alpha_mean=0.00075,
62 | alpha_sd=0.0001,
63 | gene_len=gene_len,
64 | depth_mean=50000,
65 | depth_sd=3000,
66 | )
67 |
68 | batch_RNAcounts <- DivideBatches2(observed_RNAcounts, batchIDs, batch_effect_size = 1)
69 |
70 | ## Add batch effects
71 | print((sum(batch_RNAcounts$counts==0)-sum(true_RNAcounts_res$counts==0))/sum(true_RNAcounts_res$counts>0))
72 | print(sum(batch_RNAcounts$counts==0)/prod(dim(batch_RNAcounts$counts)))
73 |
74 | ##ADT
75 | print("Simulate ADT")
76 | nadts <- 100
77 | gene_len <- sample(gene_len_pool, nadts, replace = FALSE)
78 | #The true counts of the five populations can be simulated:
79 | true_ADTcounts_res <- SimulateTrueCounts(ncells_total=ncells,
80 | min_popsize=50,
81 | i_minpop=1,
82 | ngenes=nadts,
83 | nevf=10,
84 | evf_type="discrete",
85 | n_de_evf=6,
86 | vary="s",
87 | Sigma=0.3,
88 | phyla=phyla2,
89 | randseed=k+1000)
90 |
91 | observed_ADTcounts <- True2ObservedCounts(true_counts=true_ADTcounts_res[[1]],
92 | meta_cell=true_ADTcounts_res[[3]],
93 | protocol="UMI",
94 | alpha_mean=0.045,
95 | alpha_sd=0.01,
96 | gene_len=gene_len,
97 | depth_mean=50000,
98 | depth_sd=3000,
99 | )
100 |
101 | ## Add batch effects
102 | batch_ADTcounts <- DivideBatches2(observed_ADTcounts, batchIDs, batch_effect_size = 1)
103 |
104 | print((sum(batch_ADTcounts$counts==0)-sum(true_ADTcounts_res$counts==0))/sum(true_ADTcounts_res$counts>0))
105 | print(sum(batch_ADTcounts$counts==0)/prod(dim(batch_ADTcounts$counts)))
106 |
107 | y1 = batch_RNAcounts$cell_meta$pop
108 | y2 = batch_ADTcounts$cell_meta$pop
109 | batch1 = batch_ADTcounts$cell_meta$batch
110 | batch2 = batch_RNAcounts$cell_meta$batch
111 | print(sum(y1==y2))
112 | print(sum(batch1 == batch2))
113 |
114 | counts1 <- batch_RNAcounts[[1]]
115 | counts2 <- batch_ADTcounts[[1]]
116 |
117 | #filter
118 | rownames(counts2) <- paste("G",1:nrow(counts2),sep = "")
119 | colnames(counts2) <- paste("C",1:ncol(counts2),sep = "")
120 |
121 | pbmc <- CreateSeuratObject(counts = counts2, project = "P2", min.cells = 0, min.features = 0)
122 | pbmc <- NormalizeData(pbmc, normalization.method = "LogNormalize")
123 | pbmc <- FindVariableFeatures(pbmc, selection.method = "vst", nfeatures = 30)
124 | counts2 <- counts2[pbmc@assays[["RNA"]]@var.features,]
125 |
126 | h5file = paste("./batch/Simulation.", k, ".h5", sep="")
127 | h5createFile(h5file)
128 | h5write(as.matrix(counts1), h5file,"X1")
129 | h5write(as.matrix(counts2), h5file,"X2")
130 | h5write(y1, h5file,"Y")
131 | h5write(batch1, h5file,"Batch")
132 | }
133 |
--------------------------------------------------------------------------------
/src/fig1_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xianglin226/scMDC/d0d4baeefc1fda27342d7b715167cbae4d081ca7/src/fig1_.png
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class NBLoss(nn.Module):
7 | def __init__(self):
8 | super(NBLoss, self).__init__()
9 |
10 | def forward(self, x, mean, disp, scale_factor=1.0):
11 | eps = 1e-10
12 | scale_factor = scale_factor[:, None]
13 | mean = mean * scale_factor
14 |
15 | t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps)
16 | t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps)))
17 | result = t1 + t2
18 |
19 | result = torch.mean(result)
20 | return result
21 |
22 |
23 | class ZINBLoss(nn.Module):
24 | def __init__(self):
25 | super(ZINBLoss, self).__init__()
26 |
27 | def forward(self, x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0):
28 | eps = 1e-10
29 | scale_factor = scale_factor[:, None]
30 | mean = mean * scale_factor
31 |
32 | t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps)
33 | t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps)))
34 | nb_final = t1 + t2
35 |
36 | nb_case = nb_final - torch.log(1.0-pi+eps)
37 | zero_nb = torch.pow(disp/(disp+mean+eps), disp)
38 | zero_case = -torch.log(pi + ((1.0-pi)*zero_nb)+eps)
39 | result = torch.where(torch.le(x, 1e-8), zero_case, nb_case)
40 |
41 | if ridge_lambda > 0:
42 | ridge = ridge_lambda*torch.square(pi)
43 | result += ridge
44 |
45 | result = torch.mean(result)
46 | return result
47 |
48 |
49 | class GaussianNoise(nn.Module):
50 | def __init__(self, sigma=0):
51 | super(GaussianNoise, self).__init__()
52 | self.sigma = sigma
53 |
54 | def forward(self, x):
55 | if self.training:
56 | x = x + self.sigma * torch.randn_like(x)
57 | return x
58 |
59 |
60 | class MeanAct(nn.Module):
61 | def __init__(self):
62 | super(MeanAct, self).__init__()
63 |
64 | def forward(self, x):
65 | return torch.clamp(torch.exp(x), min=1e-5, max=1e6)
66 |
67 | class DispAct(nn.Module):
68 | def __init__(self):
69 | super(DispAct, self).__init__()
70 |
71 | def forward(self, x):
72 | return torch.clamp(F.softplus(x), min=1e-4, max=1e4)
73 |
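# Illustrative usage sketch (a standalone snippet, not part of layers.py itself):
# evaluate the ZINB loss above on toy tensors. In scMDC the mean, dispersion, and
# dropout come from decoder heads built with MeanAct, DispAct, and a sigmoid, and
# the size factors come from preprocessing; here they are random placeholders.
import torch
from layers import ZINBLoss, MeanAct, DispAct

x = torch.randint(0, 20, (8, 100)).float()   # raw counts (8 cells x 100 genes)
mean = MeanAct()(torch.randn(8, 100))        # positive means
disp = DispAct()(torch.randn(8, 100))        # positive dispersions
pi = torch.sigmoid(torch.randn(8, 100))      # dropout probabilities
sf = torch.ones(8)                           # per-cell size factors

loss = ZINBLoss()(x=x, mean=mean, disp=disp, pi=pi, scale_factor=sf)
print(loss.item())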
--------------------------------------------------------------------------------
/src/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Goekcen Eraslan
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import pickle, os, numbers
22 |
23 | import numpy as np
24 | import scipy as sp
25 | import pandas as pd
26 | import scanpy as sc
27 | from sklearn.model_selection import train_test_split
28 | from sklearn.preprocessing import scale
29 | import scipy
30 |
31 | #TODO: Fix this
32 | class AnnSequence:
33 | def __init__(self, matrix, batch_size, sf=None):
34 | self.matrix = matrix
35 | if sf is None:
36 | self.size_factors = np.ones((self.matrix.shape[0], 1),
37 | dtype=np.float32)
38 | else:
39 | self.size_factors = sf
40 | self.batch_size = batch_size
41 |
42 | def __len__(self):
43 | return len(self.matrix) // self.batch_size
44 |
45 | def __getitem__(self, idx):
46 | batch = self.matrix[idx*self.batch_size:(idx+1)*self.batch_size]
47 | batch_sf = self.size_factors[idx*self.batch_size:(idx+1)*self.batch_size]
48 |
49 | # return an (X, Y) pair
50 | return {'count': batch, 'size_factors': batch_sf}, batch
51 |
52 |
53 | def read_dataset(adata, transpose=False, test_split=False, copy=False):
54 |
55 | if isinstance(adata, sc.AnnData):
56 | if copy:
57 | adata = adata.copy()
58 | elif isinstance(adata, str):
59 | adata = sc.read(adata)
60 | else:
61 | raise NotImplementedError
62 |
63 | norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
64 | assert 'n_count' not in adata.obs, norm_error
65 |
66 | if adata.X.size < 50e6: # check if adata.X is integer only if array is small
67 | if sp.sparse.issparse(adata.X):
68 | assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error
69 | else:
70 | assert np.all(adata.X.astype(int) == adata.X), norm_error
71 |
72 | if transpose: adata = adata.transpose()
73 |
74 | if test_split:
75 | train_idx, test_idx = train_test_split(np.arange(adata.n_obs), test_size=0.1, random_state=42)
76 | spl = pd.Series(['train'] * adata.n_obs)
77 | spl.iloc[test_idx] = 'test'
78 | adata.obs['DCA_split'] = spl.values
79 | else:
80 | adata.obs['DCA_split'] = 'train'
81 |
82 | adata.obs['DCA_split'] = adata.obs['DCA_split'].astype('category')
83 | print('### Autoencoder: Successfully preprocessed {} genes and {} cells.'.format(adata.n_vars, adata.n_obs))
84 |
85 | return adata
86 |
87 | def clr_normalize_each_cell(adata):
88 | """Normalize count vector for each cell, i.e. for each row of .X"""
89 |
90 | def seurat_clr(x):
91 | # TODO: support sparseness
92 | s = np.sum(np.log1p(x[x > 0]))
93 | exp = np.exp(s / len(x))
94 | return np.log1p(x / exp)
95 |
96 | adata.raw = adata.copy()
97 | sc.pp.normalize_per_cell(adata)
98 | adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
99 |
100 | # apply to dense or sparse matrix, along axis. returns dense matrix
101 | adata.X = np.apply_along_axis(
102 | seurat_clr, 1, (adata.raw.X.A if scipy.sparse.issparse(adata.raw.X) else adata.raw.X)
103 | )
104 | return adata
105 |
106 | def normalize(adata, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True):
107 |
108 | if filter_min_counts:
109 | sc.pp.filter_genes(adata, min_counts=1)
110 | sc.pp.filter_cells(adata, min_counts=1)
111 |
112 | if size_factors or normalize_input or logtrans_input:
113 | adata.raw = adata.copy()
114 | else:
115 | adata.raw = adata
116 |
117 | if size_factors:
118 | sc.pp.normalize_per_cell(adata)
119 | adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
120 | else:
121 | adata.obs['size_factors'] = 1.0
122 |
123 | if logtrans_input:
124 | sc.pp.log1p(adata)
125 |
126 | if normalize_input:
127 | sc.pp.scale(adata)
128 |
129 | return adata
130 |
131 | def read_genelist(filename):
132 | genelist = list(set(open(filename, 'rt').read().strip().split('\n')))
133 | assert len(genelist) > 0, 'No genes detected in genelist file'
134 | print('### Autoencoder: Subset of {} genes will be denoised.'.format(len(genelist)))
135 |
136 | return genelist
137 |
138 | def write_text_matrix(matrix, filename, rownames=None, colnames=None, transpose=False):
139 | if transpose:
140 | matrix = matrix.T
141 | rownames, colnames = colnames, rownames
142 |
143 | pd.DataFrame(matrix, index=rownames, columns=colnames).to_csv(filename,
144 | sep='\t',
145 | index=(rownames is not None),
146 | header=(colnames is not None),
147 | float_format='%.6f')
148 | def read_pickle(inputfile):
149 | return pickle.load(open(inputfile, "rb"))
150 |
--------------------------------------------------------------------------------
/src/run_LRP.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | import math, os
3 | from sklearn import metrics
4 | from sklearn.cluster import KMeans
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch.nn import Parameter
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.utils.data import DataLoader, TensorDataset
12 |
13 | from scMDC import scMultiCluster
14 | import numpy as np
15 | import collections
16 | import h5py
17 | import scanpy as sc
18 | from preprocess import read_dataset, normalize, clr_normalize_each_cell
19 | from utils import *
20 | from functools import reduce
21 | from LRP import LRP
22 |
23 |
24 | if __name__ == "__main__":
25 |
26 | # setting the hyper parameters
27 | import argparse
28 | parser = argparse.ArgumentParser(description='train',
29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30 | parser.add_argument('--n_clusters', default=8, type=int)
31 | parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch')
32 | parser.add_argument('--batch_size', default=256, type=int)
33 | parser.add_argument('--data_file', default='Simulation.1.h5')
34 | parser.add_argument('--cluster_index_file', default='label.txt')
35 | parser.add_argument('--maxiter', default=10000, type=int)
36 | parser.add_argument('--pretrain_epochs', default=400, type=int)
37 | parser.add_argument('--gamma', default=.1, type=float,
38 | help='coefficient of clustering loss')
39 | parser.add_argument('--tau', default=1., type=float,
40 | help='fuzziness of clustering loss')
41 | parser.add_argument('--phi1', default=0.001, type=float,
42 | help='coefficient of KL loss in pretraining stage')
43 | parser.add_argument('--phi2', default=0.001, type=float,
44 | help='coefficient of KL loss in clustering stage')
45 | parser.add_argument('--update_interval', default=1, type=int)
46 | parser.add_argument('--tol', default=0.001, type=float)
47 | parser.add_argument('--ae_weights', default=None)
48 | parser.add_argument('--save_dir', default='results/')
49 | parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
50 | parser.add_argument('--resolution', default=0.2, type=float)
51 | parser.add_argument('--n_neighbors', default=30, type=int)
52 | parser.add_argument('--embedding_file', action='store_true', default=False)
53 | parser.add_argument('--prediction_file', action='store_true', default=False)
54 | parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16])
55 | parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256])
56 | parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20])
57 | parser.add_argument('--sigma1', default=2.5, type=float)
58 | parser.add_argument('--sigma2', default=1.5, type=float)
59 | parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection')
60 | parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection')
61 | parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection')
62 | parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection')
63 | parser.add_argument('--run', default=1, type=int)
64 | parser.add_argument('--beta', default=1., type=float,
65 | help='coefficient of the clustering fuzziness')
66 | parser.add_argument('--margin', default=1., type=float,
67 | help='margin of difference between logits')
68 | parser.add_argument('--lamda', default=100., type=float,
69 | help='coefficient of the clustering perturbation loss')
70 |     parser.add_argument('--lr', default=0.001, type=float)
71 | parser.add_argument('--device', default='cuda')
72 |
73 | args = parser.parse_args()
74 | print(args)
75 |     data_mat = h5py.File(args.data_file, 'r')
76 | x1 = np.array(data_mat['X1'])
77 | x2 = np.array(data_mat['X2'])
78 | #y = np.array(data_mat['Y']) - 1
79 | data_mat.close()
80 |
81 |
82 | clust_ids = np.loadtxt(args.cluster_index_file, delimiter=",").astype(int)
83 |
84 | #Gene features
85 | if args.filter1:
86 | importantGenes = geneSelection(x1, n=args.f1, plot=False)
87 | x1 = x1[:, importantGenes]
88 | if args.filter2:
89 | importantGenes = geneSelection(x2, n=args.f2, plot=False)
90 | x2 = x2[:, importantGenes]
91 |
92 | adata1 = sc.AnnData(x1)
93 | #adata1.obs['Group'] = y
94 |
95 | adata1 = read_dataset(adata1,
96 | transpose=False,
97 | test_split=False,
98 | copy=True)
99 |
100 | adata1 = normalize(adata1,
101 | size_factors=True,
102 | normalize_input=True,
103 | logtrans_input=True)
104 |
105 | adata2 = sc.AnnData(x2)
106 | #adata2.obs['Group'] = y
107 | adata2 = read_dataset(adata2,
108 | transpose=False,
109 | test_split=False,
110 | copy=True)
111 |
112 | adata2 = normalize(adata2,
113 | size_factors=True,
114 | normalize_input=True,
115 | logtrans_input=True)
116 |
117 | #adata2 = clr_normalize_each_cell(adata2)
118 |
119 | input_size1 = adata1.n_vars
120 | input_size2 = adata2.n_vars
121 | print(adata1.X.shape)
122 | print(adata2.X.shape)
123 |
124 | print(args)
125 |
126 | encodeLayer = list(map(int, args.encodeLayer))
127 | decodeLayer1 = list(map(int, args.decodeLayer1))
128 | decodeLayer2 = list(map(int, args.decodeLayer2))
129 |
130 | model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau,
131 | encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
132 | activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma,
133 | cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)
134 |
135 | print(str(model))
136 |
137 | if os.path.isfile(args.ae_weights):
138 | print("==> loading checkpoint '{}'".format(args.ae_weights))
139 | checkpoint = torch.load(args.ae_weights)
140 | model.load_state_dict(checkpoint['ae_state_dict'])
141 | else:
142 | print("==> no checkpoint found at '{}'".format(args.ae_weights))
143 | raise ValueError
144 |
145 | n_clusters = np.unique(clust_ids).shape[0]
146 | print("n cluster is: " + str(n_clusters))
147 |
148 | Z = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device)).data.cpu().numpy()
149 |
150 | cluster_list = np.unique(clust_ids).astype(int).tolist()
151 | print(cluster_list)
152 |
153 | model_explainer = LRP(model, X1=adata1.X, X2=adata2.X, Z=Z, clust_ids=clust_ids, n_clusters=n_clusters, beta=args.beta).to(args.device)
154 |
155 | #for clust_c in [cluster_ind[0]]: #range(args.n_clusters):
156 | # for clust_k in [cluster_ind[1]]: #range(clust_c+1, args.n_clusters):
157 | # print("Cluster"+str(clust_c)+" vs Cluster"+str(clust_k))
158 | # rel_score1, rel_score2 = model_explainer.calc_carlini_wagner_one_vs_one(clust_c, clust_k, margin=args.margin, lamda=args.lamda, max_iter=args.maxiter, lr=args.lr)
159 | # print(rel_score1.shape)
160 | # print(rel_score2.shape)
161 | # np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_"+str(clust_k)+"_rel_mRNA_scores.csv", rel_score1, delimiter=",")
162 | # np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_"+str(clust_k)+"_rel_ADT_scores.csv", rel_score2, delimiter=",")
163 |
164 | for clust_c in cluster_list:
165 | print("Cluster"+str(clust_c)+" vs Rest")
166 | rel_score1, rel_score2 = model_explainer.calc_carlini_wagner_one_vs_rest(clust_c, margin=args.margin, lamda=args.lamda, max_iter=args.maxiter, lr=args.lr)
167 | np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_rest_rel_mRNA_scores.csv", rel_score1, delimiter=",")
168 | np.savetxt(args.save_dir + "/" + str(clust_c)+"_vs_rest_rel_ADT_scores.csv", rel_score2, delimiter=",")
169 |
--------------------------------------------------------------------------------
/src/run_scMDC.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | import math, os
3 | from sklearn import metrics
4 | from sklearn.cluster import KMeans
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch.nn import Parameter
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.utils.data import DataLoader, TensorDataset
12 |
13 | from scMDC import scMultiCluster
14 | import numpy as np
15 | import collections
16 | import h5py
17 | import scanpy as sc
18 | from preprocess import read_dataset, normalize, clr_normalize_each_cell
19 | from utils import *
20 |
21 | if __name__ == "__main__":
22 |
23 | # setting the hyper parameters
24 | import argparse
25 | parser = argparse.ArgumentParser(description='train',
26 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
27 | parser.add_argument('--n_clusters', default=27, type=int)
28 | parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch')
29 | parser.add_argument('--batch_size', default=256, type=int)
30 | parser.add_argument('--data_file', default='Normalized_filtered_BMNC_GSE128639_Seurat.h5')
31 | parser.add_argument('--maxiter', default=5000, type=int)
32 | parser.add_argument('--pretrain_epochs', default=400, type=int)
33 | parser.add_argument('--gamma', default=.1, type=float,
34 | help='coefficient of clustering loss')
35 | parser.add_argument('--tau', default=1., type=float,
36 | help='fuzziness of clustering loss')
37 | parser.add_argument('--phi1', default=0.001, type=float,
38 | help='coefficient of KL loss in pretraining stage')
39 | parser.add_argument('--phi2', default=0.001, type=float,
40 | help='coefficient of KL loss in clustering stage')
41 | parser.add_argument('--update_interval', default=1, type=int)
42 | parser.add_argument('--tol', default=0.001, type=float)
43 | parser.add_argument('--lr', default=1., type=float)
44 | parser.add_argument('--ae_weights', default=None)
45 | parser.add_argument('--save_dir', default='results/')
46 | parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
47 | parser.add_argument('--resolution', default=0.2, type=float)
48 | parser.add_argument('--n_neighbors', default=30, type=int)
49 | parser.add_argument('--embedding_file', action='store_true', default=False)
50 | parser.add_argument('--prediction_file', action='store_true', default=False)
51 | parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16])
52 | parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256])
53 | parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20])
54 | parser.add_argument('--sigma1', default=2.5, type=float)
55 | parser.add_argument('--sigma2', default=1.5, type=float)
56 | parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection')
57 | parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection')
58 | parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection')
59 | parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection')
60 | parser.add_argument('--run', default=1, type=int)
61 | parser.add_argument('--device', default='cuda')
62 | parser.add_argument('--no_labels', action='store_true', default=False)
63 | args = parser.parse_args()
64 | print(args)
65 |
66 |     data_mat = h5py.File(args.data_file, 'r')
67 | x1 = np.array(data_mat['X1'])
68 | x2 = np.array(data_mat['X2'])
69 | if not args.no_labels:
70 | y = np.array(data_mat['Y'])
71 | data_mat.close()
72 |
73 | #Gene filter
74 | if args.filter1:
75 | importantGenes = geneSelection(x1, n=args.f1, plot=False)
76 | x1 = x1[:, importantGenes]
77 | if args.filter2:
78 | importantGenes = geneSelection(x2, n=args.f2, plot=False)
79 | x2 = x2[:, importantGenes]
80 |
81 | # preprocessing scRNA-seq read counts matrix
82 | adata1 = sc.AnnData(x1)
83 | #adata1.obs['Group'] = y
84 |
85 | adata1 = read_dataset(adata1,
86 | transpose=False,
87 | test_split=False,
88 | copy=True)
89 |
90 | adata1 = normalize(adata1,
91 | size_factors=True,
92 | normalize_input=True,
93 | logtrans_input=True)
94 |
95 | adata2 = sc.AnnData(x2)
96 | #adata2.obs['Group'] = y
97 | adata2 = read_dataset(adata2,
98 | transpose=False,
99 | test_split=False,
100 | copy=True)
101 |
102 | adata2 = normalize(adata2,
103 | size_factors=True,
104 | normalize_input=True,
105 | logtrans_input=True)
106 |
107 | #adata2 = clr_normalize_each_cell(adata2)
108 |
109 | input_size1 = adata1.n_vars
110 | input_size2 = adata2.n_vars
111 |
112 | print(args)
113 |
114 | encodeLayer = list(map(int, args.encodeLayer))
115 | decodeLayer1 = list(map(int, args.decodeLayer1))
116 | decodeLayer2 = list(map(int, args.decodeLayer2))
117 |
118 | model = scMultiCluster(input_dim1=input_size1, input_dim2=input_size2, tau=args.tau,
119 | encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
120 | activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma,
121 | cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)
122 |
123 | print(str(model))
124 |
125 | if not os.path.exists(args.save_dir):
126 | os.makedirs(args.save_dir)
127 |
128 | t0 = time()
129 | if args.ae_weights is None:
130 | model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
131 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, batch_size=args.batch_size,
132 | epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file)
133 | else:
134 | if os.path.isfile(args.ae_weights):
135 | print("==> loading checkpoint '{}'".format(args.ae_weights))
136 | checkpoint = torch.load(args.ae_weights)
137 | model.load_state_dict(checkpoint['ae_state_dict'])
138 | else:
139 | print("==> no checkpoint found at '{}'".format(args.ae_weights))
140 | raise ValueError
141 |
142 | print('Pretraining time: %d seconds.' % int(time() - t0))
143 |
144 | #get k
145 | latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
146 | latent = latent.cpu().numpy()
147 | if args.n_clusters == -1:
148 | n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors)
149 | else:
150 | print("n_cluster is defined as " + str(args.n_clusters))
151 | n_clusters = args.n_clusters
152 |
153 | if not args.no_labels:
154 | y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
155 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=y,
156 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
157 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
158 | else:
159 | y_pred, _ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
160 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, y=None,
161 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
162 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
163 | print('Total time: %d seconds.' % int(time() - t0))
164 |
165 | if args.prediction_file:
166 | if not args.no_labels:
167 | y_pred_ = best_map(y, y_pred) - 1
168 | np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred_, delimiter=",")
169 | else:
170 | np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred, delimiter=",")
171 |
172 | if args.embedding_file:
173 | final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device))
174 | final_latent = final_latent.cpu().numpy()
175 | np.savetxt(args.save_dir + "/" + str(args.run) + "_embedding.csv", final_latent, delimiter=",")
176 |
177 | if not args.no_labels:
178 | y_pred_ = best_map(y, y_pred)
179 | ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
180 | nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
181 | ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
182 | print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
183 | else:
184 | print("No labels for evaluation!")
185 |
--------------------------------------------------------------------------------
/src/run_scMDC_batch.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | import math, os
3 | from sklearn import metrics
4 | from sklearn.cluster import KMeans
5 | from sklearn.preprocessing import OneHotEncoder
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | from torch.nn import Parameter
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torch.utils.data import DataLoader, TensorDataset
13 |
14 | from scMDC_batch import scMultiClusterBatch
15 | import numpy as np
16 | import collections
17 | import h5py
18 | import scanpy as sc
19 | from preprocess import read_dataset, normalize
20 | from utils import *
21 |
22 | if __name__ == "__main__":
23 |
24 | # setting the hyper parameters
25 | import argparse
26 | parser = argparse.ArgumentParser(description='train',
27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28 | parser.add_argument('--n_clusters', default=27, type=int)
29 | parser.add_argument('--cutoff', default=0.5, type=float, help='Start to train combined layer after what ratio of epoch')
30 | parser.add_argument('--batch_size', default=256, type=int)
31 | parser.add_argument('--data_file', default='Normalized_filtered_BMNC_GSE128639_Seurat.h5')
32 | parser.add_argument('--maxiter', default=5000, type=int)
33 | parser.add_argument('--pretrain_epochs', default=400, type=int)
34 | parser.add_argument('--gamma', default=.1, type=float,
35 | help='coefficient of clustering loss')
36 | parser.add_argument('--tau', default=1., type=float,
37 | help='weight of clustering loss')
38 | parser.add_argument('--phi1', default=0.001, type=float,
39 | help='coefficient of KL loss in pretraining stage')
40 | parser.add_argument('--phi2', default=0.001, type=float,
41 | help='coefficient of KL loss in clustering stage')
42 | parser.add_argument('--update_interval', default=1, type=int)
43 | parser.add_argument('--tol', default=0.001, type=float)
44 | parser.add_argument('--lr', default=1., type=float)
45 | parser.add_argument('--ae_weights', default=None)
46 | parser.add_argument('--save_dir', default='results/')
47 | parser.add_argument('--ae_weight_file', default='AE_weights_1.pth.tar')
48 | parser.add_argument('--resolution', default=0.2, type=float)
49 | parser.add_argument('--n_neighbors', default=30, type=int)
50 | parser.add_argument('--embedding_file', action='store_true', default=False)
51 | parser.add_argument('--prediction_file', action='store_true', default=False)
52 | parser.add_argument('-el','--encodeLayer', nargs='+', default=[256,64,32,16])
53 | parser.add_argument('-dl1','--decodeLayer1', nargs='+', default=[16,64,256])
54 | parser.add_argument('-dl2','--decodeLayer2', nargs='+', default=[16,20])
55 | parser.add_argument('--sigma1', default=2.5, type=float)
56 | parser.add_argument('--sigma2', default=1.5, type=float)
57 | parser.add_argument('--f1', default=1000, type=float, help='Number of mRNA after feature selection')
58 | parser.add_argument('--f2', default=2000, type=float, help='Number of ADT/ATAC after feature selection')
59 | parser.add_argument('--filter1', action='store_true', default=False, help='Do mRNA selection')
60 | parser.add_argument('--filter2', action='store_true', default=False, help='Do ADT/ATAC selection')
61 | parser.add_argument('--nbatch', default=2, type=int)
62 | parser.add_argument('--run', default=1, type=int)
63 | parser.add_argument('--device', default='cuda')
64 | parser.add_argument('--no_labels', action='store_true', default=False)
65 | args = parser.parse_args()
66 | print(args)
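# Illustrative invocation (a sketch, not part of the original script; the file name and flag
# values below simply reuse the argparse defaults defined above):
#   python run_scMDC_batch.py --data_file Normalized_filtered_BMNC_GSE128639_Seurat.h5 \
#       --n_clusters 27 --nbatch 2 --prediction_file --embedding_file --save_dir results/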
67 | data_mat = h5py.File(args.data_file, 'r')
68 | x1 = np.array(data_mat['X1'])
69 | x2 = np.array(data_mat['X2'])
70 | if not args.no_labels:
71 | y = np.array(data_mat['Y'])
72 | b = np.array(data_mat['Batch'])
73 | enc = OneHotEncoder()
74 | enc.fit(b.reshape(-1, 1))
75 | B = enc.transform(b.reshape(-1, 1)).toarray()
76 | data_mat.close()
77 |
78 | #Gene filter
79 | if args.filter1:
80 | importantGenes = geneSelection(x1, n=args.f1, plot=False)
81 | x1 = x1[:, importantGenes]
82 | if args.filter2:
83 | importantGenes = geneSelection(x2, n=args.f2, plot=False)
84 | x2 = x2[:, importantGenes]
85 |
86 | # preprocessing scRNA-seq read counts matrix
87 | adata1 = sc.AnnData(x1)
88 | #adata1.obs['Group'] = y
89 |
90 | adata1 = read_dataset(adata1,
91 | transpose=False,
92 | test_split=False,
93 | copy=True)
94 |
95 | adata1 = normalize(adata1,
96 | size_factors=True,
97 | normalize_input=True,
98 | logtrans_input=True)
99 |
100 | adata2 = sc.AnnData(x2)
101 | #adata2.obs['Group'] = y
102 | adata2 = read_dataset(adata2,
103 | transpose=False,
104 | test_split=False,
105 | copy=True)
106 |
107 | adata2 = normalize(adata2,
108 | size_factors=True,
109 | normalize_input=True,
110 | logtrans_input=True)
111 |
112 | input_size1 = adata1.n_vars
113 | input_size2 = adata2.n_vars
114 |
115 | print(args)
116 |
117 | encodeLayer = list(map(int, args.encodeLayer))
118 | decodeLayer1 = list(map(int, args.decodeLayer1))
119 | decodeLayer2 = list(map(int, args.decodeLayer2))
120 |
121 | model = scMultiClusterBatch(input_dim1=input_size1, input_dim2=input_size2, n_batch = args.nbatch, tau=args.tau,
122 | encodeLayer=encodeLayer, decodeLayer1=decodeLayer1, decodeLayer2=decodeLayer2,
123 | activation='elu', sigma1=args.sigma1, sigma2=args.sigma2, gamma=args.gamma,
124 | cutoff = args.cutoff, phi1=args.phi1, phi2=args.phi2, device=args.device).to(args.device)
125 |
126 | print(str(model))
127 |
128 | if not os.path.exists(args.save_dir):
129 | os.makedirs(args.save_dir)
130 |
131 |
132 | t0 = time()
133 | if args.ae_weights is None:
134 | model.pretrain_autoencoder(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
135 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B = B, batch_size=args.batch_size,
136 | epochs=args.pretrain_epochs, ae_weights=args.ae_weight_file)
137 | else:
138 | if os.path.isfile(args.ae_weights):
139 | print("==> loading checkpoint '{}'".format(args.ae_weights))
140 | checkpoint = torch.load(args.ae_weights)
141 | model.load_state_dict(checkpoint['ae_state_dict'])
142 | else:
143 | print("==> no checkpoint found at '{}'".format(args.ae_weights))
144 | raise ValueError
145 |
146 | print('Pretraining time: %d seconds.' % int(time() - t0))
147 |
148 | # estimate the number of clusters k if it is not given (n_clusters == -1)
149 | latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device), torch.tensor(B).to(args.device), batch_size=args.batch_size)
150 | latent = latent.cpu().numpy()
151 | if args.n_clusters == -1:
152 | n_clusters = GetCluster(latent, res=args.resolution, n=args.n_neighbors)
153 | else:
154 | print("n_cluster is defined as " + str(args.n_clusters))
155 | n_clusters = args.n_clusters
156 |
157 | if not args.no_labels:
158 | y_pred,_ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
159 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B=B, y=y,
160 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
161 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
162 | else:
163 | y_pred,_ = model.fit(X1=adata1.X, X_raw1=adata1.raw.X, sf1=adata1.obs.size_factors,
164 | X2=adata2.X, X_raw2=adata2.raw.X, sf2=adata2.obs.size_factors, B=B, y=None,
165 | n_clusters=n_clusters, batch_size=args.batch_size, num_epochs=args.maxiter,
166 | update_interval=args.update_interval, tol=args.tol, lr=args.lr, save_dir=args.save_dir)
167 | print('Total time: %d seconds.' % int(time() - t0))
168 |
169 | if args.prediction_file:
170 | np.savetxt(args.save_dir + "/" + str(args.run) + "_pred.csv", y_pred, delimiter=",")
171 |
172 | if args.embedding_file:
173 | final_latent = model.encodeBatch(torch.tensor(adata1.X).to(args.device), torch.tensor(adata2.X).to(args.device), torch.tensor(B).to(args.device), batch_size=args.batch_size)
174 | final_latent = final_latent.cpu().numpy()
175 | np.savetxt(args.save_dir + "/" + str(args.run) + "_embedding.csv", final_latent, delimiter=",")
176 |
177 | if not args.no_labels:
178 | y_pred_ = best_map(y, y_pred)
179 | ami = np.round(metrics.adjusted_mutual_info_score(y, y_pred), 5)
180 | nmi = np.round(metrics.normalized_mutual_info_score(y, y_pred), 5)
181 | ari = np.round(metrics.adjusted_rand_score(y, y_pred), 5)
182 | print('Final: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
183 | else:
184 | print("No labels for evaluation!")
185 |
--------------------------------------------------------------------------------
/src/scMDC.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics.pairwise import paired_distances
2 | from sklearn.decomposition import PCA
3 | from sklearn import metrics
4 | from sklearn.cluster import KMeans
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch.nn import Parameter
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.utils.data import DataLoader, TensorDataset
12 | from layers import NBLoss, ZINBLoss, MeanAct, DispAct
13 | import numpy as np
14 |
15 | import math, os
16 |
17 | def buildNetwork2(layers, type, activation="relu"):
18 | net = []
19 | for i in range(1, len(layers)):
20 | net.append(nn.Linear(layers[i-1], layers[i]))
21 | net.append(nn.BatchNorm1d(layers[i], affine=True))
22 | if activation=="relu":
23 | net.append(nn.ReLU())
24 | elif activation=="selu":
25 | net.append(nn.SELU())
26 | elif activation=="sigmoid":
27 | net.append(nn.Sigmoid())
28 | elif activation=="elu":
29 | net.append(nn.ELU())
30 | return nn.Sequential(*net)
31 |
32 | class scMultiCluster(nn.Module):
33 | def __init__(self, input_dim1, input_dim2,
34 | encodeLayer=[], decodeLayer1=[], decodeLayer2=[], tau=1., t=10, device="cuda",
35 | activation="elu", sigma1=2.5, sigma2=.1, alpha=1., gamma=1., phi1=0.0001, phi2=0.0001, cutoff = 0.5):
36 | super(scMultiCluster, self).__init__()
37 | self.tau=tau
38 | self.input_dim1 = input_dim1
39 | self.input_dim2 = input_dim2
40 | self.cutoff = cutoff
41 | self.activation = activation
42 | self.sigma1 = sigma1
43 | self.sigma2 = sigma2
44 | self.alpha = alpha
45 | self.gamma = gamma
46 | self.phi1 = phi1
47 | self.phi2 = phi2
48 | self.t = t
49 | self.device = device
50 | self.encoder = buildNetwork2([input_dim1+input_dim2]+encodeLayer, type="encode", activation=activation)
51 | self.decoder1 = buildNetwork2(decodeLayer1, type="decode", activation=activation)
52 | self.decoder2 = buildNetwork2(decodeLayer2, type="decode", activation=activation)
53 | self.dec_mean1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), MeanAct())
54 | self.dec_disp1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), DispAct())
55 | self.dec_mean2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), MeanAct())
56 | self.dec_disp2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), DispAct())
57 | self.dec_pi1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), nn.Sigmoid())
58 | self.dec_pi2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), nn.Sigmoid())
59 | self.zinb_loss = ZINBLoss()
60 | self.z_dim = encodeLayer[-1]
61 |
62 | def save_model(self, path):
63 | torch.save(self.state_dict(), path)
64 |
65 | def load_model(self, path):
66 | pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
67 | model_dict = self.state_dict()
68 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
69 | model_dict.update(pretrained_dict)
70 | self.load_state_dict(model_dict)
71 |
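# cal_latent: pairwise Student's t-kernel similarities between latent points (num) and the
# row-normalized neighbour distribution with the self-similarity diagonal zeroed out (latent_p).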
72 | def cal_latent(self, z):
73 | sum_y = torch.sum(torch.square(z), dim=1)
74 | num = -2.0 * torch.matmul(z, z.t()) + torch.reshape(sum_y, [-1, 1]) + sum_y
75 | num = num / self.alpha
76 | num = torch.pow(1.0 + num, -(self.alpha + 1.0) / 2.0)
77 | zerodiag_num = num - torch.diag(torch.diag(num))
78 | latent_p = (zerodiag_num.t() / torch.sum(zerodiag_num, dim=1)).t()
79 | return num, latent_p
80 |
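# kmeans_loss: tau-scaled squared distances to the cluster centers self.mu plus a sharpened
# soft assignment q; returns the raw distances and the assignment-weighted k-means objective.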
81 | def kmeans_loss(self, z):
82 | dist1 = self.tau*torch.sum(torch.square(z.unsqueeze(1) - self.mu), dim=2)
83 | temp_dist1 = dist1 - torch.reshape(torch.mean(dist1, dim=1), [-1, 1])
84 | q = torch.exp(-temp_dist1)
85 | q = (q.t() / torch.sum(q, dim=1)).t()
86 | q = torch.pow(q, 2)
87 | q = (q.t() / torch.sum(q, dim=1)).t()
88 | dist2 = dist1 * q
89 | return dist1, torch.mean(torch.sum(dist2, dim=1))
90 |
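# target_distribution: DEC-style sharpening of q (square, divide by per-cluster frequency,
# then renormalize each sample).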
91 | def target_distribution(self, q):
92 | p = q**2 / q.sum(0)
93 | return (p.t() / p.sum(1)).t()
94 |
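# forward/forwardAE (identical in this non-batch model): a noisy pass (Gaussian noise scaled by
# sigma1/sigma2 per modality) drives the ZINB reconstructions, while a clean pass of the
# concatenated inputs yields the latent embedding h0 used for clustering.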
95 | def forward(self, x1, x2):
96 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
97 | h = self.encoder(x)
98 |
99 | h1 = self.decoder1(h)
100 | mean1 = self.dec_mean1(h1)
101 | disp1 = self.dec_disp1(h1)
102 | pi1 = self.dec_pi1(h1)
103 |
104 | h2 = self.decoder2(h)
105 | mean2 = self.dec_mean2(h2)
106 | disp2 = self.dec_disp2(h2)
107 | pi2 = self.dec_pi2(h2)
108 |
109 | x0 = torch.cat([x1, x2], dim=-1)
110 | h0 = self.encoder(x0)
111 | num, lq = self.cal_latent(h0)
112 | return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2
113 |
114 | def forwardAE(self, x1, x2):
115 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
116 | h = self.encoder(x)
117 |
118 | h1 = self.decoder1(h)
119 | mean1 = self.dec_mean1(h1)
120 | disp1 = self.dec_disp1(h1)
121 | pi1 = self.dec_pi1(h1)
122 |
123 | h2 = self.decoder2(h)
124 | mean2 = self.dec_mean2(h2)
125 | disp2 = self.dec_disp2(h2)
126 | pi2 = self.dec_pi2(h2)
127 |
128 | x0 = torch.cat([x1, x2], dim=-1)
129 | h0 = self.encoder(x0)
130 | num, lq = self.cal_latent(h0)
131 | return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2
132 |
133 | def encodeBatch(self, X1, X2, batch_size=256):
134 | use_cuda = torch.cuda.is_available()
135 | if use_cuda:
136 | self.to(self.device)
137 | encoded = []
138 | self.eval()
139 | num = X1.shape[0]
140 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
141 | for batch_idx in range(num_batch):
142 | x1batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
143 | x2batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
144 | inputs1 = Variable(x1batch)
145 | inputs2 = Variable(x2batch)
146 | z,_,_,_,_,_,_,_,_ = self.forwardAE(inputs1, inputs2)
147 | encoded.append(z.data)
148 |
149 | encoded = torch.cat(encoded, dim=0)
150 | return encoded
151 |
152 | def kldloss(self, p, q):
153 | c1 = -torch.sum(p * torch.log(q), dim=-1)
154 | c2 = -torch.sum(p * torch.log(p), dim=-1)
155 | return torch.mean(c1 - c2)
156 |
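# Stage 1: pretrain the autoencoder on the two ZINB reconstruction losses; once epoch+1 reaches
# epochs*cutoff, the KL term (weighted by phi1) is added.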
157 | def pretrain_autoencoder(self, X1, X_raw1, sf1, X2, X_raw2, sf2,
158 | batch_size=256, lr=0.001, epochs=400, ae_save=True, ae_weights='AE_weights.pth.tar'):
159 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
160 | dataset = TensorDataset(torch.Tensor(X1), torch.Tensor(X_raw1), torch.Tensor(sf1), torch.Tensor(X2), torch.Tensor(X_raw2), torch.Tensor(sf2))
161 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
162 | print("Pretraining stage")
163 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)
164 | num = X1.shape[0]
165 | for epoch in range(epochs):
166 | loss_val = 0
167 | recon_loss1_val = 0
168 | recon_loss2_val = 0
169 | kl_loss_val = 0
170 | for batch_idx, (x1_batch, x_raw1_batch, sf1_batch, x2_batch, x_raw2_batch, sf2_batch) in enumerate(dataloader):
171 | x1_tensor = Variable(x1_batch).to(self.device)
172 | x_raw1_tensor = Variable(x_raw1_batch).to(self.device)
173 | sf1_tensor = Variable(sf1_batch).to(self.device)
174 | x2_tensor = Variable(x2_batch).to(self.device)
175 | x_raw2_tensor = Variable(x_raw2_batch).to(self.device)
176 | sf2_tensor = Variable(sf2_batch).to(self.device)
177 | zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forwardAE(x1_tensor, x2_tensor)
178 | recon_loss1 = self.zinb_loss(x=x_raw1_tensor, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sf1_tensor)
179 | recon_loss2 = self.zinb_loss(x=x_raw2_tensor, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sf2_tensor)
180 | lpbatch = self.target_distribution(lqbatch)
181 | lqbatch = lqbatch + torch.diag(torch.diag(z_num))
182 | lpbatch = lpbatch + torch.diag(torch.diag(z_num))
183 | kl_loss = self.kldloss(lpbatch, lqbatch)
184 | if epoch+1 >= epochs * self.cutoff:
185 | loss = recon_loss1 + recon_loss2 + kl_loss * self.phi1
186 | else:
187 | loss = recon_loss1 + recon_loss2
188 | optimizer.zero_grad()
189 | loss.backward()
190 | optimizer.step()
191 |
192 | loss_val += loss.item() * len(x1_batch)
193 | recon_loss1_val += recon_loss1.item() * len(x1_batch)
194 | recon_loss2_val += recon_loss2.item() * len(x2_batch)
195 | if epoch+1 >= epochs * self.cutoff:
196 | kl_loss_val += kl_loss.item() * len(x1_batch)
197 |
198 | loss_val = loss_val/num
199 | recon_loss1_val = recon_loss1_val/num
200 | recon_loss2_val = recon_loss2_val/num
201 | kl_loss_val = kl_loss_val/num
202 | if epoch%self.t == 0:
203 | print('Pretrain epoch {}, Total loss:{:.6f}, ZINB loss1:{:.6f}, ZINB loss2:{:.6f}, KL loss:{:.6f}'.format(epoch+1, loss_val, recon_loss1_val, recon_loss2_val, kl_loss_val))
204 |
205 | if ae_save:
206 | torch.save({'ae_state_dict': self.state_dict(),
207 | 'optimizer_state_dict': optimizer.state_dict()}, ae_weights)
208 |
209 | def save_checkpoint(self, state, index, filename):
210 | newfilename = os.path.join(filename, 'FTcheckpoint_%d.pth.tar' % index)
211 | torch.save(state, newfilename)
212 |
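# Stage 2: clustering. Centers self.mu are initialized by k-means on the pretrained embedding,
# then the network and centers are refined jointly with ZINB + KL (phi2) + k-means (gamma)
# losses, stopping early once the fraction of changed labels falls below tol.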
213 | def fit(self, X1, X_raw1, sf1, X2, X_raw2, sf2, y=None, lr=1., n_clusters = 4,
214 | batch_size=256, num_epochs=10, update_interval=1, tol=1e-3, save_dir=""):
215 | '''X: tensor data'''
216 | use_cuda = torch.cuda.is_available()
217 | if use_cuda:
218 | self.to(self.device)
219 | print("Clustering stage")
220 | X1 = torch.tensor(X1).to(self.device)
221 | X_raw1 = torch.tensor(X_raw1).to(self.device)
222 | sf1 = torch.tensor(sf1).to(self.device)
223 | X2 = torch.tensor(X2).to(self.device)
224 | X_raw2 = torch.tensor(X_raw2).to(self.device)
225 | sf2 = torch.tensor(sf2).to(self.device)
226 | self.mu = Parameter(torch.Tensor(n_clusters, self.z_dim), requires_grad=True)
227 | optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95)
228 |
229 | print("Initializing cluster centers with kmeans.")
230 | kmeans = KMeans(n_clusters, n_init=20)
231 | Zdata = self.encodeBatch(X1, X2, batch_size=batch_size)
232 | #latent
233 | self.y_pred = kmeans.fit_predict(Zdata.data.cpu().numpy())
234 | self.y_pred_last = self.y_pred
235 | self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_))
236 |
237 | if y is not None:
238 | ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
239 | nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
240 | ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
241 | print('Initializing k-means: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
242 |
243 | self.train()
244 | num = X1.shape[0]
245 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
246 |
247 | final_ami, final_nmi, final_ari, final_epoch = 0, 0, 0, 0
248 |
249 | for epoch in range(num_epochs):
250 | if epoch%update_interval == 0:
251 | Zdata = self.encodeBatch(X1, X2, batch_size=batch_size)
252 | dist, _ = self.kmeans_loss(Zdata)
253 | self.y_pred = torch.argmin(dist, dim=1).data.cpu().numpy()
254 | if y is not None:
255 | #acc2 = np.round(cluster_acc(y, self.y_pred), 5)
256 | final_ami = ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
257 | final_nmi = nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
258 | final_ari = ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
259 | final_epoch = epoch+1
260 | print('Clustering %d: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (epoch+1, ami, nmi, ari))
261 |
262 | # check stop criterion
263 | delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num
264 | self.y_pred_last = self.y_pred
265 | if epoch>0 and delta_label < tol:
266 | print('delta_label ', delta_label, '< tol ', tol)
267 | print("Reach tolerance threshold. Stopping training.")
268 | break
269 |
270 | #save current model
271 | # if (epoch>0 and delta_label < tol) or epoch%10 == 0:
272 | # self.save_checkpoint({'epoch': epoch+1,
273 | # 'state_dict': self.state_dict(),
274 | # 'mu': self.mu,
275 | # 'y_pred': self.y_pred,
276 | # 'y_pred_last': self.y_pred_last,
277 | # 'y': y
278 | # }, epoch+1, filename=save_dir)
279 |
280 | # train 1 epoch for clustering loss
281 | loss_val = 0.0
282 | recon_loss1_val = 0.0
283 | recon_loss2_val = 0.0
284 | cluster_loss_val = 0.0
285 | kl_loss_val = 0.0
286 | for batch_idx in range(num_batch):
287 | x1_batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
288 | x_raw1_batch = X_raw1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
289 | sf1_batch = sf1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
290 | x2_batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
291 | x_raw2_batch = X_raw2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
292 | sf2_batch = sf2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
293 |
294 | inputs1 = Variable(x1_batch)
295 | rawinputs1 = Variable(x_raw1_batch)
296 | sfinputs1 = Variable(sf1_batch)
297 | inputs2 = Variable(x2_batch)
298 | rawinputs2 = Variable(x_raw2_batch)
299 | sfinputs2 = Variable(sf2_batch)
300 |
301 | zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forward(inputs1, inputs2)
302 | _, cluster_loss = self.kmeans_loss(zbatch)
303 | recon_loss1 = self.zinb_loss(x=rawinputs1, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sfinputs1)
304 | recon_loss2 = self.zinb_loss(x=rawinputs2, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sfinputs2)
305 | target2 = self.target_distribution(lqbatch)
306 | lqbatch = lqbatch + torch.diag(torch.diag(z_num))
307 | target2 = target2 + torch.diag(torch.diag(z_num))
308 | kl_loss = self.kldloss(target2, lqbatch)
309 | loss = recon_loss1 + recon_loss2 + kl_loss * self.phi2 + cluster_loss * self.gamma
310 | optimizer.zero_grad()
311 | loss.backward()
312 | # torch.nn.utils.clip_grad_norm_(self.mu, 1)
313 | optimizer.step()
314 | cluster_loss_val += cluster_loss.data * len(inputs1)
315 | recon_loss1_val += recon_loss1.data * len(inputs1)
316 | recon_loss2_val += recon_loss2.data * len(inputs2)
317 | kl_loss_val += kl_loss.data * len(inputs1)
318 | loss_val = cluster_loss_val + recon_loss1_val + recon_loss2_val + kl_loss_val
319 |
320 | if epoch%self.t == 0:
321 | print("#Epoch %d: Total: %.6f Clustering Loss: %.6f ZINB Loss1: %.6f ZINB Loss2: %.6f KL Loss: %.6f" % (
322 | epoch + 1, loss_val / num, cluster_loss_val / num, recon_loss1_val / num, recon_loss2_val / num, kl_loss_val / num))
323 |
324 | return self.y_pred, final_epoch
325 |
--------------------------------------------------------------------------------
/src/scMDC_batch.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics.pairwise import paired_distances
2 | from sklearn.decomposition import PCA
3 | from sklearn import metrics
4 | from sklearn.cluster import KMeans
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch.nn import Parameter
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.utils.data import DataLoader, TensorDataset
12 | from layers import NBLoss, ZINBLoss, MeanAct, DispAct
13 | import numpy as np
14 |
15 | import math, os
16 |
17 | from utils import torch_PCA
18 |
19 | from preprocess import read_dataset, normalize
20 | import scanpy as sc
21 |
22 | def buildNetwork1(layers, type, activation="relu"):
23 | net = []
24 | for i in range(1, len(layers)):
25 | net.append(nn.Linear(layers[i-1], layers[i]))
26 | if type=="encode" and i==len(layers)-1:
27 | break
28 | if activation=="relu":
29 | net.append(nn.ReLU())
30 | elif activation=="sigmoid":
31 | net.append(nn.Sigmoid())
32 | elif activation=="elu":
33 | net.append(nn.ELU())
34 | return nn.Sequential(*net)
35 |
36 | def buildNetwork2(layers, type, activation="relu"):
37 | net = []
38 | for i in range(1, len(layers)):
39 | net.append(nn.Linear(layers[i-1], layers[i]))
40 | net.append(nn.BatchNorm1d(layers[i], affine=True))
41 | if activation=="relu":
42 | net.append(nn.ReLU())
43 | elif activation=="selu":
44 | net.append(nn.SELU())
45 | elif activation=="sigmoid":
46 | net.append(nn.Sigmoid())
47 | elif activation=="elu":
48 | net.append(nn.ELU())
49 | return nn.Sequential(*net)
50 |
51 | class scMultiClusterBatch(nn.Module):
52 | def __init__(self, input_dim1, input_dim2, n_batch,
53 | encodeLayer=[], decodeLayer1=[], decodeLayer2=[], tau=1., t=10, device = "cuda",
54 | activation="elu", sigma1=2.5, sigma2=.1, alpha=1., gamma=1., phi1=0.0001, phi2=0.0001, cutoff = 0.5):
55 | super(scMultiClusterBatch, self).__init__()
56 | self.tau=tau
57 | self.input_dim1 = input_dim1
58 | self.input_dim2 = input_dim2
59 | self.cutoff = cutoff
60 | self.activation = activation
61 | self.sigma1 = sigma1
62 | self.sigma2 = sigma2
63 | self.alpha = alpha
64 | self.gamma = gamma
65 | self.phi1 = phi1
66 | self.phi2 = phi2
67 | self.t=t
68 | self.device = device
69 | self.encoder = buildNetwork2([input_dim1+input_dim2+n_batch]+encodeLayer, type="encode", activation=activation)
70 | self.decoder1 = buildNetwork2([decodeLayer1[0]+n_batch]+decodeLayer1[1:], type="decode", activation=activation)
71 | self.decoder2 = buildNetwork2([decodeLayer2[0]+n_batch]+decodeLayer2[1:], type="decode", activation=activation)
72 | self.dec_mean1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), MeanAct())
73 | self.dec_disp1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), DispAct())
74 | self.dec_mean2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), MeanAct())
75 | self.dec_disp2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), DispAct())
76 | self.dec_pi1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), nn.Sigmoid())
77 | self.dec_pi2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), nn.Sigmoid())
78 | self.zinb_loss = ZINBLoss()
79 | self.NBLoss = NBLoss()
80 | self.mse = nn.MSELoss()
81 | self.z_dim = encodeLayer[-1]
82 |
83 | def save_model(self, path):
84 | torch.save(self.state_dict(), path)
85 |
86 | def load_model(self, path):
87 | pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
88 | model_dict = self.state_dict()
89 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
90 | model_dict.update(pretrained_dict)
91 | self.load_state_dict(model_dict)
92 |
93 | def soft_assign(self, z):
94 | q = 1.0 / (1.0 + torch.sum((z.unsqueeze(1) - self.mu)**2, dim=2) / self.alpha)
95 | q = q**((self.alpha+1.0)/2.0)
96 | q = (q.t() / torch.sum(q, dim=1)).t()
97 | return q
98 |
99 | def cal_latent(self, z):
100 | sum_y = torch.sum(torch.square(z), dim=1)
101 | num = -2.0 * torch.matmul(z, z.t()) + torch.reshape(sum_y, [-1, 1]) + sum_y
102 | num = num / self.alpha
103 | num = torch.pow(1.0 + num, -(self.alpha + 1.0) / 2.0)
104 | zerodiag_num = num - torch.diag(torch.diag(num))
105 | latent_p = (zerodiag_num.t() / torch.sum(zerodiag_num, dim=1)).t()
106 | return num, latent_p
107 |
108 | def target_distribution(self, q):
109 | p = q**2 / q.sum(0)
110 | return (p.t() / p.sum(1)).t()
111 |
112 | def kmeans_loss(self, z):
113 | dist1 = self.tau * torch.sum(torch.square(z.unsqueeze(1) - self.mu), dim=2)
114 | temp_dist1 = dist1 - torch.reshape(torch.mean(dist1, dim=1), [-1, 1])
115 | q = torch.exp(-temp_dist1)
116 | q = (q.t() / torch.sum(q, dim=1)).t()
117 | q = torch.pow(q, 2)
118 | q = (q.t() / torch.sum(q, dim=1)).t()
119 | dist2 = dist1 * q
120 | return dist1, torch.mean(torch.sum(dist2, dim=1))
121 |
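# Same forward pass as scMDC, except the one-hot batch indicator b is concatenated to both the
# encoder input and the latent code, so the decoders can absorb batch-specific effects.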
122 | def forward(self, x1, x2, b):
123 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
124 | h = self.encoder(torch.cat([x, b], dim=-1))
125 | h = torch.cat([h, b], dim=-1)
126 |
127 | h1 = self.decoder1(h)
128 | mean1 = self.dec_mean1(h1)
129 | disp1 = self.dec_disp1(h1)
130 | pi1 = self.dec_pi1(h1)
131 |
132 | h2 = self.decoder2(h)
133 | mean2 = self.dec_mean2(h2)
134 | disp2 = self.dec_disp2(h2)
135 | pi2 = self.dec_pi2(h2)
136 |
137 | x0 = torch.cat([x1, x2], dim=-1)
138 | h0 = self.encoder(torch.cat([x0, b], dim=-1))
139 | q = self.soft_assign(h0)
140 | num, lq = self.cal_latent(h0)
141 | return h0, q, num, lq, mean1, mean2, disp1, disp2, pi1, pi2
142 |
143 | def forwardAE(self, x1, x2, b):
144 | x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
145 | h = self.encoder(torch.cat([x, b], dim=-1))
146 | h = torch.cat([h, b], dim=-1)
147 |
148 | h1 = self.decoder1(h)
149 | mean1 = self.dec_mean1(h1)
150 | disp1 = self.dec_disp1(h1)
151 | pi1 = self.dec_pi1(h1)
152 |
153 | h2 = self.decoder2(h)
154 | mean2 = self.dec_mean2(h2)
155 | disp2 = self.dec_disp2(h2)
156 | pi2 = self.dec_pi2(h2)
157 |
158 | x0 = torch.cat([x1, x2], dim=-1)
159 | h0 = self.encoder(torch.cat([x0, b], dim=-1))
160 | num, lq = self.cal_latent(h0)
161 | return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2
162 |
163 | def encodeBatch(self, X1, X2, B, batch_size=256):
164 | use_cuda = torch.cuda.is_available()
165 | if use_cuda:
166 | self.to(self.device)
167 | encoded = []
168 | self.eval()
169 | num = X1.shape[0]
170 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
171 | for batch_idx in range(num_batch):
172 | x1batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
173 | x2batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
174 | b_batch = B[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
175 | inputs1 = Variable(x1batch).to(self.device)
176 | inputs2 = Variable(x2batch).to(self.device)
177 | b_tensor = Variable(b_batch).to(self.device)
178 | z,_,_,_,_,_,_,_,_ = self.forwardAE(inputs1.float(), inputs2.float(), b_tensor.float())
179 | encoded.append(z.data)
180 |
181 | encoded = torch.cat(encoded, dim=0)
182 | return encoded
183 |
184 | def cluster_loss(self, p, q):
185 | def kld(target, pred):
186 | return torch.mean(torch.sum(target*torch.log(target/(pred+1e-6)), dim=-1))
187 | kldloss = kld(p, q)
188 | return kldloss
189 |
190 | def kldloss(self, p, q):
191 | c1 = -torch.sum(p * torch.log(q), dim=-1)
192 | c2 = -torch.sum(p * torch.log(p), dim=-1)
193 | return torch.mean(c1 - c2)
194 |
195 | def SDis_func(self, x, y):
196 | return torch.sum(torch.square(x - y), dim=1)
197 |
198 | def pretrain_autoencoder(self, X1, X_raw1, sf1, X2, X_raw2, sf2, B,
199 | batch_size=256, lr=0.001, epochs=400, ae_save=True, ae_weights='AE_weights.pth.tar'):
200 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
201 | dataset = TensorDataset(torch.Tensor(X1), torch.Tensor(X_raw1), torch.Tensor(sf1), torch.Tensor(X2), torch.Tensor(X_raw2), torch.Tensor(sf2), torch.Tensor(B))
202 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
203 | print("Pretraining stage")
204 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)
205 | counts = 0
206 | for epoch in range(epochs):
207 | loss_val = 0
208 | recon_loss1_val = 0
209 | recon_loss2_val = 0
210 | kl_loss_val = 0
211 | for batch_idx, (x1_batch, x_raw1_batch, sf1_batch, x2_batch, x_raw2_batch, sf2_batch, b_batch) in enumerate(dataloader):
212 | x1_tensor = Variable(x1_batch).to(self.device)
213 | x_raw1_tensor = Variable(x_raw1_batch).to(self.device)
214 | sf1_tensor = Variable(sf1_batch).to(self.device)
215 | x2_tensor = Variable(x2_batch).to(self.device)
216 | x_raw2_tensor = Variable(x_raw2_batch).to(self.device)
217 | sf2_tensor = Variable(sf2_batch).to(self.device)
218 | b_tensor = Variable(b_batch).to(self.device)
219 | zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forwardAE(x1_tensor, x2_tensor, b_tensor)
220 | #recon_loss1 = self.mse(mean1_tensor, x1_tensor)
221 | recon_loss1 = self.zinb_loss(x=x_raw1_tensor, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sf1_tensor)
222 | #recon_loss2 = self.mse(mean2_tensor, x2_tensor)
223 | recon_loss2 = self.zinb_loss(x=x_raw2_tensor, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sf2_tensor)
224 | lpbatch = self.target_distribution(lqbatch)
225 | lqbatch = lqbatch + torch.diag(torch.diag(z_num))
226 | lpbatch = lpbatch + torch.diag(torch.diag(z_num))
227 | kl_loss = self.kldloss(lpbatch, lqbatch)
228 | if epoch+1 >= epochs * self.cutoff:
229 | loss = recon_loss1 + recon_loss2 + kl_loss * self.phi1
230 | else:
231 | loss = recon_loss1 + recon_loss2 #+ kl_loss
232 | optimizer.zero_grad()
233 | loss.backward()
234 | optimizer.step()
235 |
236 | loss_val += loss.item() * len(x1_batch)
237 | recon_loss1_val += recon_loss1.item() * len(x1_batch)
238 | recon_loss2_val += recon_loss2.item() * len(x1_batch)
239 | if epoch+1 >= epochs * self.cutoff:
240 | kl_loss_val += kl_loss.item() * len(x1_batch)
241 |
242 | loss_val = loss_val/X1.shape[0]
243 | recon_loss1_val = recon_loss1_val/X1.shape[0]
244 | recon_loss2_val = recon_loss2_val/X1.shape[0]
245 | kl_loss_val = kl_loss_val/X1.shape[0]
246 | if epoch%self.t == 0:
247 | print('Pretrain epoch {}, Total loss:{:.6f}, ZINB loss1:{:.6f}, ZINB loss2:{:.6f}, KL loss:{:.6f}'.format(epoch+1, loss_val, recon_loss1_val, recon_loss2_val, kl_loss_val))
248 |
249 | if ae_save:
250 | torch.save({'ae_state_dict': self.state_dict(),
251 | 'optimizer_state_dict': optimizer.state_dict()}, ae_weights)
252 |
253 | def save_checkpoint(self, state, index, filename):
254 | newfilename = os.path.join(filename, 'FTcheckpoint_%d.pth.tar' % index)
255 | torch.save(state, newfilename)
256 |
257 | def fit(self, X1, X_raw1, sf1, X2, X_raw2, sf2, B, y=None, lr=1., n_clusters = 4,
258 | batch_size=256, num_epochs=10, update_interval=1, tol=1e-3, save_dir=""):
259 | '''X: tensor data'''
260 | use_cuda = torch.cuda.is_available()
261 | if use_cuda:
262 | self.to(self.device)
263 | print("Clustering stage")
264 | X1 = torch.tensor(X1).to(self.device)
265 | X_raw1 = torch.tensor(X_raw1).to(self.device)
266 | sf1 = torch.tensor(sf1).to(self.device)
267 | X2 = torch.tensor(X2).to(self.device)
268 | X_raw2 = torch.tensor(X_raw2).to(self.device)
269 | sf2 = torch.tensor(sf2).to(self.device)
270 | B = torch.tensor(B).to(self.device)
271 | self.mu = Parameter(torch.Tensor(n_clusters, self.z_dim), requires_grad=True)
272 | optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95)
273 | #optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=0.001)
274 |
275 | print("Initializing cluster centers with kmeans.")
276 | kmeans = KMeans(n_clusters, n_init=20)
277 | Zdata = self.encodeBatch(X1, X2, B, batch_size=batch_size)
278 | #latent
279 | self.y_pred = kmeans.fit_predict(Zdata.data.cpu().numpy())
280 | self.y_pred_last = self.y_pred
281 | self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_))
282 | if y is not None:
283 | ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
284 | nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
285 | ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
286 | print('Initializing k-means: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))
287 |
288 | self.train()
289 | num = X1.shape[0]
290 | num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
291 |
292 | final_ami, final_nmi, final_ari, final_epoch = 0, 0, 0, 0
293 |
294 | for epoch in range(num_epochs):
295 | if epoch%update_interval == 0:
296 | # update the target distribution p
297 | Zdata = self.encodeBatch(X1, X2, B, batch_size=batch_size)
298 |
299 | # evaluate the clustering performance
300 | dist, _ = self.kmeans_loss(Zdata)
301 | self.y_pred = torch.argmin(dist, dim=1).data.cpu().numpy()
302 |
303 | if y is not None:
304 | #acc2 = np.round(cluster_acc(y, self.y_pred), 5)
305 | final_ami = ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
306 | final_nmi = nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
307 | final_ari = ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
308 | final_epoch = epoch+1
309 | print('Clustering %d: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (epoch+1, ami, nmi, ari))
310 |
311 | # check stop criterion
312 | delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num
313 | self.y_pred_last = self.y_pred
314 | if epoch>0 and delta_label < tol:
315 | print('delta_label ', delta_label, '< tol ', tol)
316 | print("Reach tolerance threshold. Stopping training.")
317 | break
318 |
319 | # save current model
320 | # if (epoch>0 and delta_label < tol) or epoch%10 == 0:
321 | # self.save_checkpoint({'epoch': epoch+1,
322 | # 'state_dict': self.state_dict(),
323 | # 'mu': self.mu,
324 | # 'y_pred': self.y_pred,
325 | # 'y_pred_last': self.y_pred_last,
326 | # 'y': y
327 | # }, epoch+1, filename=save_dir)
328 |
329 | # train 1 epoch for clustering loss
330 | train_loss = 0.0
331 | recon_loss1_val = 0.0
332 | recon_loss2_val = 0.0
333 | recon_loss_latent_val = 0.0
334 | cluster_loss_val = 0.0
335 | kl_loss_val = 0.0
336 | for batch_idx in range(num_batch):
337 | x1_batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
338 | x_raw1_batch = X_raw1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
339 | sf1_batch = sf1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
340 | x2_batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
341 | x_raw2_batch = X_raw2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
342 | sf2_batch = sf2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
343 | b_batch = B[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
344 | optimizer.zero_grad()
345 | inputs1 = Variable(x1_batch)
346 | rawinputs1 = Variable(x_raw1_batch)
347 | sfinputs1 = Variable(sf1_batch)
348 | inputs2 = Variable(x2_batch)
349 | rawinputs2 = Variable(x_raw2_batch)
350 | sfinputs2 = Variable(sf2_batch)
351 |
352 | zbatch, qbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forward(inputs1.float(), inputs2.float(), b_batch.float())
353 |
354 | _, cluster_loss = self.kmeans_loss(zbatch)
355 | recon_loss1 = self.zinb_loss(x=rawinputs1, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sfinputs1)
356 | recon_loss2 = self.zinb_loss(x=rawinputs2, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sfinputs2)
357 | target2 = self.target_distribution(lqbatch)
358 | lqbatch = lqbatch + torch.diag(torch.diag(z_num))
359 | target2 = target2 + torch.diag(torch.diag(z_num))
360 | kl_loss = self.kldloss(target2, lqbatch)
361 | loss = cluster_loss * self.gamma + kl_loss * self.phi2 + recon_loss1 + recon_loss2
362 | loss.backward()
363 | torch.nn.utils.clip_grad_norm_(self.mu, 1)
364 | optimizer.step()
365 | cluster_loss_val += cluster_loss.data * len(inputs1)
366 | recon_loss1_val += recon_loss1.data * len(inputs1)
367 | recon_loss2_val += recon_loss2.data * len(inputs2)
368 | kl_loss_val += kl_loss.data * len(inputs1)
369 | loss_val = cluster_loss_val + recon_loss1_val + recon_loss2_val + kl_loss_val
370 |
371 | if epoch%self.t == 0:
372 | print("#Epoch %d: Total: %.6f Clustering Loss: %.6f ZINB Loss: %.6f ZINB Loss2: %.6f KL Loss: %.6f" % (
373 | epoch + 1, loss_val / num, cluster_loss_val / num, recon_loss1_val / num, recon_loss2_val / num, kl_loss_val / num))
374 |
375 | return self.y_pred, final_epoch
376 |
--------------------------------------------------------------------------------
/src/tree.txt:
--------------------------------------------------------------------------------
1 | ((A:3,H:3):3,(B:2,C:2):4,(D:1,E:1):5,(H:4,G:4):2);
2 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import scanpy as sc
5 | from scipy import stats, spatial, sparse
6 | from scipy.linalg import norm
7 | from sklearn.metrics.pairwise import euclidean_distances
8 | import numpy as np
9 | import random
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init as init
13 | import torch.utils.data as data
14 | from sklearn.neighbors import kneighbors_graph
15 |
16 | def cluster_acc(y_true, y_pred):
17 | """
18 | Calculate clustering accuracy. Require scikit-learn installed
19 | # Arguments
20 | y: true labels, numpy.array with shape `(n_samples,)`
21 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
22 | # Return
23 | accuracy, in [0,1]
24 | """
25 | y_true = y_true.astype(np.int64)
26 | assert y_pred.size == y_true.size
27 | D = max(y_pred.max(), y_true.max()) + 1
28 | w = np.zeros((D, D), dtype=np.int64)
29 | for i in range(y_pred.size):
30 | w[y_pred[i], y_true[i]] += 1
31 | from scipy.optimize import linear_sum_assignment  # sklearn.utils.linear_assignment_ was removed; scipy provides the same Hungarian solver
32 | row_ind, col_ind = linear_sum_assignment(w.max() - w)
33 | return w[row_ind, col_ind].sum() * 1.0 / y_pred.size
34 |
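# GetCluster: estimate the number of clusters by Louvain clustering on a kNN graph of the
# embedding (subsampled to 200,000 cells for very large datasets).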
35 | def GetCluster(X, res, n):
36 | adata0=sc.AnnData(X)
37 | if adata0.shape[0]>200000:
38 | np.random.seed(adata0.shape[0])#set seed
39 | adata0=adata0[np.random.choice(adata0.shape[0],200000,replace=False)]
40 | sc.pp.neighbors(adata0, n_neighbors=n, use_rep="X")
41 | sc.tl.louvain(adata0,resolution=res)
42 | Y_pred_init=adata0.obs['louvain']
43 | Y_pred_init=np.asarray(Y_pred_init,dtype=int)
44 | if np.unique(Y_pred_init).shape[0]<=1:
45 | # avoid returning a single cluster
46 | sys.exit("Error: only one cluster was detected. The resolution " + str(res) + " is too small; please choose a larger resolution.")
47 | else:
48 | print("Estimated n_clusters is: ", np.shape(np.unique(Y_pred_init))[0])
49 | return(np.shape(np.unique(Y_pred_init))[0])
50 |
51 | def torch_PCA(X, k, center=True, scale=False):
52 | X = X.t()
53 | n,p = X.size()
54 | ones = torch.ones(n).cuda().view([n,1])
55 | h = ((1/n) * torch.mm(ones, ones.t())) if center else torch.zeros(n*n).view([n,n])
56 | H = torch.eye(n).cuda() - h
57 | X_center = torch.mm(H.double(), X.double())
58 | covariance = 1/(n-1) * torch.mm(X_center.t(), X_center).view(p,p)
59 | scaling = torch.sqrt(1/torch.diag(covariance)).double() if scale else torch.ones(p).cuda().double()
60 | scaled_covariance = torch.mm(torch.diag(scaling).view(p,p), covariance)
61 | eigenvalues, eigenvectors = torch.eig(scaled_covariance, True)
62 | components = (eigenvectors[:, :k])
63 | #explained_variance = eigenvalues[:k, 0]
64 | return components
65 |
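# best_map: relabel the predicted clusters so they best match the ground-truth labels, using a
# Hungarian assignment on the label-overlap matrix G.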
66 | def best_map(L1,L2):
67 | #L1 should be the groundtruth labels and L2 should be the clustering labels we got
68 | Label1 = np.unique(L1)
69 | nClass1 = len(Label1)
70 | Label2 = np.unique(L2)
71 | nClass2 = len(Label2)
72 | nClass = np.maximum(nClass1,nClass2)
73 | G = np.zeros((nClass,nClass))
74 | for i in range(nClass1):
75 | ind_cla1 = L1 == Label1[i]
76 | ind_cla1 = ind_cla1.astype(float)
77 | for j in range(nClass2):
78 | ind_cla2 = L2 == Label2[j]
79 | ind_cla2 = ind_cla2.astype(float)
80 | G[i,j] = np.sum(ind_cla2 * ind_cla1)
81 | # Hungarian assignment via scipy (the original called Munkres(), which is never imported here)
82 | from scipy.optimize import linear_sum_assignment
83 | row_ind, col_ind = linear_sum_assignment(-G.T)
84 | c = col_ind
85 | newL2 = np.zeros(L2.shape)
86 | for i in range(nClass2):
87 | newL2[L2 == Label2[i]] = Label1[c[i]]
88 | return newL2
89 |
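# geneSelection: select highly variable genes by thresholding the zero-expression rate against an
# exponential curve of the mean log2 nonzero expression; when n is given, xoffset is bisected
# (up to 100 iterations) so that approximately n genes pass the threshold.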
90 | def geneSelection(data, threshold=0, atleast=10,
91 | yoffset=.02, xoffset=5, decay=1.5, n=None,
92 | plot=True, markers=None, genes=None, figsize=(6,3.5),
93 | markeroffsets=None, labelsize=10, alpha=1, verbose=1):
94 |
95 | if sparse.issparse(data):
96 | zeroRate = 1 - np.squeeze(np.array((data>threshold).mean(axis=0)))
97 | A = data.multiply(data>threshold)
98 | A.data = np.log2(A.data)
99 | meanExpr = np.zeros_like(zeroRate) * np.nan
100 | detected = zeroRate < 1
101 | meanExpr[detected] = np.squeeze(np.array(A[:,detected].mean(axis=0))) / (1-zeroRate[detected])
102 | else:
103 | zeroRate = 1 - np.mean(data>threshold, axis=0)
104 | meanExpr = np.zeros_like(zeroRate) * np.nan
105 | detected = zeroRate < 1
106 | mask = data[:,detected]>threshold
107 | logs = np.zeros_like(data[:,detected]) * np.nan
108 | logs[mask] = np.log2(data[:,detected][mask])
109 | meanExpr[detected] = np.nanmean(logs, axis=0)
110 |
111 | lowDetection = np.array(np.sum(data>threshold, axis=0)).squeeze() < atleast
112 | zeroRate[lowDetection] = np.nan
113 | meanExpr[lowDetection] = np.nan
114 |
115 | if n is not None:
116 | up = 10
117 | low = 0
118 | for t in range(100):
119 | nonan = ~np.isnan(zeroRate)
120 | selected = np.zeros_like(zeroRate).astype(bool)
121 | selected[nonan] = zeroRate[nonan] > np.exp(-decay*(meanExpr[nonan] - xoffset)) + yoffset
122 | if np.sum(selected) == n:
123 | break
124 | elif np.sum(selected) < n:
125 | up = xoffset
126 | xoffset = (xoffset + low)/2
127 | else:
128 | low = xoffset
129 | xoffset = (xoffset + up)/2
130 | if verbose>0:
131 | print('Chosen offset: {:.2f}'.format(xoffset))
132 | else:
133 | nonan = ~np.isnan(zeroRate)
134 | selected = np.zeros_like(zeroRate).astype(bool)
135 | selected[nonan] = zeroRate[nonan] > np.exp(-decay*(meanExpr[nonan] - xoffset)) + yoffset
136 |
137 | if plot:
import matplotlib.pyplot as plt  # plotting dependencies, only needed when plot=True
import seaborn as sns
138 | if figsize is not None:
139 | plt.figure(figsize=figsize)
140 | plt.ylim([0, 1])
141 | if threshold>0:
142 | plt.xlim([np.log2(threshold), np.ceil(np.nanmax(meanExpr))])
143 | else:
144 | plt.xlim([0, np.ceil(np.nanmax(meanExpr))])
145 | x = np.arange(plt.xlim()[0], plt.xlim()[1]+.1,.1)
146 | y = np.exp(-decay*(x - xoffset)) + yoffset
147 | if decay==1:
148 | plt.text(.4, 0.2, '{} genes selected\ny = exp(-x+{:.2f})+{:.2f}'.format(np.sum(selected),xoffset, yoffset),
149 | color='k', fontsize=labelsize, transform=plt.gca().transAxes)
150 | else:
151 | plt.text(.4, 0.2, '{} genes selected\ny = exp(-{:.1f}*(x-{:.2f}))+{:.2f}'.format(np.sum(selected),decay,xoffset, yoffset),
152 | color='k', fontsize=labelsize, transform=plt.gca().transAxes)
153 |
154 | plt.plot(x, y, color=sns.color_palette()[1], linewidth=2)
155 | xy = np.concatenate((np.concatenate((x[:,None],y[:,None]),axis=1), np.array([[plt.xlim()[1], 1]])))
156 | t = plt.matplotlib.patches.Polygon(xy, color=sns.color_palette()[1], alpha=.4)
157 | plt.gca().add_patch(t)
158 |
159 | plt.scatter(meanExpr, zeroRate, s=1, alpha=alpha, rasterized=True)
160 | if threshold==0:
161 | plt.xlabel('Mean log2 nonzero expression')
162 | plt.ylabel('Frequency of zero expression')
163 | else:
164 | plt.xlabel('Mean log2 nonzero expression')
165 | plt.ylabel('Frequency of near-zero expression')
166 | plt.tight_layout()
167 |
168 | if markers is not None and genes is not None:
169 | if markeroffsets is None:
170 | markeroffsets = [(0, 0) for g in markers]
171 | for num,g in enumerate(markers):
172 | i = np.where(genes==g)[0]
173 | plt.scatter(meanExpr[i], zeroRate[i], s=10, color='k')
174 | dx, dy = markeroffsets[num]
175 | plt.text(meanExpr[i]+dx+.1, zeroRate[i]+dy, g, color='k', fontsize=labelsize)
176 |
177 | return selected
178 |
--------------------------------------------------------------------------------