├── .gitignore ├── LICENSE ├── README.md ├── demos ├── manuscript_analysis │ ├── 10x_pbmc_demo.ipynb │ ├── lymph_node_demo.ipynb │ ├── paired_cellline_demo.ipynb │ ├── sciCAR_cellline_demo.ipynb │ ├── share_skin_demo.ipynb │ ├── snare_cellline_demo.ipynb │ ├── snare_p0_demo.ipynb │ └── snare_p0_no_pretrain.ipynb └── scMVP_tutorial.ipynb ├── scMVP ├── __init__.py ├── _settings.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── scMVP_dataloader.py ├── inference │ ├── __init__.py │ ├── annotation.py │ ├── autotune.py │ ├── inference.py │ ├── multi_inference.py │ ├── posterior.py │ └── trainer.py └── models │ ├── __init__.py │ ├── classifier.py │ ├── log_likelihood.py │ ├── modules.py │ ├── multi_vae_attention.py │ ├── utils.py │ ├── vaePeak_selfattetion.py │ └── vae_attention.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | .idea/ 3 | venv/ 4 | data/ 5 | *.pkl 6 | *.csv 7 | *.tsv 8 | *.mtx 9 | *.xls 10 | demos/dataset/*.txt 11 | demos/dataset/*/*.txt 12 | *.bed 13 | analysis_script 14 | requirement.txt 15 | docs/ 16 | *.gz 17 | nohup.out 18 | 19 | demos/.ipynb_checkpoints/ 20 | spatstat_1.64-1.tar.gz 21 | demos/snare_p0_old_model.ipynb 22 | demos/demo.ipynb 23 | demos/demo_share_format.py 24 | 25 | # demo folders 26 | demos/share_skin/ 27 | demos/share_skin/ 28 | demos/10x_pbmc/ 29 | demos/paired_cellline/ 30 | demos/snare_cellline/ 31 | demos/sciCAR_cellline/ 32 | demos/snare_p0/ 33 | demos/paired_adult.ipynb 34 | 35 | # Byte-compiled / optimized / DLL files 36 | __pycache__/ 37 | *.py[cod] 38 | *$py.class 39 | # necessary files 40 | # C extensions 41 | *.so 42 | 43 | # DS_Store 44 | .DS_Store 45 | 46 | # Distribution / packaging 47 | .Python 48 | env/ 49 | build/ 50 | develop-eggs/ 51 | dist/ 52 | downloads/ 53 | eggs/ 54 | .eggs/ 55 | lib/ 56 | lib64/ 57 | parts/ 58 | sdist/ 59 | var/ 60 | wheels/ 61 | *.egg-info/ 62 | .installed.cfg 63 | *.egg 64 | 65 | # PyInstaller 66 | # Usually these files are written by a python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .coverage 79 | .coverage.* 80 | .cache 81 | nosetests.xml 82 | coverage.xml 83 | *.cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | 87 | # Translations 88 | *.mo 89 | *.pot 90 | 91 | # Django stuff: 92 | *.log 93 | local_settings.py 94 | 95 | # Flask stuff: 96 | instance/ 97 | .webassets-cache 98 | 99 | # Scrapy stuff: 100 | .scrapy 101 | 102 | # Sphinx documentation 103 | docs/_build/ 104 | docs/authors.rst 105 | 106 | # PyBuilder 107 | target/ 108 | 109 | # Jupyter Notebook 110 | .ipynb_checkpoints 111 | 112 | # pyenv 113 | .python-version 114 | 115 | # celery beat schedule file 116 | celerybeat-schedule 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # dotenv 122 | .env 123 | 124 | # virtualenv 125 | .venv 126 | venv/ 127 | ENV/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | 142 | # PyCharm 143 | .idea/ 144 | 145 | # Floobits 146 | .floo 147 | .flooignore 148 | 149 | 150 | /data 151 | .vscode/settings.json 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Romain Lopez, 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scMVP - single cell Multi-View Profiler 2 | 3 | scMVP is a python toolkit for joint profiling of scRNA and scATAC data and analysis 4 | with multi-modal self-attention generation model. 5 | 6 | ## Update logs 7 | - 20220815
8 | Completed testing of the GPU version on the NVIDIA 3090 platform with CUDA 11.2.
9 | Added scRNA and scATAC input checks to the tutorial.
10 | 11 | 12 | ## Installation 13 | **Environment requirements:**
14 | scMVP requires Python 3.7.x and [**PyTorch**](http://pytorch.org).
15 | For example, use [**miniconda**](https://conda.io/miniconda.html) to install Python and the CPU or GPU build of PyTorch. 16 | We have tested the GPU version on the NVIDIA 1080Ti platform with CUDA 10.2. 17 | 18 | ```Bash 19 | conda install -c pytorch python=3.7 pytorch 20 | # if you do not have Jupyter Notebook / IPython, you can also install it with conda 21 | conda install jupyter 22 | ``` 23 | 24 | Then you can install scMVP from the GitHub repo:
25 | ```Bash 26 | # first move to your target directory 27 | git clone https://github.com/bm2-lab/scMVP.git 28 | cd scMVP/ 29 | python setup.py install 30 | ``` 31 | 32 | Try ```import scMVP``` in your Python console and start your first [**tutorial**](demos/scMVP_tutorial.ipynb) with scMVP! 33 | 34 | Jupyter notebooks for the other datasets analyzed and benchmarked in our Genome Biology paper are deposited in [**this folder**](demos/manuscript_analysis/). 35 | 36 | ### All processed datasets and trained models:
37 | Download link: [Baidu cloud disk](https://pan.baidu.com/s/183jLROAUuNfVKCeBY4B4DQ)
38 | - Update 23/10/22: Google Drive link for cell line datasets: [Google Drive](https://drive.google.com/drive/folders/18ymTLyMb_wD20O4Z2qkOXBQt5yoDTvea?usp=sharing)
39 | 40 | Download code (extraction code for the Baidu link): mkij
41 | - `pre_trainer.pkl`: scRNA pretraining model
42 | - `pre_atac_trainer.pkl`: scATAC pretraining model
43 | - `multi_vae_trainer.pkl`: trained scMVP model
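
The snippet below is a minimal sketch of restoring one of the downloaded checkpoints with PyTorch. It assumes the `.pkl` files are PyTorch checkpoints written by the demo notebooks; whether a file holds a pickled trainer or a model `state_dict` depends on the notebook that produced it, so adapt the final step accordingly.

```Python
import torch

# Assumption: the downloaded .pkl files are PyTorch checkpoints from the demo notebooks.
# Load on CPU first; move the model to the GPU afterwards if one is available.
checkpoint = torch.load("multi_vae_trainer.pkl", map_location="cpu")

# If the checkpoint is a state_dict, restore it into a freshly constructed scMVP model
# built with the same arguments as in demos/scMVP_tutorial.ipynb:
# model.load_state_dict(checkpoint)
```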
44 | 45 | 46 | ## User tutorial 47 | 48 | Applying scMVP to the sci-CAR cell line mixture: [**demo**](demos/scMVP_tutorial.ipynb) 49 | - Training and visualization with scMVP. 50 | - Pretraining and transferring to scMVP (performs better on large datasets; see also the minimal usage sketch at the end of this README). 51 | 52 | 53 | 54 | ### Reference 55 | A deep generative model for multi-view profiling of single-cell RNA-seq and ATAC-seq data. Genome Biology 2022 [**paper**](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-021-02595-6) 56 | 57 | 58 | ### Contact Authors 59 | Prof. Qi Liu: [qiliu@tongji.edu.cn](mailto:qiliu@tongji.edu.cn)
60 | Dr. Gaoyang Li: [lgyzngc@gmail.com](mailto:lgyzngc@gmail.com)
61 | Shaliu Fu: [adam.tongji@gmail.com](mailto:adam.tongji@gmail.com)
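
### Minimal usage sketch
For orientation only: the sketch below shows how a joint RNA/ATAC dataset can be loaded with `scMVP.dataset.LoadData` and where model training fits in. The file names are placeholders, and the commented model/trainer lines are assumptions to be checked against the [tutorial notebook](demos/scMVP_tutorial.ipynb), which contains the actual model construction and training schedule.

```Python
from scMVP.dataset import LoadData
from scMVP.inference import MultiPosterior, MultiTrainer

# Placeholder file names: two tab-separated count tables (genes x cells and peaks x cells).
dataset = LoadData(
    dataset={
        "gene_expression": "rna_counts.tsv",
        "atac_expression": "atac_counts.tsv",
    },
    data_path="dataset/",
    file_separator="\t",
    atac_threshold=0.0005,  # keep peaks detected in at least 0.05% of cells
    cell_threshold=1,       # filter cells below this minimum total count
)

# Model construction and training follow demos/scMVP_tutorial.ipynb; the lines below are a
# hypothetical outline, not the notebook's exact code.
# model = ...  # the multi-modal VAE from scMVP.models (class name and arguments: see the notebook)
# trainer = MultiTrainer(model, dataset, train_size=0.9)
# trainer.train(n_epochs=10, lr=1e-3)
# posterior = trainer.create_posterior(model, dataset, type_class=MultiPosterior)
# latent, latent_rna, latent_atac, gamma, clusters, batches, labels = posterior.get_latent()
```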
62 | 63 | -------------------------------------------------------------------------------- /scMVP/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | __author__ = "Gaoyang Li" 5 | __email__ = "lgyzngc@tongji.edu.cn" 6 | __version__ = "0.0.1" 7 | 8 | # Set default logging handler to avoid logging with logging.lastResort logger. 9 | import logging 10 | from logging import NullHandler 11 | 12 | from ._settings import set_verbosity 13 | 14 | logger = logging.getLogger(__name__) 15 | logger.addHandler(NullHandler()) 16 | 17 | # default to INFO level logging for the scMVP package 18 | set_verbosity(logging.INFO) 19 | 20 | __all__ = ["set_verbosity"] 21 | -------------------------------------------------------------------------------- /scMVP/_settings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union 3 | 4 | logger = logging.getLogger(__name__) 5 | scMVP_logger = logging.getLogger("scMVP") 6 | 7 | autotune_formatter = logging.Formatter( 8 | "[%(asctime)s - %(processName)s - %(threadName)s] %(levelname)s - %(name)s\n%(message)s" 9 | ) 10 | 11 | 12 | class DispatchingFormatter(logging.Formatter): 13 | """Dispatch formatter for logger and it's sub logger.""" 14 | 15 | def __init__(self, default_formatter, formatters=None): 16 | super().__init__() 17 | self._formatters = formatters if formatters is not None else {} 18 | self._default_formatter = default_formatter 19 | 20 | def format(self, record): 21 | # Search from record's logger up to it's parents: 22 | logger = logging.getLogger(record.name) 23 | while logger: 24 | # Check if suitable formatter for current logger exists: 25 | if logger.name in self._formatters: 26 | formatter = self._formatters[logger.name] 27 | break 28 | else: 29 | logger = logger.parent 30 | else: 31 | # If no formatter found, just use default: 32 | formatter = self._default_formatter 33 | return formatter.format(record) 34 | 35 | 36 | def set_verbosity(level: Union[str, int]): 37 | """Sets logging configuration for scMVP based on chosen level of verbosity. 38 | 39 | Sets "scMVP" logging level to `level` 40 | If "scMVP" logger has no StreamHandler, add one. 41 | Else, set its level to `level`. 
42 | """ 43 | scMVP_logger.setLevel(level) 44 | has_streamhandler = False 45 | for handler in scMVP_logger.handlers: 46 | if isinstance(handler, logging.StreamHandler): 47 | handler.setLevel(level) 48 | logger.info( 49 | "'scMVP' logger already has a StreamHandler, set its level to {}.".format( 50 | level 51 | ) 52 | ) 53 | has_streamhandler = True 54 | 55 | 56 | if not has_streamhandler: 57 | ch = logging.StreamHandler() 58 | formatter = logging.Formatter( 59 | "[%(asctime)s] %(levelname)s - %(name)s | %(message)s" 60 | ) 61 | ch.setFormatter( 62 | DispatchingFormatter(formatter, {"scMVP.autotune": autotune_formatter}) 63 | ) 64 | scMVP_logger.addHandler(ch) 65 | logger.info("Added StreamHandler with custom formatter to 'scMVP' logger.") 66 | -------------------------------------------------------------------------------- /scMVP/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from scMVP.dataset.dataset import ( 2 | GeneExpressionDataset, 3 | CellMeasurement, 4 | ) 5 | 6 | from scMVP.dataset.scMVP_dataloader import LoadData 7 | 8 | __all__ = [ 9 | "CellMeasurement", 10 | "GeneExpressionDataset", 11 | "LoadData", 12 | ] 13 | -------------------------------------------------------------------------------- /scMVP/dataset/scMVP_dataloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib 4 | import numpy as np 5 | import pandas as pd 6 | import scipy.io as sp_io 7 | from scipy.sparse import csr_matrix, issparse 8 | 9 | from scMVP.dataset.dataset import CellMeasurement, GeneExpressionDataset 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | available_specification = ["filtered", "raw"] 14 | 15 | 16 | class LoadData(GeneExpressionDataset): 17 | 18 | """ 19 | Dataset format: 20 | dataset = { 21 | "gene_barcodes": xxx, 22 | "gene_expression": xxx, 23 | "gene_names": xxx, 24 | "atac_barcodes": xxx, 25 | "atac_expression": xxx, 26 | "atac_names": xxx, 27 | } 28 | OR 29 | dataset = { 30 | "gene_expression":xxx, 31 | "atac_expression":xxx, 32 | } 33 | """ 34 | def __init__(self, 35 | dataset: dict = None, 36 | data_path: str = "dataset/", 37 | dense: bool = False, 38 | measurement_names_column: int = 0, 39 | remove_extracted_data: bool = False, 40 | delayed_populating: bool = False, 41 | file_separator: str = "\t", 42 | gzipped: bool = False, 43 | atac_threshold: float = 0.0001, # express in over 0.01% 44 | cell_threshold: int = 1, # filtering cells less than minimum count 45 | cell_meta: pd.DataFrame = None, 46 | ): 47 | 48 | self.dataset = dataset 49 | self.data_path = data_path 50 | self.barcodes = None 51 | self.dense = dense 52 | self.measurement_names_column = measurement_names_column 53 | self.remove_extracted_data = remove_extracted_data 54 | self.file_separator = file_separator 55 | self.gzip = gzipped 56 | self.atac_thres = atac_threshold 57 | self.cell_thres = cell_threshold 58 | self._minimum_input = ("gene_expression", "atac_expression") 59 | self._allow_input = ( 60 | "gene_expression", "atac_expression", 61 | "gene_barcodes", "gene_names", 62 | "atac_barcodes", "atac_names" 63 | ) 64 | self.cell_meta = cell_meta 65 | super().__init__() 66 | if not delayed_populating: 67 | self.populate() 68 | 69 | def populate(self): 70 | logger.info("Preprocessing joint profiling dataset.") 71 | if not self._input_check(): 72 | logger.info("Please reload your dataset.") 73 | return 74 | joint_profiles = {} 75 | if len(self.dataset.keys()) == 2: 76 | # 
for _key in self.dataset.keys(): 77 | if self.gzip: 78 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["gene_expression"]), sep=self.file_separator, 79 | header=0, index_col=0, compression="gzip") 80 | else: 81 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["gene_expression"]), sep=self.file_separator, 82 | header=0) 83 | joint_profiles["gene_barcodes"] = pd.DataFrame(_tmp.columns.values) 84 | joint_profiles["gene_names"] = pd.DataFrame(_tmp._stat_axis.values) 85 | joint_profiles["gene_expression"] = np.array(_tmp).T 86 | 87 | if self.gzip: 88 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["atac_expression"]), sep=self.file_separator, 89 | header=0, index_col=0, compression="gzip") 90 | else: 91 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["atac_expression"]), sep=self.file_separator, 92 | header=0, index_col=0) 93 | joint_profiles["atac_barcodes"] = pd.DataFrame(_tmp.columns.values) 94 | joint_profiles["atac_names"] = pd.DataFrame(_tmp._stat_axis.values) 95 | joint_profiles["atac_expression"] = np.array(_tmp).T 96 | 97 | elif len(self.dataset.keys()) == 6: 98 | for _key in self.dataset.keys(): 99 | if _key == "atac_expression" or _key == "gene_expression" and not self.dense: 100 | joint_profiles[_key] = csr_matrix(sp_io.mmread("{}/{}".format(self.data_path,self.dataset[_key])).T) 101 | 102 | elif self.gzip: 103 | joint_profiles[_key] = pd.read_csv("{}/{}".format(self.data_path, self.dataset[_key]), 104 | sep=self.file_separator, 105 | compression="gzip", header=None) 106 | else: 107 | joint_profiles[_key] = pd.read_csv("{}/{}".format(self.data_path,self.dataset[_key]), 108 | sep=self.file_separator, header=None) 109 | 110 | else: 111 | logger.info("more than 6 inputs.") 112 | 113 | ## 200920 gene barcode file may include more than 1 column 114 | if joint_profiles["gene_names"].shape[1] > 1: 115 | joint_profiles["gene_names"] = pd.DataFrame(joint_profiles["gene_names"].iloc[:,1]) 116 | if joint_profiles["atac_names"].shape[1] > 1: 117 | joint_profiles["atac_names"] = pd.DataFrame(joint_profiles["atac_names"].iloc[:,1]) 118 | share_index, gene_barcode_index, atac_barcode_index = np.intersect1d(joint_profiles["gene_barcodes"].values, 119 | joint_profiles["atac_barcodes"].values, 120 | return_indices=True) 121 | if isinstance(self.cell_meta,pd.DataFrame): 122 | if self.cell_meta.shape[1] < 2: 123 | logger.info("Please use cell id in first column and give ata least 2 columns.") 124 | return 125 | meta_cell_id = self.cell_meta.iloc[:,0].values 126 | meta_share, meta_barcode_index, share_barcode_index =\ 127 | np.intersect1d(meta_cell_id, 128 | share_index, return_indices=True) 129 | _gene_barcode_index = gene_barcode_index[share_barcode_index] 130 | _atac_barcode_index = atac_barcode_index[share_barcode_index] 131 | if len(_gene_barcode_index) < 2: # no overlaps 132 | logger.info("Inconsistent metadata to expression data.") 133 | return 134 | tmp = joint_profiles["gene_barcodes"] 135 | joint_profiles["gene_barcodes"] = tmp.loc[_gene_barcode_index, :] 136 | temp = joint_profiles["atac_barcodes"] 137 | joint_profiles["atac_barcodes"] = temp.loc[_atac_barcode_index, :] 138 | 139 | else: 140 | # reorder rnaseq cell meta 141 | tmp = joint_profiles["gene_barcodes"] 142 | joint_profiles["gene_barcodes"] = tmp.loc[gene_barcode_index,:] 143 | temp = joint_profiles["atac_barcodes"] 144 | joint_profiles["atac_barcodes"] = temp.loc[atac_barcode_index, :] 145 | 146 | gene_tab = joint_profiles["gene_expression"] 147 | if issparse(gene_tab): 
148 | joint_profiles["gene_expression"] = gene_tab[gene_barcode_index, :].A 149 | else: 150 | joint_profiles["gene_expression"] = gene_tab[gene_barcode_index, :] 151 | 152 | temp = joint_profiles["atac_expression"] 153 | reorder_atac_exp = temp[atac_barcode_index, :] 154 | binary_index = reorder_atac_exp > 1 155 | reorder_atac_exp[binary_index] = 1 156 | # remove peaks > 10% of total cells 157 | high_count_atacs = ((reorder_atac_exp > 0).sum(axis=0).ravel() >= self.atac_thres * reorder_atac_exp.shape[0]) \ 158 | & ((reorder_atac_exp > 0).sum(axis=0).ravel() <= 0.1 * reorder_atac_exp.shape[0]) 159 | 160 | if issparse(reorder_atac_exp): 161 | high_count_atacs_index = np.where(high_count_atacs) 162 | _tmp = reorder_atac_exp[:, high_count_atacs_index[1]] 163 | joint_profiles["atac_expression"] = _tmp.A 164 | joint_profiles["atac_names"] = joint_profiles["atac_names"].loc[high_count_atacs_index[1], :] 165 | 166 | else: 167 | _tmp = reorder_atac_exp[:, high_count_atacs] 168 | 169 | joint_profiles["atac_expression"] = _tmp 170 | joint_profiles["atac_names"] = joint_profiles["atac_names"].loc[high_count_atacs, :] 171 | 172 | # RNA-seq as the key 173 | Ys = [] 174 | measurement = CellMeasurement( 175 | name="atac_expression", 176 | data=joint_profiles["atac_expression"], 177 | columns_attr_name="atac_names", 178 | columns=joint_profiles["atac_names"].astype(np.str), 179 | ) 180 | Ys.append(measurement) 181 | # Add cell metadata 182 | if isinstance(self.cell_meta,pd.DataFrame): 183 | for l_index, label in enumerate(list(self.cell_meta.columns.values)): 184 | if l_index >0: 185 | label_measurement = CellMeasurement( 186 | name="{}_label".format(label), 187 | data=self.cell_meta.iloc[meta_barcode_index,l_index], 188 | columns_attr_name=label, 189 | columns=self.cell_meta.iloc[meta_barcode_index, l_index] 190 | ) 191 | Ys.append(label_measurement) 192 | logger.info("Loading {} into dataset.".format(label)) 193 | 194 | cell_attributes_dict = { 195 | "barcodes": np.squeeze(np.asarray(joint_profiles["gene_barcodes"], dtype=str)) 196 | } 197 | 198 | logger.info("Finished preprocessing dataset") 199 | 200 | self.populate_from_data( 201 | X=joint_profiles["gene_expression"], 202 | batch_indices=None, 203 | gene_names=joint_profiles["gene_names"].astype(np.str), 204 | cell_attributes_dict=cell_attributes_dict, 205 | Ys=Ys, 206 | ) 207 | self.filter_cells_by_count(self.cell_thres) 208 | 209 | def _input_check(self): 210 | if len(self.dataset.keys()) == 2: 211 | for _key in self.dataset.keys(): 212 | if _key not in self._minimum_input: 213 | logger.info("Unknown input data type:{}".format(_key)) 214 | return False 215 | # if not self.dataset[_key].split(".")[-1] in ["txt","tsv","csv"]: 216 | # logger.debug("scMVP only support two files input of txt, tsv or csv!") 217 | # return False 218 | elif len(self.dataset.keys()) >= 6: 219 | for _key in self._allow_input: 220 | if not _key in self.dataset.keys(): 221 | logger.info("Data type {} missing.".format(_key)) 222 | return False 223 | else: 224 | logger.info("Incorrect input file number.") 225 | return False 226 | for _key in self.dataset.keys(): 227 | if not os.path.exists(self.data_path): 228 | logger.info("{} do not exist!".format(self.data_path)) 229 | if not os.path.exists("{}{}".format(self.data_path, self.dataset[_key])): 230 | logger.info("Cannot find {}{}!".format(self.data_path, self.dataset[_key])) 231 | return False 232 | return True 233 | 234 | def _download(self, url: str, save_path: str, filename: str): 235 | """Writes data from url to file.""" 236 
| if os.path.exists(os.path.join(save_path, filename)): 237 | logger.info("File %s already downloaded" % (os.path.join(save_path, filename))) 238 | return 239 | 240 | r = urllib.request.urlopen(url) 241 | logger.info("Downloading file at %s" % os.path.join(save_path, filename)) 242 | 243 | def read_iter(file, block_size=1000): 244 | """Given a file 'file', returns an iterator that returns bytes of 245 | size 'blocksize' from the file, using read().""" 246 | while True: 247 | block = file.read(block_size) 248 | if not block: 249 | break 250 | yield block 251 | 252 | def _add_cell_meta(self, cell_meta, filter=False): 253 | cell_ids = cell_meta.iloc[:,1].values 254 | share_index, meta_barcode_index, gene_barcode_index = \ 255 | np.intersect1d(cell_ids,self.barcodes,return_indices=True) 256 | if len(share_index) <=1: 257 | logger.info("No consistent cell IDs!") 258 | return 259 | if len(share_index) < len(self.barcodes): 260 | logger.info("{} cells match metadata.".format(len(share_index))) 261 | return 262 | 263 | 264 | 265 | class SnareDemo(LoadData): 266 | 267 | def __init__(self, dataset_name: str=None, data_path: str="/dataset", cell_meta: str = None): 268 | url="https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE126074" 269 | available_datasets = { 270 | "CellLineMixture": { 271 | "gene_expression": "GSE126074_CellLineMixture_SNAREseq_cDNA_counts.tsv.gz", 272 | "atac_expression": "GSE126074_CellLineMixture_SNAREseq_chromatin_counts.tsv.gz", 273 | }, 274 | "AdBrainCortex": { 275 | "gene_barcodes": "GSE126074_AdBrainCortex_SNAREseq_cDNA.barcodes.tsv.gz", 276 | "gene_expression": "GSE126074_AdBrainCortex_SNAREseq_cDNA.counts.mtx.gz", 277 | "gene_names": "GSE126074_AdBrainCortex_SNAREseq_cDNA.genes.tsv.gz", 278 | "atac_barcodes": "GSE126074_AdBrainCortex_SNAREseq_chromatin.barcodes.tsv.gz", 279 | "atac_expression": "GSE126074_AdBrainCortex_SNAREseq_chromatin.counts.mtx.gz", 280 | "atac_names": "GSE126074_AdBrainCortex_SNAREseq_chromatin.peaks.tsv.gz", 281 | }, 282 | "P0_BrainCortex": { 283 | "gene_barcodes": "GSE126074_P0_BrainCortex_SNAREseq_cDNA.barcodes.tsv.gz", 284 | "gene_expression": "GSE126074_P0_BrainCortex_SNAREseq_cDNA.counts.mtx.gz", 285 | "gene_names": "GSE126074_P0_BrainCortex_SNAREseq_cDNA.genes.tsv.gz", 286 | "atac_barcodes": "GSE126074_P0_BrainCortex_SNAREseq_chromatin.barcodes.tsv.gz", 287 | "atac_expression": "GSE126074_P0_BrainCortex_SNAREseq_chromatin.counts.mtx.gz", 288 | "atac_names": "GSE126074_P0_BrainCortex_SNAREseq_chromatin.peaks.tsv.gz", 289 | } 290 | } 291 | if cell_meta: 292 | cell_meta_data = pd.read_csv(cell_meta, sep=",", header=0) 293 | else: 294 | cell_meta_data = None 295 | if dataset_name=="CellLineMixture": 296 | super(SnareDemo, self).__init__(dataset = available_datasets[dataset_name], 297 | data_path= data_path, 298 | dense = False, 299 | measurement_names_column = 1, 300 | cell_meta=cell_meta_data, 301 | remove_extracted_data = False, 302 | delayed_populating = False, 303 | file_separator = "\t", 304 | gzipped = True, 305 | atac_threshold = 0.0005, 306 | cell_threshold = 1 307 | ) 308 | elif dataset_name=="AdBrainCortex" or dataset_name=="P0_BrainCortex": 309 | super(SnareDemo, self).__init__(dataset=available_datasets[dataset_name], 310 | data_path=data_path, 311 | dense=False, 312 | measurement_names_column=1, 313 | cell_meta=cell_meta_data, 314 | remove_extracted_data=False, 315 | delayed_populating=False, 316 | gzipped=True, 317 | atac_threshold=0.0005, 318 | cell_threshold=1 319 | ) 320 | else: 321 | logger.info('Please select from 
"CellLineMixture", "AdBrainCortex" or "P0_BrainCortex" dataset.') 322 | 323 | 324 | class PairedDemo(LoadData): 325 | 326 | def __init__(self, dataset_name: str = None, data_path: str = "/dataset"): 327 | urls = [ 328 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE130nnn/GSE130399/suppl/GSE130399_GSM3737488_GSM3737489_Cell_Mix.tar.gz", 329 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE130nnn/GSE130399/suppl/GSE130399_GSM3737490-GSM3737495_Adult_Cerebrail_Cortex.tar.gz", 330 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE130nnn/GSE130399/suppl/GSE130399_GSM3737496-GSM3737499_Fetal_Forebrain.tar.gz" 331 | ] 332 | 333 | available_datasets = { 334 | "CellLineMixture": { 335 | "gene_names": "Cell_Mix_RNA/genes.tsv", 336 | "gene_expression": "Cell_Mix_RNA/matrix.mtx", 337 | "gene_barcodes": "Cell_Mix_RNA/barcodes.tsv", 338 | "atac_names": "Cell_Mix_DNA/genes.tsv", 339 | "atac_expression": "Cell_Mix_DNA/matrix.mtx", 340 | "atac_barcodes":"Cell_Mix_DNA/barcodes.tsv" 341 | }, 342 | "Adult_Cerebral": { 343 | "gene_names": "Adult_CTX_RNA/genes.tsv", 344 | "gene_expression": "Adult_CTX_RNA/matrix.mtx", 345 | "gene_barcodes": "Adult_CTX_RNA/barcodes.tsv", 346 | "atac_names": "Adult_CTX_DNA/genes.tsv", 347 | "atac_expression": "Adult_CTX_DNA/matrix.mtx", 348 | "atac_barcodes": "Adult_CTX_DNA/barcodes.tsv" 349 | }, 350 | "Fetal_Forebrain": { 351 | "gene_names": "FB_RNA/genes.tsv", 352 | "gene_expression": "FB_RNA/matrix.mtx", 353 | "gene_barcodes": "FB_RNA/barcodes.tsv", 354 | "atac_names": "FB_DNA/genes.tsv", 355 | "atac_expression": "FB_DNA/matrix.mtx", 356 | "atac_barcodes": "FB_DNA/barcodes.tsv" 357 | } 358 | } 359 | if dataset_name=="CellLineMixture" or dataset_name=="Fetal_Forebrain": 360 | if os.path.exists("{}/Cell_embeddings.xls".format(data_path)): 361 | cell_embed = pd.read_csv("{}/Cell_embeddings.xls".format(data_path), sep='\t') 362 | cell_embed_info = cell_embed.iloc[:, 0:2] 363 | cell_embed_info.columns = ["Cell_ID","Cluster"] 364 | else: 365 | logger.info("Cannot find cell embedding files for Paried-seq Demo.") 366 | return 367 | 368 | super().__init__(dataset = available_datasets[dataset_name], 369 | data_path= data_path, 370 | dense = False, 371 | measurement_names_column = 1, 372 | remove_extracted_data = False, 373 | delayed_populating = False, 374 | gzipped = False, 375 | atac_threshold = 0.005, 376 | cell_threshold = 100, 377 | cell_meta=cell_embed_info 378 | ) 379 | elif dataset_name=="Adult_Cerebral": 380 | if os.path.exists("{}/Cell_embeddings.xls".format(data_path)): 381 | cell_embed = pd.read_csv("{}/Cell_embeddings.xls".format(data_path), sep='\t') 382 | cell_embed_info = cell_embed.iloc[:, ["ID","Cluster"]] 383 | cell_embed_info.columns = ["Cell_ID","Cluster"] 384 | else: 385 | logger.info("Cannot find cell embedding files for Paried-seq Demo.") 386 | return 387 | super().__init__(dataset=available_datasets[dataset_name], 388 | data_path=data_path, 389 | dense=False, 390 | measurement_names_column = 1, 391 | remove_extracted_data=False, 392 | delayed_populating=False, 393 | gzipped=False, 394 | atac_threshold=0.005, 395 | cell_threshold=1, 396 | cell_meta=cell_embed_info 397 | ) 398 | else: 399 | logger.info('Please select from {} dataset.'.format("\t".join(available_datasets.keys()))) 400 | 401 | 402 | class SciCarDemo(LoadData): 403 | def __init__(self, dataset_name: str = None, data_path: str = "/dataset", cell_meta: str = None): 404 | urls = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE117089&format=file" 405 | # NOTICE, tsv files are generated from original txt files 
406 | available_datasets = { 407 | "CellLineMixture": { 408 | "gene_barcodes": "GSM3271040_RNA_sciCAR_A549_cell.tsv", 409 | "gene_names": "GSM3271040_RNA_sciCAR_A549_gene.tsv", 410 | "gene_expression": "GSM3271040_RNA_sciCAR_A549_gene_count.txt", 411 | "atac_barcodes": "GSM3271041_ATAC_sciCAR_A549_cell.tsv", 412 | "atac_names": "GSM3271041_ATAC_sciCAR_A549_peak.tsv", 413 | "atac_expression": "GSM3271041_ATAC_sciCAR_A549_peak_count.txt" 414 | }, 415 | "mouse_kidney": { 416 | "gene_barcodes": "GSM3271044_RNA_mouse_kidney_cell.tsv", 417 | "gene_names": "GSM3271044_RNA_mouse_kidney_gene.tsv", 418 | "gene_expression": "GSM3271044_RNA_mouse_kidney_gene_count.txt", 419 | "atac_barcodes": "GSM3271045_ATAC_mouse_kidney_cell.tsv", 420 | "atac_names": "GSM3271045_ATAC_mouse_kidney_peak.tsv", 421 | "atac_expression": "GSM3271045_ATAC_mouse_kidney_peak_count.txt" 422 | } 423 | } 424 | if dataset_name: 425 | for barcode_file in ["gene_barcodes", "atac_barcodes", "gene_names", "atac_names"]: 426 | # generate gene and atac barcodes from cell metadata. 427 | with open("{}/{}".format(data_path, available_datasets[dataset_name][barcode_file]),"w") as fo: 428 | infile = "{}/{}.txt".format(data_path, available_datasets[dataset_name][barcode_file][:-4]) 429 | indata = [i.rstrip().split(",") for i in open(infile)][1:] 430 | for line in indata: 431 | fo.write("{}\n".format(line[0])) 432 | if cell_meta: 433 | cell_meta_data = pd.read_csv(cell_meta, sep=",", header=0) 434 | else: 435 | cell_meta_data = None 436 | super().__init__(dataset=available_datasets[dataset_name], 437 | data_path=data_path, 438 | dense=False, 439 | measurement_names_column=0, 440 | cell_meta=cell_meta_data, 441 | remove_extracted_data=False, 442 | delayed_populating=False, 443 | gzipped=False, 444 | atac_threshold=0.0005, 445 | cell_threshold=1 446 | ) 447 | else: 448 | logger.info('Please select from {} dataset.'.format("\t".join(available_datasets.keys()))) 449 | 450 | -------------------------------------------------------------------------------- /scMVP/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .posterior import Posterior 2 | from .trainer import Trainer 3 | from .inference import UnsupervisedTrainer, AdapterTrainer 4 | from .annotation import ( 5 | ClassifierTrainer, 6 | ) 7 | from .multi_inference import MultiPosterior, MultiTrainer 8 | 9 | __all__ = [ 10 | "Trainer", 11 | "Posterior", 12 | "UnsupervisedTrainer", 13 | "AdapterTrainer", 14 | "ClassifierTrainer", 15 | "MultiPosterior", 16 | "MultiTrainer" 17 | ] 18 | -------------------------------------------------------------------------------- /scMVP/inference/annotation.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | import logging 5 | 6 | from sklearn import neighbors 7 | from sklearn.ensemble import RandomForestClassifier 8 | from sklearn.model_selection import GridSearchCV 9 | from sklearn.neighbors import KNeighborsClassifier 10 | from sklearn.svm import SVC 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | 15 | from scMVP.inference import Posterior 16 | from scMVP.inference import Trainer 17 | from scMVP.inference.inference import UnsupervisedTrainer 18 | from scMVP.inference.posterior import unsupervised_clustering_accuracy 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class AnnotationPosterior(Posterior): 24 | def __init__(self, *args, model_zl=False, **kwargs): 25 | 
super().__init__(*args, **kwargs) 26 | self.model_zl = model_zl 27 | 28 | def accuracy(self): 29 | model, cls = ( 30 | (self.sampling_model, self.model) 31 | if hasattr(self, "sampling_model") 32 | else (self.model, None) 33 | ) 34 | acc = compute_accuracy(model, self, classifier=cls, model_zl=self.model_zl) 35 | logger.debug("Acc: %.4f" % (acc)) 36 | return acc 37 | 38 | accuracy.mode = "max" 39 | 40 | @torch.no_grad() 41 | def hierarchical_accuracy(self): 42 | all_y, all_y_pred = self.compute_predictions() 43 | acc = np.mean(all_y == all_y_pred) 44 | 45 | all_y_groups = np.array([self.model.labels_groups[y] for y in all_y]) 46 | all_y_pred_groups = np.array([self.model.labels_groups[y] for y in all_y_pred]) 47 | h_acc = np.mean(all_y_groups == all_y_pred_groups) 48 | 49 | logger.debug("Hierarchical Acc : %.4f\n" % h_acc) 50 | return acc 51 | 52 | accuracy.mode = "max" 53 | 54 | @torch.no_grad() 55 | def compute_predictions(self, soft=False): 56 | """ 57 | :return: the true labels and the predicted labels 58 | :rtype: 2-tuple of :py:class:`numpy.int32` 59 | """ 60 | model, cls = ( 61 | (self.sampling_model, self.model) 62 | if hasattr(self, "sampling_model") 63 | else (self.model, None) 64 | ) 65 | return compute_predictions( 66 | model, self, classifier=cls, soft=soft, model_zl=self.model_zl 67 | ) 68 | 69 | @torch.no_grad() 70 | def unsupervised_classification_accuracy(self): 71 | all_y, all_y_pred = self.compute_predictions() 72 | uca = unsupervised_clustering_accuracy(all_y, all_y_pred)[0] 73 | logger.debug("UCA : %.4f" % (uca)) 74 | return uca 75 | 76 | unsupervised_classification_accuracy.mode = "max" 77 | 78 | @torch.no_grad() 79 | def nn_latentspace(self, posterior): 80 | data_train, _, labels_train = self.get_latent() 81 | data_test, _, labels_test = posterior.get_latent() 82 | nn = KNeighborsClassifier() 83 | nn.fit(data_train, labels_train) 84 | score = nn.score(data_test, labels_test) 85 | return score 86 | 87 | 88 | class ClassifierTrainer(Trainer): 89 | r"""The ClassifierInference class for training a classifier either on the raw data or on top of the latent 90 | space of another model (VAE, VAEC, SCANVI). 91 | 92 | Args: 93 | :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI`` 94 | :gene_dataset: A gene_dataset instance like ``CortexDataset()`` 95 | :train_size: The train size, either a float between 0 and 1 or and integer for the number of training samples 96 | to use Default: ``0.8``. 97 | :test_size: The test size, either a float between 0 and 1 or and integer for the number of test samples 98 | to use Default: ``None``. 99 | :sampling_model: Model with z_encoder with which to first transform data. 100 | :sampling_zl: Transform data with sampling_model z_encoder and l_encoder and concat. 101 | :\**kwargs: Other keywords arguments from the general Trainer class. 102 | 103 | 104 | Examples: 105 | >>> gene_dataset = CortexDataset() 106 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, 107 | ... 
n_labels=gene_dataset.n_labels) 108 | 109 | >>> classifier = Classifier(vae.n_latent, n_labels=cortex_dataset.n_labels) 110 | >>> trainer = ClassifierTrainer(classifier, gene_dataset, sampling_model=vae, train_size=0.5) 111 | >>> trainer.train(n_epochs=20, lr=1e-3) 112 | >>> trainer.test_set.accuracy() 113 | """ 114 | 115 | def __init__( 116 | self, 117 | *args, 118 | train_size=0.8, 119 | test_size=None, 120 | sampling_model=None, 121 | sampling_zl=False, 122 | use_cuda=True, 123 | **kwargs 124 | ): 125 | self.sampling_model = sampling_model 126 | self.sampling_zl = sampling_zl 127 | super().__init__(*args, use_cuda=use_cuda, **kwargs) 128 | self.train_set, self.test_set, self.validation_set = self.train_test_validation( 129 | self.model, 130 | self.gene_dataset, 131 | train_size=train_size, 132 | test_size=test_size, 133 | type_class=AnnotationPosterior, 134 | ) 135 | self.train_set.to_monitor = ["accuracy"] 136 | self.test_set.to_monitor = ["accuracy"] 137 | self.validation_set.to_monitor = ["accuracy"] 138 | self.train_set.model_zl = sampling_zl 139 | self.test_set.model_zl = sampling_zl 140 | self.validation_set.model_zl = sampling_zl 141 | 142 | @property 143 | def posteriors_loop(self): 144 | return ["train_set"] 145 | 146 | def __setattr__(self, key, value): 147 | if key in ["train_set", "test_set"]: 148 | value.sampling_model = self.sampling_model 149 | super().__setattr__(key, value) 150 | 151 | def loss(self, tensors_labelled): 152 | x, _, _, _, labels_train = tensors_labelled 153 | if self.sampling_model: 154 | if hasattr(self.sampling_model, "classify"): 155 | return F.cross_entropy( 156 | self.sampling_model.classify(x), labels_train.view(-1) 157 | ) 158 | else: 159 | if self.sampling_model.log_variational: 160 | x = torch.log(1 + x) 161 | if self.sampling_zl: 162 | x_z = self.sampling_model.z_encoder(x)[0] 163 | x_l = self.sampling_model.l_encoder(x)[0] 164 | x = torch.cat((x_z, x_l), dim=-1) 165 | else: 166 | x = self.sampling_model.z_encoder(x)[0] 167 | return F.cross_entropy(self.model(x), labels_train.view(-1)) 168 | 169 | @torch.no_grad() 170 | def compute_predictions(self, soft=False): 171 | """ 172 | :return: the true labels and the predicted labels 173 | :rtype: 2-tuple of :py:class:`numpy.int32` 174 | """ 175 | model, cls = ( 176 | (self.sampling_model, self.model) 177 | if hasattr(self, "sampling_model") 178 | else (self.model, None) 179 | ) 180 | full_set = self.create_posterior(type_class=AnnotationPosterior) 181 | return compute_predictions( 182 | model, full_set, classifier=cls, soft=soft, model_zl=self.sampling_zl 183 | ) 184 | 185 | 186 | @torch.no_grad() 187 | def compute_predictions( 188 | model, data_loader, classifier=None, soft=False, model_zl=False 189 | ): 190 | all_y_pred = [] 191 | all_y = [] 192 | 193 | for i_batch, tensors in enumerate(data_loader): 194 | sample_batch, _, _, _, labels = tensors 195 | all_y += [labels.view(-1).cpu()] 196 | 197 | if hasattr(model, "classify"): 198 | y_pred = model.classify(sample_batch) 199 | elif classifier is not None: 200 | # Then we use the specified classifier 201 | if model is not None: 202 | if model.log_variational: 203 | sample_batch = torch.log(1 + sample_batch) 204 | if model_zl: 205 | sample_z = model.z_encoder(sample_batch)[0] 206 | sample_l = model.l_encoder(sample_batch)[0] 207 | sample_batch = torch.cat((sample_z, sample_l), dim=-1) 208 | else: 209 | sample_batch, _, _ = model.z_encoder(sample_batch) 210 | y_pred = classifier(sample_batch) 211 | else: # The model is the raw classifier 212 | y_pred = 
model(sample_batch) 213 | 214 | if not soft: 215 | y_pred = y_pred.argmax(dim=-1) 216 | 217 | all_y_pred += [y_pred.cpu()] 218 | 219 | all_y_pred = np.array(torch.cat(all_y_pred)) 220 | all_y = np.array(torch.cat(all_y)) 221 | 222 | return all_y, all_y_pred 223 | 224 | 225 | @torch.no_grad() 226 | def compute_accuracy(vae, data_loader, classifier=None, model_zl=False): 227 | all_y, all_y_pred = compute_predictions( 228 | vae, data_loader, classifier=classifier, model_zl=model_zl 229 | ) 230 | return np.mean(all_y == all_y_pred) 231 | 232 | 233 | Accuracy = namedtuple( 234 | "Accuracy", ["unweighted", "weighted", "worst", "accuracy_classes"] 235 | ) 236 | 237 | 238 | @torch.no_grad() 239 | def compute_accuracy_tuple(y, y_pred): 240 | y = y.ravel() 241 | n_labels = len(np.unique(y)) 242 | classes_probabilities = [] 243 | accuracy_classes = [] 244 | for cl in range(n_labels): 245 | idx = y == cl 246 | classes_probabilities += [np.mean(idx)] 247 | accuracy_classes += [ 248 | np.mean((y[idx] == y_pred[idx])) if classes_probabilities[-1] else 0 249 | ] 250 | # This is also referred to as the "recall": p = n_true_positive / (n_false_negative + n_true_positive) 251 | # ( We could also compute the "precision": p = n_true_positive / (n_false_positive + n_true_positive) ) 252 | accuracy_named_tuple = Accuracy( 253 | unweighted=np.dot(accuracy_classes, classes_probabilities), 254 | weighted=np.mean(accuracy_classes), 255 | worst=np.min(accuracy_classes), 256 | accuracy_classes=accuracy_classes, 257 | ) 258 | return accuracy_named_tuple 259 | 260 | 261 | @torch.no_grad() 262 | def compute_accuracy_nn(data_train, labels_train, data_test, labels_test, k=5): 263 | clf = neighbors.KNeighborsClassifier(k, weights="distance") 264 | return compute_accuracy_classifier( 265 | clf, data_train, labels_train, data_test, labels_test 266 | ) 267 | 268 | 269 | @torch.no_grad() 270 | def compute_accuracy_classifier(clf, data_train, labels_train, data_test, labels_test): 271 | clf.fit(data_train, labels_train) 272 | # Predicting the labels 273 | y_pred_test = clf.predict(data_test) 274 | y_pred_train = clf.predict(data_train) 275 | 276 | return ( 277 | ( 278 | compute_accuracy_tuple(labels_train, y_pred_train), 279 | compute_accuracy_tuple(labels_test, y_pred_test), 280 | ), 281 | y_pred_test, 282 | ) 283 | 284 | 285 | @torch.no_grad() 286 | def compute_accuracy_svc( 287 | data_train, 288 | labels_train, 289 | data_test, 290 | labels_test, 291 | param_grid=None, 292 | verbose=0, 293 | max_iter=-1, 294 | ): 295 | if param_grid is None: 296 | param_grid = [ 297 | {"C": [1, 10, 100, 1000], "kernel": ["linear"]}, 298 | {"C": [1, 10, 100, 1000], "gamma": [0.001, 0.0001], "kernel": ["rbf"]}, 299 | ] 300 | svc = SVC(max_iter=max_iter) 301 | clf = GridSearchCV(svc, param_grid, verbose=verbose) 302 | return compute_accuracy_classifier( 303 | clf, data_train, labels_train, data_test, labels_test 304 | ) 305 | 306 | 307 | @torch.no_grad() 308 | def compute_accuracy_rf( 309 | data_train, labels_train, data_test, labels_test, param_grid=None, verbose=0 310 | ): 311 | if param_grid is None: 312 | param_grid = {"max_depth": np.arange(3, 10), "n_estimators": [10, 50, 100, 200]} 313 | rf = RandomForestClassifier(max_depth=2, random_state=0) 314 | clf = GridSearchCV(rf, param_grid, verbose=verbose) 315 | return compute_accuracy_classifier( 316 | clf, data_train, labels_train, data_test, labels_test 317 | ) 318 | -------------------------------------------------------------------------------- /scMVP/inference/inference.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import matplotlib.pyplot as plt 4 | import torch 5 | 6 | from scMVP.inference import Trainer 7 | 8 | plt.switch_backend("agg") 9 | 10 | 11 | class UnsupervisedTrainer(Trainer): 12 | r"""The VariationalInference class for the unsupervised training of an autoencoder. 13 | 14 | Args: 15 | :model: A model instance from class ``VAE``, ``VAEC``, 16 | :gene_dataset: A gene_dataset instance like ``snareDataset()`` 17 | :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples 18 | to use Default: ``0.8``. 19 | :test_size: The test size, either a float between 0 and 1 or an integer for the number of training samples 20 | to use Default: ``None``, which is equivalent to data not in the train set. If ``train_size`` and ``test_size`` 21 | do not add to 1 or the length of the dataset then the remaining samples are added to a ``validation_set``. 22 | :n_epochs_kl_warmup: Number of epochs for linear warmup of KL(q(z|x)||p(z)) term. After `n_epochs_kl_warmup`, 23 | the training objective is the ELBO. This might be used to prevent inactivity of latent units, and/or to 24 | improve clustering of latent space, as a long warmup turns the model into something more of an autoencoder. 25 | :normalize_loss: A boolean determining whether the loss is divided by the total number of samples used for 26 | training. In particular, when the global KL divergence is equal to 0 and the division is performed, the loss 27 | for a minibatchis is equal to the average of reconstruction losses and KL divergences on the minibatch. 28 | Default: ``None``, which is equivalent to setting False when the model is an instance from class 29 | ``AutoZIVAE`` and True otherwise. 30 | :\*\*kwargs: Other keywords arguments from the general Trainer class. 31 | 32 | Examples: 33 | >>> gene_dataset = snareDataset() 34 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, 35 | ... n_labels=gene_dataset.n_labels) 36 | 37 | >>> infer = VariationalInference(gene_dataset, vae, train_size=0.5) 38 | >>> infer.train(n_epochs=20, lr=1e-3) 39 | """ 40 | default_metrics_to_monitor = ["elbo"] 41 | 42 | def __init__( 43 | self, 44 | model, 45 | gene_dataset, 46 | train_size=0.8, 47 | test_size=None, 48 | n_epochs_kl_warmup=400, 49 | normalize_loss=None, 50 | **kwargs 51 | ): 52 | super().__init__(model, gene_dataset, **kwargs) 53 | self.n_epochs_kl_warmup = n_epochs_kl_warmup 54 | 55 | self.normalize_loss = ( 56 | not ( 57 | hasattr(self.model, "reconstruction_loss") 58 | and self.model.reconstruction_loss == "autozinb" 59 | ) 60 | if normalize_loss is None 61 | else normalize_loss 62 | ) 63 | 64 | # Total size of the dataset used for training 65 | # (e.g. training set in this class but testing set in AdapterTrainer). 66 | # It used to rescale minibatch losses (cf. eq. 
(8) in Kingma et al., Auto-Encoding Variational Bayes, iCLR 2013) 67 | self.n_samples = 1.0 68 | 69 | if type(self) is UnsupervisedTrainer: 70 | ( 71 | self.train_set, 72 | self.test_set, 73 | self.validation_set, 74 | ) = self.train_test_validation(model, gene_dataset, train_size, test_size) 75 | self.train_set.to_monitor = ["elbo"] 76 | self.test_set.to_monitor = ["elbo"] 77 | self.validation_set.to_monitor = ["elbo"] 78 | self.n_samples = len(self.train_set.indices) 79 | 80 | @property 81 | def posteriors_loop(self): 82 | return ["train_set"] 83 | 84 | def loss(self, tensors): 85 | sample_batch, local_l_mean, local_l_var, batch_index, y = tensors 86 | #reconst_loss, kl_divergence_local, kl_divergence_global = self.model( 87 | # sample_batch, local_l_mean, local_l_var, batch_index, y 88 | #) 89 | reconst_loss, kl_divergence_local, kl_divergence_global = self.model( 90 | sample_batch, local_l_mean, local_l_var, batch_index, batch_index 91 | ) 92 | loss = ( 93 | self.n_samples 94 | * torch.mean(reconst_loss + self.kl_weight * kl_divergence_local) 95 | + kl_divergence_global 96 | ) 97 | if self.normalize_loss: 98 | loss = loss / self.n_samples 99 | return loss 100 | 101 | def on_epoch_begin(self): 102 | if self.n_epochs_kl_warmup is not None: 103 | self.kl_weight = min(1, self.epoch / self.n_epochs_kl_warmup) 104 | else: 105 | self.kl_weight = 1.0 106 | 107 | 108 | class AdapterTrainer(UnsupervisedTrainer): 109 | def __init__(self, model, gene_dataset, posterior_test, frequency=5): 110 | super().__init__(model, gene_dataset, frequency=frequency) 111 | self.test_set = posterior_test 112 | self.test_set.to_monitor = ["elbo"] 113 | self.params = list(self.model.z_encoder.parameters()) + list( 114 | self.model.l_encoder.parameters() 115 | ) 116 | self.z_encoder_state = copy.deepcopy(model.z_encoder.state_dict()) 117 | self.l_encoder_state = copy.deepcopy(model.l_encoder.state_dict()) 118 | self.n_scale = len(self.test_set.indices) 119 | 120 | @property 121 | def posteriors_loop(self): 122 | return ["test_set"] 123 | 124 | def train(self, n_path=10, n_epochs=50, **kwargs): 125 | for i in range(n_path): 126 | # Re-initialize to create new path 127 | self.model.z_encoder.load_state_dict(self.z_encoder_state) 128 | self.model.l_encoder.load_state_dict(self.l_encoder_state) 129 | super().train(n_epochs, params=self.params, **kwargs) 130 | 131 | return min(self.history["elbo_test_set"]) 132 | -------------------------------------------------------------------------------- /scMVP/inference/multi_inference.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import logging 3 | import torch 4 | from torch.distributions import Poisson, Gamma, Bernoulli, Normal 5 | from torch.utils.data import DataLoader 6 | import torch.nn.functional as F 7 | from torch import logsumexp 8 | import torch.distributions as distributions 9 | import numpy as np 10 | 11 | from scMVP.inference import Posterior 12 | from . import UnsupervisedTrainer 13 | 14 | from scMVP.dataset import GeneExpressionDataset 15 | from scMVP.models import multi_vae_attention 16 | # from sklearn.utils.linear_assignment_ import linear_assignment 17 | from scipy.optimize import linear_sum_assignment as linear_assignment 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | 23 | 24 | 25 | class MultiPosterior(Posterior): 26 | r"""The functional data unit for Multivae. 
A `MultiPosterior` instance is instantiated with a model and 27 | a gene_dataset, and as well as additional arguments that for Pytorch's `DataLoader`. A subset of indices 28 | can be specified, for purposes such as splitting the data into train/test/validation. Each trainer instance of the `Trainer` class can therefore have multiple 29 | `MultiPosterior` instances to train a model. A `MultiPosterior` instance also comes with many methods or 30 | utilities for its corresponding data. 31 | 32 | 33 | :param model: A model instance from class ``Multivae`` 34 | :param gene_dataset: A gene_dataset instance like ``ATACDataset()`` with attribute ``ATAC_expression`` 35 | :param shuffle: Specifies if a `RandomSampler` or a `SequentialSampler` should be used 36 | :param indices: Specifies how the data should be split with regards to train/test or labelled/unlabelled 37 | :param use_cuda: Default: ``True`` 38 | :param data_loader_kwarg: Keyword arguments to passed into the `DataLoader` 39 | 40 | Examples: 41 | 42 | Let us instantiate a `trainer`, with a gene_dataset and a model 43 | 44 | >>> gene_dataset = CbmcDataset() 45 | >>> totalvi = TOTALVI(gene_dataset.nb_genes, len(gene_dataset.protein_names), 46 | ... n_batch=gene_dataset.n_batches * False, n_labels=gene_dataset.n_labels, use_cuda=True) 47 | >>> trainer = TotalTrainer(vae, gene_dataset) 48 | >>> trainer.train(n_epochs=400) 49 | """ 50 | 51 | def __init__( 52 | self, 53 | model: multi_vae_attention, 54 | gene_dataset: GeneExpressionDataset, 55 | shuffle: bool = False, 56 | indices: Optional[np.ndarray] = None, 57 | use_cuda: bool = True, 58 | data_loader_kwargs=dict(), 59 | ): 60 | 61 | super().__init__( 62 | model, 63 | gene_dataset, 64 | shuffle=shuffle, 65 | indices=indices, 66 | use_cuda=use_cuda, 67 | data_loader_kwargs=data_loader_kwargs, 68 | ) 69 | # Add atac tensor as another tensor to be loaded 70 | self.data_loader_kwargs.update( 71 | { 72 | "collate_fn": gene_dataset.collate_fn_builder( 73 | {"atac_expression": np.float32}# debug cell index 74 | ) 75 | } 76 | ) 77 | 78 | self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs) 79 | 80 | def corrupted(self): 81 | return self.update( 82 | { 83 | "collate_fn": self.gene_dataset.collate_fn_builder( 84 | {"atac_expression": np.float32}, corrupted=True 85 | ) 86 | } 87 | ) 88 | 89 | def uncorrupted(self): 90 | return self.update( 91 | { 92 | "collate_fn": self.gene_dataset.collate_fn_builder( 93 | {"atac_expression": np.float32} 94 | ) 95 | } 96 | ) 97 | 98 | @torch.no_grad() 99 | def elbo(self): 100 | elbo = self.compute_elbo(self.model) 101 | logger.debug("ELBO : %.4f" % elbo) 102 | return elbo 103 | elbo.mode = "min" 104 | 105 | @torch.no_grad() 106 | def reconstruction_error(self): 107 | reconstruction_error = self.compute_reconstruction_error(self.model, self) 108 | logger.debug("Reconstruction Error : %.4f" % reconstruction_error) 109 | return reconstruction_error 110 | 111 | reconstruction_error.mode = "min" 112 | 113 | @torch.no_grad() 114 | def marginal_ll(self, n_mc_samples=1000): 115 | 116 | ll = self.compute_marginal_log_likelihood(self.model, self, n_mc_samples) 117 | logger.debug("True LL : %.4f" % ll) 118 | return ll 119 | 120 | def compute_elbo(self, vae:multi_vae_attention, **kwargs): 121 | """ Computes the ELBO. 122 | 123 | The ELBO is the reconstruction error + the KL divergences 124 | between the variational distributions and the priors. 125 | It differs from the marginal log likelihood. 
126 | Specifically, it is a lower bound on the marginal log likelihood 127 | plus a term that is constant with respect to the variational distribution. 128 | It still gives good insights on the modeling of the data, and is fast to compute. 129 | """ 130 | # Iterate once over the posterior and compute the elbo 131 | elbo = 0 132 | for i_batch, tensors in enumerate(self): 133 | ( 134 | sample_batch_X, 135 | local_l_mean, 136 | local_l_var, 137 | batch_index, 138 | label, 139 | sample_batch_Y, 140 | ) = tensors 141 | 142 | reconst_loss, kl_divergence_local, kl_divergence_global = vae( 143 | sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, label 144 | ) 145 | elbo += torch.sum(reconst_loss + kl_divergence_local).item() 146 | n_samples = len(self.indices) 147 | elbo += kl_divergence_global 148 | return elbo / n_samples 149 | 150 | def compute_reconstruction_error(self, vae:multi_vae_attention, **kwargs): 151 | r""" Computes log p(x/z), which is the reconstruction error . 152 | Differs from the marginal log likelihood, but still gives good 153 | insights on the modeling of the data, and is fast to compute 154 | 155 | This is really a helper function to self.ll, self.ll_protein, etc. 156 | """ 157 | # Iterate once over the posterior and computes the total log_likelihood 158 | log_lkl = 0 159 | for i_batch, tensors in enumerate(self): 160 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[ 161 | :5 162 | ] # general fish case 163 | 164 | # Distribution parameters 165 | outputs = vae.inference(sample_batch, batch_index, labels, **kwargs) 166 | p_rna_r = outputs["p_rna_r"] 167 | p_rna_rate = outputs["p_rna_rate"] 168 | p_rna_dropout = outputs["p_rna_dropout"] 169 | p_atac_mean = outputs["p_atac_mean"] 170 | p_atac_r = outputs["p_atac_r"] 171 | p_atac_dropout = outputs["p_atac_dropout"] 172 | 173 | # Reconstruction loss 174 | reconst_rna_loss = vae.get_reconstruction_loss( 175 | sample_batch, 176 | p_rna_rate, 177 | p_rna_r, 178 | p_rna_dropout, 179 | # bernoulli_params=bernoulli_params, 180 | **kwargs 181 | ) 182 | reconst_atac_loss = vae.get_reconstruction_atac_loss( 183 | sample_batch, 184 | p_atac_mean, 185 | p_atac_r, 186 | p_atac_dropout, 187 | **kwargs 188 | ) 189 | 190 | log_lkl += torch.sum(reconst_rna_loss).item() 191 | log_lkl += torch.sum(reconst_atac_loss).item() 192 | n_samples = len(self.indices) 193 | return log_lkl / n_samples 194 | 195 | def compute_marginal_log_likelihood(self, vae:multi_vae_attention , n_mc_samples): 196 | """ Computes a biased estimator for log p(x), which is the marginal log likelihood. 
197 | 198 | Despite its bias, the estimator still converges to the real value 199 | of log p(x) when n_samples_mc (for Monte Carlo) goes to infinity 200 | (a fairly high value like 100 should be enough) 201 | Due to the Monte Carlo sampling, this method is not as computationally efficient 202 | as computing only the reconstruction loss 203 | """ 204 | # Uses MC sampling to compute a tighter lower bound on log p(x) 205 | 206 | log_lkl = 0 207 | for i_batch, tensors in enumerate(self): 208 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors 209 | to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) 210 | 211 | for i in range(n_mc_samples): 212 | # Distribution parameters and sampled variables 213 | outputs = vae.inference(sample_batch, batch_index, labels) 214 | p_rna_r = outputs["p_rna_r"] 215 | p_rna_rate = outputs["p_rna_rate"] 216 | p_rna_dropout = outputs["p_rna_dropout"] 217 | qz_m = outputs["qz_m"] 218 | qz_v = outputs["qz_v"] 219 | z = outputs["z"] 220 | p_atac_mean = outputs["p_atac_mean"] 221 | p_atac_r = outputs["p_atac_r"] 222 | p_atac_dropout = outputs["p_atac_dropout"] 223 | mu_c = outputs["mu_c"] 224 | var_c = outputs["var_c"] 225 | gamma = outputs["gamma"] 226 | mu_c_max = outputs["mu_c_max"], 227 | var_c_max = outputs["var_c_max"], 228 | z_c_max = outputs["z_c_max"], 229 | 230 | # Reconstruction Loss 231 | reconst_rna_loss = vae.get_reconstruction_loss( 232 | sample_batch, 233 | p_rna_r, 234 | p_rna_rate, 235 | p_rna_dropout, 236 | ) 237 | reconst_atac_loss = vae.get_reconstruction_atac_loss( 238 | sample_batch, 239 | p_atac_r, 240 | p_atac_mean, 241 | p_atac_dropout, 242 | ) 243 | 244 | # Log-probabilities 245 | #p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1) 246 | p_z = 0.0 247 | for prob, mu, var in mu_c, var_c, gamma: 248 | p_z += prob*Normal(mu, var.sqrt()).log_prob(z).sum(dim=-1) 249 | 250 | p_x_zl = -reconst_rna_loss - reconst_atac_loss 251 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) 252 | #q_z_max = Normal(mu_c_max, var_c_max.sqrt()).log_prob(z_c_max).sum(dim=-1) 253 | 254 | to_sum[:, i] = p_z + p_x_zl - q_z_x #- q_z_max 255 | 256 | batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) 257 | log_lkl += torch.sum(batch_log_lkl).item() 258 | 259 | n_samples = len(self.indices) 260 | # The minus sign is there because we actually look at the negative log likelihood 261 | return -log_lkl / n_samples 262 | 263 | @torch.no_grad() 264 | def get_latent(self, sample=False): 265 | """ 266 | Output posterior z mean or sample, batch index, and label 267 | :param sample: z mean or z sample 268 | :return: three np.ndarrays, latent, batch_indices, labels 269 | """ 270 | latent = [] 271 | latent_rna = []; 272 | latent_atac = []; 273 | batch_indices = [] 274 | labels = [] 275 | cluster_gamma = [] 276 | cluster_index = [] 277 | for tensors in self: 278 | sample_batch_rna, local_l_mean, local_l_var, batch_index, label, sample_batch_atac = tensors 279 | give_mean = not sample 280 | latent_temp = self.model.sample_from_posterior_z( 281 | [sample_batch_rna, sample_batch_atac], y=label, give_mean=give_mean 282 | ) 283 | latent += [ 284 | latent_temp[0][0].cpu() 285 | ] 286 | latent_rna += [ 287 | latent_temp[1][0].cpu() 288 | ] 289 | latent_atac += [ 290 | latent_temp[2].cpu() 291 | ] 292 | gamma, mu_c, var_c, pi = self.model.get_gamma(latent_temp[0][0]) 293 | cluster_gamma += [gamma.cpu()] 294 | cluster_index += [torch.argmax(gamma.cpu(),dim=1)] 295 | batch_indices += [batch_index.cpu()] 296 | labels += 
[label.cpu()]
297 |         return (
298 |             np.array(torch.cat(latent)),
299 |             np.array(torch.cat(latent_rna)),
300 |             np.array(torch.cat(latent_atac)),
301 |             np.array(torch.cat(cluster_gamma)),
302 |             np.array(torch.cat(cluster_index)),
303 |             np.array(torch.cat(batch_indices)),
304 |             np.array(torch.cat(labels)).ravel(),
305 |         )
306 | 
307 |     @torch.no_grad()
308 |     def generate(
309 |         self,
310 |         n_samples: int = 100,
311 |         genes: Optional[np.ndarray] = None,
312 |         batch_size: int = 256,
313 |     ):
314 |         """
315 |         Create observation samples from the posterior predictive distribution
316 | 
317 |         :param n_samples: Number of required samples for each cell
318 |         :param genes: Indices of genes of interest
319 |         :param batch_size: Desired batch size to generate data
320 | 
321 |         :return: Tuple (rna_new, rna_old, atac_new, atac_old)
322 |             Where rna_old and atac_old have shape (n_cells, n_features)
323 |             Where rna_new and atac_new have shape (n_cells, n_features, n_samples)
324 |         """
325 |         assert self.model.reconstruction_loss in ["zinb", "zip"]
326 |         zero_inflated = self.model.reconstruction_loss == "zinb"
327 | 
328 |         rna_old = []
329 |         rna_new = []
330 |         atac_old = []
331 |         atac_new = []
332 |         for tensors in self.update({"batch_size": batch_size}):
333 |             x_rna, local_l_mean, local_l_var, batch_index, labels, x_atac = tensors
334 |             outputs = self.model.inference(
335 |                 [x_rna, x_atac], batch_index=batch_index, y=labels, n_samples=n_samples,
336 |                 local_l_mean=local_l_mean, local_l_var=local_l_var
337 |             )
338 |             p_rna_r = outputs["p_rna_r"]
339 |             p_rna_rate = outputs["p_rna_rate"]
340 |             p_rna_dropout = outputs["p_rna_dropout"]
341 |             p_atac_mean = outputs["p_atac_mean"]
342 |             p_atac_dropout = outputs["p_atac_dropout"]
343 | 
344 |             # Generating rna-seq data
345 |             p = p_rna_rate / (p_rna_rate + p_rna_r)
346 |             r = p_rna_r
347 |             # Important remark: Gamma is parametrized by the rate = 1/scale!
348 |             l_train_rna = distributions.Gamma(concentration=r, rate=(1 - p) / p).sample()
349 |             # Clamping as distributions objects can have buggy behaviors when
350 |             # their parameters are too high
351 |             l_train_rna = torch.clamp(l_train_rna, max=1e8)
352 |             gene_expressions = distributions.Poisson(
353 |                 l_train_rna
354 |             ).sample()  # Shape : (n_samples, n_cells_batch, n_genes)
355 | 
356 |             # Generating atac-seq data
357 |             l_train_atac = torch.clamp(p_atac_mean, max=1e2)
358 |             atac_expressions = distributions.Poisson(l_train_atac).sample()
359 | 
360 |             # zero-inflate
361 |             if zero_inflated:
362 |                 p_zero_rna = (1.0 + torch.exp(-p_rna_dropout)).pow(-1)
363 |                 random_prob_rna = torch.rand_like(p_zero_rna)
364 |                 gene_expressions[random_prob_rna <= p_zero_rna] = 0
365 | 
366 |                 p_zero_atac = (1.0 + torch.exp(-p_atac_dropout)).pow(-1)
367 |                 random_prob_atac = torch.rand_like(p_zero_atac)
368 |                 atac_expressions[random_prob_atac <= p_zero_atac] = 0
369 | 
370 |             gene_expressions = gene_expressions.permute(
371 |                 [1, 2, 0]
372 |             )  # Shape : (n_cells_batch, n_genes, n_samples)
373 |             atac_expressions = atac_expressions.permute(
374 |                 [1, 2, 0]
375 |             )
376 | 
377 |             rna_old.append(x_rna.cpu())
378 |             rna_new.append(gene_expressions.cpu())
379 |             atac_old.append(x_atac.cpu())
380 |             atac_new.append(atac_expressions.cpu())
381 | 
382 |         rna_old = torch.cat(rna_old)  # Shape (n_cells, n_genes)
383 |         rna_new = torch.cat(rna_new)  # Shape (n_cells, n_genes, n_samples)
384 |         atac_old = torch.cat(atac_old)
385 |         atac_new = torch.cat(atac_new)
386 |         if genes is not None:
387 |             gene_ids = self.gene_dataset.genes_to_index(genes)
388 |             rna_new = rna_new[:, gene_ids, :]
389 |             rna_old = rna_old[:, gene_ids]
390 |         return rna_new.numpy(), rna_old.numpy(), atac_new.numpy(), atac_old.numpy()
391 | 
392 |     @torch.no_grad()
393 |     def imputation(self, n_samples: int = 1):
394 | """ Gene imputation 395 | """ 396 | imputed_rna_list = [] 397 | imputed_atac_list = [] 398 | label_list = [] # for the annotated data 399 | atac_list = [] 400 | for tensors in self: 401 | x_rna, local_l_mean, local_l_var, batch_index, label, x_atac = tensors 402 | p_rna_rate, p_atac_rate = self.model.get_sample_rate( 403 | x=[x_rna,x_atac], batch_index=batch_index, y=label, n_samples=n_samples, local_l_mean = local_l_mean, local_l_var = local_l_var 404 | ) 405 | imputed_rna_list += [np.array(p_rna_rate.cpu())] 406 | imputed_atac_list += [np.array(p_atac_rate.cpu())] 407 | label_list += [np.array(label.cpu())] # only for annotated data 408 | atac_list += [np.array(x_atac.cpu())] # for the bins without call peak 409 | imputed_rna_list = np.concatenate(imputed_rna_list) 410 | imputed_atac_list = np.concatenate(imputed_atac_list) 411 | label_list = np.concatenate(label_list) # only for annotated data 412 | atac_list = np.concatenate(atac_list)# for the bins without call peak 413 | return imputed_rna_list.squeeze(), imputed_atac_list.squeeze(), label_list.squeeze(), atac_list 414 | 415 | @torch.no_grad() 416 | def get_sample_scale(self): 417 | p_rna_scales = [] 418 | p_atac_scales = [] 419 | for tensors in self: 420 | x_rna, _, _, batch_index, labels, x_atac = tensors 421 | p_rna_scales += [ 422 | np.array( 423 | ( 424 | self.model.get_sample_scale( 425 | x=[x_rna,x_atac], batch_index=batch_index, y=labels, n_samples=1 426 | )[0] 427 | ) 428 | ) 429 | ] 430 | p_atac_scales += [ 431 | np.array( 432 | ( 433 | self.model.get_sample_scale( 434 | x=[x_rna,x_atac], batch_index=batch_index, y=labels, n_samples=1 435 | )[1] 436 | ) 437 | ) 438 | ] 439 | return np.concatenate(p_rna_scales), np.concatenate(p_atac_scales) 440 | 441 | def cluster_acc(Y_pred, Y): 442 | assert Y_pred.size == Y.size 443 | D = max(Y_pred.max(), Y.max()) + 1 444 | w = np.zeros((D, D), dtype=np.int64) 445 | for i in range(Y_pred.size): 446 | w[Y_pred[i], Y[i]] += 1 447 | ind = linear_assignment(w.max() - w) 448 | return sum([w[i, j] for i, j in ind]) * 1.0 / Y_pred.size, ind 449 | 450 | def get_clustering(self): 451 | latent, latent_rna, latent_atac, cluster_gamma, batch_indices, labels = self.get_latent() 452 | cluster_accuarcy, ind = self.cluster_acc(np.argmax(cluster_gamma,axis=1),labels) 453 | print('cell dataset multi-vae - clustering accuracy: %.2f%%' % (cluster_accuarcy * 100)) 454 | return cluster_accuarcy, ind 455 | 456 | class MultiTrainer(UnsupervisedTrainer): 457 | r"""The VariationalInference class for the unsupervised training of an autoencoder. 458 | 459 | Args: 460 | :model: A model instance from class ``TOTALVI`` 461 | :gene_dataset: A gene_dataset instance like ``CbmcDataset()`` with attribute ``protein_expression`` 462 | :train_size: The train size, either a float between 0 and 1 or and integer for the number of training samples 463 | to use Default: ``0.93``. 464 | :test_size: The test size, either a float between 0 and 1 or and integer for the number of training samples 465 | to use Default: ``0.02``. Note that if train and test do not add to 1 the remainder is placed in a validation set 466 | :\*\*kwargs: Other keywords arguments from the general Trainer class. 
467 | """ 468 | default_metrics_to_monitor = ["elbo"] 469 | 470 | def __init__( 471 | self, 472 | model, 473 | dataset, 474 | train_size=0.90, 475 | test_size=0.05, 476 | pro_recons_weight=1.0, 477 | n_epochs_back_kl_warmup=50, #200, init 478 | n_epochs_kl_warmup=200, 479 | **kwargs 480 | ): 481 | self.n_genes = dataset.nb_genes 482 | self.n_proteins = model.n_input_atac 483 | 484 | self.pro_recons_weight = pro_recons_weight 485 | self.n_epochs_back_kl_warmup = n_epochs_back_kl_warmup 486 | super().__init__( 487 | model, dataset, n_epochs_kl_warmup=n_epochs_kl_warmup, **kwargs 488 | ) 489 | if type(self) is MultiTrainer: 490 | ( 491 | self.train_set, 492 | self.test_set, 493 | self.validation_set, 494 | ) = self.train_test_validation( 495 | model, dataset, train_size, test_size, type_class=MultiPosterior 496 | ) 497 | self.train_set.to_monitor = [] 498 | self.test_set.to_monitor = ["elbo"] 499 | self.validation_set.to_monitor = ["elbo"] 500 | 501 | def loss(self, tensors): 502 | ( 503 | sample_batch_X, 504 | local_l_mean, 505 | local_l_var, 506 | batch_index, 507 | label, 508 | sample_batch_Y, 509 | ) = tensors 510 | 511 | #reconst_loss, kl_divergence_local, kl_divergence_global = self.model( 512 | # sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, label 513 | #) 514 | reconst_loss, kl_divergence_local, kl_divergence_global = self.model( 515 | sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, batch_index 516 | ) 517 | loss = ( 518 | self.n_samples 519 | * torch.mean(reconst_loss + self.back_warmup_weight * kl_divergence_local) 520 | + kl_divergence_global 521 | ) 522 | print( 523 | "reconst_loss = %f,kl_divergence_local = %f,kl_weight = %f,loss = %f" % 524 | (torch.mean(reconst_loss), torch.mean(kl_divergence_local), self.back_warmup_weight, loss) 525 | ) 526 | # self.KL_divergence = kl_divergence_global 527 | if self.normalize_loss: 528 | loss = loss / self.n_samples 529 | return loss 530 | 531 | 532 | def on_epoch_begin(self): 533 | super().on_epoch_begin() 534 | if self.n_epochs_back_kl_warmup is not None: 535 | #self.back_warmup_weight = min(1, self.epoch + self.n_epochs_back_kl_warmup / self.n_epochs_back_kl_warmup) 536 | self.back_warmup_weight = min(1, self.epoch + self.n_epochs_back_kl_warmup / self.n_epochs_back_kl_warmup) 537 | else: 538 | self.back_warmup_weight = 1.0 539 | 540 | -------------------------------------------------------------------------------- /scMVP/inference/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import time 4 | 5 | from abc import abstractmethod 6 | from collections import defaultdict, OrderedDict 7 | from itertools import cycle 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from sklearn.model_selection._split import _validate_shuffle_split 13 | from torch.utils.data.sampler import SubsetRandomSampler 14 | from tqdm import trange 15 | 16 | from scMVP.inference.posterior import Posterior 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Trainer: 22 | r"""The abstract Trainer class for training a PyTorch model and monitoring its statistics. It should be 23 | inherited at least with a .loss() function to be optimized in the training loop. 24 | 25 | Args: 26 | :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI`` 27 | :gene_dataset: A gene_dataset instance like ``CortexDataset()`` 28 | :use_cuda: Default: ``True``. 29 | :metrics_to_monitor: A list of the metrics to monitor. 
If not specified, will use the 30 | ``default_metrics_to_monitor`` as specified in each . Default: ``None``. 31 | :benchmark: if True, prevents statistics computation in the training. Default: ``False``. 32 | :frequency: The frequency at which to keep track of statistics. Default: ``None``. 33 | :early_stopping_metric: The statistics on which to perform early stopping. Default: ``None``. 34 | :save_best_state_metric: The statistics on which we keep the network weights achieving the best store, and 35 | restore them at the end of training. Default: ``None``. 36 | :on: The data_loader name reference for the ``early_stopping_metric`` and ``save_best_state_metric``, that 37 | should be specified if any of them is. Default: ``None``. 38 | :show_progbar: If False, disables progress bar. 39 | :seed: Random seed for train/test/validate split 40 | """ 41 | default_metrics_to_monitor = [] 42 | 43 | def __init__( 44 | self, 45 | model, 46 | gene_dataset, 47 | use_cuda=True, 48 | metrics_to_monitor=None, 49 | benchmark=False, 50 | frequency=None, 51 | weight_decay=1e-6, 52 | early_stopping_kwargs=None, 53 | data_loader_kwargs=None, 54 | show_progbar=True, 55 | seed=0, 56 | ): 57 | # handle mutable defaults 58 | early_stopping_kwargs = ( 59 | early_stopping_kwargs if early_stopping_kwargs else dict() 60 | ) 61 | data_loader_kwargs = data_loader_kwargs if data_loader_kwargs else dict() 62 | 63 | self.model = model 64 | self.gene_dataset = gene_dataset 65 | self._posteriors = OrderedDict() 66 | self.seed = seed 67 | 68 | #self.data_loader_kwargs = {"batch_size": 256, "pin_memory": use_cuda} # 128 for batchsize in init 69 | self.data_loader_kwargs = {"batch_size": 64, "pin_memory": use_cuda} # 128 for batchsize in init 70 | self.data_loader_kwargs.update(data_loader_kwargs) 71 | 72 | self.weight_decay = weight_decay 73 | self.benchmark = benchmark 74 | self.epoch = -1 # epoch = self.epoch + 1 in compute metrics 75 | self.training_time = 0 76 | # self.KL_divergence = -1 77 | # self.KL_divergence_max = 10000 78 | 79 | if metrics_to_monitor is not None: 80 | self.metrics_to_monitor = set(metrics_to_monitor) 81 | else: 82 | self.metrics_to_monitor = set(self.default_metrics_to_monitor) 83 | 84 | self.early_stopping = EarlyStopping(**early_stopping_kwargs) 85 | 86 | if self.early_stopping.early_stopping_metric: 87 | self.metrics_to_monitor.add(self.early_stopping.early_stopping_metric) 88 | 89 | self.use_cuda = use_cuda and torch.cuda.is_available() 90 | if self.use_cuda: 91 | self.model.cuda() 92 | 93 | self.frequency = frequency if not benchmark else None 94 | 95 | self.history = defaultdict(list) 96 | 97 | self.best_state_dict = self.model.state_dict() 98 | self.best_epoch = self.epoch 99 | 100 | self.show_progbar = show_progbar 101 | 102 | @torch.no_grad() 103 | def compute_metrics(self): 104 | begin = time.time() 105 | epoch = self.epoch + 1 106 | if self.frequency and ( 107 | epoch == 0 or epoch == self.n_epochs or (epoch % self.frequency == 0) 108 | ): 109 | with torch.set_grad_enabled(False): 110 | self.model.eval() 111 | logger.debug("\nEPOCH [%d/%d]: " % (epoch, self.n_epochs)) 112 | 113 | for name, posterior in self._posteriors.items(): 114 | message = " ".join([s.capitalize() for s in name.split("_")[-2:]]) 115 | if posterior.nb_cells < 5: 116 | logging.debug( 117 | message + " is too small to track metrics (<5 samples)" 118 | ) 119 | continue 120 | if hasattr(posterior, "to_monitor"): 121 | for metric in posterior.to_monitor: 122 | if metric not in self.metrics_to_monitor: 123 | 
logger.debug(message) 124 | result = getattr(posterior, metric)() 125 | self.history[metric + "_" + name] += [result] 126 | for metric in self.metrics_to_monitor: 127 | result = getattr(posterior, metric)() 128 | self.history[metric + "_" + name] += [result] 129 | self.model.train() 130 | self.compute_metrics_time += time.time() - begin 131 | 132 | def train(self, n_epochs=20, lr=1e-3, eps=0.01, params=None): 133 | begin = time.time() 134 | self.model.train() 135 | 136 | if params is None: 137 | params = filter(lambda p: p.requires_grad, self.model.parameters()) 138 | 139 | optimizer = self.optimizer = torch.optim.Adam( 140 | params, lr=lr, eps=eps, weight_decay=self.weight_decay 141 | ) 142 | aa = self.model.parameters() 143 | 144 | self.compute_metrics_time = 0 145 | self.n_epochs = n_epochs 146 | flag = True 147 | 148 | with trange( 149 | n_epochs, desc="training", file=sys.stdout, disable=not self.show_progbar 150 | ) as pbar: 151 | # We have to use tqdm this way so it works in Jupyter notebook. 152 | # See https://stackoverflow.com/questions/42212810/tqdm-in-jupyter-notebook 153 | for self.epoch in pbar: 154 | self.on_epoch_begin() 155 | pbar.update(1) 156 | for tensors_list in self.data_loaders_loop(): 157 | if tensors_list[0][0].shape[0] < 3: 158 | continue 159 | loss = self.loss(*tensors_list) 160 | print(loss) 161 | # if self.KL_divergence > self.KL_divergence_max: 162 | # break 163 | #if self.epoch == 15 and flag: 164 | # flag = False 165 | # optimizer.add_param_group({'params': self.model.get_params()}) 166 | optimizer.zero_grad() 167 | loss.backward() 168 | optimizer.step() 169 | 170 | if not self.on_epoch_end(): 171 | break 172 | 173 | if self.early_stopping.save_best_state_metric is not None: 174 | self.model.load_state_dict(self.best_state_dict) 175 | self.compute_metrics() 176 | 177 | self.model.eval() 178 | self.training_time += (time.time() - begin) - self.compute_metrics_time 179 | if self.frequency: 180 | logger.debug( 181 | "\nTraining time: %i s. 
/ %i epochs" 182 | % (int(self.training_time), self.n_epochs) 183 | ) 184 | self.compute_metrics() 185 | 186 | 187 | def on_epoch_begin(self): 188 | pass 189 | 190 | def on_epoch_end(self): 191 | self.compute_metrics() 192 | on = self.early_stopping.on 193 | early_stopping_metric = self.early_stopping.early_stopping_metric 194 | save_best_state_metric = self.early_stopping.save_best_state_metric 195 | if save_best_state_metric is not None and on is not None: 196 | if self.early_stopping.update_state( 197 | self.history[save_best_state_metric + "_" + on][-1] 198 | ): 199 | self.best_state_dict = self.model.state_dict() 200 | self.best_epoch = self.epoch 201 | 202 | continue_training = True 203 | if early_stopping_metric is not None and on is not None: 204 | continue_training, reduce_lr = self.early_stopping.update( 205 | self.history[early_stopping_metric + "_" + on][-1] 206 | ) 207 | if reduce_lr: 208 | logger.info("Reducing LR.") 209 | for param_group in self.optimizer.param_groups: 210 | param_group["lr"] *= self.early_stopping.lr_factor 211 | 212 | # if self.KL_divergence > self.KL_divergence_max: 213 | # continue_training = False 214 | return continue_training 215 | 216 | @property 217 | @abstractmethod 218 | def posteriors_loop(self): 219 | pass 220 | 221 | def data_loaders_loop(self): 222 | """returns an zipped iterable corresponding to loss signature""" 223 | data_loaders_loop = [self._posteriors[name] for name in self.posteriors_loop] 224 | return zip( 225 | data_loaders_loop[0], 226 | *[cycle(data_loader) for data_loader in data_loaders_loop[1:]] 227 | ) 228 | 229 | def register_posterior(self, name, value): 230 | name = name.strip("_") 231 | self._posteriors[name] = value 232 | 233 | def corrupt_posteriors( 234 | self, rate=0.1, corruption="uniform", update_corruption=True 235 | ): 236 | if not hasattr(self.gene_dataset, "corrupted") and update_corruption: 237 | self.gene_dataset.corrupt(rate=rate, corruption=corruption) 238 | for name, posterior in self._posteriors.items(): 239 | self.register_posterior(name, posterior.corrupted()) 240 | 241 | def uncorrupt_posteriors(self): 242 | for name_, posterior in self._posteriors.items(): 243 | self.register_posterior(name_, posterior.uncorrupted()) 244 | 245 | def __getattr__(self, name): 246 | if "_posteriors" in self.__dict__: 247 | _posteriors = self.__dict__["_posteriors"] 248 | if name.strip("_") in _posteriors: 249 | return _posteriors[name.strip("_")] 250 | return object.__getattribute__(self, name) 251 | 252 | def __delattr__(self, name): 253 | if name.strip("_") in self._posteriors: 254 | del self._posteriors[name.strip("_")] 255 | else: 256 | object.__delattr__(self, name) 257 | 258 | def __setattr__(self, name, value): 259 | if isinstance(value, Posterior): 260 | name = name.strip("_") 261 | self.register_posterior(name, value) 262 | else: 263 | object.__setattr__(self, name, value) 264 | 265 | def train_test_validation( 266 | self, 267 | model=None, 268 | gene_dataset=None, 269 | train_size=0.1, 270 | test_size=None, 271 | type_class=Posterior, 272 | ): 273 | """Creates posteriors ``train_set``, ``test_set``, ``validation_set``. 274 | If ``train_size + test_size < 1`` then ``validation_set`` is non-empty. 
275 | 276 | :param train_size: float, int, or None (default is 0.1) 277 | :param test_size: float, int, or None (default is None) 278 | """ 279 | model = self.model if model is None and hasattr(self, "model") else model 280 | gene_dataset = ( 281 | self.gene_dataset 282 | if gene_dataset is None and hasattr(self, "model") 283 | else gene_dataset 284 | ) 285 | n = len(gene_dataset) 286 | try: 287 | n_train, n_test = _validate_shuffle_split(n, test_size, train_size) 288 | except ValueError: 289 | if train_size != 1.0: 290 | raise ValueError( 291 | "Choice of train_size={} and test_size={} not understood".format( 292 | train_size, test_size 293 | ) 294 | ) 295 | n_train, n_test = n, 0 296 | random_state = np.random.RandomState(seed=self.seed) 297 | permutation = random_state.permutation(n) 298 | indices_test = permutation[:n_test] 299 | indices_train = permutation[n_test : (n_test + n_train)] 300 | indices_validation = permutation[(n_test + n_train) :] 301 | 302 | return ( 303 | self.create_posterior( 304 | model, gene_dataset, indices=indices_train, type_class=type_class 305 | ), 306 | self.create_posterior( 307 | model, gene_dataset, indices=indices_test, type_class=type_class 308 | ), 309 | self.create_posterior( 310 | model, gene_dataset, indices=indices_validation, type_class=type_class 311 | ), 312 | ) 313 | 314 | def create_posterior( 315 | self, 316 | model=None, 317 | gene_dataset=None, 318 | shuffle=False, 319 | indices=None, 320 | type_class=Posterior, 321 | ): 322 | model = self.model if model is None and hasattr(self, "model") else model 323 | gene_dataset = ( 324 | self.gene_dataset 325 | if gene_dataset is None and hasattr(self, "model") 326 | else gene_dataset 327 | ) 328 | return type_class( 329 | model, 330 | gene_dataset, 331 | shuffle=shuffle, 332 | indices=indices, 333 | use_cuda=self.use_cuda, 334 | data_loader_kwargs=self.data_loader_kwargs, 335 | ) 336 | 337 | 338 | class SequentialSubsetSampler(SubsetRandomSampler): 339 | def __init__(self, indices): 340 | self.indices = np.sort(indices) 341 | 342 | def __iter__(self): 343 | return iter(self.indices) 344 | 345 | 346 | class EarlyStopping: 347 | def __init__( 348 | self, 349 | early_stopping_metric: str = None, 350 | save_best_state_metric: str = None, 351 | on: str = "test_set", 352 | patience: int = 15, 353 | threshold: int = 3, 354 | benchmark: bool = False, 355 | reduce_lr_on_plateau: bool = False, 356 | lr_patience: int = 10, 357 | lr_factor: float = 0.5, 358 | posterior_class=Posterior, 359 | ): 360 | self.benchmark = benchmark 361 | self.patience = patience 362 | self.threshold = threshold 363 | self.epoch = 0 364 | self.wait = 0 365 | self.wait_lr = 0 366 | self.mode = ( 367 | getattr(posterior_class, early_stopping_metric).mode 368 | if early_stopping_metric is not None 369 | else None 370 | ) 371 | # We set the best to + inf because we're dealing with a loss we want to minimize 372 | self.current_performance = np.inf 373 | self.best_performance = np.inf 374 | self.best_performance_state = np.inf 375 | # If we want to maximize, we start at - inf 376 | if self.mode == "max": 377 | self.best_performance *= -1 378 | self.current_performance *= -1 379 | self.mode_save_state = ( 380 | getattr(Posterior, save_best_state_metric).mode 381 | if save_best_state_metric is not None 382 | else None 383 | ) 384 | if self.mode_save_state == "max": 385 | self.best_performance_state *= -1 386 | 387 | self.early_stopping_metric = early_stopping_metric 388 | self.save_best_state_metric = save_best_state_metric 389 | self.on = on 
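        # These attributes are normally filled in through the Trainer's
        # ``early_stopping_kwargs`` dict rather than by constructing EarlyStopping
        # directly, e.g. (an illustrative sketch, not a default of this repository):
        #   early_stopping_kwargs=dict(early_stopping_metric="elbo", on="test_set",
        #                              patience=15, threshold=3, reduce_lr_on_plateau=True)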
390 | self.reduce_lr_on_plateau = reduce_lr_on_plateau 391 | self.lr_patience = lr_patience 392 | self.lr_factor = lr_factor 393 | 394 | def update(self, scalar): 395 | self.epoch += 1 396 | if self.benchmark: 397 | continue_training = True 398 | reduce_lr = False 399 | elif self.wait >= self.patience: 400 | continue_training = False 401 | reduce_lr = False 402 | else: 403 | # Check if we should reduce the learning rate 404 | if not self.reduce_lr_on_plateau: 405 | reduce_lr = False 406 | elif self.wait_lr >= self.lr_patience: 407 | reduce_lr = True 408 | self.wait_lr = 0 409 | else: 410 | reduce_lr = False 411 | # Shift 412 | self.current_performance = scalar 413 | 414 | # Compute improvement 415 | if self.mode == "max": 416 | improvement = self.current_performance - self.best_performance 417 | elif self.mode == "min": 418 | improvement = self.best_performance - self.current_performance 419 | else: 420 | raise NotImplementedError("Unknown optimization mode") 421 | 422 | # updating best performance 423 | if improvement > 0: 424 | self.best_performance = self.current_performance 425 | 426 | if improvement < self.threshold: 427 | self.wait += 1 428 | self.wait_lr += 1 429 | else: 430 | self.wait = 0 431 | self.wait_lr = 0 432 | 433 | continue_training = True 434 | if not continue_training: 435 | # FIXME: log total number of epochs run 436 | logger.info( 437 | "\nStopping early: no improvement of more than " 438 | + str(self.threshold) 439 | + " nats in " 440 | + str(self.patience) 441 | + " epochs" 442 | ) 443 | logger.info( 444 | "If the early stopping criterion is too strong, " 445 | "please instantiate it with different parameters in the train method." 446 | ) 447 | return continue_training, reduce_lr 448 | 449 | def update_state(self, scalar): 450 | improved = ( 451 | self.mode_save_state == "max" and scalar - self.best_performance_state > 0 452 | ) or ( 453 | self.mode_save_state == "min" and self.best_performance_state - scalar > 0 454 | ) 455 | if improved: 456 | self.best_performance_state = scalar 457 | return improved 458 | -------------------------------------------------------------------------------- /scMVP/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import Classifier 2 | 3 | from .vae_attention import VAE_Attention 4 | from .vaePeak_selfattetion import VAE_Peak_SelfAttention 5 | from .multi_vae_attention import Multi_VAE_Attention 6 | 7 | 8 | __all__ = [ 9 | "Classifier" 10 | "VAE_Attention", 11 | "VAE_Peak_SelfAttention", 12 | "Multi_VAE_Attention", 13 | ] 14 | -------------------------------------------------------------------------------- /scMVP/models/classifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | 3 | from scMVP.models.modules import FCLayers 4 | 5 | 6 | class Classifier(nn.Module): 7 | def __init__( 8 | self, 9 | n_input, 10 | n_hidden=128, 11 | n_labels=5, 12 | n_layers=1, 13 | dropout_rate=0.1, 14 | logits=False, 15 | ): 16 | super().__init__() 17 | layers = [ 18 | FCLayers( 19 | n_in=n_input, 20 | n_out=n_hidden, 21 | n_layers=n_layers, 22 | n_hidden=n_hidden, 23 | dropout_rate=dropout_rate, 24 | use_batch_norm=True, 25 | ), 26 | nn.Linear(n_hidden, n_labels), 27 | ] 28 | if not logits: 29 | layers.append(nn.Softmax(dim=-1)) 30 | 31 | self.classifier = nn.Sequential(*layers) 32 | 33 | def forward(self, x): 34 | return self.classifier(x) 35 | 
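# Usage sketch (illustrative; the shapes and values below are assumptions, not part
# of this module): Classifier maps an embedding to per-class probabilities.
#
#   import torch
#   clf = Classifier(n_input=10, n_hidden=128, n_labels=5)
#   probs = clf(torch.randn(32, 10))   # shape (32, 5); rows sum to 1
#
# With ``logits=True`` the final Softmax is omitted, so the raw scores can be passed
# directly to ``torch.nn.CrossEntropyLoss``.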
-------------------------------------------------------------------------------- /scMVP/models/log_likelihood.py: -------------------------------------------------------------------------------- 1 | """File for computing log likelihood of the data""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import logsumexp 7 | from torch.distributions import Normal, Beta 8 | 9 | 10 | def compute_elbo(vae, posterior, **kwargs): 11 | """ Computes the ELBO. 12 | 13 | The ELBO is the reconstruction error + the KL divergences 14 | between the variational distributions and the priors. 15 | It differs from the marginal log likelihood. 16 | Specifically, it is a lower bound on the marginal log likelihood 17 | plus a term that is constant with respect to the variational distribution. 18 | It still gives good insights on the modeling of the data, and is fast to compute. 19 | """ 20 | # Iterate once over the posterior and compute the elbo 21 | elbo = 0 22 | for i_batch, tensors in enumerate(posterior): 23 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[ 24 | :5 25 | ] # general fish case 26 | # kl_divergence_global (scalar) should be common across all batches after training 27 | reconst_loss, kl_divergence, kl_divergence_global = vae( 28 | sample_batch, 29 | local_l_mean, 30 | local_l_var, 31 | batch_index=batch_index, 32 | y=labels, 33 | **kwargs 34 | ) 35 | elbo += torch.sum(reconst_loss + kl_divergence).item() 36 | n_samples = len(posterior.indices) 37 | elbo += kl_divergence_global 38 | return elbo / n_samples 39 | 40 | 41 | def compute_reconstruction_error(vae, posterior, **kwargs): 42 | """ Computes log p(x/z), which is the reconstruction error. 43 | 44 | Differs from the marginal log likelihood, but still gives good 45 | insights on the modeling of the data, and is fast to compute. 46 | """ 47 | # Iterate once over the posterior and computes the reconstruction error 48 | log_lkl = 0 49 | for i_batch, tensors in enumerate(posterior): 50 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[ 51 | :5 52 | ] # general fish case 53 | 54 | # Distribution parameters 55 | outputs = vae.inference(sample_batch, batch_index, labels, **kwargs) 56 | px_r = outputs["px_r"] 57 | px_rate = outputs["px_rate"] 58 | px_dropout = outputs["px_dropout"] 59 | bernoulli_params = outputs.get("bernoulli_params", None) 60 | 61 | # Reconstruction loss 62 | reconst_loss = vae.get_reconstruction_loss( 63 | sample_batch, 64 | px_rate, 65 | px_r, 66 | px_dropout, 67 | bernoulli_params=bernoulli_params, 68 | **kwargs 69 | ) 70 | 71 | log_lkl += torch.sum(reconst_loss).item() 72 | n_samples = len(posterior.indices) 73 | return log_lkl / n_samples 74 | 75 | 76 | def compute_marginal_log_likelihood_scvi(vae, posterior, n_samples_mc=100): 77 | """ Computes a biased estimator for log p(x), which is the marginal log likelihood. 
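    The importance-sampling estimate computed below is (in sketch form)
        log p(x) ~ logsumexp_i[ log p(x | z_i, l_i) + log p(z_i) + log p(l_i)
                                - log q(z_i | x) - log q(l_i | x) ] - log(n_samples_mc),
    with (z_i, l_i) drawn from the variational posterior.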
78 | 79 | Despite its bias, the estimator still converges to the real value 80 | of log p(x) when n_samples_mc (for Monte Carlo) goes to infinity 81 | (a fairly high value like 100 should be enough) 82 | Due to the Monte Carlo sampling, this method is not as computationally efficient 83 | as computing only the reconstruction loss 84 | """ 85 | # Uses MC sampling to compute a tighter lower bound on log p(x) 86 | log_lkl = 0 87 | for i_batch, tensors in enumerate(posterior): 88 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors 89 | to_sum = torch.zeros(sample_batch.size()[0], n_samples_mc) 90 | 91 | for i in range(n_samples_mc): 92 | 93 | # Distribution parameters and sampled variables 94 | outputs = vae.inference(sample_batch, batch_index, labels) 95 | px_r = outputs["px_r"] 96 | px_rate = outputs["px_rate"] 97 | px_dropout = outputs["px_dropout"] 98 | qz_m = outputs["qz_m"] 99 | qz_v = outputs["qz_v"] 100 | z = outputs["z"] 101 | ql_m = outputs["ql_m"] 102 | ql_v = outputs["ql_v"] 103 | library = outputs["library"] 104 | 105 | # Reconstruction Loss 106 | reconst_loss = vae.get_reconstruction_loss( 107 | sample_batch, px_rate, px_r, px_dropout 108 | ) 109 | 110 | # Log-probabilities 111 | p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1) 112 | p_z = ( 113 | Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) 114 | .log_prob(z) 115 | .sum(dim=-1) 116 | ) 117 | p_x_zl = -reconst_loss 118 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) 119 | q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) 120 | 121 | to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x 122 | 123 | batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_samples_mc) 124 | log_lkl += torch.sum(batch_log_lkl).item() 125 | 126 | n_samples = len(posterior.indices) 127 | # The minus sign is there because we actually look at the negative log likelihood 128 | return -log_lkl / n_samples 129 | 130 | 131 | def compute_marginal_log_likelihood_autozi(autozivae, posterior, n_samples_mc=100): 132 | """ Computes a biased estimator for log p(x), which is the marginal log likelihood. 
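    The estimator mirrors ``compute_marginal_log_likelihood_scvi`` above, with an extra
    outer Monte Carlo loop over the Beta-distributed zero-inflation parameters
    (``bernoulli_params``); their prior and posterior log-probabilities (p_d and q_d
    below) are added to each sample's importance weight.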
133 | 134 | Despite its bias, the estimator still converges to the real value 135 | of log p(x) when n_samples_mc (for Monte Carlo) goes to infinity 136 | (a fairly high value like 100 should be enough) 137 | Due to the Monte Carlo sampling, this method is not as computationally efficient 138 | as computing only the reconstruction loss 139 | """ 140 | # Uses MC sampling to compute a tighter lower bound on log p(x) 141 | log_lkl = 0 142 | to_sum = torch.zeros((n_samples_mc,)) 143 | alphas_betas = autozivae.get_alphas_betas(as_numpy=False) 144 | alpha_prior = alphas_betas["alpha_prior"] 145 | alpha_posterior = alphas_betas["alpha_posterior"] 146 | beta_prior = alphas_betas["beta_prior"] 147 | beta_posterior = alphas_betas["beta_posterior"] 148 | 149 | for i in range(n_samples_mc): 150 | 151 | bernoulli_params = autozivae.sample_from_beta_distribution( 152 | alpha_posterior, beta_posterior 153 | ) 154 | 155 | for i_batch, tensors in enumerate(posterior): 156 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors 157 | 158 | # Distribution parameters and sampled variables 159 | outputs = autozivae.inference(sample_batch, batch_index, labels) 160 | px_r = outputs["px_r"] 161 | px_rate = outputs["px_rate"] 162 | px_dropout = outputs["px_dropout"] 163 | qz_m = outputs["qz_m"] 164 | qz_v = outputs["qz_v"] 165 | z = outputs["z"] 166 | ql_m = outputs["ql_m"] 167 | ql_v = outputs["ql_v"] 168 | library = outputs["library"] 169 | 170 | # Reconstruction Loss 171 | bernoulli_params_batch = autozivae.reshape_bernoulli( 172 | bernoulli_params, batch_index, labels 173 | ) 174 | reconst_loss = autozivae.get_reconstruction_loss( 175 | sample_batch, px_rate, px_r, px_dropout, bernoulli_params_batch 176 | ) 177 | 178 | # Log-probabilities 179 | p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1) 180 | p_z = ( 181 | Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) 182 | .log_prob(z) 183 | .sum(dim=-1) 184 | ) 185 | p_x_zld = -reconst_loss 186 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) 187 | q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) 188 | 189 | batch_log_lkl = torch.sum(p_x_zld + p_l + p_z - q_z_x - q_l_x, dim=0) 190 | to_sum[i] += batch_log_lkl 191 | 192 | p_d = Beta(alpha_prior, beta_prior).log_prob(bernoulli_params).sum() 193 | q_d = Beta(alpha_posterior, beta_posterior).log_prob(bernoulli_params).sum() 194 | 195 | to_sum[i] += p_d - q_d 196 | 197 | log_lkl = logsumexp(to_sum, dim=-1).item() - np.log(n_samples_mc) 198 | n_samples = len(posterior.indices) 199 | # The minus sign is there because we actually look at the negative log likelihood 200 | return -log_lkl / n_samples 201 | 202 | def binary_cross_entropy(x, recon_x, eps=1e-8): 203 | recon_x = torch.sigmoid(recon_x) 204 | res = (x * torch.log(recon_x + eps) + (1 - x) * torch.log(1 - recon_x + eps)) 205 | # print(torch.mean(recon_x)) 206 | return res 207 | 208 | def mean_square_error(x,recon_x): 209 | res = (x - recon_x)*(x - recon_x) 210 | return res 211 | def mean_square_error_positive(x,recon_x): 212 | #res = (x - recon_x + 1)*(x - recon_x + 1) 213 | #res[x==0] = 0 214 | res = torch.abs((x - recon_x)) 215 | res[x == 0] = 0 # test this property 216 | return res 217 | 218 | def log_zip_positive(x, mu, pi, eps=1e-8): 219 | # the likelihood of zero probability p(x=0) = -softplus(-pi)+softplus(-pi-mu) 220 | softplus_pi = F.softplus(-pi) 221 | softplus_mu_pi = F.softplus(-pi-mu) 222 | case_zero = - softplus_pi + softplus_mu_pi 223 | mul_case_zero = torch.mul((x < 
eps).type(torch.float32), case_zero) 224 | 225 | # the likelihood of p(x>0) = -softplus(-pi) - pi - mu +x*ln(mu) - ln(x!) 226 | case_non_zero = ( 227 | - softplus_pi 228 | - pi - mu 229 | + x * torch.log(mu + eps) 230 | - torch.lgamma(x + 1) 231 | ) 232 | mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero) 233 | 234 | res = mul_case_zero + mul_case_non_zero 235 | 236 | return res 237 | 238 | def log_zinb_positive(x, mu, theta, pi, eps=1e-8): 239 | """ 240 | Note: All inputs are torch Tensors 241 | log likelihood (scalar) of a minibatch according to a zinb model. 242 | Notes: 243 | We parametrize the bernoulli using the logits, hence the softplus functions appearing 244 | 245 | Variables: 246 | mu: mean of the negative binomial (has to be positive support) (shape: minibatch x genes) 247 | theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x genes) 248 | pi: logit of the dropout parameter (real support) (shape: minibatch x genes) 249 | eps: numerical stability constant 250 | """ 251 | 252 | # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels) 253 | if theta.ndimension() == 1: 254 | theta = theta.view( 255 | 1, theta.size(0) 256 | ) # In this case, we reshape theta for broadcasting 257 | 258 | softplus_pi = F.softplus(-pi) 259 | log_theta_eps = torch.log(theta + eps) 260 | log_theta_mu_eps = torch.log(theta + mu + eps) 261 | pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps) 262 | 263 | case_zero = F.softplus(pi_theta_log) - softplus_pi 264 | mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero) 265 | 266 | case_non_zero = ( 267 | -softplus_pi 268 | + pi_theta_log 269 | + x * (torch.log(mu + eps) - log_theta_mu_eps) 270 | + torch.lgamma(x + theta) 271 | - torch.lgamma(theta) 272 | - torch.lgamma(x + 1) 273 | ) 274 | mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero) 275 | 276 | res = mul_case_zero + mul_case_non_zero 277 | 278 | return res 279 | 280 | 281 | def log_nb_positive(x, mu, theta, eps=1e-8): 282 | """ 283 | Note: All inputs should be torch Tensors 284 | log likelihood (scalar) of a minibatch according to a nb model. 285 | 286 | Variables: 287 | mu: mean of the negative binomial (has to be positive support) (shape: minibatch x genes) 288 | theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x genes) 289 | eps: numerical stability constant 290 | """ 291 | if theta.ndimension() == 1: 292 | theta = theta.view( 293 | 1, theta.size(0) 294 | ) # In this case, we reshape theta for broadcasting 295 | 296 | log_theta_mu_eps = torch.log(theta + mu + eps) 297 | 298 | res = ( 299 | theta * (torch.log(theta + eps) - log_theta_mu_eps) 300 | + x * (torch.log(mu + eps) - log_theta_mu_eps) 301 | + torch.lgamma(x + theta) 302 | - torch.lgamma(theta) 303 | - torch.lgamma(x + 1) 304 | ) 305 | 306 | return res 307 | 308 | 309 | def log_mixture_nb(x, mu_1, mu_2, theta_1, theta_2, pi, eps=1e-8): 310 | """ 311 | Note: All inputs should be torch Tensors 312 | log likelihood (scalar) of a minibatch according to a mixture nb model. 313 | pi is the probability to be in the first component. 314 | 315 | For totalVI, the first component should be background. 
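    Concretely, writing w = sigmoid(pi) for the weight of the first component, the
    value returned below is log( w * NB(x; mu_1, theta_1) + (1 - w) * NB(x; mu_2, theta_2) ),
    evaluated in log space via logsumexp for numerical stability.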
316 | 317 | Variables: 318 | mu1: mean of the first negative binomial component (has to be positive support) (shape: minibatch x genes) 319 | theta1: first inverse dispersion parameter (has to be positive support) (shape: minibatch x genes) 320 | mu2: mean of the second negative binomial (has to be positive support) (shape: minibatch x genes) 321 | theta2: second inverse dispersion parameter (has to be positive support) (shape: minibatch x genes) 322 | If None, assume one shared inverse dispersion parameter. 323 | eps: numerical stability constant 324 | """ 325 | if theta_2 is not None: 326 | log_nb_1 = log_nb_positive(x, mu_1, theta_1) 327 | log_nb_2 = log_nb_positive(x, mu_2, theta_2) 328 | # this is intended to reduce repeated computations 329 | else: 330 | theta = theta_1 331 | if theta.ndimension() == 1: 332 | theta = theta.view( 333 | 1, theta.size(0) 334 | ) # In this case, we reshape theta for broadcasting 335 | 336 | log_theta_mu_1_eps = torch.log(theta + mu_1 + eps) 337 | log_theta_mu_2_eps = torch.log(theta + mu_2 + eps) 338 | lgamma_x_theta = torch.lgamma(x + theta) 339 | lgamma_theta = torch.lgamma(theta) 340 | lgamma_x_plus_1 = torch.lgamma(x + 1) 341 | 342 | log_nb_1 = ( 343 | theta * (torch.log(theta + eps) - log_theta_mu_1_eps) 344 | + x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps) 345 | + lgamma_x_theta 346 | - lgamma_theta 347 | - lgamma_x_plus_1 348 | ) 349 | log_nb_2 = ( 350 | theta * (torch.log(theta + eps) - log_theta_mu_2_eps) 351 | + x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps) 352 | + lgamma_x_theta 353 | - lgamma_theta 354 | - lgamma_x_plus_1 355 | ) 356 | 357 | logsumexp = torch.logsumexp(torch.stack((log_nb_1, log_nb_2 - pi)), dim=0) 358 | softplus_pi = F.softplus(-pi) 359 | 360 | log_mixture_nb = logsumexp - softplus_pi 361 | 362 | return log_mixture_nb 363 | -------------------------------------------------------------------------------- /scMVP/models/multi_vae_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Main module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | from sklearn.mixture import GaussianMixture 10 | from torch.distributions import Normal, kl_divergence as kl 11 | 12 | from scMVP.models.log_likelihood import log_zinb_positive, log_nb_positive, log_zip_positive, binary_cross_entropy, \ 13 | mean_square_error_positive 14 | from scMVP.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI, Multi_Encoder, Multi_Decoder_nb_log, \ 15 | reparameterize_gaussian, Encoder_l, Encoder_nb, Multi_Encoder_nb, Multi_Decoder_nb, Classifer, Multi_Decoder_nb_log_peak,\ 16 | Encoder_nb_attention, Multi_Encoder_nb_attention, Encoder_nb_selfattention, Multi_Encoder_nb_SelfAttention,\ 17 | Multi_Decoder_nb_SelfAttention 18 | 19 | from scMVP.models.utils import one_hot 20 | 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | # VAE model 25 | class Multi_VAE_Attention(nn.Module): 26 | r"""Variational auto-encoder model. 
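    In ``mm-vae`` mode this class couples a scRNA channel and a scATAC channel through
    a shared latent space with a Gaussian-mixture prior (``n_centroids`` components).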
27 | 28 | :param n_input: Number of input genes 29 | :param n_batch: Number of batches 30 | :param n_labels: Number of labels 31 | :param n_hidden: Number of nodes per hidden layer 32 | :param n_latent: Dimensionality of the latent space 33 | :param n_layers: Number of hidden layers used for encoder and decoder NNs 34 | :param dropout_rate: Dropout rate for neural networks 35 | :param mode: One of the following: 36 | * ``'vae'`` -single channel auto-encoder decoder neural framework for scRNA-seq data 37 | * ``'mm-vae'`` -multi-channels auto-encoder decoder neural framework for scRNA and scATAC data 38 | :param dispersion: One of the following 39 | 40 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 41 | * ``'gene-batch'`` - dispersion can differ between different batches 42 | * ``'gene-label'`` - dispersion can differ between different labels 43 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 44 | 45 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization. 46 | :param reconstruction_loss: One of 47 | 48 | * ``'nb'`` - Negative binomial distribution 49 | * ``'zinb'`` - Zero-inflated negative binomial distribution 50 | 51 | Examples: 52 | >>> gene_dataset = CortexDataset() 53 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, 54 | ... n_labels=gene_dataset.n_labels) 55 | 56 | """ 57 | 58 | def __init__( 59 | self, 60 | RNA_input: int, 61 | ATAC_input: int = 0, 62 | n_batch: int = 0, 63 | n_labels: int = 0, 64 | n_hidden: int = 128, 65 | n_latent: int = 10, 66 | n_layers: int = 1, 67 | n_centroids: int = 20, 68 | n_alfa: float = 1.0, 69 | dropout_rate: float = 0.1, 70 | mode="vae", 71 | dispersion: str = "gene", 72 | log_variational: bool = True, 73 | reconstruction_loss: str = "zinb", 74 | isLibrary: bool = True, 75 | is_cluster: bool = True, 76 | classifer_num: int = 0, 77 | ): 78 | super().__init__() 79 | self.mode = mode 80 | self.dispersion = dispersion 81 | self.n_latent = n_latent 82 | self.log_variational = log_variational 83 | self.reconstruction_loss = reconstruction_loss 84 | # Automatically deactivate if useless 85 | self.n_input_atac = ATAC_input 86 | self.n_input_RNA = RNA_input 87 | self.n_batch = n_batch 88 | self.n_labels = n_labels 89 | self.n_centroids = n_centroids 90 | self.alfa = n_alfa 91 | self.isLibrary = isLibrary 92 | self.is_cluster = is_cluster 93 | self.classifer_num = classifer_num 94 | 95 | if self.dispersion == "gene": 96 | self.px_r = torch.nn.Parameter(torch.randn(RNA_input)) 97 | self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input)) 98 | elif self.dispersion == "gene-batch": 99 | self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_batch)) 100 | self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_batch)) 101 | elif self.dispersion == "gene-label": 102 | self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_labels)) 103 | self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_labels)) 104 | elif self.dispersion == "gene-cell": 105 | pass 106 | else: 107 | raise ValueError( 108 | "dispersion must be one of ['gene', 'gene-batch'," 109 | " 'gene-label', 'gene-cell'], but input was " 110 | "{}.format(self.dispersion)" 111 | ) 112 | 113 | if self.mode == "vae": 114 | # z encoder goes from the n_input-dimensional data to an n_latent-d 115 | # latent space representation 116 | self.z_encoder = Encoder( 117 | RNA_input, 118 | n_latent, 119 | n_layers=n_layers, 120 | n_hidden=n_hidden, 121 | dropout_rate=dropout_rate, 
122 | ) 123 | # l encoder goes from n_input-dimensional data to 1-d library size 124 | self.l_encoder = Encoder( 125 | RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate 126 | ) 127 | # decoder goes from n_latent-dimensional space to n_input-d data 128 | self.decoder = DecoderSCVI( 129 | n_latent, 130 | RNA_input, 131 | n_cat_list=[n_batch], 132 | n_layers=n_layers, 133 | n_hidden=n_hidden, 134 | ) 135 | elif self.mode == "mm-vae": 136 | if ATAC_input <= 0: 137 | raise ValueError("Input size of ATAC channel should be positive value," 138 | "but input was {}.format(self.ATAC_input)" 139 | ) 140 | 141 | # init c_params 142 | self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids, requires_grad=True) # pc 143 | self.mu_c = nn.Parameter(torch.zeros(n_latent, n_centroids), requires_grad=True) # mu 144 | self.var_c = nn.Parameter(torch.ones(n_latent, n_centroids), requires_grad=True) # sigma^2 145 | self.counter = nn.Parameter(torch.zeros(2), requires_grad=False) # sigma^2 146 | 147 | if self.classifer_num > 0: 148 | self.classifer = Classifer( 149 | n_latent, 150 | self.classifer_num, 151 | ) 152 | 153 | self.RNA_encoder = Encoder_nb_attention( 154 | RNA_input, 155 | n_latent, 156 | n_layers=n_layers, 157 | n_hidden=n_hidden, 158 | dropout_rate=dropout_rate, 159 | ) 160 | self.ATAC_encoder = Encoder_nb_selfattention( 161 | ATAC_input, 162 | n_latent, 163 | n_layers=n_layers, 164 | n_hidden=n_hidden, 165 | dropout_rate=dropout_rate, 166 | ) 167 | self.concatenter = nn.Linear(2 * self.n_latent, self.n_latent) 168 | if self.isLibrary == True: 169 | # l encoder goes from n_input-dimensional data to 1-d library size 170 | self.l_encoder = Encoder_l( 171 | RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate 172 | ) 173 | self.RNA_ATAC_encoder = Multi_Encoder_nb_SelfAttention( 174 | RNA_input, 175 | ATAC_input, 176 | n_latent, 177 | n_layers=n_layers, 178 | n_hidden=n_hidden, 179 | dropout_rate=dropout_rate, 180 | ) 181 | self.RNA_ATAC_decoder = Multi_Decoder_nb_SelfAttention( 182 | n_latent, 183 | RNA_input, 184 | ATAC_input, 185 | n_cat_list=[n_batch], 186 | n_layers=n_layers, 187 | n_hidden=n_hidden, 188 | is_cluster=is_cluster, 189 | n_cluster=n_centroids 190 | ) 191 | else: 192 | raise ValueError( 193 | "mode must be one of ['vae', 'mm-vae'" 194 | " ], but input was " 195 | "{}.format(self.mode)" 196 | ) 197 | 198 | def get_params(self): 199 | params = [self.pi, self.mu_c, self.var_c] 200 | return params 201 | 202 | def get_latents(self, x_rna, y=None, x_atac=None): 203 | r""" returns the result of ``sample_from_posterior_z`` inside a list 204 | 205 | :param x: tensor of values with shape ``(batch_size, n_input)`` 206 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 207 | :return: one element list of tensor 208 | :rtype: list of :py:class:`torch.Tensor` 209 | """ 210 | return [self.sample_from_posterior_z([x_rna, x_atac], y)] 211 | 212 | def sample_from_posterior_z(self, x, y=None, give_mean=True): 213 | r""" samples the tensor of latent values from the posterior 214 | #doesn't really sample, returns the means of the posterior distribution 215 | 216 | :param x: tensor of values with shape ``(batch_size, n_input)`` 217 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 218 | :param give_mean: is True when we want the mean of the posterior distribution rather than sampling 219 | :return: tensor of shape ``(batch_size, n_latent)`` 220 | :rtype: :py:class:`torch.Tensor` 221 | """ 222 | if 
self.log_variational: 223 | x[0] = torch.log(1 + x[0]) 224 | x[1] = torch.log(1 + x[1]) 225 | 226 | qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x[0], None) 227 | qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x[1], None) 228 | qz_m, qz_v, z = self.RNA_ATAC_encoder(x, None) 229 | if give_mean: 230 | z = qz_m, 231 | rna_z = qz_rna_m, 232 | atac_z = qz_atac_m 233 | return [z, rna_z, atac_z] 234 | 235 | def sample_from_posterior_l(self, x): 236 | r""" samples the tensor of library sizes from the posterior 237 | #doesn't really sample, returns the tensor of the means of the posterior distribution 238 | 239 | :param x: tensor of values with shape ``(batch_size, n_input)`` 240 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 241 | :return: tensor of shape ``(batch_size, 1)`` 242 | :rtype: :py:class:`torch.Tensor` 243 | """ 244 | if self.log_variational: 245 | x = torch.log(1 + x) 246 | ql_m, ql_v, library = self.l_encoder(x) 247 | return library 248 | 249 | def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1): 250 | r"""Returns the tensor of predicted frequencies of expression 251 | 252 | :param x: tensor of values with shape ``(batch_size, n_input)`` 253 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 254 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 255 | :param n_samples: number of samples 256 | :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)`` 257 | :rtype: :py:class:`torch.Tensor` 258 | """ 259 | outputs = self.inference(x=x, batch_index=batch_index, y=y, n_samples=n_samples) 260 | return outputs["p_rna_scale"], outputs["p_atac_scale"] 261 | 262 | def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1, local_l_mean=None, local_l_var=None): 263 | r"""Returns the tensor of means of the negative binomial distribution 264 | 265 | :param x: tensor of values with shape ``(batch_size, n_input)`` 266 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 267 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 268 | :param n_samples: number of samples 269 | :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)`` 270 | :rtype: :py:class:`torch.Tensor` 271 | """ 272 | outputs = self.inference(x=x, batch_index=batch_index, y=y, n_samples=n_samples, local_l_mean=local_l_mean, 273 | local_l_var=local_l_var) 274 | return outputs["p_rna_rate"], outputs["p_atac_mean"] 275 | 276 | def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs): 277 | # Reconstruction Loss 278 | if self.reconstruction_loss == "nb": 279 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1) 280 | elif self.reconstruction_loss == "zinb": 281 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1) 282 | 283 | return reconst_loss 284 | 285 | def get_reconstruction_atac_loss(self, x, mu, dispersion, dropout, type="zip", **kwargs): 286 | if type == "zinb": 287 | reconst_loss = -log_zinb_positive(x, mu, dispersion, dropout).sum(dim=-1) 288 | elif type == "zip": 289 | reconst_loss = 0.5 * mean_square_error_positive(x, mu).sum(dim=-1) - log_zip_positive(x, mu, dropout).sum(dim=-1) 290 | mu[x > 0] = 0 291 | reconst_loss = reconst_loss + 0.05 * mu.sum(dim=-1) 292 | elif type == "zip_bu": 293 | 
reconst_loss = - log_zip_positive(x, mu, dropout).sum(dim=-1) - binary_cross_entropy(x, mu).sum(dim=-1) 294 | elif type == "bu": 295 | reconst_loss = - binary_cross_entropy(x, mu).sum(dim=-1) 296 | return reconst_loss 297 | 298 | def scale_from_z(self, sample_batch, fixed_batch): 299 | if self.log_variational: 300 | sample_batch[0] = torch.log(1 + sample_batch[0]) 301 | sample_batch[1] = torch.log(1 + sample_batch[1]) 302 | qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(sample_batch[0]) 303 | qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(sample_batch[1]) 304 | qz_m, qz_v, z = self.RNA_ATAC_encoder(sample_batch) 305 | 306 | batch_index = fixed_batch * torch.ones_like(sample_batch[:, [0]]) 307 | library = 4.0 * torch.ones_like(sample_batch[:, [0]]) 308 | px_scale, _, _, _ = self.decoder("gene", z, library, batch_index) 309 | return px_scale 310 | 311 | def init_gmm_params(self, z): 312 | """ 313 | Init SCALE model with GMM model parameters 314 | """ 315 | if z is None: 316 | raise ("Input data is empty!") 317 | 318 | gmm = GaussianMixture(n_components=self.n_centroids, covariance_type='diag') 319 | gmm.fit(z) 320 | # gmm.weights_ 321 | self.mu_c.data.copy_(torch.from_numpy(gmm.means_.T.astype(np.float32))) 322 | self.var_c.data.copy_(torch.from_numpy(gmm.covariances_.T.astype(np.float32))) 323 | clust_index = gmm.predict(z) 324 | 325 | return clust_index 326 | 327 | def init_gmm_params_with_louvain(self, z, label): 328 | """ 329 | Init SCALE model with GMM model parameters 330 | """ 331 | if z is None or label is None: 332 | raise ("Input data is empty!") 333 | 334 | mu = np.zeros((z.shape[1],len(np.unique(label)))) 335 | var = np.zeros((z.shape[1],len(np.unique(label)))) 336 | pi = np.zeros(len(np.unique(label))) 337 | for i in range(len(np.unique(label))): 338 | mu[:,i] = np.mean(z[label==i,:],axis=0) 339 | var[:,i] = np.var(z[label==i,:],axis=0) 340 | pi[i] = np.sum(label==i)/len(label) 341 | 342 | self.mu_c.data.copy_(torch.from_numpy(mu.astype(np.float32))) 343 | self.var_c.data.copy_(torch.from_numpy(var.astype(np.float32))) 344 | self.pi.data.copy_(torch.from_numpy(pi.astype(np.float32))) 345 | 346 | return True 347 | 348 | def get_gamma(self, z, update=False): 349 | """ 350 | Inference c from z 351 | 352 | gamma is q(c|x) 353 | q(c|x) = p(c|z) = p(c)p(c|z)/p(z) 354 | """ 355 | n_centroids = self.n_centroids 356 | 357 | N = z.size(0) 358 | z_org = z 359 | z = z.unsqueeze(2).expand(z.size(0), z.size(1), n_centroids) 360 | pi = torch.abs(self.pi.repeat(N, 1)) # NxK 361 | mu_c = self.mu_c.repeat(N, 1, 1) # NxDxK 362 | var_c = torch.abs(self.var_c.repeat(N, 1, 1)) # NxDxK 363 | 364 | p_c_z = torch.exp( 365 | torch.log(pi) - torch.sum(0.5 * torch.log(2 * math.pi * var_c) + (z - mu_c) ** 2 / (2 * var_c), 366 | dim=1)) + 1e-10 367 | gamma = p_c_z / torch.sum(p_c_z, dim=1, keepdim=True) 368 | return gamma, mu_c, var_c, pi 369 | 370 | def inference(self, x, batch_index=None, y=None, local_l_mean=None, local_l_var=None, update=False, n_samples=1): 371 | x_ = x 372 | if len(x_) != 2: 373 | raise ValueError("Input training data should be 2 data types(RNA and ATAC)," 374 | "but input was only {}.format(len(x_))" 375 | ) 376 | x_rna = x_[0] 377 | x_atac = x_[1] 378 | libary_atac = torch.log(x_[1].sum(dim=-1)).reshape(-1, 1) 379 | libary_rna = torch.log(x_[0].sum(dim=-1)).reshape(-1, 1) 380 | if self.log_variational: 381 | x_rna = torch.log(1 + x_rna) 382 | x_atac = torch.log(1 + x_atac) 383 | 384 | # Sampling 385 | if self.isLibrary: 386 | ql_m, ql_v, l_z = self.l_encoder(x_rna, batch_index) 387 | 
qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x_rna, batch_index) 388 | qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x_atac, batch_index) 389 | qz_m, qz_v, z = self.RNA_ATAC_encoder([x_rna, x_atac], batch_index) 390 | 391 | qz_joint_mu = self.concatenter(torch.cat((qz_rna_m, qz_atac_m), 1)) 392 | qz_joint_v = self.concatenter(torch.cat((torch.log(qz_rna_v), torch.log(qz_atac_v)), 1)) 393 | qz_joint_v = torch.exp(qz_joint_v) 394 | qz_joint_z = Normal(qz_joint_mu, qz_joint_v.sqrt()).rsample() 395 | gamma_joint, _, _, _ = self.get_gamma(qz_joint_z) 396 | 397 | gamma, mu_c, var_c, pi = self.get_gamma(z, update) # , self.n_centroids, c_params) 398 | index = torch.argmax(gamma, dim=1) 399 | 400 | index1 = [i for i in range(len(index))] 401 | mu_c_max = mu_c[index1, :, index] 402 | var_c_max = var_c[index1, :, index] 403 | z_c_max = reparameterize_gaussian(mu_c_max, var_c_max) 404 | 405 | libary_scale = reparameterize_gaussian(local_l_mean, local_l_var) 406 | if self.isLibrary: 407 | libary_scale = libary_rna 408 | # decoder 409 | p_rna_scale, p_rna_r, p_rna_rate, p_rna_dropout, p_atac_scale, p_atac_r, p_atac_mean, p_atac_dropout \ 410 | = self.RNA_ATAC_decoder(z, z_c_max, batch_index, libary_scale=libary_scale, gamma=gamma, libary_atac=libary_atac) 411 | # classifer 412 | if self.classifer_num > 0 and y is not None: 413 | classifer_pred = self.classifer(z) 414 | classifer_loss = -100*( 415 | one_hot(y, self.classifer_num)*torch.log(classifer_pred+1.0e-10) 416 | ).sum(dim=-1) 417 | 418 | if self.log_variational: 419 | p_rna_rate_norm = torch.log(1 + p_rna_rate) 420 | p_atac_mean_norm = torch.log(1 + p_atac_mean) 421 | rec_rna_mu, rec_rna_v, rec_rna_z = self.RNA_encoder(p_rna_rate_norm, batch_index) 422 | gamma_rna_rec, _, _, _ = self.get_gamma(rec_rna_z) 423 | rec_atac_mu, rec_atac_v, rec_atac_z = self.ATAC_encoder(p_atac_mean_norm, batch_index) 424 | gamma_atac_rec, _, _, _ = self.get_gamma(rec_atac_z) 425 | rec_joint_mu = self.concatenter(torch.cat((rec_rna_mu, rec_atac_mu), 1)) 426 | rec_joint_v = self.concatenter(torch.cat((torch.log(rec_rna_v), torch.log(rec_atac_v)), 1)) 427 | rec_joint_v = torch.exp(rec_joint_v) 428 | rec_joint_z = Normal(rec_joint_mu, rec_joint_v.sqrt()).rsample() 429 | gamma_joint_rec, _, _, _ = self.get_gamma(rec_joint_z) 430 | 431 | if self.dispersion == "gene-label": 432 | p_rna_r = F.linear( 433 | one_hot(y, self.n_labels), self.px_r 434 | ) # px_r gets transposed - last dimension is nb genes 435 | p_atac_r = F.linear( 436 | one_hot(y, self.n_labels), self.p_atac_r 437 | ) 438 | elif self.dispersion == "gene-batch": 439 | p_rna_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) 440 | p_atac_r = F.linear(one_hot(batch_index, self.n_batch), self.p_atac_r) 441 | elif self.dispersion == "gene": 442 | p_rna_r = self.px_r 443 | p_atac_r = self.p_atac_r 444 | 445 | p_rna_r = torch.exp(p_rna_r) 446 | p_atac_r = torch.exp(p_atac_r) 447 | 448 | return dict( 449 | p_rna_scale=p_rna_scale, 450 | p_rna_r=p_rna_r, 451 | p_rna_rate=p_rna_rate, 452 | p_rna_dropout=p_rna_dropout, 453 | p_atac_scale=p_atac_scale, 454 | p_atac_r=p_atac_r, 455 | p_atac_mean=p_atac_mean, 456 | p_atac_dropout=p_atac_dropout, 457 | qz_rna_m=qz_rna_m, 458 | qz_rna_v=qz_rna_v, 459 | rna_z=rna_z, 460 | qz_atac_m=qz_atac_m, 461 | qz_atac_v=qz_atac_v, 462 | atac_z=atac_z, 463 | qz_m=qz_m, 464 | qz_v=qz_v, 465 | z=z, 466 | mu_c=mu_c, 467 | var_c=var_c, 468 | gamma=gamma, 469 | pi=pi, 470 | mu_c_max=mu_c_max, 471 | var_c_max=var_c_max, 472 | z_c_max=z_c_max, 473 | gamma_rna_rec=gamma_rna_rec, 474 | 
gamma_atac_rec=gamma_atac_rec, 475 | rec_atac_mu=rec_atac_mu, 476 | rec_atac_v=rec_atac_v, 477 | rec_rna_mu=rec_rna_mu, 478 | rec_rna_v=rec_rna_v, 479 | ql_m=ql_m, 480 | ql_v=ql_v, 481 | l_z=l_z, 482 | rec_joint_mu=rec_joint_mu, 483 | rec_joint_v=rec_joint_v, 484 | rec_joint_z=rec_joint_z, 485 | gamma_joint_rec=gamma_joint_rec, 486 | qz_joint_mu=qz_joint_mu, 487 | qz_joint_v=qz_joint_v, 488 | qz_joint_z=qz_joint_z, 489 | gamma_joint=gamma_joint, 490 | classifer_loss=classifer_loss if self.classifer_num > 0 else 0, 491 | ) 492 | 493 | def forward(self, x_rna, x_atac, local_l_mean, local_l_var, batch_index=None, y=None): 494 | r""" Returns the reconstruction loss and the Kullback divergences 495 | 496 | :param x: tensor of values with shape (batch_size, n_input) 497 | :param local_l_mean: tensor of means of the prior distribution of latent variable l 498 | with shape (batch_size, 1) 499 | :param local_l_var: tensor of variancess of the prior distribution of latent variable l 500 | with shape (batch_size, 1) 501 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 502 | :param y: tensor of cell-types labels with shape (batch_size, n_labels) 503 | :return: the reconstruction loss and the Kullback divergences 504 | :rtype: 2-tuple of :py:class:`torch.FloatTensor` 505 | """ 506 | # Parameters for z latent distribution 507 | x = [x_rna, x_atac] 508 | outputs = self.inference(x, batch_index, y, local_l_mean, local_l_var, update=False) 509 | qz_rna_m = outputs["qz_rna_m"] 510 | qz_rna_v = outputs["qz_rna_v"] 511 | qz_atac_m = outputs["qz_atac_m"] 512 | qz_atac_v = outputs["qz_atac_v"] 513 | qz_m = outputs["qz_m"] 514 | qz_v = outputs["qz_v"] 515 | p_rna_rate = outputs["p_rna_rate"] 516 | p_rna_r = outputs["p_rna_r"] 517 | p_rna_dropout = outputs["p_rna_dropout"] 518 | p_atac_r = outputs["p_atac_r"] 519 | p_atac_mean = outputs["p_atac_mean"] 520 | p_atac_dropout = outputs["p_atac_dropout"] 521 | mu_c = outputs["mu_c"] 522 | var_c = outputs["var_c"] 523 | gamma = outputs["gamma"] 524 | pi = outputs["pi"] 525 | gamma_rna_rec = outputs["gamma_rna_rec"] 526 | gamma_atac_rec = outputs["gamma_atac_rec"] 527 | rec_atac_mu = outputs["rec_atac_mu"] 528 | rec_atac_v = outputs["rec_atac_v"] 529 | rec_rna_mu = outputs["rec_rna_mu"] 530 | rec_rna_v = outputs["rec_rna_v"] 531 | ql_m = outputs["ql_m"] 532 | ql_v = outputs["ql_v"] 533 | l_z = outputs["l_z"] 534 | rec_joint_mu = outputs["rec_joint_mu"] 535 | rec_joint_v = outputs["rec_joint_v"] 536 | rec_joint_z = outputs["rec_joint_z"] 537 | gamma_joint_rec = outputs["gamma_joint_rec"] 538 | qz_joint_mu = outputs["qz_joint_mu"] 539 | qz_joint_v = outputs["qz_joint_v"] 540 | qz_joint_z = outputs["qz_joint_z"] 541 | gamma_joint = outputs["gamma_joint"] 542 | classifer_loss = outputs["classifer_loss"] 543 | 544 | 545 | n_centroids = pi.size(1) 546 | mu_expand = qz_m.unsqueeze(2).expand(qz_m.size(0), qz_m.size(1), n_centroids) 547 | logvar_expand = qz_v.unsqueeze(2).expand(qz_v.size(0), qz_v.size(1), n_centroids) 548 | # zl 549 | 550 | # log p(z|c) 551 | logpzc = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) + \ 552 | torch.log(var_c) + \ 553 | torch.exp(logvar_expand) / var_c + \ 554 | (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1) 555 | # log p(c) 556 | logpc = torch.sum(gamma * torch.log(pi), 1) 557 | 558 | # log q(z|x) or q entropy 559 | qentropy = -0.5 * torch.sum(1 + qz_v + math.log(2 * math.pi), 1) 560 | 561 | # log q(c|x) 562 | logqcx = torch.sum(gamma * torch.log(gamma), 1) 563 | 564 | # kl(qz||pz) 
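        # The four quantities above are assembled below into the KL term between the
        # variational posterior and the GMM prior p(z, c) = p(z | c) p(c):
        #   KL(q(z, c | x) || p(z, c)) = E_q[log q(z | x)] + E_q[log q(c | x)]
        #                                - E_q[log p(z | c)] - E_q[log p(c)]
        # which corresponds to kld_qz_pz = qentropy + logqcx - logpzc - logpc.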
565 | kld_qz_pz = -logpzc - logpc + qentropy + logqcx 566 | print("logpzc:{}, logqcx:{}".format(torch.mean(logpzc), torch.mean(logqcx))) 567 | # print("gamma={},var_c={}".format(gamma,var_c)) 568 | # kl(qz||qz_rna) 569 | kld_qz_rna = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(qz_rna_m, torch.sqrt(qz_rna_v))).sum( 570 | dim=1 571 | ) 572 | 573 | # kl(qz||qz_atac) 574 | kld_qz_atac = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(qz_atac_m, torch.sqrt(qz_atac_v))).sum( 575 | # check the postive qz_v 576 | dim=1 577 | ) 578 | 579 | # kl(qz||qz_joint) 580 | kld_qz_joint = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(qz_joint_mu, torch.sqrt(qz_joint_v))).sum( 581 | # check the postive qz_v 582 | dim=1 583 | ) 584 | 585 | # KL Divergence 586 | kl_divergence = kld_qz_pz + 0.1 * (kld_qz_joint) 587 | if self.isLibrary: 588 | 589 | consistent_loss_rna = -( 590 | torch.softmax(gamma, dim=-1) * torch.log(torch.softmax(gamma_rna_rec, dim=-1) + 1.0e-6) + ( 591 | 1 - torch.softmax(gamma, dim=-1)) * torch.log( 592 | 1 - torch.softmax(gamma_rna_rec, dim=-1) + 1.0e-6)).sum(dim=-1) 593 | consistent_loss_atac = -( 594 | torch.softmax(gamma, dim=-1) * torch.log(torch.softmax(gamma_atac_rec, dim=-1) + 1.0e-6) + ( 595 | 1 - torch.softmax(gamma, dim=-1)) * torch.log( 596 | 1 - torch.softmax(gamma_atac_rec, dim=-1) + 1.0e-6)).sum(dim=-1) 597 | consistent_loss_joint = -( 598 | torch.softmax(gamma, dim=-1) * torch.log(torch.softmax(gamma_joint_rec, dim=-1) + 1.0e-6) + ( 599 | 1 - torch.softmax(gamma, dim=-1)) * torch.log( 600 | 1 - torch.softmax(gamma_joint_rec, dim=-1) + 1.0e-6)).sum(dim=-1) 601 | rec_rna_kl = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(rec_rna_mu, torch.sqrt(rec_rna_v))).sum( 602 | dim=1 603 | ) 604 | rec_atac_kl = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(rec_atac_mu, torch.sqrt(rec_atac_v))).sum( 605 | dim=1 606 | ) 607 | rec_joint_kl = kl(Normal(qz_joint_mu, torch.sqrt(qz_joint_v)), Normal(rec_joint_mu, torch.sqrt(rec_joint_v))).sum( 608 | dim=1 609 | ) 610 | 611 | # likelihood 612 | reconst_loss_rna = 3.0*self.get_reconstruction_loss(x[0], p_rna_rate, p_rna_r, p_rna_dropout) 613 | reconst_loss_atac = 0.1 * self.get_reconstruction_atac_loss(x[1], p_atac_mean, p_atac_r, 614 | p_atac_dropout) # implement this function 615 | reconst_loss = reconst_loss_rna + reconst_loss_atac + classifer_loss 616 | if self.isLibrary: 617 | 618 | reconst_loss = reconst_loss + 0.5 * (consistent_loss_joint - 619 | 50*torch.sum(gamma * gamma,dim=-1) - 620 | 50*torch.sum((torch.sum(gamma,dim=0)/gamma.shape[0])*(torch.log(torch.sum(gamma,dim=0)/gamma.shape[0]+1.0e-10)))) 621 | kl_divergence = kl_divergence + 0.1 * (rec_joint_kl) 622 | 623 | 624 | # init the gmm model, training pc 625 | print("kld_qz_pz = %f,kld_qz_rna = %f,kld_qz_atac = %f,kl_divergence = %f,reconst_loss_rna = %f,\ 626 | reconst_loss_atac = %f, mu=%f, sigma=%f" % ( 627 | torch.mean(kld_qz_pz), torch.mean(kld_qz_rna), torch.mean(kld_qz_atac), \ 628 | torch.mean(kl_divergence), torch.mean(reconst_loss_rna), torch.mean(reconst_loss_atac), 629 | torch.mean(self.mu_c), torch.mean(self.var_c))) 630 | return reconst_loss, kl_divergence, 0.0 631 | 632 | 633 | class LDVAE(Multi_VAE_Attention): 634 | r"""Linear-decoded Variational auto-encoder model. 635 | 636 | This model uses a linear decoder, directly mapping the latent representation 637 | to gene expression levels. It still uses a deep neural network to encode 638 | the latent representation. 
639 | 640 | Compared to standard VAE, this model is less powerful, but can be used to 641 | inspect which genes contribute to variation in the dataset. 642 | 643 | :param n_input: Number of input genes 644 | :param n_batch: Number of batches 645 | :param n_labels: Number of labels 646 | :param n_hidden: Number of nodes per hidden layer (for encoder) 647 | :param n_latent: Dimensionality of the latent space 648 | :param n_layers: Number of hidden layers used for encoder NNs 649 | :param dropout_rate: Dropout rate for neural networks 650 | :param dispersion: One of the following 651 | 652 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 653 | * ``'gene-batch'`` - dispersion can differ between different batches 654 | * ``'gene-label'`` - dispersion can differ between different labels 655 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 656 | 657 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization. 658 | :param reconstruction_loss: One of 659 | 660 | * ``'nb'`` - Negative binomial distribution 661 | * ``'zinb'`` - Zero-inflated negative binomial distribution 662 | """ 663 | 664 | def __init__( 665 | self, 666 | n_input: int, 667 | n_batch: int = 0, 668 | n_labels: int = 0, 669 | n_hidden: int = 128, 670 | n_latent: int = 10, 671 | n_layers: int = 1, 672 | dropout_rate: float = 0.1, 673 | dispersion: str = "gene", 674 | log_variational: bool = True, 675 | reconstruction_loss: str = "zinb", 676 | ): 677 | super().__init__( 678 | n_input, 679 | n_batch, 680 | n_labels, 681 | n_hidden, 682 | n_latent, 683 | n_layers, 684 | dropout_rate, 685 | dispersion, 686 | log_variational, 687 | reconstruction_loss, 688 | ) 689 | 690 | self.decoder = LinearDecoderSCVI( 691 | n_latent, 692 | n_input, 693 | n_cat_list=[n_batch], 694 | n_layers=n_layers, 695 | n_hidden=n_hidden, 696 | ) 697 | 698 | def get_loadings(self): 699 | """ Extract per-gene weights (for each Z) in the linear decoder. 
700 | """ 701 | return self.decoder.factor_regressor.parameters() 702 | -------------------------------------------------------------------------------- /scMVP/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.io as sp_io 3 | import numpy as np 4 | 5 | def iterate(obj, func): 6 | t = type(obj) 7 | if t is list or t is tuple: 8 | return t([iterate(o, func) for o in obj]) 9 | else: 10 | return func(obj) if obj is not None else None 11 | 12 | 13 | def broadcast_labels(y, *o, n_broadcast=-1): 14 | """ 15 | Utility for the semi-supervised setting 16 | If y is defined(labelled batch) then one-hot encode the labels (no broadcasting needed) 17 | If y is undefined (unlabelled batch) then generate all possible labels (and broadcast other arguments if not None) 18 | """ 19 | if not len(o): 20 | raise ValueError("Broadcast must have at least one reference argument") 21 | if y is None: 22 | ys = enumerate_discrete(o[0], n_broadcast) 23 | new_o = iterate( 24 | o, 25 | lambda x: x.repeat(n_broadcast, 1) 26 | if len(x.size()) == 2 27 | else x.repeat(n_broadcast), 28 | ) 29 | else: 30 | ys = one_hot(y, n_broadcast) 31 | new_o = o 32 | return (ys,) + new_o 33 | 34 | 35 | def one_hot(index, n_cat): 36 | onehot = torch.zeros(index.size(0), n_cat, device=index.device) 37 | onehot.scatter_(1, index.type(torch.long), 1) 38 | return onehot.type(torch.float32) 39 | 40 | 41 | def enumerate_discrete(x, y_dim): 42 | def batch(batch_size, label): 43 | labels = torch.ones(batch_size, 1, device=x.device, dtype=torch.long) * label 44 | return one_hot(labels, y_dim) 45 | 46 | batch_size = x.size(0) 47 | return torch.cat([batch(batch_size, i) for i in range(y_dim)]) 48 | 49 | 50 | def binarization(imputed, raw): 51 | return (imputed.T > np.quantile(imputed,q=0.8,axis=1).T).T & (imputed>imputed.mean(0)) & (imputed>0).astype(np.int8) 52 | -------------------------------------------------------------------------------- /scMVP/models/vaePeak_selfattetion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Main module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.distributions import Normal, kl_divergence as kl 8 | 9 | from scMVP.models.log_likelihood import log_zinb_positive, log_nb_positive, mean_square_error, mean_square_error_positive, log_zip_positive 10 | from scMVP.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI, DecoderSCVI_nb, DecoderSCVI_mse, Encoder_l, Encoder_mse,\ 11 | Encoder_nb, DecoderSCVI_Peak, DecoderSCVI_Peak_Selfattention, Encoder_nb_selfattention, DecoderSCVI_nb_Selfattention 12 | from scMVP.models.utils import one_hot 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | 17 | # VAE model 18 | class VAE_Peak_SelfAttention(nn.Module): 19 | r"""Variational auto-encoder model. 
20 | 21 | :param n_input: Number of input genes 22 | :param n_batch: Number of batches 23 | :param n_labels: Number of labels 24 | :param n_hidden: Number of nodes per hidden layer 25 | :param n_latent: Dimensionality of the latent space 26 | :param n_layers: Number of hidden layers used for encoder and decoder NNs 27 | :param dropout_rate: Dropout rate for neural networks 28 | :param mode: One of the following: 29 | * ``'vae'`` -single channel auto-encoder decoder neural framework for scRNA-seq data 30 | * ``'mm-vae'`` -multi-channels auto-encoder decoder neural framework for scRNA and scATAC data 31 | :param dispersion: One of the following 32 | 33 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 34 | * ``'gene-batch'`` - dispersion can differ between different batches 35 | * ``'gene-label'`` - dispersion can differ between different labels 36 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 37 | 38 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization. 39 | :param reconstruction_loss: One of 40 | 41 | * ``'nb'`` - Negative binomial distribution 42 | * ``'zinb'`` - Zero-inflated negative binomial distribution 43 | 44 | Examples: 45 | >>> gene_dataset = CortexDataset() 46 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, 47 | ... n_labels=gene_dataset.n_labels) 48 | 49 | """ 50 | 51 | def __init__( 52 | self, 53 | n_input: int, 54 | n_batch: int = 0, 55 | n_labels: int = 0, 56 | n_hidden: int = 128, # 256 57 | n_latent: int = 10, # 20 58 | n_layers: int = 1, 59 | dropout_rate: float = 0.1, 60 | dispersion: str = "gene", 61 | log_variational: bool = True, 62 | reconstruction_loss: str = "zinb", 63 | #reconstruction_loss: str = "nb", 64 | ): 65 | super().__init__() 66 | self.dispersion = dispersion 67 | self.n_latent = n_latent 68 | self.log_variational = log_variational 69 | self.reconstruction_loss = reconstruction_loss 70 | # Automatically deactivate if useless 71 | self.n_batch = n_batch 72 | self.n_labels = n_labels 73 | 74 | 75 | if self.dispersion == "gene": 76 | self.px_r = torch.nn.Parameter(torch.randn(n_input)) 77 | elif self.dispersion == "gene-batch": 78 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch)) 79 | elif self.dispersion == "gene-label": 80 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels)) 81 | elif self.dispersion == "gene-cell": 82 | pass 83 | else: 84 | raise ValueError( 85 | "dispersion must be one of ['gene', 'gene-batch'," 86 | " 'gene-label', 'gene-cell'], but input was " 87 | "{}.format(self.dispersion)" 88 | ) 89 | # z encoder goes from the n_input-dimensional data to an n_latent-d 90 | # latent space representation 91 | if self.reconstruction_loss == "mse": 92 | self.z_encoder = Encoder_mse( 93 | n_input, 94 | n_latent, 95 | n_layers=n_layers, 96 | n_hidden=n_hidden, 97 | dropout_rate=dropout_rate, 98 | ) 99 | elif self.reconstruction_loss == "nb": 100 | self.z_encoder = Encoder_nb_selfattention( 101 | n_input, 102 | n_latent, 103 | n_layers=n_layers, 104 | n_hidden=n_hidden, 105 | dropout_rate=dropout_rate, 106 | ) 107 | else: 108 | self.z_encoder = Encoder( 109 | n_input, 110 | n_latent, 111 | n_layers=n_layers, 112 | n_hidden=n_hidden, 113 | dropout_rate=dropout_rate, 114 | ) 115 | # l encoder goes from n_input-dimensional data to 1-d library size 116 | if self.reconstruction_loss == "nb": 117 | self.l_encoder = Encoder_l( 118 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, 119 | 
) 120 | elif self.reconstruction_loss == "zinb": 121 | self.l_encoder = Encoder_l( 122 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, 123 | ) 124 | else: 125 | self.l_encoder = Encoder( 126 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, 127 | ) 128 | if self.reconstruction_loss == "mse": 129 | self.decoder = DecoderSCVI_mse( 130 | n_latent, 131 | n_input, 132 | n_cat_list=[n_batch], 133 | n_layers=n_layers, 134 | n_hidden=n_hidden, 135 | ) 136 | elif self.reconstruction_loss == "nb" and self.log_variational: 137 | self.decoder = DecoderSCVI_Peak_Selfattention( 138 | n_latent, 139 | n_input, 140 | n_cat_list=[n_batch], 141 | n_layers=n_layers, 142 | n_hidden=n_hidden, 143 | ) 144 | elif self.reconstruction_loss == "nb" and not self.log_variational: 145 | self.decoder = DecoderSCVI_Peak_Selfattention( 146 | n_latent, 147 | n_input, 148 | n_cat_list=[n_batch], 149 | n_layers=n_layers, 150 | n_hidden=n_hidden, 151 | ) 152 | else: 153 | self.decoder = DecoderSCVI( 154 | n_latent, 155 | n_input, 156 | n_cat_list=[n_batch], 157 | n_layers=n_layers, 158 | n_hidden=n_hidden, 159 | ) 160 | 161 | def get_latents(self, x, y=None): 162 | r""" returns the result of ``sample_from_posterior_z`` inside a list 163 | 164 | :param x: tensor of values with shape ``(batch_size, n_input)`` 165 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 166 | :return: one element list of tensor 167 | :rtype: list of :py:class:`torch.Tensor` 168 | """ 169 | return [self.sample_from_posterior_z(x, y)] 170 | 171 | def sample_from_posterior_z(self, x, y=None, give_mean=False): 172 | r""" samples the tensor of latent values from the posterior 173 | #doesn't really sample, returns the means of the posterior distribution 174 | 175 | :param x: tensor of values with shape ``(batch_size, n_input)`` 176 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 177 | :param give_mean: is True when we want the mean of the posterior distribution rather than sampling 178 | :return: tensor of shape ``(batch_size, n_latent)`` 179 | :rtype: :py:class:`torch.Tensor` 180 | """ 181 | if self.log_variational: 182 | x = torch.log(1 + x) 183 | qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC 184 | if give_mean: 185 | z = qz_m 186 | return z 187 | 188 | def sample_from_posterior_l(self, x): 189 | r""" samples the tensor of library sizes from the posterior 190 | #doesn't really sample, returns the tensor of the means of the posterior distribution 191 | 192 | :param x: tensor of values with shape ``(batch_size, n_input)`` 193 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 194 | :return: tensor of shape ``(batch_size, 1)`` 195 | :rtype: :py:class:`torch.Tensor` 196 | """ 197 | if self.log_variational: 198 | x = torch.log(1 + x) 199 | ql_m, ql_v, library = self.l_encoder(x) 200 | return library 201 | 202 | def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1): 203 | r"""Returns the tensor of predicted frequencies of expression 204 | 205 | :param x: tensor of values with shape ``(batch_size, n_input)`` 206 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 207 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 208 | :param n_samples: number of samples 209 | :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)`` 210 | :rtype: :py:class:`torch.Tensor` 211 | """ 212 | return self.inference(x, 
batch_index=batch_index, y=y, n_samples=n_samples)[ 213 | "px_scale" 214 | ] 215 | 216 | def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1): 217 | r"""Returns the tensor of means of the negative binomial distribution 218 | 219 | :param x: tensor of values with shape ``(batch_size, n_input)`` 220 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 221 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 222 | :param n_samples: number of samples 223 | :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)`` 224 | :rtype: :py:class:`torch.Tensor` 225 | """ 226 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[ 227 | "px_rate" 228 | ] 229 | 230 | def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs): 231 | # Reconstruction Loss 232 | if self.reconstruction_loss == "zinb": 233 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) 234 | elif self.reconstruction_loss == "nb": 235 | reconst_loss = 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1) - log_zip_positive(x, px_rate, px_dropout).sum(dim=-1) 236 | px_rate[x > 0] = 0 237 | reconst_loss = reconst_loss + 0.05*px_rate.sum(dim=-1) 238 | 239 | elif self.reconstruction_loss == "mse": 240 | reconst_loss = mean_square_error_positive(x, px_rate).sum(dim=-1) 241 | return reconst_loss 242 | 243 | def scale_from_z(self, sample_batch, fixed_batch): 244 | if self.log_variational: 245 | sample_batch = torch.log(1 + sample_batch) 246 | qz_m, qz_v, z = self.z_encoder(sample_batch) 247 | batch_index = fixed_batch * torch.ones_like(sample_batch[:, [0]]) 248 | library = 4.0 * torch.ones_like(sample_batch[:, [0]]) 249 | px_scale, _, _, _ = self.decoder("gene", z, library, batch_index) 250 | return px_scale 251 | 252 | def inference(self, x, batch_index=None, y=None, n_samples=1): 253 | 254 | x_ = x 255 | if self.reconstruction_loss == "nb" and self.log_variational: 256 | library_nb = torch.log(x_.sum(dim=-1)).reshape(-1, 1) 257 | elif self.reconstruction_loss == "nb" and (not self.log_variational): 258 | library_nb = (x_.sum(dim=-1)).reshape(-1, 1) 259 | if self.log_variational: 260 | x_ = torch.log(1 + x_) 261 | 262 | # Sampling 263 | qz_m, qz_v, z = self.z_encoder(x_, y) 264 | ql_m, ql_v, library = self.l_encoder(x_) 265 | if self.reconstruction_loss == "nb": 266 | library = library_nb 267 | 268 | if n_samples > 1: 269 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) 270 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) 271 | z = Normal(qz_m, qz_v.sqrt()).sample() 272 | ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) 273 | ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) 274 | library = Normal(ql_m, ql_v.sqrt()).sample() 275 | 276 | px_scale, px_r, px_rate, px_dropout = self.decoder( 277 | self.dispersion, z, library, batch_index, y 278 | ) 279 | if self.dispersion == "gene-label": 280 | px_r = F.linear( 281 | one_hot(y, self.n_labels), self.px_r 282 | ) # px_r gets transposed - last dimension is nb genes 283 | elif self.dispersion == "gene-batch": 284 | px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) 285 | elif self.dispersion == "gene": 286 | px_r = self.px_r 287 | px_r = torch.exp(px_r) 288 | 289 | return dict( 290 | px_scale=px_scale, 291 | px_r=px_r, 292 | px_rate=px_rate, 293 | px_dropout=px_dropout, 294 | qz_m=qz_m, 295 | qz_v=qz_v, 296 | z=z, 
297 | ql_m=ql_m, 298 | ql_v=ql_v, 299 | library=library, 300 | ) 301 | 302 | def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): 303 | r""" Returns the reconstruction loss and the Kullback divergences 304 | 305 | :param x: tensor of values with shape (batch_size, n_input) 306 | :param local_l_mean: tensor of means of the prior distribution of latent variable l 307 | with shape (batch_size, 1) 308 | :param local_l_var: tensor of variancess of the prior distribution of latent variable l 309 | with shape (batch_size, 1) 310 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 311 | :param y: tensor of cell-types labels with shape (batch_size, n_labels) 312 | :return: the reconstruction loss and the Kullback divergences 313 | :rtype: 2-tuple of :py:class:`torch.FloatTensor` 314 | """ 315 | # Parameters for z latent distribution 316 | outputs = self.inference(x, batch_index, None) 317 | qz_m = outputs["qz_m"] 318 | qz_v = outputs["qz_v"] 319 | ql_m = outputs["ql_m"] 320 | ql_v = outputs["ql_v"] 321 | px_rate = outputs["px_rate"] 322 | px_r = outputs["px_r"] 323 | px_dropout = outputs["px_dropout"] 324 | 325 | # KL Divergence 326 | mean = torch.zeros_like(qz_m) 327 | scale = torch.ones_like(qz_v) 328 | 329 | 330 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( 331 | dim=1 332 | ) 333 | kl_divergence_l = kl( 334 | Normal(ql_m, torch.sqrt(ql_v)), 335 | Normal(local_l_mean, torch.sqrt(local_l_var)), 336 | ).sum(dim=1) 337 | kl_divergence = kl_divergence_z 338 | 339 | reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) 340 | 341 | if self.reconstruction_loss == "mse" or self.reconstruction_loss == "nb": 342 | kl_divergence_l = 1.0 343 | print("reconst_loss=%f, kl_divergence=%f"%(torch.mean(reconst_loss),torch.mean(kl_divergence))) 344 | return reconst_loss + kl_divergence_l, kl_divergence, 0.0 345 | 346 | 347 | class LDVAE(VAE_Peak_SelfAttention): 348 | r"""Linear-decoded Variational auto-encoder model. 349 | 350 | This model uses a linear decoder, directly mapping the latent representation 351 | to gene expression levels. It still uses a deep neural network to encode 352 | the latent representation. 353 | 354 | Compared to standard VAE, this model is less powerful, but can be used to 355 | inspect which genes contribute to variation in the dataset. 356 | 357 | :param n_input: Number of input genes 358 | :param n_batch: Number of batches 359 | :param n_labels: Number of labels 360 | :param n_hidden: Number of nodes per hidden layer (for encoder) 361 | :param n_latent: Dimensionality of the latent space 362 | :param n_layers: Number of hidden layers used for encoder NNs 363 | :param dropout_rate: Dropout rate for neural networks 364 | :param dispersion: One of the following 365 | 366 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 367 | * ``'gene-batch'`` - dispersion can differ between different batches 368 | * ``'gene-label'`` - dispersion can differ between different labels 369 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 370 | 371 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization. 
372 | :param reconstruction_loss: One of 373 | 374 | * ``'nb'`` - Negative binomial distribution 375 | * ``'zinb'`` - Zero-inflated negative binomial distribution 376 | """ 377 | 378 | def __init__( 379 | self, 380 | n_input: int, 381 | n_batch: int = 0, 382 | n_labels: int = 0, 383 | n_hidden: int = 128, 384 | n_latent: int = 10, 385 | n_layers: int = 1, 386 | dropout_rate: float = 0.1, 387 | dispersion: str = "gene", 388 | log_variational: bool = True, 389 | reconstruction_loss: str = "zinb", 390 | ): 391 | super().__init__( 392 | n_input, 393 | n_batch, 394 | n_labels, 395 | n_hidden, 396 | n_latent, 397 | n_layers, 398 | dropout_rate, 399 | dispersion, 400 | log_variational, 401 | reconstruction_loss, 402 | ) 403 | 404 | self.decoder = LinearDecoderSCVI( 405 | n_latent, 406 | n_input, 407 | n_cat_list=[n_batch], 408 | n_layers=n_layers, 409 | n_hidden=n_hidden, 410 | ) 411 | 412 | def get_loadings(self): 413 | """ Extract per-gene weights (for each Z) in the linear decoder. 414 | """ 415 | return self.decoder.factor_regressor.parameters() 416 | -------------------------------------------------------------------------------- /scMVP/models/vae_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Main module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.distributions import Normal, kl_divergence as kl 8 | 9 | from scMVP.models.log_likelihood import log_zinb_positive, log_nb_positive, mean_square_error, mean_square_error_positive 10 | from scMVP.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI, DecoderSCVI_nb, DecoderSCVI_mse, Encoder_l, Encoder_mse,\ 11 | Encoder_nb, DecoderSCVI_nb_rna, Encoder_nb_attention, DecoderSCVI_nb_Selfattention, Encoder_nb_selfattention 12 | from scMVP.models.utils import one_hot 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | 17 | # VAE model 18 | class VAE_Attention(nn.Module): 19 | r"""Variational auto-encoder model. 20 | 21 | :param n_input: Number of input genes 22 | :param n_batch: Number of batches 23 | :param n_labels: Number of labels 24 | :param n_hidden: Number of nodes per hidden layer 25 | :param n_latent: Dimensionality of the latent space 26 | :param n_layers: Number of hidden layers used for encoder and decoder NNs 27 | :param dropout_rate: Dropout rate for neural networks 28 | :param mode: One of the following: 29 | * ``'vae'`` -single channel auto-encoder decoder neural framework for scRNA-seq data 30 | * ``'mm-vae'`` -multi-channels auto-encoder decoder neural framework for scRNA and scATAC data 31 | :param dispersion: One of the following 32 | 33 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 34 | * ``'gene-batch'`` - dispersion can differ between different batches 35 | * ``'gene-label'`` - dispersion can differ between different labels 36 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 37 | 38 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization. 39 | :param reconstruction_loss: One of 40 | 41 | * ``'nb'`` - Negative binomial distribution 42 | * ``'zinb'`` - Zero-inflated negative binomial distribution 43 | 44 | Examples: 45 | >>> gene_dataset = CortexDataset() 46 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, 47 | ... 
n_labels=gene_dataset.n_labels) 48 | 49 | """ 50 | 51 | def __init__( 52 | self, 53 | n_input: int, 54 | n_batch: int = 0, 55 | n_labels: int = 0, 56 | n_hidden: int = 128, # 256 57 | n_latent: int = 10, # 20 58 | n_layers: int = 1, 59 | dropout_rate: float = 0.1, 60 | dispersion: str = "gene", 61 | log_variational: bool = True, 62 | reconstruction_loss: str = "zinb", 63 | #reconstruction_loss: str = "nb", 64 | ): 65 | super().__init__() 66 | self.dispersion = dispersion 67 | self.n_latent = n_latent 68 | self.log_variational = log_variational 69 | self.reconstruction_loss = reconstruction_loss 70 | # Automatically deactivate if useless 71 | self.n_batch = n_batch 72 | self.n_labels = n_labels 73 | 74 | if self.dispersion == "gene": 75 | self.px_r = torch.nn.Parameter(torch.randn(n_input)) 76 | elif self.dispersion == "gene-batch": 77 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch)) 78 | elif self.dispersion == "gene-label": 79 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels)) 80 | elif self.dispersion == "gene-cell": 81 | pass 82 | else: 83 | raise ValueError( 84 | "dispersion must be one of ['gene', 'gene-batch'," 85 | " 'gene-label', 'gene-cell'], but input was " 86 | "{}.format(self.dispersion)" 87 | ) 88 | # z encoder goes from the n_input-dimensional data to an n_latent-d 89 | # latent space representation 90 | if self.reconstruction_loss == "mse": 91 | self.z_encoder = Encoder_mse( 92 | n_input, 93 | n_latent, 94 | n_layers=n_layers, 95 | n_hidden=n_hidden, 96 | dropout_rate=dropout_rate, 97 | ) 98 | elif self.reconstruction_loss == "nb": 99 | self.z_encoder = Encoder_nb_attention( 100 | n_input, 101 | n_latent, 102 | n_layers=n_layers, 103 | n_hidden=n_hidden, 104 | dropout_rate=dropout_rate, 105 | ) 106 | 107 | else: 108 | self.z_encoder = Encoder( 109 | n_input, 110 | n_latent, 111 | n_layers=n_layers, 112 | n_hidden=n_hidden, 113 | dropout_rate=dropout_rate, 114 | ) 115 | # l encoder goes from n_input-dimensional data to 1-d library size 116 | if self.reconstruction_loss == "nb": 117 | self.l_encoder = Encoder_l( 118 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, 119 | ) 120 | elif self.reconstruction_loss == "zinb": 121 | self.l_encoder = Encoder_l( 122 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, 123 | ) 124 | else: 125 | self.l_encoder = Encoder( 126 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate, 127 | ) 128 | if self.reconstruction_loss == "mse": 129 | self.decoder = DecoderSCVI_mse( 130 | n_latent, 131 | n_input, 132 | n_cat_list=[n_batch], 133 | n_layers=n_layers, 134 | n_hidden=n_hidden, 135 | ) 136 | elif self.reconstruction_loss == "nb" and self.log_variational: 137 | self.decoder = DecoderSCVI_nb( 138 | n_latent, 139 | n_input, 140 | n_cat_list=[n_batch], 141 | n_layers=n_layers, 142 | n_hidden=n_hidden, 143 | ) 144 | elif self.reconstruction_loss == "nb" and not self.log_variational: 145 | self.decoder = DecoderSCVI_nb_rna( 146 | n_latent, 147 | n_input, 148 | n_cat_list=[n_batch], 149 | n_layers=n_layers, 150 | n_hidden=n_hidden, 151 | ) 152 | else: 153 | self.decoder = DecoderSCVI( 154 | n_latent, 155 | n_input, 156 | n_cat_list=[n_batch], 157 | n_layers=n_layers, 158 | n_hidden=n_hidden, 159 | ) 160 | 161 | def get_latents(self, x, y=None): 162 | r""" returns the result of ``sample_from_posterior_z`` inside a list 163 | 164 | :param x: tensor of values with shape ``(batch_size, n_input)`` 165 | :param y: tensor of cell-types labels with shape ``(batch_size, 
n_labels)`` 166 | :return: one element list of tensor 167 | :rtype: list of :py:class:`torch.Tensor` 168 | """ 169 | return [self.sample_from_posterior_z(x, y)] 170 | 171 | def sample_from_posterior_z(self, x, y=None, give_mean=False): 172 | r""" samples the tensor of latent values from the posterior 173 | #doesn't really sample, returns the means of the posterior distribution 174 | 175 | :param x: tensor of values with shape ``(batch_size, n_input)`` 176 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 177 | :param give_mean: is True when we want the mean of the posterior distribution rather than sampling 178 | :return: tensor of shape ``(batch_size, n_latent)`` 179 | :rtype: :py:class:`torch.Tensor` 180 | """ 181 | if self.log_variational: 182 | x = torch.log(1 + x) 183 | qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC 184 | if give_mean: 185 | z = qz_m 186 | return z 187 | 188 | def sample_from_posterior_l(self, x): 189 | r""" samples the tensor of library sizes from the posterior 190 | #doesn't really sample, returns the tensor of the means of the posterior distribution 191 | 192 | :param x: tensor of values with shape ``(batch_size, n_input)`` 193 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 194 | :return: tensor of shape ``(batch_size, 1)`` 195 | :rtype: :py:class:`torch.Tensor` 196 | """ 197 | if self.log_variational: 198 | x = torch.log(1 + x) 199 | ql_m, ql_v, library = self.l_encoder(x) 200 | return library 201 | 202 | def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1): 203 | r"""Returns the tensor of predicted frequencies of expression 204 | 205 | :param x: tensor of values with shape ``(batch_size, n_input)`` 206 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 207 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 208 | :param n_samples: number of samples 209 | :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)`` 210 | :rtype: :py:class:`torch.Tensor` 211 | """ 212 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[ 213 | "px_scale" 214 | ] 215 | 216 | def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1): 217 | r"""Returns the tensor of means of the negative binomial distribution 218 | 219 | :param x: tensor of values with shape ``(batch_size, n_input)`` 220 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)`` 221 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 222 | :param n_samples: number of samples 223 | :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)`` 224 | :rtype: :py:class:`torch.Tensor` 225 | """ 226 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[ 227 | "px_rate" 228 | ] 229 | 230 | def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs): 231 | # Reconstruction Loss: 232 | if self.reconstruction_loss == "nb": 233 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1) 234 | elif self.reconstruction_loss == "zinb": 235 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1) 236 | elif self.reconstruction_loss == "mse": 237 | reconst_loss = mean_square_error_positive(x, px_rate).sum(dim=-1) 238 | return reconst_loss 239 | 
240 | def scale_from_z(self, sample_batch, fixed_batch): 241 | if self.log_variational: 242 | sample_batch = torch.log(1 + sample_batch) 243 | qz_m, qz_v, z = self.z_encoder(sample_batch) 244 | batch_index = fixed_batch * torch.ones_like(sample_batch[:, [0]]) 245 | library = 4.0 * torch.ones_like(sample_batch[:, [0]]) 246 | px_scale, _, _, _ = self.decoder("gene", z, library, batch_index) 247 | return px_scale 248 | 249 | def inference(self, x, batch_index=None, y=None, n_samples=1): 250 | 251 | x_ = x 252 | if self.reconstruction_loss == "nb" and self.log_variational: 253 | library_nb = torch.log(x_.sum(dim=-1)).reshape(-1, 1) 254 | elif self.reconstruction_loss == "nb" and not self.log_variational: 255 | library_nb = (x_.sum(dim=-1)).reshape(-1, 1) 256 | if self.log_variational: 257 | x_ = torch.log(1 + x_) 258 | 259 | # Sampling 260 | qz_m, qz_v, z = self.z_encoder(x_, y) 261 | ql_m, ql_v, library = self.l_encoder(x_) 262 | if self.reconstruction_loss == "nb": 263 | library = library_nb 264 | 265 | if n_samples > 1: 266 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) 267 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) 268 | z = Normal(qz_m, qz_v.sqrt()).sample() 269 | ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) 270 | ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) 271 | library = Normal(ql_m, ql_v.sqrt()).sample() 272 | 273 | px_scale, px_r, px_rate, px_dropout = self.decoder( 274 | self.dispersion, z, library, batch_index, y 275 | ) 276 | if self.dispersion == "gene-label": 277 | px_r = F.linear( 278 | one_hot(y, self.n_labels), self.px_r 279 | ) # px_r gets transposed - last dimension is nb genes 280 | elif self.dispersion == "gene-batch": 281 | px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r) 282 | elif self.dispersion == "gene": 283 | px_r = self.px_r 284 | px_r = torch.exp(px_r) 285 | 286 | return dict( 287 | px_scale=px_scale, 288 | px_r=px_r, 289 | px_rate=px_rate, 290 | px_dropout=px_dropout, 291 | qz_m=qz_m, 292 | qz_v=qz_v, 293 | z=z, 294 | ql_m=ql_m, 295 | ql_v=ql_v, 296 | library=library, 297 | ) 298 | 299 | def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None): 300 | r""" Returns the reconstruction loss and the Kullback divergences 301 | 302 | :param x: tensor of values with shape (batch_size, n_input) 303 | :param local_l_mean: tensor of means of the prior distribution of latent variable l 304 | with shape (batch_size, 1) 305 | :param local_l_var: tensor of variancess of the prior distribution of latent variable l 306 | with shape (batch_size, 1) 307 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size`` 308 | :param y: tensor of cell-types labels with shape (batch_size, n_labels) 309 | :return: the reconstruction loss and the Kullback divergences 310 | :rtype: 2-tuple of :py:class:`torch.FloatTensor` 311 | """ 312 | # Parameters for z latent distribution 313 | outputs = self.inference(x, batch_index, None) 314 | qz_m = outputs["qz_m"] 315 | qz_v = outputs["qz_v"] 316 | ql_m = outputs["ql_m"] 317 | ql_v = outputs["ql_v"] 318 | px_rate = outputs["px_rate"] 319 | px_r = outputs["px_r"] 320 | px_dropout = outputs["px_dropout"] 321 | 322 | # KL Divergence 323 | mean = torch.zeros_like(qz_m) 324 | scale = torch.ones_like(qz_v) 325 | 326 | 327 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( 328 | dim=1 329 | ) 330 | kl_divergence_l = kl( 331 | Normal(ql_m, 
torch.sqrt(ql_v)), 332 | Normal(local_l_mean, torch.sqrt(local_l_var)), 333 | ).sum(dim=1) 334 | kl_divergence = kl_divergence_z 335 | 336 | reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) 337 | 338 | if self.reconstruction_loss == "mse" or self.reconstruction_loss == "nb": 339 | kl_divergence_l = 1.0 340 | print("reconst_loss=%f, kl_divergence=%f"%(torch.mean(reconst_loss),torch.mean(kl_divergence))) 341 | return reconst_loss + kl_divergence_l, kl_divergence, 0.0 342 | 343 | 344 | class LDVAE(VAE_Attention): 345 | r"""Linear-decoded Variational auto-encoder model. 346 | 347 | This model uses a linear decoder, directly mapping the latent representation 348 | to gene expression levels. It still uses a deep neural network to encode 349 | the latent representation. 350 | 351 | Compared to standard VAE, this model is less powerful, but can be used to 352 | inspect which genes contribute to variation in the dataset. 353 | 354 | :param n_input: Number of input genes 355 | :param n_batch: Number of batches 356 | :param n_labels: Number of labels 357 | :param n_hidden: Number of nodes per hidden layer (for encoder) 358 | :param n_latent: Dimensionality of the latent space 359 | :param n_layers: Number of hidden layers used for encoder NNs 360 | :param dropout_rate: Dropout rate for neural networks 361 | :param dispersion: One of the following 362 | 363 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells 364 | * ``'gene-batch'`` - dispersion can differ between different batches 365 | * ``'gene-label'`` - dispersion can differ between different labels 366 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell 367 | 368 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization. 369 | :param reconstruction_loss: One of 370 | 371 | * ``'nb'`` - Negative binomial distribution 372 | * ``'zinb'`` - Zero-inflated negative binomial distribution 373 | """ 374 | 375 | def __init__( 376 | self, 377 | n_input: int, 378 | n_batch: int = 0, 379 | n_labels: int = 0, 380 | n_hidden: int = 128, 381 | n_latent: int = 10, 382 | n_layers: int = 1, 383 | dropout_rate: float = 0.1, 384 | dispersion: str = "gene", 385 | log_variational: bool = True, 386 | reconstruction_loss: str = "zinb", 387 | ): 388 | super().__init__( 389 | n_input, 390 | n_batch, 391 | n_labels, 392 | n_hidden, 393 | n_latent, 394 | n_layers, 395 | dropout_rate, 396 | dispersion, 397 | log_variational, 398 | reconstruction_loss, 399 | ) 400 | 401 | self.decoder = LinearDecoderSCVI( 402 | n_latent, 403 | n_input, 404 | n_cat_list=[n_batch], 405 | n_layers=n_layers, 406 | n_hidden=n_hidden, 407 | ) 408 | 409 | def get_loadings(self): 410 | """ Extract per-gene weights (for each Z) in the linear decoder. 
411 | """ 412 | return self.decoder.factor_regressor.parameters() 413 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from setuptools import setup, find_packages 5 | 6 | 7 | requirements = [ 8 | "numpy>=1.16.2", 9 | "torch>=1.0.1", 10 | "matplotlib>=3.0.3", 11 | "h5py>=2.9.0", 12 | "pandas>=0.24.2", 13 | "loompy>=2.0.16", 14 | "tqdm>=4.31.1", 15 | "xlrd==1.2.0", 16 | # "nbconvert>=5.4.0", 17 | # "nbformat>=4.4.0", 18 | # "jupyter>=1.0.0", 19 | # "ipython>=7.1.1", 20 | # "anndata==0.6.22.post1", 21 | # "scanpy==1.4.4.post1", 22 | "dask>=2.0", 23 | "anndata>=0.7", 24 | "scanpy>=1.4.6", 25 | "scikit-learn>=0.22.2", 26 | "numba>=0.48", # numba 0.45.1 has a conflict with UMAP and numba 0.46.0 with parallelization in loompy 27 | "hyperopt==0.1.2", 28 | ] 29 | 30 | setup_requirements = ["pip>=18.1"] 31 | 32 | test_requirements = [ 33 | "pytest>=3.7.4", 34 | "pytest-runner>=2.11.1", 35 | "flake8>=3.7.7", 36 | "coverage>=4.5.1", 37 | "codecov>=2.0.8", 38 | "black>=19.3b0", 39 | ] 40 | 41 | extras_requirements = { 42 | "notebooks": [ 43 | "louvain>=0.6.1", 44 | "python-igraph>=0.7.1.post6", 45 | "colour>=0.1.5", 46 | "umap-learn>=0.3.8", 47 | "seaborn>=0.9.0", 48 | "leidenalg>=0.7.0", 49 | ], 50 | "docs": [ 51 | "sphinx>=1.7.1", 52 | "nbsphinx", 53 | "sphinx_autodoc_typehints", 54 | "sphinx-rtd-theme", 55 | ], 56 | "test": test_requirements, 57 | } 58 | author = ( 59 | "Gao yang Li, Shaliu FU" 60 | ) 61 | 62 | setup( 63 | author=author, 64 | author_email="lgyzngc@tongji.edu.cn", 65 | classifiers=[ 66 | "Development Status :: 4 - Beta", 67 | "Intended Audience :: Science/Research", 68 | "License :: OSI Approved :: MIT License", 69 | "Natural Language :: English", 70 | "Programming Language :: Python :: 3.7", 71 | "Operating System :: MacOS :: MacOS X", 72 | "Operating System :: Microsoft :: Windows", 73 | "Operating System :: POSIX :: Linux", 74 | "Topic :: Scientific/Engineering :: bioinformatics", 75 | ], 76 | description="Single Cell Multi-View Profiler", 77 | install_requires=requirements+extras_requirements["notebooks"], 78 | license="MIT license", 79 | # long_description=readme + "\n\n" + history, 80 | include_package_data=True, 81 | keywords="scMVP", 82 | name="scMVP", 83 | packages=find_packages(), 84 | setup_requires=setup_requirements, 85 | test_suite="tests", 86 | tests_require=test_requirements, 87 | extras_require=extras_requirements, 88 | url="https://github.com/bm2-lab/scMVP", 89 | version="0.0.1", 90 | zip_safe=False, 91 | ) 92 | --------------------------------------------------------------------------------