├── .github └── workflows │ └── test.yaml ├── .gitignore ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat └── source │ ├── _static │ ├── .DS_Store │ └── img │ │ ├── .DS_Store │ │ ├── cross-modality_integration.png │ │ ├── logo_white.png │ │ ├── mouse_atlas_reference.png │ │ ├── pancreas_projection.png │ │ ├── pbmc_before_integration.png │ │ ├── pbmc_integration.png │ │ ├── query_kidney.png │ │ ├── query_tabula_muris_aging.png │ │ ├── scATAC_integration.png │ │ └── scalex.jpg │ ├── api │ └── index.rst │ ├── conf.py │ ├── contributors.rst │ ├── index.rst │ ├── installation.rst │ ├── news.rst │ ├── release │ ├── 1.0.0.rst │ └── index.rst │ ├── tutorial │ ├── Integration_PBMC.ipynb │ ├── Integration_cross-modality.ipynb │ ├── Integration_scATAC-seq.ipynb │ ├── Projection_pancreas.ipynb │ └── index.rst │ └── usage.rst ├── experiments ├── LISI.ipynb ├── NMI-ARI.ipynb ├── Silhouette_score & batch_entropy_mixing_score.ipynb ├── dirichlet_regression.ipynb ├── overcorrection_score.ipynb └── projection.ipynb ├── pyproject.toml ├── scalex ├── __init__.py ├── analysis.py ├── atac │ ├── __init__.py │ ├── bedtools.py │ ├── fragments.py │ ├── read_modisco.py │ └── snapatac2 │ │ ├── _basic.py │ │ ├── _clustering.py │ │ ├── _diff.py │ │ ├── _embedding.py │ │ ├── _knn.py │ │ ├── _misc.py │ │ ├── _motif.py │ │ └── _utils.py ├── data.py ├── function.py ├── linkage │ ├── __init__.py │ ├── linkage.py │ ├── motif │ │ ├── __init__.py │ │ ├── logo_utils.py │ │ └── motif_compendium.py │ └── utils.py ├── logger.py ├── metrics.py ├── net │ ├── __init__.py │ ├── layer.py │ ├── loss.py │ ├── utils.py │ └── vae.py ├── pl │ ├── __init__.py │ ├── _base.py │ ├── _genometrack.py │ ├── _network.py │ ├── analysis.py │ └── plot.py ├── plot.py ├── pp │ ├── __init__.py │ └── annotation.py └── specifity.py ├── tests ├── conftest.py └── test_scalex.py └── third_parties ├── BBKNN.py ├── Conos.R ├── DESC.py ├── FastMNN.R ├── Harmony.R ├── LIGER.R ├── Raw.py ├── Scanorama.py ├── Seurat_v3.R ├── online_iNMF.R └── scVI.py /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | - cron: "0 5 1,15 * *" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | runs-on: ${{ matrix.os }} 18 | defaults: 19 | run: 20 | shell: bash -e {0} # -e to fail on error 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | include: 26 | - os: ubuntu-latest 27 | python: "3.9" 28 | - os: ubuntu-latest 29 | python: "3.11" 30 | - os: ubuntu-latest 31 | python: "3.11" 32 | pip-flags: "--pre" 33 | name: PRE-RELEASE DEPENDENCIES 34 | 35 | name: ${{ matrix.name }} Python ${{ matrix.python }} 36 | 37 | env: 38 | OS: ${{ matrix.os }} 39 | PYTHON: ${{ matrix.python }} 40 | 41 | steps: 42 | - uses: actions/checkout@v3 43 | - name: Set up Python ${{ matrix.python }} 44 | uses: actions/setup-python@v4 45 | with: 46 | python-version: ${{ matrix.python }} 47 | cache: "pip" 48 | cache-dependency-path: "**/pyproject.toml" 49 | 50 | - name: Install test dependencies 51 | run: | 52 | python -m pip install --upgrade pip wheel 53 | - name: Install dependencies 54 | run: | 55 | pip install ${{ matrix.pip-flags }} ".[dev,test]" 56 | - name: Test 57 | env: 58 | MPLBACKEND: agg 59 | PLATFORM: ${{ matrix.os }} 60 | DISPLAY: :42 61 | run: | 62 | coverage run -m pytest -v 
--color=yes
 63 |       - name: Report coverage
 64 |         run: |
 65 |           coverage report
 66 |       - name: Upload coverage
 67 |         uses: codecov/codecov-action@v3
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | lib/*.pyc
 2 | *.pyc
 3 | *.ipynb_checkpoints
 4 | *.DS_Store
 5 | lib/__pycache__
 6 | other/
 7 | output/
 8 | build/
 9 | dist/
10 | scalex.egg-info/
11 | docs/_build/
12 | docs/source/api/scalex*
13 | __pycache__/
14 | /scalex/_version.py
15 | 
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
 1 | # https://docs.readthedocs.io/en/stable/config-file/v2.html
 2 | version: 2
 3 | build:
 4 |   os: ubuntu-20.04
 5 |   tools:
 6 |     python: "3.10"
 7 | sphinx:
 8 |   configuration: docs/source/conf.py
 9 |   # disable this for more lenient docs builds
10 |   fail_on_warning: true
11 | python:
12 |   install:
13 |     - method: pip
14 |       path: .
15 |       extra_requirements:
16 |         - doc
17 | 
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
 1 | # Changelog
 2 | 
 3 | ## [May 30, 2025]
 4 | Add cmd `scalex` in addition to `SCALEX`
 5 | Add cmd `fragment` to make peak matrix and gene activity score
 6 | Add annotate function in analysis.py with filtering of pseudogenes
 7 | 
 8 | ## [Dec 20, 2024]
 9 | ### Added
10 | - Add plot track in plot.py
11 | - Add cache genome folder .cache/genome/
12 | - Add gene_sets for annotate in analysis.py
13 | 
14 | ## [October 18, 2024]
15 | 
16 | ### Added
17 | - CHANGELOG
18 | - Add analysis.py for annotation
19 | - Modify the embedding and add plot_expr to enable gene expression across batches
20 | 
21 | ## [April 29, 2024]
22 | 
23 | ### Added
24 | - CHANGELOG
25 | - pyproject.toml to enable tests
26 | - Enable query and reference in the same figure in projection
27 | 
28 | ### Replace
29 | - Replace SCALE.py with SCALE
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | BSD 3-Clause License
 2 | 
 3 | Copyright (c) 2024, scverse community
 4 | All rights reserved.
 5 | 
 6 | Redistribution and use in source and binary forms, with or without
 7 | modification, are permitted provided that the following conditions are met:
 8 | 
 9 | 1. Redistributions of source code must retain the above copyright notice, this
10 |    list of conditions and the following disclaimer.
11 | 
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 |    this list of conditions and the following disclaimer in the documentation
14 |    and/or other materials provided with the distribution.
15 | 
16 | 3. Neither the name of the copyright holder nor the names of its
17 |    contributors may be used to endorse or promote products derived from
18 |    this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | [![Stars](https://img.shields.io/github/stars/jsxlei/scalex?logo=GitHub&color=yellow)](https://github.com/jsxlei/scalex/stargazers)
  2 | [![PyPI](https://img.shields.io/pypi/v/scalex.svg)](https://pypi.org/project/scalex)
  3 | [![Documentation Status](https://readthedocs.org/projects/scalex/badge/?version=latest)](https://scalex.readthedocs.io/en/latest/?badge=stable)
  4 | [![Downloads](https://pepy.tech/badge/scalex)](https://pepy.tech/project/scalex)
  5 | [![DOI](https://zenodo.org/badge/345941713.svg)](https://zenodo.org/badge/latestdoi/345941713)
  6 | # [Online single-cell data integration through projecting heterogeneous datasets into a common cell-embedding space](https://www.nature.com/articles/s41467-022-33758-z)
  7 | 
  8 | ![](docs/source/_static/img/scalex.jpg)
  9 | 
 10 | 
 11 | ## News
 12 | 
 13 | ## [Documentation](https://scalex.readthedocs.io/en/latest/index.html)
 14 | ## [Tutorial](https://scalex.readthedocs.io/en/latest/tutorial/index.html)
 15 | ## Installation
 16 | #### Install from PyPI
 17 | 
 18 |     pip install scalex
 19 | 
 20 | #### Install from GitHub
 21 | Install the latest development version
 22 | 
 23 |     pip install git+https://github.com/jsxlei/scalex.git
 24 | 
 25 | or git clone and install
 26 | 
 27 |     git clone git://github.com/jsxlei/scalex.git
 28 |     cd scalex
 29 |     python setup.py install
 30 | 
 31 | SCALEX is implemented in the [Pytorch](https://pytorch.org/) framework.
 32 | SCALEX can run on CPU devices; running SCALEX on GPU devices is recommended if they are available.
 33 | 
 34 | ## Getting started
 35 | 
 36 | SCALEX can be used both as a command line tool and as an API function in Jupyter notebooks.
 37 | Please refer to the [Documentation](https://scalex.readthedocs.io/en/latest/index.html) and [Tutorial](https://scalex.readthedocs.io/en/latest/tutorial/index.html).
 38 | 
 39 | 
 40 | ### 1. API function
 41 | 
 42 |     from scalex import SCALEX
 43 |     adata = SCALEX(data_list, batch_categories)
 44 | 
 45 | The function parameters are similar to the command line options.
 46 | The output is an AnnData object for further analysis with scanpy.
 47 | `data_list` can be
 48 | * a data path; supported file formats include txt, csv, h5ad, h5mu/rna, h5mu/atac, or a directory containing mtx files
 49 | * a list of data paths
 50 | * an [AnnData](https://anndata.readthedocs.io/en/stable/anndata.AnnData.html#anndata.AnnData) object
 51 | * a list of [AnnData](https://anndata.readthedocs.io/en/stable/anndata.AnnData.html#anndata.AnnData) objects
 52 | * a mixture of the above
 53 | 
 54 | `batch_categories` is optional; it names each batch and defaults to the range 0 to N-1 if not provided.
 55 | 
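A minimal end-to-end sketch of the API usage (the h5ad file names below are hypothetical placeholders, and the downstream scanpy steps assume the latent embedding returned by SCALEX is stored in `adata.obsm`, e.g. under the key `latent`):

    import scanpy as sc
    from scalex import SCALEX

    # integrate two batches stored as separate h5ad files (paths are placeholders)
    adata = SCALEX(
        data_list=['batch1.h5ad', 'batch2.h5ad'],
        batch_categories=['batch1', 'batch2'],   # optional batch names
    )

    # downstream analysis with scanpy on the integrated AnnData
    sc.pp.neighbors(adata, use_rep='latent')     # use the SCALEX latent space instead of PCA
    sc.tl.umap(adata)
    sc.pl.umap(adata, color='batch')             # color by any column available in adata.obs
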
 56 | ### 2. Command line
 57 | #### Standard usage
 58 | 
 59 | 
 60 |     SCALEX --data_list data1 data2 dataN --batch_categories batch_name1 batch_name2 batch_nameN
 61 | 
 62 | 
 63 | `--data_list`: data path of each batch of the single-cell dataset; use `-d` for short
 64 | 
 65 | `--batch_categories`: name of each batch; batch_categories will range from 0 to N-1 if not specified
 66 | 
 67 | 
 68 | #### Output
 69 | Output will be saved in the output folder, including:
 70 | * **checkpoint**: saved model to reproduce results, used together with the option --checkpoint or -c
 71 | * **[adata.h5ad](https://anndata.readthedocs.io/en/stable/anndata.AnnData.html#anndata.AnnData)**: preprocessed data and results, including latent representation, clustering and imputation
 72 | * **umap.png**: UMAP visualization of the latent representations of cells
 73 | * **log.txt**: log file of the training process
 74 | 
 75 | ### Other Common Usage
 76 | #### Use h5ad files storing `anndata` as input, one or multiple separate files
 77 | 
 78 |     SCALEX --data_list <file1.h5ad> <file2.h5ad>
 79 | 
 80 | #### Specify the batch in `anndata.obs` using `--batch_name` if only one concatenated h5ad file is provided; batch_name can be e.g. conditions, samples, assays or patients, default is `batch`
 81 | 
 82 |     SCALEX --data_list <concatenated.h5ad> --batch_name <column in anndata.obs>
 83 | 
 84 | 
 85 | #### Integrate heterogeneous scATAC-seq datasets: add the option `--profile ATAC`
 86 | 
 87 |     SCALEX --data_list <data> --profile ATAC
 88 | 
 89 | #### Imputation along with integration: add the option `--impute`; results are stored at anndata.layers['impute']
 90 | 
 91 |     SCALEX --data_list <data> --profile ATAC --impute True
 92 | 
 93 | 
 94 | #### Custom features through `--n_top_features`, a file containing features in a one-column format
 95 | 
 96 |     SCALEX --data_list <data> --n_top_features features.txt
 97 | 
 98 | #### Use preprocessed data with `--processed`
 99 | 
100 |     SCALEX --data_list <data> --processed
101 | 
102 | #### Option
103 | 
104 | * --**data_list**  
105 |     A list of matrix files (each as a `batch`) or a single batch/batch-merged file.
106 | * --**batch_categories**  
107 |     Categories for the batch annotation. By default, increasing numbers are used if not given.
108 | * --**batch_name**  
109 |     Use this annotation in anndata.obs as batches for training the model. Default: 'batch'.
110 | * --**profile**  
111 |     Specify the single-cell profile, RNA or ATAC. Default: RNA.
112 | * --**min_features**  
113 |     Filter out cells with fewer than min_features detected features. Default: 600 for RNA, 100 for ATAC.
114 | * --**min_cells**  
115 |     Filter out genes detected in fewer than min_cells cells. Default: 3.
116 | * --**n_top_features**  
117 |     Number of highly-variable genes to keep. Default: 2000 for RNA, 30000 for ATAC.
118 | * --**outdir**  
119 |     Output directory. Default: 'output/'.
120 | * --**projection**  
121 |     Use for new dataset projection. Input the folder containing the pre-trained model. Default: None.
122 | * --**impute**  
123 |     If True, calculate the imputed gene expression and store it at adata.layers['impute']. Default: False.
124 | * --**chunk_size**  
125 |     Number of samples from the same batch to transform. Default: 20000.
126 | * --**ignore_umap**  
127 |     If True, do not perform UMAP for visualization and leiden clustering. Default: False.
128 | * --**join**  
129 |     Use intersection ('inner') or union ('outer') of variables of different batches.
130 | * --**batch_key**  
131 |     Add the batch annotation to obs using this key. By default, batch_key='batch'.
132 | * --**batch_size**  
133 |     Number of samples per batch to load. Default: 64.
134 | * --**lr**  
135 |     Learning rate. Default: 2e-4.
136 | * --**max_iteration**  
137 |     Max number of iterations for training; one iteration trains on one batch of batch_size samples. Default: 30000.
138 | * --**seed**  
139 |     Random seed for torch and numpy. Default: 124.
140 | * --**gpu**  
141 |     Index of the GPU to use if a GPU is available. Default: 0.
142 | * --**verbose**  
143 |     Verbosity, True or False. Default: False.
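
For example, to project a new query dataset onto a previously trained reference, a minimal sketch (here `<query.h5ad>` and `<reference_outdir>` are hypothetical placeholders, where `<reference_outdir>` is the output folder of a previous SCALEX run containing the saved checkpoint) is:

    SCALEX --data_list <query.h5ad> --projection <reference_outdir> --outdir projection_output/
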
144 | 
145 | 
146 | 
147 | 
148 | #### Help
149 | See more usage options of SCALEX with
150 | 
151 |     SCALEX --help
152 | 
153 | 
154 | ## Release notes
155 | 
156 | See the [changelog](https://github.com/jsxlei/SCALEX/blob/main/CHANGELOG.md).
157 | 
158 | 
159 | ## Citation
160 | 
161 | Xiong, L., Tian, K., Li, Y., Ning, W., Gao, X., & Zhang, Q. C. (2022). Online single-cell data integration through projecting heterogeneous datasets into a common cell-embedding space. Nature Communications, 13(1), 6118. https://doi.org/10.1038/s41467-022-33758-z
162 | 
163 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
 1 | # Minimal makefile for Sphinx documentation
 2 | #
 3 | 
 4 | # You can set these variables from the command line, and also
 5 | # from the environment for the first two.
 6 | SPHINXOPTS    ?=
 7 | SPHINXBUILD   ?= sphinx-build
 8 | SOURCEDIR     = source
 9 | BUILDDIR      = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
 1 | @ECHO OFF
 2 | 
 3 | pushd %~dp0
 4 | 
 5 | REM Command file for Sphinx documentation
 6 | 
 7 | if "%SPHINXBUILD%" == "" (
 8 | 	set SPHINXBUILD=sphinx-build
 9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | 
13 | if "%1" == "" goto help
14 | 
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | 	echo.
18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | 	echo.installed, then set the SPHINXBUILD environment variable to point
20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | 	echo.may add the Sphinx directory to PATH.
22 | 	echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/.DS_Store -------------------------------------------------------------------------------- /docs/source/_static/img/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/.DS_Store -------------------------------------------------------------------------------- /docs/source/_static/img/cross-modality_integration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/cross-modality_integration.png -------------------------------------------------------------------------------- /docs/source/_static/img/logo_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/logo_white.png -------------------------------------------------------------------------------- /docs/source/_static/img/mouse_atlas_reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/mouse_atlas_reference.png -------------------------------------------------------------------------------- /docs/source/_static/img/pancreas_projection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/pancreas_projection.png -------------------------------------------------------------------------------- /docs/source/_static/img/pbmc_before_integration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/pbmc_before_integration.png -------------------------------------------------------------------------------- /docs/source/_static/img/pbmc_integration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/pbmc_integration.png -------------------------------------------------------------------------------- /docs/source/_static/img/query_kidney.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/query_kidney.png -------------------------------------------------------------------------------- /docs/source/_static/img/query_tabula_muris_aging.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/query_tabula_muris_aging.png -------------------------------------------------------------------------------- /docs/source/_static/img/scATAC_integration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/scATAC_integration.png -------------------------------------------------------------------------------- /docs/source/_static/img/scalex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/docs/source/_static/img/scalex.jpg -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: scalex 2 | .. automodule:: scalex 3 | :noindex: 4 | 5 | 6 | API 7 | ==== 8 | 9 | 10 | Import SCALEX:: 11 | 12 | import scalex 13 | 14 | 15 | Function 16 | -------- 17 | .. module:: scalex 18 | .. currentmodule:: scalex 19 | 20 | .. autosummary:: 21 | :toctree: . 22 | 23 | SCALEX 24 | label_transfer 25 | 26 | 27 | Data 28 | ------------------- 29 | .. module:: scalex.data 30 | .. currentmodule:: scalex 31 | 32 | 33 | Load data 34 | ~~~~~~~~~~~~~~~~~~~ 35 | .. autosummary:: 36 | :toctree: . 37 | 38 | data.load_data 39 | data.concat_data 40 | data.load_files 41 | data.load_file 42 | data.read_mtx 43 | 44 | 45 | Preprocessing 46 | ~~~~~~~~~~~~~ 47 | .. autosummary:: 48 | :toctree: . 49 | 50 | data.preprocessing 51 | data.preprocessing_rna 52 | data.preprocessing_atac 53 | data.batch_scale 54 | data.reindex 55 | 56 | 57 | DataLoader 58 | ~~~~~~~~~~ 59 | .. autosummary:: 60 | :toctree: . 61 | 62 | data.SingleCellDataset 63 | data.BatchSampler 64 | 65 | 66 | 67 | Net 68 | ----------- 69 | .. module:: scalex.net 70 | .. currentmodule:: scalex 71 | 72 | 73 | Model 74 | ~~~~~ 75 | .. autosummary:: 76 | :toctree: . 77 | 78 | net.vae.VAE 79 | 80 | 81 | Layer 82 | ~~~~~ 83 | .. autosummary:: 84 | :toctree: . 85 | 86 | net.layer.DSBatchNorm 87 | net.layer.Block 88 | net.layer.NN 89 | net.layer.Encoder 90 | 91 | 92 | Loss 93 | ~~~~ 94 | .. autosummary:: 95 | :toctree: . 96 | 97 | net.loss.kl_div 98 | net.loss.binary_cross_entropy 99 | 100 | 101 | Utils 102 | ~~~~~ 103 | .. autosummary:: 104 | :toctree: . 105 | 106 | net.utils.onehot 107 | net.utils.EarlyStopping 108 | 109 | 110 | 111 | Plot 112 | ------------ 113 | .. module:: scalex.plot 114 | .. currentmodule:: scalex 115 | 116 | 117 | .. autosummary:: 118 | :toctree: . 119 | 120 | plot.embedding 121 | plot.plot_meta 122 | plot.plot_meta2 123 | plot.plot_confusion 124 | 125 | 126 | Metric 127 | ---------------- 128 | .. module:: scalex.metric 129 | .. currentmodule:: scalex 130 | 131 | 132 | Collections of useful measurements for evaluating results. 133 | 134 | .. autosummary:: 135 | :toctree: . 136 | 137 | metrics.batch_entropy_mixing_score 138 | metrics.silhouette_score 139 | 140 | 141 | Logger 142 | ------ 143 | .. module:: scalex.logger 144 | .. currentmodule:: scalex 145 | 146 | .. autosummary:: 147 | :toctree: . 
148 | 149 | logger.create_logger 150 | 151 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | from datetime import datetime 16 | 17 | sys.path.insert(0, os.path.abspath(__file__+'../../../..')) 18 | 19 | 20 | import scalex 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'SCALEX' 25 | author = scalex.__author__ 26 | copyright = f'{datetime.now():%Y}, {author}.' 27 | 28 | 29 | # The full version, including alpha/beta/rc tags 30 | release = scalex.__version__ 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # Add any Sphinx extension module names here, as strings. They can be 36 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 37 | # ones. 38 | 39 | nitpicky = True # Warn about broken links. This is here for a reason: Do not change. 40 | needs_sphinx = '2.0' # Nicer param docs 41 | 42 | extensions = [ 43 | 'sphinx.ext.autodoc', 44 | 'sphinx.ext.intersphinx', 45 | 'sphinx.ext.doctest', 46 | 'sphinx.ext.coverage', 47 | 'sphinx.ext.mathjax', 48 | 'sphinx.ext.napoleon', 49 | 'sphinx.ext.autosummary', 50 | 'sphinx_autodoc_typehints', 51 | 'nbsphinx' 52 | ] 53 | 54 | # Generate the API documentation when building 55 | autosummary_generate = True 56 | autodoc_member_order = 'bysource' 57 | 58 | napoleon_google_docstring = False 59 | napoleon_numpy_docstring = True 60 | napoleon_include_init_with_doc = False 61 | napoleon_use_rtype = True # having a separate entry generally helps readability 62 | napoleon_use_param = True 63 | napoleon_custom_sections = [('Params', 'Parameters')] 64 | todo_include_todos = False 65 | 66 | 67 | # Add any paths that contain templates here, relative to this directory. 68 | templates_path = ['_templates'] 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This pattern also affects html_static_path and html_extra_path. 73 | exclude_patterns = [] 74 | 75 | 76 | # -- Options for HTML output ------------------------------------------------- 77 | 78 | # The theme to use for HTML and HTML Help pages. See the documentation for 79 | # a list of builtin themes. 
80 | # 81 | html_theme = 'sphinx_book_theme' 82 | 83 | html_theme_options = dict(navigation_depth=4, logo_only=True) # Only show the logo 84 | html_context = dict( 85 | display_github=True, # Integrate GitHub 86 | github_user='jsxlei', # Username 87 | github_repo='SCALEX', # Repo name 88 | github_version='main', # Version 89 | conf_py_path='/docs/', # Path in the checkout to the docs root 90 | ) 91 | html_static_path = ['_static'] 92 | html_show_sphinx = False 93 | html_logo = '_static/img/logo_white.png' 94 | 95 | # Add any paths that contain custom static files (such as style sheets) here, 96 | # relative to this directory. They are copied after the builtin static files, 97 | # so a file named "default.css" will overwrite the builtin "default.css". 98 | html_static_path = ['_static'] 99 | -------------------------------------------------------------------------------- /docs/source/contributors.rst: -------------------------------------------------------------------------------- 1 | .. sidebar:: Contributors 2 | 3 | * `Lei Xiong`_: Leader Developer 4 | * `Kang Tian`_: Developer 5 | * `Yuzhe Li`_: Developer 6 | 7 | .. _Lei Xiong: http://xiong-lei.com/ 8 | .. _Kang Tian: 9 | .. _Yuzhe Li: 10 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. scalex documentation master file, created by 2 | sphinx-quickstart on Sun Mar 7 15:52:19 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | |Stars| |PyPI| |PyPIDownloads| |Docs| 7 | 8 | .. |Stars| image:: https://img.shields.io/github/stars/jsxlei/SCALEX?logo=GitHub&color=yellow 9 | :target: https://github.com/jsxlei/SCALEX/stargazers 10 | .. |PyPI| image:: https://img.shields.io/pypi/v/scalex?logo=PyPI 11 | :target: https://pypi.org/project/scalex 12 | .. |PyPIDownloads| image:: https://pepy.tech/badge/scalex 13 | :target: https://pepy.tech/project/scalex 14 | .. |Docs| image:: https://readthedocs.com/projects/scalex/badge/?version=latest 15 | :target: https://scalex.readthedocs.io 16 | 17 | 18 | Online single-cell data integration through projecting heterogeneous datasets into a common cell-embedding space 19 | ---------------- 20 | 21 | .. include:: contributors.rst 22 | 23 | .. role:: small 24 | .. role:: smaller 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | :hidden: 29 | 30 | tutorial/index 31 | installation 32 | usage 33 | api/index 34 | news 35 | release/index 36 | 37 | 38 | News 39 | ---- 40 | 41 | .. include:: news.rst 42 | :start-line: 2 43 | :end-line: 22 44 | 45 | 46 | 47 | Indices and tables 48 | ================== 49 | 50 | * :ref:`genindex` 51 | * :ref:`modindex` 52 | * :ref:`search` 53 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ------------ 3 | 4 | PyPI install 5 | ~~~~~~~~~~~~ 6 | 7 | Pull SCALE from `PyPI `__ (consider using ``pip3`` to access Python 3):: 8 | 9 | pip install scalex 10 | 11 | .. _from PyPI: https://pypi.org/project/scalex 12 | 13 | 14 | Pytorch 15 | ~~~~~~~ 16 | If you have cuda devices, consider install Pytorch_ cuda version.:: 17 | 18 | conda install pytorch torchvision torchaudio -c pytorch 19 | 20 | .. 
_Pytorch: https://pytorch.org/ 21 | 22 | Troubleshooting 23 | ~~~~~~~~~~~~~~~ 24 | 25 | 26 | Anaconda 27 | ~~~~~~~~ 28 | If you do not have a working installation of Python 3.6 (or later), consider 29 | installing Miniconda_ (see `Installing Miniconda`_). 30 | 31 | 32 | Installing Miniconda 33 | ~~~~~~~~~~~~~~~~~~~~ 34 | After downloading Miniconda_, in a unix shell (Linux, Mac), run 35 | 36 | .. code:: shell 37 | 38 | cd DOWNLOAD_DIR 39 | chmod +x Miniconda3-latest-VERSION.sh 40 | ./Miniconda3-latest-VERSION.sh 41 | 42 | and accept all suggestions. 43 | Either reopen a new terminal or `source ~/.bashrc` on Linux/ `source ~/.bash_profile` on Mac. 44 | The whole process takes just a couple of minutes. 45 | 46 | .. _Miniconda: http://conda.pydata.org/miniconda.html 47 | 48 | -------------------------------------------------------------------------------- /docs/source/news.rst: -------------------------------------------------------------------------------- 1 | News 2 | ===== 3 | .. role:: small 4 | 5 | SCALEX is online on `Nature Communications `_ :small:`2022-10-17` 6 | -------------------------------------------------------------------------------- /docs/source/release/1.0.0.rst: -------------------------------------------------------------------------------- 1 | 1.0.0 :small:`2022-08-29` 2 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 3 | 4 | Online single-cell data integration through projecting heterogeneous datasets into a common cell-embedding space 5 | 6 | -------------------------------------------------------------------------------- /docs/source/release/index.rst: -------------------------------------------------------------------------------- 1 | Release 2 | ======= 3 | 4 | 5 | 6 | Version 1.0 7 | ----------- 8 | 9 | .. include:: 1.0.0.rst -------------------------------------------------------------------------------- /docs/source/tutorial/index.rst: -------------------------------------------------------------------------------- 1 | Tutorial 2 | ========= 3 | 4 | `Integration <../tutorial/Integration_PBMC.ipynb>`_ 5 | -------------------- 6 | 7 | before integration 8 | 9 | .. image:: ../_static/img/pbmc_before_integration.png 10 | :width: 500px 11 | 12 | after SCALEX integration 13 | 14 | .. image:: ../_static/img/pbmc_integration.png 15 | :width: 600px 16 | 17 | 18 | `Projection <../tutorial/Projection_pancreas.ipynb>`_ 19 | ------------- 20 | 21 | Map new data to the embeddings of reference 22 | 23 | A pancreas reference was created by integrating eight batches. 24 | 25 | Here, map pancreas_gse81547, pancreas_gse83139 and pancreas_gse114297 to the embeddings of pancreas reference. 26 | 27 | .. image:: ../_static/img/pancreas_projection.png 28 | :width: 600px 29 | 30 | 31 | `Label transfer `_ 32 | --------------- 33 | Annotate cells in new data through label transfer 34 | 35 | Label transfer tabula muris data and mouse kidney data from mouse atlas reference 36 | 37 | mouse atlas reference 38 | 39 | .. image:: ../_static/img/mouse_atlas_reference.png 40 | :width: 400px 41 | 42 | query tabula muris aging and query mouse kidney 43 | 44 | .. image:: ../_static/img/query_tabula_muris_aging.png 45 | :width: 300px 46 | 47 | .. image:: ../_static/img/query_kidney.png 48 | :width: 300px 49 | 50 | 51 | `Integration scATAC-seq data `_ 52 | --------------- 53 | 54 | .. 
image:: ../_static/img/scATAC_integration.png 55 | :width: 600px 56 | 57 | 58 | `Integration cross-modality data <../tutorial/Integration_cross-modality.ipynb>`_ 59 | ------------------------------- 60 | Integrate scRNA-seq and scATAC-seq dataset 61 | 62 | .. image:: ../_static/img/cross-modality_integration.png 63 | :width: 600px 64 | 65 | 66 | Spatial data (To be updated) 67 | ------------ 68 | Integrating spatial data with scRNA-seq 69 | 70 | 71 | Examples 72 | -------- 73 | 74 | .. toctree:: 75 | :maxdepth: 2 76 | :hidden: 77 | 78 | Integration_PBMC 79 | Projection_pancreas 80 | Integration_scATAC-seq 81 | Integration_cross-modality 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /docs/source/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ---------------- 3 | 4 | SCALEX provide both commanline tool and api function used in jupyter notebook 5 | 6 | Command line 7 | ^^^^^^^^^^^^ 8 | Run SCALEX after installation:: 9 | 10 | SCALEX.py --data_list data1 data2 --batch_categories batch_name1 batch_name2 11 | 12 | ``data_list``: data path of each batch of single-cell dataset 13 | 14 | ``batch_categories``: name of each batch, batch_categories will range from 0 to N if not specified 15 | 16 | Input 17 | ~~~~~ 18 | Input can be one of following: 19 | 20 | * single file of format h5ad, csv, txt, mtx or their compression file 21 | * multiple files of above format 22 | 23 | .. note:: h5ad file input 24 | * SCALEX will use the ``batch`` column in the obs of adata format read from h5ad file as batch information 25 | * Users can specify any columns in the obs with option: ``--batch_name`` name 26 | * If multiple inputs are given, SCALEX can take each file as individual batch by default, and overload previous batch information, users can change the concat name via option ``--batch_key`` other_name 27 | 28 | Output 29 | ~~~~~~~~~~~ 30 | Output will be saved in the output folder including: 31 | 32 | * **checkpoint**: saved model to reproduce results cooperated with option ``--checkpoint`` or -c 33 | * **adata.h5ad**: preprocessed data and results including, latent, clustering and imputation 34 | * **umap.png**: UMAP visualization of latent representations of cells 35 | * **log.txt**: log file of training process 36 | 37 | 38 | Useful options 39 | ~~~~~~~~~~~~~~ 40 | * output folder for saveing results: [-o] or [--outdir] 41 | * filter rare genes, default 3: [--min_cell] 42 | * filter low quality cells, default 600: [--min_gene] 43 | * select the number of highly variable genes, keep all genes with -1, default 2000: [--n_top_genes] 44 | 45 | 46 | Help 47 | ~~~~ 48 | Look for more usage of SCALEX:: 49 | 50 | SCALEX.py --help 51 | 52 | 53 | API function 54 | ^^^^^^^^^^^^ 55 | Use SCALEX in jupyter notebook:: 56 | 57 | from scalex.function import SCALEX 58 | adata = SCALEX(data_list, batch_categories) 59 | 60 | or 61 | adata = SCALEX([adata_1, adata_2]) 62 | 63 | Function of parameters are similar to command line options. 64 | Input can be the files of adata or a list of AnnData or one concatenated AnnData 65 | Output is a Anndata object for further analysis with scanpy. 66 | 67 | 68 | 69 | 70 | AnnData 71 | ^^^^^^^ 72 | SCALEX supports :mod:`scanpy` and :mod:`anndata`, which provides the :class:`~anndata.AnnData` class. 73 | 74 | .. 
image:: http://falexwolf.de/img/scanpy/anndata.svg 75 | :width: 300px 76 | 77 | At the most basic level, an :class:`~anndata.AnnData` object `adata` stores 78 | a data matrix `adata.X`, annotation of observations 79 | `adata.obs` and variables `adata.var` as `pd.DataFrame` and unstructured 80 | annotation `adata.uns` as `dict`. Names of observations and 81 | variables can be accessed via `adata.obs_names` and `adata.var_names`, 82 | respectively. :class:`~anndata.AnnData` objects can be sliced like 83 | dataframes, for example, `adata_subset = adata[:, list_of_gene_names]`. 84 | For more, see this `blog post`_. 85 | 86 | .. _blog post: http://falexwolf.de/blog/171223_AnnData_indexing_views_HDF5-backing/ 87 | 88 | To read a data file to an :class:`~anndata.AnnData` object, call:: 89 | 90 | import scanpy as sc 91 | adata = sc.read(filename) 92 | 93 | to initialize an :class:`~anndata.AnnData` object. Possibly add further annotation using, e.g., `pd.read_csv`:: 94 | 95 | import pandas as pd 96 | anno = pd.read_csv(filename_sample_annotation) 97 | adata.obs['cell_groups'] = anno['cell_groups'] # categorical annotation of type pandas.Categorical 98 | adata.obs['time'] = anno['time'] # numerical annotation of type float 99 | # alternatively, you could also set the whole dataframe 100 | # adata.obs = anno 101 | 102 | To write, use:: 103 | 104 | adata.write(filename) 105 | adata.write_csvs(filename) 106 | adata.write_loom(filename) 107 | 108 | 109 | .. _Seaborn: http://seaborn.pydata.org/ 110 | .. _matplotlib: http://matplotlib.org/ -------------------------------------------------------------------------------- /experiments/LISI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "library(lisi)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 8, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "[1] \"SCALEX\"\n", 22 | "[1] \"LIGER\"\n", 23 | "[1] \"online_iNMF\"\n", 24 | "[1] \"Harmony\"\n", 25 | "[1] \"Seurat_v3\"\n", 26 | "[1] \"Conos\"\n", 27 | "[1] \"FastMNN\"\n", 28 | "[1] \"Raw\"\n", 29 | "[1] \"Scanorama\"\n", 30 | "[1] \"BBKNN\"\n", 31 | "[1] \"scVI\"\n", 32 | "[1] \"SCALEX\"\n", 33 | "[1] \"LIGER\"\n", 34 | "[1] \"online_iNMF\"\n", 35 | "[1] \"Harmony\"\n", 36 | "[1] \"Seurat_v3\"\n", 37 | "[1] \"Conos\"\n", 38 | "[1] \"FastMNN\"\n", 39 | "[1] \"Raw\"\n", 40 | "[1] \"Scanorama\"\n", 41 | "[1] \"BBKNN\"\n", 42 | "[1] \"scVI\"\n", 43 | "[1] \"SCALEX\"\n", 44 | "[1] \"LIGER\"\n", 45 | "[1] \"online_iNMF\"\n", 46 | "[1] \"Harmony\"\n", 47 | "[1] \"Seurat_v3\"\n", 48 | "[1] \"Conos\"\n", 49 | "[1] \"FastMNN\"\n", 50 | "[1] \"Raw\"\n", 51 | "[1] \"Scanorama\"\n", 52 | "[1] \"BBKNN\"\n", 53 | "[1] \"scVI\"\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "for(dataset in c('pancreas','PBMC','liver','heart','NSCLC'))\n", 59 | "{\n", 60 | " items <- c('SCALEX','LIGER','online_iNMF','Harmony','Seurat_v3','Conos','FastMNN', 'Raw', 'Scanorama', 'BBKNN', 'scVI')\n", 61 | " path <- '~/SCALEX/notebook/benchmark/LISI/data/'\n", 62 | " batch_id <- read.table(paste(path,dataset,'/','batch.txt',sep=''),sep='\\t',header = F)\n", 63 | " celltype <- read.table(paste(path,dataset,'/','celltype.txt',sep=''),sep='\\t',header = F)\n", 64 | " metadata=cbind(batch_id,celltype)\n", 65 | " colnames(metadata)=c('batch','celltype')\n", 66 | "\n", 67 | " 
lisi_res <- list()\n", 68 | " for(item in items){\n", 69 | " if(file.exists(paste(path,dataset,'/',item,'.txt',sep=''))){\n", 70 | " print(item)\n", 71 | " umap <- read.table(paste(path,dataset,'/',item,'.txt',sep=''))\n", 72 | " result <- lisi::compute_lisi(umap, metadata, c('batch', 'celltype'))\n", 73 | " lisi_res <- append(lisi_res, data.frame(result))\n", 74 | " }\n", 75 | " }\n", 76 | " lisi_res <- Reduce(cbind,lisi_res)\n", 77 | " lisi_res = round(lisi_res,5)\n", 78 | "\n", 79 | " colnames(lisi_res) <- c('SCALEX_batch','SCALEX_celltype','LIGER_batch','LIGER_celltype','online_iNMF_batch','online_iNMF_celltype','Harmony_batch','Harmony_celltype',\n", 80 | " 'Seurat_v3_batch','Seurat_v3_celltype','Conos_batch','Conos_celltype',\n", 81 | " 'FastMNN_batch','FastMNN_celltype','Raw_batch','Raw_celltype','Scanorama_batch','Scanorama_celltype',\n", 82 | " 'BBKNN_batch','BBKNN_celltype','scVI_batch','scVI_celltype')\n", 83 | "\n", 84 | " write.table(lisi_res,paste(path,dataset,'/','lisi_res.txt',sep=''),sep='\\t',quote=F)\n", 85 | " }\n" 86 | ] 87 | } 88 | ], 89 | "metadata": { 90 | "kernelspec": { 91 | "display_name": "R", 92 | "language": "R", 93 | "name": "ir" 94 | }, 95 | "language_info": { 96 | "codemirror_mode": "r", 97 | "file_extension": ".r", 98 | "mimetype": "text/x-r-source", 99 | "name": "R", 100 | "pygments_lexer": "r", 101 | "version": "4.2.0" 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 4 106 | } 107 | -------------------------------------------------------------------------------- /experiments/dirichlet_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# install.packages('DirichletReg')" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 26, 15 | "metadata": { 16 | "heading_collapsed": "false" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "library(Seurat)\n", 21 | "library(RColorBrewer) #for brewer.pal\n", 22 | "library(Matrix) #for Matrix\n", 23 | "library(DirichletReg)\n", 24 | "library(data.table)\n", 25 | "library(tidyverse)\n", 26 | "library(cowplot)\n", 27 | "\n", 28 | "## this function is extracted from analysis.r \n", 29 | "dirichlet_regression = function(counts, covariates, formula){ \n", 30 | " # Dirichlet multinomial regression to detect changes in cell frequencies\n", 31 | " # formula is not quoted, example: counts ~ condition\n", 32 | " # counts is a [samples x cell types] matrix\n", 33 | " # covariates holds additional data to use in the regression\n", 34 | " #\n", 35 | " # Example:\n", 36 | " # counts = do.call(cbind, tapply(seur@data.info$orig.ident, seur@ident, table))\n", 37 | " # covariates = data.frame(condition=gsub('[12].*', '', rownames(counts)))\n", 38 | " # res = dirichlet_regression(counts, covariates, counts ~ condition)\n", 39 | " \n", 40 | " #ep.pvals = dirichlet_regression(counts=ep.freq, covariates=ep.cov, formula=counts ~ condition)$pvals\n", 41 | "\n", 42 | " # Calculate regression\n", 43 | " counts = as.data.frame(counts)\n", 44 | " counts$counts = DR_data(counts)\n", 45 | " data = cbind(counts, covariates)\n", 46 | " fit = DirichReg(counts ~ condition, data) \n", 47 | " \n", 48 | " # Get p-values\n", 49 | " u = summary(fit)\n", 50 | " #compared with healthy condition, 15 vars. 
noninflame and inflame, 30pvalues\n", 51 | " pvals = u$coef.mat[grep('Intercept', rownames(u$coef.mat), invert=T), 4] \n", 52 | " v = names(pvals)\n", 53 | " pvals = matrix(pvals, ncol=length(u$varnames))\n", 54 | " rownames(pvals) = gsub('condition', '', v[1:nrow(pvals)])\n", 55 | " colnames(pvals) = u$varnames\n", 56 | " fit$pvals = pvals\n", 57 | " \n", 58 | " fit\n", 59 | "}" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 15, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "freq=read.csv('celltype.csv',row.names=1)\n", 69 | "freq1=as.matrix(as.data.frame(lapply(freq, as.double),row.names=row.names(freq)))\n", 70 | "cov=read.csv('conv.csv',row.names=1)\n", 71 | "cov1 = data.frame(condition=factor(cov[rownames(freq),1], levels=c('healthy control','mild(moderate)','severe','convalescence','influenza')), \n", 72 | " row.names=rownames(freq))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 16, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stderr", 82 | "output_type": "stream", 83 | "text": [ 84 | "Warning message in DR_data(counts):\n", 85 | "“not all rows sum up to 1 => normalization forced\n", 86 | " some entries are 0 or 1 => transformation forced”\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "pvals = dirichlet_regression(counts=freq1, covariates=cov1, formula=counts ~ condition)$pvals\n", 92 | "colnames(pvals) = colnames(freq1)\n", 93 | "# write.csv(pvals,'healthy(control).csv')" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 21, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "Warning message in DR_data(counts):\n", 106 | "“not all rows sum up to 1 => normalization forced\n", 107 | " some entries are 0 or 1 => transformation forced”\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "cov=read.csv('conv.csv',row.names=1)\n", 113 | "cov1 = data.frame(condition=factor(cov[rownames(freq),1], levels=c('mild(moderate)','healthy control','severe','convalescence','influenza')), \n", 114 | " row.names=rownames(freq))\n", 115 | "pvals = dirichlet_regression(counts=freq1, covariates=cov1, formula=counts ~ condition)$pvals\n", 116 | "colnames(pvals) = colnames(freq1)\n", 117 | "# write.csv(pvals,'mild(control).csv')" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 2, 123 | "metadata": { 124 | "heading_collapsed": "false" 125 | }, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "R version 4.0.2 (2020-06-22)\n", 131 | "Platform: x86_64-pc-linux-gnu (64-bit)\n", 132 | "Running under: Ubuntu 16.04.6 LTS\n", 133 | "\n", 134 | "Matrix products: default\n", 135 | "BLAS: /usr/lib/libblas/libblas.so.3.6.0\n", 136 | "LAPACK: /usr/lib/lapack/liblapack.so.3.6.0\n", 137 | "\n", 138 | "locale:\n", 139 | " [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C \n", 140 | " [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 \n", 141 | " [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 \n", 142 | " [7] LC_PAPER=en_US.UTF-8 LC_NAME=C \n", 143 | " [9] LC_ADDRESS=C LC_TELEPHONE=C \n", 144 | "[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C \n", 145 | "\n", 146 | "attached base packages:\n", 147 | "[1] stats graphics grDevices utils datasets methods base \n", 148 | "\n", 149 | "other attached packages:\n", 150 | " [1] cowplot_1.1.1 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.4 \n", 151 | " [5] purrr_0.3.4 readr_1.4.0 tidyr_1.1.2 tibble_3.0.6 \n", 152 | " [9] ggplot2_3.3.3 tidyverse_1.3.0 data.table_1.13.6 
DirichletReg_0.7-0\n", 153 | "[13] Formula_1.2-4 Matrix_1.2-18 RColorBrewer_1.1-2 SeuratObject_4.0.0\n", 154 | "[17] Seurat_4.0.0 \n", 155 | "\n", 156 | "loaded via a namespace (and not attached):\n", 157 | " [1] Rtsne_0.15 colorspace_2.0-0 deldir_0.2-10 \n", 158 | " [4] ellipsis_0.3.1 ggridges_0.5.3 IRdisplay_1.0 \n", 159 | " [7] fs_1.5.0 base64enc_0.1-3 rstudioapi_0.13 \n", 160 | " [10] spatstat.data_2.0-0 leiden_0.3.7 listenv_0.8.0 \n", 161 | " [13] ggrepel_0.9.1 lubridate_1.7.9.2 xml2_1.3.2 \n", 162 | " [16] codetools_0.2-16 splines_4.0.2 polyclip_1.10-0 \n", 163 | " [19] IRkernel_1.1.1 jsonlite_1.7.2 broom_0.7.5 \n", 164 | " [22] ica_1.0-2 dbplyr_2.1.0 cluster_2.1.0 \n", 165 | " [25] png_0.1-7 uwot_0.1.10 shiny_1.6.0 \n", 166 | " [28] sctransform_0.3.2 compiler_4.0.2 httr_1.4.2 \n", 167 | " [31] backports_1.2.1 assertthat_0.2.1 fastmap_1.1.0 \n", 168 | " [34] lazyeval_0.2.2 cli_2.3.0 later_1.1.0.1 \n", 169 | " [37] htmltools_0.5.1.1 tools_4.0.2 igraph_1.2.6 \n", 170 | " [40] gtable_0.3.0 glue_1.4.2 RANN_2.6.1 \n", 171 | " [43] reshape2_1.4.4 Rcpp_1.0.6 spatstat_1.64-1 \n", 172 | " [46] scattermore_0.7 cellranger_1.1.0 vctrs_0.3.6 \n", 173 | " [49] nlme_3.1-149 lmtest_0.9-38 ps_1.5.0 \n", 174 | " [52] globals_0.14.0 rvest_0.3.6 mime_0.10 \n", 175 | " [55] miniUI_0.1.1.1 lifecycle_1.0.0 irlba_2.3.3 \n", 176 | " [58] goftest_1.2-2 future_1.21.0 MASS_7.3-53 \n", 177 | " [61] zoo_1.8-8 scales_1.1.1 hms_1.0.0 \n", 178 | " [64] miscTools_0.6-26 promises_1.2.0.1 spatstat.utils_2.0-0\n", 179 | " [67] parallel_4.0.2 sandwich_3.0-0 reticulate_1.18 \n", 180 | " [70] pbapply_1.4-3 gridExtra_2.3 rpart_4.1-15 \n", 181 | " [73] stringi_1.5.3 repr_1.1.3 rlang_0.4.10 \n", 182 | " [76] pkgconfig_2.0.3 matrixStats_0.58.0 evaluate_0.14 \n", 183 | " [79] lattice_0.20-41 ROCR_1.0-11 tensor_1.5 \n", 184 | " [82] patchwork_1.1.1 htmlwidgets_1.5.3 tidyselect_1.1.0 \n", 185 | " [85] parallelly_1.23.0 RcppAnnoy_0.0.18 plyr_1.8.6 \n", 186 | " [88] magrittr_2.0.1 R6_2.5.0 generics_0.1.0 \n", 187 | " [91] pbdZMQ_0.3-5 DBI_1.1.1 withr_2.4.1 \n", 188 | " [94] haven_2.3.1 pillar_1.4.7 mgcv_1.8-33 \n", 189 | " [97] fitdistrplus_1.1-3 survival_3.2-3 abind_1.4-5 \n", 190 | "[100] future.apply_1.7.0 modelr_0.1.8 crayon_1.4.1 \n", 191 | "[103] uuid_0.1-4 KernSmooth_2.23-17 plotly_4.9.3 \n", 192 | "[106] maxLik_1.4-6 readxl_1.3.1 grid_4.0.2 \n", 193 | "[109] reprex_1.0.0 digest_0.6.27 xtable_1.8-4 \n", 194 | "[112] httpuv_1.5.5 munsell_0.5.0 viridisLite_0.3.0 " 195 | ] 196 | }, 197 | "metadata": {}, 198 | "output_type": "display_data" 199 | } 200 | ], 201 | "source": [ 202 | "sessionInfo()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [] 211 | } 212 | ], 213 | "metadata": { 214 | "kernelspec": { 215 | "display_name": "R", 216 | "language": "R", 217 | "name": "ir" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": "r", 221 | "file_extension": ".r", 222 | "mimetype": "text/x-r-source", 223 | "name": "R", 224 | "pygments_lexer": "r", 225 | "version": "4.2.0" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 4 230 | } 231 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | name = "scalex" 7 | authors = [{name = "Lei Xiong"}] 8 | version = "1.0.6" 9 | readme = "README.md" 10 | 
requires-python = ">=3.7" 11 | description = "Online single-cell data integration through projecting heterogeneous datasets into a common cell-embedding space" 12 | license = { text = "MIT" } 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: BSD License", 16 | "Operating System :: OS Independent", 17 | "Development Status :: 3 - Alpha", 18 | "Topic :: Scientific/Engineering :: Bio-Informatics", 19 | "Intended Audience :: Science/Research" 20 | ] 21 | dependencies = [ 22 | "numpy>=1.26.4", 23 | "pandas>=2.2.2", 24 | "scipy>=1.13.0", 25 | "scikit-learn>=1.4.2", 26 | "torch>=2.2.2", 27 | "scanpy>=1.10.1", 28 | "tqdm>=4.66.2", 29 | "matplotlib>=3.8.4", 30 | "seaborn>=0.13.2", 31 | "leidenalg>=0.8.3", 32 | "gseapy", 33 | "pyranges", 34 | ] 35 | 36 | [project.scripts] 37 | SCALEX = "scalex.function:main" 38 | scalex = "scalex.function:main" 39 | 40 | [project.optional-dependencies] 41 | dev = [ 42 | "pre-commit", 43 | "twine>=4.0.2", 44 | ] 45 | doc = [ 46 | "docutils>=0.8,!=0.18.*,!=0.19.*", 47 | "sphinx>=4", 48 | "sphinx-book-theme>=1.0.0", 49 | "myst-nb", 50 | "sphinxcontrib-bibtex>=1.0.0", 51 | "sphinx-autodoc-typehints", 52 | "sphinxext-opengraph", 53 | "nbsphinx", 54 | # For notebooks 55 | "ipykernel", 56 | "ipython", 57 | "sphinx-copybutton", 58 | "pandas", 59 | ] 60 | test = [ 61 | "pytest>=6.0", 62 | "coverage", 63 | ] 64 | 65 | [tool.coverage.run] 66 | source = ["scalex"] 67 | omit = [ 68 | "**/test_*.py", 69 | ] 70 | 71 | [tool.pytest.ini_options] 72 | testpaths = ["tests"] 73 | xfail_strict = true 74 | addopts = [ 75 | "--import-mode=importlib", # allow using test files with same name 76 | ] 77 | 78 | [tool.ruff] 79 | line-length = 120 80 | src = ["scalex"] 81 | extend-include = ["*.ipynb"] 82 | 83 | [tool.ruff.format] 84 | docstring-code-format = true 85 | -------------------------------------------------------------------------------- /scalex/__init__.py: -------------------------------------------------------------------------------- 1 | # Define the variable '__version__': 2 | __version__ = "1.0.4" 3 | __author__ = "Lei Xiong" 4 | __email__ = "jsxlei@gmail.com" 5 | 6 | from .function import SCALEX, label_transfer -------------------------------------------------------------------------------- /scalex/atac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/scalex/atac/__init__.py -------------------------------------------------------------------------------- /scalex/atac/bedtools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from pybedtools import BedTool 4 | import pandas as pd 5 | import numpy as np 6 | 7 | import torch 8 | from scipy.sparse import coo_matrix 9 | import numpy as np 10 | 11 | def coo_from_pandas(df, source, target, values=None, shape=None): 12 | if values is None: 13 | df['values'] = 1 14 | data = df['values'] 15 | else: 16 | data = df[values] 17 | 18 | coo = coo_matrix((data, (df[source], df[target])), shape=shape) 19 | return coo 20 | 21 | 22 | def edge_index_to_coo(edge, data=None, shape=None): 23 | data = np.ones_like(edge[0]) if data is None else data 24 | shape = (edge[0].max()+1, edge[1].max()+1) if shape is None else shape 25 | return coo_matrix((data, (edge[0], edge[1])), shape=shape) 26 | 27 | 28 | def coo_to_sparse_tensor(coo): 29 | values = coo.data 30 | indices = np.vstack((coo.row, coo.col)) 31 | 32 
| i = torch.LongTensor(indices) 33 | v = torch.FloatTensor(values) 34 | shape = coo.shape 35 | 36 | return torch.sparse.FloatTensor(i, v, torch.Size(shape)) 37 | 38 | 39 | # ======================================================= 40 | # Bed 41 | # ======================================================= 42 | 43 | def bed_to_df(x, keep=False): 44 | """ 45 | Convert list of peaks(str) into data frame. 46 | 47 | Args: 48 | x (list of str): list of peak names 49 | 50 | Returns: 51 | pandas.dataframe: peak info as DataFrame 52 | 53 | Examples: 54 | >>> x = ['chr1_3094484_3095479', 'chr1_3113499_3113979', 'chr1_3119478_3121690'] 55 | >>> list_peakstr_to_df(x) 56 | chr start end 57 | 0 chr1 3094484 3095479 58 | 1 chr1 3113499 3113979 59 | 2 chr1 3119478 3121690 60 | """ 61 | df = np.array([to_coord(i) for i in x]) 62 | df = pd.DataFrame(df, columns=["Chromosome", "Start", "End"]) 63 | df["Start"] = df["Start"].astype(int) 64 | df["End"] = df["End"].astype(int) 65 | df.index = x 66 | 67 | return df 68 | 69 | 70 | def to_coord(peak): 71 | if ':' not in peak: 72 | chrom, start, end = peak.split('-') 73 | else: 74 | chrom, start_end = peak.split(':') 75 | start, end = start_end.split('-') 76 | return chrom, int(start), int(end) 77 | 78 | 79 | def df_to_bed(x): 80 | # return (x.iloc[:, 0]+':'+x.iloc[:, 1].astype(str)+'-'+x.iloc[:, 2].astype(str)).values 81 | return x.apply(lambda row: row.iloc[0]+':'+str(row.iloc[1])+'-'+str(row.iloc[2]), axis=1).values 82 | 83 | 84 | def extend_bed(df, up=0, down=0, start='Start', end='End'): 85 | # assert 'Start' in df.columns and 'End' in df.columns, 'Start and End columns are required.' 86 | df = df.copy() 87 | if not isinstance(df, pd.DataFrame): 88 | df = bed_to_df(df) 89 | if 'Strand' not in df.columns: 90 | df['Strand'] = '+' 91 | 92 | if 'Start' not in df.columns or 'End' not in df.columns: 93 | df.columns = ["Chromosome", "Start", "End"] + list(df.columns[3:]) 94 | 95 | # x = x.apply(lambda row: (row[0], max(0, int(row[1])-down), int(row[2])+up) if 'Strand' in row and (row['Strand'] == '-') 96 | # else (row[0], max(0, int(row[1])-up), int(row[2])+down), axis=1, result_type='expand') 97 | # x.columns =["Chromosome", "Start", "End"] 98 | # return x 99 | pos_strand = df.query("Strand == '+'").index 100 | neg_strand = df.query("Strand == '-'").index 101 | 102 | df.loc[pos_strand, end] = df.loc[pos_strand, end] + down 103 | df.loc[pos_strand, start] = df.loc[pos_strand, start] - up 104 | 105 | df.loc[neg_strand, end] = df.loc[neg_strand, end] + up 106 | df.loc[neg_strand, start] = df.loc[neg_strand, start] - down 107 | 108 | df.loc[:, start] = df.loc[:, start].clip(lower=0) 109 | return df 110 | 111 | 112 | def process_bed(bed, sort=False, add_index=True): 113 | if not isinstance(bed, pd.DataFrame): 114 | bed = bed_to_df(bed) 115 | if add_index: 116 | bed = bed.iloc[:, :3] 117 | bed['index'] = np.arange(len(bed)) 118 | bed = BedTool.from_dataframe(bed) 119 | if sort: 120 | bed = bed.sort() 121 | return bed 122 | 123 | 124 | 125 | def decode_dist(dist): 126 | if isinstance(dist, int): 127 | dist = dist 128 | elif isinstance(dist, str): 129 | if 'M' in dist: 130 | dist = int(dist.replace('M', '')) * 1_000_000 131 | elif 'K' in dist: 132 | dist = int(dist.replace('K', '')) * 1_000 133 | else: 134 | dist = int(dist) 135 | return dist 136 | 137 | 138 | def intersect_bed( 139 | query, ref, 140 | up=0, 141 | down=0, 142 | add_query_index=True, 143 | add_ref_index=True, 144 | index=[3, 7], 145 | out='edge_index', 146 | add_distance=False, 147 | ): 148 | """ 149 | Return 
intersection index of query and ref, 150 | make sure the fourth column of query and last column of ref are index. 151 | """ 152 | up = decode_dist(up) 153 | if up > 0 or down > 0: 154 | query = extend_bed(query, up=up, down=down) 155 | 156 | query = process_bed(query, add_index=add_query_index) 157 | ref = process_bed(ref, add_index=add_ref_index) 158 | 159 | intersected = query.intersect(ref, wa=True, wb=True).to_dataframe() # nonamecheck 160 | if len(intersected) == 0: 161 | raise ValueError('No intersection found, please check the input bed file.') 162 | if index is None: 163 | return intersected 164 | else: 165 | edges = intersected.iloc[:, index].values.T.astype(int) 166 | 167 | if add_distance: 168 | mid1 = intersected.iloc[:, [1, 2]].mean(axis=1) 169 | mid2 = intersected.iloc[:, [5, 6]].mean(axis=1) 170 | distance = np.abs(mid1 - mid2).astype(int).values 171 | else: 172 | distance = None 173 | 174 | if out == 'edge_index': 175 | if add_distance: 176 | edges = np.concatenate([edges, distance.reshape(1, -1)], axis=0) 177 | return edges 178 | elif out == 'coo': 179 | return edge_index_to_coo(edges, data=distance, shape=(len(query), len(ref))) 180 | 181 | 182 | def subtract_bed(query, ref): 183 | """ 184 | Return subtraction index of query and ref, 185 | make sure the fourth column of query and last column of ref are index. 186 | """ 187 | query = BedTool.from_dataframe(query) 188 | ref = BedTool.from_dataframe(ref) 189 | return query.subtract(ref).to_dataframe() 190 | 191 | 192 | def closest_bed(query, ref, k=1, D='a', t='first'): 193 | """ 194 | Return two matrix, one is query ref pair matrix, the other is query ref distance matrix 195 | row is query index, column is ref index 196 | """ 197 | query = process_bed(query, sort=True) 198 | ref = process_bed(ref, sort=True) 199 | 200 | intersected = query.closest(ref, k=k, D=D, t=t).to_dataframe() 201 | intersected = intersected[intersected.iloc[:, -3]!=-1] 202 | 203 | out = intersected.iloc[:, [3, -2, -1]] 204 | out.columns = ['query', 'ref', 'distance'] 205 | pair = pd.DataFrame(out.groupby('query')['ref'].apply(list).to_dict()).T 206 | distance = pd.DataFrame(out.groupby('query')['distance'].apply(list).to_dict()).T 207 | return pair, distance 208 | 209 | 210 | import torch 211 | def get_promoter_offset_for_embedding(promoter, peak_list): 212 | result = intersect_bed(promoter, peak_list) 213 | promoter_dict = result.groupby('name')['thickEnd'].apply(list).to_dict() 214 | inputs = [] 215 | offsets = [] 216 | offset = 0 217 | for i in range(len(promoter)): 218 | offsets.append(offset) 219 | if i in promoter_dict: 220 | v = promoter_dict[i] 221 | inputs+=v 222 | offset += len(v) 223 | inputs = torch.LongTensor(inputs) 224 | offsets = torch.LongTensor(offsets) 225 | return inputs, offsets 226 | 227 | 228 | -------------------------------------------------------------------------------- /scalex/atac/read_modisco.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | import pandas as pd 3 | 4 | def read_modisco_html(motif_html): 5 | """ 6 | Read modisco html file and return a dataframe. 
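    Accepts either a tfmodisco ``.html`` report (parsed via BeautifulSoup and
    ``pandas.read_html``) or a tab-separated table with the same columns, and
    returns the total ``num_seqlets`` per matched motif (``match0``), sorted in
    descending order.

    Example (illustrative; the report path is hypothetical)::

        >>> seqlet_counts = read_modisco_html("modisco_report.html")
        >>> seqlet_counts.head()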
7 | """ 8 | if motif_html.endswith(".html"): 9 | with open(motif_html, "r") as file: 10 | modisco_report = BeautifulSoup(file, "html.parser") 11 | 12 | table = modisco_report.find_all("table") 13 | df = pd.read_html(str(table))[0] 14 | else: 15 | df = pd.read_csv(motif_html, sep="\t") 16 | 17 | # if motif_meta: 18 | # meta = pd.read_csv(motif_meta, sep="\t") 19 | # # df = pd.merge(df, meta, left_on='match0', right_on='motif_id', how='left') 20 | # motif_id_dict = pd.Series(meta['tf_name'].values, meta['motif_id'].values).to_dict() 21 | # source_id_dict = pd.Series(meta['tf_name'].values, meta['source_id'].values).to_dict() 22 | # else: 23 | # motif_id_dict = None 24 | # source_id_dict = None 25 | # return df['match0'] 26 | 27 | # mapping_dict = {} 28 | # for i, row in df.iterrows(): 29 | # # print(row) 30 | # k = row.loc['pattern'] 31 | # v = row.loc['match0'] 32 | # if v in motif_id_dict: 33 | # mapping_dict[k] = motif_id_dict[v] 34 | # elif v in source_id_dict: 35 | # mapping_dict[k] = source_id_dict[v] 36 | # else: 37 | # print(f"{k} {v} Not found") 38 | 39 | # df['match0'] = df['pattern'].map(mapping_dict) 40 | return df[['match0', 'num_seqlets']].groupby('match0', as_index=False)['num_seqlets'].sum().set_index('match0').sort_values('num_seqlets', ascending=False) 41 | 42 | 43 | def read_mapping_meta(motif_meta): 44 | """ 45 | Read motif meta file and return a mapping dict. 46 | """ 47 | meta = pd.read_csv(motif_meta, sep="\t") 48 | # df = pd.merge(df, meta, left_on='match0', right_on='motif_id', how='left') 49 | motif_id_dict = pd.Series(meta['tf_name'].values, meta['motif_id'].values).to_dict() 50 | source_id_dict = pd.Series(meta['tf_name'].values, meta['source_id'].values).to_dict() 51 | 52 | more_dict = {i: source_id_dict[i] for i in source_id_dict if i not in motif_id_dict} 53 | motif_id_dict.update(more_dict) 54 | return motif_id_dict 55 | # mapping_dict = {} 56 | # for i, row in df.iterrows(): 57 | # print(row) 58 | # k = row.loc['pattern'] 59 | # v = row.loc['match0'] 60 | # if v in motif_id_dict: 61 | # mapping_dict[v] = motif_id_dict[v] 62 | # elif v in source_id_dict: 63 | # mapping_dict[v] = source_id_dict[v] 64 | # else: 65 | # print(f"{v} Not found") 66 | 67 | # return mapping_dict 68 | -------------------------------------------------------------------------------- /scalex/atac/snapatac2/_basic.py: -------------------------------------------------------------------------------- 1 | import pybedtools 2 | import pandas as pd 3 | from typing import List 4 | from anndata import AnnData 5 | import numpy as np 6 | from pathlib import Path 7 | import scanpy as sc 8 | import logging 9 | 10 | from ._clustering import spectral 11 | from ._knn import knn 12 | from ._misc import aggregate_X 13 | 14 | 15 | def _find_most_accessible_features( 16 | feature_count, 17 | filter_lower_quantile, 18 | filter_upper_quantile, 19 | total_features, 20 | ) -> np.ndarray: 21 | idx = np.argsort(feature_count) 22 | for i in range(idx.size): 23 | if feature_count[idx[i]] > 0: 24 | break 25 | idx = idx[i:] 26 | n = idx.size 27 | n_lower = int(filter_lower_quantile * n) 28 | n_upper = int(filter_upper_quantile * n) 29 | idx = idx[n_lower:n-n_upper] 30 | return idx[::-1][:total_features] 31 | 32 | 33 | 34 | 35 | 36 | def intersect_bed(regions: List[str], bed_file: str) -> List[bool]: 37 | """ 38 | Check if genomic regions intersect with a BED file. 39 | 40 | Parameters: 41 | - regions: List of genomic regions as strings (e.g., "chr1:1000-2000"). 42 | - bed_file: Path to the BED file. 
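    Example (illustrative; the BED path and the resulting booleans are hypothetical):
        >>> overlaps = intersect_bed(["chr1:1000-2000", "chr2:500-800"], "blacklist.bed")
        >>> overlaps
        [True, False]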
43 | 44 | Returns: 45 | - List of booleans indicating whether each region overlaps with the BED file. 46 | """ 47 | # Load BED file as an interval tree using pybedtools 48 | bed_intervals = pybedtools.BedTool(bed_file) 49 | 50 | results = [] 51 | for region in regions: 52 | # Convert "chr1:1000-2000" to BED format "chr1 1000 2000" 53 | chrom, coords = region.split(":") 54 | start, end = map(int, coords.split("-")) 55 | 56 | # Create a temporary interval 57 | query_interval = pybedtools.BedTool(f"{chrom}\t{start}\t{end}", from_string=True) 58 | 59 | # Check for intersection 60 | results.append(bool(query_interval.intersect(bed_intervals, u=True))) 61 | 62 | return results 63 | 64 | 65 | 66 | 67 | def select_features( 68 | adata: AnnData | list[AnnData], 69 | n_features: int = 500000, 70 | filter_lower_quantile: float = 0.005, 71 | filter_upper_quantile: float = 0.005, 72 | whitelist: Path | None = None, 73 | blacklist: Path | None = None, 74 | max_iter: int = 1, 75 | inplace: bool = True, 76 | n_jobs: int = 8, 77 | verbose: bool = True, 78 | ) -> np.ndarray | list[np.ndarray] | None: 79 | """ 80 | Perform feature selection by selecting the most accessibile features across 81 | all cells unless `max_iter` > 1. 82 | 83 | Note 84 | ---- 85 | This function does not perform the actual subsetting. The feature mask is used by 86 | various functions to generate submatrices on the fly. 87 | Features that are zero in all cells will be always removed regardless of the 88 | filtering criteria. 89 | For more discussion about feature selection, see: https://github.com/kaizhang/SnapATAC2/discussions/116. 90 | 91 | Parameters 92 | ---------- 93 | adata 94 | The (annotated) data matrix of shape `n_obs` x `n_vars`. 95 | Rows correspond to cells and columns to regions. 96 | `adata` can also be a list of AnnData objects. 97 | In this case, the function will be applied to each AnnData object in parallel. 98 | n_features 99 | Number of features to keep. Note that the final number of features 100 | may be smaller than this number if there is not enough features that pass 101 | the filtering criteria. 102 | filter_lower_quantile 103 | Lower quantile of the feature count distribution to filter out. 104 | For example, 0.005 means the bottom 0.5% features with the lowest counts will be removed. 105 | filter_upper_quantile 106 | Upper quantile of the feature count distribution to filter out. 107 | For example, 0.005 means the top 0.5% features with the highest counts will be removed. 108 | Be aware that when the number of feature is very large, the default value of 0.005 may 109 | risk removing too many features. 110 | whitelist 111 | A user provided bed file containing genome-wide whitelist regions. 112 | None-zero features listed here will be kept regardless of the other 113 | filtering criteria. 114 | If a feature is present in both whitelist and blacklist, it will be kept. 115 | blacklist 116 | A user provided bed file containing genome-wide blacklist regions. 117 | Features that are overlapped with these regions will be removed. 118 | max_iter 119 | If greater than 1, this function will perform iterative clustering and feature selection 120 | based on variable features found using previous clustering results. 121 | This is similar to the procedure implemented in ArchR, but we do not recommend it, 122 | see https://github.com/kaizhang/SnapATAC2/issues/111. 123 | Default value is 1, which means no iterative clustering is performed. 124 | inplace 125 | Perform computation inplace or return result. 
126 | n_jobs 127 | Number of parallel jobs to use when `adata` is a list. 128 | verbose 129 | Whether to print progress messages. 130 | 131 | Returns 132 | ------- 133 | np.ndarray | None: 134 | If `inplace = False`, return a boolean index mask that does filtering, 135 | where `True` means that the feature is kept, `False` means the feature is removed. 136 | Otherwise, store this index mask directly to `.var['selected']`. 137 | """ 138 | 139 | count = np.zeros(adata.shape[1]) 140 | for batch, _, _ in adata.chunked_X(2000): 141 | count += np.ravel(batch.sum(axis = 0)) 142 | adata.var['count'] = count 143 | 144 | selected_features = _find_most_accessible_features( 145 | count, filter_lower_quantile, filter_upper_quantile, n_features) 146 | 147 | if blacklist is not None: 148 | blacklist = np.array(intersect_bed(adata.var_names, str(blacklist))) 149 | selected_features = selected_features[np.logical_not(blacklist[selected_features])] 150 | 151 | # Iteratively select features 152 | iter = 1 153 | while iter < max_iter: 154 | embedding = spectral(adata, features=selected_features, inplace=False)[1] 155 | clusters = sc.tl.leiden(knn(embedding, inplace=False)) 156 | rpm = aggregate_X(adata, groupby=clusters).X 157 | var = np.var(np.log(rpm + 1), axis=0) 158 | selected_features = np.argsort(var)[::-1][:n_features] 159 | 160 | # Apply blacklist to the result 161 | if blacklist is not None: 162 | selected_features = selected_features[np.logical_not(blacklist[selected_features])] 163 | iter += 1 164 | 165 | result = np.zeros(adata.shape[1], dtype=bool) 166 | result[selected_features] = True 167 | 168 | # Finally, apply whitelist to the result 169 | if whitelist is not None: 170 | whitelist = np.array(intersect_bed(adata.var_names, str(whitelist))) 171 | whitelist &= count != 0 172 | result |= whitelist 173 | 174 | if verbose: 175 | logging.info(f"Selected {result.sum()} features.") 176 | 177 | if inplace: 178 | adata.var["selected"] = result 179 | else: 180 | return result -------------------------------------------------------------------------------- /scalex/atac/snapatac2/_diff.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | import numpy as np 5 | from scipy.stats import chi2, norm, zscore 6 | import logging 7 | 8 | from anndata import AnnData 9 | from ._misc import aggregate_X 10 | 11 | def marker_regions( 12 | data: AnnData, 13 | groupby: str | list[str], 14 | pvalue: float = 0.01, 15 | ) -> dict[str, list[str]]: 16 | """ 17 | A quick-and-dirty way to get marker regions. 18 | 19 | Parameters 20 | ---------- 21 | data 22 | AnnData or AnnDataSet object. 23 | groupby 24 | Grouping variable. 25 | pvalue 26 | P-value threshold. 
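    Returns
    -------
    dict[str, list[str]]
        Mapping from each group label to the regions whose aggregated,
        log-normalized accessibility z-score is significant (one-sided) at `pvalue`.

    Example (illustrative; the `.obs` column name is an assumption about your data)::

        >>> peaks = marker_regions(adata, groupby="cell_type", pvalue=0.01)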
27 | """ 28 | count = aggregate_X(data, groupby, normalize="RPKM") 29 | z = zscore(np.log2(1 + count.X), axis = 0) 30 | peaks = {} 31 | for i in range(z.shape[0]): 32 | select = norm.sf(z[i, :]) < pvalue 33 | if np.where(select)[0].size >= 1: 34 | peaks[count.obs_names[i]] = count.var_names[select] 35 | return peaks 36 | 37 | def mad(data, axis=None): 38 | """ Compute Median Absolute Deviation """ 39 | return np.median(np.absolute(data - np.median(data, axis)), axis) 40 | 41 | def modified_zscore(matrix, axis=0): 42 | """ Compute Modified Z-score for a matrix along specified axis """ 43 | median = np.median(matrix, axis=axis) 44 | median_absolute_deviation = mad(matrix, axis=axis) 45 | min_non_zero = np.min(median_absolute_deviation[median_absolute_deviation > 0]) 46 | median_absolute_deviation[median_absolute_deviation == 0] = min_non_zero 47 | 48 | if axis == 0: 49 | modified_z_scores = 0.6745 * (matrix - median) / median_absolute_deviation 50 | elif axis == 1: 51 | modified_z_scores = 0.6745 * (matrix.T - median).T / median_absolute_deviation 52 | else: 53 | raise ValueError("Invalid axis, it should be 0 or 1") 54 | 55 | return modified_z_scores 56 | 57 | def diff_test( 58 | data: AnnData, 59 | cell_group1: list[int] | list[str], 60 | cell_group2: list[int] | list[str], 61 | features : list[str] | list[int] | None = None, 62 | covariates: list[str] | None = None, 63 | direction: Literal["positive", "negative", "both"] = "both", 64 | min_log_fc: float = 0.25, 65 | min_pct: float = 0.05, 66 | ) -> 'polars.DataFrame': 67 | """ 68 | Identify differentially accessible regions. 69 | 70 | Parameters 71 | ---------- 72 | data 73 | AnnData or AnnDataSet object. 74 | cell_group1 75 | cells belonging to group 1. This can be a list of cell barcodes, indices or 76 | boolean mask vector. 77 | cell_group2 78 | cells belonging to group 2. This can be a list of cell barcodes, indices or 79 | boolean mask vector. 80 | features 81 | Features/peaks to test. If None, all features are tested. 82 | covariates 83 | direction 84 | "positive", "negative", or "both". 85 | "positive": return features that are enriched in group 1. 86 | "negative": return features that are enriched in group 2. 87 | "both": return features that are enriched in group 1 or group 2. 88 | min_log_fc 89 | Limit testing to features which show, on average, at least 90 | X-fold difference (log2-scale) between the two groups of cells. 91 | min_pct 92 | Only test features that are detected in a minimum fraction of min_pct 93 | cells in either of the two populations. 94 | 95 | Returns 96 | ------- 97 | pl.DataFrame 98 | A DataFrame with 4 columns: "feature name", "log2(fold_change)", 99 | "p-value", and "adjusted p-value". 
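    Example (illustrative sketch; the `.obs` column and group labels are
    assumptions about your data)::

        >>> g1 = list(data.obs["cell_type"] == "B")
        >>> g2 = list(data.obs["cell_type"] == "T")
        >>> dar = diff_test(data, cell_group1=g1, cell_group2=g2, min_log_fc=0.25)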
100 | """ 101 | import polars as pl 102 | 103 | def to_indices(xs, type): 104 | xs = [_convert_to_bool_if_np_bool(x) for x in xs] 105 | if all(isinstance(x, bool) for x in xs): 106 | return [i for i, value in enumerate(xs) if value] 107 | elif all([isinstance(item, str) for item in xs]): 108 | if type == "obs": 109 | if data.isbacked: 110 | return data.obs_ix(xs) 111 | else: 112 | return [data.obs_names.get_loc(x) for x in xs] 113 | else: 114 | if data.isbacked: 115 | return data.var_ix(xs) 116 | else: 117 | return [data.var_names.get_loc(x) for x in xs] 118 | else: 119 | return xs 120 | 121 | cell_group1 = to_indices(cell_group1, "obs") 122 | n_group1 = len(cell_group1) 123 | cell_group2 = to_indices(cell_group2, "obs") 124 | n_group2 = len(cell_group2) 125 | 126 | cell_by_peak = data.X[cell_group1 + cell_group2, :].tocsc() 127 | test_var = np.array([0] * n_group1 + [1] * n_group2) 128 | if covariates is not None: 129 | raise NameError("covariates is not implemented") 130 | 131 | features = range(data.n_vars) if features is None else to_indices(features, "var") 132 | logging.info("Input contains {} features, now perform filtering with 'min_log_fc = {}' and 'min_pct = {}' ...".format(len(features), min_log_fc, min_pct)) 133 | filtered = _filter_features( 134 | cell_by_peak[:n_group1, :], 135 | cell_by_peak[n_group1:, :], 136 | features, 137 | direction, 138 | min_pct, 139 | min_log_fc, 140 | ) 141 | 142 | if len(filtered) == 0: 143 | logging.warning("Zero feature left after filtering, perhaps 'min_log_fc' or 'min_pct' is too large") 144 | return pl.DataFrame() 145 | else: 146 | features, log_fc = zip(*filtered) 147 | logging.info("Testing {} features ...".format(len(features))) 148 | pvals = _diff_test_helper(cell_by_peak, test_var, features, covariates) 149 | var_names = data.var_names 150 | return pl.DataFrame({ 151 | "feature name": [var_names[i] for i in features], 152 | "log2(fold_change)": np.array(log_fc), 153 | "p-value": np.array(pvals), 154 | "adjusted p-value": _p_adjust_bh(pvals), 155 | }).sort("adjusted p-value") 156 | 157 | def _p_adjust_bh(p): 158 | """Benjamini-Hochberg p-value correction for multiple hypothesis testing.""" 159 | p = np.asfarray(p) 160 | by_descend = p.argsort()[::-1] 161 | by_orig = by_descend.argsort() 162 | steps = float(len(p)) / np.arange(len(p), 0, -1) 163 | q = np.minimum(1, np.minimum.accumulate(steps * p[by_descend])) 164 | return q[by_orig] 165 | 166 | def _filter_features(mat1, mat2, peak_indices, direction, 167 | min_pct, min_log_fc, pseudo_count = 1, 168 | ): 169 | def rpm(m): 170 | x = np.ravel(np.sum(m, axis = 0)) + pseudo_count 171 | s = x.sum() 172 | return x / (s / 1000000) 173 | 174 | def pass_min_pct(i): 175 | cond1 = mat1[:, i].count_nonzero() / mat1.shape[0] >= min_pct 176 | cond2 = mat2[:, i].count_nonzero() / mat2.shape[0] >= min_pct 177 | return cond1 or cond2 178 | 179 | def adjust_sign(fc): 180 | if direction == "both": 181 | return abs(fc) 182 | elif direction == "positive": 183 | return fc 184 | elif direction == "negative": 185 | return -fc 186 | else: 187 | raise NameError("direction must be 'positive', 'negative' or 'both'") 188 | 189 | log_fc = np.log2(rpm(mat1) / rpm(mat2)) 190 | peak_indices = [i for i in peak_indices if pass_min_pct(i)] 191 | return [(i, log_fc[i]) for i in peak_indices if adjust_sign(log_fc[i]) >= min_log_fc] 192 | 193 | def _diff_test_helper(mat, z, peaks=None, covariate=None) -> list[float]: 194 | """ 195 | Parameters 196 | ---------- 197 | mat 198 | cell by peak matrix. 
199 | z 200 | variables to test 201 | peaks 202 | peak indices 203 | covariate 204 | additional variables to regress out. 205 | """ 206 | 207 | if len(z.shape) == 1: 208 | z = z.reshape((-1, 1)) 209 | 210 | if covariate is None: 211 | X = np.log1p(np.sum(mat, axis=1)) 212 | else: 213 | X = covariate 214 | 215 | mat = mat.tocsc() 216 | if peaks is not None: 217 | mat = mat[:, peaks] 218 | 219 | return _likelihood_ratio_test_many(np.asarray(X), np.asarray(z), mat) 220 | 221 | 222 | def _likelihood_ratio_test_many(X, z, Y) -> list[float]: 223 | """ 224 | Parameters 225 | ---------- 226 | X 227 | (n_sample, n_feature). 228 | z 229 | (n_sample, 1), the additional variable. 230 | Y 231 | (n_sample, k), labels 232 | 233 | Returns 234 | ------- 235 | P-values of whether adding z to the models improves the prediction. 236 | """ 237 | from tqdm import tqdm 238 | 239 | X0 = X 240 | X1 = np.concatenate((X, z), axis=1) 241 | 242 | _, n = Y.shape 243 | Y.data = np.ones(Y.data.shape) 244 | 245 | result = [] 246 | for i in tqdm(range(n)): 247 | result.append( 248 | _likelihood_ratio_test(X0, X1, np.asarray(np.ravel(Y[:, i].todense()))) 249 | ) 250 | return result 251 | 252 | def _likelihood_ratio_test( 253 | X0: np.ndarray, 254 | X1: np.ndarray, 255 | y: np.ndarray, 256 | ) -> float: 257 | """ 258 | Comparing null model with alternative model using the likehood ratio test. 259 | 260 | Parameters 261 | ---------- 262 | X0 263 | (n_sample, n_feature), variables used in null model. 264 | X1 265 | (n_sample, n_feature2), variables used in alternative model. 266 | Note X1 contains X0. 267 | Y 268 | (n_sample, ), labels. 269 | 270 | Returns 271 | ------- 272 | The P-value. 273 | """ 274 | from sklearn.linear_model import LogisticRegression 275 | from sklearn.metrics import log_loss 276 | 277 | model = LogisticRegression(penalty=None, random_state=0, n_jobs=1, 278 | solver="lbfgs", warm_start=False, 279 | max_iter = 1000, 280 | ).fit(X0, y) 281 | reduced = -log_loss(y, model.predict_proba(X0), normalize=False) 282 | 283 | model = LogisticRegression(penalty=None, random_state=0, n_jobs=1, 284 | solver="lbfgs", warm_start=False, 285 | max_iter = 1000, 286 | ).fit(X1, y) 287 | full = -log_loss(y, model.predict_proba(X1), normalize=False) 288 | chi = -2 * (reduced - full) 289 | return chi2.sf(chi, X1.shape[1] - X0.shape[1]) 290 | 291 | def _convert_to_bool_if_np_bool(value): 292 | if isinstance(value, np.bool_): 293 | return bool(value) 294 | return value -------------------------------------------------------------------------------- /scalex/atac/snapatac2/_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | from scipy.sparse.linalg import eigsh 4 | from sklearn.utils.extmath import randomized_svd 5 | import logging 6 | from typing import Optional, Tuple 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | def normalize(matrix: sp.csr_matrix, feature_weights: np.ndarray) -> None: 11 | """ 12 | Applies feature weighting and L2 normalization on the rows of a sparse matrix. 13 | """ 14 | matrix = matrix.copy() 15 | matrix.data *= feature_weights[matrix.indices] 16 | 17 | row_norms = np.sqrt(matrix.multiply(matrix).sum(axis=1)).A1 18 | row_norms[row_norms == 0] = 1.0 # Avoid division by zero 19 | matrix = matrix.multiply(sp.diags(1.0 / row_norms)) 20 | 21 | return matrix 22 | 23 | def compute_idf(matrix: sp.csr_matrix) -> np.ndarray: 24 | """ 25 | Computes the inverse document frequency (IDF) for feature weighting. 
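    The weight for each feature follows ``idf = log(n_docs / (doc_count + 1))``,
    where ``doc_count`` is taken from the matrix's CSR index pointer and the
    ``+ 1`` guards against division by zero.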
26 | """ 27 | doc_count = np.diff(matrix.indptr) 28 | n_docs = matrix.shape[0] 29 | idf_values = np.log(n_docs / (doc_count + 1)) # Avoid division by zero 30 | return idf_values 31 | 32 | def spectral_embedding( 33 | anndata: sp.csr_matrix, 34 | selected_features: np.ndarray, 35 | n_components: int, 36 | random_state: int, 37 | feature_weights: Optional[np.ndarray] = None 38 | ) -> Tuple[np.ndarray, np.ndarray]: 39 | """ 40 | Perform spectral embedding using matrix factorization. 41 | """ 42 | logging.info("Performing spectral embedding using matrix factorization...") 43 | 44 | # Select features 45 | matrix = anndata[:, selected_features] 46 | 47 | # Apply feature weighting 48 | if feature_weights is None: 49 | feature_weights = compute_idf(matrix) 50 | matrix = normalize(matrix, feature_weights) 51 | 52 | # Compute eigenvalues and eigenvectors 53 | evals, evecs = eigsh(matrix, k=n_components, which="LM", random_state=random_state) 54 | return evals, evecs 55 | 56 | def spectral_embedding_nystrom( 57 | anndata: sp.csr_matrix, 58 | selected_features: np.ndarray, 59 | n_components: int, 60 | sample_size: int, 61 | weighted_by_degree: bool, 62 | chunk_size: int, 63 | feature_weights: Optional[np.ndarray] = None 64 | ) -> Tuple[np.ndarray, np.ndarray]: 65 | """ 66 | Perform spectral embedding using the Nystrom method. 67 | """ 68 | logging.info("Performing spectral embedding using the Nystrom algorithm...") 69 | 70 | matrix = anndata[:, selected_features] 71 | 72 | # Compute feature weighting 73 | if feature_weights is None: 74 | feature_weights = compute_idf(matrix) 75 | 76 | matrix = normalize(matrix, feature_weights) 77 | 78 | # Sample landmarks 79 | n_samples = matrix.shape[0] 80 | rng = np.random.default_rng(2023) 81 | 82 | if weighted_by_degree: 83 | degree_weights = np.array(matrix.sum(axis=1)).flatten() 84 | degree_weights /= degree_weights.sum() 85 | selected_indices = rng.choice(n_samples, size=sample_size, p=degree_weights, replace=False) 86 | else: 87 | selected_indices = rng.choice(n_samples, size=sample_size, replace=False) 88 | 89 | seed_matrix = matrix[selected_indices, :] 90 | 91 | # Compute spectral decomposition of the sample matrix 92 | evals, evecs = eigsh(seed_matrix, k=n_components, which="LM") 93 | 94 | logging.info("Applying Nystrom extension...") 95 | full_matrix = matrix @ evecs @ np.diag(1 / evals) 96 | 97 | return evals, full_matrix 98 | 99 | def multi_spectral_embedding( 100 | anndatas: list, 101 | selected_features: list, 102 | weights: list, 103 | n_components: int, 104 | random_state: int 105 | ) -> Tuple[np.ndarray, np.ndarray]: 106 | """ 107 | Perform multi-view spectral embedding by concatenating feature spaces. 
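    Each view is IDF-weighted and normalized, scaled by its per-view weight, and
    the views are horizontally stacked before a single eigendecomposition.
    `anndatas`, `selected_features` and `weights` are parallel lists with one
    entry per view.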
108 | """ 109 | logging.info("Computing normalized views...") 110 | 111 | weighted_matrices = [] 112 | for data, features in zip(anndatas, selected_features): 113 | matrix = data[:, features] 114 | feature_weights = compute_idf(matrix) 115 | matrix = normalize(matrix, feature_weights) 116 | 117 | norm_factor = np.linalg.norm(matrix.data) # Frobenius norm approximation 118 | weighted_matrices.append(matrix * (weights / norm_factor)) 119 | 120 | combined_matrix = sp.hstack(weighted_matrices) 121 | 122 | logging.info("Computing embedding...") 123 | evals, evecs = eigsh(combined_matrix, k=n_components, which="LM", random_state=random_state) 124 | 125 | return evals, evecs 126 | -------------------------------------------------------------------------------- /scalex/atac/snapatac2/_knn.py: -------------------------------------------------------------------------------- 1 | "Adapted from snapatac2: https://github.com/kaizhang/SnapATAC2/blob/main/snapatac2-python/python/snapatac2/preprocessing/_knn.py" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Literal 6 | import numpy as np 7 | from scipy.sparse import csr_matrix 8 | from anndata import AnnData 9 | 10 | # from snapatac2._utils import is_anndata 11 | # import snapatac2._snapatac2 as internal 12 | 13 | def is_anndata(adata): 14 | """ 15 | Check if the input is an AnnData object. 16 | """ 17 | return isinstance(adata, AnnData) 18 | 19 | 20 | import numpy as np 21 | from scipy.spatial import cKDTree 22 | from sklearn.neighbors import NearestNeighbors 23 | 24 | def nearest_neighbour_graph(data: np.ndarray, k: int) -> np.ndarray: 25 | """ 26 | Compute the k-nearest neighbor graph using exact nearest neighbor search. 27 | 28 | Parameters: 29 | - data: 2D NumPy array of shape (n_samples, n_features). 30 | - k: Number of neighbors to find. 31 | 32 | Returns: 33 | - Adjacency matrix (sparse) representing the nearest neighbor graph. 34 | """ 35 | tree = cKDTree(data) 36 | distances, indices = tree.query(data, k=k + 1) # +1 to include self 37 | indices = indices[:, 1:] # Remove self-neighbors 38 | return indices 39 | 40 | def approximate_nearest_neighbour_graph(data: np.ndarray, k: int) -> np.ndarray: 41 | """ 42 | Compute the k-nearest neighbor graph using an approximate nearest neighbor search. 43 | 44 | Parameters: 45 | - data: 2D NumPy array of shape (n_samples, n_features). 46 | - k: Number of neighbors to find. 47 | 48 | Returns: 49 | - Adjacency matrix (sparse) representing the nearest neighbor graph. 50 | """ 51 | nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto").fit(data) 52 | distances, indices = nn.kneighbors(data) 53 | indices = indices[:, 1:] # Remove self-neighbors 54 | return indices 55 | 56 | 57 | def knn( 58 | adata: AnnData | np.ndarray, 59 | n_neighbors: int = 50, 60 | use_dims: int | list[int] | None = None, 61 | use_rep: str = 'X_spectral', 62 | method: Literal['kdtree', 'hora', 'pynndescent'] = "kdtree", 63 | inplace: bool = True, 64 | random_state: int = 0, 65 | ) -> csr_matrix | None: 66 | """ 67 | Compute a neighborhood graph of observations. 68 | 69 | Computes a neighborhood graph of observations stored in `adata` using 70 | the method specified by `method`. The distance metric used is Euclidean. 71 | 72 | Parameters 73 | ---------- 74 | adata 75 | Annotated data matrix or numpy array. 76 | n_neighbors 77 | The number of nearest neighbors to be searched. 78 | use_dims 79 | The dimensions used for computation. 
80 | use_rep 81 | The key for the matrix 82 | method 83 | Can be one of the following: 84 | - 'kdtree': use the kdtree algorithm to find the nearest neighbors. 85 | - 'hora': use the HNSW algorithm to find the approximate nearest neighbors. 86 | - 'pynndescent': use the pynndescent algorithm to find the approximate nearest neighbors. 87 | inplace 88 | Whether to store the result in the anndata object. 89 | random_state 90 | Random seed for approximate nearest neighbor search. 91 | Note that this is only used when `method='pynndescent'`. 92 | Currently 'hora' does not support random seed, so the result of 'hora' is not reproducible. 93 | 94 | Returns 95 | ------- 96 | csr_matrix | None 97 | if `inplace=True`, store KNN in `.obsp['distances']`. 98 | Otherwise, return a sparse matrix. 99 | """ 100 | if is_anndata(adata): 101 | data = adata.obsm[use_rep] 102 | else: 103 | inplace = False 104 | data = adata 105 | if data.size == 0: 106 | raise ValueError("matrix is empty") 107 | 108 | if use_dims is not None: 109 | if isinstance(use_dims, int): 110 | data = data[:, :use_dims] 111 | else: 112 | data = data[:, use_dims] 113 | 114 | n = data.shape[0] 115 | if method == 'hora': 116 | adj = approximate_nearest_neighbour_graph( 117 | data.astype(np.float32), n_neighbors) 118 | elif method == 'pynndescent': 119 | import pynndescent 120 | index = pynndescent.NNDescent(data, n_neighbors=max(50, n_neighbors), random_state=random_state) 121 | adj, distances = index.neighbor_graph 122 | indices = np.ravel(adj[:, :n_neighbors]) 123 | distances = np.ravel(distances[:, :n_neighbors]) 124 | indptr = np.arange(0, distances.size + 1, n_neighbors) 125 | adj = csr_matrix((distances, indices, indptr), shape=(n, n)) 126 | adj.sort_indices() 127 | elif method == 'kdtree': 128 | adj = nearest_neighbour_graph(data, n_neighbors) 129 | else: 130 | raise ValueError("method must be one of 'hora', 'pynndescent', 'kdtree'") 131 | 132 | if inplace: 133 | adata.obsp['distances'] = adj 134 | else: 135 | return adj -------------------------------------------------------------------------------- /scalex/atac/snapatac2/_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from snapatac2 3 | """ 4 | 5 | # from __future__ import annotations 6 | 7 | from typing import Literal 8 | import logging 9 | from pathlib import Path 10 | import numpy as np 11 | import functools 12 | from anndata import AnnData 13 | 14 | from ._knn import knn 15 | from scanpy.tl import leiden 16 | # import snapatac2._snapatac2 as internal 17 | # from snapatac2._utils import is_anndata 18 | # from snapatac2.tools import leiden 19 | # from snapatac2.preprocessing import knn 20 | 21 | __all__ = ['aggregate_X', 'aggregate_cells'] 22 | 23 | def is_anndata(adata): 24 | """ 25 | Check if the input is an AnnData object. 26 | """ 27 | return isinstance(adata, AnnData) 28 | 29 | def aggregate_X( 30 | adata: AnnData, 31 | groupby: str | list[str] | None = None, 32 | normalize: Literal["RPM", "RPKM"] | None = None, 33 | file: Path | None = None, 34 | ) -> np.ndarray | AnnData: 35 | """ 36 | Aggregate values in adata.X in a row-wise fashion. 37 | 38 | Aggregate values in adata.X in a row-wise fashion. This is used to compute 39 | RPKM or RPM values stratified by user-provided groupings. 40 | 41 | Parameters 42 | ---------- 43 | adata 44 | The AnnData or AnnDataSet object. 45 | groupby 46 | Group the cells into different groups. If a `str`, groups are obtained 47 | from `.obs[groupby]`. 
48 | normalize 49 | normalization method: "RPM" or "RPKM". 50 | file 51 | if provided, the results will be saved to a new h5ad file. 52 | 53 | Returns 54 | ------- 55 | np.ndarray | AnnData 56 | If `grouby` is `None`, return a 1d array. Otherwise, return an AnnData 57 | object. 58 | """ 59 | from natsort import natsorted 60 | from anndata import AnnData 61 | 62 | def norm(x): 63 | if normalize is None: 64 | return x 65 | elif normalize == "RPKM": 66 | size_factor = _get_sizes(adata.var_names) / 1000.0 67 | return _normalize(x, size_factor) 68 | elif normalize == "RPM": 69 | return _normalize(x) 70 | else: 71 | raise NameError("Normalization method must be 'RPKM' or 'RPM'") 72 | 73 | if groupby is None: 74 | row_sum = functools.reduce( 75 | lambda a, b: a + b, 76 | (np.ravel(chunk.sum(axis=0)) for chunk, _, _ in adata.chunked_X(1000)), 77 | ) 78 | row_sum = norm(row_sum) 79 | return row_sum 80 | else: 81 | groups = adata.obs[groupby].to_numpy() if isinstance(groupby, str) else np.array(groupby) 82 | if groups.size != adata.n_obs: 83 | raise NameError("the length of `groupby` should equal to the number of obervations") 84 | 85 | result = {x: np.zeros(adata.n_vars) for x in natsorted(np.unique(groups))} 86 | for chunk, start, stop in adata.chunked_X(2000): 87 | for i in range(start, stop): 88 | result[groups[i]] += chunk[i-start, :] 89 | for k in result.keys(): 90 | result[k] = norm(np.ravel(result[k])) 91 | 92 | keys, values = zip(*result.items()) 93 | if file is None: 94 | out_adata = AnnData(X=np.array(values)) 95 | else: 96 | out_adata = AnnData(filename=file, X=np.array(values)) 97 | out_adata.obs_names = list(keys) 98 | out_adata.obs[groupby] = keys 99 | out_adata.var_names = adata.var_names 100 | return out_adata 101 | 102 | def aggregate_cells( 103 | adata: AnnData | np.ndarray, 104 | use_rep: str = 'X_spectral', 105 | target_num_cells: int | None = None, 106 | min_cluster_size: int = 50, 107 | random_state: int = 0, 108 | key_added: str = 'pseudo_cell', 109 | inplace: bool = True, 110 | ) -> np.ndarray | None: 111 | """Aggregate cells into pseudo-cells. 112 | 113 | Aggregate cells into pseudo-cells by iterative clustering. 114 | 115 | Parameters 116 | ---------- 117 | adata 118 | AnnData or AnnDataSet object or matrix. 119 | use_rep 120 | `adata.obs` key for retrieving the input matrix. 121 | target_num_cells 122 | If None, `target_num_cells = num_cells / min_cluster_size`. 123 | min_cluster_size 124 | The minimum size of clusters. 125 | random_state 126 | Change the initialization of the optimization. 127 | key_added 128 | `adata.obs` key under which to add the cluster labels. 129 | inplace 130 | Whether to store the result in the anndata object. 131 | 132 | Returns 133 | ------- 134 | np.ndarray | None 135 | If `inplace=False`, return the result as a numpy array. 136 | Otherwise, store the result in `adata.obs[`key_added`]`. 
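    Example (illustrative; assumes a spectral embedding has already been stored
    in ``adata.obsm['X_spectral']``)::

        >>> aggregate_cells(adata, use_rep='X_spectral', min_cluster_size=50)
        >>> adata.obs['pseudo_cell'].value_counts().head()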
137 | """ 138 | def clustering(data): 139 | return leiden(knn(data), resolution=1, objective_function='modularity', 140 | min_cluster_size=min_cluster_size, random_state=random_state) 141 | 142 | if is_anndata(adata): 143 | X = adata.obsm[use_rep] 144 | else: 145 | inplace = False 146 | X = adata 147 | 148 | if target_num_cells is None: 149 | target_num_cells = X.shape[0] // min_cluster_size 150 | 151 | logging.info("Perform initial clustering ...") 152 | membership = clustering(X).astype('object') 153 | cluster_ids = [x for x in np.unique(membership) if x != "-1"] 154 | ids_next = cluster_ids 155 | n_clusters = len(cluster_ids) 156 | depth = 0 157 | while n_clusters < target_num_cells and len(ids_next) > 0: 158 | depth += 1 159 | logging.info("Iterative clustering: {}, number of clusters: {}".format(depth, n_clusters)) 160 | ids = set() 161 | for cid in ids_next: 162 | mask = membership == cid 163 | sub_clusters = clustering(X[mask, :]) 164 | n_sub_clusters = np.count_nonzero(np.unique(sub_clusters) != "-1") 165 | if n_sub_clusters > 1 and np.count_nonzero(sub_clusters != "-1") / sub_clusters.shape[0] > 0.9: 166 | n_clusters += n_sub_clusters - 1 167 | for i, i_ in enumerate(np.where(mask)[0]): 168 | lab = sub_clusters[i] 169 | if lab == "-1": 170 | membership[i_] = lab 171 | else: 172 | new_lab = membership[i_] + "." + lab 173 | membership[i_] = new_lab 174 | ids.add(new_lab) 175 | if n_clusters >= target_num_cells: 176 | break 177 | ids_next = ids 178 | logging.info("Asked for {} pseudo-cells; Got: {}.".format(target_num_cells, n_clusters)) 179 | 180 | if inplace: 181 | import polars 182 | adata.obs[key_added] = polars.Series( 183 | [str(x) for x in membership], 184 | dtype=polars.datatypes.Categorical, 185 | ) 186 | else: 187 | return membership 188 | 189 | def marker_enrichment( 190 | gene_matrix: AnnData, 191 | groupby: str | list[str], 192 | markers: dict[str, list[str]], 193 | min_num_markers: int = 1, 194 | hierarchical: bool = True, 195 | ): 196 | """ 197 | Parameters 198 | ---------- 199 | gene_matrix 200 | The cell by gene activity matrix. 201 | groupby 202 | Group the cells into different groups. If a `str`, groups are obtained from 203 | `.obs[groupby]`. 
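    markers
        Dictionary mapping cell-type names to lists of marker gene symbols.
    min_num_markers
        Cell types with fewer than this many markers found in `gene_matrix`
        are skipped (with a warning).
    hierarchical
        If True, marker sets are clustered by Jaccard distance and scores are
        reported along the resulting tree; otherwise a flat cell type by group
        table of mean z-scores is returned.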
204 | """ 205 | from scipy.stats import zscore 206 | import polars as pl 207 | 208 | gene_names = dict((x.upper(), i) for i, x in enumerate(gene_matrix.var_names)) 209 | retained = [] 210 | removed = [] 211 | for key in markers.keys(): 212 | genes = [] 213 | for name in markers[key]: 214 | name = name.upper() 215 | if name in gene_names: 216 | genes.append(gene_names[name]) 217 | if len(genes) >= min_num_markers: 218 | retained.append((key, genes)) 219 | else: 220 | removed.append(key) 221 | if len(removed) > 0: 222 | logging.warn("The following cell types are not annotated because they have less than {} marker genes: {}", min_num_markers, removed) 223 | 224 | aggr_counts = aggregate_X(gene_matrix, groupby=groupby, normalize="RPM") 225 | zscores = zscore( 226 | np.log2(np.vstack(list(aggr_counts.values())) + 1), 227 | axis = 0, 228 | ) 229 | 230 | if hierarchical: 231 | return _hierarchical_enrichment(dict(retained), zscores) 232 | else: 233 | df = pl.DataFrame( 234 | np.vstack([zscores[:, genes].mean(axis = 1) for _, genes in retained]), 235 | columns = list(aggr_counts.keys()), 236 | ) 237 | df.insert_at_idx(0, pl.Series("Cell type", [cell_type for cell_type, _ in retained])) 238 | return df 239 | 240 | def _hierarchical_enrichment( 241 | marker_genes, 242 | zscores, 243 | ): 244 | from scipy.cluster.hierarchy import linkage, to_tree 245 | from collections import Counter 246 | 247 | def jaccard_distances(x): 248 | def jaccard(a, b): 249 | a = set(a) 250 | b = set(b) 251 | return 1 - len(a.intersection(b)) / len(a.union(b)) 252 | 253 | result = [] 254 | n = len(x) 255 | for i in range(n): 256 | for j in range(i+1, n): 257 | result.append(jaccard(x[i], x[j])) 258 | return result 259 | 260 | def make_tree(Z, genes, labels): 261 | def get_genes_weighted(node, node2 = None): 262 | leaves = node.pre_order(lambda x: x.id) 263 | if node2 is not None: 264 | leaves = leaves + node2.pre_order(lambda x: x.id) 265 | n = len(leaves) 266 | count = Counter(g for i in leaves for g in genes[i]) 267 | for key in count.keys(): 268 | count[key] /= n 269 | return count 270 | 271 | def normalize_weights(a, b): 272 | a_ = [] 273 | for k, v in a.items(): 274 | if k in b: 275 | v = v - b[k] 276 | if v > 0: 277 | a_.append((k, v)) 278 | return a_ 279 | 280 | def process(pid, x, score): 281 | scores.append(score) 282 | parents.append(pid) 283 | ids.append(x.id) 284 | if x.id < len(labels): 285 | labels_.append(labels[x.id]) 286 | else: 287 | labels_.append("") 288 | go(x) 289 | 290 | def norm(b, x): 291 | return np.sqrt(np.exp(b) * np.exp(x)) 292 | 293 | def go(tr): 294 | def sc_fn(gene_w): 295 | if len(gene_w) > 0: 296 | idx, ws = zip(*gene_w) 297 | return np.average(zscores[:, list(idx)], axis = 1, weights=list(ws)) 298 | else: 299 | return np.zeros(zscores.shape[0]) 300 | 301 | left = tr.left 302 | right = tr.right 303 | if left is not None and right is not None: 304 | genes_left = get_genes_weighted(left) 305 | genes_right = get_genes_weighted(right) 306 | base = sc_fn(list(get_genes_weighted(left, right).items())) 307 | sc_left = sc_fn(normalize_weights(genes_left, genes_right)) 308 | sc_right = sc_fn(normalize_weights(genes_right, genes_left)) 309 | process(tr.id, left, norm(base, sc_left)) 310 | process(tr.id, right, norm(base, sc_right)) 311 | 312 | root = to_tree(Z) 313 | ids = [root.id] 314 | parents = [""] 315 | labels_ = [""] 316 | scores = [np.zeros(zscores.shape[0])] 317 | go(root) 318 | return (ids, parents, labels_, np.vstack(scores).T) 319 | 320 | jm = jaccard_distances([v for v in 
marker_genes.values()]) 321 | Z = linkage(jm, method='average') 322 | return make_tree( 323 | Z, list(marker_genes.values()), list(marker_genes.keys()), 324 | ) 325 | 326 | 327 | def _groupby(x, groups): 328 | idx = groups.argsort() 329 | groups = groups[idx] 330 | x = x[idx] 331 | u, indices = np.unique(groups, return_index=True) 332 | splits = np.split(np.arange(x.shape[0]), indices[1:]) 333 | return dict((label, x[indices, :]) for (label, indices) in zip(u, splits)) 334 | 335 | def _normalize(x, size_factor = None): 336 | result = x / (x.sum() / 1000000.0) 337 | if size_factor is not None: 338 | result /= size_factor 339 | return result 340 | 341 | def _get_sizes(regions): 342 | def size(x): 343 | x = x.split(':')[1].split("-") 344 | return int(x[1]) - int(x[0]) 345 | return np.array(list(size(x) for x in regions)) -------------------------------------------------------------------------------- /scalex/atac/snapatac2/_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | from anndata import AnnData 5 | 6 | def is_anndata(data) -> bool: 7 | return isinstance(data, AnnData) 8 | 9 | def anndata_par(adatas, func, n_jobs=4): 10 | return anndata_ipar(list(enumerate(adatas)), lambda x: func(x[1]), n_jobs=n_jobs) 11 | 12 | def anndata_ipar(inputs, func, n_jobs=4): 13 | from tqdm import tqdm 14 | 15 | exist_in_memory_adata = False 16 | for _, adata in inputs: 17 | if isinstance(adata, AnnData): 18 | exist_in_memory_adata = True 19 | break 20 | if exist_in_memory_adata: 21 | logging.warn(("Input contains in-memory AnnData objects. " 22 | "Multiprocessing will not be used. " 23 | "To enable multiprocessing, use backed AnnData objects")) 24 | return [func((i, adata)) for i, adata in tqdm(inputs)] 25 | else: 26 | from multiprocess import get_context 27 | 28 | def _func(x): 29 | adata = internal.read(x[1], backend=x[2]) 30 | result = func((x[0], adata)) 31 | adata.close() 32 | return result 33 | 34 | # Close the AnnData objects and return the filenames 35 | files = [] 36 | for i, adata in inputs: 37 | files.append((i, adata.filename, adata.backend)) 38 | adata.close() 39 | 40 | with get_context("spawn").Pool(n_jobs) as p: 41 | result = list(tqdm(p.imap(_func, files), total=len(files))) 42 | 43 | # Reopen the files if they were closed 44 | for _, adata in inputs: 45 | adata.open(mode='r+') 46 | 47 | return result 48 | 49 | def get_file_format(suffix): 50 | suffix = suffix.lower() 51 | _suffix = suffix 52 | 53 | if suffix.endswith(".gz"): 54 | compression = "gzip" 55 | _suffix = suffix[:-3] 56 | elif suffix.endswith(".zst"): 57 | compression = "zstandard" 58 | _suffix = suffix[:-4] 59 | else: 60 | compression = None 61 | 62 | if suffix.endswith(".bw") or suffix.endswith(".bigwig"): 63 | format = "bigwig" 64 | elif _suffix.endswith(".bedgraph") or _suffix.endswith(".bg") or _suffix.endswith(".bdg"): 65 | format = "bedgraph" 66 | else: 67 | format = None 68 | 69 | return format, compression 70 | 71 | def get_igraph_from_adjacency(adj): 72 | """Get igraph graph from adjacency matrix.""" 73 | import igraph as ig 74 | vcount = max(adj.shape) 75 | sources, targets = adj.nonzero() 76 | edgelist = list(zip(list(sources), list(targets))) 77 | weights = np.ravel(adj[(sources, targets)]) 78 | gr = ig.Graph(n=vcount, edges=edgelist, directed=False, edge_attrs={"weight": weights}) 79 | return gr 80 | 81 | def chunks(mat, chunk_size: int): 82 | """ 83 | Return chunks of the input matrix 84 | """ 85 | n = mat.shape[0] 86 | for i in 
range(0, n, chunk_size): 87 | j = max(i + chunk_size, n) 88 | yield mat[i:j, :] 89 | 90 | def find_elbow(x, saturation=0.01): 91 | accum_gap = 0 92 | for i in range(1, len(x)): 93 | gap = x[i-1] - x[i] 94 | accum_gap = accum_gap + gap 95 | if gap < saturation * accum_gap: 96 | return i 97 | return None 98 | 99 | def fetch_seq(fasta, region): 100 | chr, x = region.split(':') 101 | start, end = x.split('-') 102 | start = int(start) 103 | end = int(end) 104 | seq = fasta[chr][start:end].seq 105 | l1 = len(seq) 106 | l2 = end - start 107 | if l1 != l2: 108 | raise NameError( 109 | "sequence fetch error: expected length: {}, but got {}.".format(l2, l1) 110 | ) 111 | else: 112 | return seq 113 | 114 | def pcorr(A, B): 115 | """Compute pairwsie correlation between two matrices. 116 | 117 | A 118 | n_sample x n_feature 119 | B 120 | n_sample x n_feature 121 | """ 122 | N = B.shape[0] 123 | 124 | sA = A.sum(0) 125 | sB = B.sum(0) 126 | 127 | p1 = N * np.einsum('ij,ik->kj', A, B) 128 | p2 = sA * sB[:,None] 129 | p3 = N * ((B**2).sum(0)) - (sB**2) 130 | p4 = N * ((A**2).sum(0)) - (sA**2) 131 | 132 | return (p1 - p2) / np.sqrt(p4 * p3[:,None]) -------------------------------------------------------------------------------- /scalex/linkage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/scalex/linkage/__init__.py -------------------------------------------------------------------------------- /scalex/linkage/linkage.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | 7 | import matplotlib 8 | matplotlib.rcParams['pdf.fonttype'] = 42 9 | 10 | from scalex.atac.snapatac2._diff import aggregate_X 11 | from scalex.data import aggregate_data 12 | 13 | from scalex.atac.bedtools import intersect_bed, bed_to_df, df_to_bed 14 | from scalex.pp.annotation import format_rna, format_atac 15 | from scalex.analysis import get_markers, flatten_dict 16 | # from scalex.linkage.utils import row_wise_correlation 17 | 18 | class Linkage: 19 | def __init__(self, rna, atac, groupby='cell_type', gtf=None, up=100_000, down=100_000): 20 | self.rna = rna 21 | self.atac = atac 22 | self.gtf = gtf if gtf is not None else os.path.expanduser('~/.scalex/gencode.v38.annotation.gtf.gz') 23 | self.up = up 24 | self.down = down 25 | 26 | self.format_data() 27 | self.aggregate_data(groupby=groupby) 28 | self.get_edges(up=up, down=down) 29 | self.get_peak_gene_corr() 30 | 31 | 32 | def format_data(self): 33 | self.rna = format_rna(self.rna.copy(), gtf=self.gtf) 34 | self.atac = format_atac(self.atac.copy()) 35 | self.rna_var = self.rna.var 36 | self.atac_var = self.atac.var 37 | 38 | def aggregate_data(self, groupby="cell_type"): 39 | self.rna_agg = aggregate_data(self.rna, groupby=groupby) 40 | self.atac_agg = aggregate_X(self.atac, groupby=groupby, normalize="RPKM") 41 | 42 | common = set(self.rna_agg.obs[groupby].values) & set(self.atac_agg.obs[groupby].values) 43 | self.rna_agg = self.rna_agg[self.rna_agg.obs[groupby].isin(common)] 44 | self.atac_agg = self.atac_agg[self.atac_agg.obs[groupby].isin(common)] 45 | print(common) 46 | 47 | def get_edges(self, up=100_000, down=100_000): 48 | # assert self.rna_agg.obs[self.groupby] == self.atac_agg.obs[self.groupby], "groupby should have the same order" 49 | ## peak to gene linkage 50 | self.edges = 
intersect_bed(self.rna_var, self.atac_var, up=up, down=down, add_distance=True) 51 | self.gene_to_index = {gene: i for i, gene in enumerate(self.rna_var.index)} 52 | self.peak_to_index = {peak: i for i, peak in enumerate(self.atac_var.index)} 53 | 54 | def get_peak_gene_corr(self): 55 | peak_vec = self.atac_agg.to_df().iloc[:, self.edges[1]].T.values 56 | genes_vec = self.rna_agg.to_df().iloc[:, self.edges[0]].T.values 57 | 58 | self.peak_gene_corr = row_wise_correlation(genes_vec, peak_vec) 59 | 60 | def filter_peak_gene_linkage(self, threshold=0.7): 61 | indices = np.where(self.peak_gene_corr > threshold)[0] 62 | gene_indices = self.edges[0][indices] 63 | peak_indices = self.edges[1][indices] 64 | 65 | genes = self.rna_var.iloc[gene_indices].index.values 66 | peaks = self.atac_var.iloc[peak_indices].index.values 67 | 68 | return genes, peaks 69 | 70 | def find_overlapping_peak_gene_linkage(self, genes_deg, peaks_deg, threshold=0.7): 71 | genes, peaks = self.filter_peak_gene_linkage(threshold=threshold) 72 | if isinstance(genes_deg, dict): 73 | genes_deg = flatten_dict(genes_deg) 74 | if isinstance(peaks_deg, dict): 75 | peaks_deg = flatten_dict(peaks_deg) 76 | 77 | gene_peak_pair = [(genes[i], peaks[i]) for i, _ in enumerate(genes) if genes[i] in genes_deg and peaks[i] in peaks_deg] 78 | 79 | return gene_peak_pair 80 | 81 | def row_wise_correlation(arr1, arr2, epsilon=1e-8): 82 | """ 83 | Calculates the Pearson correlation coefficient between corresponding rows of two NumPy arrays, 84 | with robustness to small values and division by zero. 85 | 86 | Parameters: 87 | - arr1: NumPy array of shape (m, n) 88 | - arr2: NumPy array of shape (m, n) 89 | - epsilon: Small constant to avoid division by zero (default: 1e-8) 90 | 91 | Returns: 92 | - correlations: NumPy array of shape (m,) 93 | """ 94 | assert arr1.shape == arr2.shape, "Arrays must have the same shape." 
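    # Pearson r per row: z-score each row (guarding near-zero std with epsilon/NaN),
    # then r_i = nansum(z1_i * z2_i) / (n - 1).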
95 | 96 | # Compute means 97 | mean1 = np.mean(arr1, axis=1, keepdims=True) 98 | mean2 = np.mean(arr2, axis=1, keepdims=True) 99 | 100 | # Compute standard deviations 101 | std1 = np.std(arr1, axis=1, ddof=1, keepdims=True) 102 | std2 = np.std(arr2, axis=1, ddof=1, keepdims=True) 103 | 104 | # Avoid division by zero by adding epsilon 105 | safe_std1 = np.where(std1 < epsilon, np.nan, std1) 106 | safe_std2 = np.where(std2 < epsilon, np.nan, std2) 107 | 108 | # Standardize the data (z-scores) 109 | z1 = (arr1 - mean1) / safe_std1 110 | z2 = (arr2 - mean2) / safe_std2 111 | 112 | # Compute sum of products of z-scores 113 | sum_of_products = np.nansum(z1 * z2, axis=1) 114 | 115 | # Degrees of freedom 116 | n = arr1.shape[1] 117 | degrees_of_freedom = n - 1 118 | 119 | # Compute Pearson correlation coefficients 120 | correlations = sum_of_products / degrees_of_freedom 121 | 122 | return correlations -------------------------------------------------------------------------------- /scalex/linkage/motif/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/scalex/linkage/motif/__init__.py -------------------------------------------------------------------------------- /scalex/linkage/motif/logo_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('pdf') 3 | from matplotlib import pyplot as plt 4 | import logomaker 5 | from typing import List, Union 6 | import os 7 | import h5py 8 | import numpy as np 9 | import pandas as pd 10 | from matplotlib import rcParams 11 | rcParams['pdf.fonttype'] = 42 12 | 13 | 14 | def _plot_weights(array, path, figsize=(10,3)): 15 | """Plot weights as a sequence logo and save to file.""" 16 | 17 | if not os.path.isfile(path): 18 | fig = plt.figure(figsize=figsize) 19 | ax = fig.add_subplot(111) 20 | 21 | df = pd.DataFrame(array, columns=['A', 'C', 'G', 'T']) 22 | df.index.name = 'pos' 23 | 24 | crp_logo = logomaker.Logo(df, ax=ax) 25 | crp_logo.style_spines(visible=False) 26 | plt.ylim(min(df.sum(axis=1).min(), 0), df.sum(axis=1).max()) 27 | 28 | plt.savefig(path) 29 | plt.close() 30 | 31 | else: 32 | pass 33 | 34 | def create_modisco_logos(modisco_h5py: os.PathLike, modisco_logo_dir, trim_threshold, pattern_groups: List[str]): 35 | """Open a modisco results file and create and write logos to file for each pattern.""" 36 | modisco_results = h5py.File(modisco_h5py, 'r') 37 | 38 | tags = [] 39 | 40 | for name in pattern_groups: 41 | if name not in modisco_results.keys(): 42 | continue 43 | 44 | metacluster = modisco_results[name] 45 | # print(metacluster) 46 | key = lambda x: int(x[0].split("_")[-1]) 47 | # key = lambda x: int(x[0].split("#")[-1]) 48 | for pattern_name, pattern in sorted(metacluster.items(), key=key): 49 | tag = pattern_name 50 | tags.append(tag) 51 | 52 | cwm_fwd = np.array(pattern['contrib_scores'][:]) 53 | cwm_rev = cwm_fwd[::-1, ::-1] 54 | 55 | score_fwd = np.sum(np.abs(cwm_fwd), axis=1) 56 | score_rev = np.sum(np.abs(cwm_rev), axis=1) 57 | 58 | trim_thresh_fwd = np.max(score_fwd) * trim_threshold 59 | trim_thresh_rev = np.max(score_rev) * trim_threshold 60 | 61 | pass_inds_fwd = np.where(score_fwd >= trim_thresh_fwd)[0] 62 | pass_inds_rev = np.where(score_rev >= trim_thresh_rev)[0] 63 | 64 | start_fwd, end_fwd = max(np.min(pass_inds_fwd) - 4, 0), min(np.max(pass_inds_fwd) + 4 + 1, len(score_fwd) + 1) 65 | start_rev, end_rev = 
max(np.min(pass_inds_rev) - 4, 0), min(np.max(pass_inds_rev) + 4 + 1, len(score_rev) + 1) 66 | 67 | trimmed_cwm_fwd = cwm_fwd[start_fwd:end_fwd] 68 | trimmed_cwm_rev = cwm_rev[start_rev:end_rev] 69 | 70 | _plot_weights(trimmed_cwm_fwd, path='{}/{}.cwm.fwd.png'.format(modisco_logo_dir, tag)) 71 | _plot_weights(trimmed_cwm_rev, path='{}/{}.cwm.rev.png'.format(modisco_logo_dir, tag)) 72 | 73 | modisco_results.close() 74 | return tags 75 | 76 | def create_selin_logos(modisco_h5py: os.PathLike, modisco_logo_dir, trim_threshold, pattern_groups: List[str]): 77 | """Open a modisco results file and create and write logos to file for each pattern.""" 78 | modisco_results = h5py.File(modisco_h5py, 'r') 79 | 80 | tags = [] 81 | 82 | for name in pattern_groups: 83 | if name not in modisco_results.keys(): 84 | continue 85 | 86 | metacluster = modisco_results[name] 87 | for pattern_name, pattern in metacluster.items(): 88 | tag = pattern_name.replace('/', '-').replace("#", "-") 89 | tags.append(tag) 90 | 91 | cwm_fwd = np.array(pattern['contrib_scores'][:]) 92 | cwm_rev = cwm_fwd[::-1, ::-1] 93 | 94 | score_fwd = np.sum(np.abs(cwm_fwd), axis=1) 95 | score_rev = np.sum(np.abs(cwm_rev), axis=1) 96 | 97 | trim_thresh_fwd = np.max(score_fwd) * trim_threshold 98 | trim_thresh_rev = np.max(score_rev) * trim_threshold 99 | 100 | pass_inds_fwd = np.where(score_fwd >= trim_thresh_fwd)[0] 101 | pass_inds_rev = np.where(score_rev >= trim_thresh_rev)[0] 102 | 103 | start_fwd, end_fwd = max(np.min(pass_inds_fwd) - 4, 0), min(np.max(pass_inds_fwd) + 4 + 1, len(score_fwd) + 1) 104 | start_rev, end_rev = max(np.min(pass_inds_rev) - 4, 0), min(np.max(pass_inds_rev) + 4 + 1, len(score_rev) + 1) 105 | 106 | trimmed_cwm_fwd = cwm_fwd[start_fwd:end_fwd] 107 | trimmed_cwm_rev = cwm_rev[start_rev:end_rev] 108 | 109 | _plot_weights(trimmed_cwm_fwd, path='{}/{}.cwm.fwd.png'.format(modisco_logo_dir, tag)) 110 | _plot_weights(trimmed_cwm_rev, path='{}/{}.cwm.rev.png'.format(modisco_logo_dir, tag)) 111 | 112 | modisco_results.close() 113 | return tags 114 | 115 | def read_meme(filename): 116 | motifs = {} 117 | 118 | with open(filename, "r") as infile: 119 | motif, width, i = None, None, 0 120 | 121 | for line in infile: 122 | if motif is None: 123 | if line[:5] == 'MOTIF': 124 | motif = line.split()[1] 125 | else: 126 | continue 127 | 128 | elif width is None: 129 | if line[:6] == 'letter': 130 | width = int(line.split()[5]) 131 | pwm = np.zeros((width, 4)) 132 | 133 | elif i < width: 134 | pwm[i] = list(map(float, line.split())) 135 | i += 1 136 | 137 | else: 138 | motifs[motif] = pwm 139 | motif, width, i = None, None, 0 140 | 141 | return motifs 142 | 143 | def compute_per_position_ic(ppm, background, pseudocount): 144 | alphabet_len = len(background) 145 | ic = ((np.log((ppm+pseudocount)/(1 + pseudocount*alphabet_len))/np.log(2)) 146 | *ppm - (np.log(background)*background/np.log(2))[None,:]) 147 | return np.sum(ic,axis=1) 148 | 149 | def make_logo(match, logo_dir, motifs): 150 | if match == 'NA': 151 | return 152 | 153 | background = np.array([0.25, 0.25, 0.25, 0.25]) 154 | ppm = motifs[match] 155 | ic = compute_per_position_ic(ppm, background, 0.001) 156 | 157 | _plot_weights(ppm*ic[:, None], path='{}/{}.png'.format(logo_dir, match)) 158 | 159 | def path_to_image_link(path): 160 | return '=IMAGE("' + path + '#"&RANDBETWEEN(1111111,9999999), 4, 80, 240)' -------------------------------------------------------------------------------- /scalex/linkage/utils.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def correlation_between_df(df1, df2): 6 | df1_standardized = (df1 - df1.mean(axis=1).values.reshape(-1, 1)) / df1.std(axis=1).values.reshape(-1, 1) 7 | df2_standardized = (df2 - df2.mean(axis=1).values.reshape(-1, 1)) / df2.std(axis=1).values.reshape(-1, 1) 8 | 9 | # Calculate the correlation matrix 10 | correlation_matrix = np.dot(df2_standardized.values, df1_standardized.values.T) / df1.shape[1] 11 | 12 | # Convert the result to a DataFrame 13 | correlation_df = pd.DataFrame(correlation_matrix, index=df2.index, columns=df1.index) 14 | return correlation_df 15 | 16 | 17 | def row_wise_correlation(arr1, arr2, epsilon=1e-8): 18 | """ 19 | Calculates the Pearson correlation coefficient between corresponding rows of two NumPy arrays, 20 | with robustness to small values and division by zero. 21 | 22 | Parameters: 23 | - arr1: NumPy array of shape (m, n) 24 | - arr2: NumPy array of shape (m, n) 25 | - epsilon: Small constant to avoid division by zero (default: 1e-8) 26 | 27 | Returns: 28 | - correlations: NumPy array of shape (m,) 29 | """ 30 | assert arr1.shape == arr2.shape, "Arrays must have the same shape." 31 | 32 | # Compute means 33 | mean1 = np.mean(arr1, axis=1, keepdims=True) 34 | mean2 = np.mean(arr2, axis=1, keepdims=True) 35 | 36 | # Compute standard deviations 37 | std1 = np.std(arr1, axis=1, ddof=1, keepdims=True) 38 | std2 = np.std(arr2, axis=1, ddof=1, keepdims=True) 39 | 40 | # Avoid division by zero by adding epsilon 41 | safe_std1 = np.where(std1 < epsilon, np.nan, std1) 42 | safe_std2 = np.where(std2 < epsilon, np.nan, std2) 43 | 44 | # Standardize the data (z-scores) 45 | z1 = (arr1 - mean1) / safe_std1 46 | z2 = (arr2 - mean2) / safe_std2 47 | 48 | # Compute sum of products of z-scores 49 | sum_of_products = np.nansum(z1 * z2, axis=1) 50 | 51 | # Degrees of freedom 52 | n = arr1.shape[1] 53 | degrees_of_freedom = n - 1 54 | 55 | # Compute Pearson correlation coefficients 56 | correlations = sum_of_products / degrees_of_freedom 57 | 58 | return correlations 59 | 60 | -------------------------------------------------------------------------------- /scalex/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Xiong Lei 4 | # Created Time : Tue 20 Aug 2019 09:23:19 PM CST 5 | 6 | # File Name: logger.py 7 | # Description: 8 | 9 | """ 10 | 11 | import logging 12 | 13 | def create_logger(name='', ch=True, fh=False, levelname=logging.INFO, overwrite=False): 14 | logger = logging.getLogger(name) 15 | logger.setLevel(levelname) 16 | 17 | if overwrite: 18 | for h in logger.handlers: 19 | logger.removeHandler(h) 20 | 21 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 22 | # handler 23 | if ch: 24 | ch = logging.StreamHandler() 25 | ch.setLevel(logging.INFO) 26 | ch.setFormatter(formatter) 27 | logger.addHandler(ch) 28 | if fh: 29 | fh = logging.FileHandler(fh, mode='w') 30 | fh.setLevel(logging.DEBUG) 31 | fh.setFormatter(formatter) 32 | logger.addHandler(fh) 33 | return logger 34 | 35 | -------------------------------------------------------------------------------- /scalex/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Xiong Lei 4 | # Created Time : Thu 10 Jan 2019 07:38:10 PM CST 5 | 6 | # File Name: metrics.py 7 | # 
Description: 8 | 9 | """ 10 | 11 | import numpy as np 12 | import scipy 13 | from sklearn.neighbors import NearestNeighbors, KNeighborsRegressor 14 | 15 | 16 | def batch_entropy_mixing_score(data, batches, n_neighbors=100, n_pools=100, n_samples_per_pool=100): 17 | """ 18 | Calculate batch entropy mixing score 19 | 20 | Algorithm 21 | ----- 22 | * 1. Calculate the regional mixing entropies at the location of 100 randomly chosen cells from all batches 23 | * 2. Define 100 nearest neighbors for each randomly chosen cell 24 | * 3. Calculate the mean mixing entropy as the mean of the regional entropies 25 | * 4. Repeat above procedure for 100 iterations with different randomly chosen cells. 26 | 27 | Parameters 28 | ---------- 29 | data 30 | np.array of shape nsamples x nfeatures. 31 | batches 32 | batch labels of nsamples. 33 | n_neighbors 34 | The number of nearest neighbors for each randomly chosen cell. By default, n_neighbors=100. 35 | n_samples_per_pool 36 | The number of randomly chosen cells from all batches per iteration. By default, n_samples_per_pool=100. 37 | n_pools 38 | The number of iterations with different randomly chosen cells. By default, n_pools=100. 39 | 40 | Returns 41 | ------- 42 | Batch entropy mixing score 43 | """ 44 | # print("Start calculating Entropy mixing score") 45 | def entropy(batches): 46 | p = np.zeros(N_batches) 47 | adapt_p = np.zeros(N_batches) 48 | a = 0 49 | for i in range(N_batches): 50 | p[i] = np.mean(batches == batches_[i]) 51 | a = a + p[i]/P[i] 52 | entropy = 0 53 | for i in range(N_batches): 54 | adapt_p[i] = (p[i]/P[i])/a 55 | entropy = entropy - adapt_p[i]*np.log(adapt_p[i]+10**-8) 56 | return entropy 57 | 58 | n_neighbors = min(n_neighbors, len(data) - 1) 59 | nne = NearestNeighbors(n_neighbors=1 + n_neighbors, n_jobs=8) 60 | nne.fit(data) 61 | kmatrix = nne.kneighbors_graph(data) - scipy.sparse.identity(data.shape[0]) 62 | 63 | score = 0 64 | batches_ = np.unique(batches) 65 | N_batches = len(batches_) 66 | if N_batches < 2: 67 | raise ValueError("Should be more than one cluster for batch mixing") 68 | P = np.zeros(N_batches) 69 | for i in range(N_batches): 70 | P[i] = np.mean(batches == batches_[i]) 71 | for t in range(n_pools): 72 | indices = np.random.choice(np.arange(data.shape[0]), size=n_samples_per_pool) 73 | score += np.mean([entropy(batches[kmatrix[indices].nonzero()[1] 74 | [kmatrix[indices].nonzero()[0] == i]]) 75 | for i in range(n_samples_per_pool)]) 76 | Score = score / float(n_pools) 77 | return Score / float(np.log2(N_batches)) 78 | 79 | 80 | from sklearn.metrics import silhouette_score -------------------------------------------------------------------------------- /scalex/net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/scalex/net/__init__.py -------------------------------------------------------------------------------- /scalex/net/layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Xiong Lei 4 | # Created Time : Mon 19 Aug 2019 02:25:11 PM CST 5 | 6 | # File Name: layer.py 7 | # Description: 8 | 9 | """ 10 | import math 11 | import numpy as np 12 | 13 | import torch 14 | from torch import nn as nn 15 | import torch.nn.functional as F 16 | from torch.distributions import Normal 17 | from torch.nn.parameter import Parameter 18 | from torch.nn import init 19 | from torch.autograd import Function 20 | 21 
| 22 | activation = { 23 | 'relu':nn.ReLU(), 24 | 'rrelu':nn.RReLU(), 25 | 'sigmoid':nn.Sigmoid(), 26 | 'leaky_relu':nn.LeakyReLU(), 27 | 'tanh':nn.Tanh(), 28 | '':None 29 | } 30 | 31 | 32 | class DSBatchNorm(nn.Module): 33 | """ 34 | Domain-specific Batch Normalization 35 | """ 36 | def __init__(self, num_features, n_domain, eps=1e-5, momentum=0.1): 37 | """ 38 | Parameters 39 | ---------- 40 | num_features 41 | dimension of the features 42 | n_domain 43 | domain number 44 | """ 45 | super().__init__() 46 | self.n_domain = n_domain 47 | self.num_features = num_features 48 | self.bns = nn.ModuleList([nn.BatchNorm1d(num_features, eps=eps, momentum=momentum) for i in range(n_domain)]) 49 | 50 | def reset_running_stats(self): 51 | for bn in self.bns: 52 | bn.reset_running_stats() 53 | 54 | def reset_parameters(self): 55 | for bn in self.bns: 56 | bn.reset_parameters() 57 | 58 | def _check_input_dim(self, input): 59 | raise NotImplementedError 60 | 61 | def forward(self, x, y): 62 | out = torch.zeros(x.size(0), self.num_features, device=x.device) #, requires_grad=False) 63 | for i in range(self.n_domain): 64 | indices = np.where(y.cpu().numpy()==i)[0] 65 | 66 | if len(indices) > 1: 67 | out[indices] = self.bns[i](x[indices]) 68 | elif len(indices) == 1: 69 | out[indices] = x[indices] 70 | # self.bns[i].training = False 71 | # out[indices] = self.bns[i](x[indices]) 72 | # self.bns[i].training = True 73 | return out 74 | 75 | 76 | class Block(nn.Module): 77 | """ 78 | Basic block consist of: 79 | fc -> bn -> act -> dropout 80 | """ 81 | def __init__( 82 | self, 83 | input_dim, 84 | output_dim, 85 | norm='', 86 | act='', 87 | dropout=0 88 | ): 89 | """ 90 | Parameters 91 | ---------- 92 | input_dim 93 | dimension of input 94 | output_dim 95 | dimension of output 96 | norm 97 | batch normalization, 98 | * '' represent no batch normalization 99 | * 1 represent regular batch normalization 100 | * int>1 represent domain-specific batch normalization of n domain 101 | act 102 | activation function, 103 | * relu -> nn.ReLU 104 | * rrelu -> nn.RReLU 105 | * sigmoid -> nn.Sigmoid() 106 | * leaky_relu -> nn.LeakyReLU() 107 | * tanh -> nn.Tanh() 108 | * '' -> None 109 | dropout 110 | dropout rate 111 | """ 112 | super().__init__() 113 | self.fc = nn.Linear(input_dim, output_dim) 114 | 115 | if type(norm) == int: 116 | if norm==1: # TO DO 117 | self.norm = nn.BatchNorm1d(output_dim) 118 | else: 119 | self.norm = DSBatchNorm(output_dim, norm) 120 | else: 121 | self.norm = None 122 | 123 | self.act = activation[act] 124 | 125 | if dropout >0: 126 | self.dropout = nn.Dropout(dropout) 127 | else: 128 | self.dropout = None 129 | 130 | def forward(self, x, y=None): 131 | h = self.fc(x) 132 | if self.norm: 133 | if len(x) == 1: 134 | pass 135 | elif self.norm.__class__.__name__ == 'DSBatchNorm': 136 | h = self.norm(h, y) 137 | else: 138 | h = self.norm(h) 139 | if self.act: 140 | h = self.act(h) 141 | if self.dropout: 142 | h = self.dropout(h) 143 | return h 144 | 145 | 146 | 147 | class NN(nn.Module): 148 | """ 149 | Neural network consist of multi Blocks 150 | """ 151 | def __init__(self, input_dim, cfg): 152 | """ 153 | Parameters 154 | ---------- 155 | input_dim 156 | input dimension 157 | cfg 158 | model structure configuration, 'fc' -> fully connected layer 159 | 160 | Example 161 | ------- 162 | >>> latent_dim = 10 163 | >>> dec_cfg = [['fc', x_dim, n_domain, 'sigmoid']] 164 | >>> decoder = NN(latent_dim, dec_cfg) 165 | """ 166 | super().__init__() 167 | net = [] 168 | for i, layer in enumerate(cfg): 169 | 
if i==0: 170 | d_in = input_dim 171 | if layer[0] == 'fc': 172 | net.append(Block(d_in, *layer[1:])) 173 | d_in = layer[1] 174 | self.net = nn.ModuleList(net) 175 | 176 | def forward(self, x, y=None): 177 | for layer in self.net: 178 | x = layer(x, y) 179 | return x 180 | 181 | 182 | class Encoder(nn.Module): 183 | """ 184 | VAE Encoder 185 | """ 186 | def __init__(self, input_dim, cfg): 187 | """ 188 | Parameters 189 | ---------- 190 | input_dim 191 | input dimension 192 | cfg 193 | encoder configuration, e.g. enc_cfg = [['fc', 1024, 1, 'relu'],['fc', 10, '', '']] 194 | """ 195 | super().__init__() 196 | h_dim = cfg[-2][1] 197 | self.enc = NN(input_dim, cfg[:-1]) 198 | self.mu_enc = NN(h_dim, cfg[-1:]) 199 | self.var_enc = NN(h_dim, cfg[-1:]) 200 | 201 | def reparameterize(self, mu, var): 202 | return Normal(mu, var.sqrt()).rsample() 203 | 204 | def forward(self, x, y=None): 205 | """ 206 | """ 207 | q = self.enc(x, y) 208 | mu = self.mu_enc(q, y) 209 | var = torch.exp(self.var_enc(q, y)) 210 | z = self.reparameterize(mu, var) 211 | return z, mu, var 212 | 213 | 214 | -------------------------------------------------------------------------------- /scalex/net/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Xiong Lei 4 | # Created Time : Mon 21 Jan 2019 03:00:26 PM CST 5 | 6 | # File Name: loss.py 7 | # Description: 8 | 9 | """ 10 | 11 | import torch 12 | from torch.distributions import Normal, kl_divergence 13 | 14 | def kl_div(mu, var): 15 | return kl_divergence(Normal(mu, var.sqrt()), 16 | Normal(torch.zeros_like(mu),torch.ones_like(var))).sum(dim=1).mean() 17 | 18 | def binary_cross_entropy(recon_x, x): 19 | return -torch.sum(x * torch.log(recon_x + 1e-8) + (1 - x) * torch.log(1 - recon_x + 1e-8), dim=-1) 20 | 21 | -------------------------------------------------------------------------------- /scalex/net/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Xiong Lei 4 | # Created Time : Mon 18 Nov 2019 01:25:24 PM CST 5 | 6 | # File Name: utils.py 7 | # Description: 8 | 9 | """ 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def onehot(y, n): 16 | """ 17 | Make the input tensor one hot tensors 18 | 19 | Parameters 20 | ---------- 21 | y 22 | input tensors 23 | n 24 | number of classes 25 | 26 | Return 27 | ------ 28 | Tensor 29 | """ 30 | if (y is None) or (n<2): 31 | return None 32 | assert torch.max(y).item() < n 33 | y = y.view(y.size(0), 1) 34 | y_cat = torch.zeros(y.size(0), n).to(y.device) 35 | y_cat.scatter_(1, y.data, 1) 36 | return y_cat 37 | 38 | 39 | class EarlyStopping: 40 | """ 41 | Early stops the training if loss doesn't improve after a given patience. 42 | """ 43 | def __init__(self, patience=10, verbose=False, checkpoint_file=''): 44 | """ 45 | Parameters 46 | ---------- 47 | patience 48 | How long to wait after last time loss improved. Default: 10 49 | verbose 50 | If True, prints a message for each loss improvement. 
Default: False 51 | """ 52 | self.patience = patience 53 | self.verbose = verbose 54 | self.counter = 0 55 | self.best_score = None 56 | self.early_stop = False 57 | self.loss_min = np.Inf 58 | self.checkpoint_file = checkpoint_file 59 | 60 | def __call__(self, loss, model): 61 | if np.isnan(loss): 62 | self.early_stop = True 63 | score = -loss 64 | 65 | if self.best_score is None: 66 | self.best_score = score 67 | self.save_checkpoint(loss, model) 68 | elif score <= self.best_score: 69 | self.counter += 1 70 | if self.verbose: 71 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 72 | if self.counter >= self.patience: 73 | self.early_stop = True 74 | model.load_model(self.checkpoint_file) 75 | else: 76 | self.best_score = score 77 | self.save_checkpoint(loss, model) 78 | self.counter = 0 79 | 80 | def save_checkpoint(self, loss, model): 81 | ''' 82 | Saves model when loss decrease. 83 | ''' 84 | if self.verbose: 85 | print(f'Loss decreased ({self.loss_min:.6f} --> {loss:.6f}). Saving model ...') 86 | if self.checkpoint_file: 87 | torch.save(model.state_dict(), self.checkpoint_file) 88 | self.loss_min = loss 89 | 90 | -------------------------------------------------------------------------------- /scalex/net/vae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Xiong Lei 4 | # Created Time : Mon 18 Nov 2019 01:16:06 PM CST 5 | 6 | # File Name: vae.py 7 | # Description: 8 | 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import numpy as np 14 | from tqdm.autonotebook import trange 15 | from tqdm.contrib import tenumerate 16 | from collections import defaultdict 17 | 18 | from .layer import * 19 | from .loss import * 20 | 21 | 22 | class VAE(nn.Module): 23 | """ 24 | VAE framework 25 | """ 26 | def __init__(self, enc, dec, n_domain=1): 27 | """ 28 | Parameters 29 | ---------- 30 | enc 31 | Encoder structure config 32 | dec 33 | Decoder structure config 34 | n_domain 35 | The number of different domains 36 | """ 37 | super().__init__() 38 | x_dim = dec[-1][1] 39 | z_dim = enc[-1][1] 40 | self.encoder = Encoder(x_dim, enc) 41 | self.decoder = NN(z_dim, dec) 42 | self.n_domain = n_domain 43 | self.x_dim = x_dim 44 | self.z_dim = z_dim 45 | 46 | def load_model(self, path): 47 | """ 48 | Load trained model parameters dictionary. 49 | Parameters 50 | ---------- 51 | path 52 | file path that stores the model parameters 53 | """ 54 | pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage) 55 | model_dict = self.state_dict() 56 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 57 | model_dict.update(pretrained_dict) 58 | self.load_state_dict(model_dict) 59 | 60 | 61 | def forward(self, x, y): 62 | x, y = x.float(), y.long() 63 | 64 | # loss 65 | z, mu, var = self.encoder(x) 66 | recon_x = self.decoder(z, y) 67 | recon_loss = F.binary_cross_entropy(recon_x, x) * x.size(-1) ## TO DO 68 | kl_loss = kl_div(mu, var) 69 | 70 | # acc.append(pearson_corr_coef(recon_x, x)) 71 | loss = {'recon_loss':recon_loss, 'kl_loss':0.5*kl_loss} 72 | return z, recon_x, loss 73 | 74 | def encodeBatch( 75 | self, 76 | dataloader, 77 | device='cuda', 78 | out='latent', 79 | batch_id=None, 80 | return_idx=False, 81 | eval=False 82 | ): 83 | """ 84 | Inference 85 | 86 | Parameters 87 | ---------- 88 | dataloader 89 | An iterable over the given dataset for inference. 90 | device 91 | 'cuda' or 'cpu' for . Default: 'cuda'. 
92 | out 93 | The inference layer for output. If 'latent', output latent feature z. If 'impute', output imputed gene expression matrix. Default: 'latent'. 94 | batch_id 95 | If None, use batch 0 decoder to infer for all samples. Else, use the corresponding decoder according to the sample batch id to infer for each sample. 96 | return_idx 97 | Whether return the dataloader sample index. Default: False. 98 | eval 99 | If True, set the model to evaluation mode. If False, set the model to train mode. Default: False. 100 | 101 | Returns 102 | ------- 103 | Inference layer and sample index (if return_idx=True). 104 | """ 105 | self.to(device) 106 | if eval: 107 | self.eval();print('eval mode') 108 | else: 109 | self.train() 110 | indices = np.zeros(dataloader.dataset.shape[0]) 111 | if out == 'latent': 112 | output = np.zeros((dataloader.dataset.shape[0], self.z_dim)) 113 | 114 | for x,y,idx in dataloader: 115 | x = x.float().to(device) 116 | z = self.encoder(x)[1] # z, mu, var 117 | output[idx] = z.detach().cpu().numpy() 118 | indices[idx] = idx 119 | elif out == 'impute': 120 | output = np.zeros((dataloader.dataset.shape[0], self.x_dim)) 121 | 122 | if batch_id in dataloader.dataset.adata.obs['batch'].cat.categories: 123 | batch_id = list(dataloader.dataset.adata.obs['batch'].cat.categories).index(batch_id) 124 | else: 125 | batch_id = 0 126 | 127 | for x,y,idx in dataloader: 128 | x = x.float().to(device) 129 | z = self.encoder(x)[1] # z, mu, var 130 | output[idx] = self.decoder(z, torch.LongTensor([batch_id]*len(z))).detach().cpu().numpy() 131 | indices[idx] = idx 132 | 133 | if return_idx: 134 | return output, indices 135 | else: 136 | return output 137 | 138 | def fit( 139 | self, 140 | dataloader, 141 | lr=2e-4, 142 | max_iteration=30000, 143 | beta=0.5, 144 | early_stopping=None, 145 | device='cuda', 146 | verbose=False, 147 | ): 148 | """ 149 | Fit model 150 | 151 | Parameters 152 | ---------- 153 | dataloader 154 | An iterable over the given dataset for training. 155 | lr 156 | Learning rate. Default: 2e-4. 157 | max_iteration 158 | Max iterations for training. Training one batch_size samples is one iteration. Default: 30000. 159 | beta 160 | The co-efficient of KL-divergence when calculate loss. Default: 0.5. 161 | early_stopping 162 | EarlyStopping class (definite in utils.py) for stoping the training if loss doesn't improve after a given patience. Default: None. 163 | device 164 | 'cuda' or 'cpu' for training. Default: 'cuda'. 165 | verbose 166 | Verbosity, True or False. Default: False. 
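        Example
        -------
        Illustrative sketch only. It assumes `dataloader` yields `(x, y, idx)`
        batches (as in `encodeBatch` above) and that an `EarlyStopping` instance
        from `scalex.net.utils` is supplied, since `early_stopping` is called at
        the end of every epoch:

        >>> from scalex.net.utils import EarlyStopping
        >>> stopper = EarlyStopping(patience=10, checkpoint_file='model.pt')
        >>> model.fit(dataloader, lr=2e-4, early_stopping=stopper, device='cuda')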
167 | """ 168 | self.to(device) 169 | optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4) 170 | n_epoch = int(np.ceil(max_iteration/len(dataloader))) 171 | 172 | with trange(n_epoch, total=n_epoch, desc='Epochs') as tq: 173 | for epoch in tq: 174 | tk0 = tenumerate(dataloader, total=len(dataloader), leave=False, desc='Iterations', disable=(not verbose)) 175 | epoch_loss = defaultdict(float) 176 | acc = [] 177 | for i, (x, y, idx) in tk0: 178 | x, y = x.float().to(device), y.long().to(device) 179 | 180 | # loss 181 | z, mu, var = self.encoder(x) 182 | recon_x = self.decoder(z, y) 183 | recon_loss = F.binary_cross_entropy(recon_x, x) * x.size(-1) ## TO DO 184 | kl_loss = kl_div(mu, var) 185 | 186 | # acc.append(pearson_corr_coef(recon_x, x)) 187 | loss = {'recon_loss':recon_loss, 'kl_loss':0.5*kl_loss} 188 | 189 | optim.zero_grad() 190 | sum(loss.values()).backward() 191 | optim.step() 192 | 193 | for k,v in loss.items(): 194 | epoch_loss[k] += loss[k].item() 195 | 196 | info = ','.join(['{}={:.3f}'.format(k, v) for k,v in loss.items()]) 197 | # tk0.set_postfix_str(info) 198 | 199 | 200 | epoch_loss = {k:v/(i+1) for k, v in epoch_loss.items()} 201 | epoch_info = ','.join(['{}={:.3f}'.format(k, v) for k,v in epoch_loss.items()]) 202 | # epoch_info += ',acc={:.3f}'.format(torch.Tensor(acc).mean().item()) 203 | tq.set_postfix_str(epoch_info) 204 | 205 | early_stopping(sum(epoch_loss.values()), self) 206 | if early_stopping.early_stop: 207 | print('EarlyStopping: run {} epoch'.format(epoch+1)) 208 | break 209 | 210 | 211 | def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)): 212 | x_centered = x - x.mean(dim = dim, keepdim = True) 213 | y_centered = y - y.mean(dim = dim, keepdim = True) 214 | return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims) 215 | -------------------------------------------------------------------------------- /scalex/pl/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import logging 5 | 6 | from anndata import AnnData 7 | 8 | from ..atac.snapatac2._misc import aggregate_X 9 | from ..atac.snapatac2._utils import find_elbow, is_anndata 10 | from ._base import render_plot, heatmap, kde2d, scatter, scatter3d 11 | from ._network import network_scores, network_edge_stat 12 | 13 | __all__ = [ 14 | 'tsse', 'frag_size_distr', 'umap', 'network_scores', 'spectral_eigenvalues', 15 | 'regions', 'motif_enrichment', 16 | ] 17 | 18 | def valid_cells( 19 | values, 20 | width: int = 500, 21 | height: int = 400, 22 | **kwargs, 23 | ): 24 | import plotly.graph_objects as go 25 | 26 | values = sorted(values, reverse=True) 27 | result = {} 28 | for x, y in enumerate(values): 29 | x = x + 1 30 | if y in result: 31 | x_, n = result[y] 32 | result[y] = (x_ + x, n + 1) 33 | else: 34 | result[y] = (x, 1) 35 | for y, (x, n) in result.items(): 36 | result[y] = x / n 37 | y, x = zip(*result.items()) 38 | 39 | fig = go.Figure() 40 | fig.add_trace(go.Scatter(x=x, y=y)) 41 | fig.update_xaxes(type="log") 42 | fig.update_yaxes(type="log") 43 | fig.update_layout( 44 | xaxis_title="Barcodes", 45 | yaxis_title="Counts", 46 | ) 47 | 48 | return render_plot(fig, width, height, **kwargs) 49 | 50 | def tsse( 51 | adata: AnnData, 52 | min_fragment: int = 500, 53 | width: int = 500, 54 | height: int = 400, 55 | **kwargs, 56 | ) -> 'plotly.graph_objects.Figure' | None: 57 | """Plot the TSS enrichment vs. number of fragments density figure. 
58 | 59 | Parameters 60 | ---------- 61 | adata 62 | Annotated data matrix. 63 | min_fragment 64 | The cells' unique fragments lower than it should be removed 65 | width 66 | The width of the plot 67 | height 68 | The height of the plot 69 | kwargs 70 | Additional arguments passed to :func:`~snapatac2.pl.render_plot` to 71 | control the final plot output. Please see :func:`~snapatac2.pl.render_plot` 72 | for details. 73 | 74 | Returns 75 | ------- 76 | 'plotly.graph_objects.Figure' | None 77 | If `show=False` and `out_file=None`, an `plotly.graph_objects.Figure` will be 78 | returned, which can then be further customized using the plotly API. 79 | 80 | See Also 81 | -------- 82 | render_plot 83 | 84 | Examples 85 | -------- 86 | .. plotly:: 87 | 88 | >>> import snapatac2 as snap 89 | >>> data = snap.read(snap.datasets.pbmc5k(type='h5ad')) 90 | >>> fig = snap.pl.tsse(data, show=False, out_file=None) 91 | >>> fig.show() 92 | """ 93 | if "tsse" not in adata.obs: 94 | raise ValueError("TSS enrichment score is not computed, please run `metrics.tsse` first.") 95 | 96 | selected_cells = np.where(adata.obs["n_fragment"] >= min_fragment)[0] 97 | x = adata.obs["n_fragment"][selected_cells] 98 | y = adata.obs["tsse"][selected_cells] 99 | 100 | fig = kde2d(x, y, log_x=True, log_y=False) 101 | fig.update_layout( 102 | xaxis_title="Number of unique fragments", 103 | yaxis_title="TSS enrichment score", 104 | ) 105 | 106 | return render_plot(fig, width, height, **kwargs) 107 | 108 | def frag_size_distr( 109 | adata: AnnData | np.ndarray, 110 | use_rep: str = "frag_size_distr", 111 | max_recorded_size: int = 1000, 112 | **kwargs, 113 | ) -> 'plotly.graph_objects.Figure' | None: 114 | """ Plot the fragment size distribution. 115 | """ 116 | import plotly.graph_objects as go 117 | 118 | if is_anndata(adata): 119 | if use_rep not in adata.uns or len(adata.uns[use_rep]) <= max_recorded_size: 120 | logging.info("Computing fragment size distribution...") 121 | snapatac2.metrics.frag_size_distr(adata, add_key=use_rep, max_recorded_size=max_recorded_size) 122 | data = adata.uns[use_rep] 123 | else: 124 | data = adata 125 | data = data[:max_recorded_size+1] 126 | 127 | x, y = zip(*enumerate(data)) 128 | # Make a line plot 129 | fig = go.Figure() 130 | fig.add_trace(go.Scatter(x=x[1:], y=y[1:], mode='lines')) 131 | fig.update_layout( 132 | xaxis_title="Fragment size", 133 | yaxis_title="Count", 134 | ) 135 | return render_plot(fig, **kwargs) 136 | 137 | def spectral_eigenvalues( 138 | adata: AnnData, 139 | width: int = 600, 140 | height: int = 400, 141 | show: bool = True, 142 | interactive: bool = True, 143 | out_file: str | None = None, 144 | ) -> 'plotly.graph_objects.Figure' | None: 145 | """Plot the eigenvalues of spectral embedding. 146 | 147 | Parameters 148 | ---------- 149 | adata 150 | Annotated data matrix. 151 | width 152 | The width of the plot 153 | height 154 | The height of the plot 155 | show 156 | Show the figure. 157 | interactive 158 | Whether to make interactive plot 159 | out_file 160 | Path of the output file for saving the output image, end with 161 | '.svg' or '.pdf' or '.png' or '.html'. 162 | 163 | Returns 164 | ------- 165 | 'plotly.graph_objects.Figure' | None 166 | If `show=False` and `out_file=None`, an `plotly.graph_objects.Figure` will be 167 | returned, which can then be further customized using the plotly API. 
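    Examples
    --------
    Illustrative only; assumes `adata.uns["spectral_eigenvalue"]` has already been
    filled in by a spectral embedding step. The chosen elbow component is also
    written back to `adata.uns["num_eigen"]`:

    >>> fig = spectral_eigenvalues(adata, show=False, out_file=None)
    >>> fig.show()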
168 | """ 169 | 170 | import plotly.express as px 171 | import pandas as pd 172 | 173 | data = adata.uns["spectral_eigenvalue"] 174 | 175 | df = pd.DataFrame({"Component": map(str, range(1, data.shape[0] + 1)), "Eigenvalue": data}) 176 | fig = px.scatter(df, x="Component", y="Eigenvalue", template="plotly_white") 177 | n = find_elbow(data) 178 | adata.uns["num_eigen"] = n 179 | fig.add_vline(x=n) 180 | 181 | return render_plot(fig, width, height, interactive, show, out_file) 182 | 183 | def regions( 184 | adata: AnnData, 185 | groupby: str | list[str], 186 | peaks: dict[str, list[str]], 187 | width: float = 600, 188 | height: float = 400, 189 | show: bool = True, 190 | interactive: bool = True, 191 | out_file: str | None = None, 192 | ) -> 'plotly.graph_objects.Figure' | None: 193 | """ 194 | Parameters 195 | ---------- 196 | adata 197 | Annotated data matrix. 198 | groupby 199 | Group the cells into different groups. If a `str`, groups are obtained from 200 | `.obs[groupby]`. 201 | peaks 202 | Peaks of each group. 203 | width 204 | The width of the plot 205 | height 206 | The height of the plot 207 | show 208 | Show the figure 209 | interactive 210 | Whether to make interactive plot 211 | out_file 212 | Path of the output file for saving the output image, end with 213 | '.svg' or '.pdf' or '.png' or '.html'. 214 | 215 | Returns 216 | ------- 217 | 'plotly.graph_objects.Figure' | None 218 | If `show=False` and `out_file=None`, an `plotly.graph_objects.Figure` will be 219 | returned, which can then be further customized using the plotly API. 220 | """ 221 | import polars as pl 222 | import plotly.graph_objects as go 223 | 224 | peaks = np.concatenate([[x for x in p] for p in peaks.values()]) 225 | n = len(peaks) 226 | if n > 50000: 227 | logging.warning(f"Input contains {n} peaks, only 50000 peaks will be plotted.") 228 | np.random.seed(0) 229 | indices = np.random.choice(n, 50000, replace=False) 230 | peaks = peaks[sorted(indices)] 231 | 232 | count = aggregate_X(adata, groupby=groupby, normalize="RPKM") 233 | names = count.obs_names 234 | count = pl.DataFrame(count.X.T) 235 | count.columns = list(names) 236 | idx_map = {x: i for i, x in enumerate(adata.var_names)} 237 | idx = [idx_map[x] for x in peaks] 238 | mat = np.log2(1 + count.to_numpy()[idx, :]) 239 | 240 | trace = go.Heatmap( 241 | x=count.columns, 242 | y=peaks[::-1], 243 | z=mat, 244 | type='heatmap', 245 | colorscale='Viridis', 246 | colorbar={ "title": "log2(1 + RPKM)" }, 247 | ) 248 | data = [trace] 249 | layout = { 250 | "yaxis": { "visible": False, "autorange": "reversed" }, 251 | "xaxis": { "title": groupby }, 252 | } 253 | fig = go.Figure(data=data, layout=layout) 254 | return render_plot(fig, width, height, interactive, show, out_file) 255 | 256 | def umap( 257 | adata: AnnData | np.ndarray, 258 | color: str | np.ndarray | None = None, 259 | use_rep: str = "X_umap", 260 | marker_size: float = None, 261 | marker_opacity: float = 1, 262 | sample_size: int | None = None, 263 | **kwargs, 264 | ) -> 'plotly.graph_objects.Figure' | None: 265 | """Plot the UMAP embedding. 266 | 267 | Parameters 268 | ---------- 269 | adata 270 | Annotated data matrix. 271 | color 272 | If the input is a string, it will be used the key to retrieve values from 273 | `obs`. 274 | use_rep 275 | Use the indicated representation in `.obsm`. 276 | marker_size 277 | Size of the dots. 278 | marker_opacity 279 | Opacity of the dots. 
280 | sample_size 281 | If the number of cells is larger than `sample_size`, a random sample of 282 | `sample_size` cells will be used for plotting. 283 | kwargs 284 | Additional arguments passed to :func:`~snapatac2.pl.render_plot` to 285 | control the final plot output. Please see :func:`~snapatac2.pl.render_plot` 286 | for details. 287 | 288 | Returns 289 | ------- 290 | 'plotly.graph_objects.Figure' | None 291 | If `show=False` and `out_file=None`, an `plotly.graph_objects.Figure` will be 292 | returned, which can then be further customized using the plotly API. 293 | """ 294 | from natsort import index_natsorted 295 | 296 | embedding = adata.obsm[use_rep] if is_anndata(adata) else adata 297 | if isinstance(color, str): 298 | groups = adata.obs[color].to_numpy() 299 | else: 300 | groups = color 301 | color = "color" 302 | 303 | if sample_size is not None and embedding.shape[0] > sample_size: 304 | idx = np.random.choice(embedding.shape[0], sample_size, replace=False) 305 | embedding = embedding[idx, :] 306 | if groups is not None: groups = groups[idx] 307 | 308 | if groups is not None: 309 | idx = index_natsorted(groups) 310 | embedding = embedding[idx, :] 311 | groups = [groups[i] for i in idx] 312 | 313 | if marker_size is None: 314 | num_points = embedding.shape[0] 315 | marker_size = (1000 / num_points)**(1/3) * 3 316 | 317 | if embedding.shape[1] >= 3: 318 | return scatter3d(embedding[:, 0], embedding[:, 1], embedding[:, 2], color=groups, 319 | x_label="UMAP-1", y_label="UMAP-2", z_label="UMAP-3", color_label=color, 320 | marker_size=marker_size, marker_opacity=marker_opacity, **kwargs) 321 | else: 322 | return scatter(embedding[:, 0], embedding[:, 1], color=groups, 323 | x_label="UMAP-1", y_label="UMAP-2", color_label=color, 324 | marker_size=marker_size, marker_opacity=marker_opacity, **kwargs) 325 | 326 | def motif_enrichment( 327 | enrichment: list(str, 'pl.DataFrame'), 328 | min_log_fc: float = 1, 329 | max_fdr: float = 0.01, 330 | **kwargs, 331 | ) -> 'plotly.graph_objects.Figure' | None: 332 | """Plot the motif enrichment result. 333 | 334 | Parameters 335 | ---------- 336 | enrichment 337 | Motif enrichment result. 338 | min_log_fc 339 | Retain motifs that satisfy: log2-fold-change >= `min_log_fc`. 340 | max_fdr 341 | Retain motifs that satisfy: FDR <= `max_fdr`. 342 | kwargs 343 | Additional arguments passed to :func:`~snapatac2.pl.render_plot` to 344 | control the final plot output. Please see :func:`~snapatac2.pl.render_plot` 345 | for details. 346 | 347 | Returns 348 | ------- 349 | 'plotly.graph_objects.Figure' | None 350 | If `show=False` and `out_file=None`, an `plotly.graph_objects.Figure` will be 351 | returned, which can then be further customized using the plotly API. 
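    Examples
    --------
    Illustrative only. Although annotated as a list above, the function body
    iterates over `enrichment.values()` and `enrichment.keys()`, i.e. it expects a
    mapping from group name to a DataFrame with 'log2(fold change)',
    'adjusted p-value', 'p-value' and 'id' columns:

    >>> fig = motif_enrichment(enrichment, min_log_fc=1, max_fdr=0.01, show=False, out_file=None)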
352 | """ 353 | 354 | import pandas as pd 355 | 356 | fc = np.vstack([df['log2(fold change)'] for df in enrichment.values()]) 357 | filter1 = np.apply_along_axis(lambda x: np.any(np.abs(x) >= min_log_fc), 0, fc) 358 | 359 | fdr = np.vstack([df['adjusted p-value'] for df in enrichment.values()]) 360 | filter2 = np.apply_along_axis(lambda x: np.any(x <= max_fdr), 0, fdr) 361 | 362 | passed = np.logical_and(filter1, filter2) 363 | 364 | sign = np.sign(fc[:, passed]) 365 | pvals = np.vstack([df['p-value'].to_numpy()[passed] for df in enrichment.values()]) 366 | minval = np.min(pvals[np.nonzero(pvals)]) 367 | pvals = np.clip(pvals, minval, None) 368 | pvals = sign * np.log(-np.log10(pvals)) 369 | 370 | df = pd.DataFrame( 371 | pvals.T, 372 | columns=list(enrichment.keys()), 373 | index=next(iter(enrichment.values()))['id'].to_numpy()[passed], 374 | ) 375 | 376 | return heatmap( 377 | df.to_numpy(), 378 | row_names=df.index, 379 | column_names=df.columns, 380 | colorscale='RdBu_r', 381 | **kwargs, 382 | ) -------------------------------------------------------------------------------- /scalex/pl/_genometrack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import pandas as pd 4 | import deepdish as dd 5 | import logomaker 6 | from .GenomeTrack import GenomeTrack 7 | 8 | class ShapTrack(GenomeTrack): 9 | """ 10 | A track that plots GWAS variants within a specified region. 11 | This version hardcodes the input file path and processes data inside the plot method. 12 | The init_data() step is removed and no file input is read from configuration. 13 | """ 14 | 15 | SUPPORTED_ENDINGS = [] 16 | TRACK_TYPE = 'shap' 17 | OPTIONS_TXT = """ 18 | height = 3 19 | title = 20 | file_type = shap 21 | sample = 22 | fold = 23 | """ 24 | 25 | # Default values for properties 26 | DEFAULTS_PROPERTIES = { 27 | } 28 | 29 | # With no file in configuration, we don't need it as necessary property 30 | NECESSARY_PROPERTIES = ['sample', 'fold'] 31 | 32 | SYNONYMOUS_PROPERTIES = {} 33 | POSSIBLE_PROPERTIES = {} 34 | BOOLEAN_PROPERTIES = [] 35 | STRING_PROPERTIES = ['title', 'file_type', 'sample', 'fold'] 36 | FLOAT_PROPERTIES = { 37 | 'height': [0, np.inf], 38 | 'min_value': [-np.inf, np.inf], 39 | 'max_value': [-np.inf, np.inf] 40 | } 41 | INTEGER_PROPERTIES = { 42 | } 43 | 44 | def plot(self, ax, chrom, region_start, region_end): 45 | print("Plotting SHAP ...") 46 | 47 | # Hardcoded file path 48 | variant_shap_dir = '/oak/stanford/groups/akundaje/projects/CRC_finemap/peak_shap/specific_peaks' 49 | sample = self.properties['sample'] 50 | fold = self.properties['fold'] 51 | 52 | shap_h5 = dd.io.load(variant_shap_dir + '/' + sample + '/' + fold + '/' + sample + '.' + fold + '.specific_peaks.counts_scores.counts_scores.h5') 53 | shap_peaks = pd.read_table(variant_shap_dir + '/' + sample + '/' + fold + '/' + sample + '.' 
+ fold + '.specific_peaks.counts_scores.interpreted_regions.bed', 54 | header=None) 55 | 56 | shap_peaks_with_flanks = shap_peaks.copy() 57 | shap_peaks_with_flanks['start'] = (shap_peaks[1] + shap_peaks[9]) - (shap_h5['projected_shap']['seq'].shape[2] // 2) 58 | shap_peaks_with_flanks['end'] = (shap_peaks[1] + shap_peaks[9]) + (shap_h5['projected_shap']['seq'].shape[2] // 2) 59 | shap_peaks_with_flanks.sort_values(by=[0, 'start', 'end'], inplace=True) 60 | 61 | region_peaks = shap_peaks_with_flanks[(shap_peaks_with_flanks[0] == chrom) & (((shap_peaks_with_flanks[1] >= region_start) & (shap_peaks_with_flanks[2] <= region_end)) | 62 | ((shap_peaks_with_flanks[1] <= region_start) & (shap_peaks_with_flanks[2] > region_start)) | 63 | ((shap_peaks_with_flanks[1] < region_end) & (shap_peaks_with_flanks[2] >= region_end)) | 64 | ((shap_peaks_with_flanks[1] <= region_start) & (shap_peaks_with_flanks[2] >= region_end)))].copy() 65 | if region_peaks.empty: 66 | print("No shap peaks found in this region.") 67 | 68 | region_peaks = region_peaks.sort_values(by=1) 69 | shap_values = [] 70 | last_end = region_start 71 | 72 | for index,row in region_peaks.iterrows(): 73 | assert (row['end'] - row['start']) == shap_h5['projected_shap']['seq'].shape[2] 74 | 75 | print('peak_start:', row['start']) 76 | print('peak_end:', row['end']) 77 | print() 78 | 79 | if last_end < row['start']: 80 | if len(shap_values) == 0: 81 | shap_values = np.zeros((row['start'] - last_end, shap_h5['projected_shap']['seq'].shape[1])) 82 | last_end = row['start'] 83 | print("Added starting zeros") 84 | print(shap_values.shape) 85 | print(last_end) 86 | print() 87 | else: 88 | shap_values = np.concatenate([shap_values, np.zeros((row['start'] - last_end, shap_h5['projected_shap']['seq'].shape[1]))]) 89 | last_end = row['start'] 90 | print("Added middle zeros") 91 | print(shap_values.shape) 92 | print(last_end) 93 | print() 94 | 95 | if last_end > row['end']: 96 | continue 97 | 98 | elif last_end > row['start'] and last_end + shap_h5['projected_shap']['seq'].shape[2] <= row['end']: 99 | if len(shap_values) == 0: 100 | shap_values = shap_h5['projected_shap']['seq'][index][:,last_end - row['start']:].T 101 | last_end = row['end'] 102 | print("Added start trimmed shap") 103 | print(shap_values.shape) 104 | print(last_end) 105 | print() 106 | else: 107 | shap_values = np.concatenate([shap_values, shap_h5['projected_shap']['seq'][index][:,last_end - row['start']:].T]) 108 | last_end = row['end'] 109 | print("Added start trimmed shap") 110 | print(shap_values.shape) 111 | print(last_end) 112 | print() 113 | 114 | elif last_end > row['start'] and last_end + shap_h5['projected_shap']['seq'].shape[2] > region_end: 115 | if len(shap_values) == 0: 116 | shap_values = shap_h5['projected_shap']['seq'][index][:,last_end - row['start']:(region_end - last_end) + (last_end - row['start'])].T 117 | last_end = region_end 118 | print("Added start and end trimmed shap") 119 | print(shap_values.shape) 120 | print(last_end) 121 | print() 122 | else: 123 | shap_values = np.concatenate([shap_values, shap_h5['projected_shap']['seq'][index][:,last_end - row['start']:(region_end - last_end) + (last_end - row['start'])].T]) 124 | last_end = region_end 125 | print("Added start and end trimmed shap") 126 | print(shap_values.shape) 127 | print(last_end) 128 | print 129 | 130 | elif last_end + shap_h5['projected_shap']['seq'].shape[2] <= region_end: 131 | shap_values = np.concatenate([shap_values, shap_h5['projected_shap']['seq'][index].T]) 132 | last_end = row['end'] 
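                # (editorial note) This branch handles a peak window that lies fully
                # inside the plotting region: the whole SHAP window is appended as-is,
                # while the surrounding branches trim windows that overhang the
                # region start or end.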
133 | print("Added shap") 134 | print(shap_values.shape) 135 | print(last_end) 136 | print() 137 | 138 | elif last_end + shap_h5['projected_shap']['seq'].shape[2] > region_end: 139 | shap_values = np.concatenate([shap_values, shap_h5['projected_shap']['seq'][index][:,:region_end - last_end].T]) 140 | last_end = region_end 141 | print("Added end trimmed shap") 142 | print(shap_values.shape) 143 | print(last_end) 144 | print() 145 | 146 | else: 147 | print("ERROR: Peaks are not sorted.") 148 | 149 | if last_end < region_end: 150 | shap_values = np.concatenate([shap_values, np.zeros((region_end - last_end, shap_h5['projected_shap']['seq'].shape[1]))]) 151 | print("Added ending zeros") 152 | print(shap_values.shape) 153 | print() 154 | 155 | logo1 = logomaker.Logo(pd.DataFrame(shap_values, 156 | columns=['A','C','G','T']), ax=ax) 157 | if 'min_value' in self.properties and 'max_value' in self.properties: 158 | ax.set_ylim(float(self.properties['min_value']), float(self.properties['max_value'])) 159 | ax.set_ylabel(sample + ' Counts Shap') 160 | 161 | return ax 162 | -------------------------------------------------------------------------------- /scalex/pl/_network.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import rustworkx as rx 5 | 6 | from ._base import render_plot 7 | 8 | def network_edge_stat( 9 | network: rx.PyDiGraph, 10 | **kwargs, 11 | ): 12 | """ 13 | Parameters 14 | ---------- 15 | network 16 | Network. 17 | kwargs 18 | Additional arguments passed to :func:`~snapatac2.pl.render_plot` to 19 | control the final plot output. Please see :func:`~snapatac2.pl.render_plot` 20 | for details. 21 | """ 22 | from collections import defaultdict 23 | import plotly.graph_objects as go 24 | 25 | scores = defaultdict(lambda: defaultdict(lambda: [])) 26 | 27 | for fr, to, data in network.edge_index_map().values(): 28 | type = "{} -> {}".format(network[fr].type, network[to].type) 29 | if data.cor_score is not None: 30 | scores["correlation"][type].append(data.cor_score) 31 | if data.regr_score is not None: 32 | scores["regression"][type].append(data.regr_score) 33 | 34 | fig = go.Figure() 35 | 36 | for key, vals in scores["correlation"].items(): 37 | fig.add_trace(go.Violin( 38 | y=vals, 39 | name=key, 40 | box_visible=True, 41 | meanline_visible=True 42 | )) 43 | 44 | return render_plot(fig, **kwargs) 45 | 46 | def network_scores( 47 | network: rx.PyDiGraph, 48 | score_name: str, 49 | width: float = 800, 50 | height: float = 400, 51 | show: bool = True, 52 | interactive: bool = True, 53 | out_file: str | None = None, 54 | ): 55 | """ 56 | score_name 57 | Name of the edge attribute 58 | width 59 | The width of the plot 60 | height 61 | The height of the plot 62 | """ 63 | import plotly.express as px 64 | import pandas as pd 65 | import bisect 66 | 67 | def human_format(num): 68 | num = float('{:.3g}'.format(num)) 69 | magnitude = 0 70 | while abs(num) >= 1000: 71 | magnitude += 1 72 | num /= 1000.0 73 | return '{}{}'.format('{:f}'.format(num).rstrip('0').rstrip('.'), ['', 'K', 'M', 'B', 'T'][magnitude]) 74 | 75 | break_points = [100, 500, 2000, 20000, 50000, 100000, 500000] 76 | intervals = [] 77 | for i in range(len(break_points)): 78 | if i == 0: 79 | intervals.append("0 - " + human_format(break_points[i])) 80 | else: 81 | intervals.append(human_format(break_points[i - 1]) + " - " + human_format(break_points[i])) 82 | intervals.append("> 500k") 83 | values = [[] for _ in 
range(len(intervals))] 84 | for e in network.edges(): 85 | i = bisect.bisect(break_points, e.distance) 86 | sc = getattr(e, score_name) 87 | if sc is not None: 88 | values[i].append(sc) 89 | 90 | intervals, values = zip(*filter(lambda x: len(x[1]) > 0, zip(intervals, values))) 91 | values = [np.nanmean(v) for v in values] 92 | 93 | df = pd.DataFrame({ 94 | "Distance to TSS (bp)": intervals, 95 | "Average score": values, 96 | }) 97 | fig = px.bar( 98 | df, x="Distance to TSS (bp)", y="Average score", title = score_name, 99 | ) 100 | return render_plot(fig, width, height, interactive, show, out_file) -------------------------------------------------------------------------------- /scalex/pl/analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scanpy as sc 3 | import pandas as pd 4 | import numpy as np 5 | from gseapy import barplot, dotplot 6 | import gseapy as gp 7 | import matplotlib.pyplot as plt 8 | 9 | macrophage_markers = { 10 | 'B': ['CD79A', 'CD37', 'IGHM'], 11 | 'NK': ['GNLY', 'NKG7', 'PRF1'], 12 | 'NKT': ['DCN', 'MGP','COL1A1'], 13 | 'T': ['CD3D', 'CD3E', 'CD3G', 'CD4', 'CD8A', 'CD8B'], 14 | 'Treg': ['FOXP3', 'CD25'], #'UBC', 'DNAJB1'], 15 | 'naive T': ['TPT1'], 16 | 'mast': ['TPSB2', 'CPA3', 'MS4A2', 'KIT', 'GATA2', 'FOS2'], 17 | 'pDC': ['IRF7', 'CLEC4C', 'TCL1A'], #['IRF8', 'PLD4', 'MPEG1'], 18 | 'epithelial': ['KRT8'], 19 | 'cancer cells': ['EPCAM'], 20 | 'neutrophils': ['FCGR3B', 'CSF3R', 'CXCR2', 'SOD2', 'GOS2'], 21 | 'cDC1': ['CLEC9A'], 22 | 'cDC2': ['FCER1A', 'CD1C', 'CD1E', 'CLEC10A'], 23 | 'migratoryDC': ['BIRC3', 'CCR7', 'LAMP3'], 24 | 'follicular DC': ['FDCSP'], 25 | 'DC': ['CLEC9A', 'XCR1', 'CD1C', 'CD1A', 'LILRA4'], 26 | 'CD207+ DC': ['CD1A', 'CD207'], # 'FCAR1A'], 27 | 'Monocyte': ['FCGR3A', 'VCAN', 'SELL', 'CDKN1C', 'MTSS1'], 28 | 'Macrophage': ['CSF1R', 'C1QA', 'APOE', 'TREM2', 'MARCO', 'MCR1', 'CD68', 'CD163', 'CD206', 'CCL2', 'CCL3', 'CCL4'], 29 | 'Macro SPP1+': ['SPP1', 'LPL', 'MGLL', 'FN1'], # IL4I1 30 | 'Macro SELENOP': ['SEPP1', 'SELENOP', 'FOLR2'], # 'RNASE1', 31 | 'Macro FCN1+': ['FCN1', 'S100A9', 'S100A8'], # NLRP3 32 | 'Macro IL32+': ['IL32', 'CCL5', 'CD7', 'TRAC', 'CD3D', 'TRBC2', 'IGHG1', 'IGKC'], 33 | 'Macro APOE+': ['C1QC', 'APOE', 'GPNMB', 'LGMN'], # C1QC > APOE 34 | 'Macro IL1B': ['NFKBIA', 'CXCL8', 'IER3', 'SOD2', 'IL1B'], 35 | 'Macro FABP4+': ['FABP4'], 36 | } 37 | 38 | def enrich_analysis(gene_names, organism='hsapiens', gene_sets='GO_Biological_Process_2023', cutoff=0.05, **kwargs): # gene_sets="GO_Biological_Process_2021" 39 | """ 40 | Perform KEGG pathway analysis and plot the results as a clustermap. 41 | 42 | Parameters: 43 | - gene_names: A dictionary with group labels as keys and lists of gene names as values. 44 | - gene_sets: The gene set database to use for enrichment analysis (default is 'KEGG_2021_Human'). 'GO_Biological_Process_2021', could find in gp.get_library_name() 45 | - organism: Organism for KEGG analysis (default is 'hsapiens 46 | - top_terms: Number of top terms to consider for the clustermap. 
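    Example (illustrative; the marker genes below are placeholders taken from the
    macrophage marker dictionary above):
    >>> markers = {'Macro SPP1+': ['SPP1', 'LPL', 'FN1'], 'cDC2': ['FCER1A', 'CD1C']}
    >>> res = enrich_analysis(markers, gene_sets='GO_Biological_Process_2023')
    >>> res[['Term', 'Adjusted P-value', 'cell_type']].head()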
47 | """ 48 | import gseapy as gp 49 | from gseapy import Msigdb 50 | msig = Msigdb() 51 | if isinstance(gene_names, pd.DataFrame): 52 | gene_names = gene_names.to_dict(orient='list') 53 | if gene_sets in msig.list_category(): 54 | # ['c1.all', 'c2.all', 'c2.cgp', 'c2.cp.biocarta', 'c2.cp.kegg_legacy', 'c2.cp.kegg_medicus', 'c2.cp.pid', 'c2.cp.reactome', 'c2.cp', 'c2.cp.wikipathways', 'c3.all', 'c3.mir.mir_legacy', 'c3.mir.mirdb', 'c3.mir', 'c3.tft.gtrd', 'c3.tft.tft_legacy', 'c3.tft', 55 | # 'c4.3ca', 'c4.all', 'c4.cgn', 'c4.cm', 'c5.all', 'c5.go.bp', 'c5.go.cc', 'c5.go.mf', 'c5.go', 'c5.hpo', 'c6.all', 'c7.all', 'c7.immunesigdb', 'c7.vax', 'c8.all', 'h.all', 'msigdb'] 56 | gene_sets = msig.get_gmt(category = gene_sets, dbver='2024.1.Hs') 57 | 58 | results = pd.DataFrame() 59 | for group, genes in gene_names.items(): 60 | enr = gp.enrichr(genes, gene_sets=gene_sets, cutoff=cutoff).results 61 | enr['cell_type'] = group # Add the group label to the results 62 | results = pd.concat([results, enr]) 63 | 64 | results_filtered = results[results['Adjusted P-value'] < cutoff] 65 | # results_pivot = results_filtered.pivot_table(index='Term', columns='cell_type', values='Adjusted P-value', aggfunc='min') 66 | # results_pivot = results_pivot.sort_values(by=results_pivot.columns.tolist(), ascending=True) 67 | 68 | # return results_pivot, results_filtered 69 | return results_filtered 70 | 71 | 72 | def annotate( 73 | adata, 74 | cell_type='leiden', 75 | color = ['cell_type', 'leiden', 'tissue', 'donor'], 76 | cell_type_markers='macrophage', #None, 77 | show_markers=False, 78 | gene_sets='GO_Biological_Process_2023', 79 | n_tops = [100], 80 | options = ['pos'], # ['pos', 'neg'] 81 | additional={}, 82 | go=True, 83 | out_dir = None, #'../../results/go_and_pathway/NSCLC_macrophage/' 84 | ): 85 | 86 | color = [i for i in color if i in adata.obs.columns] 87 | color = color + [cell_type] if cell_type not in color else color 88 | sc.pl.umap(adata, color=color, legend_loc='on data', legend_fontsize=10) 89 | 90 | var_names = adata.raw.var_names if adata.raw is not None else adata.var_names 91 | if cell_type_markers is not None: 92 | if isinstance(cell_type_markers, str): 93 | if cell_type_markers == 'macrophage': 94 | cell_type_markers = macrophage_markers 95 | cell_type_markers_ = {k: [i for i in v if i in var_names] for k,v in cell_type_markers.items() } 96 | sc.pl.dotplot(adata, cell_type_markers_, groupby=cell_type, standard_scale='var', cmap='coolwarm') 97 | 98 | sc.tl.rank_genes_groups(adata, groupby=cell_type, key_added=cell_type, dendrogram=False) 99 | sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, cmap='coolwarm', key=cell_type, standard_scale='var', figsize=(22, 5), dendrogram=False) 100 | marker = pd.DataFrame(adata.uns[cell_type]['names']) 101 | marker_dict = marker.head(5).to_dict(orient='list') 102 | plt.show() 103 | 104 | if show_markers: 105 | for k, v in marker_dict.items(): 106 | print(k) 107 | sc.pl.umap(adata, color=v, ncols=5) 108 | 109 | for n_top in n_tops: 110 | print('-'*20+'\n', n_top, '\n'+'-'*20) 111 | 112 | if go: 113 | for option in options: 114 | if option == 'pos': 115 | go_results = enrich_analysis(marker.head(n_top), gene_sets=gene_sets) 116 | else: 117 | go_results = enrich_analysis(marker.tail(n_top), gene_sets=gene_sets) 118 | 119 | go_results['cell_type'] = 'leiden_' + go_results['cell_type'] 120 | n = go_results['cell_type'].nunique() 121 | ax = dotplot(go_results, 122 | column="Adjusted P-value", 123 | x='cell_type', # set x axis, so you could do a 
multi-sample/library comparsion 124 | # size=10, 125 | top_term=10, 126 | figsize=(0.7*n, 2*n), 127 | title = f"{option}_GO_BP_{n_top}", 128 | xticklabels_rot=45, # rotate xtick labels 129 | show_ring=False, # set to False to revmove outer ring 130 | marker='o', 131 | cutoff=0.05, 132 | cmap='viridis' 133 | ) 134 | if out_dir is not None: 135 | os.makedirs(out_dir, exist_ok=True) 136 | go_results = go_results.sort_values('Adjusted P-value', ascending=False).groupby('cell_type').head(10) 137 | go_results[['Gene_set','Term','Overlap', 'Adjusted P-value', 'Genes', 'cell_type']].to_csv(out_dir + f'/{option}_go_results_{n_top}.csv') 138 | plt.show() 139 | 140 | for pathway_name, pathways in additional.items(): 141 | try: 142 | pathway_results = enrich_analysis(marker.head(n_top), gene_sets=pathways) 143 | except: 144 | continue 145 | ax = dotplot(pathway_results, 146 | column="Adjusted P-value", 147 | x='cell_type', # set x axis, so you could do a multi-sample/library comparsion 148 | # size=10, 149 | top_term=10, 150 | figsize=(8,10), 151 | title = pathway_name, 152 | xticklabels_rot=45, # rotate xtick labels 153 | show_ring=False, # set to False to revmove outer ring 154 | marker='o', 155 | cutoff=0.05, 156 | cmap='viridis' 157 | ) 158 | 159 | plt.show() 160 | 161 | 162 | 163 | 164 | def find_go_term_gene(df, term): 165 | """ 166 | df: df = pd.read_csv(go_results, index_col=0) 167 | term: either Term full name or Go number: GO:xxxx 168 | """ 169 | if term.startswith('GO'): 170 | df['GO'] = df['Term'].str.split('(').str[1].str.replace(')', '') 171 | select = df[df['GO'] == term].copy() 172 | else: 173 | select = df[df['Term'] == term].copy() 174 | gene_set = set(gene for sublist in select['Genes'].str.split(';') for gene in sublist) 175 | # gene_set = set(select['Genes'].str.split(';')) 176 | # print(select['Term'].head(1).values[0]) 177 | # print('\n'.join(gene_set)) 178 | return gene_set 179 | 180 | def format_dict_of_list(d, out='table'): 181 | if out == 'matrix': 182 | data = [] 183 | for k, lt in d.items(): 184 | for v in lt: 185 | data.append({'Gene': v, 'Pathway': k}) 186 | 187 | # Step 2: Create a DataFrame from the list 188 | df = pd.DataFrame(data) 189 | 190 | # Step 3: Use crosstab to pivot the DataFrame 191 | df = pd.crosstab(df['Gene'], df['Pathway']) 192 | elif out == 'table': 193 | df = pd.DataFrame.from_dict(d, orient='index').transpose() 194 | 195 | return df 196 | 197 | 198 | def parse_go_results(df, cell_type='cell_type', top=20, out='table', tag='', dataset=''): 199 | """ 200 | Return: 201 | a term gene dataframe: each column is a term 202 | a term cluster dataframe: each column is a term 203 | """ 204 | term_genes = {} 205 | term_clusters = {} 206 | for c in np.unique(df[cell_type]): 207 | terms = df[df[cell_type]==c]['Term'].values 208 | for term in terms[:top]: 209 | if term not in term_clusters: 210 | term_clusters[term] = [] 211 | 212 | term_clusters[term].append(c) 213 | 214 | if term not in term_genes: 215 | term_genes[term] = find_go_term_gene(df, term) 216 | 217 | tag = tag + ':' if tag else '' 218 | 219 | if out == 'dict': 220 | return term_genes, term_clusters 221 | else: 222 | term_genes = format_dict_of_list(term_genes, out=out) 223 | index = [(k, dataset, tag+';'.join(v)) for k, v in term_clusters.items()] 224 | term_genes.columns = pd.MultiIndex.from_tuples(index, names=['Pathway', 'Dataset', 'Cluster']) 225 | return term_genes 226 | 227 | 228 | def merge_all_go_results(path, datasets=None, top=20, out_dir=None, add_ref=False, union=True, 
reference='GO_Biological_Process_2023', organism='human'): 229 | """ 230 | The go results should organized by path/datasets/go_results.csv 231 | Args: 232 | path is the input to store all the go results 233 | datasets are selected to merge 234 | """ 235 | df_list = [] 236 | if datasets is None: 237 | datasets = [i for i in os.listdir(path) if os.path.isdir(os.path.join(path, i))] 238 | for dataset in datasets: 239 | path2 = os.path.join(path, dataset) 240 | for filename in os.listdir(path2): 241 | if 'go_results' in filename: 242 | name = filename.replace('.csv', '') 243 | path3 = os.path.join(path2, filename) 244 | df = pd.read_csv(path3, index_col=0) 245 | term_genes = parse_go_results(df, dataset=dataset, tag=name, top=top) 246 | df_list.append(term_genes) 247 | concat_df = pd.concat(df_list, axis=1) 248 | 249 | if add_ref and not union: 250 | go_ref = gp.get_library(name=reference, organism=organism) 251 | go_ref = format_dict_of_list(go_ref) 252 | pathways = [i for i in concat_df.columns.get_level_values('Pathway').unique() if i in go_ref.columns] 253 | go_ref = go_ref.loc[:, pathways] 254 | index_tuples = [ (i, 'GO_Biological_Process_2023', 'reference') for i in go_ref.columns ] 255 | go_ref.columns = pd.MultiIndex.from_tuples(index_tuples, names=['Pathway', 'Dataset', 'Cluster']) 256 | concat_df = pd.concat([concat_df, go_ref], axis=1) 257 | 258 | concat_df = concat_df.sort_index(axis=1, level='Pathway') 259 | 260 | if union: 261 | concat_df = concat_df.groupby(level=["Pathway"], axis=1) 262 | concat_dict = {name: [i for i in set(group.values.flatten()) if pd.notnull(i)] for name, group in concat_df} 263 | concat_df = pd.DataFrame.from_dict(concat_dict, orient='index').transpose() 264 | 265 | if out_dir is not None: 266 | dirname = os.path.dirname(out_dir) 267 | os.makedirs(dirname, exist_ok=True) 268 | 269 | if not union: 270 | if not out_dir.endswith('xlsx'): 271 | out_dir = out_dir + '.xlsx' 272 | with pd.ExcelWriter(out_dir, engine='openpyxl') as writer: 273 | concat_df.to_excel(writer, sheet_name='Sheet1') 274 | else: 275 | concat_df.to_csv(out_dir, index=False) 276 | return concat_df -------------------------------------------------------------------------------- /scalex/pp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsxlei/SCALEX/7a410ffe11b1d2139d3c66961ee90f77136c26fa/scalex/pp/__init__.py -------------------------------------------------------------------------------- /scalex/pp/annotation.py: -------------------------------------------------------------------------------- 1 | import pyranges as pr 2 | import os 3 | from scalex.atac.bedtools import bed_to_df 4 | 5 | GENE_COLUMNS = ['Chromosome', 'Start', 'End', 'Strand', 'gene_name', 'gene_ids'] 6 | import re 7 | def ens_trim_version(x: str): 8 | return re.sub(r"\.[0-9_-]+$", "", x) 9 | 10 | def annotate_genes(gene_var, gtf=None, by='gene_name'): 11 | COLUMNS = ['Chromosome', 'Start', 'End', 'Strand', 'gene_name', 'gene_ids'] 12 | if isinstance(gtf, str): 13 | gtf = get_gtf(gtf, drop_by=by).df 14 | elif isinstance(gtf, pr.PyRanges): 15 | gtf = gtf.df 16 | 17 | if 'gene_ids' in gene_var.columns: 18 | gene_var.index = gene_var['gene_ids'] 19 | by = 'gene_id' 20 | print("Use gene_id") 21 | 22 | if by == 'gene_id': 23 | gtf.index = gtf[by].apply(ens_trim_version) 24 | else: 25 | gtf.index = gtf[by] 26 | 27 | gtf['gene_ids'] = gtf['gene_id'] 28 | gene_var = gtf.reindex(gene_var.index).loc[:, COLUMNS] 29 | 30 | return gene_var 31 | 32 | def 
remove_genes_without_annotation(adata, by='gene_name', exclude_mt=True): 33 | adata.var_names_make_unique() 34 | genes = adata.var.dropna(subset=[by]).index 35 | adata = adata[:, genes] 36 | if exclude_mt: 37 | indices = [i for i, name in enumerate(adata.var.gene_name) 38 | if not str(name).startswith(tuple(['ERCC', 'MT-', 'mt-']))] 39 | 40 | adata = adata[:, indices] #.copy() 41 | 42 | adata.var.Start = adata.var.Start.astype(int) 43 | adata.var.End = adata.var.End.astype(int) 44 | return adata 45 | 46 | def add_interval_to_gene_var(gene_var): 47 | # gene_var = gene_var.copy() 48 | gene_var['interval'] = df_to_bed(gene_var) 49 | gene_var = strand_specific_start_site(gene_var) 50 | gene_var['tss'] = gene_var['Start'] 51 | # gene_var['promoter_interval'] = df_to_bed(promoter) 52 | return gene_var 53 | 54 | 55 | 56 | def rna_var_to_promoter(var, up=1000, down=100): 57 | var = strand_specific_start_site(var) 58 | var = get_promoter_interval(var, up=up, down=down) 59 | return var 60 | 61 | 62 | def format_rna( 63 | rna, 64 | gtf=os.path.expanduser('~/.scalex/gencode.v38.annotation.gtf.gz'), 65 | up=1000, 66 | down=100, 67 | force=False 68 | ): 69 | if set(GENE_COLUMNS).issubset(rna.var.columns) and not force: 70 | print("Already formatted") 71 | return rna 72 | 73 | rna.var = annotate_genes(rna.var, gtf) 74 | rna = remove_genes_without_annotation(rna) 75 | rna.var = add_interval_to_gene_var(rna.var) 76 | return rna 77 | 78 | def format_atac(atac): 79 | if set(["Chromosome", "Start", "End"]).issubset(atac.var.columns): 80 | print("Already formatted") 81 | return atac 82 | else: 83 | atac.var = bed_to_df(atac.var_names) 84 | return atac 85 | 86 | 87 | def strand_specific_start_site(df): 88 | df = df.copy() 89 | if set(df["Strand"]) != set(["+", "-"]): 90 | raise ValueError("Not all features are strand specific!") 91 | 92 | pos_strand = df.query("Strand == '+'").index 93 | neg_strand = df.query("Strand == '-'").index 94 | df.loc[pos_strand, "End"] = df.loc[pos_strand, "Start"] + 1 95 | df.loc[neg_strand, "Start"] = df.loc[neg_strand, "End"] - 1 96 | return df 97 | 98 | def get_promoter_interval(genes, up=1000, down=0): 99 | tss = strand_specific_start_site(genes) 100 | from .bedtools import extend_bed 101 | promoter = extend_bed(tss, up=up, down=down) 102 | return promoter 103 | 104 | def df_to_bed(x): 105 | # return (x.iloc[:, 0]+':'+x.iloc[:, 1].astype(str)+'-'+x.iloc[:, 2].astype(str)).values 106 | return x.apply(lambda row: row.iloc[0]+':'+str(row.iloc[1])+'-'+str(row.iloc[2]), axis=1).values 107 | 108 | def get_gtf(gtf_file, genome='hg38', drop_by='gene_name'): 109 | # if genome == 'hg38': 110 | # gtf_file = GENOME_PATH / 'gencode.v38.annotation.gtf.gz' 111 | # elif genome == 'hg19': 112 | # gtf_file = GENOME_PATH / 'gencode.v19.annotation.gtf.gz' 113 | if not os.path.exists(gtf_file): 114 | version = genome.split('hg')[-1] 115 | os.system(f'wget -O {gtf_file} https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_{version}/gencode.v{version}.annotation.gtf.gz') 116 | 117 | gtf = pr.read_gtf(gtf_file) 118 | gtf = gtf.df.drop_duplicates(subset=[drop_by], keep="first") 119 | return pr.PyRanges(gtf) -------------------------------------------------------------------------------- /scalex/specifity.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Optimized Specificity Score Calculation 4 | # Author: Xiong Lei (original), optimized by Xuxin Tang 5 | """ 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 
-------------------------------------------------------------------------------- /scalex/specifity.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Optimized Specificity Score Calculation 4 | # Author: Xiong Lei (original), optimized by Xuxin Tang 5 | """ 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import scipy.stats as stats 10 | 11 | def jsd(p, q, base=np.e): 12 | """ 13 | Compute Jensen-Shannon divergence. 14 | """ 15 | p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float) 16 | p /= p.sum() 17 | q /= q.sum() 18 | m = 0.5 * (p + q) 19 | return 0.5 * (stats.entropy(p, m, base=base) + stats.entropy(q, m, base=base)) 20 | 21 | def jsd_sp(p, q, base=np.e): 22 | """ 23 | Specificity score: 1 - sqrt(jsd) 24 | """ 25 | return 1 - np.sqrt(jsd(p, q, base)) 26 | 27 | def log2norm(e): 28 | """ 29 | Log2 normalization: log2(e+1) and scaling. 30 | """ 31 | loge = np.log2(e + 1) 32 | return loge / loge.sum() 33 | 34 | def predefined_pattern(t, labels): 35 | """ 36 | Generate predefined binary pattern for cluster t. 37 | """ 38 | return (labels == t).astype(int) 39 | 40 | def vec_specificity_score(e, t, labels): 41 | """ 42 | Compute specificity score for a given cluster. 43 | """ 44 | return jsd_sp(log2norm(e), log2norm(predefined_pattern(t, labels))) 45 | 46 | def mat_specificity_score(mat, labels): 47 | """ 48 | Compute specificity scores for all genes/peaks across clusters. 49 | Returns a DataFrame of genes/peaks (rows) x clusters (columns). 50 | """ 51 | unique_labels = np.unique(labels) 52 | return pd.DataFrame( 53 | {t: mat.apply(lambda x: vec_specificity_score(x, t, labels), axis=1) for t in unique_labels}, 54 | index=mat.index 55 | ) 56 | 57 | def compute_pvalues(mat, labels, num_permutations=1000): 58 | """ 59 | Compute p-values for cluster-specificity scores by permuting cell labels on the expression matrix `mat`. 60 | """ 61 | score_mat = mat_specificity_score(mat, labels) 62 | shuffled_scores = np.zeros((score_mat.shape[0], score_mat.shape[1], num_permutations)) 63 | for i in range(num_permutations): 64 | shuffled_labels = np.random.permutation(labels) 65 | shuffled_scores[:, :, i] = mat_specificity_score(mat, shuffled_labels).values 66 | 67 | p_values = (np.sum(shuffled_scores >= score_mat.values[:, :, None], axis=2) + 1) / (num_permutations + 1) 68 | return pd.DataFrame(p_values, index=score_mat.index, columns=score_mat.columns) 69 | 70 | def filter_significant_clusters(score_mat, p_values, alpha=0.05): 71 | """ 72 | Filter cluster-specific genes/peaks based on significance threshold (p-value < alpha). 73 | """ 74 | significant_mask = p_values < alpha 75 | return score_mat.where(significant_mask) 76 | 77 | def cluster_specific(score_mat, classes=None, top=0): 78 | """ 79 | Identify top specific genes/peaks for each cluster. 80 | """ 81 | max_scores = score_mat.max(axis=1) 82 | peak_labels = score_mat.idxmax(axis=1) 83 | 84 | if classes is None: 85 | classes = peak_labels.unique() 86 | 87 | top_indices = { 88 | cls: max_scores[peak_labels == cls].nlargest(top).index for cls in classes 89 | } 90 | 91 | return top_indices 92 |
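# A minimal usage sketch tying the functions above together, assuming `expr` is a genes-by-cells
# pandas DataFrame and `labels` is a per-cell cluster vector (both hypothetical):
#
#   score_mat = mat_specificity_score(expr, labels)                  # genes x clusters
#   p_values = compute_pvalues(expr, labels, num_permutations=100)   # permutation test over cell labels
#   significant = filter_significant_clusters(score_mat, p_values)   # mask non-significant scores
#   top20 = cluster_specific(score_mat, top=20)                      # top 20 features per cluster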
20 | """ 21 | # Generate random binary data 22 | X = np.random.randint(0, 2, size=(n_obs, n_vars)) # Random binary matrix (0s and 1s) 23 | 24 | # Generate observation names and variable names 25 | obs_names = [f"Cell_{i}" for i in range(n_obs)] 26 | var_names = [f"Gene_{j}" for j in range(n_vars)] 27 | 28 | # Create observation (cell) metadata 29 | obs = pd.DataFrame({ 30 | 'condition': np.random.choice([f"Condition_{i}" for i in range(1, n_categories + 1)], n_obs), 31 | 'batch': np.random.choice([i for i in range(1, n_categories + 1)], n_obs) 32 | }, index=obs_names) 33 | 34 | # Create variable (gene) metadata 35 | var = pd.DataFrame({ 36 | 'Gene_ID': var_names 37 | }, index=var_names) 38 | 39 | # Create AnnData object 40 | adata = ad.AnnData(X=X, obs=obs, var=var) 41 | adata.obs['batch'] = adata.obs['batch'].astype('category') 42 | 43 | return adata -------------------------------------------------------------------------------- /tests/test_scalex.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import scalex 4 | from scalex.function import SCALEX 5 | from scalex.net.vae import VAE 6 | from scalex.data import preprocessing_rna 7 | 8 | import torch 9 | 10 | # def test_preprocess_rna(adata_test): 11 | # adata = adata_test.copy() 12 | # adata = preprocessing_rna(adata, min_cells=2, min_features=0) 13 | # assert adata.raw.shape == adata_test.shape 14 | 15 | 16 | def test_scalex_forward(adata_test): 17 | n_domain = len(adata_test.obs['batch'].astype('category').cat.categories) 18 | x_dim = adata_test.X.shape[1] 19 | 20 | # model config 21 | enc = [['fc', 1024, 1, 'relu'],['fc', 10, '', '']] # TO DO 22 | dec = [['fc', x_dim, n_domain, 'sigmoid']] 23 | model = VAE(enc, dec, n_domain=n_domain) 24 | 25 | x = torch.Tensor(adata_test.X) 26 | y = torch.LongTensor(adata_test.obs['batch'].values) 27 | z, recon_x, loss = model(x, y) 28 | # Print a summary of the AnnData object 29 | assert z.shape == (adata_test.shape[0], 10) 30 | assert recon_x.shape == adata_test.shape 31 | 32 | # Load the file 33 | 34 | 35 | # def test_full_model(adata_test): 36 | # out = SCALEX( 37 | # adata_test, processed=True, min_cells=0, min_features=0, batch_size=2, max_iteration=10, 38 | # ) 39 | # assert 'distances' in out.obsp 40 | # assert 'X_scalex_umap' in out.obsm 41 | 42 | 43 | if __name__ == '__main__': 44 | pytest.main([__file__]) 45 | -------------------------------------------------------------------------------- /third_parties/BBKNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | import argparse 5 | import anndata 6 | import bbknn 7 | import os 8 | import time 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3) 13 | sc.settings.set_figure_params(dpi=150) # low dpi (dots per inch) yields small inline figures 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='Integrate multi single cell datasets by BBKNN') 17 | parser.add_argument('--h5ad', type=str, default=None) 18 | parser.add_argument('--outdir', '-o', type=str, default='./') 19 | parser.add_argument('--min_genes', type=int, default=600) 20 | parser.add_argument('--min_cells', type=int, default=3) 21 | parser.add_argument('--num_pcs', type=int, default=20) 22 | parser.add_argument('--n_top_features', type=int, default=2000) 23 | 24 | args = parser.parse_args() 25 | 26 | 
outdir = args.outdir 27 | os.makedirs(outdir, exist_ok=True) 28 | 29 | adata = sc.read_h5ad(args.h5ad) 30 | time1 = time.time() 31 | sc.pp.filter_cells(adata, min_genes=args.min_genes) 32 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 33 | 34 | sc.pp.normalize_total(adata, target_sum=1e4) 35 | sc.pp.log1p(adata) 36 | adata = adata.copy() 37 | 38 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", batch_key='batch', n_top_genes=args.n_top_features, subset=True) 39 | 40 | sc.pp.scale(adata, max_value=10) 41 | sc.tl.pca(adata) 42 | sc.pp.neighbors(adata, n_pcs=args.num_pcs, n_neighbors=20) 43 | adata_bbknn = bbknn.bbknn(adata, neighbors_within_batch=5, n_pcs=args.num_pcs, trim=0, copy=True) 44 | sc.tl.umap(adata_bbknn, min_dist=0.1) 45 | 46 | # UMAP 47 | sc.settings.figdir = outdir 48 | plt.rcParams['figure.figsize'] = (6, 8) 49 | cols = ['celltype', 'batch'] 50 | 51 | color = [c for c in cols if c in adata_bbknn.obs] 52 | sc.pl.umap(adata_bbknn, color=color, frameon=False, save='.png', wspace=0.4, show=False, ncols=1) 53 | time2 = time.time() 54 | print("--- %s seconds ---" % (time2 - time1)) 55 | 56 | # pickle data 57 | adata_bbknn.write(outdir+'/adata.h5ad', compression='gzip') 58 | -------------------------------------------------------------------------------- /third_parties/Conos.R: -------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(argparse)) 2 | suppressPackageStartupMessages(library(conos)) 3 | suppressPackageStartupMessages(library(dplyr)) 4 | suppressPackageStartupMessages(library(Seurat)) 5 | suppressPackageStartupMessages(library(Matrix)) 6 | 7 | parser <- ArgumentParser(description='Conos for the integrative analysis of multi-batch single-cell transcriptomic profiles') 8 | 9 | parser$add_argument("-i", "--input_path", type="character", help="Path contains RNA data") 10 | parser$add_argument("-o", "--output_path", type="character", default='./', help="Output path") 11 | parser$add_argument("-mf", "--minFeatures", type="integer", default=600, help="Remove cells with less than minFeatures features") 12 | parser$add_argument("-mc", "--minCells", type="integer", default=3, help="Remove features with less than minCells cells") 13 | parser$add_argument("-nt", "--n_top_features", type="integer", default=2000, help="N highly variable features") 14 | args <- parser$parse_args() 15 | 16 | message('Reading matrix.mtx and metadata.txt in R...') 17 | data <- readMM(paste(args$input_path, '/matrix.mtx', sep='')) 18 | genes <- read.table(paste(args$input_path, '/genes.txt', sep=''), sep='\t') 19 | metadata <- read.csv(paste(args$input_path, '/metadata.txt', sep=''), sep='\t') 20 | row.names(metadata) <- metadata[,1] 21 | metadata <- metadata[,-1] 22 | metadata$batch = as.character(metadata$batch) 23 | # metadata$celltype = as.character(metadata$celltype) 24 | 25 | data <- data.frame(t(data)) 26 | 27 | colnames(data) <- row.names(metadata) 28 | row.names(data) <- genes[,1] 29 | 30 | adata <- CreateSeuratObject(data, 31 | meta.data = metadata, 32 | min.cells = args$minCells, 33 | min.features = args$minFeatures) 34 | 35 | batch_ <- unique(metadata$batch) 36 | panel.preprocessed <- list() 37 | for (batch in batch_){ 38 | panel.preprocessed[[as.character(batch)]] <- basicSeuratProc(adata@assays$RNA@counts[,(adata@meta.data$batch == batch)], tsne=FALSE, umap=FALSE, verbose=FALSE) 39 | panel.preprocessed[[as.character(batch)]] <- RunTSNE(panel.preprocessed[[as.character(batch)]], npcs = 30, verbose=FALSE, 
check_duplicates=FALSE) 40 | } 41 | 42 | con <- Conos$new(panel.preprocessed, n.cores=1) 43 | message('Integrating...') 44 | con$buildGraph(k=30, 45 | k.self=20, 46 | space='PCA', 47 | ncomps=30, 48 | n.odgenes=args$n_top_features, 49 | matching.method='mNN', 50 | metric='angular', 51 | score.component.variance=TRUE, 52 | verbose=FALSE) 53 | 54 | con$findCommunities(method=leiden.community, resolution=1) 55 | con$embedGraph(method="UMAP", 56 | min.dist=0.1, 57 | spread=1, 58 | min.prob.lower=1e-3) 59 | 60 | embedding <- data.frame(con$embedding) 61 | colnames(embedding) <- c('Umap1','Umap2') 62 | 63 | if (!file.exists(args$output_path)){ 64 | dir.create(file.path(args$output_path),recursive = TRUE) 65 | } 66 | 67 | write.table(metadata[rownames(embedding),], paste(args$output_path, "/metadata.txt", sep=''), sep='\t') 68 | write.table(embedding, paste(args$output_path, "/integrated.txt", sep=''), sep='\t') 69 | rm(list = ls()) -------------------------------------------------------------------------------- /third_parties/DESC.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | import os 5 | import time 6 | import matplotlib.pyplot as plt 7 | 8 | import anndata 9 | import desc 10 | import scanpy as sc 11 | 12 | sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3) 13 | sc.settings.set_figure_params(dpi=150) # low dpi (dots per inch) yields small inline figures 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='Integrate multi single cell datasets by DESC') 17 | parser.add_argument('--h5ad', type=str, default=None) 18 | parser.add_argument('--outdir', '-o', type=str, default='./') 19 | parser.add_argument('--min_genes', type=int, default=600) 20 | parser.add_argument('--min_cells', type=int, default=3) 21 | parser.add_argument('-g','--gpu', type=int, default=0) 22 | args = parser.parse_args() 23 | 24 | outdir = args.outdir 25 | os.makedirs(outdir, exist_ok=True) 26 | 27 | adata = sc.read_h5ad(args.h5ad) 28 | time1 = time.time() 29 | sc.pp.filter_cells(adata, min_genes=args.min_genes) 30 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 31 | 32 | sc.pp.normalize_total(adata, target_sum=1e4) 33 | sc.pp.log1p(adata) 34 | adata.raw=adata 35 | # sc.pp.scale(adata, zero_center=True, max_value=3) 36 | 37 | sc.pp.highly_variable_genes(adata, 38 | flavor="seurat_v3", 39 | n_top_genes=2000, 40 | batch_key="batch", 41 | subset=True) 42 | 43 | 44 | adata = desc.scale_bygroup(adata, groupby='batch', max_value=10) 45 | 46 | adata = desc.train(adata, 47 | dims=[adata.shape[1], 128, 32], 48 | tol=0.001, 49 | n_neighbors=10, 50 | batch_size=256, 51 | louvain_resolution=[0.8], 52 | save_dir=outdir, 53 | do_tsne=False, 54 | learning_rate=300, 55 | use_GPU=True, 56 | num_Cores=1, 57 | do_umap=False, 58 | num_Cores_tsne=4, 59 | use_ae_weights=False, 60 | save_encoder_weights=False) 61 | 62 | sc.pp.neighbors(adata, use_rep="X_Embeded_z0.8") 63 | sc.tl.umap(adata, min_dist=0.1) 64 | time2 = time.time() 65 | 66 | # UMAP 67 | sc.settings.figdir = outdir 68 | plt.rcParams['figure.figsize'] = (6, 8) 69 | cols = ['celltype', 'batch'] 70 | 71 | color = [c for c in cols if c in adata.obs] 72 | sc.pl.umap(adata, color=color, frameon=False, save='.png', wspace=0.4, show=False, ncols=1) 73 | 74 | # pickle data 75 | adata.write(outdir+'/adata.h5ad', compression='gzip') 76 | print("--- %s seconds ---" % (time2 - time1)) 
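# Example invocation of this DESC benchmark script (paths are hypothetical; flags mirror the
# argparse options defined above):
#   python DESC.py --h5ad pbmc.h5ad -o results/desc --min_genes 600 --min_cells 3
# Note that desc.train() above is called with use_GPU=True regardless of the -g/--gpu flag,
# so a GPU is assumed to be available.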
-------------------------------------------------------------------------------- /third_parties/FastMNN.R: -------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(argparse)) 2 | suppressPackageStartupMessages(library(Matrix)) 3 | suppressPackageStartupMessages(library(Seurat)) 4 | suppressPackageStartupMessages(library(SeuratWrappers)) 5 | suppressPackageStartupMessages(library(batchelor)) 6 | suppressPackageStartupMessages(library(future)) 7 | suppressPackageStartupMessages(library(future.apply)) 8 | 9 | parser <- ArgumentParser(description='FastMNN for the integrative analysis of multi-batch single-cell transcriptomic profiles') 10 | 11 | parser$add_argument("-i", "--input_path", type="character", help="Path contains RNA data") 12 | parser$add_argument("-o", "--output_path", type="character", default='./', help="Output path") 13 | parser$add_argument("-mf", "--minFeatures", type="integer", default=600, help="Remove cells with less than minFeatures features") 14 | parser$add_argument("-mc", "--minCells", type="integer", default=3, help="Remove features with less than minCells cells") 15 | parser$add_argument("-nt", "--n_top_features", type="integer", default=2000, help="N highly variable features") 16 | 17 | args <- parser$parse_args() 18 | 19 | plan("multiprocess", workers = 1) 20 | options(future.globals.maxSize = 100000 * 1024^2) 21 | 22 | message('Reading matrix.mtx and metadata.txt in R...') 23 | data <- readMM(paste(args$input_path, '/matrix.mtx', sep='')) 24 | genes <- read.table(paste(args$input_path, '/genes.txt', sep=''), sep='\t') 25 | metadata <- read.csv(paste(args$input_path, '/metadata.txt', sep=''), sep='\t') 26 | row.names(metadata) <- metadata[,1] 27 | metadata <- metadata[,-1] 28 | metadata$batch = as.character(metadata$batch) 29 | # metadata$celltype = as.character(metadata$celltype) 30 | 31 | data <- t(data) 32 | 33 | colnames(data) <- row.names(metadata) 34 | row.names(data) <- genes[,1] 35 | 36 | adata <- CreateSeuratObject(as(data, "sparseMatrix"), 37 | meta.data = metadata, 38 | min.cells = args$minCells, 39 | min.features = args$minFeatures) 40 | 41 | print(dim(adata)) 42 | message('Preprocessing...') 43 | adata <- NormalizeData(adata, verbose = FALSE) 44 | adata <- FindVariableFeatures(adata, selection.method = "vst", nfeatures = args$n_top_features, verbose = FALSE) 45 | 46 | message('IntegrateData...') 47 | adata <- RunFastMNN(object.list = SplitObject(adata, split.by = "batch")) 48 | 49 | if (!file.exists(args$output_path)){ 50 | dir.create(file.path(args$output_path),recursive = TRUE) 51 | } 52 | 53 | message('Writing...') 54 | write.table(adata@meta.data[,-1], paste(args$output_path, "/metadata.txt", sep=''), sep='\t') 55 | write.table(adata@reductions$mnn@cell.embeddings, paste(args$output_path, "/integrated.txt", sep=''), sep='\t') 56 | rm(list = ls()) -------------------------------------------------------------------------------- /third_parties/Harmony.R: -------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(argparse)) 2 | suppressPackageStartupMessages(library(Seurat)) 3 | suppressPackageStartupMessages(library(scater)) 4 | suppressPackageStartupMessages(library(future)) 5 | suppressPackageStartupMessages(library(Matrix)) 6 | suppressPackageStartupMessages(library(harmony)) 7 | 8 | parser <- ArgumentParser(description='Harmony for the integrative analysis of multi-batch single-cell transcriptomic profiles') 9 | 10 | 
parser$add_argument("-i", "--input_path", type="character", help="Path contains RNA data") 11 | parser$add_argument("-o", "--output_path", type="character", default='./', help="Output path") 12 | parser$add_argument("-mf", "--minFeatures", type="integer", default=600, help="Remove cells with less than minFeatures features") 13 | parser$add_argument("-mc", "--minCells", type="integer", default=3, help="Remove features with less than minCells cells") 14 | parser$add_argument("-nt", "--n_top_features", type="integer", default=2000, help="N highly variable features") 15 | 16 | args <- parser$parse_args() 17 | 18 | plan("multiprocess", workers = 4) 19 | options(future.globals.maxSize = 10000 * 1024^2) 20 | 21 | message('Reading matrix.mtx and metadata.txt in R...') 22 | data <- readMM(paste(args$input_path, '/matrix.mtx', sep='')) 23 | genes <- read.table(paste(args$input_path, '/genes.txt', sep=''), sep='\t') 24 | metadata <- read.csv(paste(args$input_path, '/metadata.txt', sep=''), sep='\t') 25 | row.names(metadata) <- metadata[,1] 26 | metadata <- metadata[,-1] 27 | metadata$batch = as.character(metadata$batch) 28 | # metadata$celltype = as.character(metadata$celltype) 29 | 30 | data <- t(data) 31 | 32 | colnames(data) <- row.names(metadata) 33 | row.names(data) <- genes[,1] 34 | 35 | adata <- CreateSeuratObject(as(data, "sparseMatrix"), 36 | meta.data = metadata, 37 | min.cells = args$minCells, 38 | min.features = args$minFeatures) 39 | 40 | print(dim(adata)[2]) 41 | message('Preprocessing...') 42 | message('Normalization') 43 | adata <- NormalizeData(adata) 44 | message('FindVariableFeatures') 45 | adata <- FindVariableFeatures(adata, selection.method = "vst", nfeatures = args$n_top_features, verbose = FALSE) 46 | message('ScaleData') 47 | adata <- ScaleData(adata, verbose = FALSE) 48 | message('RunPCA') 49 | adata <- RunPCA(adata, pc.genes = data@var.genes, npcs = 30, verbose = FALSE) 50 | 51 | message('Integrating...') 52 | options(repr.plot.height = 2.5, repr.plot.width = 6) 53 | adata <- RunHarmony(adata, "batch", plot_convergence = FALSE) 54 | 55 | if (!file.exists(args$output_path)){ 56 | dir.create(file.path(args$output_path),recursive = TRUE) 57 | } 58 | 59 | write.table(adata@meta.data[,-1], paste(args$output_path, "/metadata.txt", sep=''), sep='\t') 60 | write.table(Embeddings(adata@reductions$harmony), paste(args$output_path, "/integrated.txt", sep=''), sep='\t') 61 | rm(list = ls()) -------------------------------------------------------------------------------- /third_parties/LIGER.R: -------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(argparse)) 2 | suppressPackageStartupMessages(library(Seurat)) 3 | suppressPackageStartupMessages(library(scater)) 4 | suppressPackageStartupMessages(library(rliger)) 5 | suppressPackageStartupMessages(library(future)) 6 | suppressPackageStartupMessages(library(future.apply)) 7 | 8 | parser <- ArgumentParser(description='batch_iNMF for the integrative analysis of multi-batch single-cell transcriptomic profiles') 9 | 10 | parser$add_argument("-i", "--input_path", type="character", help="Path contains RNA data") 11 | parser$add_argument("-o", "--output_path", type="character", default='./', help="Output path") 12 | parser$add_argument("-mf", "--minFeatures", type="integer", default=600, help="Remove cells with less than minFeatures features") 13 | parser$add_argument("-mc", "--minCells", type="integer", default=3, help="Remove features with less than minCells cells") 14 | 
parser$add_argument("-nt", "--n_top_features", type="integer", default=2000, help="N highly variable features") 15 | 16 | args <- parser$parse_args() 17 | 18 | 19 | plan("multiprocess", workers = 1) 20 | options(future.globals.maxSize = 100000 * 1024^2) 21 | 22 | message('Reading matrix.mtx and metadata.txt in R...') 23 | data <- readMM(paste(args$input_path, '/matrix.mtx', sep='')) 24 | genes <- read.table(paste(args$input_path, '/genes.txt', sep=''), sep='\t') 25 | metadata <- read.csv(paste(args$input_path, '/metadata.txt', sep=''), sep='\t') 26 | row.names(metadata) <- metadata[,1] 27 | metadata <- metadata[,-1] 28 | metadata$batch = as.character(metadata$batch) 29 | # metadata$celltype = as.character(metadata$celltype) 30 | 31 | data <- t(data) 32 | 33 | colnames(data) <- row.names(metadata) 34 | row.names(data) <- genes[,1] 35 | 36 | scdata <- CreateSeuratObject(as(data, "sparseMatrix"), 37 | meta.data = metadata, 38 | min.cells = args$minCells, 39 | min.features = args$minFeatures) 40 | 41 | message('Preprocessing...') 42 | scdata.list <- SplitObject(scdata, split.by = "batch") 43 | batches=list() 44 | i = 1 45 | for(item in scdata.list){ 46 | batches[names(scdata.list[i])] = item@assays$RNA@counts 47 | i = i+1 48 | } 49 | 50 | adata = createLiger(batches,remove.missing=FALSE) 51 | varthresh = 0.2 ### default 0.2 52 | 53 | # varthresh = 0.001 54 | adata = rliger::normalize(adata,remove.missing = FALSE) 55 | adata = selectGenes(adata, var.thresh = varthresh, do.plot = F) 56 | # adata@var.genes = rownames(adata@norm.data$'0') 57 | adata = scaleNotCenter(adata, remove.missing = FALSE) 58 | adata = optimizeALS(adata, k = 20) 59 | adata = quantile_norm(adata) 60 | 61 | 62 | if (!file.exists(args$output_path)){ 63 | dir.create(file.path(args$output_path),recursive = TRUE) 64 | } 65 | 66 | message('Writing...') 67 | write.table(metadata[row.names(adata@cell.data),], paste(args$output_path, "/metadata.txt", sep=''), sep='\t') 68 | write.table(data.frame(adata@H.norm), paste(args$output_path, "/integrated.txt", sep=''), sep='\t') 69 | rm(list = ls()) 70 | -------------------------------------------------------------------------------- /third_parties/Raw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | import argparse 5 | import anndata 6 | import os 7 | import time 8 | from datetime import timedelta 9 | import matplotlib.pyplot as plt 10 | 11 | sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3) 12 | sc.settings.set_figure_params(dpi=150) # low dpi (dots per inch) yields small inline figures 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='Integrate multi single cell datasets') 16 | parser.add_argument('--h5ad', type=str, default=None) 17 | parser.add_argument('--outdir', '-o', type=str, default='./') 18 | parser.add_argument('--min_genes', type=int, default=600) 19 | parser.add_argument('--min_cells', type=int, default=3) 20 | parser.add_argument('--batch_key', '-bk', type=str, default='batch') 21 | parser.add_argument('--num_pcs', type=int, default=20) 22 | parser.add_argument('--n_top_features', type=int, default=2000) 23 | args = parser.parse_args() 24 | 25 | outdir = args.outdir 26 | os.makedirs(outdir, exist_ok=True) 27 | adata = sc.read_h5ad(args.h5ad) 28 | time1 = time.time() 29 | sc.pp.filter_cells(adata, min_genes=args.min_genes) 30 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 31 | 32 | 
sc.pp.normalize_total(adata, target_sum=1e4) 33 | sc.pp.log1p(adata) 34 | adata.raw = adata 35 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", batch_key=args.batch_key, n_top_genes=args.n_top_features, subset=True) 36 | sc.pp.scale(adata, max_value=10) 37 | sc.tl.pca(adata) 38 | sc.pp.neighbors(adata, n_pcs=args.num_pcs, n_neighbors=20) 39 | sc.tl.umap(adata, min_dist=0.1) 40 | 41 | # UMAP 42 | sc.settings.figdir = outdir 43 | plt.rcParams['figure.figsize'] = (6, 8) 44 | cols = ['celltype', 'batch','sample'] 45 | 46 | color = [c for c in cols if c in adata.obs] 47 | sc.pl.umap(adata, color=color, frameon=False, save='.png', wspace=0.4, show=False, ncols=1) 48 | 49 | time2 = time.time() 50 | print("--- %s seconds ---" % (time2 - time1)) 51 | 52 | # pickle data 53 | adata.write(outdir+'/adata.h5ad', compression='gzip') 54 | -------------------------------------------------------------------------------- /third_parties/Scanorama.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | import scanorama 5 | from anndata import AnnData 6 | import argparse 7 | import anndata 8 | import os 9 | import time 10 | import matplotlib.pyplot as plt 11 | from matplotlib import style 12 | 13 | sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3) 14 | sc.settings.set_figure_params(dpi=150) # low dpi (dots per inch) yields small inline figures 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser(description='Integrate multi single cell datasets by Scanorama') 18 | parser.add_argument('--h5ad', type=str, default=None) 19 | parser.add_argument('--outdir', '-o', type=str, default='./') 20 | parser.add_argument('--min_genes', type=int, default=600) 21 | parser.add_argument('--min_cells', type=int, default=3) 22 | parser.add_argument('--num_pcs', type=int, default=20) 23 | parser.add_argument('--n_top_features', type=int, default=2000) 24 | 25 | args = parser.parse_args() 26 | 27 | outdir = args.outdir 28 | os.makedirs(outdir, exist_ok=True) 29 | 30 | adata = sc.read_h5ad(args.h5ad) 31 | t1 = time.time() 32 | sc.pp.filter_cells(adata, min_genes=args.min_genes) 33 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 34 | 35 | sc.pp.normalize_total(adata, inplace=True) 36 | sc.pp.log1p(adata) 37 | adata = adata.copy() 38 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", batch_key='batch', n_top_genes=args.n_top_features, subset=True) 39 | batch_ = adata.obs['batch'].unique().astype('object') 40 | adatas = [adata[adata.obs['batch']==batch] for batch in batch_] 41 | 42 | time1 = time.time() 43 | integrated = scanorama.correct_scanpy(adatas, return_dimred=True) 44 | concat = AnnData.concatenate(*integrated, join='outer', batch_key='batch_scanorama', index_unique=None) 45 | sc.pp.neighbors(concat, use_rep="X_scanorama") 46 | sc.tl.umap(concat) 47 | 48 | # UMAP 49 | sc.settings.figdir = outdir 50 | plt.rcParams['figure.figsize'] = (6, 8) 51 | cols = ['celltype', 'batch'] 52 | 53 | color = [c for c in cols if c in concat.obs] 54 | sc.pl.umap(concat, color=color, frameon=False, save='.png', wspace=0.4, show=False, ncols=1) 55 | 56 | time2 = time.time() 57 | print("--- %s seconds ---" % (time2 - time1)) 58 | # pickle data 59 | concat.write(outdir+'/adata.h5ad', compression='gzip') 60 | 61 | -------------------------------------------------------------------------------- /third_parties/Seurat_v3.R: 
-------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(argparse)) 2 | suppressPackageStartupMessages(library(Seurat)) 3 | suppressPackageStartupMessages(library(scater)) 4 | suppressPackageStartupMessages(library(future)) 5 | suppressPackageStartupMessages(library(Matrix)) 6 | suppressPackageStartupMessages(library(future.apply)) 7 | 8 | parser <- ArgumentParser(description='Seurat_v3 for the integrative analysis of multi-batch single-cell transcriptomic profiles') 9 | 10 | parser$add_argument("-i", "--input_path", type="character", help="Path contains RNA data") 11 | parser$add_argument("-o", "--output_path", type="character", default='./', help="Output path") 12 | parser$add_argument("-mf", "--minFeatures", type="integer", default=600, help="Remove cells with less than minFeatures features") 13 | parser$add_argument("-mc", "--minCells", type="integer", default=3, help="Remove features with less than minCells cells") 14 | parser$add_argument("-nt", "--n_top_features", type="integer", default=2000, help="N highly variable features") 15 | args <- parser$parse_args() 16 | 17 | plan("multiprocess", workers = 1) 18 | options(future.globals.maxSize = 100000 * 1024^2) 19 | 20 | message('Reading matrix.mtx and metadata.txt in R...') 21 | data <- readMM(paste(args$input_path, '/matrix.mtx', sep='')) 22 | genes <- read.table(paste(args$input_path, '/genes.txt', sep=''), sep='\t') 23 | metadata <- read.csv(paste(args$input_path, '/metadata.txt', sep=''), sep='\t') 24 | row.names(metadata) <- metadata[,1] 25 | metadata <- metadata[,-1] 26 | metadata$batch = as.character(metadata$batch) 27 | # metadata$celltype = as.character(metadata$celltype) 28 | 29 | data <- t(data) 30 | 31 | colnames(data) <- row.names(metadata) 32 | row.names(data) <- genes[,1] 33 | 34 | adata <- CreateSeuratObject(as(data, "sparseMatrix"), 35 | meta.data = metadata, 36 | min.cells = args$minCells, 37 | min.features = args$minFeatures) 38 | 39 | 40 | message('Preprocessing...') 41 | adata.list <- SplitObject(adata, split.by = "batch") 42 | 43 | print(dim(adata)[2]) 44 | if(dim(adata)[2] < 50000){ 45 | for (i in 1:length(adata.list)) { 46 | adata.list[[i]] <- NormalizeData(adata.list[[i]], verbose = FALSE) 47 | adata.list[[i]] <- FindVariableFeatures(adata.list[[i]], selection.method = "vst", nfeatures = args$n_top_features, verbose = FALSE) 48 | } 49 | message('FindIntegrationAnchors...') 50 | adata.anchors <- FindIntegrationAnchors(object.list = adata.list, dims = 1:30,verbose =FALSE,k.filter = 30) 51 | # adata.anchors <- FindIntegrationAnchors(object.list = adata.list, dims = 1:30,verbose =FALSE,k.filter = 100) 52 | 53 | message('IntegrateData...') 54 | adata.integrated <- IntegrateData(anchorset = adata.anchors, dims = 1:30, verbose = FALSE) 55 | }else{ 56 | adata.list <- future_lapply(X = adata.list, FUN = function(x) { 57 | x <- NormalizeData(x, verbose = FALSE) 58 | x <- FindVariableFeatures(x, nfeatures = args$n_top_features, verbose = FALSE) 59 | }) 60 | 61 | features <- SelectIntegrationFeatures(object.list = adata.list) 62 | adata.list <- future_lapply(X = adata.list, FUN = function(x) { 63 | x <- ScaleData(x, features = features, verbose = FALSE) 64 | x <- RunPCA(x, features = features, verbose = FALSE) 65 | }) 66 | message('FindIntegrationAnchors...') 67 | adata.anchors <- FindIntegrationAnchors(object.list = adata.list, dims = 1:30, verbose =FALSE, reduction = 'rpca', reference = c(1, 2)) 68 | message('IntegrateData...') 69 | adata.integrated <- 
IntegrateData(anchorset = adata.anchors, dims = 1:30, verbose = FALSE) 70 | } 71 | 72 | if (!file.exists(args$output_path)){ 73 | dir.create(file.path(args$output_path),recursive = TRUE) 74 | } 75 | 76 | message('Writing...') 77 | write.table(adata.integrated@meta.data[,-1], paste(args$output_path, "/metadata.txt", sep=''), sep='\t') 78 | write.table(t(data.frame(adata.integrated@assays$integrated@data)), paste(args$output_path, "/integrated.txt", sep=''), sep='\t') 79 | rm(list = ls()) -------------------------------------------------------------------------------- /third_parties/online_iNMF.R: -------------------------------------------------------------------------------- 1 | suppressPackageStartupMessages(library(argparse)) 2 | suppressPackageStartupMessages(library(Seurat)) 3 | suppressPackageStartupMessages(library(scater)) 4 | suppressPackageStartupMessages(library(rliger)) 5 | suppressPackageStartupMessages(library(future)) 6 | suppressPackageStartupMessages(library(future.apply)) 7 | 8 | parser <- ArgumentParser(description='online_iNMF for the integrative analysis of multi-batch single-cell transcriptomic profiles') 9 | 10 | parser$add_argument("-i", "--input_path", type="character", help="Path contains RNA data") 11 | parser$add_argument("-o", "--output_path", type="character", default='./', help="Output path") 12 | parser$add_argument("-mf", "--minFeatures", type="integer", default=600, help="Remove cells with less than minFeatures features") 13 | parser$add_argument("-mc", "--minCells", type="integer", default=3, help="Remove features with less than minCells cells") 14 | parser$add_argument("-nt", "--n_top_features", type="integer", default=2000, help="N highly variable features") 15 | 16 | args <- parser$parse_args() 17 | 18 | 19 | plan("multiprocess", workers = 1) 20 | options(future.globals.maxSize = 100000 * 1024^2) 21 | 22 | message('Reading matrix.mtx and metadata.txt in R...') 23 | data <- readMM(paste(args$input_path, '/matrix.mtx', sep='')) 24 | genes <- read.table(paste(args$input_path, '/genes.txt', sep=''), sep='\t') 25 | metadata <- read.csv(paste(args$input_path, '/metadata.txt', sep=''), sep='\t') 26 | row.names(metadata) <- metadata[,1] 27 | metadata <- metadata[,-1] 28 | metadata$batch = as.character(metadata$batch) 29 | # metadata$celltype = as.character(metadata$celltype) 30 | 31 | data <- t(data) 32 | 33 | colnames(data) <- row.names(metadata) 34 | row.names(data) <- genes[,1] 35 | 36 | scdata <- CreateSeuratObject(as(data, "sparseMatrix"), 37 | meta.data = metadata, 38 | min.cells = args$minCells, 39 | min.features = args$minFeatures) 40 | 41 | message('Preprocessing...') 42 | scdata.list <- SplitObject(scdata, split.by = "batch") 43 | batches=list() 44 | i = 1 45 | for(item in scdata.list){ 46 | batches[names(scdata.list[i])] = item@assays$RNA@counts 47 | i = i+1 48 | } 49 | adata = createLiger(batches,remove.missing=FALSE) 50 | 51 | if(dim(data)[2]>5000){ 52 | Batch_size = 5000 53 | }else{ 54 | Batch_size = 1000 55 | } 56 | varthresh = 0.2 57 | # varthresh = 0.001 58 | adata = rliger::normalize(adata,remove.missing = FALSE) 59 | adata = selectGenes(adata, var.thresh = varthresh, do.plot = F) 60 | # adata@var.genes = rownames(adata@norm.data$'0') 61 | adata = scaleNotCenter(adata, remove.missing = FALSE) 62 | print(Batch_size) 63 | adata = online_iNMF(adata, k = 20, miniBatch_size = Batch_size, max.epochs = 5) 64 | adata = quantile_norm(adata) 65 | 66 | 67 | if (!file.exists(args$output_path)){ 68 | dir.create(file.path(args$output_path),recursive = TRUE) 
69 | } 70 | 71 | message('Writing...') 72 | write.table(metadata[row.names(adata@cell.data),], paste(args$output_path, "/metadata.txt", sep=''), sep='\t') 73 | write.table(data.frame(adata@H.norm), paste(args$output_path, "/integrated.txt", sep=''), sep='\t') 74 | rm(list = ls()) 75 | -------------------------------------------------------------------------------- /third_parties/scVI.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | import os 6 | import time 7 | 8 | import warnings 9 | warnings.simplefilter(action='ignore', category=FutureWarning) 10 | 11 | import anndata 12 | import scvi 13 | import scanpy as sc 14 | 15 | sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3) 16 | sc.settings.set_figure_params(dpi=150) # low dpi (dots per inch) yields small inline figures 17 | 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser(description='Integrate multi single cell datasets by scVI') 22 | parser.add_argument('--h5ad', type=str, default=None) 23 | parser.add_argument('--outdir', '-o', type=str, default='./') 24 | parser.add_argument('--min_genes', type=int, default=600) 25 | parser.add_argument('--min_cells', type=int, default=3) 26 | parser.add_argument('-g','--gpu', type=int, default=0) 27 | parser.add_argument('--n_top_features', type=int, default=2000) 28 | args = parser.parse_args() 29 | 30 | outdir = args.outdir 31 | os.makedirs(outdir, exist_ok=True) 32 | 33 | adata = sc.read_h5ad(args.h5ad) 34 | time1 = time.time() 35 | sc.pp.filter_cells(adata, min_genes=args.min_genes) 36 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 37 | 38 | adata.layers["counts"] = adata.X.copy() 39 | sc.pp.normalize_total(adata, target_sum=1e4) 40 | sc.pp.log1p(adata) 41 | adata.raw = adata 42 | 43 | sc.pp.highly_variable_genes(adata, 44 | flavor="seurat_v3", 45 | n_top_genes=args.n_top_features, 46 | layer="counts", 47 | batch_key="batch", 48 | subset=True) 49 | 50 | scvi.model.SCVI.setup_anndata(adata, 51 | layer="counts", 52 | # categorical_covariate_keys=["batch", "donor"], 53 | # continuous_covariate_keys=["percent_mito", "percent_ribo"], 54 | batch_key="batch" 55 | ) 56 | vae = scvi.model.SCVI(adata) 57 | # vae = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb") 58 | vae.train(use_gpu = args.gpu) 59 | 60 | adata.obsm["X_scVI"] = vae.get_latent_representation() 61 | sc.pp.neighbors(adata, use_rep="X_scVI") 62 | sc.tl.umap(adata, min_dist=0.1) 63 | time2 = time.time() 64 | 65 | # UMAP 66 | sc.settings.figdir = outdir 67 | plt.rcParams['figure.figsize'] = (6, 8) 68 | cols = ['celltype', 'batch'] 69 | 70 | color = [c for c in cols if c in adata.obs] 71 | sc.pl.umap(adata, color=color, frameon=False, save='.png', wspace=0.4, show=False, ncols=1) 72 | 73 | # pickle data 74 | adata.write(outdir+'/adata.h5ad', compression='gzip') 75 | print("--- %s seconds ---" % (time2 - time1)) 76 | 77 | --------------------------------------------------------------------------------