├── .github
│   └── workflows
│       ├── docs.yml
│       ├── publish-to-pypi.yml
│       ├── run_isort.yml
│       ├── run_pytest.yml
│       └── run_yapf.yml
├── .isort.cfg
├── .pre-commit-config.yaml
├── .style.yapf
├── LICENSE
├── README.md
├── config
│   ├── README.md
│   └── config.yaml
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── logo.css
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── modules.rst
│   ├── pipeline.rst
│   ├── scalr.analysis.rst
│   ├── scalr.data.preprocess.rst
│   ├── scalr.data.rst
│   ├── scalr.data.split.rst
│   ├── scalr.feature.rst
│   ├── scalr.feature.scoring.rst
│   ├── scalr.feature.selector.rst
│   ├── scalr.nn.callbacks.rst
│   ├── scalr.nn.dataloader.rst
│   ├── scalr.nn.loss.rst
│   ├── scalr.nn.model.rst
│   ├── scalr.nn.rst
│   ├── scalr.nn.trainer.rst
│   ├── scalr.rst
│   └── scalr.utils.rst
├── img
│   ├── Schematic-of-scPipeline.jpg
│   └── scaLR_logo.png
├── pipeline.py
├── pyproject.toml
├── requirements.txt
├── scalr
│   ├── __init__.py
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── _analyser.py
│   │   ├── dge_lmem.py
│   │   ├── dge_pseudobulk.py
│   │   ├── evaluation.py
│   │   ├── gene_recall_curve.py
│   │   ├── heatmap.py
│   │   ├── roc_auc.py
│   │   ├── test_dge_lmem.py
│   │   └── test_dge_pseudobulk.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── preprocess
│   │   │   ├── __init__.py
│   │   │   ├── _preprocess.py
│   │   │   ├── sample_norm.py
│   │   │   ├── standard_scale.py
│   │   │   ├── test_sample_norm.py
│   │   │   └── test_standard_scale.py
│   │   └── split
│   │       ├── __init__.py
│   │       ├── _split.py
│   │       ├── group_splitter.py
│   │       ├── stratified_group_splitter.py
│   │       └── stratified_splitter.py
│   ├── data_ingestion_pipeline.py
│   ├── eval_and_analysis_pipeline.py
│   ├── feature
│   │   ├── __init__.py
│   │   ├── feature_subsetting.py
│   │   ├── scoring
│   │   │   ├── __init__.py
│   │   │   ├── _scoring.py
│   │   │   ├── linear_scorer.py
│   │   │   └── shap_scorer.py
│   │   └── selector
│   │       ├── __init__.py
│   │       ├── _selector.py
│   │       ├── abs_mean.py
│   │       ├── classwise_abs.py
│   │       └── classwise_promoters.py
│   ├── feature_extraction_pipeline.py
│   ├── model_training_pipeline.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── callbacks
│   │   │   ├── __init__.py
│   │   │   ├── _callbacks.py
│   │   │   ├── early_stopping.py
│   │   │   ├── model_checkpoint.py
│   │   │   ├── tensorboard_logger.py
│   │   │   └── test_early_stopping.py
│   │   ├── dataloader
│   │   │   ├── __init__.py
│   │   │   ├── _dataloader.py
│   │   │   ├── simple_dataloader.py
│   │   │   ├── simple_metadataloader.py
│   │   │   ├── test_simple_dataloader.py
│   │   │   └── test_simple_metadataloader.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   └── _loss.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── _model.py
│   │   │   ├── sequential_model.py
│   │   │   └── shap_model.py
│   │   └── trainer
│   │       ├── __init__.py
│   │       ├── _trainer.py
│   │       └── simple_model_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── data_utils.py
│       ├── file_utils.py
│       ├── logger.py
│       ├── misc_utils.py
│       ├── test_file_utils.py
│       └── test_misc_utils.py
└── tutorials
    ├── analysis
    │   ├── differential_gene_expression
    │   │   ├── dge.ipynb
    │   │   ├── dge_config.yaml
    │   │   ├── dge_lmem_main.py
    │   │   ├── dge_pseudobulk_main.py
    │   │   └── tutorial_config.png
    │   ├── gene_recall_curve
    │   │   ├── gene_recall_curve.ipynb
    │   │   ├── multi_model_gene_recall_comparison.png
    │   │   ├── ranked_genes.csv
    │   │   ├── reference_genes.csv
    │   │   └── score_matrix.csv
    │   └── shap_analysis
    │       └── shap_heatmap.ipynb
    ├── pipeline
    │   ├── config_celltype.yaml
    │   ├── config_clinical.yaml
    │   ├── grc_reference_gene.csv
    │   ├── scalr_pipeline.ipynb
    │   └── scalr_pipeline_local_run.ipynb
    └── preprocessing
        ├── batch_correction.ipynb
        └── normalization.ipynb
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish Documentation on GitHub Pages 2 | 3 | on: [push, pull_request, workflow_dispatch] 4 | 5 | permissions: 6 | contents: write 7 | 8 | jobs: 9 | docs: 10 | runs-on: ubuntu-latest 11 |
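      # Builds the Sphinx site for every push and pull request; the deploy step below only publishes `_build/` to the `gh-pages` branch on pushes to `main`.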
steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | - name: Install scalr requirements 15 | run: | 16 | pip install -r requirements.txt 17 | - name: Install sphinx dependencies 18 | run: | 19 | pip install sphinx sphinx_rtd_theme myst_parser 20 | - name: Sphinx build 21 | run: | 22 | sphinx-build docs _build 23 | - name: Deploy to GitHub Pages 24 | uses: peaceiris/actions-gh-pages@v3 25 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 26 | with: 27 | publish_branch: gh-pages 28 | github_token: ${{ secrets.GITHUB_TOKEN }} 29 | publish_dir: _build/ 30 | force_orphan: true 31 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish package to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* # Push events to tag which starts with v 7 | 8 | jobs: 9 | build: 10 | name: Build distribution 📦 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.9" 19 | - name: Install pypa/build 20 | run: >- 21 | python3 -m 22 | pip install 23 | build 24 | --user 25 | - name: Build a binary wheel and a source tarball 26 | run: python3 -m build 27 | - name: Store the distribution packages 28 | uses: actions/upload-artifact@v4 29 | with: 30 | name: python-package-distributions 31 | path: dist/ 32 | publish-to-pypi: 33 | name: >- 34 | Publish to PyPI 35 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 36 | needs: 37 | - build 38 | runs-on: ubuntu-latest 39 | environment: 40 | name: pypi 41 | url: https://pypi.org/p/pyscalr 42 | permissions: 43 | id-token: write # IMPORTANT: mandatory for trusted publishing 44 | steps: 45 | - name: Download all the dists 46 | uses: actions/download-artifact@v4 47 | with: 48 | name: python-package-distributions 49 | path: dist/ 50 | - name: Publish to PyPI 51 | uses: pypa/gh-action-pypi-publish@release/v1 52 | -------------------------------------------------------------------------------- /.github/workflows/run_isort.yml: -------------------------------------------------------------------------------- 1 | name: "Import checker" 2 | 3 | on: 4 | - push 5 | 6 | jobs: 7 | build: 8 | 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.9 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.9 17 | - name: Install isort 18 | run: pip install isort 19 | - name: Run isort 20 | run: isort . 
--settings-path .isort.cfg 21 | -------------------------------------------------------------------------------- /.github/workflows/run_pytest.yml: -------------------------------------------------------------------------------- 1 | name: Run Unit Test via Pytest 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.9"] 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 26 | - name: Running test cases 27 | run: pytest -v -s 28 | -------------------------------------------------------------------------------- /.github/workflows/run_yapf.yml: -------------------------------------------------------------------------------- 1 | name: YAPF Formatting Check 2 | 3 | on: [push] 4 | 5 | jobs: 6 | formatting-check: 7 | name: Formatting Check 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: run YAPF 12 | uses: AlexanderMelde/yapf-action@master 13 | with: 14 | args: --verbose 15 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = google 3 | line_length = 80 4 | multi_line_output = 3 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-yapf 3 | rev: v0.32.0 4 | hooks: 5 | - id: yapf 6 | args: ['-ir', '--style', '.style.yapf'] 7 | additional_dependencies: [toml] 8 | - repo: https://github.com/pre-commit/mirrors-isort 9 | rev: v5.10.1 10 | hooks: 11 | - id: isort 12 | args: ['--settings-path', '.isort.cfg'] 13 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | indent_width = 4 4 | spaces_before_comment = 4 5 | column_limit = 80 6 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # Config file for pipeline run. 2 | 3 | # DEVICE SETUP. 4 | device: 'cuda' 5 | 6 | # EXPERIMENT. 7 | experiment: 8 | dirpath: 'scalr_experiments' 9 | exp_name: 'exp_name' 10 | exp_run: 0 11 | 12 | 13 | # DATA CONFIG. 14 | data: 15 | sample_chunksize: 20000 16 | num_workers: 1 17 | 18 | train_val_test: 19 | full_datapath: '/path/to/anndata.h5ad' 20 | 21 | splitter_config: 22 | name: GroupSplitter 23 | params: 24 | split_ratio: [7, 1, 2.5] 25 | stratify: 'donor_id' 26 | 27 | # split_datapaths: '' 28 | 29 | # preprocess: 30 | # - name: SampleNorm 31 | # params: 32 | # **args 33 | 34 | # - name: StandardScaler 35 | # params: 36 | # **args 37 | 38 | target: Cell_Type 39 | 40 | 41 | # FEATURE SELECTION. 
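# Roughly: this stage trains an intermediate model on chunks of `feature_subsetsize` genes, scores each gene with `scoring_config`, and keeps the top `k` genes picked by `features_selector` for final training; a precomputed `score_matrix` can be supplied instead to skip the intermediate training.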
42 | feature_selection: 43 | 44 | # score_matrix: '/path/to/matrix' 45 | feature_subsetsize: 5000 46 | num_workers: 1 47 | 48 | model: 49 | name: SequentialModel 50 | params: 51 | layers: [5000, 6] 52 | weights_init_zero: True 53 | 54 | model_train_config: 55 | trainer: SimpleModelTrainer 56 | 57 | dataloader: 58 | name: SimpleDataLoader 59 | params: 60 | batch_size: 25000 61 | padding: 5000 62 | 63 | optimizer: 64 | name: SGD 65 | params: 66 | lr: 1.0e-3 67 | weight_decay: 0.1 68 | 69 | loss: 70 | name: CrossEntropyLoss 71 | 72 | epochs: 1 73 | 74 | scoring_config: 75 | name: LinearScorer 76 | 77 | features_selector: 78 | name: AbsMean 79 | params: 80 | k: 5000 81 | 82 | 83 | # FINAL MODEL TRAINING. 84 | final_training: 85 | 86 | model: 87 | name: SequentialModel 88 | params: 89 | layers: [5000, 6] 90 | dropout: 0 91 | weights_init_zero: False 92 | 93 | model_train_config: 94 | resume_from_checkpoint: null 95 | 96 | trainer: SimpleModelTrainer 97 | 98 | dataloader: 99 | name: SimpleDataLoader 100 | params: 101 | batch_size: 15000 102 | 103 | optimizer: 104 | name: Adam 105 | params: 106 | lr: 1.0e-3 107 | weight_decay: 0 108 | 109 | loss: 110 | name: CrossEntropyLoss 111 | 112 | epochs: 1 113 | 114 | callbacks: 115 | - name: TensorboardLogger 116 | - name: EarlyStopping 117 | params: 118 | patience: 3 119 | min_delta: 1.0e-4 120 | - name: ModelCheckpoint 121 | params: 122 | interval: 5 123 | 124 | 125 | # EVALUATION & DOWNSTREAM ANALYSIS. 126 | analysis: 127 | 128 | model_checkpoint: '' 129 | 130 | dataloader: 131 | name: SimpleDataLoader 132 | params: 133 | batch_size: 15000 134 | 135 | gene_analysis: 136 | scoring_config: 137 | name: LinearScorer 138 | 139 | features_selector: 140 | name: ClasswisePromoters 141 | params: 142 | k: 100 143 | 144 | # full_samples_downstream_analysis: 145 | # - name: DgePseudoBulk 146 | # params: **kwargs 147 | # - name: DgeLMEM 148 | # params: **kwargs 149 | 150 | # test_samples_downstream_analysis: 151 | # - name: GeneRecallCurve 152 | # params: 153 | # reference_genes_path: '/path/to/reference_genes.csv' 154 | # top_K: 300 155 | # plots_per_row: 3 156 | # features_selector: 157 | # name: ClasswiseAbs 158 | # params: {} 159 | # - name: Heatmap 160 | # params: **kwargs 161 | # - name: RocAucCurve 162 | # params: **kwargs 163 | 164 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
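# For example, `make html` is forwarded as `sphinx-build -M html . _build` (assuming sphinx-build is installed, e.g. via `pip install sphinx sphinx_rtd_theme myst_parser`).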
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | -------------------------------------------------------------------------------- /docs/_static/logo.css: -------------------------------------------------------------------------------- 1 | .wy-side-nav-search .logo { 2 | max-width: 150px !important; 3 | height: auto; 4 | } 5 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -- Project information ----------------------------------------------------- 2 | project = 'scaLR' 3 | copyright = '2024, Infocusp Innovations' 4 | author = 'Infocusp Innovations' 5 | release = 'v1.1.0' 6 | 7 | # -- General configuration --------------------------------------------------- 8 | import os 9 | import sys 10 | 11 | sys.path.insert(0, os.path.abspath("..")) 12 | 13 | extensions = [ 14 | 'sphinx.ext.autodoc', 'sphinx.ext.autosectionlabel', 15 | 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.extlinks', 16 | 'sphinx.ext.githubpages', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 17 | 'sphinx.ext.todo', 'sphinx.ext.duration', 'sphinx.ext.doctest', 18 | 'sphinx.ext.autosummary', 'sphinx.ext.intersphinx', 'myst_parser' 19 | ] 20 | source_suffix = ['.rst', '.md'] 21 | templates_path = ['_templates'] 22 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 23 | master_doc = 'index' 24 | 25 | html_theme = 'sphinx_rtd_theme' 26 | html_static_path = ['_static'] 27 | html_css_files = ['logo.css'] 28 | html_logo = "../img/scaLR_logo.png" 29 | html_favicon = "../img/scaLR_logo.png" 30 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | scaLR Documentation 2 | =================== 3 | .. image:: ../img/scaLR_logo.png 4 | :scale: 60% 5 | 6 | .. include:: ../README.md 7 | :parser: myst_parser.sphinx_ 8 | :start-line: 2 9 | :end-line: 24 10 | 11 | .. image:: ../img/Schematic-of-scPipeline.jpg 12 | 13 | .. include:: ../README.md 14 | :parser: myst_parser.sphinx_ 15 | :start-line: 26 16 | 17 | .. toctree:: 18 | :maxdepth: 10 19 | :caption: scaLR 20 | :hidden: 21 | 22 | scalr.analysis 23 | scalr.data 24 | scalr.feature 25 | scalr.nn 26 | scalr.utils 27 | 28 | .. toctree:: 29 | :maxdepth: 10 30 | :caption: scaLR Pipelines 31 | :hidden: 32 | 33 | pipeline 34 | scalr.data_ingestion_pipeline 35 | scalr.feature_extraction_pipeline 36 | scalr.model_training_pipeline 37 | scalr.eval_and_analysis_pipeline 38 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | scaLR 2 | ===== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | scalr 8 | -------------------------------------------------------------------------------- /docs/pipeline.rst: -------------------------------------------------------------------------------- 1 | Data ingestion 2 | --------------- 3 | 4 | .. automodule:: scalr.data_ingestion_pipeline 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | :private-members: 9 | 10 | Eval and analysis 11 | ------------------- 12 | 13 | .. automodule:: scalr.eval_and_analysis_pipeline 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | :private-members: 18 | 19 | Feature extraction 20 | ------------------- 21 | 22 | .. automodule:: scalr.feature_extraction_pipeline 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | :private-members: 27 | 28 | Model training 29 | --------------- 30 | 31 | .. automodule:: scalr.model_training_pipeline 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | :private-members: 36 | -------------------------------------------------------------------------------- /docs/scalr.analysis.rst: -------------------------------------------------------------------------------- 1 | Analysis 2 | ======== 3 | 4 | 5 | \_analyser module 6 | -------------------------------- 7 | 8 | .. automodule:: scalr.analysis._analyser 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | :private-members: 13 | 14 | dge\_lmem 15 | --------- 16 | 17 | .. automodule:: scalr.analysis.dge_lmem 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | :private-members: 22 | 23 | dge\_pseudobulk 24 | --------------- 25 | 26 | .. automodule:: scalr.analysis.dge_pseudobulk 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | :private-members: 31 | 32 | evaluation 33 | ---------- 34 | 35 | .. automodule:: scalr.analysis.evaluation 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | :private-members: 40 | 41 | gene\_recall\_curve 42 | ------------------- 43 | 44 | .. automodule:: scalr.analysis.gene_recall_curve 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | :private-members: 49 | 50 | heatmap 51 | ------- 52 | 53 | .. automodule:: scalr.analysis.heatmap 54 | :members: 55 | :undoc-members: 56 | :show-inheritance: 57 | :private-members: 58 | 59 | roc\_auc 60 | -------- 61 | 62 | .. automodule:: scalr.analysis.roc_auc 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | :private-members: 67 | -------------------------------------------------------------------------------- /docs/scalr.data.preprocess.rst: -------------------------------------------------------------------------------- 1 | preprocess 2 | ========== 3 | 4 | \_preprocess 5 | ------------ 6 | 7 | .. automodule:: scalr.data.preprocess._preprocess 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | sample\_norm 14 | ------------ 15 | 16 | .. 
automodule:: scalr.data.preprocess.sample_norm 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | standard\_scale 23 | --------------- 24 | 25 | .. automodule:: scalr.data.preprocess.standard_scale 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | 31 | test\_sample\_norm 32 | ------------------ 33 | 34 | .. automodule:: scalr.data.preprocess.test_sample_norm 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :private-members: 39 | 40 | test\_standard\_scale 41 | --------------------- 42 | 43 | .. automodule:: scalr.data.preprocess.test_standard_scale 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | :private-members: 48 | -------------------------------------------------------------------------------- /docs/scalr.data.rst: -------------------------------------------------------------------------------- 1 | Data 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 10 6 | 7 | scalr.data.preprocess 8 | scalr.data.split 9 | 10 | -------------------------------------------------------------------------------- /docs/scalr.data.split.rst: -------------------------------------------------------------------------------- 1 | split 2 | ===== 3 | 4 | \_split module 5 | -------------- 6 | 7 | .. automodule:: scalr.data.split._split 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | group\_splitter 14 | --------------- 15 | 16 | .. automodule:: scalr.data.split.group_splitter 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | stratified\_group\_splitter 23 | --------------------------- 24 | 25 | .. automodule:: scalr.data.split.stratified_group_splitter 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | 31 | stratified\_splitter 32 | -------------------- 33 | 34 | .. automodule:: scalr.data.split.stratified_splitter 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :private-members: 39 | -------------------------------------------------------------------------------- /docs/scalr.feature.rst: -------------------------------------------------------------------------------- 1 | Feature 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 10 6 | 7 | scalr.feature.scoring 8 | scalr.feature.selector 9 | 10 | feature\_subsetting 11 | ------------------- 12 | 13 | .. automodule:: scalr.feature.feature_subsetting 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | :private-members: -------------------------------------------------------------------------------- /docs/scalr.feature.scoring.rst: -------------------------------------------------------------------------------- 1 | scoring 2 | ======= 3 | 4 | \_scoring 5 | --------- 6 | 7 | .. automodule:: scalr.feature.scoring._scoring 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | linear\_scorer 14 | -------------- 15 | 16 | .. automodule:: scalr.feature.scoring.linear_scorer 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | shap\_scorer 23 | ------------ 24 | 25 | .. 
automodule:: scalr.feature.scoring.shap_scorer 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | -------------------------------------------------------------------------------- /docs/scalr.feature.selector.rst: -------------------------------------------------------------------------------- 1 | selector 2 | ======== 3 | 4 | \_selector 5 | ---------- 6 | 7 | .. automodule:: scalr.feature.selector._selector 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | abs\_mean 14 | --------- 15 | 16 | .. automodule:: scalr.feature.selector.abs_mean 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | classwise\_abs 23 | -------------- 24 | 25 | .. automodule:: scalr.feature.selector.classwise_abs 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | 31 | classwise\_promoters 32 | -------------------- 33 | 34 | .. automodule:: scalr.feature.selector.classwise_promoters 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :private-members: 39 | -------------------------------------------------------------------------------- /docs/scalr.nn.callbacks.rst: -------------------------------------------------------------------------------- 1 | callbacks 2 | ========= 3 | 4 | \_callbacks 5 | ----------- 6 | 7 | .. automodule:: scalr.nn.callbacks._callbacks 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | early\_stopping 14 | --------------- 15 | 16 | .. automodule:: scalr.nn.callbacks.early_stopping 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | model\_checkpoint 23 | ----------------- 24 | 25 | .. automodule:: scalr.nn.callbacks.model_checkpoint 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | 31 | tensorboard\_logger 32 | ------------------- 33 | 34 | .. automodule:: scalr.nn.callbacks.tensorboard_logger 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :private-members: 39 | 40 | test\_early\_stopping 41 | --------------------- 42 | 43 | .. automodule:: scalr.nn.callbacks.test_early_stopping 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | :private-members: -------------------------------------------------------------------------------- /docs/scalr.nn.dataloader.rst: -------------------------------------------------------------------------------- 1 | dataloader 2 | ========== 3 | 4 | \_dataloader 5 | ------------ 6 | 7 | .. automodule:: scalr.nn.dataloader._dataloader 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | simple\_dataloader 14 | ------------------ 15 | 16 | .. automodule:: scalr.nn.dataloader.simple_dataloader 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | simple\_metadataloader 23 | ---------------------- 24 | 25 | .. automodule:: scalr.nn.dataloader.simple_metadataloader 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | 31 | test\_simple\_dataloader 32 | ------------------------ 33 | 34 | .. automodule:: scalr.nn.dataloader.test_simple_dataloader 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :private-members: 39 | 40 | test\_simple\_metadataloader 41 | ---------------------------- 42 | 43 | .. 
automodule:: scalr.nn.dataloader.test_simple_metadataloader 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | :private-members: -------------------------------------------------------------------------------- /docs/scalr.nn.loss.rst: -------------------------------------------------------------------------------- 1 | loss 2 | ==== 3 | 4 | \_loss 5 | ------ 6 | 7 | .. automodule:: scalr.nn.loss._loss 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: -------------------------------------------------------------------------------- /docs/scalr.nn.model.rst: -------------------------------------------------------------------------------- 1 | model 2 | ===== 3 | 4 | \_model 5 | ------- 6 | 7 | .. automodule:: scalr.nn.model._model 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | sequential\_model 14 | ----------------- 15 | 16 | .. automodule:: scalr.nn.model.sequential_model 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | shap\_model 23 | ----------- 24 | 25 | .. automodule:: scalr.nn.model.shap_model 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: -------------------------------------------------------------------------------- /docs/scalr.nn.rst: -------------------------------------------------------------------------------- 1 | NN 2 | == 3 | 4 | .. toctree:: 5 | :maxdepth: 10 6 | 7 | scalr.nn.callbacks 8 | scalr.nn.dataloader 9 | scalr.nn.loss 10 | scalr.nn.model 11 | scalr.nn.trainer -------------------------------------------------------------------------------- /docs/scalr.nn.trainer.rst: -------------------------------------------------------------------------------- 1 | trainer 2 | ======= 3 | 4 | \_trainer 5 | --------- 6 | 7 | .. automodule:: scalr.nn.trainer._trainer 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | simple\_model\_trainer 14 | ---------------------- 15 | 16 | .. automodule:: scalr.nn.trainer.simple_model_trainer 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: -------------------------------------------------------------------------------- /docs/scalr.rst: -------------------------------------------------------------------------------- 1 | scalr 2 | ===== 3 | 4 | .. toctree:: 5 | :maxdepth: 10 6 | 7 | scalr.analysis 8 | scalr.data 9 | scalr.feature 10 | scalr.nn 11 | scalr.utils 12 | -------------------------------------------------------------------------------- /docs/scalr.utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | data\_utils 5 | ----------- 6 | 7 | .. automodule:: scalr.utils.data_utils 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :private-members: 12 | 13 | file\_utils 14 | ----------- 15 | 16 | .. automodule:: scalr.utils.file_utils 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | :private-members: 21 | 22 | logger 23 | ------ 24 | 25 | .. automodule:: scalr.utils.logger 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :private-members: 30 | 31 | misc\_utils 32 | ----------- 33 | 34 | .. automodule:: scalr.utils.misc_utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :private-members: 39 | 40 | test\_file\_utils 41 | ----------------- 42 | 43 | .. 
automodule:: scalr.utils.test_file_utils 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | :private-members: 48 | 49 | test\_misc\_utils 50 | ----------------- 51 | 52 | .. automodule:: scalr.utils.test_misc_utils 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | :private-members: -------------------------------------------------------------------------------- /img/Schematic-of-scPipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocusp/scaLR/b97553bdc1f02d596d5b7b7ad21c622a304a2793/img/Schematic-of-scPipeline.jpg -------------------------------------------------------------------------------- /img/scaLR_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infocusp/scaLR/b97553bdc1f02d596d5b7b7ad21c622a304a2793/img/scaLR_logo.png -------------------------------------------------------------------------------- /pipeline.py: -------------------------------------------------------------------------------- 1 | """This file contains an implementation of end-to-end pipeline execution.""" 2 | 3 | import argparse 4 | import logging 5 | import os 6 | from os import path 7 | import random 8 | import sys 9 | from time import time 10 | 11 | from memory_profiler import memory_usage 12 | from memory_profiler import profile 13 | import numpy as np 14 | import torch 15 | 16 | from scalr.data_ingestion_pipeline import DataIngestionPipeline 17 | from scalr.eval_and_analysis_pipeline import EvalAndAnalysisPipeline 18 | from scalr.feature_extraction_pipeline import FeatureExtractionPipeline 19 | from scalr.model_training_pipeline import ModelTrainingPipeline 20 | from scalr.utils import EventLogger 21 | from scalr.utils import FlowLogger 22 | from scalr.utils import read_data 23 | from scalr.utils import set_seed 24 | from scalr.utils import write_data 25 | 26 | 27 | def get_args(): 28 | """A function to get the command line arguments.""" 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('-c', 31 | '--config', 32 | type=str, 33 | help='config.yaml file path', 34 | required=True) 35 | parser.add_argument('-l', 36 | '--log', 37 | action='store_true', 38 | help='flag to store logs for the experiment') 39 | parser.add_argument('--level', 40 | type=str, 41 | default='INFO', 42 | help='set the level of logging') 43 | parser.add_argument('--logpath', 44 | type=str, 45 | default=False, 46 | help='path to store the logs') 47 | parser.add_argument('-m', 48 | '--memoryprofiler', 49 | action='store_true', 50 | help='flag to get memory usage analysis') 51 | 52 | args = parser.parse_args() 53 | return args 54 | 55 | 56 | # Uncomment `@profile` to get line-by-line memory analysis 57 | # @profile 58 | def pipeline(config, dirpath, device, flow_logger, event_logger): 59 | """A function that configures all components of the pipeline for end-to-end execution. 60 | 61 | Args: 62 | config: User config. 63 | dirpath: Path of root directory. 64 | flow_logger: Object for flow logger. 65 | event_logger: Object for event logger. 66 | """ 67 | if config.get('data'): 68 | # Data ingestion. 
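        # Splits the full AnnData into train/val/test via `splitter_config`, applies any configured preprocessing, and generates label mappings unless they are already provided; the updated config is written back to the experiment directory.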
69 | flow_logger.info('Data Ingestion pipeline running') 70 | event_logger.heading1('Data Ingestion') 71 | 72 | data_dirpath = path.join(dirpath, 'data') 73 | os.makedirs(data_dirpath, exist_ok=True) 74 | 75 | ingest_data = DataIngestionPipeline(config['data'], data_dirpath) 76 | ingest_data.generate_train_val_test_split() 77 | ingest_data.preprocess_data() 78 | if not config['data'].get('label_mappings'): 79 | ingest_data.generate_mappings() 80 | 81 | config['data'] = ingest_data.get_updated_config() 82 | write_data(config, path.join(dirpath, 'config.yaml')) 83 | del ingest_data 84 | 85 | if config.get('feature_selection'): 86 | # Feature selection. 87 | flow_logger.info('Feature Extraction pipeline running') 88 | event_logger.heading1('Feature Selection') 89 | 90 | feature_extraction_dirpath = path.join(dirpath, 'feature_extraction') 91 | os.makedirs(feature_extraction_dirpath, exist_ok=True) 92 | 93 | extract_features = FeatureExtractionPipeline( 94 | config['feature_selection'], feature_extraction_dirpath, device) 95 | extract_features.load_data_and_targets_from_config(config['data']) 96 | 97 | if not config['feature_selection'].get('score_matrix'): 98 | extract_features.feature_subsetted_model_training() 99 | extract_features.feature_scoring() 100 | else: 101 | extract_features.set_score_matrix( 102 | read_data(config['feature_selection'].get('score_matrix'))) 103 | 104 | extract_features.top_feature_extraction() 105 | config['data'] = extract_features.write_top_features_subset_data( 106 | config['data']) 107 | 108 | config['feature_selection'] = extract_features.get_updated_config() 109 | write_data(config, path.join(dirpath, 'config.yaml')) 110 | del extract_features 111 | 112 | if config.get('final_training'): 113 | # Final model training. 114 | flow_logger.info('Final Model Training pipeline running') 115 | event_logger.heading1('Final Model Training') 116 | 117 | model_training_dirpath = path.join(dirpath, 'model') 118 | os.makedirs(model_training_dirpath, exist_ok=True) 119 | 120 | model_trainer = ModelTrainingPipeline( 121 | config['final_training']['model'], 122 | config['final_training']['model_train_config'], 123 | model_training_dirpath, device) 124 | 125 | model_trainer.load_data_and_targets_from_config(config['data']) 126 | model_trainer.build_model_training_artifacts() 127 | model_trainer.train() 128 | 129 | model_config, model_train_config = model_trainer.get_updated_config() 130 | config['final_training']['model'] = model_config 131 | config['final_training']['model_train_config'] = model_train_config 132 | write_data(config, path.join(dirpath, 'config.yaml')) 133 | del model_trainer 134 | 135 | if config.get('analysis'): 136 | # Analysis of trained model. 
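        # Evaluates the trained model on the test split (classification report), runs gene-level analysis, and then any configured downstream analyses (e.g. DGE, gene recall curve, heatmap, ROC-AUC) on the full and test samples.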
137 | flow_logger.info('Analysis pipeline running') 138 | event_logger.heading1('Analysis') 139 | 140 | analysis_dirpath = path.join(dirpath, 'analysis') 141 | os.makedirs(analysis_dirpath, exist_ok=True) 142 | 143 | if config.get('final_training'): 144 | config['analysis']['model_checkpoint'] = path.join( 145 | model_training_dirpath, 'best_model') 146 | 147 | analyser = EvalAndAnalysisPipeline(config['analysis'], analysis_dirpath, 148 | device) 149 | analyser.load_data_and_targets_from_config(config['data']) 150 | 151 | if config['analysis'].get('model_checkpoint'): 152 | analyser.evaluation_and_classification_report() 153 | 154 | if config['analysis'].get('gene_analysis'): 155 | analyser.gene_analysis() 156 | 157 | analyser.full_samples_downstream_anlaysis() 158 | analyser.test_samples_downstream_anlaysis() 159 | 160 | config['analysis'] = analyser.get_updated_config() 161 | write_data(config, path.join(dirpath, 'config.yaml')) 162 | del analyser 163 | 164 | return config 165 | 166 | 167 | if __name__ == '__main__': 168 | set_seed(42) 169 | args = get_args() 170 | 171 | start_time = time() 172 | 173 | # Parsing config. 174 | config = read_data(args.config) 175 | 176 | # Setting experiment information from config. 177 | dirpath = config['experiment']['dirpath'] 178 | exp_name = config['experiment']['exp_name'] 179 | exp_run = config['experiment']['exp_run'] 180 | dirpath = os.path.join(dirpath, f'{exp_name}_{exp_run}') 181 | device = config['device'] 182 | 183 | # Logging. 184 | log = args.log 185 | if log: 186 | level = getattr(logging, args.level) 187 | logpath = args.logpath if args.logpath else path.join( 188 | dirpath, 'logs.txt') 189 | else: 190 | level = logging.CRITICAL 191 | logpath = None 192 | 193 | flow_logger = FlowLogger('ROOT', level) 194 | flow_logger.info(f'Experiment directory: `{dirpath}`') 195 | if os.path.exists(dirpath): 196 | flow_logger.warning('Experiment directory already exists!') 197 | 198 | os.makedirs(dirpath, exist_ok=True) 199 | 200 | event_logger = EventLogger('ROOT', level, logpath) 201 | 202 | kwargs = dict(config=config, 203 | dirpath=dirpath, 204 | device=device, 205 | flow_logger=flow_logger, 206 | event_logger=event_logger) 207 | 208 | if args.memoryprofiler: 209 | max_memory = memory_usage((pipeline, [], kwargs), 210 | max_usage=True, 211 | interval=0.5, 212 | max_iterations=1) 213 | else: 214 | pipeline(**kwargs) 215 | 216 | end_time = time() 217 | flow_logger.info(f'Total time taken: {end_time - start_time} s') 218 | 219 | event_logger.heading1('Runtime Analysis') 220 | event_logger.info(f'Total time taken: {end_time - start_time} s') 221 | 222 | if args.memoryprofiler: 223 | flow_logger.info(f'Maximum memory usage: {max_memory} MB') 224 | event_logger.info(f'Maximum memory usage: {max_memory} MB') 225 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-requirements-txt"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | 7 | name = "pyscaLR" 8 | version = "1.1.0" 9 | requires-python = ">=3.10" 10 | authors = [ 11 | { name="Infocusp", email="saurabh@infocusp.com" }, 12 | ] 13 | description = "scaLR: Single cell analysis using low resource."
14 | readme = "README.md" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3.9", 17 | "Operating System :: OS Independent", 18 | "Intended Audience :: Science/Research" 19 | ] 20 | 21 | dynamic = ["dependencies"] 22 | [tool.hatch.build.targets.wheel] 23 | packages = ["scalr"] 24 | 25 | [tool.hatch.metadata.hooks.requirements_txt] 26 | files = ["requirements.txt"] 27 | 28 | [tool.hatch.build] 29 | exclude = [ 30 | "docs/*", 31 | "tutorials/*", 32 | "tests/*", 33 | ] 34 | 35 | license = {file = "LICENSE"} 36 | 37 | [project.urls] 38 | Repository = "https://github.com/infocusp/scaLR.git" 39 | Homepage = "https://github.com/infocusp/scaLR" 40 | Issues = "https://github.com/infocusp/scaLR/issues" 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anndata==0.10.9 2 | isort==5.13.2 3 | loky==3.4.1 4 | memory-profiler==0.61.0 5 | pillow==10.4.0 6 | pre_commit==4.0.1 7 | pydeseq2==0.4.11 8 | pyparsing==3.2.0 9 | pytest==8.3.3 10 | PyYAML==6.0.2 11 | scanpy==1.10.3 12 | scikit-learn==1.5.2 13 | shap==0.46.0 14 | tensorboard==2.17.0 15 | toml==0.10.2 16 | torch==2.4.1 --index-url https://download.pytorch.org/whl/cu118 17 | tqdm==4.66.5 18 | yapf==0.40.2 19 | -------------------------------------------------------------------------------- /scalr/__init__.py: -------------------------------------------------------------------------------- 1 | from . import analysis 2 | from . import data 3 | from . import feature 4 | from . import nn 5 | from . import utils 6 | from .data_ingestion_pipeline import DataIngestionPipeline 7 | from .eval_and_analysis_pipeline import EvalAndAnalysisPipeline 8 | from .feature_extraction_pipeline import FeatureExtractionPipeline 9 | from .model_training_pipeline import ModelTrainingPipeline 10 | -------------------------------------------------------------------------------- /scalr/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from ._analyser import AnalysisBase 2 | from ._analyser import build_analyser 3 | from .dge_lmem import DgeLMEM 4 | from .dge_pseudobulk import DgePseudoBulk 5 | from .evaluation import generate_and_save_classification_report 6 | from .evaluation import get_accuracy 7 | from .gene_recall_curve import GeneRecallCurve 8 | from .heatmap import Heatmap 9 | from .roc_auc import RocAucCurve 10 | -------------------------------------------------------------------------------- /scalr/analysis/_analyser.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for the analysis module.""" 2 | 3 | import os 4 | from typing import Union 5 | 6 | from anndata import AnnData 7 | from anndata.experimental import AnnCollection 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | 11 | import scalr 12 | from scalr.utils import build_object 13 | 14 | 15 | class AnalysisBase: 16 | """A base class for downstream analysis tasks. 17 | 18 | This class provides common attributes and methods for all the analysis tasks. 19 | It is intended to be subclassed to create task-specific analysis. 20 | """ 21 | 22 | def __init__(self): 23 | pass 24 | 25 | # Abstract 26 | def generate_analysis(self, model: nn.Module, 27 | test_data: Union[AnnData, AnnCollection], 28 | test_dl: DataLoader, dirpath: str, **kwargs): 29 | """A function to generate analysis, should be overridden by all subclasses. 
30 | 31 | Args: 32 | model (nn.Module): final trained model. 33 | test_data (Union[AnnData, AnnCollection]): test data to run analysis on. 34 | test_dl (DataLoader): DataLoader object to prepare inputs for the model. 35 | dirpath (str): dirpath to store analysis. 36 | **kwargs: contains all previous analysis done to be used later. 37 | """ 38 | pass 39 | 40 | @classmethod 41 | def get_default_params(cls) -> dict: 42 | """Class method to get default params for analysis_config.""" 43 | return dict() 44 | 45 | 46 | def build_analyser(analysis_config: dict) -> tuple[AnalysisBase, dict]: 47 | """Builder object to get analyser, updated analyser_config.""" 48 | return build_object(scalr.analysis, analysis_config) 49 | -------------------------------------------------------------------------------- /scalr/analysis/evaluation.py: -------------------------------------------------------------------------------- 1 | """This file generates accuracy, classification report and stores it.""" 2 | 3 | from os import path 4 | 5 | from pandas import DataFrame 6 | from sklearn.metrics import accuracy_score 7 | from sklearn.metrics import classification_report 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | 12 | from scalr.utils import EventLogger 13 | from scalr.utils import write_data 14 | 15 | 16 | def get_accuracy(test_labels: list[int], pred_labels: list[int]) -> float: 17 | """A function to get accuracy for the predicted labels. 18 | 19 | Args: 20 | test_labels (list[int]): True labels from the test set. 21 | pred_labels (list[int]): Predicted labels from the trained model. 22 | 23 | Returns: 24 | float: accuracy score 25 | """ 26 | event_logger = EventLogger('Accuracy') 27 | accuracy = accuracy_score(test_labels, pred_labels) 28 | event_logger.info(f'Accuracy: {accuracy}') 29 | return accuracy 30 | 31 | 32 | def generate_and_save_classification_report(test_labels: list[int], 33 | pred_labels: list[int], 34 | dirpath: str, 35 | mapping: dict = None) -> DataFrame: 36 | """A function to generate a classificaton report from the actual and predicted data 37 | and store at `dirpath`. 38 | 39 | Args: 40 | test_labels: True labels from the test set. 41 | pred_labels: Predicted labels from the trained model. 42 | dirpath: Path to store classification_report. 43 | mapping[optional]: Mapping of label_id to true label_names (id2label). 44 | 45 | Returns: 46 | A Pandas DataFrame with the classification report. 
47 | """ 48 | event_logger = EventLogger('ClassReport') 49 | 50 | if mapping: 51 | test_labels = [mapping[x] for x in test_labels] 52 | pred_labels = [mapping[x] for x in pred_labels] 53 | 54 | report = DataFrame( 55 | classification_report(test_labels, pred_labels, 56 | output_dict=True)).transpose() 57 | event_logger.info('\nClassification Report:') 58 | event_logger.info(report) 59 | write_data(report, path.join(dirpath, 'classification_report.csv')) 60 | 61 | return report 62 | -------------------------------------------------------------------------------- /scalr/analysis/heatmap.py: -------------------------------------------------------------------------------- 1 | """This file generates heatmaps for top genes of particular class w.r.t same top genes in other classes.""" 2 | 3 | import os 4 | from typing import Tuple, Union 5 | 6 | import matplotlib.pyplot as plt 7 | import pandas as pd 8 | import seaborn as sns 9 | 10 | from scalr.analysis import AnalysisBase 11 | from scalr.utils import EventLogger 12 | from scalr.utils import read_data 13 | 14 | 15 | class Heatmap(AnalysisBase): 16 | '''Class to generate a heatmap of top genes classwise.''' 17 | 18 | def __init__(self, 19 | top_n_genes: int = 100, 20 | save_plot: bool = True, 21 | score_matrix_path: str = None, 22 | top_features_path: str = None, 23 | *args, 24 | **kwargs): 25 | """Initialize class with shap arguments. 26 | 27 | Args: 28 | top_n_genes: top N genes for each class/label. 29 | save_plot: Where to save plot or show plot. 30 | score_matrix_path: path to score matrix. 31 | top_features_path: path to top features. 32 | """ 33 | 34 | self.top_n_genes = top_n_genes 35 | self.save_plot = save_plot 36 | self.score_matrix_path = score_matrix_path 37 | self.top_features_path = top_features_path 38 | 39 | self.event_logger = EventLogger('Heatmap') 40 | 41 | def generate_analysis(self, 42 | dirpath: str, 43 | score_matrix: pd.DataFrame = None, 44 | top_features: Union[dict, list] = None, 45 | **kwargs) -> None: 46 | """A function to generate heatmap for top features. 47 | 48 | Args: 49 | score_matrix: Matrix(class * genes) that contains a score of each gene per class. 50 | top_features: Class-wise top genes or list of top features. 51 | dirpath: Path to store the heatmap image. 52 | """ 53 | 54 | self.event_logger.heading2("Generating Heatmaps.") 55 | 56 | if isinstance(top_features, list): 57 | self.event_logger.info( 58 | "Generating heatmap for the same top genes across all classes as provided" 59 | " `top_features` is a single list and not top genes per class dict." 60 | ) 61 | top_features = {"all_class_common": top_features} 62 | 63 | if (score_matrix is None) and (top_features is None): 64 | 65 | if not self.score_matrix_path: 66 | raise ValueError("score_matrix_path required.") 67 | 68 | if not self.top_features_path: 69 | raise ValueError("top_features_path required.") 70 | 71 | score_matrix = read_data(self.score_matrix_path) 72 | top_features = read_data(self.top_features_path) 73 | 74 | for class_name, genes in top_features.items(): 75 | self.plot_heatmap(score_matrix[genes[:self.top_n_genes]].T, 76 | f"{dirpath}/heatmaps", class_name) 77 | 78 | self.event_logger.info(f"Heatmaps stored at: {dirpath}/heatmaps") 79 | 80 | def plot_heatmap(self, class_genes_weights: pd.DataFrame, dirpath: str, 81 | filename: str) -> None: 82 | """A function to plot a heatmap for top n genes across all classes. 83 | 84 | Args: 85 | class_genes_weights: Matrix(genes * classes) which contains 86 | shap_value/weights of each gene to class. 
87 | dirpath: Path to store the heatmap image. 88 | filename: Heatmap image name. 89 | """ 90 | 91 | os.makedirs(dirpath, exist_ok=True) 92 | 93 | sns.set(rc={'figure.figsize': (9, 12)}) 94 | sns.heatmap(class_genes_weights, vmin=-1e-2, vmax=1e-2) 95 | 96 | plt.tight_layout() 97 | plt.title(filename) 98 | 99 | if self.save_plot: 100 | plt.savefig(os.path.join(dirpath, f"{filename}.svg")) 101 | else: 102 | plt.show() 103 | plt.clf() 104 | -------------------------------------------------------------------------------- /scalr/analysis/roc_auc.py: -------------------------------------------------------------------------------- 1 | """This file generates ROC-AUC plot and stores it.""" 2 | 3 | from os import path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from sklearn.metrics import auc 8 | from sklearn.metrics import roc_curve 9 | from sklearn.metrics import RocCurveDisplay 10 | 11 | from scalr.analysis import AnalysisBase 12 | from scalr.utils import data_utils 13 | from scalr.utils import EventLogger 14 | 15 | 16 | class RocAucCurve(AnalysisBase): 17 | '''Class to generate ROC-AUC curve.''' 18 | 19 | def generate_analysis(self, test_labels: list[int], 20 | pred_probabilities: list[list[float]], dirpath: str, 21 | mapping: list, **kwargs) -> None: 22 | """A function to calculate ROC-AUC and save the plot. 23 | 24 | Args: 25 | test_labels: True labels from the test dataset. 26 | pred_probabilities: Predictions probabilities of each sample for all the classes. 27 | dirpath: Path to store gene recall curve if applicable. 28 | mapping: List of class names. 29 | """ 30 | 31 | logger_name = "ROC-AUC Analysis" 32 | self.event_logger = EventLogger(logger_name) 33 | self.event_logger.heading2(logger_name) 34 | self.event_logger.info("Generating one hot matrix of test labels.") 35 | 36 | # convert label predictions list to the one-hot matrix. 37 | test_labels_onehot = data_utils.get_one_hot_matrix( 38 | np.array(test_labels)) 39 | fig, ax = plt.subplots(1, 1, figsize=(16, 8)) 40 | 41 | self.event_logger.info( 42 | "Calculating ROC-AUC for each label and creating a plot for that.") 43 | 44 | # test labels start with 0 so we need to add 1 in max. 45 | for class_label in range(max(test_labels) + 1): 46 | 47 | # fpr: False Positive Rate | tpr: True Positive Rate 48 | fpr, tpr, _ = roc_curve( 49 | test_labels_onehot[:, class_label], 50 | np.array(pred_probabilities)[:, class_label]) 51 | 52 | roc_auc = auc(fpr, tpr) 53 | 54 | display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc) 55 | display.plot(ax=ax, name=mapping[class_label]) 56 | 57 | self.event_logger.info("Saving plot and clear axis & figure.") 58 | 59 | plt.axline((0, 0), (1, 1), linestyle='--', color='black') 60 | fig.savefig(path.join(dirpath, f'roc_auc.svg')) 61 | plt.clf() # clear axis & figure so it does not affect the next plot. 
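# A minimal usage sketch (the variable names below are placeholders, not part of this module):
# RocAucCurve().generate_analysis(test_labels, pred_probabilities, dirpath='.', mapping=class_names)
# writes `roc_auc.svg` with one ROC curve per class to the given directory.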
62 | -------------------------------------------------------------------------------- /scalr/analysis/test_dge_lmem.py: -------------------------------------------------------------------------------- 1 | """This is a test file for dge_lmem.py""" 2 | 3 | import os 4 | from os import path 5 | import shutil 6 | 7 | import numpy as np 8 | 9 | from scalr.analysis import dge_lmem 10 | from scalr.analysis.test_dge_pseudobulk import check_dge_result 11 | from scalr.utils import generate_dummy_dge_anndata 12 | from scalr.utils import read_data 13 | 14 | # DgeLMEM parameters 15 | lmem_parms_dict = { 16 | 'fixed_effect_column': 'disease', 17 | 'fixed_effect_factors': ['disease_x', 'normal'], 18 | 'group': 'donor_id', 19 | 'celltype_column': 'cell_type' 20 | } 21 | # Dictionary with expected results from DgeLMEM 22 | expected_lmem_dge_result_dict = { 23 | 'B_cell': { 24 | 'shape': (10, 6), 25 | 'gene': ['gene_1', 'gene_10'], 26 | 'random_col_and_val': ('coef_disease_x', 0.3, 0.02) 27 | }, 28 | 'T_cell': { 29 | 'shape': (10, 6), 30 | 'gene': ['gene_1', 'gene_10'], 31 | 'random_col_and_val': ('coef_disease_x', -0.05, 0.37) 32 | } 33 | } 34 | 35 | 36 | def test_lmem_generate_analysis( 37 | dge_parms_dict: dict = lmem_parms_dict, 38 | expected_dge_result_dict: dict = expected_lmem_dge_result_dict) -> None: 39 | """This function generates DGE result using `generate_analysis` method in DgeLMEM class. 40 | Finally checks the generated results with the expected by calling `check_dge_result`function. 41 | 42 | Args: 43 | dge_parms_dict: Parameters dictionary for the DgeLMEM class. 44 | expected_dge_result_dict: A dictionary with expected dge results. 45 | """ 46 | 47 | os.makedirs('./tmp', exist_ok=True) 48 | 49 | # Generating dummy anndata. 50 | adata = generate_dummy_dge_anndata() 51 | 52 | # Path to store dge result. 53 | dirpath = './tmp' 54 | cell_subsets = list(expected_dge_result_dict.keys()) 55 | dge_lm = dge_lmem.DgeLMEM( 56 | fixed_effect_column=dge_parms_dict['fixed_effect_column'], 57 | fixed_effect_factors=dge_parms_dict['fixed_effect_factors'], 58 | group=dge_parms_dict['group'], 59 | celltype_column=dge_parms_dict['celltype_column'], 60 | cell_subsets=cell_subsets) 61 | 62 | dge_lm.generate_analysis(adata, dirpath) 63 | lmem_dirpath = path.join(dirpath, 'lmem_dge_result') 64 | #Checking DGE result files and values for each celltype 65 | check_dge_result(lmem_dirpath, expected_dge_result_dict) 66 | 67 | shutil.rmtree('./tmp', ignore_errors=True) 68 | -------------------------------------------------------------------------------- /scalr/analysis/test_dge_pseudobulk.py: -------------------------------------------------------------------------------- 1 | """This is a test file for dge_pseudobulk.py""" 2 | 3 | import os 4 | from os import path 5 | import shutil 6 | 7 | import numpy as np 8 | 9 | from scalr.analysis import dge_pseudobulk 10 | from scalr.utils import generate_dummy_dge_anndata 11 | from scalr.utils import read_data 12 | 13 | 14 | def check_dge_result(result_path: str, expected_dge_result_dict: dict) -> None: 15 | """This function checks the expected DGE results with the generated results. 16 | 17 | Args: 18 | result_path: Path to the generated DGE results. 19 | expected_dge_result_dict: A dictionary with expected dge results. 
20 | """ 21 | 22 | celltype_list = list(expected_dge_result_dict.keys()) 23 | result_files = os.listdir(result_path) 24 | csv_files = [] 25 | svg_files = [] 26 | for file in result_files: 27 | if file.endswith('.csv'): 28 | csv_files.append(file) 29 | elif file.endswith('.svg'): 30 | svg_files.append(file) 31 | 32 | # Checking for right numbers of csv & svg files. 33 | assert len(csv_files) == len( 34 | celltype_list), f"Expected {len(celltype_list)} csv files" 35 | assert len(svg_files) == len( 36 | celltype_list), f"Expected {len(celltype_list)} svg files" 37 | 38 | for celltype in celltype_list: 39 | celltype_csv = [file for file in csv_files if celltype in file] 40 | assert celltype_csv, f"CSV file for {celltype} is not produced" 41 | celltype_svg = [file for file in svg_files if celltype in file] 42 | assert celltype_svg, f"SVG file for {celltype} is not produced" 43 | 44 | celltype_df = read_data(path.join(result_path, celltype_csv[0]), 45 | index_col=None) 46 | celltype_dge_result_dict = expected_dge_result_dict[celltype] 47 | assert celltype_df.shape == celltype_dge_result_dict['shape'], ( 48 | f"There is a mismatch in the shape of the dge_result CSV file for '{celltype}'." 49 | ) 50 | assert ( 51 | celltype_df.loc[celltype_df.index[0], 52 | 'gene'] == celltype_dge_result_dict['gene'][0] 53 | ) & (np.round( 54 | celltype_df.loc[celltype_df.index[0], 55 | celltype_dge_result_dict['random_col_and_val'][0]], 56 | 2 57 | ) == celltype_dge_result_dict['random_col_and_val'][1]), ( 58 | f'There is a mismatch in the DGE results of the first row in the CSV file for {celltype}' 59 | ) 60 | assert ( 61 | celltype_df.loc[celltype_df.index[-1], 62 | 'gene'] == celltype_dge_result_dict['gene'][-1] 63 | ) & (np.round( 64 | celltype_df.loc[celltype_df.index[-1], 65 | celltype_dge_result_dict['random_col_and_val'][0]], 66 | 2 67 | ) == celltype_dge_result_dict['random_col_and_val'][-1]), ( 68 | f'There is a mismatch in the DGE results of the last row in the CSV file for {celltype}' 69 | ) 70 | 71 | 72 | # DgePseudoBulk parameters 73 | pseudobulk_parms_dict = { 74 | 'celltype_column': 'cell_type', 75 | 'design_factor': 'disease', 76 | 'factor_categories': ['disease_x', 'normal'], 77 | 'sum_column': 'donor_id', 78 | 'cell_subsets': ['B_cell', 'T_cell'] 79 | } 80 | # Dictionary with expected results from DgePseudoBulk 81 | expected_pbk_dge_result_dict = { 82 | 'B_cell': { 83 | 'shape': (10, 7), 84 | 'gene': ['gene_1', 'gene_10'], 85 | 'random_col_and_val': ('log2FoldChange', 0.97, 0.38) 86 | }, 87 | 'T_cell': { 88 | 'shape': (10, 7), 89 | 'gene': ['gene_1', 'gene_10'], 90 | 'random_col_and_val': ('log2FoldChange', -0.26, 1.28) 91 | } 92 | } 93 | 94 | 95 | def test_pseudobulk_generate_analysis( 96 | dge_parms_dict: dict = pseudobulk_parms_dict, 97 | expected_dge_result_dict: dict = expected_pbk_dge_result_dict) -> None: 98 | """This function generates DGE result using `generate_analysis` method in DgePseudoBulk class. 99 | Finally checks the generated results with the expected by calling `check_dge_result`function. 100 | 101 | Args: 102 | dge_parms_dict: Parameters dictionary for the DgePseudoBulk class. 103 | expected_dge_result_dict: A dictionary with expected dge results. 104 | """ 105 | 106 | os.makedirs('./tmp', exist_ok=True) 107 | 108 | # Generating dummy anndata. 109 | adata = generate_dummy_dge_anndata() 110 | 111 | # Path to store dge result. 
112 | dirpath = './tmp' 113 | cell_subsets = list(expected_dge_result_dict.keys()) 114 | dge_pbk = dge_pseudobulk.DgePseudoBulk( 115 | celltype_column=dge_parms_dict['celltype_column'], 116 | design_factor=dge_parms_dict['design_factor'], 117 | factor_categories=dge_parms_dict['factor_categories'], 118 | sum_column=dge_parms_dict['sum_column'], 119 | cell_subsets=cell_subsets) 120 | 121 | dge_pbk.generate_analysis(adata, dirpath) 122 | pbk_dirpath = path.join(dirpath, 'pseudobulk_dge_result') 123 | #Checking DGE result files and values for each celltype 124 | check_dge_result(pbk_dirpath, expected_dge_result_dict) 125 | 126 | shutil.rmtree('./tmp', ignore_errors=True) 127 | -------------------------------------------------------------------------------- /scalr/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess import * 2 | from .split import * 3 | -------------------------------------------------------------------------------- /scalr/data/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from ._preprocess import build_preprocessor 2 | from ._preprocess import PreprocessorBase 3 | from .sample_norm import SampleNorm 4 | from .standard_scale import StandardScaler 5 | -------------------------------------------------------------------------------- /scalr/data/preprocess/_preprocess.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for preprocessing module.""" 2 | 3 | from os import path 4 | from typing import Union 5 | 6 | from anndata import AnnData 7 | from anndata.experimental import AnnCollection 8 | import numpy as np 9 | 10 | import scalr 11 | from scalr.utils import build_object 12 | from scalr.utils import write_chunkwise_data 13 | 14 | 15 | class PreprocessorBase: 16 | """Base class for preprocessor""" 17 | 18 | def __init__(self, **kwargs): 19 | # Store all params here. 20 | pass 21 | 22 | # Abstract 23 | def transform(self, data: np.ndarray) -> np.ndarray: 24 | """A required function to transform a numpy array. 25 | 26 | Args: 27 | data (np.ndarray): Input raw data. 28 | 29 | Returns: 30 | np.ndarray: Processed data. 31 | """ 32 | pass 33 | 34 | def fit( 35 | self, 36 | data: Union[AnnData, AnnCollection], 37 | sample_chunksize: int, 38 | ) -> None: 39 | """A function to calculate attributes for transformation. 40 | 41 | It is required only when you need to see the entire train data and 42 | calculate attributes, as required in StdScaler, etc. This method 43 | should not return anything, it should be used to store attributes 44 | that will be used by the `transform` method. 45 | 46 | Args: 47 | data (Union[AnnData, AnnCollection]): train_data in backed mode. 48 | sample_chunksize (int): several samples of data that can at most 49 | be loaded in memory. 50 | """ 51 | pass 52 | 53 | def process_data(self, 54 | full_data: Union[AnnData, AnnCollection], 55 | sample_chunksize: int, 56 | dirpath: str, 57 | num_workers: int = 1): 58 | """A function to process the entire data chunkwise and write the processed data 59 | to disk. 60 | 61 | Args: 62 | full_data (Union[AnnData, AnnCollection]): Full data for transformation. 63 | sample_chunksize (int): Number of samples in one chunk. 64 | dirpath (str): Path to write the data to. 65 | num_workers (int): number of jobs to run in parallel for data writing. 
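        Example (an illustrative sketch of the fit-then-write sequence used by
        the data ingestion pipeline; the preprocessor choice, chunk size and
        output path are placeholders):

            preprocessor = SampleNorm()
            preprocessor.fit(train_data, sample_chunksize=1000)
            preprocessor.process_data(train_data, 1000, './processed/train')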
66 | """ 67 | if not sample_chunksize: 68 | # TODO 69 | raise NotImplementedError( 70 | 'Preprocessing does not work without sample chunk size') 71 | 72 | write_chunkwise_data(full_data, 73 | sample_chunksize, 74 | dirpath, 75 | transform=self.transform, 76 | num_workers=num_workers) 77 | 78 | 79 | def build_preprocessor( 80 | preprocessing_config: dict) -> tuple[PreprocessorBase, dict]: 81 | """Builder object to get a processor, updated preprocessing_config.""" 82 | return build_object(scalr.data.preprocess, preprocessing_config) 83 | -------------------------------------------------------------------------------- /scalr/data/preprocess/sample_norm.py: -------------------------------------------------------------------------------- 1 | """This file performs Sample-wise normalization on the data.""" 2 | 3 | from typing import Union 4 | 5 | import numpy as np 6 | 7 | from scalr.data.preprocess import PreprocessorBase 8 | 9 | 10 | class SampleNorm(PreprocessorBase): 11 | """Class for Samplewise Normalization""" 12 | 13 | def __init__(self, scaling_factor: float = 1.0): 14 | """Initialize parameters for Sample-wise normalization. 15 | 16 | Args: 17 | scaling_factor: `Target sum` to maintain for each sample. 18 | """ 19 | 20 | self.scaling_factor = scaling_factor 21 | 22 | def transform(self, data: np.ndarray) -> np.ndarray: 23 | """A function to transform provided input data. 24 | 25 | Args: 26 | data (np.ndarray): Input raw data. 27 | 28 | Returns: 29 | np.ndarray: Processed data. 30 | """ 31 | data *= (self.scaling_factor / (data.sum(axis=1).reshape(len(data), 1))) 32 | return data 33 | 34 | @classmethod 35 | def get_default_params(cls) -> dict: 36 | """Class method to get default params for preprocess_config.""" 37 | return dict(scaling_factor=1.0) 38 | -------------------------------------------------------------------------------- /scalr/data/preprocess/standard_scale.py: -------------------------------------------------------------------------------- 1 | """This file performs standard scaler normalization on the data.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import numpy as np 8 | 9 | from scalr.data.preprocess import PreprocessorBase 10 | 11 | 12 | class StandardScaler(PreprocessorBase): 13 | """Class for Standard Normalization""" 14 | 15 | def __init__(self, with_mean: bool = True, with_std: bool = True): 16 | """Initialize parameters for standard scaler normalization. 17 | 18 | Args: 19 | with_mean: Mean for standard scaling. 20 | with_std: Standard deviation for standard scaling. 21 | """ 22 | 23 | self.with_mean = with_mean 24 | self.with_std = with_std 25 | 26 | # Parameters for standard scaler. 27 | self.train_mean = None 28 | self.train_std = None 29 | 30 | def transform(self, data: np.ndarray) -> np.ndarray: 31 | """A function to transform provided input data. 32 | 33 | Args: 34 | data (np.ndarray): raw data 35 | 36 | Returns: 37 | np.ndarray: processed data 38 | """ 39 | if not self.with_mean: 40 | train_mean = np.zeros((1, data.shape[1])) 41 | else: 42 | train_mean = self.train_mean 43 | return (data - train_mean) / self.train_std 44 | 45 | def fit(self, data: Union[AnnData, AnnCollection], 46 | sample_chunksize: int) -> None: 47 | """This function calculate parameters for standard scaler object from the train data. 48 | 49 | Args: 50 | data: Data to calculate the required parameters of. 51 | sample_chunksize: Chunks of data that can be loaded into memory at once. 
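        Example (illustrative; assumes `adata` is an AnnData whose `.X` yields
        dense arrays, and the chunk size is a placeholder):

            scaler = StandardScaler(with_mean=True, with_std=True)
            scaler.fit(adata, sample_chunksize=1000)
            scaled = scaler.transform(adata[:10].X)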
52 | 53 | """ 54 | 55 | self.calculate_mean(data, sample_chunksize) 56 | self.calculate_std(data, sample_chunksize) 57 | 58 | def calculate_mean(self, data: Union[AnnData, AnnCollection], 59 | sample_chunksize: int) -> None: 60 | """Function to calculate mean for each feature in the train data 61 | 62 | Args: 63 | data: Data to calculate the mean of. 64 | sample_chunksize: Chunks of data that can be loaded into memory at once. 65 | 66 | Returns: 67 | Nothing, stores mean per feature of the train data. 68 | """ 69 | 70 | train_sum = np.zeros(data.shape[1]).reshape(1, -1) 71 | 72 | # Iterate through batches of data to get mean statistics 73 | for i in range(int(np.ceil(data.shape[0] / sample_chunksize))): 74 | train_sum += data[i * sample_chunksize:i * sample_chunksize + 75 | sample_chunksize].X.sum(axis=0) 76 | self.train_mean = train_sum / data.shape[0] 77 | 78 | def calculate_std(self, data: Union[AnnData, AnnCollection], 79 | sample_chunksize: int) -> None: 80 | """A function to calculate standard deviation for each feature in the train data. 81 | 82 | Args: 83 | data: Data to calculate the standard deviation of 84 | sample_chunksize: Chunks of data that can be loaded into memory at once. 85 | 86 | Returns: 87 | Nothing, stores standard deviation per feature of the train data. 88 | """ 89 | 90 | # Getting standard deviation of entire train data per feature. 91 | if self.with_std: 92 | self.train_std = np.zeros(data.shape[1]).reshape(1, -1) 93 | # Iterate through batches of data to get std statistics 94 | for i in range(int(np.ceil(data.shape[0] / sample_chunksize))): 95 | self.train_std += np.sum(np.power( 96 | data[i * sample_chunksize:i * sample_chunksize + 97 | sample_chunksize].X - self.train_mean, 2), 98 | axis=0) 99 | self.train_std /= data.shape[0] 100 | self.train_std = np.sqrt(self.train_std) 101 | 102 | # Handling cases where standard deviation of feature is 0, replace it with 1. 103 | self.train_std[self.train_std == 0] = 1 104 | else: 105 | # If `with_std` is False, set train_std to 1. 106 | self.train_std = np.ones((1, data.shape[1])) 107 | 108 | @classmethod 109 | def get_default_params(cls) -> dict: 110 | """Class method to get default params for preprocess_config.""" 111 | return dict(with_mean=True, with_std=True) 112 | -------------------------------------------------------------------------------- /scalr/data/preprocess/test_sample_norm.py: -------------------------------------------------------------------------------- 1 | """This is a test file for Sample-wise normalization.""" 2 | 3 | import scanpy as sc 4 | 5 | from scalr.data.preprocess import sample_norm 6 | from scalr.utils import generate_dummy_anndata 7 | 8 | 9 | def test_transform(): 10 | '''This function tests the transform function of Sample-wise normalization. 11 | 12 | There is no fit() involved in Sample-wise normalization. 13 | ''' 14 | 15 | # Creating an annadata object. 16 | adata = generate_dummy_anndata(n_samples=100, n_features=25) 17 | 18 | # Sample-wise norm required parameter. 19 | target_sum = 5 20 | 21 | # scalr sample-wise normalization. 22 | scalr_sample_norm = sample_norm.SampleNorm(scaling_factor=target_sum) 23 | # No need to fit() for sample-norm normalization 24 | scalr_scaled_data = scalr_sample_norm.transform(adata.X) 25 | 26 | # scanpy sample-wise normalization. 27 | scanpy_scaled_data = sc.pp.normalize_total(adata, 28 | target_sum=target_sum, 29 | inplace=False)['X'] 30 | 31 | # asserts to check transformed data having errors less than 1e-15 compared to scanpy's transformed data. 
32 | assert sum( 33 | abs(scanpy_scaled_data.flatten() - 34 | scalr_scaled_data.flatten()).flatten() < 1e-15 35 | ) == scalr_scaled_data.flatten().shape[ 36 | 0], "The sample norm is incorrectly transforming data, please debug code." 37 | -------------------------------------------------------------------------------- /scalr/data/preprocess/test_standard_scale.py: -------------------------------------------------------------------------------- 1 | '''This is a test file for standard-scaler normalization.''' 2 | 3 | from sklearn import preprocessing 4 | 5 | from scalr.data.preprocess import standard_scale 6 | from scalr.utils import generate_dummy_anndata 7 | 8 | 9 | def test_fit(): 10 | '''This function tests fit() function of sample-norm normalization. 11 | 12 | fit() function is enough for testing, as we can compare mean and std with 13 | sklean standard-scaler object params. 14 | ''' 15 | 16 | # Creating an annadata object. 17 | adata = generate_dummy_anndata(n_samples=100, n_features=25) 18 | 19 | # Standard scaler required parameters. 20 | with_mean = False 21 | with_std = True 22 | 23 | # scalr standard-scale normalization. 24 | scalr_std_scaler = standard_scale.StandardScaler(with_mean=with_mean, 25 | with_std=with_std) 26 | # Required parameter - sample_chunksize to process data in chunks. 27 | sample_chunksize = 4 28 | scalr_std_scaler.fit(adata, sample_chunksize=sample_chunksize) 29 | 30 | # sklearn normalization 31 | sklearn_std_scaler = preprocessing.StandardScaler(with_mean=with_mean, 32 | with_std=with_std) 33 | sklearn_std_scaler.fit(adata.X) 34 | 35 | # asserts to check the calculated mean and standard deviation, the error should be less than 1e-15. 36 | assert sum( 37 | abs(scalr_std_scaler.train_mean - 38 | sklearn_std_scaler.mean_).flatten() < 1e-15 39 | ) == adata.shape[1], "Train data mean is not correctly calculated..." 40 | assert sum( 41 | abs(scalr_std_scaler.train_std - sklearn_std_scaler.scale_).flatten() < 42 | 1e-15) == adata.shape[ 43 | 1], "Train data standard deviation is not correctly calculated..." 44 | -------------------------------------------------------------------------------- /scalr/data/split/__init__.py: -------------------------------------------------------------------------------- 1 | from ._split import build_splitter 2 | from ._split import SplitterBase 3 | from .stratified_group_splitter import StratifiedGroupSplitter 4 | 5 | # GroupSplitter inherit from StratifiedSplitter. 
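# The `isort:skip` markers below keep this import order fixed: StratifiedSplitter
# must be imported before GroupSplitter, which subclasses it.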
6 | from .stratified_splitter import StratifiedSplitter # isort:skip 7 | from .group_splitter import GroupSplitter # isort:skip 8 | -------------------------------------------------------------------------------- /scalr/data/split/_split.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for splitter.""" 2 | 3 | import os 4 | from os import path 5 | from typing import Union 6 | 7 | from anndata import AnnData 8 | from anndata.experimental import AnnCollection 9 | 10 | import scalr 11 | from scalr.utils import build_object 12 | from scalr.utils import EventLogger 13 | from scalr.utils import read_data 14 | from scalr.utils import write_chunkwise_data 15 | from scalr.utils import write_data 16 | 17 | 18 | class SplitterBase: 19 | """Base class for splitter, to make Train|Val|Test Splits.""" 20 | 21 | def __init__(self): 22 | self.event_logger = EventLogger('Splitter') 23 | 24 | # Abstract 25 | def generate_train_val_test_split_indices(datapath: str, target: str, 26 | **kwargs) -> dict: 27 | """Generate a list of indices for train/val/test split of whole dataset. 28 | 29 | Args: 30 | datapath (str): Path to full data. 31 | target (str): Target for classification present in `obs`. 32 | **kwargs: Any other params needed for splitting. 33 | 34 | Returns: 35 | dict: 'train', 'val' and 'test' indices list. 36 | """ 37 | pass 38 | 39 | def check_splits(self, datapath: str, data_splits: dict, target: str): 40 | """This function performs certain checks regarding splits and logs 41 | the distribution of target classes in each split. 42 | 43 | Args: 44 | datapath (str): Path to full data. 45 | data_splits (dict): Split of 'train', 'val' and 'test' indices. 46 | target (str): Classification target column name in `obs`. 47 | """ 48 | 49 | adata = read_data(datapath) 50 | metadata = adata.obs 51 | n_cls = metadata[target].nunique() 52 | 53 | train_inds = data_splits['train'] 54 | val_inds = data_splits['val'] 55 | test_inds = data_splits['test'] 56 | 57 | # Check for classes present in splits. 58 | if len(metadata[target].iloc[train_inds].unique()) != n_cls: 59 | self.event_logger.warning( 60 | 'All classes are not present in Train set') 61 | 62 | if len(metadata[target].iloc[val_inds].unique()) != n_cls: 63 | self.event_logger.warning( 64 | 'All classes are not present in Validation set') 65 | 66 | if len(metadata[target].iloc[test_inds].unique()) != n_cls: 67 | self.event_logger.warning('All classes are not present in Test set') 68 | 69 | # Check for overlapping samples. 70 | assert len(set(train_inds).intersection( 71 | test_inds)) == 0, "Test and Train sets contain overlapping samples" 72 | assert len( 73 | set(val_inds).intersection(train_inds) 74 | ) == 0, "Validation and Train sets contain overlapping samples" 75 | assert len(set(test_inds).intersection(val_inds) 76 | ) == 0, "Test and Validation sets contain overlapping samples" 77 | 78 | # LOGGING. 
79 | self.event_logger.info('Train|Validation|Test Splits\n') 80 | self.event_logger.info(f'Length of train set: {len(train_inds)}') 81 | self.event_logger.info(f'Distribution of train set: ') 82 | self.event_logger.info( 83 | f'{metadata[target].iloc[train_inds].value_counts()}\n') 84 | 85 | self.event_logger.info(f'Length of val set: {len(val_inds)}') 86 | self.event_logger.info(f'Distribution of val set: ') 87 | self.event_logger.info( 88 | f'{metadata[target].iloc[val_inds].value_counts()}\n') 89 | 90 | self.event_logger.info(f'Length of test set: {len(test_inds)}') 91 | self.event_logger.info(f'Distribution of test set: ') 92 | self.event_logger.info( 93 | f'{metadata[target].iloc[test_inds].value_counts()}\n') 94 | 95 | def write_splits(self, 96 | full_data: Union[AnnData, AnnCollection], 97 | data_split_indices: dict, 98 | sample_chunksize: int, 99 | dirpath: int, 100 | num_workers: int = None): 101 | """THis function writes the train validation and test splits to the disk. 102 | 103 | Args: 104 | full_data (Union[AnnData, AnnCollection]): Full data to be split. 105 | data_split_indices (dict): Indices of each split. 106 | sample_chunksize (int): Number of samples to be written in one file. 107 | dirpath (int): Path to write data into. 108 | num_workers (int): number of jobs to run in parallel for data writing. 109 | """ 110 | 111 | for split in data_split_indices.keys(): 112 | if sample_chunksize: 113 | split_dirpath = path.join(dirpath, split) 114 | os.makedirs(split_dirpath, exist_ok=True) 115 | write_chunkwise_data(full_data, 116 | sample_chunksize, 117 | split_dirpath, 118 | data_split_indices[split], 119 | num_workers=num_workers) 120 | else: 121 | filepath = path.join(dirpath, f'{split}.h5ad') 122 | write_data(full_data[data_split_indices[split]].to_memory(), 123 | filepath) 124 | 125 | @classmethod 126 | def get_default_params(cls) -> dict: 127 | """Class method to get default params for model_config.""" 128 | return dict() 129 | 130 | 131 | def build_splitter(splitter_config: dict) -> tuple[SplitterBase, dict]: 132 | """Builder object to get splitter, updated splitter_config.""" 133 | return build_object(scalr.data.split, splitter_config) 134 | -------------------------------------------------------------------------------- /scalr/data/split/group_splitter.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of group splitter.""" 2 | 3 | from pandas import DataFrame 4 | from sklearn.model_selection import GroupShuffleSplit 5 | 6 | from scalr.data.split import StratifiedSplitter 7 | 8 | 9 | class GroupSplitter(StratifiedSplitter): 10 | """Class for splitting data based on the provided group. 11 | 12 | Generate a stratified split of data into train, validation, and test 13 | sets. Stratification ensures samples have the same value for `stratify` 14 | column, can not belong to different sets. 15 | """ 16 | 17 | def __init__(self, split_ratio: list[float], stratify: str): 18 | """Initialize splitter with required parameters. 19 | 20 | Args: 21 | split_ratio (list[float]): Ratio to split number of samples in. 22 | stratify (str): Column name to metadata the split upon in `obs`. 23 | """ 24 | super().__init__(split_ratio) 25 | self.stratify = stratify 26 | 27 | def _split_data_with_stratification( 28 | self, metadata: DataFrame, target: str, 29 | test_ratio: float) -> tuple[list[int], list[int]]: 30 | """A function to split given metadata into a training and testing set. 
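        The split is made at the level of the `stratify` column: all samples
        sharing a value in that column end up on the same side of the split
        (via `GroupShuffleSplit` with `groups=metadata[self.stratify]`).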
31 | 32 | Args: 33 | metadata (DataFrame): Dataframe containing all samples to be split. 34 | target (str): Target for classification present in `obs`. 35 | test_ratio (float): Ratio of samples belonging to the test split. 36 | 37 | Returns: 38 | (list(int), list(int)): Two lists consisting of train and test indices. 39 | """ 40 | splitter = GroupShuffleSplit(test_size=test_ratio, 41 | n_splits=1, 42 | random_state=42) 43 | 44 | train_inds, test_inds = next( 45 | splitter.split(metadata, 46 | metadata[target], 47 | groups=metadata[self.stratify])) 48 | 49 | return train_inds, test_inds 50 | 51 | @classmethod 52 | def get_default_params(cls) -> dict: 53 | """Class method to get default params for model_config.""" 54 | return dict(split_ratio=[7, 1, 2], stratify='donor_id') 55 | -------------------------------------------------------------------------------- /scalr/data/split/stratified_group_splitter.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of stratified group splitter.""" 2 | 3 | from pandas import DataFrame 4 | from sklearn.model_selection import GroupShuffleSplit 5 | 6 | from scalr.data.split import SplitterBase 7 | from scalr.utils import read_data 8 | 9 | 10 | class StratifiedGroupSplitter(SplitterBase): 11 | """Class for stratified group splitter. 12 | 13 | Generates split of data into train, validation, and test 14 | sets. Stratification ensures samples have the same value for `stratify` 15 | column, can not belong to different sets. Also, it ensures every split 16 | contains samples from each class available in the data. 17 | """ 18 | 19 | def __init__(self, split_ratio: list[float], stratify: str): 20 | """Initialize splitter with required parameters. 21 | 22 | Args: 23 | split_ratio (list[float]): ratio to split number of samples in 24 | stratify (str): column name to metadata the split upon in `obs` 25 | """ 26 | super().__init__() 27 | self.stratify = stratify 28 | self.split_ratio = split_ratio 29 | 30 | def _split_data_with_stratification( 31 | self, metadata: DataFrame, target: str, 32 | test_ratio: float) -> tuple[list[int], list[int]]: 33 | """A function to split given metadata into a training and testing set. 34 | 35 | Args: 36 | metadata (DataFrame): Dataframe containing all samples to be split. 37 | target (str): Target for classification present in `obs`. 38 | test_ratio (float): Ratio of samples belonging to the test split. 39 | 40 | Returns: 41 | (list(int), list(int)): Two lists consisting of train and test indices. 42 | """ 43 | splitter = GroupShuffleSplit(test_size=test_ratio, 44 | n_splits=1, 45 | random_state=42) 46 | 47 | train_inds, test_inds = next( 48 | splitter.split(metadata, 49 | metadata[target], 50 | groups=metadata[self.stratify])) 51 | 52 | return train_inds, test_inds 53 | 54 | def generate_train_val_test_split_indices(self, datapath: str, 55 | target: str) -> dict: 56 | """A function to generate a list of indices for train/val/test split of the whole dataset. 57 | 58 | Args: 59 | datapath (str): Path to full data. 60 | target (str): Target for classification present in `obs`. 61 | 62 | Returns: 63 | dict: 'train', 'val' and 'test' indices list. 
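                For example (indices are illustrative):
                    {'train': [0, 4, 5, ...], 'val': [2, ...], 'test': [1, 3, ...]}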
64 | """ 65 | if not target: 66 | raise ValueError('Must provide target for StratifiedGroupSplitter') 67 | 68 | adata = read_data(datapath) 69 | metadata = adata.obs 70 | metadata['true_index'] = range(len(metadata)) 71 | n_cls = metadata[target].nunique() 72 | 73 | if n_cls > 2: 74 | raise ValueError( 75 | 'StratifiedGroupSplitter only works for binary classification.') 76 | 77 | total_ratio = sum(self.split_ratio) 78 | train_ratio = self.split_ratio[0] / total_ratio 79 | val_ratio = self.split_ratio[1] / total_ratio 80 | val_ratio = val_ratio / (val_ratio + train_ratio) 81 | test_ratio = self.split_ratio[2] / total_ratio 82 | 83 | train_indices = [] 84 | val_indices = [] 85 | test_indices = [] 86 | 87 | for label in metadata[target].unique(): 88 | label_metadata = metadata[metadata[target] == label] 89 | 90 | # Split testing and (train+val) indices. 91 | relative_train_val_inds, relative_test_inds = self._split_data_with_stratification( 92 | label_metadata, target, test_ratio) 93 | 94 | train_val_data = label_metadata.iloc[relative_train_val_inds] 95 | 96 | # Get train and val indices, relative to the `train_val_data`. 97 | relative_train_inds, relative_val_inds = self._split_data_with_stratification( 98 | train_val_data, target, val_ratio) 99 | 100 | # Get true_indices relative to the entire data. 101 | test_indices.extend( 102 | label_metadata.iloc[relative_test_inds]['true_index'].tolist()) 103 | val_indices.extend( 104 | train_val_data.iloc[relative_val_inds]['true_index'].tolist()) 105 | train_indices.extend( 106 | train_val_data.iloc[relative_train_inds]['true_index'].tolist()) 107 | 108 | data_split = { 109 | 'train': train_indices, 110 | 'val': val_indices, 111 | 'test': test_indices 112 | } 113 | 114 | return data_split 115 | 116 | @classmethod 117 | def get_default_params(cls) -> dict: 118 | """Class method to get default params for model_config.""" 119 | return dict(split_ratio=[7, 1, 2], stratify='donor_id') 120 | -------------------------------------------------------------------------------- /scalr/data/split/stratified_splitter.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of the stratified splitter.""" 2 | 3 | from pandas import DataFrame 4 | from sklearn.model_selection import StratifiedShuffleSplit 5 | 6 | from scalr.data.split import SplitterBase 7 | from scalr.utils import read_data 8 | 9 | 10 | class StratifiedSplitter(SplitterBase): 11 | """ Generate Stratified split of data into train, validation, and test sets. 12 | 13 | Stratification ensures the percentage of samples for each class. It ensures 14 | every split contains samples from each class available in the data. 15 | """ 16 | 17 | def __init__(self, split_ratio: list[float]): 18 | """Initialize splitter with required parameters. 19 | 20 | Args: 21 | split_ratio (list[float]): ratio to split number of samples in 22 | """ 23 | super().__init__() 24 | self.split_ratio = split_ratio 25 | 26 | def _split_data_with_stratification( 27 | self, metadata: DataFrame, target: str, 28 | test_ratio: float) -> tuple[list[int], list[int]]: 29 | """A function to split the given metadata into a training and testing set. 30 | 31 | Args: 32 | metadata (DataFrame): Dataframe containing all samples to be split. 33 | target (str): Target for classification present in `obs`. 34 | test_ratio (float): Ratio of samples belonging to the test split. 35 | 36 | Returns: 37 | (list(int), list(int)): Two lists consisting of train and test indices. 
38 | """ 39 | splitter = StratifiedShuffleSplit(test_size=test_ratio, 40 | n_splits=1, 41 | random_state=42) 42 | 43 | train_inds, test_inds = next(splitter.split(metadata, metadata[target])) 44 | 45 | return train_inds, test_inds 46 | 47 | def generate_train_val_test_split_indices(self, datapath: str, 48 | target: str) -> dict: 49 | """A function to generate a list of indices for train/val/test split of the whole dataset. 50 | 51 | Args: 52 | datapath (str): Path to full data. 53 | target (str): Target for classification present in `obs`. 54 | 55 | Returns: 56 | dict: 'train', 'val' and 'test' indices list. 57 | """ 58 | if not target: 59 | raise ValueError('Must provide target for StratifiedSplitter') 60 | 61 | adata = read_data(datapath) 62 | metadata = adata.obs 63 | metadata['true_index'] = range(len(metadata)) 64 | n_cls = metadata[target].nunique() 65 | 66 | total_ratio = sum(self.split_ratio) 67 | train_ratio = self.split_ratio[0] / total_ratio 68 | val_ratio = self.split_ratio[1] / total_ratio 69 | test_ratio = self.split_ratio[2] / total_ratio 70 | 71 | # Split testing and (train+val) indices. 72 | training_inds, testing_inds = self._split_data_with_stratification( 73 | metadata, target, test_ratio) 74 | 75 | train_val_data = metadata.iloc[training_inds] 76 | val_ratio = val_ratio / (val_ratio + train_ratio) 77 | 78 | # Get train and val indices, relative to the `train_val_data`. 79 | relative_train_inds, relative_val_inds = self._split_data_with_stratification( 80 | train_val_data, target, val_ratio) 81 | 82 | # Get true_indices relative to the entire data. 83 | true_test_inds = testing_inds.tolist() 84 | true_val_inds = train_val_data.iloc[relative_val_inds][ 85 | 'true_index'].tolist() 86 | true_train_inds = train_val_data.iloc[relative_train_inds][ 87 | 'true_index'].tolist() 88 | 89 | data_split = { 90 | 'train': true_train_inds, 91 | 'val': true_val_inds, 92 | 'test': true_test_inds 93 | } 94 | 95 | return data_split 96 | 97 | @classmethod 98 | def get_default_params(cls) -> dict: 99 | """Class method to get default params for model_config.""" 100 | return dict(split_ratio=[7, 1, 2]) 101 | -------------------------------------------------------------------------------- /scalr/data_ingestion_pipeline.py: -------------------------------------------------------------------------------- 1 | """This file is a class for data ingestion into the pipeline.""" 2 | 3 | from copy import deepcopy 4 | import os 5 | from os import path 6 | 7 | import pandas as pd 8 | 9 | from scalr.data.preprocess import build_preprocessor 10 | from scalr.data.split import build_splitter 11 | from scalr.utils import FlowLogger 12 | from scalr.utils import read_data 13 | from scalr.utils import write_data 14 | 15 | 16 | class DataIngestionPipeline: 17 | """Class for Data Ingestion into the pipeline""" 18 | 19 | def __init__(self, data_config: dict, dirpath: str = '.'): 20 | """Load data config and create a `data` directory. 21 | 22 | Args: 23 | data_config (dict): Data processing configuration and paths. 24 | dirpath (str): Experiment data directory. Defaults to '.'. 25 | """ 26 | 27 | self.flow_logger = FlowLogger('DataIngestion') 28 | 29 | self.data_config = deepcopy(data_config) 30 | self.target = self.data_config.get('target') 31 | self.sample_chunksize = self.data_config.get('sample_chunksize') 32 | self.num_workers = self.data_config.get('num_workers', 1) 33 | 34 | # Make some necessary checks and logs. 
35 | if not self.target: 36 | self.flow_logger.warning('Target not given') 37 | 38 | if not self.sample_chunksize: 39 | self.flow_logger.warning( 40 | '''Sample chunk size not given. Will default to not using chunking. 41 | Might results in excessive use of memory.''') 42 | 43 | self.datadir = dirpath 44 | 45 | def generate_train_val_test_split(self): 46 | """A function to split data into train, validation and test sets.""" 47 | 48 | # TODO: Move to config validation 49 | if self.data_config['train_val_test'].get( 50 | 'full_datapath', 51 | False) and self.data_config['train_val_test'].get( 52 | 'split_datapaths', False): 53 | self.flow_logger.warning( 54 | '''`full_datapath` and `split_datapaths` are both provided in 55 | the config. Pipeline will use `full_datapath` and it will overwrite 56 | the `split_datapaths`. 57 | ''') 58 | 59 | # TODO: Move to config validation 60 | if self.data_config['train_val_test'].get( 61 | 'feature_subset_datapaths' 62 | ) or self.data_config['train_val_test'].get('final_datapaths'): 63 | raise ValueError( 64 | '''`final_datapaths` or `feature_subset_datapaths` can not be provided 65 | by the user in the config. These paths are generated by the pipeline! 66 | ''') 67 | 68 | if self.data_config['train_val_test'].get('full_datapath'): 69 | self.flow_logger.info('Generating Train, Validation and Test sets') 70 | if not self.target: 71 | self.flow_logger.warning( 72 | '''Target not provided. Will not be able to perform 73 | checks regarding splits. 74 | ''') 75 | 76 | full_datapath = self.data_config['train_val_test']['full_datapath'] 77 | self.full_data = read_data(full_datapath) 78 | splitter_config = deepcopy( 79 | self.data_config['train_val_test']['splitter_config']) 80 | splitter, splitter_config = build_splitter(splitter_config) 81 | self.data_config['train_val_test'][ 82 | 'splitter_config'] = splitter_config 83 | 84 | # Make data splits. 85 | train_val_test_split_indices = splitter.generate_train_val_test_split_indices( 86 | full_datapath, self.target) 87 | 88 | write_data(train_val_test_split_indices, 89 | path.join(self.datadir, 'train_val_test_split.json')) 90 | 91 | # Check data splits. 92 | if self.target: 93 | splitter.check_splits(full_datapath, 94 | train_val_test_split_indices, self.target) 95 | 96 | # Write data splits. 97 | train_val_test_split_dirpath = path.join(self.datadir, 98 | 'train_val_test_split') 99 | os.makedirs(train_val_test_split_dirpath, exist_ok=True) 100 | 101 | splitter.write_splits(self.full_data, train_val_test_split_indices, 102 | self.sample_chunksize, 103 | train_val_test_split_dirpath, 104 | self.num_workers) 105 | 106 | # Garbage collection 107 | del self.full_data 108 | 109 | self.data_config['train_val_test'][ 110 | 'split_datapaths'] = train_val_test_split_dirpath 111 | 112 | elif self.data_config['train_val_test'].get('split_datapaths'): 113 | self.flow_logger.info( 114 | 'Reading Train, Validation and Test sets from config') 115 | 116 | # TODO: Move to config validation 117 | else: 118 | raise ValueError( 119 | 'No Data Provided. Please provide `full_datapath` or `split_datapaths`!' 
120 | ) 121 | 122 | def preprocess_data(self): 123 | """A function to apply preprocessing on data splits.""" 124 | 125 | self.data_config['train_val_test']['final_datapaths'] = deepcopy( 126 | self.data_config['train_val_test']['split_datapaths']) 127 | 128 | all_preprocessings = self.data_config.get('preprocess', list()) 129 | if not all_preprocessings: 130 | return 131 | 132 | self.flow_logger.info('Preprocessing data') 133 | datapath = self.data_config['train_val_test']['final_datapaths'] 134 | 135 | processed_datapath = path.join(self.datadir, 'processed_data') 136 | os.makedirs(processed_datapath, exist_ok=True) 137 | 138 | for i, (preprocess) in enumerate(all_preprocessings): 139 | self.flow_logger.info(f'Applying {preprocess["name"]}') 140 | 141 | preprocessor, preprocessor_config = build_preprocessor( 142 | deepcopy(preprocess)) 143 | self.data_config['preprocess'][i] = preprocessor_config 144 | # Fit on train data. 145 | preprocessor.fit(read_data(path.join(datapath, 'train')), 146 | self.sample_chunksize) 147 | # Transform on train, val & test split. 148 | for split in ['train', 'val', 'test']: 149 | split_data = read_data(path.join(datapath, split)) 150 | preprocessor.process_data(split_data, self.sample_chunksize, 151 | path.join(processed_datapath, split), 152 | self.num_workers) 153 | 154 | datapath = processed_datapath 155 | 156 | self.data_config['train_val_test'][ 157 | 'final_datapaths'] = processed_datapath 158 | 159 | def generate_mappings(self): 160 | """A function to generate an Integer mapping to and from target columns.""" 161 | 162 | self.flow_logger.info( 163 | 'Generate label mappings for all columns in metadata') 164 | 165 | column_names = read_data( 166 | path.join(self.data_config['train_val_test']['final_datapaths'], 167 | 'val')).obs.columns 168 | 169 | data_obs = [] 170 | for split in ['train', 'val', 'test']: 171 | datapath = path.join( 172 | self.data_config['train_val_test']['final_datapaths'], split) 173 | split_data_obs = read_data(datapath).obs 174 | data_obs.append(split_data_obs) 175 | full_data_obs = pd.concat(data_obs) 176 | 177 | label_mappings = {} 178 | for column_name in column_names: 179 | label_mappings[column_name] = {} 180 | 181 | id2label = sorted(full_data_obs[column_name].astype( 182 | 'category').cat.categories.tolist()) 183 | 184 | label2id = {id2label[i]: i for i in range(len(id2label))} 185 | label_mappings[column_name]['id2label'] = id2label 186 | label_mappings[column_name]['label2id'] = label2id 187 | 188 | # Garbage collection 189 | del data_obs 190 | del full_data_obs 191 | 192 | write_data(label_mappings, path.join(self.datadir, 193 | 'label_mappings.json')) 194 | 195 | self.data_config['label_mappings'] = path.join(self.datadir, 196 | 'label_mappings.json') 197 | 198 | def get_updated_config(self): 199 | """This function returns updated configs.""" 200 | return self.data_config 201 | -------------------------------------------------------------------------------- /scalr/feature/__init__.py: -------------------------------------------------------------------------------- 1 | from . import scoring 2 | from . 
import selector 3 | from .feature_subsetting import FeatureSubsetting 4 | -------------------------------------------------------------------------------- /scalr/feature/feature_subsetting.py: -------------------------------------------------------------------------------- 1 | """This file contains implementation for model training on feature subsets.""" 2 | 3 | from copy import deepcopy 4 | import os 5 | from os import path 6 | from typing import Union 7 | 8 | from anndata import AnnData 9 | from anndata.experimental import AnnCollection 10 | from joblib import delayed 11 | from joblib import Parallel 12 | from torch import nn 13 | 14 | from scalr.model_training_pipeline import ModelTrainingPipeline 15 | from scalr.utils import EventLogger 16 | from scalr.utils import FlowLogger 17 | from scalr.utils import read_data 18 | from scalr.utils import write_chunkwise_data 19 | 20 | 21 | class FeatureSubsetting: 22 | """Class for FeatureSubsetting. 23 | 24 | It trains a model for each subsetted datasets, each 25 | containing `feature_subsetsize` genes as features. 26 | """ 27 | 28 | def __init__(self, 29 | feature_subsetsize: int, 30 | chunk_model_config: dict, 31 | chunk_model_train_config: dict, 32 | train_data: Union[AnnData, AnnCollection], 33 | val_data: Union[AnnData, AnnCollection], 34 | target: str, 35 | mappings: dict, 36 | dirpath: str = None, 37 | device: str = 'cpu', 38 | num_workers: int = 1, 39 | sample_chunksize: int = None): 40 | """Initialize required parameters for feature subset training. 41 | 42 | Args: 43 | feature_subsetsize (int): Number of features in one subset. 44 | chunk_model_config (dict): Chunked model config. 45 | chunk_model_train_config (dict): Chunked model training config. 46 | train_data (Union[AnnData, AnnCollection]): Train dataset. 47 | val_data (Union[AnnData, AnnCollection]): Validation dataset. 48 | target (str): Target to train model. 49 | mappings (dict): mapping of target to labels. 50 | dirpath (str, optional): Dirpath to store chunked model weights. Defaults to None. 51 | device (str, optional): Device to train models on. Defaults to 'cpu'. 52 | num_workers (int, optional): Number of parallel processes to launch to train multiple 53 | feature subsets simultaneously. Defaults to using single 54 | process. 55 | sample_chunksize (int, optional): Chunks of samples to be loaded in memory at once. 56 | Required when `num_workers` > 1. 
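        Example (an illustrative sketch; the config dicts and data objects are
        assumed to come from the feature-extraction pipeline):

            subsetter = FeatureSubsetting(
                feature_subsetsize=5000,
                chunk_model_config=model_cfg,
                chunk_model_train_config=train_cfg,
                train_data=train_data, val_data=val_data,
                target='cell_type', mappings=label_mappings,
                dirpath='./feature_extraction', device='cuda',
                num_workers=1)
            models = subsetter.train_chunked_models()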
57 | """ 58 | self.feature_subsetsize = feature_subsetsize 59 | self.chunk_model_config = chunk_model_config 60 | self.chunk_model_train_config = chunk_model_train_config 61 | self.train_data = train_data 62 | self.val_data = val_data 63 | self.target = target 64 | self.mappings = mappings 65 | self.dirpath = dirpath 66 | self.device = device 67 | self.num_workers = num_workers if num_workers else 1 68 | self.sample_chunksize = sample_chunksize 69 | 70 | self.total_features = len(self.train_data.var_names) 71 | 72 | # Note that EventLogger does not work with parallel training 73 | # You may use tensorboard logging to track model training logs 74 | if self.num_workers == 1: 75 | self.event_logger = EventLogger('FeatureSubsetting') 76 | 77 | def write_feature_subsetted_data(self): 78 | """Write chunks of feature-subsetted data, to enable parallel training of models 79 | using different chunks of data.""" 80 | if self.num_workers == 1: 81 | return 82 | 83 | self.feature_chunked_data_dirpath = path.join(self.dirpath, 84 | 'chunked_data') 85 | os.makedirs(self.feature_chunked_data_dirpath, exist_ok=True) 86 | 87 | i = 0 88 | for start in range(0, self.total_features, self.feature_subsetsize): 89 | 90 | feature_subset_inds = list( 91 | range(start, 92 | min(start + self.feature_subsetsize, 93 | self.total_features))) 94 | 95 | write_chunkwise_data(self.train_data, 96 | self.sample_chunksize, 97 | path.join(self.feature_chunked_data_dirpath, 98 | 'train', str(i)), 99 | feature_inds=feature_subset_inds, 100 | num_workers=self.num_workers) 101 | 102 | write_chunkwise_data(self.val_data, 103 | self.sample_chunksize, 104 | path.join(self.feature_chunked_data_dirpath, 105 | 'val', str(i)), 106 | feature_inds=feature_subset_inds, 107 | num_workers=self.num_workers) 108 | 109 | i += 1 110 | 111 | del self.train_data 112 | del self.val_data 113 | 114 | def train_chunked_models(self) -> list[nn.Module]: 115 | """Trains a model for each subset data. 116 | 117 | Returns: 118 | list[nn.Module]: List of models for each subset. 
119 | """ 120 | if self.num_workers == 1: 121 | self.event_logger.info('Feature subset models training') 122 | 123 | chunked_models_dirpath = path.join(self.dirpath, 'chunked_models') 124 | os.makedirs(chunked_models_dirpath, exist_ok=True) 125 | 126 | def train_chunked_model(i, start): 127 | if self.num_workers == 1: 128 | self.event_logger.info(f'\nChunk {i}') 129 | 130 | chunk_dirpath = path.join(chunked_models_dirpath, str(i)) 131 | os.makedirs(chunk_dirpath, exist_ok=True) 132 | 133 | if self.num_workers > 1: 134 | train_features_subset = read_data( 135 | path.join(self.feature_chunked_data_dirpath, 'train', 136 | str(i))) 137 | val_features_subset = read_data( 138 | path.join(self.feature_chunked_data_dirpath, 'val', str(i))) 139 | else: 140 | train_features_subset = self.train_data[:, start:start + 141 | self.feature_subsetsize] 142 | val_features_subset = self.val_data[:, start:start + 143 | self.feature_subsetsize] 144 | 145 | chunk_model_config = deepcopy(self.chunk_model_config) 146 | 147 | model_trainer = ModelTrainingPipeline(chunk_model_config, 148 | self.chunk_model_train_config, 149 | chunk_dirpath, self.device) 150 | 151 | model_trainer.set_data_and_targets(train_features_subset, 152 | val_features_subset, self.target, 153 | self.mappings) 154 | 155 | model_trainer.build_model_training_artifacts() 156 | best_model = model_trainer.train() 157 | 158 | self.chunk_model_config, self.chunk_model_train_config = model_trainer.get_updated_config( 159 | ) 160 | 161 | return i, best_model 162 | 163 | parallel = Parallel(n_jobs=self.num_workers) 164 | models = parallel( 165 | delayed(train_chunked_model)(i, start) for i, (start) in enumerate( 166 | range(0, self.total_features, self.feature_subsetsize))) 167 | 168 | # parallel loop returns all models with the chunk number, which is used to sort models in order 169 | # model[1] returns only the model, without the chunk number 170 | models = sorted(models) 171 | models = [model[1] for model in models] 172 | return models 173 | 174 | def get_updated_configs(self): 175 | """Returns updated configs.""" 176 | return self.chunk_model_config, self.chunk_model_train_config 177 | -------------------------------------------------------------------------------- /scalr/feature/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | from ._scoring import build_scorer 2 | from ._scoring import ScoringBase 3 | from .linear_scorer import LinearScorer 4 | from .shap_scorer import ShapScorer -------------------------------------------------------------------------------- /scalr/feature/scoring/_scoring.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for feature scorer.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import numpy as np 8 | from torch import nn 9 | 10 | import scalr 11 | from scalr.utils import build_object 12 | 13 | 14 | class ScoringBase: 15 | """Base class for the scorer.""" 16 | 17 | def __init__(self): 18 | pass 19 | 20 | # Abstract 21 | def generate_scores(self, model: nn.Module, 22 | train_data: Union[AnnData, AnnCollection], 23 | val_data: Union[AnnData, AnnCollection], target: str, 24 | mappings: dict) -> np.ndarray: 25 | """A function to return the score of each feature for each class. 26 | 27 | Args: 28 | model (nn.Module): Trained model to generate scores from. 
29 | train_data (Union[AnnData, AnnCollection]): Training data of model. 30 | val_data (Union[AnnData, AnnCollection]): Validation data of model. 31 | target (str): Column in data, used to train the model on. 32 | mappings (dict): Mapping of model output dimension to its 33 | corresponding labels in the metadata columns. 34 | 35 | Returns: 36 | np.ndarray: score_matrix [num_classes X num_features] 37 | """ 38 | pass 39 | 40 | @classmethod 41 | def get_default_params(cls) -> dict: 42 | """Class method to get default params.""" 43 | return dict() 44 | 45 | 46 | def build_scorer(scorer_config: dict) -> tuple[ScoringBase, dict]: 47 | """Builder object to get scorer, updated scorer_config.""" 48 | return build_object(scalr.feature.scoring, scorer_config) 49 | -------------------------------------------------------------------------------- /scalr/feature/scoring/linear_scorer.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of a linear scorer.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from scalr.feature.scoring import ScoringBase 12 | 13 | 14 | class LinearScorer(ScoringBase): 15 | """Class for the linear scorer. 16 | 17 | This Scorer is only applicable for linear (single-layer) models. 18 | It directly uses the weights as the score for each feature. 19 | """ 20 | 21 | def __init__(self): 22 | pass 23 | 24 | def generate_scores(self, model: nn.Module, *args, **kwargs) -> np.ndarray: 25 | """A function to generate and return the weights of the model as a score.""" 26 | return model.state_dict()['out_layer.weight'].cpu().detach().numpy() 27 | -------------------------------------------------------------------------------- /scalr/feature/scoring/shap_scorer.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of SHAP scorer.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import numpy as np 8 | import pandas as pd 9 | import shap 10 | from sklearn.preprocessing import OneHotEncoder 11 | import torch 12 | from torch import nn 13 | 14 | from scalr import utils 15 | from scalr.feature.scoring import ScoringBase 16 | from scalr.nn.dataloader import build_dataloader 17 | from scalr.nn.model import CustomShapModel 18 | 19 | 20 | class ShapScorer(ScoringBase): 21 | """Class for SHAP scorer. It can be used for any model.""" 22 | 23 | def __init__(self, 24 | early_stop: dict, 25 | dataloader: dict, 26 | device: str = 'cpu', 27 | top_n_genes: int = 100, 28 | background_tensor: int = 200, 29 | samples_abs_mean: bool = True, 30 | logger: str = 'EventLogger', 31 | *args, 32 | **kwargs): 33 | """Initialize class with SHAP arguments. 34 | 35 | Args: 36 | early_stop: Contains early stopping-related configuration. 37 | dataloader: Dataloader related config. 38 | device: Where data is processed/loaded. 39 | top_n_genes: Top N genes for each class/label. 40 | background_tensor: Number of training data used for SHAP explainer. 41 | samples_abs_mean: Apply abs before taking the mean across samples. 
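            logger: Name of a logger class available in `scalr.utils`
                (e.g. 'EventLogger') used to log progress.

        A sketch of the `early_stop` and `dataloader` dicts, mirroring
        `get_default_params` at the end of this class:

            early_stop = {'patience': 5, 'threshold': 95}
            dataloader = {'name': 'SimpleDataLoader',
                          'params': {'batch_size': 5000, 'padding': 5000}}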
42 | """ 43 | 44 | self.early_stop_config = early_stop 45 | self.device = device 46 | self.top_n_genes = top_n_genes 47 | self.background_tensor = background_tensor 48 | self.dataloader_config = dataloader 49 | self.samples_abs_mean = samples_abs_mean 50 | 51 | self.logger = getattr(utils, logger)('SHAP analysis') 52 | 53 | def generate_scores(self, model: nn.Module, 54 | train_data: Union[AnnData, AnnCollection], 55 | val_data: Union[AnnData, AnnCollection], target: str, 56 | mappings: dict, *args, **kwargs) -> np.ndarray: 57 | """This function returns the weights of the model as a score. 58 | 59 | Args: 60 | model: Trained model that is used for SHAP. 61 | train_data: Data that is used as reference data for SHAP. 62 | val_data: On which SHAP will generate the score. 63 | mappings: Contains target-related mappings. 64 | 65 | Returns: 66 | class * genes abs weights matrix. 67 | """ 68 | 69 | shap_values = self.get_top_n_genes_weights(model, train_data, val_data, 70 | target, mappings) 71 | 72 | return shap_values 73 | 74 | def get_top_n_genes_weights( 75 | self, model: nn.Module, train_data: Union[AnnData, AnnCollection], 76 | test_data: Union[AnnData, AnnCollection], target: str, 77 | mappings: dict) -> Tuple[np.ndarray, np.ndarray]: 78 | """ A function to get top n genes of each class and its weights. 79 | 80 | Args: 81 | model: Trained model to extract weights from. 82 | train_data: Train data. 83 | test_data: Test data that is used for SHAP values. 84 | target: Target name. 85 | mappings: Contains target-related mappings. 86 | 87 | Returns: 88 | (class * genes abs weights matrix, class * genes weights matrix). 89 | """ 90 | 91 | if isinstance(self.logger, utils.EventLogger): 92 | self.logger.heading2("Genes analysis using SHAP.") 93 | 94 | model.to(self.device) 95 | shap_model = CustomShapModel(model) 96 | 97 | random_indices = np.random.randint(0, train_data.shape[0], 98 | self.background_tensor) 99 | train_dl, _ = build_dataloader(self.dataloader_config, 100 | train_data[random_indices], target, 101 | mappings) 102 | random_background_data = torch.cat([batch[0] for batch in train_dl]) 103 | 104 | self.logger.info( 105 | f"Selected random background data: {random_background_data.shape}") 106 | 107 | test_dl, _ = build_dataloader(self.dataloader_config, test_data, target, 108 | mappings) 109 | 110 | explainer = shap.DeepExplainer(shap_model, 111 | random_background_data.to(self.device)) 112 | 113 | prev_top_genes_batch_wise = {} 114 | count_patience = 0 115 | total_samples = 0 116 | 117 | for batch_id, batch in enumerate(test_dl): 118 | self.logger.info(f"Running on batch: {batch_id}") 119 | total_samples += batch[0].shape[0] 120 | 121 | batch_shap_values = explainer.shap_values(batch[0].to(self.device)) 122 | if self.samples_abs_mean: 123 | sum_shap_values = np.abs(batch_shap_values).sum(axis=0) 124 | else: 125 | # Calcluating 2 mean with abs values and non-abs values. 126 | # Non-abs values required for heatmap. 
127 | sum_shap_values = batch_shap_values.sum(axis=0) 128 | 129 | if batch_id >= 1: 130 | sum_shap_values = np.sum( 131 | [sum_shap_values, prev_batches_sum_shap_values], axis=0) 132 | 133 | mean_shap_values = sum_shap_values / total_samples 134 | 135 | genes_class_shap_df = pd.DataFrame( 136 | mean_shap_values[:len(test_dl.dataset.var_names)], 137 | index=test_dl.dataset.var_names) 138 | 139 | prev_batches_sum_shap_values = sum_shap_values 140 | 141 | early_stop, prev_top_genes_batch_wise = self._is_shap_early_stop( 142 | batch_id, genes_class_shap_df, prev_top_genes_batch_wise, 143 | self.top_n_genes, self.early_stop_config['threshold']) 144 | 145 | count_patience = count_patience + 1 if early_stop else 0 146 | 147 | if count_patience == self.early_stop_config['patience']: 148 | self.logger.info(f"Early stopping at batch: {batch_id}") 149 | break 150 | 151 | return mean_shap_values.T 152 | 153 | def _is_shap_early_stop( 154 | self, 155 | batch_id: int, 156 | genes_class_shap_df: pd.DataFrame, 157 | prev_top_genes_batch_wise: dict, 158 | top_n_genes: int, 159 | threshold: int, 160 | ) -> Tuple[bool, dict]: 161 | """A function to check whether previous and current batches' common genes are 162 | are greater than or equal to the threshold and return top genes 163 | batch wise. 164 | 165 | Args: 166 | batch_id: Current batch number. 167 | genes_class_shap_df: label/class wise genes SHAP values(mean across samples). 168 | prev_top_genes_batch_wise: Dictionary where prev batches per labels top genes are stored. 169 | top_n_genes: Number of top genes check. 170 | threshold: early stop if common genes are higher than this. 171 | 172 | Returns: 173 | Early stop value, top genes batch wise. 174 | """ 175 | 176 | early_stop = True 177 | top_genes_batch_wise = {} 178 | classes = genes_class_shap_df.columns 179 | for label in classes: 180 | top_genes_batch_wise[label] = genes_class_shap_df[ 181 | label].sort_values(ascending=False)[:top_n_genes].index 182 | 183 | # Start checking after first batch. 184 | if batch_id >= 1: 185 | num_common_genes = len( 186 | set(top_genes_batch_wise[label]).intersection( 187 | set(prev_top_genes_batch_wise[label]))) 188 | # If commnon genes are less than 90 early stop will be false. 
189 | if num_common_genes < threshold: 190 | early_stop = False 191 | else: 192 | early_stop = False 193 | 194 | return early_stop, top_genes_batch_wise 195 | 196 | @classmethod 197 | def get_default_params(cls) -> dict: 198 | """Class method to get default params.""" 199 | return { 200 | "top_n_genes": 100, 201 | "background_tensor": 200, 202 | "samples_abs_mean": True, 203 | "early_stop": { 204 | "patience": 5, 205 | "threshold": 95 206 | }, 207 | "dataloader": { 208 | "name": "SimpleDataLoader", 209 | "params": { 210 | "batch_size": 5000, 211 | "padding": 5000 212 | } 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /scalr/feature/selector/__init__.py: -------------------------------------------------------------------------------- 1 | from ._selector import build_selector 2 | from ._selector import SelectorBase 3 | from .abs_mean import AbsMean 4 | from .classwise_abs import ClasswiseAbs 5 | from .classwise_promoters import ClasswisePromoters 6 | -------------------------------------------------------------------------------- /scalr/feature/selector/_selector.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for the top feature selector.""" 2 | 3 | from typing import Union 4 | 5 | from pandas import DataFrame 6 | 7 | import scalr 8 | from scalr.utils import build_object 9 | 10 | 11 | class SelectorBase: 12 | """Base class for Feature Selector from scores.""" 13 | 14 | # Abstract 15 | def get_feature_list(score_matrix: DataFrame, 16 | **kwargs) -> Union[list[str], dict]: 17 | """A function to return top features from given scores of each feature for each class. 18 | 19 | Args: 20 | score_matrix (DataFrame): Score of each feature across all classes 21 | [num_classes X num_features]. 22 | 23 | Returns: 24 | list[str]: List of features. 25 | """ 26 | return 27 | 28 | @classmethod 29 | def get_default_params(cls) -> dict: 30 | """Class method to get default params.""" 31 | return dict() 32 | 33 | 34 | def build_selector(selector_config: dict) -> tuple[SelectorBase, dict]: 35 | """Builder object to get Selector, updated selector_config.""" 36 | return build_object(scalr.feature.selector, selector_config) 37 | -------------------------------------------------------------------------------- /scalr/feature/selector/abs_mean.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of the Absolute mean feature selector strategy.""" 2 | 3 | from pandas import DataFrame 4 | 5 | from scalr.feature.selector import SelectorBase 6 | 7 | 8 | class AbsMean(SelectorBase): 9 | """Class for absolute mean feature selector strategy. 10 | 11 | It uses the absolute mean across all classes as the score of the feature. 12 | """ 13 | 14 | def __init__(self, k: int = 1e6): 15 | """Initialize required parameters for the selector.""" 16 | self.k = k 17 | 18 | def get_feature_list(self, score_matrix: DataFrame) -> list[str]: 19 | """A function to return top features using score matrix and selector strategy. 20 | 21 | Args: 22 | score_matrix (DataFrame): Score of each feature across all classes 23 | [num_classes X num_features]. 24 | 25 | Returns: 26 | list[str]: List of top k features. 
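        Ranking rule (from the implementation below): features are sorted by the
        column-wise mean of absolute scores, `score_matrix.abs().mean()`, in
        descending order, and the first `k` feature names are returned.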
27 | """ 28 | top_features_list = list(score_matrix.abs().mean().sort_values( 29 | ascending=False).reset_index()['index'][:self.k]) 30 | return top_features_list 31 | 32 | @classmethod 33 | def get_default_params(cls) -> dict: 34 | """Class method to get default params for preprocess_config.""" 35 | return dict(k=int(1e6)) 36 | -------------------------------------------------------------------------------- /scalr/feature/selector/classwise_abs.py: -------------------------------------------------------------------------------- 1 | """This file returns top K features(can be promoters or inhibitors as well) per class.""" 2 | 3 | from pandas import DataFrame 4 | 5 | from scalr.feature.selector import SelectorBase 6 | 7 | 8 | class ClasswiseAbs(SelectorBase): 9 | """Class for class-wise absolute feature selector strategy. 10 | 11 | Classwise scorer returns a dict for each class, containing the top 12 | absolute scores of genes. 13 | """ 14 | 15 | def __init__(self, k: int = 1e6) -> dict: 16 | """Initialize required parameters for the selector.""" 17 | self.k = k 18 | 19 | def get_feature_list(self, score_matrix: DataFrame): 20 | """A function to return top features per class using score matrix 21 | and selector strategy. 22 | 23 | Args: 24 | score_matrix (DataFrame): Score of each feature across all classes 25 | [num_classes X num_features]. 26 | 27 | Returns: 28 | dict: List of top_k features for each class. 29 | """ 30 | classwise_abs = dict() 31 | n_cls = len(score_matrix) 32 | 33 | for i in range(n_cls): 34 | for i in range(n_cls): 35 | classwise_abs[score_matrix.index[i]] = abs( 36 | score_matrix.iloc[i, :]).sort_values( 37 | ascending=False).reset_index()['index'][:self.k].tolist( 38 | ) 39 | 40 | return classwise_abs 41 | 42 | @classmethod 43 | def get_default_params(cls) -> dict: 44 | """Class method to get default params for preprocess_config.""" 45 | return dict(k=int(1e6)) 46 | -------------------------------------------------------------------------------- /scalr/feature/selector/classwise_promoters.py: -------------------------------------------------------------------------------- 1 | """This file returns top K promoter features per class.""" 2 | 3 | from pandas import DataFrame 4 | 5 | from scalr.feature.selector import SelectorBase 6 | 7 | 8 | class ClasswisePromoters(SelectorBase): 9 | """Class for class-wise promoter feature selector strategy. 10 | 11 | Classwise scorer returns a dict for each class, containing the top 12 | positive scored genes. 13 | """ 14 | 15 | def __init__(self, k: int = 1e6): 16 | """Initialize required parameters for the selector.""" 17 | self.k = k 18 | 19 | def get_feature_list(self, score_matrix: DataFrame): 20 | """A function to return top features per class using score matrix 21 | and selector strategy. 22 | 23 | Args: 24 | score_matrix (DataFrame): Score of each feature across all classes 25 | [num_classes X num_features]. 26 | 27 | Returns: 28 | list[str]: List of top k features. 
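            In the implementation below the result is a dict keyed by class
            label, for example (class and gene names are illustrative):
                {'B_cell': ['gene_12', 'gene_4', ...], 'T_cell': ['gene_9', ...]}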
29 |         """
30 |         classwise_promoters = dict()
31 |         n_cls = len(score_matrix)
32 | 
33 |         for i in range(n_cls):
34 |             classwise_promoters[score_matrix.index[i]] = score_matrix.iloc[
35 |                 i, :].sort_values(ascending=False).reset_index(
36 |                 )['index'][:self.k].tolist()
37 | 
38 |         return classwise_promoters
39 | 
40 |     @classmethod
41 |     def get_default_params(cls) -> dict:
42 |         """Class method to get default params for selector_config."""
43 |         return dict(k=int(1e6))
44 | 
--------------------------------------------------------------------------------
/scalr/feature_extraction_pipeline.py:
--------------------------------------------------------------------------------
1 | """This file contains the implementation of feature subsetting and model training, followed by top feature extraction."""
2 | 
3 | from copy import deepcopy
4 | import os
5 | from os import path
6 | from typing import Union
7 | 
8 | from anndata import AnnData
9 | from anndata.experimental import AnnCollection
10 | import numpy as np
11 | import pandas as pd
12 | from torch import nn
13 | 
14 | from scalr.feature import FeatureSubsetting
15 | from scalr.feature.scoring import build_scorer
16 | from scalr.feature.selector import build_selector
17 | from scalr.utils import FlowLogger
18 | from scalr.utils import load_train_val_data_from_config
19 | from scalr.utils import read_data
20 | from scalr.utils import write_chunkwise_data
21 | from scalr.utils import write_data
22 | 
23 | 
24 | class FeatureExtractionPipeline:
25 | 
26 |     def __init__(self, feature_selection_config, dirpath, device):
27 |         """Initialize required parameters for feature selection.
28 | 
29 |         Feature extraction is done in 4 steps:
30 |             1. Model(s) training on chunked/all features
31 |             2. Class X Feature scoring
32 |             3. Top features extraction
33 |             4. Feature subset data writing
34 | 
35 |         Args:
36 |             feature_selection_config: Feature selection config.
37 |             dirpath: Path to store feature extraction outputs in.
38 |         """
39 |         self.flow_logger = FlowLogger('FeatureExtraction')
40 | 
41 |         self.feature_selection_config = deepcopy(feature_selection_config)
42 |         self.device = device
43 | 
44 |         self.dirpath = dirpath
45 |         os.makedirs(dirpath, exist_ok=True)
46 | 
47 |     def load_data_and_targets_from_config(self, data_config: dict):
48 |         """A function to load data and targets from data config.
49 | 
50 |         Args:
51 |             data_config: Data config.
52 |         """
53 |         self.train_data, self.val_data = load_train_val_data_from_config(
54 |             data_config)
55 |         self.target = data_config.get('target')
56 |         self.mappings = read_data(data_config['label_mappings'])
57 |         self.sample_chunksize = data_config.get('sample_chunksize')
58 | 
59 |     def set_data_and_targets(self,
60 |                              train_data: Union[AnnData, AnnCollection],
61 |                              val_data: Union[AnnData, AnnCollection],
62 |                              target: Union[str, list[str]],
63 |                              mappings: dict,
64 |                              sample_chunksize: int = None):
65 |         """A function to set data when you don't use data directly from config,
66 |         but rather by other sources like feature subsetting, etc.
67 | 
68 |         Args:
69 |             train_data (Union[AnnData, AnnCollection]): Training data.
70 |             val_data (Union[AnnData, AnnCollection]): Validation data.
71 |             target (Union[str, list[str]]): Target columns name(s).
72 |             mappings (dict): Mapping of a column value to ids
73 |                 eg. mappings[column_name][label2id] = {A: 1, B:2, ...}.
74 |             sample_chunksize (int): Chunks of samples to be loaded in memory at once.
75 | """ 76 | self.train_data = train_data 77 | self.val_data = val_data 78 | self.target = target 79 | self.mappings = mappings 80 | 81 | def feature_subsetted_model_training(self) -> list[nn.Module]: 82 | """This function train models on subsetted data containing `feature_subsetsize` genes.""" 83 | 84 | self.flow_logger.info('Feature subset models training') 85 | 86 | self.feature_subsetsize = self.feature_selection_config.get( 87 | 'feature_subsetsize', len(self.val_data.var_names)) 88 | self.num_workers = self.feature_selection_config.get('num_workers', 1) 89 | 90 | chunk_model_config = self.feature_selection_config.get('model') 91 | chunk_model_train_config = self.feature_selection_config.get( 92 | 'model_train_config') 93 | 94 | chunked_features_model_trainer = FeatureSubsetting( 95 | self.feature_subsetsize, chunk_model_config, 96 | chunk_model_train_config, self.train_data, self.val_data, 97 | self.target, self.mappings, self.dirpath, self.device, 98 | self.num_workers, self.sample_chunksize) 99 | 100 | if self.num_workers > 1: 101 | chunked_features_model_trainer.write_feature_subsetted_data() 102 | 103 | self.chunked_models = chunked_features_model_trainer.train_chunked_models( 104 | ) 105 | chunk_model_config, chunk_model_train_config = chunked_features_model_trainer.get_updated_configs( 106 | ) 107 | self.feature_selection_config['model'] = chunk_model_config 108 | self.feature_selection_config[ 109 | 'model_train_config'] = chunk_model_train_config 110 | 111 | return self.chunked_models 112 | 113 | def set_model(self, models: list[nn.Module]): 114 | """A function to set the trained model for downstream feature tasks.""" 115 | self.chunked_models = models 116 | 117 | def feature_scoring(self) -> pd.DataFrame: 118 | """A function to generate scores of each feature for each class using a scorer 119 | and chunked models. 
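# ---------------------------------------------------------------------------
# [Editor's note -- not part of the repository] Rough shape of the feature
# chunking used above and in the scoring loop below, with illustrative numbers:
# 30,000 genes and feature_subsetsize=5,000 would presumably yield 6 chunked
# models, model i covering genes [i * 5000 : (i + 1) * 5000] (a short last
# chunk is padded by the dataloader).
n_chunks = -(-30_000 // 5_000)  # ceil division -> 6
# ---------------------------------------------------------------------------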
120 | """ 121 | self.flow_logger.info('Feature scoring') 122 | 123 | scorer, scorer_config = build_scorer( 124 | deepcopy(self.feature_selection_config.get('scoring_config'))) 125 | self.feature_selection_config['scoring_config'] = scorer_config 126 | 127 | all_scores = [] 128 | if not getattr(self, 'feature_subsetsize', None): 129 | self.feature_subsetsize = self.train_data.shape[1] 130 | 131 | # TODO: Parallelize feature scoring 132 | for i, (model) in enumerate(self.chunked_models): 133 | subset_train_data = self.train_data[:, i * 134 | self.feature_subsetsize:(i + 135 | 1) * 136 | self.feature_subsetsize] 137 | subset_val_data = self.val_data[:, i * 138 | self.feature_subsetsize:(i + 1) * 139 | self.feature_subsetsize] 140 | score = scorer.generate_scores(model, subset_train_data, 141 | subset_val_data, self.target, 142 | self.mappings) 143 | 144 | all_scores.append(score[:self.feature_subsetsize]) 145 | 146 | columns = self.train_data.var_names 147 | columns.name = "index" 148 | class_labels = self.mappings[self.target]['id2label'] 149 | all_scores = np.concatenate(all_scores, axis=1) 150 | all_scores = all_scores[:, :len(columns)] 151 | 152 | self.score_matrix = pd.DataFrame(all_scores, 153 | columns=columns, 154 | index=class_labels) 155 | write_data(self.score_matrix, path.join(self.dirpath, 156 | 'score_matrix.csv')) 157 | return self.score_matrix 158 | 159 | def set_score_matrix(self, score_matrix: pd.DataFrame): 160 | """A function to set score_matrix for feature extraction.""" 161 | self.score_matrix = score_matrix 162 | 163 | def top_feature_extraction(self) -> Union[list[str], dict]: 164 | """A function to get top features using `Selector`.""" 165 | 166 | self.flow_logger.info('Top features extraction') 167 | 168 | selector_config = self.feature_selection_config.get( 169 | 'features_selector', dict(name='AbsMean')) 170 | selector, selector_config = build_selector(selector_config) 171 | self.feature_selection_config['features_selector'] = selector_config 172 | 173 | self.top_features = selector.get_feature_list(self.score_matrix) 174 | write_data(self.top_features, 175 | path.join(self.dirpath, 'top_features.json')) 176 | 177 | return self.top_features 178 | 179 | def write_top_features_subset_data(self, data_config: dict) -> dict: 180 | """A function to write top features subset data onto disk 181 | and return updated data_config. 182 | 183 | Args: 184 | data_config: Data config. 
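# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] A rough sketch of driving
# this class end to end; `feature_selection_config`, `data_config` and the
# output path are placeholders that would normally come from the pipeline config.
fe = FeatureExtractionPipeline(feature_selection_config,
                               'results/feature_extraction', 'cuda')
fe.load_data_and_targets_from_config(data_config)
fe.feature_subsetted_model_training()    # 1. model(s) on feature chunks
fe.feature_scoring()                     # 2. class x feature score matrix
fe.top_feature_extraction()              # 3. top features via the selector
data_config = fe.write_top_features_subset_data(data_config)  # 4. subset data
feature_selection_config = fe.get_updated_config()
# ---------------------------------------------------------------------------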
185 | """ 186 | 187 | self.flow_logger.info('Writing feature-subset data onto disk') 188 | 189 | datapath = data_config['train_val_test'].get('final_datapaths') 190 | 191 | feature_subset_datapath = path.join(self.dirpath, 'feature_subset_data') 192 | os.makedirs(feature_subset_datapath, exist_ok=True) 193 | 194 | test_data = read_data(path.join(datapath, 'test')) 195 | splits = { 196 | 'train': self.train_data, 197 | 'val': self.val_data, 198 | 'test': test_data 199 | } 200 | 201 | sample_chunksize = data_config.get('sample_chunksize') 202 | num_workers = data_config.get('num_workers') 203 | 204 | for split, split_data in splits.items(): 205 | 206 | split_feature_subset_datapath = path.join(feature_subset_datapath, 207 | split) 208 | write_chunkwise_data(split_data, 209 | sample_chunksize, 210 | split_feature_subset_datapath, 211 | feature_inds=self.top_features, 212 | num_workers=num_workers) 213 | 214 | data_config['train_val_test'][ 215 | 'feature_subset_datapaths'] = feature_subset_datapath 216 | 217 | return data_config 218 | 219 | def get_updated_config(self) -> dict: 220 | """This function returns updated configs.""" 221 | return self.feature_selection_config 222 | -------------------------------------------------------------------------------- /scalr/model_training_pipeline.py: -------------------------------------------------------------------------------- 1 | """This file contains an implementation for the model training pipeline.""" 2 | 3 | from copy import deepcopy 4 | import os 5 | from os import path 6 | from typing import Union 7 | 8 | from anndata import AnnData 9 | from anndata.experimental import AnnCollection 10 | import torch 11 | 12 | import scalr 13 | from scalr.nn.callbacks import CallbackExecutor 14 | from scalr.nn.dataloader import build_dataloader 15 | from scalr.nn.loss import build_loss_fn 16 | from scalr.nn.model import build_model 17 | from scalr.utils import EventLogger 18 | from scalr.utils import FlowLogger 19 | from scalr.utils import load_train_val_data_from_config 20 | from scalr.utils import read_data 21 | from scalr.utils import set_seed 22 | from scalr.utils import write_data 23 | 24 | 25 | class ModelTrainingPipeline: 26 | """Class for Model training pipeline.""" 27 | 28 | def __init__(self, 29 | model_config: dict, 30 | train_config: dict, 31 | dirpath: str = None, 32 | device: str = 'cpu'): 33 | """Initialize required parameters for model training pipeline. 34 | 35 | Class to get trained model from given configs 36 | 37 | Args: 38 | dirpath (str): Path to store checkpoints and logs of the model. 39 | model_config (dict): Model config. 40 | train_config (dict): Model training config. 41 | device (str, optional): Device to run model on. Defaults to 'cpu'. 42 | """ 43 | self.flow_logger = FlowLogger('ModelTraining') 44 | set_seed(42) 45 | 46 | self.train_config = train_config 47 | self.model_config = model_config 48 | self.device = device 49 | self.dirpath = dirpath 50 | 51 | def load_data_and_targets_from_config(self, data_config: dict): 52 | """A function to load data and targets from data config. 53 | 54 | Args: 55 | data_config: Data config. 
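# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] The label mappings loaded
# from data_config['label_mappings'] have roughly this shape; the labels and
# ids below are illustrative. 'label2id' converts adata.obs values to integer
# class ids, and 'id2label' is the inverse lookup used for reporting.
mappings = {
    'celltype': {
        'id2label': ['B', 'C', 'DC', 'T'],
        'label2id': {'B': 0, 'C': 1, 'DC': 2, 'T': 3},
    }
}
# ---------------------------------------------------------------------------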
56 | """ 57 | self.train_data, self.val_data = load_train_val_data_from_config( 58 | data_config) 59 | self.target = data_config.get('target') 60 | self.mappings = read_data(data_config['label_mappings']) 61 | 62 | def set_data_and_targets(self, train_data: Union[AnnData, AnnCollection], 63 | val_data: Union[AnnData, AnnCollection], 64 | target: Union[str, list[str]], mappings: dict): 65 | """A function to set data when you don't use data directly from config, 66 | but rather by other sources like feature subsetting, etc. 67 | 68 | Args: 69 | train_data (Union[AnnData, AnnCollection]): Training data. 70 | val_data (Union[AnnData, AnnCollection]): Validation data. 71 | target (Union[str, list[str]]): Target columns name(s). 72 | mappings (dict): Mapping of a column value to ids 73 | eg. mappings[column_name][label2id] = {A: 1, B:2, ...}. 74 | """ 75 | self.train_data = train_data 76 | self.val_data = val_data 77 | self.target = target 78 | self.mappings = mappings 79 | 80 | def build_model_training_artifacts(self): 81 | """This function configures the model, optimizer, and loss function required 82 | for model training. 83 | """ 84 | self.flow_logger.info('Building model training artifacts') 85 | 86 | # Building model. 87 | self.model, self.model_config = build_model(self.model_config) 88 | self.model.to(self.device) 89 | 90 | # Building optimizer. 91 | opt_config = deepcopy(self.train_config.get('optimizer')) 92 | self.opt, opt_config = self.build_optimizer( 93 | self.train_config.get('optimizer')) 94 | self.train_config['optimizer'] = opt_config 95 | 96 | # Building Loss Function. 97 | self.loss_fn, loss_config = build_loss_fn( 98 | deepcopy(self.train_config.get('loss', dict()))) 99 | self.train_config['loss'] = loss_config 100 | self.loss_fn.to(self.device) 101 | 102 | # Building Callbacks executor. 103 | self.callbacks = CallbackExecutor( 104 | self.dirpath, self.train_config.get('callbacks', list())) 105 | 106 | # Resuming from checkpoint using model weights. 107 | if self.train_config.get('resume_from_checkpoint'): 108 | self.flow_logger.info('Resuming model from checkpoint') 109 | self.flow_logger.info('Loading model weights...') 110 | self.model.load_weights(self.train_config['resume_from_checkpoint']) 111 | self.flow_logger.info('Loading optimizer state dict...') 112 | self.opt.load_state_dict( 113 | torch.load(self.train_config['resume_from_checkpoint']) 114 | ['optimizer_state_dict']) 115 | 116 | def build_optimizer(self, opt_config: dict = None): 117 | """A function to build optimizer. 118 | 119 | Args: 120 | opt_config (dict): Optimizer config. 121 | """ 122 | if not opt_config: 123 | opt_config = dict() 124 | name = opt_config.get('name', 'Adam') 125 | opt_config['name'] = name 126 | params = opt_config.get('params', dict(lr=1e-3)) 127 | opt_config['params'] = params 128 | 129 | opt = getattr(torch.optim, name)(self.model.parameters(), **params) 130 | return opt, opt_config 131 | 132 | def train(self): 133 | """This function trains the model.""" 134 | self.flow_logger.info('Training the model') 135 | # Building Trainer. 136 | trainer_name = self.train_config.get('trainer', 'SimpleModelTrainer') 137 | self.train_config['trainer'] = trainer_name 138 | 139 | Trainer = getattr(scalr.nn.trainer, trainer_name) 140 | trainer = Trainer(self.model, self.opt, self.loss_fn, self.callbacks, 141 | self.device) 142 | 143 | # Building DataLoaders. 
144 | dataloader_config = self.train_config.get('dataloader') 145 | train_dl, dataloader_config = build_dataloader(dataloader_config, 146 | self.train_data, 147 | self.target, 148 | self.mappings) 149 | val_dl, dataloader_config = build_dataloader(dataloader_config, 150 | self.val_data, self.target, 151 | self.mappings) 152 | self.train_config['dataloader'] = dataloader_config 153 | 154 | epochs = self.train_config.get('epochs', 1) 155 | self.train_config['epochs'] = epochs 156 | 157 | # Train and store the best model. 158 | best_model = trainer.train(epochs, train_dl, val_dl) 159 | if self.dirpath: 160 | best_model_dir = path.join(self.dirpath, 'best_model') 161 | os.makedirs(best_model_dir, exist_ok=True) 162 | best_model.save_weights(path.join(best_model_dir, 'model.pt')) 163 | write_data(self.model_config, 164 | path.join(best_model_dir, 'model_config.yaml')) 165 | write_data(self.mappings, path.join(best_model_dir, 166 | 'mappings.json')) 167 | 168 | return best_model 169 | 170 | def get_updated_config(self): 171 | """This function returns updated configs.""" 172 | return self.model_config, self.train_config 173 | -------------------------------------------------------------------------------- /scalr/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import callbacks 2 | from . import dataloader 3 | from . import loss 4 | from . import model 5 | from . import trainer 6 | -------------------------------------------------------------------------------- /scalr/nn/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from ._callbacks import CallbackBase 2 | from ._callbacks import CallbackExecutor 3 | from .early_stopping import EarlyStopping 4 | from .model_checkpoint import ModelCheckpoint 5 | from .tensorboard_logger import TensorboardLogger 6 | -------------------------------------------------------------------------------- /scalr/nn/callbacks/_callbacks.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for implementation of Callbacks.""" 2 | 3 | import os 4 | from os import path 5 | from typing import Union 6 | 7 | from anndata import AnnData 8 | from anndata.experimental import AnnCollection 9 | 10 | import scalr 11 | from scalr.utils import build_object 12 | 13 | 14 | class CallbackBase: 15 | """Base class to build callbacks.""" 16 | 17 | def __init__(self, dirpath='.'): 18 | """Use to generate necessary arguments or create directories.""" 19 | pass 20 | 21 | def __call__(self): 22 | """Execute the callback here.""" 23 | pass 24 | 25 | @classmethod 26 | def get_default_params(cls): 27 | """Class method to get default params for callbacks config.""" 28 | return None 29 | 30 | 31 | class CallbackExecutor: 32 | """ 33 | Wrapper class to execute all enabled callbacks. 34 | 35 | Enabled callbacks are executed with the early stopping callback 36 | executed last to return a flag for continuation or stopping of model training 37 | """ 38 | 39 | def __init__(self, dirpath: str, callbacks: list[dict]): 40 | """Intialize required parameters for callbacks. 41 | 42 | Args: 43 | dirpath: Path to store logs and checkpoints. 44 | callback: List containing multiple callbacks. 
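# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] An illustrative train_config
# as read by ModelTrainingPipeline above; only the keys mirror the code, the
# values are assumptions. 'optimizer' and 'loss' names are resolved against
# torch.optim and torch.nn respectively.
train_config = {
    'trainer': 'SimpleModelTrainer',
    'epochs': 25,
    'optimizer': {'name': 'Adam', 'params': {'lr': 1e-3}},
    'loss': {'name': 'CrossEntropyLoss'},
    'dataloader': {'name': 'SimpleDataLoader', 'params': {'batch_size': 256}},
    'callbacks': [
        {'name': 'EarlyStopping', 'params': {'patience': 3, 'min_delta': 1e-4}},
        {'name': 'ModelCheckpoint', 'params': {'interval': 5}},
        {'name': 'TensorboardLogger'},
    ],
}
# ---------------------------------------------------------------------------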
45 | """ 46 | 47 | self.callbacks = [] 48 | 49 | for callback in callbacks: 50 | if callback.get('params'): 51 | callback['params']['dirpath'] = dirpath 52 | else: 53 | callback['params'] = dict(dirpath=dirpath) 54 | callback_object, _ = build_object(scalr.nn.callbacks, callback) 55 | self.callbacks.append(callback_object) 56 | 57 | def execute(self, **kwargs) -> bool: 58 | """Execute all the enabled callbacks. Returns early stopping condition.""" 59 | 60 | early_stop = False 61 | for callback in self.callbacks: 62 | # Below `| False` is to handle cases when callbacks return None. 63 | # And we want to return true when early stopping is achieved. 64 | early_stop |= callback(**kwargs) or False 65 | 66 | return early_stop 67 | -------------------------------------------------------------------------------- /scalr/nn/callbacks/early_stopping.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of early stopping callback.""" 2 | 3 | import os 4 | from os import path 5 | 6 | import torch 7 | 8 | from scalr.nn.callbacks import CallbackBase 9 | 10 | 11 | class EarlyStopping(CallbackBase): 12 | """ 13 | Implements early stopping based upon validation loss. 14 | 15 | Attributes: 16 | patience: Number of epochs with no improvement after which training will be stopped. 17 | min_delta: Minimum change in the monitored quantity to qualify as an improvement, 18 | i.e. an absolute change of less than min_delta, will count as no improvement. 19 | """ 20 | 21 | def __init__(self, 22 | dirpath: str = None, 23 | patience: int = 3, 24 | min_delta: float = 1e-4): 25 | """Intialize required parameters for early stopping callback. 26 | 27 | Args: 28 | patience: Number of epochs with no improvement after which training will be stopped. 29 | min_delta: Minimum change in the monitored quantity to qualify as an improvement, 30 | i.e. an absolute change of less than min_delta, will count as no improvement. 31 | epoch: An interger count of epochs trained. 32 | min_validation_loss: Keeps track of the minimum validation loss across all epochs. 33 | """ 34 | self.patience = int(patience) 35 | self.min_delta = float(min_delta) 36 | self.epoch = 0 37 | self.min_val_loss = float('inf') 38 | 39 | def __call__(self, val_loss: float, **kwargs) -> bool: 40 | """Return `True` if model training needs to be stopped based upon improvement conditions. 41 | Else returns `False` for continued training. 42 | """ 43 | if val_loss < self.min_val_loss: 44 | self.min_val_loss = val_loss 45 | self.epoch = 0 46 | elif val_loss >= (self.min_val_loss + self.min_delta): 47 | self.epoch += 1 48 | if self.epoch >= self.patience: 49 | return True 50 | return False 51 | 52 | @classmethod 53 | def get_default_params(cls): 54 | """Class method to get default params for model_config.""" 55 | return dict(patience=3, min_delta=1e-4) 56 | -------------------------------------------------------------------------------- /scalr/nn/callbacks/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of model checkpoint callback.""" 2 | 3 | import os 4 | from os import path 5 | 6 | import torch 7 | 8 | from scalr.nn.callbacks import CallbackBase 9 | 10 | 11 | class ModelCheckpoint(CallbackBase): 12 | """Model checkpointing to save model weights at regular intervals. 13 | 14 | Attributes: 15 | epoch: An interger count of epochs trained. 
16 | max_validation_acc: Keeps track of the maximum validation accuracy across all epochs. 17 | interval: Regular interval of model checkpointing. 18 | """ 19 | 20 | def __init__(self, dirpath: str, interval: int = 5): 21 | """Intialize required parameters for model checkpoint callback. 22 | 23 | Args: 24 | dirpath: To store the respective model checkpoints. 25 | interval: Regular interval of model checkpointing. 26 | """ 27 | 28 | self.epoch = 0 29 | self.interval = int(interval) 30 | self.dirpath = dirpath 31 | 32 | if self.interval: 33 | os.makedirs(path.join(dirpath, 'checkpoints'), exist_ok=True) 34 | 35 | def save_checkpoint(self, model_state_dict: dict, opt_state_dict: dict, 36 | path: str): 37 | """A function to save model & optimizer state dict to the given path. 38 | 39 | Args: 40 | model_state_dict: Model's state dict. 41 | opt_state_dict: Optimizer's state dict. 42 | path: Path to store checkpoint to. 43 | """ 44 | torch.save( 45 | { 46 | 'epoch': self.epoch, 47 | 'model_state_dict': model_state_dict, 48 | 'optimizer_state_dict': opt_state_dict 49 | }, path) 50 | 51 | def __call__(self, model_state_dict: dict, opt_state_dict: dict, **kwargs): 52 | """A function that evaluates when to save a checkpoint. 53 | 54 | Args: 55 | model_state_dict: Model's state dict. 56 | opt_state_dict: Optimizer's state dict. 57 | """ 58 | self.epoch += 1 59 | if self.interval and self.epoch % self.interval == 0: 60 | self.save_checkpoint( 61 | model_state_dict, opt_state_dict, 62 | path.join(self.dirpath, 'checkpoints', 63 | f'model_{self.epoch}.pt')) 64 | 65 | @classmethod 66 | def get_default_params(cls): 67 | """Class method to get default params for model_config.""" 68 | return dict(dirpath='.', interval=5) 69 | -------------------------------------------------------------------------------- /scalr/nn/callbacks/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of Tensorboard logging callback.""" 2 | 3 | import os 4 | from os import path 5 | 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from scalr.nn.callbacks import CallbackBase 10 | 11 | 12 | class TensorboardLogger(CallbackBase): 13 | """ 14 | Tensorboard logging of the training process. 15 | 16 | Attributes: 17 | epoch: An interger count of epochs trained. 18 | writer: Object that writes to tensorboard. 19 | """ 20 | 21 | def __init__(self, dirpath: str = '.'): 22 | """Intialize required parameters for tensorboard logging callback. 23 | 24 | Args: 25 | dirpath: Path of directory to store the experiment logs. 
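# ---------------------------------------------------------------------------
# [Editor's note -- not part of the repository] Checkpoints written by
# ModelCheckpoint can be resumed from via the training config; the path below
# is an illustrative assumption.
train_config['resume_from_checkpoint'] = 'results/model/checkpoints/model_10.pt'
# ModelTrainingPipeline then restores both states from the saved dict:
#   model.load_weights(ckpt_path)                                   # 'model_state_dict'
#   opt.load_state_dict(torch.load(ckpt_path)['optimizer_state_dict'])
# ---------------------------------------------------------------------------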
26 | """ 27 | self.writer = SummaryWriter(path.join(dirpath, 'logs')) 28 | self.epoch = 0 29 | 30 | def __call__(self, train_loss: float, train_acc: float, val_loss: float, 31 | val_acc: float, **kwargs): 32 | """Logs the train_loss, val_loss, train_accuracy, val_accuracy for each epoch.""" 33 | self.epoch += 1 34 | self.writer.add_scalars('Loss', { 35 | 'train': train_loss, 36 | 'val': val_loss 37 | }, self.epoch) 38 | self.writer.add_scalars('Accuracy', { 39 | 'train': train_acc, 40 | 'val': val_acc 41 | }, self.epoch) 42 | 43 | @classmethod 44 | def get_default_params(cls): 45 | """Class method to get default params for model_config.""" 46 | return dict(dirpath='.') 47 | -------------------------------------------------------------------------------- /scalr/nn/callbacks/test_early_stopping.py: -------------------------------------------------------------------------------- 1 | """This is a test file for early_stopping.py""" 2 | 3 | from copy import deepcopy 4 | 5 | from scalr.nn.callbacks import EarlyStopping 6 | 7 | 8 | def test_early_stopping(): 9 | """This function tests early stopping of the model.""" 10 | 11 | # Custom-defined validation loss to check early stopping. 12 | val_losses = [5, 2, 3, 2.1, 1.9, 3.0, 2.5, 2.0, 0.7, 0.4] 13 | patience = 3 14 | 15 | # The model should early stop at epoch 8 (val_loss=2.0) based on defined patience. 16 | expected_early_stop_epoch = 8 17 | 18 | # Creating objects for early stopping. 19 | early_stop = EarlyStopping(patience=patience) 20 | 21 | # Iterating over above val_losses to test epoch at which it is early stopping. 22 | observed_epochs = 1 23 | for val_loss in val_losses: 24 | if early_stop.__call__(val_loss=deepcopy(val_loss)): 25 | break 26 | observed_epochs += 1 27 | 28 | assert observed_epochs==expected_early_stop_epoch, f"There is some issue in early stopping."\ 29 | f" Expected epochs({expected_early_stop_epoch}) != observed epoch({observed_epochs}). Please check!" 30 | -------------------------------------------------------------------------------- /scalr/nn/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from ._dataloader import build_dataloader 2 | from ._dataloader import DataLoaderBase 3 | from .simple_dataloader import SimpleDataLoader 4 | from .simple_metadataloader import SimpleMetaDataLoader 5 | -------------------------------------------------------------------------------- /scalr/nn/dataloader/_dataloader.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for dataloader.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | from anndata.experimental import AnnLoader 8 | import torch 9 | from torch import Tensor 10 | 11 | import scalr 12 | from scalr.utils import build_object 13 | 14 | 15 | class DataLoaderBase: 16 | 17 | def __init__( 18 | self, 19 | batch_size: int = 1, 20 | target: Union[str, list[str]] = None, 21 | mappings: dict = None, 22 | ): 23 | """Initilize required parameters for dataloader. 24 | 25 | Args: 26 | batch_size (int, optional): _description_. Defaults to 1. 27 | target ([str, list[str]]): List of target. Defaults to None. 28 | mappings (dict): List of label mappings of each target to. Defaults to None. 29 | """ 30 | self.batch_size = batch_size 31 | self.target = target 32 | self.mappings = mappings 33 | 34 | # Abstract 35 | def collate_fn(self, batch): 36 | """Collate function for dataloader. 
Should be implemented in child classes. 37 | 38 | Given an input anndata of batch_size, the collate function creates inputs and outputs. 39 | It can also be used to perform batch-wise operations. 40 | """ 41 | pass 42 | 43 | def get_targets_ids_from_mappings( 44 | self, adata: Union[AnnData, AnnCollection]) -> list[Tensor]: 45 | """Helper function to generate target ids from label mappings. 46 | 47 | Args: 48 | adata (Union[AnnData, AnnCollection]): Anndata object containing targets in `obs`. 49 | """ 50 | target_ids = [] 51 | if isinstance(self.target, str): 52 | targets = [self.target] 53 | else: 54 | targets = self.target 55 | 56 | for target in targets: 57 | target_mappings = self.mappings[target]['label2id'] 58 | target_ids.append( 59 | torch.as_tensor( 60 | adata.obs[self.target].astype('category').cat. 61 | rename_categories(target_mappings).astype('int64').values)) 62 | 63 | return target_ids 64 | 65 | def __call__(self, adata): 66 | """Returns a Torch DataLoader object.""" 67 | return AnnLoader(adata, 68 | batch_size=self.batch_size, 69 | collate_fn=lambda batch: self.collate_fn(batch)) 70 | 71 | @classmethod 72 | def get_default_params(cls) -> dict: 73 | """Class method to get default params for model_config.""" 74 | return dict() 75 | 76 | 77 | def build_dataloader(dataloader_config: dict, 78 | adata: Union[AnnData, AnnCollection], target: str, 79 | mappings: dict) -> tuple[DataLoaderBase, dict]: 80 | """Builder object to get DataLoader, updated dataloader_config. 81 | 82 | Args: 83 | dataloader_config (dict): Config to build dataloader. 84 | adata (Union[AnnData, AnnCollection]): Data to load. 85 | target (str): Target column in data.obs. 86 | mappings (dict): Mappings of column labels to ids. 87 | """ 88 | if not dataloader_config.get('name'): 89 | raise ValueError('DataLoader name is required!') 90 | 91 | dataloader_config['params'] = dataloader_config.get('params', 92 | dict(batch_size=1)) 93 | dataloader_config['params']['target'] = target 94 | dataloader_config['params']['mappings'] = mappings 95 | 96 | dataloader_object, dataloader_config = build_object(scalr.nn.dataloader, 97 | dataloader_config) 98 | dataloader_config['params'].pop('target') 99 | dataloader_config['params'].pop('mappings') 100 | 101 | return dataloader_object(adata), dataloader_config 102 | -------------------------------------------------------------------------------- /scalr/nn/dataloader/simple_dataloader.py: -------------------------------------------------------------------------------- 1 | """This file is the implementation of simpledataloader.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import torch 8 | from torch import Tensor 9 | from torch.nn.functional import pad 10 | 11 | from scalr.nn.dataloader import DataLoaderBase 12 | 13 | 14 | class SimpleDataLoader(DataLoaderBase): 15 | """Class for simple dataloader. 16 | 17 | Simple DataLoader converts all adata values to inputs, and target columns in metadata 18 | to output labels. 19 | 20 | Returns: 21 | PyTorch DataLoader object with (X: Tensor [batch_size, features], y: Tensor [batch_size, ]). 22 | """ 23 | 24 | def __init__(self, 25 | batch_size: int, 26 | target: str, 27 | mappings: dict, 28 | padding: int = None): 29 | """ 30 | Args: 31 | batch_size (int): Number of samples to be loaded in each batch. 32 | target (str): Corresponding metadata name to be treated as training 33 | objective in classification. Must be present as a column_name in `adata.obs`. 
34 | mappings (dict): Mapping the target name to respective ids. 35 | padding (int): Padding size in case of #features < model input size. 36 | """ 37 | super().__init__(batch_size, target, mappings) 38 | self.padding = padding 39 | 40 | def collate_fn( 41 | self, 42 | adata_batch: Union[AnnData, AnnCollection], 43 | ) -> tuple[Tensor, Tensor]: 44 | """Given an input anndata of batch_size, the collate function creates inputs and outputs. 45 | 46 | Args: 47 | adata_batch (Union[AnnData, AnnCollection]): Anndata view object with batch_size samples. 48 | 49 | Returns: 50 | Tuple(x, y): Input to model, output from data. 51 | """ 52 | 53 | x = adata_batch.X.float() 54 | # Handle the case when observed #features are less than expected #features by the model. 55 | # Features(0s) are padded after actual features in that case to make it consistent for model training. 56 | if self.padding and x.shape[1] < self.padding: 57 | x = pad(x, (0, self.padding - x.shape[1]), 'constant', 0.0) 58 | y = self.get_targets_ids_from_mappings(adata_batch)[0] 59 | 60 | return x, y 61 | 62 | @classmethod 63 | def get_default_params(cls) -> dict: 64 | """Class method to get default params for model_config.""" 65 | return dict(batch_size=1, target=None, mappings=dict()) 66 | -------------------------------------------------------------------------------- /scalr/nn/dataloader/simple_metadataloader.py: -------------------------------------------------------------------------------- 1 | """This file is the implementation of simple metadataloader.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import numpy as np 8 | from sklearn.preprocessing import OneHotEncoder 9 | import torch 10 | 11 | from scalr.nn.dataloader import SimpleDataLoader 12 | 13 | 14 | class SimpleMetaDataLoader(SimpleDataLoader): 15 | """Class for simple metadataloader. 16 | 17 | Simple MetaDataLoader converts all adata values to inputs, concat specified metadata columns as onehotencoded vector 18 | to feature data and map target columns in metadata to output labels. 19 | 20 | Returns: 21 | PyTorch DataLoader object with (X: Tensor [batch_size, features], y: Tensor [batch_size, ]). 22 | """ 23 | 24 | def __init__(self, 25 | batch_size: int, 26 | target: str, 27 | mappings: dict, 28 | metadata_col: list[str], 29 | padding: int = None): 30 | """ 31 | Args: 32 | batch_size (int): Number of samples to be loaded in each batch. 33 | target (str): Corresponding metadata name to be treated as training 34 | objective in classification. Must be present as a column_name in `adata.obs`. 35 | mappings (dict): Mapping the target name to respective ids. 36 | metadata_col (list): List of metadata columns to be onehotencoded. 37 | padding (int): Padding size incase of #features < model input size. 38 | """ 39 | super().__init__(batch_size=batch_size, 40 | target=target, 41 | mappings=mappings, 42 | padding=padding) 43 | self.mappings = mappings 44 | self.metadata_col = metadata_col 45 | 46 | # Generating OneHotEncoder object for specified metadata_col. 47 | self.metadata_onehotencoder = {} 48 | for col in self.metadata_col: 49 | ohe = OneHotEncoder(handle_unknown='ignore') 50 | ohe.fit(np.array(sorted(mappings[col]['id2label'])).reshape(-1, 1)) 51 | self.metadata_onehotencoder[col] = ohe 52 | 53 | def collate_fn( 54 | self, 55 | adata_batch: Union[AnnData, AnnCollection], 56 | ): 57 | """Given an input anndata of batch_size, the collate function creates inputs and outputs. 
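# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] What the padding branch in
# SimpleDataLoader.collate_fn does, on a toy batch with illustrative sizes.
import torch
from torch.nn.functional import pad

x = torch.ones(2, 13)                      # 2 samples, 13 observed features
x = pad(x, (0, 20 - x.shape[1]), 'constant', 0.0)
x.shape                                    # -> torch.Size([2, 20]); last 7 columns are zeros
# ---------------------------------------------------------------------------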
58 | 59 | Args: 60 | adata_batch (Union[AnnData, AnnCollection]): Anndata view object with batch_size samples. 61 | 62 | Returns: 63 | Tuple(x, y): Input to model, output from data. 64 | """ 65 | 66 | # Getting x & y 67 | x, y = super().collate_fn(adata_batch) 68 | 69 | # One hot encoding requested metadata columns and adding to features data. 70 | for col in self.metadata_col: 71 | x = torch.cat( 72 | (x, 73 | torch.as_tensor(self.metadata_onehotencoder[col].transform( 74 | adata_batch.obs[col].values.reshape(-1, 1)).A, 75 | dtype=torch.float32)), 76 | dim=1) 77 | return x, y 78 | 79 | @classmethod 80 | def get_default_params(cls) -> dict: 81 | """Class method to get default params for model_config.""" 82 | return dict(batch_size=1, target=None, mappings=dict()) 83 | -------------------------------------------------------------------------------- /scalr/nn/dataloader/test_simple_dataloader.py: -------------------------------------------------------------------------------- 1 | """This is a test file for simpledataloader.""" 2 | 3 | from scalr.nn.dataloader import build_dataloader 4 | from scalr.utils import generate_dummy_anndata 5 | 6 | 7 | def test_metadataloader(): 8 | """This function tests features shape returned by simpledataloader for the below 2 cases. 9 | 1. #features are consistent with feature_subsetsize. No padding is required. 10 | 2. #features are less than feature_subsetsize. This case needs padding. 11 | """ 12 | 13 | # Generating dummy anndata. 14 | n_samples = 30 15 | n_features = 13 16 | adata = generate_dummy_anndata(n_samples=n_samples, n_features=n_features) 17 | 18 | # Generating mappings for anndata obs columns. 19 | mappings = {} 20 | for column_name in adata.obs.columns: 21 | mappings[column_name] = {} 22 | 23 | id2label = [] 24 | id2label += adata.obs[column_name].astype( 25 | 'category').cat.categories.tolist() 26 | 27 | label2id = {id2label[i]: i for i in range(len(id2label))} 28 | mappings[column_name]['id2label'] = id2label 29 | mappings[column_name]['label2id'] = label2id 30 | 31 | # Test case 1 32 | # Expected features shape after dataloading is (batch_size, 13). 33 | # So no padding is required as adata n_features=13. But we can pass 34 | # `padding=feature_subsetsize` in dataloader_config. 35 | 36 | ## Defining required parameters for simpledataloader. 37 | feature_subsetsize = 13 38 | dataloader_config = { 39 | 'name': 'SimpleDataLoader', 40 | 'params': { 41 | 'batch_size': 10, 42 | 'padding': feature_subsetsize, 43 | } 44 | } 45 | dataloader, _ = build_dataloader(dataloader_config=dataloader_config, 46 | adata=adata, 47 | target='celltype', 48 | mappings=mappings) 49 | 50 | ## Comparing expecting features shape after using metadatloader. 51 | for feature, _ in dataloader: 52 | assert feature.shape[ 53 | 1] == feature_subsetsize, f"There is some issue in simpledataloader."\ 54 | f" Expected #features({n_features}) != Observed #features({feature.shape[1]}). Please check!" 55 | # Breaking, as checking only the first batch is enough. 56 | break 57 | 58 | # Test case 2 59 | # Expected features shape after dataloading is (batch_size, 20). 60 | # So padding is required as adata n_features=13. Hence 7 columns should be padded in dataloader. 61 | 62 | ## Defining required parameters for simpledataloader. 
63 | feature_subsetsize = 20 64 | dataloader_config = { 65 | 'name': 'SimpleDataLoader', 66 | 'params': { 67 | 'batch_size': 10, 68 | 'padding': feature_subsetsize, 69 | } 70 | } 71 | dataloader, _ = build_dataloader(dataloader_config=dataloader_config, 72 | adata=adata, 73 | target='celltype', 74 | mappings=mappings) 75 | 76 | ## Comparing expected features shape after using metadatloader. 77 | for feature, _ in dataloader: 78 | assert feature.shape[ 79 | 1] == feature_subsetsize, f"There is some issue in simpledataloader."\ 80 | f" Expected #features({feature_subsetsize}) != Observed #features({feature.shape[1]}). Please check!" 81 | # Breaking, as checking only the first batch is enough. 82 | break 83 | -------------------------------------------------------------------------------- /scalr/nn/dataloader/test_simple_metadataloader.py: -------------------------------------------------------------------------------- 1 | '''This is a test file for simplemetadataloader.''' 2 | 3 | import anndata 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from scalr.nn.dataloader import build_dataloader 8 | from scalr.utils import generate_dummy_anndata 9 | 10 | 11 | def test_metadataloader(): 12 | 13 | # Generating dummy anndata. 14 | adata = generate_dummy_anndata(n_samples=15, n_features=7) 15 | 16 | # Generating mappings for anndata obs columns. 17 | mappings = {} 18 | for column_name in adata.obs.columns: 19 | mappings[column_name] = {} 20 | 21 | id2label = [] 22 | id2label += adata.obs[column_name].astype( 23 | 'category').cat.categories.tolist() 24 | 25 | label2id = {id2label[i]: i for i in range(len(id2label))} 26 | mappings[column_name]['id2label'] = id2label 27 | mappings[column_name]['label2id'] = label2id 28 | 29 | # Defining required parameters for metadataloader. 30 | metadata_col = ['batch', 'env'] 31 | dataloader_config = { 32 | 'name': 'SimpleMetaDataLoader', 33 | 'params': { 34 | 'batch_size': 10, 35 | 'metadata_col': metadata_col 36 | } 37 | } 38 | dataloader, _ = build_dataloader(dataloader_config=dataloader_config, 39 | adata=adata, 40 | target='celltype', 41 | mappings=mappings) 42 | 43 | # Comparing expecting features shape after using metadatloader. 44 | for feature, _ in dataloader: 45 | assert feature.shape[1] == len( 46 | adata.var_names) + adata.obs[metadata_col].nunique().sum() 47 | # Breaking, as checking only first batch is enough. 
48 | break 49 | -------------------------------------------------------------------------------- /scalr/nn/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from ._loss import build_loss_fn 2 | from ._loss import CustomLossBase 3 | -------------------------------------------------------------------------------- /scalr/nn/loss/_loss.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for loss functions.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import torch 8 | from torch import nn 9 | 10 | import scalr 11 | 12 | 13 | class CustomLossBase(nn.Module): 14 | """Base class to implement custom loss functions.""" 15 | 16 | def __init__(self): 17 | super().__init__() 18 | self.criterion = None 19 | 20 | def forward(self, out, preds): 21 | """Returns loss betwen outputs and predictions.""" 22 | return self.criterion(out, preds) 23 | 24 | 25 | def build_loss_fn(loss_config): 26 | """Builder object to get Loss function, updated loss_config.""" 27 | name = loss_config.get('name') 28 | if not name: 29 | raise ValueError('Loss function not provided') 30 | 31 | params = loss_config.get('params', dict()) 32 | 33 | # TODO: Add provision for custom loss object 34 | loss_fn = getattr(torch.nn, name)(**params) 35 | return loss_fn, loss_config 36 | -------------------------------------------------------------------------------- /scalr/nn/model/__init__.py: -------------------------------------------------------------------------------- 1 | from ._model import build_model 2 | from ._model import ModelBase 3 | from .sequential_model import SequentialModel 4 | from .shap_model import CustomShapModel 5 | -------------------------------------------------------------------------------- /scalr/nn/model/_model.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for the model.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import torch 8 | from torch import nn 9 | from torch import Tensor 10 | from torch.utils.data import DataLoader 11 | 12 | import scalr 13 | from scalr.utils import build_object 14 | 15 | 16 | class ModelBase(nn.Module): 17 | """Class for the model. 18 | 19 | Contains different methods to make a forward() call, load, save weights 20 | and predict the data provided. 
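# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] Any torch.nn loss can be
# selected by name in the loss config; a minimal illustrative call:
from scalr.nn.loss import build_loss_fn

loss_fn, loss_config = build_loss_fn({'name': 'CrossEntropyLoss'})
# equivalent to torch.nn.CrossEntropyLoss(); extra kwargs go under 'params'.
# ---------------------------------------------------------------------------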
21 |     """
22 | 
23 |     def __init__(self):
24 |         super().__init__()
25 | 
26 |     def forward(self, x: Tensor) -> Tensor:
27 |         """A function for forward pass of the model to generate outputs."""
28 |         pass
29 | 
30 |     def load_weights(self, model_weights_path: str):
31 |         """A function to initialize model weights from previous weights."""
32 |         self.load_state_dict(torch.load(model_weights_path)['model_state_dict'])
33 | 
34 |     def save_weights(self, model_weights_path: str):
35 |         """A function to save model weights at the path."""
36 |         torch.save({'model_state_dict': self.state_dict()}, model_weights_path)
37 | 
38 |     def get_predictions(self, dl: DataLoader, device: str = 'cpu'):
39 |         """A function to get predictions from the dataloader."""
40 |         pass
41 | 
42 |     @classmethod
43 |     def get_default_params(cls) -> dict:
44 |         """Class method to get default params for model_config."""
45 |         return dict()
46 | 
47 | 
48 | def build_model(model_config: dict) -> tuple[nn.Module, dict]:
49 |     """Builder object to get Model, updated model_config."""
50 |     return build_object(scalr.nn.model, model_config)
51 | 
--------------------------------------------------------------------------------
/scalr/nn/model/sequential_model.py:
--------------------------------------------------------------------------------
1 | """This file is an implementation of a sequential model."""
2 | 
3 | from typing import Tuple
4 | 
5 | import torch
6 | from torch import nn
7 | from torch import Tensor
8 | from torch.utils.data import DataLoader
9 | 
10 | from scalr.nn.model import ModelBase
11 | 
12 | 
13 | class SequentialModel(ModelBase):
14 |     """Class for Deep Neural Network model with linear layers."""
15 | 
16 |     def __init__(self,
17 |                  layers: list[int],
18 |                  dropout: float = 0,
19 |                  activation: str = 'ReLU',
20 |                  weights_init_zero: bool = False):
21 |         """Initialize required parameters for the linear model.
22 | 
23 |         Args:
24 |             layers (list[int]): List of layers' feature size going from
25 |                 input_features to output_features.
26 |             dropout (float, optional): Dropout after each layer.
27 |                 Floating point value [0,1).
28 |                 Defaults to 0.
29 |             activation (str, optional): Activation function class after each layer.
30 |                 Defaults to 'ReLU'.
31 |             weights_init_zero (bool, optional): Whether to initialize weights of the model to zero.
32 |                 Defaults to False.
33 |         """
34 |         super().__init__()
35 | 
36 |         try:
37 |             activation = getattr(nn, activation)()
38 |         except:
39 |             raise ValueError(
40 |                 f'{activation} is not a valid activation function name in torch.nn'
41 |             )
42 | 
43 |         dropout = nn.Dropout(dropout)
44 | 
45 |         self.layers = nn.ModuleList()
46 |         n = len(layers)
47 |         for i in range(n - 2):
48 |             self.layers.append(nn.Linear(layers[i], layers[i + 1]))
49 |             self.layers.append(activation)
50 |             self.layers.append(dropout)
51 |         self.out_layer = nn.Linear(layers[n - 2], layers[n - 1])
52 | 
53 |         self.weights_init_zero = weights_init_zero
54 |         if weights_init_zero:
55 |             self.make_weights_zero()
56 | 
57 |     def make_weights_zero(self):
58 |         """A function to initialize linear layer weights to 0."""
59 |         for layer in [l for l in self.layers if isinstance(l, nn.Linear)]:
60 |             torch.nn.init.constant_(layer.weight, 0.0)
61 |         torch.nn.init.constant_(self.out_layer.weight, 0.0)
62 | 
63 |     def forward(self, x: Tensor) -> Tensor:
64 |         """Pass input through the network.
65 | 
66 |         Args:
67 |             x: Tensor, shape [batch_size, layers[0]].
68 | 
69 |         Returns:
70 |             Output dict containing batched layers[-1]-dimensional vectors in ['cls_output'] key.
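# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] An illustrative instantiation:
# 5000 input genes, one hidden layer of 128 units, 10 output classes.
model = SequentialModel(layers=[5000, 128, 10], dropout=0.2)
# -> hidden block Linear(5000, 128) + ReLU + Dropout(0.2); out_layer Linear(128, 10)
out = model(torch.rand(4, 5000))
out['cls_output'].shape                    # -> torch.Size([4, 10])
# ---------------------------------------------------------------------------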
71 | """ 72 | output = {} 73 | 74 | for i, layer in enumerate(self.layers): 75 | x = layer(x) 76 | output[f'layer{i}_output'] = x 77 | 78 | output['cls_output'] = self.out_layer(x) 79 | return output 80 | 81 | def get_predictions( 82 | self, 83 | dl: DataLoader, 84 | device: str = 'cpu' 85 | ) -> Tuple[list[int], list[int], list[list[int]]]: 86 | """A function to get predictions from a model, from the dataloader. 87 | 88 | Args: 89 | dl (DataLoader): DataLoader object containing samples. 90 | device (str, optional): Device to run the model on. Defaults to 'cpu'. 91 | 92 | Returns: 93 | True labels, Predicted labels, Predicted probabilities of all samples 94 | in the dataloader. 95 | """ 96 | self.eval() 97 | test_labels, pred_labels, pred_probabilities = [], [], [] 98 | 99 | for batch in dl: 100 | with torch.no_grad(): 101 | x, y = batch[0].to(device), batch[1].to(device) 102 | out = self(x)['cls_output'] 103 | 104 | test_labels += y.tolist() 105 | pred_labels += torch.argmax(out, dim=1).tolist() 106 | pred_probabilities += out.tolist() 107 | 108 | return test_labels, pred_labels, pred_probabilities 109 | 110 | @classmethod 111 | def get_default_params(cls): 112 | """Class method to get default params for model_config.""" 113 | return dict(layers=None, 114 | dropout=0, 115 | activation='ReLU', 116 | weights_init_zero=False) 117 | -------------------------------------------------------------------------------- /scalr/nn/model/shap_model.py: -------------------------------------------------------------------------------- 1 | """This file is an implementation of the custom SHAP model.""" 2 | 3 | from torch import nn 4 | 5 | 6 | class CustomShapModel(nn.Module): 7 | """Class for a custom model for SHAP.""" 8 | 9 | def __init__(self, model, key='cls_output'): 10 | """Initialize required parameters for SHAP model. 11 | 12 | Args: 13 | model: Trained model used for SHAP calculation. 14 | key: key from model output dict. 15 | """ 16 | super().__init__() 17 | self.model = model 18 | self.key = key 19 | 20 | def forward(self, x): 21 | """Pass input through the model and return output. 22 | 23 | Args: 24 | x: Tensor. 25 | """ 26 | output = self.model(x) 27 | 28 | if isinstance(output, dict): 29 | output = output[self.key] 30 | 31 | return output 32 | -------------------------------------------------------------------------------- /scalr/nn/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from ._trainer import TrainerBase 2 | from .simple_model_trainer import SimpleModelTrainer 3 | -------------------------------------------------------------------------------- /scalr/nn/trainer/_trainer.py: -------------------------------------------------------------------------------- 1 | """This file is a base class for a model trainer.""" 2 | 3 | from copy import deepcopy 4 | import os 5 | from os import path 6 | from time import time 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import Module 11 | from torch.optim import Optimizer 12 | from torch.utils.data import DataLoader 13 | 14 | from scalr.nn.callbacks import CallbackExecutor 15 | from scalr.utils import EventLogger 16 | 17 | 18 | class TrainerBase: 19 | """ Class for a model trainer. It trains and validates a model.""" 20 | 21 | def __init__(self, 22 | model: Module, 23 | opt: Optimizer, 24 | loss_fn: Module, 25 | callbacks: CallbackExecutor, 26 | device: str = 'cpu'): 27 | """Initialize required parameters for a model trainer. 28 | 29 | Args: 30 | model (Module): Model to train. 
31 | opt (Optimizer): Optimizer used for learning. 32 | loss_fn (Module): Loss function used for training. 33 | callbacks (CallbackExecutor): Callback executor object to carry out callbacks. 34 | device (str, optional): Device to train the data on (cuda/cpu). Defaults to 'cpu'. 35 | """ 36 | self.event_logger = EventLogger('ModelTrainer') 37 | 38 | self.model = model 39 | self.opt = opt 40 | self.loss_fn = loss_fn 41 | self.callbacks = callbacks 42 | self.device = device 43 | 44 | def train_one_epoch(self, dl: DataLoader) -> tuple[float, float]: 45 | """This function trains the model for one epoch. 46 | 47 | Args: 48 | dl: Training dataloader. 49 | 50 | Returns: 51 | Train Loss, Train Accuracy. 52 | """ 53 | self.model.train() 54 | total_loss = 0 55 | hits = 0 56 | total_samples = 0 57 | for batch in dl: 58 | x, y = [example.to(self.device) for example in batch[:-1] 59 | ], batch[-1].to(self.device) 60 | 61 | out = self.model(*x)['cls_output'] 62 | loss = self.loss_fn(out, y) 63 | 64 | #training 65 | self.opt.zero_grad() 66 | loss.backward() 67 | self.opt.step() 68 | 69 | #logging 70 | total_loss += loss.item() * x[0].size(0) 71 | total_samples += x[0].size(0) 72 | hits += (torch.argmax(out, dim=1) == y).sum().item() 73 | 74 | total_loss /= total_samples 75 | accuracy = hits / total_samples 76 | return total_loss, accuracy 77 | 78 | def validation(self, dl: DataLoader) -> tuple[float, float]: 79 | """This function performs validation of the data. 80 | 81 | Args: 82 | dl: Validation dataloader. 83 | 84 | Returns: 85 | Validation Loss, Validation Accuracy. 86 | """ 87 | self.model.eval() 88 | total_loss = 0 89 | hits = 0 90 | total_samples = 0 91 | for batch in dl: 92 | with torch.no_grad(): 93 | x, y = [example.to(self.device) for example in batch[:-1] 94 | ], batch[-1].to(self.device) 95 | out = self.model(*x)['cls_output'] 96 | loss = self.loss_fn(out, y) 97 | 98 | #logging 99 | hits += (torch.argmax(out, dim=1) == y).sum().item() 100 | total_loss += loss.item() * x[0].size(0) 101 | total_samples += x[0].size(0) 102 | 103 | total_loss /= total_samples 104 | accuracy = hits / total_samples 105 | 106 | return total_loss, accuracy 107 | 108 | def train(self, epochs: int, train_dl: DataLoader, val_dl: DataLoader): 109 | """This function trains the model, and executes callbacks. 110 | 111 | Args: 112 | epochs: Max number of epochs to train model on. 113 | train_dl: Training dataloader. 114 | val_dl: Validation dataloader. 
115 | """ 116 | best_val_acc = 0 117 | best_model = deepcopy(self.model) 118 | 119 | for epoch in range(epochs): 120 | ep_start = time() 121 | self.event_logger.info(f'Epoch {epoch+1}:') 122 | train_loss, train_acc = self.train_one_epoch(train_dl) 123 | self.event_logger.info( 124 | f'Training Loss: {train_loss} || Training Accuracy: {train_acc}' 125 | ) 126 | val_loss, val_acc = self.validation(val_dl) 127 | self.event_logger.info( 128 | f'Validation Loss: {val_loss} || Validation Accuracy: {val_acc}' 129 | ) 130 | ep_end = time() 131 | self.event_logger.info(f'Time: {ep_end-ep_start}\n') 132 | 133 | if val_acc > best_val_acc: 134 | best_val_acc = val_acc 135 | best_model = deepcopy(self.model) 136 | 137 | if self.callbacks.execute(model_state_dict=self.model.state_dict(), 138 | opt_state_dict=self.opt.state_dict(), 139 | train_loss=train_loss, 140 | train_acc=train_acc, 141 | val_loss=val_loss, 142 | val_acc=val_acc): 143 | break 144 | 145 | return best_model 146 | -------------------------------------------------------------------------------- /scalr/nn/trainer/simple_model_trainer.py: -------------------------------------------------------------------------------- 1 | """This file is a wrapper for Model trainer base class.""" 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module 6 | from torch.optim import Optimizer 7 | from torch.utils.data import DataLoader 8 | 9 | from scalr.nn.callbacks import CallbackExecutor 10 | from scalr.nn.trainer import TrainerBase 11 | 12 | 13 | class SimpleModelTrainer(TrainerBase): 14 | """Class for Simple model trainer. 15 | 16 | It works with dataloaders which contain all input tensors in line 17 | with model input, and the last tensor as target to train the model. 18 | """ 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | -------------------------------------------------------------------------------- /scalr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import generate_dummy_anndata 2 | from .data_utils import generate_dummy_dge_anndata 3 | from .file_utils import load_full_data_from_config 4 | from .file_utils import load_test_data_from_config 5 | from .file_utils import load_train_val_data_from_config 6 | from .file_utils import read_data 7 | from .file_utils import write_chunkwise_data 8 | from .file_utils import write_data 9 | from .logger import EventLogger 10 | from .logger import FlowLogger 11 | from .misc_utils import build_object 12 | from .misc_utils import overwrite_default 13 | from .misc_utils import set_seed 14 | -------------------------------------------------------------------------------- /scalr/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """This file contains functions related to data utility.""" 2 | 3 | from typing import Union 4 | 5 | from anndata import AnnData 6 | from anndata.experimental import AnnCollection 7 | import numpy as np 8 | import pandas as pd 9 | from scipy.sparse import csr_matrix 10 | from sklearn.preprocessing import OneHotEncoder 11 | import torch 12 | 13 | 14 | def get_one_hot_matrix(data: np.array): 15 | """This function returns a one-hot matrix of given labels. 16 | 17 | Args: 18 | data: Categorical data of dim 1D or 2D array. 19 | 20 | Returns: 21 | one-hot matrix. 
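# ---------------------------------------------------------------------------
# [Editor's example -- not part of the repository] get_one_hot_matrix on a small
# label array; columns follow the sorted categories found by OneHotEncoder.
import numpy as np

get_one_hot_matrix(np.array(['B', 'T', 'B']))
# -> array([[1., 0.],
#           [0., 1.],
#           [1., 0.]])
# ---------------------------------------------------------------------------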
22 | """ 23 | if data.ndim == 1: 24 | data = data.reshape(-1, 1) 25 | ohe = OneHotEncoder().fit(data) 26 | one_hot_matrix = ohe.transform(data).toarray() 27 | 28 | return one_hot_matrix 29 | 30 | 31 | def get_random_samples( 32 | data: Union[AnnData, AnnCollection], 33 | n_random_samples: int, 34 | ) -> torch.tensor: 35 | """This function returns random N samples from given data. 36 | 37 | Args: 38 | data: AnnData or AnnCollection object. 39 | n_random_samples: number of random samples to extract from the data. 40 | 41 | Returns: 42 | Chosen random samples tensor. 43 | """ 44 | 45 | random_indices = np.random.randint(0, data.shape[0], n_random_samples) 46 | random_background_data = data[random_indices].X 47 | 48 | if not isinstance(random_background_data, np.ndarray): 49 | random_background_data = random_background_data.A 50 | 51 | random_background_data = torch.as_tensor(random_background_data, 52 | dtype=torch.float32) 53 | 54 | return random_background_data 55 | 56 | 57 | def generate_dummy_anndata(n_samples, n_features, target_name='celltype'): 58 | """This function returns anndata object of shape (n_samples, n_features). 59 | 60 | It generates random values for target, batch & env from below mentioned choices. 61 | If you require more columns, you can add them in the below adata.obs without editing 62 | already existing columns. 63 | 64 | Args: 65 | n_samples: Number of samples in anndata. 66 | n_features: Number of features in anndata. 67 | target_name: Any preferred target name. Default is `celltype`. 68 | 69 | Returns: 70 | Anndata object. 71 | """ 72 | 73 | # Setting seed for reproducibility. 74 | np.random.seed(0) 75 | 76 | # Creating anndata object. 77 | adata = AnnData(X=np.random.rand(n_samples, n_features)) 78 | adata.obs = pd.DataFrame.from_dict({ 79 | target_name: np.random.choice(['B', 'C', 'DC', 'T'], size=n_samples), 80 | 'batch': np.random.choice(['batch1', 'batch2'], size=n_samples), 81 | 'env': np.random.choice(['env1', 'env2', 'env3'], size=n_samples) 82 | }) 83 | adata.obs.index = adata.obs.index.astype(str) 84 | 85 | return adata 86 | 87 | 88 | def generate_dummy_dge_anndata(n_donors: int = 5, 89 | cell_type_list: list[str] = [ 90 | 'B_cell', 'T_cell', 'DC' 91 | ], 92 | cell_replicate: int = 2, 93 | n_vars: int = 10) -> AnnData: 94 | """This function returns anndata object for DGE analysis 95 | with shape (n_donors*len(cell_type_list)*cell_replicate, n_vars). 96 | 97 | It generates obs with random donors with a fixed clinical condition (disease_x or normal). 98 | Includes all the cell types in `cell_type_list` with number of `cell_replicate` for each donor. 99 | It generates a csr(Compressed Sparse Row) matrix with random gene expression values. 100 | It generates var with random gene name as `var.index` of length `n_vars`. 101 | 102 | Args: 103 | n_donors: Number of donors or subjects in `anndata.obs`. 104 | cell_type_list: List of different cell types to include. 105 | cell_replicate: Number of cell replicates per cell type. 106 | n_vars: Number of genes to include in `anndata.var`. 107 | 108 | Returns: 109 | Anndata object. 110 | """ 111 | 112 | # Setting seed for reproducibility. 
113 |     np.random.seed(0)
114 | 
115 |     donor_list = [f'D_{i}' for i in range(1, n_donors + 1)]
116 |     condition_array = np.random.choice(['disease_x', 'normal'],
117 |                                        size=n_donors,
118 |                                        replace=True)
119 | 
120 |     # Creating obs.
121 |     obs_data = []
122 |     for donor, condition in zip(donor_list, condition_array):
123 |         obs_data.extend([{
124 |             'donor_id': donor,
125 |             'cell_type': cell_type,
126 |             'disease': condition
127 |         } for i in range(cell_replicate) for cell_type in cell_type_list])
128 |     obs = pd.DataFrame(obs_data)
129 |     n_obs = obs.shape[0]
130 |     obs.index = obs.index.astype(str)
131 | 
132 |     # Random gene expression matrix.
133 |     X = csr_matrix(np.random.rand(n_obs, n_vars))
134 | 
135 |     # Creating var.
136 |     var = pd.DataFrame({
137 |         'gene_id': [f'gid_{i}' for i in range(1, n_vars + 1)],
138 |         'gene_name': [f'gene_{i}' for i in range(1, n_vars + 1)]
139 |     }).set_index('gene_name')
140 |     var.index = var.index.astype(str)
141 | 
142 |     # Creating AnnData object.
143 |     adata = AnnData(X=X, obs=obs, var=var)
144 | 
145 |     return adata
146 | 
--------------------------------------------------------------------------------
/scalr/utils/logger.py:
--------------------------------------------------------------------------------
1 | """This file contains an implementation of the loggers used in the pipeline."""
2 | 
3 | import logging
4 | 
5 | 
6 | class FlowLogger(logging.Logger):
7 |     """Class for flow logger.
8 | 
9 |     It logs a high-level overview of pipeline execution to the terminal.
10 |     """
11 |     level = logging.NOTSET
12 | 
13 |     def __init__(self, name, level=None):
14 |         if level:
15 |             FlowLogger.level = level
16 | 
17 |         if not FlowLogger.level:
18 |             FlowLogger.level = logging.INFO
19 | 
20 |         super().__init__(name, FlowLogger.level)
21 | 
22 |         formatter = logging.Formatter(
23 |             '%(asctime)s - %(name)s - %(levelname)s : %(message)s')
24 | 
25 |         handler = logging.StreamHandler()
26 |         handler.setLevel(FlowLogger.level)
27 |         handler.setFormatter(formatter)
28 |         self.addHandler(handler)
29 | 
30 | 
31 | class EventLogger(logging.Logger):
32 |     """Class for event logger. It writes detailed, file-level logs during pipeline execution.
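
    Example (illustrative usage; the log path is a placeholder):
        logger = EventLogger('DataIngestion', filepath='logs/data_ingestion.log')
        logger.heading1('Data ingestion')
        logger.info('Reading full data...')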
33 | """ 34 | level = logging.NOTSET 35 | filepath = None 36 | 37 | def __init__(self, name, level=None, filepath=None, stdout=False): 38 | """Initialize required parameters for event logger.""" 39 | if level: 40 | EventLogger.level = level 41 | 42 | if not EventLogger.level: 43 | EventLogger.level = logging.INFO 44 | 45 | super().__init__(name, EventLogger.level) 46 | 47 | if filepath: 48 | EventLogger.filepath = filepath 49 | 50 | if not EventLogger.filepath: 51 | handler = logging.NullHandler() 52 | else: 53 | handler = logging.FileHandler(EventLogger.filepath) 54 | 55 | formatter = logging.Formatter('%(message)s') 56 | 57 | handler.setLevel(EventLogger.level) 58 | handler.setFormatter(formatter) 59 | self.addHandler(handler) 60 | 61 | # If user wants to print logs to stdout 62 | if stdout: 63 | handler = logging.StreamHandler() 64 | handler.setLevel(EventLogger.level) 65 | handler.setFormatter(formatter) 66 | self.addHandler(handler) 67 | 68 | def heading(self, msg, prefix, suffix, count): 69 | """A function to configure setting for heading.""" 70 | self.info(f"\n{prefix*count} {msg} {suffix*count}\n") 71 | 72 | def heading1(self, msg): 73 | """A function to configure setting for heading 1.""" 74 | self.heading(msg, "<", ">", 10) 75 | 76 | def heading2(self, msg): 77 | """A function to configure setting for heading 2.""" 78 | self.heading(msg, "-", "-", 5) 79 | -------------------------------------------------------------------------------- /scalr/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | """This file contains functions related to miscellaneous utilities.""" 2 | 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def set_seed(seed: int): 11 | """A function to set seed for reproducibility.""" 12 | os.environ['PYTHONHASHSEED'] = str(seed) 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | 19 | def overwrite_default(user_config: dict, default_config: dict) -> dict: 20 | """The function recursively overwrites information from user_config 21 | onto the default_config. 22 | """ 23 | 24 | for key in user_config.keys(): 25 | if key not in default_config.keys() or not isinstance( 26 | user_config[key], dict): 27 | default_config[key] = user_config[key] 28 | else: 29 | default_config[key] = overwrite_default(user_config[key], 30 | default_config[key]) 31 | 32 | return default_config 33 | 34 | 35 | def build_object(module, config: dict): 36 | """A builder function to build an object from its config. 37 | 38 | Args: 39 | module: Module containing the class. 40 | config: Contains the name of the class and params to initialize the object. 41 | 42 | Returns: Object, updated config. 
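
    Example (illustrative sketch; parameter values mirror the tutorial configs, and it is
    assumed that `SequentialModel` in `scalr.nn.model` can be built from these params alone):
        import scalr.nn.model as model_module
        model, final_config = build_object(
            model_module, {'name': 'SequentialModel', 'params': {'layers': [5000, 10]}})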
43 |     """
44 |     name = config.get('name')
45 |     if not name:
46 |         raise ValueError('class name not provided!')
47 | 
48 |     params = config.get('params', dict())
49 |     default_params = getattr(module, name).get_default_params()
50 |     params = overwrite_default(params, default_params)
51 |     final_config = dict(name=name, params=params)
52 | 
53 |     return getattr(module, name)(**params), final_config
54 | 
--------------------------------------------------------------------------------
/scalr/utils/test_file_utils.py:
--------------------------------------------------------------------------------
1 | """This is a test file for file_utils.py"""
2 | 
3 | import os
4 | from os import path
5 | import shutil
6 | 
7 | import numpy as np
8 | 
9 | from scalr.utils import generate_dummy_anndata
10 | from scalr.utils import read_data
11 | from scalr.utils import write_chunkwise_data
12 | from scalr.utils import write_data
13 | 
14 | 
15 | def test_write_chunkwise_data():
16 |     """This function tests the `write_chunkwise_data()`, `write_data()` & `read_data()` functions
17 |     of file_utils."""
18 |     os.makedirs('./tmp', exist_ok=True)
19 | 
20 |     # Generating dummy anndata.
21 |     adata = generate_dummy_anndata(n_samples=25, n_features=5)
22 | 
23 |     # Path to write full data.
24 |     fulldata_path = './tmp/fulldata.h5ad'
25 |     write_data(adata, fulldata_path)
26 | 
27 |     # sample_chunksize to store full data in chunks.
28 |     sample_chunksize = 5
29 | 
30 |     # Path to store chunked data.
31 |     dirpath = './tmp/chunked_data/'
32 | 
33 |     # Writing full data in chunks.
34 |     full_data = read_data(fulldata_path)
35 |     write_chunkwise_data(full_data,
36 |                          sample_chunksize=sample_chunksize,
37 |                          dirpath=dirpath)
38 | 
39 |     # Iterating over stored chunked data to assert shape.
40 |     observed_n_chunks = 0
41 |     for i in range(len(os.listdir(dirpath))):
42 |         if os.path.exists(path.join(dirpath, f'{i}.h5ad')):
43 |             chunked_data = read_data(path.join(dirpath, f'{i}.h5ad'),
44 |                                      backed='r')
45 |             assert chunked_data.shape == (
46 |                 sample_chunksize, len(adata.var_names)
47 |             ), f"There is some issue with chunk-{i}. Please check!"
48 |             observed_n_chunks += 1
49 |         else:
50 |             break
51 | 
52 |     # Checking the number of chunks stored.
53 |     expected_n_chunks = np.ceil(adata.shape[0] / sample_chunksize).astype(int)
54 |     assert observed_n_chunks == expected_n_chunks, f"Mismatch between observed_n_chunks ({observed_n_chunks}) and expected_n_chunks ({expected_n_chunks})."
55 | 
56 |     shutil.rmtree('./tmp', ignore_errors=True)
57 | 
--------------------------------------------------------------------------------
/scalr/utils/test_misc_utils.py:
--------------------------------------------------------------------------------
1 | """This is a test file for misc_utils.py"""
2 | 
3 | from copy import deepcopy
4 | 
5 | from scalr.utils import overwrite_default
6 | 
7 | 
8 | def test_overwrite_default():
9 |     """This function tests the `overwrite_default()` function of misc_utils."""
10 | 
11 |     # User config key-values dictionary.
12 |     user_config = {'a': 0, 'b': 1, 'd': 3}
13 | 
14 |     # Default config key-values dictionary.
15 |     default_config = {'a': '5', 'b': 7, 'c': 2}
16 | 
17 |     # Getting updated default config using the overwrite function.
18 |     updated_default_params = overwrite_default(
19 |         user_config=user_config, default_config=deepcopy(default_config))
20 | 
21 |     # Checking that keys absent from `user_config` retain their `default_config`
22 |     # values in `updated_default_params`, while keys present in `user_config`
23 |     # take the values from `user_config`.
24 | for key in updated_default_params.keys(): 25 | if key not in user_config: 26 | assert updated_default_params[key] == default_config[key] 27 | elif key in user_config: 28 | assert updated_default_params[key] == user_config[key] 29 | -------------------------------------------------------------------------------- /tutorials/analysis/differential_gene_expression/dge_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # EXPERIMENT 3 | full_datapath: '/path/to/anndata.h5ad' 4 | #Path to save the results 5 | dirpath: '/path/to/save/the/result' 6 | 7 | dge_type: DgePseudoBulk 8 | # dge_type: DgeLMEM 9 | 10 | psedobulk_params: 11 | celltype_column: 'cell_type' 12 | design_factor: 'disease' 13 | factor_categories: ['COVID-19', 'normal'] 14 | sum_column: 'donor_id' 15 | cell_subsets: ['non-classical monocyte','natural killer cell'] 16 | min_cell_threshold: 1 17 | fold_change: 1.5 18 | p_val: 0.05 19 | save_plot: True 20 | # lmem_params: 21 | # fixed_effect_column: 'disease' 22 | # fixed_effect_factors: ['COVID-19', 'normal'] 23 | # group: 'donor_id' 24 | # celltype_column: 'cell_type' 25 | # cell_subsets: ['non-classical monocyte','natural killer cell'] 26 | # min_cell_threshold: 10 27 | # n_cpu: 6 28 | # gene_batch_size: 1000 29 | # coef_threshold: 0 30 | # p_val: 0.05 31 | # save_plot: True 32 | -------------------------------------------------------------------------------- /tutorials/analysis/differential_gene_expression/dge_lmem_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import os 4 | from os import path 5 | import pickle 6 | import resource 7 | import string 8 | import sys 9 | from typing import Optional, Union, Tuple 10 | import traceback 11 | import warnings 12 | import yaml 13 | 14 | from anndata import AnnData 15 | from anndata import ImplicitModificationWarning 16 | import anndata as ad 17 | from anndata.experimental import AnnCollection 18 | from joblib import Parallel, delayed 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import pandas as pd 22 | from pandas import DataFrame 23 | import scanpy as sc 24 | from scipy.optimize import OptimizeWarning 25 | import statsmodels.api as sm 26 | import statsmodels.formula.api as smf 27 | from statsmodels.stats.multitest import multipletests 28 | from statsmodels.tools.sm_exceptions import HessianInversionWarning, ConvergenceWarning 29 | 30 | from scalr.analysis import DgeLMEM 31 | 32 | 33 | def main(config): 34 | test_data = sc.read_h5ad(config['full_datapath'], backed='r') 35 | dirpath = config['dirpath'] 36 | dge_type = config['dge_type'] 37 | assert (dge_type == 'DgeLMEM') and ('lmem_params' in config), ( 38 | f"Check '{dge_type}' and 'lmem_params' in dge_config file") 39 | 40 | lmem_params = config['lmem_params'] 41 | dge = DgeLMEM(fixed_effect_column=lmem_params['fixed_effect_column'], 42 | fixed_effect_factors=lmem_params['fixed_effect_factors'], 43 | group=lmem_params['group'], 44 | celltype_column=lmem_params.get('celltype_column', None), 45 | cell_subsets=lmem_params.get('cell_subsets', None), 46 | min_cell_threshold=lmem_params.get('min_cell_threshold', 10), 47 | n_cpu=lmem_params.get('n_cpu', 6), 48 | gene_batch_size=lmem_params.get('gene_batch_size', 1000), 49 | coef_threshold=lmem_params.get('coef_threshold', 0), 50 | p_val=lmem_params.get('p_val', 0.05), 51 | y_lim_tuple=lmem_params.get('y_lim_tuple', None), 52 | save_plot=lmem_params.get('save_plot', True), 53 | 
                  stdout=True)
54 | 
55 |     dge.generate_analysis(test_data, dirpath)
56 | 
57 | 
58 | if __name__ == "__main__":
59 |     parser = argparse.ArgumentParser(description='DGE analysis: LMEM method')
60 |     parser.add_argument('--config',
61 |                         '-c',
62 |                         type=str,
63 |                         required=True,
64 |                         help='Path to input dge_config.yaml file')
65 |     argument = parser.parse_args()
66 |     with open(argument.config, 'r') as config_file:
67 |         config = yaml.safe_load(config_file)
68 |     main(config)
69 | 
--------------------------------------------------------------------------------
/tutorials/analysis/differential_gene_expression/dge_pseudobulk_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path
4 | import sys
5 | from typing import Optional, Union, Tuple
6 | import yaml
7 | 
8 | import anndata as ad
9 | from anndata import AnnData
10 | from anndata.experimental import AnnCollection
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import pandas as pd
14 | from pandas import DataFrame
15 | from pydeseq2.dds import DeseqDataSet
16 | from pydeseq2.ds import DeseqStats
17 | import scanpy as sc
18 | 
19 | from scalr.analysis import DgePseudoBulk
20 | 
21 | 
22 | def main(config):
23 |     test_data = sc.read_h5ad(config['full_datapath'], backed='r')
24 |     dirpath = config['dirpath']
25 |     dge_type = config['dge_type']
26 |     assert (dge_type == 'DgePseudoBulk') and ('psedobulk_params' in config), (
27 |         f"Check '{dge_type}' and 'psedobulk_params' in dge_config file")
28 | 
29 |     psedobulk_params = config['psedobulk_params']
30 |     dge = DgePseudoBulk(celltype_column=psedobulk_params.get('celltype_column'),
31 |                         design_factor=psedobulk_params['design_factor'],
32 |                         factor_categories=psedobulk_params['factor_categories'],
33 |                         sum_column=psedobulk_params['sum_column'],
34 |                         cell_subsets=psedobulk_params.get('cell_subsets', None),
35 |                         min_cell_threshold=psedobulk_params.get(
36 |                             'min_cell_threshold', 10),
37 |                         fold_change=psedobulk_params.get('fold_change', 1.5),
38 |                         p_val=psedobulk_params.get('p_val', 0.05),
39 |                         y_lim_tuple=psedobulk_params.get('y_lim_tuple', None),
40 |                         save_plot=psedobulk_params.get('save_plot', True),
41 |                         stdout=True)
42 | 
43 |     dge.generate_analysis(test_data, dirpath)
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     parser = argparse.ArgumentParser(
48 |         description='DGE analysis: pseudobulk method')
49 |     parser.add_argument('--config',
50 |                         '-c',
51 |                         type=str,
52 |                         required=True,
53 |                         help='Path to input dge_config.yaml file')
54 |     argument = parser.parse_args()
55 |     with open(argument.config, 'r') as config_file:
56 |         config = yaml.safe_load(config_file)
57 |     main(config)
58 | 
--------------------------------------------------------------------------------
/tutorials/analysis/differential_gene_expression/tutorial_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/infocusp/scaLR/b97553bdc1f02d596d5b7b7ad21c622a304a2793/tutorials/analysis/differential_gene_expression/tutorial_config.png
--------------------------------------------------------------------------------
/tutorials/analysis/gene_recall_curve/multi_model_gene_recall_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/infocusp/scaLR/b97553bdc1f02d596d5b7b7ad21c622a304a2793/tutorials/analysis/gene_recall_curve/multi_model_gene_recall_comparison.png
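Both DGE tutorial scripts above are driven by `dge_config.yaml`. With the pseudobulk settings active (as shipped), run `python dge_pseudobulk_main.py --config dge_config.yaml`; for the LMEM variant, set `dge_type: DgeLMEM`, uncomment the `lmem_params` block, and run `python dge_lmem_main.py --config dge_config.yaml`. The `full_datapath` and `dirpath` entries in the config are placeholders and need to point at a real `.h5ad` file and an output directory.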
-------------------------------------------------------------------------------- /tutorials/analysis/gene_recall_curve/reference_genes.csv: -------------------------------------------------------------------------------- 1 | ,B,DC,NK,Megakaryocyte,Mono,T 2 | 0,PLPP5,HSPA5,CST7,CD48,TMEM119,CD3E 3 | 1,FCRLA,RGS1,CD3E,OST4,VCAN,CD40LG 4 | 2,CD3E,LYZ,CTSW,GP6,CTSW,CD101 5 | 3,CD40LG,CD14,CXXC5,CD9,CD34,JCHAIN 6 | 4,IGHA2,MS4A3,HIST1H4C,GP1BA,TNF,RUNX3 7 | 5,JCHAIN,CD44,CD34,CD79A,CCL2,TNFSF8 8 | 6,TNFRSF13C,CD34,CD27,NRGN,ALOX5AP,LTB 9 | 7,TERF2,AIF1,NCAM1,CCL5,OSM,NCAM1 10 | 8,MEF2B,PLD4,MZB1,MYL9,G0S2,IL2RG 11 | 9,CD27,CST3,CYBA,ITGA2B,CCL3,TNFSF14 12 | 10,MZB1,FCGR1A,PTCRA,SPARC,CD1C,SKAP1 13 | 11,VCAM1,HLA-DQB1,CCL3,GNG11,CXCL3,CCR8 14 | 12,B2M,CD4,NKG7,PPBP,TLR2,CD53 15 | 13,XAF1,UBE2C,TBX21,FLI1,THBS1,TRAT1 16 | 14,ITGB1,HLA-DQA1,CD1C,GP9,GATA2,IL2RA 17 | 15,IGHM,SIRPA,IRF4,SELP,FCER1G,KLRF1 18 | 16,SEL1L3,CD209,CPVL,GMFG,TYROBP,TNFRSF18 19 | 17,NFKBIA,MRC1,KLRF1,PF4,FPR1,CCR4 20 | 18,SSR4,MZB1,FOS,TUBB1,FTL,KLRD1 21 | 19,FOXP3,IRF7,FCER1G,CXCR4,OAZ1,IL13 22 | 20,BLNK,CYBA,TYROBP,,C3,IL32 23 | 21,SPIB,CD1E,CD68,,PSAP,SNAP47 24 | 22,CD24,CCL17,LGALS2,,TUBB,TRGV9 25 | 23,AICDA,PTCRA,TNFRSF18,,CLEC12A,CCR2 26 | 24,VPREB3,HLA-DPB1,ATP1B1,,IL32,PRF1 27 | 25,ZCCHC7,HPGD,NRP1,,PRF1,TRAV1-2 28 | 26,HMGB2,CD40,KIR2DL3,,INSIG1,HLA-B 29 | 27,GZMB,CSF1R,KLRD1,,MPO,ID2 30 | 28,IGKC,C1QA,HMGB2,,LYZ,CD52 31 | 29,JUNB,FCER1A,TMIGD2,,CST3,TCF7 32 | 30,TNFRSF17,HLA-DRB1,GZMB,,MPEG1,CCL4 33 | 31,CD52,CD1C,CLEC12A,,CD4,PDCD1 34 | 32,CR2,SIGLEC6,CCR2,,C15orf48,GATA3 35 | 33,CD37,CD74,PRF1,,ZFP36L2,CD2 36 | 34,ADAM28,HLA-DPA1,CLEC4C,,MRC1,CD44 37 | 35,IGHG3,IRF4,ID2,,S100A6,CCR6 38 | 36,ITGA1,LAMP3,CD52,,CCL17,CD4 39 | 37,CD81,CPVL,KLRC2,,PTPRC,CTLA4 40 | 38,CD2,LILRA4,CCL4,,C1QA,CRTAM 41 | 39,CR1,TOP2A,FASLG,,CXCR2,PTPRC 42 | 40,BCL2A1,HLA-DRA,LYZ,,STAB1,CCL5 43 | 41,CTLA4,CLEC10A,CD2,,CXCR3,CXCR3 44 | 42,ITGAX,CD68,CD44,,MSR1,MAGEH1 45 | 43,HLA-DPB1,TYROBP,CD3G,,FBP1,CD19 46 | 44,PTPRC,LGALS2,CCR6,,KCNA5,CFD 47 | 45,CXCR3,UBC,CST3,,CD38,LCK 48 | 46,EIF4EBP1,ATP1B1,CD4,,VIM,VIM 49 | 47,CD19,NRP1,TESC,,LGALS1,ITGAE 50 | 48,CD38,IFI6,MRC1,,IFI44L,LAYN 51 | 49,IFI44L,IRF8,PRSS23,,IL1RN,TRBC2 52 | 50,VIM,CD83,CCL17,,MS4A6A,RPF1 53 | 51,IGHD,CCR7,CASP5,,INHBA,RCAN3 54 | 52,PTPRCAP,NR4A3,PTPRC,,CFP,FCGR3A 55 | 53,CALR,ISG15,HLA-DPB1,,ISG15,GZMH 56 | 54,RALY,IL3RA,C1QA,,EREG,GZMK 57 | 55,ITM2C,GZMB,CCL5,,ALDH2,CXCR6 58 | 56,CDK1,CLEC12A,CXCR3,,FCGR3A,ZBTB16 59 | 57,FCGR2A,CD86,KLRC3,,GZMH,THEMIS 60 | 58,IL4R,CCR2,SAMD3,,CD80,KLRG1 61 | 59,IGHG2,CLEC4C,TOP2A,,CCL20,TNFRSF9 62 | 60,TCF4,CIITA,CD38,,CEBPA,IGHV4-34 63 | 61,HMGB1,TCF4,HOPX,,APOBEC3A,GBP5 64 | 62,BANK1,PPT1,IFI6,,SELL,ENTPD1 65 | 63,SELL,CD1A,TP53TG5,,APOBEC3B,GATA1 66 | 64,LAPTM5,ID2,KLRC1,,LAPTM5,ICOS 67 | 65,TCL1A,CD80,FGFBP2,,AIF1,LAPTM5 68 | 66,PAX5,CD52,ISG15,,FGL2,TNFRSF4 69 | 67,HLA-DQA1,HLA-DQA2,NR4A3,,CXCL17,HAVCR2 70 | 68,IGLL5,REL,CD79B,,TIMP1,CD69 71 | 69,CD69,CD207,IL3RA,,CEBPB,CD79A 72 | 70,CD40,IL5RA,TCF4,,F5,TRGC1 73 | 71,CD22,CCL20,FCGR3A,,HLA-DRA,TRDC 74 | 72,IGHG4,LILRB4,PPT1,,CREG1,IGLV3-19 75 | 73,BACH2,IFITM3,GZMK,,PLAC8,GZMA 76 | 74,LILRA4,THBD,CXCR6,,GNLY,HCST 77 | 75,HLA-DRA,,CD80,,FGD2,SLAMF7 78 | 76,IGLL1,,CCL20,,CD63,GNLY 79 | 77,ARPP21,,DOCK2,,CD33,TRGC2 80 | 78,XBP1,,ZBTB16,,IFI44,CXCR5 81 | 79,CD83,,CD207,,CX3CR1,HHLA2 82 | 80,RPS4Y2,,HLA-DQA2,,FABP4,NR3C1 83 | 81,EOMES,,KLRG1,,S100A9,CX3CR1 84 | 82,MME,,LILRB4,,FCN1,ETS1 85 | 83,IAPP,,CHST12,,SERPING1,RPS15A 86 | 84,CXCR4,,HSPA5,,OLR1,MYC 87 | 85,DERL3,,SELL,,CD93,LYVE1 
88 | 86,LAG3,,RGS1,,NUPR1,CXCR4 89 | 87,IL7R,,STMN1,,CD163,TRBC1 90 | 88,CD14,,NCR3,,IL6,CAMK4 91 | 89,HLA-A,,MS4A3,,FCGR1A,CD8A 92 | 90,MKI67,,IL2RB,,VEGFA,IL7R 93 | 91,BLK,,AIF1,,MNDA,LAG3 94 | 92,AHNAK,,PLD4,,HDC,PHLDA1 95 | 93,IGHA1,,L1TD1,,S100A8,TRAC 96 | 94,IGHG1,,SIRPA,,MT-CO3,CD7 97 | 95,CD72,,HLA-DQA1,,CLU,TXNIP 98 | 96,PMAIP1,,CD69,,NLRP3,KLRB1 99 | 97,MX1,,CD40,,STX11,TRABD2A 100 | 98,PPY,,CD1E,,CXCL10,GIMAP7 101 | 99,BCL11A,,HPGD,,SOD2,MKI67 102 | 100,FCER2,,CLIC3,,CLEC10A,RPS12 103 | 101,RALGPS2,,PTGDR,,IL1R2,CD177 104 | 102,FCRL1,,NCR1,,MARCO,IGHA1 105 | 103,CDC20,,KLRK1,,CD300E,TIGIT 106 | 104,BCL11B,,FCER1A,,CSF3R,CD28 107 | 105,CCR7,,XCL2,,S100A12,RPL28 108 | 106,TNFRSF13B,,SIGLEC6,,MS4A7,LEF1 109 | 107,CCDC50,,HLA-DRB1,,IFITM1,IGHG1 110 | 108,AIM2,,TRGC1,,CD14,MT-CO3 111 | 109,MS4A1,,TRDC,,MCEMP1,CD1B 112 | 110,SWAP70,,LAMP3,,,CD247 113 | 111,RAG1,,GZMA,,,CD5 114 | 112,YBX3,,LILRA4,,,CD96 115 | 113,,,CD8B,,,CD3D 116 | 114,,,HLA-DRA,,,CD274 117 | 115,,,GNLY,,,BCL11B 118 | 116,,,TRGC2,,,SOCS1 119 | 117,,,UBC,,,CCR7 120 | 118,,,CHST2,,,BST2 121 | 119,,,CXCR5,,,NELL2 122 | 120,,,CD83,,,IFNG 123 | 121,,,CX3CR1,,,IFITM1 124 | 122,,,EOMES,,,LAT 125 | 123,,,ZFP36,,,RAG1 126 | 124,,,CD86,,,CD70 127 | 125,,,CIITA,,, 128 | 126,,,CD1A,,, 129 | 127,,,XCL1,,, 130 | 128,,,IL5RA,,, 131 | 129,,,CXCR4,,, 132 | 130,,,IL18RAP,,, 133 | 131,,,THBD,,, 134 | 132,,,GSG1,,, 135 | 133,,,NMUR1,,, 136 | 134,,,CD8A,,, 137 | 135,,,LAG3,,, 138 | 136,,,IL7R,,, 139 | 137,,,CD7,,, 140 | 138,,,TRAC,,, 141 | 139,,,KLRB1,,, 142 | 140,,,MKI67,,, 143 | 141,,,CD9,,, 144 | 142,,,FCGR1A,,, 145 | 143,,,HLA-DQB1,,, 146 | 144,,,UBE2C,,, 147 | 145,,,CD160,,, 148 | 146,,,CD209,,, 149 | 147,,,TIGIT,,, 150 | 148,,,IRF7,,, 151 | 149,,,CSF1R,,, 152 | 150,,,CD74,,, 153 | 151,,,HLA-DPA1,,, 154 | 152,,,CD247,,, 155 | 153,,,CD96,,, 156 | 154,,,CLEC10A,,, 157 | 155,,,CD3D,,, 158 | 156,,,IRF8,,, 159 | 157,,,CCR7,,, 160 | 158,,,SPON2,,, 161 | 159,,,IFNG,,, 162 | 160,,,KIR2DL1,,, 163 | 161,,,REL,,, 164 | 162,,,KIR2DL4,,, 165 | 163,,,CD14,,, 166 | 164,,,IFITM3,,, 167 | -------------------------------------------------------------------------------- /tutorials/pipeline/config_celltype.yaml: -------------------------------------------------------------------------------- 1 | # Config file for pipeline run for cell type classification. 2 | 3 | # DEVICE SETUP. 4 | device: 'cuda' 5 | 6 | # EXPERIMENT. 7 | experiment: 8 | dirpath: 'scalr_experiments' 9 | exp_name: 'exp_name' 10 | exp_run: 0 11 | 12 | 13 | # DATA CONFIG. 14 | data: 15 | sample_chunksize: 20000 16 | 17 | train_val_test: 18 | full_datapath: 'data/modified_adata.h5ad' 19 | num_workers: 2 20 | 21 | splitter_config: 22 | name: GroupSplitter 23 | params: 24 | split_ratio: [7, 1, 2.5] 25 | stratify: 'donor_id' 26 | 27 | # split_datapaths: '' 28 | 29 | # preprocess: 30 | # - name: SampleNorm 31 | # params: 32 | # **args 33 | 34 | # - name: StandardScaler 35 | # params: 36 | # **args 37 | 38 | target: cell_type 39 | 40 | 41 | # FEATURE SELECTION. 
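# (Pipeline note: genes are processed in chunks of `feature_subsetsize`; each chunk gets its own
# small model, per-gene scores come from `scoring_config`, and `features_selector` keeps the top
# genes for final training. See scalr/feature/ for the implementation.)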
42 | feature_selection: 43 | 44 | # score_matrix: '/path/to/matrix' 45 | feature_subsetsize: 5000 46 | num_workers: 2 47 | 48 | model: 49 | name: SequentialModel 50 | params: 51 | layers: [5000, 10] 52 | weights_init_zero: True 53 | 54 | model_train_config: 55 | trainer: SimpleModelTrainer 56 | 57 | dataloader: 58 | name: SimpleDataLoader 59 | params: 60 | batch_size: 25000 61 | padding: 5000 62 | 63 | optimizer: 64 | name: SGD 65 | params: 66 | lr: 1.0e-3 67 | weight_decay: 0.1 68 | 69 | loss: 70 | name: CrossEntropyLoss 71 | 72 | epochs: 10 73 | 74 | scoring_config: 75 | name: LinearScorer 76 | 77 | features_selector: 78 | name: AbsMean 79 | params: 80 | k: 5000 81 | 82 | 83 | # FINAL MODEL TRAINING. 84 | final_training: 85 | 86 | model: 87 | name: SequentialModel 88 | params: 89 | layers: [5000, 10] 90 | dropout: 0 91 | weights_init_zero: False 92 | 93 | model_train_config: 94 | resume_from_checkpoint: null 95 | 96 | trainer: SimpleModelTrainer 97 | 98 | dataloader: 99 | name: SimpleDataLoader 100 | params: 101 | batch_size: 15000 102 | 103 | optimizer: 104 | name: Adam 105 | params: 106 | lr: 1.0e-3 107 | weight_decay: 0 108 | 109 | loss: 110 | name: CrossEntropyLoss 111 | 112 | epochs: 100 113 | 114 | callbacks: 115 | - name: TensorboardLogger 116 | - name: EarlyStopping 117 | params: 118 | patience: 3 119 | min_delta: 1.0e-4 120 | - name: ModelCheckpoint 121 | params: 122 | interval: 5 123 | 124 | 125 | # EVALUATION & DOWNSTREAM ANALYSIS. 126 | analysis: 127 | 128 | model_checkpoint: '' 129 | 130 | dataloader: 131 | name: SimpleDataLoader 132 | params: 133 | batch_size: 15000 134 | 135 | gene_analysis: 136 | scoring_config: 137 | name: LinearScorer 138 | 139 | features_selector: 140 | name: ClasswisePromoters 141 | params: 142 | k: 100 143 | test_samples_downstream_analysis: 144 | - name: GeneRecallCurve 145 | params: 146 | reference_genes_path: 'scaLR/tutorials/pipeline/grc_reference_gene.csv' 147 | top_K: 300 148 | plots_per_row: 3 149 | features_selector: 150 | name: ClasswiseAbs 151 | params: {} 152 | - name: Heatmap 153 | params: {} 154 | - name: RocAucCurve 155 | params: {} 156 | -------------------------------------------------------------------------------- /tutorials/pipeline/config_clinical.yaml: -------------------------------------------------------------------------------- 1 | # Config file for pipeline run for clinical condition specific biomarker identification. 2 | 3 | # DEVICE SETUP. 4 | device: 'cuda' 5 | 6 | # EXPERIMENT. 7 | experiment: 8 | dirpath: 'scalr_experiments' 9 | exp_name: 'exp_name' 10 | exp_run: 1 11 | 12 | 13 | # DATA CONFIG. 14 | data: 15 | sample_chunksize: 20000 16 | 17 | train_val_test: 18 | full_datapath: 'data/modified_adata.h5ad' 19 | num_workers: 2 20 | 21 | splitter_config: 22 | name: GroupSplitter 23 | params: 24 | split_ratio: [7, 1, 2.5] 25 | stratify: 'donor_id' 26 | 27 | # split_datapaths: '' 28 | 29 | # preprocess: 30 | # - name: SampleNorm 31 | # params: 32 | # **args 33 | 34 | # - name: StandardScaler 35 | # params: 36 | # **args 37 | 38 | target: disease 39 | 40 | 41 | # FEATURE SELECTION. 
42 | feature_selection: 43 | 44 | # score_matrix: '/path/to/matrix' 45 | feature_subsetsize: 5000 46 | num_workers: 2 47 | 48 | model: 49 | name: SequentialModel 50 | params: 51 | layers: [5000, 2] 52 | weights_init_zero: True 53 | 54 | model_train_config: 55 | trainer: SimpleModelTrainer 56 | 57 | dataloader: 58 | name: SimpleDataLoader 59 | params: 60 | batch_size: 25000 61 | padding: 5000 62 | 63 | optimizer: 64 | name: SGD 65 | params: 66 | lr: 1.0e-3 67 | weight_decay: 0.1 68 | 69 | loss: 70 | name: CrossEntropyLoss 71 | 72 | epochs: 10 73 | 74 | scoring_config: 75 | name: LinearScorer 76 | 77 | features_selector: 78 | name: AbsMean 79 | params: 80 | k: 5000 81 | 82 | 83 | # FINAL MODEL TRAINING. 84 | final_training: 85 | 86 | model: 87 | name: SequentialModel 88 | params: 89 | layers: [5000, 2] 90 | dropout: 0 91 | weights_init_zero: False 92 | 93 | model_train_config: 94 | resume_from_checkpoint: null 95 | 96 | trainer: SimpleModelTrainer 97 | 98 | dataloader: 99 | name: SimpleDataLoader 100 | params: 101 | batch_size: 15000 102 | 103 | optimizer: 104 | name: Adam 105 | params: 106 | lr: 1.0e-3 107 | weight_decay: 0 108 | 109 | loss: 110 | name: CrossEntropyLoss 111 | 112 | epochs: 100 113 | 114 | callbacks: 115 | - name: TensorboardLogger 116 | - name: EarlyStopping 117 | params: 118 | patience: 3 119 | min_delta: 1.0e-4 120 | - name: ModelCheckpoint 121 | params: 122 | interval: 5 123 | 124 | 125 | # EVALUATION & DOWNSTREAM ANALYSIS. 126 | analysis: 127 | 128 | model_checkpoint: '' 129 | 130 | dataloader: 131 | name: SimpleDataLoader 132 | params: 133 | batch_size: 15000 134 | 135 | gene_analysis: 136 | scoring_config: 137 | name: LinearScorer 138 | 139 | features_selector: 140 | name: ClasswisePromoters 141 | params: 142 | k: 100 143 | full_samples_downstream_analysis: 144 | - name: Heatmap 145 | params: 146 | top_n_genes: 100 147 | - name: RocAucCurve 148 | params: {} 149 | - name: DgePseudoBulk 150 | params: 151 | celltype_column: 'cell_type' 152 | design_factor: 'disease' 153 | factor_categories: ['COVID-19', 'normal'] 154 | sum_column: 'donor_id' 155 | cell_subsets: ['conventional dendritic cell', 'natural killer cell'] 156 | - name: DgeLMEM 157 | params: 158 | fixed_effect_column: 'disease' 159 | fixed_effect_factors: ['COVID-19', 'normal'] 160 | group: 'donor_id' 161 | celltype_column: 'cell_type' 162 | cell_subsets: ['conventional dendritic cell'] 163 | gene_batch_size: 1000 164 | coef_threshold: 0.1 -------------------------------------------------------------------------------- /tutorials/pipeline/grc_reference_gene.csv: -------------------------------------------------------------------------------- 1 | ,classical monocyte,"CD16-positive, CD56-dim natural killer cell, human",non-classical monocyte,natural killer cell,platelet,"CD16-negative, CD56-bright natural killer cell, human",conventional dendritic cell,plasmacytoid dendritic cell,granulocyte,intermediate monocyte 2 | 0,HLA-DR,FCGR3A,CD14,CD94,NKG7,NCAM1,PPA1,CD19,FCGR3B,CD14 3 | 1,CD9,NCAM1,CD16,NKG7,PPBP,KLRD1,FCER1A,CD38,MPO,CD16 4 | 2,CD93,KLRD1,FCGR3A,NKp80,PF4,KLRK1,HLA-DQA1,CD27,AZU1,APOBEC3A 5 | 3,CD16,KLRK1,TNFRFA1,GNLY,GATA2,NCR1,CD1C,CD20,PRTN3,GBP4 6 | 4,CD11a,KIR2DL1,TNFRFB2,CD3,ARG1,NCR3,HLA-DQB1,MZB1,GATA2,MARCO 7 | 5,CD61,KIR2DL2,CSF1R,Vα24,ANXA1,NKG2A,CLEC9A,ITM2C,CD66b,MARCKSL1 8 | 6,CD14,KIR2DL3,FCN1,CD56,GP9,NKG2C,CDIC,CLEC4C,CD14,TNFRFB2 9 | 7,BASP1,KIR3DL1,CDKN1C,Vβ11,ITGA2B,FCER1G,CD370,LILRA4,CD15,HLA-DR 10 | 8,CXCL8,KIR2DS1,LYZ,KLRB1,ITGB3,PRF1,CD141,GZMB,CAMP,CCR2 11 | 
9,S100A9,KIR2DS2,CX3CR1,CD45,GP1BA,GZMB,FLT3,ITM2C,LTF,CX3CR1 12 | 10,IFITM3,KIR3DS1,TREM1,TBX21,GP1BB,GZMA,XCR1,SERPINF1,CD16,TLR2 13 | 11,S100A8,NCR1,ITGAL,CD16,GP6,IFNG,CD11C,IL3RA,ELANE,TLR4 14 | 12,VNN2,NCR3,ITGAM,KLRF1,THBS1,TNF,HLA-DRA,CYTL1,IL8,FCGR3A 15 | 13,IFIT3,NKG2C,S100A8,TYROBP,F2R,LTA,HLA-DRB1,GATA2,CXCR2,S100A8 16 | 14,ALOX5AP,PRF1,S100A9,CD3E,TGFB1,IL2RA,HLA-DPA1,CLEC4C,CCR1,S100A9 17 | 15,S100A12,GZMB,TLR4,NCAM1,SERPINE1,IL2RB,HLA-DPB1,IL3RA,S100A8,TNF 18 | 16,HLA-DQA1,GZMA,TNF,GZMB,VWF,IL2RG,CCR7,IRF8,S100A9,IL1B 19 | 17,IFI6,IFNG,CCL2,CD247,GP5,STAT1,CD80,JCHAIN,LTC4S,IL6 20 | 18,HLA-DRB1,TNF,CCR2,CD69,CD36,STAT4,CD86,MPEG1,PTGS2,CCL2 21 | 19,HLA-DRA,LTA,STAT3,CD7,MPL,TYROBP,IRF8,IL3RA,NCF1,CD36 22 | 20,GPR183,STAT1,NFKB1,CCL5,PDGFB,SH2D1B,IRF4,IL3RA,CTSC,CD86 23 | 21,ISG15,STAT4,IL1B,TRAC,PDGFA,SYK,BATF3,GZMB,TGFBR1,MMP9 24 | 22,CD74,TYROBP,IL6,STMN1,SELP,ZAP70,CSF1R,LILRA4,HLA-DR, 25 | 23,PLBD1,SH2D1B,CXCL10,IL2RB,SPP1,LILRB1,SIRPA,CD123,C3AR1, 26 | 24,QAS3,SYK,MMP9,SELL,F13A1,IL12RB1,CD40,CD45,TLR4, 27 | 25,IFIT2,ZAP70,VEGFA,CX3CR1,ITGB1,IL12RB2,CD83,HLA-DR,, 28 | 26,HLA-DPB1,LILRB1,CD36,TRGC2,BCL2L1,CXCR3,LAMP3,LILRA4,, 29 | 27,IFI44L,IL12RB1,HLA-DR,FOS,ANXA5,CCR7,TLR3,IL3RA,, 30 | 28,ITGA4,IL12RB2,IRF5,ZFP36,,ITGB2,TLR7,TCF4,, 31 | 29,MX1,CD3Z,MYD88,CD3D,,ITGAL,TLR8,CD303,, 32 | 30,CYP1B1,ITGB2,,SPON2,,SELL,CXCR4,CD123,, 33 | 31,HLA-DPA1,ITGAL,,TRGC1,,CX3CR1,IL12B,CD123,, 34 | 32,PADI4,CX3CR1,,FCER1G,,,CCL22,GATA2,, 35 | 33,EPSTI1,,,IL18RAP,,,CSF2RB,CD34,, 36 | 34,HLA-DQB1,,,CD160,,,ZBTB46,LILRA4,, 37 | 35,XAF1,,,ZBTB16,,,,CD2AP,, 38 | 36,CDKN1A,,,CD52,,,,CLEC4C,, 39 | 37,LYZ,,,CXXC5,,,,IL-3Rα,, 40 | 38,FCN1,,,TRDC,,,,IL3RA,, 41 | 39,FCGR3A,,,PRSS23,,,,CD123,, 42 | 40,VCAN,,,CD3G,,,,CLEC4A,, 43 | 41,CD68,,,CHST2,,,,LILRA4,, 44 | 42,CD11b,,,KLRD1,,,,CLEC4C,, 45 | 43,,,,KLRC1,,,,LILRA4,, 46 | 44,,,,FCGR3A,,,,GZMB,, 47 | 45,,,,CD38,,,,JCHAIN,, 48 | 46,,,,CD62L,,,,LILRA4,, 49 | 47,,,,KIR2DL1,,,,TCL1A,, 50 | 48,,,,CD57,,,,IL3RA,, 51 | 49,,,,NKG2C,,,,,, 52 | 50,,,,CD49e,,,,,, 53 | 51,,,,NKp46,,,,,, 54 | 52,,,,KIR2DS1,,,,,, 55 | 53,,,,NKG2D,,,,,, 56 | 54,,,,NCAM1 (CD56),,,,,, 57 | 55,,,,KLRK1,,,,,, 58 | 56,,,,KIR2DL2,,,,,, 59 | 57,,,,KIR2DL3,,,,,, 60 | 58,,,,KIR3DL1,,,,,, 61 | 59,,,,KIR2DS2,,,,,, 62 | 60,,,,KIR3DS1,,,,,, 63 | 61,,,,NCR1,,,,,, 64 | 62,,,,NCR2,,,,,, 65 | 63,,,,NCR3,,,,,, 66 | 64,,,,NKG2A,,,,,, 67 | 65,,,,PRF1,,,,,, 68 | 66,,,,GZMA,,,,,, 69 | 67,,,,IFNG,,,,,, 70 | 68,,,,TNF,,,,,, 71 | 69,,,,LTA,,,,,, 72 | 70,,,,IL2,,,,,, 73 | 71,,,,IL12RB1,,,,,, 74 | 72,,,,IL12RB2,,,,,, 75 | 73,,,,CD3Z,,,,,, 76 | 74,,,,ITGB2,,,,,, 77 | 75,,,,ITGAL,,,,,, 78 | 76,,,,CD226,,,,,, 79 | 77,,,,CD161,,,,,, 80 | 78,,,,LILRB1,,,,,, 81 | 79,,,,SH2D1A,,,,,, 82 | 80,,,,SYK,,,,,, 83 | 81,,,,ZAP70,,,,,, 84 | 82,,,,STAT1,,,,,, 85 | 83,,,,STAT3,,,,,, 86 | 84,,,,STAT5A,,,,,, 87 | 85,,,,STAT5B,,,,,, 88 | 86,,,,IRF1,,,,,, 89 | 87,,,,IRF2,,,,,, 90 | 88,,,,IRF4,,,,,, 91 | 89,,,,BATF3,,,,,, 92 | 90,,,,CXCL10,,,,,, 93 | 91,,,,CCL2,,,,,, 94 | 92,,,,PDCD1 (PD-1),,,,,, 95 | 93,,,,CTLA4,,,,,, 96 | 94,,,,CD40LG,,,,,, 97 | 95,,,,TLR3,,,,,, 98 | 96,,,,TLR7,,,,,, 99 | 97,,,,TLR8,,,,,, 100 | 98,,,,FASLG,,,,,, 101 | 99,,,,XCL1,,,,,, 102 | 100,,,,CCL4,,,,,, 103 | 101,,,,IL15,,,,,, 104 | 102,,,,IL18,,,,,, 105 | 103,,,,IL2RA,,,,,, 106 | 104,,,,IL2RG,,,,,, 107 | 105,,,,LAG3,,,,,, 108 | 106,,,,KLRG1,,,,,, 109 | 107,,,,TGFBR1,,,,,, 110 | 108,,,,TGFBR2,,,,,, 111 | 109,,,,CD44,,,,,, 112 | 110,,,,S100A8,,,,,, 113 | 111,,,,S100A9,,,,,, 114 | 112,,,,MMP9,,,,,, 115 | 113,,,,HLA-A,,,,,, 116 | 114,,,,HLA-B,,,,,, 117 | 
115,,,,HLA-C,,,,,, 118 | 116,,,,HLA-E,,,,,, 119 | 117,,,,HLA-F,,,,,, 120 | 118,,,,HLA-G,,,,,, 121 | 119,,,,TAP1,,,,,, 122 | 120,,,,TAP2,,,,,, 123 | 121,,,,BCL2,,,,,, 124 | 122,,,,BCL2L1,,,,,, 125 | 123,,,,MYC,,,,,, 126 | 124,,,,FOXP3,,,,,, 127 | 125,,,,GZMK,,,,,, 128 | 126,,,,GZMM,,,,,, 129 | 127,,,,IL10,,,,,, 130 | 128,,,,TGFB1,,,,,, 131 | 129,,,,TGFB2,,,,,, 132 | 130,,,,IFNA,,,,,, 133 | 131,,,,CCL22,,,,,, 134 | 132,,,,SIRPA,,,,,, 135 | 133,,,,HLA-DR,,,,,, 136 | 134,,,,CCL3,,,,,, 137 | 135,,,,CXCL9,,,,,, 138 | 136,,,,BIRC3,,,,,, 139 | 137,,,,NFKB1,,,,,, 140 | 138,,,,NFKB2,,,,,, 141 | 139,,,,IL1B,,,,,, 142 | 140,,,,CXCL1,,,,,, --------------------------------------------------------------------------------
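The `grc_reference_gene.csv` table above is the reference input for `GeneRecallCurve` in the pipeline configs: one column per cell type, each listing that type's reference marker genes, with shorter columns padded by blanks. A minimal sketch of inspecting it (plain pandas; the path follows the tutorial layout):

    import pandas as pd

    # Load the wide reference table; the first (unnamed) column is only a row index.
    ref = pd.read_csv('tutorials/pipeline/grc_reference_gene.csv', index_col=0)
    print(ref.columns.tolist())  # available cell-type columns
    platelet_markers = ref['platelet'].dropna().tolist()  # reference genes for one class
    print(len(platelet_markers))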