├── BIVI ├── BIVI │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── biVI.cpython-38.pyc │ │ ├── bivae.cpython-38.pyc │ │ ├── distributions.cpython-38.pyc │ │ └── nnNB_module.cpython-38.pyc │ ├── biVI.py │ ├── bivae.py │ ├── distributions.py │ ├── models │ │ ├── README.md │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-38.pyc │ │ └── best_model_MODEL.zip │ └── nnNB_module.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── biVI.cpython-38.pyc │ ├── bivae.cpython-38.pyc │ ├── distributions.cpython-38.pyc │ └── nnNB_module.cpython-38.pyc └── setup.py ├── Example ├── Demo.ipynb ├── OLD_Demo.ipynb ├── README ├── kb_pipeline.sh └── pbmc_1k_v3 │ └── counts_filtered │ └── adata.loom ├── LICENSE ├── Manuscript └── analysis │ ├── Fig_2a-d,g_S13_S16_Allen.ipynb │ ├── Fig_2e-f_Differential_Expression.ipynb │ ├── Fig_S10_Degradation_Rate_Validation_Battich.ipynb │ ├── Fig_S11_Allen.ipynb │ ├── Fig_S12_Runtime_Memory.ipynb │ ├── Fig_S14_Inferred_Parameter_Variance.ipynb │ ├── Fig_S16_Linear_Decoder.ipynb │ ├── Fig_S3_Simulated_Bursty.ipynb │ ├── Fig_S4_Simulated_Constitutive.ipynb │ ├── Fig_S5_Simulated_Extrinsic.ipynb │ ├── Fig_S6_Simulated_Marker_Genes.ipynb │ ├── Fig_S7_Nascent_MisQuantification.ipynb │ ├── Fig_S8_Burst_Size_Validation_Takei.ipynb │ ├── Fig_S9_Burst_Size_Validation_Desai.ipynb │ ├── Simulate_data.ipynb │ ├── __pycache__ │ └── calculate_metrics.cpython-38.pyc │ ├── calculate_metrics.py │ ├── preprocess.py │ ├── preprocess.sh │ ├── requirements.txt │ ├── train_biVI.py │ └── train_biVI.sh └── README.md /BIVI/BIVI/README.md: -------------------------------------------------------------------------------- 1 | Modifications to scVI-tools [1] modules to allow for mechanistic integration of unspliced and spliced count matrices. 2 | 3 | 4 | 5 | [1] Romain Lopez, Jeffrey Regier, Michael B. Cole, Michael I. Jordan & Nir Yosef (2018), 6 | _Deep generative modeling for single-cell transcriptomics_, 7 | [Nature Methods](https://www.nature.com/articles/s41592-018-0229-2). 
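8 | 
9 | A minimal usage sketch (illustrative paths and settings; assumes an AnnData whose `.X` concatenates unspliced and spliced counts along the gene axis, which is what the `BIVAE` module expects):
10 | 
11 | ```python
12 | import anndata
13 | from BIVI.biVI import biVI
14 | 
15 | adata = anndata.read_loom("counts_filtered/adata.loom")  # unspliced + spliced features
16 | biVI.setup_anndata(adata)                                # inherited from scvi-tools' SCVI
17 | model = biVI(adata, n_latent=10, mode="Bursty")          # mode is forwarded to BIVAE
18 | model.train()
19 | params = model.get_likelihood_parameters()               # burst size, splicing/degradation rates
20 | ```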
-------------------------------------------------------------------------------- /BIVI/BIVI/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/__init__.py -------------------------------------------------------------------------------- /BIVI/BIVI/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /BIVI/BIVI/__pycache__/biVI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/__pycache__/biVI.cpython-38.pyc -------------------------------------------------------------------------------- /BIVI/BIVI/__pycache__/bivae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/__pycache__/bivae.cpython-38.pyc -------------------------------------------------------------------------------- /BIVI/BIVI/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /BIVI/BIVI/__pycache__/nnNB_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/__pycache__/nnNB_module.cpython-38.pyc -------------------------------------------------------------------------------- /BIVI/BIVI/biVI.py: --------------------------------------------------------------------------------
1 | """Built atop scVI-tools https://github.com/scverse/scvi-tools/tree/7523a30c16397620cf50098fb0fa53cd32395090"""
2 | import sys
3 | sys.path.append('../')
4 | 
5 | import logging
6 | from typing import Iterable, List, Optional, Dict, Sequence, Union, TypeVar, Tuple
7 | from collections.abc import Iterable as IterableClass
8 | Number = TypeVar("Number", int, float)
9 | 
10 | from anndata import AnnData
11 | import torch
12 | import numpy as np
13 | import pandas as pd
14 | 
15 | from typing import Literal
16 | from scvi.model._scvi import SCVI
17 | from scvi.model._utils import (
18 |     _get_batch_code_from_category,
19 | )
20 | from scvi import REGISTRY_KEYS  # used for batch transformation below
21 | #### import the BIVAE model!
22 | from BIVI import bivae
23 | 
24 | logger = logging.getLogger(__name__)
25 | 
26 | class biVI(SCVI):
27 |     def __init__(
28 |         self,
29 |         adata: AnnData,
30 |         n_hidden: int = 128,
31 |         n_latent: int = 10,
32 |         n_layers: int = 1,
33 |         dropout_rate: float = 0.1,
34 |         dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
35 |         gene_likelihood: Literal["nb"] = "nb",
36 |         latent_distribution: Literal["normal", "ln"] = "normal",
37 |         n_continuous_cov: int = 0,
38 |         n_cats_per_cov: Optional[Iterable[int]] = None,
39 |         decoder_type : Literal["non-linear","linear"] = "non-linear",
40 |         **model_kwargs,
41 |     ):
42 |         ## switch the underlying module from scvi-tools' VAE to BIVAE
43 |         super(SCVI, self).__init__(adata)
44 |         self.module = bivae.BIVAE(
45 |             n_input=self.summary_stats["n_vars"],
46 |             n_batch=self.summary_stats["n_batch"],
47 |             n_hidden=n_hidden,
48 |             n_latent=n_latent,
49 |             n_layers=n_layers,
50 |             dropout_rate=dropout_rate,
51 |             dispersion=dispersion,
52 |             gene_likelihood=gene_likelihood,
53 |             latent_distribution=latent_distribution,
54 |             n_continuous_cov=n_continuous_cov,
55 |             n_cats_per_cov=n_cats_per_cov,
56 |             decoder_type = decoder_type,
57 |             **model_kwargs,
58 |         )
59 | 
60 |         self._model_summary_string = (
61 |             "BIVI Model with the following params: \n mode: {}, n_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
62 |             "{}, dispersion: {}, gene_likelihood: {}, latent_distribution: {}"
63 |         ).format(
64 |             self.module.mode,
65 |             n_hidden,
66 |             n_latent,
67 |             n_layers,
68 |             dropout_rate,
69 |             dispersion,
70 |             gene_likelihood,
71 |             latent_distribution,
72 |         )
73 |         self.init_params_ = self._get_init_params(locals())
74 | 
75 |     @torch.inference_mode()
76 |     def get_likelihood_parameters(
77 |         self,
78 |         adata: Optional[AnnData] = None,
79 |         indices: Optional[Sequence[int]] = None,
80 |         n_samples: Optional[int] = 1,
81 |         give_mean: Optional[bool] = False,
82 |         batch_size: Optional[int] = None,
83 |     ) -> Dict[str, np.ndarray]:
84 |         r"""
85 |         Estimates for the parameters of the likelihood :math:`p(x \mid z)`
86 |         Parameters
87 |         ----------
88 |         adata
89 |             AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
90 |             AnnData object used to initialize the model.
91 |         indices
92 |             Indices of cells in adata to use. If `None`, all cells are used.
93 |         n_samples
94 |             Number of posterior samples to use for estimation.
95 |         give_mean
96 |             Return the expected value of the parameters (if True) or posterior samples of them.
97 |         batch_size
98 |             Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
99 | """ 100 | 101 | adata = self._validate_anndata(adata) 102 | 103 | scdl = self._make_data_loader( 104 | adata=adata, indices=indices, batch_size=batch_size 105 | ) 106 | 107 | dropout_list = [] 108 | mean_list = [] 109 | dispersion_list = [] 110 | 111 | 112 | for tensors in scdl: 113 | 114 | inference_kwargs = dict(n_samples=n_samples) 115 | _, generative_outputs = self.module.forward( 116 | tensors=tensors, 117 | inference_kwargs=inference_kwargs, 118 | compute_loss=False, 119 | ) 120 | px = generative_outputs["px"] 121 | 122 | px_r = px.theta 123 | px_rate = px.mu 124 | 125 | if self.module.gene_likelihood == "zinb": 126 | px_dropout = px.zi_probs 127 | 128 | n_batch = px_rate.size()[0] if n_samples == 1 else px_rate.size(1) 129 | 130 | px_r = px_r.cpu().numpy() 131 | if len(px_r.shape) == 1: 132 | dispersion_list += [np.repeat(px_r[np.newaxis, :], n_batch, axis=0)] 133 | else: 134 | dispersion_list += [px_r] 135 | mean_list += [px_rate.cpu().numpy()] 136 | if self.module.gene_likelihood == "zinb": 137 | dropout_list += [px_dropout.cpu().numpy()] 138 | dropout = np.concatenate(dropout_list, axis=-2) 139 | 140 | means = np.concatenate(mean_list, axis=-2) 141 | dispersions = np.concatenate(dispersion_list, axis=-2) 142 | 143 | if give_mean and n_samples > 1: 144 | if self.module.gene_likelihood == "zinb": 145 | dropout = dropout.mean(0) 146 | means = means.mean(0) 147 | dispersions = dispersions.mean(0) 148 | 149 | return_dict = {} 150 | return_dict["mean"] = means 151 | n_genes = np.shape(means)[-1]/2 152 | 153 | 154 | 155 | if self.module.gene_likelihood == "zinb": 156 | return_dict["dropout"] = dropout 157 | return_dict["dispersions"] = dispersions 158 | if self.module.gene_likelihood == "nb": 159 | return_dict["dispersions"] = dispersions 160 | 161 | 162 | if self.module.mode == 'Bursty': 163 | print('Bursty mode, getting parameters') 164 | mu1 = means[...,:int(n_genes)] 165 | mu2 = means[...,int(n_genes):] 166 | return_dict['unspliced_means'] = mu1 167 | return_dict['spliced_means'] = mu2 168 | return_dict['dispersions'] = dispersions 169 | 170 | 171 | b,beta,gamma = get_bursty_params(mu1,mu2,dispersions,THETA_IS = self.module.THETA_IS) 172 | 173 | return_dict['burst_size'] = b 174 | return_dict['rel_splicing_rate'] = beta 175 | return_dict['rel_degradation_rate'] = gamma 176 | 177 | if self.module.mode == 'NBcorr': 178 | print('Extrinsic mode, getting parameters') 179 | 180 | mu1 = means[...,:int(n_genes)] 181 | mu2 = means[...,int(n_genes):] 182 | return_dict['unspliced_means'] = mu1 183 | return_dict['spliced_means'] = mu2 184 | return_dict['dispersions'] = dispersions 185 | 186 | alpha,beta,gamma = get_extrinsic_params(mu1,mu2,dispersions) 187 | 188 | return_dict['alpha'] = alpha 189 | return_dict['rel_splicing_rate'] = beta 190 | return_dict['rel_degradation_rate'] = gamma 191 | 192 | if self.module.mode == 'Poisson': 193 | print('Constitutive mode, getting parameters') 194 | mu1 = means[...,:int(n_genes)] 195 | mu2 = means[...,int(n_genes):] 196 | return_dict['unspliced_means'] = mu1 197 | return_dict['spliced_means'] = mu2 198 | 199 | beta,gamma = 1/mu1,1/mu2 200 | 201 | return_dict['rel_splicing_rate'] = beta 202 | return_dict['rel_degradation_rate'] = gamma 203 | 204 | return return_dict 205 | 206 | 207 | @torch.no_grad() 208 | def get_normalized_expression( 209 | self, 210 | adata=None, 211 | indices=None, 212 | transform_batch: Optional[Sequence[Union[Number, str]]] = None, 213 | gene_list: Optional[Sequence[str]] = None, 214 | library_size: Optional[Union[float, 
Literal["latent"]]] = 1,
215 |         n_samples: int = 1,
216 |         batch_size: Optional[int] = None,
217 |         return_mean: bool = False,
218 |         # return_numpy: Optional[bool] = None,
219 |     ) -> Dict[str, np.array]:
220 |         r"""
221 |         Returns the normalized gene expression and the normalized kinetic parameters (burst size, splicing rate, degradation rate).
222 | 
223 | 
224 |         Parameters
225 |         ----------
226 |         adata
227 |             AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
228 |             AnnData object used to initialize the model.
229 |         indices
230 |             Indices of cells in adata to use. If `None`, all cells are used.
231 |         transform_batch
232 |             Batch to condition on.
233 |             If transform_batch is:
234 | 
235 |             - None, then real observed batch is used
236 |             - int, then batch transform_batch is used
237 |             - List[int], then average over batches in list
238 |         gene_list
239 |             Return frequencies of expression for a subset of genes.
240 |             This can save memory when working with large datasets and few genes are
241 |             of interest.
242 |         library_size
243 |             Scale the expression frequencies to a common library size.
244 |             This allows gene expression levels to be interpreted on a common scale of relevant
245 |             magnitude.
246 |         n_samples
247 |             Number of posterior samples to draw for each cell.
248 |         batch_size
249 |             Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
250 |         return_mean
251 |             Whether to return the mean of the samples.
252 | 
253 | 
254 |         Returns
255 |         -------
256 |         - dictionary of normalized parameter arrays; depending on the model mode, the keys include
257 |           ``norm_burst_size``, ``norm_splicing_rate``, ``norm_degradation_rate``, ``norm_spliced_mean``, and ``norm_unspliced_mean``
258 | 
259 |         If ``n_samples`` > 1 and ``return_mean`` is False, then the shape is ``(samples, cells, genes)``.
260 |         Otherwise, shape is ``(cells, genes)``. Return type is numpy array.
261 |         """
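        # Illustrative usage (values hypothetical): for a model trained in 'Bursty' mode,
        #     out = model.get_normalized_expression(adata, n_samples=10, return_mean=True)
        #     out['norm_burst_size']        -> (cells, genes) array of normalized burst sizes
        #     out['norm_degradation_rate']  -> (cells, genes) array of relative degradation rates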
262 |         adata = self._validate_anndata(adata)
263 |         post = self._make_data_loader(
264 |             adata=adata, indices=indices, batch_size=batch_size
265 |         )
266 | 
267 |         if gene_list is None:
268 |             gene_mask = slice(None)
269 |         else:
270 |             # TODO: make gene_list consistent with how genes are named in the AnnData
271 |             all_genes = adata.var['gene_name'].tolist()
272 |             gene_mask = [True if gene in gene_list else False for gene in all_genes]
273 |         if indices is None:
274 |             indices = np.arange(adata.n_obs)
275 | 
276 |         # if n_samples > 1 and return_mean is False:
277 |         #     if return_numpy is False:
278 |         #         logger.warning(
279 |         #             "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
280 |         #         )
281 |         #     return_numpy = True
282 | 
283 |         if not isinstance(transform_batch, IterableClass):
284 |             transform_batch = [transform_batch]
285 | 
286 |         transform_batch = _get_batch_code_from_category(
287 |             self.get_anndata_manager(adata, required=True), transform_batch)
288 | 
289 |         scale_list_gene = []
290 |         dispersion_list = []
291 | 
292 |         for tensors in post:
293 |             x = tensors['X']
294 |             px_scale = torch.zeros_like(x)
295 | 
296 |             if n_samples > 1:
297 |                 px_scale = torch.stack(n_samples * [px_scale])
298 | 
299 |             for b in transform_batch:
300 |                 if b is not None:
301 |                     batch_indices = tensors[REGISTRY_KEYS.BATCH_KEY]
302 |                     tensors[REGISTRY_KEYS.BATCH_KEY] = torch.ones_like(batch_indices) * b
303 |                 inference_kwargs = dict(n_samples=n_samples)
304 |                 inference_outputs, generative_outputs = self.module.forward(
305 |                     tensors=tensors,
306 |                     inference_kwargs=inference_kwargs,
307 |                     compute_loss=False,
308 |                 )
309 |                 px = generative_outputs["px"]
310 | 
311 |                 if library_size == "latent":
312 |                     px_scale += px.mu.cpu()
313 |                 else:
314 |                     px_scale += px.scale.cpu()
315 | 
316 |             px_scale = px_scale[..., gene_mask]  # subset the accumulated scales to the requested genes
317 | 
318 | 
319 | 
320 |             px_scale /= len(transform_batch)
321 |             scale_list_gene.append(px_scale)
322 | 
323 |             px_theta = generative_outputs['px'].theta
324 |             dispersion_list.append(px_theta)
325 | 
326 | 
327 |         return_dict = {}
328 | 
329 | 
330 |         if n_samples > 1:
331 |             # concatenate along batch dimension -> result shape = (samples, cells, features)
332 |             scale_list_gene = torch.cat(scale_list_gene, dim=1)
333 |             dispersion_list = torch.cat(dispersion_list,dim=1).cpu().numpy()
334 |             n_genes = int(scale_list_gene.size()[-1]/2)
335 |             scale_list_unspliced = scale_list_gene[...,:n_genes].cpu().numpy()
336 |             scale_list_spliced = scale_list_gene[...,n_genes:].cpu().numpy()
337 | 
338 |             if self.module.mode == "Bursty":
339 |                 b,beta,gamma = get_bursty_params(scale_list_unspliced,scale_list_spliced,
340 |                                                  dispersion_list, THETA_IS = self.module.THETA_IS)
341 |                 return_dict['norm_burst_size'] = b
342 |                 return_dict['norm_degradation_rate'] = gamma
343 |                 return_dict['norm_splicing_rate'] = beta
344 |                 return_dict['norm_spliced_mean'] = scale_list_spliced
345 |                 return_dict['norm_unspliced_mean'] = scale_list_unspliced
346 |             elif self.module.mode == "NBcorr":
347 |                 alpha,beta,gamma = get_extrinsic_params(scale_list_unspliced,scale_list_spliced,dispersion_list)
348 |                 return_dict['norm_alpha'] = alpha
349 |                 return_dict['norm_beta'] = beta
350 |                 return_dict['norm_gamma'] = gamma
351 |                 return_dict['norm_spliced_mean'] = scale_list_spliced
352 |                 return_dict['norm_unspliced_mean'] = scale_list_unspliced
353 |             elif self.module.mode == "Poisson":
354 |                 beta,gamma = get_constitutive_params(scale_list_unspliced,scale_list_spliced)
355 |                 return_dict['norm_splicing_rate'] = beta
356 |                 return_dict['norm_degradation_rate'] = gamma
357 |                 return_dict['norm_spliced_mean'] = scale_list_spliced
358 |                 return_dict['norm_unspliced_mean'] = scale_list_unspliced
359 |             else:
360 |                 raise Exception("Please use valid biVI mode: Bursty, NBcorr, or Poisson.")
361 | 
362 | 
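        # (the 'else' branch below mirrors the n_samples > 1 branch above:
        #  single-sample outputs are (cells, features) and are concatenated
        #  along dim 0 rather than dim 1)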
363 | else: 364 | scale_list_genes = torch.cat(scale_list_gene, dim=0) 365 | dispersion_list = torch.cat(dispersion_list,dim=0).cpu().numpy() 366 | n_genes = int(np.shape(scale_list_genes)[-1]/2) 367 | scale_list_unspliced = scale_list_genes[...,:n_genes].cpu().numpy() 368 | scale_list_spliced = scale_list_genes[...,n_genes:].cpu().numpy() 369 | 370 | if self.module.mode == "Bursty": 371 | b,beta,gamma = get_bursty_params(scale_list_unspliced,scale_list_spliced, 372 | dispersion_list,THETA_IS = self.module.THETA_IS) 373 | return_dict['norm_burst_size'] = b 374 | return_dict['norm_degradation_rate'] = gamma 375 | return_dict['norm_splicing_rate'] = beta 376 | return_dict['norm_spliced_mean'] = scale_list_spliced 377 | return_dict['norm_unspliced_mean'] = scale_list_unspliced 378 | elif self.module.mode == "NBcorr": 379 | alpha,beta,gamma = get_extrinsic_params(scale_list_unspliced,scale_list_spliced,dispersion_list) 380 | return_dict['norm_alpha'] = alpha 381 | return_dict['norm_beta'] = beta 382 | return_dict['norm_gamma'] = gamma 383 | return_dict['norm_spliced_mean'] = scale_list_spliced 384 | return_dict['norm_unspliced_mean'] = scale_list_unspliced 385 | elif self.module.mode == "Poisson": 386 | beta,gamma = get_constitutive_params(scale_list_unspliced,scale_list_spliced) 387 | return_dict['norm_splicing_rate'] = beta 388 | return_dict['norm_degradation_rate'] = gamma 389 | return_dict['norm_spliced_mean'] = scale_list_spliced 390 | return_dict['norm_unspliced_mean'] = scale_list_unspliced 391 | else: 392 | raise Exception("Please use valid biVI mode: Bursty, NBcorr, or Poisson.") 393 | 394 | 395 | 396 | if (return_mean == True) and (n_samples > 1): 397 | for param in return_dict.keys(): 398 | return_dict[param] = np.mean(return_dict[param], axis=0) 399 | 400 | return return_dict 401 | 402 | def get_bayes_factors( 403 | self, 404 | adata = None, 405 | idx1 = None, 406 | idx2 = None, 407 | delta : Optional[float] = 0.2, 408 | gene_list: Optional[Sequence[str]] = None, 409 | library_size: Optional[Union[float, Literal["latent"]]] = 1, 410 | n_samples_1 : int = 10, 411 | n_samples_2 : int = 10, 412 | n_comparisons: int = 5000, 413 | batch_size : Optional[int] = None, 414 | return_all_lfc : bool = False, 415 | # potentially change 416 | eps : float = 1e-10, 417 | return_df : bool = False, 418 | params_dict_1 : str = 'Calculate', 419 | params_dict_2 : str = 'Calculate' 420 | ) -> Union[pd.DataFrame,Tuple[Dict[str, np.array],Dict[str,np.array]]]: 421 | 422 | ''' Calculates Bayes Factor for gene_list (or all genes if no gene_list is input). 
423 |     Considers two hypotheses for differential expression of parameters in groups A and B:
424 | 
425 |     LFC : $lfc = log2(\rho_a) - log2(\rho_b)$
426 | 
427 |     $H_1 : |lfc| >= delta$ (differentially expressed)
428 |     $H_0 : |lfc| < delta$ (not differentially expressed)
429 | 
430 |     Parameters
431 |     -------------------
432 |     idx1
433 |         index for group A
434 |     idx2
435 |         index for group B
436 |     delta
437 |         LFC threshold above which a parameter is considered differential between the two groups
438 |     gene_list
439 |         genes to consider for differential expression testing
440 |     library_size
441 |         default 1 ("normalized"); can be a scale factor or "latent"
442 |     n_samples_1
443 |         number of samples from the posterior to take for each cell in group A
444 |     n_samples_2
445 |         number of samples from the posterior to take for each cell in group B
446 |     n_comparisons
447 |         number of permuted comparisons between samples; the maximum is (|idx1|*n_samples_1)*(|idx2|*n_samples_2)
448 |     batch_size
449 |         size of batch to pass through the forward model when sampling from the posterior
450 |     return_all_lfc
451 |         return all the calculated LFCs instead of the median
452 |     return_df
453 |         return a pandas DataFrame with the results
454 |     params_dict_1, params_dict_2
455 |         precomputed outputs of `get_normalized_expression` for each group, or 'Calculate' to compute them here
456 |     Returns
457 |     -------------------
458 |     BF_dict
459 |         dictionary of all calculated Bayes Factors for genes between groups A and B
460 |     effect_size_dict
461 |         dictionary with either the median or an array of all effect sizes (LFCs) for genes between A and B
462 | 
463 |     or
464 |     bayes_df
465 |         pandas DataFrame with the same information
466 |     '''
467 | 
468 |     BF_dict = {}
469 |     effect_size_dict = {}
470 |     df_dict = {}
471 | 
472 |     if params_dict_1 == 'Calculate':
473 |         # sample from posterior for idx1 -- doesn't deal with batches yet, could add that later
474 |         params_dict_1 = self.get_normalized_expression(
475 |             adata=adata,
476 |             indices = idx1,
477 |             # transform_batch: Optional[Sequence[Union[Number, str]]] = None,
478 |             gene_list = gene_list,
479 |             library_size = library_size,
480 |             n_samples = n_samples_1,
481 |             batch_size = batch_size,
482 |             return_mean = False,
483 |         )
484 | 
485 |     if params_dict_2 == 'Calculate':
486 |         params_dict_2 = self.get_normalized_expression(
487 |             adata=adata,
488 |             indices = idx2,
489 |             # transform_batch: Optional[Sequence[Union[Number, str]]] = None,
490 |             gene_list = gene_list,
491 |             library_size = library_size,
492 |             n_samples = n_samples_2,
493 |             batch_size = batch_size,
494 |             return_mean = False,
495 |         )
496 | 
497 | 
498 |     n_possible_permutations = len(idx1)*len(idx2)*n_samples_1*n_samples_2
499 |     n_comparisons = min(n_comparisons, n_possible_permutations)
500 | 
501 |     # Bayes Factor for each parameter
502 | 
503 |     for param in params_dict_1.keys():
504 | 
505 |         params1 = params_dict_1[param]
506 |         params2 = params_dict_2[param]
507 | 
508 | 
509 |         N_genes = params1.shape[-1]
510 | 
511 |         # reshape
512 |         params1 = params1.reshape(len(idx1)*n_samples_1, N_genes)
513 |         params2 = params2.reshape(len(idx2)*n_samples_2, N_genes)
514 | 
515 | 
516 |         compare_array1, compare_array2 = get_compare_arrays(params1, params2, n_comparisons)
517 | 
518 |         # a la scVI -- shrink LFC to 0 when there are no observed nascent or mature counts
519 |         where_zero_a = (np.max(adata[idx1].X,0).todense() == 0)[:,:N_genes] & (np.max(adata[idx1].X,0).todense() == 0)[:,N_genes:]
520 |         where_zero_b = (np.max(adata[idx2].X,0).todense() == 0)[:,:N_genes] & (np.max(adata[idx2].X,0).todense() == 0)[:,N_genes:]
521 | 
522 | 
523 |         eps = self.estimate_pseudocounts_offset(params1, params2, where_zero_a, where_zero_b)
524 | 
525 | 
526 |         lfc_values = np.log2(compare_array1 + eps) - np.log2(compare_array2 + eps)
527 |         lfc_abs = np.abs(lfc_values)
528 | 
529 | 
530 | 
531 |         BF = np.sum(lfc_abs >= delta, axis=0) / (np.sum(lfc_abs < delta, axis=0) + eps)
532 | 
533 |         BF_dict[param] = BF
534 |         if return_all_lfc:
535 |             effect_size_dict[param] = lfc_values
536 |         else:
537 |             effect_size_dict[param] = np.median(lfc_values, axis=0)
538 | 
539 |         param_dict = {}
540 |         param_dict['bayes_factor'] = BF
541 |         param_dict['median_lfc'] = np.median(lfc_values, axis=0)
542 |         param_dict['prob_DE'] = np.sum(lfc_abs >= delta, axis=0)
543 |         param_dict['prob_not_DE'] = np.sum(lfc_abs < delta, axis=0)
544 |         df_dict[param] = param_dict
545 | 
546 |     if return_df:
547 |         bayes_df = pd.DataFrame({(param, stat): np.ravel(vals)
548 |                                  for param, pdict in df_dict.items()
549 |                                  for stat, vals in pdict.items()})
550 |         return bayes_df
551 | 
552 |     return BF_dict, effect_size_dict
553 | 
554 | 
586 |     def estimate_pseudocounts_offset(self, scales_a, scales_b, where_zero_a, where_zero_b, percentile: float = 0.9):
587 |         ''' Determines the pseudocount offset used to shrink LFCs of genes with no observed counts, a la scVI. '''
588 | 
589 |         max_scales_a = scales_a.max(0)
590 |         max_scales_b = scales_b.max(0)
591 |         if where_zero_a.sum() >= 1:
592 |             artefact_scales_a = max_scales_a[where_zero_a]
593 |             eps_a = np.percentile(artefact_scales_a, q=percentile)
594 |         else:
595 |             eps_a = 1e-10
596 | 
597 |         if where_zero_b.sum() >= 1:
598 |             artefact_scales_b = max_scales_b[where_zero_b]
599 |             eps_b = np.percentile(artefact_scales_b, q=percentile)
600 |         else:
601 |             eps_b = 1e-10
602 |         res = np.maximum(eps_a, eps_b)
603 |         return res
604 | 
605 | 
606 | 
607 | def get_compare_arrays(params1,params2,n_comparisons):
608 |     ''' Returns comparison arrays for params1 and params2.
609 | 
610 |     Randomly samples from params1 and params2 n_comparison times and constructs two arrays of the random samples.
611 |     '''
612 |     length_1 = np.shape(params1)[0]
613 |     length_2 = np.shape(params2)[0]
614 |     n_each_sample = min(length_1,length_2)
615 |     n_left_to_sample = n_comparisons
616 | 
617 |     compare_array1 = np.zeros((n_comparisons,params1.shape[1]))
618 |     compare_array2 = np.zeros((n_comparisons,params1.shape[1]))
619 | 
620 |     # how many have been sampled
621 |     n_sampled = 0
622 | 
623 |     while n_left_to_sample > 0:
624 | 
625 |         if n_left_to_sample > n_each_sample:
626 |             samp = n_each_sample
627 |         else:
628 |             samp = n_left_to_sample
629 | 
630 | 
631 |         rand_1 = np.random.choice(length_1,samp)
632 |         rand_2 = np.random.choice(length_2,samp)
633 | 
634 |         arr1 = params1[rand_1,:]
635 |         arr2 = params2[rand_2,:]
636 | 
637 | 
638 | 
639 |         compare_array1[n_sampled : n_sampled+samp, :] = arr1
640 |         compare_array2[n_sampled : n_sampled+samp, :] = arr2
641 | 
642 |         n_left_to_sample -= samp
643 |         n_sampled += samp
644 | 
645 |     return compare_array1,compare_array2
646 | 
647 | def get_bursty_params(mu1,mu2,theta,THETA_IS = 'NAS_SHAPE'):
648 |     ''' Returns b, beta, gamma of the bursty distribution given mu1, mu2 and theta.
649 |     Returns whatever shape was input.
650 |     '''
651 | 
652 |     if THETA_IS == 'MAT_SHAPE':
653 |         gamma = 1/theta
654 |         b = mu2*gamma
655 |         beta = b/mu1
656 | 
657 |     elif THETA_IS == 'B':
658 |         b = theta
659 |         beta = b/mu1
660 |         gamma = b/mu2
661 | 
662 |     elif THETA_IS == 'NAS_SHAPE':
663 |         beta = 1/theta
664 |         b = mu1*beta
665 |         gamma = b/mu2
666 | 
667 | 
668 |     return(b,beta,gamma)
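# Worked example for get_bursty_params (illustrative numbers only): with
# THETA_IS = 'NAS_SHAPE', theta is the nascent shape parameter, so for
# mu1 = 2.0, mu2 = 5.0, theta = 4.0:
#     beta  = 1/theta  = 0.25   (relative splicing rate)
#     b     = mu1*beta = 0.5    (burst size)
#     gamma = b/mu2    = 0.1    (relative degradation rate)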
671 | def get_extrinsic_params(mu1,mu2,theta):
672 |     ''' Returns splicing rate beta, degradation rate gamma, and alpha (mean of the transcription rate distribution)
673 |     given the BVNB extrinsic noise model.
674 |     '''
675 |     alpha = theta
676 |     beta = theta/mu1
677 |     gamma = theta/mu2
678 | 
679 | 
680 |     return(alpha,beta,gamma)
681 | 
682 | def get_constitutive_params(mu1,mu2):
683 |     ''' Returns the relative splicing rate beta and degradation rate gamma given the constitutive model.
684 | ''' 685 | beta = 1/mu1 686 | gamma = 1/mu2 687 | 688 | return(beta,gamma) 689 | 690 | -------------------------------------------------------------------------------- /BIVI/BIVI/bivae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Main module.""" 3 | """Built atop scVI-tools https://github.com/scverse/scvi-tools/tree/7523a30c16397620cf50098fb0fa53cd32395090""" 4 | import sys 5 | sys.path.append('../') 6 | 7 | from typing import Dict, Iterable, Optional, Sequence, Union 8 | import anndata 9 | from anndata import AnnData 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import logsumexp 15 | from torch.distributions import Normal 16 | from torch.distributions import kl_divergence as kl 17 | 18 | from scvi import REGISTRY_KEYS 19 | from typing import Literal 20 | from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial 21 | from scvi.module.base import auto_move_data 22 | from scvi.nn import DecoderSCVI, Encoder, LinearDecoderSCVI, one_hot 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | from scvi.module._vae import VAE, LDVAE 27 | 28 | # import custom distributions 29 | from BIVI.distributions import BivariateNegativeBinomial, log_prob_poisson, log_prob_NBcorr, log_prob_NBuncorr 30 | from BIVI.nnNB_module import log_prob_nnNB 31 | 32 | torch.backends.cudnn.benchmark = True 33 | 34 | # BIVAE model 35 | class BIVAE(VAE): 36 | """ 37 | """ 38 | def __init__(self, 39 | gene_likelihood: str = "nb", 40 | mode: Literal['NB','NBcorr','Poisson','Bursty','custom'] = 'Bursty', 41 | n_batch: int = 0, 42 | n_continuous_cov: int = 0, 43 | n_cats_per_cov: Optional[Iterable[int]] = None, 44 | deeply_inject_covariates: bool = True, 45 | use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", 46 | use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", 47 | use_size_factor_key: bool = False, 48 | custom_dist = None, 49 | THETA_IS : Literal['NAS_SHAPE','MAT_SHAPE','B'] ='NAS_SHAPE', 50 | decoder_type : Literal["non-linear","linear"] = "non-linear", 51 | bias : bool = False, 52 | **kwargs): 53 | print(kwargs) 54 | super().__init__(gene_likelihood=gene_likelihood, 55 | n_continuous_cov=n_continuous_cov, 56 | n_cats_per_cov=n_cats_per_cov, 57 | **kwargs) 58 | 59 | self.decoder_type = decoder_type 60 | self.mode = mode 61 | print('Initiating biVAE') 62 | print(f'Mode: {mode}, Decoder: {decoder_type}, Theta is: {THETA_IS}') 63 | 64 | # define the new custom distribution 65 | if mode == 'custom': 66 | self.custom_dist = custom_dist 67 | elif mode == 'NB': 68 | self.custom_dist = log_prob_NBuncorr 69 | elif mode == 'NBcorr': 70 | self.custom_dist = log_prob_NBcorr 71 | elif mode == 'Poisson': 72 | self.custom_dist = log_prob_poisson 73 | elif mode == 'Bursty': 74 | self.custom_dist = log_prob_nnNB 75 | self.THETA_IS = THETA_IS 76 | 77 | #### switch to n_input/2 (shared between each spliced/unspliced gene) 78 | n_input = kwargs['n_input'] 79 | n_input_px_r = int(n_input/2) # theta !! 
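        # theta (the inverse dispersion) is defined per gene rather than per feature:
        # the input concatenates unspliced and spliced counts, so only n_input/2
        # dispersion parameters are needed. Depending on THETA_IS, theta is interpreted
        # as the nascent shape ('NAS_SHAPE'), the mature shape ('MAT_SHAPE'), or the
        # burst size ('B'); see get_bursty_params in biVI.py.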
80 | 
81 |         if self.dispersion == "gene":
82 |             self.px_r = torch.nn.Parameter(torch.randn(n_input_px_r))
83 |         elif self.dispersion == "gene-batch":
84 |             self.px_r = torch.nn.Parameter(torch.randn(n_input_px_r, n_batch))
85 |         elif self.dispersion == "gene-label":
86 |             self.px_r = torch.nn.Parameter(torch.randn(n_input_px_r, self.n_labels))
87 |         elif self.dispersion == "gene-cell":
88 |             pass
89 |         else:
90 |             raise ValueError(
91 |                 "dispersion must be one of ['gene', 'gene-batch',"
92 |                 " 'gene-label', 'gene-cell'], but input was "
93 |                 "{}".format(self.dispersion)
94 |             )
95 | 
96 | 
97 |         # decoder goes from n_latent-dimensional space to n_input-d data
98 |         n_latent = kwargs['n_latent']
99 |         n_layers = kwargs['n_layers']
100 |         n_hidden = kwargs['n_hidden']
101 | 
102 |         use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
103 |         use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"
104 | 
105 |         n_input_decoder = n_latent + n_continuous_cov
106 |         cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
107 | 
108 |         n_output = n_input
109 | 
110 |         #### modify decoderSCVI class
111 |         if decoder_type == "non-linear":
112 |             self.decoder = DecoderSCVI(
113 |                 n_input_decoder,
114 |                 n_output,  # modified
115 |                 n_cat_list=cat_list,
116 |                 n_layers=n_layers,
117 |                 n_hidden=n_hidden,
118 |                 inject_covariates=deeply_inject_covariates,
119 |                 use_batch_norm=use_batch_norm_decoder,
120 |                 use_layer_norm=use_layer_norm_decoder,
121 |                 scale_activation="softplus" if use_size_factor_key else "softmax",
122 |             )
123 |         elif decoder_type == "linear":
124 |             self.decoder = LinearDecoderSCVI(
125 |                 n_latent,
126 |                 n_input,
127 |                 n_cat_list=cat_list,
128 |                 use_batch_norm=use_batch_norm_decoder,
129 |                 use_layer_norm=use_layer_norm_decoder,
130 |                 bias=bias,
131 |             )
132 | 
133 | 
134 | 
135 |     # redefine the reconstruction error
136 |     def get_reconstruction_loss(
137 |         self, x, px_rate, px_r, px_dropout, **kwargs
138 |     ) -> torch.Tensor:
139 |         # Reconstruction Loss
140 |         if self.gene_likelihood == "nb":
141 |             #### switch to BivariateNegativeBinomial
142 |             reconst_loss = (
143 |                 -BivariateNegativeBinomial(mu=px_rate,
144 |                                            theta=px_r,
145 |                                            custom_dist=self.custom_dist,
146 |                                            THETA_IS = self.THETA_IS,
147 |                                            dispersion = self.dispersion,
148 |                                            mode = self.mode,
149 |                                            **kwargs).log_prob(x).sum(dim=-1)
150 |             )
151 | 
152 |         else:
153 |             raise ValueError("Input valid gene_likelihood ['nb']")
154 |         return reconst_loss
155 | 
156 |     @auto_move_data
157 |     def generative(
158 |         self,
159 |         z,
160 |         library,
161 |         batch_index,
162 |         cont_covs=None,
163 |         cat_covs=None,
164 |         size_factor=None,
165 |         y=None,
166 |         transform_batch=None,
167 |     ):
168 |         """Runs the generative model."""
169 |         # CHANGED FOR BIVAE
170 |         # Likelihood distribution
171 |         if cont_covs is None:
172 |             decoder_input = z
173 |         elif z.dim() != cont_covs.dim():
174 |             decoder_input = torch.cat(
175 |                 [z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
176 |             )
177 |         else:
178 |             decoder_input = torch.cat([z, cont_covs], dim=-1)
179 | 
180 |         if cat_covs is not None:
181 |             categorical_input = torch.split(cat_covs, 1, dim=1)
182 |         else:
183 |             categorical_input = tuple()
184 | 
185 |         if transform_batch is not None:
186 |             batch_index = torch.ones_like(batch_index) * transform_batch
187 | 
188 |         if not self.use_size_factor_key:
189 |             size_factor = library
190 | 
191 |         px_scale, px_r, px_rate, px_dropout = self.decoder(
192 |             self.dispersion,
193 |             decoder_input,
194 |             size_factor,
195 |             batch_index,
196 |             *categorical_input,
197 |             y,
198 |         )
199 | 
200 | 
201 |         if self.dispersion == "gene-label":
202 |             px_r = F.linear(
203 |                 one_hot(y, self.n_labels), self.px_r
204 |             )  # px_r gets transposed - last dimension is nb genes
205 |         elif self.dispersion == "gene-batch":
206 |             px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
207 |         elif self.dispersion == "gene":
208 |             px_r = self.px_r
209 |         # elif self.dispersion == "gene-cell":
210 |         #     px_r = self.px_r
211 | 
212 |         px_r = torch.exp(px_r)
213 | 
214 |         # if self.gene_likelihood == "zinb":
215 |         #     px = ZeroInflatedNegativeBinomial(
216 |         #         mu=px_rate,
217 |         #         theta=px_r,
218 |         #         zi_logits=px_dropout,
219 |         #         scale=px_scale,
220 |         #     )
221 |         if self.gene_likelihood == "nb":
222 |             px = BivariateNegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale,
223 |                                            custom_dist = self.custom_dist,
224 |                                            THETA_IS = self.THETA_IS,
225 |                                            dispersion = self.dispersion,
226 |                                            mode = self.mode)
227 |         # elif self.gene_likelihood == "NegativeBinomial":
228 |         #     px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
229 |         # elif self.gene_likelihood == "poisson":
230 |         #     px = Poisson(px_rate, scale=px_scale)
231 | 
232 |         # Priors
233 |         if self.use_observed_lib_size:
234 |             pl = None
235 |         else:
236 |             (
237 |                 local_library_log_means,
238 |                 local_library_log_vars,
239 |             ) = self._compute_local_library_params(batch_index)
240 |             pl = Normal(local_library_log_means, local_library_log_vars.sqrt())
241 |         pz = Normal(torch.zeros_like(z), torch.ones_like(z))
242 |         return dict(
243 |             px=px,
244 |             pl=pl,
245 |             pz=pz,
246 |         )
247 | 
248 |     @torch.inference_mode()
249 |     def get_likelihood_parameters(
250 |         self,
251 |         adata: Optional[AnnData] = None,
252 |         indices: Optional[Sequence[int]] = None,
253 |         n_samples: Optional[int] = 1,
254 |         give_mean: Optional[bool] = False,
255 |         batch_size: Optional[int] = None,
256 |     ) -> Dict[str, np.ndarray]:
257 |         r"""
258 |         Estimates for the parameters of the likelihood :math:`p(x \mid z)`
259 |         Parameters
260 |         ----------
261 |         adata
262 |             AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
263 |             AnnData object used to initialize the model.
264 |         indices
265 |             Indices of cells in adata to use. If `None`, all cells are used.
266 |         n_samples
267 |             Number of posterior samples to use for estimation.
268 |         give_mean
269 |             Return the expected value of the parameters (if True) or posterior samples of them.
270 |         batch_size
271 |             Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
272 |         """
273 |         adata = self._validate_anndata(adata)
274 | 
275 |         scdl = self._make_data_loader(
276 |             adata=adata, indices=indices, batch_size=batch_size
277 |         )
278 | 
279 |         dropout_list = []
280 |         mean_list = []
281 |         dispersion_list = []
282 | 
283 | 
284 |         for tensors in scdl:
285 |             inference_kwargs = dict(n_samples=n_samples)
286 |             _, generative_outputs = self.forward(
287 |                 tensors=tensors,
288 |                 inference_kwargs=inference_kwargs,
289 |                 compute_loss=False,
290 |             )
291 | 
292 |             px = generative_outputs["px"]
293 | 
294 |             px_r = px.theta
295 |             px_rate = px.mu
296 | 
297 |             if self.gene_likelihood == "zinb":
298 |                 px_dropout = px.zi_probs
299 | 
300 |             n_batch = px_rate.size(0) if n_samples == 1 else px_rate.size(1)
301 | 
302 |             px_r = px_r.cpu().numpy()
303 |             if len(px_r.shape) == 1:
304 |                 dispersion_list += [np.repeat(px_r[np.newaxis, :], n_batch, axis=0)]
305 |             else:
306 |                 dispersion_list += [px_r]
307 |             mean_list += [px_rate.cpu().numpy()]
308 |             if self.gene_likelihood == "zinb":
309 |                 dropout_list += [px_dropout.cpu().numpy()]
310 |                 dropout = np.concatenate(dropout_list, axis=-2)
311 | 
312 |         means = np.concatenate(mean_list, axis=-2)
313 |         dispersions = np.concatenate(dispersion_list, axis=-2)
314 | 
315 |         if give_mean and n_samples > 1:
316 |             if self.gene_likelihood == "zinb":
317 |                 dropout = dropout.mean(0)
318 |             means = means.mean(0)
319 |             dispersions = dispersions.mean(0)
320 | 
321 |         return_dict = {}
322 |         return_dict["mean"] = means
323 | 
324 | 
325 |         if self.gene_likelihood == "zinb":
326 |             return_dict["dropout"] = dropout
327 |             return_dict["dispersions"] = dispersions
328 |         if self.gene_likelihood == "nb":
329 |             return_dict["dispersions"] = dispersions
330 |             print('gene likelihood nb, getting params')
331 |         from BIVI.biVI import get_bursty_params, get_extrinsic_params  # local import avoids a circular import
332 |         if self.mode == 'Bursty':
333 |             print('Bursty mode, returning parameters')
334 |             mu1 = means[:,:int(np.shape(means)[1]/2)]
335 |             mu2 = means[:,int(np.shape(means)[1]/2):]
336 |             return_dict['unspliced_means'] = mu1
337 |             return_dict['spliced_means'] = mu2
338 |             return_dict['dispersions'] = dispersions
339 | 
340 |             b,beta,gamma = get_bursty_params(mu1,mu2,dispersions,THETA_IS = self.THETA_IS)
341 | 
342 |             return_dict['burst_size'] = b
343 |             return_dict['rel_splicing_rate'] = beta
344 |             return_dict['rel_degradation_rate'] = gamma
345 | 
346 |         if self.mode == 'NBcorr':
347 |             mu1 = means[:,:int(np.shape(means)[1]/2)]
348 |             mu2 = means[:,int(np.shape(means)[1]/2):]
349 |             return_dict['unspliced_means'] = mu1
350 |             return_dict['spliced_means'] = mu2
351 |             return_dict['dispersions'] = dispersions
352 | 
353 |             alpha,beta,gamma = get_extrinsic_params(mu1,mu2,dispersions)
354 | 
355 |             return_dict['alpha'] = alpha
356 |             return_dict['rel_splicing_rate'] = beta
357 |             return_dict['rel_degradation_rate'] = gamma
358 | 
359 |         if self.mode == 'Poisson':
360 |             mu1 = means[:,:int(np.shape(means)[1]/2)]
361 |             mu2 = means[:,int(np.shape(means)[1]/2):]
362 |             return_dict['unspliced_means'] = mu1
363 |             return_dict['spliced_means'] = mu2
364 | 
365 |             beta,gamma = 1/mu1,1/mu2
366 | 
367 |             return_dict['rel_splicing_rate'] = beta
368 |             return_dict['rel_degradation_rate'] = gamma
369 | 
370 |         return return_dict
371 | 
372 | 
373 | 
374 | 
-------------------------------------------------------------------------------- /BIVI/BIVI/distributions.py: --------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | 
4 | from typing import Union, Tuple, Optional
5 | import warnings
6 | 
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.distributions import constraints, Distribution, Gamma, Poisson
10 | from torch.distributions.utils import (
11 |     broadcast_all,
12 |     probs_to_logits,
13 |     lazy_property,
14 |     logits_to_probs,
15 | )
16 | from scvi.distributions._negative_binomial import _convert_counts_logits_to_mean_disp, _gamma  # helpers used below
17 | 
18 | class BivariateNegativeBinomial(Distribution):
19 |     # """
20 |     # Negative binomial distribution.
21 |     # One of the following parameterizations must be provided:
22 |     # (1), (`total_count`, `probs`) where `total_count` is the number of failures until
23 |     # the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`)
24 |     # parameterization, which is the one used by scvi-tools. These parameters respectively
25 |     # control the mean and inverse dispersion of the distribution.
26 |     # In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:
27 |     # 1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
28 |     # 2. :math:`x \sim \textrm{Poisson}(w)`
29 |     # Parameters
30 |     # ----------
31 |     # total_count
32 |     #     Number of failures until the experiment is stopped.
33 |     # probs
34 |     #     The success probability.
35 |     # mu
36 |     #     Mean of the distribution.
37 |     # theta
38 |     #     Inverse dispersion.
39 |     # validate_args
40 |     #     Raise ValueError if arguments do not match constraints
41 |     # model_weight
42 |     # """
43 | 
44 |     arg_constraints = {
45 |         "mu": constraints.greater_than_eq(0),
46 |         "theta": constraints.greater_than_eq(0),
47 |     }
48 |     support = constraints.nonnegative_integer
49 | 
50 |     def __init__(
51 |         self,
52 |         total_count: Optional[torch.Tensor] = None,
53 |         probs: Optional[torch.Tensor] = None,
54 |         logits: Optional[torch.Tensor] = None,
55 |         mu: Optional[torch.Tensor] = None,
56 |         theta: Optional[torch.Tensor] = None,
57 |         validate_args: bool = False,
58 |         custom_dist = None,
59 |         scale: Optional[torch.Tensor] = None,
60 |         THETA_IS: str = 'NAS_SHAPE',
61 |         dispersion: str = 'gene',
62 |         mode: str = 'Bursty',
63 |         **kwargs,
64 |     ):
65 | 
66 |         super().__init__(validate_args=validate_args)
67 | 
68 | 
69 |         self._eps = 1e-8
70 |         if (mu is None) == (total_count is None):
71 |             raise ValueError(
72 |                 "Please use one of the two possible parameterizations. Refer to the documentation for more information."
73 | ) 74 | 75 | using_param_1 = total_count is not None and ( 76 | logits is not None or probs is not None 77 | ) 78 | if using_param_1: 79 | logits = logits if logits is not None else probs_to_logits(probs) 80 | total_count = total_count.type_as(logits) 81 | total_count, logits = broadcast_all(total_count, logits) 82 | mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits) 83 | else: 84 | mu1,mu2 = torch.chunk(mu,2,dim=-1) 85 | if (dispersion == 'gene-cell') and (mode != 'NB'): 86 | theta = theta[...,:int(theta.shape[-1]/2)] 87 | if mode != 'NB': 88 | mu1, mu2, theta = broadcast_all(mu1, mu2, theta) 89 | 90 | #### Modified for bivariate 91 | self.mu = mu 92 | self.mu1, self.mu2 = mu1, mu2 93 | self.theta = theta 94 | self.use_custom = custom_dist is not None 95 | self.custom_dist = custom_dist 96 | self.scale = scale 97 | self.THETA_IS = THETA_IS 98 | 99 | # print('MEANS UNSPLICED SHAPE', mu1.shape) 100 | # print('MEANS UNSPLICED',mu1) 101 | 102 | # print('MEANS SPLICED SHAPE', mu2.shape) 103 | # print('MEANS SPLICED',mu2) 104 | 105 | # print('THETA SHAPE', theta.shape) 106 | # print('THETA',theta) 107 | 108 | 109 | @property 110 | def mean(self): 111 | return self.mu 112 | 113 | @property 114 | def variance(self): 115 | return self.mean + (self.mean ** 2) / self.theta 116 | 117 | def sample( 118 | self, sample_shape: Union[torch.Size, Tuple] = torch.Size() 119 | ) -> torch.Tensor: 120 | with torch.no_grad(): 121 | gamma_d = self._gamma() 122 | p_means = gamma_d.sample(sample_shape) 123 | 124 | # Clamping as distributions objects can have buggy behaviors when 125 | # their parameters are too high 126 | l_train = torch.clamp(p_means, max=1e8) 127 | counts = Poisson( 128 | l_train 129 | ).sample() # Shape : (n_samples, n_cells_batch, n_vars) 130 | return counts 131 | 132 | def log_prob(self, value: torch.Tensor) -> torch.Tensor: 133 | if self._validate_args: 134 | try: 135 | self._validate_sample(value) 136 | except ValueError: 137 | warnings.warn( 138 | "The value argument must be within the support of the distribution", 139 | UserWarning, 140 | ) 141 | 142 | if self.use_custom: 143 | calculate_log_nb = log_prob_custom 144 | log_nb = calculate_log_nb(value, 145 | mu1=self.mu1, mu2=self.mu2, 146 | theta=self.theta, eps=self._eps, 147 | THETA_IS = self.THETA_IS, 148 | custom_dist = self.custom_dist) 149 | else: 150 | log_nb = log_prob_NBuncorr(value, 151 | mu1 = self.mu1, mu2 = self.mu2, eps = self._eps) 152 | 153 | return log_nb 154 | 155 | def _gamma(self): 156 | return _gamma(self.theta, self.mu) 157 | 158 | def log_prob_custom(x: torch.Tensor, mu1: torch.Tensor, mu2: torch.Tensor, 159 | theta: torch.Tensor, THETA_IS, eps=1e-8, 160 | custom_dist=None, **kwargs): 161 | """ 162 | Log likelihood (scalar) of a minibatch according to a bivariate nb model 163 | where individual genes use one of the distributions 164 | """ 165 | 166 | assert custom_dist is not None, "Input a custom_dist" 167 | res = custom_dist(x=x, mu1=mu1, mu2=mu2, theta=theta, eps=eps, THETA_IS = THETA_IS) 168 | 169 | return res 170 | 171 | 172 | 173 | 174 | 175 | 176 | def log_prob_poisson(x: torch.Tensor, mu1: torch.Tensor, mu2: torch.Tensor, 177 | theta: torch.Tensor, THETA_IS, eps, **kwargs): 178 | ''' Calculates the uncorrelated Poisson likelihood for nascent and mature: just returns Poisson(n; mu1)*Poisson(m; mu2).''' 179 | # Divide the original data x into spliced (x) and unspliced (y) 180 | n,m = torch.chunk(x,2,dim=-1) 181 | 182 | # DOES NOT USE THETA AT ALL 183 | 184 | #compute the Poisson term for n 
and m (uncorrelated)
185 |     y_n = n * torch.log(mu1+eps) - mu1 - torch.lgamma(n+1)
186 |     y_m = m * torch.log(mu2+eps) - mu2 - torch.lgamma(m+1)
187 | 
188 |     P = y_n + y_m
189 | 
190 | 
191 |     return P
192 | 
193 | 
194 | def log_prob_NBcorr(x: torch.Tensor, mu1: torch.Tensor, mu2: torch.Tensor,
195 |                     theta: torch.Tensor, THETA_IS, eps=1e-8):
196 |     """
197 |     Log likelihood (scalar) of a minibatch according to a bivariate nb model.
198 |     Parameters
199 |     ----------
200 |     x
201 |         data
202 |     mu1,mu2
203 |         mean of the negative binomial (has to be positive support) (shape: minibatch x vars/2)
204 |     theta
205 |         params (has to be positive support) (shape: minibatch x vars)
206 |     eps
207 |         numerical stability constant
208 |     Notes
209 |     -----
210 |     The correlated model shares a single theta across the unspliced and spliced counts.
211 |     """
212 | 
213 |     # Divide the original data x into spliced (x) and unspliced (y)
214 |     x,y = torch.chunk(x,2,dim=-1)
215 | 
216 |     if theta.ndimension() == 1:
217 |         theta = theta.view(
218 |             1, theta.size(0)
219 |         )  # In this case, we reshape theta for broadcasting
220 | 
221 |     log_theta_mu_eps = torch.log(theta + mu1 + mu2 + eps)
222 | 
223 |     res = (
224 |         theta * (torch.log(theta + eps) - log_theta_mu_eps)
225 |         + x * (torch.log(mu1 + eps) - log_theta_mu_eps)
226 |         + y * (torch.log(mu2 + eps) - log_theta_mu_eps)
227 |         + torch.lgamma(x + y + theta)
228 |         - torch.lgamma(theta)
229 |         - torch.lgamma(x + 1)
230 |         - torch.lgamma(y + 1)
231 |     )
232 | 
233 |     return res
234 | 
235 | def log_prob_NBuncorr(x: torch.Tensor, mu1: torch.Tensor, mu2: torch.Tensor,
236 |                       theta: torch.Tensor, THETA_IS, eps=1e-8):
237 |     """
238 |     Log likelihood (scalar) of a minibatch according to a bivariate nb model
239 |     where spliced and unspliced are predicted separately.
240 |     Parameters
241 |     ----------
242 |     x
243 |         data
244 |     mu1,mu2
245 |         mean of the negative binomial (has to be positive support) (shape: minibatch x vars/2)
246 |     theta
247 |         params (has to be positive support) (shape: minibatch x vars)
248 |     eps
249 |         numerical stability constant
250 |     Notes
251 |     -----
252 |     This version shares a single theta between the two marginals.
253 |     """
254 | 
255 |     # Divide the original data x into spliced (x) and unspliced (y)
256 |     x,y = torch.chunk(x,2,dim=-1)
257 | 
258 |     if theta.ndimension() == 1:
259 |         theta = theta.view(
260 |             1, theta.size(0)
261 |         )  # In this case, we reshape theta for broadcasting
262 | 
263 |     # In contrast to log_prob_NBcorr, each marginal gets its own normalizer
264 |     log_theta_mu1_eps = torch.log(theta + mu1 + eps)
265 |     log_theta_mu2_eps = torch.log(theta + mu2 + eps)
266 | 
267 |     res = (
268 |         theta * (2*torch.log(theta + eps) - log_theta_mu1_eps - log_theta_mu2_eps)
269 |         + x * (torch.log(mu1 + eps) - log_theta_mu1_eps)
270 |         + torch.lgamma(x + theta)
271 |         - 2*torch.lgamma(theta)
272 |         - torch.lgamma(x + 1)
273 |         + y * (torch.log(mu2 + eps) - log_theta_mu2_eps)
274 |         + torch.lgamma(y + theta)
275 |         - torch.lgamma(y + 1)
276 |     )
277 | 
278 |     return res
279 | 
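# NOTE: the next definition re-uses the name log_prob_NBuncorr and therefore
# shadows the version above at import time; the model resolves to this second,
# split-theta variant.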
280 | 
281 | def log_prob_NBuncorr(x: torch.Tensor, mu1: torch.Tensor, mu2: torch.Tensor,
282 |                       theta: torch.Tensor, THETA_IS, eps=1e-8):
283 |     """
284 |     Log likelihood (scalar) of a minibatch according to a bivariate nb model
285 |     where spliced and unspliced are predicted separately.
286 |     Parameters
287 |     ----------
288 |     x
289 |         data
290 |     mu1,mu2
291 |         mean of the negative binomial (has to be positive support) (shape: minibatch x vars/2)
292 |     theta
293 |         params (has to be positive support) (shape: minibatch x vars)
294 |     eps
295 |         numerical stability constant
296 |     Notes
297 |     -----
298 |     Here theta is split into separate inverse dispersions for the two marginals.
299 |     """
300 | 
301 |     # Divide the original data x into spliced (m) and unspliced (n)
302 |     # divide theta as well
303 |     n,m = torch.chunk(x,2,dim=-1)
304 |     theta1,theta2 = torch.chunk(theta,2,dim=-1)
305 | 
306 |     # each marginal gets its own theta and normalizer
307 |     log_theta1_mu1_eps = torch.log(theta1 + mu1 + eps)
308 |     log_theta2_mu2_eps = torch.log(theta2 + mu2 + eps)
309 | 
310 | 
311 |     res1 = (
312 |         theta1 * (torch.log(theta1 + eps) - log_theta1_mu1_eps)
313 |         + n * (torch.log(mu1 + eps) - log_theta1_mu1_eps)
314 |         + torch.lgamma(n + theta1)
315 |         - torch.lgamma(theta1)
316 |         - torch.lgamma(n + 1)
317 |     )
318 | 
319 |     res2 = (
320 |         theta2 * (torch.log(theta2 + eps) - log_theta2_mu2_eps)
321 |         + m * (torch.log(mu2 + eps) - log_theta2_mu2_eps)
322 |         + torch.lgamma(m + theta2)
323 |         - torch.lgamma(theta2)
324 |         - torch.lgamma(m + 1)
325 |     )
326 |     return res1+res2
-------------------------------------------------------------------------------- /BIVI/BIVI/models/README.md: --------------------------------------------------------------------------------
1 | Contains neural networks for custom distributions.
2 | 
-------------------------------------------------------------------------------- /BIVI/BIVI/models/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | 
-------------------------------------------------------------------------------- /BIVI/BIVI/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /BIVI/BIVI/models/best_model_MODEL.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/BIVI/BIVI/models/best_model_MODEL.zip -------------------------------------------------------------------------------- /BIVI/BIVI/nnNB_module.py: --------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | 
4 | import os
5 | import importlib_resources
6 | 
7 | import numpy as np
8 | from scipy import stats
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | 
14 | ## YC added
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 | 
17 | class MLP(nn.Module):
18 | 
19 |     def __init__(self, input_dim, npdf, h1_dim, h2_dim):
20 |         super().__init__()
21 | 
22 |         self.input = nn.Linear(input_dim, h1_dim)
23 |         self.hidden = nn.Linear(h1_dim, h2_dim)
24 |         self.output = nn.Linear(h2_dim, npdf)
25 | 
26 |         self.hyp = nn.Linear(h2_dim, 1)  # applied to the second hidden layer in forward()
27 | 
28 |         self.softmax = nn.Softmax(dim=1)
29 |         self.sigmoid = torch.sigmoid
30 | 
31 | 
32 |     def forward(self, inputs):
33 | 
34 |         # pass inputs to first layer, apply sigmoid
35 |         l_1 = self.sigmoid(self.input(inputs))
36 | 
37 |         # pass to second layer, apply sigmoid
38 |         l_2 = self.sigmoid(self.hidden(l_1))
39 | 
40 |         # pass to output layer
41 |         w_un = (self.output(l_2))
42 | 
43 |         # pass out hyperparameter, sigmoid so it is between 0 and 1 (scaled between 1 and 6 downstream)
44 |         hyp = self.sigmoid(self.hyp(l_2))
45 | 
46 |         # apply softmax
47 |         w_pred = self.softmax(w_un)
48 | 
49 |         return w_pred,hyp
50 | 
51 | # YC added
52 | try:
53 |     package_resources = importlib_resources.files("BIVI")
54 |     model_path = os.path.join(package_resources,'models/best_model_MODEL.zip')
55 |     print(package_resources)
56 | except:
57 |     import sys
58 |     package_resources = importlib_resources.files("models")
59 |     model_path = os.path.join(package_resources,'best_model_MODEL.zip')
60 | 
61 | npdf = 10
62 | 
63 | # load in model
64 | model = MLP(7,10,256,256)
65 | model.load_state_dict(torch.load(model_path))
66 | model.eval()
67 | model.to(torch.device(device))
68 | 
69 | 
70 | def get_NORM(npdf,quantiles='cheb'):
71 |     ''' Returns quantiles based on the number of kernel functions npdf.
72 |     Chebyshev or linear, with Chebyshev as default.
73 |     '''
74 |     if quantiles == 'lin':
75 |         q = np.linspace(0,1,npdf+2)[1:-1]
76 |         norm = stats.norm.ppf(q)
77 |         norm = torch.tensor(norm)
78 |         return norm
79 |     if quantiles == 'cheb':
80 |         n = np.arange(npdf)
81 |         q = np.flip((np.cos((2*(n+1)-1)/(2*npdf)*np.pi)+1)/2)
82 | 
83 |         norm = stats.norm.ppf(q)
84 |         norm = torch.tensor(norm)
85 |         return norm
86 | 
87 | NORM = get_NORM(10).to(torch.device(device))
88 | norm = NORM
89 | 
90 | 
91 | def generate_grid(logmean_cond,logstd_cond,norm):
92 |     ''' Generate grid of kernel means based on the log mean and log standard deviation of a conditional distribution.
93 |     Generates the grid of quantile values in NORM, scaled by the conditional moments.
94 |     '''
95 | 
96 |     logmean_cond = torch.reshape(logmean_cond,(-1,1))
97 |     logstd_cond = torch.reshape(logstd_cond,(-1,1))
98 |     translin = torch.exp(torch.add(logmean_cond,logstd_cond*norm))
99 | 
100 |     return translin
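# In brief: the trained MLP maps a parameter vector to npdf softmax weights w and
# a scalar hyperparameter; get_ypred_at_RT (below) evaluates a w-weighted mixture
# of Poisson/negative-binomial kernels whose means lie on the grid produced by
# generate_grid, yielding an approximate probability for each (nascent n, mature m) pair.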
102 | def get_ypred_at_RT(p,w,hyp,n,m,norm,eps=1e-8):
103 |     '''Given a parameter vector (tensor) and weights (tensor), and hyperparameter,
104 |     calculates ypred (Y), the approximate probability. Calculates over an array of nascent (n) and mature (m) values.
105 |     '''
106 | 
107 |     p_vec = 10**p[:,0:3]
108 |     logmean_cond = p[:,3]
109 |     logstd_cond = p[:,4]
110 | 
111 |     hyp = hyp*5+1
112 | 
113 |     grid = generate_grid(logmean_cond,logstd_cond,norm)
114 |     s = torch.zeros((len(n),10)).to(torch.device(device))
115 |     s[:,:-1] = torch.diff(grid,axis=1)
116 |     s *= hyp
117 |     s[:,-1] = torch.sqrt(grid[:,-1])
118 | 
119 | 
120 |     v = s**2
121 |     r = grid**2/((v-grid)+eps)
122 |     p_nb = 1-grid/v
123 | 
124 |     Y = torch.zeros((len(n),1)).to(torch.device(device))
125 | 
126 |     y_ = m * torch.log(grid + eps) - grid - torch.lgamma(m+1)
127 |     m_array = m.repeat(1,10)
128 | 
129 |     if (p_nb > 1e-10).any():
130 |         index = [p_nb > 1e-10]
131 |         y_[index] += torch.special.gammaln(m_array[index]+r[index]) - torch.special.gammaln(r[index]) \
132 |                      - m_array[index]*torch.log(r[index] + grid[index]) + grid[index] + r[index]*torch.log(r[index]/(r[index]+grid[index]))
133 | 
134 |     y_ = torch.exp(y_)
135 |     y_weighted = w*y_
136 |     Y = y_weighted.sum(axis=1)
137 | 
138 |     EPS = 1e-40
139 |     Y[Y < EPS] = EPS
# ...
-------------------------------------------------------------------------------- /Example/kb_pipeline.sh: --------------------------------------------------------------------------------
15 | # ... -> this only has to be done once for a reference genome
16 | 
17 | # download reference genome and annotations to $main_path/references/
18 | mkdir -p $main_path/references
19 | cd $main_path/references
20 | wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_38/GRCh38.primary_assembly.genome.fa.gz
21 | wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_38/gencode.v38.primary_assembly.annotation.gtf.gz
22 | 
23 | 
24 | # create nascent and mature indices
25 | cd $main_path
26 | mkdir -p $main_path/indices
27 | kb ref --workflow=lamanno --verbose --overwrite -i $main_path/indices/human_lamanno.idx -g $main_path/indices/human_lamanno.t2g -c1 $main_path/indices/human_lamanno.mature.t2c -c2 $main_path/indices/human_lamanno.nascent.t2c -f1 $main_path/indices/human.lamanno.mature.fa -f2 $main_path/indices/human.lamanno.nascent.fa $main_path/references/GRCh38.primary_assembly.genome.fa.gz $main_path/references/gencode.v38.primary_assembly.annotation.gtf.gz
28 | 
29 | 
30 | 
31 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
32 | # download 10x v3 1k PBMC raw data (change paths to reflect desired data)
33 | 
34 | mkdir -p $main_path/pbmc_1k_v3_raw
35 | cd $main_path/pbmc_1k_v3_raw/
36 | curl -O https://cf.10xgenomics.com/samples/cell-exp/3.0.0/pbmc_1k_v3/pbmc_1k_v3_fastqs.tar # 5-10 minutes
37 | tar -xvf pbmc_1k_v3_fastqs.tar
38 | 
39 | 
40 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
41 | # pseudoalignment and generation of count matrices
42 | cd $main_path
43 | mkdir -p $main_path/pbmc_1k_v3/
44 | 
45 | kb count --verbose \
46 |     -i $main_path/indices/human_lamanno.idx \
47 |     -g $main_path/indices/human_lamanno.t2g \
48 |     -x 10xv3 \
49 |     -o $main_path/pbmc_1k_v3/ \
50 |     -t $threads -m 30G \
51 |     -c1 $main_path/indices/human_lamanno.mature.t2c \
52 |     -c2 $main_path/indices/human_lamanno.nascent.t2c \
53 |     --workflow lamanno --filter bustools --overwrite --loom \
54 |     $main_path/pbmc_1k_v3_raw/pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L001_R1_001.fastq.gz \
55 |     $main_path/pbmc_1k_v3_raw/pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L001_R2_001.fastq.gz \
56 |     $main_path/pbmc_1k_v3_raw/pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L002_R1_001.fastq.gz \
57 |     $main_path/pbmc_1k_v3_raw/pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L002_R2_001.fastq.gz
58 | 
-------------------------------------------------------------------------------- /Example/pbmc_1k_v3/counts_filtered/adata.loom: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/Example/pbmc_1k_v3/counts_filtered/adata.loom -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, Pachter Lab 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /Manuscript/analysis/Fig_2e-f_Differential_Expression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# New differential expression testing using Bayes Factor\n", 8 | "\n", 9 | "\n", 10 | "\n", 11 | "Compare sets of DE genes identified using standard hypothesis testing (t-test) on point estimates of parameters and scVI bayes factor tests. \n", 12 | "\n", 13 | "\n", 14 | "\n", 15 | "### Bayes factor calculation *a la scVI* and *totalVI*\n", 16 | "\n", 17 | "\n", 18 | "Hypothesis testing in a Bayesian setting can be done by comparing two hypothesis, $H_0$ (null) and $H_1$ (alternate), and choosing the one with the higher probability given the data $X$. That is, accepting or rejecting $H_0$ by comparing the Bayes Factor to a set threshold value:\n", 19 | "\n", 20 | "$$ \\frac{P(H_1 | X)}{P(H_0 | X)}\n", 21 | "$$\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stderr", 31 | "output_type": "stream", 32 | "text": [ 33 | "Global seed set to 0\n", 34 | "/home/tara/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/warnings.py:53: LightningDeprecationWarning: pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6 and will be removed in v1.8. 
Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead.\n", 35 | " new_rank_zero_deprecation(\n", 36 | "/home/tara/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/warnings.py:58: LightningDeprecationWarning: The `pytorch_lightning.loggers.base.rank_zero_experiment` is deprecated in v1.7 and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.rank_zero_experiment` instead.\n", 37 | " return new_rank_zero_deprecation(*args, **kwargs)\n" 38 | ] 39 | }, 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "/home/tara/temp_git2/CGCCP_2023/Manuscript/analysis/../../BIVI/BIVI\n", 45 | "0.18.0\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "# system\n", 51 | "import sys, os\n", 52 | "sys.path.insert(0,'../../BIVI/')\n", 53 | "sys.path.insert(0,'../../BIVI/BIVI')\n", 54 | "\n", 55 | "# timing\n", 56 | "import time\n", 57 | "\n", 58 | "# numbers\n", 59 | "import numpy as np\n", 60 | "import torch\n", 61 | "import pandas as pd\n", 62 | "\n", 63 | "# sc \n", 64 | "import anndata\n", 65 | "\n", 66 | "# plots\n", 67 | "import matplotlib.pyplot as plt\n", 68 | "import seaborn as sns\n", 69 | "cmap = plt.get_cmap('Purples')\n", 70 | "cmap_green = plt.get_cmap('Greens')\n", 71 | "cmap_orange = plt.get_cmap('Oranges')\n", 72 | "cmap_red = plt.get_cmap('Reds')\n", 73 | "cmap_blue = plt.get_cmap('Blues')\n", 74 | "cmap_ygb = plt.get_cmap('YlGnBu')\n", 75 | "\n", 76 | "# biVI\n", 77 | "import biVI\n", 78 | "import scvi\n", 79 | "print(scvi.__version__)\n", 80 | "\n", 81 | "\n", 82 | "\n", 83 | "# reproducibility, set random seeds\n", 84 | "scvi.settings.seed = 8675309\n", 85 | "torch.manual_seed(8675309)\n", 86 | "np.random.seed(8675309)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "Load in data."
94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "/usr/local/lib/python3.8/dist-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 106 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "# first load in data\n", 112 | "adata = anndata.read_loom('../data/allen/B08_processed_hv.loom')\n", 113 | "\n", 114 | "# make variable names unique \n", 115 | "adata.var_names_make_unique()\n", 116 | "\n", 117 | "# remove cell types with fewer than 10 cells\n", 118 | "cell_types = np.array(adata.obs['subclass_label'].tolist())\n", 119 | "\n", 120 | "\n", 121 | "# ordered according to cell subclass\n", 122 | "unique_cell_types = ['Lamp5', 'Sncg', 'Vip', 'Sst', 'Pvalb',\n", 123 | " 'L2/3 IT', 'L5 IT', 'L5/6 NP', 'L6 CT', 'L6 IT', 'L6b',\n", 124 | " 'Astro', 'OPC', 'Oligo', 'Macrophage', 'Endo']\n", 125 | "\n", 126 | "\n", 127 | "for ct in unique_cell_types:\n", 128 | " \n", 129 | " cells_per_ct_ = (ct == cell_types).sum()\n", 130 | " if cells_per_ct_ < 10.0:\n", 131 | " adata = adata[adata.obs['subclass_label'] != ct, :]\n", 132 | " \n", 133 | "adata = adata.copy() \n", 134 | "\n", 135 | "cell_types = np.array(adata.obs['subclass_label'].tolist())\n", 136 | "\n" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "## Load in trained models" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "import importlib\n", 153 | "importlib.reload(biVI)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "model1 = biVI.biVI.load(\".../results/Bursty_B08_processed_hv_MODEL\", adata=adata, use_gpu = True)\n", 163 | "model2 = scvi.model.SCVI.load(\".../results/scVI_B08_processed_hv_MODEL\", adata=adata, use_gpu = True)\n", 164 | "# model3 = biVI.biVI.load(\".../results/Constitutive_B08_processed_hv_MODEL\", adata=adata, use_gpu = True)\n", 165 | "# model4 = biVI.biVI.load(\".../results/Extrinsic_B08_processed_hv_MODEL\", adata=adata, use_gpu = True)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## Calculate Bayes Factors for one cell type versus rest " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# scVI has a nice built-in function that allows easy comparisons between cell types\n", 182 | "delta = 1.0\n", 183 | "scVI_BF_DE_built = model2.differential_expression(adata, groupby = \"Cell Type\", m_permutation = 10000,\n", 184 | " mode = \"change\", delta = delta, n_samples = 20)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "# In biVI, for now loop through cell types and calculate BF for each\n", 194 | "n_samples = 20\n", 195 | "parameters = ['norm_burst_size','norm_degradation_rate','norm_spliced_mean','norm_unspliced_mean']\n", 196 | "biVI_BF_DE = {ct : {} for ct in unique_cell_types}\n", 197 | "scVI_BF_DE = {ct : {} for ct in unique_cell_types}\n", 198 | "const_BF_DE = {ct : {} for ct in unique_cell_types}\n", 199 | "ext_BF_DE = {ct : {} for ct in
unique_cell_types}\n", 200 | "\n", 201 | "for cell_type in unique_cell_types:\n", 202 | " print(cell_type)\n", 203 | " idx1 = np.arange(len(cell_types))[cell_types == cell_type]\n", 204 | " idx2 = np.arange(len(cell_types))[cell_types != cell_type]\n", 205 | "\n", 206 | " biVI_BF_DE[cell_type] = model1.get_bayes_factors(adata,idx1,idx2,\n", 207 | " n_samples_1 = n_samples,\n", 208 | " n_samples_2 = n_samples,\n", 209 | " n_comparisons = 10000,\n", 210 | " return_df = True,\n", 211 | " delta = delta)\n", 212 | " \n", 213 | " \n", 214 | "# params1 = model2.get_normalized_expression(adata[idx1],\n", 215 | "# n_samples = n_samples, return_mean = False)\n", 216 | "# params2 = model2.get_normalized_expression(adata[idx2],\n", 217 | "# n_samples = n_samples, return_mean = False)\n", 218 | "# params_dict_1 = {}\n", 219 | "# params_dict_2 = {}\n", 220 | "# params_dict_1['norm_unspliced_mean'] = params1[:,:,:2000]\n", 221 | "# params_dict_1['norm_spliced_mean'] = params1[:,:,2000:]\n", 222 | "# params_dict_2['norm_unspliced_mean'] = params2[:,:,:2000]\n", 223 | "# params_dict_2['norm_spliced_mean'] = params2[:,:,2000:]\n", 224 | " \n", 225 | "# scVI_BF_DE[cell_type] = model1.get_bayes_factors(adata,idx1,idx2,\n", 226 | "# n_samples_1 = n_samples,\n", 227 | "# n_samples_2 = n_samples,\n", 228 | "# n_comparisons = 10000,\n", 229 | "# return_df = True,\n", 230 | "# delta = delta,\n", 231 | "# params_dict_1 = params_dict_1,\n", 232 | "# params_dict_2 = params_dict_2)\n", 233 | " \n", 234 | "# params1c = model3.get_normalized_expression(adata[idx1],\n", 235 | "# n_samples = n_samples, return_mean = False)\n", 236 | "# params2c = model3.get_normalized_expression(adata[idx2],\n", 237 | "# n_samples = n_samples, return_mean = False)\n", 238 | "\n", 239 | "# const_BF_DE[cell_type] = model3.get_bayes_factors(adata,idx1,idx2,\n", 240 | "# n_samples_1 = n_samples,\n", 241 | "# n_samples_2 = n_samples,\n", 242 | "# n_comparisons = 10000,\n", 243 | "# return_df = True,\n", 244 | "# delta = 1.0)\n", 245 | " \n", 246 | "# ext_BF_DE[cell_type] = model3.get_bayes_factors(adata,idx1,idx2,\n", 247 | "# n_samples_1 = n_samples,\n", 248 | "# n_samples_2 = n_samples,\n", 249 | "# n_comparisons = 10000,\n", 250 | "# return_df = True,\n", 251 | "# delta = 1.0)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "Get significant genes using the Bayes factor for scVI and biVI.\n", 259 | "\n", 260 | "The scVI paper uses a Bayes Factor cutoff of 0.7 and a delta cutoff of 0.25 (although this seems too small to me, perhaps I will use a more stringent cutoff).\n", 261 | "\n", 262 | "\n", 263 | "The delta cutoff means that $|\log_2(p_A / p_{rest})| \geq \delta$, where $p_A$ and $p_{rest}$ are the parameter values in group A and in the remaining cells.\n", 264 | "\n", 265 | "The Bayes factor is the probability that $|\log_2(p_A / p_{rest})| \geq \delta$ divided by the probability that it is less than $\delta$.
\n" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "genes_unspliced = adata.var['gene_name'].tolist()\n", 275 | "genes_unspliced[2000:] = [0]*2000\n", 276 | "genes_spliced = adata.var['gene_name'].tolist()\n", 277 | "genes_spliced[:2000] = [0]*2000\n", 278 | "genes = np.array(adata.var['gene_name'])[:2000]" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "biVI_sig_BF_genes = {ct : {} for ct in unique_cell_types}\n", 288 | "# scVI_sig_BF_genes = {ct : {} for ct in unique_cell_types}\n", 289 | "# const_sig_BF_genes = {ct : {} for ct in unique_cell_types}\n", 290 | "# ext_sig_BF_genes = {ct : {} for ct in unique_cell_types}\n", 291 | "\n", 292 | "scVI_sig_BF_genes_built = {ct : {} for ct in unique_cell_types}\n", 293 | "\n", 294 | "BF_THRESH = 1.5\n", 295 | "LFC_THRESH = 1.0\n", 296 | "\n", 297 | "\n", 298 | "for cell_type in unique_cell_types:\n", 299 | " print(cell_type)\n", 300 | " idx_scVI = np.array(scVI_BF_DE_built[(scVI_BF_DE_built['group1']==cell_type) & \n", 301 | " (scVI_BF_DE_built['bayes_factor']>BF_THRESH) &\n", 302 | " (scVI_BF_DE_built['lfc_mean']>LFC_THRESH)].index,dtype=int)\n", 303 | "\n", 304 | " sig_genes_unspliced = list(map(genes_unspliced.__getitem__,idx_scVI))\n", 305 | " sig_genes_unspliced = [g for g in sig_genes_unspliced if g != 0]\n", 306 | " \n", 307 | " sig_genes_spliced = list(map(genes_spliced.__getitem__,idx_scVI))\n", 308 | " sig_genes_spliced = [g for g in sig_genes_spliced if g != 0]\n", 309 | " \n", 310 | " \n", 311 | " scVI_sig_BF_genes_built[cell_type]['norm_unspliced_mean'] = sig_genes_unspliced\n", 312 | " scVI_sig_BF_genes_built[cell_type]['norm_spliced_mean'] = sig_genes_spliced\n", 313 | " \n", 314 | " for param in biVI_BF_DE[cell_type].keys():\n", 315 | " \n", 316 | " df = biVI_BF_DE[cell_type][param]\n", 317 | " \n", 318 | " idx_biVI = np.array(df[(np.log(df['bayes_factor'])>BF_THRESH) &\n", 319 | " (np.log(df['lfc_mean'])>LFC_THRESH)].index,dtype=int)\n", 320 | " sig_genes_param = list(map(genes.__getitem__,idx_biVI))\n", 321 | " biVI_sig_BF_genes[cell_type][param] = sig_genes_param\n", 322 | " \n", 323 | "# for param in scVI_BF_DE[cell_type].keys():\n", 324 | "# df = scVI_BF_DE[cell_type][param]\n", 325 | " \n", 326 | "# idx_scVI = np.array(df[np.log(df['bayes_factor'])>BF_THRESH].index,dtype=int)\n", 327 | "# sig_genes_param = list(map(genes.__getitem__,idx_scVI))\n", 328 | " \n", 329 | "# scVI_sig_BF_genes[cell_type][param] = sig_genes_param\n", 330 | "\n", 331 | "# for param in const_BF_DE[cell_type].keys():\n", 332 | "# df = const_BF_DE[cell_type][param]\n", 333 | " \n", 334 | "# idx_const = np.array(df[np.log(df['bayes_factor'])>BF_THRESH].index,dtype=int)\n", 335 | "# sig_genes_param = list(map(genes.__getitem__,idx_const))\n", 336 | " \n", 337 | "# const_sig_BF_genes[cell_type][param] = sig_genes_param\n", 338 | " \n", 339 | "# for param in ext_BF_DE[cell_type].keys():\n", 340 | "# df = ext_BF_DE[cell_type][param]\n", 341 | " \n", 342 | "# idx_ext = np.array(df[np.log(df['bayes_factor'])>BF_THRESH].index,dtype=int)\n", 343 | "# sig_genes_param = list(map(genes.__getitem__,idx_ext))\n", 344 | " \n", 345 | "# ext_sig_BF_genes[cell_type][param] = sig_genes_param\n", 346 | " " 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "for ct in 
unique_cell_types[:3]:\n", 356 | "\n", 357 | "    print('Working cell type: ',ct)\n", 358 | "\n", 359 | "    for param in biVI_sig_BF_genes[ct].keys():\n", 360 | "        n = len(biVI_sig_BF_genes[ct][param])\n", 361 | "        print(f'There are {n} genes significant in {param} BIVI MINE')\n", 362 | "    \n", 363 | "    for param in scVI_sig_BF_genes_built[ct].keys():\n", 364 | "        n = len(scVI_sig_BF_genes_built[ct][param])\n", 365 | "        print(f'There are {n} genes significant in {param} SCVI BUILT')\n", 366 | "    \n", 367 | "#    for param in scVI_sig_BF_genes[ct].keys():\n", 368 | "#        n = len(scVI_sig_BF_genes[ct][param])\n", 369 | "#        print(f'There are {n} genes significant in {param} SCVI MINE')\n", 370 | "\n", 371 | "    \n", 372 | "# for param in const_sig_BF_genes[ct].keys():\n", 373 | "\n", 374 | "# n = len(const_sig_BF_genes[ct][param])\n", 375 | "# print(f'There are {n} genes significant in {param} const')\n", 376 | " \n", 377 | "# for param in ext_sig_BF_genes[ct].keys():\n", 378 | "\n", 379 | "# n = len(ext_sig_BF_genes[ct][param])\n", 380 | "# print(f'There are {n} genes significant in {param} ext')" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": {}, 386 | "source": [ 387 | "# Novel gene identification\n", 388 | "\n", 389 | "Identify genes that are DE (up, for now) in biVI parameters but NOT in scVI spliced means (as is the typical case).\n", 390 | "\n", 391 | "\n" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "b_changed_mu2_unchanged = {}\n", 401 | "gamma_changed_mu2_unchanged = {}\n", 402 | "\n", 403 | "for ct in unique_cell_types:\n", 404 | "    genes_b = biVI_sig_BF_genes[ct]['norm_burst_size']\n", 405 | "    genes_gamma = biVI_sig_BF_genes[ct]['norm_degradation_rate']\n", 406 | "    genes_scVI_mu2 = scVI_sig_BF_genes_built[ct]['norm_spliced_mean']\n", 407 | "    \n", 408 | "    b_changed_mu2_unchanged[ct] = [g for g in genes_b if g not in genes_scVI_mu2]\n", 409 | "    gamma_changed_mu2_unchanged[ct] = [g for g in genes_gamma if g not in genes_scVI_mu2]" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "'Ndnf' in b_changed_mu2_unchanged['L6 CT']" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "'Trem2' in gamma_changed_mu2_unchanged['L5 IT']" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": {}, 433 | "source": [ 434 | "------\n", 435 | "\n", 436 | "# Plot biVI DE genes" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "sig_gene_plot = []\n", 446 | "\n", 447 | "num_muX_sig = 0\n", 448 | "num_both_sig = 0\n", 449 | "num_biVIX_sig = 0\n", 450 | "\n", 451 | "for ct in unique_cell_types:\n", 452 | "\n", 453 | "    biVI_sig_genes = np.array(biVI_sig_BF_genes[ct]['norm_burst_size'] + biVI_sig_BF_genes[ct]['norm_degradation_rate'])\n", 454 | "    \n", 455 | "    biVI_sig_genes = np.unique(biVI_sig_genes)\n", 456 | "\n", 457 | "    mu_sig_genes = np.array(scVI_sig_BF_genes_built[ct]['norm_spliced_mean'])\n", 458 | "    # + scVI_sig_BF_genes_built[ct]['norm_unspliced_mean'])\n", 459 | "    mu_sig_genes = np.unique(mu_sig_genes)\n", 460 | "    muX_sig_genes = [g for g in mu_sig_genes if g not in biVI_sig_genes]\n", 461 | "    \n", 462 | "    biVIX_sig_genes = [g for g in biVI_sig_genes if g not
in mu_sig_genes]\n", 463 | " \n", 464 | " both_sig_genes = [g for g in biVI_sig_genes if g in mu_sig_genes]\n", 465 | " sig_gene_plot.append(len(biVIX_sig_genes))\n", 466 | " \n", 467 | " num_muX_sig += len(muX_sig_genes)\n", 468 | " num_both_sig += len(both_sig_genes)\n", 469 | " num_biVIX_sig += len(biVIX_sig_genes)\n", 470 | " \n", 471 | "plt.figure(figsize=(9,4))\n", 472 | "sns.barplot(x = unique_cell_types, y = sig_gene_plot,\n", 473 | " palette = ['lightsalmon' for i in range(19)])\n", 474 | "\n", 475 | "plt.xticks(rotation=90);\n", 476 | "plt.xlabel('Cell subclass',fontsize = 25)\n", 477 | "plt.tick_params(axis='both', which='major', labelsize=20)\n", 478 | "#plt.title(f'$biVI$ Significant Differentially Regulated Genes',fontsize = 30)\n", 479 | "plt.ylabel('Number of genes',fontsize = 25);\n", 480 | "\n", 481 | "\n", 482 | "\n", 483 | "plt.savefig(f'.../results/B08_processed_hv_figs/biVI_DE_genes.png',bbox_inches='tight')\n", 484 | "\n", 485 | "\n" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "print(num_muX_sig)\n", 495 | "print(num_both_sig)\n", 496 | "print(num_biVIX_sig)" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "# Fraction of DE genes\n", 504 | "\n", 505 | "\n", 506 | "For all genes found to be differentially expressed in *biVI* parameters, are they mostly regulated by burst size, degradation rate, or both?" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "compare_b_gamma_dict = {ct : {} for ct in unique_cell_types}\n", 516 | "\n", 517 | "b_fractions = []\n", 518 | "gamma_fractions = []\n", 519 | "both_fractions = []\n", 520 | "ct_for_df = []\n", 521 | "\n", 522 | "for ct in unique_cell_types:\n", 523 | " print(ct)\n", 524 | " b_genes = biVI_sig_BF_genes[ct]['norm_burst_size']\n", 525 | " gamma_genes = biVI_sig_BF_genes[ct]['norm_degradation_rate']\n", 526 | " \n", 527 | " \n", 528 | " all_sig_genes = np.unique( np.array(b_genes+gamma_genes) )\n", 529 | " \n", 530 | " if len(all_sig_genes) != 0: \n", 531 | " fraction_b_genes = len([b for b in b_genes if b not in gamma_genes])/len(all_sig_genes)\n", 532 | " fraction_gamma_genes = len([g for g in gamma_genes if g not in b_genes])/len(all_sig_genes)\n", 533 | " fraction_both = len([a for a in all_sig_genes if (a in b_genes) and (a in gamma_genes)])/len(all_sig_genes)\n", 534 | " \n", 535 | " \n", 536 | " b_fractions.append(fraction_b_genes) \n", 537 | " gamma_fractions.append(fraction_gamma_genes) \n", 538 | " both_fractions.append(fraction_both)\n", 539 | " ct_for_df.append(ct)\n", 540 | " \n", 541 | " " 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [ 550 | "# how many genes" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "metadata": {}, 557 | "outputs": [], 558 | "source": [ 559 | "\n", 560 | "N = len(ct_for_df)\n", 561 | "df_1 = pd.DataFrame({'Cell subclass' : ct_for_df,'Fraction' : b_fractions, 'Parameter' : ['Burst size']*N})\n", 562 | "df_2 = pd.DataFrame({'Cell subclass' : ct_for_df,'Fraction' : gamma_fractions, 'Parameter' : ['Relative degradation rate']*N})\n", 563 | "df_3 = pd.DataFrame({'Cell subclass' : ct_for_df,'Fraction' : both_fractions, 'Parameter' : ['Both']*N})\n", 564 | "df_plot = 
pd.concat([df_1,df_2,df_3],axis=0)" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [ 573 | "# get number of cells in each cell type\n", 574 | "num_cell_type = []\n", 575 | "for ct in ct_for_df:\n", 576 | " num_cell_type.append(len(adata[adata.obs.subclass_label==ct]))" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": null, 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "fig, ax = plt.subplots(2,1,figsize = (17,12),gridspec_kw={'height_ratios': [5, 1],\n", 586 | " 'wspace': 0.0, 'hspace':0.07})\n", 587 | "plt.xticks(rotation = 90, fontsize = 30)\n", 588 | "fig.subplots_adjust(wspace=None)\n", 589 | "#plt.yticks(fontsize=34)\n", 590 | "sns.barplot(x='Cell subclass', y='Fraction', hue='Parameter', data=df_plot, \n", 591 | " palette = [cmap_ygb(80),cmap_ygb(170),cmap_ygb(270)], ax = ax[0])\n", 592 | "ax[0].set_ylabel('Fraction',fontsize = 40)\n", 593 | "ax[0].set_ylim(0,1.2)\n", 594 | "ax[0].set(xlabel=None)\n", 595 | "ax[0].tick_params(\n", 596 | " axis='x', # changes apply to the x-axis\n", 597 | " which='both', # both major and minor ticks are affected\n", 598 | " bottom=False, # ticks along the bottom edge are off\n", 599 | " top=False, # ticks along the top edge are off\n", 600 | " labelbottom=False,)\n", 601 | "ax[0].set_yticklabels([0.0,0.2,0.4,0.6,0.8,1.0,],fontsize = 30)\n", 602 | "plt.setp(ax[0].get_legend().get_texts(), fontsize='30');\n", 603 | "plt.setp(ax[0].get_legend().get_title(), fontsize='30');\n", 604 | "\n", 605 | "\n", 606 | "sns.barplot(x=ct_for_df,y = num_cell_type,color = 'gray',ax=ax[1])\n", 607 | "ax[1].set_xticklabels(ct_for_df, rotation=90);\n", 608 | "ax[1].set_yscale('log')\n", 609 | "ax[1].set_yscale('log')\n", 610 | "ax[1].set_xlabel('Cell subclass',fontsize = 40)\n", 611 | "ax[1].set_ylabel('# Cells',fontsize = 30)\n", 612 | "ax[1].set_ylim(0,8000)\n", 613 | "ax[1].set_yticklabels([0,1,10,10**2,10**3],fontsize = 25);\n", 614 | "#ax[1].ticklabel_format(axis='y',style='sci');\n", 615 | "\n", 616 | "\n", 617 | "plt.savefig(f'.../results/B08_processed_hv_figs/b_gamma_percent.png',bbox_inches='tight');" 618 | ] 619 | }, 620 | { 621 | "cell_type": "markdown", 622 | "metadata": {}, 623 | "source": [ 624 | "------\n", 625 | "# Checking $Ndnf$ in other Allen samples" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 2, 631 | "metadata": {}, 632 | "outputs": [ 633 | { 634 | "name": "stderr", 635 | "output_type": "stream", 636 | "text": [ 637 | "/usr/local/lib/python3.8/dist-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 638 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n" 639 | ] 640 | } 641 | ], 642 | "source": [ 643 | "\n" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 5, 649 | "metadata": {}, 650 | "outputs": [ 651 | { 652 | "name": "stdout", 653 | "output_type": "stream", 654 | "text": [ 655 | "\u001b[34mINFO \u001b[0m File ..\u001b[35m/../results/Bursty_F08_processed_hv_MODEL/\u001b[0m\u001b[95mmodel.pt\u001b[0m already downloaded \n", 656 | "{'n_input': 4000, 'n_hidden': 128, 'n_latent': 10, 'n_layers': 3, 'dropout_rate': 0.1, 'dispersion': 'gene', 'latent_distribution': 'normal', 'log_variational': True}\n", 657 | "Initiating biVAE\n", 658 | "Mode: Bursty, Decoder: non-linear, Theta is: NAS_SHAPE\n" 659 | ] 660 | }, 661 | { 662 | "name": "stderr", 663 | "output_type": "stream", 664 | 
"text": [ 665 | "/home/tara/.local/lib/python3.8/site-packages/scvi/model/base/_utils.py:142: UserWarning: var_names for adata passed in does not match var_names of adata used to train the model. For valid results, the vars need to be the same and in the same order as the adata used to train the model.\n", 666 | " warnings.warn(\n" 667 | ] 668 | } 669 | ], 670 | "source": [ 671 | "model1_f08 = biVI.biVI.load(\".../results/Bursty_F08_processed_hv_MODEL\", adata=adata_f08, use_gpu = True)" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 64, 677 | "metadata": {}, 678 | "outputs": [ 679 | { 680 | "name": "stdout", 681 | "output_type": "stream", 682 | "text": [ 683 | "A08\n", 684 | "\u001b[34mINFO \u001b[0m File ..\u001b[35m/../results/Bursty_A08_processed_hv_MODEL/\u001b[0m\u001b[95mmodel.pt\u001b[0m already downloaded \n" 685 | ] 686 | }, 687 | { 688 | "name": "stderr", 689 | "output_type": "stream", 690 | "text": [ 691 | "/usr/local/lib/python3.8/dist-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 692 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n" 693 | ] 694 | }, 695 | { 696 | "name": "stdout", 697 | "output_type": "stream", 698 | "text": [ 699 | "{'n_input': 4000, 'n_hidden': 128, 'n_latent': 10, 'n_layers': 3, 'dropout_rate': 0.1, 'dispersion': 'gene', 'latent_distribution': 'normal', 'log_variational': True}\n", 700 | "Initiating biVAE\n", 701 | "Mode: Bursty, Decoder: non-linear, Theta is: NAS_SHAPE\n", 702 | "L6 CT\n", 703 | "\u001b[34mINFO \u001b[0m AnnData object appears to be a copy. Attempting to transfer setup. \n" 704 | ] 705 | }, 706 | { 707 | "name": "stderr", 708 | "output_type": "stream", 709 | "text": [ 710 | "/home/tara/.local/lib/python3.8/site-packages/scvi/model/base/_utils.py:142: UserWarning: var_names for adata passed in does not match var_names of adata used to train the model. For valid results, the vars need to be the same and in the same order as the adata used to train the model.\n", 711 | " warnings.warn(\n" 712 | ] 713 | }, 714 | { 715 | "name": "stdout", 716 | "output_type": "stream", 717 | "text": [ 718 | "L5 IT\n", 719 | "C01\n", 720 | "\u001b[34mINFO \u001b[0m File ..\u001b[35m/../results/Bursty_C01_processed_hv_MODEL/\u001b[0m\u001b[95mmodel.pt\u001b[0m already downloaded \n", 721 | "{'n_input': 4000, 'n_hidden': 128, 'n_latent': 10, 'n_layers': 3, 'dropout_rate': 0.1, 'dispersion': 'gene', 'latent_distribution': 'normal', 'log_variational': True}\n", 722 | "Initiating biVAE\n", 723 | "Mode: Bursty, Decoder: non-linear, Theta is: NAS_SHAPE\n", 724 | "L6 CT\n", 725 | "\u001b[34mINFO \u001b[0m AnnData object appears to be a copy. Attempting to transfer setup. \n" 726 | ] 727 | }, 728 | { 729 | "name": "stderr", 730 | "output_type": "stream", 731 | "text": [ 732 | "/usr/local/lib/python3.8/dist-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 733 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n", 734 | "/home/tara/.local/lib/python3.8/site-packages/scvi/model/base/_utils.py:142: UserWarning: var_names for adata passed in does not match var_names of adata used to train the model. 
For valid results, the vars need to be the same and in the same order as the adata used to train the model.\n", 735 | " warnings.warn(\n" 736 | ] 737 | }, 738 | { 739 | "name": "stdout", 740 | "output_type": "stream", 741 | "text": [ 742 | "L5 IT\n", 743 | "F08\n", 744 | "\u001b[34mINFO \u001b[0m File ..\u001b[35m/../results/Bursty_F08_processed_hv_MODEL/\u001b[0m\u001b[95mmodel.pt\u001b[0m already downloaded \n", 745 | "{'n_input': 4000, 'n_hidden': 128, 'n_latent': 10, 'n_layers': 3, 'dropout_rate': 0.1, 'dispersion': 'gene', 'latent_distribution': 'normal', 'log_variational': True}\n", 746 | "Initiating biVAE\n", 747 | "Mode: Bursty, Decoder: non-linear, Theta is: NAS_SHAPE\n", 748 | "L6 CT\n", 749 | "\u001b[34mINFO \u001b[0m AnnData object appears to be a copy. Attempting to transfer setup. \n" 750 | ] 751 | }, 752 | { 753 | "name": "stderr", 754 | "output_type": "stream", 755 | "text": [ 756 | "/usr/local/lib/python3.8/dist-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 757 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n", 758 | "/home/tara/.local/lib/python3.8/site-packages/scvi/model/base/_utils.py:142: UserWarning: var_names for adata passed in does not match var_names of adata used to train the model. For valid results, the vars need to be the same and in the same order as the adata used to train the model.\n", 759 | " warnings.warn(\n" 760 | ] 761 | }, 762 | { 763 | "name": "stdout", 764 | "output_type": "stream", 765 | "text": [ 766 | "L5 IT\n", 767 | "H12\n", 768 | "\u001b[34mINFO \u001b[0m File ..\u001b[35m/../results/Bursty_H12_processed_hv_MODEL/\u001b[0m\u001b[95mmodel.pt\u001b[0m already downloaded \n", 769 | "{'n_input': 4000, 'n_hidden': 128, 'n_latent': 10, 'n_layers': 3, 'dropout_rate': 0.1, 'dispersion': 'gene', 'latent_distribution': 'normal', 'log_variational': True}\n", 770 | "Initiating biVAE\n", 771 | "Mode: Bursty, Decoder: non-linear, Theta is: NAS_SHAPE\n", 772 | "L6 CT\n", 773 | "\u001b[34mINFO \u001b[0m AnnData object appears to be a copy. Attempting to transfer setup. \n" 774 | ] 775 | }, 776 | { 777 | "name": "stderr", 778 | "output_type": "stream", 779 | "text": [ 780 | "/usr/local/lib/python3.8/dist-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n", 781 | " warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n", 782 | "/home/tara/.local/lib/python3.8/site-packages/scvi/model/base/_utils.py:142: UserWarning: var_names for adata passed in does not match var_names of adata used to train the model. 
For valid results, the vars need to be the same and in the same order as the adata used to train the model.\n", 783 | " warnings.warn(\n" 784 | ] 785 | }, 786 | { 787 | "name": "stdout", 788 | "output_type": "stream", 789 | "text": [ 790 | "L5 IT\n" 791 | ] 792 | } 793 | ], 794 | "source": [ 795 | "# get Bayes factors for different samples and models \n", 796 | "samples = ['A08','C01','F08','H12']\n", 797 | "ndnf_index = {samp : {} for samp in samples}\n", 798 | "trem2_index = {samp : {} for samp in samples}\n", 799 | "all_samples_dict = {samp: {} for samp in samples}\n", 800 | "n_samples = 20\n", 801 | "delta = 1.0\n", 802 | "parameters = ['norm_burst_size','norm_degradation_rate','norm_spliced_mean','norm_unspliced_mean']\n", 803 | "cell_types_to_test = ['L6 CT', 'L5 IT']\n", 804 | "\n", 805 | "\n", 806 | "for samp in samples:\n", 807 | "    print(samp)\n", 808 | "    biVI_BF_DE_ = {ct : {} for ct in unique_cell_types}\n", 809 | "    # first load in data\n", 810 | "    adata_ = anndata.read_loom(f'../data/allen/{samp}_processed_hv.loom')\n", 811 | "    \n", 812 | "    ndnf_index[samp] = np.where(adata_.var['gene_name'][:2000] == 'Ndnf')[0]\n", 813 | "    trem2_index[samp] = np.where(adata_.var['gene_name'][:2000] == 'Trem2')[0]\n", 814 | "    \n", 815 | "    model1_ = biVI.biVI.load(f\".../results/Bursty_{samp}_processed_hv_MODEL\", adata=adata_, use_gpu = True)\n", 816 | "\n", 817 | "    # make variable names unique \n", 818 | "    adata_.var_names_make_unique()\n", 819 | "\n", 820 | "    # remove cell types with fewer than 10 cells\n", 821 | "    cell_types = np.array(adata_.obs['subclass_label'].tolist())\n", 822 | "    \n", 823 | "\n", 824 | "    # ordered according to cell subclass\n", 825 | "    unique_cell_types = ['Lamp5', 'Sncg', 'Vip', 'Sst', 'Pvalb',\n", 826 | "            'L2/3 IT', 'L5 IT', 'L5/6 NP', 'L6 CT', 'L6 IT', 'L6b',\n", 827 | "            'Astro', 'OPC', 'Oligo', 'Macrophage', 'Endo']\n", 828 | "\n", 829 | "\n", 830 | "    for ct in unique_cell_types:\n", 831 | "        \n", 832 | "        cells_per_ct_ = (ct == cell_types).sum()\n", 833 | "        if cells_per_ct_ < 10.0:\n", 834 | "            adata_ = adata_[adata_.obs['subclass_label'] != ct, :]\n", 835 | "        \n", 836 | "    adata_ = adata_.copy() \n", 837 | "    cell_types = np.array(adata_.obs['subclass_label'].tolist()) # recompute after filtering so idx1/idx2 align with adata_\n", 838 | "    for cell_type in cell_types_to_test:\n", 839 | "        print(cell_type)\n", 840 | "        idx1 = np.arange(len(cell_types))[cell_types == cell_type]\n", 841 | "        idx2 = np.arange(len(cell_types))[cell_types != cell_type]\n", 842 | "\n", 843 | "        biVI_BF_DE_[cell_type] = model1_.get_bayes_factors(adata_,idx1,idx2,\n", 844 | "                    n_samples_1 = n_samples,\n", 845 | "                    n_samples_2 = n_samples,\n", 846 | "                    n_comparisons = 10000,\n", 847 | "                    return_df = True,\n", 848 | "                    delta = delta)\n", 849 | "    all_samples_dict[samp] = biVI_BF_DE_" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": 65, 855 | "metadata": {}, 856 | "outputs": [ 857 | { 858 | "data": { 859 | "text/html": [ 860 | "
\n", 861 | "\n", 874 | "\n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | "
SampleBF_listCT_listLFC_list
0A080.040799L6 CT-0.194941
1A080.000000L5 IT-0.000106
2C010.252035L6 CT-0.239548
3C010.000000L5 IT-0.030240
4F080.027644L6 CT-0.280004
5F080.022077L5 IT-0.047185
6H120.050200L6 CT0.012408
7H120.083189L5 IT-0.434457
\n", 943 | "
" 944 | ], 945 | "text/plain": [ 946 | " Sample BF_list CT_list LFC_list\n", 947 | "0 A08 0.040799 L6 CT -0.194941\n", 948 | "1 A08 0.000000 L5 IT -0.000106\n", 949 | "2 C01 0.252035 L6 CT -0.239548\n", 950 | "3 C01 0.000000 L5 IT -0.030240\n", 951 | "4 F08 0.027644 L6 CT -0.280004\n", 952 | "5 F08 0.022077 L5 IT -0.047185\n", 953 | "6 H12 0.050200 L6 CT 0.012408\n", 954 | "7 H12 0.083189 L5 IT -0.434457" 955 | ] 956 | }, 957 | "execution_count": 65, 958 | "metadata": {}, 959 | "output_type": "execute_result" 960 | } 961 | ], 962 | "source": [ 963 | "# Bayes factor differential expression for Ndnf\n", 964 | "samp_list = []\n", 965 | "BF_list = []\n", 966 | "CT_list = []\n", 967 | "LFC_list = []\n", 968 | "\n", 969 | "param='norm_burst_size'\n", 970 | "cell_type_to_test = ['L6 CT','L5 IT']\n", 971 | "for samp in samples:\n", 972 | " for cell_type in cell_type_to_test:\n", 973 | " idx = ndnf_index[samp]\n", 974 | " CT_list.append(cell_type)\n", 975 | " samp_list.append(samp)\n", 976 | " LFC_list.append(all_samples_dict[samp][cell_type][param]['lfc_mean'].values[idx][0])\n", 977 | " BF_list.append(all_samples_dict[samp][cell_type][param]['bayes_factor'].values[idx][0])\n", 978 | " \n", 979 | "df_plot = pd.DataFrame({'Sample' : samp_list,\n", 980 | " 'BF_list' : BF_list,\n", 981 | " 'BF_list' : BF_list,\n", 982 | " 'CT_list' : CT_list,\n", 983 | " 'LFC_list' : LFC_list,})\n", 984 | "\n", 985 | "df_plot" 986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": 66, 991 | "metadata": {}, 992 | "outputs": [], 993 | "source": [ 994 | "# Bayes factor differential expression for Ndnf\n", 995 | "samp_list = []\n", 996 | "BF_list = []\n", 997 | "CT_list = []\n", 998 | "LFC_list = []\n", 999 | "\n", 1000 | "param='norm_degradation_rate'\n", 1001 | "cell_type_to_test = ['L6 CT','L5 IT']\n", 1002 | "for samp in samples:\n", 1003 | " for cell_type in cell_type_to_test:\n", 1004 | " idx = trem2_index[samp]\n", 1005 | " CT_list.append(cell_type)\n", 1006 | " samp_list.append(samp)\n", 1007 | " LFC_list.append(all_samples_dict[samp][cell_type][param]['lfc_mean'].values[idx][0])\n", 1008 | " BF_list.append(all_samples_dict[samp][cell_type][param]['bayes_factor'].values[idx][0])\n", 1009 | " \n", 1010 | "df_plot = pd.DataFrame({'Sample' : samp_list,\n", 1011 | " 'BF_list' : BF_list,\n", 1012 | " 'BF_list' : BF_list,\n", 1013 | " 'CT_list' : CT_list,\n", 1014 | " 'LFC_list' : LFC_list,})" 1015 | ] 1016 | }, 1017 | { 1018 | "cell_type": "code", 1019 | "execution_count": 67, 1020 | "metadata": {}, 1021 | "outputs": [ 1022 | { 1023 | "data": { 1024 | "text/html": [ 1025 | "
\n", 1026 | "\n", 1039 | "\n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | "
SampleBF_listCT_listLFC_list
0A080.053741L6 CT-0.289521
1A080.000000L5 IT0.000017
2C010.234111L6 CT-0.507248
3C010.022181L5 IT-0.060986
4F080.019888L6 CT-0.058045
5F080.010918L5 IT0.026180
6H120.016157L6 CT-0.207309
7H120.014713L5 IT-0.274923
\n", 1108 | "
" 1109 | ], 1110 | "text/plain": [ 1111 | " Sample BF_list CT_list LFC_list\n", 1112 | "0 A08 0.053741 L6 CT -0.289521\n", 1113 | "1 A08 0.000000 L5 IT 0.000017\n", 1114 | "2 C01 0.234111 L6 CT -0.507248\n", 1115 | "3 C01 0.022181 L5 IT -0.060986\n", 1116 | "4 F08 0.019888 L6 CT -0.058045\n", 1117 | "5 F08 0.010918 L5 IT 0.026180\n", 1118 | "6 H12 0.016157 L6 CT -0.207309\n", 1119 | "7 H12 0.014713 L5 IT -0.274923" 1120 | ] 1121 | }, 1122 | "execution_count": 67, 1123 | "metadata": {}, 1124 | "output_type": "execute_result" 1125 | } 1126 | ], 1127 | "source": [ 1128 | "df_plot" 1129 | ] 1130 | }, 1131 | { 1132 | "cell_type": "code", 1133 | "execution_count": null, 1134 | "metadata": {}, 1135 | "outputs": [], 1136 | "source": [] 1137 | } 1138 | ], 1139 | "metadata": { 1140 | "kernelspec": { 1141 | "display_name": "Python 3", 1142 | "language": "python", 1143 | "name": "python3" 1144 | }, 1145 | "language_info": { 1146 | "codemirror_mode": { 1147 | "name": "ipython", 1148 | "version": 3 1149 | }, 1150 | "file_extension": ".py", 1151 | "mimetype": "text/x-python", 1152 | "name": "python", 1153 | "nbconvert_exporter": "python", 1154 | "pygments_lexer": "ipython3", 1155 | "version": "3.8.10" 1156 | } 1157 | }, 1158 | "nbformat": 4, 1159 | "nbformat_minor": 4 1160 | } 1161 | -------------------------------------------------------------------------------- /Manuscript/analysis/__pycache__/calculate_metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pachterlab/CGCCP_2023/8e4b6c99e3bda5d664bea51b89302357538e5bd5/Manuscript/analysis/__pycache__/calculate_metrics.cpython-38.pyc -------------------------------------------------------------------------------- /Manuscript/analysis/calculate_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Nearest Neighbor calculation 5 | # 6 | # Calculate nearest neighbor cluster metric 7 | 8 | # In[1]: 9 | 10 | 11 | import numpy as np 12 | import torch 13 | import pandas as pd 14 | 15 | # nearest neighbor classifiers and pearson correlation calculators 16 | from sklearn.neighbors import KNeighborsClassifier 17 | from scipy import stats #function: stats.pearsonr(x,y) 18 | 19 | 20 | # In[56]: 21 | def calc_MSE_1D(x,y): 22 | '''Calculate the MSE between x and y. 23 | params 24 | ------- 25 | x : (Z) 26 | y : (Z) 27 | 28 | returns 29 | ------- 30 | MSE : 1''' 31 | 32 | MSE_ = (x-y)**2 33 | MSE = np.sum(MSE_)/len(x) 34 | return(MSE) 35 | 36 | def l1(z1,z2): 37 | ''' Calculates l1 distance between two vectors: absolute value of difference between components. 38 | ''' 39 | abs_dist = np.absolute(z1-z2) 40 | return( np.sum(abs_dist,axis=1) ) 41 | 42 | def l2(z1,z2): 43 | ''' Calculates l2 distance between two vectors: square root of difference squared. 44 | ''' 45 | 46 | dist = np.sqrt( np.sum((z1-z2)**2,axis=1) ) 47 | 48 | return( dist ) 49 | 50 | 51 | def squared_dist(z1,z2): 52 | ''' Calculates l2 distance between two vectors: square root of difference squared. 53 | ''' 54 | 55 | dist = np.sum((z1-z2)**2,axis=1) 56 | 57 | return( dist ) 58 | 59 | 60 | # def get_nearest_neighbors_percentages(z_array, cluster_memberships, top = 50, distance = 'l2'): 61 | # ''' Calculate the percent of nearest neighbors in the same cell type or cluster. 
62 63 | # Parameters 64 | # ---------- 65 | # z_array : numpy array of z_vectors, shape (num_cells, latent_dim) 66 | # cluster_memberships : list or array listing cluster or cell type memberships for cells in z_array, length (num_cells) 67 | # top : how many nearest neighbors to calculate percentage for 68 | # distance : what distance metric to use to define nearest neighbors, options ['l1','l2','squared_dist'] 69 | 70 | # Returns 71 | # ---------- 72 | # percentages : array of percentages of nearest neighbors in same celltype/cluster for each cell, length (num_cells) 73 | # ''' 74 | 75 | 76 | # # set up array to store percentages 77 | # percentages = np.zeros(len(z_array)) 78 | 79 | # if distance == 'l2': 80 | # dist_func = l2 81 | # if distance == 'l1': 82 | # dist_func = l1 83 | # if distance == 'squared_dist': 84 | # dist_func = squared_dist 85 | 86 | 87 | # for i,z_i in enumerate(z_array): 88 | # z_i = z_array[i,:] 89 | 90 | # z_i_array = np.repeat(z_i.reshape(1,-1), len(z_array), axis = 0) 91 | 92 | # dist_array = dist_func(z_i_array,z_array) 93 | 94 | # # will give indices of top nearest neighbors for z_i -- note, will include z_i itself so add 1 95 | # idx = np.argpartition(dist_array, (top+1) )[:(top+1)] 96 | 97 | # clusters = np.take(cluster_memberships,idx) 98 | 99 | # cluster_z_i = cluster_memberships[i] 100 | 101 | # same_cluster = clusters[clusters == cluster_z_i] 102 | 103 | # percent_same = ( len(same_cluster) - 1 ) / (top) # make sure to remove the cluster for z_i itself 104 | 105 | # percentages[i] = percent_same 106 | 107 | 108 | # return(percentages) 109 | 110 | def nn_percentages(x,cluster_assignments): 111 | ''' Calculate the percentage of nearest neighbors in the same cluster. 112 | 113 | params 114 | ------ 115 | x : (N,Z) N cells, Z latent space 116 | cluster_assignments : cluster assignments for vectors in x 117 | 118 | returns 119 | ------- 120 | nn_percent_array : (N) fraction of nearest neighbors in the same cluster for each vector of x 121 | ''' 122 | 123 | cluster_assignments = np.array(cluster_assignments) 124 | unique_clusters = np.unique(cluster_assignments) 125 | 126 | 127 | nn_percent_array = np.ones(x.shape[0]) 128 | 129 | for cluster in unique_clusters: 130 | cluster_assignments_ = cluster_assignments[cluster_assignments == cluster] 131 | x_ = x[cluster_assignments == cluster] 132 | 133 | # number of cells in this cluster 134 | N_ = len(cluster_assignments_) 135 | 136 | # set up nearest neighbor class 137 | neigh = KNeighborsClassifier(n_neighbors=N_) 138 | 139 | # fit model 140 | neigh.fit(x,cluster_assignments) 141 | 142 | # calculate nearest neighbor distance and indices to top N_ neighbors for all vectors in x 143 | # returns array neigh_ind of shape x by N_ 144 | neigh_ind = neigh.kneighbors(x_,return_distance = False) 145 | 146 | nn_percent_cluster_ = np.array([len(cluster_assignments[n][cluster_assignments[n] 147 | == cluster_assignments_[i]])-1 for i,n in enumerate(neigh_ind)])/(N_-1) 148 | 149 | nn_percent_array[cluster_assignments == cluster] = nn_percent_cluster_ # write back at the original cell positions (cells of a cluster need not be contiguous) 150 | 151 | 152 | return nn_percent_array 153 | 154 | 155 | def get_metrics(name,results_dict,simulated_params,cluster_assignments,adata): 156 | ''' Given results_dict from model training, returns MSE between simulated/recon means, Pearson correlation between simulated/recon means, 157 | and percentage of N nearest neighbors in the same cluster assignment for all cells.
158 | 159 | 160 | params 161 | ------ 162 | name: name of data 163 | simulated_params: if you pass simulated params, MSE and Pearson R are calculated between simulated means 164 | and reconstructed means 165 | rather than between observed counts and reconstructed means 166 | results_dict: containing keys for each setup: 167 | ['X_{z}','runtime','df_history','params','recon_error','cell_type'] 168 | 169 | 170 | outputs 171 | ------- 172 | metric_dict containing keys: 173 | ['MSE', 'Pearson_R', 'nearest_neighbors'] 174 | ''' 175 | 176 | # set up dictionary to store things in with the training setups as keys 177 | 178 | 179 | setups = list(results_dict.keys()) 180 | metric_dict = { setup : {} for setup in setups} 181 | z = list(results_dict[setups[0]].keys())[0][2:] 182 | print(z) 183 | 184 | 185 | # get observed counts 186 | obs_means = adata[:,:].layers['counts'].toarray() 187 | 188 | for setup in setups: 189 | print(setup) 190 | 191 | setup_dict = results_dict[setup] 192 | 193 | setup_metric_dict = {} 194 | 195 | # unpack dictionary 196 | X_z = setup_dict[f'X_{z}'] 197 | recon_means = setup_dict['params']['mean'] 198 | print(recon_means.shape) 199 | 200 | if simulated_params is not None: 201 | if 'const' in name: 202 | obs_means_U = 10**simulated_params[:,:,0] 203 | obs_means_S = 10**simulated_params[:,:,1] 204 | obs_means = np.concatenate((obs_means_U,obs_means_S),axis=1) 205 | if 'bursty' in name: 206 | params = 10**simulated_params 207 | b,beta,gamma = params[:,:,0],params[:,:,1],params[:,:,2] 208 | obs_means_U = b/beta 209 | obs_means_S = b/gamma 210 | obs_means = np.concatenate((obs_means_U,obs_means_S),axis=1) 211 | if 'BVNB' in name: 212 | alpha = simulated_params[:,:,0] 213 | beta = 10**simulated_params[:,:,1] 214 | gamma = 10**simulated_params[:,:,2] 215 | obs_means_U = alpha/beta 216 | obs_means_S = alpha/gamma 217 | obs_means = np.concatenate((obs_means_U,obs_means_S),axis=1) 218 | 219 | if simulated_params is None: 220 | 221 | setup_metric_dict['MSE'] = np.array([ calc_MSE_1D(recon_means[i],obs_means[i]) for i in range(len(X_z)) ]) 222 | setup_metric_dict['Pearson_R'] = np.array([ stats.pearsonr(recon_means[i], obs_means[i])[0] for i in range(len(X_z)) ]) 223 | 224 | elif simulated_params is not None: 225 | setup_metric_dict['MSE'] = np.array([ calc_MSE_1D(recon_means[i],obs_means[cluster_assignments[i]]) for i in range(len(X_z)) ]) 226 | setup_metric_dict['Pearson_R'] = np.array([ stats.pearsonr(recon_means[i], obs_means[cluster_assignments[i]])[0] for i in range(len(X_z)) ]) 227 | 228 | setup_metric_dict['nearest_neighbors'] = nn_percentages(X_z,cluster_assignments) 229 | 230 | metric_dict[setup] = setup_metric_dict 231 | 232 | return(metric_dict) 233 | 234 | 235 | 236 | 237 | def get_metrics_old(name,results_dict,adata,index,N=100): 238 | ''' Given results_dict from model training, returns MSE between simulated/recon means, Pearson correlation between simulated/recon means, 239 | and percentage of N nearest neighbors in the same cluster assignment for all cells.
240 | 241 | 242 | params 243 | ------ 244 | results_dict containing keys: 245 | ['X_{z}','runtime','df_history','params','recon_error','cell_type'] 246 | 247 | outputs 248 | ------- 249 | metric_dict containing keys: 250 | ['MSE_S', 'MSE_U', 'Pearson_R_S', 'Pearson_R_U', 'nearest_neighbors'] 251 | ''' 252 | 253 | # set up dictionary to store things in with the training setups as keys 254 | 255 | setups = list(results_dict.keys()) 256 | metric_dict = { setup : {} for setup in setups} 257 | z = list(results_dict[setups[0]].keys())[0][2:] 258 | print(z) 259 | 260 | 261 | # get observed counts 262 | obs_means_U = adata[:,adata.var['Spliced']==0].layers['counts'].toarray() 263 | obs_means_S = adata[:,adata.var['Spliced']==1].layers['counts'].toarray() 264 | obs_means = adata[:,:].layers['counts'].toarray() 265 | 266 | 267 | for setup in setups: 268 | print(setup) 269 | 270 | setup_dict = results_dict[setup] 271 | 272 | setup_metric_dict = {} 273 | 274 | # unpack dictionary 275 | X_z = setup_dict[f'X_{z}'] 276 | 277 | if '.U' in setup: 278 | recon_means_U = setup_dict['params']['mean'][:,:] 279 | setup_metric_dict['MSE_U'] = np.array([ calc_MSE_1D(recon_means_U[i],obs_means_U[i]) for i in range(len(X_z)) ]) 280 | setup_metric_dict['Pearson_R_U'] = np.array([ stats.pearsonr(recon_means_U[i], obs_means_U[i])[0] for i in range(len(X_z)) ]) 281 | 282 | elif '.S' in setup: 283 | recon_means_S = setup_dict['params']['mean'][:,:] 284 | setup_metric_dict['MSE_S'] = np.array([ calc_MSE_1D(recon_means_S[i],obs_means_S[i]) for i in range(len(X_z)) ]) 285 | setup_metric_dict['Pearson_R_S'] = np.array([ stats.pearsonr(recon_means_S[i], obs_means_S[i])[0] for i in range(len(X_z)) ]) 286 | 287 | else: 288 | recon_means_U = setup_dict['params']['mean'][:,:int(setup_dict['params']['mean'].shape[1]/2)] 289 | recon_means_S = setup_dict['params']['mean'][:,int(setup_dict['params']['mean'].shape[1]/2):] 290 | setup_metric_dict['MSE_U'] = np.array([ calc_MSE_1D(recon_means_U[i], obs_means_U[i]) for i in range(len(X_z)) ]) 291 | setup_metric_dict['Pearson_R_U'] = np.array([ stats.pearsonr(recon_means_U[i], obs_means_U[i])[0] for i in range(len(X_z)) ]) 292 | setup_metric_dict['MSE_S'] = np.array([ calc_MSE_1D(recon_means_S[i], obs_means_S[i]) for i in range(len(X_z)) ]) 293 | setup_metric_dict['Pearson_R_S'] = np.array([ stats.pearsonr(recon_means_S[i], obs_means_S[i])[0] for i in range(len(X_z)) ]) 294 | 295 | 296 | # if (('.P' not in setup) and ('const' not in name)): 297 | # recon_disp = setup_dict['params']['dispersions'] 298 | # setup_metric_dict['alpha correlation'] = stats.pearsonr(sim_disp[0],recon_disp[0,:2000])[0] 299 | setup_metric_dict['nearest_neighbors'] = nn_percentages(X_z,cluster_assignments) # (legacy function; relies on cluster_assignments from the calling scope) 300 | 301 | metric_dict[setup] = setup_metric_dict 302 | 303 | return(metric_dict) 304 | 305 | -------------------------------------------------------------------------------- /Manuscript/analysis/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Preprocess Allen Data 5 | 6 | 7 | # argument parser 8 | import argparse 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--name', type=str) 12 | parser.add_argument('--data_dir', type=str, default = '../data/allen/') 13 | args = parser.parse_args() 14 | 15 | name = args.name 16 | data_dir = args.data_dir 17 | 18 | 19 | # system 20 | import os, sys 21 | 22 | # numbers 23 | import numpy as np 24 | 25 | import pandas as pd 26 | 27 | #sc 28 
| import anndata 29 | import scanpy as sc 30 | 31 | # Plots 32 | import matplotlib 33 | import matplotlib.pyplot as plt 34 | import seaborn as sns 35 | 36 | 37 | 38 | # load in raw loom file: 39 | adata = sc.read_loom(data_dir+f'allen_{name}_raw.loom') 40 | 41 | 42 | 43 | # In[348]: 44 | 45 | 46 | # load in metadata 47 | allen_membership = pd.read_csv(data_dir+'/cluster.membership.csv',skiprows = 1, names=['barcode','cluster_id']) 48 | allen_annot = pd.read_csv(data_dir+'/cluster.annotation.csv') 49 | allen_membership['cell_barcode'] = allen_membership['barcode'].str[:16] 50 | allen_membership['sample'] = allen_membership['barcode'].str[-3:] 51 | allen_membership['cluster_id'] = allen_membership['cluster_id'].astype("category") 52 | allen_annot.set_index('cluster_id',inplace=True) 53 | allen_annot_bc = allen_annot.loc[allen_membership['cluster_id']][['cluster_label','subclass_label','class_label']].set_index(allen_membership.index) 54 | meta = pd.concat((allen_membership,allen_annot_bc),axis=1) 55 | 56 | # choose the sample to work on 57 | meta_name = meta[meta['sample'] == name] 58 | 59 | 60 | # In[349]: 61 | 62 | 63 | # subset for cells observed in metadata -- remove all others 64 | index = [adata.obs['barcode'][i] in np.array(meta_name['cell_barcode']) for i in range(len(adata))] 65 | 66 | adata_A = adata[index,:] 67 | 68 | 69 | # In[350]: 70 | 71 | 72 | S = adata_A.layers['spliced'][:] 73 | U = adata_A.layers['unspliced'][:] 74 | S_old = adata.layers['spliced'][:] 75 | U_old = adata.layers['unspliced'][:] 76 | 77 | 78 | # In[351]: 79 | 80 | 81 | def knee_plot(S): 82 | UMI_sorted = np.sort(np.array(S.sum(1)).flatten()) 83 | x_range = range(len(UMI_sorted))[::-1] 84 | 85 | plt.scatter(x_range,UMI_sorted,c='k',s=5) 86 | plt.yscale('log') 87 | plt.xscale('log') 88 | plt.xlabel('cell rank') 89 | plt.ylabel('# UMI') 90 | plt.hlines(10**4,xmin=0,xmax= len(x_range)+1000,colors='red',linestyles='dashed',label='10^4') 91 | plt.vlines(10**4,ymin=0,ymax= 10**5,colors='red',linestyles='dashed') 92 | plt.grid() 93 | plt.legend() 94 | plt.title('Cell Rank vs. UMI ') 95 | 96 | 97 | # In[ ]: 98 | 99 | 100 | # visualize knee plot, use to filter data 101 | # knee_plot(S_old+U_old) 102 | 103 | 104 | # In[ ]: 105 | 106 | 107 | cluster_ids = [] 108 | cluster_labels = [] 109 | subclass_labels = [] 110 | class_labels = [] 111 | 112 | for i in range(len(adata_A)): 113 | 114 | barcode = adata_A.obs['barcode'][i] 115 | 116 | index = np.where(np.array(meta_name['cell_barcode']) == barcode)[0][0] 117 | cluster_id = meta_name['cluster_id'].to_list()[index] 118 | cluster_label = meta_name['cluster_label'].to_list()[index] 119 | subclass_label = meta_name['subclass_label'].to_list()[index] 120 | class_label = meta_name['class_label'].to_list()[index] 121 | 122 | cluster_ids.append(cluster_id) 123 | cluster_labels.append(cluster_label) 124 | subclass_labels.append(subclass_label) 125 | class_labels.append(class_label) 126 | 127 | 128 | # In[ ]: 129 | 130 | 131 | adata_A.obs['cluster_id'] = cluster_ids 132 | adata_A.obs['cluster_label'] = cluster_labels 133 | adata_A.obs['subclass_label'] = subclass_labels 134 | adata_A.obs['class_label'] = class_labels 135 | adata_A.obs['Cell Type'] = subclass_labels 136 | 137 | 138 | # Remove low quality cells 139 | adata_A = adata_A[adata_A.obs['Cell Type'] != 'Low Quality',:] 140 | 141 | # Also remove doublet cells 142 | adata_A = adata_A[adata_A.obs['Cell Type'] != 'doublet',:] 143 | 144 | 145 | # Now, find highly variable genes.
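# NB: sc.pp.normalize_total and sc.pp.log1p below modify adata_A.X only for highly-variable-gene selection; the raw spliced/unspliced layers are re-assembled further down and stored as a "counts" layer for the count-based models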
146 | # normalize, log1p, then select highly variable genes :) 147 | 148 | sc.pp.normalize_total(adata_A, target_sum=1e4) 149 | sc.pp.log1p(adata_A) 150 | sc.pp.highly_variable_genes(adata_A, n_top_genes=2000, min_mean=0.0125, max_mean=3, min_disp=0.5) 151 | 152 | # Subset to highly variable genes 153 | adata_s = adata_A[:, adata_A.var.highly_variable] 154 | 155 | 156 | # In[ ]: 157 | 158 | 159 | adata_old = adata_s 160 | adata_spliced = anndata.AnnData(adata_A.layers['spliced']) 161 | adata_unspliced = anndata.AnnData(adata_A.layers['unspliced']) 162 | 163 | adata_spliced.var = adata_A.var.copy() 164 | adata_unspliced.var = adata_A.var.copy() 165 | adata_spliced.var['Spliced'] = True 166 | adata_unspliced.var['Spliced'] = False 167 | adata_unspliced.var_names = adata_unspliced.var_names + '-u' 168 | 169 | adata = anndata.concat([adata_unspliced,adata_spliced],axis=1) 170 | ## Change AnnData expression to raw counts for negative binomial distribution 171 | adata.layers["counts"] = adata.X.copy() # preserve counts 172 | 173 | # Update obs,var 174 | adata.obs = adata_old.obs.copy() 175 | 176 | 177 | # In[ ]: 178 | 179 | 180 | adata.write_loom(f'../data/allen/{name}_processed.loom') 181 | 182 | 183 | # In[ ]: 184 | 185 | 186 | adata_hv = adata[:, adata.var.highly_variable] 187 | adata_hv.write_loom(f'../data/allen/{name}_processed_hv.loom') 188 | -------------------------------------------------------------------------------- /Manuscript/analysis/preprocess.sh: -------------------------------------------------------------------------------- 1 | names=("A08" "B08" "B01" "C01" "F08" "H12" "A02") 2 | 3 | for name in "${names[@]}" 4 | do 5 | # logdir="out/${dataset}/data" 6 | # loomfile="data/loom_10x_kb/allen_${name}_raw.loom" 7 | 8 | python3 preprocess.py --data_dir "../../data/allen/" \ 9 | --name ${name} 10 | 11 | # python run_scBIVI.py --datadir "${logdir}/preprocessed.h5ad" \ 12 | # --percent_keep "1" \ 13 | # --cluster_method 'RNA_leiden' 14 | 15 | done 16 | -------------------------------------------------------------------------------- /Manuscript/analysis/requirements.txt: -------------------------------------------------------------------------------- 1 | scanpy 2 | scvi-tools==1.2.2 3 | loompy 4 | leidenalg 5 | anndata 6 | -------------------------------------------------------------------------------- /Manuscript/analysis/train_biVI.py: -------------------------------------------------------------------------------- 1 | # # Train scBIVI 2 | # 3 | # This script trains and stores results for different models 4 | 5 | 6 | # argument parser 7 | import argparse 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--name', type=str) 11 | parser.add_argument('--data_dir', type=str, default = '../data/simulated_data/') 12 | args = parser.parse_args() 13 | 14 | name = args.name 15 | data_dir = args.data_dir 16 | 17 | 18 | 19 | # System 20 | import time, gc 21 | 22 | # add module paths to sys path 23 | import sys 24 | sys.path.insert(0, '../../BIVI/BIVI') 25 | 26 | # Math 27 | import numpy as np 28 | import pandas as pd 29 | import torch 30 | from sklearn.model_selection import StratifiedKFold 31 | 32 | # to save results 33 | import pickle 34 | 35 | # scvi 36 | import anndata 37 | import scvi 38 | 39 | 40 | 41 | # import biVI scripts 42 | import biVI 43 | 44 | # memory usage 45 | 46 | from memory_profiler import profile 47 | 48 | # reproducibility -- set random seeds 49 | scvi.settings.seed = 8675309 50 | torch.manual_seed(8675309) 51 | np.random.seed(8675309)
52 | 53 | 54 | # first, clear out cuda 55 | torch.cuda.empty_cache() 56 | gc.collect() 57 | # # Load in data 58 | # 59 | # 60 | # Change data name to test out different simulated datasets with varying number of celltypes. 61 | 62 | 63 | # ================================================================================================================== 64 | 65 | # change to hdf5 file if that is what you store data as 66 | adata = anndata.read_loom(data_dir+f'{name}.loom') 67 | 68 | if 'gene_name' in adata.var.columns: 69 | adata.var_names = adata.var['gene_name'].to_list() 70 | 71 | # can change as necessary for data. 72 | adata.obs['Cluster'] = adata.obs['Cell Type'] 73 | adata.var_names_make_unique() 74 | 75 | 76 | #Set up train/test data splits with 5-fold split 77 | skf = StratifiedKFold(n_splits=5, random_state=None, shuffle=False) 78 | skf_splits = skf.split(adata, adata.obs['Cluster']) 79 | 80 | # Use last of the K-fold splits 81 | for k, (train_index, test_index) in enumerate(skf_splits): 82 | pass 83 | 84 | 85 | 86 | print(f'training on {len(train_index)} cells, testing on {len(test_index)} cells') 87 | 88 | 89 | # ================================================================================================================== 90 | 91 | # # Define training function 92 | 93 | 94 | # if anything goes wrong in training, this will catch where it happens 95 | torch.autograd.set_detect_anomaly(True) 96 | 97 | 98 | # compare setups 99 | def compare_setups(adata, setups, results_dict, hyperparameters, train_index = train_index, test_index = test_index): 100 | ''' Runs scBIVI on adata for listed setups in setups given hyperparameters, stores outputs in results_dict. 101 | Train index and test index are defined globally -- could be nice to pass these in as well? 
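    Example (hypothetical call; setup strings follow "<method>-<n_latent>-<constant>", matching the setups list defined at the bottom of this script): results_dict, adata = compare_setups(adata, ['Bursty-10-NAS_SHAPE'], results_dict, hyperparameters)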
102 |     '''
103 | 
104 |     lr = hyperparameters['lr']
105 |     max_epochs = hyperparameters['max_epochs']
106 |     n_hidden = hyperparameters['n_hidden']
107 |     n_layers = hyperparameters['n_layers']
108 | 
109 | 
110 |     for setup in setups:
111 |         print(setup, 'with non-linear decoder')
112 |         method, n_latent, constant = setup.split("-")  # e.g. 'Bursty-10-NAS_SHAPE'
113 |         n_latent = int(n_latent)
114 | 
115 |         # test using only spliced or unspliced counts in vanilla scVI
116 |         if '.S' in method:
117 |             adata_in = adata[:, adata.var['Spliced'] == 1]
118 |             print('spliced')
119 |         elif '.U' in method:
120 |             adata_in = adata[:, adata.var['Spliced'] == 0]
121 |             print('unspliced')
122 |         else:
123 |             adata_in = adata
124 | 
125 |         print(adata_in.X.shape)
126 |         #biVI.biVI.setup_anndata(adata_in,layer="counts")
127 |         #categorical_covariate_keys=["cell_source", "donor"],
128 |         #continuous_covariate_keys=["percent_mito", "percent_ribo"])
129 | 
130 | 
131 |         train_adata, test_adata = adata_in[train_index], adata_in[test_index]
132 |         train_adata = train_adata.copy()
133 |         test_adata = test_adata.copy()
134 |         if 'scVI' in method:
135 |             scvi.model.SCVI.setup_anndata(test_adata, layer="counts")
136 |             scvi.model.SCVI.setup_anndata(train_adata, layer="counts")
137 |         else:
138 |             biVI.biVI.setup_anndata(test_adata, layer="counts")
139 |             biVI.biVI.setup_anndata(train_adata, layer="counts")
140 | 
141 | 
142 |         ## Set model parameters
143 |         model_args = {
144 |             'n_latent'           : n_latent,
145 |             'n_layers'           : n_layers,
146 |             'dispersion'         : 'gene',
147 |             'n_hidden'           : n_hidden,
148 |             'dropout_rate'       : 0.1,
149 |             'gene_likelihood'    : 'nb',
150 |             'log_variational'    : True,
151 |             'latent_distribution': 'normal',
152 |         }
153 |         #model_args.update(additional_kwargs)
154 | 
155 |         ## Create model
156 |         if method == 'Extrinsic':
157 |             model = biVI.biVI(train_adata, mode='NBcorr', **model_args)
158 |         elif method == 'NBuncorr':
159 |             model = biVI.biVI(train_adata, mode='NBuncorr', **model_args)
160 |         elif method == 'Constitutive':
161 |             model = biVI.biVI(train_adata, mode='Poisson', **model_args)
162 |         elif method == 'Bursty':
163 |             model = biVI.biVI(train_adata, mode='Bursty', **model_args)
164 |         elif method == 'vanilla.U':
165 |             model_args['gene_likelihood'] = 'nb'
166 |             model = scvi.model.SCVI(train_adata, **model_args)
167 |         elif method == 'vanilla.S':
168 |             model_args['gene_likelihood'] = 'nb'
169 |             model = scvi.model.SCVI(train_adata, **model_args)
170 |         elif method == 'scVI':
171 |             model_args['gene_likelihood'] = 'nb'
172 |             model = scvi.model.SCVI(train_adata, **model_args)
173 |         elif method == 'vanilla.U.P':
174 |             model_args['gene_likelihood'] = 'poisson'
175 |             model = scvi.model.SCVI(train_adata, **model_args)
176 |         elif method == 'vanilla.S.P':
177 |             model_args['gene_likelihood'] = 'poisson'
178 |             model = scvi.model.SCVI(train_adata, **model_args)
179 |         elif method == 'vanilla.full.P':
180 |             model_args['gene_likelihood'] = 'poisson'
181 |             model = scvi.model.SCVI(train_adata, **model_args)
182 |         else:
183 |             raise ValueError(f'Unrecognized method {method}; pass a valid biVI/scVI setup')
184 | 
185 |         ## Train model
186 |         plan_kwargs = {'lr'                 : lr,
187 |                        'n_epochs_kl_warmup' : max_epochs // 2,
188 |                        }
189 | 
190 |         start = time.time()
191 |         model.train(max_epochs = max_epochs,
192 |                     #early_stopping_monitor = ["reconstruction_loss_validation"],
193 |                     train_size = 0.9,
194 |                     check_val_every_n_epoch = 1,
195 |                     plan_kwargs = plan_kwargs)
196 | 
197 | 
198 |         runtime = time.time() - start
199 |         memory_used = torch.cuda.memory_allocated()  # snapshot right after training
200 |         results_dict[setup]['runtime'].append(runtime)
201 | 
202 |         ## Save training history
203 |         df_history = {'reconstruction_error_train_set': [model.history['reconstruction_loss_train']],
204 |                       'reconstruction_error_test_set' : [model.history['reconstruction_loss_validation']]}
205 | 
206 | 
207 |         results_dict[setup]['df_history'] = df_history
208 | 
209 |         ## Get reconstruction loss on train and test data
210 |         test_error = model.get_reconstruction_error(test_adata)
211 |         train_error = model.get_reconstruction_error(train_adata)
212 |         results_dict[setup]['recon_error'].append(np.array([train_error, test_error]))
213 | 
214 |         # get inferred likelihood parameters and normalized expression
215 |         results_dict[setup]['params'] = model.get_likelihood_parameters(adata_in)
216 |         results_dict[setup]['norm_params'] = model.get_normalized_expression(adata_in)
217 | 
218 |         ## Extract the latent embedding
219 |         X_out_full = model.get_latent_representation(adata_in)
220 | 
221 |         adata.obsm[f'X_{method}'] = X_out_full
222 |         results_dict[setup][f'X_{n_latent}'] = X_out_full
223 | 
224 |         results_dict[setup]['memory_used'] = torch.cuda.memory_allocated()
225 |         # save model for future testing
226 | 
227 |         print('save path', f'../../results/{method}_{name}_MODEL')
228 |         if 'Bursty' in method:
229 |             model.save(f'../../results/{method}_{name}_MODEL', overwrite=True)
230 |             print('model saved')
231 | 
232 |         del model
233 |         torch.cuda.empty_cache()
234 |         gc.collect()
235 | 
236 | 
237 |     return results_dict, adata
238 | 
239 | 
240 | # ==============================================================================================================
241 | # # Compare Distributions
242 | 
243 | # Various training hyperparameters can be adjusted here.
244 | 
245 | print('Training non-linear models')
246 | 
247 | # Hyper-parameters
248 | hyperparameters = { 'lr'         : 1e-5,
249 |                     'max_epochs' : 400,
250 |                     'n_hidden'   : 128,
251 |                     'n_layers'   : 3 }
252 | 
253 | z = 10
254 | constant = 'NAS_SHAPE'
255 | 
256 | setups = [
257 |     f'scVI-{z}-{constant}',
258 |     f'Bursty-{z}-{constant}',
259 |     f'Constitutive-{z}-{constant}',
260 |     f'Extrinsic-{z}-{constant}'
261 | ]
262 | 
263 | metrics_list = [f'X_{z}', 'runtime', 'df_history', 'params', 'recon_error', 'norm_params']
264 | results_dict = {setup: {metric: [] for metric in metrics_list} for setup in setups}
265 | 
266 | if __name__ == "__main__":
267 |     results_dict, adata = compare_setups(adata, setups, results_dict, hyperparameters, train_index = train_index, test_index = test_index)
268 | 
269 | 
270 | # results_dict, adata = compare_setups(adata, setups,results_dict,hyperparameters)
271 | results_dict['Cell Type'] = adata.obs['Cell Type']
272 | results_dict['train_index'] = train_index
273 | results_dict['test_index'] = test_index
274 | 
275 | # # Save results dict
276 | 
277 | with open(f"../results/{name}_results_dict_092523.pickle", "wb") as results_file:
278 |     pickle.dump(results_dict, results_file)
279 | 
280 | 
281 | 
--------------------------------------------------------------------------------
/Manuscript/analysis/train_biVI.sh:
--------------------------------------------------------------------------------
 1 | # pip install scanpy -q
 2 | # pip install scvi-tools==0.8.1 -q
 3 | # pip install loompy -q
 4 | # pip install leidenalg -q
 5 | # pip install --upgrade torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html -q
 6 | 
 7 | # allen data
 8 | python3 train_biVI.py --name B08_processed_hv --data_dir ../data/allen/
 9 | 
10 | # simulated data
11 | python3 train_biVI.py --name 'bursty_20ct_many' --data_dir ../data/simulated_data/
12 | python3 train_biVI.py --name 'const_20ct_many' --data_dir ../data/simulated_data/
13 | python3 train_biVI.py --name 'extrinsic_20ct_many' --data_dir ../data/simulated_data/
14 | 
15 | python3 train_biVI.py --name A08_processed_hv --data_dir ../data/allen/
16 | python3 train_biVI.py --name B01_processed_hv --data_dir ../data/allen/
17 | python3 train_biVI.py --name C01_processed_hv --data_dir ../data/allen/
18 | python3 train_biVI.py --name F08_processed_hv --data_dir ../data/allen/
19 | python3 train_biVI.py --name H12_processed_hv --data_dir ../data/allen/
20 | 
21 | 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # CGCCP_2023
 2 | This repository contains the scripts and notebooks for the preprint "Biophysical modeling with variational autoencoders for bimodal, single-cell RNA sequencing data". Although variational autoencoders can be treated as purely phenomenological and descriptive, making no explicit claims about the processes that generated the data, it is possible to exploit the *implicit* physics encoded in their mathematical formulation to model the biophysical processes that gave rise to the observations. By interpreting the _scVI_ generative model as a description of a particular biophysical model, we can represent bivariate RNA distributions. We benchmark the implementation, _biVI_, on simulated and biological data.
 3 | 
 4 | `BIVI/` contains all of the scripts used to implement _biVI_, while `Manuscript/analysis/` contains all of the notebooks and scripts used to generate the manuscript figures and results. `Example/` contains `kb_pipeline.sh`, a script that demonstrates how to align raw reads to a reference genome to obtain the unspliced/spliced count matrices required by _biVI_, and `Demo.ipynb`, a Google Colaboratory notebook that processes the output matrices, trains a _biVI_ model, and visualizes the results.
 5 | 
 6 | 
 7 | 
 8 | 
 9 | 
10 | The _biVI_ software can be installed as a standalone package using the following command:
11 | 
12 | 
13 | pip3 install git+https://github.com/pachterlab/CGCCP_2023.git#subdirectory=BIVI
14 | 
15 | 
16 | 
17 | If package dependencies cause installation issues, create a clean Conda environment, activate it, and rerun the installation:
18 | 
19 | 
20 | conda create --name biVI_env python==3.9
21 | conda activate biVI_env
22 | pip3 install git+https://github.com/pachterlab/CGCCP_2023.git#subdirectory=BIVI
23 | 
24 | 
25 | Installation takes one to several minutes on a standard laptop. Alternatively, _biVI_ can be run in a Google Colab notebook, an example of which is given in `Example/Demo.ipynb`.
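26 | 
27 | As a quick orientation, the minimal sketch below shows how a _biVI_ model is typically trained and queried; it is distilled from `Manuscript/analysis/train_biVI.py`, not a definitive API reference. The loom path is a placeholder, the hyperparameter values are illustrative rather than tuned, and the `mode` options listed are the ones exercised in that script.
28 | 
29 | ```python
30 | # Minimal biVI usage sketch -- assumes a loom file with concatenated
31 | # unspliced/spliced counts, as produced by Manuscript/analysis/preprocess.py.
32 | import anndata
33 | import biVI
34 | 
35 | adata = anndata.read_loom('B08_processed_hv.loom')  # placeholder path
36 | adata.layers['counts'] = adata.X.copy()             # raw counts for the count likelihood
37 | 
38 | biVI.biVI.setup_anndata(adata, layer='counts')      # register the data, as in train_biVI.py
39 | model = biVI.biVI(adata,
40 |                   mode='Bursty',                    # or 'NBcorr', 'NBuncorr', 'Poisson'
41 |                   n_latent=10,
42 |                   n_layers=3,
43 |                   n_hidden=128,
44 |                   gene_likelihood='nb',
45 |                   log_variational=True)
46 | model.train(max_epochs=400, train_size=0.9)
47 | 
48 | latent = model.get_latent_representation(adata)     # low-dimensional cell embeddings
49 | params = model.get_likelihood_parameters(adata)     # inferred distribution parameters
50 | ```
--------------------------------------------------------------------------------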