├── .gitignore
├── LICENSE
├── README.md
├── demos
│   ├── manuscript_analysis
│   │   ├── 10x_pbmc_demo.ipynb
│   │   ├── lymph_node_demo.ipynb
│   │   ├── paired_cellline_demo.ipynb
│   │   ├── sciCAR_cellline_demo.ipynb
│   │   ├── share_skin_demo.ipynb
│   │   ├── snare_cellline_demo.ipynb
│   │   ├── snare_p0_demo.ipynb
│   │   └── snare_p0_no_pretrain.ipynb
│   └── scMVP_tutorial.ipynb
├── scMVP
│   ├── __init__.py
│   ├── _settings.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── scMVP_dataloader.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── annotation.py
│   │   ├── autotune.py
│   │   ├── inference.py
│   │   ├── multi_inference.py
│   │   ├── posterior.py
│   │   └── trainer.py
│   └── models
│       ├── __init__.py
│       ├── classifier.py
│       ├── log_likelihood.py
│       ├── modules.py
│       ├── multi_vae_attention.py
│       ├── utils.py
│       ├── vaePeak_selfattetion.py
│       └── vae_attention.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom
2 | .idea/
3 | venv/
4 | data/
5 | *.pkl
6 | *.csv
7 | *.tsv
8 | *.mtx
9 | *.xls
10 | demos/dataset/*.txt
11 | demos/dataset/*/*.txt
12 | *.bed
13 | analysis_script
14 | requirement.txt
15 | docs/
16 | *.gz
17 | nohup.out
18 |
19 | demos/.ipynb_checkpoints/
20 | spatstat_1.64-1.tar.gz
21 | demos/snare_p0_old_model.ipynb
22 | demos/demo.ipynb
23 | demos/demo_share_format.py
24 |
25 | # demo folders
26 | demos/share_skin/
27 | demos/share_skin/
28 | demos/10x_pbmc/
29 | demos/paired_cellline/
30 | demos/snare_cellline/
31 | demos/sciCAR_cellline/
32 | demos/snare_p0/
33 | demos/paired_adult.ipynb
34 |
35 | # Byte-compiled / optimized / DLL files
36 | __pycache__/
37 | *.py[cod]
38 | *$py.class
39 | # necessary files
40 | # C extensions
41 | *.so
42 |
43 | # DS_Store
44 | .DS_Store
45 |
46 | # Distribution / packaging
47 | .Python
48 | env/
49 | build/
50 | develop-eggs/
51 | dist/
52 | downloads/
53 | eggs/
54 | .eggs/
55 | lib/
56 | lib64/
57 | parts/
58 | sdist/
59 | var/
60 | wheels/
61 | *.egg-info/
62 | .installed.cfg
63 | *.egg
64 |
65 | # PyInstaller
66 | # Usually these files are written by a python script from a template
67 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
68 | *.manifest
69 | *.spec
70 |
71 | # Installer logs
72 | pip-log.txt
73 | pip-delete-this-directory.txt
74 |
75 | # Unit test / coverage reports
76 | htmlcov/
77 | .tox/
78 | .coverage
79 | .coverage.*
80 | .cache
81 | nosetests.xml
82 | coverage.xml
83 | *.cover
84 | .hypothesis/
85 | .pytest_cache/
86 |
87 | # Translations
88 | *.mo
89 | *.pot
90 |
91 | # Django stuff:
92 | *.log
93 | local_settings.py
94 |
95 | # Flask stuff:
96 | instance/
97 | .webassets-cache
98 |
99 | # Scrapy stuff:
100 | .scrapy
101 |
102 | # Sphinx documentation
103 | docs/_build/
104 | docs/authors.rst
105 |
106 | # PyBuilder
107 | target/
108 |
109 | # Jupyter Notebook
110 | .ipynb_checkpoints
111 |
112 | # pyenv
113 | .python-version
114 |
115 | # celery beat schedule file
116 | celerybeat-schedule
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # dotenv
122 | .env
123 |
124 | # virtualenv
125 | .venv
126 | venv/
127 | ENV/
128 |
129 | # Spyder project settings
130 | .spyderproject
131 | .spyproject
132 |
133 | # Rope project settings
134 | .ropeproject
135 |
136 | # mkdocs documentation
137 | /site
138 |
139 | # mypy
140 | .mypy_cache/
141 |
142 | # PyCharm
143 | .idea/
144 |
145 | # Floobits
146 | .floo
147 | .flooignore
148 |
149 |
150 | /data
151 | .vscode/settings.json
152 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Romain Lopez, 2018
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scMVP - single cell Multi-View Profiler
2 |
3 | scMVP is a Python toolkit for the joint analysis of paired scRNA-seq and scATAC-seq profiling
4 | data with a multi-modal, self-attention-based generative model.
5 |
6 | ## Update logs
7 | - 20220815
8 | Completed testing of the GPU version on an NVIDIA 3090 platform with CUDA 11.2.
9 | Added scRNA and scATAC input checks to the tutorial.
10 |
11 |
12 | ## Installation
13 | **Environment requirements:**
14 | scMVP requires Python 3.7.x and [**PyTorch**](http://pytorch.org).
15 | For example, use [**miniconda**](https://conda.io/miniconda.html) to install the CPU or GPU version of Python and PyTorch.
16 | We have tested the GPU version on an NVIDIA 1080Ti platform with CUDA 10.2.
17 |
18 | ```Bash
19 | conda install -c pytorch python=3.7 pytorch
20 | # if you do not have Jupyter Notebook/IPython Notebook, you can also install it with conda
21 | conda install jupyter
22 | ```
23 |
24 | Then you can install scMVP from the GitHub repo:
25 | ```Bash
26 | # first move to your target directory
27 | git clone https://github.com/bm2-lab/scMVP.git
28 | cd scMVP/
29 | python setup.py install
30 | ```
31 |
32 | Try ```import scMVP``` in your Python console and start your first [**tutorial**](demos/scMVP_tutorial.ipynb) with scMVP!
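For a quick sanity check of the installation, a minimal sketch (``set_verbosity`` is exported by the package; the DEBUG level here is just an illustrative choice):

```Python
import logging

import scMVP

# optional: raise the package log level while testing the installation
scMVP.set_verbosity(logging.DEBUG)
print(scMVP.__version__)
```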
33 |
34 | Jupyter notebooks for the other datasets analyzed and benchmarked in our Genome Biology paper are deposited in the [**manuscript_analysis**](demos/manuscript_analysis/) folder.
35 |
36 | ### All processed dataset and trained models:
37 | Download link: [baidu cloud disk](https://pan.baidu.com/s/183jLROAUuNfVKCeBY4B4DQ)
38 | - Update 23/10/22: Google Drive link for the cell line datasets: [google drive](https://drive.google.com/drive/folders/18ymTLyMb_wD20O4Z2qkOXBQt5yoDTvea?usp=sharing)
39 |
40 | Download code: mkij
41 | - `pre_trainer.pkl`: scRNA pretraining model
42 | - `pre_atac_trainer.pkl`: scATAC pretraining model
43 | - `multi_vae_trainer.pkl`: trained scMVP model
44 |
45 |
46 | ## User tutorial
47 |
48 | Applying scMVP to the sci-CAR cell line mixture: [**demo**](demos/scMVP_tutorial.ipynb)
49 | - Training and visualization with scMVP.
50 | - Pretraining and transferring to scMVP(perform better in large dataset).
51 |
52 |
53 |
54 | ### Reference
55 | A deep generative model for multi-view profiling of single-cell RNA-seq and ATAC-seq data. Genome Biology 2022 [**paper**](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-021-02595-6)
56 |
57 |
58 | ### Contact Authors
59 | Prof. Qi Liu: [qiliu@tongji.edu.cn](mailto:qiliu@tongji.edu.cn)
60 | Dr. Gaoyang Li: [lgyzngc@gmail.com](mailto:lgyzngc@gmail.com)
61 | Shaliu Fu: [adam.tongji@gmail.com](mailto:adam.tongji@gmail.com)
62 |
63 |
--------------------------------------------------------------------------------
/scMVP/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | __author__ = "Gaoyang Li"
5 | __email__ = "lgyzngc@tongji.edu.cn"
6 | __version__ = "0.0.1"
7 |
8 | # Set default logging handler to avoid logging with logging.lastResort logger.
9 | import logging
10 | from logging import NullHandler
11 |
12 | from ._settings import set_verbosity
13 |
14 | logger = logging.getLogger(__name__)
15 | logger.addHandler(NullHandler())
16 |
17 | # default to INFO level logging for the scMVP package
18 | set_verbosity(logging.INFO)
19 |
20 | __all__ = ["set_verbosity"]
21 |
--------------------------------------------------------------------------------
/scMVP/_settings.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | logger = logging.getLogger(__name__)
5 | scMVP_logger = logging.getLogger("scMVP")
6 |
7 | autotune_formatter = logging.Formatter(
8 | "[%(asctime)s - %(processName)s - %(threadName)s] %(levelname)s - %(name)s\n%(message)s"
9 | )
10 |
11 |
12 | class DispatchingFormatter(logging.Formatter):
13 | """Dispatch formatter for logger and it's sub logger."""
14 |
15 | def __init__(self, default_formatter, formatters=None):
16 | super().__init__()
17 | self._formatters = formatters if formatters is not None else {}
18 | self._default_formatter = default_formatter
19 |
20 | def format(self, record):
21 | # Search from the record's logger up to its parents:
22 | logger = logging.getLogger(record.name)
23 | while logger:
24 | # Check if suitable formatter for current logger exists:
25 | if logger.name in self._formatters:
26 | formatter = self._formatters[logger.name]
27 | break
28 | else:
29 | logger = logger.parent
30 | else:
31 | # If no formatter found, just use default:
32 | formatter = self._default_formatter
33 | return formatter.format(record)
34 |
35 |
36 | def set_verbosity(level: Union[str, int]):
37 | """Sets logging configuration for scMVP based on chosen level of verbosity.
38 |
39 | Sets "scMVP" logging level to `level`
40 | If "scMVP" logger has no StreamHandler, add one.
41 | Else, set its level to `level`.
42 | """
43 | scMVP_logger.setLevel(level)
44 | has_streamhandler = False
45 | for handler in scMVP_logger.handlers:
46 | if isinstance(handler, logging.StreamHandler):
47 | handler.setLevel(level)
48 | logger.info(
49 | "'scMVP' logger already has a StreamHandler, set its level to {}.".format(
50 | level
51 | )
52 | )
53 | has_streamhandler = True
54 |
55 |
56 | if not has_streamhandler:
57 | ch = logging.StreamHandler()
58 | formatter = logging.Formatter(
59 | "[%(asctime)s] %(levelname)s - %(name)s | %(message)s"
60 | )
61 | ch.setFormatter(
62 | DispatchingFormatter(formatter, {"scMVP.autotune": autotune_formatter})
63 | )
64 | scMVP_logger.addHandler(ch)
65 | logger.info("Added StreamHandler with custom formatter to 'scMVP' logger.")
66 |
--------------------------------------------------------------------------------
/scMVP/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from scMVP.dataset.dataset import (
2 | GeneExpressionDataset,
3 | CellMeasurement,
4 | )
5 |
6 | from scMVP.dataset.scMVP_dataloader import LoadData
7 |
8 | __all__ = [
9 | "CellMeasurement",
10 | "GeneExpressionDataset",
11 | "LoadData",
12 | ]
13 |
--------------------------------------------------------------------------------
/scMVP/dataset/scMVP_dataloader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import urllib
4 | import numpy as np
5 | import pandas as pd
6 | import scipy.io as sp_io
7 | from scipy.sparse import csr_matrix, issparse
8 |
9 | from scMVP.dataset.dataset import CellMeasurement, GeneExpressionDataset
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | available_specification = ["filtered", "raw"]
14 |
15 |
16 | class LoadData(GeneExpressionDataset):
17 |
18 | """
19 | Dataset format:
20 | dataset = {
21 | "gene_barcodes": xxx,
22 | "gene_expression": xxx,
23 | "gene_names": xxx,
24 | "atac_barcodes": xxx,
25 | "atac_expression": xxx,
26 | "atac_names": xxx,
27 | }
28 | OR
29 | dataset = {
30 | "gene_expression":xxx,
31 | "atac_expression":xxx,
32 | }
33 | """
34 | def __init__(self,
35 | dataset: dict = None,
36 | data_path: str = "dataset/",
37 | dense: bool = False,
38 | measurement_names_column: int = 0,
39 | remove_extracted_data: bool = False,
40 | delayed_populating: bool = False,
41 | file_separator: str = "\t",
42 | gzipped: bool = False,
43 | atac_threshold: float = 0.0001, # keep peaks detected in more than 0.01% of cells
44 | cell_threshold: int = 1, # filter out cells with fewer than this many counts
45 | cell_meta: pd.DataFrame = None,
46 | ):
47 |
48 | self.dataset = dataset
49 | self.data_path = data_path
50 | self.barcodes = None
51 | self.dense = dense
52 | self.measurement_names_column = measurement_names_column
53 | self.remove_extracted_data = remove_extracted_data
54 | self.file_separator = file_separator
55 | self.gzip = gzipped
56 | self.atac_thres = atac_threshold
57 | self.cell_thres = cell_threshold
58 | self._minimum_input = ("gene_expression", "atac_expression")
59 | self._allow_input = (
60 | "gene_expression", "atac_expression",
61 | "gene_barcodes", "gene_names",
62 | "atac_barcodes", "atac_names"
63 | )
64 | self.cell_meta = cell_meta
65 | super().__init__()
66 | if not delayed_populating:
67 | self.populate()
68 |
69 | def populate(self):
70 | logger.info("Preprocessing joint profiling dataset.")
71 | if not self._input_check():
72 | logger.info("Please reload your dataset.")
73 | return
74 | joint_profiles = {}
75 | if len(self.dataset.keys()) == 2:
76 | # for _key in self.dataset.keys():
77 | if self.gzip:
78 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["gene_expression"]), sep=self.file_separator,
79 | header=0, index_col=0, compression="gzip")
80 | else:
81 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["gene_expression"]), sep=self.file_separator,
82 | header=0)
83 | joint_profiles["gene_barcodes"] = pd.DataFrame(_tmp.columns.values)
84 | joint_profiles["gene_names"] = pd.DataFrame(_tmp._stat_axis.values)
85 | joint_profiles["gene_expression"] = np.array(_tmp).T
86 |
87 | if self.gzip:
88 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["atac_expression"]), sep=self.file_separator,
89 | header=0, index_col=0, compression="gzip")
90 | else:
91 | _tmp = pd.read_csv("{}/{}".format(self.data_path,self.dataset["atac_expression"]), sep=self.file_separator,
92 | header=0, index_col=0)
93 | joint_profiles["atac_barcodes"] = pd.DataFrame(_tmp.columns.values)
94 | joint_profiles["atac_names"] = pd.DataFrame(_tmp._stat_axis.values)
95 | joint_profiles["atac_expression"] = np.array(_tmp).T
96 |
97 | elif len(self.dataset.keys()) == 6:
98 | for _key in self.dataset.keys():
99 | if _key in ("atac_expression", "gene_expression") and not self.dense:
100 | joint_profiles[_key] = csr_matrix(sp_io.mmread("{}/{}".format(self.data_path,self.dataset[_key])).T)
101 |
102 | elif self.gzip:
103 | joint_profiles[_key] = pd.read_csv("{}/{}".format(self.data_path, self.dataset[_key]),
104 | sep=self.file_separator,
105 | compression="gzip", header=None)
106 | else:
107 | joint_profiles[_key] = pd.read_csv("{}/{}".format(self.data_path,self.dataset[_key]),
108 | sep=self.file_separator, header=None)
109 |
110 | else:
111 | logger.info("more than 6 inputs.")
112 |
113 | ## 200920 gene barcode file may include more than 1 column
114 | if joint_profiles["gene_names"].shape[1] > 1:
115 | joint_profiles["gene_names"] = pd.DataFrame(joint_profiles["gene_names"].iloc[:,1])
116 | if joint_profiles["atac_names"].shape[1] > 1:
117 | joint_profiles["atac_names"] = pd.DataFrame(joint_profiles["atac_names"].iloc[:,1])
118 | share_index, gene_barcode_index, atac_barcode_index = np.intersect1d(joint_profiles["gene_barcodes"].values,
119 | joint_profiles["atac_barcodes"].values,
120 | return_indices=True)
121 | if isinstance(self.cell_meta,pd.DataFrame):
122 | if self.cell_meta.shape[1] < 2:
123 | logger.info("Please use cell id in first column and give ata least 2 columns.")
124 | return
125 | meta_cell_id = self.cell_meta.iloc[:,0].values
126 | meta_share, meta_barcode_index, share_barcode_index =\
127 | np.intersect1d(meta_cell_id,
128 | share_index, return_indices=True)
129 | _gene_barcode_index = gene_barcode_index[share_barcode_index]
130 | _atac_barcode_index = atac_barcode_index[share_barcode_index]
131 | if len(_gene_barcode_index) < 2: # no overlaps
132 | logger.info("Inconsistent metadata to expression data.")
133 | return
134 | tmp = joint_profiles["gene_barcodes"]
135 | joint_profiles["gene_barcodes"] = tmp.loc[_gene_barcode_index, :]
136 | temp = joint_profiles["atac_barcodes"]
137 | joint_profiles["atac_barcodes"] = temp.loc[_atac_barcode_index, :]
138 |
139 | else:
140 | # reorder rnaseq cell meta
141 | tmp = joint_profiles["gene_barcodes"]
142 | joint_profiles["gene_barcodes"] = tmp.loc[gene_barcode_index,:]
143 | temp = joint_profiles["atac_barcodes"]
144 | joint_profiles["atac_barcodes"] = temp.loc[atac_barcode_index, :]
145 |
146 | gene_tab = joint_profiles["gene_expression"]
147 | if issparse(gene_tab):
148 | joint_profiles["gene_expression"] = gene_tab[gene_barcode_index, :].A
149 | else:
150 | joint_profiles["gene_expression"] = gene_tab[gene_barcode_index, :]
151 |
152 | temp = joint_profiles["atac_expression"]
153 | reorder_atac_exp = temp[atac_barcode_index, :]
154 | binary_index = reorder_atac_exp > 1
155 | reorder_atac_exp[binary_index] = 1
156 | # keep peaks detected in at least atac_threshold and at most 10% of cells
157 | high_count_atacs = ((reorder_atac_exp > 0).sum(axis=0).ravel() >= self.atac_thres * reorder_atac_exp.shape[0]) \
158 | & ((reorder_atac_exp > 0).sum(axis=0).ravel() <= 0.1 * reorder_atac_exp.shape[0])
159 |
160 | if issparse(reorder_atac_exp):
161 | high_count_atacs_index = np.where(high_count_atacs)
162 | _tmp = reorder_atac_exp[:, high_count_atacs_index[1]]
163 | joint_profiles["atac_expression"] = _tmp.A
164 | joint_profiles["atac_names"] = joint_profiles["atac_names"].loc[high_count_atacs_index[1], :]
165 |
166 | else:
167 | _tmp = reorder_atac_exp[:, high_count_atacs]
168 |
169 | joint_profiles["atac_expression"] = _tmp
170 | joint_profiles["atac_names"] = joint_profiles["atac_names"].loc[high_count_atacs, :]
171 |
172 | # RNA-seq as the key
173 | Ys = []
174 | measurement = CellMeasurement(
175 | name="atac_expression",
176 | data=joint_profiles["atac_expression"],
177 | columns_attr_name="atac_names",
178 | columns=joint_profiles["atac_names"].astype(str),
179 | )
180 | Ys.append(measurement)
181 | # Add cell metadata
182 | if isinstance(self.cell_meta,pd.DataFrame):
183 | for l_index, label in enumerate(list(self.cell_meta.columns.values)):
184 | if l_index >0:
185 | label_measurement = CellMeasurement(
186 | name="{}_label".format(label),
187 | data=self.cell_meta.iloc[meta_barcode_index,l_index],
188 | columns_attr_name=label,
189 | columns=self.cell_meta.iloc[meta_barcode_index, l_index]
190 | )
191 | Ys.append(label_measurement)
192 | logger.info("Loading {} into dataset.".format(label))
193 |
194 | cell_attributes_dict = {
195 | "barcodes": np.squeeze(np.asarray(joint_profiles["gene_barcodes"], dtype=str))
196 | }
197 |
198 | logger.info("Finished preprocessing dataset")
199 |
200 | self.populate_from_data(
201 | X=joint_profiles["gene_expression"],
202 | batch_indices=None,
203 | gene_names=joint_profiles["gene_names"].astype(str),
204 | cell_attributes_dict=cell_attributes_dict,
205 | Ys=Ys,
206 | )
207 | self.filter_cells_by_count(self.cell_thres)
208 |
209 | def _input_check(self):
210 | if len(self.dataset.keys()) == 2:
211 | for _key in self.dataset.keys():
212 | if _key not in self._minimum_input:
213 | logger.info("Unknown input data type:{}".format(_key))
214 | return False
215 | # if not self.dataset[_key].split(".")[-1] in ["txt","tsv","csv"]:
216 | # logger.debug("scMVP only support two files input of txt, tsv or csv!")
217 | # return False
218 | elif len(self.dataset.keys()) >= 6:
219 | for _key in self._allow_input:
220 | if not _key in self.dataset.keys():
221 | logger.info("Data type {} missing.".format(_key))
222 | return False
223 | else:
224 | logger.info("Incorrect input file number.")
225 | return False
226 | for _key in self.dataset.keys():
227 | if not os.path.exists(self.data_path):
228 | logger.info("{} do not exist!".format(self.data_path))
229 | if not os.path.exists("{}{}".format(self.data_path, self.dataset[_key])):
230 | logger.info("Cannot find {}{}!".format(self.data_path, self.dataset[_key]))
231 | return False
232 | return True
233 |
234 | def _download(self, url: str, save_path: str, filename: str):
235 | """Writes data from url to file."""
236 | if os.path.exists(os.path.join(save_path, filename)):
237 | logger.info("File %s already downloaded" % (os.path.join(save_path, filename)))
238 | return
239 |
240 | r = urllib.request.urlopen(url)
241 | logger.info("Downloading file at %s" % os.path.join(save_path, filename))
242 |
243 | def read_iter(file, block_size=1000):
244 | """Given a file 'file', returns an iterator that returns bytes of
245 | size 'blocksize' from the file, using read()."""
246 | while True:
247 | block = file.read(block_size)
248 | if not block:
249 | break
250 | yield block
251 |
252 | def _add_cell_meta(self, cell_meta, filter=False):
253 | cell_ids = cell_meta.iloc[:,1].values
254 | share_index, meta_barcode_index, gene_barcode_index = \
255 | np.intersect1d(cell_ids,self.barcodes,return_indices=True)
256 | if len(share_index) <=1:
257 | logger.info("No consistent cell IDs!")
258 | return
259 | if len(share_index) < len(self.barcodes):
260 | logger.info("{} cells match metadata.".format(len(share_index)))
261 | return
262 |
263 |
264 |
265 | class SnareDemo(LoadData):
266 |
267 | def __init__(self, dataset_name: str=None, data_path: str="/dataset", cell_meta: str = None):
268 | url="https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE126074"
269 | available_datasets = {
270 | "CellLineMixture": {
271 | "gene_expression": "GSE126074_CellLineMixture_SNAREseq_cDNA_counts.tsv.gz",
272 | "atac_expression": "GSE126074_CellLineMixture_SNAREseq_chromatin_counts.tsv.gz",
273 | },
274 | "AdBrainCortex": {
275 | "gene_barcodes": "GSE126074_AdBrainCortex_SNAREseq_cDNA.barcodes.tsv.gz",
276 | "gene_expression": "GSE126074_AdBrainCortex_SNAREseq_cDNA.counts.mtx.gz",
277 | "gene_names": "GSE126074_AdBrainCortex_SNAREseq_cDNA.genes.tsv.gz",
278 | "atac_barcodes": "GSE126074_AdBrainCortex_SNAREseq_chromatin.barcodes.tsv.gz",
279 | "atac_expression": "GSE126074_AdBrainCortex_SNAREseq_chromatin.counts.mtx.gz",
280 | "atac_names": "GSE126074_AdBrainCortex_SNAREseq_chromatin.peaks.tsv.gz",
281 | },
282 | "P0_BrainCortex": {
283 | "gene_barcodes": "GSE126074_P0_BrainCortex_SNAREseq_cDNA.barcodes.tsv.gz",
284 | "gene_expression": "GSE126074_P0_BrainCortex_SNAREseq_cDNA.counts.mtx.gz",
285 | "gene_names": "GSE126074_P0_BrainCortex_SNAREseq_cDNA.genes.tsv.gz",
286 | "atac_barcodes": "GSE126074_P0_BrainCortex_SNAREseq_chromatin.barcodes.tsv.gz",
287 | "atac_expression": "GSE126074_P0_BrainCortex_SNAREseq_chromatin.counts.mtx.gz",
288 | "atac_names": "GSE126074_P0_BrainCortex_SNAREseq_chromatin.peaks.tsv.gz",
289 | }
290 | }
291 | if cell_meta:
292 | cell_meta_data = pd.read_csv(cell_meta, sep=",", header=0)
293 | else:
294 | cell_meta_data = None
295 | if dataset_name=="CellLineMixture":
296 | super(SnareDemo, self).__init__(dataset = available_datasets[dataset_name],
297 | data_path= data_path,
298 | dense = False,
299 | measurement_names_column = 1,
300 | cell_meta=cell_meta_data,
301 | remove_extracted_data = False,
302 | delayed_populating = False,
303 | file_separator = "\t",
304 | gzipped = True,
305 | atac_threshold = 0.0005,
306 | cell_threshold = 1
307 | )
308 | elif dataset_name=="AdBrainCortex" or dataset_name=="P0_BrainCortex":
309 | super(SnareDemo, self).__init__(dataset=available_datasets[dataset_name],
310 | data_path=data_path,
311 | dense=False,
312 | measurement_names_column=1,
313 | cell_meta=cell_meta_data,
314 | remove_extracted_data=False,
315 | delayed_populating=False,
316 | gzipped=True,
317 | atac_threshold=0.0005,
318 | cell_threshold=1
319 | )
320 | else:
321 | logger.info('Please select from "CellLineMixture", "AdBrainCortex" or "P0_BrainCortex" dataset.')
322 |
323 |
324 | class PairedDemo(LoadData):
325 |
326 | def __init__(self, dataset_name: str = None, data_path: str = "/dataset"):
327 | urls = [
328 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE130nnn/GSE130399/suppl/GSE130399_GSM3737488_GSM3737489_Cell_Mix.tar.gz",
329 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE130nnn/GSE130399/suppl/GSE130399_GSM3737490-GSM3737495_Adult_Cerebrail_Cortex.tar.gz",
330 | "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE130nnn/GSE130399/suppl/GSE130399_GSM3737496-GSM3737499_Fetal_Forebrain.tar.gz"
331 | ]
332 |
333 | available_datasets = {
334 | "CellLineMixture": {
335 | "gene_names": "Cell_Mix_RNA/genes.tsv",
336 | "gene_expression": "Cell_Mix_RNA/matrix.mtx",
337 | "gene_barcodes": "Cell_Mix_RNA/barcodes.tsv",
338 | "atac_names": "Cell_Mix_DNA/genes.tsv",
339 | "atac_expression": "Cell_Mix_DNA/matrix.mtx",
340 | "atac_barcodes":"Cell_Mix_DNA/barcodes.tsv"
341 | },
342 | "Adult_Cerebral": {
343 | "gene_names": "Adult_CTX_RNA/genes.tsv",
344 | "gene_expression": "Adult_CTX_RNA/matrix.mtx",
345 | "gene_barcodes": "Adult_CTX_RNA/barcodes.tsv",
346 | "atac_names": "Adult_CTX_DNA/genes.tsv",
347 | "atac_expression": "Adult_CTX_DNA/matrix.mtx",
348 | "atac_barcodes": "Adult_CTX_DNA/barcodes.tsv"
349 | },
350 | "Fetal_Forebrain": {
351 | "gene_names": "FB_RNA/genes.tsv",
352 | "gene_expression": "FB_RNA/matrix.mtx",
353 | "gene_barcodes": "FB_RNA/barcodes.tsv",
354 | "atac_names": "FB_DNA/genes.tsv",
355 | "atac_expression": "FB_DNA/matrix.mtx",
356 | "atac_barcodes": "FB_DNA/barcodes.tsv"
357 | }
358 | }
359 | if dataset_name=="CellLineMixture" or dataset_name=="Fetal_Forebrain":
360 | if os.path.exists("{}/Cell_embeddings.xls".format(data_path)):
361 | cell_embed = pd.read_csv("{}/Cell_embeddings.xls".format(data_path), sep='\t')
362 | cell_embed_info = cell_embed.iloc[:, 0:2]
363 | cell_embed_info.columns = ["Cell_ID","Cluster"]
364 | else:
365 | logger.info("Cannot find cell embedding files for Paried-seq Demo.")
366 | return
367 |
368 | super().__init__(dataset = available_datasets[dataset_name],
369 | data_path= data_path,
370 | dense = False,
371 | measurement_names_column = 1,
372 | remove_extracted_data = False,
373 | delayed_populating = False,
374 | gzipped = False,
375 | atac_threshold = 0.005,
376 | cell_threshold = 100,
377 | cell_meta=cell_embed_info
378 | )
379 | elif dataset_name=="Adult_Cerebral":
380 | if os.path.exists("{}/Cell_embeddings.xls".format(data_path)):
381 | cell_embed = pd.read_csv("{}/Cell_embeddings.xls".format(data_path), sep='\t')
382 | cell_embed_info = cell_embed.loc[:, ["ID","Cluster"]]
383 | cell_embed_info.columns = ["Cell_ID","Cluster"]
384 | else:
385 | logger.info("Cannot find cell embedding files for Paried-seq Demo.")
386 | return
387 | super().__init__(dataset=available_datasets[dataset_name],
388 | data_path=data_path,
389 | dense=False,
390 | measurement_names_column = 1,
391 | remove_extracted_data=False,
392 | delayed_populating=False,
393 | gzipped=False,
394 | atac_threshold=0.005,
395 | cell_threshold=1,
396 | cell_meta=cell_embed_info
397 | )
398 | else:
399 | logger.info('Please select from {} dataset.'.format("\t".join(available_datasets.keys())))
400 |
401 |
402 | class SciCarDemo(LoadData):
403 | def __init__(self, dataset_name: str = None, data_path: str = "/dataset", cell_meta: str = None):
404 | urls = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE117089&format=file"
405 | # NOTICE, tsv files are generated from original txt files
406 | available_datasets = {
407 | "CellLineMixture": {
408 | "gene_barcodes": "GSM3271040_RNA_sciCAR_A549_cell.tsv",
409 | "gene_names": "GSM3271040_RNA_sciCAR_A549_gene.tsv",
410 | "gene_expression": "GSM3271040_RNA_sciCAR_A549_gene_count.txt",
411 | "atac_barcodes": "GSM3271041_ATAC_sciCAR_A549_cell.tsv",
412 | "atac_names": "GSM3271041_ATAC_sciCAR_A549_peak.tsv",
413 | "atac_expression": "GSM3271041_ATAC_sciCAR_A549_peak_count.txt"
414 | },
415 | "mouse_kidney": {
416 | "gene_barcodes": "GSM3271044_RNA_mouse_kidney_cell.tsv",
417 | "gene_names": "GSM3271044_RNA_mouse_kidney_gene.tsv",
418 | "gene_expression": "GSM3271044_RNA_mouse_kidney_gene_count.txt",
419 | "atac_barcodes": "GSM3271045_ATAC_mouse_kidney_cell.tsv",
420 | "atac_names": "GSM3271045_ATAC_mouse_kidney_peak.tsv",
421 | "atac_expression": "GSM3271045_ATAC_mouse_kidney_peak_count.txt"
422 | }
423 | }
424 | if dataset_name:
425 | for barcode_file in ["gene_barcodes", "atac_barcodes", "gene_names", "atac_names"]:
426 | # generate gene and atac barcodes from cell metadata.
427 | with open("{}/{}".format(data_path, available_datasets[dataset_name][barcode_file]),"w") as fo:
428 | infile = "{}/{}.txt".format(data_path, available_datasets[dataset_name][barcode_file][:-4])
429 | indata = [i.rstrip().split(",") for i in open(infile)][1:]
430 | for line in indata:
431 | fo.write("{}\n".format(line[0]))
432 | if cell_meta:
433 | cell_meta_data = pd.read_csv(cell_meta, sep=",", header=0)
434 | else:
435 | cell_meta_data = None
436 | super().__init__(dataset=available_datasets[dataset_name],
437 | data_path=data_path,
438 | dense=False,
439 | measurement_names_column=0,
440 | cell_meta=cell_meta_data,
441 | remove_extracted_data=False,
442 | delayed_populating=False,
443 | gzipped=False,
444 | atac_threshold=0.0005,
445 | cell_threshold=1
446 | )
447 | else:
448 | logger.info('Please select from {} dataset.'.format("\t".join(available_datasets.keys())))
449 |
450 |
--------------------------------------------------------------------------------
/scMVP/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .posterior import Posterior
2 | from .trainer import Trainer
3 | from .inference import UnsupervisedTrainer, AdapterTrainer
4 | from .annotation import (
5 | ClassifierTrainer,
6 | )
7 | from .multi_inference import MultiPosterior, MultiTrainer
8 |
9 | __all__ = [
10 | "Trainer",
11 | "Posterior",
12 | "UnsupervisedTrainer",
13 | "AdapterTrainer",
14 | "ClassifierTrainer",
15 | "MultiPosterior",
16 | "MultiTrainer"
17 | ]
18 |
--------------------------------------------------------------------------------
/scMVP/inference/annotation.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import numpy as np
4 | import logging
5 |
6 | from sklearn import neighbors
7 | from sklearn.ensemble import RandomForestClassifier
8 | from sklearn.model_selection import GridSearchCV
9 | from sklearn.neighbors import KNeighborsClassifier
10 | from sklearn.svm import SVC
11 |
12 | import torch
13 | from torch.nn import functional as F
14 |
15 | from scMVP.inference import Posterior
16 | from scMVP.inference import Trainer
17 | from scMVP.inference.inference import UnsupervisedTrainer
18 | from scMVP.inference.posterior import unsupervised_clustering_accuracy
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class AnnotationPosterior(Posterior):
24 | def __init__(self, *args, model_zl=False, **kwargs):
25 | super().__init__(*args, **kwargs)
26 | self.model_zl = model_zl
27 |
28 | def accuracy(self):
29 | model, cls = (
30 | (self.sampling_model, self.model)
31 | if hasattr(self, "sampling_model")
32 | else (self.model, None)
33 | )
34 | acc = compute_accuracy(model, self, classifier=cls, model_zl=self.model_zl)
35 | logger.debug("Acc: %.4f" % (acc))
36 | return acc
37 |
38 | accuracy.mode = "max"
39 |
40 | @torch.no_grad()
41 | def hierarchical_accuracy(self):
42 | all_y, all_y_pred = self.compute_predictions()
43 | acc = np.mean(all_y == all_y_pred)
44 |
45 | all_y_groups = np.array([self.model.labels_groups[y] for y in all_y])
46 | all_y_pred_groups = np.array([self.model.labels_groups[y] for y in all_y_pred])
47 | h_acc = np.mean(all_y_groups == all_y_pred_groups)
48 |
49 | logger.debug("Hierarchical Acc : %.4f\n" % h_acc)
50 | return acc
51 |
52 | hierarchical_accuracy.mode = "max"
53 |
54 | @torch.no_grad()
55 | def compute_predictions(self, soft=False):
56 | """
57 | :return: the true labels and the predicted labels
58 | :rtype: 2-tuple of :py:class:`numpy.int32`
59 | """
60 | model, cls = (
61 | (self.sampling_model, self.model)
62 | if hasattr(self, "sampling_model")
63 | else (self.model, None)
64 | )
65 | return compute_predictions(
66 | model, self, classifier=cls, soft=soft, model_zl=self.model_zl
67 | )
68 |
69 | @torch.no_grad()
70 | def unsupervised_classification_accuracy(self):
71 | all_y, all_y_pred = self.compute_predictions()
72 | uca = unsupervised_clustering_accuracy(all_y, all_y_pred)[0]
73 | logger.debug("UCA : %.4f" % (uca))
74 | return uca
75 |
76 | unsupervised_classification_accuracy.mode = "max"
77 |
78 | @torch.no_grad()
79 | def nn_latentspace(self, posterior):
80 | data_train, _, labels_train = self.get_latent()
81 | data_test, _, labels_test = posterior.get_latent()
82 | nn = KNeighborsClassifier()
83 | nn.fit(data_train, labels_train)
84 | score = nn.score(data_test, labels_test)
85 | return score
86 |
87 |
88 | class ClassifierTrainer(Trainer):
89 | r"""The ClassifierInference class for training a classifier either on the raw data or on top of the latent
90 | space of another model (VAE, VAEC, SCANVI).
91 |
92 | Args:
93 | :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI``
94 | :gene_dataset: A gene_dataset instance like ``CortexDataset()``
95 | :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples
96 | to use. Default: ``0.8``.
97 | :test_size: The test size, either a float between 0 and 1 or an integer for the number of test samples
98 | to use. Default: ``None``.
99 | :sampling_model: Model with z_encoder with which to first transform data.
100 | :sampling_zl: Transform data with sampling_model z_encoder and l_encoder and concat.
101 | :\**kwargs: Other keywords arguments from the general Trainer class.
102 |
103 |
104 | Examples:
105 | >>> gene_dataset = CortexDataset()
106 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
107 | ... n_labels=gene_dataset.n_labels)
108 |
109 | >>> classifier = Classifier(vae.n_latent, n_labels=cortex_dataset.n_labels)
110 | >>> trainer = ClassifierTrainer(classifier, gene_dataset, sampling_model=vae, train_size=0.5)
111 | >>> trainer.train(n_epochs=20, lr=1e-3)
112 | >>> trainer.test_set.accuracy()
113 | """
114 |
115 | def __init__(
116 | self,
117 | *args,
118 | train_size=0.8,
119 | test_size=None,
120 | sampling_model=None,
121 | sampling_zl=False,
122 | use_cuda=True,
123 | **kwargs
124 | ):
125 | self.sampling_model = sampling_model
126 | self.sampling_zl = sampling_zl
127 | super().__init__(*args, use_cuda=use_cuda, **kwargs)
128 | self.train_set, self.test_set, self.validation_set = self.train_test_validation(
129 | self.model,
130 | self.gene_dataset,
131 | train_size=train_size,
132 | test_size=test_size,
133 | type_class=AnnotationPosterior,
134 | )
135 | self.train_set.to_monitor = ["accuracy"]
136 | self.test_set.to_monitor = ["accuracy"]
137 | self.validation_set.to_monitor = ["accuracy"]
138 | self.train_set.model_zl = sampling_zl
139 | self.test_set.model_zl = sampling_zl
140 | self.validation_set.model_zl = sampling_zl
141 |
142 | @property
143 | def posteriors_loop(self):
144 | return ["train_set"]
145 |
146 | def __setattr__(self, key, value):
147 | if key in ["train_set", "test_set"]:
148 | value.sampling_model = self.sampling_model
149 | super().__setattr__(key, value)
150 |
151 | def loss(self, tensors_labelled):
152 | x, _, _, _, labels_train = tensors_labelled
153 | if self.sampling_model:
154 | if hasattr(self.sampling_model, "classify"):
155 | return F.cross_entropy(
156 | self.sampling_model.classify(x), labels_train.view(-1)
157 | )
158 | else:
159 | if self.sampling_model.log_variational:
160 | x = torch.log(1 + x)
161 | if self.sampling_zl:
162 | x_z = self.sampling_model.z_encoder(x)[0]
163 | x_l = self.sampling_model.l_encoder(x)[0]
164 | x = torch.cat((x_z, x_l), dim=-1)
165 | else:
166 | x = self.sampling_model.z_encoder(x)[0]
167 | return F.cross_entropy(self.model(x), labels_train.view(-1))
168 |
169 | @torch.no_grad()
170 | def compute_predictions(self, soft=False):
171 | """
172 | :return: the true labels and the predicted labels
173 | :rtype: 2-tuple of :py:class:`numpy.int32`
174 | """
175 | model, cls = (
176 | (self.sampling_model, self.model)
177 | if hasattr(self, "sampling_model")
178 | else (self.model, None)
179 | )
180 | full_set = self.create_posterior(type_class=AnnotationPosterior)
181 | return compute_predictions(
182 | model, full_set, classifier=cls, soft=soft, model_zl=self.sampling_zl
183 | )
184 |
185 |
186 | @torch.no_grad()
187 | def compute_predictions(
188 | model, data_loader, classifier=None, soft=False, model_zl=False
189 | ):
190 | all_y_pred = []
191 | all_y = []
192 |
193 | for i_batch, tensors in enumerate(data_loader):
194 | sample_batch, _, _, _, labels = tensors
195 | all_y += [labels.view(-1).cpu()]
196 |
197 | if hasattr(model, "classify"):
198 | y_pred = model.classify(sample_batch)
199 | elif classifier is not None:
200 | # Then we use the specified classifier
201 | if model is not None:
202 | if model.log_variational:
203 | sample_batch = torch.log(1 + sample_batch)
204 | if model_zl:
205 | sample_z = model.z_encoder(sample_batch)[0]
206 | sample_l = model.l_encoder(sample_batch)[0]
207 | sample_batch = torch.cat((sample_z, sample_l), dim=-1)
208 | else:
209 | sample_batch, _, _ = model.z_encoder(sample_batch)
210 | y_pred = classifier(sample_batch)
211 | else: # The model is the raw classifier
212 | y_pred = model(sample_batch)
213 |
214 | if not soft:
215 | y_pred = y_pred.argmax(dim=-1)
216 |
217 | all_y_pred += [y_pred.cpu()]
218 |
219 | all_y_pred = np.array(torch.cat(all_y_pred))
220 | all_y = np.array(torch.cat(all_y))
221 |
222 | return all_y, all_y_pred
223 |
224 |
225 | @torch.no_grad()
226 | def compute_accuracy(vae, data_loader, classifier=None, model_zl=False):
227 | all_y, all_y_pred = compute_predictions(
228 | vae, data_loader, classifier=classifier, model_zl=model_zl
229 | )
230 | return np.mean(all_y == all_y_pred)
231 |
232 |
233 | Accuracy = namedtuple(
234 | "Accuracy", ["unweighted", "weighted", "worst", "accuracy_classes"]
235 | )
236 |
237 |
238 | @torch.no_grad()
239 | def compute_accuracy_tuple(y, y_pred):
240 | y = y.ravel()
241 | n_labels = len(np.unique(y))
242 | classes_probabilities = []
243 | accuracy_classes = []
244 | for cl in range(n_labels):
245 | idx = y == cl
246 | classes_probabilities += [np.mean(idx)]
247 | accuracy_classes += [
248 | np.mean((y[idx] == y_pred[idx])) if classes_probabilities[-1] else 0
249 | ]
250 | # This is also referred to as the "recall": p = n_true_positive / (n_false_negative + n_true_positive)
251 | # ( We could also compute the "precision": p = n_true_positive / (n_false_positive + n_true_positive) )
252 | accuracy_named_tuple = Accuracy(
253 | unweighted=np.dot(accuracy_classes, classes_probabilities),
254 | weighted=np.mean(accuracy_classes),
255 | worst=np.min(accuracy_classes),
256 | accuracy_classes=accuracy_classes,
257 | )
258 | return accuracy_named_tuple
259 |
260 |
261 | @torch.no_grad()
262 | def compute_accuracy_nn(data_train, labels_train, data_test, labels_test, k=5):
263 | clf = neighbors.KNeighborsClassifier(k, weights="distance")
264 | return compute_accuracy_classifier(
265 | clf, data_train, labels_train, data_test, labels_test
266 | )
267 |
268 |
269 | @torch.no_grad()
270 | def compute_accuracy_classifier(clf, data_train, labels_train, data_test, labels_test):
271 | clf.fit(data_train, labels_train)
272 | # Predicting the labels
273 | y_pred_test = clf.predict(data_test)
274 | y_pred_train = clf.predict(data_train)
275 |
276 | return (
277 | (
278 | compute_accuracy_tuple(labels_train, y_pred_train),
279 | compute_accuracy_tuple(labels_test, y_pred_test),
280 | ),
281 | y_pred_test,
282 | )
283 |
284 |
285 | @torch.no_grad()
286 | def compute_accuracy_svc(
287 | data_train,
288 | labels_train,
289 | data_test,
290 | labels_test,
291 | param_grid=None,
292 | verbose=0,
293 | max_iter=-1,
294 | ):
295 | if param_grid is None:
296 | param_grid = [
297 | {"C": [1, 10, 100, 1000], "kernel": ["linear"]},
298 | {"C": [1, 10, 100, 1000], "gamma": [0.001, 0.0001], "kernel": ["rbf"]},
299 | ]
300 | svc = SVC(max_iter=max_iter)
301 | clf = GridSearchCV(svc, param_grid, verbose=verbose)
302 | return compute_accuracy_classifier(
303 | clf, data_train, labels_train, data_test, labels_test
304 | )
305 |
306 |
307 | @torch.no_grad()
308 | def compute_accuracy_rf(
309 | data_train, labels_train, data_test, labels_test, param_grid=None, verbose=0
310 | ):
311 | if param_grid is None:
312 | param_grid = {"max_depth": np.arange(3, 10), "n_estimators": [10, 50, 100, 200]}
313 | rf = RandomForestClassifier(max_depth=2, random_state=0)
314 | clf = GridSearchCV(rf, param_grid, verbose=verbose)
315 | return compute_accuracy_classifier(
316 | clf, data_train, labels_train, data_test, labels_test
317 | )
318 |
--------------------------------------------------------------------------------
/scMVP/inference/inference.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import matplotlib.pyplot as plt
4 | import torch
5 |
6 | from scMVP.inference import Trainer
7 |
8 | plt.switch_backend("agg")
9 |
10 |
11 | class UnsupervisedTrainer(Trainer):
12 | r"""The VariationalInference class for the unsupervised training of an autoencoder.
13 |
14 | Args:
15 | :model: A model instance from class ``VAE``, ``VAEC``,
16 | :gene_dataset: A gene_dataset instance like ``snareDataset()``
17 | :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples
18 | to use. Default: ``0.8``.
19 | :test_size: The test size, either a float between 0 and 1 or an integer for the number of test samples
20 | to use. Default: ``None``, which is equivalent to data not in the train set. If ``train_size`` and ``test_size``
21 | do not add to 1 or the length of the dataset then the remaining samples are added to a ``validation_set``.
22 | :n_epochs_kl_warmup: Number of epochs for linear warmup of KL(q(z|x)||p(z)) term. After `n_epochs_kl_warmup`,
23 | the training objective is the ELBO. This might be used to prevent inactivity of latent units, and/or to
24 | improve clustering of latent space, as a long warmup turns the model into something more of an autoencoder.
25 | :normalize_loss: A boolean determining whether the loss is divided by the total number of samples used for
26 | training. In particular, when the global KL divergence is equal to 0 and the division is performed, the loss
27 | for a minibatch is equal to the average of the reconstruction losses and KL divergences on the minibatch.
28 | Default: ``None``, which is equivalent to setting False when the model is an instance from class
29 | ``AutoZIVAE`` and True otherwise.
30 | :\*\*kwargs: Other keywords arguments from the general Trainer class.
31 |
32 | Examples:
33 | >>> gene_dataset = snareDataset()
34 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
35 | ... n_labels=gene_dataset.n_labels)
36 |
37 | >>> infer = VariationalInference(gene_dataset, vae, train_size=0.5)
38 | >>> infer.train(n_epochs=20, lr=1e-3)
39 | """
40 | default_metrics_to_monitor = ["elbo"]
41 |
42 | def __init__(
43 | self,
44 | model,
45 | gene_dataset,
46 | train_size=0.8,
47 | test_size=None,
48 | n_epochs_kl_warmup=400,
49 | normalize_loss=None,
50 | **kwargs
51 | ):
52 | super().__init__(model, gene_dataset, **kwargs)
53 | self.n_epochs_kl_warmup = n_epochs_kl_warmup
54 |
55 | self.normalize_loss = (
56 | not (
57 | hasattr(self.model, "reconstruction_loss")
58 | and self.model.reconstruction_loss == "autozinb"
59 | )
60 | if normalize_loss is None
61 | else normalize_loss
62 | )
63 |
64 | # Total size of the dataset used for training
65 | # (e.g. training set in this class but testing set in AdapterTrainer).
66 | # It is used to rescale minibatch losses (cf. eq. (8) in Kingma & Welling, Auto-Encoding Variational Bayes, ICLR 2014)
67 | self.n_samples = 1.0
68 |
69 | if type(self) is UnsupervisedTrainer:
70 | (
71 | self.train_set,
72 | self.test_set,
73 | self.validation_set,
74 | ) = self.train_test_validation(model, gene_dataset, train_size, test_size)
75 | self.train_set.to_monitor = ["elbo"]
76 | self.test_set.to_monitor = ["elbo"]
77 | self.validation_set.to_monitor = ["elbo"]
78 | self.n_samples = len(self.train_set.indices)
79 |
80 | @property
81 | def posteriors_loop(self):
82 | return ["train_set"]
83 |
84 | def loss(self, tensors):
85 | sample_batch, local_l_mean, local_l_var, batch_index, y = tensors
86 | #reconst_loss, kl_divergence_local, kl_divergence_global = self.model(
87 | # sample_batch, local_l_mean, local_l_var, batch_index, y
88 | #)
89 | reconst_loss, kl_divergence_local, kl_divergence_global = self.model(
90 | sample_batch, local_l_mean, local_l_var, batch_index, batch_index
91 | )
92 | loss = (
93 | self.n_samples
94 | * torch.mean(reconst_loss + self.kl_weight * kl_divergence_local)
95 | + kl_divergence_global
96 | )
97 | if self.normalize_loss:
98 | loss = loss / self.n_samples
99 | return loss
100 |
101 | def on_epoch_begin(self):
102 | if self.n_epochs_kl_warmup is not None:
103 | self.kl_weight = min(1, self.epoch / self.n_epochs_kl_warmup)
104 | else:
105 | self.kl_weight = 1.0
106 |
107 |
108 | class AdapterTrainer(UnsupervisedTrainer):
109 | def __init__(self, model, gene_dataset, posterior_test, frequency=5):
110 | super().__init__(model, gene_dataset, frequency=frequency)
111 | self.test_set = posterior_test
112 | self.test_set.to_monitor = ["elbo"]
113 | self.params = list(self.model.z_encoder.parameters()) + list(
114 | self.model.l_encoder.parameters()
115 | )
116 | self.z_encoder_state = copy.deepcopy(model.z_encoder.state_dict())
117 | self.l_encoder_state = copy.deepcopy(model.l_encoder.state_dict())
118 | self.n_scale = len(self.test_set.indices)
119 |
120 | @property
121 | def posteriors_loop(self):
122 | return ["test_set"]
123 |
124 | def train(self, n_path=10, n_epochs=50, **kwargs):
125 | for i in range(n_path):
126 | # Re-initialize to create new path
127 | self.model.z_encoder.load_state_dict(self.z_encoder_state)
128 | self.model.l_encoder.load_state_dict(self.l_encoder_state)
129 | super().train(n_epochs, params=self.params, **kwargs)
130 |
131 | return min(self.history["elbo_test_set"])
132 |
--------------------------------------------------------------------------------
/scMVP/inference/multi_inference.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import logging
3 | import torch
4 | from torch.distributions import Poisson, Gamma, Bernoulli, Normal
5 | from torch.utils.data import DataLoader
6 | import torch.nn.functional as F
7 | from torch import logsumexp
8 | import torch.distributions as distributions
9 | import numpy as np
10 |
11 | from scMVP.inference import Posterior
12 | from . import UnsupervisedTrainer
13 |
14 | from scMVP.dataset import GeneExpressionDataset
15 | from scMVP.models import multi_vae_attention
16 | # from sklearn.utils.linear_assignment_ import linear_assignment
17 | from scipy.optimize import linear_sum_assignment as linear_assignment
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 |
23 |
24 |
25 | class MultiPosterior(Posterior):
26 | r"""The functional data unit for Multivae. A `MultiPosterior` instance is instantiated with a model and
27 | a gene_dataset, as well as additional arguments for PyTorch's `DataLoader`. A subset of indices
28 | can be specified, for purposes such as splitting the data into train/test/validation. Each trainer instance of the `Trainer` class can therefore have multiple
29 | `MultiPosterior` instances to train a model. A `MultiPosterior` instance also comes with many methods or
30 | utilities for its corresponding data.
31 |
32 |
33 | :param model: A model instance from class ``Multivae``
34 | :param gene_dataset: A gene_dataset instance like ``ATACDataset()`` with attribute ``ATAC_expression``
35 | :param shuffle: Specifies if a `RandomSampler` or a `SequentialSampler` should be used
36 | :param indices: Specifies how the data should be split with regards to train/test or labelled/unlabelled
37 | :param use_cuda: Default: ``True``
38 | :param data_loader_kwargs: Keyword arguments passed into the `DataLoader`
39 |
40 | Examples:
41 |
42 | Let us instantiate a `trainer`, with a gene_dataset and a model
43 |
44 | >>> gene_dataset = CbmcDataset()
45 | >>> totalvi = TOTALVI(gene_dataset.nb_genes, len(gene_dataset.protein_names),
46 | ... n_batch=gene_dataset.n_batches * False, n_labels=gene_dataset.n_labels, use_cuda=True)
47 | >>> trainer = TotalTrainer(totalvi, gene_dataset)
48 | >>> trainer.train(n_epochs=400)
49 | """
50 |
51 | def __init__(
52 | self,
53 | model: multi_vae_attention,
54 | gene_dataset: GeneExpressionDataset,
55 | shuffle: bool = False,
56 | indices: Optional[np.ndarray] = None,
57 | use_cuda: bool = True,
58 | data_loader_kwargs=dict(),
59 | ):
60 |
61 | super().__init__(
62 | model,
63 | gene_dataset,
64 | shuffle=shuffle,
65 | indices=indices,
66 | use_cuda=use_cuda,
67 | data_loader_kwargs=data_loader_kwargs,
68 | )
69 | # Add atac tensor as another tensor to be loaded
70 | self.data_loader_kwargs.update(
71 | {
72 | "collate_fn": gene_dataset.collate_fn_builder(
73 | {"atac_expression": np.float32}# debug cell index
74 | )
75 | }
76 | )
77 |
78 | self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs)
79 |
80 | def corrupted(self):
81 | return self.update(
82 | {
83 | "collate_fn": self.gene_dataset.collate_fn_builder(
84 | {"atac_expression": np.float32}, corrupted=True
85 | )
86 | }
87 | )
88 |
89 | def uncorrupted(self):
90 | return self.update(
91 | {
92 | "collate_fn": self.gene_dataset.collate_fn_builder(
93 | {"atac_expression": np.float32}
94 | )
95 | }
96 | )
97 |
98 | @torch.no_grad()
99 | def elbo(self):
100 | elbo = self.compute_elbo(self.model)
101 | logger.debug("ELBO : %.4f" % elbo)
102 | return elbo
103 | elbo.mode = "min"
104 |
105 | @torch.no_grad()
106 | def reconstruction_error(self):
107 | reconstruction_error = self.compute_reconstruction_error(self.model, self)
108 | logger.debug("Reconstruction Error : %.4f" % reconstruction_error)
109 | return reconstruction_error
110 |
111 | reconstruction_error.mode = "min"
112 |
113 | @torch.no_grad()
114 | def marginal_ll(self, n_mc_samples=1000):
115 |
116 | ll = self.compute_marginal_log_likelihood(self.model, self, n_mc_samples)
117 | logger.debug("True LL : %.4f" % ll)
118 | return ll
119 |
120 | def compute_elbo(self, vae:multi_vae_attention, **kwargs):
121 | """ Computes the ELBO.
122 |
123 | The ELBO is the reconstruction error + the KL divergences
124 | between the variational distributions and the priors.
125 | It differs from the marginal log likelihood.
126 | Specifically, it is a lower bound on the marginal log likelihood
127 | plus a term that is constant with respect to the variational distribution.
128 | It still gives good insights on the modeling of the data, and is fast to compute.
129 | """
130 | # Iterate once over the posterior and compute the elbo
131 | elbo = 0
132 | for i_batch, tensors in enumerate(self):
133 | (
134 | sample_batch_X,
135 | local_l_mean,
136 | local_l_var,
137 | batch_index,
138 | label,
139 | sample_batch_Y,
140 | ) = tensors
141 |
142 | reconst_loss, kl_divergence_local, kl_divergence_global = vae(
143 | sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, label
144 | )
145 | elbo += torch.sum(reconst_loss + kl_divergence_local).item()
146 | n_samples = len(self.indices)
147 | elbo += kl_divergence_global
148 | return elbo / n_samples
149 |
150 | def compute_reconstruction_error(self, vae:multi_vae_attention, **kwargs):
151 | r""" Computes log p(x/z), which is the reconstruction error .
152 | Differs from the marginal log likelihood, but still gives good
153 | insights on the modeling of the data, and is fast to compute
154 |
155 | This is really a helper function to self.ll, self.ll_protein, etc.
156 | """
157 | # Iterate once over the posterior and computes the total log_likelihood
158 | log_lkl = 0
159 | for i_batch, tensors in enumerate(self):
160 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[
161 | :5
162 | ] # general fish case
163 |
164 | # Distribution parameters
165 | outputs = vae.inference(sample_batch, batch_index, labels, **kwargs)
166 | p_rna_r = outputs["p_rna_r"]
167 | p_rna_rate = outputs["p_rna_rate"]
168 | p_rna_dropout = outputs["p_rna_dropout"]
169 | p_atac_mean = outputs["p_atac_mean"]
170 | p_atac_r = outputs["p_atac_r"]
171 | p_atac_dropout = outputs["p_atac_dropout"]
172 |
173 | # Reconstruction loss
174 | reconst_rna_loss = vae.get_reconstruction_loss(
175 | sample_batch,
176 | p_rna_rate,
177 | p_rna_r,
178 | p_rna_dropout,
179 | # bernoulli_params=bernoulli_params,
180 | **kwargs
181 | )
182 | reconst_atac_loss = vae.get_reconstruction_atac_loss(
183 | sample_batch,
184 | p_atac_mean,
185 | p_atac_r,
186 | p_atac_dropout,
187 | **kwargs
188 | )
189 |
190 | log_lkl += torch.sum(reconst_rna_loss).item()
191 | log_lkl += torch.sum(reconst_atac_loss).item()
192 | n_samples = len(self.indices)
193 | return log_lkl / n_samples
194 |
195 | def compute_marginal_log_likelihood(self, vae:multi_vae_attention , n_mc_samples):
196 | """ Computes a biased estimator for log p(x), which is the marginal log likelihood.
197 |
198 | Despite its bias, the estimator still converges to the real value
199 | of log p(x) when n_mc_samples (the number of Monte Carlo samples) goes to infinity
200 | (a fairly high value like 100 should be enough)
201 | Due to the Monte Carlo sampling, this method is not as computationally efficient
202 | as computing only the reconstruction loss
203 | """
204 | # Uses MC sampling to compute a tighter lower bound on log p(x)
205 |
206 | log_lkl = 0
207 | for i_batch, tensors in enumerate(self):
208 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors
209 | to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples)
210 |
211 | for i in range(n_mc_samples):
212 | # Distribution parameters and sampled variables
213 | outputs = vae.inference(sample_batch, batch_index, labels)
214 | p_rna_r = outputs["p_rna_r"]
215 | p_rna_rate = outputs["p_rna_rate"]
216 | p_rna_dropout = outputs["p_rna_dropout"]
217 | qz_m = outputs["qz_m"]
218 | qz_v = outputs["qz_v"]
219 | z = outputs["z"]
220 | p_atac_mean = outputs["p_atac_mean"]
221 | p_atac_r = outputs["p_atac_r"]
222 | p_atac_dropout = outputs["p_atac_dropout"]
223 | mu_c = outputs["mu_c"]
224 | var_c = outputs["var_c"]
225 | gamma = outputs["gamma"]
226 | mu_c_max = outputs["mu_c_max"]
227 | var_c_max = outputs["var_c_max"]
228 | z_c_max = outputs["z_c_max"]
229 |
230 | # Reconstruction Loss
231 | reconst_rna_loss = vae.get_reconstruction_loss(
232 | sample_batch,
233 | p_rna_r,
234 | p_rna_rate,
235 | p_rna_dropout,
236 | )
237 | reconst_atac_loss = vae.get_reconstruction_atac_loss(
238 | sample_batch,
239 | p_atac_r,
240 | p_atac_mean,
241 | p_atac_dropout,
242 | )
243 |
244 | # Log-probabilities
245 | #p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1)
246 | p_z = 0.0
247 | for prob, mu, var in mu_c, var_c, gamma:
248 | p_z += prob*Normal(mu, var.sqrt()).log_prob(z).sum(dim=-1)
249 |
250 | p_x_zl = -reconst_rna_loss - reconst_atac_loss
251 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
252 | #q_z_max = Normal(mu_c_max, var_c_max.sqrt()).log_prob(z_c_max).sum(dim=-1)
253 |
254 | to_sum[:, i] = p_z + p_x_zl - q_z_x #- q_z_max
255 |
256 | batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_mc_samples)
257 | log_lkl += torch.sum(batch_log_lkl).item()
258 |
259 | n_samples = len(self.indices)
260 | # The minus sign is there because we actually look at the negative log likelihood
261 | return -log_lkl / n_samples
262 |
263 | @torch.no_grad()
264 | def get_latent(self, sample=False):
265 | """
266 | Output posterior z mean or sample, batch index, and label
267 | :param sample: z mean or z sample
268 | :return: three np.ndarrays, latent, batch_indices, labels
269 | """
270 | latent = []
271 | latent_rna = []
272 | latent_atac = []
273 | batch_indices = []
274 | labels = []
275 | cluster_gamma = []
276 | cluster_index = []
277 | for tensors in self:
278 | sample_batch_rna, local_l_mean, local_l_var, batch_index, label, sample_batch_atac = tensors
279 | give_mean = not sample
280 | latent_temp = self.model.sample_from_posterior_z(
281 | [sample_batch_rna, sample_batch_atac], y=label, give_mean=give_mean
282 | )
283 |             latent += [
284 |                 latent_temp[0].cpu()
285 |             ]
286 |             latent_rna += [
287 |                 latent_temp[1].cpu()
288 |             ]
289 |             latent_atac += [
290 |                 latent_temp[2].cpu()
291 |             ]
292 |             gamma, mu_c, var_c, pi = self.model.get_gamma(latent_temp[0])
293 | cluster_gamma += [gamma.cpu()]
294 | cluster_index += [torch.argmax(gamma.cpu(),dim=1)]
295 | batch_indices += [batch_index.cpu()]
296 | labels += [label.cpu()]
297 | return (
298 | np.array(torch.cat(latent)),
299 | np.array(torch.cat(latent_rna)),
300 | np.array(torch.cat(latent_atac)),
301 | np.array(torch.cat(cluster_gamma)),
302 | np.array(torch.cat(cluster_index)),
303 | np.array(torch.cat(batch_indices)),
304 | np.array(torch.cat(labels)).ravel(),
305 | )
306 |
307 | @torch.no_grad()
308 | def generate(
309 | self,
310 | n_samples: int = 100,
311 | genes: Optional[np.ndarray] = None,
312 | batch_size: int = 256,
313 | #batch_size: int = 128,
314 | ) :
315 | """
316 | Create observation samples from the Posterior Predictive distribution
317 |
318 | :param n_samples: Number of required samples for each cell
319 | :param genes: Indices of genes of interest
320 | :param batch_size: Desired Batch size to generate data
321 |
322 |         :return: Tuple (rna_new, rna_old, atac_new, atac_old)
323 |             Where rna_old has shape (n_cells, n_genes) and rna_new has shape (n_cells, n_genes, n_samples);
324 |             the ATAC arrays are shaped analogously over peaks/bins
325 | """
326 | assert self.model.reconstruction_loss in ["zinb", "zip"]
327 |         zero_inflated = self.model.reconstruction_loss in ["zinb", "zip"]
328 |
329 | rna_old = []
330 | rna_new = []
331 | atac_old = []
332 | atac_new = []
333 | for tensors in self.update({"batch_size": batch_size}):
334 |             x_rna, _, _, batch_index, labels, x_atac = tensors
335 | outputs = self.model.inference(
336 |                 [x_rna, x_atac], batch_index=batch_index, y=labels, n_samples=n_samples
337 | )
338 | p_rna_r = outputs["p_rna_r"]
339 | p_rna_rate = outputs["p_rna_rate"]
340 | p_rna_dropout = outputs["p_rna_dropout"]
341 | p_atac_mean = outputs["p_atac_mean"]
342 | p_atac_dropout = outputs["p_atac_dropout"]
343 |
344 | # Generating rna-seq data
345 | p = p_rna_rate / (p_rna_rate + p_rna_r)
346 | r = p_rna_r
347 | # Important remark: Gamma is parametrized by the rate = 1/scale!
348 | l_train_rna = distributions.Gamma(concentration=r, rate=(1 - p) / p).sample()
349 | # Clamping as distributions objects can have buggy behaviors when
350 | # their parameters are too high
351 | l_train_rna = torch.clamp(l_train_rna, max=1e8)
352 | gene_expressions = distributions.Poisson(
353 | l_train_rna
354 | ).sample() # Shape : (n_samples, n_cells_batch, n_genes)
355 |
356 | #Generating atac-seq data
357 | l_train_atac = torch.clamp(p_atac_mean, max=1e2)
358 | atac_expressions = distributions.Poisson(
359 | l_train_atac
360 | ).sample()
361 |
362 | # zero-inflate
363 | if zero_inflated:
364 | p_zero_rna = (1.0 + torch.exp(-p_rna_dropout)).pow(-1)
365 | random_prob_rna = torch.rand_like(p_zero_rna)
366 | gene_expressions[random_prob_rna <= p_zero_rna] = 0
367 |
368 | p_zero_atac = (1.0 + torch.exp(-p_atac_dropout)).pow(-1)
369 | random_prob_atac = torch.rand_like(p_zero_atac)
370 | atac_expressions[random_prob_atac <= p_zero_atac] = 0
371 |
372 | gene_expressions = gene_expressions.permute(
373 | [1, 2, 0]
374 | ) # Shape : (n_cells_batch, n_genes, n_samples)
375 | atac_expressions = atac_expressions.permute(
376 | [1, 2, 0]
377 | )
378 |
379 |             rna_old.append(x_rna.cpu())
380 | rna_new.append(gene_expressions.cpu())
381 |             atac_old.append(x_atac.cpu())
382 | atac_new.append(atac_expressions.cpu())
383 |
384 |         rna_old, rna_new = torch.cat(rna_old), torch.cat(rna_new)  # (n_cells, n_genes) / (n_cells, n_genes, n_samples)
385 |         atac_old, atac_new = torch.cat(atac_old), torch.cat(atac_new)
386 |         if genes is not None:
387 |             gene_ids = self.gene_dataset.genes_to_index(genes)
388 |             rna_new = rna_new[:, gene_ids, :]
389 |             rna_old = rna_old[:, gene_ids]
390 |         return rna_new.numpy(), rna_old.numpy(), atac_new.numpy(), atac_old.numpy()
391 |
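    # Note on generate() above: a negative-binomial sample is drawn as a Gamma-Poisson mixture,
    #   rate ~ Gamma(concentration=r, rate=(1 - p) / p) with p = mu / (mu + r), then x ~ Poisson(rate),
    # and zero inflation is applied afterwards by zeroing each entry with probability
    # sigmoid(dropout_logit) = (1 + exp(-dropout_logit))**-1, exactly as in the loop above.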
392 | @torch.no_grad()
393 | def imputation(self, n_samples: int = 1):
394 | """ Gene imputation
395 | """
396 | imputed_rna_list = []
397 | imputed_atac_list = []
398 | label_list = [] # for the annotated data
399 | atac_list = []
400 | for tensors in self:
401 | x_rna, local_l_mean, local_l_var, batch_index, label, x_atac = tensors
402 | p_rna_rate, p_atac_rate = self.model.get_sample_rate(
403 | x=[x_rna,x_atac], batch_index=batch_index, y=label, n_samples=n_samples, local_l_mean = local_l_mean, local_l_var = local_l_var
404 | )
405 | imputed_rna_list += [np.array(p_rna_rate.cpu())]
406 | imputed_atac_list += [np.array(p_atac_rate.cpu())]
407 | label_list += [np.array(label.cpu())] # only for annotated data
408 |             atac_list += [np.array(x_atac.cpu())]  # for bins without a called peak
409 | imputed_rna_list = np.concatenate(imputed_rna_list)
410 | imputed_atac_list = np.concatenate(imputed_atac_list)
411 | label_list = np.concatenate(label_list) # only for annotated data
412 |         atac_list = np.concatenate(atac_list)  # for bins without a called peak
413 | return imputed_rna_list.squeeze(), imputed_atac_list.squeeze(), label_list.squeeze(), atac_list
414 |
415 | @torch.no_grad()
416 | def get_sample_scale(self):
417 | p_rna_scales = []
418 | p_atac_scales = []
419 | for tensors in self:
420 | x_rna, _, _, batch_index, labels, x_atac = tensors
421 | p_rna_scales += [
422 | np.array(
423 | (
424 | self.model.get_sample_scale(
425 | x=[x_rna,x_atac], batch_index=batch_index, y=labels, n_samples=1
426 | )[0]
427 | )
428 | )
429 | ]
430 | p_atac_scales += [
431 | np.array(
432 | (
433 | self.model.get_sample_scale(
434 | x=[x_rna,x_atac], batch_index=batch_index, y=labels, n_samples=1
435 | )[1]
436 | )
437 | )
438 | ]
439 | return np.concatenate(p_rna_scales), np.concatenate(p_atac_scales)
440 |
441 |     def cluster_acc(self, Y_pred, Y):
442 |         assert Y_pred.size == Y.size
443 |         D = max(Y_pred.max(), Y.max()) + 1
444 |         w = np.zeros((D, D), dtype=np.int64)
445 |         for i in range(Y_pred.size):
446 |             w[Y_pred[i], Y[i]] += 1
447 |         ind = linear_assignment(w.max() - w)
448 |         return sum([w[i, j] for i, j in ind]) * 1.0 / Y_pred.size, ind
449 | 
450 |     def get_clustering(self):
451 |         latent, latent_rna, latent_atac, cluster_gamma, cluster_index, batch_indices, labels = self.get_latent()
452 |         cluster_accuracy, ind = self.cluster_acc(np.argmax(cluster_gamma, axis=1), labels)
453 |         print('cell dataset multi-vae - clustering accuracy: %.2f%%' % (cluster_accuracy * 100))
454 |         return cluster_accuracy, ind
455 |
456 | class MultiTrainer(UnsupervisedTrainer):
457 | r"""The VariationalInference class for the unsupervised training of an autoencoder.
458 |
459 | Args:
460 | :model: A model instance from class ``TOTALVI``
461 | :gene_dataset: A gene_dataset instance like ``CbmcDataset()`` with attribute ``protein_expression``
462 | :train_size: The train size, either a float between 0 and 1 or and integer for the number of training samples
463 | to use Default: ``0.93``.
464 | :test_size: The test size, either a float between 0 and 1 or and integer for the number of training samples
465 | to use Default: ``0.02``. Note that if train and test do not add to 1 the remainder is placed in a validation set
466 | :\*\*kwargs: Other keywords arguments from the general Trainer class.
467 | """
468 | default_metrics_to_monitor = ["elbo"]
469 |
470 | def __init__(
471 | self,
472 | model,
473 | dataset,
474 | train_size=0.90,
475 | test_size=0.05,
476 | pro_recons_weight=1.0,
477 | n_epochs_back_kl_warmup=50, #200, init
478 | n_epochs_kl_warmup=200,
479 | **kwargs
480 | ):
481 | self.n_genes = dataset.nb_genes
482 | self.n_proteins = model.n_input_atac
483 |
484 | self.pro_recons_weight = pro_recons_weight
485 | self.n_epochs_back_kl_warmup = n_epochs_back_kl_warmup
486 | super().__init__(
487 | model, dataset, n_epochs_kl_warmup=n_epochs_kl_warmup, **kwargs
488 | )
489 | if type(self) is MultiTrainer:
490 | (
491 | self.train_set,
492 | self.test_set,
493 | self.validation_set,
494 | ) = self.train_test_validation(
495 | model, dataset, train_size, test_size, type_class=MultiPosterior
496 | )
497 | self.train_set.to_monitor = []
498 | self.test_set.to_monitor = ["elbo"]
499 | self.validation_set.to_monitor = ["elbo"]
500 |
501 | def loss(self, tensors):
502 | (
503 | sample_batch_X,
504 | local_l_mean,
505 | local_l_var,
506 | batch_index,
507 | label,
508 | sample_batch_Y,
509 | ) = tensors
510 |
511 | #reconst_loss, kl_divergence_local, kl_divergence_global = self.model(
512 | # sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, label
513 | #)
514 | reconst_loss, kl_divergence_local, kl_divergence_global = self.model(
515 | sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, batch_index
516 | )
517 | loss = (
518 | self.n_samples
519 | * torch.mean(reconst_loss + self.back_warmup_weight * kl_divergence_local)
520 | + kl_divergence_global
521 | )
522 | print(
523 | "reconst_loss = %f,kl_divergence_local = %f,kl_weight = %f,loss = %f" %
524 | (torch.mean(reconst_loss), torch.mean(kl_divergence_local), self.back_warmup_weight, loss)
525 | )
526 | # self.KL_divergence = kl_divergence_global
527 | if self.normalize_loss:
528 | loss = loss / self.n_samples
529 | return loss
530 |
531 |
532 | def on_epoch_begin(self):
533 | super().on_epoch_begin()
534 | if self.n_epochs_back_kl_warmup is not None:
535 |             # linear warm-up: ramp the local KL weight from 0 to 1 over n_epochs_back_kl_warmup epochs
536 |             self.back_warmup_weight = min(1, self.epoch / self.n_epochs_back_kl_warmup)
537 | else:
538 | self.back_warmup_weight = 1.0
539 |
540 |
--------------------------------------------------------------------------------
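A minimal, self-contained sketch of the importance-weighted Monte Carlo estimator that
compute_marginal_log_likelihood above implements. It is illustrative only: the standard-normal prior,
the toy encoder outputs q_mu/q_var and the decoder_log_lik callable are stand-ins, not scMVP APIs.

import math
import torch
from torch import logsumexp
from torch.distributions import Normal

def mc_log_marginal(x, decoder_log_lik, q_mu, q_var, n_mc_samples=100):
    # log p(x) ~= logsumexp_i[ log p(z_i) + log p(x|z_i) - log q(z_i|x) ] - log(n_mc_samples)
    to_sum = torch.zeros(x.size(0), n_mc_samples)
    q = Normal(q_mu, q_var.sqrt())
    prior = Normal(torch.zeros_like(q_mu), torch.ones_like(q_var))
    for i in range(n_mc_samples):
        z = q.rsample()                       # one posterior sample per cell
        to_sum[:, i] = (
            prior.log_prob(z).sum(dim=-1)     # log p(z)
            + decoder_log_lik(x, z)           # log p(x|z), one value per cell
            - q.log_prob(z).sum(dim=-1)       # log q(z|x)
        )
    return logsumexp(to_sum, dim=-1) - math.log(n_mc_samples)
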
/scMVP/inference/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import time
4 |
5 | from abc import abstractmethod
6 | from collections import defaultdict, OrderedDict
7 | from itertools import cycle
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from sklearn.model_selection._split import _validate_shuffle_split
13 | from torch.utils.data.sampler import SubsetRandomSampler
14 | from tqdm import trange
15 |
16 | from scMVP.inference.posterior import Posterior
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class Trainer:
22 | r"""The abstract Trainer class for training a PyTorch model and monitoring its statistics. It should be
23 | inherited at least with a .loss() function to be optimized in the training loop.
24 |
25 | Args:
26 | :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI``
27 | :gene_dataset: A gene_dataset instance like ``CortexDataset()``
28 | :use_cuda: Default: ``True``.
29 | :metrics_to_monitor: A list of the metrics to monitor. If not specified, will use the
30 |         ``default_metrics_to_monitor`` as specified in each trainer subclass. Default: ``None``.
31 | :benchmark: if True, prevents statistics computation in the training. Default: ``False``.
32 | :frequency: The frequency at which to keep track of statistics. Default: ``None``.
33 | :early_stopping_metric: The statistics on which to perform early stopping. Default: ``None``.
34 |     :save_best_state_metric: The statistic on which we keep the network weights achieving the best score, and
35 | restore them at the end of training. Default: ``None``.
36 | :on: The data_loader name reference for the ``early_stopping_metric`` and ``save_best_state_metric``, that
37 | should be specified if any of them is. Default: ``None``.
38 | :show_progbar: If False, disables progress bar.
39 | :seed: Random seed for train/test/validate split
40 | """
41 | default_metrics_to_monitor = []
42 |
43 | def __init__(
44 | self,
45 | model,
46 | gene_dataset,
47 | use_cuda=True,
48 | metrics_to_monitor=None,
49 | benchmark=False,
50 | frequency=None,
51 | weight_decay=1e-6,
52 | early_stopping_kwargs=None,
53 | data_loader_kwargs=None,
54 | show_progbar=True,
55 | seed=0,
56 | ):
57 | # handle mutable defaults
58 | early_stopping_kwargs = (
59 | early_stopping_kwargs if early_stopping_kwargs else dict()
60 | )
61 | data_loader_kwargs = data_loader_kwargs if data_loader_kwargs else dict()
62 |
63 | self.model = model
64 | self.gene_dataset = gene_dataset
65 | self._posteriors = OrderedDict()
66 | self.seed = seed
67 |
68 |         # default minibatch size; override through data_loader_kwargs if needed
69 |         self.data_loader_kwargs = {"batch_size": 64, "pin_memory": use_cuda}
70 | self.data_loader_kwargs.update(data_loader_kwargs)
71 |
72 | self.weight_decay = weight_decay
73 | self.benchmark = benchmark
74 | self.epoch = -1 # epoch = self.epoch + 1 in compute metrics
75 | self.training_time = 0
76 | # self.KL_divergence = -1
77 | # self.KL_divergence_max = 10000
78 |
79 | if metrics_to_monitor is not None:
80 | self.metrics_to_monitor = set(metrics_to_monitor)
81 | else:
82 | self.metrics_to_monitor = set(self.default_metrics_to_monitor)
83 |
84 | self.early_stopping = EarlyStopping(**early_stopping_kwargs)
85 |
86 | if self.early_stopping.early_stopping_metric:
87 | self.metrics_to_monitor.add(self.early_stopping.early_stopping_metric)
88 |
89 | self.use_cuda = use_cuda and torch.cuda.is_available()
90 | if self.use_cuda:
91 | self.model.cuda()
92 |
93 | self.frequency = frequency if not benchmark else None
94 |
95 | self.history = defaultdict(list)
96 |
97 | self.best_state_dict = self.model.state_dict()
98 | self.best_epoch = self.epoch
99 |
100 | self.show_progbar = show_progbar
101 |
102 | @torch.no_grad()
103 | def compute_metrics(self):
104 | begin = time.time()
105 | epoch = self.epoch + 1
106 | if self.frequency and (
107 | epoch == 0 or epoch == self.n_epochs or (epoch % self.frequency == 0)
108 | ):
109 | with torch.set_grad_enabled(False):
110 | self.model.eval()
111 | logger.debug("\nEPOCH [%d/%d]: " % (epoch, self.n_epochs))
112 |
113 | for name, posterior in self._posteriors.items():
114 | message = " ".join([s.capitalize() for s in name.split("_")[-2:]])
115 | if posterior.nb_cells < 5:
116 | logging.debug(
117 | message + " is too small to track metrics (<5 samples)"
118 | )
119 | continue
120 | if hasattr(posterior, "to_monitor"):
121 | for metric in posterior.to_monitor:
122 | if metric not in self.metrics_to_monitor:
123 | logger.debug(message)
124 | result = getattr(posterior, metric)()
125 | self.history[metric + "_" + name] += [result]
126 | for metric in self.metrics_to_monitor:
127 | result = getattr(posterior, metric)()
128 | self.history[metric + "_" + name] += [result]
129 | self.model.train()
130 | self.compute_metrics_time += time.time() - begin
131 |
132 | def train(self, n_epochs=20, lr=1e-3, eps=0.01, params=None):
133 | begin = time.time()
134 | self.model.train()
135 |
136 | if params is None:
137 | params = filter(lambda p: p.requires_grad, self.model.parameters())
138 |
139 | optimizer = self.optimizer = torch.optim.Adam(
140 | params, lr=lr, eps=eps, weight_decay=self.weight_decay
141 | )
142 | aa = self.model.parameters()
143 |
144 | self.compute_metrics_time = 0
145 | self.n_epochs = n_epochs
146 | flag = True
147 |
148 | with trange(
149 | n_epochs, desc="training", file=sys.stdout, disable=not self.show_progbar
150 | ) as pbar:
151 | # We have to use tqdm this way so it works in Jupyter notebook.
152 | # See https://stackoverflow.com/questions/42212810/tqdm-in-jupyter-notebook
153 | for self.epoch in pbar:
154 | self.on_epoch_begin()
155 | pbar.update(1)
156 | for tensors_list in self.data_loaders_loop():
157 | if tensors_list[0][0].shape[0] < 3:
158 | continue
159 | loss = self.loss(*tensors_list)
160 | print(loss)
161 | # if self.KL_divergence > self.KL_divergence_max:
162 | # break
163 | #if self.epoch == 15 and flag:
164 | # flag = False
165 | # optimizer.add_param_group({'params': self.model.get_params()})
166 | optimizer.zero_grad()
167 | loss.backward()
168 | optimizer.step()
169 |
170 | if not self.on_epoch_end():
171 | break
172 |
173 | if self.early_stopping.save_best_state_metric is not None:
174 | self.model.load_state_dict(self.best_state_dict)
175 | self.compute_metrics()
176 |
177 | self.model.eval()
178 | self.training_time += (time.time() - begin) - self.compute_metrics_time
179 | if self.frequency:
180 | logger.debug(
181 | "\nTraining time: %i s. / %i epochs"
182 | % (int(self.training_time), self.n_epochs)
183 | )
184 | self.compute_metrics()
185 |
186 |
187 | def on_epoch_begin(self):
188 | pass
189 |
190 | def on_epoch_end(self):
191 | self.compute_metrics()
192 | on = self.early_stopping.on
193 | early_stopping_metric = self.early_stopping.early_stopping_metric
194 | save_best_state_metric = self.early_stopping.save_best_state_metric
195 | if save_best_state_metric is not None and on is not None:
196 | if self.early_stopping.update_state(
197 | self.history[save_best_state_metric + "_" + on][-1]
198 | ):
199 | self.best_state_dict = self.model.state_dict()
200 | self.best_epoch = self.epoch
201 |
202 | continue_training = True
203 | if early_stopping_metric is not None and on is not None:
204 | continue_training, reduce_lr = self.early_stopping.update(
205 | self.history[early_stopping_metric + "_" + on][-1]
206 | )
207 | if reduce_lr:
208 | logger.info("Reducing LR.")
209 | for param_group in self.optimizer.param_groups:
210 | param_group["lr"] *= self.early_stopping.lr_factor
211 |
212 | # if self.KL_divergence > self.KL_divergence_max:
213 | # continue_training = False
214 | return continue_training
215 |
216 | @property
217 | @abstractmethod
218 | def posteriors_loop(self):
219 | pass
220 |
221 | def data_loaders_loop(self):
222 | """returns an zipped iterable corresponding to loss signature"""
223 | data_loaders_loop = [self._posteriors[name] for name in self.posteriors_loop]
224 | return zip(
225 | data_loaders_loop[0],
226 | *[cycle(data_loader) for data_loader in data_loaders_loop[1:]]
227 | )
228 |
229 | def register_posterior(self, name, value):
230 | name = name.strip("_")
231 | self._posteriors[name] = value
232 |
233 | def corrupt_posteriors(
234 | self, rate=0.1, corruption="uniform", update_corruption=True
235 | ):
236 | if not hasattr(self.gene_dataset, "corrupted") and update_corruption:
237 | self.gene_dataset.corrupt(rate=rate, corruption=corruption)
238 | for name, posterior in self._posteriors.items():
239 | self.register_posterior(name, posterior.corrupted())
240 |
241 | def uncorrupt_posteriors(self):
242 | for name_, posterior in self._posteriors.items():
243 | self.register_posterior(name_, posterior.uncorrupted())
244 |
245 | def __getattr__(self, name):
246 | if "_posteriors" in self.__dict__:
247 | _posteriors = self.__dict__["_posteriors"]
248 | if name.strip("_") in _posteriors:
249 | return _posteriors[name.strip("_")]
250 | return object.__getattribute__(self, name)
251 |
252 | def __delattr__(self, name):
253 | if name.strip("_") in self._posteriors:
254 | del self._posteriors[name.strip("_")]
255 | else:
256 | object.__delattr__(self, name)
257 |
258 | def __setattr__(self, name, value):
259 | if isinstance(value, Posterior):
260 | name = name.strip("_")
261 | self.register_posterior(name, value)
262 | else:
263 | object.__setattr__(self, name, value)
264 |
265 | def train_test_validation(
266 | self,
267 | model=None,
268 | gene_dataset=None,
269 | train_size=0.1,
270 | test_size=None,
271 | type_class=Posterior,
272 | ):
273 | """Creates posteriors ``train_set``, ``test_set``, ``validation_set``.
274 | If ``train_size + test_size < 1`` then ``validation_set`` is non-empty.
275 |
276 | :param train_size: float, int, or None (default is 0.1)
277 | :param test_size: float, int, or None (default is None)
278 | """
279 | model = self.model if model is None and hasattr(self, "model") else model
280 | gene_dataset = (
281 | self.gene_dataset
282 | if gene_dataset is None and hasattr(self, "model")
283 | else gene_dataset
284 | )
285 | n = len(gene_dataset)
286 | try:
287 | n_train, n_test = _validate_shuffle_split(n, test_size, train_size)
288 | except ValueError:
289 | if train_size != 1.0:
290 | raise ValueError(
291 | "Choice of train_size={} and test_size={} not understood".format(
292 | train_size, test_size
293 | )
294 | )
295 | n_train, n_test = n, 0
296 | random_state = np.random.RandomState(seed=self.seed)
297 | permutation = random_state.permutation(n)
298 | indices_test = permutation[:n_test]
299 | indices_train = permutation[n_test : (n_test + n_train)]
300 | indices_validation = permutation[(n_test + n_train) :]
301 |
302 | return (
303 | self.create_posterior(
304 | model, gene_dataset, indices=indices_train, type_class=type_class
305 | ),
306 | self.create_posterior(
307 | model, gene_dataset, indices=indices_test, type_class=type_class
308 | ),
309 | self.create_posterior(
310 | model, gene_dataset, indices=indices_validation, type_class=type_class
311 | ),
312 | )
313 |
314 | def create_posterior(
315 | self,
316 | model=None,
317 | gene_dataset=None,
318 | shuffle=False,
319 | indices=None,
320 | type_class=Posterior,
321 | ):
322 | model = self.model if model is None and hasattr(self, "model") else model
323 | gene_dataset = (
324 | self.gene_dataset
325 | if gene_dataset is None and hasattr(self, "model")
326 | else gene_dataset
327 | )
328 | return type_class(
329 | model,
330 | gene_dataset,
331 | shuffle=shuffle,
332 | indices=indices,
333 | use_cuda=self.use_cuda,
334 | data_loader_kwargs=self.data_loader_kwargs,
335 | )
336 |
337 |
338 | class SequentialSubsetSampler(SubsetRandomSampler):
339 | def __init__(self, indices):
340 | self.indices = np.sort(indices)
341 |
342 | def __iter__(self):
343 | return iter(self.indices)
344 |
345 |
346 | class EarlyStopping:
347 | def __init__(
348 | self,
349 | early_stopping_metric: str = None,
350 | save_best_state_metric: str = None,
351 | on: str = "test_set",
352 | patience: int = 15,
353 | threshold: int = 3,
354 | benchmark: bool = False,
355 | reduce_lr_on_plateau: bool = False,
356 | lr_patience: int = 10,
357 | lr_factor: float = 0.5,
358 | posterior_class=Posterior,
359 | ):
360 | self.benchmark = benchmark
361 | self.patience = patience
362 | self.threshold = threshold
363 | self.epoch = 0
364 | self.wait = 0
365 | self.wait_lr = 0
366 | self.mode = (
367 | getattr(posterior_class, early_stopping_metric).mode
368 | if early_stopping_metric is not None
369 | else None
370 | )
371 | # We set the best to + inf because we're dealing with a loss we want to minimize
372 | self.current_performance = np.inf
373 | self.best_performance = np.inf
374 | self.best_performance_state = np.inf
375 | # If we want to maximize, we start at - inf
376 | if self.mode == "max":
377 | self.best_performance *= -1
378 | self.current_performance *= -1
379 | self.mode_save_state = (
380 | getattr(Posterior, save_best_state_metric).mode
381 | if save_best_state_metric is not None
382 | else None
383 | )
384 | if self.mode_save_state == "max":
385 | self.best_performance_state *= -1
386 |
387 | self.early_stopping_metric = early_stopping_metric
388 | self.save_best_state_metric = save_best_state_metric
389 | self.on = on
390 | self.reduce_lr_on_plateau = reduce_lr_on_plateau
391 | self.lr_patience = lr_patience
392 | self.lr_factor = lr_factor
393 |
394 | def update(self, scalar):
395 | self.epoch += 1
396 | if self.benchmark:
397 | continue_training = True
398 | reduce_lr = False
399 | elif self.wait >= self.patience:
400 | continue_training = False
401 | reduce_lr = False
402 | else:
403 | # Check if we should reduce the learning rate
404 | if not self.reduce_lr_on_plateau:
405 | reduce_lr = False
406 | elif self.wait_lr >= self.lr_patience:
407 | reduce_lr = True
408 | self.wait_lr = 0
409 | else:
410 | reduce_lr = False
411 | # Shift
412 | self.current_performance = scalar
413 |
414 | # Compute improvement
415 | if self.mode == "max":
416 | improvement = self.current_performance - self.best_performance
417 | elif self.mode == "min":
418 | improvement = self.best_performance - self.current_performance
419 | else:
420 | raise NotImplementedError("Unknown optimization mode")
421 |
422 | # updating best performance
423 | if improvement > 0:
424 | self.best_performance = self.current_performance
425 |
426 | if improvement < self.threshold:
427 | self.wait += 1
428 | self.wait_lr += 1
429 | else:
430 | self.wait = 0
431 | self.wait_lr = 0
432 |
433 | continue_training = True
434 | if not continue_training:
435 | # FIXME: log total number of epochs run
436 | logger.info(
437 | "\nStopping early: no improvement of more than "
438 | + str(self.threshold)
439 | + " nats in "
440 | + str(self.patience)
441 | + " epochs"
442 | )
443 | logger.info(
444 | "If the early stopping criterion is too strong, "
445 | "please instantiate it with different parameters in the train method."
446 | )
447 | return continue_training, reduce_lr
448 |
449 | def update_state(self, scalar):
450 | improved = (
451 | self.mode_save_state == "max" and scalar - self.best_performance_state > 0
452 | ) or (
453 | self.mode_save_state == "min" and self.best_performance_state - scalar > 0
454 | )
455 | if improved:
456 | self.best_performance_state = scalar
457 | return improved
458 |
--------------------------------------------------------------------------------
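A small usage sketch for the EarlyStopping helper defined above. The loss values are made up, and the
optimization mode is set directly on the instance so the example does not depend on any Posterior metric
(normally it is inferred from early_stopping_metric).

from scMVP.inference.trainer import EarlyStopping

stopper = EarlyStopping(patience=3, threshold=1)
stopper.mode = "min"  # set directly for this toy example; usually derived from the monitored metric
fake_losses = [100.0, 90.0, 89.5, 89.4, 89.3, 89.2, 89.1]  # improvements fall below the threshold
for epoch, loss in enumerate(fake_losses):
    continue_training, _ = stopper.update(loss)
    print(epoch, loss, continue_training)
    if not continue_training:
        break  # stops once `patience` epochs pass without an improvement of at least `threshold`
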
/scMVP/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier import Classifier
2 |
3 | from .vae_attention import VAE_Attention
4 | from .vaePeak_selfattetion import VAE_Peak_SelfAttention
5 | from .multi_vae_attention import Multi_VAE_Attention
6 |
7 |
8 | __all__ = [
9 | "Classifier"
10 | "VAE_Attention",
11 | "VAE_Peak_SelfAttention",
12 | "Multi_VAE_Attention",
13 | ]
14 |
--------------------------------------------------------------------------------
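These re-exports let downstream code pull the model classes from the package root, for example:

from scMVP.models import Multi_VAE_Attention, VAE_Attention, Classifier
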
/scMVP/models/classifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 |
3 | from scMVP.models.modules import FCLayers
4 |
5 |
6 | class Classifier(nn.Module):
7 | def __init__(
8 | self,
9 | n_input,
10 | n_hidden=128,
11 | n_labels=5,
12 | n_layers=1,
13 | dropout_rate=0.1,
14 | logits=False,
15 | ):
16 | super().__init__()
17 | layers = [
18 | FCLayers(
19 | n_in=n_input,
20 | n_out=n_hidden,
21 | n_layers=n_layers,
22 | n_hidden=n_hidden,
23 | dropout_rate=dropout_rate,
24 | use_batch_norm=True,
25 | ),
26 | nn.Linear(n_hidden, n_labels),
27 | ]
28 | if not logits:
29 | layers.append(nn.Softmax(dim=-1))
30 |
31 | self.classifier = nn.Sequential(*layers)
32 |
33 | def forward(self, x):
34 | return self.classifier(x)
35 |
--------------------------------------------------------------------------------
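A minimal usage sketch for the Classifier above (the sizes are arbitrary, not taken from the scMVP
demos): it maps a batch of latent vectors to per-class probabilities through FCLayers plus a softmax.

import torch
from scMVP.models import Classifier

clf = Classifier(n_input=10, n_hidden=32, n_labels=5)  # e.g. classify 10-d latent codes into 5 classes
z = torch.randn(8, 10)                                 # a batch of 8 latent vectors
probs = clf(z)                                         # shape (8, 5); each row sums to 1 (softmax output)
print(probs.shape, probs.sum(dim=-1))
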
/scMVP/models/log_likelihood.py:
--------------------------------------------------------------------------------
1 | """File for computing log likelihood of the data"""
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import logsumexp
7 | from torch.distributions import Normal, Beta
8 |
9 |
10 | def compute_elbo(vae, posterior, **kwargs):
11 | """ Computes the ELBO.
12 |
13 | The ELBO is the reconstruction error + the KL divergences
14 | between the variational distributions and the priors.
15 | It differs from the marginal log likelihood.
16 | Specifically, it is a lower bound on the marginal log likelihood
17 | plus a term that is constant with respect to the variational distribution.
18 | It still gives good insights on the modeling of the data, and is fast to compute.
19 | """
20 | # Iterate once over the posterior and compute the elbo
21 | elbo = 0
22 | for i_batch, tensors in enumerate(posterior):
23 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[
24 | :5
25 | ] # general fish case
26 | # kl_divergence_global (scalar) should be common across all batches after training
27 | reconst_loss, kl_divergence, kl_divergence_global = vae(
28 | sample_batch,
29 | local_l_mean,
30 | local_l_var,
31 | batch_index=batch_index,
32 | y=labels,
33 | **kwargs
34 | )
35 | elbo += torch.sum(reconst_loss + kl_divergence).item()
36 | n_samples = len(posterior.indices)
37 | elbo += kl_divergence_global
38 | return elbo / n_samples
39 |
40 |
41 | def compute_reconstruction_error(vae, posterior, **kwargs):
42 | """ Computes log p(x/z), which is the reconstruction error.
43 |
44 | Differs from the marginal log likelihood, but still gives good
45 | insights on the modeling of the data, and is fast to compute.
46 | """
47 | # Iterate once over the posterior and computes the reconstruction error
48 | log_lkl = 0
49 | for i_batch, tensors in enumerate(posterior):
50 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[
51 | :5
52 | ] # general fish case
53 |
54 | # Distribution parameters
55 | outputs = vae.inference(sample_batch, batch_index, labels, **kwargs)
56 | px_r = outputs["px_r"]
57 | px_rate = outputs["px_rate"]
58 | px_dropout = outputs["px_dropout"]
59 | bernoulli_params = outputs.get("bernoulli_params", None)
60 |
61 | # Reconstruction loss
62 | reconst_loss = vae.get_reconstruction_loss(
63 | sample_batch,
64 | px_rate,
65 | px_r,
66 | px_dropout,
67 | bernoulli_params=bernoulli_params,
68 | **kwargs
69 | )
70 |
71 | log_lkl += torch.sum(reconst_loss).item()
72 | n_samples = len(posterior.indices)
73 | return log_lkl / n_samples
74 |
75 |
76 | def compute_marginal_log_likelihood_scvi(vae, posterior, n_samples_mc=100):
77 | """ Computes a biased estimator for log p(x), which is the marginal log likelihood.
78 |
79 | Despite its bias, the estimator still converges to the real value
80 | of log p(x) when n_samples_mc (for Monte Carlo) goes to infinity
81 | (a fairly high value like 100 should be enough)
82 | Due to the Monte Carlo sampling, this method is not as computationally efficient
83 | as computing only the reconstruction loss
84 | """
85 | # Uses MC sampling to compute a tighter lower bound on log p(x)
86 | log_lkl = 0
87 | for i_batch, tensors in enumerate(posterior):
88 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors
89 | to_sum = torch.zeros(sample_batch.size()[0], n_samples_mc)
90 |
91 | for i in range(n_samples_mc):
92 |
93 | # Distribution parameters and sampled variables
94 | outputs = vae.inference(sample_batch, batch_index, labels)
95 | px_r = outputs["px_r"]
96 | px_rate = outputs["px_rate"]
97 | px_dropout = outputs["px_dropout"]
98 | qz_m = outputs["qz_m"]
99 | qz_v = outputs["qz_v"]
100 | z = outputs["z"]
101 | ql_m = outputs["ql_m"]
102 | ql_v = outputs["ql_v"]
103 | library = outputs["library"]
104 |
105 | # Reconstruction Loss
106 | reconst_loss = vae.get_reconstruction_loss(
107 | sample_batch, px_rate, px_r, px_dropout
108 | )
109 |
110 | # Log-probabilities
111 | p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1)
112 | p_z = (
113 | Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
114 | .log_prob(z)
115 | .sum(dim=-1)
116 | )
117 | p_x_zl = -reconst_loss
118 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
119 | q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1)
120 |
121 | to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x
122 |
123 | batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_samples_mc)
124 | log_lkl += torch.sum(batch_log_lkl).item()
125 |
126 | n_samples = len(posterior.indices)
127 | # The minus sign is there because we actually look at the negative log likelihood
128 | return -log_lkl / n_samples
129 |
130 |
131 | def compute_marginal_log_likelihood_autozi(autozivae, posterior, n_samples_mc=100):
132 | """ Computes a biased estimator for log p(x), which is the marginal log likelihood.
133 |
134 | Despite its bias, the estimator still converges to the real value
135 | of log p(x) when n_samples_mc (for Monte Carlo) goes to infinity
136 | (a fairly high value like 100 should be enough)
137 | Due to the Monte Carlo sampling, this method is not as computationally efficient
138 | as computing only the reconstruction loss
139 | """
140 | # Uses MC sampling to compute a tighter lower bound on log p(x)
141 | log_lkl = 0
142 | to_sum = torch.zeros((n_samples_mc,))
143 | alphas_betas = autozivae.get_alphas_betas(as_numpy=False)
144 | alpha_prior = alphas_betas["alpha_prior"]
145 | alpha_posterior = alphas_betas["alpha_posterior"]
146 | beta_prior = alphas_betas["beta_prior"]
147 | beta_posterior = alphas_betas["beta_posterior"]
148 |
149 | for i in range(n_samples_mc):
150 |
151 | bernoulli_params = autozivae.sample_from_beta_distribution(
152 | alpha_posterior, beta_posterior
153 | )
154 |
155 | for i_batch, tensors in enumerate(posterior):
156 | sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors
157 |
158 | # Distribution parameters and sampled variables
159 | outputs = autozivae.inference(sample_batch, batch_index, labels)
160 | px_r = outputs["px_r"]
161 | px_rate = outputs["px_rate"]
162 | px_dropout = outputs["px_dropout"]
163 | qz_m = outputs["qz_m"]
164 | qz_v = outputs["qz_v"]
165 | z = outputs["z"]
166 | ql_m = outputs["ql_m"]
167 | ql_v = outputs["ql_v"]
168 | library = outputs["library"]
169 |
170 | # Reconstruction Loss
171 | bernoulli_params_batch = autozivae.reshape_bernoulli(
172 | bernoulli_params, batch_index, labels
173 | )
174 | reconst_loss = autozivae.get_reconstruction_loss(
175 | sample_batch, px_rate, px_r, px_dropout, bernoulli_params_batch
176 | )
177 |
178 | # Log-probabilities
179 | p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1)
180 | p_z = (
181 | Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
182 | .log_prob(z)
183 | .sum(dim=-1)
184 | )
185 | p_x_zld = -reconst_loss
186 | q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
187 | q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1)
188 |
189 | batch_log_lkl = torch.sum(p_x_zld + p_l + p_z - q_z_x - q_l_x, dim=0)
190 | to_sum[i] += batch_log_lkl
191 |
192 | p_d = Beta(alpha_prior, beta_prior).log_prob(bernoulli_params).sum()
193 | q_d = Beta(alpha_posterior, beta_posterior).log_prob(bernoulli_params).sum()
194 |
195 | to_sum[i] += p_d - q_d
196 |
197 | log_lkl = logsumexp(to_sum, dim=-1).item() - np.log(n_samples_mc)
198 | n_samples = len(posterior.indices)
199 | # The minus sign is there because we actually look at the negative log likelihood
200 | return -log_lkl / n_samples
201 |
202 | def binary_cross_entropy(x, recon_x, eps=1e-8):
203 | recon_x = torch.sigmoid(recon_x)
204 | res = (x * torch.log(recon_x + eps) + (1 - x) * torch.log(1 - recon_x + eps))
205 | # print(torch.mean(recon_x))
206 | return res
207 |
208 | def mean_square_error(x,recon_x):
209 | res = (x - recon_x)*(x - recon_x)
210 | return res
211 | def mean_square_error_positive(x,recon_x):
212 | #res = (x - recon_x + 1)*(x - recon_x + 1)
213 | #res[x==0] = 0
214 | res = torch.abs((x - recon_x))
215 | res[x == 0] = 0 # test this property
216 | return res
217 |
218 | def log_zip_positive(x, mu, pi, eps=1e-8):
219 | # the likelihood of zero probability p(x=0) = -softplus(-pi)+softplus(-pi-mu)
220 | softplus_pi = F.softplus(-pi)
221 | softplus_mu_pi = F.softplus(-pi-mu)
222 | case_zero = - softplus_pi + softplus_mu_pi
223 | mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)
224 |
225 | # the likelihood of p(x>0) = -softplus(-pi) - pi - mu +x*ln(mu) - ln(x!)
226 | case_non_zero = (
227 | - softplus_pi
228 | - pi - mu
229 | + x * torch.log(mu + eps)
230 | - torch.lgamma(x + 1)
231 | )
232 | mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)
233 |
234 | res = mul_case_zero + mul_case_non_zero
235 |
236 | return res
237 |
238 | def log_zinb_positive(x, mu, theta, pi, eps=1e-8):
239 | """
240 | Note: All inputs are torch Tensors
241 | log likelihood (scalar) of a minibatch according to a zinb model.
242 | Notes:
243 | We parametrize the bernoulli using the logits, hence the softplus functions appearing
244 |
245 | Variables:
246 | mu: mean of the negative binomial (has to be positive support) (shape: minibatch x genes)
247 | theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
248 | pi: logit of the dropout parameter (real support) (shape: minibatch x genes)
249 | eps: numerical stability constant
250 | """
251 |
252 | # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels)
253 | if theta.ndimension() == 1:
254 | theta = theta.view(
255 | 1, theta.size(0)
256 | ) # In this case, we reshape theta for broadcasting
257 |
258 | softplus_pi = F.softplus(-pi)
259 | log_theta_eps = torch.log(theta + eps)
260 | log_theta_mu_eps = torch.log(theta + mu + eps)
261 | pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)
262 |
263 | case_zero = F.softplus(pi_theta_log) - softplus_pi
264 | mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)
265 |
266 | case_non_zero = (
267 | -softplus_pi
268 | + pi_theta_log
269 | + x * (torch.log(mu + eps) - log_theta_mu_eps)
270 | + torch.lgamma(x + theta)
271 | - torch.lgamma(theta)
272 | - torch.lgamma(x + 1)
273 | )
274 | mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)
275 |
276 | res = mul_case_zero + mul_case_non_zero
277 |
278 | return res
279 |
280 |
281 | def log_nb_positive(x, mu, theta, eps=1e-8):
282 | """
283 | Note: All inputs should be torch Tensors
284 | log likelihood (scalar) of a minibatch according to a nb model.
285 |
286 | Variables:
287 | mu: mean of the negative binomial (has to be positive support) (shape: minibatch x genes)
288 | theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
289 | eps: numerical stability constant
290 | """
291 | if theta.ndimension() == 1:
292 | theta = theta.view(
293 | 1, theta.size(0)
294 | ) # In this case, we reshape theta for broadcasting
295 |
296 | log_theta_mu_eps = torch.log(theta + mu + eps)
297 |
298 | res = (
299 | theta * (torch.log(theta + eps) - log_theta_mu_eps)
300 | + x * (torch.log(mu + eps) - log_theta_mu_eps)
301 | + torch.lgamma(x + theta)
302 | - torch.lgamma(theta)
303 | - torch.lgamma(x + 1)
304 | )
305 |
306 | return res
307 |
308 |
309 | def log_mixture_nb(x, mu_1, mu_2, theta_1, theta_2, pi, eps=1e-8):
310 | """
311 | Note: All inputs should be torch Tensors
312 | log likelihood (scalar) of a minibatch according to a mixture nb model.
313 | pi is the probability to be in the first component.
314 |
315 | For totalVI, the first component should be background.
316 |
317 | Variables:
318 | mu1: mean of the first negative binomial component (has to be positive support) (shape: minibatch x genes)
319 | theta1: first inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
320 | mu2: mean of the second negative binomial (has to be positive support) (shape: minibatch x genes)
321 | theta2: second inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
322 | If None, assume one shared inverse dispersion parameter.
323 | eps: numerical stability constant
324 | """
325 | if theta_2 is not None:
326 | log_nb_1 = log_nb_positive(x, mu_1, theta_1)
327 | log_nb_2 = log_nb_positive(x, mu_2, theta_2)
328 | # this is intended to reduce repeated computations
329 | else:
330 | theta = theta_1
331 | if theta.ndimension() == 1:
332 | theta = theta.view(
333 | 1, theta.size(0)
334 | ) # In this case, we reshape theta for broadcasting
335 |
336 | log_theta_mu_1_eps = torch.log(theta + mu_1 + eps)
337 | log_theta_mu_2_eps = torch.log(theta + mu_2 + eps)
338 | lgamma_x_theta = torch.lgamma(x + theta)
339 | lgamma_theta = torch.lgamma(theta)
340 | lgamma_x_plus_1 = torch.lgamma(x + 1)
341 |
342 | log_nb_1 = (
343 | theta * (torch.log(theta + eps) - log_theta_mu_1_eps)
344 | + x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps)
345 | + lgamma_x_theta
346 | - lgamma_theta
347 | - lgamma_x_plus_1
348 | )
349 | log_nb_2 = (
350 | theta * (torch.log(theta + eps) - log_theta_mu_2_eps)
351 | + x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps)
352 | + lgamma_x_theta
353 | - lgamma_theta
354 | - lgamma_x_plus_1
355 | )
356 |
357 | logsumexp = torch.logsumexp(torch.stack((log_nb_1, log_nb_2 - pi)), dim=0)
358 | softplus_pi = F.softplus(-pi)
359 |
360 | log_mixture_nb = logsumexp - softplus_pi
361 |
362 | return log_mixture_nb
363 |
--------------------------------------------------------------------------------
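A small sanity-check sketch (not part of the library) for log_zinb_positive above: because pi is the
dropout logit, the ZINB log-likelihood at x = 0 has the closed form
log[sigmoid(pi) + (1 - sigmoid(pi)) * (theta / (theta + mu))**theta], which the implementation
reproduces up to the eps stabiliser.

import torch
from scMVP.models.log_likelihood import log_zinb_positive

mu = torch.tensor([[2.0]])     # NB mean
theta = torch.tensor([[1.5]])  # NB inverse dispersion
pi = torch.tensor([[0.3]])     # dropout logit
x = torch.zeros_like(mu)       # evaluate the zero-count case

ll = log_zinb_positive(x, mu, theta, pi)
closed_form = torch.log(torch.sigmoid(pi) + (1 - torch.sigmoid(pi)) * (theta / (theta + mu)) ** theta)
print(ll.item(), closed_form.item())  # the two values should agree to ~1e-7
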
/scMVP/models/multi_vae_attention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Main module."""
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import math
8 | import numpy as np
9 | from sklearn.mixture import GaussianMixture
10 | from torch.distributions import Normal, kl_divergence as kl
11 |
12 | from scMVP.models.log_likelihood import log_zinb_positive, log_nb_positive, log_zip_positive, binary_cross_entropy, \
13 | mean_square_error_positive
14 | from scMVP.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI, Multi_Encoder, Multi_Decoder_nb_log, \
15 | reparameterize_gaussian, Encoder_l, Encoder_nb, Multi_Encoder_nb, Multi_Decoder_nb, Classifer, Multi_Decoder_nb_log_peak,\
16 | Encoder_nb_attention, Multi_Encoder_nb_attention, Encoder_nb_selfattention, Multi_Encoder_nb_SelfAttention,\
17 | Multi_Decoder_nb_SelfAttention
18 |
19 | from scMVP.models.utils import one_hot
20 |
21 | torch.backends.cudnn.benchmark = True
22 |
23 |
24 | # VAE model
25 | class Multi_VAE_Attention(nn.Module):
26 | r"""Variational auto-encoder model.
27 |
28 |     :param RNA_input: Number of input genes (ATAC_input analogously gives the number of input ATAC peaks/bins)
29 | :param n_batch: Number of batches
30 | :param n_labels: Number of labels
31 | :param n_hidden: Number of nodes per hidden layer
32 | :param n_latent: Dimensionality of the latent space
33 | :param n_layers: Number of hidden layers used for encoder and decoder NNs
34 | :param dropout_rate: Dropout rate for neural networks
35 | :param mode: One of the following:
36 | * ``'vae'`` -single channel auto-encoder decoder neural framework for scRNA-seq data
37 | * ``'mm-vae'`` -multi-channels auto-encoder decoder neural framework for scRNA and scATAC data
38 | :param dispersion: One of the following
39 |
40 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
41 | * ``'gene-batch'`` - dispersion can differ between different batches
42 | * ``'gene-label'`` - dispersion can differ between different labels
43 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell
44 |
45 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
46 | :param reconstruction_loss: One of
47 |
48 | * ``'nb'`` - Negative binomial distribution
49 | * ``'zinb'`` - Zero-inflated negative binomial distribution
50 |
51 | Examples:
52 | >>> gene_dataset = CortexDataset()
53 | >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
54 | ... n_labels=gene_dataset.n_labels)
55 |
56 | """
57 |
58 | def __init__(
59 | self,
60 | RNA_input: int,
61 | ATAC_input: int = 0,
62 | n_batch: int = 0,
63 | n_labels: int = 0,
64 | n_hidden: int = 128,
65 | n_latent: int = 10,
66 | n_layers: int = 1,
67 | n_centroids: int = 20,
68 | n_alfa: float = 1.0,
69 | dropout_rate: float = 0.1,
70 | mode="vae",
71 | dispersion: str = "gene",
72 | log_variational: bool = True,
73 | reconstruction_loss: str = "zinb",
74 | isLibrary: bool = True,
75 | is_cluster: bool = True,
76 | classifer_num: int = 0,
77 | ):
78 | super().__init__()
79 | self.mode = mode
80 | self.dispersion = dispersion
81 | self.n_latent = n_latent
82 | self.log_variational = log_variational
83 | self.reconstruction_loss = reconstruction_loss
84 | # Automatically deactivate if useless
85 | self.n_input_atac = ATAC_input
86 | self.n_input_RNA = RNA_input
87 | self.n_batch = n_batch
88 | self.n_labels = n_labels
89 | self.n_centroids = n_centroids
90 | self.alfa = n_alfa
91 | self.isLibrary = isLibrary
92 | self.is_cluster = is_cluster
93 | self.classifer_num = classifer_num
94 |
95 | if self.dispersion == "gene":
96 | self.px_r = torch.nn.Parameter(torch.randn(RNA_input))
97 | self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input))
98 | elif self.dispersion == "gene-batch":
99 | self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_batch))
100 | self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_batch))
101 | elif self.dispersion == "gene-label":
102 | self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_labels))
103 | self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_labels))
104 | elif self.dispersion == "gene-cell":
105 | pass
106 | else:
107 |             raise ValueError(
108 |                 "dispersion must be one of ['gene', 'gene-batch',"
109 |                 " 'gene-label', 'gene-cell'], but input was "
110 |                 "{}".format(self.dispersion)
111 |             )
112 |
113 | if self.mode == "vae":
114 | # z encoder goes from the n_input-dimensional data to an n_latent-d
115 | # latent space representation
116 | self.z_encoder = Encoder(
117 | RNA_input,
118 | n_latent,
119 | n_layers=n_layers,
120 | n_hidden=n_hidden,
121 | dropout_rate=dropout_rate,
122 | )
123 | # l encoder goes from n_input-dimensional data to 1-d library size
124 | self.l_encoder = Encoder(
125 | RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate
126 | )
127 | # decoder goes from n_latent-dimensional space to n_input-d data
128 | self.decoder = DecoderSCVI(
129 | n_latent,
130 | RNA_input,
131 | n_cat_list=[n_batch],
132 | n_layers=n_layers,
133 | n_hidden=n_hidden,
134 | )
135 | elif self.mode == "mm-vae":
136 | if ATAC_input <= 0:
137 | raise ValueError("Input size of ATAC channel should be positive value,"
138 | "but input was {}.format(self.ATAC_input)"
139 | )
140 |
141 | # init c_params
142 | self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids, requires_grad=True) # pc
143 | self.mu_c = nn.Parameter(torch.zeros(n_latent, n_centroids), requires_grad=True) # mu
144 | self.var_c = nn.Parameter(torch.ones(n_latent, n_centroids), requires_grad=True) # sigma^2
145 |             self.counter = nn.Parameter(torch.zeros(2), requires_grad=False)  # bookkeeping buffer, not trained
146 |
147 | if self.classifer_num > 0:
148 | self.classifer = Classifer(
149 | n_latent,
150 | self.classifer_num,
151 | )
152 |
153 | self.RNA_encoder = Encoder_nb_attention(
154 | RNA_input,
155 | n_latent,
156 | n_layers=n_layers,
157 | n_hidden=n_hidden,
158 | dropout_rate=dropout_rate,
159 | )
160 | self.ATAC_encoder = Encoder_nb_selfattention(
161 | ATAC_input,
162 | n_latent,
163 | n_layers=n_layers,
164 | n_hidden=n_hidden,
165 | dropout_rate=dropout_rate,
166 | )
167 | self.concatenter = nn.Linear(2 * self.n_latent, self.n_latent)
168 | if self.isLibrary == True:
169 | # l encoder goes from n_input-dimensional data to 1-d library size
170 | self.l_encoder = Encoder_l(
171 | RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate
172 | )
173 | self.RNA_ATAC_encoder = Multi_Encoder_nb_SelfAttention(
174 | RNA_input,
175 | ATAC_input,
176 | n_latent,
177 | n_layers=n_layers,
178 | n_hidden=n_hidden,
179 | dropout_rate=dropout_rate,
180 | )
181 | self.RNA_ATAC_decoder = Multi_Decoder_nb_SelfAttention(
182 | n_latent,
183 | RNA_input,
184 | ATAC_input,
185 | n_cat_list=[n_batch],
186 | n_layers=n_layers,
187 | n_hidden=n_hidden,
188 | is_cluster=is_cluster,
189 | n_cluster=n_centroids
190 | )
191 | else:
192 |             raise ValueError(
193 |                 "mode must be one of ['vae', 'mm-vae'],"
194 |                 " but input was "
195 |                 "{}".format(self.mode)
196 |             )
197 |
198 | def get_params(self):
199 | params = [self.pi, self.mu_c, self.var_c]
200 | return params
201 |
202 | def get_latents(self, x_rna, y=None, x_atac=None):
203 | r""" returns the result of ``sample_from_posterior_z`` inside a list
204 |
205 | :param x: tensor of values with shape ``(batch_size, n_input)``
206 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
207 | :return: one element list of tensor
208 | :rtype: list of :py:class:`torch.Tensor`
209 | """
210 | return [self.sample_from_posterior_z([x_rna, x_atac], y)]
211 |
212 | def sample_from_posterior_z(self, x, y=None, give_mean=True):
213 | r""" samples the tensor of latent values from the posterior
214 | #doesn't really sample, returns the means of the posterior distribution
215 |
216 | :param x: tensor of values with shape ``(batch_size, n_input)``
217 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
218 | :param give_mean: is True when we want the mean of the posterior distribution rather than sampling
219 | :return: tensor of shape ``(batch_size, n_latent)``
220 | :rtype: :py:class:`torch.Tensor`
221 | """
222 | if self.log_variational:
223 | x[0] = torch.log(1 + x[0])
224 | x[1] = torch.log(1 + x[1])
225 |
226 | qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x[0], None)
227 | qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x[1], None)
228 | qz_m, qz_v, z = self.RNA_ATAC_encoder(x, None)
229 | if give_mean:
230 |             z = qz_m
231 |             rna_z = qz_rna_m
232 | atac_z = qz_atac_m
233 | return [z, rna_z, atac_z]
234 |
235 | def sample_from_posterior_l(self, x):
236 | r""" samples the tensor of library sizes from the posterior
237 | #doesn't really sample, returns the tensor of the means of the posterior distribution
238 |
239 | :param x: tensor of values with shape ``(batch_size, n_input)``
240 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
241 | :return: tensor of shape ``(batch_size, 1)``
242 | :rtype: :py:class:`torch.Tensor`
243 | """
244 | if self.log_variational:
245 | x = torch.log(1 + x)
246 | ql_m, ql_v, library = self.l_encoder(x)
247 | return library
248 |
249 | def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1):
250 | r"""Returns the tensor of predicted frequencies of expression
251 |
252 | :param x: tensor of values with shape ``(batch_size, n_input)``
253 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
254 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
255 | :param n_samples: number of samples
256 | :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
257 | :rtype: :py:class:`torch.Tensor`
258 | """
259 | outputs = self.inference(x=x, batch_index=batch_index, y=y, n_samples=n_samples)
260 | return outputs["p_rna_scale"], outputs["p_atac_scale"]
261 |
262 | def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1, local_l_mean=None, local_l_var=None):
263 | r"""Returns the tensor of means of the negative binomial distribution
264 |
265 | :param x: tensor of values with shape ``(batch_size, n_input)``
266 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
267 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
268 | :param n_samples: number of samples
269 | :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
270 | :rtype: :py:class:`torch.Tensor`
271 | """
272 | outputs = self.inference(x=x, batch_index=batch_index, y=y, n_samples=n_samples, local_l_mean=local_l_mean,
273 | local_l_var=local_l_var)
274 | return outputs["p_rna_rate"], outputs["p_atac_mean"]
275 |
276 | def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs):
277 | # Reconstruction Loss
278 | if self.reconstruction_loss == "nb":
279 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1)
280 | elif self.reconstruction_loss == "zinb":
281 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1)
282 |
283 | return reconst_loss
284 |
285 | def get_reconstruction_atac_loss(self, x, mu, dispersion, dropout, type="zip", **kwargs):
286 | if type == "zinb":
287 | reconst_loss = -log_zinb_positive(x, mu, dispersion, dropout).sum(dim=-1)
288 | elif type == "zip":
289 | reconst_loss = 0.5 * mean_square_error_positive(x, mu).sum(dim=-1) - log_zip_positive(x, mu, dropout).sum(dim=-1)
290 | mu[x > 0] = 0
291 | reconst_loss = reconst_loss + 0.05 * mu.sum(dim=-1)
292 | elif type == "zip_bu":
293 | reconst_loss = - log_zip_positive(x, mu, dropout).sum(dim=-1) - binary_cross_entropy(x, mu).sum(dim=-1)
294 | elif type == "bu":
295 | reconst_loss = - binary_cross_entropy(x, mu).sum(dim=-1)
296 | return reconst_loss
297 |
298 | def scale_from_z(self, sample_batch, fixed_batch):
299 | if self.log_variational:
300 | sample_batch[0] = torch.log(1 + sample_batch[0])
301 | sample_batch[1] = torch.log(1 + sample_batch[1])
302 | qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(sample_batch[0])
303 | qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(sample_batch[1])
304 | qz_m, qz_v, z = self.RNA_ATAC_encoder(sample_batch)
305 |
306 | batch_index = fixed_batch * torch.ones_like(sample_batch[:, [0]])
307 | library = 4.0 * torch.ones_like(sample_batch[:, [0]])
308 | px_scale, _, _, _ = self.decoder("gene", z, library, batch_index)
309 | return px_scale
310 |
311 | def init_gmm_params(self, z):
312 | """
313 | Init SCALE model with GMM model parameters
314 | """
315 | if z is None:
316 |             raise ValueError("Input data is empty!")
317 |
318 | gmm = GaussianMixture(n_components=self.n_centroids, covariance_type='diag')
319 | gmm.fit(z)
320 | # gmm.weights_
321 | self.mu_c.data.copy_(torch.from_numpy(gmm.means_.T.astype(np.float32)))
322 | self.var_c.data.copy_(torch.from_numpy(gmm.covariances_.T.astype(np.float32)))
323 | clust_index = gmm.predict(z)
324 |
325 | return clust_index
326 |
327 | def init_gmm_params_with_louvain(self, z, label):
328 | """
329 | Init SCALE model with GMM model parameters
330 | """
331 | if z is None or label is None:
332 |             raise ValueError("Input data is empty!")
333 |
334 | mu = np.zeros((z.shape[1],len(np.unique(label))))
335 | var = np.zeros((z.shape[1],len(np.unique(label))))
336 | pi = np.zeros(len(np.unique(label)))
337 | for i in range(len(np.unique(label))):
338 | mu[:,i] = np.mean(z[label==i,:],axis=0)
339 | var[:,i] = np.var(z[label==i,:],axis=0)
340 | pi[i] = np.sum(label==i)/len(label)
341 |
342 | self.mu_c.data.copy_(torch.from_numpy(mu.astype(np.float32)))
343 | self.var_c.data.copy_(torch.from_numpy(var.astype(np.float32)))
344 | self.pi.data.copy_(torch.from_numpy(pi.astype(np.float32)))
345 |
346 | return True
347 |
348 | def get_gamma(self, z, update=False):
349 | """
350 | Inference c from z
351 |
352 | gamma is q(c|x)
353 | q(c|x) = p(c|z) = p(c)p(c|z)/p(z)
354 | """
355 | n_centroids = self.n_centroids
356 |
357 | N = z.size(0)
358 | z_org = z
359 | z = z.unsqueeze(2).expand(z.size(0), z.size(1), n_centroids)
360 | pi = torch.abs(self.pi.repeat(N, 1)) # NxK
361 | mu_c = self.mu_c.repeat(N, 1, 1) # NxDxK
362 | var_c = torch.abs(self.var_c.repeat(N, 1, 1)) # NxDxK
363 |
364 | p_c_z = torch.exp(
365 | torch.log(pi) - torch.sum(0.5 * torch.log(2 * math.pi * var_c) + (z - mu_c) ** 2 / (2 * var_c),
366 | dim=1)) + 1e-10
367 | gamma = p_c_z / torch.sum(p_c_z, dim=1, keepdim=True)
368 | return gamma, mu_c, var_c, pi
369 |
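    # Note on get_gamma above: the responsibilities are standard Gaussian-mixture posteriors,
    #   gamma[n, k] proportional to pi[k] * prod_d N(z[n, d]; mu_c[d, k], var_c[d, k]),
    # evaluated in log space before exponentiation; the 1e-10 term guards against all-zero rows
    # prior to normalisation, and torch.abs() keeps the unconstrained pi and var_c parameters positive.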
370 | def inference(self, x, batch_index=None, y=None, local_l_mean=None, local_l_var=None, update=False, n_samples=1):
371 | x_ = x
372 | if len(x_) != 2:
373 | raise ValueError("Input training data should be 2 data types(RNA and ATAC),"
374 | "but input was only {}.format(len(x_))"
375 | )
376 | x_rna = x_[0]
377 | x_atac = x_[1]
378 | libary_atac = torch.log(x_[1].sum(dim=-1)).reshape(-1, 1)
379 | libary_rna = torch.log(x_[0].sum(dim=-1)).reshape(-1, 1)
380 | if self.log_variational:
381 | x_rna = torch.log(1 + x_rna)
382 | x_atac = torch.log(1 + x_atac)
383 |
384 | # Sampling
385 | if self.isLibrary:
386 | ql_m, ql_v, l_z = self.l_encoder(x_rna, batch_index)
387 | qz_rna_m, qz_rna_v, rna_z = self.RNA_encoder(x_rna, batch_index)
388 | qz_atac_m, qz_atac_v, atac_z = self.ATAC_encoder(x_atac, batch_index)
389 | qz_m, qz_v, z = self.RNA_ATAC_encoder([x_rna, x_atac], batch_index)
390 |
391 | qz_joint_mu = self.concatenter(torch.cat((qz_rna_m, qz_atac_m), 1))
392 | qz_joint_v = self.concatenter(torch.cat((torch.log(qz_rna_v), torch.log(qz_atac_v)), 1))
393 | qz_joint_v = torch.exp(qz_joint_v)
394 | qz_joint_z = Normal(qz_joint_mu, qz_joint_v.sqrt()).rsample()
395 | gamma_joint, _, _, _ = self.get_gamma(qz_joint_z)
396 |
397 | gamma, mu_c, var_c, pi = self.get_gamma(z, update) # , self.n_centroids, c_params)
398 | index = torch.argmax(gamma, dim=1)
399 |
400 | index1 = [i for i in range(len(index))]
401 | mu_c_max = mu_c[index1, :, index]
402 | var_c_max = var_c[index1, :, index]
403 | z_c_max = reparameterize_gaussian(mu_c_max, var_c_max)
404 |
405 |         # use the encoded RNA library size when available; otherwise sample the scale from the given prior
406 |         libary_scale = (libary_rna if self.isLibrary
407 |                         else reparameterize_gaussian(local_l_mean, local_l_var))
408 | # decoder
409 | p_rna_scale, p_rna_r, p_rna_rate, p_rna_dropout, p_atac_scale, p_atac_r, p_atac_mean, p_atac_dropout \
410 | = self.RNA_ATAC_decoder(z, z_c_max, batch_index, libary_scale=libary_scale, gamma=gamma, libary_atac=libary_atac)
411 |         # classifier
412 | if self.classifer_num > 0 and y is not None:
413 | classifer_pred = self.classifer(z)
414 | classifer_loss = -100*(
415 | one_hot(y, self.classifer_num)*torch.log(classifer_pred+1.0e-10)
416 | ).sum(dim=-1)
417 |
418 | if self.log_variational:
419 | p_rna_rate_norm = torch.log(1 + p_rna_rate)
420 | p_atac_mean_norm = torch.log(1 + p_atac_mean)
421 | rec_rna_mu, rec_rna_v, rec_rna_z = self.RNA_encoder(p_rna_rate_norm, batch_index)
422 | gamma_rna_rec, _, _, _ = self.get_gamma(rec_rna_z)
423 | rec_atac_mu, rec_atac_v, rec_atac_z = self.ATAC_encoder(p_atac_mean_norm, batch_index)
424 | gamma_atac_rec, _, _, _ = self.get_gamma(rec_atac_z)
425 | rec_joint_mu = self.concatenter(torch.cat((rec_rna_mu, rec_atac_mu), 1))
426 | rec_joint_v = self.concatenter(torch.cat((torch.log(rec_rna_v), torch.log(rec_atac_v)), 1))
427 | rec_joint_v = torch.exp(rec_joint_v)
428 | rec_joint_z = Normal(rec_joint_mu, rec_joint_v.sqrt()).rsample()
429 | gamma_joint_rec, _, _, _ = self.get_gamma(rec_joint_z)
430 |
431 | if self.dispersion == "gene-label":
432 | p_rna_r = F.linear(
433 | one_hot(y, self.n_labels), self.px_r
434 | ) # px_r gets transposed - last dimension is nb genes
435 | p_atac_r = F.linear(
436 | one_hot(y, self.n_labels), self.p_atac_r
437 | )
438 | elif self.dispersion == "gene-batch":
439 | p_rna_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
440 | p_atac_r = F.linear(one_hot(batch_index, self.n_batch), self.p_atac_r)
441 | elif self.dispersion == "gene":
442 | p_rna_r = self.px_r
443 | p_atac_r = self.p_atac_r
444 |
445 | p_rna_r = torch.exp(p_rna_r)
446 | p_atac_r = torch.exp(p_atac_r)
447 |
448 | return dict(
449 | p_rna_scale=p_rna_scale,
450 | p_rna_r=p_rna_r,
451 | p_rna_rate=p_rna_rate,
452 | p_rna_dropout=p_rna_dropout,
453 | p_atac_scale=p_atac_scale,
454 | p_atac_r=p_atac_r,
455 | p_atac_mean=p_atac_mean,
456 | p_atac_dropout=p_atac_dropout,
457 | qz_rna_m=qz_rna_m,
458 | qz_rna_v=qz_rna_v,
459 | rna_z=rna_z,
460 | qz_atac_m=qz_atac_m,
461 | qz_atac_v=qz_atac_v,
462 | atac_z=atac_z,
463 | qz_m=qz_m,
464 | qz_v=qz_v,
465 | z=z,
466 | mu_c=mu_c,
467 | var_c=var_c,
468 | gamma=gamma,
469 | pi=pi,
470 | mu_c_max=mu_c_max,
471 | var_c_max=var_c_max,
472 | z_c_max=z_c_max,
473 | gamma_rna_rec=gamma_rna_rec,
474 | gamma_atac_rec=gamma_atac_rec,
475 | rec_atac_mu=rec_atac_mu,
476 | rec_atac_v=rec_atac_v,
477 | rec_rna_mu=rec_rna_mu,
478 | rec_rna_v=rec_rna_v,
479 | ql_m=ql_m,
480 | ql_v=ql_v,
481 | l_z=l_z,
482 | rec_joint_mu=rec_joint_mu,
483 | rec_joint_v=rec_joint_v,
484 | rec_joint_z=rec_joint_z,
485 | gamma_joint_rec=gamma_joint_rec,
486 | qz_joint_mu=qz_joint_mu,
487 | qz_joint_v=qz_joint_v,
488 | qz_joint_z=qz_joint_z,
489 | gamma_joint=gamma_joint,
490 |             classifer_loss=classifer_loss if (self.classifer_num > 0 and y is not None) else 0,
491 | )
492 |
493 | def forward(self, x_rna, x_atac, local_l_mean, local_l_var, batch_index=None, y=None):
494 |         r""" Returns the reconstruction loss and the Kullback-Leibler divergences
495 |
496 | :param x: tensor of values with shape (batch_size, n_input)
497 | :param local_l_mean: tensor of means of the prior distribution of latent variable l
498 | with shape (batch_size, 1)
499 |         :param local_l_var: tensor of variances of the prior distribution of latent variable l
500 | with shape (batch_size, 1)
501 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
502 | :param y: tensor of cell-types labels with shape (batch_size, n_labels)
503 |         :return: the reconstruction loss and the Kullback-Leibler divergences
504 |         :rtype: 3-tuple of :py:class:`torch.FloatTensor`
505 | """
506 | # Parameters for z latent distribution
507 | x = [x_rna, x_atac]
508 | outputs = self.inference(x, batch_index, y, local_l_mean, local_l_var, update=False)
509 | qz_rna_m = outputs["qz_rna_m"]
510 | qz_rna_v = outputs["qz_rna_v"]
511 | qz_atac_m = outputs["qz_atac_m"]
512 | qz_atac_v = outputs["qz_atac_v"]
513 | qz_m = outputs["qz_m"]
514 | qz_v = outputs["qz_v"]
515 | p_rna_rate = outputs["p_rna_rate"]
516 | p_rna_r = outputs["p_rna_r"]
517 | p_rna_dropout = outputs["p_rna_dropout"]
518 | p_atac_r = outputs["p_atac_r"]
519 | p_atac_mean = outputs["p_atac_mean"]
520 | p_atac_dropout = outputs["p_atac_dropout"]
521 | mu_c = outputs["mu_c"]
522 | var_c = outputs["var_c"]
523 | gamma = outputs["gamma"]
524 | pi = outputs["pi"]
525 | gamma_rna_rec = outputs["gamma_rna_rec"]
526 | gamma_atac_rec = outputs["gamma_atac_rec"]
527 | rec_atac_mu = outputs["rec_atac_mu"]
528 | rec_atac_v = outputs["rec_atac_v"]
529 | rec_rna_mu = outputs["rec_rna_mu"]
530 | rec_rna_v = outputs["rec_rna_v"]
531 | ql_m = outputs["ql_m"]
532 | ql_v = outputs["ql_v"]
533 | l_z = outputs["l_z"]
534 | rec_joint_mu = outputs["rec_joint_mu"]
535 | rec_joint_v = outputs["rec_joint_v"]
536 | rec_joint_z = outputs["rec_joint_z"]
537 | gamma_joint_rec = outputs["gamma_joint_rec"]
538 | qz_joint_mu = outputs["qz_joint_mu"]
539 | qz_joint_v = outputs["qz_joint_v"]
540 | qz_joint_z = outputs["qz_joint_z"]
541 | gamma_joint = outputs["gamma_joint"]
542 | classifer_loss = outputs["classifer_loss"]
543 |
544 |
545 | n_centroids = pi.size(1)
546 | mu_expand = qz_m.unsqueeze(2).expand(qz_m.size(0), qz_m.size(1), n_centroids)
547 | logvar_expand = qz_v.unsqueeze(2).expand(qz_v.size(0), qz_v.size(1), n_centroids)
548 | # zl
549 |
550 | # log p(z|c)
551 | logpzc = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) + \
552 | torch.log(var_c) + \
553 | torch.exp(logvar_expand) / var_c + \
554 | (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)
555 | # log p(c)
556 | logpc = torch.sum(gamma * torch.log(pi), 1)
557 |
558 | # log q(z|x) or q entropy
559 |         qentropy = -0.5 * torch.sum(1 + torch.log(qz_v) + math.log(2 * math.pi), 1)  # Gaussian entropy uses the log-variance
560 |
561 | # log q(c|x)
562 | logqcx = torch.sum(gamma * torch.log(gamma), 1)
563 |
564 | # kl(qz||pz)
565 | kld_qz_pz = -logpzc - logpc + qentropy + logqcx
566 | print("logpzc:{}, logqcx:{}".format(torch.mean(logpzc), torch.mean(logqcx)))
567 | # print("gamma={},var_c={}".format(gamma,var_c))
568 | # kl(qz||qz_rna)
569 | kld_qz_rna = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(qz_rna_m, torch.sqrt(qz_rna_v))).sum(
570 | dim=1
571 | )
572 |
573 | # kl(qz||qz_atac)
574 | kld_qz_atac = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(qz_atac_m, torch.sqrt(qz_atac_v))).sum(
575 |             # check that qz_v is positive
576 | dim=1
577 | )
578 |
579 | # kl(qz||qz_joint)
580 | kld_qz_joint = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(qz_joint_mu, torch.sqrt(qz_joint_v))).sum(
581 |             # check that qz_v is positive
582 | dim=1
583 | )
584 |
585 | # KL Divergence
586 | kl_divergence = kld_qz_pz + 0.1 * (kld_qz_joint)
587 | if self.isLibrary:
588 |
589 | consistent_loss_rna = -(
590 | torch.softmax(gamma, dim=-1) * torch.log(torch.softmax(gamma_rna_rec, dim=-1) + 1.0e-6) + (
591 | 1 - torch.softmax(gamma, dim=-1)) * torch.log(
592 | 1 - torch.softmax(gamma_rna_rec, dim=-1) + 1.0e-6)).sum(dim=-1)
593 | consistent_loss_atac = -(
594 | torch.softmax(gamma, dim=-1) * torch.log(torch.softmax(gamma_atac_rec, dim=-1) + 1.0e-6) + (
595 | 1 - torch.softmax(gamma, dim=-1)) * torch.log(
596 | 1 - torch.softmax(gamma_atac_rec, dim=-1) + 1.0e-6)).sum(dim=-1)
597 | consistent_loss_joint = -(
598 | torch.softmax(gamma, dim=-1) * torch.log(torch.softmax(gamma_joint_rec, dim=-1) + 1.0e-6) + (
599 | 1 - torch.softmax(gamma, dim=-1)) * torch.log(
600 | 1 - torch.softmax(gamma_joint_rec, dim=-1) + 1.0e-6)).sum(dim=-1)
601 | rec_rna_kl = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(rec_rna_mu, torch.sqrt(rec_rna_v))).sum(
602 | dim=1
603 | )
604 | rec_atac_kl = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(rec_atac_mu, torch.sqrt(rec_atac_v))).sum(
605 | dim=1
606 | )
607 | rec_joint_kl = kl(Normal(qz_joint_mu, torch.sqrt(qz_joint_v)), Normal(rec_joint_mu, torch.sqrt(rec_joint_v))).sum(
608 | dim=1
609 | )
610 |
611 | # likelihood
612 | reconst_loss_rna = 3.0*self.get_reconstruction_loss(x[0], p_rna_rate, p_rna_r, p_rna_dropout)
613 | reconst_loss_atac = 0.1 * self.get_reconstruction_atac_loss(x[1], p_atac_mean, p_atac_r,
614 | p_atac_dropout) # implement this function
615 | reconst_loss = reconst_loss_rna + reconst_loss_atac + classifer_loss
616 | if self.isLibrary:
617 |
618 | reconst_loss = reconst_loss + 0.5 * (consistent_loss_joint -
619 | 50*torch.sum(gamma * gamma,dim=-1) -
620 | 50*torch.sum((torch.sum(gamma,dim=0)/gamma.shape[0])*(torch.log(torch.sum(gamma,dim=0)/gamma.shape[0]+1.0e-10))))
621 | kl_divergence = kl_divergence + 0.1 * (rec_joint_kl)
622 |
623 |
624 | # init the gmm model, training pc
625 | print("kld_qz_pz = %f,kld_qz_rna = %f,kld_qz_atac = %f,kl_divergence = %f,reconst_loss_rna = %f,\
626 | reconst_loss_atac = %f, mu=%f, sigma=%f" % (
627 | torch.mean(kld_qz_pz), torch.mean(kld_qz_rna), torch.mean(kld_qz_atac), \
628 | torch.mean(kl_divergence), torch.mean(reconst_loss_rna), torch.mean(reconst_loss_atac),
629 | torch.mean(self.mu_c), torch.mean(self.var_c)))
630 | return reconst_loss, kl_divergence, 0.0
631 |
632 |
633 | class LDVAE(Multi_VAE_Attention):
634 | r"""Linear-decoded Variational auto-encoder model.
635 |
636 | This model uses a linear decoder, directly mapping the latent representation
637 | to gene expression levels. It still uses a deep neural network to encode
638 | the latent representation.
639 |
640 | Compared to standard VAE, this model is less powerful, but can be used to
641 | inspect which genes contribute to variation in the dataset.
642 |
643 | :param n_input: Number of input genes
644 | :param n_batch: Number of batches
645 | :param n_labels: Number of labels
646 | :param n_hidden: Number of nodes per hidden layer (for encoder)
647 | :param n_latent: Dimensionality of the latent space
648 | :param n_layers: Number of hidden layers used for encoder NNs
649 | :param dropout_rate: Dropout rate for neural networks
650 | :param dispersion: One of the following
651 |
652 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
653 | * ``'gene-batch'`` - dispersion can differ between different batches
654 | * ``'gene-label'`` - dispersion can differ between different labels
655 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell
656 |
657 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
658 | :param reconstruction_loss: One of
659 |
660 | * ``'nb'`` - Negative binomial distribution
661 | * ``'zinb'`` - Zero-inflated negative binomial distribution
662 | """
663 |
664 | def __init__(
665 | self,
666 | n_input: int,
667 | n_batch: int = 0,
668 | n_labels: int = 0,
669 | n_hidden: int = 128,
670 | n_latent: int = 10,
671 | n_layers: int = 1,
672 | dropout_rate: float = 0.1,
673 | dispersion: str = "gene",
674 | log_variational: bool = True,
675 | reconstruction_loss: str = "zinb",
676 | ):
677 | super().__init__(
678 | n_input,
679 | n_batch,
680 | n_labels,
681 | n_hidden,
682 | n_latent,
683 | n_layers,
684 | dropout_rate,
685 | dispersion,
686 | log_variational,
687 | reconstruction_loss,
688 | )
689 |
690 | self.decoder = LinearDecoderSCVI(
691 | n_latent,
692 | n_input,
693 | n_cat_list=[n_batch],
694 | n_layers=n_layers,
695 | n_hidden=n_hidden,
696 | )
697 |
698 | def get_loadings(self):
699 | """ Extract per-gene weights (for each Z) in the linear decoder.
700 | """
701 | return self.decoder.factor_regressor.parameters()
702 |
--------------------------------------------------------------------------------
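Note on the clustering math used above: `get_gamma` in multi_vae_attention.py evaluates the posterior responsibility of each Gaussian-mixture component for a latent code z, i.e. gamma_k = pi_k * N(z; mu_k, diag(var_k)) / sum_j pi_j * N(z; mu_j, diag(var_j)), the same quantities that `init_gmm_params` seeds from a fitted sklearn GaussianMixture. Below is a minimal NumPy sketch of that responsibility computation; the function name `gmm_responsibilities` and the toy shapes are illustrative only and are not part of the scMVP API.

    import numpy as np

    def gmm_responsibilities(z, pi, mu_c, var_c, eps=1e-10):
        """Posterior q(c|z) of a diagonal-covariance Gaussian mixture.

        z:     (N, D) latent codes
        pi:    (K,)   mixture weights
        mu_c:  (D, K) component means
        var_c: (D, K) component variances
        """
        z = z[:, :, None]                                                        # (N, D, 1)
        log_pdf = -0.5 * (np.log(2 * np.pi * var_c) + (z - mu_c) ** 2 / var_c)   # (N, D, K)
        log_p = np.log(pi) + log_pdf.sum(axis=1)                                 # log pi_k + log N(z | mu_k, var_k)
        p = np.exp(log_p) + eps
        return p / p.sum(axis=1, keepdims=True)                                  # normalize over the K components

    rng = np.random.default_rng(0)
    gamma = gmm_responsibilities(
        z=rng.normal(size=(5, 3)),          # 5 cells, 3 latent dimensions
        pi=np.array([0.5, 0.5]),            # 2 components
        mu_c=rng.normal(size=(3, 2)),
        var_c=np.ones((3, 2)),
    )
    print(gamma.sum(axis=1))                # every row sums to ~1

Each row of the returned matrix sums to one, mirroring the normalization `p_c_z / torch.sum(p_c_z, dim=1, keepdim=True)` in `get_gamma`.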
/scMVP/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.io as sp_io
3 | import numpy as np
4 |
5 | def iterate(obj, func):
6 | t = type(obj)
7 | if t is list or t is tuple:
8 | return t([iterate(o, func) for o in obj])
9 | else:
10 | return func(obj) if obj is not None else None
11 |
12 |
13 | def broadcast_labels(y, *o, n_broadcast=-1):
14 | """
15 | Utility for the semi-supervised setting
 16 |     If y is defined (labelled batch), one-hot encode the labels (no broadcasting needed).
 17 |     If y is undefined (unlabelled batch), generate all possible labels (and broadcast the other arguments if not None).
18 | """
19 | if not len(o):
20 | raise ValueError("Broadcast must have at least one reference argument")
21 | if y is None:
22 | ys = enumerate_discrete(o[0], n_broadcast)
23 | new_o = iterate(
24 | o,
25 | lambda x: x.repeat(n_broadcast, 1)
26 | if len(x.size()) == 2
27 | else x.repeat(n_broadcast),
28 | )
29 | else:
30 | ys = one_hot(y, n_broadcast)
31 | new_o = o
32 | return (ys,) + new_o
33 |
34 |
35 | def one_hot(index, n_cat):
36 | onehot = torch.zeros(index.size(0), n_cat, device=index.device)
37 | onehot.scatter_(1, index.type(torch.long), 1)
38 | return onehot.type(torch.float32)
39 |
40 |
41 | def enumerate_discrete(x, y_dim):
42 | def batch(batch_size, label):
43 | labels = torch.ones(batch_size, 1, device=x.device, dtype=torch.long) * label
44 | return one_hot(labels, y_dim)
45 |
46 | batch_size = x.size(0)
47 | return torch.cat([batch(batch_size, i) for i in range(y_dim)])
48 |
49 |
50 | def binarization(imputed, raw):
51 | return (imputed.T > np.quantile(imputed,q=0.8,axis=1).T).T & (imputed>imputed.mean(0)) & (imputed>0).astype(np.int8)
52 |
--------------------------------------------------------------------------------
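The helpers in utils.py are small tensor utilities: `one_hot` scatters integer labels of shape (N, 1) into an (N, n_cat) indicator matrix, `enumerate_discrete` stacks one copy of the batch per possible label for the semi-supervised path, and `binarization` thresholds imputed ATAC values. A short stand-alone illustration of the one-hot scatter in plain PyTorch follows; the name `one_hot_example` is hypothetical and merely mirrors what the module's `one_hot` does.

    import torch

    def one_hot_example(index, n_cat):
        # index: (N, 1) tensor of integer category ids; returns (N, n_cat) float indicators
        onehot = torch.zeros(index.size(0), n_cat, device=index.device)
        onehot.scatter_(1, index.long(), 1)
        return onehot.float()

    labels = torch.tensor([[0], [2], [1]])
    print(one_hot_example(labels, 3))
    # tensor([[1., 0., 0.],
    #         [0., 0., 1.],
    #         [0., 1., 0.]])

The labels must arrive as a column vector of integer ids, which is why callers reshape or unsqueeze them before passing them in.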
/scMVP/models/vaePeak_selfattetion.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Main module."""
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.distributions import Normal, kl_divergence as kl
8 |
9 | from scMVP.models.log_likelihood import log_zinb_positive, log_nb_positive, mean_square_error, mean_square_error_positive, log_zip_positive
10 | from scMVP.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI, DecoderSCVI_nb, DecoderSCVI_mse, Encoder_l, Encoder_mse,\
11 | Encoder_nb, DecoderSCVI_Peak, DecoderSCVI_Peak_Selfattention, Encoder_nb_selfattention, DecoderSCVI_nb_Selfattention
12 | from scMVP.models.utils import one_hot
13 |
14 | torch.backends.cudnn.benchmark = True
15 |
16 |
17 | # VAE model
18 | class VAE_Peak_SelfAttention(nn.Module):
19 | r"""Variational auto-encoder model.
20 |
21 | :param n_input: Number of input genes
22 | :param n_batch: Number of batches
23 | :param n_labels: Number of labels
24 | :param n_hidden: Number of nodes per hidden layer
25 | :param n_latent: Dimensionality of the latent space
26 | :param n_layers: Number of hidden layers used for encoder and decoder NNs
27 | :param dropout_rate: Dropout rate for neural networks
28 | :param mode: One of the following:
 29 |         * ``'vae'`` - single-channel auto-encoder/decoder framework for scRNA-seq data
 30 |         * ``'mm-vae'`` - multi-channel auto-encoder/decoder framework for joint scRNA-seq and scATAC-seq data
31 | :param dispersion: One of the following
32 |
33 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
34 | * ``'gene-batch'`` - dispersion can differ between different batches
35 | * ``'gene-label'`` - dispersion can differ between different labels
36 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell
37 |
38 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
39 | :param reconstruction_loss: One of
40 |
41 | * ``'nb'`` - Negative binomial distribution
42 | * ``'zinb'`` - Zero-inflated negative binomial distribution
43 |
44 | Examples:
45 | >>> gene_dataset = CortexDataset()
 46 |         >>> vae = VAE_Peak_SelfAttention(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
47 | ... n_labels=gene_dataset.n_labels)
48 |
49 | """
50 |
51 | def __init__(
52 | self,
53 | n_input: int,
54 | n_batch: int = 0,
55 | n_labels: int = 0,
56 | n_hidden: int = 128, # 256
57 | n_latent: int = 10, # 20
58 | n_layers: int = 1,
59 | dropout_rate: float = 0.1,
60 | dispersion: str = "gene",
61 | log_variational: bool = True,
62 | reconstruction_loss: str = "zinb",
63 | #reconstruction_loss: str = "nb",
64 | ):
65 | super().__init__()
66 | self.dispersion = dispersion
67 | self.n_latent = n_latent
68 | self.log_variational = log_variational
69 | self.reconstruction_loss = reconstruction_loss
70 | # Automatically deactivate if useless
71 | self.n_batch = n_batch
72 | self.n_labels = n_labels
73 |
74 |
75 | if self.dispersion == "gene":
76 | self.px_r = torch.nn.Parameter(torch.randn(n_input))
77 | elif self.dispersion == "gene-batch":
78 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
79 | elif self.dispersion == "gene-label":
80 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
81 | elif self.dispersion == "gene-cell":
82 | pass
83 | else:
84 | raise ValueError(
85 | "dispersion must be one of ['gene', 'gene-batch',"
86 | " 'gene-label', 'gene-cell'], but input was "
 87 |                 "{}".format(self.dispersion)
88 | )
89 | # z encoder goes from the n_input-dimensional data to an n_latent-d
90 | # latent space representation
91 | if self.reconstruction_loss == "mse":
92 | self.z_encoder = Encoder_mse(
93 | n_input,
94 | n_latent,
95 | n_layers=n_layers,
96 | n_hidden=n_hidden,
97 | dropout_rate=dropout_rate,
98 | )
99 | elif self.reconstruction_loss == "nb":
100 | self.z_encoder = Encoder_nb_selfattention(
101 | n_input,
102 | n_latent,
103 | n_layers=n_layers,
104 | n_hidden=n_hidden,
105 | dropout_rate=dropout_rate,
106 | )
107 | else:
108 | self.z_encoder = Encoder(
109 | n_input,
110 | n_latent,
111 | n_layers=n_layers,
112 | n_hidden=n_hidden,
113 | dropout_rate=dropout_rate,
114 | )
115 | # l encoder goes from n_input-dimensional data to 1-d library size
116 | if self.reconstruction_loss == "nb":
117 | self.l_encoder = Encoder_l(
118 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
119 | )
120 | elif self.reconstruction_loss == "zinb":
121 | self.l_encoder = Encoder_l(
122 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
123 | )
124 | else:
125 | self.l_encoder = Encoder(
126 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
127 | )
128 | if self.reconstruction_loss == "mse":
129 | self.decoder = DecoderSCVI_mse(
130 | n_latent,
131 | n_input,
132 | n_cat_list=[n_batch],
133 | n_layers=n_layers,
134 | n_hidden=n_hidden,
135 | )
136 | elif self.reconstruction_loss == "nb" and self.log_variational:
137 | self.decoder = DecoderSCVI_Peak_Selfattention(
138 | n_latent,
139 | n_input,
140 | n_cat_list=[n_batch],
141 | n_layers=n_layers,
142 | n_hidden=n_hidden,
143 | )
144 | elif self.reconstruction_loss == "nb" and not self.log_variational:
145 | self.decoder = DecoderSCVI_Peak_Selfattention(
146 | n_latent,
147 | n_input,
148 | n_cat_list=[n_batch],
149 | n_layers=n_layers,
150 | n_hidden=n_hidden,
151 | )
152 | else:
153 | self.decoder = DecoderSCVI(
154 | n_latent,
155 | n_input,
156 | n_cat_list=[n_batch],
157 | n_layers=n_layers,
158 | n_hidden=n_hidden,
159 | )
160 |
161 | def get_latents(self, x, y=None):
162 | r""" returns the result of ``sample_from_posterior_z`` inside a list
163 |
164 | :param x: tensor of values with shape ``(batch_size, n_input)``
165 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
166 | :return: one element list of tensor
167 | :rtype: list of :py:class:`torch.Tensor`
168 | """
169 | return [self.sample_from_posterior_z(x, y)]
170 |
171 | def sample_from_posterior_z(self, x, y=None, give_mean=False):
172 | r""" samples the tensor of latent values from the posterior
173 | #doesn't really sample, returns the means of the posterior distribution
174 |
175 | :param x: tensor of values with shape ``(batch_size, n_input)``
176 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
177 | :param give_mean: is True when we want the mean of the posterior distribution rather than sampling
178 | :return: tensor of shape ``(batch_size, n_latent)``
179 | :rtype: :py:class:`torch.Tensor`
180 | """
181 | if self.log_variational:
182 | x = torch.log(1 + x)
183 | qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC
184 | if give_mean:
185 | z = qz_m
186 | return z
187 |
188 | def sample_from_posterior_l(self, x):
189 | r""" samples the tensor of library sizes from the posterior
190 | #doesn't really sample, returns the tensor of the means of the posterior distribution
191 |
192 | :param x: tensor of values with shape ``(batch_size, n_input)``
193 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
194 | :return: tensor of shape ``(batch_size, 1)``
195 | :rtype: :py:class:`torch.Tensor`
196 | """
197 | if self.log_variational:
198 | x = torch.log(1 + x)
199 | ql_m, ql_v, library = self.l_encoder(x)
200 | return library
201 |
202 | def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1):
203 | r"""Returns the tensor of predicted frequencies of expression
204 |
205 | :param x: tensor of values with shape ``(batch_size, n_input)``
206 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
207 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
208 | :param n_samples: number of samples
209 | :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
210 | :rtype: :py:class:`torch.Tensor`
211 | """
212 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[
213 | "px_scale"
214 | ]
215 |
216 | def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1):
217 | r"""Returns the tensor of means of the negative binomial distribution
218 |
219 | :param x: tensor of values with shape ``(batch_size, n_input)``
220 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
221 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
222 | :param n_samples: number of samples
223 | :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
224 | :rtype: :py:class:`torch.Tensor`
225 | """
226 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[
227 | "px_rate"
228 | ]
229 |
230 | def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs):
231 | # Reconstruction Loss
232 | if self.reconstruction_loss == "zinb":
233 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1)
234 | elif self.reconstruction_loss == "nb":
235 | reconst_loss = 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1) - log_zip_positive(x, px_rate, px_dropout).sum(dim=-1)
236 | px_rate[x > 0] = 0
237 | reconst_loss = reconst_loss + 0.05*px_rate.sum(dim=-1)
238 |
239 | elif self.reconstruction_loss == "mse":
240 | reconst_loss = mean_square_error_positive(x, px_rate).sum(dim=-1)
241 | return reconst_loss
242 |
243 | def scale_from_z(self, sample_batch, fixed_batch):
244 | if self.log_variational:
245 | sample_batch = torch.log(1 + sample_batch)
246 | qz_m, qz_v, z = self.z_encoder(sample_batch)
247 | batch_index = fixed_batch * torch.ones_like(sample_batch[:, [0]])
248 | library = 4.0 * torch.ones_like(sample_batch[:, [0]])
249 | px_scale, _, _, _ = self.decoder("gene", z, library, batch_index)
250 | return px_scale
251 |
252 | def inference(self, x, batch_index=None, y=None, n_samples=1):
253 |
254 | x_ = x
255 | if self.reconstruction_loss == "nb" and self.log_variational:
256 | library_nb = torch.log(x_.sum(dim=-1)).reshape(-1, 1)
257 | elif self.reconstruction_loss == "nb" and (not self.log_variational):
258 | library_nb = (x_.sum(dim=-1)).reshape(-1, 1)
259 | if self.log_variational:
260 | x_ = torch.log(1 + x_)
261 |
262 | # Sampling
263 | qz_m, qz_v, z = self.z_encoder(x_, y)
264 | ql_m, ql_v, library = self.l_encoder(x_)
265 | if self.reconstruction_loss == "nb":
266 | library = library_nb
267 |
268 | if n_samples > 1:
269 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
270 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
271 | z = Normal(qz_m, qz_v.sqrt()).sample()
272 | ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
273 | ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
274 | library = Normal(ql_m, ql_v.sqrt()).sample()
275 |
276 | px_scale, px_r, px_rate, px_dropout = self.decoder(
277 | self.dispersion, z, library, batch_index, y
278 | )
279 | if self.dispersion == "gene-label":
280 | px_r = F.linear(
281 | one_hot(y, self.n_labels), self.px_r
282 | ) # px_r gets transposed - last dimension is nb genes
283 | elif self.dispersion == "gene-batch":
284 | px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
285 | elif self.dispersion == "gene":
286 | px_r = self.px_r
287 | px_r = torch.exp(px_r)
288 |
289 | return dict(
290 | px_scale=px_scale,
291 | px_r=px_r,
292 | px_rate=px_rate,
293 | px_dropout=px_dropout,
294 | qz_m=qz_m,
295 | qz_v=qz_v,
296 | z=z,
297 | ql_m=ql_m,
298 | ql_v=ql_v,
299 | library=library,
300 | )
301 |
302 | def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
303 |         r""" Returns the reconstruction loss and the Kullback-Leibler divergences
304 |
305 | :param x: tensor of values with shape (batch_size, n_input)
306 | :param local_l_mean: tensor of means of the prior distribution of latent variable l
307 | with shape (batch_size, 1)
308 |         :param local_l_var: tensor of variances of the prior distribution of latent variable l
309 | with shape (batch_size, 1)
310 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
311 | :param y: tensor of cell-types labels with shape (batch_size, n_labels)
312 |         :return: the reconstruction loss and the Kullback-Leibler divergences
313 |         :rtype: 3-tuple of :py:class:`torch.FloatTensor`
314 | """
315 | # Parameters for z latent distribution
316 | outputs = self.inference(x, batch_index, None)
317 | qz_m = outputs["qz_m"]
318 | qz_v = outputs["qz_v"]
319 | ql_m = outputs["ql_m"]
320 | ql_v = outputs["ql_v"]
321 | px_rate = outputs["px_rate"]
322 | px_r = outputs["px_r"]
323 | px_dropout = outputs["px_dropout"]
324 |
325 | # KL Divergence
326 | mean = torch.zeros_like(qz_m)
327 | scale = torch.ones_like(qz_v)
328 |
329 |
330 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
331 | dim=1
332 | )
333 | kl_divergence_l = kl(
334 | Normal(ql_m, torch.sqrt(ql_v)),
335 | Normal(local_l_mean, torch.sqrt(local_l_var)),
336 | ).sum(dim=1)
337 | kl_divergence = kl_divergence_z
338 |
339 | reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)
340 |
341 | if self.reconstruction_loss == "mse" or self.reconstruction_loss == "nb":
342 | kl_divergence_l = 1.0
343 | print("reconst_loss=%f, kl_divergence=%f"%(torch.mean(reconst_loss),torch.mean(kl_divergence)))
344 | return reconst_loss + kl_divergence_l, kl_divergence, 0.0
345 |
346 |
347 | class LDVAE(VAE_Peak_SelfAttention):
348 | r"""Linear-decoded Variational auto-encoder model.
349 |
350 | This model uses a linear decoder, directly mapping the latent representation
351 | to gene expression levels. It still uses a deep neural network to encode
352 | the latent representation.
353 |
354 | Compared to standard VAE, this model is less powerful, but can be used to
355 | inspect which genes contribute to variation in the dataset.
356 |
357 | :param n_input: Number of input genes
358 | :param n_batch: Number of batches
359 | :param n_labels: Number of labels
360 | :param n_hidden: Number of nodes per hidden layer (for encoder)
361 | :param n_latent: Dimensionality of the latent space
362 | :param n_layers: Number of hidden layers used for encoder NNs
363 | :param dropout_rate: Dropout rate for neural networks
364 | :param dispersion: One of the following
365 |
366 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
367 | * ``'gene-batch'`` - dispersion can differ between different batches
368 | * ``'gene-label'`` - dispersion can differ between different labels
369 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell
370 |
371 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
372 | :param reconstruction_loss: One of
373 |
374 | * ``'nb'`` - Negative binomial distribution
375 | * ``'zinb'`` - Zero-inflated negative binomial distribution
376 | """
377 |
378 | def __init__(
379 | self,
380 | n_input: int,
381 | n_batch: int = 0,
382 | n_labels: int = 0,
383 | n_hidden: int = 128,
384 | n_latent: int = 10,
385 | n_layers: int = 1,
386 | dropout_rate: float = 0.1,
387 | dispersion: str = "gene",
388 | log_variational: bool = True,
389 | reconstruction_loss: str = "zinb",
390 | ):
391 | super().__init__(
392 | n_input,
393 | n_batch,
394 | n_labels,
395 | n_hidden,
396 | n_latent,
397 | n_layers,
398 | dropout_rate,
399 | dispersion,
400 | log_variational,
401 | reconstruction_loss,
402 | )
403 |
404 | self.decoder = LinearDecoderSCVI(
405 | n_latent,
406 | n_input,
407 | n_cat_list=[n_batch],
408 | n_layers=n_layers,
409 | n_hidden=n_hidden,
410 | )
411 |
412 | def get_loadings(self):
413 | """ Extract per-gene weights (for each Z) in the linear decoder.
414 | """
415 | return self.decoder.factor_regressor.parameters()
416 |
--------------------------------------------------------------------------------
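A few details of `VAE_Peak_SelfAttention.inference` above are easy to miss: with `reconstruction_loss == 'nb'` the library size is the log of the raw per-cell total count taken before any transform, the encoder input itself is log(1 + x) when `log_variational` is set, and with `n_samples > 1` the latent code is drawn from Normal(qz_m, sqrt(qz_v)) expanded along a leading sample dimension. The sketch below reproduces those tensor operations on random counts; shapes and values are illustrative only and do not come from the model.

    import torch
    from torch.distributions import Normal

    x = torch.randint(0, 20, (4, 100)).float()          # 4 cells x 100 peaks of fake counts
    library = torch.log(x.sum(dim=-1)).reshape(-1, 1)    # observed log library size, shape (4, 1)
    x_ = torch.log(1 + x)                                 # log1p input used when log_variational=True

    # stand-in posterior parameters; in the model these come from the encoder
    qz_m, qz_v = torch.zeros(4, 10), torch.ones(4, 10)
    n_samples = 3
    qz_m = qz_m.unsqueeze(0).expand(n_samples, 4, 10)
    qz_v = qz_v.unsqueeze(0).expand(n_samples, 4, 10)
    z = Normal(qz_m, qz_v.sqrt()).sample()                # (3, 4, 10) Monte Carlo latent draws
    print(library.shape, x_.shape, z.shape)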
/scMVP/models/vae_attention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Main module."""
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.distributions import Normal, kl_divergence as kl
8 |
9 | from scMVP.models.log_likelihood import log_zinb_positive, log_nb_positive, mean_square_error, mean_square_error_positive
10 | from scMVP.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI, DecoderSCVI_nb, DecoderSCVI_mse, Encoder_l, Encoder_mse,\
11 | Encoder_nb, DecoderSCVI_nb_rna, Encoder_nb_attention, DecoderSCVI_nb_Selfattention, Encoder_nb_selfattention
12 | from scMVP.models.utils import one_hot
13 |
14 | torch.backends.cudnn.benchmark = True
15 |
16 |
17 | # VAE model
18 | class VAE_Attention(nn.Module):
19 | r"""Variational auto-encoder model.
20 |
21 | :param n_input: Number of input genes
22 | :param n_batch: Number of batches
23 | :param n_labels: Number of labels
24 | :param n_hidden: Number of nodes per hidden layer
25 | :param n_latent: Dimensionality of the latent space
26 | :param n_layers: Number of hidden layers used for encoder and decoder NNs
27 | :param dropout_rate: Dropout rate for neural networks
28 | :param mode: One of the following:
 29 |         * ``'vae'`` - single-channel auto-encoder/decoder framework for scRNA-seq data
 30 |         * ``'mm-vae'`` - multi-channel auto-encoder/decoder framework for joint scRNA-seq and scATAC-seq data
31 | :param dispersion: One of the following
32 |
33 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
34 | * ``'gene-batch'`` - dispersion can differ between different batches
35 | * ``'gene-label'`` - dispersion can differ between different labels
36 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell
37 |
38 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
39 | :param reconstruction_loss: One of
40 |
41 | * ``'nb'`` - Negative binomial distribution
42 | * ``'zinb'`` - Zero-inflated negative binomial distribution
43 |
44 | Examples:
45 | >>> gene_dataset = CortexDataset()
 46 |         >>> vae = VAE_Attention(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
47 | ... n_labels=gene_dataset.n_labels)
48 |
49 | """
50 |
51 | def __init__(
52 | self,
53 | n_input: int,
54 | n_batch: int = 0,
55 | n_labels: int = 0,
56 | n_hidden: int = 128, # 256
57 | n_latent: int = 10, # 20
58 | n_layers: int = 1,
59 | dropout_rate: float = 0.1,
60 | dispersion: str = "gene",
61 | log_variational: bool = True,
62 | reconstruction_loss: str = "zinb",
63 | #reconstruction_loss: str = "nb",
64 | ):
65 | super().__init__()
66 | self.dispersion = dispersion
67 | self.n_latent = n_latent
68 | self.log_variational = log_variational
69 | self.reconstruction_loss = reconstruction_loss
70 | # Automatically deactivate if useless
71 | self.n_batch = n_batch
72 | self.n_labels = n_labels
73 |
74 | if self.dispersion == "gene":
75 | self.px_r = torch.nn.Parameter(torch.randn(n_input))
76 | elif self.dispersion == "gene-batch":
77 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
78 | elif self.dispersion == "gene-label":
79 | self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
80 | elif self.dispersion == "gene-cell":
81 | pass
82 | else:
83 | raise ValueError(
84 | "dispersion must be one of ['gene', 'gene-batch',"
85 | " 'gene-label', 'gene-cell'], but input was "
 86 |                 "{}".format(self.dispersion)
87 | )
88 | # z encoder goes from the n_input-dimensional data to an n_latent-d
89 | # latent space representation
90 | if self.reconstruction_loss == "mse":
91 | self.z_encoder = Encoder_mse(
92 | n_input,
93 | n_latent,
94 | n_layers=n_layers,
95 | n_hidden=n_hidden,
96 | dropout_rate=dropout_rate,
97 | )
98 | elif self.reconstruction_loss == "nb":
99 | self.z_encoder = Encoder_nb_attention(
100 | n_input,
101 | n_latent,
102 | n_layers=n_layers,
103 | n_hidden=n_hidden,
104 | dropout_rate=dropout_rate,
105 | )
106 |
107 | else:
108 | self.z_encoder = Encoder(
109 | n_input,
110 | n_latent,
111 | n_layers=n_layers,
112 | n_hidden=n_hidden,
113 | dropout_rate=dropout_rate,
114 | )
115 | # l encoder goes from n_input-dimensional data to 1-d library size
116 | if self.reconstruction_loss == "nb":
117 | self.l_encoder = Encoder_l(
118 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
119 | )
120 | elif self.reconstruction_loss == "zinb":
121 | self.l_encoder = Encoder_l(
122 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
123 | )
124 | else:
125 | self.l_encoder = Encoder(
126 | n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
127 | )
128 | if self.reconstruction_loss == "mse":
129 | self.decoder = DecoderSCVI_mse(
130 | n_latent,
131 | n_input,
132 | n_cat_list=[n_batch],
133 | n_layers=n_layers,
134 | n_hidden=n_hidden,
135 | )
136 | elif self.reconstruction_loss == "nb" and self.log_variational:
137 | self.decoder = DecoderSCVI_nb(
138 | n_latent,
139 | n_input,
140 | n_cat_list=[n_batch],
141 | n_layers=n_layers,
142 | n_hidden=n_hidden,
143 | )
144 | elif self.reconstruction_loss == "nb" and not self.log_variational:
145 | self.decoder = DecoderSCVI_nb_rna(
146 | n_latent,
147 | n_input,
148 | n_cat_list=[n_batch],
149 | n_layers=n_layers,
150 | n_hidden=n_hidden,
151 | )
152 | else:
153 | self.decoder = DecoderSCVI(
154 | n_latent,
155 | n_input,
156 | n_cat_list=[n_batch],
157 | n_layers=n_layers,
158 | n_hidden=n_hidden,
159 | )
160 |
161 | def get_latents(self, x, y=None):
162 | r""" returns the result of ``sample_from_posterior_z`` inside a list
163 |
164 | :param x: tensor of values with shape ``(batch_size, n_input)``
165 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
166 | :return: one element list of tensor
167 | :rtype: list of :py:class:`torch.Tensor`
168 | """
169 | return [self.sample_from_posterior_z(x, y)]
170 |
171 | def sample_from_posterior_z(self, x, y=None, give_mean=False):
172 | r""" samples the tensor of latent values from the posterior
173 | #doesn't really sample, returns the means of the posterior distribution
174 |
175 | :param x: tensor of values with shape ``(batch_size, n_input)``
176 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
177 | :param give_mean: is True when we want the mean of the posterior distribution rather than sampling
178 | :return: tensor of shape ``(batch_size, n_latent)``
179 | :rtype: :py:class:`torch.Tensor`
180 | """
181 | if self.log_variational:
182 | x = torch.log(1 + x)
183 | qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC
184 | if give_mean:
185 | z = qz_m
186 | return z
187 |
188 | def sample_from_posterior_l(self, x):
189 | r""" samples the tensor of library sizes from the posterior
190 | #doesn't really sample, returns the tensor of the means of the posterior distribution
191 |
192 | :param x: tensor of values with shape ``(batch_size, n_input)``
193 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
194 | :return: tensor of shape ``(batch_size, 1)``
195 | :rtype: :py:class:`torch.Tensor`
196 | """
197 | if self.log_variational:
198 | x = torch.log(1 + x)
199 | ql_m, ql_v, library = self.l_encoder(x)
200 | return library
201 |
202 | def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1):
203 | r"""Returns the tensor of predicted frequencies of expression
204 |
205 | :param x: tensor of values with shape ``(batch_size, n_input)``
206 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
207 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
208 | :param n_samples: number of samples
209 | :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
210 | :rtype: :py:class:`torch.Tensor`
211 | """
212 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[
213 | "px_scale"
214 | ]
215 |
216 | def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1):
217 | r"""Returns the tensor of means of the negative binomial distribution
218 |
219 | :param x: tensor of values with shape ``(batch_size, n_input)``
220 | :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
221 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
222 | :param n_samples: number of samples
223 | :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
224 | :rtype: :py:class:`torch.Tensor`
225 | """
226 | return self.inference(x, batch_index=batch_index, y=y, n_samples=n_samples)[
227 | "px_rate"
228 | ]
229 |
230 | def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs):
231 | # Reconstruction Loss:
232 | if self.reconstruction_loss == "nb":
233 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1)
234 | elif self.reconstruction_loss == "zinb":
235 | reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1) + 0.5*mean_square_error_positive(x, px_rate).sum(dim=-1)
236 | elif self.reconstruction_loss == "mse":
237 | reconst_loss = mean_square_error_positive(x, px_rate).sum(dim=-1)
238 | return reconst_loss
239 |
240 | def scale_from_z(self, sample_batch, fixed_batch):
241 | if self.log_variational:
242 | sample_batch = torch.log(1 + sample_batch)
243 | qz_m, qz_v, z = self.z_encoder(sample_batch)
244 | batch_index = fixed_batch * torch.ones_like(sample_batch[:, [0]])
245 | library = 4.0 * torch.ones_like(sample_batch[:, [0]])
246 | px_scale, _, _, _ = self.decoder("gene", z, library, batch_index)
247 | return px_scale
248 |
249 | def inference(self, x, batch_index=None, y=None, n_samples=1):
250 |
251 | x_ = x
252 | if self.reconstruction_loss == "nb" and self.log_variational:
253 | library_nb = torch.log(x_.sum(dim=-1)).reshape(-1, 1)
254 | elif self.reconstruction_loss == "nb" and not self.log_variational:
255 | library_nb = (x_.sum(dim=-1)).reshape(-1, 1)
256 | if self.log_variational:
257 | x_ = torch.log(1 + x_)
258 |
259 | # Sampling
260 | qz_m, qz_v, z = self.z_encoder(x_, y)
261 | ql_m, ql_v, library = self.l_encoder(x_)
262 | if self.reconstruction_loss == "nb":
263 | library = library_nb
264 |
265 | if n_samples > 1:
266 | qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
267 | qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
268 | z = Normal(qz_m, qz_v.sqrt()).sample()
269 | ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
270 | ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
271 | library = Normal(ql_m, ql_v.sqrt()).sample()
272 |
273 | px_scale, px_r, px_rate, px_dropout = self.decoder(
274 | self.dispersion, z, library, batch_index, y
275 | )
276 | if self.dispersion == "gene-label":
277 | px_r = F.linear(
278 | one_hot(y, self.n_labels), self.px_r
279 | ) # px_r gets transposed - last dimension is nb genes
280 | elif self.dispersion == "gene-batch":
281 | px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
282 | elif self.dispersion == "gene":
283 | px_r = self.px_r
284 | px_r = torch.exp(px_r)
285 |
286 | return dict(
287 | px_scale=px_scale,
288 | px_r=px_r,
289 | px_rate=px_rate,
290 | px_dropout=px_dropout,
291 | qz_m=qz_m,
292 | qz_v=qz_v,
293 | z=z,
294 | ql_m=ql_m,
295 | ql_v=ql_v,
296 | library=library,
297 | )
298 |
299 | def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
300 |         r""" Returns the reconstruction loss and the Kullback-Leibler divergences
301 |
302 | :param x: tensor of values with shape (batch_size, n_input)
303 | :param local_l_mean: tensor of means of the prior distribution of latent variable l
304 | with shape (batch_size, 1)
305 |         :param local_l_var: tensor of variances of the prior distribution of latent variable l
306 | with shape (batch_size, 1)
307 | :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
308 | :param y: tensor of cell-types labels with shape (batch_size, n_labels)
309 |         :return: the reconstruction loss and the Kullback-Leibler divergences
310 |         :rtype: 3-tuple of :py:class:`torch.FloatTensor`
311 | """
312 | # Parameters for z latent distribution
313 | outputs = self.inference(x, batch_index, None)
314 | qz_m = outputs["qz_m"]
315 | qz_v = outputs["qz_v"]
316 | ql_m = outputs["ql_m"]
317 | ql_v = outputs["ql_v"]
318 | px_rate = outputs["px_rate"]
319 | px_r = outputs["px_r"]
320 | px_dropout = outputs["px_dropout"]
321 |
322 | # KL Divergence
323 | mean = torch.zeros_like(qz_m)
324 | scale = torch.ones_like(qz_v)
325 |
326 |
327 | kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
328 | dim=1
329 | )
330 | kl_divergence_l = kl(
331 | Normal(ql_m, torch.sqrt(ql_v)),
332 | Normal(local_l_mean, torch.sqrt(local_l_var)),
333 | ).sum(dim=1)
334 | kl_divergence = kl_divergence_z
335 |
336 | reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)
337 |
338 | if self.reconstruction_loss == "mse" or self.reconstruction_loss == "nb":
339 | kl_divergence_l = 1.0
340 | print("reconst_loss=%f, kl_divergence=%f"%(torch.mean(reconst_loss),torch.mean(kl_divergence)))
341 | return reconst_loss + kl_divergence_l, kl_divergence, 0.0
342 |
343 |
344 | class LDVAE(VAE_Attention):
345 | r"""Linear-decoded Variational auto-encoder model.
346 |
347 | This model uses a linear decoder, directly mapping the latent representation
348 | to gene expression levels. It still uses a deep neural network to encode
349 | the latent representation.
350 |
351 | Compared to standard VAE, this model is less powerful, but can be used to
352 | inspect which genes contribute to variation in the dataset.
353 |
354 | :param n_input: Number of input genes
355 | :param n_batch: Number of batches
356 | :param n_labels: Number of labels
357 | :param n_hidden: Number of nodes per hidden layer (for encoder)
358 | :param n_latent: Dimensionality of the latent space
359 | :param n_layers: Number of hidden layers used for encoder NNs
360 | :param dropout_rate: Dropout rate for neural networks
361 | :param dispersion: One of the following
362 |
363 | * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
364 | * ``'gene-batch'`` - dispersion can differ between different batches
365 | * ``'gene-label'`` - dispersion can differ between different labels
366 | * ``'gene-cell'`` - dispersion can differ for every gene in every cell
367 |
368 | :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
369 | :param reconstruction_loss: One of
370 |
371 | * ``'nb'`` - Negative binomial distribution
372 | * ``'zinb'`` - Zero-inflated negative binomial distribution
373 | """
374 |
375 | def __init__(
376 | self,
377 | n_input: int,
378 | n_batch: int = 0,
379 | n_labels: int = 0,
380 | n_hidden: int = 128,
381 | n_latent: int = 10,
382 | n_layers: int = 1,
383 | dropout_rate: float = 0.1,
384 | dispersion: str = "gene",
385 | log_variational: bool = True,
386 | reconstruction_loss: str = "zinb",
387 | ):
388 | super().__init__(
389 | n_input,
390 | n_batch,
391 | n_labels,
392 | n_hidden,
393 | n_latent,
394 | n_layers,
395 | dropout_rate,
396 | dispersion,
397 | log_variational,
398 | reconstruction_loss,
399 | )
400 |
401 | self.decoder = LinearDecoderSCVI(
402 | n_latent,
403 | n_input,
404 | n_cat_list=[n_batch],
405 | n_layers=n_layers,
406 | n_hidden=n_hidden,
407 | )
408 |
409 | def get_loadings(self):
410 | """ Extract per-gene weights (for each Z) in the linear decoder.
411 | """
412 | return self.decoder.factor_regressor.parameters()
413 |
--------------------------------------------------------------------------------
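The KL term assembled in `VAE_Attention.forward` above is the standard VAE regularizer: the divergence between the encoder posterior N(qz_m, qz_v) and a unit-Gaussian prior, summed over latent dimensions for each cell (the library KL against N(local_l_mean, local_l_var) is computed the same way, but is replaced by a constant for the 'mse' and 'nb' losses). A stand-alone check of that computation with torch.distributions, using toy tensors rather than the model's outputs:

    import torch
    from torch.distributions import Normal, kl_divergence as kl

    qz_m = torch.randn(4, 10)            # posterior means, 4 cells x 10 latent dims
    qz_v = torch.rand(4, 10) + 0.1       # posterior variances, kept strictly positive

    prior = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
    kl_z = kl(Normal(qz_m, qz_v.sqrt()), prior).sum(dim=1)   # one KL value per cell
    print(kl_z.shape, kl_z.mean().item())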
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from setuptools import setup, find_packages
5 |
6 |
7 | requirements = [
8 | "numpy>=1.16.2",
9 | "torch>=1.0.1",
10 | "matplotlib>=3.0.3",
11 | "h5py>=2.9.0",
12 | "pandas>=0.24.2",
13 | "loompy>=2.0.16",
14 | "tqdm>=4.31.1",
15 | "xlrd==1.2.0",
16 | # "nbconvert>=5.4.0",
17 | # "nbformat>=4.4.0",
18 | # "jupyter>=1.0.0",
19 | # "ipython>=7.1.1",
20 | # "anndata==0.6.22.post1",
21 | # "scanpy==1.4.4.post1",
22 | "dask>=2.0",
23 | "anndata>=0.7",
24 | "scanpy>=1.4.6",
25 | "scikit-learn>=0.22.2",
26 | "numba>=0.48", # numba 0.45.1 has a conflict with UMAP and numba 0.46.0 with parallelization in loompy
27 | "hyperopt==0.1.2",
28 | ]
29 |
30 | setup_requirements = ["pip>=18.1"]
31 |
32 | test_requirements = [
33 | "pytest>=3.7.4",
34 | "pytest-runner>=2.11.1",
35 | "flake8>=3.7.7",
36 | "coverage>=4.5.1",
37 | "codecov>=2.0.8",
38 | "black>=19.3b0",
39 | ]
40 |
41 | extras_requirements = {
42 | "notebooks": [
43 | "louvain>=0.6.1",
44 | "python-igraph>=0.7.1.post6",
45 | "colour>=0.1.5",
46 | "umap-learn>=0.3.8",
47 | "seaborn>=0.9.0",
48 | "leidenalg>=0.7.0",
49 | ],
50 | "docs": [
51 | "sphinx>=1.7.1",
52 | "nbsphinx",
53 | "sphinx_autodoc_typehints",
54 | "sphinx-rtd-theme",
55 | ],
56 | "test": test_requirements,
57 | }
58 | author = (
59 | "Gao yang Li, Shaliu FU"
60 | )
61 |
62 | setup(
63 | author=author,
64 | author_email="lgyzngc@tongji.edu.cn",
65 | classifiers=[
66 | "Development Status :: 4 - Beta",
67 | "Intended Audience :: Science/Research",
68 | "License :: OSI Approved :: MIT License",
69 | "Natural Language :: English",
70 | "Programming Language :: Python :: 3.7",
71 | "Operating System :: MacOS :: MacOS X",
72 | "Operating System :: Microsoft :: Windows",
73 | "Operating System :: POSIX :: Linux",
74 |         "Topic :: Scientific/Engineering :: Bio-Informatics",
75 | ],
76 | description="Single Cell Multi-View Profiler",
77 | install_requires=requirements+extras_requirements["notebooks"],
78 | license="MIT license",
79 | # long_description=readme + "\n\n" + history,
80 | include_package_data=True,
81 | keywords="scMVP",
82 | name="scMVP",
83 | packages=find_packages(),
84 | setup_requires=setup_requirements,
85 | test_suite="tests",
86 | tests_require=test_requirements,
87 | extras_require=extras_requirements,
88 | url="https://github.com/bm2-lab/scMVP",
89 | version="0.0.1",
90 | zip_safe=False,
91 | )
92 |
--------------------------------------------------------------------------------
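setup.py pins the core requirements, folds the `notebooks` extra into install_requires, and exposes `docs` and `test` extras. Assuming a local checkout of the repository, an editable install with the test extra can be driven from Python as sketched below; this simply wraps the usual pip invocation, and running `pip install -e ".[test]"` from a shell is equivalent.

    import subprocess
    import sys

    # editable install of the current checkout plus the optional test dependencies
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".[test]"])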