├── Cluster_model ├── Cluster.py ├── GNN.py ├── KNN.py ├── __pycache__ │ ├── GNN.cpython-37.pyc │ ├── KNN.cpython-37.pyc │ ├── Model.cpython-37.pyc │ ├── evaluation.cpython-37.pyc │ ├── preprocess.cpython-37.pyc │ ├── pretrain.cpython-37.pyc │ ├── utils.cpython-37.pyc │ └── utilss.cpython-37.pyc ├── dataset │ ├── Quake_10x_Limb_Muscle │ │ └── data.h5 │ └── Yan │ │ ├── data.csv │ │ └── label.csv ├── evaluation.py ├── preprocess.py ├── pretain_model │ ├── Adam │ │ └── Adam.pkl │ ├── Camp_Brain │ │ └── Camp_Brain.pkl │ ├── Camp_Liver │ │ └── Camp_Liver.pkl │ ├── Chung │ │ └── Chung.pkl │ ├── Klein │ │ └── Klein.pkl │ ├── Kolodziejczyk │ │ └── Kolodziejczyk.pkl │ ├── Muraro │ │ └── Muraro.pkl │ ├── Quake_10x_Limb_Muscle │ │ └── Quake_10x_Limb_Muscle.pkl │ ├── Quake_Smart-seq2_Diaphragm │ │ └── Quake_Smart-seq2_Diaphragm3.pkl │ ├── Quake_Smart-seq2_Limb_Muscle │ │ └── Quake_Smart-seq2_Limb_Muscle127.pkl │ ├── Quake_Smart-seq2_Lung │ │ └── Quake_Smart-seq2_Lung.pkl │ ├── Yan │ │ └── Yan.pkl │ ├── Young │ │ └── Young.pkl │ ├── Zeisel │ │ └── Zeisel.pkl │ ├── human │ │ └── human.pkl │ ├── mouse │ │ └── mouse.pkl │ ├── panc │ │ └── panc.pkl │ └── pbmc │ │ └── pbmc.pkl ├── pretrain.py ├── utils.py └── utilss.py ├── Interaction_model ├── Feature.R ├── Feature.py ├── Interaction_inference.py ├── LRDB │ ├── LRDB.human.rda │ ├── LRDB.mouse.rda │ └── myCompute.RData ├── Mobilev2.py ├── Train │ ├── Mobilev2.py │ ├── ResNet.py │ ├── __pycache__ │ │ ├── Mobilev2.cpython-37.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── modelv2.cpython-37.pyc │ │ └── test.cpython-37.pyc │ ├── data │ │ └── readme.md │ ├── dataset.py │ ├── modelv2.py │ ├── test.py │ ├── train_kfold.py │ └── trainall_no.py ├── __pycache__ │ ├── Mobilev2.cpython-37.pyc │ ├── dataset.cpython-37.pyc │ ├── modelv2.cpython-37.pyc │ └── utils.cpython-37.pyc ├── cluster │ ├── CellAnnotate.R │ ├── CellAnnotate.py │ ├── Cluster.py │ ├── Feature.R │ ├── GNN.py │ ├── __pycache__ │ │ ├── GNN.cpython-37.pyc │ │ ├── KNN.cpython-37.pyc │ │ ├── Model.cpython-37.pyc │ │ ├── evaluation.cpython-37.pyc │ │ ├── preprocess.cpython-37.pyc │ │ ├── pretrain.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── evaluation.py │ ├── preprocess.py │ ├── pretain_model │ │ └── pbmc │ │ │ └── pbmc.pkl │ ├── pretrain.py │ ├── test.txt │ └── utils.py ├── dataset.py ├── input │ ├── Download.md │ ├── test_cell_label.csv │ └── test_label.txt ├── model │ └── checkpoint-000100.pth ├── modelv2.py ├── nohup.out └── utils.py ├── LICENSE ├── Plot ├── Plot.py ├── bubble.R ├── cell_type.csv ├── chord.R ├── chord.py ├── heatmap.R ├── heatmap.py ├── network.R └── network.py ├── README.md ├── Rpack.Rdata └── requirements.txt /Cluster_model/Cluster.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | from preprocess import * 6 | from pretrain import * 7 | import sys 8 | import argparse 9 | import random 10 | from sklearn.cluster import SpectralBiclustering,KMeans, kmeans_plusplus, DBSCAN,SpectralClustering 11 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 12 | from sklearn.metrics import adjusted_rand_score as ari_score 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.nn.parameter import Parameter 17 | from torch.optim import Adam ,SGD,Adamax 18 | from torch.utils.data import DataLoader 19 | from torch.nn import Linear 20 | from utils import load_data, load_graph 21 | from 
GNN import GNNLayer 22 | import umap 23 | from evaluation import eva,eva_pretrain 24 | from collections import Counter 25 | from sklearn.manifold import TSNE 26 | import matplotlib.pyplot as plt 27 | from collections import OrderedDict 28 | 29 | 30 | 31 | def plot(X, fig, col, size, true_labels,ann): 32 | ax = fig.add_subplot(1, 1, 1) 33 | for i, point in enumerate(X): 34 | ax.scatter(point[0], point[1], s=size, c=col[true_labels[i]],label=ann[i]) 35 | 36 | 37 | def plotClusters(hidden_emb, true_labels,ann): 38 | # Doing dimensionality reduction for plotting 39 | Umap = umap.UMAP(random_state=42) 40 | X_umap = Umap.fit_transform(hidden_emb) 41 | fig2 = plt.figure(figsize=(10,10),dpi=500) 42 | plot(X_umap, fig2, ['green','brown','purple','orange','yellow','hotpink','red','cyan','blue'], 8, true_labels,ann) 43 | handles, labels = fig2.gca().get_legend_handles_labels() 44 | by_label = OrderedDict(zip(labels, handles)) 45 | fig2.legend(by_label.values(), by_label.keys(),loc="upper right") 46 | #fig2.legend() 47 | fig2.savefig("./dataset/"+args.name+"/UMAP.pdf") 48 | plt.close() 49 | 50 | def init_seed(opt): 51 | torch.cuda.cudnn_enabled = False 52 | np.random.seed(opt.seed) 53 | torch.manual_seed(opt.seed) 54 | torch.cuda.manual_seed(opt.seed) 55 | def pretarin_cluster(n_clusters,x,device): 56 | 57 | #print("generate cell graph...") 58 | Auto = args.Auto 59 | #calculate the number of clusters 60 | 61 | 62 | 63 | device = device 64 | 65 | silhouette_pre=[] 66 | print("Start pretrain") 67 | for i in range(args.pretrain_frequency): 68 | print("pretrain:"+str(i)) 69 | model = AE( 70 | n_enc_1=100, 71 | n_enc_2=200, 72 | n_enc_3=200, 73 | n_dec_1=200, 74 | n_dec_2=200, 75 | n_dec_3=100, 76 | n_input=2000, 77 | n_z=5).to(device) 78 | dataset = LoadDataset(x) 79 | epoch = args.pretrain_epoch 80 | silhouette=pretrain_ae(model,dataset,i,device,n_clusters,epoch,args.name,Auto=Auto) 81 | silhouette_pre.append(silhouette) 82 | silhouette_pre = np.array(silhouette_pre) 83 | premodel_i=np.where(silhouette_pre==np.max(silhouette_pre))[0][0] 84 | print("Pretrain end") 85 | return premodel_i 86 | 87 | class AE_train(nn.Module): 88 | 89 | def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 90 | n_input, n_z): 91 | super(AE_train, self).__init__() 92 | self.enc_1 = Linear(n_input, n_enc_1) 93 | self.enc_2 = Linear(n_enc_1, n_enc_2) 94 | self.enc_3 = Linear(n_enc_2, n_enc_3) 95 | self.z_layer = Linear(n_enc_3, n_z) 96 | 97 | self.dec_1 = Linear(n_z, n_dec_1) 98 | self.dec_2 = Linear(n_dec_1, n_dec_2) 99 | self.dec_3 = Linear(n_dec_2, n_dec_3) 100 | self.x_bar_layer = Linear(n_dec_3, n_input) 101 | 102 | def forward(self, x): 103 | enc_h1 = F.relu(self.enc_1(x)) 104 | enc_h2 = F.relu(self.enc_2(enc_h1)) 105 | enc_h3 = F.relu(self.enc_3(enc_h2)) 106 | z = self.z_layer(enc_h3) 107 | 108 | dec_h1 = F.relu(self.dec_1(z)) 109 | dec_h2 = F.relu(self.dec_2(dec_h1)) 110 | dec_h3 = F.relu(self.dec_3(dec_h2)) 111 | x_bar = self.x_bar_layer(dec_h3) 112 | 113 | return x_bar, enc_h1, enc_h2, enc_h3, z 114 | 115 | 116 | class ClusterModel(nn.Module): 117 | 118 | def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 119 | n_input, n_z, n_clusters, v=1): 120 | super(ClusterModel, self).__init__() 121 | 122 | # autoencoder for intra information 123 | 124 | self.ae = AE_train( 125 | n_enc_1=n_enc_1, 126 | n_enc_2=n_enc_2, 127 | n_enc_3=n_enc_3, 128 | n_dec_1=n_dec_1, 129 | n_dec_2=n_dec_2, 130 | n_dec_3=n_dec_3, 131 | n_input=n_input, 132 | n_z=n_z) 133 | 
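# initialise the autoencoder branch from the pretrained weights selected via args.pretrain_path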
self.ae.load_state_dict(torch.load(args.pretrain_path, map_location='cpu')) 134 | 135 | 136 | #self.ae.load_state_dict(torch.load(args.pretrain_path, map_location='cpu')) 137 | 138 | # GCN for inter information 139 | self.gnn_1 = GNNLayer(n_input, n_enc_1) 140 | self.gnn_2 = GNNLayer(n_enc_1, n_enc_2) 141 | self.gnn_3 = GNNLayer(n_enc_2, n_enc_3) 142 | self.gnn_4 = GNNLayer(n_enc_3, n_z) 143 | self.gnn_5 = GNNLayer(n_z, n_clusters) 144 | 145 | # cluster layer 146 | self.cluster_layer = Parameter(torch.Tensor(n_clusters, n_z)) 147 | torch.nn.init.xavier_normal_(self.cluster_layer.data) 148 | 149 | # degree 150 | self.v = v 151 | 152 | def forward(self, x, adj): 153 | # DNN Module 154 | #x_bar, tra1, tra2, tra3, z = self.ae(x) 155 | #print(x.size()) 156 | # GCN Module 157 | h1 = self.gnn_1(x, adj) 158 | h2 = self.gnn_2(h1, adj) 159 | h3 = self.gnn_3(h2, adj) 160 | h4 = self.gnn_4(h3, adj) 161 | h5 = self.gnn_5(h4, adj, active=False) 162 | predict = F.softmax(h5, dim=1) 163 | 164 | 165 | enc_h1 = F.relu(self.ae.enc_1(x)) 166 | #print(enc_h1.size()) 167 | enc_h2 = F.relu(self.ae.enc_2(enc_h1+h1)) 168 | enc_h3 = F.relu(self.ae.enc_3(enc_h2+h2)) 169 | z = self.ae.z_layer(enc_h3+h3) 170 | 171 | dec_h1 = F.relu(self.ae.dec_1(z+h4)) 172 | dec_h2 = F.relu(self.ae.dec_2(dec_h1+h3)) 173 | dec_h3 = F.relu(self.ae.dec_3(dec_h2+h2)) 174 | x_bar = self.ae.x_bar_layer(dec_h3+h1) 175 | 176 | 177 | 178 | # Dual Self-supervised Module 179 | q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2) / self.v) 180 | q = q.pow((self.v + 1.0) / 2.0) 181 | q = (q.t() / torch.sum(q, 1)).t() 182 | 183 | return x_bar, q, predict, z 184 | 185 | 186 | def target_distribution(q): 187 | weight = q**2 / q.sum(0) 188 | return (weight.t() / weight.sum(1)).t() 189 | def adjust_learning_rate(optimizer, epoch): 190 | lr = 0.001 * (0.1 ** (epoch // 20)) 191 | for param_group in optimizer.param_groups: 192 | param_group['lr'] = lr 193 | 194 | def train_cluster(dataset,n_clusters,device): 195 | Auto=args.Auto 196 | if Auto: 197 | if z.shape[0] < 2000: 198 | resolution = 0.8 199 | else: 200 | resolution = 0.5 201 | n_clusters = int(n_clusters*resolution) if int(n_clusters*resolution)>=3 else 2 202 | else: 203 | n_clusters=n_clusters 204 | device = device 205 | model = ClusterModel(100, 200, 200, 200, 200, 100, 206 | n_input=args.n_input, 207 | n_z=args.n_z, 208 | n_clusters=n_clusters).to(device) 209 | 210 | 211 | optimizer = Adamax(model.parameters(), lr=args.lr) 212 | 213 | # KNN Graph 214 | adj = load_graph(args.name) 215 | adj = adj.to(device) 216 | 217 | # cluster parameter initiate 218 | data = torch.Tensor(dataset.x).to(device) 219 | y = dataset.y 220 | with torch.no_grad(): 221 | _, _, _, _, z = model.ae(data) 222 | #print(n_clusters) 223 | kmeans = KMeans(n_clusters=n_clusters, n_init=20) 224 | y_pred = kmeans.fit_predict(z.data.cpu().numpy()) 225 | y_pred_last = y_pred 226 | model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device) 227 | 228 | #print(meta.shape) 229 | for epoch in range(args.Train_epoch): 230 | adjust_learning_rate(optimizer, epoch) 231 | 232 | if epoch % 1 == 0: 233 | # update_interval 234 | _, tmp_q, pred, _ = model(data, adj) 235 | tmp_q = tmp_q.data 236 | p = target_distribution(tmp_q) 237 | 238 | res1 = tmp_q.cpu().numpy().argmax(1) #Q 239 | res2 = pred.data.cpu().numpy().argmax(1) #Z 240 | res3 = p.data.cpu().numpy().argmax(1) #P 241 | nmi,ari,ami,silhouette=eva(tmp_q.cpu().numpy(),y, res1, str(epoch) + 'Q') 242 | 243 | print(str(epoch) + 'Q', 244 | ', nmi 
{:.4f}'.format(nmi), ', ari {:.4f}'.format(ari), 245 | ', ami {:.4f}'.format(ami),', silhouette {:.4f}'.format(silhouette) 246 | ) 247 | 248 | x_bar, q, pred, _ = model(data, adj) 249 | 250 | kl_loss = F.kl_div(q.log(), p, reduction='batchmean') 251 | ce_loss = F.kl_div(pred.log(), p, reduction='batchmean') 252 | re_loss = F.mse_loss(x_bar, data) 253 | #loss = 0.1*kl_loss + 1*ce_loss + 0.001*re_loss 254 | loss =0.0001*kl_loss + 0.001*ce_loss + 1*re_loss 255 | 256 | optimizer.zero_grad() 257 | loss.backward() 258 | optimizer.step() 259 | #np.savetxt("./output/pre_label.txt",res1,fmt="%s",delimiter=",") 260 | np.savetxt("./dataset/"+args.name+"/pre_embedding.txt",tmp_q.cpu().numpy(),fmt="%s",delimiter=",") 261 | np.savetxt("./dataset/"+args.name+"/pre_label.csv",res1,fmt="%s",delimiter=",") 262 | #pd.DataFrame(res1,index=index,columns=columns).to_csv("./dataset/"+args.name+"/pre_label.csv",quoting=1) 263 | #size = len(np.unique(res1)) 264 | #drawUMAP(tmp_q.cpu().numpy(), res1, size, saveFlag=True) 265 | 266 | 267 | 268 | if __name__ == "__main__": 269 | 270 | 271 | parser = argparse.ArgumentParser( 272 | description='Cell_cluster', 273 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 274 | parser.add_argument('--name', type=str, default='Yan') 275 | parser.add_argument('--seed', type=int, default=2022, help='Random seed.') 276 | parser.add_argument('--lr', type=float, default=0.001) 277 | parser.add_argument('--n_z', default=5, type=int) 278 | parser.add_argument('--pretrain_epoch', default=50, type=int) 279 | parser.add_argument('--pretrain_frequency', default=20, type=int) 280 | parser.add_argument('--Train_epoch', default=30, type=int) 281 | parser.add_argument('--n_input', default=2000, type=int) 282 | parser.add_argument('--pretrain_path', type=str, default='pkl') 283 | parser.add_argument('--Auto', default=False) 284 | parser.add_argument('--pretain', default=True) 285 | parser.add_argument('--device', type=str, default='cuda', 286 | help="Device: 'cuda' or 'cpu'") 287 | args = parser.parse_args() 288 | args.cuda = torch.cuda.is_available() 289 | print("use cuda: {}".format(args.cuda)) 290 | if not os.path.exists("./dataset/"+args.name+"/data"): 291 | os.system('mkdir ./dataset/'+args.name+'/data') 292 | if not os.path.exists("./dataset/"+args.name+"/graph"): 293 | os.system('mkdir ./dataset/'+args.name+'/graph') 294 | if not os.path.exists("./dataset/"+args.name+"/model"): 295 | os.system('mkdir ./dataset/'+args.name+'/model') 296 | 297 | 298 | device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 299 | #args.Auto = False 300 | x = np.loadtxt("./dataset/"+args.name+"/data/" + args.name + ".txt", dtype=float) 301 | y = np.loadtxt("./dataset/"+args.name+"/data/" + args.name + "_label.txt", dtype=int) 302 | if args.Auto: 303 | auto_clusters = getcluster(x) 304 | n_clusters = auto_clusters 305 | else: 306 | #cluster_number = int(max(Y) - min(Y) + 1) 307 | n_clusters = int(max(y) - min(y) + 1) 308 | 309 | if args.pretain: 310 | premodel_i = pretarin_cluster(n_clusters,x,device) 311 | #print(premodel_i) 312 | #pretrain_path 313 | args.pretrain_path = './dataset/'+args.name+'/model/'+args.name+str(premodel_i)+'.pkl' 314 | else: 315 | #pretain_model 316 | args.pretrain_path = './pretain_model/'+args.name+'/'+args.name+'.pkl' 317 | 318 | dataset = load_data(args.name) 319 | train_cluster(dataset,n_clusters,device) 320 | -------------------------------------------------------------------------------- /Cluster_model/GNN.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | from torch.nn.modules.module import Module 6 | 7 | 8 | class GNNLayer(Module): 9 | def __init__(self, in_features, out_features): 10 | super(GNNLayer, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 14 | torch.nn.init.xavier_uniform_(self.weight) 15 | 16 | def forward(self, features, adj, active=True): 17 | support = torch.mm(features, self.weight) 18 | output = torch.spmm(adj, support) 19 | if active: 20 | output = F.relu(output) 21 | return output 22 | 23 | -------------------------------------------------------------------------------- /Cluster_model/KNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import ast 4 | import operator 5 | from itertools import chain 6 | import math 7 | import os 8 | from scipy import sparse 9 | from sklearn.cluster import KMeans 10 | from sklearn.neighbors import kneighbors_graph 11 | from scipy.spatial import distance_matrix, minkowski_distance, distance 12 | import scipy.sparse 13 | import sys 14 | import pickle 15 | import csv 16 | import networkx as nx 17 | import numpy as np 18 | from sklearn.ensemble import IsolationForest 19 | import time 20 | from multiprocessing import Pool 21 | import multiprocessing 22 | from igraph import * 23 | from sklearn import preprocessing 24 | 25 | def calculateKNNgraphDistanceMatrixStatsSingleThread(featureMatrix, distanceType='euclidean', k=10, param=None): 26 | 27 | edgeList=[] 28 | 29 | p_time = time.time() 30 | for i in np.arange(featureMatrix.shape[0]): 31 | tmp=featureMatrix[i,:].reshape(1,-1) 32 | distMat = distance.cdist(tmp,featureMatrix, distanceType) 33 | res = distMat.argsort()[:k+1] 34 | tmpdist = distMat[0,res[0][1:k+1]] 35 | 36 | boundary = np.mean(tmpdist)+np.std(tmpdist) 37 | for j in np.arange(1,k+1): 38 | # TODO: check, only exclude large outliners 39 | # if (distMat[0,res[0][j]]<=mean+std) and (distMat[0,res[0][j]]>=mean-std): 40 | 41 | if distMat[0,res[0][j]]<=boundary: 42 | weight = 1.0 43 | else: 44 | weight = 0.0 45 | 46 | #weight = 1.0 47 | edgeList.append((i,res[0][j],weight)) 48 | 49 | return edgeList 50 | 51 | def calculateKNNgraphDistanceMatrix(featureMatrix, distanceType='euclidean', k=10): 52 | 53 | distMat = distance.cdist(featureMatrix,featureMatrix, distanceType) 54 | 55 | edgeList=[] 56 | 57 | for i in np.arange(distMat.shape[0]): 58 | res = distMat[:,i].argsort()[:k] 59 | for j in np.arange(k): 60 | edgeList.append((i,res[j])) 61 | 62 | return edgeList 63 | 64 | def edgeList2edgeDict(edgeList, nodesize): 65 | graphdict={} 66 | tdict={} 67 | 68 | for edge in edgeList: 69 | end1 = edge[0] 70 | end2 = edge[1] 71 | tdict[end1]="" 72 | tdict[end2]="" 73 | if end1 in graphdict: 74 | tmplist = graphdict[end1] 75 | else: 76 | tmplist = [] 77 | tmplist.append(end2) 78 | graphdict[end1]= tmplist 79 | 80 | #check and get full matrix 81 | for i in range(nodesize): 82 | if i not in tdict: 83 | graphdict[i]=[] 84 | 85 | return graphdict 86 | 87 | def generateAdj(featureMatrix, graphType='KNNgraph', para = None): 88 | """ 89 | Generating edgeList 90 | """ 91 | edgeList = None 92 | adj = None 93 | #edgeList = calculateKNNgraphDistanceMatrix(featureMatrix) 94 | edgeList = 
calculateKNNgraphDistanceMatrixStatsSingleThread(featureMatrix) 95 | graphdict = edgeList2edgeDict(edgeList, featureMatrix.shape[0]) 96 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graphdict)) 97 | return adj, edgeList 98 | 99 | 100 | def generateLouvainCluster(edgeList): 101 | """ 102 | Louvain Clustering using igraph 103 | """ 104 | Gtmp = nx.Graph() 105 | Gtmp.add_weighted_edges_from(edgeList) 106 | W = nx.adjacency_matrix(Gtmp) 107 | W = W.todense() 108 | graph = Graph.Weighted_Adjacency( 109 | W.tolist(), mode=ADJ_UNDIRECTED, attr="weight", loops=False) 110 | louvain_partition = graph.community_multilevel( 111 | weights=graph.es['weight'], return_levels=False) 112 | size = len(louvain_partition) 113 | hdict = {} 114 | count = 0 115 | for i in range(size): 116 | tlist = louvain_partition[i] 117 | for j in range(len(tlist)): 118 | hdict[tlist[j]] = i 119 | count += 1 120 | 121 | listResult = [] 122 | for i in range(count): 123 | listResult.append(hdict[i]) 124 | 125 | return listResult, size 126 | 127 | 128 | 129 | def getcluster(): 130 | 131 | #featureMatrix = pd.read_csv('./output/Top2000.csv').values[:,1:].T 132 | 133 | feature = pd.read_csv('./output/Top2000.csv',header=None,low_memory=False).values 134 | 135 | Cell_name = feature[0,1:] 136 | featureMatrix = feature[1:,1:].T 137 | np.savetxt("./output/cell_name.txt",Cell_name,fmt="%s",delimiter=" ") 138 | data=(featureMatrix.astype(np.float32)) 139 | np.savetxt("./output/data/cell.txt",data,fmt="%s",delimiter=" ") 140 | adj, edgeList = generateAdj(featureMatrix) 141 | #print(adj) 142 | idx=[] 143 | for i in range(np.array(edgeList).shape[0]): 144 | if np.array(edgeList)[i,-1]==1.0: 145 | idx.append(i) 146 | np.savetxt("./output/graph/cell_graph.csv",np.array(edgeList)[idx,0:-1],fmt="%d") 147 | listResult, size = generateLouvainCluster(edgeList) 148 | n_clusters = len(np.unique(listResult)) 149 | #print('Louvain cluster: '+str(n_clusters)) 150 | return n_clusters 151 | -------------------------------------------------------------------------------- /Cluster_model/__pycache__/GNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/GNN.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/KNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/KNN.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/Model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/Model.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/preprocess.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/preprocess.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/pretrain.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/pretrain.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/__pycache__/utilss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/__pycache__/utilss.cpython-37.pyc -------------------------------------------------------------------------------- /Cluster_model/dataset/Quake_10x_Limb_Muscle/data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/dataset/Quake_10x_Limb_Muscle/data.h5 -------------------------------------------------------------------------------- /Cluster_model/dataset/Yan/label.csv: -------------------------------------------------------------------------------- 1 | Barcode,Cluster 2 | cell1,1 3 | cell2,1 4 | cell3,1 5 | cell4,1 6 | cell5,1 7 | cell6,1 8 | cell7,2 9 | cell8,2 10 | cell9,2 11 | cell10,2 12 | cell11,2 13 | cell12,2 14 | cell13,3 15 | cell14,3 16 | cell15,3 17 | cell16,3 18 | cell17,3 19 | cell18,3 20 | cell19,3 21 | cell20,3 22 | cell21,3 23 | cell22,3 24 | cell23,3 25 | cell24,3 26 | cell25,4 27 | cell26,4 28 | cell27,4 29 | cell28,4 30 | cell29,4 31 | cell30,4 32 | cell31,4 33 | cell32,4 34 | cell33,4 35 | cell34,4 36 | cell35,4 37 | cell36,4 38 | cell37,4 39 | cell38,4 40 | cell39,4 41 | cell40,4 42 | cell41,4 43 | cell42,4 44 | cell43,4 45 | cell44,4 46 | cell45,5 47 | cell46,5 48 | cell47,5 49 | cell48,5 50 | cell49,5 51 | cell50,5 52 | cell51,5 53 | cell52,5 54 | cell53,5 55 | cell54,5 56 | cell55,5 57 | cell56,5 58 | cell57,5 59 | cell58,5 60 | cell59,5 61 | cell60,5 62 | cell61,6 63 | cell62,6 64 | cell63,6 65 | cell64,6 66 | cell65,6 67 | cell66,6 68 | cell67,6 69 | cell68,6 70 | cell69,6 71 | cell70,6 72 | cell71,6 73 | cell72,6 74 | cell73,6 75 | cell74,6 76 | cell75,6 77 | cell76,6 78 | cell77,6 79 | cell78,6 80 | cell79,6 81 | cell80,6 82 | cell81,6 83 | cell82,6 84 | cell83,6 85 | cell84,6 86 | cell85,6 87 | cell86,6 88 | cell87,6 89 | cell88,6 90 | cell89,6 91 | cell90,6 92 | -------------------------------------------------------------------------------- /Cluster_model/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from munkres import Munkres, print_matrix 3 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 4 | from sklearn.metrics import adjusted_rand_score as ari_score 5 | from 
scipy.optimize import linear_sum_assignment as linear 6 | from sklearn import metrics 7 | from sklearn.metrics import adjusted_mutual_info_score as ami_score 8 | from sklearn.metrics import silhouette_score,davies_bouldin_score 9 | 10 | def cluster_acc(y_true, y_pred): 11 | y_true = y_true - np.min(y_true) 12 | 13 | l1 = list(set(y_true)) 14 | numclass1 = len(l1) 15 | 16 | l2 = list(set(y_pred)) 17 | numclass2 = len(l2) 18 | 19 | ind = 0 20 | if numclass1 != numclass2: 21 | for i in l1: 22 | if i in l2: 23 | pass 24 | else: 25 | y_pred[ind] = i 26 | ind += 1 27 | 28 | l2 = list(set(y_pred)) 29 | numclass2 = len(l2) 30 | 31 | if numclass1 != numclass2: 32 | print('error') 33 | return 34 | 35 | cost = np.zeros((numclass1, numclass2), dtype=int) 36 | for i, c1 in enumerate(l1): 37 | mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1] 38 | for j, c2 in enumerate(l2): 39 | mps_d = [i1 for i1 in mps if y_pred[i1] == c2] 40 | cost[i][j] = len(mps_d) 41 | 42 | # match two clustering results by Munkres algorithm 43 | m = Munkres() 44 | cost = cost.__neg__().tolist() 45 | indexes = m.compute(cost) 46 | 47 | # get the match results 48 | new_predict = np.zeros(len(y_pred)) 49 | for i, c in enumerate(l1): 50 | # correponding label in l2: 51 | c2 = l2[indexes[i][1]] 52 | 53 | # ai is the index with label==c2 in the pred_label list 54 | ai = [ind for ind, elm in enumerate(y_pred) if elm == c2] 55 | new_predict[ai] = c 56 | 57 | acc = metrics.accuracy_score(y_true, new_predict) 58 | f1_macro = metrics.f1_score(y_true, new_predict, average='macro') 59 | precision_macro = metrics.precision_score(y_true, new_predict, average='macro') 60 | recall_macro = metrics.recall_score(y_true, new_predict, average='macro') 61 | f1_micro = metrics.f1_score(y_true, new_predict, average='micro') 62 | precision_micro = metrics.precision_score(y_true, new_predict, average='micro') 63 | recall_micro = metrics.recall_score(y_true, new_predict, average='micro') 64 | return acc, f1_macro 65 | 66 | 67 | def eva(X,y_true, y_pred, epoch=0): 68 | #acc, f1 = cluster_acc(y_true, y_pred) 69 | nmi = nmi_score(y_true, y_pred, average_method='arithmetic') 70 | ari = ari_score(y_true, y_pred) 71 | ami = ami_score(y_true, y_pred) 72 | silhouette = silhouette_score(X, y_pred,metric='euclidean') 73 | return nmi,ari,ami,silhouette 74 | def eva_pretrain(X, y_pred, epoch=0): 75 | silhouette = silhouette_score(X, y_pred,metric='euclidean') 76 | return silhouette -------------------------------------------------------------------------------- /Cluster_model/preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import utilss as utilss 5 | import h5py 6 | import scipy as sp 7 | import numpy as np 8 | import scanpy as sc 9 | import pandas as pd 10 | import ast 11 | import argparse 12 | import operator 13 | from itertools import chain 14 | import math 15 | import os 16 | from scipy import sparse 17 | from sklearn.cluster import KMeans 18 | from sklearn.neighbors import kneighbors_graph 19 | from scipy.spatial import distance_matrix, minkowski_distance, distance 20 | import scipy.sparse 21 | import sys 22 | import pickle 23 | import csv 24 | import networkx as nx 25 | import numpy as np 26 | from sklearn.ensemble import IsolationForest 27 | import time 28 | from multiprocessing import Pool 29 | import multiprocessing 30 | from igraph import * 31 | from sklearn import preprocessing 32 | 
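# The function below builds the KNN cell graph used downstream: for each cell, its
# k (default 10) nearest neighbours under the chosen distance metric are linked;
# neighbours within mean+std of those k neighbour distances get weight 1.0, the rest 0.0,
# and only the weight-1 edges are later written to the *_graph.txt file.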
33 | def calculateKNNgraphDistanceMatrixStatsSingleThread(featureMatrix, distanceType='euclidean', k=10, param=None): 34 | 35 | edgeList=[] 36 | 37 | p_time = time.time() 38 | for i in np.arange(featureMatrix.shape[0]): 39 | tmp=featureMatrix[i,:].reshape(1,-1) 40 | distMat = distance.cdist(tmp,featureMatrix, distanceType) 41 | res = distMat.argsort()[:k+1] 42 | tmpdist = distMat[0,res[0][1:k+1]] 43 | 44 | boundary = np.mean(tmpdist)+np.std(tmpdist) 45 | for j in np.arange(1,k+1): 46 | # TODO: check, only exclude large outliners 47 | # if (distMat[0,res[0][j]]<=mean+std) and (distMat[0,res[0][j]]>=mean-std): 48 | 49 | if distMat[0,res[0][j]]<=boundary: 50 | weight = 1.0 51 | else: 52 | weight = 0.0 53 | 54 | #weight = 1.0 55 | edgeList.append((i,res[0][j],weight)) 56 | 57 | return edgeList 58 | 59 | def calculateKNNgraphDistanceMatrix(featureMatrix, distanceType='euclidean', k=10): 60 | 61 | distMat = distance.cdist(featureMatrix,featureMatrix, distanceType) 62 | 63 | edgeList=[] 64 | 65 | for i in np.arange(distMat.shape[0]): 66 | res = distMat[:,i].argsort()[:k] 67 | for j in np.arange(k): 68 | edgeList.append((i,res[j])) 69 | 70 | return edgeList 71 | 72 | def edgeList2edgeDict(edgeList, nodesize): 73 | graphdict={} 74 | tdict={} 75 | 76 | for edge in edgeList: 77 | end1 = edge[0] 78 | end2 = edge[1] 79 | tdict[end1]="" 80 | tdict[end2]="" 81 | if end1 in graphdict: 82 | tmplist = graphdict[end1] 83 | else: 84 | tmplist = [] 85 | tmplist.append(end2) 86 | graphdict[end1]= tmplist 87 | 88 | #check and get full matrix 89 | for i in range(nodesize): 90 | if i not in tdict: 91 | graphdict[i]=[] 92 | 93 | return graphdict 94 | 95 | def generateAdj(featureMatrix, graphType='KNNgraph', para = None): 96 | """ 97 | Generating edgeList 98 | """ 99 | edgeList = None 100 | adj = None 101 | #edgeList = calculateKNNgraphDistanceMatrix(featureMatrix) 102 | edgeList = calculateKNNgraphDistanceMatrixStatsSingleThread(featureMatrix) 103 | graphdict = edgeList2edgeDict(edgeList, featureMatrix.shape[0]) 104 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graphdict)) 105 | return adj, edgeList 106 | 107 | 108 | def generateLouvainCluster(edgeList): 109 | """ 110 | Louvain Clustering using igraph 111 | """ 112 | Gtmp = nx.Graph() 113 | Gtmp.add_weighted_edges_from(edgeList) 114 | W = nx.adjacency_matrix(Gtmp) 115 | W = W.todense() 116 | graph = Graph.Weighted_Adjacency( 117 | W.tolist(), mode=ADJ_UNDIRECTED, attr="weight", loops=False) 118 | louvain_partition = graph.community_multilevel( 119 | weights=graph.es['weight'], return_levels=False) 120 | size = len(louvain_partition) 121 | hdict = {} 122 | count = 0 123 | for i in range(size): 124 | tlist = louvain_partition[i] 125 | for j in range(len(tlist)): 126 | hdict[tlist[j]] = i 127 | count += 1 128 | 129 | listResult = [] 130 | for i in range(count): 131 | listResult.append(hdict[i]) 132 | 133 | return listResult, size 134 | 135 | 136 | 137 | def read_clean(data): 138 | assert isinstance(data, np.ndarray) 139 | if data.dtype.type is np.bytes_: 140 | data = utilss.decode(data) 141 | if data.size == 1: 142 | data = data.flat[0] 143 | return data 144 | 145 | 146 | def dict_from_group(group): 147 | assert isinstance(group, h5py.Group) 148 | d = utilss.dotdict() 149 | for key in group: 150 | if isinstance(group[key], h5py.Group): 151 | value = dict_from_group(group[key]) 152 | else: 153 | value = read_clean(group[key][...]) 154 | d[key] = value 155 | return d 156 | 157 | 158 | def read_data(filename, sparsify = False, skip_exprs = False): 159 | with 
h5py.File(filename, "r") as f: 160 | obs = pd.DataFrame(dict_from_group(f["obs"]), index = utilss.decode(f["obs_names"][...])) 161 | var = pd.DataFrame(dict_from_group(f["var"]), index = utilss.decode(f["var_names"][...])) 162 | uns = dict_from_group(f["uns"]) 163 | if not skip_exprs: 164 | exprs_handle = f["exprs"] 165 | if isinstance(exprs_handle, h5py.Group): 166 | mat = sp.sparse.csr_matrix((exprs_handle["data"][...], exprs_handle["indices"][...], 167 | exprs_handle["indptr"][...]), shape = exprs_handle["shape"][...]) 168 | else: 169 | mat = exprs_handle[...].astype(np.float32) 170 | if sparsify: 171 | mat = sp.sparse.csr_matrix(mat) 172 | else: 173 | mat = sp.sparse.csr_matrix((obs.shape[0], var.shape[0])) 174 | return mat, obs, var, uns 175 | 176 | 177 | def prepro(data_type,filename): 178 | if data_type == 'csv': 179 | data_path = "./dataset/" + filename + "/data.csv" 180 | label_path = "./dataset/" + filename + "/label.csv" 181 | X = pd.read_csv(data_path, header=0, index_col=0, sep=',') 182 | #X = np.expm1(X) 183 | cell_label = pd.read_csv(label_path).values[:,-1] 184 | 185 | if data_type == 'h5': 186 | data_path = "./dataset/" + filename + "/data.h5" 187 | mat, obs, var, uns = read_data(data_path, sparsify=False, skip_exprs=False) 188 | if isinstance(mat, np.ndarray): 189 | X = np.array(mat) 190 | else: 191 | X = np.array(mat.toarray()) 192 | cell_name = np.array(obs["cell_type1"]) 193 | cell_type, cell_label = np.unique(cell_name, return_inverse=True) 194 | return X, cell_label 195 | 196 | 197 | def Selecting_highly_variable_genes(X, highly_genes): 198 | adata = sc.AnnData(X) 199 | adata.var_names_make_unique() 200 | sc.pp.filter_genes(adata, min_counts=1) 201 | sc.pp.filter_cells(adata, min_counts=1) 202 | sc.pp.normalize_per_cell(adata) 203 | sc.pp.log1p(adata) 204 | adata.raw = adata 205 | sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes=highly_genes) 206 | adata = adata[:, adata.var['highly_variable']].copy() 207 | sc.pp.scale(adata) 208 | data = adata.X 209 | 210 | return data 211 | 212 | def normalize(adata, copy=True, highly_genes = None, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True): 213 | if isinstance(adata, sc.AnnData): 214 | if copy: 215 | adata = adata.copy() 216 | elif isinstance(adata, str): 217 | adata = sc.read(adata) 218 | else: 219 | raise NotImplementedError 220 | norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.' 
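# sanity checks: the matrix handed to normalize() should still hold raw (integer) counts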
221 | assert 'n_count' not in adata.obs, norm_error 222 | if adata.X.size < 50e6: # check if adata.X is integer only if array is small 223 | if sp.sparse.issparse(adata.X): 224 | assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error 225 | else: 226 | assert np.all(adata.X.astype(int) == adata.X), norm_error 227 | 228 | if filter_min_counts: 229 | sc.pp.filter_genes(adata, min_counts=1) 230 | sc.pp.filter_cells(adata, min_counts=1) 231 | if size_factors or normalize_input or logtrans_input: 232 | adata.raw = adata.copy() 233 | else: 234 | adata.raw = adata 235 | if size_factors: 236 | sc.pp.normalize_per_cell(adata) 237 | adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts) 238 | else: 239 | adata.obs['size_factors'] = 1.0 240 | if logtrans_input: 241 | sc.pp.log1p(adata) 242 | if highly_genes != None: 243 | sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes = highly_genes, subset=True) 244 | if normalize_input: 245 | sc.pp.scale(adata) 246 | return adata 247 | def getcluster(x): 248 | 249 | adj, edgeList = generateAdj(x) 250 | #print(adj) 251 | idx=[] 252 | for i in range(np.array(edgeList).shape[0]): 253 | if np.array(edgeList)[i,-1]==1.0: 254 | idx.append(i) 255 | listResult, size = generateLouvainCluster(edgeList) 256 | n_clusters = len(np.unique(listResult)) 257 | #print('Louvain cluster: '+str(n_clusters)) 258 | return n_clusters 259 | 260 | if __name__ == "__main__": 261 | 262 | parser = argparse.ArgumentParser( 263 | description='Cell_cluster', 264 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 265 | parser.add_argument('--name', type=str, default='Yan') 266 | parser.add_argument('--file_format', type=str, default='csv') 267 | args = parser.parse_args() 268 | 269 | 270 | 271 | 272 | 273 | 274 | filename = args.name 275 | if not os.path.exists("./dataset/"+filename+"/data"): 276 | os.system('mkdir ./dataset/'+filename+'/data') 277 | if not os.path.exists("./dataset/"+filename+"/graph"): 278 | os.system('mkdir ./dataset/'+filename+'/graph') 279 | if not os.path.exists("./dataset/"+filename+"/model"): 280 | os.system('mkdir ./dataset/'+filename+'/model') 281 | 282 | data_type = args.file_format 283 | if data_type == 'h5': 284 | X, Y = prepro(data_type,filename) 285 | X = np.ceil(X).astype(np.float32) 286 | #print(X) 287 | count_X = X 288 | cluster_number = int(max(Y) - min(Y) + 1) 289 | adata = sc.AnnData(X) 290 | adata.obs['Group'] = Y 291 | adata = normalize(adata, copy=True, highly_genes=2000, size_factors=True, normalize_input=True, logtrans_input=True) 292 | X = adata.X.astype(np.float32) 293 | Y = np.array(adata.obs["Group"]) 294 | high_variable = np.array(adata.var.highly_variable.index, dtype=np.int32) 295 | count_X = count_X[:, high_variable] 296 | data=(count_X.astype(np.float32)) 297 | data=preprocessing.MinMaxScaler().fit_transform(data) 298 | data=preprocessing.normalize(data, norm='l2') 299 | np.savetxt("./dataset/"+filename+"/data/" + filename + ".txt",data,fmt="%s",delimiter=" ") 300 | np.savetxt("./dataset/"+filename+"/data/" + filename + "_label.txt",Y,fmt="%s",delimiter=" ") 301 | if data_type == 'csv': 302 | X, Y = prepro(data_type,filename) 303 | #X = np.ceil(X).astype(np.float32) 304 | data = np.array(X).astype('float32') 305 | #print(X) 306 | #count_X = X 307 | cluster_number = int(max(Y) - min(Y) + 1) 308 | #data = np.expm1(data) 309 | #data=(count_X.astype(np.float32)) 310 | data = Selecting_highly_variable_genes(data, 2000) 311 | 
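# csv branch: quantile-transform and L2-normalise the 2000 selected genes before saving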
#data=preprocessing.MinMaxScaler().fit_transform(data) 312 | data=preprocessing.QuantileTransformer(random_state=0).fit_transform(data) 313 | 314 | data=preprocessing.normalize(data, norm='l2') 315 | np.savetxt("./dataset/"+filename+"/data/" + filename + ".txt",data,fmt="%s",delimiter=" ") 316 | np.savetxt("./dataset/"+filename+"/data/" + filename + "_label.txt",Y,fmt="%s",delimiter=" ") 317 | 318 | adj, edgeList = generateAdj(data) 319 | #print(adj) 320 | idx=[] 321 | for i in range(np.array(edgeList).shape[0]): 322 | if np.array(edgeList)[i,-1]==1.0: 323 | idx.append(i) 324 | np.savetxt("./dataset/"+filename+"/graph/" + filename + "_graph.txt",np.array(edgeList)[idx,0:-1],fmt="%d") 325 | 326 | 327 | -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Adam/Adam.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Adam/Adam.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Camp_Brain/Camp_Brain.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Camp_Brain/Camp_Brain.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Camp_Liver/Camp_Liver.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Camp_Liver/Camp_Liver.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Chung/Chung.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Chung/Chung.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Klein/Klein.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Klein/Klein.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Kolodziejczyk/Kolodziejczyk.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Kolodziejczyk/Kolodziejczyk.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Muraro/Muraro.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Muraro/Muraro.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Quake_10x_Limb_Muscle/Quake_10x_Limb_Muscle.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Quake_10x_Limb_Muscle/Quake_10x_Limb_Muscle.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Quake_Smart-seq2_Diaphragm/Quake_Smart-seq2_Diaphragm3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Quake_Smart-seq2_Diaphragm/Quake_Smart-seq2_Diaphragm3.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Quake_Smart-seq2_Limb_Muscle/Quake_Smart-seq2_Limb_Muscle127.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Quake_Smart-seq2_Limb_Muscle/Quake_Smart-seq2_Limb_Muscle127.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Quake_Smart-seq2_Lung/Quake_Smart-seq2_Lung.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Quake_Smart-seq2_Lung/Quake_Smart-seq2_Lung.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Yan/Yan.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Yan/Yan.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Young/Young.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Young/Young.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/Zeisel/Zeisel.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/Zeisel/Zeisel.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/human/human.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/human/human.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/mouse/mouse.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/mouse/mouse.pkl -------------------------------------------------------------------------------- /Cluster_model/pretain_model/panc/panc.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/panc/panc.pkl 
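Note: the .pkl checkpoints listed here appear to be PyTorch state_dicts written the same way pretrain.py saves its models (torch.save(model.state_dict(), ...)), and Cluster.py restores them with load_state_dict. A minimal sketch of loading one into the AE defined in pretrain.py, assuming the default layer sizes used in Cluster.py ('Yan' is only an example dataset name):

import torch
from pretrain import AE  # Cluster_model/pretrain.py

model = AE(n_enc_1=100, n_enc_2=200, n_enc_3=200,
           n_dec_1=200, n_dec_2=200, n_dec_3=100,
           n_input=2000, n_z=5)
state = torch.load('./pretain_model/Yan/Yan.pkl', map_location='cpu')  # path layout as in Cluster.py
model.load_state_dict(state)
model.eval()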
-------------------------------------------------------------------------------- /Cluster_model/pretain_model/pbmc/pbmc.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Cluster_model/pretain_model/pbmc/pbmc.pkl -------------------------------------------------------------------------------- /Cluster_model/pretrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from torch.utils.data import DataLoader 8 | from torch.optim import Adam, SGD, Adamax 9 | from torch.nn import Linear 10 | from torch.utils.data import Dataset 11 | from sklearn.metrics import silhouette_score, davies_bouldin_score 12 | from sklearn.cluster import SpectralBiclustering,KMeans, kmeans_plusplus, DBSCAN,SpectralClustering 13 | from evaluation import eva_pretrain 14 | import umap 15 | import argparse 16 | 17 | #torch.cuda.set_device(3) 18 | 19 | 20 | class AE(nn.Module): 21 | 22 | def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 23 | n_input, n_z): 24 | super(AE, self).__init__() 25 | self.enc_1 = Linear(n_input, n_enc_1) 26 | self.enc_2 = Linear(n_enc_1, n_enc_2) 27 | self.enc_3 = Linear(n_enc_2, n_enc_3) 28 | self.z_layer = Linear(n_enc_3, n_z) 29 | 30 | self.dec_1 = Linear(n_z, n_dec_1) 31 | self.dec_2 = Linear(n_dec_1, n_dec_2) 32 | self.dec_3 = Linear(n_dec_2, n_dec_3) 33 | self.x_bar_layer = Linear(n_dec_3, n_input) 34 | 35 | def forward(self, x): 36 | enc_h1 = F.relu(self.enc_1(x)) 37 | enc_h2 = F.relu(self.enc_2(enc_h1)) 38 | enc_h3 = F.relu(self.enc_3(enc_h2)) 39 | z = self.z_layer(enc_h3) 40 | 41 | dec_h1 = F.relu(self.dec_1(z)) 42 | dec_h2 = F.relu(self.dec_2(dec_h1)) 43 | dec_h3 = F.relu(self.dec_3(dec_h2)) 44 | x_bar = self.x_bar_layer(dec_h3) 45 | 46 | return x_bar, z 47 | 48 | 49 | class LoadDataset(Dataset): 50 | def __init__(self, data): 51 | self.x = data 52 | 53 | def __len__(self): 54 | return self.x.shape[0] 55 | 56 | def __getitem__(self, idx): 57 | return torch.from_numpy(np.array(self.x[idx])).float(), \ 58 | torch.from_numpy(np.array(idx)) 59 | 60 | 61 | def adjust_learning_rate(optimizer, epoch): 62 | lr = 0.001 * (0.1 ** (epoch // 20)) 63 | for param_group in optimizer.param_groups: 64 | param_group['lr'] = lr 65 | 66 | 67 | def pretrain_ae(model,dataset,m,device,n_clusters,epoch,name,Auto=True): 68 | train_loader = DataLoader(dataset, batch_size=None, shuffle=True) 69 | #device = args.device 70 | #print(device) 71 | optimizer = Adamax(model.parameters(), lr=1e-2) 72 | for epoch in range(epoch): 73 | adjust_learning_rate(optimizer, epoch) 74 | for batch_idx, (x, _) in enumerate(train_loader): 75 | x = x.to(device) 76 | x_bar, _ = model(x) 77 | loss = F.mse_loss(x_bar, x) 78 | 79 | optimizer.zero_grad() 80 | loss.backward() 81 | optimizer.step() 82 | 83 | with torch.no_grad(): 84 | x = torch.Tensor(dataset.x).to(device).float() 85 | x_bar, z = model(x) 86 | loss = F.mse_loss(x_bar, x) 87 | 88 | #for i in range(0,100,10): 89 | if z.shape[0] < 5000: 90 | resolution = 0.8 91 | else: 92 | resolution = 0.5 93 | if Auto: 94 | n_clusters = int(clusters*resolution) if int(clusters*resolution)>=3 else 2 95 | #print(n_clusters) 96 | kmeans = KMeans(n_clusters=n_clusters, n_init=20).fit(z.data.cpu().numpy()) 97 | silhouette 
=eva_pretrain(z.data.cpu().numpy(), kmeans.labels_, epoch) 98 | #print(kmeans.labels_) 99 | #print('epoch{} loss: {}'.format(epoch, loss),'silhouette {:.4f}'.format(silhouette)) 100 | torch.save(model.state_dict(), './dataset/'+name+'/model/'+name+str(m)+'.pkl') 101 | 102 | return silhouette 103 | 104 | 105 | -------------------------------------------------------------------------------- /Cluster_model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import h5py 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def load_graph(dataset): 9 | 10 | path = './dataset/{}/graph/{}_graph.txt'.format(dataset,dataset) 11 | 12 | data = np.loadtxt('./dataset/{}/data/{}.txt'.format(dataset,dataset)) 13 | n, _ = data.shape 14 | 15 | idx = np.array([i for i in range(n)], dtype=np.int32) 16 | idx_map = {j: i for i, j in enumerate(idx)} 17 | edges_unordered = np.genfromtxt(path, dtype=np.int32) 18 | edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), 19 | dtype=np.int32).reshape(edges_unordered.shape) 20 | adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 21 | shape=(n, n), dtype=np.float32) 22 | 23 | 24 | # build symmetric adjacency matrix 25 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 26 | adj = adj + sp.eye(adj.shape[0]) 27 | adj = normalize(adj) 28 | adj = sparse_mx_to_torch_sparse_tensor(adj) 29 | 30 | return adj 31 | 32 | 33 | def normalize(mx): 34 | """Row-normalize sparse matrix""" 35 | rowsum = np.array(mx.sum(1)) 36 | r_inv = np.power(rowsum, -1).flatten() 37 | r_inv[np.isinf(r_inv)] = 0. 38 | r_mat_inv = sp.diags(r_inv) 39 | mx = r_mat_inv.dot(mx) 40 | return mx 41 | 42 | 43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 44 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 45 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 46 | indices = torch.from_numpy( 47 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 48 | values = torch.from_numpy(sparse_mx.data) 49 | shape = torch.Size(sparse_mx.shape) 50 | return torch.sparse.FloatTensor(indices, values, shape) 51 | 52 | 53 | class load_data(Dataset): 54 | def __init__(self, dataset): 55 | #self.x = np.loadtxt('dataset/data/{}.txt'.format(dataset), dtype=float) 56 | #self.y = np.loadtxt('input/test_label.txt', dtype=int) 57 | self.x = np.loadtxt("./dataset/{}/data/{}.txt".format(dataset,dataset), dtype=float) 58 | self.y = np.loadtxt("./dataset/{}/data/{}_label.txt".format(dataset,dataset), dtype=int) 59 | 60 | def __len__(self): 61 | return self.x.shape[0] 62 | 63 | def __getitem__(self, idx): 64 | return torch.from_numpy(np.array(self.x[idx])),\ 65 | torch.from_numpy(np.array(self.y[idx])),\ 66 | torch.from_numpy(np.array(idx)) 67 | 68 | 69 | -------------------------------------------------------------------------------- /Cluster_model/utilss.py: -------------------------------------------------------------------------------- 1 | import json 2 | import functools 3 | import operator 4 | import collections 5 | import jgraph 6 | import numpy as np 7 | import scipy.sparse 8 | import tqdm 9 | 10 | 11 | class dotdict(dict): 12 | __getattr__ = dict.get 13 | __setattr__ = dict.__setitem__ 14 | __delattr__ = dict.__delitem__ 15 | 16 | 17 | def in_ipynb(): # pragma: no cover 18 | try: 19 | # noinspection PyUnresolvedReferences 20 | shell = get_ipython().__class__.__name__ 21 | if shell == "ZMQInteractiveShell": 22 | return True # Jupyter 
notebook or qtconsole 23 | elif shell == "TerminalInteractiveShell": 24 | return False # Terminal running IPython 25 | else: 26 | return False # Other type (?) 27 | except NameError: 28 | return False # Probably standard Python interpreter 29 | 30 | 31 | def smart_tqdm(): # pragma: no cover 32 | if in_ipynb(): 33 | return tqdm.tqdm_notebook 34 | return tqdm.tqdm 35 | 36 | 37 | def with_self_graph(fn): 38 | @functools.wraps(fn) 39 | def wrapped(self, *args, **kwargs): 40 | with self.graph.as_default(): 41 | return fn(self, *args, **kwargs) 42 | return wrapped 43 | 44 | 45 | # Wraps a batch function into minibatch version 46 | def minibatch(batch_size, desc, use_last=False, progress_bar=True): 47 | def minibatch_wrapper(func): 48 | @functools.wraps(func) 49 | def wrapped_func(*args, **kwargs): 50 | total_size = args[0].shape[0] 51 | if use_last: 52 | n_batch = np.ceil( 53 | total_size / float(batch_size) 54 | ).astype(np.int) 55 | else: 56 | n_batch = max(1, np.floor( 57 | total_size / float(batch_size) 58 | ).astype(np.int)) 59 | for batch_idx in smart_tqdm()( 60 | range(n_batch), desc=desc, unit="batches", 61 | leave=False, disable=not progress_bar 62 | ): 63 | start = batch_idx * batch_size 64 | end = min((batch_idx + 1) * batch_size, total_size) 65 | this_args = (item[start:end] for item in args) 66 | func(*this_args, **kwargs) 67 | return wrapped_func 68 | return minibatch_wrapper 69 | 70 | 71 | # Avoid sklearn warning 72 | def encode_integer(label, sort=False): 73 | label = np.array(label).ravel() 74 | classes = np.unique(label) 75 | if sort: 76 | classes.sort() 77 | mapping = {v: i for i, v in enumerate(classes)} 78 | return np.array([mapping[v] for v in label]), classes 79 | 80 | 81 | # Avoid sklearn warning 82 | def encode_onehot(label, sort=False, ignore=None): 83 | i, c = encode_integer(label, sort) 84 | onehot = scipy.sparse.csc_matrix(( 85 | np.ones_like(i, dtype=np.int32), (np.arange(i.size), i) 86 | )) 87 | if ignore is None: 88 | ignore = [] 89 | return onehot[:, ~np.in1d(c, ignore)].tocsr() 90 | 91 | 92 | class CellTypeDAG(object): 93 | 94 | def __init__(self, graph=None, vdict=None): 95 | self.graph = jgraph.Graph(directed=True) if graph is None else graph 96 | self.vdict = {} if vdict is None else vdict 97 | 98 | @classmethod 99 | def load(cls, file): 100 | if file.endswith(".json"): 101 | return cls.load_json(file) 102 | elif file.endswith(".obo"): 103 | return cls.load_obo(file) 104 | else: 105 | raise ValueError("Unexpected file format!") 106 | 107 | @classmethod 108 | def load_json(cls, file): 109 | with open(file, "r") as f: 110 | d = json.load(f) 111 | dag = cls() 112 | dag._build_tree(d) 113 | return dag 114 | 115 | @classmethod 116 | def load_obo(cls, file): # Only building on "is_a" relation between CL terms 117 | import pronto 118 | ont = pronto.Ontology(file) 119 | graph, vdict = jgraph.Graph(directed=True), {} 120 | for item in ont: 121 | if not item.id.startswith("CL"): 122 | continue 123 | if "is_obsolete" in item.other and item.other["is_obsolete"][0] == "true": 124 | continue 125 | graph.add_vertex( 126 | name=item.id, cell_ontology_class=item.name, 127 | desc=str(item.desc), synonyms=[( 128 | "%s (%s)" % (syn.desc, syn.scope) 129 | ) for syn in item.synonyms] 130 | ) 131 | assert item.id not in vdict 132 | vdict[item.id] = item.id 133 | assert item.name not in vdict 134 | vdict[item.name] = item.id 135 | for synonym in item.synonyms: 136 | if synonym.scope == "EXACT" and synonym.desc != item.name: 137 | vdict[synonym.desc] = item.id 138 | for source in 
graph.vs: 139 | for relation in ont[source["name"]].relations: 140 | if relation.obo_name != "is_a": 141 | continue 142 | for target in ont[source["name"]].relations[relation]: 143 | if not target.id.startswith("CL"): 144 | continue 145 | graph.add_edge( 146 | source["name"], 147 | graph.vs.find(name=target.id.split()[0])["name"] 148 | ) 149 | # Split because there are many "{is_infered...}" suffix, 150 | # falsely joined to the actual id when pronto parses the 151 | # obo file 152 | return cls(graph, vdict) 153 | 154 | def _build_tree(self, d, parent=None): # For json loading 155 | self.graph.add_vertex(name=d["name"]) 156 | v = self.graph.vs.find(d["name"]) 157 | if parent is not None: 158 | self.graph.add_edge(v, parent) 159 | self.vdict[d["name"]] = d["name"] 160 | if "alias" in d: 161 | for alias in d["alias"]: 162 | self.vdict[alias] = d["name"] 163 | if "children" in d: 164 | for subd in d["children"]: 165 | self._build_tree(subd, v) 166 | 167 | def get_vertex(self, name): 168 | return self.graph.vs.find(self.vdict[name]) 169 | 170 | def is_related(self, name1, name2): 171 | return self.is_descendant_of(name1, name2) \ 172 | or self.is_ancestor_of(name1, name2) 173 | 174 | def is_descendant_of(self, name1, name2): 175 | if name1 not in self.vdict or name2 not in self.vdict: 176 | return False 177 | shortest_path = self.graph.shortest_paths( 178 | self.get_vertex(name1), self.get_vertex(name2) 179 | )[0][0] 180 | return np.isfinite(shortest_path) 181 | 182 | def is_ancestor_of(self, name1, name2): 183 | if name1 not in self.vdict or name2 not in self.vdict: 184 | return False 185 | shortest_path = self.graph.shortest_paths( 186 | self.get_vertex(name2), self.get_vertex(name1) 187 | )[0][0] 188 | return np.isfinite(shortest_path) 189 | 190 | def conditional_prob(self, name1, name2): # p(name1|name2) 191 | if name1 not in self.vdict or name2 not in self.vdict: 192 | return 0 193 | self.graph.vs["prob"] = 0 194 | v2_parents = list(self.graph.bfsiter( 195 | self.get_vertex(name2), mode=jgraph.OUT)) 196 | v1_parents = list(self.graph.bfsiter( 197 | self.get_vertex(name1), mode=jgraph.OUT)) 198 | for v in v2_parents: 199 | v["prob"] = 1 200 | while True: 201 | changed = False 202 | for v1_parent in v1_parents[::-1]: # Reverse may be more efficient 203 | if v1_parent["prob"] != 0: 204 | continue 205 | v1_parent["prob"] = np.prod([ 206 | v["prob"] / v.degree(mode=jgraph.IN) 207 | for v in v1_parent.neighbors(mode=jgraph.OUT) 208 | ]) 209 | if v1_parent["prob"] != 0: 210 | changed = True 211 | if not changed: 212 | break 213 | return self.get_vertex(name1)["prob"] 214 | 215 | def similarity(self, name1, name2, method="probability"): 216 | if method == "probability": 217 | return ( 218 | self.conditional_prob(name1, name2) + 219 | self.conditional_prob(name2, name1) 220 | ) / 2 221 | # if method == "distance": 222 | # return self.distance_ratio(name1, name2) 223 | raise ValueError("Invalid method!") # pragma: no cover 224 | 225 | def count_reset(self): 226 | self.graph.vs["raw_count"] = 0 227 | self.graph.vs["prop_count"] = 0 # count propagated from children 228 | self.graph.vs["count"] = 0 229 | 230 | def count_set(self, name, count): 231 | self.get_vertex(name)["raw_count"] = count 232 | 233 | def count_update(self): 234 | origins = [v for v in self.graph.vs.select(raw_count_gt=0)] 235 | for origin in origins: 236 | for v in self.graph.bfsiter(origin, mode=jgraph.OUT): 237 | if v != origin: # bfsiter includes the vertex self 238 | v["prop_count"] += origin["raw_count"] 239 | 
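# a vertex's final count is its own raw_count plus the counts propagated up from its descendants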
self.graph.vs["count"] = list(map( 240 | operator.add, self.graph.vs["raw_count"], 241 | self.graph.vs["prop_count"] 242 | )) 243 | 244 | def best_leaves(self, thresh, retrieve="name"): 245 | subgraph = self.graph.subgraph(self.graph.vs.select(count_ge=thresh)) 246 | leaves, max_count = [], 0 247 | for leaf in subgraph.vs.select(lambda v: v.indegree() == 0): 248 | if leaf["count"] > max_count: 249 | max_count = leaf["count"] 250 | leaves = [leaf[retrieve]] 251 | elif leaf["count"] == max_count: 252 | leaves.append(leaf[retrieve]) 253 | return leaves 254 | 255 | 256 | class DataDict(collections.OrderedDict): 257 | 258 | def shuffle(self, random_state=np.random): 259 | shuffled = DataDict() 260 | shuffle_idx = None 261 | for item in self: 262 | shuffle_idx = random_state.permutation(self[item].shape[0]) \ 263 | if shuffle_idx is None else shuffle_idx 264 | shuffled[item] = self[item][shuffle_idx] 265 | return shuffled 266 | 267 | @property 268 | def size(self): 269 | data_size = set([item.shape[0] for item in self.values()]) 270 | assert len(data_size) == 1 271 | return data_size.pop() 272 | 273 | @property 274 | def shape(self): # Compatibility with numpy arrays 275 | return [self.size] 276 | 277 | def __getitem__(self, fetch): 278 | if isinstance(fetch, (slice, np.ndarray)): 279 | return DataDict([ 280 | (item, self[item][fetch]) for item in self 281 | ]) 282 | return super(DataDict, self).__getitem__(fetch) 283 | 284 | 285 | def densify(arr): 286 | if scipy.sparse.issparse(arr): 287 | return arr.toarray() 288 | return arr 289 | 290 | 291 | def empty_safe(fn, dtype): 292 | def _fn(x): 293 | if x.size: 294 | return fn(x) 295 | return x.astype(dtype) 296 | return _fn 297 | 298 | 299 | decode = empty_safe(np.vectorize(lambda _x: _x.decode("utf-8")), str) 300 | encode = empty_safe(np.vectorize(lambda _x: str(_x).encode("utf-8")), "S") 301 | upper = empty_safe(np.vectorize(lambda x: str(x).upper()), str) 302 | lower = empty_safe(np.vectorize(lambda x: str(x).lower()), str) 303 | tostr = empty_safe(np.vectorize(str), str) -------------------------------------------------------------------------------- /Interaction_model/Feature.R: -------------------------------------------------------------------------------- 1 | #Rscript pbmc.R &>nohup.out& 2 | suppressMessages(library(CellChat)) 3 | suppressMessages(library(patchwork)) 4 | options(warn = -1) 5 | options(stringsAsFactors = FALSE) 6 | future::plan("multiprocess", workers = 6) 7 | options(future.globals.maxSize = 2000 * 1024^2) 8 | suppressMessages(require(tibble)) 9 | suppressMessages(require(magrittr)) 10 | suppressMessages(require(purrr)) 11 | 12 | suppressMessages(library(Seurat)) 13 | suppressMessages(library(dplyr)) 14 | warnings('off') 15 | args=commandArgs(T) 16 | parameter1 = args[1] 17 | parameter2 = args[2] 18 | parameter3 = args[3] 19 | testdata <- readRDS(parameter1) 20 | 21 | data.input <- GetAssayData(testdata, assay = "RNA", slot = "data") 22 | label <- read.csv(parameter2) 23 | labels <- as.factor(label$labels) 24 | meta <- data.frame(labels = labels, row.names = names(labels)) 25 | 26 | LRDB<-load(parameter3) 27 | 28 | load('./LRDB/myCompute.RData') 29 | #load('./LRDB/myCompute1.RData') 30 | test <- suppressMessages(createCellChat(object = data.input, meta = meta, group.by = "labels")) 31 | 32 | test <- suppressMessages(addMeta(test, meta = meta)) 33 | test <- suppressMessages(setIdent(test, ident.use = "labels")) 34 | #levels(cellchat@idents) 35 | groupSize <- as.numeric(table(test@idents)) 36 | 37 | HumanDB <- LRDB.human 38 | 
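# Note: load() returns only the *name* of the restored object ("LRDB.human" or
# "LRDB.mouse"), so `HumanDB <- LRDB.human` above only works when the human database
# was supplied as parameter3. A hedged generalisation (hypothetical, not part of the
# original script) would be to resolve whatever object was actually loaded:
# DB_use <- get(LRDB)
# test@DB <- DB_use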
39 | test@DB <- HumanDB 40 | 41 | test <- suppressMessages(subsetData(test)) 42 | 43 | test <- suppressMessages(identifyOverExpressedGenes(test)) 44 | test <- suppressMessages(identifyOverExpressedInteractions(test)) 45 | 46 | test <- suppressMessages(projectData(test, PPI.human)) 47 | 48 | test <- suppressMessages(mycomputeCommunProb(test,type = c( "truncatedMean"),trim = 0.1,nboot = 1000)) 49 | df.net <- suppressMessages(subsetCommunication(test,thresh = 1)) 50 | 51 | write.csv(test@LR$LRsig,"./output/pairLR_use.csv") 52 | write.csv(test@meta,"./output/meta.csv") 53 | write.csv(test@data.project,"./output/data_project.csv") 54 | write.csv(test@DB$complex,"./output/complex_input.csv") 55 | write.csv(test@DB$cofactor,"./output/cofactor.csv") 56 | write.csv(df.net,"./output/df_net.csv") 57 | write.csv(test@DB$interaction,"./output/pairLR.csv") 58 | write.csv(as.matrix(test@data.signaling),"./output/data_signaling.csv") 59 | -------------------------------------------------------------------------------- /Interaction_model/Feature.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import ast 4 | import operator 5 | from itertools import chain 6 | import math 7 | import os 8 | import argparse 9 | from utils import * 10 | from scipy import sparse 11 | from collections import defaultdict 12 | from sklearn.decomposition import TruncatedSVD 13 | import torch 14 | def geometricMean(x): 15 | if np.shape(x)[0]==0: 16 | y=0 17 | else: 18 | y=np.exp(np.mean(np.log(x),axis=0)) 19 | y[np.isnan(y)]=0 20 | return y 21 | 22 | def computeExpr_complex(complex_input, data_use, sorted_indexC,data_rownames): 23 | Rsubunits=complex_input[sorted_indexC,] 24 | #print(Rsubunits) 25 | data_complex_all = np.zeros((len(sorted_indexC),np.shape(data_use)[1])) 26 | for i in range(len(sorted_indexC)): 27 | list1=[] 28 | list1.append(Rsubunits.tolist()[i]) 29 | 30 | list1 = list(chain(*list1)) 31 | while '' in list1: 32 | list1.remove('') 33 | sorted_index=[] 34 | #print(list1) 35 | for j in list1: 36 | if j in data_rownames.tolist(): 37 | sorted_index.append(data_rownames.tolist().index(j)) 38 | 39 | data_complex=geometricMean(data_use[sorted_index,]) 40 | 41 | data_complex_all[i]=data_complex 42 | 43 | return data_complex_all 44 | 45 | def cut_graph(): 46 | net=pd.read_csv('./output/df_net.csv').values 47 | P_prob = net[:,5].astype(np.float32) 48 | P_label=net[:,6].astype(np.float32) 49 | Pair_name=net[:,7] 50 | 51 | Pair_uni = np.unique(Pair_name) 52 | 53 | sum_pair = [] 54 | 55 | for j in range(len(Pair_uni)): 56 | sumP = 0 57 | for i in range(len(Pair_name)): 58 | if str(Pair_name[i]) == str(Pair_uni[j]) and P_label[i]<=0.01: 59 | sumP += P_prob[i] 60 | sum_pair.append(sumP) 61 | #print((sum_pair)) 62 | #print(len(sum_pair)) 63 | sum_pair = np.array(sum_pair) 64 | index=np.argsort(-sum_pair) 65 | Pair_uni = np.array(Pair_uni) 66 | Pair_uninew = Pair_uni[index][:100] 67 | 68 | 69 | pairLRsigdata = pd.read_csv('./output/pairLR_use.csv',header=None).values 70 | pairname = pairLRsigdata[1:,1].tolist() 71 | #print(len(pairname)) 72 | idx=[] 73 | for i in range(len(pairname)): 74 | if str(pairname[i]) in Pair_uni: 75 | idx.append(i) 76 | pair_use_new = pairLRsigdata[1:,1:][idx,:] 77 | #print(pair_use_new.shape) 78 | 79 | index = pairLRsigdata[1:,0][idx] 80 | columns = pairLRsigdata[0,1:] 81 | pd.DataFrame(pair_use_new, index=index,columns=columns).to_csv("./output/pair_use_new.csv",quoting=1) 82 | 83 | 84 | idx1=[] 85 | for i in 
range(len(pairname)): 86 | if str(pairname[i]) in Pair_uninew: 87 | idx1.append(i) 88 | pair_nui_new = pairLRsigdata[1:,1:][idx1,:] 89 | #print(pair_nui_new.shape) 90 | 91 | index = pairLRsigdata[1:,0][idx1] 92 | columns = pairLRsigdata[0,1:] 93 | pd.DataFrame(pair_nui_new, index=index,columns=columns).to_csv("./output/pair_use_top.csv",quoting=1) 94 | 95 | 96 | 97 | def Feature(): 98 | complex_inputdata = './output/complex_input.csv' 99 | with open(complex_inputdata,encoding = 'gbk') as f: 100 | 101 | complex_inputname=np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,0] 102 | #print(complex_inputname) 103 | for i in range(len(complex_inputname)): 104 | complex_inputname[i]=ast.literal_eval(complex_inputname[i]) 105 | subunit1 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,1] 106 | for i in range(len(subunit1)): 107 | subunit1[i]=ast.literal_eval(subunit1[i]) 108 | subunit2 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,2] 109 | for i in range(len(subunit2)): 110 | subunit2[i]=ast.literal_eval(subunit2[i]) 111 | subunit3 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,3] 112 | for i in range(len(subunit3)): 113 | subunit3[i]=ast.literal_eval(subunit3[i]) 114 | subunit4 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,4] 115 | for i in range(len(subunit4)): 116 | subunit4[i]=ast.literal_eval(subunit4[i]) 117 | 118 | 119 | complex_input1=np.vstack((np.array(subunit1),np.array(subunit2))) 120 | complex_input2=np.vstack((np.array(subunit3),np.array(subunit4))) 121 | complex_input=np.vstack((complex_input1,complex_input2)).T 122 | 123 | 124 | pairLRsigdata = './output/pair_use_new.csv' 125 | geneL = pd.read_csv(pairLRsigdata).values[:,3] 126 | geneR = pd.read_csv(pairLRsigdata).values[:,4] 127 | pairname=pd.read_csv(pairLRsigdata).values[:,1] 128 | nLR=len(geneL) 129 | data_project='./output/data_project.csv' 130 | with open(data_project,encoding = 'gbk') as f: 131 | data = np.loadtxt(data_project,str,delimiter = ",",skiprows=1)[:,1:] 132 | data_rownames=np.loadtxt(data_project,str,delimiter = ",",skiprows=1)[:,0] 133 | data=(data.astype(np.float64)) 134 | for i in range(len(data_rownames)): 135 | data_rownames[i]=ast.literal_eval(data_rownames[i]) 136 | 137 | metadata=np.loadtxt('./output/meta.csv',str,delimiter = ",",skiprows=1)[:,-1] 138 | for i in range(len(metadata)): 139 | metadata[i]=ast.literal_eval(metadata[i]).replace (' ','') 140 | labelname = np.unique(metadata) 141 | idx = defaultdict(list) 142 | 143 | for j in range(len(labelname)): 144 | for i in range(len(metadata)): 145 | if labelname[j]== metadata[i]: 146 | idx[labelname[j]].append((i)) 147 | len_idx=[] 148 | for i in range(len(labelname)): 149 | len_idx.append(len(idx[labelname[i]])) 150 | 151 | #net=np.loadtxt('./output/df_net.csv',str,delimiter = ",",skiprows=1) 152 | net=pd.read_csv('./output/df_net.csv').values 153 | Cell_s=net[:,1] 154 | Cell_t=net[:,2] 155 | PairL=net[:,3] 156 | PairR=net[:,4] 157 | P_label=net[:,6] 158 | Pair_name=net[:,7] 159 | Cell1 = [] 160 | Cell2 = [] 161 | for i in range(len(Cell_s)): 162 | Cell1.append(Cell_s[i].replace (' ','')) 163 | Cell2.append(Cell_t[i].replace (' ','')) 164 | 165 | pair_name_label=[] 166 | for i in range(len(Cell1)): 167 | pair_name_label.append(Cell1[i]+'_'+Cell2[i]+'_'+Pair_name[i]) 168 | 169 | data_use = data/np.max(data) 170 | nC=np.shape(data_use)[1] 171 | data_rownames=data_rownames.tolist() 172 | sorted_indexL = [] 173 | for i in geneL.tolist(): 174 | if i in data_rownames: 175 | 
sorted_indexL.append(data_rownames.index(i)) 176 | sorted_indexR = [] 177 | for i in geneR.tolist(): 178 | if i in data_rownames: 179 | sorted_indexR.append(data_rownames.index(i)) 180 | 181 | data_rownames=np.array(data_rownames) 182 | index_singleL=np.where(np.isin(geneL,data_rownames,invert=False)==True) 183 | index_complexL=np.where(np.isin(geneL,data_rownames,invert=False)==False) 184 | index_singleR=np.where(np.isin(geneR,data_rownames,invert=False)==True) 185 | index_complexR=np.where(np.isin(geneR,data_rownames,invert=False)==False) 186 | dataL1 = data_use[sorted_indexL,] 187 | 188 | 189 | dataL=np.zeros((nLR,nC)) 190 | dataL[index_singleL,] = dataL1 191 | if len(index_complexL[0]) > 0: 192 | complexL = geneL[index_complexL] 193 | sorted_indexCL=[] 194 | complexLnewL=[] 195 | 196 | for i in complexL.tolist(): 197 | if i in complex_inputname.tolist(): 198 | sorted_indexCL.append(complex_inputname.tolist().index(i)) 199 | complexLnewL=complexLnewL+[j for j,v in enumerate(geneL.tolist()) if v==i] 200 | 201 | complexLnewL = list(set(complexLnewL)) 202 | complexLnewL.sort() 203 | data_complex=computeExpr_complex(complex_input, data_use, sorted_indexCL,data_rownames) 204 | dataL[complexLnewL,] = data_complex 205 | 206 | 207 | dataR1 = data_use[sorted_indexR,] 208 | dataR=np.zeros((nLR,nC)) 209 | dataR[index_singleR,] = dataR1 210 | 211 | if len(index_complexR[0]) > 0: 212 | complexR = geneR[index_complexR] 213 | sorted_indexCR=[] 214 | complexLnewR=[] 215 | for i in complexR.tolist(): 216 | if i in complex_inputname.tolist(): 217 | sorted_indexCR.append(complex_inputname.tolist().index(i)) 218 | complexLnewR=complexLnewR+[j for j,v in enumerate(geneR.tolist()) if v==i] 219 | complexLnewR = list(set(complexLnewR)) 220 | complexLnewR.sort() 221 | data_complex=computeExpr_complex(complex_input, data_use, sorted_indexCR,data_rownames) 222 | 223 | dataR[complexLnewR,] = data_complex 224 | 225 | p_idx=[] 226 | for i in Pair_name.tolist(): 227 | if i in pairname.tolist(): 228 | p_idx.append(pairname.tolist().index(i)) 229 | maxlen=np.max(np.array(len_idx)) 230 | Feature = np.zeros((len(Cell1),dataL.shape[1]*2)) 231 | for i in range(len(Cell1)): 232 | L_feature=np.zeros((1,dataL.shape[1])) 233 | R_feature=np.zeros((1,dataR.shape[1])) 234 | L_feature[0,idx[Cell1[i]]] = dataL[p_idx[i],idx[Cell1[i]]] 235 | R_feature[0,idx[Cell2[i]]] = dataR[p_idx[i],idx[Cell2[i]]] 236 | Feature[i,:]=np.hstack((L_feature,R_feature)) 237 | Feature_sparse = sparse.csc_matrix(Feature) 238 | svd = TruncatedSVD(n_components=1600, n_iter=10, random_state=42) 239 | Feature1 = svd.fit_transform(Feature_sparse) 240 | Feature1 = Feature1.astype(str) 241 | Feature_idx=np.arange(0,len(Cell1)) 242 | Featureall=np.column_stack((Feature_idx,Feature1)) 243 | np.savetxt("./output/Feature.csv",Featureall,fmt="%s",delimiter=",") 244 | 245 | 246 | def graph(): 247 | complex_inputdata = './output/complex_input.csv' 248 | with open(complex_inputdata,encoding = 'gbk') as f: 249 | 250 | complex_inputname=np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,0] 251 | #print(complex_inputname) 252 | for i in range(len(complex_inputname)): 253 | complex_inputname[i]=ast.literal_eval(complex_inputname[i]) 254 | subunit1 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,1] 255 | for i in range(len(subunit1)): 256 | subunit1[i]=ast.literal_eval(subunit1[i]) 257 | subunit2 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,2] 258 | for i in range(len(subunit2)): 259 | subunit2[i]=ast.literal_eval(subunit2[i]) 
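# (The four subunit_* columns of complex_input.csv are parsed the same way and later
# stacked into an (n_complex, 4) array; computeExpr_complex() then looks up whichever
# subunits are present in the expression matrix and takes their geometric mean,
# exp(mean(log(x))), as the expression level of the multi-subunit ligand/receptor complex.)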
260 | subunit3 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,3] 261 | for i in range(len(subunit3)): 262 | subunit3[i]=ast.literal_eval(subunit3[i]) 263 | subunit4 = np.loadtxt(complex_inputdata,str,delimiter = ",",skiprows=1)[:,4] 264 | for i in range(len(subunit4)): 265 | subunit4[i]=ast.literal_eval(subunit4[i]) 266 | 267 | 268 | complex_input1=np.vstack((np.array(subunit1),np.array(subunit2))) 269 | complex_input2=np.vstack((np.array(subunit3),np.array(subunit4))) 270 | complex_input=np.vstack((complex_input1,complex_input2)).T 271 | 272 | 273 | pairLRsigdata = './output/pair_use_top.csv' 274 | geneL = pd.read_csv(pairLRsigdata).values[:,3] 275 | geneR = pd.read_csv(pairLRsigdata).values[:,4] 276 | pairname=pd.read_csv(pairLRsigdata).values[:,1] 277 | nLR=len(geneL) 278 | data_project='./output/data_signaling.csv' 279 | with open(data_project,encoding = 'gbk') as f: 280 | data = np.loadtxt(data_project,str,delimiter = ",",skiprows=1)[:,1:] 281 | data_rownames=np.loadtxt(data_project,str,delimiter = ",",skiprows=1)[:,0] 282 | data=(data.astype(np.float64)) 283 | for i in range(len(data_rownames)): 284 | data_rownames[i]=ast.literal_eval(data_rownames[i]) 285 | 286 | metadata=np.loadtxt('./output/meta.csv',str,delimiter = ",",skiprows=1)[:,-1] 287 | for i in range(len(metadata)): 288 | metadata[i]=ast.literal_eval(metadata[i]).replace (' ','') 289 | #np.savetxt("./output/labelname.csv",metadata,fmt="%s",delimiter=",") 290 | labelname = np.unique(metadata) 291 | idx = defaultdict(list) 292 | 293 | for j in range(len(labelname)): 294 | for i in range(len(metadata)): 295 | if labelname[j]== metadata[i]: 296 | idx[labelname[j]].append((i)) 297 | len_idx=[] 298 | for i in range(len(labelname)): 299 | len_idx.append(len(idx[labelname[i]])) 300 | 301 | #net=np.loadtxt('./output/df_net.csv',str,delimiter = ",",skiprows=1) 302 | net=pd.read_csv('./output/df_net.csv').values 303 | Cell_s=net[:,1] 304 | Cell_t=net[:,2] 305 | PairL=net[:,3] 306 | PairR=net[:,4] 307 | P_label=net[:,6] 308 | Pair_name=net[:,7] 309 | Cell1 = [] 310 | Cell2 = [] 311 | for i in range(len(Cell_s)): 312 | Cell1.append(Cell_s[i].replace (' ','')) 313 | Cell2.append(Cell_t[i].replace (' ','')) 314 | 315 | pair_name_label=[] 316 | for i in range(len(Cell1)): 317 | pair_name_label.append(Cell1[i]+'_'+Cell2[i]+'_'+Pair_name[i]) 318 | 319 | 320 | data_use = data/np.max(data) 321 | nC=np.shape(data_use)[1] 322 | data_rownames=data_rownames.tolist() 323 | sorted_indexL = [] 324 | for i in geneL.tolist(): 325 | if i in data_rownames: 326 | sorted_indexL.append(data_rownames.index(i)) 327 | sorted_indexR = [] 328 | for i in geneR.tolist(): 329 | if i in data_rownames: 330 | sorted_indexR.append(data_rownames.index(i)) 331 | 332 | data_rownames=np.array(data_rownames) 333 | index_singleL=np.where(np.isin(geneL,data_rownames,invert=False)==True) 334 | index_complexL=np.where(np.isin(geneL,data_rownames,invert=False)==False) 335 | index_singleR=np.where(np.isin(geneR,data_rownames,invert=False)==True) 336 | index_complexR=np.where(np.isin(geneR,data_rownames,invert=False)==False) 337 | dataL1 = data_use[sorted_indexL,] 338 | 339 | 340 | dataL=np.zeros((nLR,nC)) 341 | dataL[index_singleL,] = dataL1 342 | if len(index_complexL[0]) > 0: 343 | complexL = geneL[index_complexL] 344 | sorted_indexCL=[] 345 | complexLnewL=[] 346 | 347 | for i in complexL.tolist(): 348 | if i in complex_inputname.tolist(): 349 | sorted_indexCL.append(complex_inputname.tolist().index(i)) 350 | complexLnewL=complexLnewL+[j for j,v in 
enumerate(geneL.tolist()) if v==i] 351 | 352 | complexLnewL = list(set(complexLnewL)) 353 | complexLnewL.sort() 354 | data_complex=computeExpr_complex(complex_input, data_use, sorted_indexCL,data_rownames) 355 | dataL[complexLnewL,] = data_complex 356 | 357 | 358 | dataR1 = data_use[sorted_indexR,] 359 | dataR=np.zeros((nLR,nC)) 360 | dataR[index_singleR,] = dataR1 361 | 362 | if len(index_complexR[0]) > 0: 363 | complexR = geneR[index_complexR] 364 | sorted_indexCR=[] 365 | complexLnewR=[] 366 | for i in complexR.tolist(): 367 | if i in complex_inputname.tolist(): 368 | sorted_indexCR.append(complex_inputname.tolist().index(i)) 369 | complexLnewR=complexLnewR+[j for j,v in enumerate(geneR.tolist()) if v==i] 370 | complexLnewR = list(set(complexLnewR)) 371 | complexLnewR.sort() 372 | data_complex=computeExpr_complex(complex_input, data_use, sorted_indexCR,data_rownames) 373 | 374 | dataR[complexLnewR,] = data_complex 375 | 376 | Pair_LR = np.hstack((dataL,dataR)) 377 | Pair_LR_sparse = sparse.csc_matrix(Pair_LR) 378 | svd = TruncatedSVD(n_components=500, n_iter=10, random_state=42) 379 | #sparse.save_npz("../Data/pathway_pre/"+str(pathwayname[m])+".npz",pathway_sparse) 380 | 381 | Pair_LR1 = svd.fit_transform(Pair_LR_sparse) 382 | Pair_LR = np.column_stack((pairname,Pair_LR1)) 383 | 384 | labels=Pair_LR[:,0] 385 | Pair_LRdata = Pair_LR[:,1:] 386 | data=Pair_LRdata.astype(np.float32) 387 | 388 | AJ=np.corrcoef((data)) 389 | 390 | AJ[AJ>=0.95]=1 391 | AJ[AJ<0.95]=0 392 | 393 | 394 | row, col = np.diag_indices_from(AJ) 395 | AJ[row,col] = 0 396 | A=np.array(np.where(AJ>0)) 397 | #print(A.T) 398 | np.savetxt("./output/graph/test.cites",A.T,fmt="%s",delimiter=" ") 399 | 400 | index=np.arange(0,(data.shape[0])).astype(str) 401 | data=data.astype(str) 402 | data1=np.insert(data, 0, values=index, axis=1) 403 | data2=np.column_stack((data1,labels)) 404 | #print(data2.shape) 405 | np.savetxt("./output/graph/test.content",data2,fmt="%s",delimiter=" ") 406 | adj, features = load_data(path="./output/graph/", dataset="test") 407 | torch.save(adj, "./output/adj.pth") 408 | torch.save(features, "./output/features.pth") 409 | 410 | 411 | 412 | 413 | if __name__ == "__main__": 414 | if not os.path.exists("./output/"): 415 | os.system('mkdir ./output/') 416 | if not os.path.exists("./output/graph"): 417 | os.system('mkdir ./output/graph') 418 | 419 | parser = argparse.ArgumentParser( 420 | description='Cell_Interaction', 421 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 422 | parser.add_argument('--species', type=str, default='Human') 423 | parser.add_argument('--label_mode', default=True) 424 | args = parser.parse_args() 425 | if args.species == 'Human': 426 | LR_DB='./LRDB/LRDB.human.rda' 427 | elif args.species == 'Mouse': 428 | LR_DB='./LRDB/LRDB.mouse.rda' 429 | 430 | if args.label_mode: 431 | os.system('Rscript Feature.R ./input/test.rds ./input/test_cell_label.csv '+str(LR_DB)) 432 | else: 433 | print("Cell Clustering...") 434 | os.system('python ./cluster/Cluster.py --pretain True --pretrain_epoch 50 --device cuda --Auto False') 435 | os.system('Rscript Feature.R ./input/test.rds ./cluster/output/cell_annotatetype.csv '+str(LR_DB)) 436 | #print("cut graph") 437 | cut_graph() 438 | #print("Feature") 439 | Feature() 440 | #print("graph") 441 | graph() 442 | 443 | -------------------------------------------------------------------------------- /Interaction_model/Interaction_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
import os 3 | import warnings 4 | import math 5 | from collections import defaultdict 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from modelv2 import MultiOutputModel 10 | from torch.utils.data import DataLoader 11 | from sklearn.metrics import roc_curve, auc 12 | from scipy import interp 13 | from dataset import ForDataset 14 | from scipy import sparse 15 | from sklearn import preprocessing 16 | cpu_num = 3 17 | torch.set_num_threads(cpu_num) 18 | def checkpoint_load(model, name): 19 | #print('Restoring checkpoint: {}'.format(name)) 20 | model.load_state_dict(torch.load(name, map_location='cpu')) 21 | epoch = int(os.path.splitext(os.path.basename(name))[0].split('-')[1]) 22 | return epoch 23 | 24 | 25 | def predict(model, dataloader, batch_size, device, adj, features,checkpoint): 26 | #pretrained_dict = torch.load(checkpoint) 27 | #model_dict = model.state_dict() 28 | #pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)} 29 | #pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'gc2' not in k)} 30 | #model_dict.update(pretrained_dict) 31 | #model.load_state_dict(model_dict) 32 | checkpoint_load(model, checkpoint) 33 | model.eval() 34 | results = defaultdict(list) 35 | with torch.no_grad(): 36 | Predict=[] 37 | for batch in dataloader: 38 | data = batch['data'] 39 | data=torch.reshape(data,(batch_size,1,int(math.sqrt(data.size(1))),int(math.sqrt(data.size(1))))).to(torch.float32) 40 | 41 | output = model(data.to(device),features.to(device), adj.to(device)) 42 | 43 | Predict += (output['sigmoid'].cpu().detach().numpy().tolist()) 44 | 45 | return Predict 46 | 47 | 48 | if __name__ == '__main__': 49 | 50 | parser = argparse.ArgumentParser(description='Inference pipeline') 51 | #parser.add_argument('--checkpoint', type=str, required=True, help="Path to the checkpoint") 52 | parser.add_argument('--batch_size', type=int, default=None, help='batch_size') 53 | 54 | parser.add_argument('--device', type=str, default='cuda', 55 | help="Device: 'cuda' or 'cpu'") 56 | 57 | #parser.add_argument('--label_mode', default=True) 58 | args = parser.parse_args() 59 | num_workers = 0 60 | ''' 61 | if args.label_mode: 62 | os.system("python Feature.py --label_mode True") 63 | else: 64 | os.system("python Feature.py --label_mode False") 65 | ''' 66 | 67 | #print(batch_size) 68 | device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 69 | print(device) 70 | #adj, features = load_data(path="./Predict/human/graph/", dataset="test") 71 | Feature=pd.read_csv('./output/Feature.csv',header=None).values 72 | data_predict=Feature[:,1:].astype(np.float32) 73 | idx=Feature[:,0] 74 | adj = torch.load("./output/adj.pth") 75 | features = torch.load("./output/features.pth") 76 | 77 | #Scaler=preprocessing.StandardScaler() 78 | #data_predict = Scaler.fit_transform(data_predict) 79 | MinMax = preprocessing.MinMaxScaler() 80 | data_predict=MinMax.fit_transform(data_predict) 81 | data_predict=preprocessing.normalize(data_predict, norm='l2') 82 | args.batch_size = data_predict.shape[0] 83 | pre_dataset = ForDataset(data_predict,idx) 84 | pre_dataloader = DataLoader(pre_dataset, batch_size=args.batch_size, shuffle=False) 85 | model = MultiOutputModel(nfeat=features.shape[1],nlabel=features.shape[0]).to(device) 86 | 87 | #for i in range(5,100,5): 88 | #print(i) 89 | checkpoint="./model/checkpoint-000100.pth" 90 | Predict = predict(model, pre_dataloader, args.batch_size ,device, adj, 
features,checkpoint=checkpoint) 91 | #print(Predict) 92 | #label = ["label"] + Predict 93 | Predict=np.array(Predict) 94 | #print(Predict) 95 | Predict[Predict>0.5] = 1 96 | Predict[Predict<=0.5] = 0 97 | CC_net = pd.read_csv('./output/df_net.csv',header=None).values 98 | CC_net_data = CC_net[1:,1:] 99 | CC_pval = CC_net_data[:,5].astype(np.float16) 100 | CC_pval = Predict 101 | CC_pval=CC_pval.astype(np.int32) 102 | CC_net_data[:,5] = CC_pval 103 | 104 | index = CC_net[1:,0] 105 | 106 | columns = CC_net[0,1:] 107 | columns[5]=='label' 108 | pd.DataFrame(CC_net_data, index=index,columns=columns).to_csv("./output/CCI_out.csv",quoting=1) -------------------------------------------------------------------------------- /Interaction_model/LRDB/LRDB.human.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/LRDB/LRDB.human.rda -------------------------------------------------------------------------------- /Interaction_model/LRDB/LRDB.mouse.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/LRDB/LRDB.mouse.rda -------------------------------------------------------------------------------- /Interaction_model/LRDB/myCompute.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/LRDB/myCompute.RData -------------------------------------------------------------------------------- /Interaction_model/Mobilev2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv_bn(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def conv_1x1_bn(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def make_divisible(x, divisible_by=8): 22 | import numpy as np 23 | return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | assert stride in [1, 2] 31 | 32 | hidden_dim = int(inp * expand_ratio) 33 | self.use_res_connect = self.stride == 1 and inp == oup 34 | 35 | if expand_ratio == 1: 36 | self.conv = nn.Sequential( 37 | # dw 38 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 39 | nn.BatchNorm2d(hidden_dim), 40 | nn.ReLU6(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(oup), 44 | ) 45 | else: 46 | self.conv = nn.Sequential( 47 | # pw 48 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 49 | nn.BatchNorm2d(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # dw 52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 53 | nn.BatchNorm2d(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 57 | nn.BatchNorm2d(oup), 58 | ) 59 | 60 | def forward(self, x): 61 | if self.use_res_connect: 62 | return x + self.conv(x) 63 | else: 64 | return self.conv(x) 65 | 66 | 67 | class MobileNetV2(nn.Module): 68 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 69 | super(MobileNetV2, self).__init__() 70 | block = InvertedResidual 71 | input_channel = 32 72 | last_channel = 1280 73 | interverted_residual_setting = [ 74 | # t, c, n, s 75 | [1, 16, 1, 1], 76 | [6, 24, 2, 2], 77 | [6, 32, 3, 2], 78 | [6, 64, 4, 2], 79 | [6, 96, 3, 1], 80 | [6, 160, 3, 2], 81 | [6, 320, 1, 1], 82 | ] 83 | 84 | # building first layer 85 | assert input_size % 32 == 0 86 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32! 87 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 88 | self.features = [conv_bn(1, input_channel, 2)] 89 | # building inverted residual blocks 90 | for t, c, n, s in interverted_residual_setting: 91 | output_channel = make_divisible(c * width_mult) if t > 1 else c 92 | for i in range(n): 93 | if i == 0: 94 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 95 | else: 96 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 97 | input_channel = output_channel 98 | # building last several layers 99 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 100 | # make it nn.Sequential 101 | self.features = nn.Sequential(*self.features) 102 | 103 | # building classifier 104 | self.classifier = nn.Linear(self.last_channel, n_class) 105 | 106 | self._initialize_weights() 107 | 108 | def forward(self, x): 109 | x = self.features(x) 110 | x = x.mean(3).mean(2) 111 | x = self.classifier(x) 112 | return x 113 | 114 | def _initialize_weights(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 119 | if m.bias is not None: 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.Linear): 125 | n = m.weight.size(1) 126 | m.weight.data.normal_(0, 0.01) 127 | m.bias.data.zero_() 128 | 129 | 130 | def mobilenet_v2(pretrained=True): 131 | model = MobileNetV2(width_mult=1) 132 | 133 | if pretrained: 134 | try: 135 | from torch.hub import load_state_dict_from_url 136 | except ImportError: 137 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 138 | state_dict = load_state_dict_from_url( 139 | 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True) 140 | model.load_state_dict(state_dict) 141 | return model 142 | -------------------------------------------------------------------------------- /Interaction_model/Train/Mobilev2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv_bn(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def conv_1x1_bn(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def make_divisible(x, divisible_by=8): 22 | import numpy as np 23 | return int(np.ceil(x * 1. / divisible_by) * divisible_by) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | assert stride in [1, 2] 31 | 32 | hidden_dim = int(inp * expand_ratio) 33 | self.use_res_connect = self.stride == 1 and inp == oup 34 | 35 | if expand_ratio == 1: 36 | self.conv = nn.Sequential( 37 | # dw 38 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 39 | nn.BatchNorm2d(hidden_dim), 40 | nn.ReLU6(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(oup), 44 | ) 45 | else: 46 | self.conv = nn.Sequential( 47 | # pw 48 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 49 | nn.BatchNorm2d(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # dw 52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 53 | nn.BatchNorm2d(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 57 | nn.BatchNorm2d(oup), 58 | ) 59 | 60 | def forward(self, x): 61 | if self.use_res_connect: 62 | return x + self.conv(x) 63 | else: 64 | return self.conv(x) 65 | 66 | 67 | class MobileNetV2(nn.Module): 68 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 69 | super(MobileNetV2, self).__init__() 70 | block = InvertedResidual 71 | input_channel = 32 72 | last_channel = 1280 73 | interverted_residual_setting = [ 74 | # t, c, n, s 75 | [1, 16, 1, 1], 76 | [6, 24, 2, 2], 77 | [6, 32, 3, 2], 78 | [6, 64, 4, 2], 79 | [6, 96, 3, 1], 80 | [6, 160, 3, 2], 81 | [6, 320, 1, 1], 82 | ] 83 | 84 | # building first layer 85 | assert input_size % 32 == 0 86 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32! 
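# make_divisible() rounds a channel count up to the nearest multiple of 8
# (e.g. make_divisible(100) -> 104), which keeps layer widths hardware-friendly when
# width_mult != 1. The stem below, conv_bn(1, input_channel, 2), takes a single input
# channel because the training scripts reshape each feature vector into a 1-channel
# square "image" before feeding it to the network.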
87 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 88 | self.features = [conv_bn(1, input_channel, 2)] 89 | # building inverted residual blocks 90 | for t, c, n, s in interverted_residual_setting: 91 | output_channel = make_divisible(c * width_mult) if t > 1 else c 92 | for i in range(n): 93 | if i == 0: 94 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 95 | else: 96 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 97 | input_channel = output_channel 98 | # building last several layers 99 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 100 | # make it nn.Sequential 101 | self.features = nn.Sequential(*self.features) 102 | 103 | # building classifier 104 | self.classifier = nn.Linear(self.last_channel, n_class) 105 | 106 | self._initialize_weights() 107 | 108 | def forward(self, x): 109 | x = self.features(x) 110 | x = x.mean(3).mean(2) 111 | x = self.classifier(x) 112 | return x 113 | 114 | def _initialize_weights(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. / n)) 119 | if m.bias is not None: 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.Linear): 125 | n = m.weight.size(1) 126 | m.weight.data.normal_(0, 0.01) 127 | m.bias.data.zero_() 128 | 129 | 130 | def mobilenet_v2(pretrained=True): 131 | model = MobileNetV2(width_mult=1) 132 | 133 | if pretrained: 134 | try: 135 | from torch.hub import load_state_dict_from_url 136 | except ImportError: 137 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 138 | state_dict = load_state_dict_from_url( 139 | 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True) 140 | model.load_state_dict(state_dict) 141 | return model 142 | -------------------------------------------------------------------------------- /Interaction_model/Train/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | print("PyTorch Version: ",torch.__version__) 6 | print("Torchvision Version: ",torchvision.__version__) 7 | 8 | __all__ = ['ResNet50', 'ResNet101','ResNet152'] 9 | 10 | def Conv1(in_planes, places, stride=2): 11 | return nn.Sequential( 12 | nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False), 13 | nn.BatchNorm2d(places), 14 | nn.ReLU(inplace=True), 15 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 16 | ) 17 | 18 | class Bottleneck(nn.Module): 19 | def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4): 20 | super(Bottleneck,self).__init__() 21 | self.expansion = expansion 22 | self.downsampling = downsampling 23 | 24 | self.bottleneck = nn.Sequential( 25 | nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False), 26 | nn.BatchNorm2d(places), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False), 29 | nn.BatchNorm2d(places), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False), 32 | nn.BatchNorm2d(places*self.expansion), 33 | ) 34 | 35 | if self.downsampling: 
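# The 1x1 convolution below projects the shortcut to places*expansion channels with
# the same stride as the main branch, so the residual addition in forward() has
# matching shapes.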
36 | self.downsample = nn.Sequential( 37 | nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False), 38 | nn.BatchNorm2d(places*self.expansion) 39 | ) 40 | self.relu = nn.ReLU(inplace=True) 41 | 42 | def forward(self, x): 43 | residual = x 44 | out = self.bottleneck(x) 45 | 46 | if self.downsampling: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | return out 52 | 53 | class ResNet(nn.Module): 54 | def __init__(self,blocks, num_classes=1000, expansion = 4): 55 | super(ResNet,self).__init__() 56 | self.expansion = expansion 57 | 58 | self.conv1 = Conv1(in_planes = 1, places= 64) 59 | self.bn1 = nn.BatchNorm2d(64) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 62 | 63 | self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1) 64 | self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2) 65 | self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2) 66 | self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2) 67 | 68 | self.avgpool = nn.AdaptiveAvgPool2d(1) 69 | self.fc = nn.Linear(2048,num_classes) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def make_layer(self, in_places, places, block, stride): 79 | layers = [] 80 | layers.append(Bottleneck(in_places, places,stride, downsampling =True)) 81 | for i in range(1, block): 82 | layers.append(Bottleneck(places*self.expansion, places)) 83 | 84 | return nn.Sequential(*layers) 85 | 86 | 87 | def forward(self, x): 88 | x = self.conv1(x) 89 | x = self.bn1(x) 90 | x = self.relu(x) 91 | x = self.maxpool(x) 92 | x = self.layer1(x) 93 | x = self.layer2(x) 94 | x = self.layer3(x) 95 | x = self.layer4(x) 96 | 97 | x = self.avgpool(x) 98 | x = x.view(x.size(0), -1) 99 | x = self.fc(x) 100 | return x 101 | 102 | def ResNet50(): 103 | return ResNet([3, 4, 6, 3]) 104 | 105 | def ResNet101(): 106 | return ResNet([3, 4, 23, 3]) 107 | 108 | def ResNet152(): 109 | return ResNet([3, 8, 36, 3]) 110 | -------------------------------------------------------------------------------- /Interaction_model/Train/__pycache__/Mobilev2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/Train/__pycache__/Mobilev2.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/Train/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/Train/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/Train/__pycache__/modelv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/Train/__pycache__/modelv2.cpython-37.pyc 
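A minimal usage sketch (not part of the repository) for the ResNet variants defined in Train/ResNet.py above. It assumes PyTorch is installed and the interpreter is started from Interaction_model/Train/; note that Conv1 is built with in_planes=1, so inputs must be single-channel:

import torch
from ResNet import ResNet50

model = ResNet50()                 # Bottleneck depths [3, 4, 6, 3], default num_classes=1000
x = torch.randn(2, 1, 224, 224)    # two single-channel square "images"
logits = model(x)                  # -> tensor of shape (2, 1000)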
-------------------------------------------------------------------------------- /Interaction_model/Train/__pycache__/test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/Train/__pycache__/test.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/Train/data/readme.md: -------------------------------------------------------------------------------- 1 | The data used to train the model can be download from: 2 | 3 | https://pan.baidu.com/s/1CM6bIfmlm4I3xuGQHLbG7w?pwd=yz33 4 | 5 | Extraction code: yz33 -------------------------------------------------------------------------------- /Interaction_model/Train/dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from sklearn.model_selection import StratifiedKFold,KFold 7 | from scipy import sparse 8 | from torch.utils.data import Dataset 9 | from torch.utils.data import DataLoader 10 | import torchvision.transforms as transforms 11 | mean = [ 0.456] 12 | std = [0.224] 13 | 14 | 15 | class AttributesDataset(): 16 | def __init__(self, annotation_path): 17 | class_labels=[] 18 | multi_labels=[] 19 | with open(annotation_path) as f: 20 | reader = csv.DictReader(f) 21 | for row in reader: 22 | class_labels.append(row['Labels']) 23 | for line in reader: 24 | self.multi_labels.append(line[1:-2]) 25 | self.class_labels = np.unique(class_labels) 26 | self.num_labels = len(self.class_labels) 27 | 28 | #self.class_labels_id_to_name = dict(zip(range(len(self.class_labels)), self.class_labels)) 29 | 30 | #self.class_labels_name_to_id = dict(zip(self.class_labels, range(len(self.class_labels)))) 31 | 32 | 33 | 34 | 35 | 36 | 37 | class ForDataset(Dataset): 38 | def __init__(self, data_train, annotation, transform=None): 39 | super().__init__() 40 | 41 | self.transform = transform 42 | #self.attr = attributes 43 | #self.data=dd.read_csv(data_train,header=None).values.compute() 44 | self.data=data_train 45 | #print(np.shape(data)) 46 | self.data_idx=[] 47 | # initialize the arrays to store the ground truth labels and paths to the images 48 | #self.class_labels=[] 49 | 50 | self.labels=[] 51 | # read the annotations from the CSV file 52 | #with open(annotation_path) as f: 53 | # reader = csv.DictReader(f) 54 | # for row in reader: 55 | #for i in range(len()) 56 | self.data_idx=annotation[:,0].astype(np.int32).tolist() 57 | #self.class_labels.append(self.attr.class_labels_name_to_id[row['Labels']]) 58 | #self.class_labels=dd.read_csv(annotation_path).values.compute()[:,-2] 59 | 60 | self.labels=annotation[:,-1] 61 | #print(self.multi_labels.shape) 62 | def __len__(self): 63 | return len(self.data_idx) 64 | 65 | def __getitem__(self, idx): 66 | # take the data sample by its index 67 | 68 | data=self.data 69 | #print(data.shape) 70 | 71 | #class_labels=np.array(self.class_labels).astype(np.int32) 72 | labels=np.array(self.labels).astype(np.int32) 73 | #print(np.shape(data)) 74 | # read image 75 | #img = Image.open(img_path) 76 | 77 | # apply the image augmentations if needed 78 | if self.transform: 79 | data = self.transform(data) 80 | # return the image and all the associated labels 81 | dict_data = { 82 | 'data': data[idx], 83 | 'labels': { 84 | 'labels':labels[idx] 85 | } 86 | } 87 | 
#print(dict_data) 88 | 89 | return dict_data 90 | 91 | -------------------------------------------------------------------------------- /Interaction_model/Train/modelv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | #import torchvision.models as models 7 | from Mobilev2 import MobileNetV2 8 | 9 | class FocalLoss(nn.Module): 10 | def __init__(self, alpha=0.65, gamma=2, logits=False, reduce=True): 11 | super(FocalLoss, self).__init__() 12 | self.alpha = alpha 13 | self.gamma = gamma 14 | self.logits = logits 15 | self.reduce = reduce 16 | 17 | def forward(self, inputs, targets): 18 | if self.logits: 19 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 20 | else: 21 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none') 22 | pt = torch.exp(-BCE_loss) 23 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 24 | 25 | if self.reduce: 26 | return torch.mean(F_loss) 27 | else: 28 | return F_loss 29 | 30 | 31 | 32 | class GraphConvolution(nn.Module): 33 | """ 34 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 35 | """ 36 | 37 | def __init__(self, in_features, out_features, bias=False): 38 | super(GraphConvolution, self).__init__() 39 | self.in_features = in_features 40 | self.out_features = out_features 41 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 42 | if bias: 43 | self.bias = Parameter(torch.Tensor(1, 1, out_features)) 44 | else: 45 | self.register_parameter('bias', None) 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self): 49 | stdv = 1. / math.sqrt(self.weight.size(1)) 50 | self.weight.data.uniform_(-stdv, stdv) 51 | if self.bias is not None: 52 | self.bias.data.uniform_(-stdv, stdv) 53 | 54 | def forward(self, input, adj): 55 | support = torch.matmul(input, self.weight) 56 | output = torch.matmul(adj, support) 57 | if self.bias is not None: 58 | return output + self.bias 59 | else: 60 | return output 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + ' (' \ 64 | + str(self.in_features) + ' -> ' \ 65 | + str(self.out_features) + ')' 66 | class MultiOutputModel(nn.Module): 67 | 68 | def __init__(self,nfeat,nlabel): 69 | 70 | super().__init__() 71 | models=MobileNetV2() 72 | self.base_model = MobileNetV2().features # take the model without classifier 73 | last_channel = MobileNetV2().last_channel # size of the layer before classifier 74 | self.pool = nn.AdaptiveMaxPool2d((1, 1)) 75 | self.gc1 = GraphConvolution(nfeat, 512) 76 | self.gc2 = GraphConvolution(512, last_channel) 77 | self.relu = nn.LeakyReLU(0.2) 78 | self.fc = nn.Linear(nlabel,1) 79 | self.dropout=nn.Dropout(0.2) 80 | 81 | def forward(self, x, feature, adj): 82 | x = self.base_model(x) 83 | x = self.pool(x) 84 | x = self.dropout(x) 85 | # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier 86 | x = torch.flatten(x, 1) 87 | feature = self.gc1(feature, adj) 88 | #feature = self.dropout(feature) 89 | feature = self.relu(feature) 90 | feature = self.gc2(feature, adj) 91 | #feature = self.dropout(feature) 92 | feature = self.relu(feature) 93 | feature = feature.transpose(0, 1) 94 | 95 | x = torch.matmul(x,feature) 96 | x = self.fc(x) 97 | x = self.dropout(x) 98 | x = x.squeeze(-1) 99 | xt = torch.sigmoid(x) 100 | #print(y) 101 | return {'class': x,'sigmoid': xt 102 | 103 | } 104 | def initialize(self): 105 | for m in 
self.modules(): 106 | if isinstance(m, nn.Linear): 107 | nn.init.xavier_uniform_(m.weight, gain=1) 108 | #print(m.weight) 109 | 110 | def get_config_optim(self, lr, lrp): 111 | return [ 112 | {'params': self.base_model.parameters(), 'lr': lr }, 113 | {'params': self.gc1.parameters(), 'lr': lrp}, 114 | {'params': self.gc2.parameters(), 'lr': lrp}, 115 | ] 116 | def get_loss(self, net_output, ground_truth): 117 | #crition=nn.BCEWithLogitsLoss() 118 | crition=FocalLoss() 119 | loss = crition(net_output['sigmoid'].float(), ground_truth['labels'].float()) 120 | #loss = crition(net_output['class'].float(), ground_truth['labels'].float()) 121 | #crition2 = nn.MultiLabelSoftMarginLoss() 122 | #loss = crition2(net_output['mutil'], ground_truth['multi_labels']) 123 | return loss -------------------------------------------------------------------------------- /Interaction_model/Train/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | import math 5 | from collections import defaultdict 6 | import numpy as np 7 | import torch 8 | from sklearn import preprocessing 9 | import torchvision.transforms as transforms 10 | from modelv2 import MultiOutputModel 11 | from sklearn.metrics import jaccard_score, confusion_matrix, multilabel_confusion_matrix,hamming_loss,accuracy_score,average_precision_score,roc_auc_score,label_ranking_average_precision_score,recall_score,f1_score,precision_score 12 | from torch.utils.data import DataLoader 13 | from sklearn.metrics import roc_curve, auc 14 | from scipy import interp 15 | from itertools import cycle 16 | def checkpoint_load(model, name): 17 | print('Restoring checkpoint: {}'.format(name)) 18 | model.load_state_dict(torch.load(name, map_location='cpu')) 19 | epoch = int(os.path.splitext(os.path.basename(name))[0].split('-')[1]) 20 | return epoch 21 | 22 | 23 | def validate(model, dataloader, batch_size, adj, features, iteration, device, checkpoint=None): 24 | ''' 25 | pretrained_dict = torch.load(checkpoint) 26 | model_dict = model.state_dict() 27 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'gc1' not in k)} 28 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'gc2' not in k)} 29 | model_dict.update(pretrained_dict) 30 | 31 | model.load_state_dict(model_dict) 32 | ''' 33 | if checkpoint is not None: 34 | checkpoint_load(model, checkpoint) 35 | #adj, features = load_data() 36 | 37 | model.eval() 38 | #batch_size=512 39 | results = defaultdict(list) 40 | with torch.no_grad(): 41 | avg_loss2 = 0 42 | F1=0 43 | Recall=0 44 | ACC=0 45 | Precision=0 46 | #accuracy_mutil6=0 47 | Target=[] 48 | Predict=[] 49 | for batch in dataloader: 50 | data = batch['data'] 51 | data=torch.reshape(data,(batch_size,1,int(math.sqrt(data.size(1))),int(math.sqrt(data.size(1))))).to(torch.float32) 52 | 53 | target_labels = batch['labels'] 54 | target_labels = {t: target_labels[t].to(device) for t in target_labels} 55 | output = model(data.to(device),features.to(device), adj.to(device)) 56 | val_loss = model.get_loss(output, target_labels) 57 | #n_classes=output['class'].shape[1] 58 | avg_loss2 += val_loss.item() 59 | 60 | batch_F1,batch_Recall,batch_ACC,batch_Precision = \ 61 | calculate_metrics(output, target_labels) 62 | F1 += batch_F1 63 | Recall += batch_Recall 64 | ACC += batch_ACC 65 | Precision += batch_Precision 66 | Target += (target_labels['labels'].cpu().tolist()) 67 | Predict += 
(output['sigmoid'].cpu().detach().numpy().tolist()) 68 | n_samples = len(dataloader) 69 | avg_loss2 /= n_samples 70 | F1 /=n_samples 71 | Recall /=n_samples 72 | ACC /=n_samples 73 | Precision /=n_samples 74 | 75 | print('-' * 72) 76 | print("Validation loss: {:.4f}, F1: {:.4f}, Recall: {:.4f}, ACC: {:.4f},Precision: {:.4f},".format(avg_loss2, F1,Recall,ACC,Precision)) 77 | 78 | model.train() 79 | 80 | 81 | return avg_loss2,F1,Recall,ACC,Precision,Target,Predict 82 | 83 | 84 | 85 | 86 | 87 | def calculate_metrics(output, target): 88 | predicted_class_labels = output['sigmoid'].cpu() 89 | 90 | predicted_mutil_labels = output['class'].cpu() 91 | 92 | gt_mutil_labels = target['labels'].cpu() 93 | with warnings.catch_warnings(): # sklearn may produce a warning when processing zero row in confusion matrix 94 | warnings.simplefilter("ignore") 95 | 96 | y_true=gt_mutil_labels.numpy() 97 | 98 | #print(y_true) 99 | y_score =predicted_mutil_labels.detach().numpy() 100 | y_pred = predicted_class_labels.detach().numpy() 101 | 102 | F1 = f1_score(y_true,(y_pred > 0.5).astype(float)) 103 | Recall = recall_score(y_true, (y_pred > 0.5).astype(float)) 104 | Acc = accuracy_score(y_true, (y_pred > 0.5).astype(float)) 105 | Precision = precision_score(y_true,(y_pred > 0.5).astype(float)) 106 | 107 | return F1,Recall,Acc,Precision 108 | 109 | 110 | -------------------------------------------------------------------------------- /Interaction_model/Train/train_kfold.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | import pandas as pd 5 | import torch 6 | import math 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | import math 10 | import ctypes 11 | libgcc_s = ctypes.CDLL('libgcc_s.so.1') 12 | import matplotlib.pyplot as plt 13 | import torch.nn as nn 14 | from sklearn import preprocessing 15 | from sklearn.model_selection import KFold 16 | from collections import defaultdict 17 | #from utils import * 18 | from scipy import sparse 19 | import torchvision.transforms as transforms 20 | from dataset import ForDataset, AttributesDataset 21 | from modelv2 import MultiOutputModel 22 | #from modelv2 import MultiOutputModel 23 | from test import calculate_metrics, validate 24 | from torch.utils.data import DataLoader 25 | import numpy as np 26 | import os 27 | from torch.nn import init 28 | import time 29 | #import torchcontrib 30 | from sklearn.metrics import roc_curve, auc 31 | from sklearn.model_selection import KFold,StratifiedKFold 32 | cpu_num = 2 33 | from scipy.stats import ks_2samp 34 | torch.set_num_threads(cpu_num) 35 | def sum_dict(a,b): 36 | temp = dict() 37 | for key in a.keys()| b.keys(): 38 | temp[key] = sum([d.get(key, 0) for d in (a, b)]) 39 | return temp 40 | 41 | def get_cur_time(): 42 | return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M') 43 | 44 | 45 | def checkpoint_save(model, name, epoch): 46 | f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch)) 47 | torch.save(model.state_dict(), f) 48 | print('Saved checkpoint:', f) 49 | def l2_penalty(w): 50 | return (w**2).sum() / 2 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | 56 | parser = argparse.ArgumentParser(description='Training pipeline') 57 | parser.add_argument('--epochs', type=int, default=5, help='Number of epochs to train.') 58 | parser.add_argument('--seed', type=int, default=72, help='Random seed.') 59 | #parser.add_argument('--attributes_file', type=str, 
default='/home/user/yangwenyi/dataNew/dataall/pathwaylabels.csv',help="Path to the file with attributes") 60 | parser.add_argument('--device', type=str, default='cuda', help="Device: 'cuda' or 'cpu'") 61 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,metavar='LR', help='initial learning rate') 62 | parser.add_argument('--lrp', '--learning-rate-pretrained', default=0.01, type=float, metavar='LR', help='learning rate for pre-trained layers') 63 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') 64 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 65 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 66 | args = parser.parse_args() 67 | 68 | 69 | start_epoch = 1 70 | N_epochs = args.epochs 71 | batch_size = args.batch_size 72 | #print(batch_size) 73 | num_workers = 2 # number of processes to handle dataset loading 74 | #device = torch.device("cpu") 75 | 76 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 77 | #adj, features = load_data(path="./Predict/human/graph/", dataset="test") 78 | 79 | adj = torch.load("./data/adj.pth") 80 | features = torch.load("./data/features.pth") 81 | print(args) 82 | #print(features.shape) 83 | Feature=pd.read_csv('./data/Feature_data.csv',header=None).values 84 | #Feature=np.loadtxt('./Data/Feature_data.csv',str,delimiter = ",",skiprows=0) 85 | data_train=Feature[:,1:-1].astype(np.float32) 86 | label=Feature[:,[0,-1]].astype(np.float32).astype(np.int32) 87 | skf = StratifiedKFold(n_splits=5, shuffle=True) 88 | train_F1=[];train_Recall=[];train_Hamming=[];train_Precision=[];train_AP=[] 89 | test_F1=[];test_Recall=[];test_Hamming=[];test_Precision=[];test_AP=[] 90 | kn=0 91 | plt.figure() 92 | mean_tpr = 0.0 93 | MinMax = preprocessing.MinMaxScaler() 94 | data_train=MinMax.fit_transform(data_train) 95 | data_train=preprocessing.normalize(data_train, norm='l2') 96 | mean_fpr = np.linspace(0, 1, 100) 97 | for train_index, test_index in skf.split(data_train,label[:,-1]): 98 | print(str(kn+1)+"fold:") 99 | results = defaultdict(list) 100 | #print(len(train_index)) 101 | #data_train = preprocessing.scale(data_train) 102 | 103 | ''' 104 | Scaler=preprocessing.StandardScaler() 105 | data_train = Scaler.fit_transform(data_train) 106 | MinMax = preprocessing.MinMaxScaler() 107 | data_train=MinMax.fit_transform(data_train) 108 | 109 | data_train=preprocessing.normalize(data_train, norm='l2') 110 | ''' 111 | train_X, train_y = data_train[train_index], label[train_index] 112 | test_X, test_y = data_train[test_index], label[test_index] 113 | 114 | train_X = torch.from_numpy(train_X) 115 | test_X = torch.from_numpy(test_X) 116 | 117 | #test_X = Scaler.fit_transform(test_X) 118 | #test_X=preprocessing.normalize(test_X, norm='l1') 119 | train_dataset = ForDataset(train_X, train_y) 120 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers,drop_last=True) 121 | val_dataset = ForDataset(test_X,test_y) 122 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers,drop_last=True) 123 | model = MultiOutputModel(nfeat=features.shape[1],nlabel=features.shape[0]).to(device) 124 | #print(model) 125 | #optimizer = torch.optim.Adam(model.get_config_optim(args.lr, args.lrp),weight_decay=args.weight_decay) 126 | #optimizer = 
torch.optim.Adamax(model.get_config_optim(args.lr, args.lrp),weight_decay=args.weight_decay) 127 | optimizer = torch.optim.Adamax(model.parameters(),lr=args.lr,weight_decay=args.weight_decay) 128 | #optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay) 129 | 130 | #optimizer = torch.optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay) 131 | #optimizer = torch.optim.RMSprop(model.parameters(),lr=args.lr,alpha=0.99, eps=1e-08,momentum=args.momentum,weight_decay=args.weight_decay) 132 | 133 | #scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=4, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08) 134 | #optimizer = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.01) 135 | #print('SGD') 136 | #optimizer = torch.optim.Adadelta(model.get_config_optim(args.lr, args.lrp), lr=args.lr, rho=0.9, eps=1e-6, weight_decay=args.weight_decay) 137 | 138 | #optimizer = torch.optim.SGD(model.get_config_optim(args.lr, args.lrp),momentum=args.momentum,weight_decay=args.weight_decay) 139 | scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=4, verbose=False, threshold=0.0000001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08) 140 | 141 | #logdir = os.path.join('./logs/', get_cur_time()) 142 | #savedir = os.path.join('./checkpoint/', get_cur_time()) 143 | #savedir="./model/panc/train_kfold/" 144 | #os.makedirs(logdir, exist_ok=True) 145 | resultdir = "./results/" 146 | os.makedirs(resultdir, exist_ok=True) 147 | #logger = SummaryWriter(logdir) 148 | 149 | n_train_samples = len(train_dataloader) 150 | 151 | print("Starting training ...") 152 | loss_early = 10000 153 | for epoch in range(start_epoch, N_epochs + 1): 154 | 155 | model.train() 156 | batch_size = args.batch_size 157 | #print("epoch:",epoch) 158 | t=time.time() 159 | 160 | total_loss2 = 0 161 | F1=0 162 | Recall=0 163 | ACC=0 164 | Precision=0 165 | for batch in train_dataloader: 166 | optimizer.zero_grad() 167 | data = batch['data'] 168 | data=torch.reshape(data,(args.batch_size,1,int(math.sqrt(data.size(1))),int(math.sqrt(data.size(1))))).to(torch.float32) 169 | target_labels = batch['labels'] 170 | target_labels = {t: target_labels[t].to(device) for t in target_labels} 171 | #print(target_labels) 172 | output = model(data.to(device),features.to(device), adj.to(device)) 173 | loss = model.get_loss(output, target_labels) 174 | ''' 175 | re_l1=0 176 | for param in model.parameters(): 177 | re_l1+=torch.sum(torch.abs(param)) 178 | ''' 179 | loss2= loss.item() 180 | total_loss2 += loss2 181 | batch_F1,batch_Recall,batch_ACC,batch_Precision = \ 182 | calculate_metrics(output, target_labels) 183 | F1 += batch_F1 184 | Recall += batch_Recall 185 | ACC += batch_ACC 186 | Precision += batch_Precision 187 | loss.backward() 188 | optimizer.step() 189 | #optimizer.swap_swa_sgd() 190 | 191 | print("epoch {:4d}, loss: {:.4f},F1: {:.4f}, recall: {:.4f}, ACC: {:.4f}, Precision: {:.4f},time: {:.4f}".format( 192 | epoch, 193 | total_loss2 / n_train_samples, 194 | F1 / n_train_samples, 195 | Recall / n_train_samples, 196 | ACC / n_train_samples, 197 | Precision / n_train_samples, 198 | time.time()-t 199 | )) 200 | #print(accuracy_mutil6/n_train_samples) 201 | results['train_loss'].append(total_loss2 / n_train_samples) 202 | results['F1_train'].append(F1 / n_train_samples) 203 | results['recall_train'].append(Recall / n_train_samples) 204 | 
results['Pre_train'].append(Precision / n_train_samples) 205 | #logger.add_scalar('train_loss', total_loss2 / n_train_samples, epoch) 206 | 207 | #if epoch % 5 == 0: 208 | # checkpoint_save(model, savedir, epoch) 209 | 210 | if epoch % 5 == 0: 211 | #checkpoint = os.path.join(savedir, 'checkpoint-{:06d}.pth'.format(epoch)) 212 | val_loss2,F1_val,Recall_val,Hamming_val,Precision_val,Target,Predict = validate(model, val_dataloader, batch_size ,adj, features, epoch, device,checkpoint=None) 213 | results['val_loss'].append(val_loss2) 214 | results['F1_val'].append(F1_val) 215 | results['recall_val'].append(Recall_val) 216 | results['presion_val'].append(Precision_val) 217 | #print(Target) 218 | #print(Predict) 219 | test_F1.append(F1_val);test_Recall.append(Recall_val);test_Hamming.append(Hamming_val) 220 | test_Precision.append(Precision_val) 221 | fpr, tpr, thresholds = roc_curve(Target, Predict) 222 | mean_tpr += np.interp(mean_fpr, fpr, tpr) 223 | mean_tpr[0] = 0.0 224 | roc_auc = auc(fpr, tpr) 225 | plt.plot(fpr, tpr, lw=1, label='Fold {0:.0f} (AUC = {1:.2f})'.format(kn+1, roc_auc)) 226 | ''' 227 | if loss_early > val_loss2: 228 | if epoch < N_epochs: 229 | #print("early stopping") 230 | loss_early = val_loss2 231 | elif epoch == N_epochs: 232 | test_F1.append(F1_val);test_Recall.append(Recall_val);test_Hamming.append(Hamming_val) 233 | test_Precision.append(Precision_val) 234 | fpr, tpr, thresholds = roc_curve(Target, Predict) 235 | mean_tpr += np.interp(mean_fpr, fpr, tpr) 236 | mean_tpr[0] = 0.0 237 | roc_auc = auc(fpr, tpr) 238 | plt.plot(fpr, tpr, lw=1, label='Fold {0:.0f} (AUC = {1:.2f})'.format(kn+1, roc_auc)) 239 | else: 240 | if epoch>30: 241 | print("early stopping") 242 | checkpoint = os.path.join(savedir, 'checkpoint-{:06d}.pth'.format(epoch-5)) 243 | val_loss2,F1_val,Recall_val,Hamming_val,Precision_val,Target,Predict = validate(model, val_dataloader, batch_size ,adj, features, epoch, device,checkpoint=checkpoint) 244 | 245 | #if epoch % N_epochs == 0: 246 | #train_F1.append(F1 / n_train_samples);train_Recall.append(Recall / n_train_samples) 247 | #train_Hamming.append(ACC / n_train_samples);train_Precision.append(Precision / n_train_samples) 248 | test_F1.append(F1_val);test_Recall.append(Recall_val);test_Hamming.append(Hamming_val) 249 | test_Precision.append(Precision_val) 250 | fpr, tpr, thresholds = roc_curve(Target, Predict) 251 | mean_tpr += np.interp(mean_fpr, fpr, tpr) 252 | mean_tpr[0] = 0.0 253 | roc_auc = auc(fpr, tpr) 254 | plt.plot(fpr, tpr, lw=1, label='Fold {0:.0f} (AUC = {1:.2f})'.format(kn+1, roc_auc)) 255 | break 256 | else: 257 | pass 258 | ''' 259 | #plot_results(results, 10,"./results/results"+str(kn)+".png") 260 | kn+=1 261 | 262 | mean_tpr /= kn 263 | mean_tpr[-1] = 1.0 264 | mean_auc = auc(mean_fpr, mean_tpr) 265 | 266 | plt.plot(mean_fpr, mean_tpr, 'k--',label='Mean (ROC = {0:.2f})'.format(mean_auc), lw=2) 267 | 268 | plt.xlim([-0.05, 1.05]) 269 | plt.ylim([-0.05, 1.05]) 270 | plt.xlabel('False Positive Rate') 271 | plt.ylabel('True Positive Rate') 272 | plt.title('Receiver operating characteristic example') 273 | plt.legend(loc="lower right") 274 | plt.tight_layout() 275 | plt.savefig('./results/ROC.pdf',dpi=500) 276 | F1_test= np.mean(np.array(test_F1));Recall_test= np.mean(np.array(test_Recall)); 277 | ACC_test= np.mean(np.array(test_Hamming));Precision_test= np.mean(np.array(test_Precision)); 278 | 279 | print("Test:F1: {:.4f}, Recall: {:.4f}, ACC: {:.4f}, Precision: {:.4f}".format( 280 | F1_test, 281 | Recall_test, 282 | ACC_test, 283 | 
Precision_test, 284 | )) 285 | -------------------------------------------------------------------------------- /Interaction_model/Train/trainall_no.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | import pandas as pd 5 | import torch 6 | import math 7 | import matplotlib.pyplot as plt 8 | import torch.nn as nn 9 | from sklearn import preprocessing 10 | from sklearn.model_selection import KFold 11 | from collections import defaultdict 12 | #from utils import * 13 | from scipy import sparse 14 | import torchvision.transforms as transforms 15 | from dataset import ForDataset, AttributesDataset 16 | from modelv2 import MultiOutputModel 17 | from test import calculate_metrics, validate 18 | from torch.utils.data import DataLoader 19 | import numpy as np 20 | import os 21 | from torch.nn import init 22 | import time 23 | #import torchcontrib 24 | from sklearn.metrics import roc_curve, auc 25 | from sklearn.model_selection import KFold,StratifiedKFold 26 | cpu_num = 6 27 | from scipy.stats import ks_2samp 28 | torch.set_num_threads(cpu_num) 29 | def sum_dict(a,b): 30 | temp = dict() 31 | for key in a.keys()| b.keys(): 32 | temp[key] = sum([d.get(key, 0) for d in (a, b)]) 33 | return temp 34 | 35 | def get_cur_time(): 36 | return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M') 37 | 38 | 39 | def checkpoint_save(model, name, epoch): 40 | f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch)) 41 | torch.save(model.state_dict(), f) 42 | print('Saved checkpoint:', f) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser(description='Training pipeline') 47 | parser.add_argument('--epochs', type=int, default=500, help='Number of epochs to train.') 48 | parser.add_argument('--seed', type=int, default=72, help='Random seed.') 49 | #parser.add_argument('--attributes_file', type=str, default='/home/user/yangwenyi/dataNew/dataall/pathwaylabels.csv',help="Path to the file with attributes") 50 | parser.add_argument('--device', type=str, default='cuda', help="Device: 'cuda' or 'cpu'") 51 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,metavar='LR', help='initial learning rate') 52 | parser.add_argument('--lrp', '--learning-rate-pretrained', default=0.01, type=float, metavar='LR', help='learning rate for pre-trained layers') 53 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') 54 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 55 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 56 | args = parser.parse_args() 57 | 58 | 59 | start_epoch = 1 60 | N_epochs = args.epochs 61 | batch_size = args.batch_size 62 | #print(batch_size) 63 | num_workers = 0 # number of processes to handle dataset loading 64 | device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 65 | print(device) 66 | #adj, features = load_data() 67 | adj = torch.load("./data/adj.pth") 68 | features = torch.load("./data/features.pth") 69 | print(args) 70 | #print(features.shape) 71 | Feature=pd.read_csv('./data/Feature_data.csv',header=None).values 72 | data_train=Feature[:,1:-1].astype(np.float32) 73 | label=Feature[:,[0,-1]].astype(np.float32).astype(np.int32) 74 | 75 | MinMax = preprocessing.MinMaxScaler() 76 | data_train=MinMax.fit_transform(data_train) 77 | 
data_train=preprocessing.normalize(data_train, norm='l2') 78 | 79 | train_dataset = ForDataset(data_train, label) 80 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 81 | 82 | 83 | model = MultiOutputModel(nfeat=features.shape[1],nlabel=features.shape[0]).to(device) 84 | 85 | optimizer = torch.optim.Adamax(model.parameters(),lr=args.lr,weight_decay=args.weight_decay) 86 | scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=4, verbose=False, threshold=0.000001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08) 87 | savedir="./model/" 88 | #os.makedirs(logdir, exist_ok=True) 89 | os.makedirs(savedir, exist_ok=True) 90 | 91 | n_train_samples = len(train_dataloader) 92 | 93 | print("Starting training ...") 94 | for epoch in range(start_epoch, N_epochs + 1): 95 | #model.train() 96 | batch_size = args.batch_size 97 | #print("epoch:",epoch) 98 | t=time.time() 99 | 100 | total_loss2 = 0 101 | F1=0 102 | Recall=0 103 | ACC=0 104 | Precision=0 105 | for batch in train_dataloader: 106 | optimizer.zero_grad() 107 | data = batch['data'] 108 | data=torch.reshape(data,(args.batch_size,1,int(math.sqrt(data.size(1))),int(math.sqrt(data.size(1))))).to(torch.float32) 109 | target_labels = batch['labels'] 110 | target_labels = {t: target_labels[t].to(device) for t in target_labels} 111 | #print(target_labels) 112 | output = model(data.to(device),features.to(device), adj.to(device)) 113 | loss = model.get_loss(output, target_labels) 114 | loss2= loss.item() 115 | total_loss2 += loss2 116 | batch_F1,batch_Recall,batch_ACC,batch_Precision = \ 117 | calculate_metrics(output, target_labels) 118 | F1 += batch_F1 119 | Recall += batch_Recall 120 | ACC += batch_ACC 121 | Precision += batch_Precision 122 | loss.backward() 123 | optimizer.step() 124 | 125 | #optimizer.swap_swa_sgd() 126 | print("epoch {:4d}, loss: {:.4f},F1: {:.4f}, recall: {:.4f}, ACC: {:.4f}, Precision: {:.4f},time: {:.4f}".format( 127 | epoch, 128 | total_loss2 / n_train_samples, 129 | F1 / n_train_samples, 130 | Recall / n_train_samples, 131 | ACC / n_train_samples, 132 | Precision / n_train_samples, 133 | time.time()-t 134 | )) 135 | 136 | if epoch % 5 == 0: 137 | checkpoint_save(model, savedir, epoch) 138 | 139 | -------------------------------------------------------------------------------- /Interaction_model/__pycache__/Mobilev2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/__pycache__/Mobilev2.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/__pycache__/modelv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/__pycache__/modelv2.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/__pycache__/utils.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/CellAnnotate.R: -------------------------------------------------------------------------------- 1 | suppressMessages(library(Seurat)) 2 | suppressMessages(library(scCATCH)) 3 | 4 | args=commandArgs(T) 5 | parameter1 = args[1] 6 | options(warn = -1) 7 | testdata <- readRDS(parameter1) 8 | testdata <-Seurat::NormalizeData(testdata) 9 | testdata <-Seurat::FindVariableFeatures(testdata,selection.method = "vst", nfeatures = 2000) 10 | testdatasc <- rev_gene(data = testdata@assays$RNA@data, data_type = "data", species = "Human", geneinfo = geneinfo) 11 | label<-read.csv('./output/pre_label.csv') 12 | labels<-as.character(label$labels) 13 | obj <- createscCATCH(data = testdatasc, cluster = labels) 14 | obj <- findmarkergene(object = obj, species = "Human", marker = cellmatch, tissue = c('Blood','Peripheral blood','Bone marrow')) 15 | obj <- findcelltype(object = obj) 16 | print(obj@celltype$cell_type) 17 | write.csv(obj@celltype$cell_type,"./output/cell_type.csv") 18 | write.csv(obj@celltype$cluster,"./output/cell_cluster.csv") -------------------------------------------------------------------------------- /Interaction_model/cluster/CellAnnotate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import umap 5 | from collections import OrderedDict 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def plot(X, fig, col, size, true_labels,ann): 10 | ax = fig.add_subplot(1, 1, 1) 11 | for i, point in enumerate(X): 12 | ax.scatter(point[0], point[1], s=size, c=col[true_labels[i]],label=ann[i]) 13 | 14 | 15 | def plotClusters(hidden_emb, true_labels,ann): 16 | # Doing dimensionality reduction for plotting 17 | Umap = umap.UMAP(random_state=42) 18 | X_umap = Umap.fit_transform(hidden_emb) 19 | fig2 = plt.figure(figsize=(10,10),dpi=500) 20 | plot(X_umap, fig2, ['green','brown','purple','orange','yellow','hotpink','red','cyan','blue'], 8, true_labels,ann) 21 | handles, labels = fig2.gca().get_legend_handles_labels() 22 | by_label = OrderedDict(zip(labels, handles)) 23 | fig2.legend(by_label.values(), by_label.keys(),loc="upper right") 24 | #fig2.legend() 25 | fig2.savefig("./output/UMAP.pdf") 26 | plt.close() 27 | 28 | 29 | Cell_type = pd.read_csv('./output/cell_type.csv').values[:,-1] 30 | pre_cell_cluster = pd.read_csv('./output/cell_cluster.csv').values[:,-1] 31 | pre_label = pd.read_csv('./output/pre_label.csv').values 32 | Cell_label = pre_label[:,-1] 33 | index = pre_label[:,0] 34 | #print(index) 35 | columns = ["labels"] 36 | #print(pre_cell_cluster) 37 | pre_cell_cluster = pre_cell_cluster.tolist() 38 | 39 | all_cell = [] 40 | for i in range(len(Cell_label)): 41 | #if (Cell_label[i]) in pre_cell_cluster: 42 | #print((Cell_label[i])) 43 | index1=pre_cell_cluster.index((Cell_label[i])) 44 | all_cell.append(Cell_type[index1]) 45 | pd.DataFrame(all_cell,index=index,columns=columns).to_csv("./output/cell_annotatetype.csv",quoting=1) 46 | 47 | 48 | featureMatrix = pd.read_csv('./output/pre_embedding.txt',header=None).values 49 | plotClusters(featureMatrix,Cell_label,all_cell) 50 | 51 | 52 | 53 | 54 | 
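55 | # Illustrative sketch, not part of the original script: the loop above joins each cell's 56 | # predicted cluster id to its scCATCH annotation via pre_cell_cluster.index(); assuming the 57 | # rows of cell_cluster.csv and cell_type.csv stay aligned (one row per cluster, as written 58 | # by CellAnnotate.R), the same join can be written as a dict lookup: 59 | # cluster_to_type = dict(zip(pre_cell_cluster, Cell_type)) 60 | # all_cell = [cluster_to_type[c] for c in Cell_label] 61 | 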
-------------------------------------------------------------------------------- /Interaction_model/cluster/Cluster.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | from preprocess import * 6 | from pretrain import * 7 | import sys 8 | import argparse 9 | import random 10 | from sklearn.cluster import SpectralBiclustering,KMeans, kmeans_plusplus, DBSCAN,SpectralClustering 11 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 12 | from sklearn.metrics import adjusted_rand_score as ari_score 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.nn.parameter import Parameter 17 | from torch.optim import Adam ,SGD,Adamax 18 | from torch.utils.data import DataLoader 19 | from torch.nn import Linear 20 | from utils import load_data, load_graph 21 | from GNN import GNNLayer 22 | import umap 23 | from evaluation import eva,eva_pretrain 24 | from collections import Counter 25 | from sklearn.manifold import TSNE 26 | import matplotlib.pyplot as plt 27 | from collections import OrderedDict 28 | 29 | 30 | 31 | def plot(X, fig, col, size, true_labels,ann): 32 | ax = fig.add_subplot(1, 1, 1) 33 | for i, point in enumerate(X): 34 | ax.scatter(point[0], point[1], s=size, c=col[true_labels[i]],label=ann[i]) 35 | 36 | 37 | def plotClusters(hidden_emb, true_labels,ann): 38 | # Doing dimensionality reduction for plotting 39 | Umap = umap.UMAP(random_state=42) 40 | X_umap = Umap.fit_transform(hidden_emb) 41 | fig2 = plt.figure(figsize=(10,10),dpi=500) 42 | plot(X_umap, fig2, ['green','brown','purple','orange','yellow','hotpink','red','cyan','blue'], 8, true_labels,ann) 43 | handles, labels = fig2.gca().get_legend_handles_labels() 44 | by_label = OrderedDict(zip(labels, handles)) 45 | fig2.legend(by_label.values(), by_label.keys(),loc="upper right") 46 | #fig2.legend() 47 | fig2.savefig("./output/UMAP.pdf") 48 | plt.close() 49 | 50 | def init_seed(opt): 51 | torch.cuda.cudnn_enabled = False 52 | np.random.seed(opt.seed) 53 | torch.manual_seed(opt.seed) 54 | torch.cuda.manual_seed(opt.seed) 55 | def pretarin_cluster(n_clusters,x,device): 56 | 57 | print("generate cell graph...") 58 | Auto = args.Auto 59 | #calculate the number of clusters 60 | 61 | 62 | device = device 63 | 64 | silhouette_pre=[] 65 | print("Start pretrain") 66 | for i in range(args.pretrain_frequency): 67 | print("pretrain:"+str(i)) 68 | model = AE( 69 | n_enc_1=100, 70 | n_enc_2=200, 71 | n_enc_3=200, 72 | n_dec_1=200, 73 | n_dec_2=200, 74 | n_dec_3=100, 75 | n_input=2000, 76 | n_z=5).to(device) 77 | dataset = LoadDataset(x) 78 | epoch = args.pretrain_epoch 79 | silhouette=pretrain_ae(model,dataset,i,device,n_clusters,epoch,Auto=Auto) 80 | silhouette_pre.append(silhouette) 81 | silhouette_pre = np.array(silhouette_pre) 82 | premodel_i=np.where(silhouette_pre==np.max(silhouette_pre))[0][0] 83 | print("Pretrain end") 84 | return premodel_i 85 | 86 | class AE_train(nn.Module): 87 | 88 | def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 89 | n_input, n_z): 90 | super(AE_train, self).__init__() 91 | self.enc_1 = Linear(n_input, n_enc_1) 92 | self.enc_2 = Linear(n_enc_1, n_enc_2) 93 | self.enc_3 = Linear(n_enc_2, n_enc_3) 94 | self.z_layer = Linear(n_enc_3, n_z) 95 | 96 | self.dec_1 = Linear(n_z, n_dec_1) 97 | self.dec_2 = Linear(n_dec_1, n_dec_2) 98 | self.dec_3 = Linear(n_dec_2, n_dec_3) 99 | 
self.x_bar_layer = Linear(n_dec_3, n_input) 100 | 101 | def forward(self, x): 102 | enc_h1 = F.relu(self.enc_1(x)) 103 | enc_h2 = F.relu(self.enc_2(enc_h1)) 104 | enc_h3 = F.relu(self.enc_3(enc_h2)) 105 | z = self.z_layer(enc_h3) 106 | 107 | dec_h1 = F.relu(self.dec_1(z)) 108 | dec_h2 = F.relu(self.dec_2(dec_h1)) 109 | dec_h3 = F.relu(self.dec_3(dec_h2)) 110 | x_bar = self.x_bar_layer(dec_h3) 111 | 112 | return x_bar, enc_h1, enc_h2, enc_h3, z 113 | 114 | 115 | class ClusterModel(nn.Module): 116 | 117 | def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 118 | n_input, n_z, n_clusters, v=1): 119 | super(ClusterModel, self).__init__() 120 | 121 | # autoencoder for intra information 122 | 123 | self.ae = AE_train( 124 | n_enc_1=n_enc_1, 125 | n_enc_2=n_enc_2, 126 | n_enc_3=n_enc_3, 127 | n_dec_1=n_dec_1, 128 | n_dec_2=n_dec_2, 129 | n_dec_3=n_dec_3, 130 | n_input=n_input, 131 | n_z=n_z) 132 | self.ae.load_state_dict(torch.load(args.pretrain_path, map_location='cpu')) 133 | 134 | 135 | #self.ae.load_state_dict(torch.load(args.pretrain_path, map_location='cpu')) 136 | 137 | # GCN for inter information 138 | self.gnn_1 = GNNLayer(n_input, n_enc_1) 139 | self.gnn_2 = GNNLayer(n_enc_1, n_enc_2) 140 | self.gnn_3 = GNNLayer(n_enc_2, n_enc_3) 141 | self.gnn_4 = GNNLayer(n_enc_3, n_z) 142 | self.gnn_5 = GNNLayer(n_z, n_clusters) 143 | 144 | # cluster layer 145 | self.cluster_layer = Parameter(torch.Tensor(n_clusters, n_z)) 146 | torch.nn.init.xavier_normal_(self.cluster_layer.data) 147 | 148 | # degree 149 | self.v = v 150 | 151 | def forward(self, x, adj): 152 | # DNN Module 153 | #x_bar, tra1, tra2, tra3, z = self.ae(x) 154 | #print(x.size()) 155 | # GCN Module 156 | h1 = self.gnn_1(x, adj) 157 | h2 = self.gnn_2(h1, adj) 158 | h3 = self.gnn_3(h2, adj) 159 | h4 = self.gnn_4(h3, adj) 160 | h5 = self.gnn_5(h4, adj, active=False) 161 | predict = F.softmax(h5, dim=1) 162 | 163 | 164 | enc_h1 = F.relu(self.ae.enc_1(x)) 165 | #print(enc_h1.size()) 166 | enc_h2 = F.relu(self.ae.enc_2(enc_h1+h1)) 167 | enc_h3 = F.relu(self.ae.enc_3(enc_h2+h2)) 168 | z = self.ae.z_layer(enc_h3+h3) 169 | 170 | dec_h1 = F.relu(self.ae.dec_1(z+h4)) 171 | dec_h2 = F.relu(self.ae.dec_2(dec_h1+h3)) 172 | dec_h3 = F.relu(self.ae.dec_3(dec_h2+h2)) 173 | x_bar = self.ae.x_bar_layer(dec_h3+h1) 174 | 175 | 176 | 177 | # Dual Self-supervised Module 178 | q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2) / self.v) 179 | q = q.pow((self.v + 1.0) / 2.0) 180 | q = (q.t() / torch.sum(q, 1)).t() 181 | 182 | return x_bar, q, predict, z 183 | 184 | 185 | def target_distribution(q): 186 | weight = q**2 / q.sum(0) 187 | return (weight.t() / weight.sum(1)).t() 188 | def adjust_learning_rate(optimizer, epoch): 189 | lr = 0.001 * (0.1 ** (epoch // 20)) 190 | for param_group in optimizer.param_groups: 191 | param_group['lr'] = lr 192 | 193 | def train_cluster(dataset,n_clusters,device): 194 | Auto=args.Auto 195 | if Auto: 196 | if z.shape[0] < 2000: 197 | resolution = 0.8 198 | else: 199 | resolution = 0.5 200 | n_clusters = int(n_clusters*resolution) if int(n_clusters*resolution)>=3 else 2 201 | else: 202 | n_clusters=n_clusters 203 | #device = args.device 204 | model = ClusterModel(100, 200, 200, 200, 200, 100, 205 | n_input=args.n_input, 206 | n_z=args.n_z, 207 | n_clusters=n_clusters).to(device) 208 | 209 | 210 | optimizer = Adamax(model.parameters(), lr=args.lr) 211 | 212 | # KNN Graph 213 | adj = load_graph(args.name) 214 | adj = adj.to(device) 215 | 216 | # cluster parameter initiate 217 | 
data = torch.Tensor(dataset.x).to(device) 218 | y = dataset.y 219 | with torch.no_grad(): 220 | _, _, _, _, z = model.ae(data) 221 | print(n_clusters) 222 | kmeans = KMeans(n_clusters=n_clusters, n_init=20) 223 | y_pred = kmeans.fit_predict(z.data.cpu().numpy()) 224 | y_pred_last = y_pred 225 | model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device) 226 | meta = pd.read_csv('./output/cell_name.csv',header=None).values 227 | index = meta[:,0] 228 | columns = ["labels"] 229 | #print(meta.shape) 230 | for epoch in range(args.Train_epoch): 231 | adjust_learning_rate(optimizer, epoch) 232 | 233 | if epoch % 1 == 0: 234 | # update_interval 235 | _, tmp_q, pred, _ = model(data, adj) 236 | tmp_q = tmp_q.data 237 | p = target_distribution(tmp_q) 238 | 239 | res1 = tmp_q.cpu().numpy().argmax(1) #Q 240 | res2 = pred.data.cpu().numpy().argmax(1) #Z 241 | res3 = p.data.cpu().numpy().argmax(1) #P 242 | nmi,ari,ami,silhouette=eva(tmp_q.cpu().numpy(),y, res1, str(epoch) + 'Q') 243 | 244 | print(str(epoch) + 'Q', 245 | ', nmi {:.4f}'.format(nmi), ', ari {:.4f}'.format(ari), 246 | ', ami {:.4f}'.format(ami),', silhouette {:.4f}'.format(silhouette) 247 | ) 248 | ''' 249 | nmi,ari,ami,silhouette=eva(z.data.cpu().numpy(),y, res2, str(epoch) + 'Z') 250 | print(str(epoch) + 'Z', 251 | ', nmi {:.4f}'.format(nmi), ', ari {:.4f}'.format(ari), 252 | ', ami {:.4f}'.format(ami),', silhouette {:.4f}'.format(silhouette) 253 | ) 254 | nmi,ari,ami,silhouette=eva(z.data.cpu().numpy(),y, res3, str(epoch) + 'P') 255 | print(str(epoch) + 'P', 256 | ', nmi {:.4f}'.format(nmi), ', ari {:.4f}'.format(ari), 257 | ', ami {:.4f}'.format(ami),', silhouette {:.4f}'.format(silhouette) 258 | ) 259 | ''' 260 | x_bar, q, pred, _ = model(data, adj) 261 | 262 | kl_loss = F.kl_div(q.log(), p, reduction='batchmean') 263 | ce_loss = F.kl_div(pred.log(), p, reduction='batchmean') 264 | re_loss = F.mse_loss(x_bar, data) 265 | #loss = 0.1*kl_loss + 1*ce_loss + 0.001*re_loss 266 | loss =0.0001*kl_loss + 0.001*ce_loss + 1*re_loss 267 | 268 | optimizer.zero_grad() 269 | loss.backward() 270 | optimizer.step() 271 | #np.savetxt("./output/pre_label.txt",res1,fmt="%s",delimiter=",") 272 | np.savetxt("./output/pre_embedding.txt",tmp_q.cpu().numpy(),fmt="%s",delimiter=",") 273 | pd.DataFrame(res1,index=index,columns=columns).to_csv("./output/pre_label.csv",quoting=1) 274 | #size = len(np.unique(res1)) 275 | #drawUMAP(tmp_q.cpu().numpy(), res1, size, saveFlag=True) 276 | 277 | 278 | 279 | if __name__ == "__main__": 280 | if not os.path.exists("./output/"): 281 | os.system('mkdir ./output/') 282 | if not os.path.exists("./output/graph"): 283 | os.system('mkdir ./output/graph') 284 | if not os.path.exists("./output/data"): 285 | os.system('mkdir ./output/data') 286 | if not os.path.exists("./output/model"): 287 | os.system('mkdir ./output/model') 288 | 289 | parser = argparse.ArgumentParser( 290 | description='Cell_cluster', 291 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 292 | parser.add_argument('--name', type=str, default='cell') 293 | parser.add_argument('--seed', type=int, default=2022, help='Random seed.') 294 | parser.add_argument('--lr', type=float, default=0.001) 295 | parser.add_argument('--n_clusters', default=10, type=int) 296 | parser.add_argument('--n_z', default=5, type=int) 297 | parser.add_argument('--pretrain_epoch', default=30, type=int) 298 | parser.add_argument('--pretrain_frequency', default=10, type=int) 299 | parser.add_argument('--Train_epoch', default=10, type=int) 300 | 
parser.add_argument('--n_input', default=2000, type=int) 301 | parser.add_argument('--pretrain_path', type=str, default='pkl') 302 | parser.add_argument('--Auto', default=False) 303 | parser.add_argument('--pretain', default=True) 304 | parser.add_argument('--device', type=str, default='cuda', 305 | help="Device: 'cuda' or 'cpu'") 306 | args = parser.parse_args() 307 | args.cuda = torch.cuda.is_available() 308 | print("use cuda: {}".format(args.cuda)) 309 | 310 | para0=str("../input/test.rds") 311 | os.system("Rscript Feature.R"+" "+para0) 312 | feature = pd.read_csv('./output/Top2000.csv',header=None,low_memory=False).values 313 | Cell_name = feature[0,1:] 314 | featureMatrix = feature[1:,1:].T 315 | np.savetxt("./output/cell_name.txt",Cell_name,fmt="%s",delimiter=" ") 316 | data=(featureMatrix.astype(np.float32)) 317 | np.savetxt("./output/data/cell.txt",data,fmt="%s",delimiter=" ") 318 | 319 | adj, edgeList = generateAdj(featureMatrix) 320 | #print(adj) 321 | idx=[] 322 | for i in range(np.array(edgeList).shape[0]): 323 | if np.array(edgeList)[i,-1]==1.0: 324 | idx.append(i) 325 | np.savetxt("./output/graph/cell_graph.txt",np.array(edgeList)[idx,0:-1],fmt="%d") 326 | 327 | 328 | x = np.loadtxt('./output/data/cell.txt', dtype=float) 329 | y = np.loadtxt('../input/test_label.txt', dtype=int) 330 | 331 | if args.Auto: 332 | print("Auto_mode") 333 | auto_clusters = getcluster(edgeList) 334 | n_clusters = auto_clusters 335 | 336 | print(n_clusters) 337 | 338 | else: 339 | n_clusters = int(max(y) - min(y) + 1) 340 | 341 | print(n_clusters) 342 | 343 | 344 | device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu") 345 | args.Auto = False 346 | if args.pretain: 347 | premodel_i = pretarin_cluster(n_clusters,x,device) 348 | #print(premodel_i) 349 | #pretrain_path 350 | args.pretrain_path = './output/model/test'+str(premodel_i)+'.pkl' 351 | else: 352 | #pretain_model 353 | args.pretrain_path = './pretain_model/pbmc/pbmc.pkl' 354 | 355 | dataset = load_data(args.name) 356 | train_cluster(dataset,n_clusters,device) 357 | 358 | print("Cell Annotate") 359 | os.system("Rscript CellAnnotate.R"+" "+para0) 360 | Cell_type = pd.read_csv('./output/cell_type.csv').values[:,-1] 361 | Cell_type_new=[] 362 | for i in range(len(Cell_type)): 363 | Cell_type_new.append(Cell_type[i].split(",")[0]) 364 | pre_cell_cluster = pd.read_csv('./output/cell_cluster.csv').values[:,-1] 365 | pre_label = pd.read_csv('./output/pre_label.csv').values 366 | Cell_label = pre_label[:,-1] 367 | index = pre_label[:,0] 368 | #print((pre_label[:,0])) 369 | #print(index) 370 | columns = ["labels"] 371 | #print(pre_cell_cluster) 372 | pre_cell_cluster = pre_cell_cluster.tolist() 373 | all_cell = [] 374 | for i in range(len(Cell_label)): 375 | index1=pre_cell_cluster.index((Cell_label[i])) 376 | all_cell.append(Cell_type_new[index1].split(",")[0]) 377 | 378 | 379 | pd.DataFrame(all_cell,index=index,columns=columns).to_csv("./output/cell_annotatetype.csv",quoting=1) 380 | featureMatrix = pd.read_csv('./output/pre_embedding.txt',header=None).values 381 | plotClusters(featureMatrix,Cell_label,all_cell) 382 | -------------------------------------------------------------------------------- /Interaction_model/cluster/Feature.R: -------------------------------------------------------------------------------- 1 | suppressMessages(library(Seurat)) 2 | args=commandArgs(T) 3 | parameter1 = args[1] 4 | parameter2 = args[2] 5 | options(warn = -1) 6 | 7 | testdata <- readRDS(parameter1) 8 | testdata 
<-Seurat::NormalizeData(testdata) 9 | testdata <-Seurat::FindVariableFeatures(testdata,selection.method = "vst", nfeatures = 2000) 10 | data.input<-testdata@assays$RNA@data[testdata@assays$RNA@var.features,] 11 | write.csv(data.input,"./output/Top2000.csv") 12 | 13 | 14 | 15 | #grid_col <- c("red", "green", "cornflowerblue","blueviolet", brewer.pal(8, 'Accent')[6], "darkseagreen1", "hotpink1", "hotpink4", "gold", "slateblue3","tomato", brewer.pal(8, 'Set3')[5],brewer.pal(8, 'Set3')[4]) 16 | 17 | #brewer.pal(8, 'Reds')[6],brewer.pal(8, 'YlGn')[5],brewer.pal(8, 'YlGnBu')[6],brewer.pal(8, 'RdPu')[7],brewer.pal(8, 'Purples')[6], 18 | #brewer.pal(8, 'Reds')[5],brewer.pal(8, 'PuRd')[7],brewer.pal(8, 'PuBuGn')[7],brewer.pal(8, 'PuBu')[8],brewer.pal(8, 'Oranges')[5], 19 | #brewer.pal(8, 'BuPu')[8],brewer.pal(9, 'Oranges')[9],brewer.pal(8, 'Blues')[8] 20 | 21 | -------------------------------------------------------------------------------- /Interaction_model/cluster/GNN.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | from torch.nn.modules.module import Module 6 | 7 | 8 | class GNNLayer(Module): 9 | def __init__(self, in_features, out_features): 10 | super(GNNLayer, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 14 | torch.nn.init.xavier_uniform_(self.weight) 15 | 16 | def forward(self, features, adj, active=True): 17 | support = torch.mm(features, self.weight) 18 | output = torch.spmm(adj, support) 19 | if active: 20 | output = F.relu(output) 21 | return output 22 | 23 | -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/GNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/GNN.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/KNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/KNN.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/Model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/Model.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/preprocess.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/preprocess.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/pretrain.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/pretrain.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /Interaction_model/cluster/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from munkres import Munkres, print_matrix 3 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 4 | from sklearn.metrics import adjusted_rand_score as ari_score 5 | from scipy.optimize import linear_sum_assignment as linear 6 | from sklearn import metrics 7 | from sklearn.metrics import adjusted_mutual_info_score as ami_score 8 | from sklearn.metrics import silhouette_score,davies_bouldin_score 9 | 10 | def cluster_acc(y_true, y_pred): 11 | y_true = y_true - np.min(y_true) 12 | 13 | l1 = list(set(y_true)) 14 | numclass1 = len(l1) 15 | 16 | l2 = list(set(y_pred)) 17 | numclass2 = len(l2) 18 | 19 | ind = 0 20 | if numclass1 != numclass2: 21 | for i in l1: 22 | if i in l2: 23 | pass 24 | else: 25 | y_pred[ind] = i 26 | ind += 1 27 | 28 | l2 = list(set(y_pred)) 29 | numclass2 = len(l2) 30 | 31 | if numclass1 != numclass2: 32 | print('error') 33 | return 34 | 35 | cost = np.zeros((numclass1, numclass2), dtype=int) 36 | for i, c1 in enumerate(l1): 37 | mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1] 38 | for j, c2 in enumerate(l2): 39 | mps_d = [i1 for i1 in mps if y_pred[i1] == c2] 40 | cost[i][j] = len(mps_d) 41 | 42 | # match two clustering results by Munkres algorithm 43 | m = Munkres() 44 | cost = cost.__neg__().tolist() 45 | indexes = m.compute(cost) 46 | 47 | # get the match results 48 | new_predict = np.zeros(len(y_pred)) 49 | for i, c in enumerate(l1): 50 | # correponding label in l2: 51 | c2 = l2[indexes[i][1]] 52 | 53 | # ai is the index with label==c2 in the pred_label list 54 | ai = [ind for ind, elm in enumerate(y_pred) if elm == c2] 55 | new_predict[ai] = c 56 | 57 | acc = metrics.accuracy_score(y_true, new_predict) 58 | f1_macro = metrics.f1_score(y_true, new_predict, average='macro') 59 | precision_macro = metrics.precision_score(y_true, new_predict, average='macro') 60 | recall_macro = metrics.recall_score(y_true, new_predict, average='macro') 61 | f1_micro = metrics.f1_score(y_true, new_predict, average='micro') 62 | precision_micro = metrics.precision_score(y_true, new_predict, average='micro') 63 | recall_micro = metrics.recall_score(y_true, new_predict, average='micro') 64 | return acc, f1_macro 65 | 66 | 67 | def eva(X,y_true, y_pred, epoch=0): 68 | #acc, f1 = cluster_acc(y_true, y_pred) 69 | nmi = nmi_score(y_true, y_pred, average_method='arithmetic') 70 | ari = ari_score(y_true, y_pred) 71 | ami = 
ami_score(y_true, y_pred) 72 | silhouette = silhouette_score(X, y_pred,metric='euclidean') 73 | return nmi,ari,ami,silhouette 74 | def eva_pretrain(X, y_pred, epoch=0): 75 | silhouette = silhouette_score(X, y_pred,metric='euclidean') 76 | return silhouette -------------------------------------------------------------------------------- /Interaction_model/cluster/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import ast 4 | import operator 5 | from itertools import chain 6 | import math 7 | import os 8 | from scipy import sparse 9 | from sklearn.cluster import KMeans 10 | from sklearn.neighbors import kneighbors_graph 11 | from scipy.spatial import distance_matrix, minkowski_distance, distance 12 | import scipy.sparse 13 | import sys 14 | import pickle 15 | import csv 16 | import networkx as nx 17 | import numpy as np 18 | from sklearn.ensemble import IsolationForest 19 | import time 20 | from multiprocessing import Pool 21 | import multiprocessing 22 | from igraph import * 23 | from sklearn import preprocessing 24 | 25 | def calculateKNNgraphDistanceMatrixStatsSingleThread(featureMatrix, distanceType='euclidean', k=10, param=None): 26 | 27 | edgeList=[] 28 | 29 | p_time = time.time() 30 | for i in np.arange(featureMatrix.shape[0]): 31 | tmp=featureMatrix[i,:].reshape(1,-1) 32 | distMat = distance.cdist(tmp,featureMatrix, distanceType) 33 | res = distMat.argsort()[:k+1] 34 | tmpdist = distMat[0,res[0][1:k+1]] 35 | 36 | boundary = np.mean(tmpdist)+np.std(tmpdist) 37 | for j in np.arange(1,k+1): 38 | # TODO: check, only exclude large outliners 39 | # if (distMat[0,res[0][j]]<=mean+std) and (distMat[0,res[0][j]]>=mean-std): 40 | 41 | if distMat[0,res[0][j]]<=boundary: 42 | weight = 1.0 43 | else: 44 | weight = 0.0 45 | 46 | #weight = 1.0 47 | edgeList.append((i,res[0][j],weight)) 48 | 49 | return edgeList 50 | 51 | def calculateKNNgraphDistanceMatrix(featureMatrix, distanceType='euclidean', k=10): 52 | 53 | distMat = distance.cdist(featureMatrix,featureMatrix, distanceType) 54 | 55 | edgeList=[] 56 | 57 | for i in np.arange(distMat.shape[0]): 58 | res = distMat[:,i].argsort()[:k] 59 | for j in np.arange(k): 60 | edgeList.append((i,res[j])) 61 | 62 | return edgeList 63 | 64 | def edgeList2edgeDict(edgeList, nodesize): 65 | graphdict={} 66 | tdict={} 67 | 68 | for edge in edgeList: 69 | end1 = edge[0] 70 | end2 = edge[1] 71 | tdict[end1]="" 72 | tdict[end2]="" 73 | if end1 in graphdict: 74 | tmplist = graphdict[end1] 75 | else: 76 | tmplist = [] 77 | tmplist.append(end2) 78 | graphdict[end1]= tmplist 79 | 80 | #check and get full matrix 81 | for i in range(nodesize): 82 | if i not in tdict: 83 | graphdict[i]=[] 84 | 85 | return graphdict 86 | 87 | def generateAdj(featureMatrix, graphType='KNNgraph', para = None): 88 | """ 89 | Generating edgeList 90 | """ 91 | edgeList = None 92 | adj = None 93 | #edgeList = calculateKNNgraphDistanceMatrix(featureMatrix) 94 | edgeList = calculateKNNgraphDistanceMatrixStatsSingleThread(featureMatrix) 95 | graphdict = edgeList2edgeDict(edgeList, featureMatrix.shape[0]) 96 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graphdict)) 97 | return adj, edgeList 98 | 99 | 100 | def generateLouvainCluster(edgeList): 101 | """ 102 | Louvain Clustering using igraph 103 | """ 104 | Gtmp = nx.Graph() 105 | Gtmp.add_weighted_edges_from(edgeList) 106 | W = nx.adjacency_matrix(Gtmp) 107 | W = W.todense() 108 | graph = Graph.Weighted_Adjacency( 109 | W.tolist(), 
mode=ADJ_UNDIRECTED, attr="weight", loops=False) 110 | louvain_partition = graph.community_multilevel( 111 | weights=graph.es['weight'], return_levels=False) 112 | size = len(louvain_partition) 113 | hdict = {} 114 | count = 0 115 | for i in range(size): 116 | tlist = louvain_partition[i] 117 | for j in range(len(tlist)): 118 | hdict[tlist[j]] = i 119 | count += 1 120 | 121 | listResult = [] 122 | for i in range(count): 123 | listResult.append(hdict[i]) 124 | 125 | return listResult, size 126 | 127 | def read_data(filename, sparsify = False, skip_exprs = False): 128 | with h5py.File(filename, "r") as f: 129 | obs = pd.DataFrame(dict_from_group(f["obs"]), index = utilss.decode(f["obs_names"][...])) 130 | var = pd.DataFrame(dict_from_group(f["var"]), index = utilss.decode(f["var_names"][...])) 131 | uns = dict_from_group(f["uns"]) 132 | if not skip_exprs: 133 | exprs_handle = f["exprs"] 134 | if isinstance(exprs_handle, h5py.Group): 135 | mat = sp.sparse.csr_matrix((exprs_handle["data"][...], exprs_handle["indices"][...], 136 | exprs_handle["indptr"][...]), shape = exprs_handle["shape"][...]) 137 | else: 138 | mat = exprs_handle[...].astype(np.float32) 139 | if sparsify: 140 | mat = sp.sparse.csr_matrix(mat) 141 | else: 142 | mat = sp.sparse.csr_matrix((obs.shape[0], var.shape[0])) 143 | return mat, obs, var, uns 144 | 145 | 146 | def prepro(data_type,filename): 147 | if data_type == 'csv': 148 | data_path = "./dataset/" + filename + "/data.csv" 149 | label_path = "./dataset/" + filename + "/label.csv" 150 | X = pd.read_csv(data_path, header=0, index_col=0, sep=',') 151 | #X = np.expm1(X) 152 | cell_label = pd.read_csv(label_path).values[:,-1] 153 | if data_type == 'txt': 154 | data_path = "./dataset/" + filename + "/data.txt" 155 | label_path = "./dataset/" + filename + "/label.csv" 156 | X = pd.read_csv(data_path, header=0, index_col=0, sep=',') 157 | cell_label = pd.read_csv(label_path).values[:,-1] 158 | if data_type == 'h5': 159 | data_path = "./dataset/" + filename + "/data.h5" 160 | mat, obs, var, uns = read_data(data_path, sparsify=False, skip_exprs=False) 161 | if isinstance(mat, np.ndarray): 162 | X = np.array(mat) 163 | else: 164 | X = np.array(mat.toarray()) 165 | cell_name = np.array(obs["cell_type1"]) 166 | cell_type, cell_label = np.unique(cell_name, return_inverse=True) 167 | return X, cell_label 168 | 169 | 170 | def Selecting_highly_variable_genes(X, highly_genes): 171 | adata = sc.AnnData(X) 172 | adata.var_names_make_unique() 173 | sc.pp.filter_genes(adata, min_counts=1) 174 | sc.pp.filter_cells(adata, min_counts=1) 175 | sc.pp.normalize_per_cell(adata) 176 | sc.pp.log1p(adata) 177 | adata.raw = adata 178 | sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes=highly_genes) 179 | adata = adata[:, adata.var['highly_variable']].copy() 180 | sc.pp.scale(adata) 181 | data = adata.X 182 | 183 | return data 184 | 185 | def normalize(adata, copy=True, highly_genes = None, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True): 186 | if isinstance(adata, sc.AnnData): 187 | if copy: 188 | adata = adata.copy() 189 | elif isinstance(adata, str): 190 | adata = sc.read(adata) 191 | else: 192 | raise NotImplementedError 193 | norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.' 
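# The assertions just below are a sanity check that adata.X still holds raw (integer) counts before size-factor normalization, log1p and scaling are applied. # Note that read_data() and normalize() refer to h5py, scanpy (sc), scipy (sp) and the local utilss helpers; this copy of preprocess.py assumes they are importable even though they are not imported at the top of the file.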
194 | assert 'n_count' not in adata.obs, norm_error 195 | if adata.X.size < 50e6: # check if adata.X is integer only if array is small 196 | if sp.sparse.issparse(adata.X): 197 | assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error 198 | else: 199 | assert np.all(adata.X.astype(int) == adata.X), norm_error 200 | 201 | if filter_min_counts: 202 | sc.pp.filter_genes(adata, min_counts=1) 203 | sc.pp.filter_cells(adata, min_counts=1) 204 | if size_factors or normalize_input or logtrans_input: 205 | adata.raw = adata.copy() 206 | else: 207 | adata.raw = adata 208 | if size_factors: 209 | sc.pp.normalize_per_cell(adata) 210 | adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts) 211 | else: 212 | adata.obs['size_factors'] = 1.0 213 | if logtrans_input: 214 | sc.pp.log1p(adata) 215 | if highly_genes != None: 216 | sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes = highly_genes, subset=True) 217 | if normalize_input: 218 | sc.pp.scale(adata) 219 | return adata 220 | 221 | def getcluster(edgeList): 222 | 223 | #featureMatrix = pd.read_csv('./output/Top2000.csv').values[:,1:].T 224 | 225 | 226 | listResult, size = generateLouvainCluster(edgeList) 227 | n_clusters = len(np.unique(listResult)) 228 | #print('Louvain cluster: '+str(n_clusters)) 229 | return n_clusters 230 | -------------------------------------------------------------------------------- /Interaction_model/cluster/pretain_model/pbmc/pbmc.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/cluster/pretain_model/pbmc/pbmc.pkl -------------------------------------------------------------------------------- /Interaction_model/cluster/pretrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from torch.utils.data import DataLoader 8 | from torch.optim import Adam, SGD, Adamax 9 | from torch.nn import Linear 10 | from torch.utils.data import Dataset 11 | from sklearn.metrics import silhouette_score, davies_bouldin_score 12 | from sklearn.cluster import SpectralBiclustering,KMeans, kmeans_plusplus, DBSCAN,SpectralClustering 13 | from evaluation import eva_pretrain 14 | import umap 15 | import argparse 16 | 17 | #torch.cuda.set_device(3) 18 | 19 | 20 | class AE(nn.Module): 21 | 22 | def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 23 | n_input, n_z): 24 | super(AE, self).__init__() 25 | self.enc_1 = Linear(n_input, n_enc_1) 26 | self.enc_2 = Linear(n_enc_1, n_enc_2) 27 | self.enc_3 = Linear(n_enc_2, n_enc_3) 28 | self.z_layer = Linear(n_enc_3, n_z) 29 | 30 | self.dec_1 = Linear(n_z, n_dec_1) 31 | self.dec_2 = Linear(n_dec_1, n_dec_2) 32 | self.dec_3 = Linear(n_dec_2, n_dec_3) 33 | self.x_bar_layer = Linear(n_dec_3, n_input) 34 | 35 | def forward(self, x): 36 | enc_h1 = F.relu(self.enc_1(x)) 37 | enc_h2 = F.relu(self.enc_2(enc_h1)) 38 | enc_h3 = F.relu(self.enc_3(enc_h2)) 39 | z = self.z_layer(enc_h3) 40 | 41 | dec_h1 = F.relu(self.dec_1(z)) 42 | dec_h2 = F.relu(self.dec_2(dec_h1)) 43 | dec_h3 = F.relu(self.dec_3(dec_h2)) 44 | x_bar = self.x_bar_layer(dec_h3) 45 | 46 | return x_bar, z 47 | 48 | 49 | class LoadDataset(Dataset): 50 | def __init__(self, data): 51 | self.x = data 52 | 53 | def __len__(self): 54 
| return self.x.shape[0] 55 | 56 | def __getitem__(self, idx): 57 | return torch.from_numpy(np.array(self.x[idx])).float(), \ 58 | torch.from_numpy(np.array(idx)) 59 | 60 | 61 | def adjust_learning_rate(optimizer, epoch): 62 | lr = 0.001 * (0.1 ** (epoch // 20)) 63 | for param_group in optimizer.param_groups: 64 | param_group['lr'] = lr 65 | 66 | 67 | def pretrain_ae(model,dataset,m,device,n_clusters,epoch,Auto=True): 68 | train_loader = DataLoader(dataset, batch_size=None, shuffle=True) 69 | #device = args.device 70 | #print(device) 71 | optimizer = Adamax(model.parameters(), lr=1e-2) 72 | for epoch in range(epoch): 73 | adjust_learning_rate(optimizer, epoch) 74 | for batch_idx, (x, _) in enumerate(train_loader): 75 | x = x.to(device) 76 | x_bar, _ = model(x) 77 | loss = F.mse_loss(x_bar, x) 78 | 79 | optimizer.zero_grad() 80 | loss.backward() 81 | optimizer.step() 82 | 83 | with torch.no_grad(): 84 | x = torch.Tensor(dataset.x).to(device).float() 85 | x_bar, z = model(x) 86 | loss = F.mse_loss(x_bar, x) 87 | 88 | #for i in range(0,100,10): 89 | if z.shape[0] < 5000: 90 | resolution = 0.8 91 | else: 92 | resolution = 0.5 93 | if Auto: 94 | n_clusters = int(n_clusters*resolution) if int(n_clusters*resolution)>=3 else 2 # in Auto mode, shrink the cluster number by the resolution factor (minimum of 2) 95 | #print(n_clusters) 96 | kmeans = KMeans(n_clusters=n_clusters, n_init=20).fit(z.data.cpu().numpy()) 97 | silhouette =eva_pretrain(z.data.cpu().numpy(), kmeans.labels_, epoch) 98 | #print(kmeans.labels_) 99 | #print('epoch{} loss: {}'.format(epoch, loss),'silhouette {:.4f}'.format(silhouette)) 100 | torch.save(model.state_dict(), './output/model/test'+str(m)+'.pkl') 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /Interaction_model/cluster/test.txt: -------------------------------------------------------------------------------- 1 | nohup: ignoring input 2 | use cuda: True 3 | Performing log-normalization 4 | 0% 10 20 30 40 50 60 70 80 90 100% 5 | [----|----|----|----|----|----|----|----|----|----| 6 | **************************************************| 7 | Calculating gene variances 8 | 0% 10 20 30 40 50 60 70 80 90 100% 9 | [----|----|----|----|----|----|----|----|----|----| 10 | **************************************************| 11 | Calculating feature variances of standardized and clipped values 12 | 0% 10 20 30 40 50 60 70 80 90 100% 13 | [----|----|----|----|----|----|----|----|----|----| 14 | **************************************************| 15 | 9 16 | generate cell graph... 
17 | Start pretrain 18 | pretrain:0 19 | pretrain:1 20 | pretrain:2 21 | pretrain:3 22 | pretrain:4 23 | pretrain:5 24 | -------------------------------------------------------------------------------- /Interaction_model/cluster/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import h5py 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def load_graph(dataset): 9 | 10 | path = 'output/graph/{}_graph.txt'.format(dataset) 11 | 12 | data = np.loadtxt('output/data/{}.txt'.format(dataset)) 13 | n, _ = data.shape 14 | 15 | idx = np.array([i for i in range(n)], dtype=np.int32) 16 | idx_map = {j: i for i, j in enumerate(idx)} 17 | edges_unordered = np.genfromtxt(path, dtype=np.int32) 18 | edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), 19 | dtype=np.int32).reshape(edges_unordered.shape) 20 | adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 21 | shape=(n, n), dtype=np.float32) 22 | 23 | 24 | # build symmetric adjacency matrix 25 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 26 | adj = adj + sp.eye(adj.shape[0]) 27 | adj = normalize(adj) 28 | adj = sparse_mx_to_torch_sparse_tensor(adj) 29 | 30 | return adj 31 | 32 | 33 | def normalize(mx): 34 | """Row-normalize sparse matrix""" 35 | rowsum = np.array(mx.sum(1)) 36 | r_inv = np.power(rowsum, -1).flatten() 37 | r_inv[np.isinf(r_inv)] = 0. 38 | r_mat_inv = sp.diags(r_inv) 39 | mx = r_mat_inv.dot(mx) 40 | return mx 41 | 42 | 43 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 44 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 45 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 46 | indices = torch.from_numpy( 47 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 48 | values = torch.from_numpy(sparse_mx.data) 49 | shape = torch.Size(sparse_mx.shape) 50 | return torch.sparse.FloatTensor(indices, values, shape) 51 | 52 | 53 | class load_data(Dataset): 54 | def __init__(self, dataset): 55 | self.x = np.loadtxt('output/data/{}.txt'.format(dataset), dtype=float) 56 | self.y = np.loadtxt('../input/test_label.txt', dtype=int) 57 | 58 | def __len__(self): 59 | return self.x.shape[0] 60 | 61 | def __getitem__(self, idx): 62 | return torch.from_numpy(np.array(self.x[idx])),\ 63 | torch.from_numpy(np.array(self.y[idx])),\ 64 | torch.from_numpy(np.array(idx)) 65 | 66 | 67 | -------------------------------------------------------------------------------- /Interaction_model/dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from sklearn.model_selection import StratifiedKFold,KFold 7 | from scipy import sparse 8 | from torch.utils.data import Dataset 9 | from torch.utils.data import DataLoader 10 | import torchvision.transforms as transforms 11 | 12 | 13 | class AttributesDataset(): 14 | def __init__(self, annotation_path): 15 | class_labels=[] 16 | multi_labels=[] 17 | with open(annotation_path) as f: 18 | reader = csv.DictReader(f) 19 | for row in reader: 20 | class_labels.append(row['Labels']) 21 | for line in reader: 22 | self.multi_labels.append(line[1:-2]) 23 | self.class_labels = np.unique(class_labels) 24 | self.num_labels = len(self.class_labels) 25 | 26 | #self.class_labels_id_to_name = dict(zip(range(len(self.class_labels)), self.class_labels)) 27 | 28 | #self.class_labels_name_to_id = 
dict(zip(self.class_labels, range(len(self.class_labels)))) 29 | 30 | 31 | 32 | 33 | 34 | 35 | class ForDataset(Dataset): 36 | def __init__(self, data_train, annotation, transform=None): 37 | super().__init__() 38 | 39 | self.transform = transform 40 | #self.attr = attributes 41 | #self.data=dd.read_csv(data_train,header=None).values.compute() 42 | self.data=data_train 43 | #print(np.shape(data)) 44 | self.data_idx=[] 45 | # initialize the array that stores the sample indices 46 | #self.class_labels=[] 47 | 48 | #self.labels=[] 49 | # read the annotations from the CSV file 50 | #with open(annotation_path) as f: 51 | # reader = csv.DictReader(f) 52 | # for row in reader: 53 | #for i in range(len()) 54 | self.data_idx=annotation.astype(np.int32).tolist() 55 | #self.class_labels.append(self.attr.class_labels_name_to_id[row['Labels']]) 56 | #self.class_labels=dd.read_csv(annotation_path).values.compute()[:,-2] 57 | 58 | #self.labels=annotation[:,-1] 59 | #print(self.multi_labels.shape) 60 | def __len__(self): 61 | return len(self.data_idx) 62 | 63 | def __getitem__(self, idx): 64 | # take the data sample by its index 65 | 66 | data=self.data 67 | #print(data.shape) 68 | 69 | #class_labels=np.array(self.class_labels).astype(np.int32) 70 | #labels=np.array(self.labels).astype(np.int32) 71 | #print(np.shape(data)) 72 | # read image 73 | #img = Image.open(img_path) 74 | 75 | # apply the transform if one is provided 76 | if self.transform: 77 | data = self.transform(data) 78 | # return the selected data sample 79 | dict_data = { 80 | 'data': data[idx], 81 | } 82 | #print(dict_data) 83 | 84 | return dict_data 85 | 86 | -------------------------------------------------------------------------------- /Interaction_model/input/Download.md: -------------------------------------------------------------------------------- 1 | The test data can be downloaded from http://jianglab.org.cn/deepcci_download/ -------------------------------------------------------------------------------- /Interaction_model/model/checkpoint-000100.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Interaction_model/model/checkpoint-000100.pth -------------------------------------------------------------------------------- /Interaction_model/modelv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | #import torchvision.models as models 7 | from Mobilev2 import MobileNetV2 8 | 9 | class FocalLoss(nn.Module): 10 | def __init__(self, alpha=0.65, gamma=2, logits=False, reduce=True): 11 | super(FocalLoss, self).__init__() 12 | self.alpha = alpha 13 | self.gamma = gamma 14 | self.logits = logits 15 | self.reduce = reduce 16 | 17 | def forward(self, inputs, targets): 18 | if self.logits: 19 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 20 | else: 21 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none') 22 | pt = torch.exp(-BCE_loss) 23 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 24 | 25 | if self.reduce: 26 | return torch.mean(F_loss) 27 | else: 28 | return F_loss 29 | 30 | 31 | 32 | class GraphConvolution(nn.Module): 33 | """ 34 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 35 | """ 36 | 37 | def __init__(self, in_features, 
out_features, bias=False): 38 | super(GraphConvolution, self).__init__() 39 | self.in_features = in_features 40 | self.out_features = out_features 41 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 42 | if bias: 43 | self.bias = Parameter(torch.Tensor(1, 1, out_features)) 44 | else: 45 | self.register_parameter('bias', None) 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self): 49 | stdv = 1. / math.sqrt(self.weight.size(1)) 50 | self.weight.data.uniform_(-stdv, stdv) 51 | if self.bias is not None: 52 | self.bias.data.uniform_(-stdv, stdv) 53 | 54 | def forward(self, input, adj): 55 | support = torch.matmul(input, self.weight) 56 | output = torch.matmul(adj, support) 57 | if self.bias is not None: 58 | return output + self.bias 59 | else: 60 | return output 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + ' (' \ 64 | + str(self.in_features) + ' -> ' \ 65 | + str(self.out_features) + ')' 66 | class MultiOutputModel(nn.Module): 67 | 68 | def __init__(self,nfeat,nlabel): 69 | 70 | super().__init__() 71 | models=MobileNetV2() 72 | self.base_model = MobileNetV2().features # take the model without classifier 73 | last_channel = MobileNetV2().last_channel # size of the layer before classifier 74 | self.pool = nn.AdaptiveMaxPool2d((1, 1)) 75 | self.gc1 = GraphConvolution(nfeat, 512) 76 | self.gc2 = GraphConvolution(512, last_channel) 77 | self.relu = nn.LeakyReLU(0.2) 78 | self.fc = nn.Linear(nlabel,1) 79 | self.dropout=nn.Dropout(0.2) 80 | 81 | def forward(self, x, feature, adj): 82 | x = self.base_model(x) 83 | x = self.pool(x) 84 | x = self.dropout(x) 85 | # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier 86 | x = torch.flatten(x, 1) 87 | feature = self.gc1(feature, adj) 88 | #feature = self.dropout(feature) 89 | feature = self.relu(feature) 90 | feature = self.gc2(feature, adj) 91 | #feature = self.dropout(feature) 92 | feature = self.relu(feature) 93 | feature = feature.transpose(0, 1) 94 | 95 | x = torch.matmul(x,feature) 96 | x = self.fc(x) 97 | x = self.dropout(x) 98 | x = x.squeeze(-1) 99 | xt = torch.sigmoid(x) 100 | #print(y) 101 | return {'class': x,'sigmoid': xt 102 | 103 | } 104 | def initialize(self): 105 | for m in self.modules(): 106 | if isinstance(m, nn.Linear): 107 | nn.init.xavier_uniform_(m.weight, gain=1) 108 | #print(m.weight) 109 | 110 | def get_config_optim(self, lr, lrp): 111 | return [ 112 | {'params': self.base_model.parameters(), 'lr': lr }, 113 | {'params': self.gc1.parameters(), 'lr': lrp}, 114 | {'params': self.gc2.parameters(), 'lr': lrp}, 115 | ] 116 | def get_loss(self, net_output, ground_truth): 117 | #crition=nn.BCEWithLogitsLoss() 118 | crition=FocalLoss() 119 | loss = crition(net_output['sigmoid'].float(), ground_truth['labels'].float()) 120 | #loss = crition(net_output['class'].float(), ground_truth['labels'].float()) 121 | #crition2 = nn.MultiLabelSoftMarginLoss() 122 | #loss = crition2(net_output['mutil'], ground_truth['multi_labels']) 123 | return loss -------------------------------------------------------------------------------- /Interaction_model/nohup.out: -------------------------------------------------------------------------------- 1 | The cell barcodes in 'meta' is 1 2 3 4 5 6 2 | The cell groups used for CellChat analysis are B CD14+ Mono CD8 T DC FCGR3A+ Mono Memory CD4 T Naive CD4 T NK Platelet 3 | Issue identified!! 
Please check the official Gene Symbol of the following genes: 4 | COL6A4 CD1D2 CD1D1 CD209A CLEC2G CLEC2H CLEC2I CLEC2F DSG1C ITGAL_ITGB2L H2-T23 H2-M2 H2-M3 H2-M9 H2-Q1 H2-Q10 H2-Q2 H2-Q7 H2-Q8 H2-T22 H2-T24 H2-T3 H2-T9 H2-M10.4 H2-M11 H2-M1 H2-M10.5 H2-M10.2 H2-M10.6 H2-T10 H2-Q6 H2-M10.3 H2-BL H2-D1 H2-K1 H2-M10.1 H2-T18 H2-Q9 GM8909 H2-L GM10499 H2-Q4 H2-T-PS GM11127 H2-M5 GM7030 H2-D1 GM8909 GM10499 H2-BI H2-D RAET1A H60a H2-Ea-ps H2-AA H2-AB1 H2-EB1 H2-DMB1 H2-DMB2 YARS WISP3 CTGF CYR61 IL8 MFI2 MLLT4 NOV MTf SHP1 SHP2 RAP TIM-1 BY55 HLAE AREGB C1orf200 C5orf55 CGB FIGF KAL1 CCL12 CD55B PIRA2 CD200R2 CD200R3 CD200R4 SIGLECG KLRB1B KLRB1F CD209F ITGAM_ITGB2L KLRA TIMD2 CD97 NGFRAP1 HFE2 C14orf1 TMEM8A PVRL2 PPAPDC2 DARC EMR2 ELTD1 PVRL1 PVRL3 PVRL4 BAI2 LPHN1 LPHN2 GPR56 ECRF3 US28 UL12 E1 KSHV 5 | -------------------------------------------------------------------------------- /Interaction_model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import pickle as pkl 6 | 7 | 8 | def encode_onehot(labels): 9 | classes = set(labels) 10 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 11 | enumerate(classes)} 12 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 13 | dtype=np.int32) 14 | return labels_onehot 15 | 16 | 17 | def load_data(path=" ", dataset="test"): 18 | """Load citation network dataset (cora only for now)""" 19 | print('Loading {} dataset...'.format(dataset)) 20 | 21 | idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset), 22 | dtype=np.dtype(str)) 23 | features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32) 24 | labels = encode_onehot(idx_features_labels[:, -1]) 25 | 26 | # build graph 27 | idx = np.array(idx_features_labels[:, 0], dtype=np.int32) 28 | idx_map = {j: i for i, j in enumerate(idx)} 29 | edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), 30 | dtype=np.int32) 31 | edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), 32 | dtype=np.int32).reshape(edges_unordered.shape) 33 | adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 34 | shape=(labels.shape[0], labels.shape[0]), 35 | dtype=np.float32) 36 | 37 | # build symmetric adjacency matrix 38 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 39 | 40 | features = normalize(features) 41 | adj = normalize(adj + sp.eye(adj.shape[0])) 42 | 43 | features = torch.FloatTensor(np.array(features.todense())) 44 | #labels = torch.LongTensor(np.where(labels)[1]) 45 | adj = sparse_mx_to_torch_sparse_tensor(adj) 46 | 47 | 48 | 49 | return adj, features 50 | 51 | 52 | def normalize(mx): 53 | """Row-normalize sparse matrix""" 54 | rowsum = np.array(mx.sum(1)) 55 | r_inv = np.power(rowsum, -1).flatten() 56 | r_inv[np.isinf(r_inv)] = 0. 
57 | r_mat_inv = sp.diags(r_inv) 58 | mx = r_mat_inv.dot(mx) 59 | return mx 60 | 61 | 62 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 63 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 64 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 65 | indices = torch.from_numpy( 66 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 67 | values = torch.from_numpy(sparse_mx.data) 68 | shape = torch.Size(sparse_mx.shape) 69 | return torch.sparse.FloatTensor(indices, values, shape) 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 JiangBioLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Plot/Plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | def get_files(path,rule): 6 | all = [] 7 | for fpathe, dirs, fs in os.walk(path): 8 | for f in fs: 9 | if f.endswith(rule): 10 | all.append(f) 11 | return all 12 | 13 | if not os.path.exists("./output/"): 14 | os.system('mkdir ./output/') 15 | 16 | os.system('python heatmap.py') 17 | 18 | os.system('python chord.py') 19 | 20 | os.system('python network.py') 21 | cell_type = np.loadtxt('./cell_type.csv',delimiter=",", dtype=str) 22 | para1="./output/chord/chord_all.pdf" 23 | para2="./output/CCImatix.csv" 24 | 25 | #para5=("brown" "green" "cornflowerblue" "blueviolet") 26 | os.system("Rscript chord.R"+" "+para1+" "+para2) 27 | 28 | for i in range(len(cell_type)): 29 | #print(str(cell_type[i])) 30 | para3=str("./output/heatmap/Plot/"+str(cell_type[i])+".pdf") 31 | para4=str("./output/heatmap/File/"+str(cell_type[i])+".csv") 32 | os.system("Rscript heatmap.R"+" "+para3+" "+para4) 33 | 34 | para5=str("./output/chord/Plot/"+str(cell_type[i])+".pdf") 35 | para6=str("./output/chord/File/"+str(cell_type[i])+".csv") 36 | os.system("Rscript chord.R"+" "+para5+" "+para6) 37 | 38 | para7="./output/CCImatix.csv" 39 | para8="./output/bubble.pdf" 40 | os.system("Rscript bubble.R"+" "+para7+" "+para8) 41 | 42 | net_file = get_files(path='./output/Network/File', rule=".csv") 43 | 44 | for i in range(len(net_file)): 45 | para9 = "./output/Network/File/"+str(net_file[i]) 46 | para10 = "./output/Network/Plot/"+str(net_file[i].split(".")[0])+".pdf" 47 | os.system("Rscript network.R"+" "+para9+" "+para10) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /Plot/bubble.R: -------------------------------------------------------------------------------- 1 | library(ggplot2) 2 | library(RColorBrewer) 3 | library(reshape2) 4 | args=commandArgs(T) 5 | parameter1 = args[1] 6 | parameter2 = args[2] 7 | data.final<-read.csv(parameter1,header=T,row.names = 1,check.names=F) 8 | data.final<-as.matrix(data.final) 9 | mydata <- melt(data.final) 10 | colnames(mydata)<-c("Cell_type1","Cell_type2","value") 11 | pdf(file = parameter2,width = 9,height = 8) 12 | ggplot(mydata, aes(x= Cell_type1 , y=Cell_type2)) +theme_bw()+ 13 | geom_point(aes(size=value,fill = value), shape=21, colour="black") + 14 | scale_fill_gradientn(colours=c(brewer.pal(7,"Blues")[7],brewer.pal(7,"Reds")[6]),na.value=NA)+ 15 | scale_size_area(max_size=14, guide = "none") + 16 | #geom_text(aes(label=value),color="white",size=5) + 17 | theme(panel.grid = element_blank(), 18 | text=element_text(size=13,face="plain",color="black"),axis.text.x=element_text(angle=45,hjust =1,vjust=1), 19 | axis.title=element_text(size=0,face="plain",color="white"), 20 | axis.text = element_text(size=23,face="plain",color="black"), 21 | legend.position="right" 22 | ) 23 | while (!is.null(dev.list())) dev.off() 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /Plot/cell_type.csv: -------------------------------------------------------------------------------- 1 | B 2 | CD14+Mono 3 | CD8T 4 | DC 5 | FCGR3A+Mono 6 | MemoryCD4T 7 | NK 8 | NaiveCD4T 9 | Platelet 10 | -------------------------------------------------------------------------------- /Plot/chord.R: 
-------------------------------------------------------------------------------- 1 | suppressMessages(library(RColorBrewer)) 2 | suppressMessages(library(ggalluvial)) 3 | suppressMessages(library(ggplot2)) 4 | suppressMessages(library(circlize)) 5 | options(warn = -1) 6 | 7 | args=commandArgs(T) 8 | parameter1 = args[1] 9 | parameter2 = args[2] 10 | #parameter3 = args[3] 11 | #grid_col = c("red", "green", "cornflowerblue","blueviolet", brewer.pal(8, 'Accent')[6], "darkseagreen1", "hotpink1", "gold", "slateblue3", brewer.pal(8, 'Set3')[5],brewer.pal(8, 'Set3')[4]) 12 | #grid_col = parameter3 13 | #print(grid_col) 14 | pdf(file = parameter1,width = 10,height = 10) 15 | 16 | mat <- read.csv(file = parameter2, row.names = 1,head=T) 17 | colnames(mat) <- rownames(mat) 18 | y = data.matrix(mat) 19 | #circos.par( clock.wise = FALSE,cex = 2) 20 | #par(cex = 1.5) 21 | chordDiagram(y,directional = 1, 22 | direction.type = c("arrows"), 23 | link.arr.type = "triangle",big.gap = 40, small.gap = 10,annotationTrack = c("name", "grid"),annotationTrackHeight = c(0.03, 0.06)) 24 | while (!is.null(dev.list())) dev.off() 25 | 26 | -------------------------------------------------------------------------------- /Plot/chord.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import ast 4 | import operator 5 | from itertools import chain,groupby 6 | import math 7 | import os 8 | from scipy import sparse 9 | from collections import defaultdict 10 | if not os.path.exists("./output/chord"): 11 | os.system('mkdir ./output/chord') 12 | if not os.path.exists("./output/chord/File"): 13 | os.system('mkdir ./output/chord/File') 14 | if not os.path.exists("./output/chord/Plot"): 15 | os.system('mkdir ./output/chord/Plot') 16 | CC_net = pd.read_csv('../Interaction_model/output/CCI_out.csv',header=None).values 17 | CC_net_data = CC_net[1:,1:] 18 | CC_pval = CC_net_data[:,5].astype(np.float16) 19 | source = CC_net_data[:,0] 20 | target = CC_net_data[:,1] 21 | for i in range(len(source)): 22 | source[i] = source[i].replace(" ","") 23 | target[i] = target[i].replace(" ","") 24 | 25 | Cell_type_old = np.unique(source) 26 | Cell_type = [] 27 | for i in range(len(Cell_type_old)): 28 | Cell_type.append(Cell_type_old[i]) 29 | 30 | 31 | matix= np.zeros((len(Cell_type),len(Cell_type))) 32 | 33 | for i in range(len(Cell_type)): 34 | for j in range(len(Cell_type)): 35 | for k in range(len(source)): 36 | if source[k] == Cell_type[i] and target[k] == Cell_type[j]: 37 | if CC_pval[k]==1: 38 | matix[i,j]+=1 39 | pd.DataFrame(matix.astype(np.int32), index=Cell_type,columns=Cell_type).to_csv("./output/CCImatix.csv",quoting=1) 40 | 41 | for i in range(len(Cell_type)): 42 | CellMatix=np.zeros((len(Cell_type),len(Cell_type))) 43 | CCImatix = pd.read_csv('./output/CCImatix.csv',header=None).values[1:,1:] 44 | CellMatix[i,:] = CCImatix[i,:] 45 | CellMatix[:,i] = CCImatix[:,i] 46 | #print(CellMatix) 47 | pd.DataFrame(CellMatix.astype(np.int32), index=Cell_type,columns=Cell_type).to_csv("./output/chord/File/"+Cell_type[i]+".csv",quoting=1) 48 | -------------------------------------------------------------------------------- /Plot/heatmap.R: -------------------------------------------------------------------------------- 1 | library(ggplot2) 2 | library(RColorBrewer) 3 | library(reshape2) 4 | library(pheatmap) 5 | 6 | args=commandArgs(T) 7 | parameter1 = args[1] 8 | parameter2 = args[2] 9 | pdf(file = parameter1,width = 12,height = 5) 10 | 11 | 
data.final<-read.csv(parameter2,header=T,row.names = 1,check.names=F) 12 | pheatmap((data.final),cluster_rows = FALSE, cluster_cols = FALSE,cellheight = 8, cellwidth = 8, 13 | cexCol = 1,angle_col = "45",fontsize=6) 14 | while (!is.null(dev.list())) dev.off() -------------------------------------------------------------------------------- /Plot/heatmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import ast 4 | import operator 5 | from itertools import chain,groupby 6 | import math 7 | import os 8 | from scipy import sparse 9 | from collections import defaultdict 10 | if not os.path.exists("./output/heatmap"): 11 | os.system('mkdir ./output/heatmap') 12 | if not os.path.exists("./output/heatmap/File"): 13 | os.system('mkdir ./output/heatmap/File') 14 | if not os.path.exists("./output/heatmap/Plot"): 15 | os.system('mkdir ./output/heatmap/Plot') 16 | 17 | columns = ["Cell2Cell","Interaction","prob","pval"] 18 | 19 | Cell = [] 20 | prob = [] 21 | pval = [] 22 | pair = [] 23 | CC_net = pd.read_csv('../Interaction_model/output/CCI_out.csv',header=None).values 24 | CC_net_data = CC_net[1:,1:] 25 | 26 | 27 | CC_prob = CC_net_data[:,4].astype(np.float16) 28 | CC_pval = CC_net_data[:,5].astype(np.int32) 29 | source = [] 30 | target = [] 31 | Interaction = [] 32 | 33 | for i in range(CC_net_data.shape[0]): 34 | #print(i) 35 | 36 | Interaction.append(CC_net_data[i,7]) 37 | 38 | for i in range(CC_net_data.shape[0]): 39 | source.append(CC_net_data[i,0].replace(" ","")) 40 | target.append(CC_net_data[i,1].replace(" ","")) 41 | #Interaction.append(CC_net_data[i,7].replace("(","")) 42 | #Interaction.append(CC_net_data[i,7]) 43 | source = np.array(source) 44 | 45 | 46 | Cell_type = np.unique(source) 47 | #Cell_type = ['B','Basophil','CD14+Mono','CD1C+_Bdendriticcell','Circulatingfetalcell','CD8+T','DC','FCGR3A+Mono','NaiveT'] 48 | np.savetxt("./cell_type.csv",Cell_type,fmt="%s",delimiter=",") 49 | #print(Cell_type) 50 | for j in range(len(Cell_type)): 51 | Cell = [] 52 | prob = [] 53 | pair = [] 54 | 55 | for i in range(len(source)): 56 | if str(source[i]) == Cell_type[j]: 57 | if CC_pval[i] == 1: 58 | Cell.append(str(source[i])+" | "+str(target[i])) 59 | pair.append(Interaction[i]) 60 | prob.append(CC_prob[i]) 61 | idx = np.argsort(-np.array(prob)) 62 | Cellnew = np.array(Cell)[idx] 63 | pairnew = np.array(pair)[idx] 64 | probnew = np.array(prob)[idx] 65 | probnew = probnew 66 | #pvalnew = np.array(pval)[idx] 67 | Cell_unique = np.unique(np.array(Cellnew)) 68 | pair_unique = np.unique(np.array(pairnew)) 69 | 70 | matix = np.zeros((len(Cell_unique),len(pair_unique))) 71 | if len(Cellnew)>=150: 72 | #Len=150 73 | Len=150 74 | for m in range((Len)): 75 | for k in range(len(Cell_unique)): 76 | for i in range(len(pair_unique)): 77 | if Cell_unique[k] == Cellnew[m] and pair_unique[i] == pairnew[m]: 78 | matix[k,i] = float(probnew[m]) 79 | else: 80 | #print(len(Cellnew)) 81 | Len=len(Cellnew) 82 | for m in range((Len)): 83 | for k in range(len(Cell_unique)): 84 | for i in range(len(pair_unique)): 85 | if Cell_unique[k] == Cellnew[m] and pair_unique[i] == pairnew[m]: 86 | matix[k,i] = float(probnew[m]) 87 | 88 | 89 | 90 | idx = np.argwhere(np.all(matix[..., :] == 0, axis=0)) 91 | a2 = np.delete(matix, idx, axis=1) 92 | idx_all = np.arange(len(pair_unique)) 93 | idx1 = np.delete(idx_all, idx) 94 | #print(idx1) 95 | pd.DataFrame(a2, 
index=Cell_unique,columns=pair_unique[idx1]).to_csv("./output/heatmap/File/"+str(Cell_type[j])+".csv",quoting=1) 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /Plot/network.R: -------------------------------------------------------------------------------- 1 | suppressMessages(library(psych)) 2 | suppressMessages(library(qgraph)) 3 | suppressMessages(library(igraph)) 4 | suppressMessages(library(purrr)) 5 | suppressMessages(library(RColorBrewer)) 6 | args=commandArgs(T) 7 | parameter1 = args[1] 8 | parameter2 = args[2] 9 | options(warn = -1) 10 | netf<- read.csv(parameter1,header=T,row.names = 1,check.names=F) 11 | mynet <- netf 12 | net<- graph_from_data_frame(mynet) 13 | 14 | allcolour = c("red", "green", "cornflowerblue","blueviolet", brewer.pal(8, 'Accent')[6], "darkseagreen1", "hotpink1", "gold", "slateblue3", brewer.pal(8, 'Set3')[5],brewer.pal(8, 'Set3')[4]) 15 | pdf(file = parameter2,width = 12,height = 12) 16 | 17 | karate_groups <- cluster_optimal(net) 18 | coords <- layout_in_circle(net, order = 19 | order(membership(karate_groups))) 20 | 21 | E(net)$width <- E(net)$count*1 22 | V(net)$color <- allcolour 23 | E(net)$color <- tail_of(net, E(net))$color 24 | plot(net, edge.arrow.size=.5, 25 | edge.curved=0, 26 | #edge.color=allcolour, 27 | #vertex.color=allcolour, 28 | vertex.frame.color="#555555", 29 | vertex.label.color="black", 30 | layout = coords, 31 | vertex.size = 30, 32 | vertex.label.cex=2) 33 | while (!is.null(dev.list())) dev.off() 34 | 35 | -------------------------------------------------------------------------------- /Plot/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import ast 4 | import operator 5 | from itertools import chain,groupby 6 | import math 7 | import os 8 | from scipy import sparse 9 | from collections import defaultdict 10 | if not os.path.exists("./output/Network"): 11 | os.system('mkdir ./output/Network') 12 | if not os.path.exists("./output/Network/File"): 13 | os.system('mkdir ./output/Network/File') 14 | if not os.path.exists("./output/Network/Plot"): 15 | os.system('mkdir ./output/Network/Plot') 16 | 17 | columns = ["Source","Target","count"] 18 | 19 | CC_net = pd.read_csv('../Interaction_model/output/CCI_out.csv',header=None).values 20 | CC_net_data = CC_net[1:,1:] 21 | CC_pval = CC_net_data[:,5].astype(np.int32) 22 | source = [] 23 | target = [] 24 | Ligand_all = [] 25 | Receptor_all = [] 26 | cc_idx = [] 27 | for i in range(len(CC_pval)): 28 | if (CC_pval[i]) == 1: 29 | cc_idx.append(i) 30 | CC_new = CC_net_data[cc_idx,:] 31 | 32 | pair_name = CC_new[:,6] 33 | source = [] 34 | target = [] 35 | prob = [] 36 | for i in range(CC_new.shape[0]): 37 | source.append(CC_new[i,0].replace(" ","")) 38 | target.append(CC_new[i,1].replace(" ","")) 39 | prob.append(CC_new[i,4]) 40 | 41 | 42 | pair = np.unique(pair_name) 43 | for j in range(len(pair)): 44 | Source = [] 45 | Target = [] 46 | count = [] 47 | for i in range(len(pair_name)): 48 | if str(pair_name[i]) == pair[j]: 49 | Source.append(str(source[i])) 50 | Target.append(str(target[i])) 51 | count.append(str(prob[i])) 52 | INTER1 = np.vstack((np.array(Source),np.array(Target))) 53 | #INTER2 = np.vstack(((INTER1),(CellYnew))) 54 | #INTER3 = np.vstack(((INTER2),(Probnew))) 55 | INTER = np.vstack((INTER1,np.array(count))).T 56 | index = np.arange(1,len(Source)+1) 57 | pd.DataFrame(INTER, 
index=index,columns=columns).to_csv("./output/Network/File/"+str(pair[j])+".csv",quoting=1) 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepCCI (Deep learning framework for Cell-Cell Interactions inference from scRNA-seq data) 2 | 3 | DeepCCI is a graph convolutional network (GCN)-based deep learning framework for Cell-Cell Interaction inference from scRNA-seq data. 4 | ![workflow](https://user-images.githubusercontent.com/72069543/169433397-ff34dce1-717f-446e-8b0a-0e1b5ccf6da6.png) 5 | 6 | 7 | ## Installation: 8 | 9 | ### From Source: 10 | 11 | Start by grabbing the source code: 12 | 13 | ``` 14 | git clone https://github.com/JiangBioLab/DeepCCI.git 15 | cd DeepCCI 16 | ``` 17 | 18 | ### (Recommended) Use a Python virtual environment with conda (https://anaconda.org/) 19 | 20 | ``` 21 | conda create -n deepcciEnv python=3.7.4 pip 22 | conda activate deepcciEnv 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | **Because data processing and result visualization require R, also install the following packages in R:** 27 | 28 | ``` 29 | conda install r-base 30 | R 31 | install.packages('Seurat') 32 | install.packages("igraph") 33 | install.packages('NMF') 34 | install.packages("devtools") 35 | devtools::install_github("jokergoo/circlize") 36 | devtools::install_github("jokergoo/ComplexHeatmap") 37 | devtools::install_github("sqjin/CellChat") 38 | devtools::install_github('satijalab/seurat-data') 39 | ``` 40 | 41 | 42 | ### Quick Start 43 | 44 | ### 1. Cell Cluster Model 45 | 46 | #### **(1) Preprocess input files** 47 | 48 | The cluster model of DeepCCI accepts scRNA-seq data in CSV and h5 formats. The processed feature file of the scRNA-seq data will be provided. Depending on the size of the scRNA-seq file, the process will take about 5-10 minutes. 49 | 50 | ##### CSV format 51 | 52 | Take Yan's dataset (GSE36552) as an example. 53 | 54 | ``` 55 | cd Cluster_model 56 | python preprocess.py --name Yan --file_format csv 57 | ``` 58 | 59 | ##### h5 format 60 | 61 | Take the Quake 10x Limb Muscle dataset (GSE109774) as an example. 62 | 63 | ``` 64 | cd Cluster_model 65 | python preprocess.py --name Quake_10x_Limb_Muscle --file_format h5 66 | ``` 67 | 68 | #### (2) Cell Clustering 69 | The clustering results of the scRNA-seq data will be output. 70 | ##### With pre-train: 71 | It will take about 25 minutes. 72 | ``` 73 | python Cluster.py --name Yan --pretain True --pretrain_epoch 50 --device cuda 74 | ``` 75 | 76 | ##### Without pre-train: 77 | The pretrained model files are in the pretain_model folder. 78 | It will take about 5 minutes. 79 | ``` 80 | python Cluster.py --name Yan --pretain False --device cuda 81 | python Cluster.py --name Quake_10x_Limb_Muscle --pretain False --device cuda 82 | ``` 83 | 84 | ### 2. Cell Interaction Model 85 | 86 | #### **(1) Preprocess input files** 87 | 88 | The example test file can be downloaded from http://jianglab.org.cn/deepcci_download/. 89 | The processed feature file will be provided. Depending on the size of the scRNA-seq file, the process will take about 10-20 minutes.
90 | 91 | ##### With cell-label: 92 | 93 | ``` 94 | cd Interaction_model 95 | python Feature.py --label_mode True --species Human 96 | ``` 97 | 98 | ##### Without cell-label: 99 | 100 | ``` 101 | python Feature.py --label_mode False --species Human 102 | ``` 103 | 104 | #### (2) Interaction Inference 105 | The predicted interaction output file will be provided. The prediction process will take about 1-2 minutes. 106 | ``` 107 | python Interaction_inference.py --device cuda 108 | ``` 109 | 110 | ### 3. Visualization 111 | 112 | 113 | To show the CCI output intuitively, several visualization methods are provided. The process will take about 1 minute. 114 | ``` 115 | cd Plot 116 | python Plot.py 117 | ``` 118 | 119 | ## Contact 120 | 121 | Feel free to submit an issue or contact us at wenyiyang22@163.com for problems with the package. 122 | -------------------------------------------------------------------------------- /Rpack.Rdata: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiangBioLab/DeepCCI/0ad534f91f3c669ab97546741778d977f3347242/Rpack.Rdata -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | alabaster==0.7.12 3 | anaconda-client==1.7.2 4 | anaconda-navigator==1.9.7 5 | anaconda-project==0.8.3 6 | anndata==0.7.6 7 | asn1crypto==1.0.1 8 | astroid==2.3.1 9 | astropy==3.2.2 10 | atomicwrites==1.3.0 11 | attrs==19.2.0 12 | Babel==2.7.0 13 | backcall==0.1.0 14 | backports.functools-lru-cache==1.5 15 | backports.os==0.1.1 16 | backports.shutil-get-terminal-size==1.0.0 17 | backports.tempfile==1.0 18 | backports.weakref==1.0.post1 19 | beautifulsoup4==4.8.0 20 | bitarray==1.0.1 21 | bkcharts==0.2 22 | bleach==3.1.0 23 | bokeh==1.3.4 24 | boto==2.49.0 25 | Bottleneck==1.2.1 26 | cached-property==1.5.2 27 | certifi==2019.9.11 28 | cffi==1.12.3 29 | chardet==3.0.4 30 | chord==1.0.2 31 | Click==7.0 32 | cloudpickle==1.2.2 33 | clyent==1.2.2 34 | colorama==0.4.1 35 | conda==4.10.3 36 | conda-build==3.18.9 37 | conda-package-handling==1.6.0 38 | conda-verify==3.4.2 39 | contextlib2==0.6.0 40 | cryptography==2.7 41 | cycler==0.10.0 42 | Cython==0.29.13 43 | cytoolz==0.10.0 44 | dask==2.5.2 45 | decorator==4.4.0 46 | defusedxml==0.6.0 47 | distributed==2.5.2 48 | docutils==0.15.2 49 | dunamai==1.5.5 50 | entrypoints==0.3 51 | et-xmlfile==1.0.1 52 | fastcache==1.1.0 53 | filelock==3.0.12 54 | Flask==1.1.1 55 | fsspec==0.5.2 56 | future==0.17.1 57 | get-version==3.2 58 | gevent==1.4.0 59 | glob2==0.7 60 | gmpy2==2.0.8 61 | graphviz==0.20 62 | greenlet==0.4.15 63 | grpcio==1.39.0 64 | h5py==3.3.0 65 | HeapDict==1.0.1 66 | html5lib==1.0.1 67 | idna==2.8 68 | imageio==2.6.0 69 | imagesize==1.1.0 70 | importlib-metadata==0.23 71 | ipykernel==5.1.2 72 | ipython==7.8.0 73 | ipython-genutils==0.2.0 74 | ipywidgets==7.5.1 75 | isort==4.3.21 76 | itsdangerous==1.1.0 77 | jdcal==1.4.1 78 | jedi==0.15.1 79 | jeepney==0.4.1 80 | jgraph==0.2.1 81 | Jinja2==2.10.3 82 | joblib==0.13.2 83 | json5==0.8.5 84 | jsonschema==3.0.2 85 | jupyter==1.0.0 86 | jupyter-client==5.3.3 87 | jupyter-console==6.0.0 88 | jupyter-core==4.5.0 89 | jupyterlab==1.1.4 90 | jupyterlab-server==1.0.6 91 | keyring==18.0.0 92 | kiwisolver==1.1.0 93 | lazy-object-proxy==1.4.2 94 | legacy-api-wrap==1.2 95 | leidenalg==0.8.2 96 | libarchive-c==2.8 97 | lief==0.9.0 98 | llvmlite==0.34.0 99 | locket==0.2.0 100 | lxml==4.4.1 101
| Markdown==3.3.4 102 | MarkupSafe==1.1.1 103 | matplotlib==3.4.2 104 | mccabe==0.6.1 105 | mistune==0.8.4 106 | mkl-fft==1.0.14 107 | mkl-random==1.1.0 108 | mkl-service==2.3.0 109 | mock==3.0.5 110 | more-itertools==7.2.0 111 | mpmath==1.1.0 112 | msgpack==0.6.1 113 | multipledispatch==0.6.0 114 | munkres==1.1.4 115 | natsort==7.1.1 116 | navigator-updater==0.2.1 117 | nbconvert==5.6.0 118 | nbformat==4.4.0 119 | networkx==2.3 120 | nltk==3.4.5 121 | nose==1.3.7 122 | notebook==6.0.1 123 | numba==0.51.2 124 | numexpr==2.7.0 125 | numpy==1.21.0 126 | numpydoc==0.9.1 127 | olefile==0.46 128 | openpyxl==3.0.0 129 | packaging==20.0 130 | pandas==1.1.1 131 | pandocfilters==1.4.2 132 | parso==0.5.1 133 | partd==1.0.0 134 | path.py==12.0.1 135 | pathlib2==2.3.5 136 | patsy==0.5.1 137 | pep8==1.7.1 138 | pexpect==4.7.0 139 | pickleshare==0.7.5 140 | Pillow==6.2.0 141 | pkginfo==1.5.0.1 142 | pluggy==0.13.0 143 | ply==3.11 144 | prometheus-client==0.7.1 145 | prompt-toolkit==2.0.10 146 | protobuf==3.17.3 147 | psutil==5.6.3 148 | ptyprocess==0.6.0 149 | py==1.8.0 150 | pycairo==1.18.0 151 | pycodestyle==2.5.0 152 | pycosat==0.6.3 153 | pycparser==2.19 154 | pycrypto==2.6.1 155 | pycurl==7.43.0.3 156 | pydot==1.4.2 157 | pyflakes==2.1.1 158 | Pygments==2.4.2 159 | pylint==2.4.2 160 | pynndescent==0.5.2 161 | pyodbc==4.0.27 162 | pyOpenSSL==19.0.0 163 | pyparsing==2.4.2 164 | pyrsistent==0.15.4 165 | PySocks==1.7.1 166 | pytest==5.2.1 167 | pytest-arraydiff==0.3 168 | pytest-astropy==0.5.0 169 | pytest-doctestplus==0.4.0 170 | pytest-openfiles==0.4.0 171 | pytest-remotedata==0.3.2 172 | python-dateutil==2.8.0 173 | python-igraph==0.7.1.post7 174 | pytz==2019.3 175 | PyWavelets==1.0.3 176 | PyYAML==5.1.2 177 | pyzmq==18.1.0 178 | QtAwesome==0.6.0 179 | qtconsole==4.5.5 180 | QtPy==1.9.0 181 | requests==2.22.0 182 | rope==0.14.0 183 | ruamel-yaml==0.15.46 184 | scanpy==1.8.0 185 | scikit-image==0.15.0 186 | scikit-learn==0.24.2 187 | scipy==1.6.2 188 | seaborn==0.9.0 189 | SecretStorage==3.1.1 190 | Send2Trash==1.5.0 191 | simplegeneric==0.8.1 192 | sinfo==0.3.4 193 | singledispatch==3.4.0.3 194 | six==1.12.0 195 | snowballstemmer==2.0.0 196 | sortedcollections==1.1.2 197 | sortedcontainers==2.1.0 198 | soupsieve==1.9.3 199 | Sphinx==2.2.0 200 | sphinxcontrib-applehelp==1.0.1 201 | sphinxcontrib-devhelp==1.0.1 202 | sphinxcontrib-htmlhelp==1.0.2 203 | sphinxcontrib-jsmath==1.0.1 204 | sphinxcontrib-qthelp==1.0.2 205 | sphinxcontrib-serializinghtml==1.1.3 206 | sphinxcontrib-websupport==1.1.2 207 | spyder==3.3.6 208 | spyder-kernels==0.5.2 209 | SQLAlchemy==1.3.9 210 | statsmodels==0.10.1 211 | stdlib-list==0.8.0 212 | sympy==1.4 213 | tables==3.5.2 214 | tblib==1.4.0 215 | tensorboard==1.15.0 216 | terminado==0.8.2 217 | testpath==0.4.2 218 | threadpoolctl==2.1.0 219 | toolz==0.10.0 220 | torch==1.10.0 221 | torchaudio==0.9.0a0+33b2469 222 | torchvision==0.10.0 223 | torchviz==0.0.2 224 | tornado==6.0.3 225 | tqdm==4.36.1 226 | traitlets==4.3.3 227 | typing-extensions==3.10.0.0 228 | umap-learn==0.5.1 229 | unicodecsv==0.14.1 230 | urllib3==1.24.2 231 | wcwidth==0.1.7 232 | webencodings==0.5.1 233 | Werkzeug==0.16.0 234 | widgetsnbextension==3.5.1 235 | wrapt==1.11.2 236 | wurlitzer==1.0.3 237 | xlrd==1.2.0 238 | XlsxWriter==1.2.1 239 | xlwt==1.3.0 240 | zict==1.0.0 241 | zipp==0.6.0 242 | --------------------------------------------------------------------------------
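
For readers who want to sanity-check the custom building blocks in `Interaction_model/modelv2.py` before running the full pipeline, the sketch below exercises `GraphConvolution` and `FocalLoss` in isolation. It is only an illustrative smoke test, not part of DeepCCI itself: the tensor shapes are arbitrary, and it assumes it is run from the `Interaction_model` directory so that `modelv2.py` (and the local `Mobilev2.py` it imports) are on the Python path.

```python
# Illustrative smoke test for the layers defined in modelv2.py (not part of the pipeline).
# Assumes the working directory is Interaction_model/; the shapes below are arbitrary.
import torch
from modelv2 import FocalLoss, GraphConvolution

# GraphConvolution computes adj @ (input @ weight); with an identity adjacency it
# reduces to a plain linear map, so the output shape is easy to verify.
gc = GraphConvolution(in_features=8, out_features=4)
node_features = torch.randn(5, 8)    # 5 nodes with 8 input features each
adj = torch.eye(5)                   # identity adjacency: each node only sees itself
out = gc(node_features, adj)
print(out.shape)                     # torch.Size([5, 4])

# FocalLoss expects probabilities by default (logits=False), so pass sigmoid outputs.
criterion = FocalLoss()
probs = torch.sigmoid(torch.randn(6))
labels = torch.randint(0, 2, (6,)).float()
print(criterion(probs, labels))      # scalar focal loss averaged over the batch
```

Because the adjacency is the identity, the first print should always report `torch.Size([5, 4])`; any other result suggests the layer definition in `modelv2.py` has been modified.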