├── DSTG ├── R_utils.R ├── __init__.py ├── convert_data.R ├── data.py ├── evaluation.R ├── graph.py ├── gutils.py ├── layers.py ├── metrics.py ├── models.py ├── synthetic_data │ ├── example_data.RDS │ └── example_label.RDS ├── train.py └── utils.py ├── LICENCE ├── README.md ├── requirments.txt └── setup.py /DSTG/R_utils.R: -------------------------------------------------------------------------------- 1 | #' evalaute performance 2 | JSD_performance <- function(spots_true_composition, spots_predicted_composition) 3 | { 4 | suppressMessages(require(philentropy)) 5 | jsd_matrix <- matrix(nrow = nrow(spots_true_composition), ncol = 1) 6 | 7 | for (i in seq_len(nrow(spots_true_composition))) { 8 | x <- rbind(spots_true_composition[i, ], spots_predicted_composition[i, ]) 9 | if (sum(spots_predicted_composition[i, ]) > 0) { 10 | jsd_matrix[i, 1] <- suppressMessages(JSD(x = x, unit = "log2", est.prob = "empirical")) 11 | } 12 | else { jsd_matrix[i, 1] <- 1 } } 13 | 14 | quants_jsd <- round(quantile(matrixStats::rowMins( 15 | jsd_matrix, 16 | na.rm = TRUE), c(0.25, 0.5, 0.75)), 5) 17 | cat(sprintf("The following JSD quantiles are obtained: 18 | %s [%s - %s]", 19 | quants_jsd[[2]], quants_jsd[[1]], quants_jsd[[3]]), sep = "\n") 20 | return(list(JSD = jsd_matrix)) 21 | } 22 | 23 | #' normalize function 24 | normalize_data <- function(count.list){ 25 | norm.list <- vector('list') 26 | var.features <- vector('list') 27 | for ( i in 1:length(count.list)){ 28 | norm.list[[i]] <- as.matrix(Seurat:::NormalizeData.default(count.list[[i]],verbose=FALSE)) 29 | hvf.info <- Seurat:::FindVariableFeatures.default(count.list[[i]],selection.method='vst',verbose=FALSE) 30 | hvf.info <- hvf.info[which(x = hvf.info[, 1, drop = TRUE] != 0), ] 31 | hvf.info <- hvf.info[order(hvf.info$vst.variance.standardized, decreasing = TRUE), , drop = FALSE] 32 | var.features[[i]] <- head(rownames(hvf.info), n = 2000) 33 | } 34 | sel.features <- selectIntegrationFeature(count.list,var.features) 35 | return (list(norm.list,sel.features))} 36 | 37 | #' scaling function 38 | scale_data <- function(count.list,norm.list,hvg.features){ 39 | scale.list <- lapply(norm.list,function(mat){ 40 | Seurat:::ScaleData.default(object = mat, features = hvg.features,verbose=FALSE)}) 41 | scale.list <- lapply(1:length(count.list),function(i){ 42 | return (scale.list[[i]][na.omit(match(rownames(count.list[[i]]),rownames(scale.list[[i]]))),])}) 43 | return (scale.list)} 44 | 45 | #' select HVG genes 46 | selectIntegrationFeature <- function(count.list,var.features,nfeatures = 2000){ 47 | var.features1 <- unname(unlist(var.features)) 48 | var.features2 <- sort(table(var.features1), decreasing = TRUE) 49 | for (i in 1:length(count.list)) { 50 | var.features3 <- var.features2[names(var.features2) %in% rownames(count.list[[i]])]} 51 | tie.val <- var.features3[min(nfeatures, length(var.features3))] 52 | features <- names(var.features3[which(var.features3 > tie.val)]) 53 | if (length(features) > 0) { 54 | feature.ranks <- sapply(features, function(x) { 55 | ranks <- sapply(var.features, function(y) { 56 | if (x %in% y) { 57 | return(which(x == y)) 58 | } 59 | return(NULL) 60 | }) 61 | median(unlist(ranks)) 62 | }) 63 | features <- names(sort(feature.ranks)) 64 | } 65 | features.tie <- var.features3[which(var.features3 == tie.val)] 66 | tie.ranks <- sapply(names(features.tie), function(x) { 67 | ranks <- sapply(var.features, function(y) { 68 | if (x %in% y) {return(which(x == y))} 69 | return(NULL) 70 | }) 71 | median(unlist(ranks)) 72 | }) 73 | features <- 
c(features, names(head(sort(tie.ranks), nfeatures - length(features)))) 74 | return(features) 75 | } 76 | 77 | #' select variable genes 78 | select_feature <- function(data,label,nf=2000){ 79 | M <- nrow(data); new.label <- label[,1] 80 | pv1 <- sapply(1:M, function(i){ 81 | mydataframe <- data.frame(y=as.numeric(data[i,]), ig=new.label) 82 | fit <- aov(y ~ ig, data=mydataframe) 83 | summary(fit)[[1]][["Pr(>F)"]][1]}) 84 | names(pv1) <- rownames(data) 85 | pv1.sig <- names(pv1)[order(pv1)[1:nf]] 86 | egen <- unique(pv1.sig) 87 | return (egen) 88 | } 89 | 90 | #' This function takes pseudo-spatail and real-spatial data to identify variable genes 91 | data_process <- function(st_count,st_label,anova){ 92 | if (anova){ 93 | sel.features <- select_feature(st_count[[1]],st_label[[1]]) 94 | st_count_new <- list(st_count[[1]][sel.features,],st_count[[2]][sel.features,]) 95 | 96 | colnames(st_label[[1]]) <- 'subclass' 97 | tem.t1 <- Seurat::CreateSeuratObject(counts = st_count_new[[1]],meta.data=st_label[[1]]); 98 | Seurat::Idents(object = tem.t1) <- tem.t1@meta.data$subclass 99 | 100 | #' convert scRNA-seq data to pseudo-spatial data 101 | test.spot.ls1<-SPOTlight::test_spot_fun(se_obj=tem.t1,clust_vr='subclass',n=1000); 102 | test.spot.counts1 <- as.matrix(test.spot.ls1[[1]]) 103 | colnames(test.spot.counts1)<-paste("mixt",1:ncol(test.spot.counts1),sep="_"); 104 | metadata1 <- test.spot.ls1[[2]] 105 | test.spot.metadata1 <- do.call(rbind,lapply(1:nrow(metadata1),function(i){metadata1[i,]/sum(metadata1[i,])})) 106 | st_counts <- list(test.spot.counts1,st_count_new[[2]]) 107 | 108 | st_label[[1]] <- test.spot.metadata1 109 | N1 <- ncol(st_counts[[1]]); N2 <- ncol(st_counts[[2]]) 110 | label.list2 <- do.call("rbind", rep(list(st_label[[1]]), round(N2/N1)+1))[1:N2,] 111 | st_labels <- list(st_label[[1]],label.list2) 112 | } else { 113 | st_counts <- st_count; st_labels=st_label } 114 | res1 <- normalize_data(st_counts) 115 | st_norm <- res1[[1]]; variable_gene <- res1[[2]]; 116 | st_scale <- scale_data(st_counts,st_norm,variable_gene) 117 | return (list(st_counts,st_labels,st_norm,st_scale,variable_gene)) 118 | } 119 | 120 | #' @param count.list list of pseudo-spatial data and real-spatial data, of which rows are genes and columns are cells 121 | #' @param label.list list of pseudo-spatail label and real-spatial label (if any) 122 | #' @return This function returns files saved in folders "Datadir" & "Infor_Data" 123 | Convert_Data <- function(count.list,label.list,anova=TRUE){ 124 | step1 <- data_process(st_count=count.list,st_label=label.list,anova) 125 | st.count <- step1[[1]]; 126 | st.label <- step1[[2]]; 127 | st.norm <- step1[[3]]; 128 | st.scale <- step1[[4]]; 129 | variable.genes <- step1[[5]] 130 | 131 | #' create data folders 132 | dir.create('Datadir'); dir.create('Output'); dir.create('DSTG_Result') 133 | inforDir <- 'Infor_Data'; dir.create(inforDir) 134 | 135 | #' save counts data to certain path: 'Datadir' 136 | write.csv(t(st.count[[1]]),file='Datadir/Pseudo_ST1.csv',quote=F,row.names=T) 137 | write.csv(t(st.count[[2]]),file='Datadir/Real_ST2.csv',quote=F,row.names=T) 138 | 139 | #' save scaled data to certain path: 'Infor_Data' 140 | write.csv(variable.genes,file=paste0(inforDir,'/Variable_features.csv'),quote=F,row.names=F) 141 | 142 | if (!dir.exists(paste0(inforDir,'/ST_count'))){dir.create(paste0(inforDir,'/ST_count'))} 143 | if (!dir.exists(paste0(inforDir,'/ST_label'))){dir.create(paste0(inforDir,'/ST_label'))} 144 | if 
(!dir.exists(paste0(inforDir,'/ST_norm'))){dir.create(paste0(inforDir,'/ST_norm'))} 145 | if (!dir.exists(paste0(inforDir,'/ST_scale'))){dir.create(paste0(inforDir,'/ST_scale'))} 146 | 147 | for (i in 1:2){ 148 | write.csv(st.count[[i]],file=paste0(inforDir,'/ST_count/ST_count_',i,'.csv'),quote=F) 149 | write.csv(st.label[[i]],file=paste0(inforDir,'/ST_label/ST_label_',i,'.csv'),quote=F) 150 | write.csv(st.norm[[i]],file=paste0(inforDir,'/ST_norm/ST_norm_',i,'.csv'),quote=F) 151 | write.csv(st.scale[[i]],file=paste0(inforDir,'/ST_scale/ST_scale_',i,'.csv'),quote=F) 152 | } 153 | } 154 | 155 | 156 | 157 | # se_obj is a seurat object. 158 | 159 | test_spot_fun = function (se_obj, clust_vr, n = 1000, verbose = TRUE){ 160 | if (is(se_obj) != "Seurat") 161 | stop("ERROR: se_obj must be a Seurat object!") 162 | if (!is.character(clust_vr)) 163 | stop("ERROR: clust_vr must be a character string!") 164 | if (!is.numeric(n)) 165 | stop("ERROR: n must be an integer!") 166 | if (!is.logical(verbose)) 167 | stop("ERROR: verbose must be a logical object!") 168 | suppressMessages(require(DropletUtils)) 169 | suppressMessages(require(purrr)) 170 | suppressMessages(require(dplyr)) 171 | suppressMessages(require(tidyr)) 172 | se_obj@meta.data[, clust_vr] <- gsub(pattern = "[[:punct:]]|[[:blank:]]", 173 | ".", x = se_obj@meta.data[, clust_vr], perl = TRUE) 174 | print("Generating synthetic test spots...") 175 | start_gen <- Sys.time() 176 | pb <- txtProgressBar(min = 0, max = n, style = 3) 177 | count_mtrx <- as.matrix(se_obj@assays$RNA@counts) 178 | ds_spots <- lapply(seq_len(n), function(i) { 179 | cell_pool <- sample(colnames(count_mtrx), sample(x = 2:10, 180 | size = 1)) 181 | pos <- which(colnames(count_mtrx) %in% cell_pool) 182 | tmp_ds <- se_obj@meta.data[pos, ] %>% mutate(weight = 1) 183 | name_simp <- paste("spot_", i, sep = "") 184 | spot_ds <- tmp_ds %>% dplyr::select(all_of(clust_vr), 185 | weight) %>% dplyr::group_by(!!sym(clust_vr)) %>% 186 | dplyr::summarise(sum_weights = sum(weight)) %>% dplyr::ungroup() %>% 187 | tidyr::pivot_wider(names_from = all_of(clust_vr), 188 | values_from = sum_weights) %>% dplyr::mutate(name = name_simp) 189 | syn_spot <- rowSums(as.matrix(count_mtrx[, cell_pool])) 190 | sum(syn_spot) 191 | names_genes <- names(syn_spot) 192 | if (sum(syn_spot) > 25000) { 193 | syn_spot_sparse <- DropletUtils::downsampleMatrix(Matrix::Matrix(syn_spot, 194 | sparse = T), prop = 20000/sum(syn_spot)) 195 | } 196 | else { 197 | syn_spot_sparse <- Matrix::Matrix(syn_spot, sparse = T) 198 | } 199 | rownames(syn_spot_sparse) <- names_genes 200 | colnames(syn_spot_sparse) <- name_simp 201 | setTxtProgressBar(pb, i) 202 | return(list(syn_spot_sparse, spot_ds)) 203 | }) 204 | ds_syn_spots <- purrr::map(ds_spots, 1) %>% base::Reduce(function(m1, 205 | m2) cbind(unlist(m1), unlist(m2)), .) 
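    # ds_syn_spots is now a genes x n_spots matrix of pooled counts, one column per synthetic spot;
    # the matching per-spot cell-type composition table is assembled next from the second
    # element of each ds_spots entry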
206 | ds_spots_metadata <- purrr::map(ds_spots, 2) %>% dplyr::bind_rows() %>% 207 | data.frame() 208 | ds_spots_metadata[is.na(ds_spots_metadata)] <- 0 209 | lev_mod <- gsub("[\\+|\\ |\\/]", ".", unique(se_obj@meta.data[, 210 | clust_vr])) 211 | colnames(ds_spots_metadata) <- gsub("[\\+|\\ |\\/]", ".", 212 | colnames(ds_spots_metadata)) 213 | if (sum(lev_mod %in% colnames(ds_spots_metadata)) == (length(unique(se_obj@meta.data[, 214 | clust_vr])) + 1)) { 215 | ds_spots_metadata <- ds_spots_metadata[, lev_mod] 216 | } 217 | else { 218 | missing_cols <- lev_mod[which(!lev_mod %in% colnames(ds_spots_metadata))] 219 | ds_spots_metadata[missing_cols] <- 0 220 | ds_spots_metadata <- ds_spots_metadata[, lev_mod] 221 | } 222 | close(pb) 223 | print(sprintf("Generation of %s test spots took %s mins", 224 | n, round(difftime(Sys.time(), start_gen, units = "mins"), 225 | 2))) 226 | print("output consists of a list with two data frames: the first one has the weighted count matrix and the second has the metadata for each spot") 227 | return(list(topic_profiles = ds_syn_spots, cell_composition = ds_spots_metadata)) 228 | } 229 | 230 | 231 | -------------------------------------------------------------------------------- /DSTG/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | -------------------------------------------------------------------------------- /DSTG/convert_data.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | args = commandArgs(trailingOnly=TRUE) 3 | 4 | #' This script takes synthetic data to test DSTG's performance 5 | 6 | #' @return This function returns files saved in folders "Datadir" & "Infor_Data" 7 | #' @export: all files are saved in the current path 8 | #' @examples: load data from folder "synthetic_data" 9 | 10 | source('R_utils.R') 11 | #' if you have the scRNA-seq data and want to decompose ST data, please run below: 12 | 13 | if (length(args)==0) { 14 | message('run synthetic data...') 15 | synthetic.count <- readRDS('./synthetic_data/example_data.RDS') 16 | synthetic.label <- readRDS('./synthetic_data/example_label.RDS') 17 | Convert_Data(synthetic.count,synthetic.label,anova=FALSE) 18 | } else if (length(args)==3) { 19 | message('run real data...') 20 | sc.count <- readRDS(args[1]) 21 | st.count <- readRDS(args[2]) 22 | intersect.genes <- intersect(rownames(sc.count),rownames(st.count)) 23 | sc.count <- sc.count[intersect.genes,] 24 | st.count <- st.count[intersect.genes,] 25 | count.list <- list(sc.count,st.count) 26 | label.list <- list(data.frame(readRDS(args[3]),stringsAsFactors=F)) 27 | Convert_Data(count.list,label.list,anova=TRUE) 28 | } 29 | 30 | -------------------------------------------------------------------------------- /DSTG/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import pandas as pd 5 | import time as tm 6 | from operator import itemgetter 7 | from sklearn.model_selection import train_test_split 8 | import pickle as pkl 9 | import scipy.sparse 10 | from metrics import * 11 | from gutils import * 12 | from graph import * 13 | 14 | 15 | #' data preparation 16 | def input_data(DataDir): 17 | Link_Graph(outputdir='Infor_Data') 18 | DataPath1 = '{}/Pseudo_ST1.csv'.format(DataDir) 19 | DataPath2 = '{}/Real_ST2.csv'.format(DataDir) 20 | LabelsPath1 = 
'{}/Pseudo_Label1.csv'.format(DataDir) 21 | LabelsPath2 = '{}/Real_Label2.csv'.format(DataDir) 22 | 23 | #' read the data 24 | data1 = pd.read_csv(DataPath1, index_col=0, sep=',') 25 | data2 = pd.read_csv(DataPath2, index_col=0, sep=',') 26 | lab_label1 = pd.read_csv(LabelsPath1, header=0, index_col=False, sep=',') 27 | lab_label2 = pd.read_csv(LabelsPath2, header=0, index_col=False, sep=',') 28 | 29 | lab_data1 = data1.reset_index(drop=True) #.transpose() 30 | lab_data2 = data2.reset_index(drop=True) #.transpose() 31 | 32 | random.seed(123) 33 | p_data = lab_data1 34 | p_label = lab_label1 35 | 36 | temD_train, temd_test, temL_train, teml_test = train_test_split( 37 | p_data, p_label, test_size=0.1, random_state=1) 38 | temd_train, temd_val, teml_train, teml_val = train_test_split( 39 | temD_train, temL_train, test_size=0.1, random_state=1) 40 | 41 | print((temd_train.index == teml_train.index).all()) 42 | print((temd_test.index == teml_test.index).all()) 43 | print((temd_val.index == teml_val.index).all()) 44 | data_train = temd_train 45 | label_train = teml_train 46 | data_test = temd_test 47 | label_test = teml_test 48 | data_val = temd_val 49 | label_val = teml_val 50 | 51 | data_train1 = data_train 52 | data_test1 = data_test 53 | data_val1 = data_val 54 | label_train1 = label_train 55 | label_test1 = label_test 56 | label_val1 = label_val 57 | 58 | train2 = pd.concat([data_train1, lab_data2]) 59 | lab_train2 = pd.concat([label_train1, lab_label2]) 60 | 61 | #' save objects 62 | 63 | PIK = "{}/datasets.dat".format(DataDir) 64 | res = [ 65 | data_train1, data_test1, data_val1, label_train1, label_test1, 66 | label_val1, lab_data2, lab_label2 67 | ] 68 | 69 | with open(PIK, "wb") as f: 70 | pkl.dump(res, f) 71 | 72 | print('load data succesfully....') 73 | -------------------------------------------------------------------------------- /DSTG/evaluation.R: -------------------------------------------------------------------------------- 1 | 2 | #' check JSD score 3 | true <- read.csv('./DSTG_Result/true_output.csv',header=F) 4 | predict <- read.csv('./DSTG_Result/predict_output.csv',header=F) 5 | 6 | source('R_utils.R') 7 | 8 | jsd.score <- JSD_performance( 9 | spots_true_composition = as.matrix(true), 10 | spots_predicted_composition = as.matrix(predict)) 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /DSTG/graph.py: -------------------------------------------------------------------------------- 1 | from gutils import * 2 | import glob 3 | import pandas as pd 4 | import os 5 | import itertools 6 | 7 | 8 | def Link_Graph(outputdir): 9 | path0 = os.path.join(os.getcwd(), outputdir) 10 | 11 | #' import processed data 12 | files1 = glob.glob(path0 + "/ST_count/*.csv") 13 | files1.sort() 14 | count_list = [] 15 | for df in files1: 16 | print(df) 17 | count_list.append(pd.read_csv(df, index_col=0)) 18 | 19 | files2 = glob.glob(path0 + "/ST_norm/*.csv") 20 | files2.sort() 21 | norm_list = [] 22 | for df in files2: 23 | print(df) 24 | norm_list.append(pd.read_csv(df, index_col=0)) 25 | 26 | files3 = glob.glob(path0 + "/ST_scale/*.csv") 27 | files3.sort() 28 | scale_list = [] 29 | for df in files3: 30 | print(df) 31 | scale_list.append(pd.read_csv(df, index_col=0)) 32 | 33 | files4 = glob.glob(path0 + "/ST_label/*.csv") 34 | files4.sort() 35 | label_list = [] 36 | for df in files4: 37 | print(df) 38 | label_list.append(pd.read_csv(df, index_col=0)) 39 | 40 | fpath = os.path.join(path0, 'Variable_features.csv') 41 | features = 
pd.read_csv(fpath, index_col=False).values.flatten() 42 | 43 | N = len(count_list) 44 | if (N == 1): 45 | combine = pd.Series([(0, 0)]) 46 | else: 47 | combin = list(itertools.product(list(range(N)), list(range(N)))) 48 | index = [i for i, x in enumerate([i[0] < i[1] for i in combin]) if x] 49 | combine = pd.Series(combin)[index] 50 | 51 | link1 = Link_graph(count_list=count_list, 52 | norm_list=norm_list, 53 | scale_list=scale_list, 54 | features=features, 55 | combine=combine) 56 | 57 | #' ---- input data for link grpah 2 ----- 58 | files1 = glob.glob(path0 + "/ST_count/*.csv") 59 | files1.sort() 60 | tem_count = pd.read_csv(files1[1], index_col=0) 61 | tem_count.columns = tem_count.columns.str.replace("mixt_", "rept_") 62 | 63 | files2 = glob.glob(path0 + "/ST_norm/*.csv") 64 | files2.sort() 65 | tem_norm = pd.read_csv(files2[1], index_col=0) 66 | tem_norm.columns = tem_norm.columns.str.replace("mixt_", "rept_") 67 | 68 | files3 = glob.glob(path0 + "/ST_scale/*.csv") 69 | files3.sort() 70 | tem_scale = pd.read_csv(files3[1], index_col=0) 71 | tem_scale.columns = tem_scale.columns.str.replace("mixt_", "rept_") 72 | 73 | files4 = glob.glob(path0 + "/ST_label/*.csv") 74 | files4.sort() 75 | tem_label = pd.read_csv(files4[1], index_col=0) 76 | tem_label.columns = tem_label.columns.str.replace("mixt_", "rept_") 77 | 78 | count_list2 = [count_list[1],tem_count] 79 | norm_list2 = [norm_list[1], tem_norm] 80 | scale_list2 = [scale_list[1], tem_scale] 81 | 82 | link2 = Link_graph(count_list=count_list2, 83 | norm_list=norm_list2, 84 | scale_list=scale_list2, 85 | features=features, 86 | combine=combine, 87 | k_filter=100) 88 | 89 | graph1 = link1[0].iloc[:, 0:2].reset_index() 90 | graph1 = graph1.iloc[:,1:3] 91 | graph1.to_csv('./Datadir/Linked_graph1.csv') 92 | 93 | graph2 = link2[0].iloc[:, 0:2].reset_index() 94 | graph2 = graph2.iloc[:,1:3] 95 | graph2.to_csv('./Datadir/Linked_graph2.csv') 96 | 97 | label1 = label_list[0] 98 | label1.to_csv('./Datadir/Pseudo_Label1.csv', index=False) 99 | 100 | label2 = label_list[1] 101 | label2.to_csv('./Datadir/Real_Label2.csv', index=False) 102 | -------------------------------------------------------------------------------- /DSTG/gutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import random 4 | from sklearn.neighbors import KDTree 5 | from metrics import * 6 | 7 | 8 | #' @param num.cc Number of canonical vectors to calculate 9 | #' @param seed.use Random seed to set. 
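#' @param data1,data2 gene-by-spot expression matrices (rows are the shared genes, columns are spots) whose scaled cross-product is factorised with SVD to give the joint spot embedding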
10 | #' @importFrom SVD 11 | def embed(data1, data2, num_cc=20): 12 | random.seed(123) 13 | object1 = Scale(data1) 14 | object2 = Scale(data2) 15 | mat3 = np.matmul(np.matrix(object1).transpose(), np.matrix(object2)) 16 | a = SVD(mat=mat3, num_cc=int(num_cc)) 17 | embeds_data = np.concatenate((a[0], a[1])) 18 | ind = np.where( 19 | [embeds_data[:, col][0] < 0 for col in range(embeds_data.shape[1])])[0] 20 | embeds_data[:, ind] = embeds_data[:, ind] * (-1) 21 | 22 | embeds_data = pd.DataFrame(embeds_data) 23 | embeds_data.index = np.concatenate( 24 | (np.array(data1.columns), np.array(data2.columns))) 25 | embeds_data.columns = ['D_' + str(i) for i in range(num_cc)] 26 | d = a[2] 27 | #' d = np.around(a[2], 3) #.astype('int') 28 | return embeds_data, d 29 | 30 | 31 | def Embed(data_use1, data_use2, features, count_names, num_cc): 32 | features = checkFeature(data_use1, features) 33 | features = checkFeature(data_use2, features) 34 | data1 = data_use1.loc[features, ] 35 | data2 = data_use2.loc[features, ] 36 | embed_results = embed(data1=data1, data2=data2, num_cc=num_cc) 37 | cell_embeddings = np.matrix(embed_results[0]) 38 | combined_data = data1.merge(data2, 39 | left_index=True, 40 | right_index=True, 41 | how='inner') 42 | new_data1 = combined_data.loc[count_names, ].dropna() 43 | # loadings=loadingDim(new.data1,cell.embeddings) 44 | loadings = pd.DataFrame(np.matmul(np.matrix(new_data1), cell_embeddings)) 45 | loadings.index = new_data1.index 46 | return embed_results, loadings 47 | 48 | 49 | def checkFeature(data_use, features): 50 | data1 = data_use.loc[features, ] 51 | feature_var = data1.var(1) 52 | Var_features = features[np.where(feature_var != 0)[0]] 53 | return Var_features 54 | 55 | 56 | def kNN(data, k, query=None): 57 | tree = KDTree(data) 58 | if query is None: 59 | query = data 60 | dist, ind = tree.query(query, k) 61 | return dist, ind 62 | 63 | 64 | #' @param cell_embedding : pandas data frame 65 | def KNN(cell_embedding, spots1, spots2, k): 66 | embedding_spots1 = cell_embedding.loc[spots1, ] 67 | embedding_spots2 = cell_embedding.loc[spots2, ] 68 | nnaa = kNN(embedding_spots1, k=k + 1) 69 | nnbb = kNN(embedding_spots2, k=k + 1) 70 | nnab = kNN(data=embedding_spots2, k=k, query=embedding_spots1) 71 | nnba = kNN(data=embedding_spots1, k=k, query=embedding_spots2) 72 | return nnaa, nnab, nnba, nnbb, spots1, spots2 73 | 74 | 75 | def MNN(neighbors, colnames, num): 76 | max_nn = np.array([neighbors[1][1].shape[1], neighbors[2][1].shape[1]]) 77 | if ((num > max_nn).any()): 78 | num = np.min(max_nn) 79 | # convert cell name to neighbor index 80 | spots1 = colnames 81 | spots2 = colnames 82 | nn_spots1 = neighbors[4] 83 | nn_spots2 = neighbors[5] 84 | cell1_index = [ 85 | list(nn_spots1).index(i) for i in spots1 if (nn_spots1 == i).any() 86 | ] 87 | cell2_index = [ 88 | list(nn_spots2).index(i) for i in spots2 if (nn_spots2 == i).any() 89 | ] 90 | ncell = range(neighbors[1][1].shape[0]) 91 | ncell = np.array(ncell)[np.in1d(ncell, cell1_index)] 92 | # initialize a list 93 | mnn_cell1 = [None] * (len(ncell) * 5) 94 | mnn_cell2 = [None] * (len(ncell) * 5) 95 | idx = -1 96 | for cell in ncell: 97 | neighbors_ab = neighbors[1][1][cell, 0:5] 98 | mutual_neighbors = np.where( 99 | neighbors[2][1][neighbors_ab, 0:5] == cell)[0] 100 | for i in neighbors_ab[mutual_neighbors]: 101 | idx = idx + 1 102 | mnn_cell1[idx] = cell 103 | mnn_cell2[idx] = i 104 | mnn_cell1 = mnn_cell1[0:(idx + 1)] 105 | mnn_cell2 = mnn_cell2[0:(idx + 1)] 106 | import pandas as pd 107 | mnns = 
pd.DataFrame(np.column_stack((mnn_cell1, mnn_cell2))) 108 | mnns.columns = ['spot1', 'spot2'] 109 | return mnns 110 | 111 | 112 | def filterEdge(edges, neighbors, mats, features, k_filter): 113 | nn_spots1 = neighbors[4] 114 | nn_spots2 = neighbors[5] 115 | mat1 = mats.loc[features, nn_spots1].transpose() 116 | mat2 = mats.loc[features, nn_spots2].transpose() 117 | cn_data1 = l2norm(mat1) 118 | cn_data2 = l2norm(mat2) 119 | nn = kNN(data=cn_data2.loc[nn_spots2, ], 120 | query=cn_data1.loc[nn_spots1, ], 121 | k=k_filter) 122 | position = [ 123 | np.where( 124 | edges.loc[:, "spot2"][x] == nn[1][edges.loc[:, 'spot1'][x], ])[0] 125 | for x in range(edges.shape[0]) 126 | ] 127 | nps = np.concatenate(position, axis=0) 128 | fedge = edges.iloc[nps, ] 129 | return (fedge) 130 | 131 | 132 | def Link_graph(count_list, 133 | norm_list, 134 | scale_list, 135 | features, 136 | combine, 137 | k_filter=200): 138 | all_edges = [] 139 | for row in combine: 140 | i = row[0] 141 | j = row[1] 142 | counts1 = count_list[i] 143 | counts2 = count_list[j] 144 | norm_data1 = norm_list[i] 145 | norm_data2 = norm_list[j] 146 | scale_data1 = scale_list[i] 147 | scale_data2 = scale_list[j] 148 | rowname = counts1.index 149 | cell_embedding, loading = Embed(data_use1=scale_data1, 150 | data_use2=scale_data2, 151 | features=features, 152 | count_names=rowname, 153 | num_cc=30) 154 | norm_embedding = l2norm(mat=cell_embedding[0]) 155 | spots1 = counts1.columns 156 | spots2 = counts2.columns 157 | neighbor = KNN(cell_embedding=norm_embedding, 158 | spots1=spots1, 159 | spots2=spots2, 160 | k=30) 161 | mnn_edges = MNN(neighbors=neighbor, 162 | colnames=cell_embedding[0].index, 163 | num=5) 164 | select_genes = TopGenes(Loadings=loading, 165 | dims=range(30), 166 | DimGenes=100, 167 | maxGenes=200) 168 | Mat = pd.concat([norm_data1, norm_data2], axis=1) 169 | final_edges = filterEdge(edges=mnn_edges, 170 | neighbors=neighbor, 171 | mats=Mat, 172 | features=select_genes, 173 | k_filter=k_filter) 174 | final_edges['Dataset1'] = [i + 1] * final_edges.shape[0] 175 | final_edges['Dataset2'] = [j + 1] * final_edges.shape[0] 176 | all_edges.append(final_edges) 177 | return all_edges 178 | -------------------------------------------------------------------------------- /DSTG/layers.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import tensorflow as tf 3 | 4 | flags = tf.app.flags 5 | FLAGS = flags.FLAGS 6 | 7 | # global unique layer ID dictionary for layer name assignment 8 | _LAYER_UIDS = {} 9 | 10 | 11 | def get_layer_uid(layer_name=''): 12 | """Helper function, assigns unique layer IDs.""" 13 | if layer_name not in _LAYER_UIDS: 14 | _LAYER_UIDS[layer_name] = 1 15 | return 1 16 | else: 17 | _LAYER_UIDS[layer_name] += 1 18 | return _LAYER_UIDS[layer_name] 19 | 20 | 21 | def sparse_dropout(x, keep_prob, noise_shape): 22 | """Dropout for sparse tensors.""" 23 | random_tensor = keep_prob 24 | random_tensor += tf.random_uniform(noise_shape) 25 | dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool) 26 | pre_out = tf.sparse_retain(x, dropout_mask) 27 | return pre_out * (1. 
/ keep_prob) 28 | 29 | 30 | def glorot(shape, name=None): 31 | """Glorot & Bengio (AISTATS 2010) init.""" 32 | init_range = np.sqrt(6.0 / (shape[0] + shape[1])) 33 | initial = tf.random_uniform(shape, 34 | minval=-init_range, 35 | maxval=init_range, 36 | dtype=tf.float32) 37 | return tf.Variable(initial, name=name) 38 | 39 | 40 | def dot(x, y, sparse=False): 41 | """Wrapper for tf.matmul (sparse vs dense).""" 42 | if sparse: 43 | res = tf.sparse_tensor_dense_matmul(x, y) 44 | else: 45 | res = tf.matmul(x, y) 46 | return res 47 | 48 | 49 | class Layer(object): 50 | def __init__(self, **kwargs): 51 | allowed_kwargs = {'name', 'logging'} 52 | for kwarg in kwargs.keys(): 53 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg 54 | name = kwargs.get('name') 55 | if not name: 56 | layer = self.__class__.__name__.lower() 57 | name = layer + '_' + str(get_layer_uid(layer)) 58 | self.name = name 59 | self.vars = {} 60 | logging = kwargs.get('logging', False) 61 | self.logging = logging 62 | self.sparse_inputs = False 63 | 64 | def _call(self, inputs): 65 | return inputs 66 | 67 | def __call__(self, inputs): 68 | with tf.name_scope(self.name): 69 | if self.logging and not self.sparse_inputs: 70 | tf.summary.histogram(self.name + '/inputs', inputs) 71 | outputs = self._call(inputs) 72 | if self.logging: 73 | tf.summary.histogram(self.name + '/outputs', outputs) 74 | return outputs 75 | 76 | def _log_vars(self): 77 | for var in self.vars: 78 | tf.summary.histogram(self.name + '/vars/' + var, self.vars[var]) 79 | 80 | 81 | class GraphConvolution(Layer): 82 | def __init__(self, 83 | input_dim, 84 | output_dim, 85 | placeholders, 86 | dropout=0., 87 | sparse_inputs=False, 88 | act=tf.nn.relu, 89 | bias=False, 90 | featureless=False, 91 | **kwargs): 92 | super(GraphConvolution, self).__init__(**kwargs) 93 | 94 | if dropout: 95 | self.dropout = placeholders['dropout'] 96 | else: 97 | self.dropout = 0. 
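        # when dropout is enabled, the rate is read from the 'dropout' placeholder,
        # so it can be changed through the feed dict without rebuilding the graph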
98 | 99 | self.act = act 100 | self.support = placeholders['support'] 101 | self.sparse_inputs = sparse_inputs 102 | self.featureless = featureless 103 | self.bias = bias 104 | 105 | # helper variable for sparse dropout 106 | self.num_features_nonzero = placeholders['num_features_nonzero'] 107 | 108 | with tf.variable_scope(self.name + '_vars'): 109 | for i in range(len(self.support)): 110 | self.vars['weights_' + str(i)] = glorot( 111 | [input_dim, output_dim], name='weights_' + str(i)) 112 | if self.bias: 113 | self.vars['bias'] = zeros([output_dim], name='bias') 114 | 115 | if self.logging: 116 | self._log_vars() 117 | 118 | def _call(self, inputs): 119 | x = inputs 120 | 121 | # dropout 122 | if self.sparse_inputs: 123 | x = sparse_dropout(x, 1 - self.dropout, self.num_features_nonzero) 124 | else: 125 | x = tf.nn.dropout(x, 1 - self.dropout) 126 | 127 | # convolve 128 | supports = list() 129 | for i in range(len(self.support)): 130 | if not self.featureless: 131 | pre_sup = dot(x, 132 | self.vars['weights_' + str(i)], 133 | sparse=self.sparse_inputs) 134 | else: 135 | pre_sup = self.vars['weights_' + str(i)] 136 | support = dot(self.support[i], pre_sup, sparse=True) 137 | supports.append(support) 138 | output = tf.add_n(supports) 139 | 140 | # bias 141 | if self.bias: 142 | output += self.vars['bias'] 143 | 144 | return self.act(output) 145 | -------------------------------------------------------------------------------- /DSTG/metrics.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import scipy.sparse 3 | import numpy as np 4 | import pandas as pd 5 | from scipy import sparse as sp 6 | import networkx as nx 7 | from collections import defaultdict 8 | from scipy.stats import uniform 9 | import tensorflow as tf 10 | from sklearn import preprocessing 11 | 12 | 13 | def masked_softmax_cross_entropy(preds, labels, mask): 14 | """Softmax cross-entropy loss with masking.""" 15 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels) 16 | mask = tf.cast(mask, dtype=tf.float32) 17 | mask /= tf.reduce_mean(mask) 18 | loss *= mask 19 | return tf.reduce_mean(loss) 20 | 21 | 22 | def masked_accuracy(preds, labels, mask): 23 | """Accuracy with masking.""" 24 | correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1)) 25 | accuracy_all = tf.cast(correct_prediction, tf.float32) 26 | mask = tf.cast(mask, dtype=tf.float32) 27 | mask /= tf.reduce_mean(mask) 28 | accuracy_all *= mask 29 | return tf.reduce_mean(accuracy_all) 30 | 31 | 32 | def SVD(mat, num_cc): 33 | U, s, V = np.linalg.svd(mat) 34 | d = s[0:int(num_cc)] 35 | u = U[:, 0:int(num_cc)] 36 | v = V[0:int(num_cc), :].transpose() 37 | return u, v, d 38 | 39 | 40 | def Scale(x): 41 | y = preprocessing.scale(x) 42 | return y 43 | 44 | 45 | def l2norm(mat): 46 | stat = np.sqrt(np.sum(mat**2, axis=1)) 47 | cols = mat.columns 48 | mat[cols] = mat[cols].div(stat, axis=0) 49 | mat[np.isinf(mat)] = 0 50 | return mat 51 | 52 | 53 | def topGenes(Loadings, dim, numG): 54 | data = Loadings.iloc[:, dim] 55 | num = np.round(numG / 2).astype('int') 56 | data1 = data.sort_values(ascending=False) 57 | data2 = data.sort_values(ascending=True) 58 | posG = np.array(data1.index[0:num]) 59 | negG = np.array(data2.index[0:num]) 60 | topG = np.concatenate((posG, negG)) 61 | return topG 62 | 63 | 64 | def TopGenes(Loadings, dims, DimGenes, maxGenes): 65 | maxG = max(len(dims) * 2, maxGenes) 66 | gens = [None] * DimGenes 67 | idx = -1 68 | for i in range(1, DimGenes + 
1): 69 | idx = idx + 1 70 | selg = [] 71 | for j in dims: 72 | selg.extend(set(topGenes(Loadings, dim=j, numG=i))) 73 | gens[idx] = set(selg) 74 | lens = np.array([len(i) for i in gens]) 75 | lens = lens[lens < maxG] 76 | maxPer = np.where(lens == np.max(lens))[0][0] + 1 77 | selg = [] 78 | for j in dims: 79 | selg.extend(set(topGenes(Loadings, dim=j, numG=maxPer))) 80 | selgene = np.array(list(set(selg)), dtype=object) 81 | return (selgene) 82 | 83 | 84 | def preprocess_features(features): 85 | """Row-normalize feature matrix and convert to tuple representation""" 86 | rowsum = np.array(features.sum(1)) 87 | r_inv = np.power(rowsum, -1).flatten() 88 | r_inv[np.isinf(r_inv)] = 0. 89 | r_mat_inv = sp.diags(r_inv) 90 | features = r_mat_inv.dot(features) 91 | return sparse_to_tuple(features) 92 | 93 | 94 | def sparse_to_tuple(sparse_mx): 95 | """Convert sparse matrix to tuple representation.""" 96 | def to_tuple(mx): 97 | if not sp.isspmatrix_coo(mx): 98 | mx = mx.tocoo() 99 | coords = np.vstack((mx.row, mx.col)).transpose() 100 | values = mx.data 101 | shape = mx.shape 102 | return coords, values, shape 103 | 104 | if isinstance(sparse_mx, list): 105 | for i in range(len(sparse_mx)): 106 | sparse_mx[i] = to_tuple(sparse_mx[i]) 107 | else: 108 | sparse_mx = to_tuple(sparse_mx) 109 | return sparse_mx 110 | 111 | 112 | def preprocess_adj(adj): 113 | """Preprocessing of adjacency matrix for scGCN model and conversion to tuple representation.""" 114 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) 115 | return sparse_to_tuple(adj_normalized) 116 | 117 | 118 | def normalize_adj(adj): 119 | """Symmetrically normalize adjacency matrix.""" 120 | adj = sp.coo_matrix(adj) 121 | rowsum = np.array(adj.sum(1)) 122 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 123 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
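    # build D^-1/2 and return the symmetrically normalized adjacency D^-1/2 A D^-1/2
    # (the renormalization used in graph convolutional networks)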
124 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 125 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 126 | 127 | 128 | def construct_feed_dict(features, support, labels, labels_mask, placeholders): 129 | """Construct feed dictionary.""" 130 | feed_dict = dict() 131 | feed_dict.update({placeholders['labels']: labels}) 132 | feed_dict.update({placeholders['labels_mask']: labels_mask}) 133 | feed_dict.update({placeholders['features']: features}) 134 | feed_dict.update( 135 | {placeholders['support'][i]: support[i] 136 | for i in range(len(support))}) 137 | feed_dict.update({placeholders['num_features_nonzero']: features[1].shape}) 138 | return feed_dict 139 | 140 | 141 | def get_value(diction, specific): 142 | for key, val in diction.items(): 143 | if val == specific: 144 | return (key) 145 | 146 | 147 | def graph(matrix): 148 | adj = defaultdict(list) # default value of int is 0 149 | for i, row in enumerate(matrix): 150 | for j, adjacent in enumerate(row): 151 | if adjacent: 152 | adj[i].append(j) 153 | if adj[i].__len__ == 0: 154 | adj[i] = [] 155 | return adj 156 | 157 | 158 | def sample_mask(idx, l): 159 | """Create mask.""" 160 | mask = np.zeros(l) 161 | mask[idx] = 1 162 | return np.array(mask, dtype=np.bool) 163 | 164 | 165 | # convert nested lists to a flat list 166 | output = [] 167 | 168 | 169 | def removNestings(l): 170 | for i in l: 171 | if type(i) == list: 172 | removNestings(i) 173 | else: 174 | output.append(i) 175 | return (output) 176 | -------------------------------------------------------------------------------- /DSTG/models.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | from utils import * 3 | 4 | flags = tf.app.flags 5 | FLAGS = flags.FLAGS 6 | 7 | 8 | class Model(object): 9 | def __init__(self, **kwargs): 10 | allowed_kwargs = {'name', 'logging'} 11 | for kwarg in kwargs.keys(): 12 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg 13 | name = kwargs.get('name') 14 | if not name: 15 | name = self.__class__.__name__.lower() 16 | self.name = name 17 | 18 | logging = kwargs.get('logging', False) 19 | self.logging = logging 20 | 21 | self.vars = {} 22 | self.placeholders = {} 23 | 24 | self.layers = [] 25 | self.activations = [] 26 | 27 | self.inputs = None 28 | self.outputs = None 29 | 30 | self.loss = 0 31 | self.accuracy = 0 32 | self.optimizer = None 33 | self.opt_op = None 34 | 35 | def _build(self): 36 | raise NotImplementedError 37 | 38 | def build(self): 39 | """ Wrapper for _build() """ 40 | with tf.variable_scope(self.name): 41 | self._build() 42 | 43 | # Build sequential layer model 44 | self.activations.append(self.inputs) 45 | for layer in self.layers: 46 | hidden = layer(self.activations[-1]) 47 | self.activations.append(hidden) 48 | self.outputs = self.activations[-1] 49 | 50 | # Store model variables for easy access 51 | variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 52 | scope=self.name) 53 | self.vars = {var.name: var for var in variables} 54 | 55 | # Build metrics 56 | self._loss() 57 | self._accuracy() 58 | 59 | self.opt_op = self.optimizer.minimize(self.loss) 60 | 61 | def predict(self): 62 | pass 63 | 64 | def _loss(self): 65 | raise NotImplementedError 66 | 67 | def _accuracy(self): 68 | raise NotImplementedError 69 | 70 | def save(self, sess=None): 71 | if not sess: 72 | raise AttributeError("TensorFlow session not provided.") 73 | saver = tf.train.Saver(self.vars) 74 | save_path = saver.save(sess, "tmp/%s.ckpt" % self.name) 75 
| print("Model saved in file: %s" % save_path) 76 | 77 | def load(self, sess=None): 78 | if not sess: 79 | raise AttributeError("TensorFlow session not provided.") 80 | saver = tf.train.Saver(self.vars) 81 | save_path = "tmp/%s.ckpt" % self.name 82 | saver.restore(sess, save_path) 83 | print("Model restored from file: %s" % save_path) 84 | 85 | 86 | class DSTG(Model): 87 | def __init__(self, placeholders, input_dim, **kwargs): 88 | super(DSTG, self).__init__(**kwargs) 89 | 90 | self.inputs = placeholders['features'] 91 | self.input_dim = input_dim 92 | # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions 93 | self.output_dim = placeholders['labels'].get_shape().as_list()[1] 94 | self.placeholders = placeholders 95 | 96 | self.optimizer = tf.train.AdamOptimizer( 97 | learning_rate=FLAGS.learning_rate) 98 | 99 | self.build() 100 | 101 | def _loss(self): 102 | # Weight decay loss 103 | for var in self.layers[0].vars.values(): 104 | self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var) 105 | 106 | # Cross entropy error 107 | self.loss += masked_softmax_cross_entropy( 108 | self.outputs, self.placeholders['labels'], 109 | self.placeholders['labels_mask']) 110 | 111 | def _accuracy(self): 112 | self.accuracy = masked_accuracy(self.outputs, 113 | self.placeholders['labels'], 114 | self.placeholders['labels_mask']) 115 | 116 | def _build(self): 117 | 118 | self.layers.append( 119 | GraphConvolution(input_dim=self.input_dim, 120 | output_dim=FLAGS.hidden1, 121 | placeholders=self.placeholders, 122 | act=tf.nn.relu, 123 | dropout=True, 124 | sparse_inputs=True, 125 | logging=self.logging)) 126 | 127 | self.layers.append( 128 | GraphConvolution(input_dim=FLAGS.hidden1, 129 | output_dim=self.output_dim, 130 | placeholders=self.placeholders, 131 | act=lambda x: x, 132 | dropout=True, 133 | logging=self.logging)) 134 | 135 | def predict(self): 136 | return tf.nn.softmax(self.outputs) 137 | -------------------------------------------------------------------------------- /DSTG/synthetic_data/example_data.RDS: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Su-informatics-lab/DSTG/4a2f958a87ba3137c3ffa75188563af3e0245a5a/DSTG/synthetic_data/example_data.RDS -------------------------------------------------------------------------------- /DSTG/synthetic_data/example_label.RDS: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Su-informatics-lab/DSTG/4a2f958a87ba3137c3ffa75188563af3e0245a5a/DSTG/synthetic_data/example_label.RDS -------------------------------------------------------------------------------- /DSTG/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import pickle as pkl 6 | import tensorflow as tf 7 | from utils import * 8 | from models import DSTG 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | # Set random seed 14 | seed = 123 15 | np.random.seed(seed) 16 | tf.compat.v1.set_random_seed(seed) 17 | tf.set_random_seed(seed) 18 | 19 | # Settings 20 | flags = tf.app.flags 21 | FLAGS = flags.FLAGS 22 | flags.DEFINE_string('dataset', 'Datadir', 'Input data') 23 | flags.DEFINE_string('result', 'DSTG_Result', 'Output result') 24 | flags.DEFINE_string('model', 'DSTG', 'Model string.') 25 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 26 | 
flags.DEFINE_integer('epochs', 200, 'Number of epochs to train.') 27 | flags.DEFINE_integer('hidden1', 32, 'Number of units in hidden layer 1.') 28 | flags.DEFINE_float('dropout', 0, 'Dropout rate (1 - keep probability).') 29 | flags.DEFINE_float('weight_decay', 0, 30 | 'Weight for L2 loss on embedding matrix.') 31 | flags.DEFINE_integer('early_stopping', 10, 32 | 'Tolerance for early stopping (# of epochs).') 33 | # Load data 34 | adj, features, labels_binary_train, labels_binary_val, labels_binary_test, train_mask, pred_mask, val_mask, test_mask, new_label, true_label = load_data( 35 | FLAGS.dataset) 36 | 37 | support = [preprocess_adj(adj)] 38 | num_supports = 1 39 | model_func = DSTG 40 | 41 | # Define placeholders 42 | placeholders = { 43 | 'support': 44 | [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)], 45 | 'features': 46 | tf.sparse_placeholder(tf.float32, 47 | shape=tf.constant(features[2], dtype=tf.int64)), 48 | 'labels': 49 | tf.placeholder(tf.float32, shape=(None, labels_binary_train.shape[1])), 50 | 'labels_mask': 51 | tf.placeholder(tf.int32), 52 | 'dropout': 53 | tf.placeholder_with_default(0., shape=()), 54 | 'num_features_nonzero': 55 | tf.placeholder(tf.int32) # helper variable for sparse dropout 56 | } 57 | 58 | # Create model 59 | model = model_func(placeholders, input_dim=features[2][1], logging=True) 60 | 61 | 62 | # Define model evaluation function 63 | def evaluate(features, support, labels, mask, placeholders): 64 | t_test = time.time() 65 | feed_dict_val = construct_feed_dict(features, support, labels, mask, 66 | placeholders) 67 | outs_val = sess.run([model.loss, model.accuracy], feed_dict=feed_dict_val) 68 | return outs_val[0], outs_val[1], (time.time() - t_test) 69 | 70 | 71 | # Initialize session 72 | sess = tf.Session() 73 | # Init variables 74 | sess.run(tf.global_variables_initializer()) 75 | 76 | train_accuracy = [] 77 | train_loss = [] 78 | val_accuracy = [] 79 | val_loss = [] 80 | test_accuracy = [] 81 | test_loss = [] 82 | 83 | # Train model 84 | for epoch in range(FLAGS.epochs): 85 | t = time.time() 86 | # Construct feed dictionary 87 | feed_dict = construct_feed_dict(features, support, labels_binary_train, 88 | train_mask, placeholders) 89 | feed_dict.update({placeholders['dropout']: FLAGS.dropout}) 90 | # Training step 91 | outs = sess.run([model.opt_op, model.loss, model.accuracy], 92 | feed_dict=feed_dict) 93 | train_accuracy.append(outs[2]) 94 | train_loss.append(outs[1]) 95 | # Validation 96 | cost, acc, duration = evaluate(features, support, labels_binary_val, 97 | val_mask, placeholders) 98 | val_loss.append(cost) 99 | val_accuracy.append(acc) 100 | test_cost, test_acc, test_duration = evaluate(features, support, 101 | labels_binary_test, 102 | test_mask, placeholders) 103 | test_accuracy.append(test_acc) 104 | test_loss.append(test_cost) 105 | print("Epoch:", '%04d' % (epoch + 1), "train_loss=", 106 | "{:.5f}".format(outs[1]), "train_acc=", "{:.5f}".format(outs[2]), 107 | "val_loss=", "{:.5f}".format(cost), "val_acc=", "{:.5f}".format(acc), 108 | "time=", "{:.5f}".format(time.time() - t)) 109 | if epoch > FLAGS.early_stopping and val_loss[-1] > np.mean( 110 | val_loss[-(FLAGS.early_stopping + 1):-1]): 111 | print("Early stopping...") 112 | break 113 | 114 | print("Finished Training....") 115 | 116 | #' --------------- --------------- --------------- 117 | #' all outputs : prediction and activation 118 | #' --------------- --------------- --------------- 119 | 120 | all_mask = np.array([True] * len(train_mask)) 121 | 
labels_binary_all = new_label 122 | 123 | feed_dict_all = construct_feed_dict(features, support, labels_binary_all, 124 | all_mask, placeholders) 125 | feed_dict_all.update({placeholders['dropout']: FLAGS.dropout}) 126 | 127 | activation_output = sess.run(model.activations, feed_dict=feed_dict_all)[1] 128 | predict_output = sess.run(model.outputs, feed_dict=feed_dict_all) 129 | 130 | #' ------- accuracy on prediction masks --------- 131 | ab = sess.run(tf.nn.softmax(predict_output)) 132 | 133 | true_label1 = np.array(true_label) 134 | 135 | result_file1 = '{}/predict_output.csv'.format(FLAGS.result) 136 | result_file2 = '{}/true_output.csv'.format(FLAGS.result) 137 | 138 | np.savetxt(result_file1, ab[pred_mask], delimiter=',') 139 | np.savetxt(result_file2, true_label1[pred_mask], delimiter=',') 140 | -------------------------------------------------------------------------------- /DSTG/utils.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import scipy.sparse 3 | import numpy as np 4 | import pandas as pd 5 | from scipy import sparse as sp 6 | import networkx as nx 7 | from collections import defaultdict 8 | from scipy.stats import uniform 9 | from data import * 10 | 11 | def load_data(datadir): 12 | input_data(datadir) 13 | PIK = "{}/datasets.dat".format(datadir) 14 | with open(PIK, "rb") as f: 15 | objects = pkl.load(f) 16 | 17 | data_train1, data_test1, data_val1, label_train1, label_test1, label_val1, lab_data2, lab_label2 = tuple( 18 | objects) 19 | 20 | train2 = pd.concat([data_train1, lab_data2]) 21 | lab_train2 = pd.concat([label_train1, lab_label2]) 22 | 23 | datas_train = np.array(train2) 24 | datas_test = np.array(data_test1) 25 | datas_val = np.array(data_val1) 26 | labels_train = np.array(lab_train2) 27 | labels_test = np.array(label_test1) 28 | labels_val = np.array(label_val1) 29 | 30 | #' convert pandas data frame to csr_matrix format 31 | datas_tr = scipy.sparse.csr_matrix(datas_train.astype('Float64')) 32 | datas_va = scipy.sparse.csr_matrix(datas_val.astype('Float64')) 33 | datas_te = scipy.sparse.csr_matrix(datas_test.astype('Float64')) 34 | 35 | M = len(data_train1) 36 | 37 | #' 4) get the feature object by combining training, test, valiation sets 38 | features = sp.vstack((sp.vstack((datas_tr, datas_va)), datas_te)).tolil() 39 | features = preprocess_features(features) 40 | 41 | labels_tr = labels_train 42 | labels_va = labels_val 43 | labels_te = labels_test 44 | 45 | labels = np.concatenate( 46 | [np.concatenate([labels_tr, labels_va]), labels_te]) 47 | Labels = pd.DataFrame(labels) 48 | 49 | true_label = Labels 50 | 51 | #' new label with binary values 52 | new_label = labels 53 | idx_train = range(M) 54 | idx_pred = range(M, len(labels_tr)) 55 | idx_val = range(len(labels_tr), len(labels_tr) + len(labels_va)) 56 | idx_test = range( 57 | len(labels_tr) + len(labels_va), 58 | len(labels_tr) + len(labels_va) + len(labels_te)) 59 | 60 | train_mask = sample_mask(idx_train, new_label.shape[0]) 61 | pred_mask = sample_mask(idx_pred, new_label.shape[0]) 62 | val_mask = sample_mask(idx_val, new_label.shape[0]) 63 | test_mask = sample_mask(idx_test, new_label.shape[0]) 64 | 65 | labels_binary_train = np.zeros(new_label.shape) 66 | labels_binary_val = np.zeros(new_label.shape) 67 | labels_binary_test = np.zeros(new_label.shape) 68 | labels_binary_train[train_mask, :] = new_label[train_mask, :] 69 | labels_binary_val[val_mask, :] = new_label[val_mask, :] 70 | labels_binary_test[test_mask, :] = 
new_label[test_mask, :] 71 | 72 | #' construct adjacent matrix 73 | id_graph1 = pd.read_csv('{}/Linked_graph1.csv'.format(datadir), 74 | index_col=0, 75 | sep=',') 76 | 77 | #' map index 78 | fake1 = np.array([-1] * len(lab_data2.index)) 79 | index1 = np.concatenate((data_train1.index, fake1, data_val1.index, 80 | data_test1.index)).flatten() 81 | #' (feature_data.index==index1).all() 82 | fake2 = np.array([-1] * len(data_train1)) 83 | fake3 = np.array([-1] * (len(data_val1) + len(data_test1))) 84 | find1 = np.concatenate((fake2, np.array(lab_data2.index), fake3)).flatten() 85 | 86 | row1 = [np.where(find1 == id_graph1.iloc[i, 1])[0][0] 87 | for i in range(len(id_graph1)) 88 | ] 89 | col1 = [np.where(index1 == id_graph1.iloc[i, 0])[0][0] 90 | for i in range(len(id_graph1)) 91 | ] 92 | adj = defaultdict(list) # default value of int is 0 93 | for i in range(len(labels)): 94 | adj[i].append(i) 95 | for i in range(len(row1)): 96 | adj[row1[i]].append(col1[i]) 97 | adj[col1[i]].append(row1[i]) 98 | 99 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(adj)) 100 | 101 | print("assign input coordinatly....") 102 | return adj, features, labels_binary_train, labels_binary_val, labels_binary_test, train_mask, pred_mask, val_mask, test_mask, new_label, true_label 103 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2020 Qianqian Song 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deconvoluting Spatial Transcriptomics data through Graph-based convolutional networks (DSTG) 2 | 3 | This is a TensorFlow implementation of DSTG for decomposing spatial transcriptomics data, which is described in our paper: 4 | 5 | ## Installation 6 | 7 | ```bash 8 | python setup.py install 9 | ``` 10 | 11 | ## Requirements 12 | * tensorflow (>0.12) 13 | * networkx 14 | 15 | ## Run the demo 16 | 17 | load the example data using the convert_data.R script 18 | In the example data, we provide two synthetic spatial transcriptomics data generated from scRNA-seq data (GSE72056). Each synthetic data consists of 1,000 spots, which can be found in folder synthetic_data. 
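If you want to sanity-check these inputs first, the short sketch below can be run from an R session started in the repository root; it only assumes the two files shipped in `DSTG/synthetic_data/`, and the comments reflect what `Convert_Data()` in `R_utils.R` expects rather than a documented schema.
```r
# Peek at the bundled example inputs (structure as expected by Convert_Data)
count.list <- readRDS('DSTG/synthetic_data/example_data.RDS')   # expected: list of two gene x spot count matrices
label.list <- readRDS('DSTG/synthetic_data/example_label.RDS')  # expected: matching per-spot label/composition tables
str(count.list, max.level = 1)
lapply(count.list, dim)     # genes x spots for each synthetic dataset
head(label.list[[1]])       # ground-truth cell-type information per pseudo-spot
```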
19 | ```bash 20 | cd DSTG 21 | Rscript convert_data.R # load example data 22 | python train.py # run DSTG 23 | ``` 24 | Predicted compositions within each spot are saved in the DSTG_Result folder. 25 | 26 | The JSD performance score will be shown if you run 27 | ``` 28 | Rscript evaluation.R 29 | ``` 30 | If you want to use your own scRNA-seq data to deconvolute your spatial transcriptomics data, provide your data to the script as described below: 31 | 32 | ## Run your own data 33 | When using your own scRNA-seq data to deconvolute your spatial transcriptomics data, you have to provide 34 | * the raw scRNA-seq data matrix and labels, saved in .RDS format (e.g. 'scRNAseq_data.RDS' & 'scRNAseq_label.RDS') 35 | * the raw spatial transcriptomics data matrix, saved in .RDS format (e.g. 'spatial_data.RDS') 36 | A minimal sketch of one way to prepare these .RDS files is appended at the very end of this document. 37 | ``` 38 | cd DSTG 39 | Rscript convert_data.R scRNAseq_data.RDS spatial_data.RDS scRNAseq_label.RDS 40 | python train.py # run DSTG 41 | ``` 42 | Then you will get your results in the DSTG_Result folder. 43 | 44 | 45 | ## Cite 46 | 47 | Please cite our paper if you use this code in your own work: 48 | 49 | ``` 50 | Qianqian Song, Jing Su, DSTG: deconvoluting spatial transcriptomics data through graph-based artificial intelligence, Briefings in Bioinformatics, 2021, bbaa414, https://doi.org/10.1093/bib/bbaa414 51 | ``` 52 | -------------------------------------------------------------------------------- /requirments.txt: -------------------------------------------------------------------------------- 1 | networkx==2.2 2 | scipy==1.1.0 3 | setuptools==40.6.3 4 | numpy==1.15.4 5 | tensorflow==1.15.2 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="DSTG", 8 | version="0.0.1", 9 | author="QSong", 10 | author_email="wasqqdyx@gmail.com", 11 | description= 12 | "Deconvoluting Spatial Transcriptomics data through Graph-based convolutional networks (DSTG)", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/Su-informatics-lab/DSTG", 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | python_requires='>=3.7', 23 | ) 24 | --------------------------------------------------------------------------------
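As a companion to the "Run your own data" section of the README above, here is a minimal sketch of one way the three input .RDS files could be prepared. The object names (`my_scRNAseq_counts`, `my_cell_annotations`, `my_spatial_counts`) and the `celltype` column are hypothetical placeholders for your own data; the only assumptions taken from `convert_data.R` are that it expects a gene x cell scRNA-seq count matrix, a gene x spot spatial count matrix, and a one-column cell-type label table aligned with the scRNA-seq cells.

```r
# Hypothetical preparation of DSTG inputs (object names are placeholders, not part of DSTG)
sc_mat  <- as.matrix(my_scRNAseq_counts)   # genes x cells, raw counts
st_mat  <- as.matrix(my_spatial_counts)    # genes x spots, raw counts
sc_meta <- data.frame(celltype = my_cell_annotations,   # one label per column of sc_mat
                      row.names = colnames(sc_mat),
                      stringsAsFactors = FALSE)
saveRDS(sc_mat,  'scRNAseq_data.RDS')
saveRDS(st_mat,  'spatial_data.RDS')
saveRDS(sc_meta, 'scRNAseq_label.RDS')
# then, from the DSTG/ folder:
#   Rscript convert_data.R scRNAseq_data.RDS spatial_data.RDS scRNAseq_label.RDS
#   python train.py
```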