├── .gitignore ├── LICENSE ├── README.rst ├── deprecated ├── README.rst ├── old_examples │ ├── fly_brain │ │ ├── config.py │ │ ├── fly_brain.py │ │ └── process.py │ ├── rna_velocity │ │ └── schema_rna_velocity_demo1.ipynb │ ├── sci-car │ │ ├── cmdlines.sh │ │ └── scicar-process2.py │ ├── slide-seq │ │ ├── cmdlines.sh │ │ └── slideseq-process1.py │ └── tcr-binding │ │ ├── cmdlines.sh │ │ └── tcrbinding-process1.py └── old_readme.md ├── docs ├── Makefile ├── _static │ ├── Schema-Overview-v2.png │ ├── Schema-Overview-v3.png │ ├── Schema-webpage-logo-1.png │ └── Schema-webpage-logo-2-blue.png ├── _templates │ └── footer.html ├── make.bat └── source │ ├── _static │ ├── schema_atacrna_demo_dotplot1.png │ ├── schema_atacrna_demo_tsne1.png │ ├── schema_atacrna_demo_wts1.png │ ├── schema_paired-tag_data-dist.png │ ├── schema_paired-tag_gene_plots.png │ ├── schema_paired-tag_go-annot.csv │ ├── schema_paired-tag_umap-row1.png │ ├── schema_paired-tag_umap-row2.png │ ├── schema_paired-tag_umap-row3.png │ ├── umap_flybrain_regular_r3.png │ ├── umap_flybrain_schema0.999-0.99_r3.png │ └── umap_flybrain_schema0.99_r3.png │ ├── api │ └── index.rst │ ├── conf.py │ ├── datasets.rst │ ├── index.rst │ ├── installation.rst │ ├── overview.rst │ ├── recipes │ └── index.rst │ ├── references.rst │ └── visualization │ └── index.rst ├── examples ├── README.rst └── Schema_demo.ipynb ├── requirements.txt └── schema ├── __init__.py ├── datasets ├── __init__.py └── _datasets.py ├── schema_base_config.py ├── schema_qp.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.log 3 | *.log.* 4 | *.pdf 5 | *.png 6 | *.pyc 7 | *.svg 8 | *.txt 9 | *~ 10 | .#* 11 | build/ 12 | data 13 | dist/ 14 | plots* 15 | target* 16 | temp* 17 | setup.py 18 | *.out 19 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Rohit Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | |PyPI| |Docs| 2 | 3 | .. |PyPI| image:: https://img.shields.io/pypi/v/schema_learn.svg 4 | :target: https://pypi.org/project/schema_learn 5 | .. 
|Docs| image:: https://readthedocs.org/projects/schema-multimodal/badge/?version=latest
 6 |    :target: https://schema-multimodal.readthedocs.io/en/latest/?badge=latest
 7 | 
 8 | 
 9 | 
10 | Schema - Analyze and Visualize Multimodal Single-Cell Data
11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
12 | 
13 | Schema is a Python library for the synthesis and integration of heterogeneous single-cell modalities.
14 | **It is designed for the case where the modalities have all been assayed for the same cells simultaneously.**
15 | Here are some of the analyses that you can do with Schema:
16 | 
17 | - infer cell types jointly across modalities.
18 | - perform spatial transcriptomic analyses to identify differentially-expressed genes in cells that display a specific spatial characteristic.
19 | - create informative t-SNE & UMAP visualizations of multimodal data by infusing information from other modalities into scRNA-seq data.
20 | 
21 | Schema supports the incorporation of more than two modalities and can simultaneously handle batch effects and metadata (e.g., cell age).
22 | 
23 | 
24 | Schema is based on a metric learning approach and formulates the modality-synthesis problem as a quadratic programming problem. Its Python-based implementation can efficiently process large datasets without the need for a GPU.
25 | 
26 | Read the documentation_.
27 | We encourage you to report issues at our `GitHub page`_; you can also submit pull requests there to contribute your enhancements.
28 | If Schema is useful in your research, please consider citing our papers: `Genome Biology (2021)`_, with a preprint in `bioRxiv (2019)`_.
29 | 
30 | .. _documentation: https://schema-multimodal.readthedocs.io/en/latest/overview.html
31 | .. _bioRxiv (2019): http://doi.org/10.1101/834549
32 | .. _GitHub page: https://github.com/rs239/schema
33 | .. _Genome Biology (2021): https://genomebiology.biomedcentral.com/articles/10.1186/s13059-021-02313-2
34 | 
--------------------------------------------------------------------------------
/deprecated/README.rst:
--------------------------------------------------------------------------------
 1 | Deprecated
 2 | ==========
 3 | 
 4 | This section contains deprecated areas of the code. In particular, if you're looking for examples, we recommend the `documentation`_ page, which will be the primary location for examples going forward.
 5 | 
 6 | .. 
_documentation: https://schema-multimodal.readthedocs.io/en/latest/overview.html
 7 | 
--------------------------------------------------------------------------------
/deprecated/old_examples/fly_brain/config.py:
--------------------------------------------------------------------------------
 1 | import fileinput
 2 | import sys
 3 | 
 4 | data_names = None
 5 | 
 6 | if data_names is None:
 7 |     if len(sys.argv) == 1:
 8 |         print('Enter data names followed by EOF/Ctrl-D:')
 9 | 
10 |     data_names = []
11 |     for line in fileinput.input():
12 |         fields = line.rstrip().split(',')
13 |         for f in fields:
14 |             if f.strip() == '':
15 |                 continue
16 |             data_names.append(f)
17 |     print('Data names loaded')
18 | 
--------------------------------------------------------------------------------
/deprecated/old_examples/fly_brain/fly_brain.py:
--------------------------------------------------------------------------------
 1 | from schema import SchemaQP
 2 | from anndata import AnnData
 3 | import numpy as np
 4 | import scanpy as sc
 5 | 
 6 | from process import load_names  # plain (non-relative) import: this script is run directly, not as a package module
 7 | 
 8 | def load_meta(fname):
 9 |     age, strain = [], []
10 |     with open(fname) as f:
11 |         f.readline() # Consume header.
12 |         for line in f:
13 |             fields = line.rstrip().split()
14 |             age.append(int(fields[4]))
15 |             strain.append(fields[3])
16 |     return np.array(age), np.array(strain)
17 | 
18 | if __name__ == '__main__':
19 |     [ X ], [ genes ], _ = load_names([ 'data/fly_brain/GSE107451' ], norm=False)
20 | 
21 |     age, strain = load_meta('data/fly_brain/GSE107451/annotation.tsv')
22 | 
23 |     # Only analyze the wild-type strain.
24 |     adata = AnnData(X[strain == 'DGRP-551'])
25 |     adata.var['gene_symbols'] = genes
26 |     adata.obs['age'] = age[strain == 'DGRP-551']
27 | 
28 |     # No Schema transformation.
29 | 
30 |     sc.pp.pca(adata)
31 |     sc.tl.tsne(adata, n_pcs=50)
32 |     sc.pl.tsne(adata, color='age', color_map='coolwarm',
33 |                save='_flybrain_regular.png')
34 | 
35 |     sc.pp.neighbors(adata, n_neighbors=15)
36 |     sc.tl.umap(adata)
37 |     sc.pl.umap(adata, color='age', color_map='coolwarm',
38 |                save='_flybrain_regular.png')
39 | 
40 |     # Schema transformation to include age.
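    # (Added commentary; a sketch of the parameter semantics, per the SchemaQP
    # usage in this repository.) min_desired_corr is the minimum correlation the
    # transformed distances must retain with the primary scRNA-seq modality, so
    # values near 1.0 permit only a gentle perturbation of the original geometry.
    # w_max_to_avg bounds the ratio of the largest feature weight to the average
    # weight, preventing any single feature from dominating. The NMF decomposition
    # with 20 components defines the low-dimensional space in which feature
    # weights are learned. The secondary modality is the numeric 'age' label,
    # included with weight +1, i.e., agreement with age is rewarded.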
41 | 42 | schema_corrs = [ 0.9999, 0.999, 0.99, 0.9, 0.7, 0.5 ] 43 | 44 | for schema_corr in schema_corrs: 45 | 46 | sqp = SchemaQP( 47 | min_desired_corr=schema_corr, 48 | w_max_to_avg=100, 49 | params={ 50 | 'decomposition_model': 'nmf', 51 | 'num_top_components': 20, 52 | }, 53 | ) 54 | 55 | X = sqp.fit_transform( 56 | adata.X, 57 | [ adata.obs['age'].values, ], 58 | [ 'numeric', ], 59 | [ 1, ] 60 | ) 61 | 62 | sdata = AnnData(X) 63 | sdata.obs['age'] = age[strain == 'DGRP-551'] 64 | 65 | sc.tl.tsne(sdata) 66 | sc.pl.tsne(sdata, color='age', color_map='coolwarm', 67 | save='_flybrain_schema_corr{}_w100.png'.format(schema_corr)) 68 | 69 | sc.pp.neighbors(sdata, n_neighbors=15) 70 | sc.tl.umap(sdata) 71 | sc.pl.umap(sdata, color='age', color_map='coolwarm', 72 | save='_flybrain_schema{}_w100.png'.format(schema_corr)) 73 | -------------------------------------------------------------------------------- /deprecated/old_examples/fly_brain/process.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import numpy as np 3 | import os.path 4 | import scipy.sparse 5 | from scipy.sparse import csr_matrix, csc_matrix 6 | from sklearn.preprocessing import normalize 7 | import sys 8 | 9 | MIN_TRANSCRIPTS = 0 10 | 11 | def load_tab(fname, delim='\t'): 12 | if fname.endswith('.gz'): 13 | opener = gzip.open 14 | else: 15 | opener = open 16 | 17 | with opener(fname, 'r') as f: 18 | if fname.endswith('.gz'): 19 | header = f.readline().decode('utf-8').rstrip().replace('"', '').split(delim) 20 | else: 21 | header = f.readline().rstrip().replace('"', '').split(delim) 22 | 23 | X = [] 24 | genes = [] 25 | for i, line in enumerate(f): 26 | if fname.endswith('.gz'): 27 | line = line.decode('utf-8') 28 | fields = line.rstrip().replace('"', '').split(delim) 29 | genes.append(fields[0]) 30 | X.append([ float(f) for f in fields[1:] ]) 31 | 32 | if i == 0: 33 | if len(header) == (len(fields) - 1): 34 | cells = header 35 | elif len(header) == len(fields): 36 | cells = header[1:] 37 | else: 38 | raise ValueError('Incompatible header/value dimensions {} and {}' 39 | .format(len(header), len(fields))) 40 | 41 | return np.array(X).T, np.array(cells), np.array(genes) 42 | 43 | def load_mtx(dname): 44 | with open(dname + '/matrix.mtx', 'r') as f: 45 | while True: 46 | header = f.readline() 47 | if not header.startswith('%'): 48 | break 49 | header = header.rstrip().split() 50 | n_genes, n_cells = int(header[0]), int(header[1]) 51 | 52 | data, i, j = [], [], [] 53 | for line in f: 54 | fields = line.rstrip().split() 55 | data.append(float(fields[2])) 56 | i.append(int(fields[1])-1) 57 | j.append(int(fields[0])-1) 58 | X = csr_matrix((data, (i, j)), shape=(n_cells, n_genes)) 59 | 60 | genes = [] 61 | with open(dname + '/genes.tsv', 'r') as f: 62 | for line in f: 63 | fields = line.rstrip().split() 64 | genes.append(fields[1]) 65 | assert(len(genes) == n_genes) 66 | 67 | return X, np.array(genes) 68 | 69 | def load_h5(fname, genome='GRCh38'): 70 | try: 71 | import tables 72 | except ImportError: 73 | sys.stderr.write('Please install PyTables to read .h5 files: ' 74 | 'https://www.pytables.org/usersguide/installation.html\n') 75 | exit(1) 76 | 77 | # Adapted from scanpy's read_10x_h5() method. 
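    # (Added commentary; layout assumptions follow the legacy 10x Genomics
    # HDF5 format.) Each genome group stores a genes-by-cells matrix in CSC
    # form via 'data', 'indices', and 'indptr' arrays, plus a 'shape' entry of
    # (n_genes, n_cells). Because the CSC arrays of a matrix are exactly the
    # CSR arrays of its transpose, reading them into a csr_matrix of shape
    # (n_cells, n_genes) below yields the cells-by-genes matrix directly.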
78 | with tables.open_file(str(fname), 'r') as f: 79 | try: 80 | dsets = {} 81 | for node in f.walk_nodes('/' + genome, 'Array'): 82 | dsets[node.name] = node.read() 83 | 84 | n_genes, n_cells = dsets['shape'] 85 | data = dsets['data'] 86 | if dsets['data'].dtype == np.dtype('int32'): 87 | data = dsets['data'].view('float32') 88 | data[:] = dsets['data'] 89 | 90 | X = csr_matrix((data, dsets['indices'], dsets['indptr']), 91 | shape=(n_cells, n_genes)) 92 | genes = [ gene for gene in dsets['gene_names'].astype(str) ] 93 | assert(len(genes) == n_genes) 94 | assert(len(genes) == X.shape[1]) 95 | 96 | except tables.NoSuchNodeError: 97 | raise Exception('Genome %s does not exist in this file.' % genome) 98 | except KeyError: 99 | raise Exception('File is missing one or more required datasets.') 100 | 101 | return X, np.array(genes) 102 | 103 | def process_tab(fname, min_trans=MIN_TRANSCRIPTS, delim='\t'): 104 | X, cells, genes = load_tab(fname, delim=delim) 105 | 106 | gt_idx = [ i for i, s in enumerate(np.sum(X != 0, axis=1)) 107 | if s >= min_trans ] 108 | X = csr_matrix(X[gt_idx, :]) 109 | cells = cells[gt_idx] 110 | if len(gt_idx) == 0: 111 | print('Warning: 0 cells passed QC in {}'.format(fname)) 112 | if fname.endswith('.txt'): 113 | cache_prefix = '.'.join(fname.split('.')[:-1]) 114 | elif fname.endswith('.txt.gz'): 115 | cache_prefix = '.'.join(fname.split('.')[:-2]) 116 | elif fname.endswith('.tsv'): 117 | cache_prefix = '.'.join(fname.split('.')[:-1]) 118 | elif fname.endswith('.tsv.gz'): 119 | cache_prefix = '.'.join(fname.split('.')[:-2]) 120 | elif fname.endswith('.csv'): 121 | cache_prefix = '.'.join(fname.split('.')[:-1]) 122 | elif fname.endswith('.csv.gz'): 123 | cache_prefix = '.'.join(fname.split('.')[:-2]) 124 | else: 125 | cache_prefix = fname 126 | 127 | cache_fname = cache_prefix + '_tab.npz' 128 | scipy.sparse.save_npz(cache_fname, X, compressed=False) 129 | 130 | with open(cache_prefix + '_tab.genes.txt', 'w') as of: 131 | of.write('\n'.join(genes) + '\n') 132 | 133 | return X, cells, genes 134 | 135 | def process_mtx(dname, min_trans=MIN_TRANSCRIPTS): 136 | X, genes = load_mtx(dname) 137 | 138 | gt_idx = [ i for i, s in enumerate(np.sum(X != 0, axis=1)) 139 | if s >= min_trans ] 140 | X = X[gt_idx, :] 141 | if len(gt_idx) == 0: 142 | print('Warning: 0 cells passed QC in {}'.format(dname)) 143 | 144 | cache_fname = dname + '/tab.npz' 145 | scipy.sparse.save_npz(cache_fname, X, compressed=False) 146 | 147 | with open(dname + '/tab.genes.txt', 'w') as of: 148 | of.write('\n'.join(genes) + '\n') 149 | 150 | return X, genes 151 | 152 | def process_h5(fname, min_trans=MIN_TRANSCRIPTS): 153 | X, genes = load_h5(fname) 154 | 155 | gt_idx = [ i for i, s in enumerate(np.sum(X != 0, axis=1)) 156 | if s >= min_trans ] 157 | X = X[gt_idx, :] 158 | if len(gt_idx) == 0: 159 | print('Warning: 0 cells passed QC in {}'.format(fname)) 160 | 161 | if fname.endswith('.h5'): 162 | cache_prefix = '.'.join(fname.split('.')[:-1]) 163 | 164 | cache_fname = cache_prefix + '.h5.npz' 165 | scipy.sparse.save_npz(cache_fname, X, compressed=False) 166 | 167 | with open(cache_prefix + '.h5.genes.txt', 'w') as of: 168 | of.write('\n'.join(genes) + '\n') 169 | 170 | return X, genes 171 | 172 | def load_data(name): 173 | if os.path.isfile(name + '.h5.npz'): 174 | X = scipy.sparse.load_npz(name + '.h5.npz') 175 | with open(name + '.h5.genes.txt') as f: 176 | genes = np.array(f.read().rstrip().split('\n')) 177 | elif os.path.isfile(name + '_tab.npz'): 178 | X = scipy.sparse.load_npz(name + '_tab.npz') 179 
| with open(name + '_tab.genes.txt') as f: 180 | genes = np.array(f.read().rstrip().split('\n')) 181 | elif os.path.isfile(name + '/tab.npz'): 182 | X = scipy.sparse.load_npz(name + '/tab.npz') 183 | with open(name + '/tab.genes.txt') as f: 184 | genes = np.array(f.read().rstrip().split('\n')) 185 | else: 186 | sys.stderr.write('Could not find: {}\n'.format(name)) 187 | exit(1) 188 | genes = np.array([ gene.upper() for gene in genes ]) 189 | return X, genes 190 | 191 | def load_names(data_names, norm=False, log1p=False, verbose=True): 192 | # Load datasets. 193 | datasets = [] 194 | genes_list = [] 195 | n_cells = 0 196 | for name in data_names: 197 | X_i, genes_i = load_data(name) 198 | if norm: 199 | X_i = normalize(X_i, axis=1) 200 | if log1p: 201 | X_i = np.log1p(X_i) 202 | X_i = csr_matrix(X_i) 203 | 204 | datasets.append(X_i) 205 | genes_list.append(genes_i) 206 | n_cells += X_i.shape[0] 207 | if verbose: 208 | print('Loaded {} with {} genes and {} cells'. 209 | format(name, X_i.shape[1], X_i.shape[0])) 210 | if verbose: 211 | print('Found {} cells among all datasets' 212 | .format(n_cells)) 213 | 214 | return datasets, genes_list, n_cells 215 | 216 | def save_datasets(datasets, genes, data_names, verbose=True, 217 | truncate_neg=False): 218 | for i in range(len(datasets)): 219 | dataset = datasets[i].toarray() 220 | name = data_names[i] 221 | 222 | if truncate_neg: 223 | dataset[dataset < 0] = 0 224 | 225 | with open(name + '.scanorama_corrected.txt', 'w') as of: 226 | # Save header. 227 | of.write('Genes\t') 228 | of.write('\t'.join( 229 | [ 'cell' + str(cell) for cell in range(dataset.shape[0]) ] 230 | ) + '\n') 231 | 232 | for g in range(dataset.shape[1]): 233 | of.write(genes[g] + '\t') 234 | of.write('\t'.join( 235 | [ str(expr) for expr in dataset[:, g] ] 236 | ) + '\n') 237 | 238 | def merge_datasets(datasets, genes, ds_names=None, verbose=True, 239 | union=False, keep_genes=None): 240 | if keep_genes is None: 241 | # Find genes in common. 242 | keep_genes = set() 243 | for idx, gene_list in enumerate(genes): 244 | gene_list = [ g for gene in gene_list for g in gene.split(';') ] 245 | if len(keep_genes) == 0: 246 | keep_genes = set(gene_list) 247 | elif union: 248 | keep_genes |= set(gene_list) 249 | else: 250 | keep_genes &= set(gene_list) 251 | if not union and not ds_names is None and verbose: 252 | print('After {}: {} genes'.format(ds_names[idx], len(keep_genes))) 253 | if len(keep_genes) == 0: 254 | print('Error: No genes found in all datasets, exiting...') 255 | exit(1) 256 | else: 257 | union = True 258 | 259 | if verbose: 260 | print('Found {} genes among all datasets' 261 | .format(len(keep_genes))) 262 | 263 | if union: 264 | union_genes = sorted(keep_genes) 265 | for i in range(len(datasets)): 266 | if verbose: 267 | print('Processing dataset {}'.format(i)) 268 | X_new = np.zeros((datasets[i].shape[0], len(union_genes))) 269 | X_old = csc_matrix(datasets[i]) 270 | gene_to_idx = { g: idx for idx, gene in enumerate(genes[i]) 271 | for g in gene.split(';') } 272 | for j, gene in enumerate(union_genes): 273 | if gene in gene_to_idx: 274 | X_new[:, j] = X_old[:, gene_to_idx[gene]].toarray().flatten() 275 | datasets[i] = csr_matrix(X_new) 276 | ret_genes = np.array(union_genes) 277 | else: 278 | # Only keep genes in common. 
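        # (Added commentary.) The shared gene set is sorted to fix a canonical
        # column order; each dataset is then checked for a consistent gene
        # dimension, de-duplicated by gene symbol, and column-filtered via an
        # argsort over the unique symbols, so every dataset ends up with
        # identical, identically ordered gene columns.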
279 | ret_genes = np.array(sorted(keep_genes)) 280 | for i in range(len(datasets)): 281 | if len(genes[i]) != datasets[i].shape[1]: 282 | raise ValueError('Mismatch along gene dimension for dataset {}, ' 283 | '{} genes vs {} matrix shape' 284 | .format(ds_names[i] if ds_names is not None 285 | else i, len(genes[i]), datasets[i].shape[1])) 286 | 287 | # Remove duplicate genes. 288 | uniq_genes, uniq_idx = np.unique(genes[i], return_index=True) 289 | datasets[i] = datasets[i][:, uniq_idx] 290 | 291 | # Do gene filtering. 292 | gene_sort_idx = np.argsort(uniq_genes) 293 | gene_idx = [ 294 | idx 295 | for idx in gene_sort_idx 296 | for g in uniq_genes[idx].split(';') if g in keep_genes 297 | ] 298 | datasets[i] = datasets[i][:, gene_idx] 299 | assert(len(uniq_genes[gene_idx]) == len(ret_genes)) 300 | 301 | return datasets, ret_genes 302 | 303 | def process(data_names, min_trans=MIN_TRANSCRIPTS): 304 | for name in data_names: 305 | if os.path.isdir(name): 306 | process_mtx(name, min_trans=min_trans) 307 | elif os.path.isfile(name) and name.endswith('.h5'): 308 | process_h5(name, min_trans=min_trans) 309 | elif os.path.isfile(name + '.h5'): 310 | process_h5(name + '.h5', min_trans=min_trans) 311 | elif os.path.isfile(name): 312 | process_tab(name, min_trans=min_trans) 313 | elif os.path.isfile(name + '.txt'): 314 | process_tab(name + '.txt', min_trans=min_trans) 315 | elif os.path.isfile(name + '.txt.gz'): 316 | process_tab(name + '.txt.gz', min_trans=min_trans) 317 | elif os.path.isfile(name + '.tsv'): 318 | process_tab(name + '.tsv', min_trans=min_trans) 319 | elif os.path.isfile(name + '.tsv.gz'): 320 | process_tab(name + '.tsv.gz', min_trans=min_trans) 321 | elif os.path.isfile(name + '.csv'): 322 | process_tab(name + '.csv', min_trans=min_trans, delim=',') 323 | elif os.path.isfile(name + '.csv.gz'): 324 | process_tab(name + '.csv.gz', min_trans=min_trans, delim=',') 325 | else: 326 | sys.stderr.write('Warning: Could not find {}\n'.format(name)) 327 | continue 328 | print('Successfully processed {}'.format(name)) 329 | 330 | if __name__ == '__main__': 331 | from config import data_names 332 | 333 | process(data_names) 334 | -------------------------------------------------------------------------------- /deprecated/old_examples/sci-car/cmdlines.sh: -------------------------------------------------------------------------------- 1 | ####################### 2 | ######### README ###### 3 | # Do not execute this file in one go. Best is to execute each command separately in a separate window and make sure all goes well 4 | # You will need to change SCRIPT and DDIR variables below. 5 | # You should first download the pre-processed data generated from raw Sci-CAR data, available at: schema.csail.mit.edu 6 | ####################### 7 | 8 | DDIR=/afs/csail.mit.edu/u/r/rsingh/work/schema/data/sci-car/processed/ 9 | SCRIPT=/afs/csail.mit.edu/u/r/rsingh/work/schema/examples/sci-car/scicar-process2.py 10 | 11 | 12 | ### read raw data and make Scanpy files ###################### 13 | # You do NOT need to run this. Its output is available at schema.csail.mit.edu 14 | 15 | # $SCRIPT --mode raw_data_read 16 | 17 | 18 | 19 | 20 | ### generate peak <-> gene (radial basis function) features ################ 21 | # NOTE: Pre-generated features (i.e. the output of this command) are available at the aforementioned site. 22 | # The cmd takes a while to run; for each gene, it looks at all peaks when computing the feature values... 
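# (Added note; the exact kernel definitions live in SciCar.fpeak_list_all in schema/utils.py.)
# Roughly, each fpeak_rbf_<bw> feature aggregates ATAC peak signal around a gene,
# down-weighting each peak by a radial-basis function of its genomic distance d
# from the gene, on the order of exp(-d^2 / (2*bw^2)) for bandwidth bw. The
# fpeak_cols_to_drop option in the Schema runs below excludes individual
# bandwidth columns (e.g., fpeak_rbf_500).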
23 | 
24 | oldd=$PWD; cd $DDIR
25 | $SCRIPT --mode produce_gene2fpeak --outsfx 20191218-1615 --njobs 36 --infile ${DDIR}/adata1x.h5ad
26 | cd $oldd
27 | 
28 | 
29 | 
30 | 
31 | ### do Schema runs #######################
32 | # We do Schema runs for the entire dataset ('0:12000' below), as well as for each quartile of genes ranked by expression variability.
33 | # After processing and filtering for sparsity, the dataset has a little under 12K genes.
34 | 
35 | oldd=$PWD; cd $DDIR
36 | for x in 0:12000 0:3000 3000:6000 6000:9000 9000:12000   # 0:12000 corresponds to the entire dataset; the rest are quartiles
37 | do
38 |     minhvg=$(echo $x | cut -d: -f1);
39 |     maxhvg=$(echo $x | cut -d: -f2);
40 | 
41 |     $SCRIPT --infile ${DDIR}/adata1x.h5ad --mode schema_gene2fpeak --outsfx 20200211-1600 --njobs 4 --extra adata_norm_style=2 gene2fpeak_file=adata1_M-produce_gene2fpeak_S-0_mtx_20191218-1615.csv min_hvgrank=$minhvg max_hvgrank=$maxhvg fpeak_cols_to_drop=fpeak_rbf_500
42 | 
43 | done
44 | cd $oldd
45 | 
46 | 
47 | 
48 | ### measure the clustering of highly variable genes (HVGs) within topologically associating domains (TADs) ######
49 | #
50 | # with prep_significance_testing_data=1, each call to $SCRIPT produces a file of 1001 rows, in the following format:
51 | #   - the first row corresponds to the actual ('Orig') data, while the remaining 1000 rows correspond to randomly shuffled instances
52 | #   - each row has the format <desc>,<n>,<pair2_freq>,c_1,...,c_15,f_1,...,f_15
53 | #   - desc = Orig|Random
54 | #   - n = number of genes that were within a TAD. We limit our analysis to these, excluding genes that lie outside a TAD.
55 | #     The "Random" instances shuffle n genes across k TADs, with n and k as determined from "Orig"
56 | #   - pair2_freq = fraction of gene-pairs (denominator is n*(n-1)/2) that share a TAD
57 | #   - c_i (i=1,...,15) = number of TADs that contain exactly 'i' genes
58 | #   - f_i (i=1,...,15) = fraction of genes (denominator is n) in TADs that contain exactly 'i' genes
59 | 
60 | oldd=$PWD; cd $DDIR
61 | for x in 0:3000 3000:6000 6000:9000 9000:12000   # 0:3000 is the top quartile
62 | do
63 |     minhvg=$(echo $x | cut -d: -f1);
64 |     maxhvg=$(echo $x | cut -d: -f2);
65 | 
66 |     $SCRIPT --infile ${DDIR}/adata1x.h5ad --mode compute_hvg_tad_dispersion --outsfx t11 --extra min_hvgrank=$minhvg max_hvgrank=$maxhvg tad_locations_file=hg19_A549_TAD.bed prep_significance_testing_data=1
67 | 
68 | done
69 | cd $oldd
--------------------------------------------------------------------------------
/deprecated/old_examples/sci-car/scicar-process2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import pandas as pd
4 | import numpy as np
5 | import scipy, sklearn, os, sys, string, fileinput, glob, re, math, itertools, functools, copy, multiprocessing, traceback
6 | import scipy.stats, sklearn.decomposition, sklearn.preprocessing, sklearn.covariance
7 | from scipy.stats import describe
8 | from scipy import sparse
9 | import os.path
10 | import scipy.sparse
11 | from scipy.sparse import csr_matrix, csc_matrix
12 | from sklearn.preprocessing import normalize
13 | from collections import defaultdict
14 | from tqdm import tqdm
15 | 
16 | 
17 | 
18 | def read_raw_files_and_write_h5_files(outdir):
19 |     def f_atac1(w):
20 |         try:
21 |             v = w["chr"]
22 |             assert ("_" not in v) and (v in ['X','Y'] or int(v)<=22)
23 |             return True
24 |         except:
25 |             return False
26 | 
27 |     import utils
28 |     adata1 = utils.SciCar.loadData('/afs/csail.mit.edu/u/r/rsingh/work/schema/data/sci-car',
29 |                                    [('rna'
,'gene','GSM3271040', lambda v: v["cell_name"]=="A549", lambda v: v["gene_type"]=="protein_coding"), 30 | ('atac','peak','GSM3271041', lambda v: v["group"][:4]=="A549", f_atac1), 31 | ], 32 | "/afs/csail.mit.edu/u/r/rsingh/work/refdata/hg19_mapped.tsv") 33 | adata1.write("{0}/adata1x.h5ad".format(outdir)) 34 | 35 | 36 | def mprun_geneXfpeak_mtx(adata, n_jobs=8): 37 | peak_func_list = [a[0] for a in SciCar.fpeak_list_all] 38 | print (adata.uns.keys(), adata.uns["atac.var"].head(2)) 39 | chr_mapping = SciCar.getChrMapping(adata) 40 | 41 | 42 | nPeaks = adata.uns["atac.X"].shape[1] 43 | l = [list(a) for a in np.array_split(range(nPeaks), 5*n_jobs)] 44 | pool = multiprocessing.Pool(processes = n_jobs) 45 | 46 | lx = pool.map(functools.partial(SciCar.computeGeneByFpeakMatrix, 47 | adata, peak_func_list, chr_mapping, normalize_distwt= True), 48 | l) 49 | 50 | 51 | g2p = None 52 | for m, _ in lx: 53 | if g2p is None: 54 | g2p = m 55 | else: 56 | g2p += m 57 | 58 | g2p = g2p * (1e5/nPeaks) 59 | 60 | dx = pd.DataFrame(g2p, index=None) 61 | dx.columns = [a[1] for a in SciCar.fpeak_list_all] 62 | dx["gene"] = list(adata.var.index) 63 | dx["ensembl_id"] = list(adata.var.ensembl_id) 64 | return dx 65 | 66 | 67 | 68 | 69 | def f_helper_mprun_schemawts_2(args): 70 | ax, dz, dir1, outsfx, min_corr, maxwt, strand, chromosome, adata_norm_style = args 71 | 72 | if adata_norm_style == 1: 73 | ax_l2norm = np.sqrt(np.sum(ax**2, axis=1)) 74 | ax = ax.copy() / (1e-12 + ax_l2norm[:,None]) 75 | 76 | elif adata_norm_style == 2: 77 | ax_l2norm = np.sqrt(np.sum(ax**2, axis=1)) 78 | ax = np.sort(ax.copy(), axis=1) / (1e-12 + ax_l2norm[:,None]) 79 | print("Flag 231.10 ") 80 | print(ax[-10:,-10:]) 81 | print(np.sum(ax**2,axis=1)[-10:]) 82 | 83 | dz_cols = dz.columns[:-2] 84 | dz_vals = dz.values[:,:-2] 85 | 86 | vstd = np.std(dz_vals.astype(float), axis=0) 87 | print("Flag 231.0275 ", vstd.shape, dz_vals.shape, flush=True) 88 | dz_vals = dz_vals.copy() / (1e-12 + vstd) 89 | 90 | try: 91 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 92 | import schema_qp 93 | except: 94 | from schema import schema_qp 95 | 96 | sqp = schema_qp.SchemaQP(min_corr, maxwt, params= {"dist_npairs": 1000000}, mode="scale") 97 | try: 98 | dz1 = sqp.fit_transform(dz_vals, [ax], ['feature_vector'], [1]) 99 | 100 | print("Flag 231.030 ", min_corr, maxwt, dz_vals.shape, ax.shape, flush=True) 101 | 102 | wtsx = np.sqrt(np.maximum(sqp._wts/np.sum(sqp._wts), 0)) 103 | except: 104 | print ("ERROR: schema failed for ", min_corr, maxwt, strand, chromosome) 105 | wtsx = 1e12*np.ones(dz_vals.shape[1]) 106 | 107 | wdf = pd.Series(wtsx, index=dz_cols).sort_values(ascending=False).reset_index().rename(columns={"index": "fdist",0: "wt"}) 108 | wdf.to_csv("{0}/adata1_sqp_wts_mincorr{1}_maxw{2}_strand{4}_chr{5}_adatanorm{8}_{7}.csv".format(dir1, min_corr, maxwt, 1, strand, chromosome, 5, outsfx, adata_norm_style), index=False) 109 | 110 | 111 | 112 | def mprun_schemawts_2(adata1, dz, dir1, outsfx, adata_norm_style, do_dataset_split=False, n_jobs=4): 113 | 114 | try: 115 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 116 | import schema_qp 117 | except: 118 | from schema import schema_qp 119 | 120 | pool = multiprocessing.Pool(processes = n_jobs) 121 | try: 122 | ax = np.copy(adata1.X.todense().T) 123 | except: 124 | ax = adata1.X.T.copy() 125 | 126 | nGenes = adata1.var.shape[0] 127 | chrx = adata1.var["chr"].apply(lambda s: s.replace("chr","")) 128 | 129 | lx = [] 130 | 131 | # one-time export of data for ridge-regression comparison 132 
| # if False: 133 | # dfax = pd.DataFrame(ax.T) 134 | # dfax.columns = list(adata1.var_names) 135 | 136 | # dz_vals = dz[dz.ensembl_id.isin(adata1.var["ensembl_id"])].values.copy()[:,:-2] 137 | # vstd = np.std(dz_vals.astype(float), axis=0) 138 | # dfdz = pd.DataFrame(dz_vals/ (1e-12 + vstd)) 139 | # dfdz.columns = list(dz.columns)[:-2] 140 | # dfax.to_csv("saved_data_{0}_gexp.csv".format(outsfx), index=False) 141 | # adata1.var.to_csv("saved_data_{0}_aux.csv".format(outsfx), index=False) 142 | # dfdz.to_csv("saved_data_{0}_features.csv".format(outsfx), index=False) 143 | 144 | 145 | dz2 = dz 146 | for strand in ["both", "plus", "minus"]: 147 | if strand != "both" and do_dataset_split==False: continue 148 | 149 | gidx_strand = np.full(nGenes, True) 150 | if strand == "plus": gidx_strand = np.where(adata1.var["strand"]=="+", True, False) 151 | if strand == "minus": gidx_strand = np.where(adata1.var["strand"]=="-", True, False) 152 | 153 | for chromosome in ["all", "1--8","9--16","17--23"]: 154 | if chromosome != "all" and do_dataset_split==False: continue 155 | 156 | gidx_chr = np.full(nGenes, True) 157 | if chromosome=="1--8": gidx_chr = chrx.isin("1,2,3,4,5,6,7,8".split(",")) 158 | if chromosome=="9--16": gidx_chr = chrx.isin("9,10,11,12,13,14,15,16".split(",")) 159 | if chromosome=="17--23": gidx_chr = chrx.isin("17,18,19,20,21,22,X,Y".split(",")) 160 | 161 | for mc in [0.20, 0.5, 0.9]: 162 | for mw in [10,5,3]: 163 | 164 | gidx = gidx_strand & gidx_chr 165 | print("Flag 3312.040 ", gidx.shape, np.sum(gidx), ax.shape, dz2.shape, flush=True) 166 | lx.append((ax[gidx,:], dz2[dz2.ensembl_id.isin(adata1.var["ensembl_id"][gidx])], dir1, outsfx, 167 | mc, mw, strand, chromosome, adata_norm_style)) 168 | print("Flag 3312.050 ", np.sum(gidx), lx[-1][2:], flush=True) 169 | 170 | try: 171 | pool.map(f_helper_mprun_schemawts_2, lx) 172 | 173 | finally: 174 | pool.close() 175 | pool.join() 176 | 177 | 178 | 179 | 180 | def getGeneDistances(adata1): 181 | chrInt = adata1.var["chr"].apply(lambda s: s.replace("chr","").replace("X","231").replace("Y","232")).astype(int).values 182 | diff_chr = np.subtract.outer(chrInt, chrInt)!=0 183 | 184 | tss_adj = adata1.var['tss_adj'].values 185 | chrdist = np.abs(np.subtract.outer(tss_adj, tss_adj)) 186 | 187 | import numpy.ma as ma 188 | gene_dist = ma.masked_where( diff_chr, chrdist).filled(1e15) 189 | nGenes = gene_dist.shape[0] 190 | l0 = list(range(nGenes)) 191 | gene_dist[l0, l0] = 1e15 192 | return gene_dist 193 | 194 | 195 | def getGeneStrandMatch(adata1): 196 | chrInt = adata1.var["chr"].apply(lambda s: s.replace("chr","").replace("X","231").replace("Y","232")).astype(int).values 197 | diff_chr = np.subtract.outer(chrInt, chrInt)!=0 198 | 199 | strand_adj = np.where(adata1.var['strand'].values=="+",1,0) 200 | strandM = (np.subtract.outer(strand_adj, strand_adj) == 0).astype(int) 201 | 202 | import numpy.ma as ma 203 | gene_sign = ma.masked_where( diff_chr, strandM).filled(1e15) 204 | nGenes = gene_sign.shape[0] 205 | l0 = list(range(nGenes)) 206 | gene_sign[l0, l0] = 1e15 207 | return gene_sign 208 | 209 | 210 | def rankGenesByVariability(adata1): 211 | import matplotlib 212 | matplotlib.use("agg") #otherwise scanpy tries to use tkinter which has issues importing 213 | import scanpy as sc 214 | nGenes = adata1.shape[1] 215 | print ("Flag 7868.200 ", nGenes, adata1.X.shape, adata1.shape, scipy.sparse.isspmatrix(adata1.X)) 216 | v = np.full(nGenes, nGenes) 217 | for i in range(500,nGenes-500,250): 218 | print ("Flag 7868.205 ", i) 219 | 220 | v0 = 
sc.pp.highly_variable_genes(adata1, n_top_genes=i,inplace=False).highly_variable 221 | print ("Flag 7868.210 ", i, np.sum(v0), len(v0)) 222 | v = np.where((v==nGenes) & v0, i, v).copy() 223 | return v 224 | 225 | 226 | def findNonVariableGenes(adata1, gene_window): 227 | nGenes = adata1.shape[1] 228 | 229 | v = np.full(nGenes, nGenes) 230 | for i in range(250,nGenes-500,250): 231 | try: 232 | hv = sc.pp.highly_variable_genes(adata1, n_top_genes=i,inplace=False).highly_variable 233 | except: 234 | print ("Flag 7867.500 ", i) 235 | traceback.print_exc(file=sys.stdout) 236 | continue 237 | 238 | v = np.where((v==nGenes) & hv, i, v).copy() 239 | 240 | 241 | w = (v >= gene_window[0]) & (v < gene_window[1]) 242 | print ("Flag 7867.200 ", np.sum(w), gene_window, nGenes) 243 | print ("Flag 7867.300 ", pd.Series(v).value_counts()) 244 | 245 | return w 246 | 247 | 248 | 249 | def random_gene_list(nGenes, gene_cnt, num_random_samples=1000): 250 | L = [] 251 | for i in range(num_random_samples): 252 | b = np.zeros(nGenes, dtype=bool) 253 | b[np.random.choice(nGenes, gene_cnt, replace=False)] = True 254 | L.append(b) 255 | 256 | return L 257 | 258 | 259 | def getGene2TADmap(adata1, tad): 260 | l = [] 261 | nGenes = adata1.var.shape[0] 262 | for i in tqdm(range(nGenes)): 263 | gchr = adata1.var["chr"][i] 264 | gtss = adata1.var["tss_adj"][i] 265 | idx1 = (tad.chr==gchr) & (tad["start"] <= gtss) & (tad["end"] > gtss) 266 | #if i%100==0: print ("Flag 3231.10 ", i) 267 | if idx1.sum() > 0: 268 | l.append( np.argmax(idx1.values)) 269 | else: 270 | l.append(np.NaN) 271 | g2tad = np.array(l) 272 | gene_sharedtad = np.subtract.outer( g2tad, g2tad)==0 273 | return g2tad, gene_sharedtad 274 | 275 | 276 | 277 | def f_in_tad_cnt( g2tad, gene_set): 278 | return np.sum(~np.isnan(g2tad[gene_set])) 279 | 280 | 281 | def f_mode_hvg_tad_dispersion_helper(g2tad, gene_sharedtad, gene_set): 282 | tad2genes = defaultdict(list) 283 | for i, t in enumerate(g2tad): 284 | if not np.isnan(t) and gene_set[i]: 285 | tad2genes[t].append(i) 286 | return [(k,v,len(v)) for k,v in tad2genes.items()] 287 | 288 | 289 | 290 | ################################################################################# 291 | 292 | if __name__ == "__main__": 293 | try: 294 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 295 | from utils import SciCar 296 | except: 297 | import schema 298 | from schema.utils import SciCar 299 | 300 | import argparse 301 | parser = argparse.ArgumentParser() 302 | parser.add_argument("--mode", help="which code path to run. see main(..) for details") 303 | parser.add_argument("--outdir", help="output directory (can set to '.')", type=str, default=".") 304 | parser.add_argument("--outsfx", help="suffix to use when producing output files") 305 | parser.add_argument("--style", help="mode-specific interpretation", type=int, default=-1) 306 | parser.add_argument("--infile", help="input .h5ad file. 
Default is Sci-Car hs A549") 307 | parser.add_argument("--njobs", help="number of parallel cores to use", type=int, default=24) 308 | 309 | parser.add_argument("--extra", help="put this as the LAST option and arbitrary space-separated key=val pairs after that", type=str, nargs='*') 310 | 311 | 312 | args = parser.parse_args() 313 | assert args.mode is not None 314 | if args.mode !="raw_data_read": assert args.outsfx is not None 315 | 316 | extra_args = dict([a.split("=") for a in args.extra]) if args.extra else {} 317 | 318 | if args.mode== "raw_data_read": 319 | read_raw_files_and_write_h5_files( args.outdir) 320 | 321 | 322 | 323 | if args.infile is None: 324 | args.infile = "/afs/csail.mit.edu/u/r/rsingh/work/schema/data/sci-car/processed/adata1x.h5ad" 325 | 326 | 327 | adata1 = SciCar.loadAnnData(args.infile) 328 | adata1 = SciCar.preprocessAnnData(adata1, True, 5, 3, 5) 329 | 330 | if args.mode=="produce_gene2fpeak": 331 | if args.style < 0: args.style = 0 332 | dx = mprun_geneXfpeak_mtx(adata1, 36) 333 | dx.to_csv("{0}/adata1_M-{1}_S-{2}_mtx_{3}.csv".format( args.outdir, args.mode, args.style, args.outsfx), index=False) 334 | 335 | 336 | if args.mode=="schema_gene2fpeak": 337 | # --mode schema_gene2fpeak --outsfx 20191218-1615 --njobs 5 --extra gene2fpeak_file=adata1_M-produce_gene2fpeak_S-0_mtx_20191218-1615.csv 338 | 339 | if args.style < 0: args.style = 0 340 | 341 | assert "gene2fpeak_file" in extra_args 342 | dz = pd.read_csv(extra_args["gene2fpeak_file"]) 343 | 344 | if "fpeak_cols_to_drop" in extra_args: 345 | dz = dz.drop(columns = extra_args["fpeak_cols_to_drop"].split(",")) 346 | 347 | adata1.var["gene_variability_ranking"] = rankGenesByVariability(adata1) 348 | 349 | adata1.X = adata1.X.todense() 350 | 351 | min_hvgrank, max_hvgrank = int(extra_args.get("min_hvgrank",0)), int(extra_args.get("max_hvgrank",1000000)) 352 | adata1 = adata1[:, ((adata1.var["gene_variability_ranking"] >= min_hvgrank) & 353 | (adata1.var["gene_variability_ranking"] < max_hvgrank))] 354 | 355 | dz = dz[dz.ensembl_id.isin(adata1.var["ensembl_id"])].reset_index(drop=True) 356 | 357 | assert np.sum(dz.ensembl_id.values != adata1.var["ensembl_id"].values)==0 358 | 359 | adata_norm_style = int(extra_args.get("adata_norm_style",0)) 360 | do_dataset_split = int(extra_args.get("do_dataset_split",0)) > 0.5 361 | 362 | mprun_schemawts_2(adata1, dz, args.outdir, "M-{0}_S-{1}_minhvg-{2}_maxhvg-{3}_{4}".format(args.mode, args.style, min_hvgrank, max_hvgrank, args.outsfx), adata_norm_style, do_dataset_split, args.njobs) 363 | 364 | 365 | 366 | if args.mode == "compute_hvg_tad_dispersion": 367 | import scanpy as sc 368 | 369 | min_hvgrank, max_hvgrank = int(extra_args.get("min_hvgrank",0)), int(extra_args.get("max_hvgrank",1000000)) 370 | gene_window = [min_hvgrank, max_hvgrank] 371 | 372 | outsfx = args.outsfx 373 | 374 | gdist_matrix = getGeneDistances(adata1) 375 | print ("Flag 54123.20 ", gdist_matrix.shape) #, describe(gdist_matrix)) 376 | 377 | gsign_matrix = getGeneStrandMatch(adata1) 378 | print ("Flag 54123.21 ", gsign_matrix.shape) #, describe(gdist_matrix)) 379 | 380 | tad = pd.read_csv(extra_args.get("tad_locations_file", "hg19_A549_TAD.bed"), delimiter="\t", header=None) 381 | tad.columns = ["chr","start","end", "x1","x2"] 382 | g2tad, gene_sharedtad = getGene2TADmap(adata1, tad) 383 | 384 | nGenes = adata1.shape[1] 385 | 386 | # v = np.full(nGenes, nGenes) 387 | # for i in tqdm(range(500,nGenes-500,500)): 388 | # v0 = sc.pp.highly_variable_genes(adata1, n_top_genes=i,inplace=False).highly_variable 
389 | # v = np.where((v==nGenes) & v0, i, v).copy() 390 | # adata1.var["gene_variability_ranking"] = v 391 | 392 | adata1.var["gene_variability_ranking"] = rankGenesByVariability(adata1) 393 | 394 | window_genes = ((adata1.var["gene_variability_ranking"].values >= gene_window[0]) & 395 | (adata1.var["gene_variability_ranking"].values < gene_window[1])) 396 | 397 | n_window_genes = np.sum(window_genes) 398 | 399 | print ("Flag 54123.30 ", gene_window, len(window_genes), np.sum(window_genes)) 400 | 401 | num_samples = 1000 402 | 403 | in_tad_cnt = f_in_tad_cnt(g2tad, window_genes) 404 | tad_total_cnt = np.sum(~np.isnan(g2tad)) 405 | 406 | rl = random_gene_list(tad_total_cnt, in_tad_cnt, num_random_samples = num_samples) 407 | 408 | # we don't really care about genes outside TADs in this analysis 409 | window_genes [np.isnan(g2tad)] = False 410 | n_window_genes = np.sum(window_genes) 411 | lx0 = [window_genes] 412 | print ("Flag 54123.32 ", len(window_genes), np.sum(window_genes), len(lx0)) 413 | for rlx in rl: 414 | ax = np.full( nGenes, False) 415 | ax[~np.isnan(g2tad)] = rlx #[rlx] = True 416 | lx0.append(ax) 417 | print ("Flag 54123.33 ", len(window_genes), np.sum(window_genes), len(lx0), len(lx0[1]), np.sum(lx0[1])) 418 | 419 | lx = [(g2tad, gene_sharedtad, a) for a in lx0] 420 | 421 | 422 | print ("Flag 54123.40 ", len(lx), len(lx[1][2]), np.sum(lx[1][2])) 423 | 424 | n_jobs = 36 425 | print ("Flag 54123.435 ", nGenes, n_window_genes, gdist_matrix.shape, adata1.shape, len(lx)) 426 | 427 | 428 | pool = multiprocessing.Pool(processes = n_jobs) 429 | ly = pool.starmap(f_mode_hvg_tad_dispersion_helper, lx) 430 | 431 | outfile = "adata1_tad_membership_genewindow{0}-{1}_{2}".format(min_hvgrank, max_hvgrank, outsfx) 432 | outfh = open(outfile, 'w') 433 | 434 | prep_significance_testing_data = int(extra_args.get("prep_significance_testing_data",1)) > 0.5 435 | 436 | if not prep_significance_testing_data: 437 | random_samples_cnts = defaultdict(int) 438 | random_samples_N = 0 439 | for z in ly[1:]: 440 | for _, _, tadcnt in z: 441 | random_samples_cnts[tadcnt] += 1 442 | random_samples_N += 1.0 443 | 444 | for k,v in sorted(random_samples_cnts.items(), key = lambda a: a[0]): 445 | outfh.write("Random,{0},{1}\n".format(k, v/random_samples_N)) 446 | 447 | for tadidx, tad_members, tadcnt in sorted(ly[0], key=lambda a: -a[2]): 448 | outfh.write("Orig,{0},{1},{2}\n".format( tad["x1"][tadidx], tadcnt, ",".join([adata1.var["Symbol"][t] for t in tad_members]))) 449 | 450 | else: 451 | for i , z in enumerate(ly): 452 | #print ("Flag 54123.50 ", i, z) 453 | tadcnt2freq = defaultdict(int) 454 | numpairs_shared_tad = 0.0 455 | n = 0.0 456 | for _, _, tadcnt in z: 457 | tadcnt2freq[tadcnt] += 1 458 | numpairs_shared_tad += tadcnt*(tadcnt-1)/2.0 459 | n += tadcnt 460 | numpairs_all = n*(n-1)/2.0 461 | s = "Random" if i>0 else "Orig" 462 | s += ",{0},{1}".format(n, numpairs_shared_tad/numpairs_all) 463 | s += ",".join([""]+["{0}".format(tadcnt2freq[j]) for j in range(1,16)]) 464 | s += ",".join([""]+["{0}".format(tadcnt2freq[j]*j/n) for j in range(1,16)]) 465 | #print("Flag 54123.56 ", s) 466 | outfh.write(s + "\n") 467 | 468 | 469 | outfh.close() 470 | 471 | 472 | -------------------------------------------------------------------------------- /deprecated/old_examples/slide-seq/cmdlines.sh: -------------------------------------------------------------------------------- 1 | ####################### 2 | ######### README ###### 3 | # Do not execute this file in one go. 
It is best to run each command separately, each in its own window, and confirm that it completes successfully.
 4 | # You will need to change the SCRIPT and DDIR variables below.
 5 | # You should first download the pre-processed data generated from raw Slide-seq data, available at: schema.csail.mit.edu
 6 | #######################
 7 | 
 8 | DDIR=/afs/csail.mit.edu/u/r/rsingh/work/schema/data/slideseq/processed/
 9 | SCRIPT=/afs/csail.mit.edu/u/r/rsingh/work/schema/examples/slide-seq/slideseq-process1.py
10 | 
11 | 
12 | 
13 | ### read raw data and make Scanpy files ######################
14 | # $SCRIPT --mode raw_data_read   # You do NOT need to run this. Its output is available at schema.csail.mit.edu
15 | 
16 | 
17 | 
18 | ### do Schema runs ####################
19 | 
20 | for pid in 180430_1 180430_5 180430_6; do echo $pid; $SCRIPT --mode schema_kd_granule_cells --infile ${DDIR}/puck_${pid}.h5ad --outpfx ${DDIR}/kd_fit-on-all_kdbw-45_${pid} --extra kd_fit_granule_only=0 kd_bw=45 ; done
21 | 
22 | 
23 | 
24 | 
25 | ### do CCA runs (2nd method, involving 2 steps) #######################
26 | # Note: the output prefix must be cca2stepkd_* so that it matches the pkl_file_glob patterns in the ranking steps below.
27 | for pid in 180430_1 180430_5 180430_6; do echo $pid; $SCRIPT --mode cca2step_kd_granule_cells --infile ${DDIR}/puck_${pid}.h5ad --outpfx ${DDIR}/cca2stepkd_fit-on-all_kdbw-45_${pid} --extra kd_fit_granule_only=0 kd_bw=45 ; done
28 | 
29 | 
30 | 
31 | 
32 | ### Generate gene-rankings per puck
33 | 
34 | oldd=$PWD;
35 | cd $DDIR
36 | 
37 | for f in kd_fit-on-all_kdbw-45_180430_?_func_output.pkl cca2stepkd_fit-on-all_kdbw-45_180430_?_CCA2STEP_output.pkl
38 | do
39 |     sfx=$(echo $f | perl -pe 's/(180430_.).*$/\1/')
40 |     pid=$(echo $f | perl -pe 's/^.*(180430_.).*$/\1/')
41 |     typ=$(echo $f | awk '/CCA/ {print "cca"} !/CCA/ {print "schema"}')
42 | 
43 |     $SCRIPT --mode generate_multipuck_gene_ranks --outpfx per_puck_${sfx} --extra data_type=$typ pkl_file_glob=./$f
44 | done
45 | cd $oldd
46 | 
47 | 
48 | 
49 | 
50 | ### Generate consensus gene-rankings by aggregating scores across 3 pucks.
51 | ### same usage as above except we now specify globs instead of one file at a time 52 | 53 | oldd=$PWD; 54 | cd $DDIR 55 | 56 | for f in "kd_fit-on-all_kdbw-45_180430_?_func_output.pkl" "cca2stepkd_fit-on-all_kdbw-45_180430_?_CCA2STEP_output.pkl" 57 | do 58 | sfx=$(echo $f | perl -pe 's/(180430_.).*$/\1/') 59 | pid=$(echo $f | perl -pe 's/^.*(180430_.).*$/\1/') 60 | typ=$(echo $f | awk '/CCA/ {print "cca"} !/CCA/ {print "schema"}') 61 | 62 | $SCRIPT --mode generate_multipuck_gene_ranks --outpfx per_puck_${sfx} --extra data_type=$typ pkl_file_glob="./$f" 63 | done 64 | cd $oldd 65 | -------------------------------------------------------------------------------- /deprecated/old_examples/slide-seq/slideseq-process1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import scipy, sklearn, os, sys, string, fileinput, glob, re, math, itertools, functools 6 | import copy, multiprocessing, traceback, logging, pickle 7 | import scipy.stats, sklearn.decomposition, sklearn.preprocessing, sklearn.covariance 8 | from scipy.stats import describe 9 | from scipy import sparse 10 | import os.path 11 | import scipy.sparse 12 | from scipy.sparse import csr_matrix, csc_matrix 13 | from sklearn.preprocessing import normalize 14 | from collections import defaultdict 15 | from tqdm import tqdm 16 | 17 | 18 | 19 | def read_raw_files_and_write_h5_files(outdir): 20 | import utils 21 | for pid in ['180430_1','180430_5','180430_6']: 22 | adata1 = utils.SlideSeq.loadRawData("/afs/csail.mit.edu/u/r/rsingh/work/afid/data/slideseq/raw/", pid, 100) 23 | adata1.write("/afs/csail.mit.edu/u/r/rsingh/work/afid/data/slideseq/processed/puck_{0}.h5ad".format(pid)) 24 | 25 | 26 | def computeKernelDensityGranuleCells(adata1, kd_fit_granule_only=True, kd_bw=125): 27 | from sklearn.neighbors import KernelDensity 28 | fscl = lambda v: 20*(sklearn.preprocessing.MinMaxScaler().fit_transform(np.exp(v[:,None]-v.min()))).ravel() 29 | d3 = adata1.obs.copy(deep=True) #adata1.uns["Ho"].merge(adata1.obs, how="inner", left_index=True, right_index=True) 30 | d3c = d3[["xcoord","ycoord"]] 31 | if kd_fit_granule_only: 32 | d3["kd"] = fscl(KernelDensity(kernel='gaussian', bandwidth=kd_bw).fit(d3c[d3["atlas_cluster"]==1].values).score_samples(d3c.values)) 33 | else: 34 | d3["kd"] = fscl(KernelDensity(kernel='gaussian', bandwidth=kd_bw).fit(d3c.values).score_samples(d3c.values)) 35 | adata1.obs["kd"] = d3["kd"] 36 | return adata1 37 | 38 | 39 | 40 | def checkMaxFeasibleCorr(D, d0, g, tg, wg): 41 | try: 42 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 43 | import schema_qp 44 | except: 45 | from schema import schema_qp 46 | 47 | for thresh in [0.30, 0.275, 0.25, 0.225, 0.20, 0.15, 0.10, 0.075, 0.06, 0.05, 0.04, 0.03, 0.025, 0.02, 0.015, 0.01]: 48 | print ("STARTING TRY OF ", thresh) 49 | try: 50 | sqp = schema_qp.SchemaQP(thresh, w_max_to_avg=1000, params= {"dist_npairs": 1000000}, mode="scale") 51 | dz1 = sqp.fit(D, g, tg, wg, d0=d0) 52 | print ("SUCCEEDED TRY OF ", thresh) 53 | return 0.9*thresh, thresh 54 | except: 55 | print ("EXCEPTION WHEN TRYING ", thresh) 56 | #raise 57 | return 0,0 58 | 59 | 60 | def runSchemaGranuleCellDensity(D, d0, gIn, tgIn, wgIn, min_corr1, min_corr2): 61 | try: 62 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 63 | import schema_qp 64 | except: 65 | from schema import schema_qp 66 | 67 | f_linear = lambda v:v 68 | ret_val = {} 69 | 70 | w_list= [1,10,50,100] 71 | for w in 
w_list: 72 | s="linear" 73 | f=f_linear 74 | 75 | g1, wg1, tg1 = gIn[:], wgIn[:], tgIn[:] # does maximize negative corr with non-granule 76 | wg1[0] = w 77 | 78 | g = [g1[0]]; wg = [wg1[0]]; tg=[tg1[0]] # does NOT maximize negative corr with non-granule 79 | 80 | #afx0 = schema_qp.SchemaQP(0.001, 1000, mode="scale") 81 | #Dx0 = afx0.fit_transform(D,g,tg,wg,d0) 82 | #ret_val[(s,w,0)] = (np.sqrt(afx0._wts), afx0._soln_info) 83 | 84 | try: 85 | afx1 = schema_qp.SchemaQP(min_corr1, w_max_to_avg=1000, mode="scale") 86 | Dx1 = afx1.fit_transform(D,g,tg,wg,d0=d0) 87 | ret_val[(s,w,1)] = (np.sqrt(afx1._wts), afx1._soln_info ) # does NOT maximize negative corr with non-granule 88 | except: 89 | print("TRYING min-corr {0} for afx1 broke here".format(min_corr1)) 90 | continue 91 | 92 | try: 93 | afx2 = schema_qp.SchemaQP(min_corr1, w_max_to_avg=1000, mode="scale") 94 | Dx2 = afx2.fit_transform(D,g1,tg1,wg1,d0=d0) 95 | ret_val[(s,w,2)] = (np.sqrt(afx2._wts), afx2._soln_info) # does maximize negative corr with non-granule 96 | except: 97 | print("TRYING min-corr {0} for afx2 broke here".format(min_corr1)) 98 | continue 99 | 100 | try: 101 | afx3 = schema_qp.SchemaQP(min_corr2, w_max_to_avg=1000, mode="scale") 102 | Dx3 = afx3.fit_transform(D,g1,tg1,wg1,d0=d0) # does maximize negative corr with non-granule 103 | ret_val[(s,w,3)] = (np.sqrt(afx3._wts), afx3._soln_info) 104 | except: 105 | print("TRYING min-corr {0} for afx3 broke here".format(min_corr2)) 106 | continue 107 | 108 | return ret_val 109 | 110 | 111 | 112 | def getSoftmaxCombinedScores(Wo, schema_ret, use_generanks=True, do_pow2=True, schema_allowed_w1s=None): 113 | if use_generanks: 114 | R = Wo.rank(axis=1, pct=True).values.T 115 | else: 116 | R = Wo.values.T 117 | 118 | sumr = None; nr=0 119 | for x in schema_ret: 120 | style,w,i = x 121 | 122 | if style!="linear" or i not in [2,3]: continue #2,3 correspond to schema runs that also require disagreement w other altas clusters 123 | 124 | if schema_allowed_w1s is not None and w not in schema_allowed_w1s: continue 125 | 126 | wx = schema_ret[x][0]**2 127 | #wx = wx/np.sum(wx) 128 | 129 | schema_wts = wx**(2 if do_pow2 else 1) 130 | if np.max(schema_wts) > 20: 131 | schema_wts = 20*schema_wts/np.max(schema_wts) 132 | schema_probs = np.exp(schema_wts)/np.sum(np.exp(schema_wts)) 133 | g1 = (R*schema_probs).sum(axis=1) 134 | g2 = g1/np.std(g1.ravel()) 135 | #g2 = scipy.stats.rankdata(g1); g2 = g2/np.max(g2) 136 | if sumr is None: 137 | sumr = g2 138 | else: 139 | sumr += g2 140 | nr += 1 141 | rnks = sumr/nr 142 | s1= pd.Series(rnks, index=list(Wo.columns)).sort_values(ascending=False) 143 | return {u:(i+1) for i,u in enumerate(list(s1.index))} 144 | 145 | 146 | 147 | def getCellLoadingSoftmax(d3, Wo, schema_ret): 148 | R = Wo.values.T 149 | 150 | sumr = None; nr=0 151 | for x in schema_ret: 152 | style,w,i = x 153 | if style!="linear" or i not in [2,3]: continue 154 | 155 | wx = schema_ret[x][0]**2 156 | #wx = wx/np.sum(wx) 157 | 158 | schema_wts = wx 159 | if np.max(schema_wts) > 20: 160 | schema_wts = 20*schema_wts/np.max(schema_wts) 161 | 162 | schema_probs = np.exp(schema_wts)/np.sum(np.exp(schema_wts)) 163 | 164 | if sumr is None: 165 | sumr = schema_probs 166 | else: 167 | sumr += schema_probs 168 | nr += 1 169 | 170 | r = sumr/nr 171 | v = (d3.iloc[:,:100].multiply(r,axis=1)).sum(axis=1) 172 | return v.values 173 | 174 | 175 | 176 | 177 | def generatePlotGranuleCellDensity(d3, cell_loadings): 178 | import matplotlib.pyplot as plt 179 | import seaborn as sns 180 | score1 = cell_loadings 
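    # (Added commentary on the figure assembled below.) Panels a-d scatter bead
    # coordinates for four atlas clusters (granule, Purkinje, interneuron,
    # oligodendrocyte); panel e colors granule-cell beads by their kernel-density
    # score; panels f-i split granule vs. non-granule beads by whether their
    # score1 value falls in its top quartile.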
181 | 182 | np.random.seed(239) 183 | plt.style.use('seaborn-paper') 184 | plt.rcParams['lines.markersize'] = np.sqrt(0.25) 185 | fc = lambda v: np.where(v,'lightslategray','red') 186 | 187 | #fig = plt.figure(constrained_layout=True, figsize=(6.48,2.16), dpi=300) #(2*6.48,2*2.16)) 188 | fig = plt.figure(figsize=(6.48,2.16), dpi=300) #(2*6.48,2*2.16)) 189 | gs = fig.add_gridspec(2,6,wspace=0,hspace=0) #(2, 2) 190 | 191 | idxY=d3["atlas_cluster"]==1 192 | idxN=d3["atlas_cluster"]!=1 193 | 194 | coords = d3.loc[:,["xcoord","ycoord"]].values 195 | clstr = d3["atlas_cluster"].values 196 | 197 | cid_list = [1,2,3,6] 198 | 199 | fc = lambda v: np.where(v,'lightslategray','red') 200 | axdict = {} 201 | 202 | xyL = [(gs[0,0], coords[clstr==1,:], 'a','Granule Cells'), (gs[0,1], coords[clstr==2,:],'b','Purkinje Cells'), 203 | (gs[1,0], coords[clstr==3,:], 'c','Interneuron'), (gs[1,1], coords[clstr==6,:], 'd','Oligodendrocytes')] 204 | 205 | for g, dx, titlestr, desc in xyL: 206 | 207 | ax = fig.add_subplot(g) 208 | fc = lambda v: np.where(v,'lightslategray','red') 209 | ax.text(.95, .05, titlestr, horizontalalignment='center', transform=ax.transAxes, size=14 ) 210 | ax.axis('off') 211 | ax.scatter(dx[:,0], dx[:,1], color='black', alpha=0.20 if titlestr=='a' else 0.6 ) 212 | ax.set_aspect('equal') 213 | axdict[titlestr] = ax 214 | 215 | 216 | ax = fig.add_subplot(gs[:,2:4]) 217 | im = ax.scatter(coords[clstr==1,0],coords[clstr==1,1],c=2*d3["kd"].values[clstr==1],cmap="seismic",s=1) 218 | #im = ax.scatter(coords[:,0],coords[:,1],c=2*d3["kd"].values,cmap="seismic",s=1) 219 | ax.set_aspect('equal') 220 | ax.axis('off') 221 | from mpl_toolkits.axes_grid1 import make_axes_locatable 222 | div = make_axes_locatable(ax) 223 | cax = div.append_axes("bottom", size="3%", pad=0.01) 224 | cbar = fig.colorbar(im, cax=cax, shrink=0.2, orientation='horizontal') 225 | ax.text(.9, .05, "e", horizontalalignment='center', transform=ax.transAxes, size=14 ) 226 | 227 | sx = score1 > np.quantile(score1,0.75) 228 | for g, titlestr, ii, c1 in [(gs[0,4], "f", idxY & sx, 'r'), (gs[0,5], "g", idxY & (~sx), 'b'), 229 | (gs[1,4], "h", idxN & sx, 'r'), (gs[1,5], "i", idxN & (~sx), 'b')]: 230 | 231 | ax = fig.add_subplot(g) 232 | ax.text(.95, .05, titlestr, horizontalalignment='center', transform=ax.transAxes, size=14 ) 233 | #ax.axes.get_xaxis().set_visible(False) 234 | #ax.axes.get_yaxis().set_visible(False) 235 | ax.axis('off') 236 | ax.scatter(coords[ii,0], coords[ii,1], color=c1, alpha=0.40 ) 237 | ax.set_aspect('equal') 238 | axdict[titlestr] = ax 239 | 240 | 241 | #################################### 242 | fig.tight_layout() 243 | return fig 244 | 245 | 246 | 247 | 248 | def processGranuleCellDensitySchema(adata1, extra_args): 249 | 250 | if "kd" not in adata1.obs.columns: 251 | adata1 = computeKernelDensityGranuleCells(adata1, 252 | kd_fit_granule_only = int(extra_args.get("kd_fit_granule_only",1))==1, 253 | kd_bw = float(extra_args.get("kd_bw",125))) 254 | 255 | d3 = adata1.uns["Ho"].merge(adata1.obs, how="inner", left_index=True, right_index=True) 256 | Wo = adata1.uns["Wo"] 257 | cols_Ho = list(adata1.uns["Ho"].columns) 258 | 259 | D = d3[cols_Ho].values 260 | d0 = 1*(d3["atlas_cluster"].values==1) 261 | g = [(d3["kd"].values)]; wg=[10]; tg=["numeric"] 262 | for clid in [2,3,6,7]: 263 | g.append(1*(d3["atlas_cluster"].values==clid)) 264 | wg.append(-1) 265 | tg.append("categorical") 266 | 267 | 268 | min_corr1, min_corr2 = checkMaxFeasibleCorr(D, d0, g, tg, wg) 269 | schema_ret = runSchemaGranuleCellDensity(D, d0, g, tg, 
wg, min_corr1, min_corr2) 270 | scores = getSoftmaxCombinedScores(Wo, schema_ret, use_generanks=False, do_pow2=False) 271 | cell_loadings = getCellLoadingSoftmax(d3, Wo, schema_ret) 272 | 273 | fig = generatePlotGranuleCellDensity(d3, cell_loadings) 274 | return (fig, d3, schema_ret, min_corr1, min_corr2, scores, cell_loadings) 275 | 276 | 277 | 278 | def doSchemaCCA_CellScorePlot2(d3, cca_x_scores, cell_loadings): 279 | clstrs = d3["atlas_cluster"] 280 | kd = d3["kd"] 281 | 282 | cca_sgn = np.sign(scipy.stats.pearsonr(d3["kd"],cca_x_scores)[0]) #flip signs if needed 283 | 284 | R = {} 285 | for desc,v in [("ccax", cca_sgn*cca_x_scores), ("schema", cell_loadings)]: 286 | vr = scipy.stats.rankdata(v) 287 | vr = vr/vr.max() 288 | l = [] 289 | for t in np.linspace(0,1,100)[:-1]: 290 | cx = clstrs[vr >=t ] 291 | granule_frac = (np.sum(cx==1)/(1e-12+ len(cx))) 292 | cx2 = kd[ vr >= t] 293 | kd_val = np.median(cx2) 294 | l.append((granule_frac, kd_val)) 295 | R[desc]= list(zip(*l)) 296 | 297 | import matplotlib.pyplot as plt 298 | plt.style.use('seaborn-paper') 299 | plt.rcParams['lines.markersize'] = np.sqrt(0.25) 300 | 301 | fig = plt.figure(dpi=300) #(2*6.48,2*2.16)) 302 | 303 | a = np.linspace(0,1,100) 304 | plt.scatter(R["ccax"][0], R["ccax"][1], s=(1+3*a)**2, c="red", figure=fig) 305 | plt.scatter(R["schema"][0], R["schema"][1], s=(1+3*a)**2, c="blue", figure=fig) 306 | fig.legend("CCA fit,Schema fit".split(",")) 307 | plt.xlabel("Fraction of Beads labeled as Granule Cells", figure=fig) 308 | plt.ylabel("Median Kernel Density Score", figure=fig) 309 | return fig 310 | 311 | 312 | 313 | 314 | ################################################################################# 315 | 316 | if __name__ == "__main__": 317 | try: 318 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 319 | from utils import SlideSeq 320 | except: 321 | from schema.utils import SlideSeq 322 | 323 | import argparse 324 | parser = argparse.ArgumentParser() 325 | parser.add_argument("--mode", help="which code path to run. see main(..) for details") 326 | parser.add_argument("--outdir", help="output directory (can set to '.')", type=str, default=".") 327 | parser.add_argument("--outpfx", help="prefix to use when producing output files") 328 | parser.add_argument("--style", help="mode-specific interpretation", type=int, default=-1) 329 | parser.add_argument("--infile", help="input .h5ad file. 
Default is SlideSeq 180430_1 h5ad") 330 | parser.add_argument("--njobs", help="number of parallel cores to use", type=int, default=24) 331 | 332 | parser.add_argument("--extra", help="put this as the LAST option and arbitrary space-separated key=val pairs after that", type=str, nargs='*') 333 | 334 | 335 | args = parser.parse_args() 336 | assert args.mode is not None 337 | if args.mode !="raw_data_read": assert args.outpfx is not None 338 | if args.infile is None: 339 | args.infile = "/afs/csail.mit.edu/u/r/rsingh/work/schema/data/slideseq/processed/puck_180430_1.h5ad" 340 | extra_args = dict([a.split("=") for a in args.extra]) if args.extra else {} 341 | 342 | 343 | if args.mode== "raw_data_read": 344 | read_raw_files_and_write_h5_files( args.outdir) 345 | 346 | 347 | if args.mode == "schema_kd_granule_cells": 348 | adata1 = SlideSeq.loadAnnData(args.infile) 349 | try: 350 | from schema import schema_qp 351 | except: 352 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 353 | import schema_qp 354 | 355 | schema_qp.schema_loglevel = logging.WARNING 356 | 357 | fig, d3, schema_ret, min_corr1, min_corr2, scores, cell_loadings = processGranuleCellDensitySchema(adata1, extra_args) 358 | fig.tight_layout() 359 | fig.savefig("{0}_fig-KD.png".format(args.outpfx), dpi=300) 360 | fig.savefig("{0}_fig-KD.svg".format(args.outpfx)) 361 | pickle.dump((d3[["xcoord","ycoord","kd","atlas_cluster"]], schema_ret, min_corr1, min_corr2, scores, cell_loadings), 362 | open("{0}_func_output.pkl".format(args.outpfx), "wb")) 363 | 364 | 365 | if args.mode == "cca_kd_granule_cells": 366 | adata1 = SlideSeq.loadAnnData(args.infile) 367 | if "kd" not in adata1.obs.columns: 368 | adata1 = computeKernelDensityGranuleCells(adata1, 369 | kd_fit_granule_only = int(extra_args.get("kd_fit_granule_only",1))==1, 370 | kd_bw = float(extra_args.get("kd_bw",125))) 371 | from sklearn.cross_decomposition import CCA 372 | cca = CCA(1) 373 | cca.fit(adata1.X, adata1.obs["kd"]) 374 | cca_sgn = np.sign(scipy.stats.pearsonr(adata1.obs["kd"], cca.x_scores_[:,0])[0]) #flip signs if needed 375 | fig = generatePlotGranuleCellDensity(adata1.obs, cca_sgn*cca.x_scores_[:,0]) 376 | fig.tight_layout() 377 | fig.savefig("{0}_fig-CCA.png".format(args.outpfx), dpi=300) 378 | fig.savefig("{0}_fig-CCA.svg".format(args.outpfx)) 379 | pickle.dump((adata1.obs[["xcoord","ycoord","kd","atlas_cluster"]], cca.x_scores_[:,0], cca.x_loadings_[:,0], cca.y_scores_[:,0]), 380 | open("{0}_CCA_output.pkl".format(args.outpfx), "wb")) 381 | 382 | 383 | 384 | if args.mode == "cca2step_kd_granule_cells": 385 | adata1 = SlideSeq.loadAnnData(args.infile) 386 | if "kd" not in adata1.obs.columns: 387 | adata1 = computeKernelDensityGranuleCells(adata1, 388 | kd_fit_granule_only = int(extra_args.get("kd_fit_granule_only",1))==1, 389 | kd_bw = float(extra_args.get("kd_bw",125))) 390 | 391 | #### adata1 = adata1[:,:40] ## FOR TESTING 392 | 393 | from sklearn.cross_decomposition import CCA 394 | cca1 = CCA(1) 395 | cca1.fit(adata1.X, adata1.obs["kd"]) 396 | cca1_sgn = np.sign(scipy.stats.pearsonr(adata1.obs["kd"],cca1.x_scores_[:,0])[0]) #flip signs if needed 397 | 398 | cca2 = CCA(1) 399 | cca2.fit(adata1.X, 1*(adata1.obs["atlas_cluster"]==1)) 400 | cca2_sgn = np.sign(scipy.stats.pearsonr(1*(adata1.obs["atlas_cluster"]==1),cca2.x_scores_[:,0])[0]) #flip signs if needed 401 | 402 | score1 = cca1_sgn*cca1.x_scores_[:,0] 403 | score2 = cca2_sgn*cca2.x_scores_[:,0] 404 | scorex = 0.5 * (score1/np.std(score1) + score2/np.std(score2)) 405 | 406 | scorex = 
scorex/np.sqrt(np.sum(scorex**2)) 407 | 408 | loadings = np.matmul(np.transpose(adata1.X), scorex) 409 | intcpt = 0 410 | 411 | print("Flag 2320.01 ", scorex.shape, adata1.X.shape, loadings.shape, describe(scorex), describe(loadings)) 412 | 413 | 414 | fig = generatePlotGranuleCellDensity(adata1.obs, scorex) 415 | fig.tight_layout() 416 | fig.savefig("{0}_fig-CCA2STEP.png".format(args.outpfx), dpi=300) 417 | fig.savefig("{0}_fig-CCA2STEP.svg".format(args.outpfx)) 418 | pickle.dump((adata1.obs[["xcoord","ycoord","kd","atlas_cluster"]], scorex, loadings, intcpt), 419 | open("{0}_CCA2STEP_output.pkl".format(args.outpfx), "wb")) 420 | 421 | 422 | 423 | if args.mode == "cca_schema_comparison_plot": 424 | cca_pkl_file = extra_args["cca_pkl_file"] 425 | schema_pkl_file = extra_args["schema_pkl_file"] 426 | cca_d3, cca_x_scores, _ , _ = pickle.load(open(cca_pkl_file,"rb")) 427 | schema_d3, _, _, _, _, cell_loadings = pickle.load(open(schema_pkl_file,"rb")) 428 | #fig = doSchemaCCA_CellScorePlot(cca_d3, cca_x_scores, cell_loadings) 429 | fig = doSchemaCCA_CellScorePlot2(cca_d3, cca_x_scores, cell_loadings) 430 | fig.savefig("{0}_fig-Schema-CCA-cmp.png".format(args.outpfx), dpi=300) 431 | fig.savefig("{0}_fig-Schema-CCA-cmp.svg".format(args.outpfx)) 432 | 433 | 434 | 435 | if args.mode == "generate_multipuck_gene_ranks": 436 | pkl_file_glob = extra_args["pkl_file_glob"] 437 | assert extra_args["data_type"].lower() in ["schema","cca"] 438 | data_type = extra_args["data_type"].lower() 439 | 440 | pkl_flist = glob.glob(pkl_file_glob) 441 | print("Flag 67.10 ", pkl_flist) 442 | assert len(pkl_flist) > 0 443 | 444 | L = [] 445 | for f in pkl_flist: 446 | if data_type == "schema": 447 | _, _, _, _, scores, _ = pickle.load(open(f,"rb")) 448 | L.append([a[0] for a in sorted(scores.items(), key=lambda v:v[1])]) #in schema rankings, low number means top-rank 449 | 450 | elif data_type == "cca": 451 | d3, cca_x_scores, cca_x_loadings, _ = pickle.load(open(f,"rb")) 452 | cca_sgn = np.sign(scipy.stats.pearsonr(d3["kd"],cca_x_scores)[0]) 453 | 454 | puckid = f[f.index("180430"):][:8] 455 | adata1 = SlideSeq.loadAnnData("{0}/puck_{1}.h5ad".format(os.path.dirname(f), puckid)) 456 | 457 | df = pd.DataFrame.from_dict({"gene": list(adata1.uns["Wo"].columns), "cca_scores": cca_sgn*cca_x_loadings}) 458 | df = df.sort_values("cca_scores", ascending=False) 459 | L.append(list(df.gene.values)) 460 | 461 | Nmax = max(len(a) for a in L) 462 | print ("Flag 67.40 ", len(L), len(L[0]), Nmax) 463 | cons_score = {} 464 | active_set = set() 465 | for i in range(1,Nmax+1): 466 | currset = set.intersection(*[set(a[:i]) for a in L]) 467 | if len(currset) > len(active_set): 468 | for s in currset-active_set: 469 | cons_score[s] = len(currset) 470 | active_set = currset 471 | 472 | g = []; s = [] 473 | for k,v in cons_score.items(): 474 | g.append(k) 475 | s.append(v) 476 | 477 | pd.DataFrame.from_dict({"gene": g, "rank": s}).to_csv("{0}_generankings_dtype-{1}.csv".format(args.outpfx, data_type), index=False) 478 | 479 | -------------------------------------------------------------------------------- /deprecated/old_examples/tcr-binding/cmdlines.sh: -------------------------------------------------------------------------------- 1 | ####################### 2 | ######### README ###### 3 | # Do not execute this file in one go. Best is to execute each command separately in a separate window and make sure all goes well 4 | # You will need to change SCRIPT and DDIR variables below. 
5 | # You should first download the pre-processed data generated from the raw 10x TCR-binding data, available at: schema.csail.mit.edu
6 | #######################
7 | 
8 | #DDIR=/afs/csail.mit.edu/u/r/rsingh/work/schema/data/tcr-binding/processed/
9 | #SCRIPT=/afs/csail.mit.edu/u/r/rsingh/work/schema/examples/tcr-binding/tcrbinding-process1.py
10 | 
11 | DDIR=/afs/csail.mit.edu/u/r/rsingh/work/schema/data/p2/tcr-binding/processed/
12 | SCRIPT=/afs/csail.mit.edu/u/r/rsingh/work/schema/public/schema/examples/tcr-binding/tcrbinding-process1.py
13 | 
14 | 
15 | 
16 | ### read raw data and make a HDF5 file ######################
17 | # You do NOT need to run this. Its output is available at schema.csail.mit.edu
18 | 
19 | # $SCRIPT --mode raw_data_read
20 | 
21 | 
22 | 
23 | 
24 | ### do Schema run to identify location-wise preferences in TCR alpha (tra) and TCR beta (trb) chains ####################
25 | # Here, we're doing a 3-modality integration: a) CDR3 sequence + b) epitope binding-specificity - c) cell-surface protein markers.
26 | # The last one (c) above is put in with a -0.25 wt as a batch-effect correction [the corresponding wt of (b) is +1.0]
27 | # Can also be run as a 2-modality problem by changing to "--mode compute_2_modality_columnwise_preference" below
28 | # The output file produced has two columns: <location>,<score>. Location is 0-indexed, and the score indicates how likely
29 | # it is that the location displays low variability. In the paper, we show 1-score, to indicate which locations are more variable.
30 | # Separate runs for alpha and beta chains
31 | #
32 | 
33 | oldd=$PWD; cd $DDIR
34 | for c in tra trb
35 | do
36 |     $SCRIPT --infile ${DDIR}/vdj_binarized_alldonors.h5 --mode compute_3_modality_columnwise_preference --outsfx _mode3_location_${c}_v1 --style 0 --extra chain=$c w_surface_markers=-0.25 > std1-mode3-${c}-wneg-style0.out 2>&1
37 | done
38 | cd $oldd
39 | 
40 | 
41 | 
42 | 
43 | 
44 | ### do Schema run to identify amino acid selection pressure in TCR alpha (tra) and TCR beta (trb) chains ####################
45 | # Here, we're doing a 3-modality integration: a) CDR3 sequence + b) epitope binding-specificity - c) cell-surface protein markers.
46 | # The last one (c) above is put in with a -0.25 wt as a batch-effect correction [the corresponding wt of (b) is +1.0]
47 | # Can also be run as a 2-modality problem by changing to "--mode compute_2_modality_selection_pressure" below
48 | # The output file produced has two columns: <amino acid>,<score>. The score indicates how likely it is
49 | # that the amino acid is under selection pressure.
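# For reference, the 3-modality Schema call these commands ultimately trigger inside
# tcrbinding-process1.py is roughly the following Python (a sketch; the names D, binds
# and surface are illustrative stand-ins for the script's arrays):
#
#   from schema import schema_qp
#   sqp = schema_qp.SchemaQP(0.01, 2.0, mode="scale")   # loose min corr, max wt-ratio ~2
#   sqp.fit(D,                                          # CDR3 sequence features (primary)
#           [binds, surface],                           # epitope binding + surface markers
#           ["feature_vector", "feature_vector"],
#           [1.0, -0.25])                               # +1.0 agreement; -0.25 batch-effect correction
#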
50 | 
51 | oldd=$PWD; cd $DDIR
52 | $SCRIPT --infile ${DDIR}/vdj_binarized_alldonors.h5 --mode compute_3_modality_selection_pressure --outsfx _mode3_aa_v2 --extra kmer_type=std w_surface_markers=-0.25 > std2-mode3-wneg.out 2>&1
53 | cd $oldd
54 | 
55 | 
--------------------------------------------------------------------------------
/deprecated/old_examples/tcr-binding/tcrbinding-process1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import pandas as pd
4 | import numpy as np
5 | import scipy, sklearn, os, sys, string, fileinput, glob, re, math, itertools, functools
6 | import copy, multiprocessing, traceback, logging, pickle
7 | import scipy.stats, sklearn.decomposition, sklearn.preprocessing, sklearn.covariance
8 | from scipy.stats import describe
9 | from scipy import sparse
10 | import os.path
11 | import scipy.sparse
12 | from scipy.sparse import csr_matrix, csc_matrix
13 | from sklearn.preprocessing import normalize
14 | from collections import defaultdict
15 | from tqdm import tqdm
16 | 
17 | def fast_csv_read(filename, *args, **kwargs):
18 |     small_chunk = pd.read_csv(filename, nrows=50)
19 |     if small_chunk.index[0] == 0:
20 |         coltypes = dict(enumerate([a.name for a in small_chunk.dtypes.values]))
21 |         return pd.read_csv(filename, dtype=coltypes, *args, **kwargs)
22 |     else:
23 |         coltypes = dict((i+1,k) for i,k in enumerate([a.name for a in small_chunk.dtypes.values]))
24 |         coltypes[0] = str
25 |         return pd.read_csv(filename, index_col=0, dtype=coltypes, *args, **kwargs)
26 | 
27 | 
28 | def processRawData(rawdir, h5file):
29 |     M = pd.concat([fast_csv_read(f) for f in glob.glob("{0}/vdj_v1_hs_aggregated_donor?_binarized_matrix.csv".format(rawdir))])
30 |     truthval_cols = [c for c in M.columns if 'binder' in c]
31 | 
32 |     surface_marker_cols = "CD3,CD19,CD45RA,CD4,CD8a,CD14,CD45RO,CD279_PD-1,IgG1,IgG2a,IgG2b,CD127,CD197_CCR7,HLA-DR".split(",")
33 |     dx = copy.deepcopy(M.loc[:, surface_marker_cols])
34 | 
35 |     M.loc[:,surface_marker_cols] = (np.log2(1 + 1e6*dx.divide(dx.sum(axis=1), axis="index"))).values
36 | 
37 |     # get rid of B cells etc.
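    # f_trim marks values below a marker's 97.5th percentile; requiring this jointly for
    # CD19 (B cells), CD4 and CD14 (monocytes) drops cells with outlier levels of these
    # markers, i.e., likely non-T-cell contamination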
38 | f_trim = lambda v: v < v.quantile(0.975) 39 | ok_Tcells = f_trim(M["CD19"]) & f_trim(M["CD4"]) & f_trim(M["CD14"]) 40 | 41 | M = M.loc[ ok_Tcells, :] 42 | 43 | a = M.loc[:,truthval_cols] 44 | a2= M['cell_clono_cdr3_aa'].apply(lambda v: 'TRA:' in v and 'TRB:' in v) 45 | bc_idx = (a2 & a.any(axis=1)) 46 | M = M[bc_idx].reset_index(drop=True) 47 | 48 | mcols = ["donor","cell_clono_cdr3_aa"] + truthval_cols + surface_marker_cols 49 | print("Flag 67.10 ", h5file) 50 | M.loc[:,mcols].to_hdf(h5file, "df", mode="w") 51 | 52 | 53 | 54 | def chunkifyCDRseqs(M, f_tra_filter, f_trb_filter, tra_start=0, trb_start=0, tra_end=100, trb_end=100): 55 | truthval_cols = [c for c in M.columns if 'binder' in c] 56 | surface_marker_cols = "CD3,CD19,CD45RA,CD4,CD8a,CD14,CD45RO,CD279_PD-1,IgG1,IgG2a,IgG2b,CD127,CD197_CCR7,HLA-DR".split(",") 57 | 58 | tra_L = []; trb_L = []; binds_L = []; idxL = [] 59 | for i in tqdm(range(M.shape[0])): #range(M.shape[0]): 60 | sl = M.at[i,"cell_clono_cdr3_aa"].split(";") 61 | a_l = [x[4:][tra_start:tra_end] for x in sl if x[:4]=="TRA:" and f_tra_filter(x[4:])] 62 | b_l = [x[4:][trb_start:trb_end] for x in sl if x[:4]=="TRB:" and f_trb_filter(x[4:])] 63 | c_np = M.loc[i,truthval_cols].astype(int).values 64 | A0 = ord('A') 65 | for a in a_l: 66 | a_np = np.zeros(26) 67 | for letter in a: 68 | a_np[ord(letter)-A0] += 1 69 | 70 | for b in b_l: 71 | b_np = np.zeros(26) 72 | for letter in b: 73 | b_np[ord(letter)-A0] += 1 74 | 75 | tra_L.append(a_np) 76 | trb_L.append(b_np) 77 | binds_L.append(c_np) 78 | idxL.append(i) 79 | tra = np.array(tra_L) 80 | trb = np.array(trb_L) 81 | binds = np.array(binds_L) 82 | return tra, trb, binds, M.loc[:, surface_marker_cols].iloc[idxL,:], M.iloc[idxL,:]["donor"] 83 | 84 | 85 | 86 | 87 | 88 | def f_dataset_helper(w_vdj, x, md, mw, w_surface_markers): 89 | trax, trbx, bindsx, d_surface_markers, _ = w_vdj 90 | try: 91 | return run_dataset_schema(trax, trbx, bindsx, 0.01, max_w = mw, mode=md, d_surface_markers = d_surface_markers, w_surface_markers=w_surface_markers) 92 | except: 93 | print ("Flag 67567.10 Saw exception in f_dataset_helper") 94 | return (None, None) 95 | 96 | 97 | 98 | 99 | def run_dataset_schema(tra, trb, binds, min_corr, max_w=1000, mode="both", d_surface_markers=None, w_surface_markers=0): 100 | alphabet = [chr(ord('A')+i) for i in range(26)] 101 | non_aa = np.array([ord(c)-ord('A') for c in "BJOUXZ"]) #list interepretation of string 102 | if "both" in mode: 103 | D = np.hstack([tra,trb]) 104 | letters = np.array(['a'+c for c in alphabet] + ['b'+c for c in alphabet]) 105 | to_delete = list(non_aa)+list(non_aa+26) 106 | elif "tra" in mode: 107 | D = tra 108 | letters = ['{0}'.format(chr(ord('A')+i)) for i in range(26)] 109 | to_delete = list(non_aa) 110 | elif "trb" in mode: 111 | D = trb 112 | letters = ['{0}'.format(chr(ord('A')+i)) for i in range(26)] 113 | to_delete = list(non_aa) 114 | 115 | D = np.delete(D, to_delete, axis=1) 116 | letters = np.delete(letters, to_delete) 117 | 118 | if "bin:" in mode: 119 | D = 1*(D>0) 120 | 121 | if "std:" in mode: 122 | D = D / (1e-12 + D.std(axis=0)) 123 | 124 | g = [binds]; wg = [1]; tg=["feature_vector"] 125 | 126 | if w_surface_markers != 0: 127 | g.append(d_surface_markers.values) 128 | wg.append(w_surface_markers) 129 | tg.append("feature_vector") 130 | 131 | try: 132 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 133 | import schema_qp 134 | except: 135 | from schema import schema_qp 136 | 137 | afx = schema_qp.SchemaQP(min_corr, max_w, params = 
{"require_nonzero_lambda":1}, mode="scale") 138 | afx.fit(D,g,tg,wg) 139 | return (pd.Series(np.sqrt(afx._wts), index=letters), afx._soln_info) 140 | 141 | 142 | 143 | 144 | def f_tra_filter(v): 145 | return v[:2]=="CA" and len(v)>=13 146 | 147 | def f_trb_filter(v): 148 | return v[:4]=="CASS" and len(v)>=13 149 | 150 | 151 | def do_dataset_process(M, l, n_jobs, intermediate_file=None, include_full_seq=True, w_surface_markers=0, kmer_type=""): 152 | try: 153 | if include_full_seq: 154 | l.append((2,4,7,11)) 155 | 156 | l1 = [(M, f_tra_filter, f_trb_filter, *v) for v in l] 157 | 158 | print ("Flag 681.10 ", len(l1), M.shape, l, n_jobs) 159 | 160 | if intermediate_file is not None and os.path.exists(intermediate_file): 161 | vdj_data = pickle.load(open(intermediate_file,'rb')) 162 | else: 163 | pool = multiprocessing.Pool(processes = n_jobs) 164 | try: 165 | vdj_data = pool.starmap(chunkifyCDRseqs, l1) 166 | finally: 167 | pool.close() 168 | pool.join() 169 | 170 | if intermediate_file is not None: 171 | pickle.dump(vdj_data, open(intermediate_file,'wb')) 172 | 173 | print ("Flag 681.50 ", len(vdj_data)) 174 | 175 | lx = [] 176 | for md in ["tra","trb"]: 177 | for mw in [1.5, 1.75, 2, 2.25, 3, 5]: 178 | lx.extend([(w, l[i], kmer_type+":"+md, mw, w_surface_markers) for i,w in enumerate(vdj_data)]) 179 | 180 | 181 | print ("Flag 681.70 ", len(lx)) 182 | 183 | pool2 = multiprocessing.Pool(processes = n_jobs) 184 | try: 185 | lx2 = pool2.starmap(f_dataset_helper, lx) 186 | except: 187 | lx2 188 | finally: 189 | pool2.close() 190 | pool2.join() 191 | 192 | print ("Flag 681.80 ", len(lx2)) 193 | 194 | f_0_1_scaler = lambda v: (v-np.min(v))/(np.max(v)-np.min(v)) 195 | 196 | ly = [] 197 | rd = {} 198 | for i, v in enumerate(lx): 199 | _, k, md, mw, w_surface = v 200 | rd[(k, md, mw, w_surface)] = lx2[i][0] 201 | 202 | ly.append(f_0_1_scaler(lx2[i][0])) 203 | 204 | print ("Flag 681.85 ", len(ly), len(rd)) 205 | 206 | v_rd = pd.DataFrame(ly).median(axis=0).sort_values() 207 | return v_rd, rd 208 | 209 | except: 210 | return (None, None) 211 | 212 | 213 | 214 | 215 | def f_colprocess_helper(trx, binds, surface_markers, mw, w_surface_markers): 216 | try: 217 | g = [binds]; wg = [1]; tg=["feature_vector"] 218 | 219 | if w_surface_markers != 0: 220 | g.append(d_surface_markers.values) 221 | wg.append(w_surface_markers) 222 | tg.append("feature_vector") 223 | 224 | try: 225 | sys.path.append(os.path.join(sys.path[0],'../../schema')) 226 | import schema_qp 227 | except: 228 | from schema import schema_qp 229 | 230 | afx = schema_qp.SchemaQP(0.01, mw, params = {"require_nonzero_lambda":1, 231 | "scale_mode_uses_standard_scaler":1, 232 | "d0_type_is_feature_vector_categorical":1,}, 233 | mode="scale") 234 | afx.fit(trx,g,tg,wg) 235 | return (pd.Series(np.sqrt(afx._wts)), afx._soln_info) 236 | 237 | except: 238 | print ("Flag 67568.10 Saw exception in f_colprocess_helper") 239 | return (None, None) 240 | 241 | 242 | 243 | def do_columnwise_process(M, chain, n_jobs, intermediate_file=None, w_surface_markers=0): 244 | assert chain in ["tra","trb"] 245 | 246 | try: 247 | truthval_cols = [c for c in M.columns if 'binder' in c] 248 | surface_marker_cols = "CD3,CD19,CD45RA,CD4,CD8a,CD14,CD45RO,CD279_PD-1,IgG1,IgG2a,IgG2b,CD127,CD197_CCR7,HLA-DR".split(",") 249 | 250 | def f_pad20(s): 251 | return [ord(c)-ord('A')+1 for c in s[:20]] + [0]*max(0,20-len(s)) 252 | 253 | trx_L = []; binds_L = []; markers_L = [] 254 | for i in tqdm(range(M.shape[0])): #range(M.shape[0]): 255 | sl = M.at[i,"cell_clono_cdr3_aa"].split(";") 
256 |             for x in sl:
257 |                 if x[:4].lower() != (chain+":"): continue
258 |                 trx_L.append(f_pad20(x[4:]))
259 |                 binds_L.append(M.loc[i,truthval_cols].astype(int).values)
260 |                 markers_L.append(M.loc[i,surface_marker_cols].astype(int).values)
261 | 
262 |         trx = np.array(trx_L)
263 |         binds = np.array(binds_L)
264 |         surface_markers = np.array(markers_L)
265 | 
266 |         vdj_data = (trx, binds, surface_markers)
267 | 
268 |         print ("Flag 682.10 ", M.shape, chain, n_jobs)
269 | 
270 |         if intermediate_file is not None and os.path.exists(intermediate_file):
271 |             vdj_data = pickle.load(open(intermediate_file,'rb'))
272 |         elif intermediate_file is not None:
273 |             pickle.dump(vdj_data, open(intermediate_file,'wb'))
274 |         trx, binds, surface_markers = vdj_data
275 | 
276 |         print ("Flag 681.50 ", trx.shape)
277 | 
278 |         lx = []
279 |         for mw in [1.6, 1.8, 2, 2.2]:
280 |             lx.extend([(trx, binds, surface_markers, mw, w_surface_markers)])
281 | 
282 | 
283 |         print ("Flag 681.70 ", len(lx))
284 | 
285 |         pool2 = multiprocessing.Pool(processes = n_jobs)
286 |         try:
287 |             lx2 = pool2.starmap(f_colprocess_helper, lx)
288 |         except:
289 |             raise   # propagate worker errors; the finally clause still cleans up the pool
290 |         finally:
291 |             pool2.close()
292 |             pool2.join()
293 | 
294 |         print ("Flag 681.80 ", len(lx2))
295 | 
296 |         f_0_1_scaler = lambda v: (v-np.min(v))/(np.max(v)-np.min(v))
297 | 
298 |         ly = []
299 |         rd = {}
300 |         for i, v in enumerate(lx):
301 |             _, _, _, mw, w_surface = v
302 |             rd[(chain, mw, w_surface)] = lx2[i][0]
303 | 
304 |             ly.append(f_0_1_scaler(lx2[i][0]))
305 | 
306 |         print ("Flag 681.85 ", len(ly), len(rd))
307 | 
308 |         v_rd = pd.DataFrame(ly).median(axis=0).sort_values()
309 |         v_rd2 = pd.DataFrame(ly).mean(axis=0).sort_values()
310 |         return v_rd, v_rd2, rd
311 | 
312 |     except:
313 |         #raise
314 |         return (None, None, None)
315 | 
316 | 
317 | #################################################################################
318 | 
319 | if __name__ == "__main__":
320 |     sys.path.append(os.path.join(sys.path[0],'../../schema'))
321 |     from utils import SlideSeq
322 | 
323 |     import argparse
324 |     parser = argparse.ArgumentParser()
325 |     parser.add_argument("--mode", help="which code path to run. see main(..)
for details") 326 | parser.add_argument("--outdir", help="output directory (can set to '.')", type=str, default=".") 327 | parser.add_argument("--outsfx", help="suffix to use when producing output files") 328 | parser.add_argument("--style", help="mode-specific interpretation", type=int, default=0) 329 | parser.add_argument("--infile", help="path to the .h5 file containing processed 10X binarized-csv dataframe", default="/afs/csail.mit.edu/u/r/rsingh/work/schema/data/10x/processed/vdj_binarized_alldonors.h5") 330 | parser.add_argument("--njobs", help="number of parallel cores to use", type=int, default=24) 331 | 332 | parser.add_argument("--extra", help="put this as the LAST option and arbitrary space-separated key=val pairs after that", type=str, nargs='*') 333 | 334 | 335 | args = parser.parse_args() 336 | assert args.mode is not None 337 | extra_args = dict([a.split("=") for a in args.extra]) if args.extra else {} 338 | 339 | 340 | if args.mode== "raw_data_read": 341 | processRawData( os.path.dirname(args.infile).replace("/processed",""), args.infile) 342 | 343 | 344 | 345 | 346 | if args.mode == "compute_2_modality_selection_pressure" or args.mode== "compute_3_modality_selection_pressure": 347 | K = 2 if args.mode=="compute_2_modality_selection_pressure" else 3 348 | 349 | intmdt_file = "{0}/schema_median_wts_{3}-dataset_style{1}_intermediate_{2}.pkl".format(args.outdir, args.style, args.outsfx, K) 350 | 351 | M = pd.read_hdf(args.infile) 352 | 353 | l = [] 354 | for i in range(5): 355 | l += [(2+i, 4+i, min(7,2+1+i), min(11,4+1+i))] 356 | l += [(2,9,7,10)] 357 | l += [(2,10,7,11)] 358 | 359 | w_surface_markers = 0 360 | if args.mode=="compute_3_modality_selection_pressure": 361 | w_surface_markers = float(extra_args.get("w_surface_markers",-0.1)) 362 | 363 | # regular run, not for cross-validation 364 | if int(extra_args.get("separate_by_donor",0))==0 and int(extra_args.get("kfold_split",1))<=1: 365 | print ("Flag 679.10 ", M.shape, len(l), l) 366 | v_rd, rd = do_dataset_process(M, l, args.njobs, 367 | intermediate_file = intmdt_file, 368 | include_full_seq= int(extra_args.get("include_full_seq",0))==1, 369 | w_surface_markers = w_surface_markers, 370 | kmer_type = extra_args.get("kmer_type","")) 371 | 372 | 373 | csvfile = "{0}/schema_median_wts_{3}-dataset_style{1}_{2}.csv".format(args.outdir, args.style, args.outsfx, K) 374 | pklfile = "{0}/schema_median_wts_{3}-dataset_style{1}_schema-results_{2}.pkl".format(args.outdir, args.style, args.outsfx, K) 375 | 376 | try: 377 | v_rd.to_csv(csvfile, index=True, header=False) 378 | pickle.dump(rd, open(pklfile, "wb")) 379 | except: 380 | os.system("echo Error > {0}".format(csvfile)) 381 | traceback.print_exc(open(csvfile,'at')) 382 | os.system("echo Error > {0}".format(pklfile)) 383 | 384 | # cross-validation run, split by donor 385 | elif int(extra_args.get("separate_by_donor",0))>0 : 386 | dnrL = list(M["donor"].unique()) 387 | for dnr in ["ALL"] + dnrL: 388 | Mx = M if dnr=="ALL" else M[M["donor"]==dnr].copy().reset_index(drop=True) 389 | print ("Flag 679.20 ", Mx.shape, len(l), l) 390 | v_rd, rd = do_dataset_process(Mx, l, args.njobs, 391 | intermediate_file = None, 392 | include_full_seq= int(extra_args.get("include_full_seq",0))==1, 393 | w_surface_markers = w_surface_markers, 394 | kmer_type = extra_args.get("kmer_type","")) 395 | 396 | 397 | csvfile = "{0}/schema_median_wts_{4}-dataset_style{1}_donor-{3}_{2}.csv".format(args.outdir, args.style, args.outsfx, dnr, K) 398 | pklfile = 
"{0}/schema_median_wts_{4}-dataset_style{1}_donor-{3}_schema-results_{2}.pkl".format(args.outdir, args.style, args.outsfx, dnr, K) 399 | try: 400 | v_rd.to_csv(csvfile, index=True, header=False) 401 | pickle.dump(rd, open(pklfile, "wb")) 402 | except: 403 | os.system("echo Error > {0}".format(csvfile)) 404 | traceback.print_exc(open(csvfile,'at')) 405 | os.system("echo Error > {0}".format(pklfile)) 406 | 407 | 408 | # cross-validation run, split by epitopes into k subsets 409 | elif int(extra_args.get("kfold_split",1)) >1: 410 | n = M.shape[0] 411 | 412 | # split by binding specifities 413 | truthval_cols = [c for c in M.columns if 'binder' in c] 414 | Mbinds = M.loc[:,truthval_cols].values 415 | b1 = np.argmax(Mbinds, axis=1) 416 | vL = np.array_split(list(set(b1)), int(extra_args["kfold_split"])) 417 | print ("Flag 681.10 ", vL) 418 | idxL = [ np.ravel(np.nonzero(np.isin(b1, v))) for v in vL] 419 | 420 | # split randomly 421 | #X = np.arange(n).reshape(n,1) 422 | #kfold = sklearn.model_selection.KFold(n_splits = int(extra_args["kfold_split"]), random_state=0) 423 | #idxL = [idx for idx,_ in kfold.split(X)] 424 | 425 | for i, idx in enumerate(idxL + ["ALL"]): #(["ALL"] + idxL): 426 | if idx=="ALL": 427 | Mx = M 428 | nm = "ALL" 429 | else: 430 | Mx = M.iloc[idx,:].copy().reset_index(drop=True) 431 | #nm = "C{0}".format(i) 432 | nm = "B{0}".format(i) 433 | 434 | print ("Flag 681.20 ", Mx.shape, len(l), l) 435 | v_rd, rd = do_dataset_process(Mx, l, args.njobs, 436 | intermediate_file = None, 437 | include_full_seq= int(extra_args.get("include_full_seq",0))==1, 438 | w_surface_markers = w_surface_markers, 439 | kmer_type = extra_args.get("kmer_type","")) 440 | 441 | 442 | csvfile = "{0}/schema_median_wts_{4}-dataset_style{1}_kfold-{3}_{2}.csv".format(args.outdir, args.style, args.outsfx, nm, K) 443 | pklfile = "{0}/schema_median_wts_{4}-dataset_style{1}_kfold-{3}_schema-results_{2}.pkl".format(args.outdir, args.style, args.outsfx, nm, K) 444 | try: 445 | v_rd.to_csv(csvfile, index=True, header=False) 446 | pickle.dump(rd, open(pklfile, "wb")) 447 | except: 448 | os.system("echo Error > {0}".format(csvfile)) 449 | traceback.print_exc(open(csvfile,'at')) 450 | os.system("echo Error > {0}".format(pklfile)) 451 | 452 | 453 | 454 | 455 | if args.mode == "compute_2_modality_columnwise_preference" or args.mode == "compute_3_modality_columnwise_preference": 456 | K = 3 if args.mode=="compute_3_modality_columnwise_preference" else 2 457 | 458 | intmdt_file = "{0}/schema_columnwise_median_wts_{3}-dataset_style{1}_intermediate_{2}.pkl".format(args.outdir, args.style, args.outsfx, K) 459 | 460 | M = pd.read_hdf(args.infile) 461 | #M = M.iloc[:5000,:] #for testing 462 | 463 | w_surface_markers = 0 464 | if args.mode=="compute_3_modality_selection_pressure": 465 | w_surface_markers = float(extra_args.get("w_surface_markers",-0.1)) 466 | 467 | chain = extra_args.get("chain","tra") 468 | 469 | print ("Flag 683.20 ", M.shape, chain) 470 | v_rd, v_rd2, rd = do_columnwise_process(M, chain, args.njobs, 471 | intermediate_file = None, 472 | w_surface_markers = w_surface_markers) 473 | 474 | 475 | csvfile = "{0}/schema_columnwise_median_wts_{3}-dataset_style{1}_chain-{4}_{2}.csv".format(args.outdir, args.style, args.outsfx, K, chain) 476 | csvfile2 = "{0}/schema_columnwise_mean_wts_{3}-dataset_style{1}_chain-{4}_{2}.csv".format(args.outdir, args.style, args.outsfx, K, chain) 477 | pklfile = "{0}/schema_columnwise_median_wts_{3}-dataset_style{1}_schema-results_chain-{4}_{2}.pkl".format(args.outdir, args.style, 
args.outsfx, K, chain) 478 | try: 479 | v_rd.to_csv(csvfile, index=True, header=False) 480 | v_rd2.to_csv(csvfile2, index=True, header=False) 481 | pickle.dump(rd, open(pklfile, "wb")) 482 | except: 483 | os.system("echo Error > {0}".format(csvfile)) 484 | traceback.print_exc(open(csvfile,'at')) 485 | os.system("echo Error > {0}".format(csvfile2)) 486 | os.system("echo Error > {0}".format(pklfile)) 487 | 488 | 489 | -------------------------------------------------------------------------------- /deprecated/old_readme.md: -------------------------------------------------------------------------------- 1 | # Schema 2 | 3 | Schema is a general algorithm for integrating heterogeneous data 4 | modalities. It has been specially designed for multi-modal 5 | single-cell biological datasets, but should work in other contexts too. 6 | This version is based on a Quadratic Programming framework. 7 | 8 | 9 | It is described in the paper 10 | ["*Schema: A general framework for integrating heterogeneous single-cell modalities*"](https://www.biorxiv.org/content/10.1101/834549v1). 11 | 12 | 13 | 14 | The module provides a class SchemaQP that offers a sklearn type fit+transform API for affine 15 | transformations of input datasets such that the transformed data is in agreement 16 | with all the input datasets. 17 | 18 | ## Getting Started 19 | The examples provided here are also available in the examples/Schema_demo.ipynb notebook 20 | 21 | ### Installation 22 | ``` 23 | pip install schema_learn 24 | ``` 25 | 26 | ### Schema: A simple example 27 | 28 | For the examples below, you'll also need scanpy (`pip install scanpy`). 29 | We use `fast_tsne` below for visualization, but feel free to use your favorite tool. 30 | 31 | #### Sample data 32 | 33 | The data in the examples below is from the paper below; we thank the authors for making it available: 34 | * Tasic et al. [*Shared and distinct transcriptomic cell types across neocortical areas*.](https://www.nature.com/articles/s41586-018-0654-5) Nature. 2018 Nov;563(7729):72-78. doi:10.1038/s41586-018-0654-5 35 | 36 | We make available a processed subset of the data for demonstration and analysis. 37 | Linux shell commands to get this data: 38 | ``` 39 | wget http://schema.csail.mit.edu/datasets/Schema_demo_Tasic2018.h5ad.gz 40 | gunzip Schema_demo_Tasic2018.h5ad.gz 41 | ``` 42 | 43 | In Python, set the `DATASET_DIR` variable to the folder containing this file. 44 | 45 | The processing of raw data here broadly followed the steps in Kobak & Berens 46 | * https://www.biorxiv.org/content/10.1101/453449v1 47 | 48 | The gene expression data has been count-normalized and log-transformed. Load with the commands 49 | ```python 50 | import scanpy as sc 51 | adata = sc.read(DATASET_DIR + "/" + "Schema_demo_Tasic2018.h5ad") 52 | ``` 53 | 54 | #### Sample Schema usage 55 | 56 | Import Schema as: 57 | ```python 58 | from schema import SchemaQP 59 | afx = SchemaQP(0.75) # min_desired_corr is the only required argument. 60 | 61 | dx_pca = afx.fit_transform(adata.X, # primary dataset 62 | [adata.obs["class"].values], # just one secondary dataset 63 | ['categorical'] # has labels, i.e., is a categorical datatype 64 | ) 65 | ``` 66 | This uses PCA as the change-of-basis transform; requires a min corr of 0.75 between the 67 | primary dataset (gene expression) and the transformed dataset; and maximizes 68 | correlation between the primary dataset and the secondary dataset, supercluster 69 | (i.e. 
higher-level clusters) labels produced during Tasic et al.'s hierarchical clustering. 70 | 71 | 72 | ### More Schema examples 73 | * In all of what follows, the primary dataset is gene expression. The secondary datasets are 1) cluster IDs; and/or 2) cell-type "class" variables which correspond to superclusters (i.e. higher-level clusters) in the Tasic et al. paper. 74 | 75 | 76 | 77 | #### With NMF (Non-negative Matrix Factorization) as change-of-basis, a different min_desired_corr, and two secondary datasets 78 | 79 | 80 | ```python 81 | afx = SchemaQP(0.6, params= {"decomposition_model": "nmf", "num_top_components": 50}) 82 | 83 | dx_nmf = afx.fit_transform(adata.X, 84 | [adata.obs["class"].values, adata.obs.cluster_id.values], # two secondary datasets 85 | ['categorical', 'categorical'], # both are labels 86 | [10, 1] # relative wts 87 | ) 88 | ``` 89 | 90 | #### Now let's do something unusual. Perturb the data so it *disagrees* with cluster ids 91 | 92 | 93 | ```python 94 | afx = SchemaQP(0.97, # Notice that we bumped up the min_desired_corr so the perturbation is limited 95 | params = {"decomposition_model": "nmf", "num_top_components": 50}) 96 | 97 | dx_perturb = afx.fit_transform(adata.X, 98 | [adata.obs.cluster_id.values], # could have used both secondary datasets, but one's fine here 99 | ['categorical'], 100 | [-1] # This is key: we are putting a negative wt on the correlation 101 | ) 102 | ``` 103 | 104 | 105 | #### Recommendations for parameter settings 106 | * `min_desired_corr` and `w_max_to_avg` are the names for the hyperparameters $s_1$ and $\bar{w}$ from our paper 107 | * *min_desired_corr*: at first, you should try a range of values for `min_desired_corr` (e.g., 0.99, 0.90, 0.50). This will give you a sense of what might work well for your data; after this, you can progressively narrow down your range. In typical use-cases, high `min_desired_corr` values (> 0.80) work best. 108 | * *w_max_to_avg*: start by keeping this constraint very loose. This ensures that `min_desired_corr` remains the binding constraint. Later, as you get a better sense for `min_desired_corr` values, you can experiment with this too. A value of 100 is pretty high and should work well in the beginning. 109 | 110 | 111 | 112 | #### tSNE plots of the baseline and Schema transforms 113 | 114 | ```python 115 | fig = plt.figure(constrained_layout=True, figsize=(8,2), dpi=300) 116 | tmps = {} 117 | for i,p in enumerate([("Original", adata.X), 118 | ("PCA1 (pos corr)", dx_pca), 119 | ("NMF (pos corr)", dx_nmf), 120 | ("Perturb (neg corr)", dx_perturb) 121 | ]): 122 | titlestr, dx1 = p 123 | ax = fig.add_subplot(1,4,i+1, frameon=False) 124 | tmps[titlestr] = dy = fast_tsne(dx1, seed=42) 125 | ax = plt.gca() 126 | ax.set_aspect('equal', adjustable='datalim') 127 | ax.scatter(dy[:,0], dy[:,1], s=1, color=adata.obs['cluster_color']) 128 | ax.set_title(titlestr) 129 | ax.axis("off") 130 | ``` 131 | 132 | 133 | 134 | ## API 135 | 136 | ### Constructor 137 | Initializes the `SchemaQP` object 138 | 139 | #### Parameters 140 | 141 | `min_desired_corr`: `float` in [0,1) 142 | 143 | The minimum desired correlation between squared L2 distances in the transformed space 144 | and distances in the original space. 145 | 146 | 147 | RECOMMENDED VALUES: At first, you should try a range of values (e.g., 0.99, 0.90, 0.50). 148 | This will give you a sense of what might work well for your data. 149 | After this, you can progressively narrow down your range. 
150 | In typical use-cases of large biological datasets,
151 | high values (> 0.80) will probably work best.
152 | 
153 | 
154 | `w_max_to_avg`: `float` >1, optional (default: 100)
155 | 
156 | Sets the upper-bound on the ratio of w's largest element to w's avg element.
157 | Making it large will allow for more severe transformations.
158 | 
159 | RECOMMENDED VALUES: Start by keeping this constraint very loose; the default value (100) does
160 | this, ensuring that min_desired_corr remains the binding constraint.
161 | Later, as you get a better sense for the right min_desired_corr values
162 | for your data, you can experiment with this too.
163 | 
164 | To really constrain this, set it in the (1-5] range, depending on
165 | how many features you have.
166 | 
167 | 
168 | `params`: `dict` of key-value pairs, optional (see defaults below)
169 | 
170 | Additional configuration parameters.
171 | Here are the important ones:
172 |   * decomposition_model: "pca" or "nmf" (default=pca)
173 |   * num_top_components: (default=50) number of PCA (or NMF) components to use
174 |     when mode=="affine".
175 | 
176 | You can ignore the rest on your first pass; the default values are pretty reasonable:
177 |   * dist_npairs: (default=2000000). How many pt-pairs to use for computing pairwise distances.
178 |     value=None means compute exhaustively over all n*(n-1)/2 pt-pairs. Not recommended for n>5000.
179 |     Otherwise, the given number of pt-pairs is sampled randomly. The sampling is done
180 |     in a way in which each point will be represented roughly equally.
181 |   * scale_mode_uses_standard_scaler: 1 or 0 (default=0), apply the standard scaler
182 |     in the scaling mode
183 |   * do_whiten: 1 or 0 (default=1). When mode=="affine", should the change-of-basis loadings
184 |     be made 1-variance?
185 | 
186 | 
187 | `mode`: {`'affine'`, `'scale'`}, optional (default: `'affine'`)
188 | 
189 | Whether to perform a general affine transformation or just a scaling transformation.
190 | 
191 |  * 'scale' does scaling transformations only.
192 |  * 'affine' first does a mapping to PCA or NMF space (you can specify n_components).
193 |    It then does a scaling transform in that space and then maps everything back to the
194 |    regular space, the final result being an affine transformation.
195 | 
196 | RECOMMENDED VALUES: 'affine' is the default, which uses PCA or NMF to do the change-of-basis.
197 | You'll want 'scale' only in one of two cases:
198 |  1) You have some features on which you directly want Schema to compute
199 |     feature-weights.
200 |  2) You want to do a change-of-basis transform other than PCA or NMF. If so, you will
201 |     need to do that yourself and then call SchemaQP with the transformed
202 |     primary dataset with mode='scale'.
203 | 
204 | #### Returns
205 | 
206 | A SchemaQP object on which you can call fit(...), transform(...) or fit_transform(....).
207 | 
208 | 
209 | ### Fit
210 | Given the primary dataset 'd' and a list of secondary datasets, fit a linear transformation (d*) of
211 | 'd' such that the correlation between squared pairwise distances in d* and those in secondary datasets
212 | is maximized while the correlation between the primary dataset d and d* remains above
213 | min_desired_corr.
214 | 
215 | 
216 | #### Parameters
217 | 
218 | `d`: A numpy 2-d `array`
219 | 
220 | The primary dataset (e.g. scanpy/anndata's .X).
221 | The rows are observations (e.g., cells) and the cols are variables (e.g., gene expression).
222 | The default distance measure computed is L2: sum((point1-point2)**2). See d0_dist_transform.
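For instance, a distance transform can be passed directly in a `fit(...)` call. A minimal sketch (the arrays `d` and `labels` here are hypothetical placeholders):

```python
import numpy as np
from schema import SchemaQP

afx = SchemaQP(0.90)
afx.fit(d,                          # primary dataset, e.g., gene expression
        [labels],                   # one secondary dataset
        ['categorical'],
        d0_dist_transform=np.sqrt)  # correlate sqrt-dampened L2 distances instead of raw ones
```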
223 | 
224 | 
225 | `secondary_data_val_list`: `list` of 1-d or 2-d numpy `array`s, each with same number of rows as `d`
226 | 
227 | The secondary datasets you want to align the primary data towards.
228 | Columns in scanpy's .obs variables work well (just remember to use .values)
229 | 
230 | 
231 | `secondary_data_type_list`: `list` of `string`s, each value in {'numeric','feature_vector','categorical'}
232 | 
233 | The list's length should match the length of secondary_data_val_list
234 | 
235 |  * 'numeric' means you're giving one floating-pt value for each obs.
236 |    The default distance measure is L2: (point1-point2)**2
237 |  * 'feature_vector' means you're giving some multi-dimensional representation for each obs.
238 |    The default distance measure is L2: sum((point1-point2)**2)
239 |  * 'categorical' means that you are providing label information that should be compared for equality.
240 |    The default distance measure is: 1*(val1!=val2)
241 | 
242 | 
243 | `secondary_data_wt_list`: `list` of `float`s, optional (default: `None`)
244 | 
245 | User-specified wts for each dataset. If 'None', the wts are 1.
246 | If specified, the list's length should match the length of secondary_data_val_list
247 | 
248 | NOTE: you can try to get a mapping that *disagrees* with a secondary dataset instead of *agreeing* with it.
249 | To do so, pass in a negative number (e.g., -1) here. This works even if you have just one secondary
250 | dataset
251 | 
252 | 
253 | `d0`: A 1-d or 2-d numpy array, same number of rows as 'd', optional (default: `None`)
254 | 
255 | An alternative representation of the primary dataset.
256 | 
257 | HANDLE WITH CARE! Most likely, you don't need this parameter.
258 | This is useful if you want to provide the primary dataset in two forms: one for transforming and
259 | another one for computing pairwise distances to use in the QP constraint; if so, 'd' is used for the
260 | former, while 'd0' is used for the latter
261 | 
262 | 
263 | `d0_dist_transform`: a function that takes a non-negative float as input and
264 | returns a non-negative float, optional (default: `None`)
265 | 
266 | 
267 | HANDLE WITH CARE! Most likely, you don't need this parameter.
268 | The transformation to apply on d or d0's L2 distances before using them for correlations.
269 | 
270 | 
271 | `secondary_data_dist_transform`: `list` of functions, each taking a non-negative float and
272 | returning a non-negative float, optional (default: `None`)
273 | 
274 | HANDLE WITH CARE! Most likely, you don't need this parameter.
275 | The transformations to apply on secondary dataset's L2 distances before using them for correlations.
276 | If specified, the length of the list should match that of secondary_data_val_list
277 | 
278 | 
279 | #### Returns:
280 | 
281 | None
282 | 
283 | 
284 | ### Transform
285 | Given a dataset `d`, apply the fitted transform to it
286 | 
287 | 
288 | #### Parameters
289 | 
290 | `d`: a numpy 2-d array with same number of columns as primary dataset `d` in the fit(...)
291 | 
292 | The rows are observations (e.g., cells) and the cols are variables (e.g., gene expression).
293 | 
294 | 
295 | #### Returns
296 | 
297 | a 2-d numpy array with the same shape as `d`
298 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
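# Typical usage: "make html" renders the docs into build/html via the catch-all
# target below; extra sphinx-build flags can be passed through SPHINXOPTS,
# e.g. "make html SPHINXOPTS=-W" (an assumed invocation, shown for illustration).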
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/Schema-Overview-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/_static/Schema-Overview-v2.png -------------------------------------------------------------------------------- /docs/_static/Schema-Overview-v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/_static/Schema-Overview-v3.png -------------------------------------------------------------------------------- /docs/_static/Schema-webpage-logo-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/_static/Schema-webpage-logo-1.png -------------------------------------------------------------------------------- /docs/_static/Schema-webpage-logo-2-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/_static/Schema-webpage-logo-2-blue.png -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends "!footer.html" %} 2 | {% block extrafooter %} 3 |

The design of this documentation has been inspired by the documentation of the Python package scanpy

4 | {{ super() }} 5 | {% endblock %} 6 | 7 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/schema_atacrna_demo_dotplot1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_atacrna_demo_dotplot1.png -------------------------------------------------------------------------------- /docs/source/_static/schema_atacrna_demo_tsne1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_atacrna_demo_tsne1.png -------------------------------------------------------------------------------- /docs/source/_static/schema_atacrna_demo_wts1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_atacrna_demo_wts1.png -------------------------------------------------------------------------------- /docs/source/_static/schema_paired-tag_data-dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_paired-tag_data-dist.png -------------------------------------------------------------------------------- /docs/source/_static/schema_paired-tag_gene_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_paired-tag_gene_plots.png -------------------------------------------------------------------------------- /docs/source/_static/schema_paired-tag_go-annot.csv: -------------------------------------------------------------------------------- 1 | GO:0007416,synapse assembly 2 | GO:0050808,synapse organization 3 | GO:0051960,regulation of nervous system development 4 | GO:0007156,homophilic cell adhesion via plasma membrane adhesion molecules 5 | GO:0032989,cellular component morphogenesis 6 | GO:0007155,cell adhesion 7 | GO:0051239,regulation of multicellular organismal process 8 | GO:0022610,biological adhesion 9 | GO:0099177,regulation of trans-synaptic 
signaling 10 | GO:0050804,modulation of chemical synaptic transmission 11 | GO:0007399,nervous system development 12 | GO:0099536,synaptic signaling 13 | GO:0006810,transport 14 | GO:0042391,regulation of membrane potential 15 | GO:0007610,behavior 16 | GO:0098742,cell-cell adhesion via plasma-membrane adhesion molecules 17 | -------------------------------------------------------------------------------- /docs/source/_static/schema_paired-tag_umap-row1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_paired-tag_umap-row1.png -------------------------------------------------------------------------------- /docs/source/_static/schema_paired-tag_umap-row2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_paired-tag_umap-row2.png -------------------------------------------------------------------------------- /docs/source/_static/schema_paired-tag_umap-row3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/schema_paired-tag_umap-row3.png -------------------------------------------------------------------------------- /docs/source/_static/umap_flybrain_regular_r3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/umap_flybrain_regular_r3.png -------------------------------------------------------------------------------- /docs/source/_static/umap_flybrain_schema0.999-0.99_r3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/umap_flybrain_schema0.999-0.99_r3.png -------------------------------------------------------------------------------- /docs/source/_static/umap_flybrain_schema0.99_r3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/schema/cabdbb2a4069fd6848d5c0951b99f20bf10206f6/docs/source/_static/umap_flybrain_schema0.99_r3.png -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | API 2 | === 3 | 4 | .. automodule:: schema 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os.path 14 | import sys 15 | from datetime import datetime 16 | 17 | import matplotlib 18 | matplotlib.use('agg') 19 | 20 | #sys.path[:0] = [os.path.dirname(os.path.abspath(__file__)) + "/../../schema/"] 21 | #import schema_qp 22 | 23 | sys.path[:0] = [os.path.dirname(os.path.abspath(__file__)) + "/../../"] 24 | 25 | 26 | # -- Project information ----------------------------------------------------- 27 | 28 | project = 'Schema' 29 | copyright = '2020, Rohit Singh, Ashwin Narayan, Brian Hie' 30 | author = 'Rohit Singh, Ashwin Narayan, Brian Hie' 31 | 32 | # The full version, including alpha/beta/rc tags 33 | release = '0.1.0' 34 | 35 | master_doc = 'index' 36 | 37 | # -- General configuration --------------------------------------------------- 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc'] 43 | 44 | 45 | autoclass_content = 'both' 46 | 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['../_templates'] 50 | 51 | # List of patterns, relative to source directory, that match files and 52 | # directories to ignore when looking for source files. 53 | # This pattern also affects html_static_path and html_extra_path. 54 | exclude_patterns = [] 55 | 56 | 57 | # -- Options for HTML output ------------------------------------------------- 58 | 59 | # The theme to use for HTML and HTML Help pages. See the documentation for 60 | # a list of builtin themes. 61 | # 62 | html_theme = 'sphinx_rtd_theme' 63 | 64 | html_theme_options = { 65 | 'logo_only': True, 66 | } 67 | 68 | 69 | # Add any paths that contain custom static files (such as style sheets) here, 70 | # relative to this directory. They are copied after the builtin static files, 71 | # so a file named "default.css" will overwrite the builtin "default.css". 72 | html_static_path = ['../_static'] 73 | html_logo = '../_static/Schema-webpage-logo-2-blue.png' 74 | 75 | 76 | autodoc_mock_imports = ["scanpy"] 77 | 78 | 79 | #import schema 80 | -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ========= 3 | 4 | Ageing *Drosophila* brain 5 | ~~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | This is sourced from `Davie et al.`_ (*Cell* 2018, `GSE 107451`_) and contains scRNA-seq data from a collection of fly brain cells along with each cell's age (in days). It is a useful dataset for exploring a common scenario in multi-modal integration: scRNA-seq data aligned to a 1-dimensional secondary modality. Please see the `example in Visualization`_ where this dataset is used. 8 | 9 | .. code-block:: Python 10 | 11 | import schema 12 | adata = schema.datasets.fly_brain() 13 | 14 | 15 | Paired RNA-seq and ATAC-seq from mouse kidney cells 16 | ~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | This is sourced from `Cao et al.`_ (*Science* 2018, `GSE 117089`_) and contains paired RNA-seq and ATAC-seq data from a collection of mouse kidney cells. The AnnData object provided here has some additional processing done to remove very low count genes and peaks. This is a useful dataset for the case where one of the modalities is very sparse (here, ATAC-seq). Please see the example in `Paired RNA-seq and ATAC-seq`_ where this dataset is used. 19 | 20 | .. 
code-block:: Python
21 | 
22 |     import schema
23 |     adata = schema.datasets.scicar_mouse_kidney()
24 | 
25 | 
26 | 
27 | 
28 | 
29 | .. _Davie et al.: https://doi.org/10.1016/j.cell.2018.05.057
30 | .. _GSE 107451: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE107451
31 | .. _example in Visualization: https://schema-multimodal.readthedocs.io/en/latest/visualization/index.html#ageing-fly-brain
32 | .. _Cao et al.: https://doi.org/10.1126/science.aau0730
33 | .. _GSE 117089: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE117089
34 | .. _Paired RNA-seq and ATAC-seq: https://schema-multimodal.readthedocs.io/en/latest/recipes/index.html#paired-rna-seq-and-atac-seq
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../../README.rst
2 | 
3 | 
4 | .. toctree::
5 |    :maxdepth: 1
6 |    :hidden:
7 | 
8 |    overview
9 |    installation
10 |    api/index
11 |    recipes/index
12 |    visualization/index
13 |    datasets
14 |    references
15 | 
16 | 
17 | 
18 | Indices and tables
19 | ==================
20 | 
21 | * :ref:`genindex`
22 | * :ref:`modindex`
23 | * :ref:`search`
24 | 
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 | 
4 | 
5 | We recommend Python v3.6 or higher.
6 | 
7 | PyPI, Virtualenv, or Anaconda
8 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9 | 
10 | You can use ``pip`` (or ``pip3``):
11 | 
12 | .. code-block:: bash
13 | 
14 |     pip install schema_learn
15 | 
16 | 
17 | 
18 | Docker
19 | ~~~~~~
20 | 
21 | Schema has been designed to be compatible with the popular and excellent single-cell Python package, Scanpy_.
22 | We recommend installing the Docker image recommended_ by Scanpy maintainers and then using ``pip``, as described above, to install Schema in it.
23 | 
24 | 
25 | .. _Scanpy: http://scanpy.readthedocs.io
26 | 
27 | .. _recommended: https://scanpy.readthedocs.io/en/1.4.4.post1/installation.html#docker
--------------------------------------------------------------------------------
/docs/source/overview.rst:
--------------------------------------------------------------------------------
1 | Overview
2 | ========
3 | 
4 | 
5 | Schema is a general algorithm for integrating heterogeneous data
6 | modalities. While it has been specially designed for multi-modal
7 | single-cell biological datasets, it should work in other multi-modal
8 | contexts too.
9 | 
10 | .. image:: ../_static/Schema-Overview-v2.png
11 |    :width: 648
12 |    :alt: 'Overview of Schema'
13 | 
14 | Schema is designed for single-cell assays where multiple modalities have
15 | been *simultaneously* measured for each cell. For example, this could be
16 | simultaneously-assayed ("paired") scRNA-seq and scATAC-seq data, or a
17 | spatial-transcriptomics dataset (e.g. 10x Visium, Slideseq or
18 | STARmap). Schema can also be used with just a scRNA-seq dataset where some
19 | per-cell metadata is available (e.g., cell age, donor information, batch
20 | ID etc.). With this data, Schema can help perform analyses like:
21 | 
22 | * Characterize cells that look similar transcriptionally but differ
23 |   epigenetically.
24 | 
25 | * Improve cell-type inference by combining RNA-seq and ATAC-seq data.
26 | 
27 | * In spatially-resolved single-cell data, identify differentially
28 |   expressed genes (DEGs) specific to a spatial pattern.
29 | 30 | * **Improved visualizations**: tune t-SNE or UMAP plots to more clearly 31 | arrange cells along a desired manifold. 32 | 33 | * Simultaneously account for batch effects while also integrating 34 | other modalities. 35 | 36 | Intuition 37 | ~~~~~~~~~ 38 | 39 | To integrate multi-modal data, Schema takes a `metric learning`_ 40 | approach. Each modality is interpreted as a multi-dimensional space, with 41 | observations mapped to points in it (**B** in figure above). We associate 42 | a distance metric with each modality: the metric reflects what it means 43 | for cells to be similar under that modality. For example, Euclidean 44 | distances between L2-normalized expression vectors are a proxy for 45 | coexpression. Across the three graphs in the figure (**B**), the dashed and 46 | dotted lines indicate distances between the same pairs of 47 | observations. 48 | 49 | Schema learns a new distance metric between points, informed 50 | jointly by all the modalities. In Schema, we start by designating one 51 | high-confidence modality as the *primary* (i.e., reference) and the 52 | remaining modalities as *secondary*--- we've found scRNA-seq to typically 53 | be a good choice for the primary modality. Schema transforms the 54 | primary-modality space by scaling each of its dimensions so that the 55 | distances in the transformed space have a higher (or lower, if desired!) 56 | correlation with corresponding distances in the secondary modalities 57 | (**C,D** in the figure above). You can choose any distance metric for the 58 | secondary modalities, though the primary modality's metric needs to be Euclidean. 59 | The primary modality can be pre-transformed by 60 | a `PCA`_ or `NMF`_ transformation so that the scaling occurs in this latter 61 | space; this can often be more powerful because the major directions of variance are 62 | now axis-aligned and hence can be scaled independently.
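To make the dimension-scaling idea concrete, here is a tiny NumPy sketch of our own (illustrative only; none of these variable names are part of the Schema API). Scaling each feature by a weight changes all pairwise distances; Schema's quadratic program learns such weights so that the new distances agree better with the secondary modalities while staying above a user-set correlation (`min_desired_corr`) with the original distances:

.. code-block:: Python

    import numpy as np
    import scipy.spatial.distance, scipy.stats

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 5))               # primary modality, e.g., 5 NMF factors
    w = np.array([2.0, 1.5, 1.0, 0.5, 0.25])    # hypothetical per-dimension weights

    d_orig = scipy.spatial.distance.pdist(X)                # original pairwise distances
    d_new = scipy.spatial.distance.pdist(X * np.sqrt(w))    # distances after scaling by w

    # Schema's constraint keeps this correlation above min_desired_corr
    print(scipy.stats.pearsonr(d_orig, d_new)[0])

In the actual method, the weights are learned by the quadratic program rather than fixed by hand.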
63 | 64 | Advantages 65 | ~~~~~~~~~~ 66 | 67 | In generating a shared-space representation, Schema is similar to 68 | statistical approaches like CCA (canonical correlation analysis) and 69 | deep-learning methods like autoencoders (which map multiple 70 | representations into a shared latent space). Each of these approaches offers a 71 | different set of trade-offs. Schema, for instance, requires the output 72 | space to be a linear transformation of the primary modality. Doing so 73 | allows it to offer the following advantages: 74 | 75 | * **Interpretability**: Schema identifies which features of the primary 76 | modality were important in maximizing its agreement with the secondary 77 | modalities. If the features corresponded to genes (or principal components), 78 | this can directly be interpreted in terms of gene importances. 79 | 80 | * **Regularization**: single-cell data can be sparse and noisy. As we 81 | discuss in our `paper`_, unconstrained approaches like CCA and 82 | autoencoders seek to maximize the alignment between modalities without 83 | any other considerations. In doing so, they can pick up on artifacts 84 | rather than true biology. A key feature of Schema is its 85 | regularization: it enforces a limit on the distortion of the primary 86 | modality, making sure that the final result remains biologically 87 | informative. 88 | 89 | * **Speed and flexibility**: Schema is based on a fast quadratic 90 | programming approach that allows for substantial flexibility in the 91 | number of secondary modalities supported and their relative weights. Also, arbitrary 92 | distance metrics (i.e., kernels) are supported for the secondary modalities. 93 | 94 | 95 | Quick Start 96 | ~~~~~~~~~~~ 97 | 98 | Install via pip: 99 | 100 | .. code-block:: bash 101 | 102 | pip install schema_learn 103 | 104 | **Example**: correlate gene expression with developmental stage. We demonstrate use with AnnData objects here. 105 | 106 | .. code-block:: Python 107 | 108 | import schema 109 | adata = schema.datasets.fly_brain() # adata has scRNA-seq data & cell age 110 | 111 | sqp = schema.SchemaQP( min_desired_corr=0.99, # require 99% agreement with original scRNA-seq distances 112 | params= {'decomposition_model': 'nmf', 'num_top_components': 20} ) 113 | 114 | # correlate the gene expression with the 'age' parameter 115 | mod_X = sqp.fit_transform( adata.X, # primary modality 116 | [ adata.obs['age'] ], # list of secondary modalities 117 | [ 'numeric' ] ) # datatypes of secondary modalities 118 | 119 | gene_wts = sqp.feature_weights() # get a ranking of gene wts important to the alignment 120 | 121 | 122 | Paper & Code 123 | ~~~~~~~~~~~~ 124 | 125 | Schema is described in the paper *Schema: metric learning enables 126 | interpretable synthesis of heterogeneous single-cell modalities* 127 | (http://doi.org/10.1101/834549) 128 | 129 | Source code available at: https://github.com/rs239/schema 130 | 131 | 132 | .. _metric learning: https://en.wikipedia.org/wiki/Similarity_learning#Metric_learning 133 | .. _paper: https://doi.org/10.1101/834549 134 | .. _PCA: https://en.wikipedia.org/wiki/Principal_component_analysis 135 | .. _NMF: https://en.wikipedia.org/wiki/Non-negative_matrix_factorization 136 | -------------------------------------------------------------------------------- /docs/source/recipes/index.rst: -------------------------------------------------------------------------------- 1 | Data Integration Examples 2 | ========================= 3 | 4 | API-usage Examples 5 | ~~~~~~~~~~~~~~~~~~ 6 | 7 | *Note*: The code snippets below show how Schema could be used for hypothetical datasets and illustrate the API usage. In the next sections (`Paired RNA-seq and ATAC-seq`_, `Paired-Tag`_) and in `Visualization`_, we describe worked examples where we also provide the dataset to try things on. We are working to add more datasets. 8 | 9 | 10 | **Example** Correlate gene expression 1) positively with ATAC-seq data and 2) negatively with batch information. 11 | 12 | .. code-block:: Python 13 | 14 | atac_50d = sklearn.decomposition.TruncatedSVD(50).fit_transform( atac_cnts_sp_matrix) 15 | 16 | sqp = SchemaQP(min_desired_corr=0.9) 17 | 18 | # df_gene_exp is a pd.DataFrame, batch_id is a pd.Series; -1 means try to disagree 19 | mod_X = sqp.fit_transform( df_gene_exp, # gene expression dataframe: rows=cells, cols=genes 20 | [ atac_50d, batch_id], # batch_id can be a pd.Series or np.array. rows=cells 21 | [ 'feature_vector', 'categorical'], 22 | [ 1, -1]) # maximize combination of (agreement with ATAC-seq + disagreement with batch_id) 23 | 24 | gene_wts = sqp.feature_weights() # get gene importances
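Since the weights come back as a vector over genes, it is easy to inspect what drove the alignment. A minimal sketch of ours (not part of the API; it assumes the hypothetical `df_gene_exp` dataframe from above):

.. code-block:: Python

    import pandas as pd

    df_wts = pd.DataFrame({'gene': df_gene_exp.columns, 'wt': gene_wts})
    print(df_wts.sort_values('wt', ascending=False).head(10))  # top 10 genes by importance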
25 | 26 | 27 | 28 | **Example** Correlate gene expression with three secondary modalities. 29 | 30 | .. code-block:: Python 31 | 32 | sqp = SchemaQP(min_desired_corr = 0.9) # lower than the default, allowing greater distortion of the primary modality 33 | sqp.fit( adata.X, 34 | [ adata.obs['col1'], adata.obs['col2'], adata.obsm['Matrix1'] ], 35 | [ "categorical", "numeric", "feature_vector"]) # data types of the three modalities 36 | mod_X = sqp.transform( adata.X) # transform 37 | gene_wts = sqp.feature_weights() # get gene importances 38 | 39 | 40 | 41 | 42 | Paired RNA-seq and ATAC-seq 43 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 44 | 45 | Here, we integrate simultaneously assayed RNA- and ATAC-seq data from `Cao et al.'s`_ sci-CAR study of mouse kidney cells. Specifically, we'll try to do better cell-type inference by considering both RNA-seq and ATAC-seq data simultaneously. The original study has ground-truth labels for most of the cell types, allowing us to benchmark automatically-computed clusters (generated by Leiden clustering here). As we'll show, a key challenge here is that the ATAC-seq data is very sparse and noisy. Naively incorporating it with RNA-seq can actually be counter-productive--- the joint clustering from a naive approach can have a *lower* overlap with the ground truth labels than if we were to just use RNA-seq-based clustering. 46 | 47 | *Note*: This example involves generating Leiden clusters; you will need to install the *igraph* and *leidenalg* Python packages if you want to use them: 48 | 49 | .. code-block:: bash 50 | 51 | pip install igraph 52 | pip install leidenalg 53 | 54 | Let's start by getting the data. We have preprocessed the original dataset, done some basic cleanup, and put it into an AnnData object that you can download. Please remember to also cite the original study if you use this dataset. 55 | 56 | .. code-block:: Python 57 | 58 | import schema 59 | adata = schema.datasets.scicar_mouse_kidney() 60 | print(adata.shape, adata.uns['atac.X'].shape) 61 | print(adata.uns.keys()) 62 | 63 | As you see, we have stored the ATAC data (as a sparse numpy matrix) in the .uns slots of the anndata object. Also look at the *adata.obs* dataframe which has t-SNE coordinates, ground-truth cell type names (as assigned by Cao et al.) and cluster colors etc. You'll notice that some cells don't have ground truth assignments. When evaluating, we'll skip those. 64 | 65 | 66 | To use the ATAC-seq data, we reduce its dimensionality to 50. Instead of PCA, we apply *TruncatedSVD* since the ATAC counts matrix is sparse. 67 | 68 | .. code-block:: Python 69 | 70 | svd2 = sklearn.decomposition.TruncatedSVD(n_components= 50, random_state = 17) 71 | H2 = svd2.fit_transform(adata.uns["atac.X"]) 72 | 73 | 
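As a quick, optional sanity check (our addition, not part of the original walkthrough), you can verify just how sparse the ATAC matrix is and how much of its variance the 50 SVD components capture:

.. code-block:: Python

    atac = adata.uns['atac.X']
    print('non-zero fraction of ATAC counts:', atac.nnz / float(atac.shape[0] * atac.shape[1]))
    print('ATAC variance captured by 50 SVD components:', svd2.explained_variance_ratio_.sum())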
74 | Next, we run Schema. We choose RNA-seq as the primary modality because 1) it has lower noise than ATAC-seq, and 2) we want to investigate which of its features (i.e., genes) are important during the integration. We will first perform an NMF transformation on the RNA-seq data. For the secondary modality, we'll use the dimensionality-reduced ATAC-seq. We require a positive correlation between the two (`secondary_data_wt_list = [1]` below). **Importantly, we force Schema to generate a low-distortion transformation**: the correlation of distances between the original RNA-seq space and the transformed space (`min_desired_corr`) is required to be >99%. This low-distortion capability of Schema is crucial here, as we'll demonstrate. 75 | 76 | In the `params` settings below, the number of randomly sampled point-pairs has been bumped up to 5M (from the default of 2M). This helps with accuracy and doesn't cost too much computationally. We also turned off `do_whiten` (default=1, i.e., true). When `do_whiten=1`, Schema first rescales the PCA/NMF transformation so that each axis has unit variance; typically, doing so is "nice" from a theoretical/statistical perspective. But it can interfere with downstream analyses (e.g., Leiden clustering here). 77 | 78 | .. code-block:: Python 79 | 80 | sqp99 = schema.SchemaQP(0.99, mode='affine', params= {"decomposition_model":"nmf", 81 | "num_top_components":50, 82 | "do_whiten": 0, 83 | "dist_npairs": 5000000}) 84 | dz99 = sqp99.fit_transform(adata.X, [H2], ['feature_vector'], [1]) 85 | 86 | 87 | Let's look at the feature weights. Since we ran the code in 'affine' mode, the raw weights from the quadratic program will correspond to the 50 NMF factors. Three of these factors seem to stand out; most other weights are quite low. 88 | 89 | .. code-block:: Python 90 | 91 | plt.plot(sqp99._wts) 92 | 93 | 94 | .. image:: ../_static/schema_atacrna_demo_wts1.png 95 | :width: 300 96 | 97 | 98 | 99 | Schema offers a helper function to convert these NMF (or PCA) feature weights to gene weights. The function offers a few ways of doing so, but the default is to simply average the loadings across the top-k factors: 100 | 101 | .. code-block:: Python 102 | 103 | v99 = sqp99.feature_weights("top-k-loading", 3) 104 | 105 | 106 | Let's do a dotplot to visualize how the expression of these genes varies by cell type. We plot the top 10 genes by importance here. 107 | 108 | .. code-block:: Python 109 | 110 | dfv99 = pd.DataFrame({"gene": adata.var_names, "v":v99}).sort_values("v", ascending=False).reset_index(drop=True) 111 | sc.pl.dotplot(adata, dfv99.gene.head(10).tolist(),'cell_name_short', figsize=(8,6)) 112 | 113 | As you'll notice, these genes seem to be differentially expressed in PT cells, PBA and Ki-67+ cells. Essentially, these are cell types where ATAC-seq data was most informative. As we'll see shortly, it is precisely in these cells where Schema is able to offer the biggest improvement. 114 | 115 | .. image:: ../_static/schema_atacrna_demo_dotplot1.png 116 | :width: 500 117 | 118 | 119 | For a comparison later, let's also do a Schema run without a strong distortion control. Below, we set the `min_desired_corr` parameter to 0.10 (i.e., 10%). Thus, the ATAC-seq data will get to influence the transformation a lot more. 120 | 121 | .. code-block:: Python 122 | 123 | sqp10 = schema.SchemaQP(0.10, mode='affine', params= {"decomposition_model":"nmf", 124 | "num_top_components":50, 125 | "do_whiten": 0, 126 | "dist_npairs": 5000000}) 127 | dz10 = sqp10.fit_transform(adata.X, [H2], ['feature_vector'], [1]) 128 | 129 | 130 | Finally, let's do Leiden clustering of the RNA-seq, ATAC-seq, and the two Schema runs. We'll compare the cluster assignments to the ground truth cell labels. Intuitively, by combining RNA-seq and ATAC-seq, one should be able to get a more biologically accurate clustering. We visually evaluate the clusterings below; in the paper, we've supplemented this with more quantitative estimates. 131 | 132 | .. code-block:: Python 133 | 134 | import schema.utils 135 | fcluster = schema.utils.get_leiden_clustering # feel free to try your own clustering algo 136 | 137 | ld_cluster_rna = fcluster(sqp99._decomp_mdl.transform(adata.X.todense())) 138 | ld_cluster_atac = fcluster(H2) 139 | ld_cluster_sqp99 = fcluster(dz99) 140 | ld_cluster_sqp10 = fcluster(dz10) 141 | 142 | 
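If you'd like a quantitative readout to go along with the visual comparison below, one option is the adjusted Rand index (ARI) against the ground-truth labels, computed only on cells that have them. This is a sketch of our own, along the lines of what the paper reports; it assumes the `cell_name_short` column used below:

.. code-block:: Python

    import sklearn.metrics

    lbls = adata.obs['cell_name_short'].astype(str)
    mask = (lbls != 'nan').values  # skip cells without ground-truth labels
    for nm, cl in [('RNA', ld_cluster_rna), ('ATAC', ld_cluster_atac),
                   ('Schema-99%', ld_cluster_sqp99), ('Schema-10%', ld_cluster_sqp10)]:
        print(nm, sklearn.metrics.adjusted_rand_score(lbls[mask], cl[mask]))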
143 | .. code-block:: Python 144 | 145 | x = adata.obs.tsne_1 146 | y = adata.obs.tsne_2 147 | idx = adata.obs.rgb.apply(lambda s: isinstance(s,str) and '#' in s).values.tolist() #skip nan cells 148 | 149 | fig, axs = plt.subplots(3,2, figsize=(10,15)) 150 | axs[0][0].scatter(x[idx], y[idx], c=adata.obs.rgb.values[idx], s=1) 151 | axs[0][0].set_title('Ground Truth') 152 | axs[0][1].scatter(x[idx], y[idx], c=adata.obs.rgb.values[idx], s=1, alpha=0.1) 153 | axs[0][1].set_title('Ground Truth Labels') 154 | for c in np.unique(adata.obs.cell_name_short[idx]): 155 | if c=='nan': continue 156 | cx,cy = x[adata.obs.cell_name_short==c].mean(), y[adata.obs.cell_name_short==c].mean() 157 | axs[0][1].text(cx,cy,c,fontsize=10) 158 | axs[1][0].scatter(x[idx], y[idx], c=ld_cluster_rna[idx], cmap='tab20b', s=1) 159 | axs[1][0].set_title('RNA-seq') 160 | axs[1][1].scatter(x[idx], y[idx], c=ld_cluster_atac[idx], cmap='tab20b', s=1) 161 | axs[1][1].set_title('ATAC-seq') 162 | axs[2][0].scatter(x[idx], y[idx], c=ld_cluster_sqp99[idx], cmap='tab20b', s=1) 163 | axs[2][0].set_title('Schema-99%') 164 | axs[2][1].scatter(x[idx], y[idx], c=ld_cluster_sqp10[idx], cmap='tab20b', s=1) 165 | axs[2][1].set_title('Schema-10%') 166 | 167 | for ax in np.ravel(axs): ax.axis('off') 168 | 169 | 170 | 171 | Below, we show the figures in a 3x2 panel of t-SNE plots. In the first row, the left panel shows the cells colored by ground-truth cell types; the right panel is basically the same but lists the cell types explicitly. The next row shows cells colored by RNA- or ATAC-only clustering. Notice how noisy the ATAC-only clustering is! This is not a bug in our analysis--- less than 0.3% of ATAC count matrix entries are non-zero, and the sparsity of the ATAC data makes it difficult to produce high-quality cell type estimates. 172 | 173 | The third row shows cells colored by Schema-based clustering at 99% (left) and 10% (right) `min_desired_corr` thresholds. With Schema at a low-distortion setting (i.e., `min_desired_corr = 99%`), notice that PT cells and Ki-67+ cells, circled in red, are getting more correctly classified now. This improvement of the Schema-implied clustering over the RNA-seq-only clustering can be quantified by measuring the overlap with ground truth cell grouping, as we do in the paper. 174 | 175 | **This is a key strength of Schema** --- even with a modality that is sparse and noisy (like ATAC-seq here), it can nonetheless extract something of value from the noisy modality because the constraint on distortion of the primary modality acts as a regularization. This is also why we recommend that your highest-confidence modality be set as the primary. Lastly, as a demonstration, if we relax the distortion constraint by setting `min_desired_corr = 10%`, you'll notice that the noise of ATAC-seq data does swamp out the RNA-seq signal. With an unconstrained approach (e.g., CCA or some deep learning approaches), this ends up being a major challenge. 176 | 177 | .. image:: ../_static/schema_atacrna_demo_tsne1.png 178 | :width: 600 179 | 180 | 181 | 182 | Paired-Tag 183 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 184 | 185 | Here we synthesize simultaneously assayed RNA-seq, ATAC-seq and histone-modification data at a single-cell resolution, from the Paired-Tag protocol described in `Zhu et al.’s study`_ of adult mouse frontal cortex and hippocampus (Nature Methods, 2021). This is a fascinating dataset with five different histone modifications assayed separately (3 repressors and 2 activators), in addition to RNA-seq and ATAC-seq. 
As in the original study, we consider each of the histone modifications as a separate modality, implying a hepta-modal assay! 186 | 187 | Interestingly, though, the modalities are available only in pairwise combinations with RNA-seq: some cells were assayed for H3K4me1 & RNA-seq while another set of cells provided ATAC-seq & RNA-seq data, and so on. Here’s the overall distribution of non-RNA-seq modalities across 64,849 cells. 188 | 189 | 190 | .. image:: ../_static/schema_paired-tag_data-dist.png 191 | :width: 300 192 | 193 | This organization of data might be tricky to integrate with a method which expects *each* modality to be available for *all* cells and has difficulty accommodating partial coverage of some modalities. Of course, you could always fall back to an integrative approach that treats each modality’s cell population as independent, but then you miss out on the simultaneously-multimodal aspect of this data. 194 | 195 | With Schema, you can have your cake and eat it too! We do 6 two-way integrations (RNA-seq as the primary modality against each of the other modalities) using the subsets of cells available in each case. Schema’s interpretable and linear framework makes it easy to combine these. Once Schema computes the optimal transformation of RNA-seq that aligns it with, say, ATAC-seq, we apply that transformation to the entire RNA-seq dataset, including cells that do *not* have ATAC-seq data. 196 | 197 | Such full-dataset extensions of the pairwise syntheses can then be stacked together. Doing Leiden clustering on the result would enable us to infer cell types by integrating information from all modalities. As we will show below, Schema's synthesis helps improve the quality of cell type inference over what you could get just from RNA-seq. Similarly for feature selection, Schema's computed feature weights for each two-way synthesis can be averaged to get the genes important to the overall synthesis. In a completely automated fashion and without any knowledge of the tissue’s source or biology, we'll find that the genes Schema identifies as important turn out to be very relevant to neuronal function and disease. Ready for more? 198 | 199 | First, you will need the data. The original is available on GEO (`GSE152020`_) but the individual modalities are huge (e.g., the ATAC-seq peak-counts are in a 14,095 x 2,443,832 sparse matrix!). This is not unusual--- epigenetic modalities are typically very sparse (we discuss why this matters in `Paired RNA-seq and ATAC-seq`_). As a preprocessing step, we performed singular value decompositions (SVD) of these modalities and also reduced the RNA-seq data to its 4,000 highly variable genes. An AnnData object with this preprocessing is available here (please remember to also cite the original study if you use this dataset): 200 | 201 | .. code-block:: bash 202 | 203 | wget http://cb.csail.mit.edu/cb/schema/adata_dimreduced_paired-tag.pkl 204 | 205 | 206 | Let's load it in: 207 | 208 | .. code-block:: Python 209 | 210 | import schema, pickle, anndata, sklearn.metrics, numpy as np, pandas as pd 211 | import scanpy as sc 212 | 213 | # you may need to change the file location as appropriate to your setup 214 | adata = pickle.load(open("adata_dimreduced_paired-tag.pkl", "rb")) 215 | 216 | print (adata.shape, 217 | [(c, adata.uns['SVD_'+c].shape) for c in adata.uns['sec_modalities']]) 218 | 219 | 220 | As you see, we have stored the 50-dimensional SVDs of the secondary modalities in the :code:`.uns` slots of the anndata object. Also look at the :code:`adata.obs` dataframe which has UMAP coordinates, ground-truth cell type names (as assigned by Zhu et al.) etc. 
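To see the per-modality cell counts programmatically (the same information as the distribution figure above; a small sketch of ours), you can check how many cells each secondary modality covers:

.. code-block:: Python

    for desc in adata.uns['sec_modalities']:
        print(desc, adata.uns['SVD_' + desc].shape[0], 'cells')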
221 | 222 | 223 | We now do Schema runs for the 6 two-way modality combinations, with RNA-seq as the primary in each run. Each run will also store the transformation of the entire 64,849-cell RNA-seq dataset and the gene importances. 224 | 225 | 226 | .. code-block:: Python 227 | 228 | d_rna = adata.X.todense() 229 | 230 | desc2transforms = {} 231 | for desc in adata.uns['sec_modalities']: 232 | print(desc) 233 | 234 | # we mostly stick with the default settings, explicitly listed here for clarity 235 | sqp = schema.SchemaQP(0.99, mode='affine', params= {"decomposition_model": 'pca', 236 | "num_top_components":50, 237 | "do_whiten": 0, # this is different from default 238 | "dist_npairs": 5000000}) 239 | 240 | # extract the relevant subset 241 | idx1 = adata.obs['rowidx'][adata.uns["SVD_"+desc].index] 242 | prim_d = d_rna[idx1,:] 243 | sec_d = adata.uns["SVD_"+desc].values 244 | print(len(idx1), prim_d.shape, sec_d.shape) 245 | 246 | sqp.fit(prim_d, [sec_d], ['feature_vector'], [1]) # fit on the idx1 subset... 247 | dz = sqp.transform(d_rna) # ...then transform the full RNA-seq dataset 248 | 249 | desc2transforms[desc] = (sqp, dz, idx1, sqp.feature_weights(k=3)) 250 | 251 | 252 | **Cell type inference:** In each of the 6 runs above, :code:`dz` is a 64,849 x 50 matrix. We can horizontally stack these matrices for a 64,849 x 300 matrix that represents the transformation of RNA-seq data informed simultaneously by all 6 secondary modalities. 253 | 254 | .. code-block:: Python 255 | 256 | a6Xpca = np.hstack([dz for _,dz,_,_ in desc2transforms.values()]) 257 | adata_schema = anndata.AnnData(X=a6Xpca, obs=adata.obs) 258 | print (adata_schema.shape) 259 | 260 | We then perform Leiden clustering on the original and transformed data, computing the overlap with the expert marker-gene-based annotation by Zhu et al. 261 | 262 | 263 | .. code-block:: Python 264 | 265 | # original 266 | sc.pp.pca(adata) 267 | sc.pp.neighbors(adata) 268 | sc.tl.leiden(adata) 269 | 270 | # Schema-transformed 271 | # since Schema had already done PCA before it transformed, let's stick with its raw output 272 | sc.pp.neighbors(adata_schema, use_rep='X') 273 | sc.tl.leiden(adata_schema) 274 | 275 | # we'll do plots etc. with the original AnnData object 276 | adata.obs['leiden_schema'] = adata_schema.obs['leiden'].values 277 | 278 | # compute overlap with manual cell type annotations 279 | ari_orig = sklearn.metrics.adjusted_rand_score(adata.obs.Annotation, adata.obs.leiden) 280 | ari_schema= sklearn.metrics.adjusted_rand_score(adata.obs.Annotation, adata.obs.leiden_schema) 281 | 282 | print ("ARI: Orig: {} With Schema: {}".format( ari_orig, ari_schema)) 283 | 284 | 285 | As you can see, the ARI with Schema improved from 0.437 (using only RNA-seq) to 0.446 (using all modalities). Single-cell epigenetic modalities are very sparse, making it difficult to distinguish signal from noise. However, Schema's constrained approach allows it to extract signal from these secondary modalities nonetheless, a task which has otherwise been challenging (see the related discussion in our `paper`_ or in `Paired RNA-seq and ATAC-seq`_). 
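ARI is one reasonable overlap measure; as a quick cross-check (our addition, not part of the original analysis), you could also compute normalized mutual information on the same label lists:

.. code-block:: Python

    nmi_orig = sklearn.metrics.normalized_mutual_info_score(adata.obs.Annotation, adata.obs.leiden)
    nmi_schema = sklearn.metrics.normalized_mutual_info_score(adata.obs.Annotation, adata.obs.leiden_schema)
    print ("NMI: Orig: {} With Schema: {}".format(nmi_orig, nmi_schema))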
286 | 287 | Before we plot these clusters, we'll relabel the Schema-based Leiden clusters to match the labeling of the RNA-seq-only Leiden clusters; this will make their color schemes consistent. You will need to install the Python package *munkres* (:code:`pip install munkres`) for the related computation. 288 | 289 | 290 | .. code-block:: Python 291 | 292 | import munkres 293 | list1 = adata.obs['leiden'].astype(int).tolist() 294 | list2 = adata.obs['leiden_schema'].astype(int).tolist() 295 | 296 | contmat = sklearn.metrics.cluster.contingency_matrix(list1, list2) 297 | map21 = dict(munkres.Munkres().compute(contmat.max() - contmat)) 298 | adata.obs['leiden_schema_relabeled'] = [str(map21[a]) for a in list2] 299 | adata.obs['Schema_reassign'] = [('Same' if (map21[a]==a) else 'Different') for a in list2] 300 | 301 | for c in ['Annotation','Annot2', 'leiden', 'leiden_schema_relabeled', 'Schema_reassign']: 302 | sc.pl.umap(adata, color=c) 303 | 304 | .. image:: ../_static/schema_paired-tag_umap-row1.png 305 | :width: 800 306 | 307 | .. image:: ../_static/schema_paired-tag_umap-row2.png 308 | :width: 650 309 | 310 | 311 | It's also interesting to identify cells where the cluster assignments changed after multi-modal synthesis. As you can see, it's only in certain cell types that the epigenetic data suggests a different clustering than the primary RNA-seq modality. 312 | 313 | .. image:: ../_static/schema_paired-tag_umap-row3.png 314 | :width: 300 315 | 316 | **Gene set identification:** The feature importances output by Schema here identify the genes whose expression variations best agree with the epigenetic variations in these tissues. We first aggregate the feature importances across the 6 two-way runs: 317 | 318 | .. code-block:: Python 319 | 320 | df_genes = pd.DataFrame({'gene': adata.var.symbol}) 321 | for desc, (_,_,_,wts) in desc2transforms.items(): 322 | df_genes[desc] = wts 323 | df_genes['avg_wt'] = df_genes.iloc[:,1:].mean(axis=1) 324 | df_genes = df_genes.sort_values('avg_wt', ascending=False).reset_index(drop=True) 325 | 326 | gene_list = df_genes.gene.values 327 | 328 | sc.pl.umap(adata, color= gene_list[:6], gene_symbols='symbol', color_map='plasma', frameon=False, ncols=3) 329 | 330 | 331 | .. image:: ../_static/schema_paired-tag_gene_plots.png 332 | :width: 800 333 | 334 | Many of the top genes identified by Schema (e.g., `Erbb4`_, `Npas3`_, `Zbtb20`_, `Luzp2`_) are known to be relevant to neuronal function or disease. Note that all of this fell out of the synthesis directly--- we didn't do any differential expression analysis against an external background or provide the method some other indication that the data is from brain tissue. 
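For the GO analysis described next, we took the top 100 genes by Schema weight; a minimal sketch of ours for exporting such a list (the filename is arbitrary, and this snippet is illustrative rather than the exact command we used):

.. code-block:: Python

    df_genes.gene.head(100).to_csv('schema_top100_genes.txt', index=False, header=False)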
335 | 336 | We also did a GO enrichment analysis (via `Gorilla`_) of the top 100 genes by Schema weight. Here are the significant hits (FDR q-val < 0.1). Again, most GO terms relate to neuronal development, activity, and communication: 337 | 338 | 339 | .. csv-table:: GO Enrichment of Top Schema-identified genes 340 | :file: ../_static/schema_paired-tag_go-annot.csv 341 | :widths: 20, 80 342 | :header-rows: 0 343 | 344 | 345 | 346 | 347 | 348 | .. _Visualization: https://schema-multimodal.readthedocs.io/en/latest/visualization/index.html#ageing-fly-brain 349 | 350 | .. _Cao et al.'s: https://science.sciencemag.org/content/361/6409/1380/ 351 | 352 | .. _paper: https://genomebiology.biomedcentral.com/articles/10.1186/s13059-021-02313-2 353 | 354 | .. _Erbb4: https://www.ncbi.nlm.nih.gov/gene/2066 355 | 356 | .. _Npas3: https://www.ncbi.nlm.nih.gov/gene/64067 357 | 358 | .. _Zbtb20: https://www.ncbi.nlm.nih.gov/gene/26137 359 | 360 | .. _Luzp2: https://www.ncbi.nlm.nih.gov/gene/338645 361 | 362 | .. _Gorilla: http://cbl-gorilla.cs.technion.ac.il/ 363 | 364 | .. _Zhu et al.’s study: https://www.nature.com/articles/s41592-021-01060-3 365 | 366 | .. _GSE152020: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE152020 367 | 368 | -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ========== 3 | 4 | **Code**: `Github`_ repo 5 | 6 | **Paper**: If you use Schema, please consider citing *Schema: metric learning enables interpretable synthesis of heterogeneous single-cell modalities* (http://doi.org/10.1101/834549) 7 | 8 | **Project Website**: http://schema.csail.mit.edu 9 | 10 | 11 | 12 | .. _Github: https://github.com/rs239/schema 13 | 14 | 15 | -------------------------------------------------------------------------------- /docs/source/visualization/index.rst: -------------------------------------------------------------------------------- 1 | Visualization Examples 2 | ====================== 3 | 4 | Popular tools like `t-SNE`_ and `UMAP`_ can produce intuitive and appealing 5 | visualizations. However, since they perform opaque non-linear transformations of 6 | the input data, it can be unclear how to "tweak" the visualization to 7 | accentuate a specific aspect of the input. Also, it can sometimes 8 | be difficult to understand which features (e.g. genes) of the input were most important to getting 9 | the plot. 10 | 11 | Schema can help with both of these issues. With scRNA-seq data as the primary 12 | modality, Schema can transform it by infusing additional information into it 13 | while preserving a high level of similarity with the original data. When 14 | t-SNE/UMAP are applied on the transformed data, we have found that the 15 | broad contours of the original plot are preserved while the new 16 | information is also reflected. Furthermore, the relative weight of the new data 17 | can be calibrated using the `min_desired_corr` parameter of Schema. 18 | 19 | Ageing fly brain 20 | ~~~~~~~~~~~~~~~~ 21 | 22 | Here, we tweak the UMAP plot of `Davie et al.'s`_ ageing fly brain data to 23 | accentuate cell age. 24 | 25 | First, let's get the data and do a regular UMAP plot. 26 | 27 | .. code-block:: Python 28 | 29 | import schema 30 | import scanpy as sc 31 | import anndata 32 | 33 | def sc_umap_pipeline(bdata, fig_suffix): 34 | sc.pp.pca(bdata) 35 | sc.pp.neighbors(bdata, n_neighbors=15) 36 | sc.tl.umap(bdata) 37 | sc.pl.umap(bdata, color='age', color_map='coolwarm', save='_{}.png'.format(fig_suffix) ) 38 | 39 | 40 | .. code-block:: Python 41 | 42 | adata = schema.datasets.fly_brain() # adata has scRNA-seq data & cell age 43 | sc_umap_pipeline(adata, 'regular') 44 | 45 | This should produce a plot like this, where cells are colored by age. 46 | 47 | .. image:: ../_static/umap_flybrain_regular_r3.png 48 | :width: 300 49 | 50 | Next, we apply Schema to infuse cell age into the scRNA-seq data, while 51 | preserving a high level of correlation with the original scRNA-seq 52 | distances. We start by requiring a minimum 99.9% correlation with the original 53 | scRNA-seq distances. 54 | 
55 | .. code-block:: Python 56 | 57 | sqp = schema.SchemaQP( min_desired_corr=0.999, # require 99.9% agreement with original scRNA-seq distances 58 | params= {'decomposition_model': 'nmf', 'num_top_components': 20} ) 59 | 60 | mod999_X = sqp.fit_transform( adata.X, [ adata.obs['age'] ], ['numeric']) # correlate gene expression with the age 61 | sc_umap_pipeline( anndata.AnnData( mod999_X, obs=adata.obs), '0.999' ) 62 | 63 | We then loosen the `min_desired_corr` constraint a tiny bit, to 99%. 64 | 65 | .. code-block:: Python 66 | 67 | sqp.reset_mincorr_param(0.99) # we can re-use the NMF transform (which takes more time than the quadratic program) 68 | 69 | mod990_X = sqp.fit_transform( adata.X, [ adata.obs['age'] ], ['numeric']) 70 | sc_umap_pipeline( anndata.AnnData( mod990_X, obs=adata.obs), '0.990' ) 71 | 72 | diffexp_gene_wts = sqp.feature_weights() # get a ranking of genes important to the alignment 73 | 74 | These runs should produce a pair of plots like the ones shown below. Note 75 | how cell-age progressively stands out as a characteristic feature. We also 76 | encourage you to try out other choices of `min_desired_corr` (e.g., 0.90 77 | or 0.7); these will show the effect of allowing greater distortions of the 78 | primary modality. 79 | 80 | .. image:: ../_static/umap_flybrain_schema0.999-0.99_r3.png 81 | :width: 620 82 | 83 | This example also illustrates Schema's interpretability. The variable 84 | `diffexp_gene_wts` identifies the genes most important to aligning 85 | scRNA-seq with cell age. As we describe in our `paper`_, these genes turn 86 | out to be differentially expressed between young cells and old cells. 87 | 88 | 89 | 90 | 91 | .. _Davie et al.'s: https://doi.org/10.1016/j.cell.2018.05.057 92 | .. _paper: https://doi.org/10.1101/834549 93 | .. _t-SNE: https://lvdmaaten.github.io/tsne/ 94 | .. _UMAP: https://umap-learn.readthedocs.io/en/latest/ 95 | -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | Examples are moving... 2 | ~~~~~~~~~~~~~~~~~~~~~~ 3 | 4 | We recommend looking at the `documentation`_ page, which will be the primary location for examples going forward. 5 | 6 | .. _documentation: https://schema-multimodal.readthedocs.io/en/latest/overview.html 7 | -------------------------------------------------------------------------------- /examples/Schema_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Preamble" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "\n", 20 | "import warnings\n", 21 | "with warnings.catch_warnings():\n", 22 | " warnings.simplefilter(\"ignore\")\n", 23 | " import scanpy as sc\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "## local paths etc. 
You'll want to change these\n", 33 | "DATASET_DIR = \"/scratch1/rsingh/work/schema/data/tasic-nature\"\n", 34 | "import sys; sys.path.extend(['/scratch1/rsingh/tools','/afs/csail.mit.edu/u/r/rsingh/work/schema/'])\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "#### Import Schema and tSNE\n", 42 | "We use fast-tsne here, but use whatever you like" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from fast_tsne import fast_tsne\n", 52 | "from schema import SchemaQP" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "### Get example data \n", 60 | " * This data is from Tasic et al. (Nature 2018, DOI: 10.1038/s41586-018-0654-5 )\n", 61 | " * Shell commands to get our copy of the data:\n", 62 | " * wget http://schema.csail.mit.edu/datasets/Schema_demo_Tasic2018.h5ad.gz\n", 63 | " * gunzip Schema_demo_Tasic2018.h5ad.gz\n", 64 | " * The processing of raw data here broadly followed the steps in Kobak & Berens, https://www.biorxiv.org/content/10.1101/453449v1\n", 65 | " * The gene expression data has been count-normalized and log-transformed. \n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "adata = sc.read(DATASET_DIR + \"/\" + \"Schema_demo_Tasic2018.h5ad\")" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Schema examples\n", 82 | " * In all of what follows, the primary dataset is gene expression. The secondary datasets are 1) cluster IDs; and 2) cell-type \"class\" variables which correspond to superclusters (i.e. higher-level clusters) in the Tasic et al. paper.\n", 83 | "#### Recommendations for parameter settings\n", 84 | " * min_desired_corr and w_max_to_avg are the names for the hyperparameters $s_1$ and $\\bar{w}$ from our paper\n", 85 | " * *min_desired_corr*: at first, you should try a range of values for min_desired_corr (e.g., 0.99, 0.90, 0.50). This will give you a sense of what might work well for your data; after this, you can progressively narrow down your range. In typical use-cases, high min_desired_corr values (> 0.80) work best.\n", 86 | " * *w_max_to_avg*: start by keeping this constraint very loose. This ensures that min_desired_corr remains the binding constraint. Later, as you get a better sense for min_desired_corr values, you can experiment with this too. 
A value of 100 is pretty high and should work well in the beginning.\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "#### With PCA as change-of-basis, min_desired_corr=0.75, positive correlation with secondary datasets" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "afx = SchemaQP(0.75) # min_desired_corr is the only required argument.\n", 103 | "\n", 104 | "dx_pca = afx.fit_transform(adata.X, # primary dataset\n", 105 | " [adata.obs[\"class\"].values], # one secondary dataset\n", 106 | " ['categorical'] #it has labels, i.e., is a categorical datatype\n", 107 | " )" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "#### Similar to above, with NMF as change-of-basis and a different min_desired_corr" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "afx = SchemaQP(0.6, params= {\"decomposition_model\": \"nmf\", \"num_top_components\": 50})\n", 124 | "\n", 125 | "dx_nmf = afx.fit_transform(adata.X,\n", 126 | " [adata.obs[\"class\"].values, adata.obs.cluster_id.values], # two secondary datasets \n", 127 | " ['categorical', 'categorical'], # both are labels\n", 128 | " [10, 1] # relative wts\n", 129 | " )" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "#### Now let's do something unusual. Perturb the data so it *disagrees* with cluster ids" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "afx = SchemaQP(0.97, # Notice that we bumped up the min_desired_corr so the perturbation is limited \n", 146 | " params = {\"decomposition_model\": \"nmf\", \"num_top_components\": 50})\n", 147 | "\n", 148 | "dx_perturb = afx.fit_transform(adata.X,\n", 149 | " [adata.obs.cluster_id.values], # could have used both secondary datasets, but one's fine here\n", 150 | " ['categorical'],\n", 151 | " [-1] # This is key: we are putting a negative wt on the correlation\n", 152 | " )" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "### tSNE plots of the baseline and Schema transforms " 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "fig = plt.figure(constrained_layout=True, figsize=(8,2), dpi=300)\n", 169 | "tmps = {}\n", 170 | "for i,p in enumerate([(\"Original\", adata.X), \n", 171 | " (\"PCA1 (pos corr)\", dx_pca), \n", 172 | " (\"NMF (pos corr)\", dx_nmf), \n", 173 | " (\"Perturb (neg corr)\", dx_perturb)\n", 174 | " ]):\n", 175 | " titlestr, dx1 = p \n", 176 | " ax = fig.add_subplot(1,4,i+1, frameon=False)\n", 177 | " tmps[titlestr] = dy = fast_tsne(dx1, seed=42)\n", 178 | " ax = plt.gca()\n", 179 | " ax.set_aspect('equal', adjustable='datalim')\n", 180 | " ax.scatter(dy[:,0], dy[:,1], s=1, color=adata.obs['cluster_color'])\n", 181 | " ax.set_title(titlestr)\n", 182 | " ax.axis(\"off\")" 183 | ] 184 | } 185 | ], 186 | "metadata": { 187 | "kernelspec": { 188 | "display_name": "Python 3", 189 | "language": "python", 190 | "name": "python3" 191 | }, 192 | "language_info": { 193 | "codemirror_mode": { 194 | "name": "ipython", 195 | "version": 3 196 | }, 197 | "file_extension": ".py", 198 | "mimetype": 
"text/x-python", 199 | "name": "python", 200 | "nbconvert_exporter": "python", 201 | "pygments_lexer": "ipython3", 202 | "version": "3.6.1" 203 | }, 204 | "toc": { 205 | "base_numbering": 1, 206 | "nav_menu": {}, 207 | "number_sections": true, 208 | "sideBar": true, 209 | "skip_h1_title": false, 210 | "title_cell": "Table of Contents", 211 | "title_sidebar": "Contents", 212 | "toc_cell": false, 213 | "toc_position": {}, 214 | "toc_section_display": true, 215 | "toc_window_display": false 216 | }, 217 | "varInspector": { 218 | "cols": { 219 | "lenName": 16, 220 | "lenType": 16, 221 | "lenVar": 40 222 | }, 223 | "kernels_config": { 224 | "python": { 225 | "delete_cmd_postfix": "", 226 | "delete_cmd_prefix": "del ", 227 | "library": "var_list.py", 228 | "varRefreshCmd": "print(var_dic_list())" 229 | }, 230 | "r": { 231 | "delete_cmd_postfix": ") ", 232 | "delete_cmd_prefix": "rm(", 233 | "library": "var_list.r", 234 | "varRefreshCmd": "cat(var_dic_list()) " 235 | } 236 | }, 237 | "types_to_exclude": [ 238 | "module", 239 | "function", 240 | "builtin_function_or_method", 241 | "instance", 242 | "_Feature" 243 | ], 244 | "window_display": false 245 | } 246 | }, 247 | "nbformat": 4, 248 | "nbformat_minor": 2 249 | } 250 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | pandas 4 | scipy 5 | seaborn 6 | cvxopt 7 | tables 8 | tqdm 9 | scikit-learn 10 | umap-learn 11 | legacy-api-wrap 12 | setuptools_scm 13 | packaging 14 | sinfo 15 | scanpy 16 | -------------------------------------------------------------------------------- /schema/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .schema_qp import SchemaQP 3 | from .schema_base_config import schema_loglevel 4 | 5 | __all__ = ['SchemaQP', 'schema_loglevel'] 6 | 7 | 8 | import pkgutil, pathlib, importlib 9 | 10 | # from pkgutil import iter_modules 11 | # from pathlib import Path 12 | # from importlib import import_module 13 | 14 | # https://julienharbulot.com/python-dynamical-import.html 15 | # iterate through the modules in the current package 16 | # 17 | package_dir = str(pathlib.Path(__file__).resolve().parent) 18 | 19 | for (_, module_name, _) in pkgutil.iter_modules([package_dir]): 20 | if 'datasets' in module_name: 21 | module = importlib.import_module(f"{__name__}.{module_name}") 22 | -------------------------------------------------------------------------------- /schema/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._datasets import fly_brain, scicar_mouse_kidney 3 | 4 | __all__ = ['fly_brain', 'scicar_mouse_kidney'] 5 | 6 | -------------------------------------------------------------------------------- /schema/datasets/_datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 4 | ################################################################### 5 | ## Primary Author: Rohit Singh rsingh@alum.mit.edu 6 | ## Co-Authors: Ashwin Narayan, Brian Hie {ashwinn,brianhie}@mit.edu 7 | ## License: MIT 8 | ## Repository: http://github.io/rs239/schema 9 | ################################################################### 10 | 11 | import sys, copy, os, warnings 12 | 13 | with warnings.catch_warnings(): 14 | warnings.simplefilter("ignore") 15 | import scanpy 16 | 17 | 18 | # #### local directory 
imports #### 19 | # oldpath = copy.copy(sys.path) 20 | # sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))+"/../") 21 | 22 | # from schema_base_config import * 23 | 24 | # sys.path = copy.copy(oldpath) 25 | # #### 26 | 27 | 28 | 29 | 30 | def fly_brain(): 31 | """ Anndata object containing scRNA-seq data of the ageing Drosophila brain (GSE107451, Davie et al., Cell 2018) 32 | """ 33 | 34 | adata = scanpy.read("datasets/Davie_fly_brain.h5", backup_url="http://cb.csail.mit.edu/cb/schema/datasets/Davie_fly_brain.h5") 35 | return adata 36 | 37 | 38 | def scicar_mouse_kidney(): 39 | """ Anndata object containing scRNA-seq+ATAC-seq data of mouse kidney cells from the Sci-CAR study (GSE117089, Cao et al., Science 2018) 40 | """ 41 | 42 | adata = scanpy.read("datasets/Cao_mouse_kidney.h5", backup_url="http://cb.csail.mit.edu/cb/schema/datasets/Cao_mouse_kidney.h5") 43 | return adata 44 | 45 | -------------------------------------------------------------------------------- /schema/schema_base_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | 5 | schema_loglevel = logging.WARNING #WARNING 6 | 7 | 8 | def schema_debug(*args, **kwargs): 9 | if schema_loglevel <= logging.DEBUG: print("DEBUG: ", *args, **kwargs) 10 | 11 | def schema_info(*args, **kwargs): 12 | if schema_loglevel <= logging.INFO: print("INFO: ", *args, **kwargs) 13 | 14 | def schema_warning(*args, **kwargs): 15 | if schema_loglevel <= logging.WARNING: print("WARNING: ", *args, **kwargs) 16 | 17 | def schema_error(*args, **kwargs): 18 | if schema_loglevel <= logging.ERROR: print("ERROR: ", *args, **kwargs) 19 | 20 | 21 | ########## for maintenance ################### 22 | # def noop(*args, **kwargs): 23 | # pass 24 | # 25 | # logging.info = print 26 | # logging.debug = noop 27 | ############################################## 28 | -------------------------------------------------------------------------------- /schema/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import scipy, sklearn, os, sys, string, fileinput, glob, re, math, itertools, functools, copy, multiprocessing 6 | import scipy.stats, sklearn.decomposition, sklearn.preprocessing, sklearn.covariance, sklearn.neighbors 7 | from scipy.stats import describe 8 | from scipy import sparse 9 | import os.path 10 | import scipy.sparse 11 | from scipy.sparse import csr_matrix, csc_matrix 12 | from sklearn.preprocessing import normalize 13 | from collections import defaultdict 14 | 15 | 16 | #### local directory imports #### 17 | oldpath = copy.copy(sys.path) 18 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 19 | 20 | from schema_base_config import * 21 | 22 | sys.path = copy.copy(oldpath) 23 | #### 24 | 25 | 26 | def get_leiden_clustering(mtx, num_neighbors=30): 27 | import igraph 28 | gNN = igraph.Graph() 29 | N = mtx.shape[0] 30 | gNN.add_vertices(N) 31 | schema_debug("Flag 192.30 ", mtx.shape) 32 | mNN = scipy.sparse.coo_matrix( sklearn.neighbors.kneighbors_graph(mtx, num_neighbors)) 33 | schema_debug("Flag 192.40 ", mNN.shape) 34 | gNN.add_edges([ (i,j) for i,j in zip(mNN.row, mNN.col)]) 35 | schema_debug("Flag 192.50 ") 36 | import leidenalg as la 37 | p1 = la.find_partition(gNN, la.ModularityVertexPartition, seed=0) 38 | schema_debug("Flag 192.60 ") 39 | return np.array([a for a in p1.membership]) 40 | 41 | 42 | 43 | class 
ScatterPlotAligner: 44 | """ 45 | When seeded with a Mx2 matrix (that will be used for a scatter plot), 46 | will allow you to map other Mx2 matrices so that the original scatter plot and the 47 | new scatter plot will be in a similar orientation and positioning 48 | 49 | !!!!! DEPRECATED - do not use !!!!! 50 | """ 51 | 52 | def __init__(self): 53 | """ 54 | you will need to seed it before you can use it 55 | """ 56 | self._is_seeded = False 57 | 58 | 59 | def seed(self, D): 60 | """ 61 | the base matrix to be used for seeding. Should be a 2-d numpy type array. 62 | IMPORTANT: a transformed version of this will be returned to you by the call. Use that for plots. 63 | 64 | Returns: 1) Kx2 transformed matrix and, 2) (xmin, xmax, ymin, ymax) to use with plt.xlim & plt.ylim 65 | """ 66 | assert (not self._is_seeded) # Trying to seed an already-seeded ScatterPlotAligner 67 | assert D.shape[1] == 2 68 | 69 | self._is_seeded = True 70 | self._K = D.shape[0] 71 | 72 | mcd = sklearn.covariance.MinCovDet() 73 | mcd.fit(D) 74 | 75 | v = mcd.mahalanobis(D) 76 | self._valid_idx = (v < np.quantile(v,0.75)) #focus on the 75% of non-outlier data 77 | 78 | self._seed = D.copy() 79 | 80 | # translate and scale _seed so it's centered nicely around the _valid_idx's 81 | seed_mu = self._seed[self._valid_idx].mean(axis=0) 82 | self._seed = self._seed - seed_mu 83 | 84 | pt_magnitudes = (np.sum((self._seed[self._valid_idx])**2, axis=1))**0.5 85 | 86 | d_std = pt_magnitudes.std() 87 | self._seed = self._seed * (2.0/d_std) #have 1-SD pts be roughly at dist 2 from origin 88 | 89 | # get the bounding boxes 90 | u1 = 2*self._seed[self._valid_idx].min(axis=0) 91 | self._xmin, self._ymin = u1[0], u1[1] 92 | u2 = 2*self._seed[self._valid_idx].max(axis=0) 93 | self._xmax, self._ymax = u2[0], u2[1] 94 | 95 | return self._seed, (self._xmin, self._xmax, self._ymin, self._ymax) 96 | 97 | 98 | 99 | def map(self, D): 100 | """ 101 | D is a 2-d numpy-type array to be mapped to the seed. require D.shape == seed.shape 102 | 103 | IMPORTANT: seed() should be called first. 104 | Subsequent calls to map() will transform their inputs to match the seed's orientation 105 | 106 | Returns: a transformed version of D that you should use for plotting 107 | """ 108 | assert self._is_seeded 109 | assert D.shape[0] == self._K 110 | assert D.shape[1] == 2 111 | 112 | dx = D.copy() 113 | 114 | # translate and scale dx so it's centered nicely around the _valid_idx's 115 | dx_mu = dx[self._valid_idx].mean(axis=0) 116 | dx = dx - dx_mu 117 | 118 | pt_magnitudes = (np.sum((dx[self._valid_idx])**2, axis=1))**0.5 119 | 120 | d_std = pt_magnitudes.std() 121 | dx = dx * (2.0/d_std) 122 | 123 | #get the rotation matrix 124 | # from https://igl.ethz.ch/projects/ARAP/svd_rot.pdf (pg5) and http://post.queensu.ca/~sdb2/PAPERS/PAMI-3DLS-1987.pdf 125 | 126 | H = np.matmul( dx[self._valid_idx].T, self._seed[self._valid_idx]) 127 | #H = np.matmul( self._seed[self._valid_idx].T, dx[self._valid_idx]) 128 | 129 | import numpy.linalg 130 | 131 | U,S,V = numpy.linalg.svd(H, full_matrices=True) 132 | det_R0 = np.linalg.det(np.matmul(V, U.T)) 133 | R1 = np.diagflat(np.ones(2)) 134 | R1[1,1] = det_R0 135 | 136 | R = np.matmul(V, np.matmul(R1, U.T)) 137 | 138 | dx.iloc[:,:] = np.transpose( np.matmul(R, dx.T)) 139 | 140 | return dx #, (self._xmin, self._xmax, self._ymin, self._ymax) 141 | 142 | 143 | 144 | def sparse_read_csv(filename, index_col=None, verbose=False): 145 | """ 146 | csr read of file. Does a 100-row load first to get data-types. 
147 | returns sparse matrix, rows and cols. 148 | based on https://github.com/berenslab/rna-seq-tsne 149 | 150 | #### Parameters 151 | 152 | `filename`: `string` 153 | 154 | 155 | 156 | `index_col`: `string` (default=`None`) 157 | 158 | name of column that serves as index_col. Passed unchanged to read_csv(...,index_col=index_col, ...) 159 | 160 | 161 | `verbose: `boolean` (default=`False`) 162 | 163 | say stuff 164 | """ 165 | 166 | small_chunk = pd.read_csv(filename, nrows=100) 167 | coltypes = dict(enumerate([a.name for a in small_chunk.dtypes.values])) 168 | 169 | indexL = [] 170 | chunkL = [] 171 | with open(filename) as file: 172 | for i,chunk in enumerate(pd.read_csv(filename, chunksize=3000, index_col=index_col, dtype=coltypes)): 173 | if verbose: print('.', end='', flush=True) 174 | if i==0: 175 | colsL = list(chunk.columns) 176 | indexL.extend(list(chunk.index)) 177 | chunkL.append(sparse.csr_matrix(chunk.values.astype(float))) 178 | 179 | mat = sparse.vstack(chunkL, 'csr') 180 | if verbose: print(' done') 181 | return mat, np.array(indexL), colsL 182 | 183 | 184 | 185 | def fast_csv_read(filename, *args, **kwargs): 186 | """ 187 | fast csv read. Like sparse_read_csv but returns a dense Pandas DF 188 | """ 189 | 190 | small_chunk = pd.read_csv(filename, nrows=50) 191 | if small_chunk.index[0] == 0: 192 | coltypes = dict(enumerate([a.name for a in small_chunk.dtypes.values])) 193 | return pd.read_csv(filename, dtype=coltypes, *args, **kwargs) 194 | else: 195 | coltypes = dict((i+1,k) for i,k in enumerate([a.name for a in small_chunk.dtypes.values])) 196 | coltypes[0] = str 197 | return pd.read_csv(filename, index_col=0, dtype=coltypes, *args, **kwargs) 198 | 199 | 200 | 201 | 202 | class SlideSeq: 203 | """ 204 | Utility class for Slide-Seq (Rodriques et al., Science 2019) data. Most methods are static. 205 | """ 206 | 207 | @staticmethod 208 | def loadRawData(datadir, puckid, num_nmf_factors=100, prep_for_benchmarking=False): 209 | """ 210 | Load data for a particular puck, clean it up a bit and store as AnnData. For later use, also performs a NMF and stores those. 211 | Borrows code from autoNMFreg_windows.py, provided with the Slide-Seq raw data. 
212 | """ 213 | from sklearn.preprocessing import StandardScaler 214 | 215 | puckdir = "{0}/Puck_{1}".format(datadir, puckid) 216 | beadmapdir = max(glob.glob("{0}/BeadMapping_*-*_????".format(puckdir)), key=os.path.getctime) 217 | schema_debug("Flag 314.001 ", beadmapdir) 218 | 219 | # gene exp 220 | gexp_file = "{0}/MappedDGEForR.csv".format(beadmapdir) 221 | dge = fast_csv_read(gexp_file, header = 0, index_col = 0) 222 | # for faster testing runs, use below, it has just the first 500 cols of the gexp_file 223 | ## dge = fast_csv_read("/tmp/a1_dge.csv", header = 0, index_col = 0) 224 | dge = dge.T 225 | dge = dge.reset_index() 226 | dge = dge.rename(columns={'index':'barcode'}) 227 | schema_debug("Flag 314.010 ", dge.shape, dge.columns) 228 | 229 | # spatial location 230 | beadloc_file = "{0}/BeadLocationsForR.csv".format(beadmapdir) 231 | coords = fast_csv_read(beadloc_file, header = 0) 232 | coords = coords.rename(columns={'Barcodes':'barcode'}) 233 | coords = coords.rename(columns={'barcodes':'barcode'}) 234 | schema_debug("Flag 314.020 ", coords.shape, coords.columns) 235 | 236 | # Slide-Seq cluster assignments 237 | atlas_clusters_file = "{0}/AnalogizerClusterAssignments.csv".format(beadmapdir) 238 | clstrs = pd.read_csv(atlas_clusters_file, index_col=None) 239 | assert list(clstrs.columns) == ["Var1","x"] 240 | clstrs.columns = ["barcode","atlas_cluster"] 241 | clstrs = clstrs.set_index("barcode") 242 | schema_debug("Flag 314.030 ", clstrs.shape, clstrs.columns) 243 | 244 | df_merged = dge.merge(coords, right_on='barcode', left_on='barcode') 245 | df_merged = df_merged[ df_merged.barcode.isin(clstrs.index)] 246 | schema_debug("Flag 314.040 ", df_merged.shape, df_merged.columns) 247 | 248 | # remove sparse gene exp 249 | counts = df_merged.drop(['xcoord', 'ycoord'], axis=1) 250 | counts2 = counts.copy(deep=True) 251 | counts2 = counts2.set_index('barcode') #.drop('barcode',axis=1) 252 | counts2_okcols = counts2.sum(axis=0) > 0 253 | counts2 = counts2.loc[:, counts2_okcols] 254 | UMI_threshold = 5 255 | counts2_umis = counts2.sum(axis=1).values 256 | counts2 = counts2.loc[counts2_umis > UMI_threshold,:] 257 | schema_debug("Flag 314.0552 ", counts.shape, counts2.shape, counts2_umis.shape,isinstance(counts2, pd.DataFrame)) 258 | 259 | #slide-seq authors normalize to have sum=1 across each bead, rather than 1e6 260 | cval = counts2_umis[counts2_umis>UMI_threshold] 261 | if not prep_for_benchmarking: 262 | counts2 = counts2.divide(cval, axis=0) #np.true_divide(counts2, counts2_umis[:,None]) 263 | #counts2 = np.true_divide(counts2, counts2_umis[:,None]) 264 | 265 | # this is also a little unusual, but I'm following their practice 266 | counts2.iloc[:,:] = StandardScaler(with_mean=False).fit_transform(counts2.values) 267 | schema_debug("Flag 314.0553 ", counts2.shape, counts2_umis.shape,isinstance(counts2, pd.DataFrame)) 268 | 269 | coords2 = df_merged.loc[ df_merged.barcode.isin(counts2.index), ["barcode","xcoord","ycoord"]].copy(deep=True) 270 | coords2 = coords2.set_index('barcode') #.drop('barcode', axis=1) 271 | schema_debug("Flag 314.0555 ", coords2.shape,isinstance(coords2, pd.DataFrame)) 272 | 273 | ok_barcodes = set(coords2.index) & set(counts2.index) & set(clstrs.index) 274 | schema_debug("Flag 314.060 ", coords2.shape, counts2.shape, clstrs.shape, len(ok_barcodes)) 275 | 276 | if prep_for_benchmarking: 277 | return (counts2[counts2.index.isin(ok_barcodes)].sort_index(), coords2[coords2.index.isin(ok_barcodes)].sort_index(), clstrs[clstrs.index.isin(ok_barcodes)].sort_index()) 
278 | 279 | ## do NMF 280 | K1 = num_nmf_factors 281 | listK1 = ["P{}".format(i+1) for i in range(K1)] 282 | random_state = 17 #for repeatability, a fixed value 283 | model1 = sklearn.decomposition.NMF(n_components= K1, init='random', random_state = random_state, alpha = 0, l1_ratio = 0) 284 | Ho = model1.fit_transform(counts2.values) #yes, slideseq code had Ho and Wo mixed up. Just following their lead here. 285 | Wo = model1.components_ 286 | 287 | schema_debug("Flag 314.070 ", Ho.shape, Wo.shape) 288 | 289 | Ho_norm = StandardScaler(with_mean=False).fit_transform(Ho) 290 | Ho_norm = pd.DataFrame(Ho_norm) 291 | Ho_norm.index = counts2.index 292 | Ho_norm.columns = listK1 293 | Wo = pd.DataFrame(Wo) 294 | Wo.index = listK1; Wo.index.name = "Factor" 295 | Wo.columns = list(counts2.columns) 296 | 297 | Ho_norm = Ho_norm[Ho_norm.index.isin(ok_barcodes)] 298 | Ho_norm = Ho_norm / Ho_norm.std(axis=0) 299 | 300 | schema_debug("Flag 314.080 ", Ho_norm.shape, Wo.shape) 301 | 302 | genexp = counts2[ counts2.index.isin(ok_barcodes)].sort_index() 303 | beadloc = coords2[ coords2.index.isin(ok_barcodes)].sort_index() 304 | clstrs = clstrs[ clstrs.index.isin(ok_barcodes)].sort_index() 305 | Ho_norm = Ho_norm.sort_index() 306 | 307 | schema_debug("Flag 314.090 ", genexp.shape, beadloc.shape, clstrs.shape, Ho_norm.shape, genexp.index[:5], beadloc.index[:5]) 308 | 309 | beadloc["atlas_cluster"] = clstrs["atlas_cluster"] 310 | 311 | if "AnnData" not in dir(): 312 | from anndata import AnnData 313 | 314 | adata = AnnData(X = genexp.values, obs = beadloc, uns = {"Ho": Ho_norm, "Ho.index": list(Ho_norm.index), "Ho.columns": list(Ho_norm.columns), 315 | "Wo": Wo, "Wo.index": list(Wo.index), "Wo.columns": list(Wo.columns)}) 316 | return adata 317 | 318 | 319 | 320 | @staticmethod 321 | def loadAnnData(fpath): 322 | """ 323 | Import a h5ad file. Also deals with some scanpy weirdness when loading dataframes in the .uns slot 324 | """ 325 | 326 | import matplotlib 327 | matplotlib.use("agg") #otherwise scanpy tries to use tkinter which has issues importing 328 | import scanpy as sc 329 | 330 | adata = sc.read(fpath) 331 | for k in ["Ho","Wo"]: 332 | adata.uns[k] = pd.DataFrame(adata.uns[k], index=adata.uns[k + ".index"], columns= adata.uns[k + ".columns"]) 333 | return adata 334 | 335 | 336 | 337 | 338 | class SciCar: 339 | """ 340 | Utility class for Sci-Car (Cao et al., Science 2018) data. Most methods are static. 341 | """ 342 | 343 | 344 | @staticmethod 345 | def loadData(path, gsm_info, refdata): 346 | """ 347 | Load sci-CAR data format, as uploaded to GEO. Written for DOI:10.1126/science.aau0730 348 | but will probably work with other Shendure Lab datasets as well 349 | 350 | #### Parameters 351 | 352 | `path`: `string` 353 | 354 | directory where the files are 355 | 356 | 357 | `gsm_info`: `list` of (`string`, `string`, `string`, `function`, `function`) tuples 358 | 359 | Each tuple corresponds to one dataset (e.g. RNA-seq or ATAC-seq) 360 | The list should be in order of importance. Once a feature name for a modality shows up, 361 | it will be ignored in subsequent tuples 362 | The cells will be taken as the intersection of "sample" column from the various cell files 363 | 364 | tuple = (name, modality, gsm_id, f_cell_filter, f_mdlty_filter) 365 | 366 | name: your name for the dataset. 
class SciCar:
    """
    Utility class for sci-CAR (Cao et al., Science 2018) data. Most methods are static.
    """


    @staticmethod
    def loadData(path, gsm_info, refdata):
        """
        Load the sci-CAR data format, as uploaded to GEO. Written for DOI:10.1126/science.aau0730
        but will probably work with other Shendure Lab datasets as well

        #### Parameters

        `path`: `string`

            directory where the files are


        `gsm_info`: `list` of (`string`, `string`, `string`, `function`, `function`) tuples

            Each tuple corresponds to one dataset (e.g. RNA-seq or ATAC-seq).
            The list should be in order of importance: once a feature name for a modality shows up,
            it will be ignored in subsequent tuples.
            The cells will be taken as the intersection of the "sample" columns from the various cell files.

            tuple = (name, modality, gsm_id, f_cell_filter, f_mdlty_filter)

            name: your name for the dataset.
            modality: 'gene' or 'peak' for now
            gsm_id: GSMXXXXXX id
            f_cell_filter: None, or a function that returns a boolean given a cell row-vector; used to filter cells
            f_mdlty_filter: None, or a function that returns a boolean given a feature row-vector; used to filter features


        `refdata`: `string`: path to a tab-separated file

            Read as a dataframe containing Ensembl IDs ("ensembl_id"), TSS start/end etc.

        #### Returns

        AnnData object where cells are rows. Dataset #1's features form the columns of .X;
        the matrices of the remaining datasets are stored in .uns dicts keyed by the given dataset names.
        The per-cell info (e.g., cell line) is in .obs, while dataset #1's feature info is in .var
        """

        dref = pd.read_csv(refdata, sep="\t", low_memory=False)
        dref["tss_adj"] = np.where(dref["strand"]=="+", dref["txstart"], dref["txend"])
        dref = dref["ensembl_id,chr,strand,txstart,txend,tss_adj,map_location,Symbol".split(",")]

        if "AnnData" not in dir():
            from anndata import AnnData

        datasets = []
        cell_sets = []
        for nm, typestr, gsm, f_cell_filter, f_mdlty_filter in gsm_info:
            try:
                assert typestr in ['gene','peak']

                cells = pd.read_csv(glob.glob("{0}/{1}*_cell.txt".format(path, gsm))[0], low_memory=False)
                mdlty = pd.read_csv(glob.glob("{0}/{1}*_{2}.txt".format(path, gsm, typestr))[0], low_memory=False)


                cell_idx = np.full((cells.shape[0],), True)
                if f_cell_filter is not None:
                    cell_idx = (cells.apply(f_cell_filter, axis=1, result_type="reduce")).values

                if typestr=="gene":
                    mdlty_idx = np.full((mdlty.shape[0],), True)
                    if f_mdlty_filter is not None:
                        mdlty_idx = (mdlty.apply(f_mdlty_filter, axis=1, result_type="reduce")).values

                    schema_debug ("Flag 324.112 filtered {0}:{1}:{2} mdlty".format(len(mdlty_idx), np.sum(mdlty_idx), mdlty_idx.shape))

                    mdlty["gene_id"] = mdlty["gene_id"].apply(lambda v: v.split('.')[0])
                    mdlty_idx[~(mdlty["gene_id"].isin(dref["ensembl_id"]))] = False

                    schema_debug ("Flag 324.113 filtered {0}:{1}:{2} mdlty".format(len(mdlty_idx), np.sum(mdlty_idx), mdlty_idx.shape))

                    mdlty["index"] = np.arange(mdlty.shape[0])
                    mdlty = (pd.merge(mdlty, dref, left_on="gene_id", right_on="ensembl_id", how="left"))
                    mdlty = mdlty.drop_duplicates("index").sort_values("index").reset_index(drop=True).drop(columns=["index"])

                    schema_debug ("Flag 324.114 filtered {0}:{1}:{2} mdlty".format(len(mdlty_idx), np.sum(mdlty_idx), mdlty_idx.shape))

                    def f_is_standard_chromosome(v):
                        try:
                            assert v[:3]=="chr" and ("_" not in v) and (v[3] in ['X','Y'] or int(v[3:])<=22 )
                            return True
                        except:
                            return False

                    mdlty_idx[~(mdlty["chr"].apply(f_is_standard_chromosome))] = False

                    schema_debug ("Flag 324.115 filtered {0}:{1}:{2} mdlty".format(len(mdlty_idx), np.sum(mdlty_idx), mdlty_idx.shape))

                else:
                    mdlty_idx = np.full((mdlty.shape[0],), True)
                    if f_mdlty_filter is not None:
                        mdlty_idx = (mdlty.apply(f_mdlty_filter, axis=1, result_type="reduce")).values



                # data = pd.DataFrame(data = scipy.io.mmread(glob.glob("{0}/{1}*_count.txt".format(path, gsm))[0]).T.tocsr().astype(np.float_),
                #                     index = cells['sample'],
                #                     columns = mdlty['mdlty_short_name'])

                data = scipy.io.mmread(glob.glob("{0}/{1}*_{2}_count.txt".format(path, gsm, typestr))[0]).T.tocsr().astype(np.float_)
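                # Note: mmread returns the matrix as written on disk (features x cells
                # here), so the .T above flips it to cells x features, and .tocsr()
                # makes the later row-wise (per-cell) filtering and sorting cheap.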

                schema_debug ("Flag 324.12 read {0} cells, {1} mdlty, {2} data".format(cells.shape, mdlty.shape, data.shape))


                schema_debug ("Flag 324.13 filtered {0}:{1}:{2} cells, {3}:{4}:{5} mdlty".format( len(cell_idx), np.sum(cell_idx), cell_idx.shape, len(mdlty_idx), np.sum(mdlty_idx), mdlty_idx.shape))

                data = data[cell_idx, :]
                data = data[:, mdlty_idx]
                cells = cells[cell_idx].reset_index(drop=True)
                mdlty = mdlty[mdlty_idx].reset_index(drop=True)

                schema_debug ("Flag 324.14 filtered down to {0} cells, {1} mdlty, {2} data".format(cells.shape, mdlty.shape, data.shape))

                sortidx = np.argsort(cells['sample'].values)
                cells = cells.iloc[sortidx,:].reset_index(drop=True)
                data = data[sortidx, :]

                schema_debug ("Flag 324.17 \n {0} \n {1}".format( cells.head(2), data[:2,:2]))

                datasets.append((nm, typestr, data, cells, mdlty))

                cell_sets.append(set(cells['sample'].values))

            except:
                raise
                #raise ValueError('{0} could not be read in {1}'.format(nm, path))

        common_cells = list(set.intersection(*cell_sets))
        schema_debug ("Flag 324.20 got {0} common cells {1}".format( len(common_cells), common_cells[:10]))

        # note: logcpm is defined here but deliberately not applied below; normalization happens later, in preprocessAnnData(...)
        def logcpm(dx):
            libsizes = 1e-6 + np.ravel(np.sum(dx, axis=1))
            dxout = copy.deepcopy(dx)
            for i in range(dxout.shape[0]):
                # rescale each CSR row's nonzero entries in place
                i0,i1 = dxout.indptr[i], dxout.indptr[i+1]
                dxout.data[i0:i1] = np.log2(dxout.data[i0:i1]*1e6/libsizes[i] + 1)
            return dxout


        for i, dx in enumerate(datasets):
            nm, typestr, data, cells, mdlty = dx

            cidx = np.in1d(cells["sample"].values, common_cells)
            schema_debug ("Flag 324.205 got {0} {1} {2} {3}".format( nm, cells.shape, len(cidx), np.sum(cidx)))

            data = data[cidx,:]
            cells = cells.iloc[cidx,:].reset_index(drop=True)
            mdlty = mdlty.set_index('gene_short_name' if typestr=='gene' else 'peak')

            if i==0:
                cells = cells.set_index('sample')
                #adata = AnnData(X = logcpm(data), obs = cells.copy(deep=True), var = mdlty.copy(deep=True))
                adata = AnnData(X = data, obs = cells.copy(deep=True), var = mdlty.copy(deep=True))
                adata.uns["names"] = []

                schema_debug ("Flag 324.22 got X {0} obs {1} var {2} uns {3}".format( adata.X.shape, adata.obs.shape, adata.var.shape, list(adata.uns.keys())))
            else:
                for c in cells.columns:
                    if c in adata.obs.columns: continue
                    adata.obs[c] = cells[c]
                #adata.uns[nm + ".X"] = logcpm(data)
                adata.uns[nm + ".X"] = data
                adata.uns[nm + ".var"] = mdlty.copy(deep=True)
                adata.uns[nm + ".var.index"] = list(mdlty.index) # scanpy is annoying: it will convert these to numpy matrices when writing
                adata.uns[nm + ".var.columns"] = list(mdlty.columns)

            adata.obs[typestr + "_log_sumcounts"] = np.log2(np.ravel(np.sum(data, axis=1)) + 1)
            adata.uns["names"].append(nm)
            adata.uns[nm + ".type"] = typestr
            adata.var_names_make_unique()
            schema_debug ("Flag 324.25 got X {0} obs {1} var {2} uns {3}".format( adata.X.shape, adata.obs.shape, adata.var.shape, list(adata.uns.keys())))

        return adata

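    # Example usage (a sketch; the directory, GSM ids, and the filter lambda's
    # "experiment" field are placeholders/assumptions, not real identifiers):
    #
    #   gsm_info = [("rna",  "gene", "GSMxxxxxxx", lambda row: row["experiment"]=="coassay", None),
    #               ("atac", "peak", "GSMyyyyyyy", None, None)]
    #   adata = SciCar.loadData("/path/to/geo_files", gsm_info, "/path/to/gene_ref.tsv")
    #   adata = SciCar.preprocessAnnData(adata, do_logcpm=True)
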

    @staticmethod
    def loadAnnData(fpath):
        """
        Import an h5ad file, rebuilding the .uns dataframes that scanpy flattens to numpy arrays on write
        """
        import matplotlib
        matplotlib.use("agg")  # otherwise scanpy may try to use tkinter, which can fail to import
        import scanpy as sc

        adata = sc.read(fpath)
        for k in adata.uns.keys():
            if k.endswith(".var") and isinstance(adata.uns[k], np.ndarray):
                schema_debug("Flag 343.100 ", k, adata.uns.keys(), type(adata.uns[k]))
                adata.uns[k] = pd.DataFrame(adata.uns[k], index=adata.uns[k + ".index"], columns=adata.uns[k + ".columns"])
                schema_debug("Flag 343.102 ", k, adata.uns.keys(), type(adata.uns[k]))
                for c in adata.uns[k].columns:
                    if c in ["id","start","end"]:
                        adata.uns[k][c] = adata.uns[k][c].astype(int)
                    if c in ["chr"]:
                        adata.uns[k][c] = adata.uns[k][c].astype(str)
        return adata

    @staticmethod
    def preprocessAnnData(adata, do_logcpm=True, valid_gene_minobs=0, valid_peak_minobs=0, valid_cell_mingenes=0):
        """
        Preprocess sci-CAR data to remove too-sparse genes, peaks and cells; optionally, convert counts to log2(1 + counts-per-million)

        #### Parameters

        `adata`: `AnnData`

            The AnnData object produced by loadData(...)


        `do_logcpm`: `bool`

            If True, convert peak and gene expression counts to log2(1 + counts-per-million)


        `valid_gene_minobs`: `int`

            Only keep genes that show up in at least valid_gene_minobs cells


        `valid_peak_minobs`: `int`

            Only keep peaks that show up in at least valid_peak_minobs cells


        `valid_cell_mingenes`: `int`

            Only keep cells that have at least valid_cell_mingenes genes

        #### Returns

        a copy of the filtered (and optionally normalized) AnnData object
        """

        valid_genes = np.ravel((adata.X > 0).sum(axis=0)) >= valid_gene_minobs
        adata = adata[:, valid_genes]

        valid_cells = np.ravel((adata.X > 0).sum(axis=1)) >= valid_cell_mingenes
        adata = adata[valid_cells, :]

        if "atac.X" in adata.uns:
            adata.uns["atac.X"] = adata.uns["atac.X"][ valid_cells, :]

            valid_peaks = np.ravel((adata.uns["atac.X"] > 0).sum(axis=0)) >= valid_peak_minobs
            adata.uns["atac.X"] = adata.uns["atac.X"][:, valid_peaks]
            adata.uns["atac.var"] = adata.uns["atac.var"][valid_peaks]
            adata.uns["atac.var.index"] = list(np.asarray(adata.uns["atac.var.index"])[valid_peaks])

        adata2 = adata.copy()

        def logcpm(dx):
            libsizes = 1e-6 + np.ravel(np.sum(dx, axis=1))
            schema_debug ("Flag 3343.10 ", libsizes.shape, libsizes.sum())
            dxout = dx  # modified in place; adata2 is already a copy
            for i in range(dxout.shape[0]):
                # rescale each CSR row's nonzero entries in place
                i0,i1 = dxout.indptr[i], dxout.indptr[i+1]
                dxout.data[i0:i1] = np.log2(dxout.data[i0:i1]*1e6/libsizes[i] + 1)
            return dxout

        if do_logcpm:
            adata2.X = logcpm(adata2.X)
            if "atac.X" in adata2.uns:
                adata2.uns["atac.X"] = logcpm(adata2.uns["atac.X"])

        return adata2

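    # Normalization sketch for logcpm above: for a cell with 50,000 total counts,
    # a gene with 5 counts maps to log2(5 * 1e6/50000 + 1) = log2(101) ~= 6.66.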

    @staticmethod
    def getChrMapping(adata):
        """
        Get a mapping of genes/peaks to chromosomes and back

        #### Parameters

        `adata`: `AnnData` object

            output from loadData(...)

        #### Returns
        gchr, chr2genes, pchr, chr2peaks : the first and third are per-gene and per-peak vectors of chromosome labels; the second and fourth are chromosome -> set(int) dicts. All gene and peak integers refer to indexes into adata.var and adata.uns["atac.var"], respectively
        """
        chr2genes = defaultdict(set)
        chr2peaks = defaultdict(set)
        gchr = adata.var["chr"].astype(str).apply(lambda s: s.replace("chr",""))
        for i in range(adata.var.shape[0]):
            chr2genes[gchr.iloc[i]].add(i)
        pchr = adata.uns["atac.var"]["chr"].astype(str).apply(lambda s: s.replace("chr",""))
        for i in range(adata.uns["atac.var"].shape[0]):
            chr2peaks[pchr.iloc[i]].add(i)
        return gchr, chr2genes, pchr, chr2peaks


    @staticmethod
    def getPeakPosReGenes(gVar, peak, gene2chr):
        """
        Given a peak, get its position vis-a-vis the genes in the genome

        #### Parameters

        `gVar`: `pd.DataFrame`

            the adata.var df from the adata object


        `peak`: `list` of size 3

            [chr, start, end] identifying the peak

        `gene2chr`: `array of strings`

            for each gene idx, indicates which chromosome it is a part of

        #### Returns
        gVar.shape[0] x 5 nd-array, with 5 numbers describing the peak's position relative to each gene (see code)
        """
        pchr, pstart, pend = peak
        schema_debug ("Flag 321.10012 ", peak)
        pchr = (str(pchr)).replace("chr","")


        # 5 dims -- 0: same chr; 1: gStart-pStart; 2: gStart-pEnd; 3: gEnd-pStart; 4: gEnd-pEnd
        # (strand-adjusted: positive means the peak coordinate lies upstream, per the comments below)
        pos = np.zeros((gVar.shape[0],5))
        pos[:,0] = np.where(gene2chr==pchr, 1, 0)
        pos[:,1] = np.where(gVar["strand"]=="+", gVar["txstart"]-pstart, pend-gVar["txend"])   #+ if pstart is upstream of txstart
        pos[:,2] = np.where(gVar["strand"]=="+", gVar["txstart"]-pend, pstart-gVar["txend"])   #+ if pend is upstream of txstart
        pos[:,3] = np.where(gVar["strand"]=="+", gVar["txend"]-pstart, pend-gVar["txstart"])   #+ if pstart is upstream of txend
        pos[:,4] = np.where(gVar["strand"]=="+", gVar["txend"]-pend, pstart-gVar["txstart"])   #+ if pend is upstream of txend

        assert np.sum((pos[:,2] > 0) & (pos[:,1] <= 0)) == 0
        assert np.sum((pos[:,1] > 0) & (pos[:,3] <= 0)) == 0
        assert np.sum((pos[:,3] < 0) & (pos[:,4] >= 0)) == 0
        assert np.sum((pos[:,4] < 0) & (pos[:,2] >= 0)) == 0

        # alternative sign convention (unused):
        # pos[:,1] = np.where(gVar["strand"]=="+", pstart-gVar["txstart"], gVar["txend"]-pend)
        # pos[:,2] = np.where(gVar["strand"]=="+", pend-gVar["txstart"], gVar["txend"]-pstart)
        # pos[:,3] = np.where(gVar["strand"]=="+", pstart-gVar["txend"], gVar["txstart"]-pend)
        # pos[:,4] = np.where(gVar["strand"]=="+", pend-gVar["txend"], gVar["txstart"]-pstart)
        return pos

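    # Worked example of the pos encoding above (illustrative numbers): for a
    # "+"-strand gene with txstart=10000, txend=15000 and a peak [chr1, 8000, 9000]
    # on the same chromosome, pos = [1, 2000, 1000, 7000, 6000] -- the peak ends
    # 1000bp upstream of the TSS, so fpeak_500_2e3 below scores it 1 while the other
    # upstream bins score it 0 (the rbf variants decay smoothly instead).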

    @staticmethod
    def fpeak_0_500(pos):
        """
        peak ends within 500bp upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>=0) &
                         (pos[:,2] < 5e2)), 1, 0)

    @staticmethod
    def fpeak_500_2e3(pos):
        """
        peak ends within 500-2000bp upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>0) &
                         (pos[:,2]> 5e2) &
                         (pos[:,2] <= 2e3)), 1, 0)

    @staticmethod
    def fpeak_2e3_20e3(pos):
        """
        peak ends within 2-20kb upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>0) &
                         (pos[:,2]> 2e3) &
                         (pos[:,2] <= 20e3)), 1, 0)

    @staticmethod
    def fpeak_20e3_100e3(pos):
        """
        peak ends within 20-100kb upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>0) &
                         (pos[:,2]> 20e3) &
                         (pos[:,2] <= 100e3)), 1, 0)


    @staticmethod
    def fpeak_100e3_1e6(pos):
        """
        peak ends within 100kb-1Mb upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>0) &
                         (pos[:,2]> 100e3) &
                         (pos[:,2] <= 1e6)), 1, 0)

    @staticmethod
    def fpeak_1e6_10e6(pos):
        """
        peak ends within 1-10Mb upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>0) &
                         (pos[:,2]> 1e6) &
                         (pos[:,2] <= 10e6)), 1, 0)


    @staticmethod
    def fpeak_10e6_20e6(pos):
        """
        peak ends within 10-20Mb upstream of the gene's TSS
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,2]>0) &
                         (pos[:,2]> 10e6) &
                         (pos[:,2] <= 20e6)), 1, 0)


    @staticmethod
    def fpeak_crossing_in(pos):
        """
        peak spans the TSS of the gene
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,1]>0) &
                         (pos[:,2]<=0) &
                         (pos[:,4]>0)), 1, 0)


    @staticmethod
    def fpeak_inside(pos):
        """
        peak lies between the start and end of the gene
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,1]<0) &
                         (pos[:,2]<0) &
                         (pos[:,3]>0) &
                         (pos[:,4]>0)), 1, 0)


    @staticmethod
    def fpeak_crossing_out(pos):
        """
        peak spans the txend of the gene
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,1]<0) &
                         (pos[:,2]<0) &
                         (pos[:,3]>0) &
                         (pos[:,4]<0)), 1, 0)



    @staticmethod
    def fpeak_behind_1e3(pos):
        """
        peak starts within 1kb downstream of the gene's end
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,3]<0) &
                         (-pos[:,3] < 1e3)), 1, 0)


    @staticmethod
    def fpeak_behind_1e3_20e3(pos):
        """
        peak starts within 1-20kb downstream of the gene's end
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,3]<0) &
                         (-pos[:,3] > 1e3) &
                         (-pos[:,3] < 20e3)), 1, 0)


    @staticmethod
    def fpeak_behind_20e3_100e3(pos):
        """
        peak starts within 20-100kb downstream of the gene's end
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,3]<0) &
                         (-pos[:,3] > 20e3) &
                         (-pos[:,3] < 100e3)), 1, 0)

    @staticmethod
    def fpeak_behind_100e3_1e6(pos):
        """
        peak starts within 100kb-1Mb downstream of the gene's end
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,3]<0) &
                         (-pos[:,3] > 100e3) &
                         (-pos[:,3] < 1e6)), 1, 0)


    @staticmethod
    def fpeak_behind_1e6_10e6(pos):
        """
        peak starts within 1-10Mb downstream of the gene's end
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,3]<0) &
                         (-pos[:,3] > 1e6) &
                         (-pos[:,3] < 10e6)), 1, 0)


    @staticmethod
    def fpeak_behind_10e6_20e6(pos):
        """
        peak starts within 10-20Mb downstream of the gene's end
        """
        return np.where(((pos[:,0]> 0) &
                         (pos[:,3]<0) &
                         (-pos[:,3] > 10e6) &
                         (-pos[:,3] < 20e6)), 1, 0)


    @staticmethod
    def fpeak_rbf_500(pos):
        """
        exp(-(d/500)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/5e2)**2)), 0)


    @staticmethod
    def fpeak_rbf_1e3(pos):
        """
        exp(-(d/1e3)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/1e3)**2)), 0)


    @staticmethod
    def fpeak_rbf_5e3(pos):
        """
        exp(-(d/5e3)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/5e3)**2)), 0)

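    # For instance (illustrative numbers): a peak ending 1kb upstream of the TSS gets
    # weight exp(-(1e3/5e3)**2) = exp(-0.04) ~= 0.96 from fpeak_rbf_5e3 above, but
    # only exp(-(1e3/5e2)**2) = exp(-4) ~= 0.018 from fpeak_rbf_500.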

    @staticmethod
    def fpeak_rbf_20e3(pos):
        """
        exp(-(d/20e3)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/20e3)**2)), 0)


    @staticmethod
    def fpeak_rbf_100e3(pos):
        """
        exp(-(d/100e3)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/100e3)**2)), 0)


    @staticmethod
    def fpeak_rbf_1e6(pos):
        """
        exp(-(d/1e6)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/1e6)**2)), 0)


    @staticmethod
    def fpeak_rbf_10e6(pos):
        """
        exp(-(d/10e6)^2), where d is the peak-end to TSS distance, if the peak ends upstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,2]>=0),
                        np.exp(-((pos[:,2]/10e6)**2)), 0)


    @staticmethod
    def fpeak_behind_rbf_20e3(pos):
        """
        exp(-(d/20e3)^2), where d is the gene-end to peak-start distance, if the peak starts downstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,3]<=0),
                        np.exp(-((pos[:,3]/20e3)**2)), 0)


    @staticmethod
    def fpeak_behind_rbf_100e3(pos):
        """
        exp(-(d/100e3)^2), where d is the gene-end to peak-start distance, if the peak starts downstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,3]<=0),
                        np.exp(-((pos[:,3]/100e3)**2)), 0)


    @staticmethod
    def fpeak_behind_rbf_1e6(pos):
        """
        exp(-(d/1e6)^2), where d is the gene-end to peak-start distance, if the peak starts downstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,3]<=0),
                        np.exp(-((pos[:,3]/1e6)**2)), 0)


    @staticmethod
    def fpeak_behind_rbf_10e6(pos):
        """
        exp(-(d/10e6)^2), where d is the gene-end to peak-start distance, if the peak starts downstream of the gene
        """
        return np.where((pos[:,0]> 0) & (pos[:,3]<=0),
                        np.exp(-((pos[:,3]/10e6)**2)), 0)



    # (function, name) pairs; .__func__ unwraps the staticmethod descriptor at class-definition time
    fpeak_list_all = [ (fpeak_0_500.__func__, "fpeak_0_500"),
                       (fpeak_500_2e3.__func__, "fpeak_500_2e3"),
                       (fpeak_2e3_20e3.__func__, "fpeak_2e3_20e3"),
                       (fpeak_20e3_100e3.__func__, "fpeak_20e3_100e3"),
                       (fpeak_100e3_1e6.__func__, "fpeak_100e3_1e6"),
                       (fpeak_1e6_10e6.__func__, "fpeak_1e6_10e6"),
                       (fpeak_10e6_20e6.__func__, "fpeak_10e6_20e6"),
                       (fpeak_crossing_in.__func__, "fpeak_crossing_in"),
                       (fpeak_inside.__func__, "fpeak_inside"),
                       (fpeak_crossing_out.__func__, "fpeak_crossing_out"),
                       (fpeak_behind_1e3.__func__, "fpeak_behind_1e3"),
                       (fpeak_behind_1e3_20e3.__func__, "fpeak_behind_1e3_20e3"),
                       (fpeak_behind_20e3_100e3.__func__, "fpeak_behind_20e3_100e3"),
                       (fpeak_behind_100e3_1e6.__func__, "fpeak_behind_100e3_1e6"),
                       (fpeak_behind_1e6_10e6.__func__, "fpeak_behind_1e6_10e6"),
                       (fpeak_behind_10e6_20e6.__func__, "fpeak_behind_10e6_20e6"),
                       (fpeak_rbf_500.__func__, "fpeak_rbf_500"),
                       (fpeak_rbf_1e3.__func__, "fpeak_rbf_1e3"),
                       (fpeak_rbf_5e3.__func__, "fpeak_rbf_5e3"),
                       (fpeak_rbf_20e3.__func__, "fpeak_rbf_20e3"),
                       (fpeak_rbf_100e3.__func__, "fpeak_rbf_100e3"),
                       (fpeak_rbf_1e6.__func__, "fpeak_rbf_1e6"),
                       (fpeak_rbf_10e6.__func__, "fpeak_rbf_10e6"),
                       (fpeak_behind_rbf_20e3.__func__, "fpeak_behind_rbf_20e3"),
                       (fpeak_behind_rbf_100e3.__func__, "fpeak_behind_rbf_100e3"),
                       (fpeak_behind_rbf_1e6.__func__, "fpeak_behind_rbf_1e6"),
                       (fpeak_behind_rbf_10e6.__func__, "fpeak_behind_rbf_10e6"),
                     ]

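    # Example usage (a sketch; assumes `adata` came from SciCar.loadData and was
    # preprocessed as above):
    #
    #   funcs = [f for f, _ in SciCar.fpeak_list_all]
    #   names = [nm for _, nm in SciCar.fpeak_list_all]
    #   chrmap = SciCar.getChrMapping(adata)
    #   g2p, g2p_wts = SciCar.computeGeneByFpeakMatrix(adata, funcs, chr_mapping=chrmap)
    #   g2p_df = pd.DataFrame(g2p, index=adata.var_names, columns=names)
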

    @staticmethod
    def computeGeneByFpeakMatrix(adata, peak_func_list, chr_mapping = None, peakList = None, normalize_distwt=True, booleanPeakCounts = False):
        """
        Compute a matrix that is nGenes x len(peak_func_list), with entry [i,j] being gene i's dot-product with
        peak scores across all peaks, subject to the peak weighting described by peak_func_list[j]

        #### Parameters

        `adata` : AnnData object from loadData(...)

        `peak_func_list` : `list` of functions, each taking the nGenes x 5 position matrix produced by getPeakPosReGenes(...) and returning a vector of nGenes per-gene weights

        `chr_mapping`: `4-tuple`

            Optional argument providing the output of `getChrMapping(..)`, for caching

        `peakList`: `list of int`

            Optional argument specifying which peak indexes to run over

        `normalize_distwt`: `bool`

            If set, rescale each peak's per-gene weight vector to have mean 1 before use

        `booleanPeakCounts`: `bool`

            If set, use binarized (presence/absence) peak counts instead of raw peak values

        #### Returns

        a 2-tuple of matrices, each of shape adata.var.shape[0] x len(peak_func_list): the weighted gene-by-fpeak scores and the corresponding weight sums

        """

        if chr_mapping is None:
            gene2chr, chr2genes, peak2chr, chr2peaks = SciCar.getChrMapping(adata)
        else:
            gene2chr, chr2genes, peak2chr, chr2peaks = chr_mapping

        nCells, nGenes, nPeaks = adata.shape[0], adata.shape[1], adata.uns["atac.X"].shape[1]
        try:
            gXt = adata.X.T.tocsr()
        except:
            gXt = adata.X.T

        if booleanPeakCounts and gXt.shape[0] > 0:
            gXt = gXt > np.median(gXt, axis=0) ## UNDO ????

        pXt = adata.uns["atac.X"].T.tocsr()
        gVar = adata.var[["chr","strand","txstart","txend","tss_adj"]]
        pVar = adata.uns["atac.var"][["chr","start","end"]]

        k = len(peak_func_list)

        g2p = np.zeros((nGenes,k))
        g2p_wts = np.zeros((nGenes,k))

        if peakList is not None:
            plist = peakList
        else:
            plist = range(nPeaks)

        for i in plist:
            v = SciCar.getPeakPosReGenes(gVar, pVar.values[i,:], gene2chr)

            pXti = np.ravel(pXt[i,:].todense())

            pXti_positive = (pXti > 0).astype(int)
            pXti_ones = np.ones_like(pXti)

            schema_debug ("Flag 2.0008 ", i, flush=True)

            for j,f in enumerate(peak_func_list):
                distwt = f(v)
                #schema_debug("Flag 2.0010 ", i,j,np.std(distwt), np.mean(distwt))
                if np.sum(distwt)>1e-12:
                    if normalize_distwt: distwt = distwt/np.mean(distwt)
                    G = gXt.copy()
                    try:
                        # scale each gene-row's nonzeros by that gene's distance weight
                        G.data *= distwt.repeat(np.diff(G.indptr))
                    except:
                        G = G*distwt[:,None]
                    gw = G.dot(pXti).ravel()
                    if booleanPeakCounts:
                        gw = G.dot(pXti_positive).ravel()
                    g2p[:,j] += gw/nCells

                    gw_ones = G.dot(pXti_ones).ravel()
                    g2p_wts[:,j] += gw_ones/nCells

        return (g2p, g2p_wts)

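    # In matrix terms (a sketch of the loop above, ignoring the per-peak mean-
    # normalization applied when normalize_distwt is set): with X the nCells x nGenes
    # expression matrix, P the nCells x nPeaks peak matrix, and W_j the nGenes x nPeaks
    # array of distance weights induced by peak_func_list[j],
    #
    #   g2p[g, j] = (1/nCells) * sum_p W_j[g, p] * (X.T @ P)[g, p]
    #
    # and g2p_wts accumulates the same sum with P replaced by an all-ones matrix.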

    @staticmethod
    def aggregatePeaksByGenes(adata, peak_func, chr_mapping = None):
        """
        Compute a matrix that is nCells x nGenes, with entry [i,j] being the aggregate expression,
        in cell i, of the peaks around gene j, subject to the peak weighting described by peak_func

        #### Parameters

        `adata` : AnnData object from loadData(...)

        `peak_func` : a function taking the nGenes x 5 position matrix produced by getPeakPosReGenes(...) and returning a vector of nGenes per-gene weights

        `chr_mapping`: `4-tuple`

            Optional argument providing the output of `getChrMapping(..)`, for caching

        #### Returns

        a matrix of shape adata.shape

        """


        if chr_mapping is None:
            gene2chr, chr2genes, peak2chr, chr2peaks = SciCar.getChrMapping(adata)
        else:
            gene2chr, chr2genes, peak2chr, chr2peaks = chr_mapping

        nCells, nGenes, nPeaks = adata.shape[0], adata.shape[1], adata.uns["atac.X"].shape[1]

        pX = adata.uns["atac.X"].tocsc()  # CSC makes the per-peak column slicing below cheap

        gVar = adata.var[["chr","strand","txstart","txend","tss_adj"]]
        pVar = adata.uns["atac.var"][["chr","start","end"]]

        c2g = np.zeros((nCells, nGenes))
        c2g_tmp = np.zeros((nCells, nGenes))

        plist = range(nPeaks)

        for i in plist:
            c2g_tmp[:] = 0

            v = SciCar.getPeakPosReGenes(gVar, pVar.values[i,:], gene2chr)
            c2g_tmp += peak_func(v)  # broadcast the per-gene weights across cells
            # the data was log1p'd before; exponentiate to add the cpm counts, then re-log1p at the end
            p_i = np.exp(np.ravel(pX[:,i].todense()))
            c2g_tmp *= p_i[:,None]

            c2g += c2g_tmp

        c2g = np.log1p(c2g)
        return c2g







def cmpSpearmanVsPearson(d1, type1="numeric", d2=None, type2=None, nPointPairs=2000000, nRuns=1):
    """
    Compare Pearson vs Spearman correlation between pairwise point-distances in d1 and d2 (or, if d2 is None,
    between d1's distances and their ranks), estimated over randomly sampled point pairs
    """

    N, K1 = d1.shape[0], d1.shape[1]
    K2 = 0
    if d2 is not None:
        assert d2.shape[0] == N
        K2 = d2.shape[1]

    if nPointPairs is None: assert nRuns==1

    schema_debug ("Flag 676.10 ", N, K1, K2, nPointPairs, type1, type2)

    corrL = []

    for nr in range(nRuns):
        if nPointPairs is not None:
            j_u = np.random.randint(0, N, int(3*nPointPairs))
            j_v = np.random.randint(0, N, int(3*nPointPairs))
            valid = j_u < j_v  # get rid of potential duplicates (x1,x2) and (x2,x1), as well as (x1,x1)
            i_u = (j_u[valid])[:nPointPairs]
            i_v = (j_v[valid])[:nPointPairs]
        else:
            x = pd.DataFrame.from_records(list(itertools.combinations(range(N),2)), columns=["u","v"])
            x = x[x.u < x.v]
            i_u = x.u.values
            i_v = x.v.values

        schema_debug ("Flag 676.30 ", i_u.shape, i_v.shape, nr)

        dL = []
        for g_val, g_type in [(d1, type1), (d2, type2)]:
            if g_val is None:
                dL.append(None)
                continue

            dg = []
            # process point-pairs in chunks of ~2000; array_split tolerates unequal chunk sizes
            for ii in np.array_split(np.arange(i_u.shape[0]), max(1, int(i_u.shape[0]/2000))):
                ii_u = i_u[ii]
                ii_v = i_v[ii]

                if g_type == "categorical":
                    dgx = 1.0*( g_val[ii_u] != g_val[ii_v])
                elif g_type == "feature_vector":
                    schema_debug (g_val[ii_u].shape, g_val[ii_v].shape)
                    dgx = np.ravel(np.sum(np.power(g_val[ii_u].astype(np.float64) - g_val[ii_v].astype(np.float64),2), axis=1))
                else: #numeric
                    dgx = (g_val[ii_u].astype(np.float64) - g_val[ii_v].astype(np.float64))**2
                dg.extend(dgx)
            schema_debug ("Flag 676.50 ", g_type, len(dg))
            dL.append(np.array(dg))

        schema_debug ("Flag 676.60 ")
        dg1, dg2 = dL[0], dL[1]

        if d2 is None:
            rp = scipy.stats.pearsonr(dg1, scipy.stats.rankdata(dg1))[0]
            rs = None
        else:
            rp = scipy.stats.pearsonr(dg1, dg2)[0]
            rs = scipy.stats.spearmanr(dg1, dg2)[0]

        schema_debug ("Flag 676.80 ", rp, rs)
        corrL.append( (rp,rs))

    return corrL
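
# Example usage (a sketch; `X_rna` and `X_spatial` are hypothetical numpy arrays,
# one row per cell):
#
#   corrs = cmpSpearmanVsPearson(X_rna, type1="feature_vector",
#                                d2=X_spatial, type2="feature_vector",
#                                nPointPairs=100000, nRuns=3)
#   # -> list of (pearson_r, spearman_r) tuples, one per run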




--------------------------------------------------------------------------------