├── .gitattributes
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── build
│   └── lib
│       └── solo
│           ├── __init__.py
│           ├── hashsolo.py
│           ├── solo.py
│           └── utils.py
├── hashsolo_params_example.json
├── prespecified.txt
├── release_script.sh
├── requirements.txt
├── setup.cfg
├── setup.py
├── solo
│   ├── __init__.py
│   ├── hashsolo.py
│   ├── solo.py
│   └── utils.py
├── solo_params_example.json
├── testdata
│   ├── 2c.h5ad
│   ├── calculate_performance.py
│   ├── gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad
│   ├── kidney_performance_tracking.png
│   ├── pbmc_performance_tracking.png
│   ├── performance_test_kidney_PoolB4FACs_L4_Rep1.sh
│   ├── performance_test_pbmc_2c.sh
│   ├── performance_tracking.png
│   └── tracking_performance.csv
└── tests
    └── hashsolo_tests.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
 1 | *.h5ad filter=lfs diff=lfs merge=lfs -text
 2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | ._*.py
 2 | solo_sc.egg-info/
 3 | dist/
 4 | solo.egg-info*
 5 | solo/__pycache_*
 6 | .ipynb_checkpoints
 7 | */.ipynb_checkpoints
 8 | slurm-*
 9 | testdata/results*
10 | logs/
11 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2019 Calico
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
 1 | include requirements.txt
 2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # solo -- Doublet detection via semi-supervised deep learning
  2 | ### Why
  3 | Cells subjected to single cell RNA-seq have been through a lot, and they'd really just like to be alone now, please. If they cannot escape the cell social scene, you end up sequencing RNA from more than one cell to a barcode, creating a *doublet* when you expected single cell profiles. https://www.cell.com/cell-systems/fulltext/S2405-4712(20)30195-2
  4 | 
  5 | **_solo_** is a neural network framework to classify doublets, so that you can remove them from your data and clean your single cell profile.
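
A typical run takes a small JSON of model parameters and a counts matrix (all flags are described under Usage below; `solo_params_example.json` ships with the repo, and the data path here is a placeholder):

```
solo -j solo_params_example.json -d path/to/counts.h5ad -o solo_out
```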
  6 | 
  7 | We benchmarked **_solo_** against other doublet detection tools such as DoubletFinder and Scrublet, and found that it consistently outperformed them in terms of average precision. Additionally, Solo performed much better on a more complex tissue, mouse kidney.
  8 | 
  9 | ### Quick set up
 10 | Run the following to clone the repository and set up a conda environment:
 11 | `git clone git@github.com:calico/solo.git && cd solo && conda create -n solo python=3.12 && conda activate solo && pip install -e .`
 12 | 
 13 | Or install via pip:
 14 | `conda create -n solo python=3.12 && conda activate solo && pip install solo-sc`
 15 | 
 16 | If you don't have conda, follow the instructions here: https://docs.conda.io/projects/conda/en/latest/user-guide/install/
 17 | 
 18 | ### Usage
 19 | ```
 20 | usage: solo [-h] -j MODEL_JSON_FILE -d DATA_PATH
 21 |             [--set-reproducible-seed REPRODUCIBLE_SEED]
 22 |             [--doublet-depth DOUBLET_DEPTH] [-g] [-a] [-o OUT_DIR]
 23 |             [-r DOUBLET_RATIO] [-s SEED] [-e EXPECTED_NUMBER_OF_DOUBLETS] [-p]
 24 |             [-recalibrate_scores] [--version] [--lr_st] [--lr_vae]
 25 | 
 26 | optional arguments:
 27 |   -h, --help            show this help message and exit
 28 |   -j MODEL_JSON_FILE    json file to pass VAE parameters (default: None)
 29 |   -d DATA_PATH          path to h5ad, loom, or 10x mtx dir cell by genes
 30 |                         counts (default: None)
 31 |   --set-reproducible-seed REPRODUCIBLE_SEED
 32 |                         Reproducible seed, give an int to set seed (default:
 33 |                         None)
 34 |   --doublet-depth DOUBLET_DEPTH
 35 |                         Depth multiplier for a doublet relative to the average
 36 |                         of its constituents (default: 2.0)
 37 |   -g                    Run on GPU (default: True)
 38 |   -a                    output modified anndata object with solo scores Only
 39 |                         works for anndata (default: False)
 40 |   -o OUT_DIR
 41 |   -r DOUBLET_RATIO      Ratio of doublets to true cells (default: 2)
 42 |   -s SEED               Path to previous solo output directory. Seed VAE
 43 |                         models with previously trained solo model. Directory
 44 |                         structure is assumed to be the same as solo output
 45 |                         directory structure. should at least have a vae.pt a
 46 |                         pickled object of vae weights and a latent.npy an
 47 |                         np.ndarray of the latents of your cells. (default:
 48 |                         None)
 49 |   -e EXPECTED_NUMBER_OF_DOUBLETS
 50 |                         Experimentally expected number of doublets (default:
 51 |                         None)
 52 |   -p                    Plot outputs for solo (default: False)
 53 |   -recalibrate_scores   Recalibrate doublet scores (not recommended anymore)
 54 |                         (default: False)
 55 |   --version             Get version of solo-sc (default: False)
 56 |   --lr_st
 57 |                         Learning rate used for solo.train (default: 1e-3)
 58 |   --lr_vae
 59 |                         Learning rate used for vae (default: 1e-3)
 60 | 
 61 | ```
 62 | 
 63 | Warning: If you are going directly from cellranger 10x output, you may want to manually inspect your data prior to running solo.
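
For instance, a quick look at the spread of per-cell UMI depths (a minimal sketch; the 10x path is a placeholder) mirrors the check solo itself performs, which warns when the 95th percentile depth is more than ten times the 5th percentile:

```
import numpy as np
from scanpy import read_10x_mtx

adata = read_10x_mtx("path/to/filtered_feature_bc_matrix/")  # placeholder path
depth = np.asarray(adata.X.sum(axis=1)).ravel()  # UMI counts per cell barcode
fifth, ninetyfifth = np.percentile(depth, [5, 95])
print(f"5th percentile depth: {fifth:.0f}, 95th percentile depth: {ninetyfifth:.0f}")
if fifth * 10 < ninetyfifth:
    print("Wide range of cell depths; review your barcode filtering before running solo")
```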
 64 | 
 65 | model_json example:
 66 | ```
 67 | {
 68 |   "n_hidden": 384,
 69 |   "n_latent": 64,
 70 |   "n_layers": 1,
 71 |   "cl_hidden": 128,
 72 |   "cl_layers": 1,
 73 |   "dropout_rate": 0.2,
 74 |   "lr_st": 1e-3,
 75 |   "valid_pct": 0.10
 76 | }
 77 | ```
 78 | 
 79 | The suggested learning rates work best in most settings; if a ValueError occurs during training, consider lowering the learning rates to 1e-5.
 80 | 
 81 | 
 82 | Outputs:
 83 | * `is_doublet.npy` np boolean array, True if a cell is a doublet; differs from `preds.npy` if the `-e EXPECTED_NUMBER_OF_DOUBLETS` parameter was used
 84 | * `vae` scVI directory for the vae
 85 | * `classifier` scVI directory for the classifier
 86 | * `latent.npy` latent embedding for each cell
 87 | * `preds.npy` doublet predictions
 88 | * `softmax_scores.npy` updated softmax of doublet scores (see paper), now the same as `no_updates_softmax_scores.npy`
 89 | * `no_updates_softmax_scores.npy` raw softmax of doublet scores
 90 | 
 91 | * `logit_scores.npy` logit of doublet scores
 92 | * `real_cells_dist.pdf` histogram of the distribution of doublet scores
 93 | * `accuracy.pdf` accuracy plot, test vs train
 94 | * `train_v_test_dist.pdf` doublet scores of test vs train
 95 | * `roc.pdf` ROC of test vs train
 96 | * `softmax_scores_sim.npy` see above, but for simulated doublets
 97 | * `logit_scores_sim.npy` see above, but for simulated doublets
 98 | * `preds_sim.npy` see above, but for simulated doublets
 99 | * `is_doublet_sim.npy` see above, but for simulated doublets
100 | 
101 | 
102 | ### How to demultiplex cell hashing data using HashSolo CLI
103 | 
104 | Demultiplexing takes as input an h5ad file with only hashing counts. Counts can be obtained from your fastqs by using kite. See the tutorial here: https://github.com/pachterlab/kite
105 | 
106 | ```
107 | usage: hashsolo [-h] [-j MODEL_JSON_FILE] [-o OUT_DIR] [-c CLUSTERING_DATA]
108 |                 [-p PRE_EXISTING_CLUSTERS] [-q PLOT_NAME]
109 |                 [-n NUMBER_OF_NOISE_BARCODES]
110 |                 data_file
111 | 
112 | positional arguments:
113 |   data_file             h5ad file containing cell hashing counts
114 | 
115 | optional arguments:
116 |   -h, --help            show this help message and exit
117 |   -j MODEL_JSON_FILE    json file to pass optional arguments (default: None)
118 |   -o OUT_DIR            Output directory for results (default:
119 |                         hashsolo_output)
120 |   -c CLUSTERING_DATA    h5ad file with count transcriptional data to perform
121 |                         clustering on (default: None)
122 |   -p PRE_EXISTING_CLUSTERS
123 |                         column in cell_hashing_data_file.obs to specifying
124 |                         different cell types or clusters (default: None)
125 |   -q PLOT_NAME          name of plot to output (default: hashing_qc_plots.pdf)
126 |   -n NUMBER_OF_NOISE_BARCODES
127 |                         Number of barcodes to use to create noise distribution
128 |                         (default: None)
129 | ```
130 | 
131 | model_json example:
132 | ```
133 | {
134 |   "priors": [0.01, 0.5, 0.49]
135 | }
136 | ```
137 | 
138 | Priors is a list of the probabilities of the three hypotheses (negative, singlet,
139 | or doublet) that we test when demultiplexing cell hashing data. A negative cell's barcode counts
140 | don't have enough signal to identify its sample of origin. A singlet has
141 | enough signal from a single hashing barcode to associate the cell with its
142 | originating sample. A doublet is a cell barcode which has signal for more than one hashing barcode.
143 | Depending on how you processed your cell hashing matrix beforehand you may
144 | want to set different priors. Assuming you have already subset your cell
145 | barcodes using typical QC on your cell by genes matrix, e.g.
min UMI counts,
146 | percent mitochondrial reads, etc., we found the above setting of priors performed
147 | well (see paper). If you have only done relatively light QC in transcriptome space
148 | I'd suggest an even prior, e.g. `[1./3., 1./3., 1./3.]`.
149 | 
150 | 
151 | Outputs:
152 | * `hashsoloed.h5ad` anndata with demultiplexing information in .obs
153 | * `hashing_qc_plots.png` plots of probabilities for each cell
154 | 
155 | 
156 | ### How to demultiplex cell hashing data using HashSolo in Python
157 | 
158 | ```
159 | >>> from solo import hashsolo
160 | >>> import anndata
161 | >>> cell_hashing_data = anndata.read("cell_hashing_counts.h5ad")
162 | >>> hashsolo.hashsolo(cell_hashing_data)
163 | >>> cell_hashing_data.obs.head()
164 |                   most_likely_hypothesis  cluster_feature  negative_hypothesis_probability  singlet_hypothesis_probability  doublet_hypothesis_probability         Classification
165 | index
166 | CCTTTCTGTCCGAACC                       2                0                     1.203673e-16                        0.000002                        0.999998                Doublet
167 | CTGATAGGTGACTCAT                       1                0                     1.370633e-09                        0.999920                        0.000080  BatchF-GTGTGACGTATT_x
168 | AGCTCTCGTTGTCTTT                       1                0                     2.369380e-13                        0.996992                        0.003008  BatchE-GAGGCTGAGCTA_x
169 | GTGCGGTAGCGATGAC                       1                0                     1.579405e-09                        0.999879                        0.000121  BatchB-ACATGTTACCGT_x
170 | AAATGCCTCTAACCGA                       1                0                     1.867626e-13                        0.999707                        0.000293  BatchB-ACATGTTACCGT_x
171 | >>> hashsolo.plot_qc_checks_cell_hashing(cell_hashing_data)
172 | ```
173 | 
174 | 
175 | * `most_likely_hypothesis` 0 == Negative, 1 == Singlet, 2 == Doublet
176 | * `cluster_feature` how the cell hashing data was divided, if specified, or done automatically by giving a cell by genes anndata object to the `clustering_data` argument when calling `hashsolo`
177 | * `negative_hypothesis_probability`
178 | * `singlet_hypothesis_probability`
179 | * `doublet_hypothesis_probability`
180 | * `Classification` The sample of origin for the cell, or whether it was a negative or doublet cell.
181 | 
--------------------------------------------------------------------------------
/build/lib/solo/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = "David Kelley, Nick Bernstein"
2 | __email__ = "nicholas@calicolabs.com"
3 | __version__ = "0.1"
4 | 
5 | from . import hashsolo, utils
6 | 
--------------------------------------------------------------------------------
/build/lib/solo/hashsolo.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | import os
 3 | import json
 4 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 5 | 
 6 | from scipy.stats import norm
 7 | from itertools import product
 8 | import anndata
 9 | import numpy as np
10 | import pandas as pd
11 | import scanpy as sc
12 | 
13 | from scipy.sparse import issparse
14 | from sklearn.metrics import calinski_harabasz_score
15 | 
16 | """
17 | HashSolo provides a probabilistic cell hashing demultiplexing method
18 | which generates a noise distribution and signal distribution for
19 | each hashing barcode from empirically observed counts. These distributions
20 | are updated from the global signal and noise barcode distributions, which
21 | helps in the setting where not many cells are observed. Signal distributions
22 | for a hashing barcode are estimated from samples where that hashing barcode
23 | has the highest count. Noise distributions for a hashing barcode are estimated
24 | from samples where that hashing barcode is one of the k-2 lowest barcodes, where
25 | k is the number of barcodes.
A doublet should then have its two highest 26 | barcode counts most likely coming from a signal distribution for those barcodes. 27 | A singlet should have its highest barcode from a signal distribution, and its 28 | second highest barcode from a noise distribution. A negative two highest 29 | barcodes should come from noise distributions. We test each of these 30 | hypotheses in a bayesian fashion, and select the most probable hypothesis. 31 | """ 32 | 33 | 34 | def _calculate_log_likelihoods(data, number_of_noise_barcodes): 35 | """Calculate log likelihoods for each hypothesis, negative, singlet, doublet 36 | 37 | Parameters 38 | ---------- 39 | data : np.ndarray 40 | cells by hashing counts matrix 41 | number_of_noise_barcodes : int, 42 | number of barcodes to used to calculated noise distribution 43 | Returns 44 | ------- 45 | log_likelihoods_for_each_hypothesis : np.ndarray 46 | a 2d np.array log likelihood of each hypothesis 47 | all_indices 48 | counter_to_barcode_combo 49 | """ 50 | 51 | def gaussian_updates(data, mu_o, std_o): 52 | """Update parameters of your gaussian 53 | https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf 54 | Parameters 55 | ---------- 56 | data : np.array 57 | 1-d array of counts 58 | mu_o : float, 59 | global mean for hashing count distribution 60 | std_o : float, 61 | global std for hashing count distribution 62 | Returns 63 | ------- 64 | float 65 | mean of gaussian 66 | float 67 | std of gaussian 68 | """ 69 | lam_o = 1 / (std_o ** 2) 70 | n = len(data) 71 | lam = 1 / np.var(data) if len(data) > 1 else lam_o 72 | lam_n = lam_o + n * lam 73 | mu_n = ( 74 | (np.mean(data) * n * lam + mu_o * lam_o) / lam_n if len(data) > 0 else mu_o 75 | ) 76 | return mu_n, (1 / (lam_n / (n + 1))) ** (1 / 2) 77 | 78 | eps = 1e-15 79 | # probabilites for negative, singlet, doublets 80 | log_likelihoods_for_each_hypothesis = np.zeros((data.shape[0], 3)) 81 | 82 | all_indices = np.empty(data.shape[0]) 83 | num_of_barcodes = data.shape[1] 84 | number_of_non_noise_barcodes = ( 85 | num_of_barcodes - number_of_noise_barcodes 86 | if number_of_noise_barcodes is not None 87 | else 2 88 | ) 89 | num_of_noise_barcodes = num_of_barcodes - number_of_non_noise_barcodes 90 | 91 | # assume log normal 92 | data = np.log(data + 1) 93 | data_arg = np.argsort(data, axis=1) 94 | data_sort = np.sort(data, axis=1) 95 | 96 | # global signal and noise counts useful for when we have few cells 97 | # barcodes with the highest number of counts are assumed to be a true signal 98 | # barcodes with rank < k are considered to be noise 99 | global_signal_counts = np.ravel(data_sort[:, -1]) 100 | global_noise_counts = np.ravel(data_sort[:, :-number_of_non_noise_barcodes]) 101 | global_mu_signal_o, global_sigma_signal_o = np.mean(global_signal_counts), np.std( 102 | global_signal_counts 103 | ) 104 | global_mu_noise_o, global_sigma_noise_o = np.mean(global_noise_counts), np.std( 105 | global_noise_counts 106 | ) 107 | 108 | noise_params_dict = {} 109 | signal_params_dict = {} 110 | 111 | # for each barcode get empirical noise and signal distribution parameterization 112 | for x in np.arange(num_of_barcodes): 113 | sample_barcodes = data[:, x] 114 | sample_barcodes_noise_idx = np.where(data_arg[:, :num_of_noise_barcodes] == x)[ 115 | 0 116 | ] 117 | sample_barcodes_signal_idx = np.where(data_arg[:, -1] == x) 118 | 119 | # get noise and signal counts 120 | noise_counts = sample_barcodes[sample_barcodes_noise_idx] 121 | signal_counts = sample_barcodes[sample_barcodes_signal_idx] 122 | 123 | # get parameters of 
distribution, assuming lognormal do update from global values 124 | noise_param = gaussian_updates( 125 | noise_counts, global_mu_noise_o, global_sigma_noise_o 126 | ) 127 | signal_param = gaussian_updates( 128 | signal_counts, global_mu_signal_o, global_sigma_signal_o 129 | ) 130 | noise_params_dict[x] = noise_param 131 | signal_params_dict[x] = signal_param 132 | 133 | counter_to_barcode_combo = {} 134 | counter = 0 135 | 136 | # for each combination of noise and signal barcode calculate probiltiy of in silico and real cell hypotheses 137 | for noise_sample_idx, signal_sample_idx in product( 138 | np.arange(num_of_barcodes), np.arange(num_of_barcodes) 139 | ): 140 | signal_subset = data_arg[:, -1] == signal_sample_idx 141 | noise_subset = data_arg[:, -2] == noise_sample_idx 142 | subset = signal_subset & noise_subset 143 | if sum(subset) == 0: 144 | continue 145 | 146 | indices = np.where(subset)[0] 147 | barcode_combo = "_".join([str(noise_sample_idx), str(signal_sample_idx)]) 148 | all_indices[np.where(subset)[0]] = counter 149 | counter_to_barcode_combo[counter] = barcode_combo 150 | counter += 1 151 | noise_params = noise_params_dict[noise_sample_idx] 152 | signal_params = signal_params_dict[signal_sample_idx] 153 | 154 | # calculate probabilties for each hypothesis for each cell 155 | data_subset = data[subset] 156 | log_signal_signal_probs = np.log( 157 | norm.pdf( 158 | data_subset[:, signal_sample_idx], 159 | *signal_params[:-2], 160 | loc=signal_params[-2], 161 | scale=signal_params[-1] 162 | ) 163 | + eps 164 | ) 165 | signal_noise_params = signal_params_dict[noise_sample_idx] 166 | log_noise_signal_probs = np.log( 167 | norm.pdf( 168 | data_subset[:, noise_sample_idx], 169 | *signal_noise_params[:-2], 170 | loc=signal_noise_params[-2], 171 | scale=signal_noise_params[-1] 172 | ) 173 | + eps 174 | ) 175 | 176 | log_noise_noise_probs = np.log( 177 | norm.pdf( 178 | data_subset[:, noise_sample_idx], 179 | *noise_params[:-2], 180 | loc=noise_params[-2], 181 | scale=noise_params[-1] 182 | ) 183 | + eps 184 | ) 185 | log_signal_noise_probs = np.log( 186 | norm.pdf( 187 | data_subset[:, signal_sample_idx], 188 | *noise_params[:-2], 189 | loc=noise_params[-2], 190 | scale=noise_params[-1] 191 | ) 192 | + eps 193 | ) 194 | 195 | probs_of_negative = np.sum( 196 | [log_noise_noise_probs, log_signal_noise_probs], axis=0 197 | ) 198 | probs_of_singlet = np.sum( 199 | [log_noise_noise_probs, log_signal_signal_probs], axis=0 200 | ) 201 | probs_of_doublet = np.sum( 202 | [log_noise_signal_probs, log_signal_signal_probs], axis=0 203 | ) 204 | log_probs_list = [probs_of_negative, probs_of_singlet, probs_of_doublet] 205 | 206 | # each cell and each hypothesis probability 207 | for prob_idx, log_prob in enumerate(log_probs_list): 208 | log_likelihoods_for_each_hypothesis[indices, prob_idx] = log_prob 209 | return log_likelihoods_for_each_hypothesis, all_indices, counter_to_barcode_combo 210 | 211 | 212 | def _calculate_bayes_rule(data, priors, number_of_noise_barcodes): 213 | """ 214 | Calculate bayes rule from log likelihoods 215 | 216 | Parameters 217 | ---------- 218 | data : np.array 219 | Anndata object filled only with hashing counts 220 | priors : list, 221 | a list of your prior for each hypothesis 222 | first element is your prior for the negative hypothesis 223 | second element is your prior for the singlet hypothesis 224 | third element is your prior for the doublet hypothesis 225 | We use [0.01, 0.8, 0.19] by default because we assume the barcodes 226 | in your cell hashing matrix 
are those cells which have passed QC 227 | in the transcriptome space, e.g. UMI counts, pct mito reads, etc. 228 | number_of_noise_barcodes : int 229 | number of barcodes to used to calculated noise distribution 230 | Returns 231 | ------- 232 | bayes_dict_results : dict 233 | 'most_likely_hypothesis' key is a 1d np.array of the most likely hypothesis 234 | 'probs_hypotheses' key is a 2d np.array probability of each hypothesis 235 | 'log_likelihoods_for_each_hypothesis' key is a 2d np.array log likelihood of each hypothesis 236 | """ 237 | priors = np.array(priors) 238 | log_likelihoods_for_each_hypothesis, _, _ = _calculate_log_likelihoods( 239 | data, number_of_noise_barcodes 240 | ) 241 | probs_hypotheses = ( 242 | np.exp(log_likelihoods_for_each_hypothesis) 243 | * priors 244 | / np.sum( 245 | np.multiply(np.exp(log_likelihoods_for_each_hypothesis), priors), axis=1 246 | )[:, None] 247 | ) 248 | most_likely_hypothesis = np.argmax(probs_hypotheses, axis=1) 249 | return { 250 | "most_likely_hypothesis": most_likely_hypothesis, 251 | "probs_hypotheses": probs_hypotheses, 252 | "log_likelihoods_for_each_hypothesis": log_likelihoods_for_each_hypothesis, 253 | } 254 | 255 | 256 | def _get_clusters(clustering_data: anndata.AnnData, resolutions: list): 257 | """ 258 | Principled cell clustering 259 | Parameters 260 | ---------- 261 | cell_hashing_adata : anndata.AnnData 262 | Anndata object filled only with hashing counts 263 | resolutions : list 264 | clustering resolutions for leiden 265 | Returns 266 | ------- 267 | np.ndarray 268 | leiden clustering results for each cell 269 | """ 270 | sc.pp.normalize_per_cell(clustering_data, counts_per_cell_after=1e4) 271 | sc.pp.log1p(clustering_data) 272 | sc.pp.highly_variable_genes( 273 | clustering_data, min_mean=0.0125, max_mean=3, min_disp=0.5 274 | ) 275 | clustering_data = clustering_data[:, clustering_data.var["highly_variable"]] 276 | sc.pp.scale(clustering_data, max_value=10) 277 | sc.tl.pca(clustering_data, svd_solver="arpack") 278 | sc.pp.neighbors(clustering_data, n_neighbors=10, n_pcs=40) 279 | sc.tl.umap(clustering_data) 280 | best_ch_score = -np.inf 281 | 282 | for resolution in resolutions: 283 | sc.tl.leiden(clustering_data, resolution=resolution) 284 | 285 | ch_score = calinski_harabasz_score( 286 | clustering_data.X, clustering_data.obs["leiden"] 287 | ) 288 | 289 | if ch_score > best_ch_score: 290 | clustering_data.obs["best_leiden"] = clustering_data.obs["leiden"].values 291 | best_ch_score = ch_score 292 | return clustering_data.obs["best_leiden"].values 293 | 294 | 295 | def hashsolo( 296 | cell_hashing_adata: anndata.AnnData, 297 | priors: list = [0.01, 0.8, 0.19], 298 | pre_existing_clusters: str = None, 299 | clustering_data: anndata.AnnData = None, 300 | resolutions: list = [0.1, 0.25, 0.5, 0.75, 1], 301 | number_of_noise_barcodes: int = None, 302 | inplace: bool = True, 303 | ): 304 | """Demultiplex cell hashing dataset using HashSolo method 305 | 306 | Parameters 307 | ---------- 308 | cell_hashing_adata : anndata.AnnData 309 | Anndata object filled only with hashing counts 310 | priors : list, 311 | a list of your prior for each hypothesis 312 | first element is your prior for the negative hypothesis 313 | second element is your prior for the singlet hypothesis 314 | third element is your prior for the doublet hypothesis 315 | We use [0.01, 0.8, 0.19] by default because we assume the barcodes 316 | in your cell hashing matrix are those cells which have passed QC 317 | in the transcriptome space, e.g. 
UMI counts, pct mito reads, etc. 318 | clustering_data : anndata.AnnData 319 | transcriptional data for clustering 320 | resolutions : list 321 | clustering resolutions for leiden 322 | pre_existing_clusters : str 323 | column in cell_hashing_adata.obs for how to break up demultiplexing 324 | inplace : bool 325 | To do operation in place 326 | 327 | Returns 328 | ------- 329 | cell_hashing_adata : AnnData 330 | if inplace is False returns AnnData with demultiplexing results 331 | in .obs attribute otherwise does is in place 332 | """ 333 | if issparse(cell_hashing_adata.X): 334 | cell_hashing_adata.X = np.array(cell_hashing_adata.X.todense()) 335 | 336 | if clustering_data is not None: 337 | print( 338 | "This may take awhile we are running clustering at {} different resolutions".format( 339 | len(resolutions) 340 | ) 341 | ) 342 | if not all(clustering_data.obs_names == cell_hashing_adata.obs_names): 343 | raise ValueError( 344 | "clustering_data and cell hashing cell_hashing_adata must have same index" 345 | ) 346 | cell_hashing_adata.obs["best_leiden"] = _get_clusters( 347 | clustering_data, resolutions 348 | ) 349 | 350 | data = cell_hashing_adata.X 351 | num_of_cells = cell_hashing_adata.shape[0] 352 | results = pd.DataFrame( 353 | np.zeros((num_of_cells, 6)), 354 | columns=[ 355 | "most_likely_hypothesis", 356 | "probs_hypotheses", 357 | "cluster_feature", 358 | "negative_hypothesis_probability", 359 | "singlet_hypothesis_probability", 360 | "doublet_hypothesis_probability", 361 | ], 362 | index=cell_hashing_adata.obs_names, 363 | ) 364 | if clustering_data is not None or pre_existing_clusters is not None: 365 | cluster_features = ( 366 | "best_leiden" if pre_existing_clusters is None else pre_existing_clusters 367 | ) 368 | unique_cluster_features = np.unique(cell_hashing_adata.obs[cluster_features]) 369 | for cluster_feature in unique_cluster_features: 370 | cluster_feature_bool_vector = ( 371 | cell_hashing_adata.obs[cluster_features] == cluster_feature 372 | ) 373 | posterior_dict = _calculate_bayes_rule( 374 | data[cluster_feature_bool_vector], priors, number_of_noise_barcodes 375 | ) 376 | results.loc[ 377 | cluster_feature_bool_vector, "most_likely_hypothesis" 378 | ] = posterior_dict["most_likely_hypothesis"] 379 | results.loc[ 380 | cluster_feature_bool_vector, "cluster_feature" 381 | ] = cluster_feature 382 | results.loc[ 383 | cluster_feature_bool_vector, "negative_hypothesis_probability" 384 | ] = posterior_dict["probs_hypotheses"][:, 0] 385 | results.loc[ 386 | cluster_feature_bool_vector, "singlet_hypothesis_probability" 387 | ] = posterior_dict["probs_hypotheses"][:, 1] 388 | results.loc[ 389 | cluster_feature_bool_vector, "doublet_hypothesis_probability" 390 | ] = posterior_dict["probs_hypotheses"][:, 2] 391 | else: 392 | posterior_dict = _calculate_bayes_rule(data, priors, number_of_noise_barcodes) 393 | results.loc[:, "most_likely_hypothesis"] = posterior_dict[ 394 | "most_likely_hypothesis" 395 | ] 396 | results.loc[:, "cluster_feature"] = 0 397 | results.loc[:, "negative_hypothesis_probability"] = posterior_dict[ 398 | "probs_hypotheses" 399 | ][:, 0] 400 | results.loc[:, "singlet_hypothesis_probability"] = posterior_dict[ 401 | "probs_hypotheses" 402 | ][:, 1] 403 | results.loc[:, "doublet_hypothesis_probability"] = posterior_dict[ 404 | "probs_hypotheses" 405 | ][:, 2] 406 | 407 | cell_hashing_adata.obs["most_likely_hypothesis"] = results.loc[ 408 | cell_hashing_adata.obs_names, "most_likely_hypothesis" 409 | ] 410 | cell_hashing_adata.obs["cluster_feature"] = 
results.loc[ 411 | cell_hashing_adata.obs_names, "cluster_feature" 412 | ] 413 | cell_hashing_adata.obs["negative_hypothesis_probability"] = results.loc[ 414 | cell_hashing_adata.obs_names, "negative_hypothesis_probability" 415 | ] 416 | cell_hashing_adata.obs["singlet_hypothesis_probability"] = results.loc[ 417 | cell_hashing_adata.obs_names, "singlet_hypothesis_probability" 418 | ] 419 | cell_hashing_adata.obs["doublet_hypothesis_probability"] = results.loc[ 420 | cell_hashing_adata.obs_names, "doublet_hypothesis_probability" 421 | ] 422 | 423 | cell_hashing_adata.obs["Classification"] = None 424 | cell_hashing_adata.obs.loc[ 425 | cell_hashing_adata.obs["most_likely_hypothesis"] == 2, "Classification" 426 | ] = "Doublet" 427 | cell_hashing_adata.obs.loc[ 428 | cell_hashing_adata.obs["most_likely_hypothesis"] == 0, "Classification" 429 | ] = "Negative" 430 | all_sings = cell_hashing_adata.obs["most_likely_hypothesis"] == 1 431 | singlet_sample_index = np.argmax(cell_hashing_adata.X[all_sings], axis=1) 432 | cell_hashing_adata.obs.loc[ 433 | all_sings, "Classification" 434 | ] = cell_hashing_adata.var_names[singlet_sample_index] 435 | 436 | return cell_hashing_adata if not inplace else None 437 | 438 | 439 | def plot_qc_checks_cell_hashing( 440 | cell_hashing_adata: anndata.AnnData, alpha: float = 0.05, fig_path: str = None 441 | ): 442 | """Plot HashSolo demultiplexing results 443 | 444 | Parameters 445 | ---------- 446 | cell_hashing_adata : Anndata 447 | Anndata object filled only with hashing counts 448 | alpha : float 449 | Tranparency of scatterplot points 450 | fig_path : str 451 | Path to save figure 452 | Returns 453 | ------- 454 | """ 455 | import matplotlib 456 | 457 | matplotlib.use("Agg") 458 | import matplotlib.pyplot as plt 459 | 460 | cell_hashing_demultiplexing = cell_hashing_adata.obs 461 | cell_hashing_demultiplexing["log_counts"] = np.log( 462 | np.sum(cell_hashing_adata.X, axis=1) 463 | ) 464 | number_of_clusters = ( 465 | cell_hashing_demultiplexing["cluster_feature"].drop_duplicates().shape[0] 466 | ) 467 | fig, all_axes = plt.subplots( 468 | number_of_clusters, 4, figsize=(40, 10 * number_of_clusters) 469 | ) 470 | counter = 0 471 | for cluster_feature, group in cell_hashing_demultiplexing.groupby( 472 | "cluster_feature" 473 | ): 474 | if number_of_clusters > 1: 475 | axes = all_axes[counter] 476 | else: 477 | axes = all_axes 478 | 479 | ax = axes[0] 480 | ax.plot( 481 | group["log_counts"], 482 | group["negative_hypothesis_probability"], 483 | "bo", 484 | alpha=alpha, 485 | ) 486 | ax.set_title("Probability of negative hypothesis vs log hashing counts") 487 | ax.set_ylabel("Probability of negative hypothesis") 488 | ax.set_xlabel("Log hashing counts") 489 | 490 | ax = axes[1] 491 | ax.plot( 492 | group["log_counts"], 493 | group["singlet_hypothesis_probability"], 494 | "bo", 495 | alpha=alpha, 496 | ) 497 | ax.set_title("Probability of singlet hypothesis vs log hashing counts") 498 | ax.set_ylabel("Probability of singlet hypothesis") 499 | ax.set_xlabel("Log hashing counts") 500 | 501 | ax = axes[2] 502 | ax.plot( 503 | group["log_counts"], 504 | group["doublet_hypothesis_probability"], 505 | "bo", 506 | alpha=alpha, 507 | ) 508 | ax.set_title("Probability of doublet hypothesis vs log hashing counts") 509 | ax.set_ylabel("Probability of doublet hypothesis") 510 | ax.set_xlabel("Log hashing counts") 511 | 512 | ax = axes[3] 513 | group["Classification"].value_counts().plot.bar(ax=ax) 514 | ax.set_title("Count of each samples classification") 515 | counter += 1 
516 | plt.show() 517 | if fig_path is not None: 518 | fig.savefig(fig_path, dpi=300, format="pdf") 519 | 520 | 521 | def main(): 522 | usage = "hashsolo" 523 | parser = ArgumentParser(usage, formatter_class=ArgumentDefaultsHelpFormatter) 524 | 525 | parser.add_argument( 526 | dest="data_file", help="h5ad file containing cell hashing counts" 527 | ) 528 | parser.add_argument( 529 | "-j", 530 | dest="model_json_file", 531 | default=None, 532 | help="json file to pass optional arguments", 533 | ) 534 | parser.add_argument( 535 | "-o", 536 | dest="out_dir", 537 | default="hashsolo_output", 538 | help="Output directory for results", 539 | ) 540 | parser.add_argument( 541 | "-c", 542 | dest="clustering_data", 543 | default=None, 544 | help="h5ad file with count transcriptional data to\ 545 | perform clustering on", 546 | ) 547 | parser.add_argument( 548 | "-p", 549 | dest="pre_existing_clusters", 550 | default=None, 551 | help="column in cell_hashing_data_file.obs to \ 552 | specifying different cell types or clusters", 553 | ) 554 | parser.add_argument( 555 | "-q", 556 | dest="plot_name", 557 | default="hashing_qc_plots.pdf", 558 | help="name of plot to output", 559 | ) 560 | parser.add_argument( 561 | "-n", 562 | dest="number_of_noise_barcodes", 563 | default=None, 564 | help="Number of barcodes to use to create noise \ 565 | distribution", 566 | ) 567 | 568 | args = parser.parse_args() 569 | 570 | model_json_file = args.model_json_file 571 | if model_json_file is not None: 572 | # read parameters 573 | with open(model_json_file) as model_json_open: 574 | params = json.load(model_json_open) 575 | else: 576 | params = {} 577 | data_file = args.data_file 578 | data_ext = os.path.splitext(data_file)[-1] 579 | if data_ext == ".h5ad": 580 | cell_hashing_adata = anndata.read(data_file) 581 | else: 582 | print("Unrecognized file format") 583 | 584 | if args.clustering_data is not None: 585 | clustering_data_file = args.clustering_data 586 | clustering_data_ext = os.path.splitext(clustering_data_file)[-1] 587 | if clustering_data_ext == ".h5ad": 588 | clustering_data = anndata.read(clustering_data_file) 589 | else: 590 | print("Unrecognized file format for clustering data") 591 | else: 592 | clustering_data = None 593 | 594 | if not os.path.isdir(args.out_dir): 595 | os.mkdir(args.out_dir) 596 | 597 | hashsolo( 598 | cell_hashing_adata, 599 | pre_existing_clusters=args.pre_existing_clusters, 600 | number_of_noise_barcodes=args.number_of_noise_barcodes, 601 | clustering_data=clustering_data, 602 | **params 603 | ) 604 | cell_hashing_adata.write(os.path.join(args.out_dir, "hashsoloed.h5ad")) 605 | plot_qc_checks_cell_hashing( 606 | cell_hashing_adata, fig_path=os.path.join(args.out_dir, args.plot_name) 607 | ) 608 | 609 | 610 | ############################################################################### 611 | # __main__ 612 | ############################################################################### 613 | 614 | 615 | if __name__ == "__main__": 616 | main() 617 | -------------------------------------------------------------------------------- /build/lib/solo/solo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | import os 4 | import umap 5 | import sys 6 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 7 | import pkg_resources 8 | 9 | import numpy as np 10 | from sklearn.metrics import * 11 | from scipy.special import softmax 12 | from scanpy import read_10x_mtx 13 | 14 | import torch 15 | from 
lightning.pytorch.callbacks.early_stopping import EarlyStopping 16 | 17 | import scvi 18 | from scvi.data import read_h5ad, read_loom 19 | from scvi.model import SCVI 20 | from scvi.external import SOLO 21 | 22 | from .utils import knn_smooth_pred_class 23 | 24 | """ 25 | solo.py 26 | 27 | Simulate doublets, train a VAE, and then a classifier on top. 28 | """ 29 | 30 | 31 | ############################################################################### 32 | # main 33 | ############################################################################### 34 | 35 | 36 | def main(): 37 | usage = "solo" 38 | parser = ArgumentParser(usage, formatter_class=ArgumentDefaultsHelpFormatter) 39 | parser.add_argument( 40 | "-j", 41 | dest="model_json_file", 42 | help="json file to pass VAE parameters", 43 | required="--version" not in sys.argv, 44 | ) 45 | parser.add_argument( 46 | "-d", 47 | dest="data_path", 48 | help="path to h5ad, loom, or 10x mtx dir cell by genes counts", 49 | required="--version" not in sys.argv, 50 | ) 51 | parser.add_argument( 52 | "--set-reproducible-seed", 53 | dest="reproducible_seed", 54 | default=None, 55 | type=int, 56 | help="Reproducible seed, give an int to set seed", 57 | ) 58 | parser.add_argument( 59 | "--doublet-depth", 60 | dest="doublet_depth", 61 | default=2.0, 62 | type=float, 63 | help="Depth multiplier for a doublet relative to the \ 64 | average of its constituents", 65 | ) 66 | parser.add_argument( 67 | "-g", dest="gpu", default=True, action="store_true", help="Run on GPU" 68 | ) 69 | parser.add_argument( 70 | "-a", 71 | dest="anndata_output", 72 | default=False, 73 | action="store_true", 74 | help="output modified anndata object with solo scores \ 75 | Only works for anndata", 76 | ) 77 | parser.add_argument("-o", dest="out_dir", default="solo_out") 78 | parser.add_argument( 79 | "-r", 80 | dest="doublet_ratio", 81 | default=2, 82 | type=int, 83 | help="Ratio of doublets to true \ 84 | cells", 85 | ) 86 | parser.add_argument( 87 | "-s", 88 | dest="seed", 89 | default=None, 90 | help="Path to previous solo output \ 91 | directory. Seed VAE models with previously \ 92 | trained solo model. Directory structure is assumed to \ 93 | be the same as solo output directory structure. 
\ 94 | should at least have a vae.pt a pickled object of \ 95 | vae weights and a latent.npy an np.ndarray of the \ 96 | latents of your cells.", 97 | ) 98 | parser.add_argument( 99 | "-e", 100 | dest="expected_number_of_doublets", 101 | help="Experimentally expected number of doublets", 102 | type=int, 103 | default=None, 104 | ) 105 | parser.add_argument( 106 | "-p", 107 | dest="plot", 108 | default=False, 109 | action="store_true", 110 | help="Plot outputs for solo", 111 | ) 112 | parser.add_argument( 113 | "-recalibrate_scores", 114 | dest="recalibrate_scores", 115 | default=False, 116 | action="store_true", 117 | help="Recalibrate doublet scores (not recommended anymore)", 118 | ) 119 | parser.add_argument( 120 | "--version", 121 | dest="version", 122 | default=False, 123 | action="store_true", 124 | help="Get version of solo-sc", 125 | ) 126 | 127 | parser.add_argument( 128 | "--lr_st", 129 | dest="lr_st", 130 | default=1e-3, 131 | type=int, 132 | help="Learning rate used for solo.train", 133 | ) 134 | parser.add_argument( 135 | "--lr_vae", 136 | dest="lr_vae", 137 | default=1e-3, 138 | type=int, 139 | help="Learning rate used for vae", 140 | ) 141 | 142 | args = parser.parse_args() 143 | 144 | if args.version: 145 | version = pkg_resources.require("solo-sc")[0].version 146 | print(f"Current version of solo-sc is {version}") 147 | if args.model_json_file is None or args.data_path is None: 148 | print("Json or data path not give exiting solo") 149 | sys.exit() 150 | 151 | model_json_file = args.model_json_file 152 | data_path = args.data_path 153 | if args.gpu and not torch.cuda.is_available(): 154 | args.gpu = torch.cuda.is_available() 155 | print("Cuda is not available, switching to cpu running!") 156 | 157 | if not os.path.isdir(args.out_dir): 158 | os.mkdir(args.out_dir) 159 | 160 | if args.reproducible_seed is not None: 161 | scvi.settings.seed = args.reproducible_seed 162 | else: 163 | scvi.settings.seed = np.random.randint(10000) 164 | 165 | ################################################## 166 | # data 167 | 168 | # read loom/anndata 169 | data_ext = os.path.splitext(data_path)[-1] 170 | if data_ext == ".loom": 171 | scvi_data = read_loom(data_path) 172 | elif data_ext == ".h5ad": 173 | scvi_data = read_h5ad(data_path) 174 | elif os.path.isdir(data_path): 175 | scvi_data = read_10x_mtx(path=data_path) 176 | cell_umi_depth = scvi_data.X.sum(axis=1) 177 | fifth, ninetyfifth = np.percentile(cell_umi_depth, [5, 95]) 178 | min_cell_umi_depth = np.min(cell_umi_depth) 179 | max_cell_umi_depth = np.max(cell_umi_depth) 180 | if fifth * 10 < ninetyfifth: 181 | print( 182 | """WARNING YOUR DATA HAS A WIDE RANGE OF CELL DEPTHS. 
183 | PLEASE MANUALLY REVIEW YOUR DATA""" 184 | ) 185 | print( 186 | f"Min cell depth: {min_cell_umi_depth}, Max cell depth: {max_cell_umi_depth}" 187 | ) 188 | else: 189 | msg = f"{data_path} is not a recognized format.\n" 190 | msg += "must be one of {h5ad, loom, 10x mtx dir}" 191 | raise TypeError(msg) 192 | 193 | num_cells, num_genes = scvi_data.X.shape 194 | 195 | # check for parameters 196 | if not os.path.exists(model_json_file): 197 | raise FileNotFoundError(f"{model_json_file} does not exist.") 198 | # read parameters 199 | with open(model_json_file, "r") as model_json_open: 200 | params = json.load(model_json_open) 201 | 202 | # set VAE params 203 | vae_params = {} 204 | for par in ["n_hidden", "n_latent", "n_layers", "dropout_rate", "ignore_batch"]: 205 | if par in params: 206 | vae_params[par] = params[par] 207 | 208 | # training parameters 209 | batch_key = params.get("batch_key", None) 210 | batch_size = params.get("batch_size", 128) 211 | valid_pct = params.get("valid_pct", 0.1) 212 | check_val_every_n_epoch = params.get("check_val_every_n_epoch", 5) 213 | learning_rate = params.get("learning_rate", args.lr_st) 214 | stopping_params = {"patience": params.get("patience", 8), "min_delta": 0} 215 | 216 | # protect against single example batch 217 | while num_cells % batch_size == 1: 218 | batch_size = int(np.round(1.25 * batch_size)) 219 | print("Increasing batch_size to %d to avoid single example batch." % batch_size) 220 | 221 | scvi.settings.batch_size = batch_size 222 | ################################################## 223 | # SCVI 224 | #setup_anndata(scvi_data, batch_key=batch_key) 225 | scvi.model.SCVI.setup_anndata(scvi_data, batch_key=batch_key) 226 | vae = SCVI( 227 | scvi_data, 228 | gene_likelihood="nb", 229 | log_variational=True, 230 | **vae_params, 231 | use_observed_lib_size=False, 232 | ) 233 | 234 | if args.seed: 235 | vae = vae.load(os.path.join(args.seed, "vae"), use_gpu=args.gpu) 236 | else: 237 | scvi_callbacks = [] 238 | scvi_callbacks += [ 239 | EarlyStopping( 240 | monitor="reconstruction_loss_validation", mode="min", **stopping_params 241 | ) 242 | ] 243 | plan_kwargs = { 244 | "reduce_lr_on_plateau": True, 245 | "lr_factor": 0.1, 246 | "lr": args.lr_vae, 247 | "lr_patience": 10, 248 | "lr_threshold": 0, 249 | "lr_min": 1e-4, 250 | "lr_scheduler_metric": "reconstruction_loss_validation", 251 | } 252 | vae.train( 253 | max_epochs=2000, 254 | validation_size=valid_pct, 255 | check_val_every_n_epoch=check_val_every_n_epoch, 256 | plan_kwargs=plan_kwargs, 257 | callbacks=scvi_callbacks, 258 | ) 259 | # save VAE 260 | vae.save(os.path.join(args.out_dir, "vae")) 261 | 262 | latent = vae.get_latent_representation() 263 | # save latent representation 264 | np.save(os.path.join(args.out_dir, "latent.npy"), latent.astype("float32")) 265 | 266 | ################################################## 267 | # classifier 268 | 269 | # model 270 | # todo add doublet ratio 271 | solo = SOLO.from_scvi_model(vae, doublet_ratio=args.doublet_ratio) 272 | solo.train( 273 | 2000, 274 | lr=learning_rate, 275 | train_size=0.9, 276 | check_val_every_n_epoch=5, 277 | early_stopping_patience=6, 278 | ) 279 | solo.train( 280 | 2000, 281 | lr=learning_rate * 0.1, 282 | train_size=0.9, 283 | check_val_every_n_epoch=1, 284 | early_stopping_patience=30, 285 | callbacks=[], 286 | ) 287 | solo.save(os.path.join(args.out_dir, "classifier")) 288 | 289 | logit_predictions = solo.predict(include_simulated_doublets=True) 290 | 291 | is_doublet_known = solo.adata.obs._solo_doub_sim == "doublet" 
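    # predictions have a 'doublet' and a 'singlet' column; a cell is called a
    # doublet when its 'singlet' logit is the row-wise minimum, i.e. when the
    # 'doublet' logit is the larger of the two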
292 | is_doublet_pred = logit_predictions.idxmin(axis=1) == "singlet" 293 | 294 | validation_is_doublet_known = is_doublet_known[solo.validation_indices] 295 | validation_is_doublet_pred = is_doublet_pred[solo.validation_indices] 296 | training_is_doublet_known = is_doublet_known[solo.train_indices] 297 | training_is_doublet_pred = is_doublet_pred[solo.train_indices] 298 | 299 | valid_as = accuracy_score(validation_is_doublet_known, validation_is_doublet_pred) 300 | valid_roc = roc_auc_score(validation_is_doublet_known, validation_is_doublet_pred) 301 | valid_ap = average_precision_score( 302 | validation_is_doublet_known, validation_is_doublet_pred 303 | ) 304 | 305 | train_as = accuracy_score(training_is_doublet_known, training_is_doublet_pred) 306 | train_roc = roc_auc_score(training_is_doublet_known, training_is_doublet_pred) 307 | train_ap = average_precision_score( 308 | training_is_doublet_known, training_is_doublet_pred 309 | ) 310 | 311 | print(f"Training results") 312 | print(f"AUROC: {train_roc}, Accuracy: {train_as}, Average precision: {train_ap}") 313 | 314 | print(f"Validation results") 315 | print(f"AUROC: {valid_roc}, Accuracy: {valid_as}, Average precision: {valid_ap}") 316 | 317 | # write predictions 318 | # softmax predictions 319 | softmax_predictions = softmax(logit_predictions, axis=1) 320 | if logit_predictions.columns[0]=='doublet': 321 | doublet_score = softmax_predictions[:, 0] 322 | else: 323 | doublet_score = softmax_predictions[:, 1] 324 | 325 | np.save( 326 | os.path.join(args.out_dir, "no_updates_softmax_scores.npy"), 327 | doublet_score[:num_cells], 328 | ) 329 | np.savetxt( 330 | os.path.join(args.out_dir, "no_updates_softmax_scores.csv"), 331 | doublet_score[:num_cells], 332 | delimiter=",", 333 | ) 334 | np.save( 335 | os.path.join(args.out_dir, "no_updates_softmax_scores_sim.npy"), 336 | doublet_score[num_cells:], 337 | ) 338 | 339 | # logit predictions 340 | logit_doublet_score = logit_predictions.loc[:, "doublet"] 341 | np.save( 342 | os.path.join(args.out_dir, "logit_scores.npy"), logit_doublet_score[:num_cells] 343 | ) 344 | np.savetxt( 345 | os.path.join(args.out_dir, "logit_scores.csv"), 346 | logit_doublet_score[:num_cells], 347 | delimiter=",", 348 | ) 349 | np.save( 350 | os.path.join(args.out_dir, "logit_scores_sim.npy"), 351 | logit_doublet_score[num_cells:], 352 | ) 353 | 354 | # update threshold as a function of Solo's estimate of the number of 355 | # doublets 356 | # essentially a log odds update 357 | # TODO put in a function 358 | # currently overshrinking softmaxes 359 | diff = np.inf 360 | counter_update = 0 361 | solo_scores = doublet_score[:num_cells] 362 | logit_scores = logit_doublet_score[:num_cells] 363 | d_s = args.doublet_ratio / (args.doublet_ratio + 1) 364 | if args.recalibrate_scores: 365 | while (diff > 0.01) | (counter_update < 5): 366 | 367 | # calculate log odds calibration for logits 368 | d_o = np.mean(solo_scores) 369 | c = np.log(d_o / (1 - d_o)) - np.log(d_s / (1 - d_s)) 370 | 371 | # update solo scores 372 | solo_scores = 1 / (1 + np.exp(-(logit_scores + c))) 373 | 374 | # update while conditions 375 | diff = np.abs(d_o - np.mean(solo_scores)) 376 | counter_update += 1 377 | 378 | np.save(os.path.join(args.out_dir, "softmax_scores.npy"), solo_scores) 379 | np.savetxt( 380 | os.path.join(args.out_dir, "softmax_scores.csv"), solo_scores, delimiter="," 381 | ) 382 | 383 | if args.expected_number_of_doublets is not None: 384 | k = len(solo_scores) - args.expected_number_of_doublets 385 | if 
args.expected_number_of_doublets / len(solo_scores) > 0.5: 386 | print( 387 | """Make sure you actually expect more than half your cells 388 | to be doublets. If not change your 389 | -e parameter value""" 390 | ) 391 | assert k > 0 392 | idx = np.argpartition(solo_scores, k) 393 | threshold = np.max(solo_scores[idx[:k]]) 394 | is_solo_doublet = solo_scores > threshold 395 | else: 396 | is_solo_doublet = solo_scores > 0.5 397 | 398 | np.save(os.path.join(args.out_dir, "is_doublet.npy"), is_solo_doublet[:num_cells]) 399 | np.savetxt( 400 | os.path.join(args.out_dir, "is_doublet.csv"), 401 | is_solo_doublet[:num_cells], 402 | delimiter=",", 403 | ) 404 | 405 | np.save( 406 | os.path.join(args.out_dir, "is_doublet_sim.npy"), is_solo_doublet[num_cells:] 407 | ) 408 | 409 | np.save(os.path.join(args.out_dir, "preds.npy"), is_doublet_pred[:num_cells]) 410 | np.savetxt( 411 | os.path.join(args.out_dir, "preds.csv"), 412 | is_doublet_pred[:num_cells], 413 | delimiter=",", 414 | ) 415 | 416 | smoothed_preds = knn_smooth_pred_class( 417 | X=latent, pred_class=is_doublet_pred[:num_cells] 418 | ) 419 | np.save(os.path.join(args.out_dir, "smoothed_preds.npy"), smoothed_preds) 420 | 421 | if args.anndata_output and data_ext == ".h5ad": 422 | scvi_data.obs["is_doublet"] = is_solo_doublet[:num_cells].astype(bool) 423 | scvi_data.obs["logit_scores"] = logit_doublet_score[:num_cells].astype( 424 | float 425 | ) 426 | scvi_data.obs["softmax_scores"] = solo_scores[:num_cells].astype(float) 427 | scvi_data.write(os.path.join(args.out_dir, "soloed.h5ad")) 428 | 429 | if args.plot: 430 | import matplotlib 431 | 432 | matplotlib.use("Agg") 433 | import matplotlib.pyplot as plt 434 | import seaborn as sns 435 | 436 | train_solo_scores = doublet_score[solo.train_indices] 437 | validation_solo_scores = doublet_score[solo.validation_indices] 438 | 439 | train_fpr, train_tpr, _ = roc_curve( 440 | training_is_doublet_known, train_solo_scores 441 | ) 442 | val_fpr, val_tpr, _ = roc_curve( 443 | validation_is_doublet_known, validation_solo_scores 444 | ) 445 | 446 | # plot ROC 447 | plt.figure() 448 | plt.plot(train_fpr, train_tpr, label="Train") 449 | plt.plot(val_fpr, val_tpr, label="Validation") 450 | plt.gca().set_xlabel("False positive rate") 451 | plt.gca().set_ylabel("True positive rate") 452 | plt.legend() 453 | plt.savefig(os.path.join(args.out_dir, "roc.pdf")) 454 | plt.close() 455 | 456 | train_precision, train_recall, _ = precision_recall_curve( 457 | training_is_doublet_known, train_solo_scores 458 | ) 459 | val_precision, val_recall, _ = precision_recall_curve( 460 | validation_is_doublet_known, validation_solo_scores 461 | ) 462 | # plot accuracy 463 | plt.figure() 464 | plt.plot(train_recall, train_precision, label="Train") 465 | plt.plot(val_recall, val_precision, label="Validation") 466 | plt.gca().set_xlabel("Recall") 467 | plt.gca().set_ylabel("pytPrecision") 468 | plt.legend() 469 | plt.savefig(os.path.join(args.out_dir, "precision_recall.pdf")) 470 | plt.close() 471 | 472 | # plot distributions 473 | obs_indices = solo.validation_indices[solo.validation_indices < num_cells] 474 | sim_indices = solo.validation_indices[solo.validation_indices > num_cells] 475 | 476 | plt.figure() 477 | sns.displot(doublet_score[sim_indices], label="Simulated") 478 | sns.displot(doublet_score[obs_indices], label="Observed") 479 | plt.legend() 480 | plt.savefig(os.path.join(args.out_dir, "sim_vs_obs_dist.pdf")) 481 | plt.close() 482 | 483 | plt.figure() 484 | sns.distplot(solo_scores[:num_cells], label="Observed 
(transformed)") 485 | plt.legend() 486 | plt.savefig(os.path.join(args.out_dir, "real_cells_dist.pdf")) 487 | plt.close() 488 | 489 | scvi_umap = umap.UMAP(n_neighbors=16).fit_transform(latent) 490 | fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 491 | ax.scatter( 492 | scvi_umap[:, 0], 493 | scvi_umap[:, 1], 494 | c=doublet_score[:num_cells], 495 | s=8, 496 | cmap="GnBu", 497 | ) 498 | 499 | ax.set_xlabel("UMAP 1") 500 | ax.set_ylabel("UMAP 2") 501 | fig.savefig(os.path.join(args.out_dir, "umap_solo_scores.pdf")) 502 | 503 | 504 | ############################################################################### 505 | # __main__ 506 | ############################################################################### 507 | 508 | 509 | if __name__ == "__main__": 510 | main() 511 | -------------------------------------------------------------------------------- /build/lib/solo/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy.stats import multinomial 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | 7 | def knn_smooth_pred_class( 8 | X: np.ndarray, 9 | pred_class: np.ndarray, 10 | grouping: np.ndarray = None, 11 | k: int = 15, 12 | ) -> np.ndarray: 13 | """ 14 | Smooths class predictions by taking the modal class from each cell's 15 | nearest neighbors. 16 | Parameters 17 | ---------- 18 | X : np.ndarray 19 | [N, Features] embedding space for calculation of nearest neighbors. 20 | pred_class : np.ndarray 21 | [N,] array of unique class labels. 22 | groupings : np.ndarray 23 | [N,] unique grouping labels for i.e. clusters. 24 | if provided, only considers nearest neighbors *within the cluster*. 25 | k : int 26 | number of nearest neighbors to use for smoothing. 27 | Returns 28 | ------- 29 | smooth_pred_class : np.ndarray 30 | [N,] unique class labels, smoothed by kNN. 31 | Examples 32 | -------- 33 | >>> smooth_pred_class = knn_smooth_pred_class( 34 | ... X = X, 35 | ... pred_class = raw_predicted_classes, 36 | ... grouping = louvain_cluster_groups, 37 | ... k = 15,) 38 | Notes 39 | ----- 40 | scNym classifiers do not incorporate neighborhood information. 41 | By using a simple kNN smoothing heuristic, we can leverage neighborhood 42 | information to improve classification performance, smoothing out cells 43 | that have an outlier prediction relative to their local neighborhood. 44 | """ 45 | if grouping is None: 46 | # do not use a grouping to restrict local neighborhood 47 | # associations, create a universal pseudogroup `0`. 
48 | grouping = np.zeros(X.shape[0]) 49 | 50 | smooth_pred_class = np.zeros_like(pred_class) 51 | for group in np.unique(grouping): 52 | # identify only cells in the relevant group 53 | group_idx = np.where(grouping == group)[0].astype("int") 54 | X_group = X[grouping == group, :] 55 | # if there are < k cells in the group, change `k` to the 56 | # group size 57 | if X_group.shape[0] < k: 58 | k_use = X_group.shape[0] 59 | else: 60 | k_use = k 61 | # compute a nearest neighbor graph and identify kNN 62 | nns = NearestNeighbors( 63 | n_neighbors=k_use, 64 | ).fit(X_group) 65 | dist, idx = nns.kneighbors(X_group) 66 | 67 | # for each cell in the group, assign a class as 68 | # the majority class of the kNN 69 | for i in range(X_group.shape[0]): 70 | classes = pred_class[group_idx[idx[i, :]]] 71 | uniq_classes, counts = np.unique(classes, return_counts=True) 72 | maj_class = uniq_classes[int(np.argmax(counts))] 73 | smooth_pred_class[group_idx[i]] = maj_class 74 | return smooth_pred_class 75 | -------------------------------------------------------------------------------- /hashsolo_params_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "priors": [0.01, 0.5, 0.49] 3 | } 4 | -------------------------------------------------------------------------------- /prespecified.txt: -------------------------------------------------------------------------------- 1 | ConfigArgParse 2 | pandas 3 | seaborn 4 | tqdm 5 | scanpy 6 | scvi 7 | leidenalg 8 | torch 9 | numba 10 | -------------------------------------------------------------------------------- /release_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf dist 4 | source activate solo-sc 5 | python setup.py sdist 6 | twine upload dist/* 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ConfigArgParse>=1.7,<2.0 2 | h5py>=3.10.0,<4.0 3 | leidenalg>=0.10.2,<1.0 4 | lightning>=2,<3.0 5 | lightning-utilities>=0.10.1,<1.0 6 | numba>=0.59.0,<1.0 7 | pandas>=2.0 8 | pytorch-lightning>=2.2.1,<3.0 9 | scanpy>=1.9.8,<2.0 10 | scipy>=1.12.0,<2.0 11 | scvi-tools>=1.1.1,<2.0 12 | seaborn>=0.13.2,<1.0 13 | torch>=2.2.1,<3.0 14 | torchmetrics>=1.3.1,<2.0 15 | tqdm>=4.66.2,<5.0 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Inside of setup.cfg 2 | [metadata] 3 | description-file = README.md 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from setuptools import setup, find_packages 4 | 5 | if sys.version_info < (3,): 6 | sys.exit("solo requires Python >= 3.12") 7 | 8 | try: 9 | from solo import __author__, __email__ 10 | except ImportError: # Deps not yet installed 11 | __author__ = __email__ = "" 12 | 13 | 14 | setup( 15 | name="solo-sc", 16 | version="1.2", 17 | description="Neural network classifiers for doublets", 18 | long_description=Path("README.md").read_text("utf-8"), 19 | long_description_content_type="text/markdown", 20 | url="http://github.com/calico/solo", 21 | download_url="https://github.com/calico/solo/archive/1.2.tar.gz", 22 | author=__author__, 23 | author_email=__email__, 24 | 
license="Apache", 25 | python_requires=">=3.12", 26 | install_requires=[ 27 | l.strip() for l in Path("requirements.txt").read_text("utf-8").splitlines() 28 | ], 29 | packages=find_packages(exclude="testdata"), 30 | entry_points=dict( 31 | console_scripts=["solo=solo.solo:main", "hashsolo=solo.hashsolo:main"], 32 | ), 33 | classifiers=[ 34 | "Environment :: Console", 35 | "Intended Audience :: Science/Research", 36 | "Topic :: Scientific/Engineering :: Bio-Informatics", 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /solo/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = "David Kelley, Nick Bernstein" 2 | __email__ = "nicholas@calicolabs.com" 3 | __version__ = "0.1" 4 | 5 | from . import hashsolo, utils 6 | -------------------------------------------------------------------------------- /solo/hashsolo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import json 4 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 5 | 6 | from scipy.stats import norm 7 | from itertools import product 8 | import anndata 9 | import numpy as np 10 | import pandas as pd 11 | import scanpy as sc 12 | 13 | from scipy.sparse import issparse 14 | from sklearn.metrics import calinski_harabasz_score 15 | 16 | """ 17 | HashSolo script provides a probabilistic cell hashing demultiplexing method 18 | which generates a noise distribution and signal distribution for 19 | each hashing barcode from empirically observed counts. These distributions 20 | are updates from the global signal and noise barcode distributions, which 21 | helps in the setting where not many cells are observed. Signal distributions 22 | for a hashing barcode are estimated from samples where that hashing barcode 23 | has the highest count. Noise distributions for a hashing barcode are estimated 24 | from samples where that hashing barcode is one the k-2 lowest barcodes, where 25 | k is the number of barcodes. A doublet should then have its two highest 26 | barcode counts most likely coming from a signal distribution for those barcodes. 27 | A singlet should have its highest barcode from a signal distribution, and its 28 | second highest barcode from a noise distribution. A negative two highest 29 | barcodes should come from noise distributions. We test each of these 30 | hypotheses in a bayesian fashion, and select the most probable hypothesis. 
31 | """ 32 | 33 | 34 | def _calculate_log_likelihoods(data, number_of_noise_barcodes): 35 | """Calculate log likelihoods for each hypothesis, negative, singlet, doublet 36 | 37 | Parameters 38 | ---------- 39 | data : np.ndarray 40 | cells by hashing counts matrix 41 | number_of_noise_barcodes : int, 42 | number of barcodes to used to calculated noise distribution 43 | Returns 44 | ------- 45 | log_likelihoods_for_each_hypothesis : np.ndarray 46 | a 2d np.array log likelihood of each hypothesis 47 | all_indices 48 | counter_to_barcode_combo 49 | """ 50 | 51 | def gaussian_updates(data, mu_o, std_o): 52 | """Update parameters of your gaussian 53 | https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf 54 | Parameters 55 | ---------- 56 | data : np.array 57 | 1-d array of counts 58 | mu_o : float, 59 | global mean for hashing count distribution 60 | std_o : float, 61 | global std for hashing count distribution 62 | Returns 63 | ------- 64 | float 65 | mean of gaussian 66 | float 67 | std of gaussian 68 | """ 69 | lam_o = 1 / (std_o ** 2) 70 | n = len(data) 71 | lam = 1 / np.var(data) if len(data) > 1 else lam_o 72 | lam_n = lam_o + n * lam 73 | mu_n = ( 74 | (np.mean(data) * n * lam + mu_o * lam_o) / lam_n if len(data) > 0 else mu_o 75 | ) 76 | return mu_n, (1 / (lam_n / (n + 1))) ** (1 / 2) 77 | 78 | eps = 1e-15 79 | # probabilites for negative, singlet, doublets 80 | log_likelihoods_for_each_hypothesis = np.zeros((data.shape[0], 3)) 81 | 82 | all_indices = np.empty(data.shape[0]) 83 | num_of_barcodes = data.shape[1] 84 | number_of_non_noise_barcodes = ( 85 | num_of_barcodes - number_of_noise_barcodes 86 | if number_of_noise_barcodes is not None 87 | else 2 88 | ) 89 | num_of_noise_barcodes = num_of_barcodes - number_of_non_noise_barcodes 90 | 91 | # assume log normal 92 | data = np.log(data + 1) 93 | data_arg = np.argsort(data, axis=1) 94 | data_sort = np.sort(data, axis=1) 95 | 96 | # global signal and noise counts useful for when we have few cells 97 | # barcodes with the highest number of counts are assumed to be a true signal 98 | # barcodes with rank < k are considered to be noise 99 | global_signal_counts = np.ravel(data_sort[:, -1]) 100 | global_noise_counts = np.ravel(data_sort[:, :-number_of_non_noise_barcodes]) 101 | global_mu_signal_o, global_sigma_signal_o = np.mean(global_signal_counts), np.std( 102 | global_signal_counts 103 | ) 104 | global_mu_noise_o, global_sigma_noise_o = np.mean(global_noise_counts), np.std( 105 | global_noise_counts 106 | ) 107 | 108 | noise_params_dict = {} 109 | signal_params_dict = {} 110 | 111 | # for each barcode get empirical noise and signal distribution parameterization 112 | for x in np.arange(num_of_barcodes): 113 | sample_barcodes = data[:, x] 114 | sample_barcodes_noise_idx = np.where(data_arg[:, :num_of_noise_barcodes] == x)[ 115 | 0 116 | ] 117 | sample_barcodes_signal_idx = np.where(data_arg[:, -1] == x) 118 | 119 | # get noise and signal counts 120 | noise_counts = sample_barcodes[sample_barcodes_noise_idx] 121 | signal_counts = sample_barcodes[sample_barcodes_signal_idx] 122 | 123 | # get parameters of distribution, assuming lognormal do update from global values 124 | noise_param = gaussian_updates( 125 | noise_counts, global_mu_noise_o, global_sigma_noise_o 126 | ) 127 | signal_param = gaussian_updates( 128 | signal_counts, global_mu_signal_o, global_sigma_signal_o 129 | ) 130 | noise_params_dict[x] = noise_param 131 | signal_params_dict[x] = signal_param 132 | 133 | counter_to_barcode_combo = {} 134 | counter = 0 135 | 136 | # for 
each combination of noise and signal barcode calculate probiltiy of in silico and real cell hypotheses 137 | for noise_sample_idx, signal_sample_idx in product( 138 | np.arange(num_of_barcodes), np.arange(num_of_barcodes) 139 | ): 140 | signal_subset = data_arg[:, -1] == signal_sample_idx 141 | noise_subset = data_arg[:, -2] == noise_sample_idx 142 | subset = signal_subset & noise_subset 143 | if sum(subset) == 0: 144 | continue 145 | 146 | indices = np.where(subset)[0] 147 | barcode_combo = "_".join([str(noise_sample_idx), str(signal_sample_idx)]) 148 | all_indices[np.where(subset)[0]] = counter 149 | counter_to_barcode_combo[counter] = barcode_combo 150 | counter += 1 151 | noise_params = noise_params_dict[noise_sample_idx] 152 | signal_params = signal_params_dict[signal_sample_idx] 153 | 154 | # calculate probabilties for each hypothesis for each cell 155 | data_subset = data[subset] 156 | log_signal_signal_probs = np.log( 157 | norm.pdf( 158 | data_subset[:, signal_sample_idx], 159 | *signal_params[:-2], 160 | loc=signal_params[-2], 161 | scale=signal_params[-1] 162 | ) 163 | + eps 164 | ) 165 | signal_noise_params = signal_params_dict[noise_sample_idx] 166 | log_noise_signal_probs = np.log( 167 | norm.pdf( 168 | data_subset[:, noise_sample_idx], 169 | *signal_noise_params[:-2], 170 | loc=signal_noise_params[-2], 171 | scale=signal_noise_params[-1] 172 | ) 173 | + eps 174 | ) 175 | 176 | log_noise_noise_probs = np.log( 177 | norm.pdf( 178 | data_subset[:, noise_sample_idx], 179 | *noise_params[:-2], 180 | loc=noise_params[-2], 181 | scale=noise_params[-1] 182 | ) 183 | + eps 184 | ) 185 | log_signal_noise_probs = np.log( 186 | norm.pdf( 187 | data_subset[:, signal_sample_idx], 188 | *noise_params[:-2], 189 | loc=noise_params[-2], 190 | scale=noise_params[-1] 191 | ) 192 | + eps 193 | ) 194 | 195 | probs_of_negative = np.sum( 196 | [log_noise_noise_probs, log_signal_noise_probs], axis=0 197 | ) 198 | probs_of_singlet = np.sum( 199 | [log_noise_noise_probs, log_signal_signal_probs], axis=0 200 | ) 201 | probs_of_doublet = np.sum( 202 | [log_noise_signal_probs, log_signal_signal_probs], axis=0 203 | ) 204 | log_probs_list = [probs_of_negative, probs_of_singlet, probs_of_doublet] 205 | 206 | # each cell and each hypothesis probability 207 | for prob_idx, log_prob in enumerate(log_probs_list): 208 | log_likelihoods_for_each_hypothesis[indices, prob_idx] = log_prob 209 | return log_likelihoods_for_each_hypothesis, all_indices, counter_to_barcode_combo 210 | 211 | 212 | def _calculate_bayes_rule(data, priors, number_of_noise_barcodes): 213 | """ 214 | Calculate bayes rule from log likelihoods 215 | 216 | Parameters 217 | ---------- 218 | data : np.array 219 | Anndata object filled only with hashing counts 220 | priors : list, 221 | a list of your prior for each hypothesis 222 | first element is your prior for the negative hypothesis 223 | second element is your prior for the singlet hypothesis 224 | third element is your prior for the doublet hypothesis 225 | We use [0.01, 0.8, 0.19] by default because we assume the barcodes 226 | in your cell hashing matrix are those cells which have passed QC 227 | in the transcriptome space, e.g. UMI counts, pct mito reads, etc. 
228 | number_of_noise_barcodes : int 229 | number of barcodes to used to calculated noise distribution 230 | Returns 231 | ------- 232 | bayes_dict_results : dict 233 | 'most_likely_hypothesis' key is a 1d np.array of the most likely hypothesis 234 | 'probs_hypotheses' key is a 2d np.array probability of each hypothesis 235 | 'log_likelihoods_for_each_hypothesis' key is a 2d np.array log likelihood of each hypothesis 236 | """ 237 | priors = np.array(priors) 238 | log_likelihoods_for_each_hypothesis, _, _ = _calculate_log_likelihoods( 239 | data, number_of_noise_barcodes 240 | ) 241 | probs_hypotheses = ( 242 | np.exp(log_likelihoods_for_each_hypothesis) 243 | * priors 244 | / np.sum( 245 | np.multiply(np.exp(log_likelihoods_for_each_hypothesis), priors), axis=1 246 | )[:, None] 247 | ) 248 | most_likely_hypothesis = np.argmax(probs_hypotheses, axis=1) 249 | return { 250 | "most_likely_hypothesis": most_likely_hypothesis, 251 | "probs_hypotheses": probs_hypotheses, 252 | "log_likelihoods_for_each_hypothesis": log_likelihoods_for_each_hypothesis, 253 | } 254 | 255 | 256 | def _get_clusters(clustering_data: anndata.AnnData, resolutions: list): 257 | """ 258 | Principled cell clustering 259 | Parameters 260 | ---------- 261 | cell_hashing_adata : anndata.AnnData 262 | Anndata object filled only with hashing counts 263 | resolutions : list 264 | clustering resolutions for leiden 265 | Returns 266 | ------- 267 | np.ndarray 268 | leiden clustering results for each cell 269 | """ 270 | sc.pp.normalize_per_cell(clustering_data, counts_per_cell_after=1e4) 271 | sc.pp.log1p(clustering_data) 272 | sc.pp.highly_variable_genes( 273 | clustering_data, min_mean=0.0125, max_mean=3, min_disp=0.5 274 | ) 275 | clustering_data = clustering_data[:, clustering_data.var["highly_variable"]] 276 | sc.pp.scale(clustering_data, max_value=10) 277 | sc.tl.pca(clustering_data, svd_solver="arpack") 278 | sc.pp.neighbors(clustering_data, n_neighbors=10, n_pcs=40) 279 | sc.tl.umap(clustering_data) 280 | best_ch_score = -np.inf 281 | 282 | for resolution in resolutions: 283 | sc.tl.leiden(clustering_data, resolution=resolution) 284 | 285 | ch_score = calinski_harabasz_score( 286 | clustering_data.X, clustering_data.obs["leiden"] 287 | ) 288 | 289 | if ch_score > best_ch_score: 290 | clustering_data.obs["best_leiden"] = clustering_data.obs["leiden"].values 291 | best_ch_score = ch_score 292 | return clustering_data.obs["best_leiden"].values 293 | 294 | 295 | def hashsolo( 296 | cell_hashing_adata: anndata.AnnData, 297 | priors: list = [0.01, 0.8, 0.19], 298 | pre_existing_clusters: str = None, 299 | clustering_data: anndata.AnnData = None, 300 | resolutions: list = [0.1, 0.25, 0.5, 0.75, 1], 301 | number_of_noise_barcodes: int = None, 302 | inplace: bool = True, 303 | ): 304 | """Demultiplex cell hashing dataset using HashSolo method 305 | 306 | Parameters 307 | ---------- 308 | cell_hashing_adata : anndata.AnnData 309 | Anndata object filled only with hashing counts 310 | priors : list, 311 | a list of your prior for each hypothesis 312 | first element is your prior for the negative hypothesis 313 | second element is your prior for the singlet hypothesis 314 | third element is your prior for the doublet hypothesis 315 | We use [0.01, 0.8, 0.19] by default because we assume the barcodes 316 | in your cell hashing matrix are those cells which have passed QC 317 | in the transcriptome space, e.g. UMI counts, pct mito reads, etc. 
318 | clustering_data : anndata.AnnData 319 | transcriptional data for clustering 320 | resolutions : list 321 | clustering resolutions for leiden 322 | pre_existing_clusters : str 323 | column in cell_hashing_adata.obs for how to break up demultiplexing 324 | inplace : bool 325 | To do operation in place 326 | 327 | Returns 328 | ------- 329 | cell_hashing_adata : AnnData 330 | if inplace is False returns AnnData with demultiplexing results 331 | in .obs attribute otherwise does is in place 332 | """ 333 | if issparse(cell_hashing_adata.X): 334 | cell_hashing_adata.X = np.array(cell_hashing_adata.X.todense()) 335 | 336 | if clustering_data is not None: 337 | print( 338 | "This may take awhile we are running clustering at {} different resolutions".format( 339 | len(resolutions) 340 | ) 341 | ) 342 | if not all(clustering_data.obs_names == cell_hashing_adata.obs_names): 343 | raise ValueError( 344 | "clustering_data and cell hashing cell_hashing_adata must have same index" 345 | ) 346 | cell_hashing_adata.obs["best_leiden"] = _get_clusters( 347 | clustering_data, resolutions 348 | ) 349 | 350 | data = cell_hashing_adata.X 351 | num_of_cells = cell_hashing_adata.shape[0] 352 | results = pd.DataFrame( 353 | np.zeros((num_of_cells, 6)), 354 | columns=[ 355 | "most_likely_hypothesis", 356 | "probs_hypotheses", 357 | "cluster_feature", 358 | "negative_hypothesis_probability", 359 | "singlet_hypothesis_probability", 360 | "doublet_hypothesis_probability", 361 | ], 362 | index=cell_hashing_adata.obs_names, 363 | ) 364 | if clustering_data is not None or pre_existing_clusters is not None: 365 | cluster_features = ( 366 | "best_leiden" if pre_existing_clusters is None else pre_existing_clusters 367 | ) 368 | unique_cluster_features = np.unique(cell_hashing_adata.obs[cluster_features]) 369 | for cluster_feature in unique_cluster_features: 370 | cluster_feature_bool_vector = ( 371 | cell_hashing_adata.obs[cluster_features] == cluster_feature 372 | ) 373 | posterior_dict = _calculate_bayes_rule( 374 | data[cluster_feature_bool_vector], priors, number_of_noise_barcodes 375 | ) 376 | results.loc[ 377 | cluster_feature_bool_vector, "most_likely_hypothesis" 378 | ] = posterior_dict["most_likely_hypothesis"] 379 | results.loc[ 380 | cluster_feature_bool_vector, "cluster_feature" 381 | ] = cluster_feature 382 | results.loc[ 383 | cluster_feature_bool_vector, "negative_hypothesis_probability" 384 | ] = posterior_dict["probs_hypotheses"][:, 0] 385 | results.loc[ 386 | cluster_feature_bool_vector, "singlet_hypothesis_probability" 387 | ] = posterior_dict["probs_hypotheses"][:, 1] 388 | results.loc[ 389 | cluster_feature_bool_vector, "doublet_hypothesis_probability" 390 | ] = posterior_dict["probs_hypotheses"][:, 2] 391 | else: 392 | posterior_dict = _calculate_bayes_rule(data, priors, number_of_noise_barcodes) 393 | results.loc[:, "most_likely_hypothesis"] = posterior_dict[ 394 | "most_likely_hypothesis" 395 | ] 396 | results.loc[:, "cluster_feature"] = 0 397 | results.loc[:, "negative_hypothesis_probability"] = posterior_dict[ 398 | "probs_hypotheses" 399 | ][:, 0] 400 | results.loc[:, "singlet_hypothesis_probability"] = posterior_dict[ 401 | "probs_hypotheses" 402 | ][:, 1] 403 | results.loc[:, "doublet_hypothesis_probability"] = posterior_dict[ 404 | "probs_hypotheses" 405 | ][:, 2] 406 | 407 | cell_hashing_adata.obs["most_likely_hypothesis"] = results.loc[ 408 | cell_hashing_adata.obs_names, "most_likely_hypothesis" 409 | ] 410 | cell_hashing_adata.obs["cluster_feature"] = results.loc[ 411 | 
cell_hashing_adata.obs_names, "cluster_feature" 412 | ] 413 | cell_hashing_adata.obs["negative_hypothesis_probability"] = results.loc[ 414 | cell_hashing_adata.obs_names, "negative_hypothesis_probability" 415 | ] 416 | cell_hashing_adata.obs["singlet_hypothesis_probability"] = results.loc[ 417 | cell_hashing_adata.obs_names, "singlet_hypothesis_probability" 418 | ] 419 | cell_hashing_adata.obs["doublet_hypothesis_probability"] = results.loc[ 420 | cell_hashing_adata.obs_names, "doublet_hypothesis_probability" 421 | ] 422 | 423 | cell_hashing_adata.obs["Classification"] = None 424 | cell_hashing_adata.obs.loc[ 425 | cell_hashing_adata.obs["most_likely_hypothesis"] == 2, "Classification" 426 | ] = "Doublet" 427 | cell_hashing_adata.obs.loc[ 428 | cell_hashing_adata.obs["most_likely_hypothesis"] == 0, "Classification" 429 | ] = "Negative" 430 | all_sings = cell_hashing_adata.obs["most_likely_hypothesis"] == 1 431 | singlet_sample_index = np.argmax(cell_hashing_adata.X[all_sings], axis=1) 432 | cell_hashing_adata.obs.loc[ 433 | all_sings, "Classification" 434 | ] = cell_hashing_adata.var_names[singlet_sample_index] 435 | 436 | return cell_hashing_adata if not inplace else None 437 | 438 | 439 | def plot_qc_checks_cell_hashing( 440 | cell_hashing_adata: anndata.AnnData, alpha: float = 0.05, fig_path: str = None 441 | ): 442 | """Plot HashSolo demultiplexing results 443 | 444 | Parameters 445 | ---------- 446 | cell_hashing_adata : Anndata 447 | Anndata object filled only with hashing counts 448 | alpha : float 449 | Tranparency of scatterplot points 450 | fig_path : str 451 | Path to save figure 452 | Returns 453 | ------- 454 | """ 455 | import matplotlib 456 | 457 | matplotlib.use("Agg") 458 | import matplotlib.pyplot as plt 459 | 460 | cell_hashing_demultiplexing = cell_hashing_adata.obs 461 | cell_hashing_demultiplexing["log_counts"] = np.log( 462 | np.sum(cell_hashing_adata.X, axis=1) 463 | ) 464 | number_of_clusters = ( 465 | cell_hashing_demultiplexing["cluster_feature"].drop_duplicates().shape[0] 466 | ) 467 | fig, all_axes = plt.subplots( 468 | number_of_clusters, 4, figsize=(40, 10 * number_of_clusters) 469 | ) 470 | counter = 0 471 | for cluster_feature, group in cell_hashing_demultiplexing.groupby( 472 | "cluster_feature" 473 | ): 474 | if number_of_clusters > 1: 475 | axes = all_axes[counter] 476 | else: 477 | axes = all_axes 478 | 479 | ax = axes[0] 480 | ax.plot( 481 | group["log_counts"], 482 | group["negative_hypothesis_probability"], 483 | "bo", 484 | alpha=alpha, 485 | ) 486 | ax.set_title("Probability of negative hypothesis vs log hashing counts") 487 | ax.set_ylabel("Probability of negative hypothesis") 488 | ax.set_xlabel("Log hashing counts") 489 | 490 | ax = axes[1] 491 | ax.plot( 492 | group["log_counts"], 493 | group["singlet_hypothesis_probability"], 494 | "bo", 495 | alpha=alpha, 496 | ) 497 | ax.set_title("Probability of singlet hypothesis vs log hashing counts") 498 | ax.set_ylabel("Probability of singlet hypothesis") 499 | ax.set_xlabel("Log hashing counts") 500 | 501 | ax = axes[2] 502 | ax.plot( 503 | group["log_counts"], 504 | group["doublet_hypothesis_probability"], 505 | "bo", 506 | alpha=alpha, 507 | ) 508 | ax.set_title("Probability of doublet hypothesis vs log hashing counts") 509 | ax.set_ylabel("Probability of doublet hypothesis") 510 | ax.set_xlabel("Log hashing counts") 511 | 512 | ax = axes[3] 513 | group["Classification"].value_counts().plot.bar(ax=ax) 514 | ax.set_title("Count of each samples classification") 515 | counter += 1 516 | plt.show() 517 
| if fig_path is not None: 518 | fig.savefig(fig_path, dpi=300, format="pdf") 519 | 520 | 521 | def main(): 522 | usage = "hashsolo" 523 | parser = ArgumentParser(usage, formatter_class=ArgumentDefaultsHelpFormatter) 524 | 525 | parser.add_argument( 526 | dest="data_file", help="h5ad file containing cell hashing counts" 527 | ) 528 | parser.add_argument( 529 | "-j", 530 | dest="model_json_file", 531 | default=None, 532 | help="json file to pass optional arguments", 533 | ) 534 | parser.add_argument( 535 | "-o", 536 | dest="out_dir", 537 | default="hashsolo_output", 538 | help="Output directory for results", 539 | ) 540 | parser.add_argument( 541 | "-c", 542 | dest="clustering_data", 543 | default=None, 544 | help="h5ad file with count transcriptional data to\ 545 | perform clustering on", 546 | ) 547 | parser.add_argument( 548 | "-p", 549 | dest="pre_existing_clusters", 550 | default=None, 551 | help="column in cell_hashing_data_file.obs to \ 552 | specifying different cell types or clusters", 553 | ) 554 | parser.add_argument( 555 | "-q", 556 | dest="plot_name", 557 | default="hashing_qc_plots.pdf", 558 | help="name of plot to output", 559 | ) 560 | parser.add_argument( 561 | "-n", 562 | dest="number_of_noise_barcodes", 563 | default=None, 564 | help="Number of barcodes to use to create noise \ 565 | distribution", 566 | ) 567 | 568 | args = parser.parse_args() 569 | 570 | model_json_file = args.model_json_file 571 | if model_json_file is not None: 572 | # read parameters 573 | with open(model_json_file) as model_json_open: 574 | params = json.load(model_json_open) 575 | else: 576 | params = {} 577 | data_file = args.data_file 578 | data_ext = os.path.splitext(data_file)[-1] 579 | if data_ext == ".h5ad": 580 | cell_hashing_adata = anndata.read(data_file) 581 | else: 582 | print("Unrecognized file format") 583 | 584 | if args.clustering_data is not None: 585 | clustering_data_file = args.clustering_data 586 | clustering_data_ext = os.path.splitext(clustering_data_file)[-1] 587 | if clustering_data_ext == ".h5ad": 588 | clustering_data = anndata.read(clustering_data_file) 589 | else: 590 | print("Unrecognized file format for clustering data") 591 | else: 592 | clustering_data = None 593 | 594 | if not os.path.isdir(args.out_dir): 595 | os.mkdir(args.out_dir) 596 | 597 | hashsolo( 598 | cell_hashing_adata, 599 | pre_existing_clusters=args.pre_existing_clusters, 600 | number_of_noise_barcodes=args.number_of_noise_barcodes, 601 | clustering_data=clustering_data, 602 | **params 603 | ) 604 | cell_hashing_adata.write(os.path.join(args.out_dir, "hashsoloed.h5ad")) 605 | plot_qc_checks_cell_hashing( 606 | cell_hashing_adata, fig_path=os.path.join(args.out_dir, args.plot_name) 607 | ) 608 | 609 | 610 | ############################################################################### 611 | # __main__ 612 | ############################################################################### 613 | 614 | 615 | if __name__ == "__main__": 616 | main() 617 | -------------------------------------------------------------------------------- /solo/solo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | import os 4 | import umap 5 | import sys 6 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 7 | import pkg_resources 8 | 9 | import numpy as np 10 | from sklearn.metrics import * 11 | from scipy.special import softmax 12 | from scanpy import read_10x_mtx 13 | 14 | import torch 15 | from 
lightning.pytorch.callbacks.early_stopping import EarlyStopping 16 | 17 | import scvi 18 | from scvi.data import read_h5ad, read_loom 19 | from scvi.model import SCVI 20 | from scvi.external import SOLO 21 | 22 | from .utils import knn_smooth_pred_class 23 | 24 | """ 25 | solo.py 26 | 27 | Simulate doublets, train a VAE, and then a classifier on top. 28 | """ 29 | 30 | 31 | ############################################################################### 32 | # main 33 | ############################################################################### 34 | 35 | 36 | def main(): 37 | usage = "solo" 38 | parser = ArgumentParser(usage, formatter_class=ArgumentDefaultsHelpFormatter) 39 | parser.add_argument( 40 | "-j", 41 | dest="model_json_file", 42 | help="json file to pass VAE parameters", 43 | required="--version" not in sys.argv, 44 | ) 45 | parser.add_argument( 46 | "-d", 47 | dest="data_path", 48 | help="path to h5ad, loom, or 10x mtx dir cell by genes counts", 49 | required="--version" not in sys.argv, 50 | ) 51 | parser.add_argument( 52 | "--set-reproducible-seed", 53 | dest="reproducible_seed", 54 | default=None, 55 | type=int, 56 | help="Reproducible seed, give an int to set seed", 57 | ) 58 | parser.add_argument( 59 | "--doublet-depth", 60 | dest="doublet_depth", 61 | default=2.0, 62 | type=float, 63 | help="Depth multiplier for a doublet relative to the \ 64 | average of its constituents", 65 | ) 66 | parser.add_argument( 67 | "-g", dest="gpu", default=True, action="store_true", help="Run on GPU" 68 | ) 69 | parser.add_argument( 70 | "-a", 71 | dest="anndata_output", 72 | default=False, 73 | action="store_true", 74 | help="output modified anndata object with solo scores \ 75 | Only works for anndata", 76 | ) 77 | parser.add_argument("-o", dest="out_dir", default="solo_out") 78 | parser.add_argument( 79 | "-r", 80 | dest="doublet_ratio", 81 | default=2, 82 | type=int, 83 | help="Ratio of doublets to true \ 84 | cells", 85 | ) 86 | parser.add_argument( 87 | "-s", 88 | dest="seed", 89 | default=None, 90 | help="Path to previous solo output \ 91 | directory. Seed VAE models with previously \ 92 | trained solo model. Directory structure is assumed to \ 93 | be the same as solo output directory structure. 
\ 94 | should at least have a vae.pt a pickled object of \ 95 | vae weights and a latent.npy an np.ndarray of the \ 96 | latents of your cells.", 97 | ) 98 | parser.add_argument( 99 | "-e", 100 | dest="expected_number_of_doublets", 101 | help="Experimentally expected number of doublets", 102 | type=int, 103 | default=None, 104 | ) 105 | parser.add_argument( 106 | "-p", 107 | dest="plot", 108 | default=False, 109 | action="store_true", 110 | help="Plot outputs for solo", 111 | ) 112 | parser.add_argument( 113 | "-recalibrate_scores", 114 | dest="recalibrate_scores", 115 | default=False, 116 | action="store_true", 117 | help="Recalibrate doublet scores (not recommended anymore)", 118 | ) 119 | parser.add_argument( 120 | "--version", 121 | dest="version", 122 | default=False, 123 | action="store_true", 124 | help="Get version of solo-sc", 125 | ) 126 | 127 | parser.add_argument( 128 | "--lr_st", 129 | dest="lr_st", 130 | default=1e-3, 131 | type=int, 132 | help="Learning rate used for solo.train", 133 | ) 134 | parser.add_argument( 135 | "--lr_vae", 136 | dest="lr_vae", 137 | default=1e-3, 138 | type=int, 139 | help="Learning rate used for vae", 140 | ) 141 | 142 | args = parser.parse_args() 143 | 144 | if args.version: 145 | version = pkg_resources.require("solo-sc")[0].version 146 | print(f"Current version of solo-sc is {version}") 147 | if args.model_json_file is None or args.data_path is None: 148 | print("Json or data path not give exiting solo") 149 | sys.exit() 150 | 151 | model_json_file = args.model_json_file 152 | data_path = args.data_path 153 | if args.gpu and not torch.cuda.is_available(): 154 | args.gpu = torch.cuda.is_available() 155 | print("Cuda is not available, switching to cpu running!") 156 | 157 | if not os.path.isdir(args.out_dir): 158 | os.mkdir(args.out_dir) 159 | 160 | if args.reproducible_seed is not None: 161 | scvi.settings.seed = args.reproducible_seed 162 | else: 163 | scvi.settings.seed = np.random.randint(10000) 164 | 165 | ################################################## 166 | # data 167 | 168 | # read loom/anndata 169 | data_ext = os.path.splitext(data_path)[-1] 170 | if data_ext == ".loom": 171 | scvi_data = read_loom(data_path) 172 | elif data_ext == ".h5ad": 173 | scvi_data = read_h5ad(data_path) 174 | elif os.path.isdir(data_path): 175 | scvi_data = read_10x_mtx(path=data_path) 176 | cell_umi_depth = scvi_data.X.sum(axis=1) 177 | fifth, ninetyfifth = np.percentile(cell_umi_depth, [5, 95]) 178 | min_cell_umi_depth = np.min(cell_umi_depth) 179 | max_cell_umi_depth = np.max(cell_umi_depth) 180 | if fifth * 10 < ninetyfifth: 181 | print( 182 | """WARNING YOUR DATA HAS A WIDE RANGE OF CELL DEPTHS. 
183 | PLEASE MANUALLY REVIEW YOUR DATA""" 184 | ) 185 | print( 186 | f"Min cell depth: {min_cell_umi_depth}, Max cell depth: {max_cell_umi_depth}" 187 | ) 188 | else: 189 | msg = f"{data_path} is not a recognized format.\n" 190 | msg += "must be one of {h5ad, loom, 10x mtx dir}" 191 | raise TypeError(msg) 192 | 193 | num_cells, num_genes = scvi_data.X.shape 194 | 195 | # check for parameters 196 | if not os.path.exists(model_json_file): 197 | raise FileNotFoundError(f"{model_json_file} does not exist.") 198 | # read parameters 199 | with open(model_json_file, "r") as model_json_open: 200 | params = json.load(model_json_open) 201 | 202 | # set VAE params 203 | vae_params = {} 204 | for par in ["n_hidden", "n_latent", "n_layers", "dropout_rate", "ignore_batch"]: 205 | if par in params: 206 | vae_params[par] = params[par] 207 | 208 | # training parameters 209 | batch_key = params.get("batch_key", None) 210 | batch_size = params.get("batch_size", 128) 211 | valid_pct = params.get("valid_pct", 0.1) 212 | check_val_every_n_epoch = params.get("check_val_every_n_epoch", 5) 213 | learning_rate = params.get("learning_rate", args.lr_st) 214 | stopping_params = {"patience": params.get("patience", 8), "min_delta": 0} 215 | 216 | # protect against single example batch 217 | while num_cells % batch_size == 1: 218 | batch_size = int(np.round(1.25 * batch_size)) 219 | print("Increasing batch_size to %d to avoid single example batch." % batch_size) 220 | 221 | scvi.settings.batch_size = batch_size 222 | ################################################## 223 | # SCVI 224 | #setup_anndata(scvi_data, batch_key=batch_key) 225 | scvi.model.SCVI.setup_anndata(scvi_data, batch_key=batch_key) 226 | vae = SCVI( 227 | scvi_data, 228 | gene_likelihood="nb", 229 | log_variational=True, 230 | **vae_params, 231 | use_observed_lib_size=False, 232 | ) 233 | 234 | if args.seed: 235 | vae = vae.load(os.path.join(args.seed, "vae"), use_gpu=args.gpu) 236 | else: 237 | scvi_callbacks = [] 238 | scvi_callbacks += [ 239 | EarlyStopping( 240 | monitor="reconstruction_loss_validation", mode="min", **stopping_params 241 | ) 242 | ] 243 | plan_kwargs = { 244 | "reduce_lr_on_plateau": True, 245 | "lr_factor": 0.1, 246 | "lr": args.lr_vae, 247 | "lr_patience": 10, 248 | "lr_threshold": 0, 249 | "lr_min": 1e-4, 250 | "lr_scheduler_metric": "reconstruction_loss_validation", 251 | } 252 | vae.train( 253 | max_epochs=2000, 254 | validation_size=valid_pct, 255 | check_val_every_n_epoch=check_val_every_n_epoch, 256 | plan_kwargs=plan_kwargs, 257 | callbacks=scvi_callbacks, 258 | ) 259 | # save VAE 260 | vae.save(os.path.join(args.out_dir, "vae")) 261 | 262 | latent = vae.get_latent_representation() 263 | # save latent representation 264 | np.save(os.path.join(args.out_dir, "latent.npy"), latent.astype("float32")) 265 | 266 | ################################################## 267 | # classifier 268 | 269 | # model 270 | # todo add doublet ratio 271 | solo = SOLO.from_scvi_model(vae, doublet_ratio=args.doublet_ratio) 272 | solo.train( 273 | 2000, 274 | lr=learning_rate, 275 | train_size=0.9, 276 | check_val_every_n_epoch=5, 277 | early_stopping_patience=6, 278 | ) 279 | solo.train( 280 | 2000, 281 | lr=learning_rate * 0.1, 282 | train_size=0.9, 283 | check_val_every_n_epoch=1, 284 | early_stopping_patience=30, 285 | callbacks=[], 286 | ) 287 | solo.save(os.path.join(args.out_dir, "classifier")) 288 | 289 | logit_predictions = solo.predict(include_simulated_doublets=True) 290 | 291 | is_doublet_known = solo.adata.obs._solo_doub_sim == "doublet" 
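    # `logit_predictions` has one column per class ("doublet", "singlet");
    # idxmin returns the name of the lower-scoring column, so cells whose
    # minimum is "singlet" are the ones where "doublet" scores higher, i.e.
    # predicted doublets. `_solo_doub_sim` above marks the simulated doublets
    # and provides the ground-truth labels for the metrics computed below.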
292 | is_doublet_pred = logit_predictions.idxmin(axis=1) == "singlet" 293 | 294 | validation_is_doublet_known = is_doublet_known[solo.validation_indices] 295 | validation_is_doublet_pred = is_doublet_pred[solo.validation_indices] 296 | training_is_doublet_known = is_doublet_known[solo.train_indices] 297 | training_is_doublet_pred = is_doublet_pred[solo.train_indices] 298 | 299 | valid_as = accuracy_score(validation_is_doublet_known, validation_is_doublet_pred) 300 | valid_roc = roc_auc_score(validation_is_doublet_known, validation_is_doublet_pred) 301 | valid_ap = average_precision_score( 302 | validation_is_doublet_known, validation_is_doublet_pred 303 | ) 304 | 305 | train_as = accuracy_score(training_is_doublet_known, training_is_doublet_pred) 306 | train_roc = roc_auc_score(training_is_doublet_known, training_is_doublet_pred) 307 | train_ap = average_precision_score( 308 | training_is_doublet_known, training_is_doublet_pred 309 | ) 310 | 311 | print(f"Training results") 312 | print(f"AUROC: {train_roc}, Accuracy: {train_as}, Average precision: {train_ap}") 313 | 314 | print(f"Validation results") 315 | print(f"AUROC: {valid_roc}, Accuracy: {valid_as}, Average precision: {valid_ap}") 316 | 317 | # write predictions 318 | # softmax predictions 319 | softmax_predictions = softmax(logit_predictions, axis=1) 320 | if logit_predictions.columns[0]=='doublet': 321 | doublet_score = softmax_predictions[:, 0] 322 | else: 323 | doublet_score = softmax_predictions[:, 1] 324 | 325 | np.save( 326 | os.path.join(args.out_dir, "no_updates_softmax_scores.npy"), 327 | doublet_score[:num_cells], 328 | ) 329 | np.savetxt( 330 | os.path.join(args.out_dir, "no_updates_softmax_scores.csv"), 331 | doublet_score[:num_cells], 332 | delimiter=",", 333 | ) 334 | np.save( 335 | os.path.join(args.out_dir, "no_updates_softmax_scores_sim.npy"), 336 | doublet_score[num_cells:], 337 | ) 338 | 339 | # logit predictions 340 | logit_doublet_score = logit_predictions.loc[:, "doublet"] 341 | np.save( 342 | os.path.join(args.out_dir, "logit_scores.npy"), logit_doublet_score[:num_cells] 343 | ) 344 | np.savetxt( 345 | os.path.join(args.out_dir, "logit_scores.csv"), 346 | logit_doublet_score[:num_cells], 347 | delimiter=",", 348 | ) 349 | np.save( 350 | os.path.join(args.out_dir, "logit_scores_sim.npy"), 351 | logit_doublet_score[num_cells:], 352 | ) 353 | 354 | # update threshold as a function of Solo's estimate of the number of 355 | # doublets 356 | # essentially a log odds update 357 | # TODO put in a function 358 | # currently overshrinking softmaxes 359 | diff = np.inf 360 | counter_update = 0 361 | solo_scores = doublet_score[:num_cells] 362 | logit_scores = logit_doublet_score[:num_cells] 363 | d_s = args.doublet_ratio / (args.doublet_ratio + 1) 364 | if args.recalibrate_scores: 365 | while (diff > 0.01) | (counter_update < 5): 366 | 367 | # calculate log odds calibration for logits 368 | d_o = np.mean(solo_scores) 369 | c = np.log(d_o / (1 - d_o)) - np.log(d_s / (1 - d_s)) 370 | 371 | # update solo scores 372 | solo_scores = 1 / (1 + np.exp(-(logit_scores + c))) 373 | 374 | # update while conditions 375 | diff = np.abs(d_o - np.mean(solo_scores)) 376 | counter_update += 1 377 | 378 | np.save(os.path.join(args.out_dir, "softmax_scores.npy"), solo_scores) 379 | np.savetxt( 380 | os.path.join(args.out_dir, "softmax_scores.csv"), solo_scores, delimiter="," 381 | ) 382 | 383 | if args.expected_number_of_doublets is not None: 384 | k = len(solo_scores) - args.expected_number_of_doublets 385 | if 
args.expected_number_of_doublets / len(solo_scores) > 0.5: 386 | print( 387 | """Make sure you actually expect more than half your cells 388 | to be doublets. If not change your 389 | -e parameter value""" 390 | ) 391 | assert k > 0 392 | idx = np.argpartition(solo_scores, k) 393 | threshold = np.max(solo_scores[idx[:k]]) 394 | is_solo_doublet = solo_scores > threshold 395 | else: 396 | is_solo_doublet = solo_scores > 0.5 397 | 398 | np.save(os.path.join(args.out_dir, "is_doublet.npy"), is_solo_doublet[:num_cells]) 399 | np.savetxt( 400 | os.path.join(args.out_dir, "is_doublet.csv"), 401 | is_solo_doublet[:num_cells], 402 | delimiter=",", 403 | ) 404 | 405 | np.save( 406 | os.path.join(args.out_dir, "is_doublet_sim.npy"), is_solo_doublet[num_cells:] 407 | ) 408 | 409 | np.save(os.path.join(args.out_dir, "preds.npy"), is_doublet_pred[:num_cells]) 410 | np.savetxt( 411 | os.path.join(args.out_dir, "preds.csv"), 412 | is_doublet_pred[:num_cells], 413 | delimiter=",", 414 | ) 415 | 416 | smoothed_preds = knn_smooth_pred_class( 417 | X=latent, pred_class=is_doublet_pred[:num_cells] 418 | ) 419 | np.save(os.path.join(args.out_dir, "smoothed_preds.npy"), smoothed_preds) 420 | 421 | if args.anndata_output and data_ext == ".h5ad": 422 | scvi_data.obs["is_doublet"] = is_solo_doublet[:num_cells].astype(bool) 423 | scvi_data.obs["logit_scores"] = logit_doublet_score[:num_cells].astype( 424 | float 425 | ) 426 | scvi_data.obs["softmax_scores"] = solo_scores[:num_cells].astype(float) 427 | scvi_data.write(os.path.join(args.out_dir, "soloed.h5ad")) 428 | 429 | if args.plot: 430 | import matplotlib 431 | 432 | matplotlib.use("Agg") 433 | import matplotlib.pyplot as plt 434 | import seaborn as sns 435 | 436 | train_solo_scores = doublet_score[solo.train_indices] 437 | validation_solo_scores = doublet_score[solo.validation_indices] 438 | 439 | train_fpr, train_tpr, _ = roc_curve( 440 | training_is_doublet_known, train_solo_scores 441 | ) 442 | val_fpr, val_tpr, _ = roc_curve( 443 | validation_is_doublet_known, validation_solo_scores 444 | ) 445 | 446 | # plot ROC 447 | plt.figure() 448 | plt.plot(train_fpr, train_tpr, label="Train") 449 | plt.plot(val_fpr, val_tpr, label="Validation") 450 | plt.gca().set_xlabel("False positive rate") 451 | plt.gca().set_ylabel("True positive rate") 452 | plt.legend() 453 | plt.savefig(os.path.join(args.out_dir, "roc.pdf")) 454 | plt.close() 455 | 456 | train_precision, train_recall, _ = precision_recall_curve( 457 | training_is_doublet_known, train_solo_scores 458 | ) 459 | val_precision, val_recall, _ = precision_recall_curve( 460 | validation_is_doublet_known, validation_solo_scores 461 | ) 462 | # plot accuracy 463 | plt.figure() 464 | plt.plot(train_recall, train_precision, label="Train") 465 | plt.plot(val_recall, val_precision, label="Validation") 466 | plt.gca().set_xlabel("Recall") 467 | plt.gca().set_ylabel("pytPrecision") 468 | plt.legend() 469 | plt.savefig(os.path.join(args.out_dir, "precision_recall.pdf")) 470 | plt.close() 471 | 472 | # plot distributions 473 | obs_indices = solo.validation_indices[solo.validation_indices < num_cells] 474 | sim_indices = solo.validation_indices[solo.validation_indices > num_cells] 475 | 476 | plt.figure() 477 | sns.displot(doublet_score[sim_indices], label="Simulated") 478 | sns.displot(doublet_score[obs_indices], label="Observed") 479 | plt.legend() 480 | plt.savefig(os.path.join(args.out_dir, "sim_vs_obs_dist.pdf")) 481 | plt.close() 482 | 483 | plt.figure() 484 | sns.distplot(solo_scores[:num_cells], label="Observed 
(transformed)") 485 | plt.legend() 486 | plt.savefig(os.path.join(args.out_dir, "real_cells_dist.pdf")) 487 | plt.close() 488 | 489 | scvi_umap = umap.UMAP(n_neighbors=16).fit_transform(latent) 490 | fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 491 | ax.scatter( 492 | scvi_umap[:, 0], 493 | scvi_umap[:, 1], 494 | c=doublet_score[:num_cells], 495 | s=8, 496 | cmap="GnBu", 497 | ) 498 | 499 | ax.set_xlabel("UMAP 1") 500 | ax.set_ylabel("UMAP 2") 501 | fig.savefig(os.path.join(args.out_dir, "umap_solo_scores.pdf")) 502 | 503 | 504 | ############################################################################### 505 | # __main__ 506 | ############################################################################### 507 | 508 | 509 | if __name__ == "__main__": 510 | main() 511 | -------------------------------------------------------------------------------- /solo/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy.stats import multinomial 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | 7 | def knn_smooth_pred_class( 8 | X: np.ndarray, 9 | pred_class: np.ndarray, 10 | grouping: np.ndarray = None, 11 | k: int = 15, 12 | ) -> np.ndarray: 13 | """ 14 | Smooths class predictions by taking the modal class from each cell's 15 | nearest neighbors. 16 | Parameters 17 | ---------- 18 | X : np.ndarray 19 | [N, Features] embedding space for calculation of nearest neighbors. 20 | pred_class : np.ndarray 21 | [N,] array of unique class labels. 22 | groupings : np.ndarray 23 | [N,] unique grouping labels for i.e. clusters. 24 | if provided, only considers nearest neighbors *within the cluster*. 25 | k : int 26 | number of nearest neighbors to use for smoothing. 27 | Returns 28 | ------- 29 | smooth_pred_class : np.ndarray 30 | [N,] unique class labels, smoothed by kNN. 31 | Examples 32 | -------- 33 | >>> smooth_pred_class = knn_smooth_pred_class( 34 | ... X = X, 35 | ... pred_class = raw_predicted_classes, 36 | ... grouping = louvain_cluster_groups, 37 | ... k = 15,) 38 | Notes 39 | ----- 40 | scNym classifiers do not incorporate neighborhood information. 41 | By using a simple kNN smoothing heuristic, we can leverage neighborhood 42 | information to improve classification performance, smoothing out cells 43 | that have an outlier prediction relative to their local neighborhood. 44 | """ 45 | if grouping is None: 46 | # do not use a grouping to restrict local neighborhood 47 | # associations, create a universal pseudogroup `0`. 
48 | grouping = np.zeros(X.shape[0]) 49 | 50 | smooth_pred_class = np.zeros_like(pred_class) 51 | for group in np.unique(grouping): 52 | # identify only cells in the relevant group 53 | group_idx = np.where(grouping == group)[0].astype("int") 54 | X_group = X[grouping == group, :] 55 | # if there are < k cells in the group, change `k` to the 56 | # group size 57 | if X_group.shape[0] < k: 58 | k_use = X_group.shape[0] 59 | else: 60 | k_use = k 61 | # compute a nearest neighbor graph and identify kNN 62 | nns = NearestNeighbors( 63 | n_neighbors=k_use, 64 | ).fit(X_group) 65 | dist, idx = nns.kneighbors(X_group) 66 | 67 | # for each cell in the group, assign a class as 68 | # the majority class of the kNN 69 | for i in range(X_group.shape[0]): 70 | classes = pred_class[group_idx[idx[i, :]]] 71 | uniq_classes, counts = np.unique(classes, return_counts=True) 72 | maj_class = uniq_classes[int(np.argmax(counts))] 73 | smooth_pred_class[group_idx[i]] = maj_class 74 | return smooth_pred_class 75 | -------------------------------------------------------------------------------- /solo_params_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_hidden": 128, 3 | "n_latent": 16, 4 | "cl_hidden": 64, 5 | "cl_layers": 1, 6 | "dropout_rate": 0.1, 7 | "learning_rate": 0.001, 8 | "valid_pct": 0.10 9 | } 10 | -------------------------------------------------------------------------------- /testdata/2c.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calico/solo/d94baa34157eaa9bfe79d87f7a04f578b2e82a4c/testdata/2c.h5ad -------------------------------------------------------------------------------- /testdata/calculate_performance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import anndata 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score, roc_auc_score 6 | from scipy.stats import mannwhitneyu 7 | import matplotlib.pyplot as plt 8 | import datetime 9 | import pandas as pd 10 | from glob import glob 11 | 12 | """ 13 | calculate performance 14 | """ 15 | 16 | ############################################################################### 17 | # main 18 | ############################################################################### 19 | 20 | experiment_name_to_dataset = { 21 | "pbmc": "2c.h5ad", 22 | "kidney": "gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad", 23 | } 24 | 25 | 26 | def main(): 27 | for result in glob("results_*/softmax_scores.npy"): 28 | experiment_name = result.split("/")[0].split("_")[1] 29 | experiment_number = result.split("/")[0].split("_")[2] 30 | scores = np.load(result) 31 | adata = anndata.read(experiment_name_to_dataset[experiment_name]) 32 | true_labels = adata.obs.doublet_bool 33 | apr = average_precision_score(true_labels, scores) 34 | auc = roc_auc_score(true_labels, scores) 35 | time = datetime.datetime.now().strftime("%Y-%m-%d %H") 36 | with open("tracking_performance.csv", "a") as file: 37 | file.write(f"{time},{experiment_name},{experiment_number},{apr},{auc}\n") 38 | 39 | performance_tracking = pd.read_csv("tracking_performance.csv") 40 | performance_tracking["date (dt)"] = pd.to_datetime( 41 | performance_tracking["date"], format="%Y-%m-%d %H" 42 | ) 43 | for experiment_name, group in performance_tracking.groupby("experiment_name"): 44 | fig, axes = plt.subplots(2, 1, figsize=(10, 20)) 45 | ax = axes[0] 46 | ax.plot(group["date"], 
group["average_precision"], ".") 47 | ax.set_xlabel("date") 48 | ax.set_ylabel("average precision") 49 | ax = axes[1] 50 | ax.plot(group["date"], group["AUROC"], ".") 51 | ax.set_xlabel("date") 52 | ax.set_ylabel("AUROC") 53 | fig.savefig(f"{experiment_name}_performance_tracking.png") 54 | second_to_last, most_recent = ( 55 | group["date (dt)"].drop_duplicates().sort_values()[-2:] 56 | ) 57 | second_to_last_df = group[group["date (dt)"] == second_to_last] 58 | most_recent_df = group[group["date (dt)"] == most_recent] 59 | for metric in ["AUROC", "average_precision"]: 60 | mean_change = ( 61 | most_recent_df[metric].mean() - second_to_last_df[metric].mean() 62 | ) 63 | pvalue = mannwhitneyu( 64 | most_recent_df[metric], second_to_last_df[metric] 65 | ).pvalue 66 | print(f"Mean {metric} has changed by for {experiment_name}: {mean_change}") 67 | print( 68 | f"P value for metric change {metric} in experiment {experiment_name}: {pvalue}" 69 | ) 70 | if mean_change < 0 and pvalue < 0.05: 71 | for x in range(0, 5): 72 | print("WARNING!") 73 | print( 74 | f"WARNING {metric} HAS GOTTEN SIGNIFICANTLY WORSE for {experiment_name}!" 75 | ) 76 | if mean_change > 0 and pvalue < 0.05: 77 | for x in range(0, 5): 78 | print("NICE JOB!") 79 | print( 80 | f"NICE JOB {metric} HAS GOTTEN SIGNIFICANTLY BETTER for {experiment_name}!" 81 | ) 82 | 83 | 84 | ############################################################################### 85 | # __main__ 86 | ############################################################################### 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /testdata/gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eab4e4dd8eac5c884ce51a01b35482fccd4760fccc40a38add4c9568dd8a97ec 3 | size 355343136 4 | -------------------------------------------------------------------------------- /testdata/kidney_performance_tracking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calico/solo/d94baa34157eaa9bfe79d87f7a04f578b2e82a4c/testdata/kidney_performance_tracking.png -------------------------------------------------------------------------------- /testdata/pbmc_performance_tracking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calico/solo/d94baa34157eaa9bfe79d87f7a04f578b2e82a4c/testdata/pbmc_performance_tracking.png -------------------------------------------------------------------------------- /testdata/performance_test_kidney_PoolB4FACs_L4_Rep1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH -p gpu 3 | #SBATCH -n 1 4 | #SBATCH --gres=gpu:gtx1080ti:1 5 | #SBATCH --mem 120000 6 | #SBATCH --time 8:00:00 7 | #SBATCH -J solo_permformace_test 8 | #SBATCH --array=1-6 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH -o logs/solo_performance_%A_%a.out 11 | #SBATCH -e logs/solo_performance_%A_%a.err 12 | 13 | echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 14 | echo 'kidney' 15 | source activate solo-sc 16 | solo -p -a -g -r 2 --set-reproducible-seed "$SLURM_ARRAY_TASK_ID" -o results_kidney_"$SLURM_ARRAY_TASK_ID" -j ../solo_params_example.json -d gene_ad_filtered_PoolB4FACs_L4_Rep1.h5ad 17 | 18 | -------------------------------------------------------------------------------- 
/testdata/performance_test_pbmc_2c.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH -p gpu 3 | #SBATCH -n 1 4 | #SBATCH --gres=gpu:gtx1080ti:1 5 | #SBATCH --mem 120000 6 | #SBATCH --time 8:00:00 7 | #SBATCH -J solo_permformace_test 8 | #SBATCH --array=1-6 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH -o logs/solo_performance_%A_%a.out 11 | #SBATCH -e logs/solo_performance_%A_%a.err 12 | 13 | echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 14 | echo 'pbmc' 15 | source activate solo-sc 16 | solo -p -a -g -r 2 --set-reproducible-seed "$SLURM_ARRAY_TASK_ID" -o results_pbmc_"$SLURM_ARRAY_TASK_ID" -j ../solo_params_example.json -d 2c.h5ad 17 | -------------------------------------------------------------------------------- /testdata/performance_tracking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calico/solo/d94baa34157eaa9bfe79d87f7a04f578b2e82a4c/testdata/performance_tracking.png -------------------------------------------------------------------------------- /testdata/tracking_performance.csv: -------------------------------------------------------------------------------- 1 | date,experiment_name,experiment_number,average_precision,AUROC 2 | 2020-02-11 10,pbmc,0,0.642,0.941 3 | 2020-02-11 10,kidney,0,0.652,0.756 4 | 2021-02-11 11,kidney,6,0.6556529157908413,0.7708304010670927 5 | 2021-02-11 11,kidney,1,0.6607756974710838,0.7726586445084818 6 | 2021-02-11 11,pbmc,4,0.6591329163786043,0.9271364908721599 7 | 2021-02-11 11,kidney,5,0.6494058766966985,0.7658514072862478 8 | 2021-02-11 11,kidney,2,0.6618481153467091,0.7730679879883625 9 | 2021-02-11 11,pbmc,3,0.6561663933351891,0.9281592671986435 10 | 2021-02-11 11,pbmc,1,0.6623743880515199,0.9226130098544348 11 | 2021-02-11 11,pbmc,6,0.6277683225071036,0.9207927916996268 12 | 2021-02-11 11,pbmc,2,0.6502136907750218,0.9275159493211569 13 | 2021-02-11 11,kidney,3,0.6467093639812479,0.7573174534495049 14 | 2021-02-11 11,kidney,4,0.6586798480352564,0.7664625647110882 15 | 2021-02-11 11,pbmc,5,0.6478040939493424,0.9225424292745276 16 | 2021-02-18 14,kidney,1,0.6462699623991646,0.7627879613475528 17 | 2021-02-18 14,pbmc,4,0.663326842485893,0.9165240840419175 18 | 2021-02-18 14,kidney,5,0.6480422575749873,0.7652822022156713 19 | 2021-02-18 14,kidney,2,0.648134236408012,0.7684026253229614 20 | 2021-02-18 14,pbmc,3,0.6697257411365286,0.9161077482696734 21 | 2021-02-18 14,pbmc,1,0.6660700438231723,0.916962444444143 22 | 2021-02-18 14,pbmc,6,0.6515373213882809,0.9138401321031976 23 | 2021-02-18 14,pbmc,2,0.6693514926787216,0.9162711036660809 24 | 2021-02-18 14,kidney,3,0.6492802667687769,0.7649097761990968 25 | 2021-02-18 14,kidney,4,0.6472856211876892,0.7633028207786451 26 | 2021-02-18 14,pbmc,5,0.6710936786375143,0.9192144568023591 27 | 2021-02-19 16,pbmc,4,0.6723676680997775,0.9177868779532147 28 | 2021-02-19 16,pbmc,3,0.6651878244617202,0.9157017827302478 29 | 2021-02-19 16,pbmc,1,0.6658168009292,0.9160364650338042 30 | 2021-02-19 16,pbmc,2,0.6746171723485587,0.9178513284655913 31 | 2021-02-19 16,pbmc,5,0.6647481223195565,0.918238273835011 32 | 2021-06-08 11,kidney,6,0.6574942259697635,0.7713399289897841 33 | 2021-06-08 11,kidney,1,0.6491524324360479,0.7683925479354703 34 | 2021-06-08 11,pbmc,4,0.6376143679761396,0.9177092708136871 35 | 2021-06-08 11,kidney,5,0.6507265141297041,0.764045303675037 36 | 2021-06-08 11,kidney,2,0.6536805748853676,0.7694188176691817 37 | 2021-06-08 
11,pbmc,3,0.6497117076563463,0.9188563930159294 38 | 2021-06-08 11,pbmc,1,0.6410024329832936,0.9169514442439101 39 | 2021-06-08 11,pbmc,6,0.628524458448013,0.9143017043816851 40 | 2021-06-08 11,pbmc,2,0.6408642380778574,0.9174513476166006 41 | 2021-06-08 11,kidney,3,0.6573702076510592,0.7709389098122496 42 | 2021-06-08 11,kidney,4,0.6527276636519359,0.7714110145008662 43 | 2021-06-08 11,pbmc,5,0.6339190482879701,0.915807205354064 44 | -------------------------------------------------------------------------------- /tests/hashsolo_tests.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from solo import hashsolo 4 | 5 | from anndata import AnnData 6 | import numpy as np 7 | 8 | 9 | def test_cell_demultiplexing(): 10 | from scipy import stats 11 | import random 12 | 13 | random.seed(52) 14 | signal = stats.poisson.rvs(1000, 1, 990) 15 | doublet_signal = stats.poisson.rvs(1000, 1, 10) 16 | x = np.reshape(stats.poisson.rvs(5, 1, 10000), (1000, 10)) 17 | for idx, signal_count in enumerate(signal): 18 | col_pos = idx % 10 19 | x[idx, col_pos] = signal_count 20 | 21 | for idx, signal_count in enumerate(doublet_signal): 22 | col_pos = (idx % 10) - 1 23 | x[idx, col_pos] = signal_count 24 | test_data = AnnData(x) 25 | hashsolo.hashsolo(test_data) 26 | 27 | doublets = ["Doublet"] * 10 28 | classes = list(np.repeat(np.arange(10), 98).reshape(98, 10, order="F").ravel()) 29 | negatives = ["Negative"] * 10 30 | classification = doublets + classes + negatives 31 | 32 | assert all(test_data.obs["Classification"] == classification) 33 | 34 | doublets = [2] * 10 35 | classes = [1] * 980 36 | negatives = [0] * 10 37 | classification = doublets + classes + negatives 38 | ll_results = np.argmax(hashsolo._calculate_log_likelihoods(x, 8)[0], axis=1) 39 | assert all(ll_results == classification) 40 | 41 | bayes_results = hashsolo._calculate_bayes_rule(x, [0.1, 0.8, 0.1], 8) 42 | assert all(bayes_results["most_likely_hypothesis"] == classification) 43 | 44 | singlet_prior = 0.99999999999999999 45 | other_prior = (1 - singlet_prior) / 2 46 | bayes_results = hashsolo._calculate_bayes_rule( 47 | x, [other_prior, singlet_prior, other_prior], 8 48 | ) 49 | assert all(bayes_results["most_likely_hypothesis"] == 1) 50 | --------------------------------------------------------------------------------