├── LICENSE
├── README.md
├── STARCH.py
├── hgTables_hg19.txt
└── run_STARCH.py


/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 raphael-group
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # STARCH
2 | Spatial Transcriptomics Algorithm Reconstructing Copy-number Heterogeneity
3 | 
4 | ## Required Python Libraries
5 | numpy, pandas, scipy, scikit-learn, hmmlearn (multiprocessing is part of the Python standard library)
6 | 
7 | ## Required input
8 | 1. Gene x Spot ST RNA-seq expression matrix. This can be a text file (.csv, .tsv, or .txt) where the first column contains the gene names in HUGO format and the first row contains the spot coordinates in the format '1x2'. STARCH can also read the output of 10X's Space Ranger pipeline; in this case, the input is the path to a directory containing the following files: barcodes.tsv, features.tsv, matrix.mtx, and tissue_positions_list.csv
9 | 
10 | 2. Gene mapping file which maps each HUGO gene name to its chromosomal position. The mapping file for human assembly hg19 is provided by default (hgTables_hg19.txt). For other organisms, an hgTables file can be downloaded from https://genome.ucsc.edu/cgi-bin/hgTables (when generating the hgTables file, select group = Genes and Gene predictions, track = all GENCODE V33, region = Genome).
11 | 
12 | ## Running from Command Line
13 | 
14 | python run_STARCH.py -i gene_expression_matrix.csv (or 10X_directory/) --output name --n_clusters 3 --outdir output/directory/
15 | 
16 | ## Output
17 | 
18 | run_STARCH.py writes two files: (1) the gene x clone CNV matrix, saved to outdir/states_&lt;output&gt;.csv, and (2) a spot label vector assigning each spot to one of n_clusters clones, saved to outdir/labels_&lt;output&gt;.csv. For the command above, these are output/directory/states_name.csv and output/directory/labels_name.csv.
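19 | 
20 | A minimal sketch of loading these results with pandas (the paths follow from the command above):
21 | 
22 | ```python
23 | import pandas as pd
24 | 
25 | # gene x clone CNV states: 0 = deletion, 1 = neutral, 2 = amplification
26 | states = pd.read_csv('output/directory/states_name.csv', index_col=0)
27 | 
28 | # clone assignment per spot; the index holds the spot coordinates (e.g. '5x18')
29 | labels = pd.read_csv('output/directory/labels_name.csv', index_col=0)
30 | 
31 | # e.g. select the spots assigned to clone 0
32 | clone0_spots = labels.index[labels.iloc[:, 0] == 0]
33 | ```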
34 | 
--------------------------------------------------------------------------------
/STARCH.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import argparse
4 | import logging
5 | logging.basicConfig(level=logging.INFO)
6 | logger = logging.getLogger()
7 | import math
8 | import copy
9 | import sklearn
10 | import sklearn.cluster
11 | import random
12 | from sklearn.cluster import KMeans
13 | from sklearn.mixture import GaussianMixture
14 | from sklearn.metrics import silhouette_score, davies_bouldin_score, v_measure_score
15 | from sklearn.preprocessing import MinMaxScaler
16 | from sklearn.cluster import AgglomerativeClustering
17 | from sklearn.decomposition import NMF
18 | from sklearn.decomposition import PCA
19 | import multiprocessing as mp
20 | from functools import partial
21 | from scipy.spatial import distance
22 | import os
23 | from scipy.stats import norm
24 | from scipy.stats import multivariate_normal
25 | from scipy.stats import ttest_ind
26 | from scipy.stats import ks_2samp
27 | from hmmlearn import hmm
28 | from scipy.io import mmread
29 | from scipy.sparse import csr_matrix
30 | import multiprocessing
31 | import warnings
32 | from pathlib import Path
33 | os.environ['NUMEXPR_MAX_THREADS'] = '50'
34 | 
35 | 
36 | def jointLikelihoodEnergyLabels_helper(label,data,states,norms): # negative log-likelihood of each spot under one clone's CNV states
37 |     e = 1e-50 # pseudocount to avoid taking log(0)
38 |     r0 = [x for x in range(data.shape[0]) if states[x,label]==0] # bins called as deletion in this clone
39 |     l0 = np.sum(-np.log(np.asarray(norms[0].pdf(data[r0,:])+e)),axis=0)
40 |     r1 = [x for x in range(data.shape[0]) if states[x,label]==1] # bins called as neutral
41 |     l1 = np.sum(-np.log(np.asarray(norms[1].pdf(data[r1,:])+e)),axis=0)
42 |     r2 = [x for x in range(data.shape[0]) if states[x,label]==2] # bins called as amplification
43 |     l2 = np.sum(-np.log(np.asarray(norms[2].pdf(data[r2,:])+e)),axis=0)
44 |     return l0 + l1 + l2
45 | 
46 | def init_helper(i,data, n_clusters,normal,diff,labels,c): # initial CNV state of gene i in each clone via KS test against the normal spots
47 |     l = []
48 |     for k in range(n_clusters):
49 |         pval = ks_2samp(data[i,labels==k],normal[i,:])[1]
50 |         mn = np.mean(normal[i,:])
51 |         if c[i,k]< mn and pval <= diff: # significantly below the normal mean: deletion
52 |             l.append(0)
53 |         elif c[i,k]> mn and pval <= diff: # significantly above the normal mean: amplification
54 |             l.append(2)
55 |         else:
56 |             l.append(1)
57 |     return np.asarray(l).astype(int)
58 | 
59 | def HMM_helper(inds, data, means, sigmas ,t, num_states, model,normal): # Viterbi decoding of CNV states for one (chromosome, clone) block
60 |     ind_bin,ind_spot,k = inds
61 |     data = data[np.asarray(ind_bin)[:, None],np.asarray(ind_spot)]
62 |     data2 = np.mean(data,axis=1)
63 |     X = np.asarray([[x] for x in data2])
64 |     C = np.asarray(model.predict(X))
65 |     score = model.score(X)
66 |     # bootstrap: re-decode on random 80% subsets of spots and reset unstable calls to neutral
67 |     b=3
68 |     for i in range(b):
69 |         inds = random.sample(range(data.shape[1]),int(data.shape[1]*.8+1))
70 |         data2 = np.mean(data[:,inds],axis=1)
71 |         X = np.asarray([[x] for x in data2])
72 |         C2 = np.asarray(model.predict(X))
73 |         for j,c in enumerate(C2):
74 |             if C[j] != c:
75 |                 C[j] = 1
76 |     return [C,score]
77 | 
78 | class STARCH:
79 |     """
80 |     A hidden Markov random field (HMRF) model for calling copy number aberrations
81 |     using spatial relationships and gene adjacencies along chromosomes
82 |     """
83 | 
84 |     def __init__(self,data,normal_spots=[],labels=[],beta_spots=2,n_clusters=3,num_states=3,gene_mapping_file_name='hgTables_hg19.txt',nthreads=0,platform="ST"):
85 |         """
86 |         The constructor for STARCH
87 | 
88 |         Parameters:
89 |         data (pandas data frame): gene x spot (or cell).
90 |             colnames = 2d or 3d indices (e.g. 5x18, or 5x18x2 if multiple layers).
91 |             rownames = HUGO gene name
92 |         """
93 |         assert( platform == "ST" or platform == "Visium" )
94 |         self.platform = platform
95 |         logger.info("platform is {}".format(self.platform))
96 |         if nthreads == 0:
97 |             nthreads = int(multiprocessing.cpu_count() / 2 + 1)
98 |         logger.info('Running with ' + str(nthreads) + ' threads')
99 |         logger.info("initializing HMRF...")
100 |         self.beta_spots = beta_spots
101 |         self.gene_mapping_file_name = gene_mapping_file_name
102 |         self.n_clusters = int(n_clusters)
103 |         dat,data = self.preload(data)
104 |         dat,data = self.filter_spots(dat,data)
105 |         logger.info(str(self.rows[0:20]))
106 |         logger.info(str(len(self.rows)) + ' ' + str(len(self.columns)) + ' ' + str(data.shape))
107 |         if isinstance(normal_spots, str): # normal spots supplied as a file of indices
108 |             self.read_normal_spots(normal_spots)
109 |         elif len(normal_spots) == 0: # no normal spots supplied: infer them by clustering
110 |             self.get_normal_spots(data)
111 |         else:
112 |             self.normal_spots = np.asarray([int(x) for x in normal_spots])
113 |         logger.info('normal spots ' + str(len(self.normal_spots)))
114 |         dat = self.preprocess_data(data,dat)
115 |         logger.info('done preprocessing...')
116 |         self.data = self.data * 1000
117 |         self.bins = self.data.shape[0]
118 |         self.spots = self.data.shape[1]
119 |         self.tumor_spots = np.asarray([int(x) for x in range(self.spots) if int(x) not in self.normal_spots])
120 |         self.normal = self.data[:,self.normal_spots]
121 |         self.data = self.data[:,self.tumor_spots]
122 |         self.bins = self.data.shape[0]
123 |         self.spots = self.data.shape[1]
124 |         self.num_states = int(num_states)
125 |         self.normal_state = int((self.num_states-1)/2)
126 | 
127 |         logger.info('getting spot network...')
128 |         self.get_spot_network(self.data,self.columns[self.tumor_spots])
129 |         if isinstance(labels, str): # labels supplied as a file
130 |             self.get_labels(labels)
131 |         elif len(labels)>0:
132 |             self.labels = labels
133 |         else:
134 |             logger.info('initializing labels...')
135 |             self.initialize_labels()
136 |         logger.debug('starting labels: '+str(self.labels))
137 |         np.fill_diagonal(self.spot_network, 0)
138 |         logger.info('getting params...')
139 |         count_valueerror = 0
140 |         for d in range(10,20,1): # retry with a weaker KS-test threshold until the initial fit is non-degenerate
141 |             try:
142 |                 self.init_params(d/10,nthreads)
143 |                 break
144 |             except ValueError:
145 |                 count_valueerror += 1
146 |                 continue
147 |         logger.info('Count of ValueError in init_params is {}'.format(count_valueerror))
148 |         self.states = np.zeros((self.bins,self.n_clusters))
149 |         logger.info('starting means: '+str(self.means))
150 |         logger.info('starting cov: '+str(self.sigmas))
151 |         logger.info(str(len(self.rows)) + ' ' + str(len(self.columns)) + ' ' + str(self.data.shape))
152 | 
153 | 
154 |     def to_transpose(self,sep,data): # the matrix is transposed if the rownames (not colnames) look like spot coordinates
155 |         dat = pd.read_csv(data,sep=sep,header=0,index_col=0)
156 |         if 'x' in dat.index.values[0] and 'x' in dat.index.values[1] and 'x' in dat.index.values[2]:
157 |             return True
158 |         return False
159 | 
160 |     def which_sep(self,data): # pick the delimiter that parses the most fields
161 |         dat = np.asarray(pd.read_csv(data,sep='\t',header=0,index_col=0)).size
162 |         dat2 = np.asarray(pd.read_csv(data,sep=',',header=0,index_col=0)).size
163 |         dat3 = np.asarray(pd.read_csv(data,sep=' ',header=0,index_col=0)).size
164 |         if dat > dat2 and dat > dat3:
165 |             return '\t'
166 |         elif dat2 > dat and dat2 > dat3:
167 |             return ','
168 |         else:
169 |             return ' '
170 | 
171 |     def get_bin_size(self,data,chroms):
172 |         for bin_size in range(20,100):
173 |             test = self.bin_data2(data[:,self.normal_spots],chroms,bin_size=bin_size,step_size=1)
174 |             test = test[test!=0]
175 |             logger.debug(str(bin_size)+' mean expression binned ' + str(np.mean(test)))
176 |             logger.debug(str(bin_size)+' median expression binned ' + str(np.median(test)))
177 |             if np.median(test) >= 10: # smallest bin size whose median nonzero binned expression reaches 10
178 |                 break
179 |         logger.info('selected bin size: ' + str(bin_size))
180 |         return bin_size
181 | 
182 |     def preload(self,l):
183 |         if isinstance(l,list): # list of multiple datasets
184 |             offset = 0
185 |             dats = []
186 |             datas = []
187 |             for data in l:
188 |                 dat,data = self.load(data)
189 |                 datas.append(data)
190 |                 dats.append(dat)
191 |             conserved_genes = []
192 |             inds = []
193 |             for dat in dats:
194 |                 inds.append([])
195 |             for gene in dats[0].index.values: # keep only genes present in every dataset
196 |                 inall = True
197 |                 for dat in dats:
198 |                     if gene not in dat.index.values:
199 |                         inall = False
200 |                 if inall:
201 |                     conserved_genes.append(gene)
202 |                     for i,dat in enumerate(dats):
203 |                         ind = inds[i]
204 |                         ind.append(np.where(dat.index.values == gene)[0][0])
205 |                         inds[i] = ind
206 |             conserved_genes = np.asarray(conserved_genes)
207 |             logger.info(str(conserved_genes))
208 |             newdatas = []
209 |             newdats = []
210 |             for i in range(len(datas)):
211 |                 data = datas[i]
212 |                 dat = dats[i]
213 |                 ind = np.asarray(inds[i])
214 |                 newdatas.append(data[ind,:])
215 |                 newdats.append(dat.iloc[ind,:])
216 |             for dat in newdats:
217 |                 spots = np.asarray([[float(y) for y in x.split('x')] for x in dat.columns.values])
218 |                 for spot in spots:
219 |                     spot[0] += offset
220 |                 spots = ['x'.join([str(y) for y in x]) for x in spots]
221 |                 dat.columns = spots
222 |                 offset += 100 # shift x coordinates so spots from different datasets cannot collide
223 |             data = np.concatenate(newdatas,axis=1)
224 |             dat = pd.concat(newdats,axis=1)
225 |             self.rows = dat.index.values
226 |             self.columns = dat.columns.values
227 |         else:
228 |             dat,data = self.load(l)
229 |         return dat,data
230 | 
231 |     def load(self,data):
232 |         try:
233 |             if isinstance(data, str) and ('.csv' in data or '.tsv' in data or '.txt' in data):
234 |                 logger.info('Reading data...')
235 |                 sep = self.which_sep(data)
236 |                 if self.to_transpose(sep,data):
237 |                     dat = pd.read_csv(data,sep=sep,header=0,index_col=0).T
238 |                 else:
239 |                     dat = pd.read_csv(data,sep=sep,header=0,index_col=0)
240 |             elif isinstance(data,str):
241 |                 logger.info('Importing 10X data from directory. Directory must contain barcodes.tsv, features.tsv, matrix.mtx, tissue_positions_list.csv')
242 |                 # find the barcodes file from 10X directory
243 |                 file_barcodes = [str(x) for x in Path(data).rglob("*barcodes.tsv*")]
244 |                 if len(file_barcodes) == 0:
245 |                     logger.error('There is no barcodes.tsv file in the 10X directory.')
246 |                 file_barcodes = file_barcodes[0]
247 |                 barcodes = np.asarray(pd.read_csv(file_barcodes,header=None)).flatten()
248 |                 # find the features file from 10X directory
249 |                 file_features = [str(x) for x in Path(data).rglob("*features.tsv*")]
250 |                 if len(file_features) == 0:
251 |                     logger.error('There is no features.tsv file in the 10X directory.')
252 |                 file_features = file_features[0]
253 |                 genes = np.asarray(pd.read_csv(file_features,sep='\t',header=None))
254 |                 genes = genes[:,1] # use gene symbols, not Ensembl ids
255 |                 # find the tissue_positions_list file from 10X directory
256 |                 file_coords = [str(x) for x in Path(data).rglob("*tissue_positions_list.csv*")]
257 |                 if len(file_coords) == 0:
258 |                     logger.error('There is no tissue_positions_list.csv file in the 10X directory.')
259 |                 file_coords = file_coords[0]
260 |                 coords = np.asarray(pd.read_csv(file_coords,sep=',',header=None))
261 |                 d = dict()
262 |                 for row in coords:
263 |                     d[row[0]] = str(row[2]) + 'x' + str(row[3]) # barcode -> 'rowxcol' spot coordinate
264 |                 inds = []
265 |                 coords2 = []
266 |                 for i,barcode in enumerate(barcodes): # keep only barcodes with tissue coordinates
267 |                     if barcode in d.keys():
268 |                         inds.append(i)
269 |                         coords2.append(d[barcode])
270 |                 # find the count matrix file
271 |                 file_matrix = [str(x) for x in Path(data).rglob("*matrix.mtx*")]
272 |                 if len(file_matrix) == 0:
273 |                     logger.error('There is no matrix.mtx file in the 10X directory.')
274 |                 file_matrix = file_matrix[0]
275 |                 matrix = mmread(file_matrix).toarray()
276 |                 logger.info(str(barcodes) + ' ' + str(barcodes.shape))
277 |                 logger.info(str(genes) + ' ' + str(genes.shape))
278 |                 logger.info(str(coords) + ' ' + str(coords.shape))
279 |                 logger.info(str(matrix.shape))
280 | 
281 |                 matrix = matrix[:,inds]
282 |                 genes,inds2 = np.unique(genes, return_index=True) # drop duplicate gene names
283 |                 matrix = matrix[inds2,:]
284 |                 dat = pd.DataFrame(matrix,index = genes,columns = coords2)
285 | 
286 |                 logger.info(str(dat))
287 |             else:
288 |                 dat = pd.DataFrame(data)
289 |         except:
290 |             raise Exception("Incorrect input format")
291 |         logger.info('coords ' + str(len(dat.columns.values)))
292 |         logger.info('genes ' + str(len(dat.index.values)))
293 |         data = dat.values
294 |         logger.info(str(data.shape))
295 |         self.rows = dat.index.values
296 |         self.columns = dat.columns.values
297 |         return(dat,data)
298 | 
299 |     def filter_spots(self, dat, data, min_umi_perspot=10):
300 |         tmpdata, inds = self.filter_genes(data,min_cells=int(data.shape[1]/20))
301 |         idx_spots = np.where(np.sum(tmpdata, axis=0) > min_umi_perspot)[0]
302 |         data = data[:, idx_spots]
303 |         dat = dat.iloc[:, idx_spots]
304 |         self.rows = dat.index.values
305 |         self.columns = dat.columns.values
306 |         logger.info('Filtered spots, now have ' + str(data.shape[1]) + ' spots')
307 |         return dat, data
308 | 
309 |     def preprocess_data(self,data,dat):
310 |         logger.info('data shape ' + str(data.shape))
311 |         data,inds = self.filter_genes(data,min_cells=int(data.shape[1]/20))
312 |         logger.info('Filtered genes, now have ' + str(data.shape[0]) + ' genes')
313 |         data[data>np.mean(data)+np.std(data)*2]=np.mean(data)+np.std(data)*2 # clip extreme counts at mean + 2 std
314 |         dat = dat.T[dat.index.values[inds]].T
315 |         self.rows = dat.index.values
316 |         self.columns = dat.columns.values
317 |         logger.info('filter ' + str(len(self.rows)) + ' ' + str(len(self.columns)) + ' ' + str(data.shape))
318 |         data,chroms,pos,inds = self.order_genes_by_position(data,dat.index.values)
319 |         dat = dat.T[dat.index.values[inds]].T
320 |         self.rows = dat.index.values
321 |         self.columns = dat.columns.values
322 |         logger.info('order ' + str(len(self.rows)) + ' ' + str(len(self.columns)) + ' ' + str(data.shape))
323 |         logger.info('zero percentage ' + str((data.size - np.count_nonzero(data)) / data.size))
324 | 
325 |         bin_size = self.get_bin_size(data,chroms)
326 | 
327 |         data = np.log(data+1)
328 |         data = self.library_size_normalize(data) # normalize each spot to the median library size
329 |         data = data-np.mean(data[:,self.normal_spots],axis=1).reshape(data.shape[0],1) # subtract the mean normal expression per gene
330 |         data = self.threshold_data(data,max_value=3.0)
331 |         data = self.bin_data(data,chroms,bin_size=bin_size,step_size=1)
332 |         data = self.center_at_zero(data)
333 |         data = data-np.mean(data[:,self.normal_spots],axis=1).reshape(data.shape[0],1)
334 |         data = np.exp(data)-1 # back to (roughly) linear scale
335 |         self.data = data
336 |         self.pos = np.asarray([str(x) for x in pos])
337 |         logger.info('preprocess ' + str(len(self.rows)) + ' ' + str(len(self.columns)) + ' ' + str(data.shape))
338 |         return(dat)
339 | 
340 |     def read_normal_spots(self,normal_spots):
341 |         normal_spots = pd.read_csv(normal_spots,sep=',',header=None)
342 |         self.normal_spots = np.asarray([int(x) for x in np.asarray(normal_spots).flatten()])
343 | 
344 |     def get_normal_spots(self,data):
345 |         data,k = self.filter_genes(data,min_cells=int(data.shape[1]/20))
346 |         data = self.library_size_normalize(data)
347 |         data = np.log(data+1)
348 |         data = self.threshold_data(data,max_value=3.0)
349 |         pca = PCA(n_components=1).fit_transform(data.T)
350 |         km = KMeans(n_clusters=2).fit(pca)
351 |         clusters = np.asarray(km.predict(pca))
352 |         if np.mean(data[:,clusters==0]) < np.mean(data[:,clusters==1]): # the cluster with lower mean expression is taken as normal
353 |             self.normal_spots = np.asarray([x for x in range(data.shape[1])])[clusters==0]
354 |         else:
355 |             self.normal_spots = np.asarray([x for x in range(data.shape[1])])[clusters==1]
356 | 
357 |     def filter_genes(self,data,min_cells=20): # keep genes expressed in at least min_cells spots
358 |         keep = []
359 |         for gene in range(data.shape[0]):
360 |             if np.count_nonzero(data[gene,:]) >= min_cells:
361 |                 keep.append(gene)
362 |         return data[np.asarray(keep),:],np.asarray(keep)
363 | 
364 |     def library_size_normalize(self,data):
365 |         m = np.median(np.sum(data,axis=0))
366 |         data = data / np.sum(data,axis=0)
367 |         data = data * m
368 |         return data
369 | 
370 |     def threshold_data(self,data,max_value=4.0):
371 |         data[data> max_value] = max_value
372 |         data[data< -max_value] = -max_value
373 |         return data
374 | 
375 |     def center_at_zero(self,data):
376 |         return data - np.median(data,axis=0).reshape(1,data.shape[1])
377 | 
378 | 
379 |     def bin_data2(self,data,chroms,bin_size,step_size): # unweighted binning, used only to select the bin size
380 |         newdata = copy.deepcopy(data)
381 |         i=0
382 |         c = np.asarray(list(set(chroms)))
383 |         c.sort()
384 |         for chrom in c:
385 |             data2 = data[chroms==chrom,:]
386 |             for gene in range(data2.shape[0]):
387 |                 start = max(0,gene-int(bin_size/2))
388 |                 end = min(data2.shape[0],gene+int(bin_size/2))
389 |                 r = np.asarray([x for x in range(start,end)])
390 |                 mean = np.sum(data2[r,:],axis=0)
391 |                 newdata[i,:] = mean
392 |                 i += 1
393 |         return newdata
394 | 
395 | 
396 |     def bin_data(self,data,chroms,bin_size,step_size): # sliding-window binning along each chromosome
397 |         newdata = copy.deepcopy(data)
398 |         i=0
399 |         c = np.asarray(list(set(chroms)))
400 |         c.sort()
401 |         for chrom in c:
402 |             data2 = data[chroms==chrom,:]
403 |             for gene in range(data2.shape[0]):
404 |                 start = max(0,gene-int(bin_size/2))
405 |                 end = min(data2.shape[0],gene+int(bin_size/2))
406 |                 r = np.asarray([x for x in range(start,end)])
407 |                 weighting = np.asarray([x+1 for x in range(len(r))])
408 |                 weighting = abs(weighting - len(weighting)/2) # distance from the window center
409 |                 weighting = 1/(weighting+1)
410 |                 weighting = weighting / sum(weighting) # pyramidal weighting
411 |                 weighting = weighting.reshape(len(r),1)
412 |                 mean = np.sum(data2[r,:]*weighting,axis=0)
413 |                 newdata[i,:] = mean
414 |                 i += 1
415 |         return newdata
416 | 
417 | 
418 |     def order_genes_by_position(self,data,genes):
419 |         mapping = pd.read_csv(self.gene_mapping_file_name,sep='\t')
420 |         names = mapping['name2']
421 |         chroms = mapping['chrom']
422 |         starts = mapping['cdsStart']
423 |         ends = mapping['cdsEnd']
424 |         d = dict()
425 |         d2 = dict()
426 |         for i,gene in enumerate(names):
427 |             try:
428 |                 if int(chroms[i][3:]) > 0: # numbered chromosomes only; chrX/chrY raise ValueError and are skipped
429 |                     d[gene.upper()] = int(int(chroms[i][3:])*1e10 + int(starts[i])) # encode (chromosome, start) as one sortable integer
430 |                     d2[gene.upper()] = str(chroms[i][3:]) + ':' + str(starts[i])
431 |             except:
432 |                 pass
433 |         positions = []
434 |         posnames = []
435 |         for gene in genes:
436 |             gene = gene.upper()
437 |             if gene in d.keys():
438 |                 positions.append(d[gene])
439 |                 posnames.append(d2[gene])
440 |             else:
441 |                 positions.append(-1)
442 |                 posnames.append(-1)
443 |         positions = np.asarray(positions)
444 |         posnames = np.asarray(posnames)
445 |         l = len(positions[positions==-1]) # number of genes without a mapped position
446 |         order = np.argsort(positions)
447 |         order = order[l:] # drop the unmapped genes
448 |         positions = positions[order]/1e10
449 |         posnames = posnames[order]
450 |         return data[order,:],positions.astype('int'),posnames,order
451 | 
452 |     def get_labels(self,labels):
453 |         labels = np.asarray(pd.read_csv(labels,sep=','))
454 |         self.labels = labels
455 | 
456 |     def init_params(self,d=1.3,nthreads=1):
457 |         c = np.zeros((self.data.shape[0],self.n_clusters))
458 |         for i in range(self.data.shape[0]):
459 |             for k in range(self.n_clusters):
460 |                 c[i,k] = np.mean(self.data[i,self.labels==k]) # mean expression of gene i in clone k
461 |         labels = np.zeros((self.data.shape[0],self.n_clusters))
462 |         diffs = []
463 |         for i in range(0,self.data.shape[0],10): # calibrate the KS-test p-value threshold on shifted normal data
464 |             diffs.append(ks_2samp(self.normal[i,:]+np.std(self.normal[i,:])/d,self.normal[i,:])[1])
465 |         diff = np.mean(diffs)
466 |         logger.info(str(diff))
467 | 
468 |         pool = mp.Pool(nthreads)
469 |         results = pool.map(partial(init_helper, data=self.data, n_clusters=self.n_clusters,normal=self.normal,diff=diff,labels=self.labels,c=c), [x for x in range(self.data.shape[0])])
470 |         for i in range(len(results)):
471 |             labels[i,:] = results[i]
472 |         labels = labels.astype(int)
473 |         with warnings.catch_warnings():
474 |             warnings.simplefilter("ignore", category=RuntimeWarning)
475 |             means = [np.mean(c[labels==cluster]) for cluster in range(self.num_states)]
476 |             sigmas = [np.std(c[labels==cluster]) for cluster in range(self.num_states)]
477 |         indices = np.argsort([x for x in means]) # order states by mean so that 0 = deletion, 1 = neutral, 2 = amplification
478 |         states = copy.deepcopy(labels)
479 |         m = np.zeros((3,1))
480 |         s = np.zeros((3,1))
481 |         i=0
482 |         for index in indices:
483 |             states[labels==index]=i # set states
484 |             mean = means[index]
485 |             sigma = sigmas[index]
486 |             if np.isnan(mean) or np.isnan(sigma) or sigma < .01: # degenerate fit; the caller retries with a different d
487 |                 raise ValueError()
488 |             m[i] = [mean]
489 |             s[i] = [sigma**2]
490 |             i+=1
491 |         self.means = m
492 |         self.sigmas = s
493 | 
494 |     def init_params2(self): # re-estimate state means/variances from the current states and labels
495 |         means = [[],[],[]]
496 |         sigmas = [[],[],[]]
497 |         for s in range(self.num_states):
498 |             d=[]
499 |             for cluster in range(self.n_clusters):
500 |                 dat = np.asarray(list(self.data[:,self.labels==cluster]))
501 |                 d += list(dat[np.asarray(list(self.states[:,cluster].astype(int)==int(s)))].flatten())
502 |             means[s] = [np.mean(d)]
503 |             sigmas[s] = [np.std(d)**2]
504 |         logger.info(str(means))
505 |         self.means = np.asarray(means)
506 |         self.sigmas = np.asarray(sigmas)
507 | 
508 |     def initialize_labels(self): # initialize clone labels by k-means clustering of the spots
509 |         dat=self.data
510 | 
511 |         km = KMeans(n_clusters=self.n_clusters).fit(dat.T)
512 |         clusters = np.asarray(km.predict(dat.T))
513 |         self.labels = clusters
514 | 
515 |     def get_spot_network(self,data,spots,l=1):
516 |         spots = np.asarray([[float(y) for y in x.split('x')] for x in spots])
517 |         if self.platform == "Visium":
518 |             logger.info("Using Visium platform layout.")
519 |             # scale row and col coordinate to make them a regular hexagon with the adjacent hexagon center distance = 1
520 |             scale_row = np.sqrt(3) / 2
521 |             scale_col = 1.0 / 2
522 |             spots[:,0] = spots[:,0] * scale_row
523 |             spots[:,1] = spots[:,1] * scale_col
524 |         spot_network = np.zeros((len(spots),len(spots)))
525 |         for i in range(len(spots)):
526 |             for j in range(i,len(spots)):
527 |                 dist = distance.euclidean(spots[i],spots[j])
528 |                 spot_network[i,j] = np.exp(-dist/(l)) # exponential covariance
529 |                 spot_network[j,i] = spot_network[i,j]
530 |         self.spot_network = spot_network
531 | 
532 | 
533 |     def get_gene_network(self,data,genes,l=1):
534 |         genes = np.asarray(genes)
535 |         gene_network = np.zeros((len(genes),len(genes)))
536 |         for i in range(len(genes)):
537 |             for j in range(i,len(genes)):
538 |                 dist = j-i
539 |                 gene_network[i,j] = np.exp(-dist/(l)) # exponential covariance
540 |                 gene_network[j,i] = gene_network[i,j]
541 |         return gene_network
542 | 
543 |     def _optimalK(self,data, maxClusters=15): # choose the number of clusters by silhouette score
544 |         X_scaled = data
545 |         km_scores= []
546 |         km_silhouette = []
547 |         db_score = []
548 |         for i in range(2,maxClusters):
549 |             km = KMeans(n_clusters=i).fit(X_scaled)
550 |             preds = km.predict(X_scaled)
551 |             db_score.append(davies_bouldin_score(X_scaled,preds))
552 |             silhouette = silhouette_score(X_scaled,preds)
553 |             km_silhouette.append(silhouette)
554 |             logger.info("Silhouette score for number of cluster(s) {}: {}".format(i,silhouette))
555 | 
556 |         best_silhouette = np.argmax(km_silhouette)+2
557 |         best_db = np.argmin(db_score)+2
558 |         logger.info('silhouette: ' + str(best_silhouette))
559 |         return(int(best_silhouette))
560 | 
561 | 
562 |     def HMM_estimate_states_parallel(self,t,maxiters=100,deltoamp=0,nthreads=1): # decode CNV states per (chromosome, clone) in parallel
563 |         n_clusters = self.n_clusters
564 |         self.EnergyPriors = np.zeros((self.data.shape[0],n_clusters,self.num_states))
565 |         self.t = t
566 |         chromosomes = [int(x.split(':')[0]) for x in self.pos]
567 |         inds = []
568 |         n_clusters = self.n_clusters
569 |         if len(set(self.labels)) != self.n_clusters: # some clones vanished; relabel to contiguous ids
570 |             labels = copy.deepcopy(self.labels)
571 |             i=0
572 |             for label in set(self.labels):
573 |                 labels[self.labels==label]=i
574 |                 i=i+1
575 |             self.labels = labels
576 |             self.n_clusters = len(set(self.labels))
577 |         for chrom in set(chromosomes):
578 |             for k in range(self.n_clusters):
579 |                 inds.append([np.asarray([i for i in range(len(chromosomes)) if chromosomes[i] == chrom]),np.asarray([i for i in range(len(self.labels)) if self.labels[i]==k]),k])
580 |         pool = mp.Pool(nthreads)
581 |         results = pool.map(partial(HMM_helper, data=self.data, means = self.means, sigmas = self.sigmas,t = self.t,num_states = self.num_states,model=self.model,normal=self.normal), inds)
582 |         score = 0
583 |         for i in range(len(results)):
584 |             self.states[inds[i][0][:, None],inds[i][2]] = results[i][0].reshape((len(results[i][0]),1))
585 |             score += results[i][1]
586 |         return score
587 | 
588 | 
589 |     def jointLikelihoodEnergyLabels(self,norms,pool):
590 |         Z = (2*math.pi)**(self.num_states/2)
591 |         n_clusters = self.n_clusters
592 |         likelihoods = np.zeros((self.data.shape[1],n_clusters))
593 |         results = pool.map(partial(jointLikelihoodEnergyLabels_helper, data=self.data, states=self.states,norms=norms), range(n_clusters))
594 |         for label in range(n_clusters):
595 |             likelihoods[:,label] += results[label]
596 |         likelihoods = likelihoods / self.data.shape[0]
597 |         likelihood_energies = likelihoods
598 |         return(likelihood_energies)
599 | 
600 |     def jointLikelihoodEnergyLabelsapprox(self,means): # cheaper L1 approximation of the likelihood energies
601 |         e = 1e-20
602 |         n_clusters = self.n_clusters
603 |         likelihoods = np.zeros((self.data.shape[1],n_clusters))
604 |         for spot in range(self.spots):
605 |             ml=np.inf
606 |             for label in range(n_clusters):
607 |                 likelihood = np.sum(abs(self.data[:,spot]-means[:,label]))/self.data.shape[0]
608 |                 if likelihood < ml:
609 |                     ml = likelihood
610 |                 likelihoods[spot,label] = likelihood
611 |             likelihoods[spot,:]-=ml # subtract the per-spot minimum
612 |         likelihood_energies = likelihoods
613 |         return(likelihood_energies)
614 | 
615 |     def MAP_estimate_labels(self,beta_spots,nthreads,maxiters=20): # iterative updates of clone labels under the HMRF prior
616 |         inds_spot = []
617 |         tmp_spot = []
618 |         n_clusters = self.n_clusters
619 |         prev_labels = copy.deepcopy(self.labels)
620 |         for j in range(self.spots):
621 |             inds_spot.append(np.where(self.spot_network[j,:] >= .25)[0]) # neighbors with interaction weight >= .25
622 |             tmp_spot.append(self.spot_network[j,inds_spot[j]])
623 |         logger.debug(str(tmp_spot))
624 |         pool = mp.Pool(nthreads)
625 |         norms = [norm(self.means[0][0],np.sqrt(self.sigmas[0][0])),norm(self.means[1][0],np.sqrt(self.sigmas[1][0])),norm(self.means[2][0],np.sqrt(self.sigmas[2][0]))]
626 |         for m in range(maxiters):
627 |             posteriors = 0
628 |             means = np.zeros((self.bins,n_clusters))
629 |             for label in range(n_clusters):
630 |                 means[:,label] = np.asarray([self.means[int(i)][0] for i in self.states[:,label]])
631 |             likelihood_energies = self.jointLikelihoodEnergyLabels(norms,pool)
632 |             #likelihood_energies = self.jointLikelihoodEnergyLabelsapprox(means)
633 |             for j in range(self.spots):
634 |                 p = [((np.sum(tmp_spot[j][self.labels[inds_spot[j]] != label]))) for label in range(n_clusters)] # prior energy: weighted count of disagreeing neighbors
635 |                 val = [likelihood_energies[j,label]+beta_spots*1*p[label] for label in range(n_clusters)]
636 |                 arg = np.argmin(val) # minimize energy = likelihood term + beta * prior term
637 |                 posteriors += val[arg]
638 |                 self.labels[j] = arg
639 |             if np.array_equal(np.asarray(prev_labels),np.asarray(self.labels)): # check for convergence
640 |                 break
641 |             prev_labels = copy.deepcopy(self.labels)
642 |         return(-posteriors)
643 | 
644 | 
645 |     def update_params(self):
646 |         c = np.zeros((self.data.shape[0],self.n_clusters))
647 |         for i in range(self.data.shape[0]):
648 |             for k in range(self.n_clusters):
649 |                 c[i,k] = np.mean(self.data[i,self.labels==k])
650 |         means = [np.mean(c[self.states==cluster]) for cluster in range(self.num_states)]
651 |         sigmas = [np.std(c[self.states==cluster]) for cluster in range(self.num_states)]
652 | 
653 |         indices = np.argsort([x for x in means])
654 | 
655 |         m = np.zeros((3,1))
656 |         s = np.zeros((3,1))
657 |         i=0
658 |         for index in indices:
659 |             self.states[self.states==index]=i # set states
660 |             mean = means[index]
661 |             sigma = sigmas[index]
662 |             m[i] = [mean]
663 |             s[i] = [sigma**2]
664 |             i+=1
665 |         self.means = m
666 |         self.sigmas = s
667 |         logger.debug(str(self.means))
668 |         logger.debug(str(self.sigmas))
669 | 
670 |     def callCNA(self,t=.00001,beta_spots=2,maxiters=20,deltoamp=0.0,nthreads=0,returnnormal=True):
671 |         """
672 |         Run the HMRF-EM framework to call CNA states by alternating between
673 |         MAP estimates of states given the current params and EM estimates of
674 |         params given the current states, until convergence
675 | 
676 |         Returns:
677 |         float: the final posterior energy; the called CNA states (0 = deletion, 1 = neutral, 2 = amplification) are stored in self.states
678 |         """
679 |         logger.info("running HMRF to call CNAs...")
680 |         states = [copy.deepcopy(self.states),copy.deepcopy(self.states)]
681 |         logger.debug('sum start:'+str(np.sum(states[-1])))
682 |         logger.info('beta spots: '+str(beta_spots))
683 |         if nthreads == 0:
684 |             nthreads = int(multiprocessing.cpu_count() / 2 + 1)
685 |         logger.info('Running with ' + str(nthreads) + ' threads')
686 |         X = []
687 |         lengths = []
688 |         for i in range(self.data.shape[1]):
689 |             X.append([[x] for x in self.data[:,i]])
690 |             lengths.append(len(self.data[:,i]))
691 |         X = np.concatenate(X)
692 |         model = hmm.GaussianHMM(n_components=self.num_states, covariance_type="diag",init_params="mc", params="",algorithm='viterbi')
693 |         model.transmat_ = np.array([[1-2*t, t, t],
694 |                                     [t, 1-2*t, t],
695 |                                     [t, t, 1-2*t]]) # symmetric transitions with switch probability t
696 |         model.startprob_ = np.asarray([.1,.8,.1]) # favor starting in the neutral state
697 |         model.means_ = self.means
698 |         model.covars_ = self.sigmas
699 |         model.fit(X,lengths) # init_params="mc" re-initializes means/covariances from the data; params="" freezes them during fitting
700 |         logger.info("fitted HMM means: " + str(model.means_))
701 |         logger.info("fitted HMM covariance matrices: " + str(model.covars_))
702 |         logger.info("fitted HMM transition matrix: " + str(model.transmat_))
703 |         logger.info("fitted HMM starting probability: " + str(model.startprob_))
704 |         self.model = model
705 |         for i in range(maxiters):
706 |             score_state = self.HMM_estimate_states_parallel(t=t,deltoamp=deltoamp,nthreads=nthreads)
707 |             self.init_params2()
708 |             score_label = self.MAP_estimate_labels(beta_spots=beta_spots,nthreads=nthreads,maxiters=20)
709 |             states.append(copy.deepcopy(self.states))
710 |             logger.debug('sum iter:'+str(i) + ' ' + str(np.sum(states[-1])))
711 |             if np.array_equal(states[-2],states[-1]) or np.array_equal(states[-3],states[-1]): # check for convergence
712 |                 logger.info('states converged')
713 |                 break
714 |             if len(states) > 3:
715 |                 states = states[-3:]
716 |         logger.info('Posterior Energy: ' + str(score_state + score_label))
717 |         if returnnormal: # append the normal spots as an extra clone whose states are all neutral
718 |             labels = np.asarray([self.n_clusters for i in range(len(self.columns))])
719 |             labels[self.tumor_spots] = self.labels
720 |             states = np.ones((self.states.shape[0],self.n_clusters+1))
721 |             for cluster in range(self.n_clusters):
722 |                 states[:,cluster] = self.states[:,cluster]
723 |             self.labels = pd.DataFrame(data=labels,index=self.columns)
724 |             self.states = states
725 |             self.n_clusters += 1
726 |         else:
727 |             self.labels = pd.DataFrame(data=self.labels,index=self.columns[self.tumor_spots])
728 |         states = pd.DataFrame(self.states)
729 |         logger.info(str(len(self.rows)) + ' ' + str(len(np.asarray([i for i in range(self.states.shape[1])]))) + ' ' + str(self.states.shape))
730 |         self.states = pd.DataFrame(self.states, index=self.rows,columns=np.asarray([i for i in range(self.states.shape[1])]))
731 |         logger.debug(str(self.states))
732 |         logger.debug(str(self.labels))
733 |         return(score_state + score_label) # return the final posterior energy
734 | 
735 | 
736 | 
737 | 
738 | 
--------------------------------------------------------------------------------
/run_STARCH.py:
--------------------------------------------------------------------------------
1 | 
2 | from STARCH import STARCH
3 | import numpy as np
4 | import pandas as pd
5 | import argparse
6 | import logging
7 | logging.basicConfig(level=logging.INFO)
8 | logger = logging.getLogger()
9 | 
10 | if __name__ == "__main__":
11 |     parser = argparse.ArgumentParser()
12 | 
13 |     parser.add_argument('-t','--threads',required=False,type=int,default=0,help="number of threads (0 = use half the available cores)")
14 |     parser.add_argument('-beta_spot','--beta_spot',required=False,default=2.0,type=float,help="strength of the spatial prior on spot labels (HMRF beta)")
15 |     parser.add_argument('-c','--n_clusters',required=True,type=int,help="number of clones")
16 |     parser.add_argument('-i','--input',required=True,nargs='+',help="name of input file (either expression matrix in .csv, .tsv, .txt format or 10X directory containing barcodes.tsv, features.tsv, matrix.mtx, and tissue_positions_list.csv)")
17 |     parser.add_argument('-normal_spots','--normal_spots',required=False,type=str,help="name of input file containing indices of normal spots",default=None)
18 |     parser.add_argument('-returnnormal','--returnnormal',required=False,type=int,default=1,help="1 = include normal spots as an additional all-neutral clone in the output, 0 = exclude them")
19 |     parser.add_argument('-o','--output',required=False,type=str,default='STARCH_output',help='output name (ex. prostate1)')
20 |     parser.add_argument('-outdir','--outdir',required=False,type=str,default='.',help='output directory')
21 |     parser.add_argument('-m','--gene_mapping_file_name',required=False,type=str,default='hgTables_hg19.txt',help='gene mapping file name')
22 |     parser.add_argument('-p','--platform',required=False,type=str,choices=["ST", "Visium"],default='ST',help='platform for spatial transcriptomics data')
23 |     args = parser.parse_args()
24 | 
25 |     nthreads = args.threads
26 |     beta_spot = args.beta_spot
27 |     n_clusters = args.n_clusters
28 |     returnnormal = args.returnnormal
29 |     i = args.input
30 |     normal_spots = args.normal_spots
31 |     out = args.output
32 |     gene_mapping_file_name = args.gene_mapping_file_name
33 |     outdir = args.outdir
34 | 
35 |     if normal_spots is not None:
36 |         normal_spots = np.asarray(pd.read_csv(normal_spots,header=None)).flatten()
37 |     else:
38 |         normal_spots = []
39 | 
40 |     operator = STARCH(i,n_clusters=n_clusters,num_states=3,normal_spots=normal_spots,beta_spots = beta_spot,nthreads=nthreads,gene_mapping_file_name=gene_mapping_file_name, platform=args.platform)
41 |     posteriors = operator.callCNA(beta_spots=beta_spot,nthreads=nthreads,returnnormal=returnnormal)
42 | 
43 |     # relabel states so that ids are ordered by their fitted means (0 = deletion, 1 = neutral, 2 = amplification)
44 |     new_states = np.array(operator.states).astype(int)
45 |     index = np.argsort(operator.means.flatten())
46 |     map_states = {index[i]:i for i in range(len(index))}
47 |     new_states = np.array([ [map_states[x] for x in y] for y in new_states])
48 |     if returnnormal:
49 |         new_states[:,-1] = 1 # the appended normal clone is all neutral
50 |     logger.info("{}, {}".format(new_states.shape, operator.states.shape) )
51 |     pd.DataFrame(new_states, index=operator.states.index, columns=operator.states.columns).to_csv('%s/states_%s.csv'%(outdir,out),sep=',')
52 |     pd.DataFrame(operator.labels).to_csv('%s/labels_%s.csv'%(outdir,out),sep=',')
53 | 
--------------------------------------------------------------------------------
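For reference, STARCH can also be driven programmatically instead of through run_STARCH.py. A minimal sketch (the input path and output file names are illustrative; unlike run_STARCH.py, this skips the final re-ordering of state ids by fitted means):

```python
from STARCH import STARCH

# build the model from an expression matrix (a 10X directory path also works)
operator = STARCH('gene_expression_matrix.csv', n_clusters=3, num_states=3, platform='ST')

# alternate HMM state decoding and HMRF label updates until convergence;
# the return value is the final posterior energy
posterior = operator.callCNA(beta_spots=2.0, returnnormal=True)

# gene x clone CNV states (pandas DataFrame) and spot -> clone labels
operator.states.to_csv('states_example.csv')
operator.labels.to_csv('labels_example.csv')
```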