├── .github └── workflows │ └── code-quality.yaml ├── .gitignore ├── LICENSE.txt ├── README.md ├── notebooks ├── demos │ ├── data_splitter_demo.ipynb │ ├── evaluator_demo.ipynb │ └── generate_prediction_dataframe.ipynb └── neurips2025 │ ├── build_data_scaling_splits.ipynb │ ├── build_imbalance_splits.ipynb │ ├── build_jiang24_frangieh21_splits.ipynb │ ├── cpa_theis_fork_hparams.yaml │ ├── data_curation │ ├── curate_Frangieh21.ipynb │ ├── curate_Jiang24_step1.R │ ├── curate_Jiang24_step2.ipynb │ ├── curate_McFalineFigueroa23_step1.R │ ├── curate_McFalineFigueroa23_step2.ipynb │ ├── curate_Norman19.ipynb │ ├── curate_Srivatsan20.ipynb │ └── readme.txt │ ├── gears │ ├── final_evaluation.ipynb │ ├── gears_helpers.py │ └── preprocess_norman19_gears.ipynb │ └── generate_scgpt_embeddings.ipynb ├── pyproject.toml ├── setup.py └── src └── perturbench ├── analysis ├── __init__.py ├── benchmarks │ ├── __init__.py │ ├── aggregation.py │ ├── evaluation.py │ ├── evaluator.py │ └── metrics.py ├── plotting.py ├── preprocess.py └── utils.py ├── configs ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── lr_monitor.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── data │ ├── devel.yaml │ ├── evaluation │ │ ├── default.yaml │ │ └── final_test.yaml │ ├── frangieh21.yaml │ ├── jiang24.yaml │ ├── mcfaline23.yaml │ ├── norman19.yaml │ ├── sciplex3.yaml │ ├── splitter │ │ ├── cell_type_transfer_task.yaml │ │ ├── cell_type_treatment_transfer_task.yaml │ │ ├── combination_prediction_task.yaml │ │ ├── mcfaline23_split.yaml │ │ └── saved_split.yaml │ └── transform │ │ └── linear_model_pipeline.yaml ├── experiment │ └── neurips2024 │ │ ├── frangieh21 │ │ ├── biolord_best_params_frangieh21.yaml │ │ ├── cpa_best_params_frangieh21.yaml │ │ ├── cpa_no_adv_best_params_frangieh21.yaml │ │ ├── cpa_scgpt_best_params_frangieh21.yaml │ │ ├── decoder_best_params_frangieh21.yaml │ │ ├── decoder_cov_best_params_frangieh21.yaml │ │ ├── latent_best_params_frangieh21.yaml │ │ ├── latent_scgpt_best_params_frangieh21.yaml │ │ ├── linear_best_params_frangieh21.yaml │ │ ├── sams_best_params_frangieh21.yaml │ │ └── sams_modified_best_params_frangieh21.yaml │ │ ├── jiang24 │ │ ├── cpa_best_params_jiang24.yaml │ │ ├── cpa_no_adv_best_params_jiang24.yaml │ │ ├── decoder_best_params_jiang24.yaml │ │ ├── latent_best_params_jiang24.yaml │ │ ├── linear_best_params_jiang24.yaml │ │ ├── sams_best_params_jiang24.yaml │ │ └── sams_modified_best_params_jiang24.yaml │ │ ├── mcfaline23 │ │ ├── cpa_best_params_mcfaline23_full.yaml │ │ ├── cpa_best_params_mcfaline23_medium.yaml │ │ ├── cpa_best_params_mcfaline23_small.yaml │ │ ├── cpa_no_adv_best_params_mcfaline23_full.yaml │ │ ├── cpa_no_adv_best_params_mcfaline23_medium.yaml │ │ ├── cpa_no_adv_best_params_mcfaline23_small.yaml │ │ ├── decoder_only_best_params_mcfaline23_full.yaml │ │ ├── decoder_only_best_params_mcfaline23_medium.yaml │ │ ├── decoder_only_best_params_mcfaline23_small.yaml │ │ ├── latent_additive_best_params_mcfaline23_full.yaml │ │ ├── latent_additive_best_params_mcfaline23_medium.yaml │ │ ├── latent_additive_best_params_mcfaline23_small.yaml │ │ ├── linear_additive_best_params_mcfaline23_full.yaml │ │ ├── linear_additive_best_params_mcfaline23_medium.yaml │ │ ├── linear_additive_best_params_mcfaline23_small.yaml │ │ ├── sams_best_params_mcfaline23_full.yaml │ │ ├── sams_best_params_mcfaline23_medium.yaml │ │ ├── sams_best_params_mcfaline23_small.yaml │ │ ├── 
sams_modified_best_params_mcfaline23_full.yaml │ │ ├── sams_modified_best_params_mcfaline23_medium.yaml │ │ └── sams_modified_best_params_mcfaline23_small.yaml │ │ ├── norman19 │ │ ├── biolord_best_params_norman19.yaml │ │ ├── cpa_best_params_norman19.yaml │ │ ├── cpa_no_adv_best_params_norman19.yaml │ │ ├── cpa_scgpt_best_params_norman19.yaml │ │ ├── decoder_best_params_norman19.yaml │ │ ├── latent_best_params_norman19.yaml │ │ ├── latent_scgpt_best_params_norman19.yaml │ │ ├── linear_best_params_norman19.yaml │ │ ├── sams_best_params_norman19.yaml │ │ └── sams_modified_best_params_norman19.yaml │ │ └── sciplex3 │ │ ├── biolord_best_params_sciplex3.yaml │ │ ├── cpa_best_params_sciplex3.yaml │ │ ├── cpa_no_adv_best_params_sciplex3.yaml │ │ ├── cpa_scgpt_best_params_sciplex3.yaml │ │ ├── decoder_best_params_sciplex3.yaml │ │ ├── decoder_cov_best_params_sciplex3.yaml │ │ ├── latent_best_params_sciplex3.yaml │ │ ├── latent_scgpt_best_params_sciplex3.yaml │ │ ├── linear_best_params_sciplex3.yaml │ │ ├── sams_best_params_sciplex3.yaml │ │ └── sams_modified_best_params_sciplex3.yaml ├── hpo │ ├── biolord_hpo.yaml │ ├── cpa_hpo.yaml │ ├── decoder_only_hpo.yaml │ ├── latent_additive_hpo.yaml │ ├── linear_additive_hpo.yaml │ ├── local.yaml │ └── sams_vae_hpo.yaml ├── hydra │ └── default.yaml ├── logger │ ├── csv.yaml │ ├── default.yaml │ └── tensorboard.yaml ├── model │ ├── biolord.yaml │ ├── cpa.yaml │ ├── decoder_only.yaml │ ├── latent_additive.yaml │ ├── linear_additive.yaml │ └── sams_vae.yaml ├── paths │ └── default.yaml ├── predict.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── default.yaml │ └── gpu.yaml ├── data ├── __init__.py ├── accessors │ ├── base.py │ ├── frangieh21.py │ ├── jiang24.py │ ├── mcfaline23.py │ ├── norman19.py │ └── srivatsan20.py ├── collate.py ├── datasets │ ├── __init__.py │ ├── population.py │ └── singlecell.py ├── datasplitter.py ├── modules.py ├── resources │ ├── __init__.py │ └── devel.h5ad ├── transforms │ ├── __init__.py │ ├── base.py │ ├── encoders.py │ ├── ops.py │ └── pipelines.py ├── types.py └── utils.py └── modelcore ├── __init__.py ├── models ├── __init__.py ├── average.py ├── base.py ├── biolord.py ├── cpa.py ├── decoder_only.py ├── latent_additive.py ├── linear_additive.py └── sams_vae.py ├── nn ├── __init__.py ├── decoders.py ├── mlp.py ├── utils.py └── vae.py ├── predict.py ├── train.py └── utils.py /.github/workflows/code-quality.yaml: -------------------------------------------------------------------------------- 1 | name: code-quality 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | run-linter: 7 | name: run linter 8 | 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: chartboost/ruff-action@v1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | *.sqlite 28 | *.db 29 | *perturbench_data* 30 | *logs/* 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into 
it. 36 | *.manifest 37 | *.spec 38 | 39 | # Machine Learning 40 | tensorboard/ 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | __pycache__ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # VSCode 171 | .vscode/ 172 | 173 | # Lightning logs 174 | **/lightning_logs/ 175 | 176 | # Data 177 | *.gz 178 | 179 | # Machine Learning 180 | tensorboard/ 181 | [0-9]* 182 | 183 | # Biomart 184 | .pybiomart* 185 | 186 | # Jupyter Notebook 187 | !*.ipynb 188 | !*.py 189 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Altos Labs 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | Additional License Information 31 | ------------------- 32 | Certain files within this repository are subject to different licensing terms. 33 | Please see the specific files in the perturbench/modelcore/models for their 34 | respective licenses. 
35 | -------------------------------------------------------------------------------- /notebooks/demos/generate_prediction_dataframe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 2024-03-11-Demo: Creating a prediction dataframe" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "A notebook demonstrating how to generate a prediction dataframe on disk for the `predict.py` script" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import scanpy as sc\n", 24 | "import pandas as pd" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "data_cache_dir = '../neurips2024/perturbench_data' ## Change this to your local data directory" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 4, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "AnnData object with n_obs × n_vars = 183856 × 9198 backed at '../neurips2024/perturbench_data/srivatsan20_processed.h5ad'\n", 45 | " obs: 'ncounts', 'well', 'plate', 'cell_line', 'replicate', 'time', 'dose_value', 'pathway_level_1', 'pathway_level_2', 'perturbation', 'target', 'pathway', 'dose_unit', 'celltype', 'disease', 'cancer', 'tissue_type', 'organism', 'perturbation_type', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts', 'chembl-ID', 'dataset', 'cell_type', 'treatment', 'condition', 'dose', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'\n", 46 | " var: 'ensembl_id', 'ncounts', 'ncells', 'gene_symbol', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'\n", 47 | " uns: 'hvg', 'log1p', 'rank_genes_groups_cov'\n", 48 | " layers: 'counts'" 49 | ] 50 | }, 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "adata = sc.read_h5ad(f'{data_cache_dir}/srivatsan20_processed.h5ad', backed='r')\n", 58 | "adata" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "data": { 68 | "text/plain": [ 69 | "['mcf7', 'k562', 'a549']\n", 70 | "Categories (3, object): ['a549', 'k562', 'mcf7']" 71 | ] 72 | }, 73 | "execution_count": 5, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | } 77 | ], 78 | "source": [ 79 | "adata.obs.cell_type.unique()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "188" 91 | ] 92 | }, 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "unique_perturbations = [p for p in adata.obs.perturbation.unique() if p != 'control']\n", 100 | "len(unique_perturbations)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 7, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/html": [ 111 | "
\n", 112 | "\n", 125 | "\n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | "
conditioncell_type
0TAK-901k562
1Busulfank562
2BMS-536924k562
3Enzastaurin (LY317615)k562
4BMS-911543k562
\n", 161 | "
" 162 | ], 163 | "text/plain": [ 164 | " condition cell_type\n", 165 | "0 TAK-901 k562\n", 166 | "1 Busulfan k562\n", 167 | "2 BMS-536924 k562\n", 168 | "3 Enzastaurin (LY317615) k562\n", 169 | "4 BMS-911543 k562" 170 | ] 171 | }, 172 | "execution_count": 7, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "prediction_df = pd.DataFrame(\n", 179 | " {\n", 180 | " 'condition': unique_perturbations,\n", 181 | " 'cell_type': 'k562',\n", 182 | " }\n", 183 | ")\n", 184 | "prediction_df.head()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 8, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "prediction_df.to_csv(f'{data_cache_dir}/prediction_dataframe.csv', index=False)" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "perturbench-dev", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.11.9" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 2 218 | } 219 | -------------------------------------------------------------------------------- /notebooks/neurips2025/cpa_theis_fork_hparams.yaml: -------------------------------------------------------------------------------- 1 | ae_hparams: 2 | autoencoder_depth: 4 3 | autoencoder_width: 256 4 | adversary_depth: 3 5 | adversary_width: 256 6 | use_batch_norm: False 7 | use_layer_norm: True 8 | output_activation: linear 9 | dropout_rate: 0.2 10 | variational: True 11 | seed: 0 12 | 13 | trainer_hparams: 14 | n_epochs_warmup: 10 15 | n_epochs_kl_warmup: 10 16 | max_kl_weight: 0.1 17 | adversary_lr: 0.00008847032648856746 18 | adversary_wd: 9.629190571404551e-07 19 | adversary_steps: 3 20 | autoencoder_lr: 0.00007208558788012054 21 | autoencoder_wd: 1.2280838320404273e-07 22 | dosers_lr: 0.00008835011062896268 23 | dosers_wd: 5.886005123780177e-06 24 | penalty_adversary: 63.44954424334805 25 | reg_adversary: 48.73324753854268 26 | cycle_coeff: 7.19539336141403 27 | step_size_lr: 25 -------------------------------------------------------------------------------- /notebooks/neurips2025/data_curation/curate_Jiang24_step1.R: -------------------------------------------------------------------------------- 1 | if (!requireNamespace("remotes", quietly = TRUE)) { 2 | install.packages("remotes") 3 | } 4 | 5 | remotes::install_version("Matrix", version = "1.6.5", repos = "http://cran.us.r-project.org") 6 | install.packages("Seurat") 7 | remotes::install_github("mojaveazure/seurat-disk") 8 | 9 | library(Seurat) 10 | library(SeuratDisk) 11 | 12 | ## Set your working directory to the current file (only works on Rstudio) 13 | setwd(dirname(rstudioapi::getActiveDocumentContext()$path)) 14 | 15 | ## Data cache directory 16 | data_cache_dir = '../perturbench_data' ## Change this to your local data directory 17 | 18 | ## Define paths to Seurat objects and download from Zenodo 19 | ifng_seurat_path = paste0(data_cache_dir, "/Seurat_object_IFNG_Perturb_seq.rds", sep="") 20 | system2( 21 | "wget", 22 | c( 23 | "https://zenodo.org/records/10520190/files/Seurat_object_IFNG_Perturb_seq.rds?download=1", 24 | "-O", 25 | ifng_seurat_path 26 | ) 27 | ) 28 | 29 | ifnb_seurat_path = paste0(data_cache_dir, "/Seurat_object_IFNB_Perturb_seq.rds", 
sep="") 30 | system2( 31 | "wget", 32 | c( 33 | "https://zenodo.org/records/10520190/files/Seurat_object_IFNB_Perturb_seq.rds?download=1", 34 | "-O", 35 | ifnb_seurat_path 36 | ) 37 | ) 38 | 39 | ins_seurat_path = paste0(data_cache_dir, "/Seurat_object_INS_Perturb_seq.rds", sep="") 40 | system2( 41 | "wget", 42 | c( 43 | "https://zenodo.org/records/10520190/files/Seurat_object_INS_Perturb_seq.rds?download=1", 44 | "-O", 45 | ins_seurat_path 46 | ) 47 | ) 48 | 49 | tgfb_seurat_path = paste0(data_cache_dir, "/Seurat_object_TGFB_Perturb_seq.rds", sep="") 50 | system2( 51 | "wget", 52 | c( 53 | "https://zenodo.org/records/10520190/files/Seurat_object_TGFB_Perturb_seq.rds?download=1", 54 | "-O", 55 | tgfb_seurat_path 56 | ) 57 | ) 58 | 59 | tnfa_seurat_path = paste0(data_cache_dir, "/Seurat_object_TNFA_Perturb_seq.rds", sep="") 60 | system2( 61 | "wget", 62 | c( 63 | "https://zenodo.org/records/10520190/files/Seurat_object_TNFA_Perturb_seq.rds?download=1", 64 | "-O", 65 | tnfa_seurat_path 66 | ) 67 | ) 68 | 69 | 70 | seurat_files = c( 71 | ifng_seurat_path, 72 | ifnb_seurat_path, 73 | ins_seurat_path, 74 | tgfb_seurat_path, 75 | tnfa_seurat_path 76 | ) 77 | 78 | for (seurat_path in seurat_files) { 79 | print(seurat_path) 80 | obj = readRDS(seurat_path) 81 | print(obj) 82 | 83 | out_h5seurat_path = gsub(".rds", ".h5Seurat", seurat_path) 84 | SaveH5Seurat(obj, filename = out_h5seurat_path) 85 | Convert(out_h5seurat_path, dest = "h5ad") 86 | } 87 | 88 | -------------------------------------------------------------------------------- /notebooks/neurips2025/data_curation/curate_McFalineFigueroa23_step1.R: -------------------------------------------------------------------------------- 1 | ## Install packages 2 | if (!require("BiocManager", quietly = TRUE)) 3 | install.packages("BiocManager") 4 | 5 | 6 | if (!requireNamespace("remotes", quietly = TRUE)) { 7 | install.packages("remotes") 8 | } 9 | 10 | ## Install libcairo and libxt-dev using apt-get 11 | # system2( 12 | # "apt-get", 13 | # c("install", "libcairo2-dev", "libxt-dev") 14 | # ) 15 | 16 | remotes::install_version("Matrix", version = "1.6.5", repos = "http://cran.us.r-project.org") 17 | install.packages("Seurat") 18 | remotes::install_github("mojaveazure/seurat-disk") 19 | BiocManager::install(c('BiocGenerics', 'DelayedArray', 'DelayedMatrixStats', 20 | 'limma', 'lme4', 'S4Vectors', 'SingleCellExperiment', 21 | 'SummarizedExperiment', 'batchelor', 'HDF5Array', 22 | 'terra', 'ggrastr')) 23 | remotes::install_github('cole-trapnell-lab/monocle3') 24 | 25 | ## Load libraries 26 | library(Seurat) 27 | library(SingleCellExperiment) 28 | library(SeuratDisk) 29 | library(Matrix) 30 | 31 | ## Set your working directory to the current file (only works on Rstudio) 32 | setwd(dirname(rstudioapi::getActiveDocumentContext()$path)) 33 | 34 | ## Data cache directory 35 | data_cache_dir = '../perturbench_data' ## Change this to your local data directory 36 | 37 | ## GXE1 38 | gxe1_cds_path = paste0(data_cache_dir, "/GSM7056148_sciPlexGxE_1_preprocessed_cds.rds.gz", sep="") 39 | print(gxe1_cds_path) 40 | 41 | system2( 42 | "wget", 43 | c( 44 | "https://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7056nnn/GSM7056148/suppl/GSM7056148_sciPlexGxE_1_preprocessed_cds.rds.gz", 45 | "-O", 46 | gxe1_cds_path 47 | ) 48 | ) 49 | system2( 50 | "gzip", 51 | c("-d", gxe1_cds_path) 52 | ) 53 | 54 | gxe1 = readRDS(gsub(".gz", "", gxe1_cds_path)) 55 | gxe1_counts = gxe1@assays@data$counts 56 | gxe1_gene_meta = data.frame(gxe1@rowRanges@elementMetadata@listData) 57 | gxe1_cell_meta = 
data.frame(gxe1@colData) 58 | head(gxe1_cell_meta) 59 | 60 | colnames(gxe1_counts) = rownames(gxe1_cell_meta) 61 | rownames(gxe1_counts) = gxe1_gene_meta$id 62 | gxe1_seurat = CreateSeuratObject(gxe1_counts, meta.data = gxe1_cell_meta) 63 | gxe1_seurat@assays$RNA@meta.features <- gxe1_gene_meta 64 | gxe1_seurat$cell_type = 'A172' 65 | for (col in colnames(gxe1_seurat@meta.data)) { 66 | if (is.factor(gxe1_seurat@meta.data[[col]])) { 67 | print(col) 68 | gxe1_seurat@meta.data[[col]] = as.character(gxe1_seurat@meta.data[[col]]) 69 | } 70 | } 71 | 72 | SaveH5Seurat(gxe1_seurat, filename = paste0(data_cache_dir, "/gxe1.h5Seurat"), overwrite = T) 73 | Convert(paste0(data_cache_dir, "/gxe1.h5Seurat"), dest = "h5ad", overwrite = T) 74 | 75 | ## GXE2 76 | gxe2_cds_path = paste0(data_cache_dir, "/GSM7056149_sciPlexGxE_2_preprocessed_cds.list.RDS.gz") 77 | system2( 78 | "wget", 79 | c( 80 | "https://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7056nnn/GSM7056149/suppl/GSM7056149%5FsciPlexGxE%5F2%5Fpreprocessed%5Fcds.list.RDS.gz", 81 | "-O", 82 | gxe2_cds_path 83 | ) 84 | ) 85 | system2( 86 | "gzip", 87 | c("-d", gxe2_cds_path) 88 | ) 89 | gxe2_list = readRDS(gsub(".gz", "", gxe2_cds_path)) 90 | 91 | base_path = data_cache_dir 92 | gxe2_seurat_list = lapply(1:length(gxe2_list), function(i) { 93 | sce = gxe2_list[[i]] 94 | counts = sce@assays@.xData$data$counts 95 | gene_meta = data.frame(sce@rowRanges@elementMetadata@listData) 96 | cell_meta = data.frame(sce@colData) 97 | head(cell_meta) 98 | 99 | colnames(counts) = rownames(cell_meta) 100 | rownames(counts) = gene_meta$id 101 | seurat_obj = CreateSeuratObject(counts, meta.data = cell_meta) 102 | seurat_obj@assays$RNA@meta.features <- gene_meta 103 | for (col in colnames(seurat_obj@meta.data)) { 104 | if (is.factor(seurat_obj@meta.data[[col]])) { 105 | print(col) 106 | seurat_obj@meta.data[[col]] = as.character(seurat_obj@meta.data[[col]]) 107 | } 108 | } 109 | cl = names(gxe2_list)[[i]] 110 | seurat_obj$cell_type = cl 111 | 112 | out_h5Seurat = paste0(base_path, "/gxe2_", cl, ".h5Seurat") 113 | SaveH5Seurat(seurat_obj, filename = out_h5Seurat, overwrite = T) 114 | Convert(out_h5Seurat, dest = "h5ad", overwrite = T) 115 | 116 | seurat_obj 117 | }) -------------------------------------------------------------------------------- /notebooks/neurips2025/data_curation/curate_Norman19.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 2023-04-17-Curation: Norman19 Combo Screen" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Add cross validation splits to Norman19" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import scanpy as sc\n", 24 | "import os\n", 25 | "import subprocess as sp\n", 26 | "from perturbench.analysis.preprocess import preprocess\n", 27 | "from scipy.sparse import csr_matrix\n", 28 | "\n", 29 | "%reload_ext autoreload\n", 30 | "%autoreload 2" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "Download from: https://zenodo.org/records/7041849/files/NormanWeissman2019_filtered.h5ad?download=1" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "data_url = 'https://zenodo.org/records/7041849/files/NormanWeissman2019_filtered.h5ad?download=1'\n", 47 | 
"data_cache_dir = '../perturbench_data' ## Change this to your local data directory\n", 48 | "\n", 49 | "if not os.path.exists(data_cache_dir):\n", 50 | " os.makedirs(data_cache_dir)\n", 51 | "\n", 52 | "tmp_data_dir = f'{data_cache_dir}/norman19_downloaded.h5ad'\n", 53 | "\n", 54 | "if not os.path.exists(tmp_data_dir):\n", 55 | " sp.call(f'wget {data_url} -O {tmp_data_dir}', shell=True)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "AnnData object with n_obs × n_vars = 111445 × 33694\n", 67 | " obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo'\n", 68 | " var: 'ensemble_id', 'ncounts', 'ncells'" 69 | ] 70 | }, 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "adata = sc.read_h5ad(tmp_data_dir)\n", 78 | "adata" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "adata.obs.rename(columns = {\n", 88 | " 'nCount_RNA': 'ncounts',\n", 89 | " 'nFeature_RNA': 'ngenes',\n", 90 | " 'percent.mt': 'percent_mito',\n", 91 | " 'cell_line': 'cell_type',\n", 92 | "}, inplace=True)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "['ARID1A', 'BCORL1', 'FOSB', 'SET_KLF1', 'OSR2', ..., 'CEBPB_OSR2', 'PRDM1_CBFA2T3', 'FOSB_CEBPB', 'ZBTB10_DLX2', 'FEV_CBFA2T3']\n", 104 | "Length: 237\n", 105 | "Categories (237, object): ['AHR', 'AHR_FEV', 'AHR_KLF1', 'ARID1A', ..., 'ZC3HAV1_HOXC13', 'ZNF318', 'ZNF318_FOXL2', 'control']" 106 | ] 107 | }, 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "adata.obs.perturbation.unique()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "perturbation\n", 126 | "control 11855\n", 127 | "KLF1 1960\n", 128 | "BAK1 1457\n", 129 | "CEBPE 1233\n", 130 | "CEBPE+RUNX1T1 1219\n", 131 | " ... 
\n", 132 | "CEBPB+CEBPA 64\n", 133 | "CBL+UBASH3A 64\n", 134 | "C3orf72+FOXL2 59\n", 135 | "JUN+CEBPB 59\n", 136 | "JUN+CEBPA 54\n", 137 | "Name: count, Length: 237, dtype: int64" 138 | ] 139 | }, 140 | "execution_count": 6, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "adata.obs['perturbation'] = adata.obs['perturbation'].str.replace('_', '+')\n", 147 | "adata.obs['perturbation'] = adata.obs['perturbation'].astype('category')\n", 148 | "adata.obs.perturbation.value_counts()" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "adata.obs['condition'] = adata.obs.perturbation.copy()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "adata.X = csr_matrix(adata.X)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 9, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "Preprocessing ...\n", 179 | "Filtering for highly variable genes or differentially expressed genes ...\n", 180 | "Processed dataset summary:\n", 181 | "View of AnnData object with n_obs × n_vars = 111445 × 5666\n", 182 | " obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_type', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'condition', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'\n", 183 | " var: 'ensemble_id', 'ncounts', 'ncells', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'\n", 184 | " uns: 'log1p', 'hvg', 'rank_genes_groups_cov'\n", 185 | " layers: 'counts'\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "adata = preprocess(\n", 191 | " adata,\n", 192 | " perturbation_key='condition',\n", 193 | " covariate_keys=['cell_type'],\n", 194 | ")" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 12, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "adata = adata.copy()\n", 204 | "output_data_path = f'{data_cache_dir}/norman19_processed.h5ad'\n", 205 | "adata.write_h5ad(output_data_path)" 206 | ] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "perturbench-dev", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.11.9" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 4 230 | } 231 | -------------------------------------------------------------------------------- /notebooks/neurips2025/data_curation/readme.txt: -------------------------------------------------------------------------------- 1 | URLs to download public data are in the beginning of each curation 
notebook or R script.
2 | 
3 | Some datasets were uploaded as R objects and thus required a two-step conversion to AnnData h5ad.
--------------------------------------------------------------------------------
/notebooks/neurips2025/gears/gears_helpers.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import anndata as ad
3 | import pandas as pd
4 | import numpy as np
5 | import torch
6 | import optuna
7 | 
8 | from gears import PertData, GEARS
9 | from perturbench.analysis.benchmarks.evaluation import Evaluation
10 | 
11 | 
12 | class GEARSHParamsRange:
13 |     """
14 |     Hyperparameter ranges for GEARS.
15 |     """
16 | 
17 |     @staticmethod
18 |     def get_distributions():
19 |         return {
20 |             "hidden_size": optuna.distributions.IntDistribution(32, 512, step=32),
21 |             "num_go_gnn_layers": optuna.distributions.IntDistribution(1, 3),
22 |             "num_gene_gnn_layers": optuna.distributions.IntDistribution(1, 3),
23 |             "decoder_hidden_size": optuna.distributions.IntDistribution(
24 |                 16, 48, step=16
25 |             ),
26 |             "num_similar_genes_go_graph": optuna.distributions.IntDistribution(
27 |                 10, 30, step=5
28 |             ),
29 |             "num_similar_genes_co_express_graph": optuna.distributions.IntDistribution(
30 |                 10, 30, step=5
31 |             ),
32 |             "coexpress_threshold": optuna.distributions.FloatDistribution(
33 |                 0.2, 0.5, step=0.01
34 |             ),
35 |             "lr": optuna.distributions.FloatDistribution(5e-6, 1e-3, log=True),
36 |             "wd": optuna.distributions.FloatDistribution(1e-8, 1e-3, log=True),
37 |         }
38 | 
39 | 
40 | def run_gears(
41 |     pert_data_path: str = "/weka/prime-shared/prime-data/gears/",
42 |     dataset_name: str = "norman19",
43 |     split_dict_path: str = "/weka/prime-shared/prime-data/gears/norman19_gears_split.pkl",
44 |     eval_split: str = "val",
45 |     batch_size: int = 32,
46 |     epochs: int = 10,
47 |     lr: float = 5e-4,
48 |     wd: float = 1e-4,
49 |     hidden_size: int = 128,
50 |     num_go_gnn_layers: int = 1,
51 |     num_gene_gnn_layers: int = 1,
52 |     decoder_hidden_size: int = 16,
53 |     num_similar_genes_go_graph: int = 20,
54 |     num_similar_genes_co_express_graph: int = 20,
55 |     coexpress_threshold: float = 0.4,
56 |     device="cuda:0",
57 |     seed=0,
58 | ):
59 |     """
60 |     Helper function to train and evaluate a GEARS model
61 |     """
62 |     pert_data = PertData(pert_data_path)  # specific saved folder
63 |     pert_data.load(
64 |         data_path=pert_data_path + dataset_name
65 |     )  # load the processed data, the path is saved folder + dataset_name
66 |     pert_data.prepare_split(split="custom", split_dict_path=split_dict_path)
67 |     pert_data.get_dataloader(batch_size=batch_size, test_batch_size=batch_size)
68 | 
69 |     gears_model = GEARS(
70 |         pert_data,
71 |         device=device,
72 |         weight_bias_track=False,
73 |         proj_name="pertnet",
74 |         exp_name="pertnet",
75 |     )
76 |     gears_model.model_initialize(
77 |         hidden_size=hidden_size,
78 |         num_go_gnn_layers=num_go_gnn_layers,
79 |         num_gene_gnn_layers=num_gene_gnn_layers,
80 |         decoder_hidden_size=decoder_hidden_size,
81 |         num_similar_genes_go_graph=num_similar_genes_go_graph,
82 |         num_similar_genes_co_express_graph=num_similar_genes_co_express_graph,
83 |         coexpress_threshold=coexpress_threshold,
84 |         seed=seed,
85 |     )
86 |     gears_model.train(epochs=epochs, lr=lr, weight_decay=wd)
87 |     torch.cuda.empty_cache()
88 | 
89 |     val_perts = []
90 |     for p in pert_data.set2conditions[eval_split]:
91 |         newp_list = []
92 |         for gene in p.split("+"):
93 |             if gene in gears_model.pert_list:
94 |                 newp_list.append(gene)
95 |         if len(newp_list) > 0:
96 |             val_perts.append(newp_list)
97 | 
98 |     val_avg_pred = gears_model.predict(val_perts)
99 |     pred_df = pd.DataFrame(val_avg_pred).T
100 |     pred_df.columns = gears_model.adata.var_names.values
101 |     torch.cuda.empty_cache()
102 | 
103 |     ctrl_adata = gears_model.adata[gears_model.adata.obs.condition == "ctrl"]
104 |     val_conditions = ["+".join(p) for p in val_perts] + ["ctrl"]
105 |     ref_adata = gears_model.adata[gears_model.adata.obs.condition.isin(val_conditions)]
106 | 
107 |     pred_adata = sc.AnnData(pred_df)
108 |     pred_adata.obs["condition"] = [x.replace("_", "+") for x in pred_adata.obs_names]
109 |     pred_adata.obs["condition"] = pred_adata.obs["condition"].astype("category")
110 |     pred_adata = ad.concat([pred_adata, ctrl_adata])
111 | 
112 |     ev = Evaluation(
113 |         model_adatas={
114 |             "GEARS": pred_adata,
115 |         },
116 |         ref_adata=ref_adata,
117 |         pert_col="condition",
118 |         ctrl="ctrl",
119 |     )
120 | 
121 |     aggr_metrics = [
122 |         ("average", "rmse"),
123 |         ("logfc", "cosine"),
124 |     ]
125 |     summary_metrics_dict = {}
126 |     for aggr, metric in aggr_metrics:
127 |         ev.evaluate(aggr_method=aggr, metric=metric)
128 |         ev.evaluate_pairwise(aggr_method=aggr, metric=metric)
129 |         ev.evaluate_rank(aggr_method=aggr, metric=metric)
130 | 
131 |         metric_df = ev.get_eval(aggr_method=aggr, metric=metric)
132 |         rank_df = ev.get_rank_eval(aggr_method=aggr, metric=metric)
133 |         summary_metrics_dict[f"{metric}_{aggr}"] = np.mean(metric_df["GEARS"])
134 |         summary_metrics_dict[f"{metric}_rank_{aggr}"] = np.mean(rank_df["GEARS"])
135 | 
136 |     return pd.Series(summary_metrics_dict)
137 | 
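
A hedged usage sketch for the two helpers above (illustrative only, not a file in this repository): GEARSHParamsRange.get_distributions() is shaped to drive an Optuna search whose trials call run_gears(). The data paths, the trial count, and the choice of "rmse_average" as the objective are assumptions for illustration.

# Hypothetical Optuna driver for run_gears(); paths below are placeholders.
import optuna

from gears_helpers import GEARSHParamsRange, run_gears  # assumes this directory is on sys.path


def objective(trial: optuna.Trial) -> float:
    # Sample one value per hyperparameter from the ranges defined above.
    params = {}
    for name, dist in GEARSHParamsRange.get_distributions().items():
        if isinstance(dist, optuna.distributions.IntDistribution):
            params[name] = trial.suggest_int(name, dist.low, dist.high, step=dist.step, log=dist.log)
        else:
            params[name] = trial.suggest_float(name, dist.low, dist.high, step=dist.step, log=dist.log)

    # run_gears() returns a pd.Series keyed like "rmse_average", "cosine_logfc", etc.
    metrics = run_gears(
        pert_data_path="perturbench_data/gears/",  # placeholder path
        dataset_name="norman19",
        split_dict_path="perturbench_data/gears/norman19_gears_split.pkl",  # placeholder path
        eval_split="val",
        epochs=10,
        **params,
    )
    return float(metrics["rmse_average"])  # lower RMSE is better


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
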
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "perturbench"
7 | description = "PerturBench: Benchmarking Machine Learning Models for Cellular Perturbation Analysis"
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | license = {text = "..."}
11 | classifiers = [
12 |     "Programming Language :: Python :: 3",
13 | ]
14 | dependencies = [
15 |     'numpy',
16 |     'pyyaml',
17 |     'pandas',
18 |     'anndata',
19 |     'scanpy',
20 |     'scipy==1.12.0',
21 |     'tqdm',
22 |     'lightning',
23 |     "torch",
24 |     "torchvision",
25 |     "tensorboard",
26 |     "hydra-core",
27 |     "hydra_colorlog",
28 |     'mlflow-skinny',
29 |     'matplotlib',
30 |     'seaborn',
31 |     'scikit-learn',
32 |     'scikit-misc',
33 |     'adjusttext',
34 |     'pytest',
35 |     'rich',
36 |     'psycopg2-binary',
37 |     'optuna',
38 |     'ray',
39 |     'python-dotenv',
40 | ]
41 | dynamic = ["version"]
42 | 
43 | [project.optional-dependencies]
44 | dev = [
45 |     "ruff",
46 | ]
47 | test = [
48 |     "pytest",
49 |     "pytest-cov",
50 | ]
51 | cli = [
52 |     "click",
53 |     "rich",
54 | ]
55 | 
56 | [project.scripts]
57 | train = "perturbench.modelcore.train:main"
58 | predict = "perturbench.modelcore.predict:main"
59 | 
60 | [tool.setuptools.packages.find]
61 | where = ["src"]  # list of folders that contain the packages (["."] by default)
62 | include = ["perturbench*"]  # package names should match these glob patterns (["*"] by default)
63 | exclude = ["tests", "docs", "examples"]  # exclude packages matching these glob patterns (empty by default)
64 | # namespaces = false  # to disable scanning PEP 420 namespaces (true by default)
65 | 
66 | [tool.setuptools.dynamic]
67 | version = {attr = "perturbench.modelcore.VERSION"}
68 | 
69 | [tool.pytest.ini_options]
70 | addopts = ["--import-mode=importlib"]
71 | 
72 | 
73 | [tool.ruff.lint.per-file-ignores]
74 | # Ignore `E402` (import violations) in all `__init__.py` files, and in `path/to/file.py`.
75 | "__init__.py" = ["F401"]
76 | "*.ipynb" = ["E402", "F401"]
77 | # "path/to/file.py" = ["E402"]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup()
4 | 
--------------------------------------------------------------------------------
/src/perturbench/analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/analysis/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/analysis/benchmarks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/analysis/benchmarks/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/analysis/benchmarks/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import r2_score
2 | from sklearn.metrics.pairwise import euclidean_distances, rbf_kernel
3 | from scipy.stats import pearsonr
4 | import numpy as np
5 | from numpy.linalg import norm
6 | import pandas as pd
7 | import tqdm
8 | from ..utils import merge_cols
9 | 
10 | 
11 | def compute_metric(x, y, metric):
12 |     """Compute specified similarity/distance metric between x and y vectors"""
13 | 
14 |     if metric == "pearson":
15 |         score = pearsonr(x, y)[0]
16 |     elif metric == "r2_score":
17 |         score = r2_score(x, y)
18 |     elif metric == "cosine":
19 |         score = np.dot(x, y) / (norm(x) * norm(y))
20 |     elif metric == "mse":
21 |         score = np.mean(np.square(x - y))
22 |     elif metric == "rmse":
23 |         score = np.sqrt(np.mean(np.square(x - y)))
24 |     elif metric == "mae":
25 |         score = np.mean(np.abs(x - y))
26 |     else:
27 |         raise ValueError(f"Unknown metric: {metric}")
28 | 
29 |     return score
30 | 
31 | 
32 | def compare_perts(
33 |     pred, ref, features=None, perts=None, metric="pearson", deg_dict=None
34 | ):
35 |     """Compare expression similarities between `pred` and `ref` DataFrames using the specified metric"""
36 | 
37 |     if perts is None:
38 |         perts = list(set(pred.index).intersection(ref.index))
39 |         assert len(perts) > 0
40 |     else:
41 |         perts = list(perts)
42 | 
43 |     if features is not None:
44 |         pred = pred.loc[:, features]
45 |         ref = ref.loc[:, features]
46 | 
47 |     pred = pred.loc[perts, :]
48 |     ref = ref.loc[perts, :]
49 | 
50 |     pred = pred.replace([np.inf, -np.inf], 0)
51 |     ref = ref.replace([np.inf, -np.inf], 0)
52 | 
53 |     eval_metric = []
54 |     for p in perts:
55 |         if deg_dict is not None:
56 |             genes = deg_dict[p]
57 |         else:
58 |             genes = ref.columns
59 | 
60 |         eval_metric.append(
61 |             compute_metric(pred.loc[p, genes], ref.loc[p, genes], metric)
62 |         )
63 | 
64 |     eval_scores = pd.Series(index=perts, data=eval_metric)
65 |     return eval_scores
66 | 
67 | 
68 | def pairwise_metric_helper(
69 |     df,
70 |     df2=None,
71 |     metric="rmse",
72 |     pairwise_deg_dict=None,
73 |     verbose=False,
74 | ):
75 |     if df2 is None:
76 |         df2 = df
77 | 
78 |     mat = pd.DataFrame(0.0, index=df.index, columns=df2.index)
79 |     for p1 in tqdm.tqdm(df.index, disable=not verbose):
80 |         for p2 in df2.index:
81 |             if pairwise_deg_dict is not None:
82 |                 pp = frozenset([p1, p2])
83 |                 genes_ix = pairwise_deg_dict[pp]
84 | 
85 |                 m = compute_metric(
86 |                     df.loc[p1, genes_ix],
87 |                     df2.loc[p2, genes_ix],
88 |                     metric=metric,
89 |                 )
90 |             else:
91 |                 m = compute_metric(
92 |                     df.loc[p1],
93 |                     df2.loc[p2],
94 |                     metric=metric,
95 |                 )
96 |             mat.at[p1, p2] = m
97 | 
98 |     return mat
99 | 
100 | 
101 | def rank_helper(pred_ref_mat, metric_type):
102 |     rel_ranks = pd.Series(1.0, index=pred_ref_mat.columns)
103 |     for p in pred_ref_mat.columns:
104 |         pred_metrics = pred_ref_mat.loc[:, p]
105 |         pred_metrics = pred_metrics.sample(frac=1.0)  ## Shuffle to avoid ties
106 |         if metric_type == "distance":
107 |             pred_metrics = pred_metrics.sort_values(ascending=True)
108 |         elif metric_type == "similarity":
109 |             pred_metrics = pred_metrics.sort_values(ascending=False)
110 |         else:
111 |             raise ValueError(
112 |                 "Invalid metric_type, should be either distance or similarity"
113 |             )
114 | 
115 |         rel_ranks.loc[p] = np.where(pred_metrics.index == p)[0][0]
116 | 
117 |     rel_ranks = rel_ranks / len(rel_ranks)
118 |     return rel_ranks
119 | 
120 | 
121 | def mmd_energy_distance_helper(
122 |     eval,  # an Evaluation object
123 |     model_name,
124 |     pert_col,
125 |     cov_cols,
126 |     ctrl,
127 |     delim='_',
128 |     kernel='energy_distance',
129 |     gamma=None,
130 | ):
131 |     model_adata = eval.adatas[model_name]
132 |     ref_adata = eval.adatas['ref']
133 | 
134 |     model_adata.obs[pert_col] = model_adata.obs[pert_col].astype('category')
135 |     ref_adata.obs[pert_col] = ref_adata.obs[pert_col].astype('category')
136 | 
137 |     if len(cov_cols) == 0:
138 |         model_adata.obs['_dummy_cov'] = '1'
139 |         ref_adata.obs['_dummy_cov'] = '1'
140 |         cov_cols = ['_dummy_cov']
141 | 
142 |     for col in cov_cols:
143 |         assert col in model_adata.obs.columns
144 |         assert col in ref_adata.obs.columns
145 | 
146 |     if kernel == 'energy_distance':
147 |         kernel_fns = [lambda x, y: - euclidean_distances(x, y)]
148 |     elif kernel == 'rbf_kernel':
149 |         if gamma is None:
150 |             all_gamma = np.logspace(1, -3, num=5)
151 |         elif isinstance(gamma, list):
152 |             all_gamma = np.array(gamma)
153 |         else:
154 |             all_gamma = np.array([gamma])
155 |         kernel_fns = [lambda x, y, g=gamma: rbf_kernel(x, y, gamma=g) for gamma in all_gamma]  # bind each gamma at definition time
156 |         print('rbf kernels with gammas:', all_gamma)
157 |     else:
158 |         raise ValueError('Invalid kernel')
159 | 
160 |     model_adata_covs = merge_cols(model_adata.obs, cov_cols, delim=delim)
161 |     ref_adata_covs = merge_cols(ref_adata.obs, cov_cols, delim=delim)
162 | 
163 |     ret = {'cov_pert': [], 'model': [], 'metric': []}
164 |     for cov in model_adata_covs.cat.categories:
165 | 
166 |         if len(model_adata[model_adata_covs == cov, :].obs[pert_col].unique()) > 1:  # has any perturbations besides control
167 | 
168 |             model_adata_subset_cov = model_adata[model_adata_covs == cov, :]
169 |             ref_adata_subset_cov = ref_adata[ref_adata_covs == cov, :]
170 |             model_adata_covs_perts = merge_cols(model_adata_subset_cov.obs, [pert_col], delim=delim)
171 |             ref_adata_covs_perts = merge_cols(ref_adata_subset_cov.obs, [pert_col], delim=delim)
172 | 
173 |             for i, pert in enumerate(model_adata_covs_perts.cat.categories):
174 |                 if pert == ctrl:
175 |                     continue
176 | 
177 |                 population_pred = model_adata_subset_cov[model_adata_covs_perts.isin([pert]), :].X
178 |                 population_truth = ref_adata_subset_cov[ref_adata_covs_perts.isin([pert]), :].X
179 | 
180 |                 all_mmd = []
181 |                 for kernel_fn in kernel_fns:
182 |                     xx = kernel_fn(population_pred, population_pred)
183 |                     xy = kernel_fn(population_pred, population_truth)
184 |                     yy = kernel_fn(population_truth, population_truth)
185 |                     mmd = xx.mean() + yy.mean() - 2 * xy.mean()
186 |                     all_mmd.append(mmd)
187 |                 mmd = np.nanmean(all_mmd)
188 | 
189 |                 ret['cov_pert'].append(f'{cov}{delim}{pert}')
190 |                 ret['model'].append(model_name)
191 |                 ret['metric'].append(mmd)
192 | 
193 |     eval.mmd_df = pd.DataFrame.from_dict(ret)
194 | 
195 |     return eval.mmd_df
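
A brief, hedged sketch of how the metric helpers above compose (illustrative only, not a repository file; the toy data are made up): pairwise_metric_helper builds a prediction-versus-reference distance matrix over perturbations, and rank_helper converts each column into the normalized rank of the matching perturbation, where 0 is best.

# Minimal sketch, assuming the package is installed (e.g. `pip install -e .`).
import numpy as np
import pandas as pd

from perturbench.analysis.benchmarks.metrics import (
    compute_metric,
    pairwise_metric_helper,
    rank_helper,
)

rng = np.random.default_rng(0)
perts = ["pertA", "pertB", "pertC"]

# Toy pseudobulk profiles: rows are perturbations, columns are genes.
pred = pd.DataFrame(rng.normal(size=(3, 5)), index=perts)
ref = pd.DataFrame(rng.normal(size=(3, 5)), index=perts)

# Single-pair comparison.
print(compute_metric(pred.loc["pertA"].values, ref.loc["pertA"].values, "rmse"))

# All predicted-vs-reference pairs, then the normalized rank of each
# true perturbation among the predictions (rmse is a distance metric).
dist_mat = pairwise_metric_helper(pred, df2=ref, metric="rmse")
print(rank_helper(dist_mat, metric_type="distance"))
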
--------------------------------------------------------------------------------
/src/perturbench/analysis/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | from adjustText import adjust_text
4 | 
5 | 
6 | def scatter_labels(
7 |     x,
8 |     y,
9 |     df=None,
10 |     hue=None,
11 |     labels=None,
12 |     label_size=6,
13 |     x_title=None,
14 |     y_title=None,
15 |     axis_title_size=14,
16 |     axis_text_size=12,
17 |     plot_title=None,
18 |     title_size=15,
19 |     ax=None,
20 |     figsize=None,
21 |     hide_legend=True,
22 |     size=60,
23 |     alpha=0.8,
24 |     xlim=None,
25 |     ylim=None,
26 |     ident_line=False,
27 |     quadrants=False,
28 |     **kwargs,
29 | ):
30 |     """Generate a scatterplot with optional text labels for data points"""
31 | 
32 |     if ax is None:
33 |         _, ax = plt.subplots(figsize=figsize)
34 | 
35 |     if df is not None:
36 |         sns.scatterplot(ax=ax, x=x, y=y, hue=hue, data=df, marker=".", s=size, **kwargs)
37 |     else:
38 |         if isinstance(size, (int, float)):
39 |             size = [size] * len(x)
40 | 
41 |         sns.scatterplot(ax=ax, x=x, y=y, hue=hue, marker=".", s=size, **kwargs)
42 | 
43 |     ax.set_xlabel(x_title, size=axis_title_size)
44 |     ax.set_ylabel(y_title, size=axis_title_size)
45 |     ax.tick_params(axis="x", labelsize=axis_text_size)
46 |     ax.tick_params(axis="y", labelsize=axis_text_size)
47 |     ax.set_title(plot_title, size=title_size)
48 | 
49 |     if xlim is not None:
50 |         ax.set_xlim(xmin=xlim[0], xmax=xlim[1])
51 |     if ylim is not None:
52 |         ax.set_ylim(ymin=ylim[0], ymax=ylim[1])
53 | 
54 |     xlim = ax.get_xlim()
55 |     ylim = ax.get_ylim()
56 |     min_pt = max(xlim[0], ylim[0])
57 |     max_pt = min(xlim[1], ylim[1])
58 |     if ident_line:
59 |         ax.plot((min_pt, max_pt), (min_pt, max_pt), ls="--", color="red", alpha=0.7)
60 | 
61 |     if quadrants:
62 |         ax.axvline(ls="--", color="red", alpha=0.3)
63 |         ax.axhline(ls="--", color="red", alpha=0.3)
64 | 
65 |     if hide_legend and (ax.get_legend() is not None):
66 |         ax.get_legend().remove()
67 | 
68 |     if labels is not None:
69 |         assert df is not None
70 |         texts = []
71 |         for lab in labels:
72 |             texts.append(ax.text(df.loc[lab, x], df.loc[lab, y], lab, size=label_size))
73 | 
74 |         adjust_text(texts, arrowprops=dict(arrowstyle="-", color="black"), ax=ax)
75 | 
76 | 
77 | ## Base boxplot
78 | def boxplot_jitter(
79 |     x,
80 |     y,
81 |     df,
82 |     hue=None,
83 |     x_title=None,
84 |     y_title=None,
85 |     axis_title_size=14,
86 |     axis_text_size=12,
87 |     jitter_size=10,
88 |     alpha=0.8,
89 |     figsize=None,
90 |     plot_title=None,
91 |     title_size=15,
92 |     ax=None,
93 |     violin=False,
94 | ):
95 |     """Generate a boxplot or violin plot with jitter for the data points"""
96 |     if (show_plot := ax is None):  # remember whether we created the figure here
97 |         _, ax = plt.subplots(figsize=figsize)
98 | 
99 |     if violin:
100 |         sns.violinplot(
101 |             ax=ax, x=x, y=y, hue=hue, data=df, palette=sns.color_palette("colorblind")
102 |         )
103 |         plt.setp(ax.collections, alpha=alpha)
104 | 
105 |     else:
106 |         sns.boxplot(
107 |             ax=ax,
108 |             x=x,
109 |             y=y,
110 |             hue=hue,
111 |             data=df,
112 |             showfliers=False,
113 |             palette=sns.color_palette("colorblind"),
114 |         )
115 |         for patch in ax.patches:
116 |             r, g, b, a = patch.get_facecolor()
117 |             patch.set_facecolor((r, g, b, alpha))
118 | 
119 |     sns.stripplot(
120 |         ax=ax,
121 |         x=x,
122 |         y=y,
123 |         hue=hue,
124 |         data=df,
125 |         color="black",
126 |         marker=".",
127 |         size=jitter_size,
128 |         dodge=True,
129 |         label=False,
130 |         legend=None,
131 |     )
132 | 
133 |     ax.set_xlabel(x_title, size=axis_title_size)
134 |     ax.set_ylabel(y_title, size=axis_title_size)
135 |     ax.set_title(plot_title, size=title_size)
136 |     ax.tick_params(axis="x", labelrotation=90, labelsize=axis_text_size)
137 |     ax.tick_params(axis="y", labelsize=axis_text_size)
138 | 
139 |     if show_plot:
140 |         plt.show()
141 | 
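
A hedged usage sketch for scatter_labels() above (illustrative only, not a repository file; the DataFrame is made up):

# Minimal sketch, assuming the package is installed.
import matplotlib.pyplot as plt
import pandas as pd

from perturbench.analysis.plotting import scatter_labels

toy = pd.DataFrame(
    {"pred": [0.1, 0.4, 0.8], "obs": [0.2, 0.35, 0.9]},
    index=["pertA", "pertB", "pertC"],
)

fig, ax = plt.subplots(figsize=(4, 4))
scatter_labels(
    x="pred",
    y="obs",
    df=toy,
    labels=["pertB"],   # text labels are placed with adjustText
    x_title="Predicted effect",
    y_title="Observed effect",
    ident_line=True,    # draw the y = x reference line
    ax=ax,
)
plt.show()
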
--------------------------------------------------------------------------------
/src/perturbench/analysis/preprocess.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import anndata as ad
3 | from .utils import merge_cols
4 | import pandas as pd
5 | import warnings
6 | 
7 | 
8 | def differential_expression_by_covariate(
9 |     adata,
10 |     perturbation_key: str,
11 |     perturbation_control_value: str,
12 |     covariate_keys: list[str] | None = None,
13 |     n_differential_genes=25,
14 |     rankby_abs=True,
15 |     key_added="rank_genes_groups_cov",
16 |     return_dict=False,
17 |     delim="_",
18 | ):
19 |     """Finds the top `n_differential_genes` differentially expressed genes for
20 |     each perturbation. The differential expression is run separately for
21 |     each unique covariate category in the dataset.
22 | 
23 |     Args:
24 |         adata: AnnData dataset
25 |         perturbation_key: the key in adata.obs that contains the perturbations
26 |         perturbation_control_value: the value in adata.obs[perturbation_key]
27 |             that corresponds to control cells
28 |         covariate_keys: a list of keys in adata.obs that contain the covariates
29 |         n_differential_genes: number of top differentially expressed genes to
30 |             keep for each perturbation
31 |         rankby_abs: if True, rank genes by absolute values of the score, thus including
32 |             top downregulated genes in the top N genes. If False, the ranking will
33 |             have only upregulated genes at the top.
34 |         delim: the delimiter used to join covariate values, and to join them
35 |             with perturbation names in the keys of the output dictionary
36 |         key_added: key used when adding the DEG dictionary to adata.uns
37 |         return_dict: if True, return the DEG dictionary
38 | 
39 |     Returns:
40 |         Adds the DEG dictionary to adata.uns
41 | 
42 |         If return_dict is True returns:
43 |         gene_dict : dict
44 |             Dictionary where groups are stored as keys, and the list of DEGs
45 |             are the corresponding values
46 |     """
47 |     if "base" not in adata.uns["log1p"]:
48 |         adata.uns["log1p"]["base"] = None
49 | 
50 |     gene_dict = {}
51 |     if not covariate_keys:
52 |         sc.tl.rank_genes_groups(
53 |             adata,
54 |             groupby=perturbation_key,
55 |             reference=perturbation_control_value,
56 |             rankby_abs=rankby_abs,
57 |             n_genes=n_differential_genes,
58 |         )
59 | 
60 |         top_de_genes = pd.DataFrame(adata.uns["rank_genes_groups"]["names"])
61 |         for group in top_de_genes:
62 |             gene_dict[group] = top_de_genes[group].tolist()
63 | 
64 |     else:
65 |         merged_covariates = merge_cols(adata.obs, covariate_keys, delim=delim)
66 |         for unique_covariate in merged_covariates.unique():
67 |             adata_cov = adata[merged_covariates == unique_covariate]
68 |             sc.tl.rank_genes_groups(
69 |                 adata_cov,
70 |                 groupby=perturbation_key,
71 |                 reference=perturbation_control_value,
72 |                 rankby_abs=rankby_abs,
73 |                 n_genes=n_differential_genes,
74 |             )
75 | 
76 |             top_de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
77 |             for group in top_de_genes:
78 |                 cov_group = unique_covariate + delim + group
79 |                 gene_dict[cov_group] = top_de_genes[group].tolist()
80 | 
81 |     adata.uns[key_added] = gene_dict
82 | 
83 |     if return_dict:
84 |         return gene_dict
85 | 
86 | 
87 | def preprocess(
88 |     adata: ad.AnnData,
89 |     perturbation_key: str,
90 |     covariate_keys: list[str],
91 |     control_value: str = "control",
92 |     combination_delimiter: str = "+",
93 |     highly_variable: int = 4000,
94 |     degs: int = 25,
95 | ):
96 |     adata.raw = None
97 |     adata.var_names_make_unique()
98 |     adata.obs_names_make_unique()
99 | 
100 |     ## Merge covariate columns
101 |     adata.obs["cov_merged"] = merge_cols(adata.obs, covariate_keys)
102 | 
103 |     ## Preprocess
104 |     print("Preprocessing ...")
105 |     sc.pp.filter_genes(adata, min_cells=10)
106 |     sc.pp.calculate_qc_metrics(adata, inplace=True)
107 | 
108 |     ## Save raw counts, then normalize
109 |     adata.layers["counts"] = adata.X.copy()
110 |     sc.pp.normalize_total(adata)
111 |     sc.pp.log1p(adata)
112 | 
113 |     ## Pull out perturbed genes
114 |     unique_perturbations = set()
115 |     for comb in adata.obs[perturbation_key].unique():
116 |         unique_perturbations.update(comb.split(combination_delimiter))
117 |     unique_perturbations = unique_perturbations.intersection(adata.var_names)
118 | 
119 |     ## Subset to highly variable or differentially expressed genes
120 |     if highly_variable > 0:
121 |         print(
122 |             "Filtering for highly variable genes or differentially expressed genes ..."
123 |         )
124 |         sc.pp.highly_variable_genes(
125 |             adata,
126 |             batch_key="cov_merged",
127 |             flavor="seurat_v3",
128 |             layer="counts",
129 |             n_top_genes=int(highly_variable),
130 |             subset=False,
131 |         )
132 | 
133 |         with warnings.catch_warnings():
134 |             warnings.simplefilter("ignore")
135 |             deg_gene_dict = differential_expression_by_covariate(
136 |                 adata,
137 |                 perturbation_key,
138 |                 control_value,
139 |                 covariate_keys,
140 |                 n_differential_genes=degs,
141 |                 rankby_abs=True,
142 |                 key_added="rank_genes_groups_cov",
143 |                 return_dict=True,
144 |                 delim="_",
145 |             )
146 |         deg_genes = set()
147 |         for genes in deg_gene_dict.values():
148 |             deg_genes.update(genes)
149 | 
150 |         var_genes = list(adata.var_names[adata.var["highly_variable"]])
151 |         var_genes = list(unique_perturbations.union(var_genes).union(deg_genes))
152 |         adata = adata[:, var_genes]
153 | 
154 |     print("Processed dataset summary:")
155 |     print(adata)
156 | 
157 |     return adata
158 | 
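
A hedged sketch mirroring how the curation notebooks earlier in this dump call preprocess(); the input .h5ad path is a placeholder. preprocess() expects raw counts in adata.X, a perturbation column containing a control value, and covariate columns; it normalizes, selects highly variable genes plus per-covariate DEGs, and stores the DEG dictionary in adata.uns.

# Minimal sketch, assuming the package is installed; the .h5ad path is hypothetical.
import scanpy as sc

from perturbench.analysis.preprocess import preprocess

adata = sc.read_h5ad("perturbench_data/my_raw_dataset.h5ad")  # placeholder input

adata = preprocess(
    adata,
    perturbation_key="condition",  # as in the Norman19 curation notebook above
    covariate_keys=["cell_type"],
    control_value="control",
    combination_delimiter="+",
    highly_variable=4000,
    degs=25,
)

# Per-covariate top DEGs land in adata.uns["rank_genes_groups_cov"],
# keyed as "<covariate>_<perturbation>".
print(list(adata.uns["rank_genes_groups_cov"])[:5])
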
--------------------------------------------------------------------------------
/src/perturbench/analysis/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | 
4 | 
5 | def merge_cols(obs_df, cols, delim="_"):
6 |     """Merge columns in DataFrame"""
7 | 
8 |     covs = obs_df[cols[0]]
9 |     if len(cols) > 1:
10 |         for i in range(1, len(cols)):
11 |             covs = covs.astype(str) + delim + obs_df[cols[i]].astype(str)
12 |     covs = covs.astype("category")
13 |     return covs
14 | 
15 | 
16 | def subsample_by_group(adata, obs_key, max_cells=10000):
17 |     """
18 |     Downsample anndata on a per group basis
19 |     """
20 |     cells_keep = []
21 |     for k in list(adata.obs[obs_key].unique()):
22 |         cells = list(adata[adata.obs[obs_key] == k].obs_names)
23 |         if len(cells) > max_cells:
24 |             cells = random.sample(cells, k=max_cells)
25 |         cells_keep.extend(cells)
26 | 
27 |     print("Input cells: " + str(adata.shape[0]))
28 |     print("Sampled cells: " + str(len(cells_keep)))
29 | 
30 |     adata_downsampled = adata[cells_keep, :]
31 |     return adata_downsampled
32 | 
33 | 
34 | def pert_cluster_filter(adata, pert_key, cluster_key, delim="_", min_cells=20):
35 |     """
36 |     Filter anndata to only include perturbations that have at least min_cells in all clusters
37 |     """
38 | 
39 |     adata.obs["pert_cluster"] = (
40 |         adata.obs[pert_key].astype(str) + delim + adata.obs[cluster_key].astype(str)
41 |     )
42 |     pert_cl_counts = adata.obs["pert_cluster"].value_counts()
43 |     pert_cl_keep = pert_cl_counts.loc[[x >= min_cells for x in pert_cl_counts]].index
44 | 
45 |     pert_counts_dict = defaultdict(int)
46 |     for x in list(pert_cl_keep):
47 |         pert_counts_dict[x.split(delim)[0]] += 1
48 | 
49 |     perts_keep = [
50 |         x
51 |         for x, n in pert_counts_dict.items()
52 |         if n == len(adata.obs[cluster_key].unique())
53 |     ]
54 | 
55 |     print("Original perturbations: " + str(len(adata.obs[pert_key].unique())))
56 |     print("Filtered perturbations: " + str(len(perts_keep)))
57 | 
58 |     adata_filtered = adata[adata.obs[pert_key].isin(perts_keep), :]
59 |     return adata_filtered
60 | 
61 | 
62 | def get_ensembl_mappings():
63 |     try:
64 |         from pybiomart import Dataset
65 |     except ImportError:
66 |         raise ImportError("Please install the pybiomart package to use this function")
67 | 
68 |     # Set up connection to server
69 |     dataset = Dataset(name="hsapiens_gene_ensembl", host="http://www.ensembl.org")
70 | 
71 |     id_gene_df = dataset.query(attributes=["ensembl_gene_id", "hgnc_symbol"])
72 | 
73 |     ensembl_to_genesymbol = {}
74 |     # Store the data in a dict
75 |     for gene_id,
gene_symbol in zip( 76 | id_gene_df["Gene stable ID"], id_gene_df["HGNC symbol"] 77 | ): 78 | ensembl_to_genesymbol[gene_id] = gene_symbol 79 | 80 | return ensembl_to_genesymbol 81 | -------------------------------------------------------------------------------- /src/perturbench/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/configs/__init__.py -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - early_stopping 3 | - lr_monitor 4 | - model_summary 5 | - model_checkpoint 6 | - rich_progress_bar 7 | - _self_ 8 | 9 | model_summary: 10 | max_depth: -1 11 | -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: lightning.pytorch.callbacks.EarlyStopping 3 | monitor: val_loss # quantity to monitor 4 | patience: 50 # early stopping patience -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html 2 | 3 | lr_monitor: 4 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 5 | log_momentum: True 6 | logging_interval: epoch 7 | -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 3 | monitor: val_loss # quantity to monitor 4 | dirpath: ${hydra:runtime.output_dir}/checkpoints # directory to save checkpoints 5 | # every_n_train_steps: 80 # save checkpoint every n train steps 6 | # every_n_epochs: 5 # save checkpoint every n epochs -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /src/perturbench/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | 
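A minimal usage sketch for the perturbench.analysis.utils helpers shown above (not part of the repo; assumes an AnnData whose .obs carries condition and cell_type columns, matching the conventions in the data configs below, and uses the small bundled devel.h5ad):

import scanpy as sc
from perturbench.analysis.utils import merge_cols, subsample_by_group, pert_cluster_filter

# Small development dataset bundled with the package
adata = sc.read_h5ad("src/perturbench/data/resources/devel.h5ad")

# Build a merged covariate column (e.g. a hypothetical "A549_drugX") for grouped operations
adata.obs["cov_merged"] = merge_cols(adata.obs, ["cell_type", "condition"])

# Cap each condition at 5000 cells, then keep only perturbations observed
# with at least 20 cells in every cell type
adata = subsample_by_group(adata, "condition", max_cells=5000)
adata = pert_cluster_filter(adata, pert_key="condition", cluster_key="cell_type")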
-------------------------------------------------------------------------------- /src/perturbench/configs/data/devel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - transform: linear_model_pipeline 3 | - collate: linear_model_collate 4 | - splitter: default 5 | - evaluation: default 6 | 7 | _target_: perturbench.data.modules.AnnDataLitModule 8 | datapath: src/perturbench/data/resources/devel.h5ad 9 | perturbation_key: condition 10 | perturbation_combination_delimiter: + 11 | perturbation_control_value: control 12 | covariate_keys: [cell_type] 13 | batch_size: 8 14 | num_workers: 0 15 | batch_sample: True 16 | add_controls: True 17 | use_counts: False -------------------------------------------------------------------------------- /src/perturbench/configs/data/evaluation/default.yaml: -------------------------------------------------------------------------------- 1 | split_value_to_evaluate: val 2 | max_control_cells_per_covariate: 1000 3 | 4 | evaluation_pipelines: 5 | - aggregation: average 6 | metric: rmse 7 | rank: True 8 | 9 | - aggregation: logfc 10 | metric: cosine 11 | rank: True 12 | 13 | save_evaluation: True 14 | save_dir: "${paths.output_dir}/evaluation/" 15 | chunk_size: 20 16 | print_summary: True -------------------------------------------------------------------------------- /src/perturbench/configs/data/evaluation/final_test.yaml: -------------------------------------------------------------------------------- 1 | split_value_to_evaluate: test 2 | max_control_cells_per_covariate: 1000 3 | 4 | evaluation_pipelines: 5 | - aggregation: average 6 | metric: rmse 7 | rank: True 8 | 9 | - aggregation: logfc 10 | metric: cosine 11 | rank: True 12 | 13 | save_evaluation: True 14 | save_dir: "${paths.output_dir}/evaluation/" 15 | chunk_size: 20 16 | print_summary: True -------------------------------------------------------------------------------- /src/perturbench/configs/data/frangieh21.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - transform: linear_model_pipeline 3 | - splitter: saved_split 4 | - evaluation: default 5 | 6 | _target_: perturbench.data.modules.AnnDataLitModule 7 | datapath: ${paths.data_dir}/frangieh21_processed.h5ad 8 | perturbation_key: condition 9 | perturbation_combination_delimiter: + 10 | perturbation_control_value: control 11 | covariate_keys: [treatment] 12 | batch_size: 8000 13 | num_workers: 8 14 | batch_sample: True 15 | add_controls: True 16 | use_counts: False 17 | 18 | splitter: 19 | split_path: ${paths.data_dir}/frangieh21_split.csv -------------------------------------------------------------------------------- /src/perturbench/configs/data/jiang24.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - transform: linear_model_pipeline 3 | - splitter: saved_split 4 | - evaluation: default 5 | 6 | _target_: perturbench.data.modules.AnnDataLitModule 7 | datapath: ${paths.data_dir}/jiang24_processed.h5ad 8 | perturbation_key: condition 9 | perturbation_combination_delimiter: + 10 | perturbation_control_value: control 11 | covariate_keys: [cell_type,treatment] 12 | batch_size: 2000 13 | num_workers: 8 14 | batch_sample: True 15 | add_controls: True 16 | use_counts: False 17 | 18 | splitter: 19 | split_path: ${paths.data_dir}/jiang24_split.csv -------------------------------------------------------------------------------- /src/perturbench/configs/data/mcfaline23.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - transform: linear_model_pipeline 3 | - splitter: mcfaline23_split 4 | - evaluation: default 5 | 6 | _target_: perturbench.data.modules.AnnDataLitModule 7 | datapath: ${paths.data_dir}/mcfaline23_gxe_processed.h5ad 8 | perturbation_key: condition 9 | perturbation_combination_delimiter: + 10 | perturbation_control_value: control 11 | covariate_keys: [cell_type,treatment] 12 | batch_size: 8000 13 | num_workers: 12 14 | num_val_workers: 2 15 | num_test_workers: 0 16 | batch_sample: True 17 | add_controls: True 18 | use_counts: False 19 | embedding_key: null -------------------------------------------------------------------------------- /src/perturbench/configs/data/norman19.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - transform: linear_model_pipeline 3 | - splitter: combination_prediction_task 4 | - evaluation: default 5 | 6 | _target_: perturbench.data.modules.AnnDataLitModule 7 | datapath: ${paths.data_dir}/norman19_processed.h5ad 8 | perturbation_key: condition 9 | perturbation_combination_delimiter: + 10 | perturbation_control_value: control 11 | covariate_keys: [cell_type] 12 | batch_size: 4000 13 | num_workers: 8 14 | batch_sample: True 15 | add_controls: True 16 | use_counts: False 17 | embedding_key: null 18 | 19 | splitter: 20 | max_heldout_fraction_per_covariate: 0.7 -------------------------------------------------------------------------------- /src/perturbench/configs/data/sciplex3.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - transform: linear_model_pipeline 3 | - splitter: cell_type_transfer_task 4 | - evaluation: default 5 | 6 | _target_: perturbench.data.modules.AnnDataLitModule 7 | datapath: ${paths.data_dir}/sciplex3_processed.h5ad 8 | perturbation_key: condition 9 | perturbation_combination_delimiter: + 10 | perturbation_control_value: control 11 | covariate_keys: [cell_type] 12 | batch_size: 4000 13 | num_workers: 8 14 | batch_sample: True 15 | add_controls: True 16 | use_counts: False 17 | embedding_key: null -------------------------------------------------------------------------------- /src/perturbench/configs/data/splitter/cell_type_transfer_task.yaml: -------------------------------------------------------------------------------- 1 | # split_path: data/splits ## Specify this path if you want to use a specific saved path for the splits 2 | task: transfer ## Either `transfer`, `combine`, or `combine_inverse` 3 | covariate_keys: [cell_type] 4 | min_train_covariates: 1 ## Minimum number of covariates to train on per perturbation 5 | max_heldout_covariates: 1 ## Maximum number of covariates to hold out per perturbation 6 | max_heldout_fraction_per_covariate: 0.3 7 | max_heldout_perturbations_per_covariate: 200 8 | train_control_fraction: 0.5 9 | downsample_fraction: 1.0 10 | splitter_seed: 42 11 | save: True ## Whether to save the split to disk 12 | output_path: "${paths.output_dir}/" ## Specify this path if you want to save the split -------------------------------------------------------------------------------- /src/perturbench/configs/data/splitter/cell_type_treatment_transfer_task.yaml: -------------------------------------------------------------------------------- 1 | # split_path: data/splits ## Specify this path if you want to use a specific saved path for the splits 2 | task: transfer ## Either `transfer`, `combine`, or `combine_inverse` 3 
| covariate_keys: [cell_type,treatment] 4 | min_train_covariates: 1 ## Minimum number of covariates to train on per perturbation 5 | max_heldout_covariates: 1 ## Maximum number of covariates to hold out per perturbation 6 | max_heldout_fraction_per_covariate: 0.3 7 | max_heldout_perturbations_per_covariate: 200 8 | train_control_fraction: 0.5 9 | downsample_fraction: 1.0 10 | splitter_seed: 42 11 | save: True ## Whether to save the split to disk 12 | output_path: "${paths.output_dir}/" ## Specify this path if you want to save the split -------------------------------------------------------------------------------- /src/perturbench/configs/data/splitter/combination_prediction_task.yaml: -------------------------------------------------------------------------------- 1 | # split_path: data/splits ## Specify this path if you want to use a specific saved path for the splits 2 | task: combine ## Either `transfer`, `combine`, or `combine_inverse` 3 | covariate_keys: [cell_type] 4 | max_heldout_fraction_per_covariate: 0.3 5 | max_heldout_perturbations_per_covariate: 200 6 | train_control_fraction: 0.5 7 | downsample_fraction: 1.0 8 | splitter_seed: 42 9 | save: True ## Whether to save the split to disk 10 | output_path: "${paths.output_dir}/" ## Specify this path if you want to save the split -------------------------------------------------------------------------------- /src/perturbench/configs/data/splitter/mcfaline23_split.yaml: -------------------------------------------------------------------------------- 1 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv ## Specify this path if you want to use a specific saved path for the splits -------------------------------------------------------------------------------- /src/perturbench/configs/data/splitter/saved_split.yaml: -------------------------------------------------------------------------------- 1 | split_path: null -------------------------------------------------------------------------------- /src/perturbench/configs/data/transform/linear_model_pipeline.yaml: -------------------------------------------------------------------------------- 1 | conf: 2 | _target_: perturbench.data.transforms.pipelines.SingleCellPipeline 3 | _partial_: true 4 | dependencies: 5 | perturbation_uniques: perturbation_uniques 6 | covariate_uniques: covariate_uniques 7 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/biolord_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: biolord 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | _target_: anubis.models.BiolordStar 15 | dropout: 0.30000000000000004 16 | encoder_width: 5376 17 | latent_dim: 512 18 | lr: 2.643716468011295e-05 19 | n_genes: null 20 | n_layers: 1 21 | n_perts: null 22 | penalty_weight: 49.426728002719415 23 | softplus_output: true 24 | wd: 3.8596437745886474e-08 25 | 26 | data: 27 | add_controls: false 28 | batch_size: 1000 29 | evaluation: 30 | chunk_size: 10 31 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/cpa_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 
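# Note: tuned CPA hyperparameters for frangieh21; the values below most likely
# come from the HPO search space defined in configs/hpo/cpa_hpo.yaml.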
3 | defaults: 4 | - override /model: cpa 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | adv_classifier_hidden_dim: 559 14 | adv_classifier_n_layers: 4 15 | adv_steps: 5 16 | adv_weight: 0.1853767280161523 17 | dropout: 0.1 18 | elementwise_affine: false 19 | hidden_dim: 4352 20 | kl_weight: 0.21855673778824847 21 | lr: 3.680992074544368e-05 22 | n_latent: 128 23 | n_layers_covar_emb: 1 24 | n_layers_encoder: 1 25 | n_layers_pert_emb: 1 26 | n_warmup_epochs: 20 27 | penalty_weight: 7.605167161204701 28 | softplus_output: false 29 | variational: true 30 | wd: 1.9023489520016998e-05 31 | 32 | data: 33 | add_controls: false 34 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/cpa_no_adv_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: cpa 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | use_adversary: false 14 | adv_classifier_hidden_dim: 559 15 | adv_classifier_n_layers: 4 16 | adv_steps: 5 17 | adv_weight: 0.1853767280161523 18 | dropout: 0.1 19 | elementwise_affine: false 20 | hidden_dim: 4352 21 | kl_weight: 0.21855673778824847 22 | lr: 3.680992074544368e-05 23 | n_latent: 128 24 | n_layers_covar_emb: 1 25 | n_layers_encoder: 1 26 | n_layers_pert_emb: 1 27 | n_warmup_epochs: 400 28 | penalty_weight: 7.605167161204701 29 | softplus_output: false 30 | variational: true 31 | wd: 1.9023489520016998e-05 32 | 33 | data: 34 | add_controls: false 35 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/cpa_scgpt_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: cpa 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | adv_classifier_hidden_dim: 249 14 | adv_classifier_n_layers: 5 15 | adv_steps: 2 16 | adv_weight: 0.4809114972414237 17 | dropout: 0.1 18 | elementwise_affine: false 19 | hidden_dim: 4352 20 | kl_weight: 0.1519255328598612 21 | lr: 0.0006148770073381143 22 | n_latent: 128 23 | n_layers_covar_emb: 1 24 | n_layers_encoder: 1 25 | n_layers_pert_emb: 1 26 | n_warmup_epochs: 10 27 | penalty_weight: 0.19363459943522998 28 | softplus_output: false 29 | variational: true 30 | wd: 7.097486120374588e-08 31 | 32 | data: 33 | add_controls: false 34 | datapath: ${paths.data_dir}/prime-data/frangieh21_processed_scgpt.h5ad 35 | embedding_key: scgpt_embeddings 36 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/decoder_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: decoder_only 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | encoder_width: 3328 14 | lr: 0.001489196297126292 15 | n_genes: null 16 | n_layers: 1 17 | n_perts: null 18 | softplus_output: true 19 | use_covariates: true 
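# This variant conditions the decoder on both covariates and perturbations;
# decoder_cov_best_params_frangieh21.yaml below is the covariates-only
# ablation (use_perturbations: false).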
20 | use_perturbations: true 21 | wd: 4.786646240929937e-06 22 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/decoder_cov_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: decoder_only 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | encoder_width: 4352 14 | lr: 0.0003078912524764639 15 | n_genes: null 16 | n_layers: 5 17 | n_perts: null 18 | softplus_output: false 19 | use_covariates: true 20 | use_perturbations: false 21 | wd: 9.633717477684286e-07 22 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/latent_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: latent_additive 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | dropout: 0.30000000000000004 14 | encoder_width: 2304 15 | inject_covariates_decoder: true 16 | inject_covariates_encoder: true 17 | latent_dim: 256 18 | lr: 0.00052308363076951 19 | n_genes: null 20 | n_layers: 1 21 | n_perts: null 22 | softplus_output: true 23 | wd: 1.564792695502722e-07 24 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/latent_scgpt_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: latent_additive 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | dropout: 0.30000000000000004 14 | encoder_width: 3328 15 | inject_covariates_decoder: true 16 | inject_covariates_encoder: true 17 | latent_dim: 192 18 | lr: 0.0004036126472173799 19 | n_genes: null 20 | n_layers: 1 21 | n_perts: null 22 | softplus_output: true 23 | wd: 2.0909467656368513e-08 24 | 25 | data: 26 | datapath: ${paths.data_dir}/prime-data/frangieh21_processed_scgpt.h5ad 27 | embedding_key: scgpt_embeddings 28 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/linear_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: linear_additive 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | inject_covariates: true 14 | lr: 0.003915672085065509 15 | n_genes: null 16 | n_perts: null 17 | wd: 1.4293895399278078e-08 18 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/sams_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 
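# All SAMS-VAE experiment configs in this directory pin full fp32 precision
# (next line); a plausible reading is that mixed precision is unstable for
# the sparse-mask VAE at the very small learning rate used here.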
precision: 32 12 | 13 | model: 14 | dropout: 0.1 15 | hidden_dim_cond: 2560 16 | hidden_dim_x: 4096 17 | inject_covariates_decoder: false 18 | inject_covariates_encoder: false 19 | latent_dim: 24 20 | lr: 5.182988883689651e-07 21 | mask_prior_probability: 0.012572644354323684 22 | n_genes: null 23 | n_layers_decoder: 3 24 | n_layers_encoder_e: 5 25 | n_layers_encoder_x: 1 26 | n_perts: null 27 | softplus_output: true 28 | wd: 7.422084694841772e-10 29 | 30 | data: 31 | add_controls: false 32 | batch_size: 256 33 | evaluation: 34 | chunk_size: 10 35 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/frangieh21/sams_modified_best_params_frangieh21.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: frangieh21 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | disable_e_dist: true 15 | disable_sparsity: true 16 | dropout: 0.0 17 | generative_counterfactual: true 18 | hidden_dim_cond: 329 19 | hidden_dim_x: 2560 20 | inject_covariates_decoder: true 21 | inject_covariates_encoder: true 22 | latent_dim: 128 23 | lr: 5.579821197158771e-06 24 | mask_prior_probability: 0.09647872920959553 25 | n_layers_decoder: 1 26 | n_layers_encoder_e: 4 27 | n_layers_encoder_x: 4 28 | softplus_output: true 29 | wd: 2.0857105211314087e-08 30 | 31 | data: 32 | add_controls: false 33 | batch_size: 256 34 | evaluation: 35 | chunk_size: 10 36 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/cpa_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: cpa 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | adv_classifier_hidden_dim: 537 14 | adv_classifier_n_layers: 1 15 | adv_steps: 7 16 | adv_weight: 1.4870029218495044 17 | dropout: 0.30000000000000004 18 | elementwise_affine: false 19 | hidden_dim: 3328 20 | kl_weight: 6.904783806453286 21 | lr: 2.6302273130616072e-05 22 | n_latent: 512 23 | n_layers_covar_emb: 1 24 | n_layers_encoder: 7 25 | n_layers_pert_emb: 2 26 | n_warmup_epochs: 5 27 | penalty_weight: 2.8152900769395943 28 | softplus_output: false 29 | variational: true 30 | wd: 3.693112285894233e-07 31 | 32 | data: 33 | add_controls: false 34 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/cpa_no_adv_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: cpa 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | use_adversary: false 14 | adv_classifier_hidden_dim: 537 15 | adv_classifier_n_layers: 1 16 | adv_steps: 7 17 | adv_weight: 1.4870029218495044 18 | dropout: 0.30000000000000004 19 | elementwise_affine: false 20 | hidden_dim: 3328 21 | kl_weight: 6.904783806453286 22 | lr: 2.6302273130616072e-05 23 | n_latent: 512 24 | n_layers_covar_emb: 1 25 | n_layers_encoder: 7 26 | n_layers_pert_emb: 2 27 | n_warmup_epochs: 400 28 | 
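# With use_adversary: false, n_warmup_epochs is pinned to max_epochs (400 here),
# presumably so the adversarial phase never starts; the other cpa_no_adv configs
# follow the same pattern (warmup equal to the full training length).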
penalty_weight: 2.8152900769395943 29 | softplus_output: false 30 | variational: true 31 | wd: 3.693112285894233e-07 32 | 33 | data: 34 | add_controls: false 35 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/decoder_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: decoder_only 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | encoder_width: 5376 14 | lr: 9.426394813968772e-05 15 | n_genes: null 16 | n_layers: 3 17 | n_perts: null 18 | softplus_output: true 19 | use_covariates: true 20 | use_perturbations: true 21 | wd: 8.513975008706282e-08 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/latent_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: latent_additive 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | dropout: 0.6000000000000001 14 | encoder_width: 3328 15 | inject_covariates_decoder: true 16 | inject_covariates_encoder: true 17 | latent_dim: 256 18 | lr: 0.004991266977288785 19 | n_genes: null 20 | n_layers: 3 21 | n_perts: null 22 | softplus_output: true 23 | wd: 8.471259185195355e-07 24 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/linear_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: linear_additive 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | 12 | model: 13 | inject_covariates: true 14 | lr: 0.0014212823361302294 15 | n_genes: null 16 | n_perts: null 17 | wd: 1.0079009696147276e-08 18 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/sams_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | dropout: 0.5 15 | generative_counterfactual: true 16 | hidden_dim_cond: 329 17 | hidden_dim_x: 3328 18 | inject_covariates_decoder: true 19 | inject_covariates_encoder: true 20 | latent_dim: 64 21 | lr: 7.594588526179077e-06 22 | mask_prior_probability: 0.20684484580190832 23 | n_layers_decoder: 3 24 | n_layers_encoder_e: 5 25 | n_layers_encoder_x: 3 26 | softplus_output: true 27 | wd: 1.5915315253373636e-08 28 | 29 | data: 30 | add_controls: false 31 | batch_size: 256 32 | evaluation: 33 | chunk_size: 10 34 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/jiang24/sams_modified_best_params_jiang24.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: 
sams_vae 5 | - override /callbacks: default 6 | - override /data: jiang24 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | disable_e_dist: true 15 | disable_sparsity: true 16 | dropout: 0.6000000000000001 17 | generative_counterfactual: true 18 | hidden_dim_cond: 329 19 | hidden_dim_x: 1536 20 | inject_covariates_decoder: true 21 | inject_covariates_encoder: true 22 | latent_dim: 32 23 | lr: 0.0003892340882567657 24 | mask_prior_probability: 0.01017246847284454 25 | n_layers_decoder: 3 26 | n_layers_encoder_e: 2 27 | n_layers_encoder_x: 5 28 | softplus_output: true 29 | wd: 9.99382916747775e-07 30 | 31 | data: 32 | add_controls: false 33 | batch_size: 256 34 | evaluation: 35 | chunk_size: 10 36 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: mcfaline23 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | evaluation: 21 | chunk_size: 5 22 | batch_size: 8000 23 | num_workers: 12 24 | splitter: 25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 26 | 27 | model: 28 | adv_classifier_hidden_dim: 281 29 | adv_classifier_n_layers: 1 30 | adv_steps: 3 31 | adv_weight: 3.3267124584103804 32 | dropout: 0.30000000000000004 33 | elementwise_affine: false 34 | hidden_dim: 4352 35 | kl_weight: 0.4194843985243823 36 | lr: 1.7339648673958316e-05 37 | n_latent: 64 38 | n_layers_covar_emb: 1 39 | n_layers_encoder: 1 40 | n_layers_pert_emb: 2 41 | n_warmup_epochs: 10 42 | penalty_weight: 0.6115268823779908 43 | softplus_output: false 44 | variational: true 45 | wd: 2.346560291976718e-06 46 | 47 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_best_params_mcfaline23_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: mcfaline23 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | evaluation: 21 | chunk_size: 5 22 | batch_size: 8000 23 | num_workers: 12 24 | splitter: 25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 26 | 27 | model: 28 | adv_classifier_hidden_dim: 205 29 | adv_classifier_n_layers: 1 30 | adv_steps: 3 31 | adv_weight: 0.9729207769467227 32 | dropout: 0.2 33 | elementwise_affine: false 34 | hidden_dim: 4352 35 | kl_weight: 8.582254554076528 36 | lr: 6.168904762775334e-05 37 | n_latent: 64 38 | n_layers_covar_emb: 1 39 | n_layers_encoder: 7 40 | n_layers_pert_emb: 2 41 | n_warmup_epochs: 15 42 | penalty_weight: 12.269805096188989 43 | softplus_output: false 44 | variational: true 45 | 
wd: 3.2081170417806746e-07 46 | 47 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: mcfaline23 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | evaluation: 21 | chunk_size: 5 22 | batch_size: 8000 23 | num_workers: 12 24 | splitter: 25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 26 | 27 | model: 28 | adv_classifier_hidden_dim: 446 29 | adv_classifier_n_layers: 3 30 | adv_steps: 3 31 | adv_weight: 6.657453598080865 32 | dropout: 0.1 33 | elementwise_affine: false 34 | hidden_dim: 2304 35 | kl_weight: 1.1791506221938572 36 | lr: 1.9312841187693903e-05 37 | n_latent: 256 38 | n_layers_covar_emb: 1 39 | n_layers_encoder: 7 40 | n_layers_pert_emb: 3 41 | n_warmup_epochs: 10 42 | penalty_weight: 9.392286981965224 43 | softplus_output: false 44 | variational: true 45 | wd: 1.9552773792043903e-08 46 | 47 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_no_adv_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: mcfaline23 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | evaluation: 21 | chunk_size: 5 22 | batch_size: 8000 23 | num_workers: 12 24 | splitter: 25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 26 | 27 | model: 28 | use_adversary: false 29 | adv_classifier_hidden_dim: 281 30 | adv_classifier_n_layers: 1 31 | adv_steps: 3 32 | adv_weight: 3.3267124584103804 33 | dropout: 0.30000000000000004 34 | elementwise_affine: false 35 | hidden_dim: 4352 36 | kl_weight: 0.4194843985243823 37 | lr: 1.7339648673958316e-05 38 | n_latent: 64 39 | n_layers_covar_emb: 1 40 | n_layers_encoder: 1 41 | n_layers_pert_emb: 2 42 | n_warmup_epochs: 500 43 | penalty_weight: 0.6115268823779908 44 | softplus_output: false 45 | variational: true 46 | wd: 2.346560291976718e-06 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_no_adv_best_params_mcfaline23_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: mcfaline23 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 
| trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | evaluation: 21 | chunk_size: 5 22 | batch_size: 8000 23 | num_workers: 12 24 | splitter: 25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 26 | 27 | model: 28 | use_adversary: false 29 | adv_classifier_hidden_dim: 205 30 | adv_classifier_n_layers: 1 31 | adv_steps: 3 32 | adv_weight: 0.9729207769467227 33 | dropout: 0.2 34 | elementwise_affine: false 35 | hidden_dim: 4352 36 | kl_weight: 8.582254554076528 37 | lr: 6.168904762775334e-05 38 | n_latent: 64 39 | n_layers_covar_emb: 1 40 | n_layers_encoder: 7 41 | n_layers_pert_emb: 2 42 | n_warmup_epochs: 15 43 | penalty_weight: 12.269805096188989 44 | softplus_output: false 45 | variational: true 46 | wd: 3.2081170417806746e-07 47 | 48 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_no_adv_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: mcfaline23 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | evaluation: 21 | chunk_size: 5 22 | batch_size: 8000 23 | num_workers: 12 24 | splitter: 25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 26 | 27 | model: 28 | use_adversary: false 29 | adv_classifier_hidden_dim: 446 30 | adv_classifier_n_layers: 3 31 | adv_steps: 3 32 | adv_weight: 6.657453598080865 33 | dropout: 0.1 34 | elementwise_affine: false 35 | hidden_dim: 2304 36 | kl_weight: 1.1791506221938572 37 | lr: 1.9312841187693903e-05 38 | n_latent: 256 39 | n_layers_covar_emb: 1 40 | n_layers_encoder: 7 41 | n_layers_pert_emb: 3 42 | n_warmup_epochs: 10 43 | penalty_weight: 9.392286981965224 44 | softplus_output: false 45 | variational: true 46 | wd: 1.9552773792043903e-08 47 | 48 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/decoder_only_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: decoder_only 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 12 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 19 | 20 | model: 21 | encoder_width: 3328 22 | lr: 0.00019009753022442835 23 | n_layers: 1 24 | softplus_output: true 25 | wd: 5.697560765121058e-07 26 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/decoder_only_best_params_mcfaline23_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: decoder_only 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | 
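# chunk_size appears to bound how many conditions are scored per evaluation
# pass (smaller values trade speed for memory).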
batch_size: 8000 16 | num_workers: 8 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 19 | 20 | model: 21 | encoder_width: 4352 22 | lr: 0.00023414037024915073 23 | n_layers: 1 24 | softplus_output: true 25 | wd: 7.503473517441124e-07 26 | 27 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/decoder_only_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: decoder_only 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 8 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 19 | 20 | model: 21 | encoder_width: 4352 22 | lr: 0.0007134752592105323 23 | n_layers: 1 24 | softplus_output: true 25 | wd: 6.892595385769382e-07 26 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: latent_additive 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 12 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 19 | 20 | model: 21 | dropout: 0.1 22 | encoder_width: 4352 23 | latent_dim: 256 24 | lr: 0.00046952967507921957 25 | n_layers: 1 26 | wd: 3.348258680704949e-08 27 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: latent_additive 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 8 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 19 | 20 | model: 21 | dropout: 0.0 22 | encoder_width: 3328 23 | latent_dim: 256 24 | lr: 2.0752864206129073e-05 25 | n_layers: 1 26 | wd: 9.448743387490416e-08 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: latent_additive 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 8 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 19 | 20 | model: 21 | dropout: 0.0 22 | encoder_width: 1280 23 | latent_dim: 512 24 | lr: 0.00021234604809346303 25 | n_layers: 1 26 | wd: 1.4101074283996682e-07 27 | 
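The experiment files in this directory are meant to be selected through Hydra's experiment group. Below is a hedged sketch of composing one of the configs above and checking that its tuned values took effect; it assumes Hydra is installed, that the script runs from the repo root so the relative config path resolves, and that train.yaml exposes an optional experiment group as in the lightning-hydra layout this repo follows:

from hydra import compose, initialize

with initialize(config_path="src/perturbench/configs", version_base=None):
    cfg = compose(
        config_name="train",
        overrides=[
            "experiment=neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_small",
        ],
    )

# The experiment file overrides the model and data groups, so the merged
# config should reflect the values shown above
print(cfg.model.latent_dim)  # expected: 512
print(cfg.data.batch_size)   # expected: 8000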
-------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/linear_additive_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: linear_additive 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 8 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 19 | 20 | model: 21 | inject_covariates: True 22 | lr: 0.0013589931928117893 23 | wd: 1.0042027312774061e-08 24 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/linear_additive_best_params_mcfaline23_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: linear_additive 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 8 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 19 | 20 | model: 21 | inject_covariates: true 22 | lr: 0.0022999550256692877 23 | wd: 1.0685179638547894e-08 24 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/linear_additive_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: linear_additive 4 | - override /data: mcfaline23 5 | 6 | seed: 245 7 | 8 | trainer: 9 | min_epochs: 5 10 | max_epochs: 500 11 | 12 | data: 13 | evaluation: 14 | chunk_size: 5 15 | batch_size: 8000 16 | num_workers: 12 17 | splitter: 18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 19 | 20 | model: 21 | inject_covariates: true 22 | lr: 0.004950837663306272 23 | wd: 1.0397598707776114e-08 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: mcfaline23 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | dropout: 0.1 15 | generative_counterfactual: true 16 | hidden_dim_cond: 329 17 | hidden_dim_x: 3328 18 | inject_covariates_decoder: true 19 | inject_covariates_encoder: true 20 | latent_dim: 128 21 | lr: 3.6723171184888935e-06 22 | mask_prior_probability: 0.018435146998978795 23 | n_layers_decoder: 3 24 | n_layers_encoder_e: 4 25 | n_layers_encoder_x: 5 26 | softplus_output: true 27 | wd: 2.569943996675991e-07 28 | 29 | data: 30 | splitter: 31 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 32 | add_controls: false 33 | batch_size: 256 34 | evaluation: 35 | chunk_size: 10 36 | -------------------------------------------------------------------------------- 
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_best_params_mcfaline23_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: mcfaline23 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | dropout: 0.30000000000000004 15 | generative_counterfactual: true 16 | hidden_dim_cond: 329 17 | hidden_dim_x: 1536 18 | inject_covariates_decoder: true 19 | inject_covariates_encoder: true 20 | latent_dim: 128 21 | lr: 6.4220933723681095e-06 22 | mask_prior_probability: 0.021735343521054572 23 | n_layers_decoder: 2 24 | n_layers_encoder_e: 3 25 | n_layers_encoder_x: 4 26 | softplus_output: true 27 | wd: 3.2516220842555925e-09 28 | 29 | data: 30 | splitter: 31 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 32 | add_controls: false 33 | batch_size: 256 34 | evaluation: 35 | chunk_size: 10 36 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: mcfaline23 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | dropout: 0.30000000000000004 15 | generative_counterfactual: true 16 | hidden_dim_cond: 329 17 | hidden_dim_x: 2304 18 | inject_covariates_decoder: true 19 | inject_covariates_encoder: true 20 | latent_dim: 24 21 | lr: 6.950438037779035e-06 22 | mask_prior_probability: 0.03505373699588523 23 | n_layers_decoder: 1 24 | n_layers_encoder_e: 2 25 | n_layers_encoder_x: 3 26 | softplus_output: true 27 | wd: 8.162330576054943e-09 28 | 29 | data: 30 | splitter: 31 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 32 | add_controls: false 33 | batch_size: 256 34 | evaluation: 35 | chunk_size: 10 36 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_modified_best_params_mcfaline23_full.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: mcfaline23 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | disable_e_dist: true 15 | disable_sparsity: true 16 | dropout: 0.1 17 | generative_counterfactual: true 18 | hidden_dim_cond: 329 19 | hidden_dim_x: 3584 20 | inject_covariates_decoder: true 21 | inject_covariates_encoder: true 22 | latent_dim: 24 23 | lr: 3.952811264448736e-05 24 | mask_prior_probability: 0.07551256281258875 25 | n_layers_decoder: 3 26 | n_layers_encoder_e: 3 27 | n_layers_encoder_x: 1 28 | softplus_output: true 29 | wd: 6.380767802230304e-06 30 | 31 | data: 32 | splitter: 33 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv 34 | add_controls: false 35 | batch_size: 256 36 | evaluation: 37 | chunk_size: 10 38 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_modified_best_params_mcfaline23_medium.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: mcfaline23 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | disable_e_dist: true 15 | disable_sparsity: true 16 | dropout: 0.30000000000000004 17 | generative_counterfactual: true 18 | hidden_dim_cond: 329 19 | hidden_dim_x: 3584 20 | inject_covariates_decoder: true 21 | inject_covariates_encoder: true 22 | latent_dim: 128 23 | lr: 2.97852617667106e-05 24 | mask_prior_probability: 0.012905035283331237 25 | n_layers_decoder: 2 26 | n_layers_encoder_e: 1 27 | n_layers_encoder_x: 1 28 | softplus_output: true 29 | wd: 1.0248794045174245e-05 30 | 31 | data: 32 | splitter: 33 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv 34 | add_controls: false 35 | batch_size: 256 36 | evaluation: 37 | chunk_size: 10 38 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_modified_best_params_mcfaline23_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: sams_vae 5 | - override /callbacks: default 6 | - override /data: mcfaline23 7 | 8 | trainer: 9 | max_epochs: 400 10 | min_epochs: 5 11 | precision: 32 12 | 13 | model: 14 | disable_e_dist: true 15 | disable_sparsity: true 16 | dropout: 0.30000000000000004 17 | generative_counterfactual: true 18 | hidden_dim_cond: 329 19 | hidden_dim_x: 3328 20 | inject_covariates_decoder: true 21 | inject_covariates_encoder: true 22 | latent_dim: 32 23 | lr: 3.4324600190407314e-05 24 | mask_prior_probability: 0.05761980928292434 25 | n_layers_decoder: 1 26 | n_layers_encoder_e: 4 27 | n_layers_encoder_x: 1 28 | softplus_output: true 29 | wd: 1.1815419587902247e-08 30 | 31 | data: 32 | splitter: 33 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv 34 | add_controls: false 35 | batch_size: 256 36 | evaluation: 37 | chunk_size: 10 38 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/biolord_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: biolord 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | batch_size: 1000 21 | splitter: 22 | max_heldout_fraction_per_covariate: 0.7 23 | add_controls: False 24 | 25 | model: 26 | dropout: 0.4 27 | encoder_width: 2304 28 | latent_dim: 512 29 | lr: 0.00016701245023478605 30 | n_layers: 1 31 | penalty_weight: 2621.3333751927075 32 | wd: 0.00046077468143989676 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/cpa_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py 
experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | add_controls: False 23 | 24 | model: 25 | adv_classifier_hidden_dim: 776 26 | adv_classifier_n_layers: 2 27 | adv_steps: 2 28 | adv_weight: 3.804798319367216 29 | dropout: 0.2 30 | elementwise_affine: false 31 | hidden_dim: 3328 32 | kl_weight: 19.756374864509844 33 | lr: 6.844044407644798e-05 34 | n_latent: 128 35 | n_layers_covar_emb: 1 36 | n_layers_encoder: 1 37 | n_layers_pert_emb: 2 38 | n_warmup_epochs: 15 39 | penalty_weight: 1.8984251429709187 40 | softplus_output: false 41 | variational: true 42 | wd: 1.6600312433505752e-07 43 | 44 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/cpa_no_adv_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | add_controls: False 23 | 24 | model: 25 | use_adversary: False 26 | adv_classifier_hidden_dim: 776 27 | adv_classifier_n_layers: 2 28 | adv_steps: 2 29 | adv_weight: 3.804798319367216 30 | dropout: 0.2 31 | elementwise_affine: false 32 | hidden_dim: 3328 33 | kl_weight: 19.756374864509844 34 | lr: 6.844044407644798e-05 35 | n_latent: 128 36 | n_layers_covar_emb: 1 37 | n_layers_encoder: 1 38 | n_layers_pert_emb: 2 39 | n_warmup_epochs: 500 40 | penalty_weight: 1.8984251429709187 41 | softplus_output: false 42 | variational: true 43 | wd: 1.6600312433505752e-07 44 | 45 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/cpa_scgpt_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 400 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | datapath: ${paths.data_dir}/norman19_preprocessed_with_embeddings.h5ad 23 | embedding_key: scgpt_embbeddings 24 | add_controls: False 25 | 26 | model: 27 | adv_classifier_hidden_dim: 837 28 | adv_classifier_n_layers: 2 29 | adv_steps: 3 30 | adv_weight: 0.12619384778246517 31 | dropout: 0.6000000000000001 32 | elementwise_affine: false 33 | hidden_dim: 5376 34 | kl_weight: 11.931005035532673 35 | lr: 0.0005134219006570818 36 | n_latent: 256 37 | n_layers_covar_emb: 1 38 | 
n_layers_encoder: 1 39 | n_layers_pert_emb: 1 40 | n_warmup_epochs: 10 41 | penalty_weight: 0.9962990820647598 42 | softplus_output: false 43 | variational: true 44 | wd: 6.849486237372675e-05 45 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/decoder_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: decoder_only 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | 23 | model: 24 | use_covariates: False 25 | encoder_width: 3328 26 | lr: 0.00013753391233021738 27 | n_layers: 3 28 | softplus_output: false 29 | wd: 4.417109615721373e-05 30 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/latent_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: latent_additive 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | 23 | model: 24 | inject_covariates_encoder: False 25 | inject_covariates_decoder: False 26 | dropout: 0.1 27 | encoder_width: 5376 28 | latent_dim: 64 29 | lr: 6.779597503293815e-05 30 | n_layers: 1 31 | wd: 1.0406767176550133e-08 32 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/latent_scgpt_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: latent_additive 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | datapath: ${paths.data_dir}/norman19_preprocessed_with_embeddings.h5ad 23 | embedding_key: scgpt_embbeddings 24 | 25 | model: 26 | inject_covariates_encoder: False 27 | inject_covariates_decoder: False 28 | dropout: 0.4 29 | encoder_width: 5376 30 | latent_dim: 512 31 | lr: 8.965448576753094e-05 32 | n_layers: 1 33 | wd: 1.9716752225147476e-05 34 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/linear_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package 
_global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: linear_additive 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | splitter: 21 | max_heldout_fraction_per_covariate: 0.7 22 | 23 | model: 24 | inject_covariates: false 25 | lr: 0.004716813309487752 26 | wd: 1.7588044643755207e-08 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/sams_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: sams_vae 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 137 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | precision: 32 19 | 20 | data: 21 | batch_size: 256 22 | splitter: 23 | max_heldout_fraction_per_covariate: 0.7 24 | add_controls: False 25 | evaluation: 26 | chunk_size: 10 27 | 28 | model: 29 | dropout: 0.4 30 | hidden_dim_cond: 1280 31 | hidden_dim_x: 3584 32 | inject_covariates_decoder: false 33 | inject_covariates_encoder: false 34 | latent_dim: 128 35 | lr: 3.0263552583537424e-05 36 | mask_prior_probability: 0.11384202578456981 37 | n_genes: null 38 | n_layers_decoder: 2 39 | n_layers_encoder_e: 3 40 | n_layers_encoder_x: 1 41 | n_perts: null 42 | softplus_output: true 43 | wd: 4.244962253886731e-09 44 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/norman19/sams_modified_best_params_norman19.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: sams_vae 8 | - override /data: norman19 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 137 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | precision: 32 19 | 20 | data: 21 | batch_size: 256 22 | splitter: 23 | max_heldout_fraction_per_covariate: 0.7 24 | add_controls: False 25 | evaluation: 26 | chunk_size: 10 27 | 28 | model: 29 | disable_e_dist: true 30 | disable_sparsity: true 31 | dropout: 0.30000000000000004 32 | generative_counterfactual: true 33 | hidden_dim_cond: 329 34 | hidden_dim_x: 2048 35 | inject_covariates_decoder: false 36 | inject_covariates_encoder: false 37 | latent_dim: 32 38 | lr: 1.5981872317838004e-05 39 | mask_prior_probability: 0.018746135628343086 40 | n_layers_decoder: 3 41 | n_layers_encoder_e: 1 42 | n_layers_encoder_x: 1 43 | softplus_output: true 44 | wd: 6.55647585152116e-05 45 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/biolord_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # 
@package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: biolord 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | 14 | seed: 245 15 | 16 | trainer: 17 | min_epochs: 5 18 | max_epochs: 500 19 | precision: 16 20 | 21 | data: 22 | batch_size: 1000 23 | add_controls: False 24 | 25 | model: 26 | dropout: 0.4 27 | encoder_width: 2304 28 | latent_dim: 512 29 | lr: 0.00016701245023478605 30 | wd: 0.00046077468143989676 31 | n_layers: 1 32 | penalty_weight: 2621.3333751927075 33 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/cpa_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | add_controls: False 21 | 22 | model: 23 | adv_classifier_hidden_dim: 383 24 | adv_classifier_n_layers: 2 25 | adv_steps: 3 26 | adv_weight: 0.21007098128462928 27 | dropout: 0.1 28 | elementwise_affine: false 29 | hidden_dim: 1280 30 | kl_weight: 0.2643107960544633 31 | lr: 0.0002688668829765555 32 | n_latent: 512 33 | n_layers_covar_emb: 1 34 | n_layers_encoder: 3 35 | n_layers_pert_emb: 2 36 | n_warmup_epochs: 5 37 | penalty_weight: 10.2126759375119 38 | softplus_output: false 39 | variational: true 40 | wd: 2.992280181677909e-06 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/cpa_no_adv_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | add_controls: False 21 | 22 | model: 23 | use_adversary: False 24 | adv_classifier_hidden_dim: 383 25 | adv_classifier_n_layers: 2 26 | adv_steps: 3 27 | adv_weight: 0.21007098128462928 28 | decoder_distribution: IsotropicGaussian 29 | dropout: 0.1 30 | elementwise_affine: false 31 | hidden_dim: 1280 32 | kl_weight: 0.2643107960544633 33 | lr: 0.0002688668829765555 34 | n_latent: 512 35 | n_layers_covar_emb: 1 36 | n_layers_encoder: 3 37 | n_layers_pert_emb: 2 38 | n_warmup_epochs: 500 39 | penalty_weight: 10.2126759375119 40 | softplus_output: false 41 | variational: true 42 | wd: 2.992280181677909e-06 43 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/cpa_scgpt_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 
3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: cpa 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | datapath: ${paths.data_dir}/srivatsan20_highest_dose_preprocessed_with_embeddings.h5ad 21 | embedding_key: scgpt_embbeddings 22 | add_controls: False 23 | 24 | model: 25 | adv_classifier_hidden_dim: 681 26 | adv_classifier_n_layers: 1 27 | adv_steps: 2 28 | adv_weight: 1.9003113280995905 29 | dropout: 0.1 30 | elementwise_affine: false 31 | hidden_dim: 5376 32 | kl_weight: 1.095403737862168 33 | lr: 0.0001711130140991506 34 | n_latent: 512 35 | n_layers_covar_emb: 1 36 | n_layers_encoder: 1 37 | n_layers_pert_emb: 4 38 | n_warmup_epochs: 10 39 | penalty_weight: 19.75923566901279 40 | softplus_output: false 41 | variational: true 42 | wd: 0.0002629416918365637 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/decoder_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: decoder_only 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | model: 20 | encoder_width: 4352 21 | lr: 6.798204826057849e-05 22 | n_layers: 5 23 | softplus_output: true 24 | wd: 7.601981738045318e-06 25 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/decoder_cov_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: decoder_only 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | model: 20 | use_perturbations: False 21 | encoder_width: 4352 22 | lr: 5.330813318795293e-06 23 | wd: 1.0553034125634522e-08 24 | n_layers: 5 25 | softplus_output: true 26 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/latent_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: latent_additive 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | model: 20 | dropout: 0.6000000000000001 21 | 
encoder_width: 3328 22 | latent_dim: 64 23 | lr: 0.0013195915995113784 24 | n_layers: 1 25 | wd: 9.371223651937301e-07 -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/latent_scgpt_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: latent_additive 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | data: 20 | datapath: ${paths.data_dir}/srivatsan20_highest_dose_preprocessed_with_embeddings.h5ad 21 | embedding_key: scgpt_embbeddings 22 | 23 | model: 24 | lr_scheduler_freq: 5 25 | dropout: 0.1 26 | encoder_width: 2304 27 | latent_dim: 128 28 | lr: 1.2051920391885433e-05 29 | n_layers: 3 30 | wd: 2.701962713462281e-07 31 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/linear_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: linear_additive 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 245 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | 19 | model: 20 | lr: 0.0007787610860511324 21 | wd: 1.0632418512521614e-08 22 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/sams_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /model: sams_vae 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 137 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | precision: 32 19 | 20 | data: 21 | batch_size: 256 22 | add_controls: False 23 | evaluation: 24 | chunk_size: 10 25 | 26 | model: 27 | dropout: 0.30000000000000004 28 | hidden_dim_cond: 2816 29 | hidden_dim_x: 1536 30 | inject_covariates_decoder: false 31 | inject_covariates_encoder: false 32 | latent_dim: 64 33 | lr: 0.00040019008240191795 34 | mask_prior_probability: 0.011486688007756388 35 | n_genes: null 36 | n_layers_decoder: 1 37 | n_layers_encoder_e: 2 38 | n_layers_encoder_x: 4 39 | n_perts: null 40 | softplus_output: true 41 | wd: 1.749495655302928e-06 42 | -------------------------------------------------------------------------------- /src/perturbench/configs/experiment/neurips2024/sciplex3/sams_modified_best_params_sciplex3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 
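# For this particular file the full override would be (an assumption inferred from
# the config tree): python train.py experiment=neurips2024/sciplex3/sams_modified_best_params_sciplex3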
5 | 6 | defaults: 7 | - override /model: sams_vae 8 | - override /data: sciplex3 9 | 10 | # all parameters below will be merged with parameters from default configurations set above 11 | # this allows you to overwrite only specified parameters 12 | 13 | seed: 137 14 | 15 | trainer: 16 | min_epochs: 5 17 | max_epochs: 500 18 | precision: 32 19 | 20 | data: 21 | batch_size: 256 22 | add_controls: False 23 | evaluation: 24 | chunk_size: 10 25 | 26 | model: 27 | disable_e_dist: true 28 | disable_sparsity: true 29 | dropout: 0.4 30 | generative_counterfactual: true 31 | hidden_dim_cond: 329 32 | hidden_dim_x: 2560 33 | inject_covariates_decoder: true 34 | inject_covariates_encoder: true 35 | latent_dim: 32 36 | lr: 0.000341765289234546 37 | mask_prior_probability: 0.05618183575503891 38 | n_layers_decoder: 3 39 | n_layers_encoder_e: 2 40 | n_layers_encoder_x: 4 41 | softplus_output: true 42 | wd: 6.943657713062555e-08 43 | -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/biolord_hpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results. 20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | params: 24 | model.penalty_weight: tag(log, interval(1e1, 1e5)) 25 | model.n_layers: range(1, 7, step=2) 26 | model.encoder_width: range(256, 5376, step=1024) 27 | model.latent_dim: choice(64, 128, 192, 256, 512) 28 | model.lr: tag(log, interval(5e-6, 5e-3)) 29 | model.wd: tag(log, interval(1e-8, 1e-3)) 30 | model.dropout: range(0.0, 0.8, step=0.1) 31 | launcher: 32 | ray: 33 | init: 34 | ## for local runs 35 | num_gpus: 2 # number of total gpus to use 36 | remote: 37 | num_gpus: 1 # number of gpus per trial 38 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/cpa_hpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results.
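# Illustrative only (standard Optuna storage URL format, not taken from this repo):
# point this at a database so an interrupted sweep can resume under the same study_name, e.g.
#   storage: sqlite:///cpa_hpo_study.db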
20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | params: 24 | model.n_layers_encoder: range(1, 7, step=2) 25 | model.n_layers_pert_emb: range(1, 5, step=1) 26 | model.adv_classifier_n_layers: range(1, 5, step=1) 27 | model.hidden_dim: range(256, 5376, step=1024) 28 | model.adv_classifier_hidden_dim: tag(log, int(interval(128, 1024))) # log scale 29 | model.adv_steps: choice(2, 3, 5, 7, 10, 20, 30) 30 | model.n_latent: choice(64, 128, 192, 256, 512) 31 | model.lr: tag(log, interval(5e-6, 1e-3)) 32 | model.wd: tag(log, interval(1e-8, 1e-3)) 33 | model.dropout: range(0.0, 0.8, step=0.1) 34 | model.kl_weight: tag(log, interval(0.1, 20)) 35 | model.adv_weight: tag(log, interval(0.1, 20)) 36 | model.penalty_weight: tag(log, interval(0.1, 20)) 37 | launcher: 38 | ray: 39 | init: 40 | ## for local runs 41 | num_gpus: 2 # number of total gpus to use 42 | remote: 43 | num_gpus: 1 # number of gpus per trial 44 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/decoder_only_hpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results. 20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | params: 24 | model.n_layers: range(1, 7, step=2) 25 | model.encoder_width: range(256, 5376, step=1024) 26 | model.lr: tag(log, interval(5e-6, 5e-3)) 27 | model.wd: tag(log, interval(1e-8, 1e-3)) 28 | model.softplus_output: choice(true, false) 29 | launcher: 30 | ray: 31 | init: 32 | ## for local runs 33 | num_gpus: 2 # number of total gpus to use 34 | remote: 35 | num_gpus: 1 # number of gpus per trial 36 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/latent_additive_hpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results. 
20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | params: 24 | model.n_layers: range(1, 7, step=2) 25 | model.encoder_width: range(256, 5376, step=1024) 26 | model.latent_dim: choice(64, 128, 192, 256, 512) 27 | model.lr: tag(log, interval(5e-6, 5e-3)) 28 | model.wd: tag(log, interval(1e-8, 1e-3)) 29 | model.dropout: range(0.0, 0.8, step=0.1) 30 | launcher: 31 | ray: 32 | init: 33 | ## for local runs 34 | num_gpus: 2 # number of total gpus to use 35 | remote: 36 | num_gpus: 1 # number of gpus per trial 37 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/linear_additive_hpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results. 20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | params: 24 | model.lr: tag(log, interval(5e-6, 5e-3)) 25 | model.wd: tag(log, interval(1e-8, 1e-3)) 26 | launcher: 27 | ray: 28 | init: 29 | ## for local runs 30 | num_gpus: 2 # number of total gpus to use 31 | remote: 32 | num_gpus: 1 # number of gpus per trial 33 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results. 20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | launcher: 24 | ray: 25 | init: 26 | ## for local runs 27 | num_gpus: 2 # number of total gpus to use 28 | remote: 29 | num_gpus: 1 # number of gpus per trial 30 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hpo/sams_vae_hpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | #for HPO runs. 
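# Launch sketch (assumed invocation, mirroring the run comments in the experiment
# configs above): python train.py model=sams_vae data=norman19 hpo=sams_vae_hpo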
4 | defaults: 5 | - override /hydra/sweeper: optuna 6 | - override /hydra/launcher: ray 7 | - override /hydra/sweeper/sampler: tpe 8 | 9 | ## Specify multiple metrics and how they should be added together 10 | metrics_to_optimize: 11 | rmse_average: 1.0 12 | rmse_rank_average: 0.1 13 | 14 | hydra: 15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 16 | sweeper: 17 | direction: minimize 18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed 19 | storage: null # optuna database to store HPO results. 20 | n_trials: 4 # number of total trials 21 | n_jobs: 2 # number of parallel jobs 22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail 23 | params: 24 | model.n_layers_encoder_x: range(1, 7, step=2) 25 | model.n_layers_encoder_e: range(1, 7, step=2) 26 | model.n_layers_decoder: range(1, 7, step=2) 27 | model.latent_dim: choice(64, 128, 192, 256, 512) 28 | model.hidden_dim_x: range(256, 5376, step=1024) 29 | model.hidden_dim_cond: range(50, 500, step=50) 30 | model.sparse_additive_mechanism: choice(true, false) 31 | model.mean_field_encoding: choice(true, false) 32 | model.lr: tag(log, interval(5e-6, 5e-3)) 33 | model.wd: tag(log, interval(1e-8, 1e-3)) 34 | model.mask_prior_probability: tag(log, interval(1e-4, 0.99)) 35 | model.dropout: range(0.0, 0.8, step=0.1) 36 | launcher: 37 | ray: 38 | init: 39 | ## for local runs 40 | num_gpus: 2 # number of total gpus to use 41 | remote: 42 | num_gpus: 1 # number of gpus per trial 43 | max_calls: 1 -------------------------------------------------------------------------------- /src/perturbench/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${task_name}.log 20 | -------------------------------------------------------------------------------- /src/perturbench/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" -------------------------------------------------------------------------------- /src/perturbench/configs/logger/default.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | - tensorboard 5 | - csv -------------------------------------------------------------------------------- /src/perturbench/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # TensorBoard logger built in lightning 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.TensorBoardLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "tensorboard/" 7 | prefix: "" 
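# Selection sketch (assumption about intended usage): pick a logger group on the command line, e.g.
#   python train.py logger=csv          # CSV only
#   python train.py logger=tensorboard  # TensorBoard only; logger=default combines both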
-------------------------------------------------------------------------------- /src/perturbench/configs/model/biolord.yaml: -------------------------------------------------------------------------------- 1 | _target_: perturbench.modelcore.models.BiolordStar 2 | n_genes: null 3 | n_perts: null 4 | 5 | ## Optimizer params 6 | lr: 6.72106935566634e-05 7 | wd: 7.90200208844498e-09 8 | penalty_weight: 1.0 9 | 10 | ## Model params 11 | n_layers: 3 12 | encoder_width: 3072 13 | latent_dim: 32 14 | softplus_output: True 15 | dropout: 0.1 -------------------------------------------------------------------------------- /src/perturbench/configs/model/cpa.yaml: -------------------------------------------------------------------------------- 1 | _target_: perturbench.modelcore.models.CPA 2 | n_genes: null 3 | n_perts: null 4 | 5 | n_latent: 128 6 | hidden_dim: 256 7 | n_layers_encoder: 3 8 | n_layers_pert_emb: 2 9 | n_layers_covar_emb: 1 10 | 11 | variational: True 12 | 13 | adv_classifier_n_layers: 2 14 | adv_classifier_hidden_dim: 128 15 | 16 | lr: 1e-5 17 | wd: 1e-8 18 | 19 | dropout: 1e-1 20 | kl_weight: 1.0 21 | adv_weight: 1.0 22 | penalty_weight: 10.0 23 | adv_steps: 2 24 | n_warmup_epochs: 5 25 | 26 | softplus_output: False 27 | elementwise_affine: False -------------------------------------------------------------------------------- /src/perturbench/configs/model/decoder_only.yaml: -------------------------------------------------------------------------------- 1 | _target_: perturbench.modelcore.models.DecoderOnly 2 | n_genes: null 3 | n_perts: null 4 | 5 | ## Optimizer params 6 | lr: 6.72106935566634e-05 7 | wd: 7.90200208844498e-09 8 | 9 | ## Model params 10 | n_layers: 3 11 | encoder_width: 3072 12 | softplus_output: True 13 | use_covariates: True 14 | use_perturbations: True -------------------------------------------------------------------------------- /src/perturbench/configs/model/latent_additive.yaml: -------------------------------------------------------------------------------- 1 | _target_: perturbench.modelcore.models.LatentAdditive 2 | n_genes: null 3 | n_perts: null 4 | 5 | ## Optimizer params 6 | lr: 6.72106935566634e-05 7 | wd: 7.90200208844498e-09 8 | 9 | ## Model params 10 | n_layers: 3 11 | encoder_width: 3072 12 | latent_dim: 160 13 | softplus_output: True 14 | inject_covariates_encoder: True 15 | inject_covariates_decoder: True 16 | dropout: 0.1 -------------------------------------------------------------------------------- /src/perturbench/configs/model/linear_additive.yaml: -------------------------------------------------------------------------------- 1 | _target_: perturbench.modelcore.models.LinearAdditive 2 | n_genes: null 3 | n_perts: null 4 | lr: 1e-3 5 | wd: 1e-4 6 | inject_covariates: True -------------------------------------------------------------------------------- /src/perturbench/configs/model/sams_vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: perturbench.modelcore.models.SparseAdditiveVAE 2 | n_genes: null 3 | n_perts: null 4 | 5 | ## Model params 6 | n_layers_encoder_x: 5 7 | n_layers_encoder_e: 3 8 | n_layers_decoder: 4 9 | hidden_dim_x: 2332 10 | hidden_dim_cond: 329 11 | latent_dim: 55 12 | dropout: 0.2 13 | inject_covariates_encoder: False 14 | inject_covariates_decoder: False 15 | mask_prior_probability: 0.01 16 | softplus_output: True 17 | disable_sparsity: false 18 | disable_e_dist: false 19 | generative_counterfactual: false 20 | 21 | ## Optimizer params 22 | lr: 
8.184599223121104e-05 23 | wd: 2.4153279969771546e-05 -------------------------------------------------------------------------------- /src/perturbench/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to data directory 2 | data_dir: ./notebooks/neurips2024/perturbench_data/ 3 | 4 | # path to logging directory 5 | log_dir: ./logs/ 6 | 7 | # path to output directory, created dynamically by hydra 8 | # path generation pattern is specified in `configs/hydra/default.yaml` 9 | # use it to store all files generated during the run, like ckpts and metrics 10 | output_dir: ${hydra:runtime.output_dir} 11 | 12 | # path to working directory 13 | work_dir: ${hydra:runtime.cwd} 14 | -------------------------------------------------------------------------------- /src/perturbench/configs/predict.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ??? ## Specify data config 8 | - model: ??? ## Specify model config 9 | - trainer: default 10 | - paths: default 11 | - hydra: default 12 | 13 | # experiment configs allow for version control of specific hyperparameters 14 | # e.g. best hyperparameters for given model and datamodule 15 | - experiment: null 16 | 17 | # Provide checkpoint path to pretrained model 18 | ckpt_path: ??? 19 | 20 | # task name, determines output directory path 21 | task_name: "predict" 22 | 23 | # Path to prediction dataframe (generate with notebooks/demos/generate_prediction_dataframe) 24 | prediction_dataframe_path: ??? 25 | 26 | # seed for random number generators in pytorch, numpy and python.random 27 | seed: null 28 | 29 | # Path to save predictions 30 | output_path: "${paths.output_dir}/predictions/" 31 | 32 | # Number of perturbations to generate in memory at once 33 | chunk_size: 50 -------------------------------------------------------------------------------- /src/perturbench/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ??? 8 | - model: linear_additive 9 | - callbacks: default 10 | - logger: default 11 | - trainer: default 12 | - paths: default 13 | - hydra: default 14 | 15 | # experiment configs allow for version control of specific hyperparameters 16 | # e.g. 
best hyperparameters for given model and datamodule 17 | - experiment: null 18 | 19 | # config for hyperparameter optimization 20 | - hpo: null 21 | 22 | # task name, determines output directory path 23 | task_name: "train" 24 | 25 | # set False to skip model training 26 | train: True 27 | 28 | # evaluate on test set, using best model weights achieved during training 29 | # lightning chooses best weights based on the metric specified in checkpoint callback 30 | test: True 31 | 32 | # simply provide checkpoint path to resume training 33 | ckpt_path: null 34 | 35 | # seed for random number generators in pytorch, numpy and python.random 36 | seed: null 37 | -------------------------------------------------------------------------------- /src/perturbench/configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /src/perturbench/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: gpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | 21 | # log every N steps 22 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /src/perturbench/configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /src/perturbench/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/__init__.py -------------------------------------------------------------------------------- /src/perturbench/data/accessors/base.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import os 3 | from abc import abstractmethod 4 | import subprocess as sp 5 | 6 | from perturbench.data.datasets import ( 7 | Counterfactual, 8 | CounterfactualWithReference, 9 | SingleCellPerturbation, 10 | SingleCellPerturbationWithControls, 11 | ) 12 | from perturbench.data.transforms.pipelines import SingleCellPipeline 13 | 14 | 15 | def download_scperturb_adata(data_url, data_cache_dir, filename): 16 | """ 17 | Helper function to download and cache anndata files. Returns an in-memory 18 | anndata object as-is with no curation.
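
    Args:
        data_url: URL of the remote scPerturb .h5ad file.
        data_cache_dir: Local directory used to cache the download.
        filename: File name for the cached copy inside data_cache_dir.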
19 | """ 20 | if not os.path.exists(data_cache_dir): 21 | os.makedirs(data_cache_dir) 22 | 23 | tmp_data_path = f"{data_cache_dir}/{filename}" 24 | 25 | if not os.path.exists(tmp_data_path): 26 | sp.call(f"wget {data_url} -O {tmp_data_path}", shell=True) 27 | 28 | adata = sc.read_h5ad(tmp_data_path) 29 | return adata 30 | 31 | 32 | def download_file(url: str, output_dir: str, output_filename: str) -> str: 33 | """Downloads a file from a URL using wget and returns its local path. 34 | 35 | Args: 36 | url (str): The URL of the file to download 37 | output_dir (str): The local directory where the file should be saved 38 | output_filename (str): The file name to save as (".gz" downloads are decompressed) 39 | """ 40 | if not os.path.exists(output_dir): 41 | os.makedirs(output_dir) 42 | output_path = f"{output_dir}/{output_filename}" 43 | sp.call(f"wget {url} -O {output_path}", shell=True) 44 | 45 | if ".gz" in output_filename: 46 | sp.call(f"gzip -d {output_path}", shell=True) 47 | 48 | return output_path.replace(".gz", "") 49 | 50 | 51 | class Accessor: 52 | data_cache_dir: str 53 | dataset_hf_url: str 54 | dataset_orig_url: str 55 | dataset_name: str 56 | processed_data_path: str 57 | 58 | def __init__( 59 | self, 60 | dataset_hf_url, 61 | dataset_orig_url, 62 | dataset_name, 63 | data_cache_dir="../perturbench_data", 64 | ): 65 | self.dataset_hf_url = dataset_hf_url 66 | self.dataset_orig_url = dataset_orig_url 67 | self.dataset_name = dataset_name 68 | self.data_cache_dir = data_cache_dir 69 | 70 | def get_dataset( 71 | self, 72 | dataset_class=SingleCellPerturbation, 73 | add_default_transforms=True, 74 | **dataset_kwargs, 75 | ): 76 | if dataset_class not in [ 77 | SingleCellPerturbation, 78 | SingleCellPerturbationWithControls, 79 | Counterfactual, 80 | CounterfactualWithReference, 81 | ]: 82 | raise ValueError("Invalid dataset class.") 83 | 84 | ## Load the curated AnnData for this dataset 85 | adata = self.get_anndata() 86 | 87 | if "perturbation_key" not in dataset_kwargs: 88 | dataset_kwargs["perturbation_key"] = "condition" 89 | if "covariate_keys" not in dataset_kwargs: 90 | dataset_kwargs["covariate_keys"] = ["cell_type"] 91 | if "perturbation_control_value" not in dataset_kwargs: 92 | dataset_kwargs["perturbation_control_value"] = "control" 93 | 94 | dataset, context = dataset_class.from_anndata( 95 | adata=adata, 96 | **dataset_kwargs, 97 | ) 98 | 99 | if add_default_transforms: 100 | dataset.transform = SingleCellPipeline( 101 | perturbation_uniques=context["perturbation_uniques"], 102 | covariate_uniques=context["covariate_uniques"], 103 | ) 104 | 105 | return dataset, context 106 | 107 | @abstractmethod 108 | def get_anndata(self): 109 | pass 110 | -------------------------------------------------------------------------------- /src/perturbench/data/accessors/frangieh21.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import os 3 | from scipy.sparse import csr_matrix 4 | 5 | from perturbench.analysis.preprocess import preprocess 6 | from perturbench.data.accessors.base import ( 7 | download_scperturb_adata, 8 | download_file, 9 | Accessor, 10 | ) 11 | 12 | 13 | class Frangieh21(Accessor): 14 | def __init__(self, data_cache_dir="../perturbench_data"): 15 | super().__init__( 16 | data_cache_dir=data_cache_dir, 17 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/resolve/main/frangieh19_preprocessed.h5ad.gz", 18 | dataset_orig_url="https://zenodo.org/records/7041849/files/FrangiehIzar2021_RNA.h5ad?download=1", 19 |
dataset_name="frangieh21", 20 | ) 21 | 22 | def get_anndata(self): 23 | """ 24 | Downloads, curates, and preprocesses the Frangieh21 dataset from Hugging 25 | Face or the scPerturb database. Saves the preprocessed data to disk and 26 | returns it in-memory. 27 | 28 | Returns: 29 | adata (anndata.AnnData): Anndata object containing the processed data. 30 | 31 | """ 32 | self.processed_data_path = ( 33 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad" 34 | ) 35 | if os.path.exists(self.processed_data_path): 36 | print("Loading processed data from:", self.processed_data_path) 37 | adata = sc.read_h5ad(self.processed_data_path) 38 | 39 | else: 40 | try: 41 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz" 42 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename) 43 | adata = sc.read_h5ad(self.processed_data_path) 44 | 45 | except Exception as e: 46 | print(f"Error downloading file from {self.dataset_hf_url}: {e}") 47 | print(f"Downloading file from {self.dataset_orig_url}") 48 | 49 | adata = download_scperturb_adata( 50 | self.dataset_orig_url, 51 | self.data_cache_dir, 52 | filename=f"{self.dataset_name}_downloaded.h5ad", 53 | ) 54 | 55 | ## Format column names 56 | treatment_map = { 57 | "Co-culture": "co-culture", 58 | "Control": "none", 59 | } 60 | adata.obs["treatment"] = [ 61 | treatment_map[x] if x in treatment_map else x 62 | for x in adata.obs.perturbation_2 63 | ] 64 | adata.obs["cell_type"] = "melanocyte" 65 | adata.obs["condition"] = adata.obs.perturbation.copy() 66 | adata.obs["perturbation_type"] = "CRISPRi" 67 | adata.obs["dataset"] = "frangieh21" 68 | 69 | adata.X = csr_matrix(adata.X) 70 | adata = preprocess( 71 | adata, 72 | perturbation_key="condition", 73 | covariate_keys=["treatment"], 74 | ) 75 | 76 | adata = adata.copy() 77 | adata.write_h5ad(self.processed_data_path) 78 | 79 | print("Saved processed data to:", self.processed_data_path) 80 | 81 | return adata 82 | -------------------------------------------------------------------------------- /src/perturbench/data/accessors/jiang24.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import os 3 | 4 | from perturbench.data.accessors.base import ( 5 | download_file, 6 | Accessor, 7 | ) 8 | 9 | 10 | class Jiang24(Accessor): 11 | def __init__(self, data_cache_dir="../perturbench_data"): 12 | super().__init__( 13 | data_cache_dir=data_cache_dir, 14 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/blob/main/jiang24_preprocessed.h5ad.gz", 15 | dataset_orig_url=None, 16 | dataset_name="jiang24", 17 | ) 18 | 19 | def get_anndata(self): 20 | """ 21 | Downloads, curates, and preprocesses the Jiang24 dataset from Hugging Face. 22 | Saves the preprocessed data to disk and returns it in-memory. 23 | 24 | Returns: 25 | adata (anndata.AnnData): Anndata object containing the processed data. 
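
        Raises:
            ValueError: if the Hugging Face download fails, since automatic
                curation is not available for this dataset (see the
                data_curation notebooks).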
26 | 27 | """ 28 | self.processed_data_path = ( 29 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad" 30 | ) 31 | if os.path.exists(self.processed_data_path): 32 | print("Loading processed data from:", self.processed_data_path) 33 | adata = sc.read_h5ad(self.processed_data_path) 34 | 35 | else: 36 | try: 37 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz" 38 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename) 39 | adata = sc.read_h5ad(self.processed_data_path) 40 | 41 | except Exception as e: 42 | print(f"Error downloading file from {self.dataset_hf_url}: {e}") 43 | raise ValueError( 44 | "Automatic data curation not available for this dataset. \ 45 | Use the notebooks in notebooks/neurips2025/data_curation \ 46 | to download and preprocess the data." 47 | ) 48 | 49 | print("Saved processed data to:", self.processed_data_path) 50 | 51 | return adata 52 | -------------------------------------------------------------------------------- /src/perturbench/data/accessors/mcfaline23.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import os 3 | 4 | from perturbench.data.accessors.base import ( 5 | download_file, 6 | Accessor, 7 | ) 8 | 9 | 10 | class McFaline23(Accessor): 11 | def __init__(self, data_cache_dir="../perturbench_data"): 12 | super().__init__( 13 | data_cache_dir=data_cache_dir, 14 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/blob/main/mcfaline23_gxe_preprocessed.h5ad.gz", 15 | dataset_orig_url=None, 16 | dataset_name="mcfaline23", 17 | ) 18 | 19 | def get_anndata(self): 20 | """ 21 | Downloads, curates, and preprocesses the McFalineFigueroa23 dataset from 22 | Hugging Face. Saves the preprocessed data to disk and returns it in-memory. 23 | 24 | Returns: 25 | adata (anndata.AnnData): Anndata object containing the processed data. 26 | 27 | """ 28 | self.processed_data_path = ( 29 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad" 30 | ) 31 | if os.path.exists(self.processed_data_path): 32 | print("Loading processed data from:", self.processed_data_path) 33 | adata = sc.read_h5ad(self.processed_data_path) 34 | 35 | else: 36 | try: 37 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz" 38 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename) 39 | adata = sc.read_h5ad(self.processed_data_path) 40 | 41 | except Exception as e: 42 | print(f"Error downloading file from {self.dataset_hf_url}: {e}") 43 | raise ValueError( 44 | "Automatic data curation not available for this dataset. \ 45 | Use the notebooks in notebooks/neurips2025/data_curation \ 46 | to download and preprocess the data." 
47 | ) 48 | 49 | print("Saved processed data to:", self.processed_data_path) 50 | 51 | return adata 52 | -------------------------------------------------------------------------------- /src/perturbench/data/accessors/norman19.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import os 3 | from scipy.sparse import csr_matrix 4 | 5 | from perturbench.analysis.preprocess import preprocess 6 | from perturbench.data.accessors.base import ( 7 | download_scperturb_adata, 8 | download_file, 9 | Accessor, 10 | ) 11 | 12 | 13 | class Norman19(Accessor): 14 | def __init__(self, data_cache_dir="../perturbench_data"): 15 | super().__init__( 16 | data_cache_dir=data_cache_dir, 17 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/resolve/main/norman19_preprocessed.h5ad.gz", 18 | dataset_orig_url="https://zenodo.org/records/7041849/files/NormanWeissman2019_filtered.h5ad?download=1", 19 | dataset_name="norman19", 20 | ) 21 | 22 | def get_anndata(self): 23 | """ 24 | Downloads, curates, and preprocesses the norman19 dataset from either 25 | Hugging Face or the scPerturb database. Saves the preprocessed data to 26 | disk and returns it in-memory. 27 | 28 | Returns: 29 | adata (anndata.AnnData): Anndata object containing the processed data. 30 | 31 | """ 32 | self.processed_data_path = ( 33 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad" 34 | ) 35 | if os.path.exists(self.processed_data_path): 36 | print("Loading processed data from:", self.processed_data_path) 37 | adata = sc.read_h5ad(self.processed_data_path) 38 | 39 | else: 40 | try: 41 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz" 42 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename) 43 | adata = sc.read_h5ad(self.processed_data_path) 44 | 45 | except Exception as e: 46 | print(f"Error downloading file from {self.dataset_hf_url}: {e}") 47 | print(f"Downloading file from {self.dataset_orig_url}") 48 | 49 | adata = download_scperturb_adata( 50 | self.dataset_orig_url, 51 | self.data_cache_dir, 52 | filename=f"{self.dataset_name}_downloaded.h5ad", 53 | ) 54 | 55 | adata.obs.rename( 56 | columns={ 57 | "nCount_RNA": "ncounts", 58 | "nFeature_RNA": "ngenes", 59 | "percent.mt": "percent_mito", 60 | "cell_line": "cell_type", 61 | }, 62 | inplace=True, 63 | ) 64 | 65 | adata.obs["perturbation"] = adata.obs["perturbation"].str.replace("_", "+") 66 | adata.obs["perturbation"] = adata.obs["perturbation"].astype("category") 67 | adata.obs["condition"] = adata.obs.perturbation.copy() 68 | 69 | adata.X = csr_matrix(adata.X) 70 | 71 | adata = preprocess( 72 | adata, 73 | perturbation_key="condition", 74 | covariate_keys=["cell_type"], 75 | ) 76 | 77 | adata = adata.copy() 78 | adata.write_h5ad(self.processed_data_path) 79 | 80 | print("Saved processed data to:", self.processed_data_path) 81 | 82 | return adata 83 | -------------------------------------------------------------------------------- /src/perturbench/data/accessors/srivatsan20.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import os 3 | from scipy.sparse import csr_matrix 4 | 5 | from perturbench.analysis.preprocess import preprocess 6 | from perturbench.analysis.utils import get_ensembl_mappings 7 | from perturbench.data.accessors.base import ( 8 | download_scperturb_adata, 9 | download_file, 10 | Accessor, 11 | ) 12 | 13 | 14 | class Sciplex3(Accessor): 15 | def __init__(self, 
data_cache_dir="../perturbench_data"): 16 | super().__init__( 17 | data_cache_dir=data_cache_dir, 18 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/resolve/main/srivatsan20_highest_dose_preprocessed.h5ad.gz", 19 | dataset_orig_url="https://zenodo.org/records/7041849/files/SrivatsanTrapnell2020_sciplex3.h5ad?download=1", 20 | dataset_name="sciplex3", 21 | ) 22 | 23 | def get_anndata(self): 24 | """ 25 | Downloads, curates, and preprocesses the sciplex3 dataset from the scPerturb 26 | database. Saves the preprocessed data to disk and returns it in-memory. 27 | 28 | Returns: 29 | adata (anndata.AnnData): Anndata object containing the processed data. 30 | 31 | """ 32 | self.processed_data_path = ( 33 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad" 34 | ) 35 | if os.path.exists(self.processed_data_path): 36 | print("Loading processed data from:", self.processed_data_path) 37 | adata = sc.read_h5ad(self.processed_data_path) 38 | 39 | else: 40 | try: 41 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz" 42 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename) 43 | adata = sc.read_h5ad(self.processed_data_path) 44 | 45 | except Exception as e: 46 | print(f"Error downloading file from {self.dataset_hf_url}: {e}") 47 | print(f"Downloading file from {self.dataset_orig_url}") 48 | 49 | adata = download_scperturb_adata( 50 | self.dataset_orig_url, 51 | self.data_cache_dir, 52 | filename=f"{self.dataset_name}_downloaded.h5ad", 53 | ) 54 | 55 | unique_genes = ~adata.var.ensembl_id.duplicated() 56 | adata = adata[:, unique_genes] 57 | 58 | ## Map ENSEMBL IDs to gene symbols 59 | adata.var_names = adata.var.ensembl_id.astype(str) 60 | human_ids = [x for x in adata.var_names if "ENSG" in x] 61 | 62 | adata = adata[:, human_ids] 63 | gene_mappings = get_ensembl_mappings() 64 | gene_mappings = { 65 | k: v for k, v in gene_mappings.items() if isinstance(v, str) and v != "" 66 | } 67 | adata = adata[:, [x in gene_mappings for x in adata.var_names]] 68 | adata.var["gene_symbol"] = [gene_mappings[x] for x in adata.var_names] 69 | adata.var_names = adata.var["gene_symbol"] 70 | adata.var_names_make_unique() 71 | adata.var.index.name = None 72 | 73 | ## Format column names 74 | adata.obs.rename( 75 | columns={ 76 | "n_genes": "ngenes", 77 | "n_counts": "ncounts", 78 | }, 79 | inplace=True, 80 | ) 81 | 82 | ## Format cell line names 83 | adata.obs["cell_type"] = adata.obs["cell_line"].copy() 84 | adata = adata[adata.obs.cell_type.isin(["MCF7", "A549", "K562"])] 85 | adata.obs["cell_type"] = [x.lower() for x in adata.obs.cell_type] 86 | 87 | ## Rename some chemicals with the "+" symbol 88 | perturbation_remap = { 89 | "(+)-JQ1": "JQ1", 90 | "ENMD-2076 L-(+)-Tartaric acid": "ENMD-2076", 91 | } 92 | adata.obs["perturbation"] = [ 93 | perturbation_remap.get(x, x) for x in adata.obs.perturbation.astype(str) 94 | ] 95 | adata.obs["condition"] = adata.obs["perturbation"].copy() 96 | 97 | ## Subset to highest dose only 98 | adata = adata[ 99 | (adata.obs.dose_value == 10000) | (adata.obs.condition == "control") 100 | ].copy() 101 | 102 | adata.X = csr_matrix(adata.X) 103 | adata = preprocess( 104 | adata, 105 | perturbation_key="condition", 106 | covariate_keys=["cell_type"], 107 | ) 108 | 109 | adata = adata.copy() 110 | adata.write_h5ad(self.processed_data_path) 111 | 112 | print("Saved processed data to:", self.processed_data_path) 113 | 114 | return adata 115 | -------------------------------------------------------------------------------- 
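Usage sketch for the accessor classes above (illustrative only; assumes network access and a writable cache directory):

from perturbench.data.accessors.norman19 import Norman19

accessor = Norman19(data_cache_dir="../perturbench_data")
adata = accessor.get_anndata()             # download (or load the cached copy) and return an AnnData
dataset, context = accessor.get_dataset()  # SingleCellPerturbation plus a context dict of unique perturbations/covariates

The same pattern applies to Frangieh21, Jiang24, McFaline23, and Sciplex3.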
/src/perturbench/data/collate.py: -------------------------------------------------------------------------------- 1 | class noop_collate: 2 | """No-op collate: returns a batch of examples unchanged, unwrapping a singleton batch to its single element.""" 3 | 4 | def __call__(self, batch: list): 5 | if len(batch) == 1: 6 | return batch[0] 7 | else: 8 | return batch 9 | -------------------------------------------------------------------------------- /src/perturbench/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .singlecell import SingleCellPerturbation, SingleCellPerturbationWithControls 2 | from .population import Counterfactual, CounterfactualWithReference 3 | -------------------------------------------------------------------------------- /src/perturbench/data/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/resources/__init__.py -------------------------------------------------------------------------------- /src/perturbench/data/resources/devel.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/resources/devel.h5ad -------------------------------------------------------------------------------- /src/perturbench/data/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/transforms/__init__.py -------------------------------------------------------------------------------- /src/perturbench/data/transforms/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from abc import ABC, abstractmethod 3 | from typing import Iterator, TypeVar 4 | 5 | from ..types import Example, Batch 6 | 7 | Datum = TypeVar("Datum", Example, Batch) 8 | 9 | 10 | class Transform(ABC): 11 | """Abstract transform interface.""" 12 | 13 | @abstractmethod 14 | def __call__(self, data: Datum) -> Datum: 15 | pass 16 | 17 | @abstractmethod 18 | def __repr__(self) -> str: 19 | return self.__class__.__name__ + "({!s})" 20 | 21 | 22 | class ExampleTransform(Transform): 23 | """Transforms an example.""" 24 | 25 | @abstractmethod 26 | def __call__(self, example: Example) -> Example: 27 | pass 28 | 29 | def __repr__(self) -> str: 30 | _base = super().__repr__() 31 | return "[example-wise]" + _base 32 | 33 | def batchify(self, collate_fn: callable = list) -> Map: 34 | """Converts an example transform to a batch transform.""" 35 | return Map(self, collate_fn) 36 | 37 | 38 | class Map(Transform): 39 | """Maps a transform to a batch of examples.""" 40 | 41 | def __init__(self, transform: ExampleTransform, collate_fn: callable = list): 42 | self.transform = transform 43 | self.collate_fn = collate_fn 44 | 45 | def __call__(self, batch: Iterator[Example]) -> Batch: 46 | return self.collate_fn(list(map(self.transform, batch))) 47 | 48 | def __repr__(self) -> str: 49 | _base = super().__repr__() 50 | # strip the [example-wise] prefix from the self.transform repr 51 | _instance_repr = repr(self.transform) 52 | if _instance_repr.startswith("[example-wise]"): 53 | _instance_repr = _instance_repr[len("[example-wise]") :] 54 | args_repr =
f"{_instance_repr}, collate_fn={self.collate_fn.__name__}" 55 | return "[batch-wise]" + _base.format(args_repr) 56 | 57 | 58 | class Dispatch(dict, Transform): 59 | """Dispatches a transform to an example based on a key field. 60 | 61 | Attributes: 62 | self: A map of key to transform. 63 | """ 64 | 65 | def __call__(self, data: Datum) -> Datum: 66 | """Apply each transform to the field of an example matching its key.""" 67 | result = {} 68 | try: 69 | for key, transform in self.items(): 70 | result[key] = transform(getattr(data, key)) 71 | except KeyError as exc: 72 | raise TypeError( 73 | f"Invalid {key=} in transforms. All keys need to match the " 74 | f"fields of an example." 75 | ) from exc 76 | 77 | return data._replace(**result) 78 | 79 | def __repr__(self) -> str: 80 | _base = Transform.__repr__(self) 81 | transforms_repr = ", ".join( 82 | f"{key}: {repr(transform)}" for key, transform in self.items() 83 | ) 84 | return _base.format(transforms_repr) 85 | 86 | 87 | class Compose(list, Transform): 88 | """Creates a transform from a sequence of transforms.""" 89 | 90 | def __call__(self, data: Datum) -> Datum: 91 | for transform in self: 92 | data = transform(data) 93 | return data 94 | 95 | def __repr__(self) -> str: 96 | transforms_repr = " \u2192 ".join(repr(transform) for transform in self) 97 | return f"[{transforms_repr}]" 98 | -------------------------------------------------------------------------------- /src/perturbench/data/transforms/encoders.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import functools 3 | from typing import Collection, Sequence 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder, MultiLabelBinarizer 8 | 9 | from .base import Transform 10 | from ..types import ExampleMultiLabel, BatchMultiLabel 11 | 12 | 13 | class OneHotEncode(Transform): 14 | """One-hot encode a categorical variable. 15 | 16 | Attributes: 17 | onehot_encoder: the wrapped encoder instance 18 | """ 19 | 20 | one_hot_encoder: OneHotEncoder 21 | 22 | def __init__(self, categories: Collection[str], **kwargs): 23 | categories = [list(categories)] 24 | self.one_hot_encoder = OneHotEncoder( 25 | categories=categories, 26 | sparse_output=False, 27 | **kwargs, 28 | ) 29 | 30 | def __call__(self, labels: Sequence[str]): 31 | string_array = np.array(labels).reshape(-1, 1) 32 | encoded = self.one_hot_encoder.fit_transform(string_array) 33 | return torch.Tensor(encoded) 34 | 35 | def __repr__(self): 36 | _base = super().__repr__() 37 | categories = ", ".join(self.one_hot_encoder.categories[0]) 38 | return _base.format(categories) 39 | 40 | 41 | class LabelEncode(Transform): 42 | """Label encode categorical variables. 43 | 44 | Attributes: 45 | ordinal_encoder: sklearn.preprocessing.OrdinalEncoder 46 | """ 47 | 48 | ordinal_encoder: OrdinalEncoder 49 | 50 | def __init__(self, values: Sequence[str]): 51 | categories = [np.array(values)] 52 | self.ordinal_encoder = OrdinalEncoder(categories=categories) 53 | 54 | def __call__(self, labels: Sequence[str]): 55 | string_array = np.array(labels).reshape(-1, 1) 56 | return torch.Tensor(self.ordinal_encoder.fit_transform(string_array)) 57 | 58 | def __repr__(self): 59 | _base = super().__repr__() 60 | categories = ", ".join(self.ordinal_encoder.categories[0]) 61 | return _base.format(categories) 62 | 63 | 64 | class MultiLabelEncode(Transform): 65 | """Transforms a sequence of labels into a binary vector. 
66 | 67 | Attributes: 68 | label_binarizer: the wrapped binarizer instance 69 | 70 | Raises: 71 | ValueError: if any of the labels are not found in the encoder classes 72 | """ 73 | 74 | label_binarizer: MultiLabelBinarizer 75 | 76 | def __init__(self, classes: Collection[str]): 77 | self.label_binarizer = MultiLabelBinarizer( 78 | classes=list(classes), sparse_output=False 79 | ) 80 | 81 | @functools.cached_property 82 | def classes(self): 83 | return set(self.label_binarizer.classes) 84 | 85 | def __call__(self, labels: ExampleMultiLabel | BatchMultiLabel) -> torch.Tensor: 86 | # If labels is a single example, convert it to a batch 87 | if not labels or isinstance(labels[0], str): 88 | labels = [labels] 89 | self._check_inputs(labels) 90 | encoded = self.label_binarizer.fit_transform(labels) 91 | return torch.from_numpy(encoded) 92 | 93 | def _check_inputs(self, labels: BatchMultiLabel): 94 | unique_labels = set(itertools.chain.from_iterable(labels)) 95 | if not unique_labels <= self.classes: 96 | missing_labels = unique_labels - self.classes 97 | raise ValueError( 98 | f"Labels {missing_labels} not found in the encoder classes {self.classes}" 99 | ) 100 | 101 | def __repr__(self): 102 | _base = super().__repr__() 103 | classes = ", ".join(self.label_binarizer.classes) 104 | return _base.format(classes) 105 | -------------------------------------------------------------------------------- /src/perturbench/data/transforms/ops.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import torch 4 | from scipy.sparse import csr_matrix 5 | 6 | from .base import Transform 7 | 8 | 9 | class ToDense(Transform): 10 | """Convert a sparse matrix/tensor to a dense matrix/tensor.""" 11 | 12 | def __call__(self, value: torch.Tensor | csr_matrix) -> torch.Tensor: 13 | if isinstance(value, torch.Tensor): 14 | return value.to_dense() 15 | elif isinstance(value, csr_matrix): 16 | return torch.Tensor(value.toarray()) 17 | else: 18 | return value 19 | 20 | def __repr__(self): 21 | return "ToDense" 22 | 23 | 24 | class ToFloat(Transform): 25 | """Convert a tensor to float.""" 26 | 27 | def __call__(self, value: torch.Tensor): 28 | return value.float() 29 | 30 | def __repr__(self): 31 | return "ToFloat" 32 | 33 | 34 | class MapApply(Transform): 35 | """Map each transform to an input based on a key. 36 | 37 | Attributes: 38 | transform_map: A map of key to transform. 39 | """ 40 | 41 | transform_map: dict[str, Transform | Callable] 42 | 43 | def __init__( 44 | self, 45 | transforms: dict[str, Transform | Callable], 46 | init_params_map: dict | None = None, 47 | ) -> None: 48 | """Initializes the instance based on passed transforms. 49 | 50 | This class supports two ways of initializing the transforms. The first 51 | is by passing a map of key to transform. The second is by passing a map of 52 | key to factory callable. The factory callable will be called with the 53 | corresponding init params from the init_params_map. The factory callable 54 | should return a Transform. 55 | 56 | Args: 57 | transforms: A map of key to transform. 58 | init_params_map: A map of key to init params for the transforms. 59 | 60 | Raises: 61 | ValueError: If init_params_map is not None when using a dict of 62 | Transforms. 63 | TypeError: If the transform is not a dict of Transform or a callable.
64 | """ 65 | super().__init__() 66 | self.transform_map = {} 67 | for key, transform in transforms.items(): 68 | # Transforms are dict[str, Transform], directly assign them 69 | if isinstance(transform, Transform): 70 | if init_params_map is not None: 71 | raise ValueError( 72 | "init_params_map should be None when using a dict of " 73 | "Transforms." 74 | ) 75 | self.transform_map[key] = transform 76 | # Transforms are dict[str, factory_callable], call the factory 77 | elif callable(transform): 78 | self.transform_map[key] = transform(init_params_map[key]) 79 | else: 80 | raise TypeError( 81 | f"Invalid type for {key=} in transform. Must be either a " 82 | f"Transform or a callable." 83 | ) 84 | 85 | def __call__(self, value_map: dict[str, Any]) -> dict[str, Any]: 86 | return {key: self.transform_map[key](val) for key, val in value_map.items()} 87 | 88 | def __repr__(self) -> str: 89 | transforms_repr = ", ".join( 90 | f"{key}: {repr(transform)}" for key, transform in self.transform_map.items() 91 | ) 92 | return "{" + transforms_repr + "}" 93 | -------------------------------------------------------------------------------- /src/perturbench/data/transforms/pipelines.py: -------------------------------------------------------------------------------- 1 | from .base import Dispatch, Compose 2 | from .encoders import OneHotEncode, MultiLabelEncode 3 | from .ops import ToDense, ToFloat, MapApply 4 | 5 | 6 | class SingleCellPipeline(Dispatch): 7 | """Single cell transform pipeline.""" 8 | 9 | def __init__( 10 | self, 11 | perturbation_uniques: set[str], 12 | covariate_uniques: dict[str:set], 13 | ) -> None: 14 | # Set up covariates transform 15 | covariate_transform = { 16 | key: Compose([OneHotEncode(uniques), ToFloat()]) 17 | for key, uniques in covariate_uniques.items() 18 | } 19 | # Initialize the pipeline 20 | super().__init__( 21 | perturbations=Compose( 22 | [ 23 | MultiLabelEncode(perturbation_uniques), 24 | ToFloat(), 25 | ] 26 | ), 27 | gene_expression=ToDense(), 28 | covariates=MapApply(covariate_transform), 29 | ) 30 | -------------------------------------------------------------------------------- /src/perturbench/data/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import NamedTuple, Sequence, Any 3 | 4 | from scipy.sparse import csr_matrix, csr_array 5 | import numpy as np 6 | 7 | SparseMatrix = csr_matrix 8 | SparseVector = csr_array 9 | 10 | ExampleMultiLabel = Sequence[str] 11 | BatchMultiLabel = Sequence[ExampleMultiLabel] 12 | 13 | 14 | class Example(NamedTuple): 15 | """Single Cell Expression Example.""" 16 | 17 | # A vector of size (num_genes, ) 18 | gene_expression: SparseVector 19 | # A list of perturbations applied to the cell 20 | perturbations: Sequence[str] ## TODO: Should be [] if control 21 | # A map from covariate name to covariate value 22 | covariates: dict[str, str] | None = None 23 | # A map from control condition name to control gene expression of 24 | # shape (num_controls_in_condition, num_genes) 25 | controls: SparseVector | None = None 26 | # A cell id 27 | id: str | None = None 28 | # A list of gene names of length num_genes 29 | gene_names: Sequence[str] | None = None 30 | # Optional foundation model embeddings 31 | embeddings: np.ndarray | None = None 32 | 33 | 34 | class Batch(NamedTuple): 35 | """Single Cell Expression Batch.""" 36 | 37 | gene_expression: SparseMatrix 38 | perturbations: Sequence[list[str]] | SparseMatrix 39 | covariates: dict[str, 
Sequence[str] | SparseMatrix] | None = None 40 | controls: SparseMatrix | None = None 41 | id: Sequence[str] | None = None 42 | gene_names: Sequence[str] | None = None 43 | embeddings: np.ndarray | None = None 44 | 45 | 46 | class FrozenDictKeyMap(dict): 47 | """A dictionary that uses dictionaries as keys. 48 | 49 | Dictionaries cannot be used directly as keys to another dictionary because 50 | they are mutable. As a result this class first converts the dictionary to a 51 | frozenset of key-value pairs before using it as a key. The underlying data 52 | is stored using the dictionary data structure and this class just modifies 53 | the accessor and mutator methods. 54 | 55 | Example: 56 | >>> d = FrozenDictKeyMap() 57 | >>> d[{"a": 1, "b": 2}] = 1 58 | >>> d[{"a": 1, "b": 2}] = 2 59 | >>> d[{"a": 1, "b": 2}] = 3 60 | >>> d 61 | {frozenset({('a', 1), ('b', 2)}): 3} 62 | 63 | Attributes: see dict class 64 | """ 65 | 66 | def __init__(self, data: Sequence[tuple[dict, Any]] | None = None): 67 | """Initialize the dictionary. 68 | 69 | Args: 70 | data: a sequence of (key, value) pairs to initialize the dictionary 71 | """ 72 | if data is not None: 73 | try: 74 | _data = [(frozenset(key.items()), value) for key, value in data] 75 | except AttributeError as exc: 76 | raise ValueError( 77 | "data must be a sequence of (key, value) pairs where key is a " 78 | "dictionary" 79 | ) from exc 80 | else: 81 | _data = [] 82 | super().__init__(_data) 83 | 84 | def __getitem__(self, key: dict) -> Any: 85 | """Get the value associated with the key. 86 | 87 | Args: 88 | key: a dictionary. 89 | 90 | Returns: 91 | The value associated with the key. 92 | """ 93 | if isinstance(key, frozenset): 94 | key = dict(key) 95 | return super().__getitem__(frozenset(key.items())) 96 | 97 | def __setitem__(self, key: dict, value: Any) -> None: 98 | """Set the value associated with the key. 99 | 100 | Args: 101 | key: a dictionary. 102 | value: the value to set. 103 | """ 104 | if isinstance(key, frozenset): 105 | key = dict(key) 106 | super().__setitem__(frozenset(key.items()), value) 107 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.0.1" 2 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear_additive import LinearAdditive 2 | from .latent_additive import LatentAdditive 3 | from .decoder_only import DecoderOnly 4 | from .sams_vae import SparseAdditiveVAE 5 | from .base import PerturbationModel 6 | from .biolord import BiolordStar 7 | from .average import Average 8 | from .cpa import CPA 9 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/models/average.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2024, 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """ 31 | 32 | import lightning as L 33 | import torch 34 | from torch import optim 35 | from perturbench.data.types import Batch 36 | from .base import PerturbationModel 37 | 38 | 39 | class Average(PerturbationModel): 40 | """ 41 | A perturbation prediction baseline model that returns the average expression of each perturbation in the training data. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | n_genes: int, 47 | n_perts: int, 48 | datamodule: L.LightningDataModule | None = None, 49 | ) -> None: 50 | """ 51 | The constructor for the Average class. 
52 | 53 | Args: 54 | n_genes (int): Number of genes in the dataset 55 | n_perts (int): Number of perturbations in the dataset (not including controls) 56 | """ 57 | super(Average, self).__init__(datamodule) 58 | self.save_hyperparameters(ignore=["datamodule"]) 59 | 60 | if n_genes is None: 61 | n_genes = datamodule.num_genes 62 | 63 | if n_perts is None: 64 | n_perts = datamodule.num_perturbations 65 | 66 | self.n_genes = n_genes 67 | self.n_perts = n_perts 68 | self.average_expression = torch.nn.Parameter( 69 | torch.zeros(n_perts, n_genes), requires_grad=False 70 | ) 71 | self.sum_expression = torch.zeros(n_perts, n_genes) 72 | self.num_cells = torch.zeros(n_perts) 73 | self.dummy_nn = torch.nn.Linear(1, 1) 74 | 75 | def configure_optimizers(self): 76 | optimizer = optim.Adam(self.parameters()) 77 | return optimizer 78 | 79 | def backward(self, use_amp, loss, optimizer): 80 | return 81 | 82 | def on_train_start(self): 83 | self.sum_expression = self.sum_expression.to(self.device) 84 | self.num_cells = self.num_cells.to(self.device) 85 | 86 | def training_step(self, batch: Batch, batch_idx: int | list[int]): 87 | # Unpack the batch 88 | observed_perturbed_expression = batch.gene_expression.squeeze() 89 | perturbation = batch.perturbations.squeeze() 90 | self.sum_expression += torch.matmul( 91 | perturbation.t(), observed_perturbed_expression 92 | ) 93 | self.num_cells += perturbation.sum(0) 94 | 95 | def on_train_epoch_end(self): 96 | average_expression = self.sum_expression.t() / self.num_cells 97 | self.average_expression = torch.nn.Parameter( 98 | average_expression.t(), requires_grad=False 99 | ) 100 | 101 | self.sum_expression = torch.zeros(self.n_perts, self.n_genes, device=self.device) 102 | self.num_cells = torch.zeros(self.n_perts, device=self.device) 103 | 104 | def predict(self, batch: Batch): 105 | perturbation = batch.perturbations.squeeze() 106 | perturbation = perturbation.to(self.device) 107 | predicted_perturbed_expression = torch.matmul( 108 | perturbation, self.average_expression 109 | ) 110 | predicted_perturbed_expression = ( 111 | predicted_perturbed_expression.t() / perturbation.sum(1) 112 | ) 113 | return predicted_perturbed_expression.t() 114 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/models/decoder_only.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2024, 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """ 31 | 32 | import torch 33 | import torch.nn.functional as F 34 | import lightning as L 35 | import numpy as np 36 | 37 | from ..nn.mlp import MLP 38 | from .base import PerturbationModel 39 | from perturbench.data.types import Batch 40 | 41 | 42 | class DecoderOnly(PerturbationModel): 43 | """ 44 | A decoder-only model for predicting perturbation effects 45 | """ 46 | 47 | def __init__( 48 | self, 49 | n_genes: int, 50 | n_perts: int, 51 | n_layers=2, 52 | encoder_width=128, 53 | softplus_output=True, 54 | use_covariates=True, 55 | use_perturbations=True, 56 | lr: float | None = None, 57 | wd: float | None = None, 58 | lr_scheduler_freq: int | None = None, 59 | lr_scheduler_patience: int | None = None, 60 | lr_scheduler_factor: float | None = None, 61 | datamodule: L.LightningDataModule | None = None, 62 | ) -> None: 63 | """ 64 | The constructor for the DecoderOnly class. 65 | 66 | Args: 67 | n_genes (int): Number of genes to use for prediction 68 | n_perts (int): Number of perturbations in the dataset (not including controls) 69 | n_layers (int): Number of layers in the decoder MLP 70 | lr (float): Learning rate 71 | wd (float): Weight decay 72 | lr_scheduler_freq (int): How often the learning rate scheduler checks val_loss 73 | lr_scheduler_patience (int): Learning rate scheduler patience 74 | lr_scheduler_factor (float): Factor by which to reduce learning rate when learning rate scheduler triggers 75 | softplus_output (bool): Whether to apply a softplus activation to the output of the decoder to enforce non-negativity 76 | """ 77 | 78 | super(DecoderOnly, self).__init__( 79 | datamodule=datamodule, 80 | lr=lr, 81 | wd=wd, 82 | lr_scheduler_freq=lr_scheduler_freq, 83 | lr_scheduler_patience=lr_scheduler_patience, 84 | lr_scheduler_factor=lr_scheduler_factor, 85 | ) 86 | self.save_hyperparameters(ignore=["datamodule"]) 87 | 88 | if not (use_covariates or use_perturbations): 89 | raise ValueError( 90 | "'use_covariates' and 'use_perturbations' cannot both be False. Either covariates or perturbations have to be used."
91 | ) 92 | 93 | if n_genes is not None: 94 | self.n_genes = n_genes 95 | if n_perts is not None: 96 | self.n_perts = n_perts 97 | 98 | n_total_covariates = ( 99 | np.sum( 100 | [ 101 | len(unique_covs) 102 | for unique_covs in datamodule.train_context[ 103 | "covariate_uniques" 104 | ].values() 105 | ] 106 | ) 107 | if use_covariates 108 | else 0 109 | ) 110 | 111 | n_perts = self.n_perts if use_perturbations else 0 112 | 113 | decoder_input_dim = n_total_covariates + n_perts 114 | 115 | self.decoder = MLP(decoder_input_dim, encoder_width, self.n_genes, n_layers) 116 | self.softplus_output = softplus_output 117 | self.use_covariates = use_covariates 118 | self.use_perturbations = use_perturbations 119 | 120 | def forward( 121 | self, 122 | control_expression: torch.Tensor, 123 | perturbation: torch.Tensor, 124 | covariates: dict[str, torch.Tensor], 125 | ): 126 | if self.use_covariates and self.use_perturbations: 127 | embedding = torch.cat([cov.squeeze() for cov in covariates.values()], dim=1) 128 | embedding = torch.cat([perturbation, embedding], dim=1) 129 | elif self.use_covariates: 130 | embedding = torch.cat([cov.squeeze() for cov in covariates.values()], dim=1) 131 | elif self.use_perturbations: 132 | embedding = perturbation 133 | 134 | predicted_perturbed_expression = self.decoder(embedding) 135 | 136 | if self.softplus_output: 137 | predicted_perturbed_expression = F.softplus(predicted_perturbed_expression) 138 | return predicted_perturbed_expression 139 | 140 | def training_step(self, batch: Batch, batch_idx: int): 141 | ( 142 | observed_perturbed_expression, 143 | control_expression, 144 | perturbation, 145 | covariates, 146 | _, 147 | ) = self.unpack_batch(batch) 148 | predicted_perturbed_expression = self.forward( 149 | control_expression, perturbation, covariates 150 | ) 151 | loss = F.mse_loss(predicted_perturbed_expression, observed_perturbed_expression) 152 | self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=len(batch)) 153 | return loss 154 | 155 | def validation_step(self, batch: Batch, batch_idx: int): 156 | ( 157 | observed_perturbed_expression, 158 | control_expression, 159 | perturbation, 160 | covariates, 161 | _, 162 | ) = self.unpack_batch(batch) 163 | predicted_perturbed_expression = self.forward( 164 | control_expression, perturbation, covariates 165 | ) 166 | val_loss = F.mse_loss( 167 | predicted_perturbed_expression, observed_perturbed_expression 168 | ) 169 | self.log( 170 | "val_loss", 171 | val_loss, 172 | on_step=True, 173 | prog_bar=True, 174 | logger=True, 175 | batch_size=len(batch), 176 | ) 177 | return val_loss 178 | 179 | def predict(self, batch): 180 | control_expression = batch.gene_expression.squeeze().to(self.device) 181 | perturbation = batch.perturbations.squeeze().to(self.device) 182 | covariates = {k: v.to(self.device) for k, v in batch.covariates.items()} 183 | 184 | predicted_perturbed_expression = self.forward( 185 | control_expression, 186 | perturbation, 187 | covariates, 188 | ) 189 | return predicted_perturbed_expression 190 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/models/linear_additive.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2024, 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """ 31 | 32 | import lightning as L 33 | import torch 34 | import torch.nn as nn 35 | import torch.nn.functional as F 36 | import numpy as np 37 | 38 | from perturbench.data.types import Batch 39 | from .base import PerturbationModel 40 | 41 | 42 | class LinearAdditive(PerturbationModel): 43 | """ 44 | A linear additive model for predicting perturbation effects 45 | """ 46 | 47 | def __init__( 48 | self, 49 | n_genes: int, 50 | n_perts: int, 51 | inject_covariates: bool = False, 52 | lr: float | None = None, 53 | wd: float | None = None, 54 | lr_scheduler_freq: int | None = None, 55 | lr_scheduler_interval: str | None = None, 56 | lr_scheduler_patience: int | None = None, 57 | lr_scheduler_factor: float | None = None, 58 | softplus_output: bool = True, 59 | datamodule: L.LightningDataModule | None = None, 60 | ) -> None: 61 | """ 62 | The constructor for the LinearAdditive class.
63 | 64 | Args: 65 | n_genes (int): Number of genes in the dataset 66 | n_perts (int): Number of perturbations in the dataset (not including controls) 67 | lr (float): Learning rate 68 | wd (float): Weight decay 69 | lr_scheduler_freq (int): How often the learning rate scheduler checks val_loss 70 | lr_scheduler_interval (str): Whether to check val_loss every epoch or every step 71 | lr_scheduler_patience (int): Learning rate scheduler patience 72 | lr_scheduler_factor (float): Factor by which to reduce learning rate when learning rate scheduler triggers 73 | inject_covariates: Whether to condition the linear layer on 74 | covariates 75 | softplus_output: Whether to apply a softplus activation to the 76 | output of the decoder to enforce non-negativity 77 | datamodule: The datamodule used to train the model 78 | """ 79 | super(LinearAdditive, self).__init__( 80 | datamodule=datamodule, 81 | lr=lr, 82 | wd=wd, 83 | lr_scheduler_interval=lr_scheduler_interval, 84 | lr_scheduler_freq=lr_scheduler_freq, 85 | lr_scheduler_patience=lr_scheduler_patience, 86 | lr_scheduler_factor=lr_scheduler_factor, 87 | ) 88 | self.save_hyperparameters(ignore=["datamodule"]) 89 | self.softplus_output = softplus_output 90 | 91 | if n_genes is not None: 92 | self.n_genes = n_genes 93 | if n_perts is not None: 94 | self.n_perts = n_perts 95 | 96 | self.inject_covariates = inject_covariates 97 | if inject_covariates: 98 | if datamodule is None or datamodule.train_context is None: 99 | raise ValueError( 100 | "If inject_covariates is True, datamodule must be provided" 101 | ) 102 | n_total_covariates = np.sum( 103 | [ 104 | len(unique_covs) 105 | for unique_covs in datamodule.train_context[ 106 | "covariate_uniques" 107 | ].values() 108 | ] 109 | ) 110 | self.fc_pert = nn.Linear(self.n_perts + n_total_covariates, self.n_genes) 111 | else: 112 | self.fc_pert = nn.Linear(self.n_perts, self.n_genes) 113 | 114 | def forward( 115 | self, 116 | control_expression: torch.Tensor, 117 | perturbation: torch.Tensor, 118 | covariates: dict, 119 | ): 120 | if self.inject_covariates: 121 | merged_covariates = torch.cat( 122 | [cov.squeeze() for cov in covariates.values()], dim=1 123 | ) 124 | perturbation = torch.cat([perturbation, merged_covariates], dim=1) 125 | 126 | predicted_perturbed_expression = control_expression + self.fc_pert(perturbation) 127 | if self.softplus_output: 128 | predicted_perturbed_expression = F.softplus(predicted_perturbed_expression) 129 | return predicted_perturbed_expression 130 | 131 | def training_step(self, batch: Batch, batch_idx: int): 132 | ( 133 | observed_perturbed_expression, 134 | control_expression, 135 | perturbation, 136 | covariates, 137 | _, 138 | ) = self.unpack_batch(batch) 139 | predicted_perturbed_expression = self.forward( 140 | control_expression, perturbation, covariates 141 | ) 142 | loss = F.mse_loss(predicted_perturbed_expression, observed_perturbed_expression) 143 | self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=len(batch)) 144 | return loss 145 | 146 | def validation_step(self, batch: Batch, batch_idx: int): 147 | ( 148 | observed_perturbed_expression, 149 | control_expression, 150 | perturbation, 151 | covariates, 152 | _, 153 | ) = self.unpack_batch(batch) 154 | predicted_perturbed_expression = self.forward( 155 | control_expression, perturbation, covariates 156 | ) 157 | val_loss = F.mse_loss( 158 | predicted_perturbed_expression, observed_perturbed_expression 159 | ) 160 | self.log( 161 | "val_loss", 162 | val_loss, 163 | on_step=True, 164 | 
prog_bar=True, 165 | logger=True, 166 | batch_size=len(batch), 167 | ) 168 | return val_loss 169 | 170 | def predict(self, batch: Batch): 171 | control_expression = batch.gene_expression.squeeze().to(self.device) 172 | perturbation = batch.perturbations.squeeze().to(self.device) 173 | covariates = {k: v.to(self.device) for k, v in batch.covariates.items()} 174 | predicted_perturbed_expression = self.forward( 175 | control_expression, 176 | perturbation, 177 | covariates, 178 | ) 179 | return predicted_perturbed_expression 180 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/modelcore/nn/__init__.py -------------------------------------------------------------------------------- /src/perturbench/modelcore/nn/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.distributions import RelaxedBernoulli 3 | 4 | 5 | def gumbel_softmax_bernoulli(probs, temperature=0.5): 6 | # Create a RelaxedBernoulli distribution with the specified temperature 7 | relaxed_bernoulli = RelaxedBernoulli(temperature, probs=probs) 8 | 9 | # Sample from the relaxed distribution 10 | soft_sample = relaxed_bernoulli.rsample() 11 | 12 | # Quantize the soft sample to get a hard sample 13 | hard_sample = (soft_sample > 0.5).float() 14 | 15 | # Use straight-through estimator 16 | st_sample = hard_sample - soft_sample.detach() + soft_sample 17 | 18 | return st_sample 19 | 20 | 21 | class MLP(nn.Module): 22 | def __init__( 23 | self, 24 | input_dim: int, 25 | hidden_dim: int, 26 | output_dim: int, 27 | n_layers: int, 28 | dropout: float | None = None, 29 | norm: str | None = "layer", 30 | elementwise_affine: bool = False, 31 | ): 32 | """Class for defining MLP with arbitrary number of layers""" 33 | super(MLP, self).__init__() 34 | 35 | if norm not in ["layer", "batch", None]: 36 | raise ValueError("norm must be one of ['layer', 'batch', None]") 37 | 38 | layers = nn.Sequential() 39 | layers.append(nn.Linear(input_dim, hidden_dim)) 40 | for _ in range(0, n_layers): 41 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 42 | if norm == "layer": 43 | layers.append( 44 | nn.LayerNorm(hidden_dim, elementwise_affine=elementwise_affine) 45 | ) 46 | elif norm == "batch": 47 | layers.append(nn.BatchNorm1d(hidden_dim, momentum=0.01, eps=0.001)) 48 | layers.append(nn.ReLU()) 49 | if dropout is not None: 50 | layers.append(nn.Dropout(dropout)) 51 | 52 | layers.append(nn.Linear(hidden_dim, output_dim)) 53 | self.network = layers 54 | 55 | def forward(self, x): 56 | return self.network(x) 57 | 58 | 59 | class ResMLP(nn.Module): 60 | def __init__( 61 | self, 62 | input_dim: int, 63 | hidden_dim: int, 64 | n_layers: int, 65 | dropout: float | None = None, 66 | norm: str | None = "layer", 67 | elementwise_affine: bool = False, 68 | ): 69 | super(ResMLP, self).__init__() 70 | 71 | layers = nn.Sequential() 72 | layers.append(nn.Linear(input_dim, hidden_dim)) 73 | layers.append(nn.LayerNorm(hidden_dim)) 74 | layers.append(nn.ReLU()) 75 | 76 | for i in range(0, n_layers - 1): 77 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 78 | layers.append(nn.LayerNorm(hidden_dim)) 79 | layers.append(nn.ReLU()) 80 | 81 | layers.append(nn.Linear(hidden_dim, input_dim)) 82 | self.layers = layers 83 | 84 | def 
forward(self, x): 85 | return x + self.layers(x) 86 | 87 | 88 | class MaskNet(nn.Module): 89 | def __init__( 90 | self, 91 | input_dim: int, 92 | hidden_dim: int, 93 | output_dim: int, 94 | n_layers: int, 95 | dropout: float | None = None, 96 | norm: str | None = "layer", 97 | elementwise_affine: bool = False, 98 | ): 99 | """ 100 | Implements a mask module. Similar to a standard MLP, but the output will be discrete 0's and 1's 101 | and gradients are calculated with the Gumbel-Softmax relaxation. 102 | """ 103 | super(MaskNet, self).__init__() 104 | 105 | self.mlp = MLP( 106 | input_dim=input_dim, 107 | hidden_dim=hidden_dim, 108 | output_dim=output_dim, 109 | n_layers=n_layers, 110 | dropout=dropout, 111 | norm=norm, 112 | elementwise_affine=elementwise_affine, 113 | ) 114 | 115 | def forward(self, x): 116 | m_probs = self.mlp(x).sigmoid() 117 | m = gumbel_softmax_bernoulli(m_probs) 118 | return m 119 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/nn/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as dist 3 | import torch.nn.functional as F 4 | 5 | 6 | class ZeroInflatedNegativeBinomial(dist.Distribution): 7 | def __init__(self, base_dist, zero_prob_logits): 8 | super(ZeroInflatedNegativeBinomial, self).__init__() 9 | self.base_dist = base_dist 10 | self.zero_prob_logits = zero_prob_logits 11 | self._mean = None 12 | self._variance = None 13 | 14 | def sample(self, sample_shape=torch.Size()): 15 | # note: this is not actually sampling from NB 16 | base_samples = self.base_dist.sample(sample_shape) 17 | zero_mask = torch.bernoulli(torch.sigmoid(self.zero_prob_logits)).bool() 18 | return torch.where(zero_mask, torch.zeros_like(base_samples), base_samples) 19 | 20 | def log_prob(self, value, eps=1e-8): 21 | # the original way but can be numerically unstable 22 | # base_log_prob = self.base_dist.log_prob(value) 23 | # zero_probs = torch.sigmoid(self.zero_prob_logits) 24 | # log_prob_non_zero = torch.log1p(-zero_probs + 1e-8) + base_log_prob 25 | # log_prob = torch.where( 26 | # value == 0, 27 | # torch.log(zero_probs + (1 - zero_probs) * torch.exp(base_log_prob) + 1e-8), 28 | # log_prob_non_zero 29 | # ) 30 | 31 | # Adapted from SCVI's implementation of ZINB log_prob 32 | base_log_prob = self.base_dist.log_prob(value) 33 | softplus_neg_logits = F.softplus(-self.zero_prob_logits) 34 | case_zero = ( 35 | F.softplus(-self.zero_prob_logits + base_log_prob) - softplus_neg_logits 36 | ) 37 | mul_case_zero = torch.mul((value < eps).type(torch.float32), case_zero) 38 | case_non_zero = -self.zero_prob_logits - softplus_neg_logits + base_log_prob 39 | mul_case_non_zero = torch.mul((value > eps).type(torch.float32), case_non_zero) 40 | log_prob = mul_case_zero + mul_case_non_zero 41 | 42 | return log_prob 43 | 44 | @property 45 | def mean(self): 46 | if self._mean is None: 47 | base_mean = self.base_dist.mean 48 | self._mean = (1 - torch.sigmoid(self.zero_prob_logits)) * base_mean 49 | return self._mean 50 | 51 | @property 52 | def variance( 53 | self, 54 | ): # https://docs.pyro.ai/en/dev/_modules/pyro/distributions/zero_inflated.html#ZeroInflatedNegativeBinomial 55 | if self._variance is None: 56 | base_mean = self.base_dist.mean 57 | base_variance = self.base_dist.variance 58 | self._variance = (1 - torch.sigmoid(self.zero_prob_logits)) * ( 59 | base_mean**2 + base_variance 60 | ) - self.mean**2 61 | return self._variance 62 | 
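A minimal sketch (not part of the library) of how the zero-inflated wrapper above composes with a torch NegativeBinomial base distribution; the tensor shapes and parameter values are illustrative assumptions.

# Hypothetical usage of the ZeroInflatedNegativeBinomial defined above.
import torch
import torch.distributions as dist

counts = torch.tensor([[0.0, 3.0, 0.0, 7.0]])  # observed per-gene counts
base = dist.NegativeBinomial(
    total_count=torch.full_like(counts, 5.0),
    logits=torch.zeros_like(counts),
)
zero_logits = torch.zeros_like(counts)  # sigmoid(0) = 0.5 dropout probability per gene
zinb = ZeroInflatedNegativeBinomial(base, zero_logits)
print(zinb.log_prob(counts))  # numerically stable per-entry ZINB log-likelihood
print(zinb.mean)              # (1 - sigmoid(zero_logits)) * base mean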
-------------------------------------------------------------------------------- /src/perturbench/modelcore/predict.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from omegaconf import DictConfig 3 | import lightning as L 4 | import logging 5 | import hydra 6 | import os 7 | 8 | from perturbench.data.datasets import Counterfactual 9 | from perturbench.data.utils import batch_dataloader 10 | from perturbench.data.collate import noop_collate 11 | from .models.base import PerturbationModel 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | 16 | def predict( 17 | cfg: DictConfig, 18 | ): 19 | """Predict counterfactual perturbation effects""" 20 | # Set seed for random number generators in pytorch, numpy and python.random 21 | if cfg.get("seed"): 22 | L.seed_everything(cfg.seed, workers=True) 23 | 24 | log.info("Instantiating datamodule <%s>", cfg.data._target_) 25 | datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data) 26 | 27 | log.info("Instantiating model <%s>", cfg.model._target_) 28 | model_class: PerturbationModel = hydra.utils.get_class(cfg.model._target_) 29 | 30 | # Load trained model 31 | if not os.path.exists(cfg.ckpt_path): 32 | raise ValueError(f"Checkpoint path {cfg.ckpt_path} does not exist") 35 | 36 | trained_model: PerturbationModel = model_class.load_from_checkpoint( 37 | cfg.ckpt_path, 38 | datamodule=datamodule, 39 | ) 40 | 41 | # Load prediction dataframe 42 | if not os.path.exists(cfg.prediction_dataframe_path): 43 | raise ValueError( 44 | f"Prediction dataframe path {cfg.prediction_dataframe_path} does not exist" 45 | ) 46 | pred_df = pd.read_csv(cfg.prediction_dataframe_path) 47 | 48 | if cfg.data.perturbation_key not in pred_df.columns: 49 | raise ValueError( 50 | f"Prediction dataframe must contain column {cfg.data.perturbation_key}" 51 | ) 52 | for covariate_key in cfg.data.covariate_keys: 53 | if covariate_key not in pred_df.columns: 54 | raise ValueError( 55 | f"Prediction dataframe must contain column {covariate_key}" 56 | ) 57 | 58 | # Create inference dataloader 59 | test_adata = datamodule.test_dataset.reference_adata 60 | control_adata = test_adata[ 61 | test_adata.obs[cfg.data.perturbation_key] == cfg.data.perturbation_control_value 62 | ] 63 | del test_adata 64 | 65 | inference_dataset, _ = Counterfactual.from_anndata( 66 | control_adata, 67 | pred_df, 68 | cfg.data.perturbation_key, 69 | perturbation_combination_delimiter=cfg.data.perturbation_combination_delimiter, 70 | covariate_keys=cfg.data.covariate_keys, 71 | perturbation_control_value=cfg.data.perturbation_control_value, 72 | seed=cfg.seed, 73 | max_control_cells_per_covariate=cfg.data.evaluation.max_control_cells_per_covariate, 74 | ) 75 | inference_dataset.transform = trained_model.training_record["transform"] 76 | inference_dataloader = batch_dataloader( 77 | inference_dataset, 78 | batch_size=cfg.chunk_size, 79 | num_workers=cfg.data.num_workers, 80 | shuffle=False, 81 | collate_fn=noop_collate(), 82 | ) 83 | 84 | log.info("Instantiating trainer <%s>", cfg.trainer._target_) 85 | trainer: L.Trainer = hydra.utils.instantiate(cfg.trainer) 86 | 87 | log.info("Generating predictions") 88 | if not os.path.exists(cfg.output_path): 89 | os.makedirs(cfg.output_path) 90 | trained_model.prediction_output_path = cfg.output_path 91 | trainer.predict(model=trained_model, dataloaders=inference_dataloader) 92 | 93
| 94 | @hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml") 95 | def main(cfg: DictConfig): 96 | predict(cfg) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | import hydra 4 | import lightning as L 5 | from omegaconf import DictConfig 6 | from lightning.pytorch.loggers import Logger 7 | from perturbench.modelcore.utils import multi_instantiate 8 | from perturbench.modelcore.models import PerturbationModel 9 | from hydra.core.hydra_config import HydraConfig 10 | 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | def train(runtime_context: dict): 16 | cfg = runtime_context["cfg"] 17 | 18 | # Set seed for random number generators in pytorch, numpy and python.random 19 | if cfg.get("seed"): 20 | L.seed_everything(cfg.seed, workers=True) 21 | 22 | log.info("Instantiating datamodule <%s>", cfg.data._target_) 23 | datamodule: L.LightningDataModule = hydra.utils.instantiate( 24 | cfg.data, 25 | seed=cfg.seed, 26 | ) 27 | 28 | log.info("Instantiating model <%s>", cfg.model._target_) 29 | model: PerturbationModel = hydra.utils.instantiate(cfg.model, datamodule=datamodule) 30 | 31 | log.info("Instantiating callbacks...") 32 | callbacks: List[L.Callback] = multi_instantiate(cfg.get("callbacks")) 33 | 34 | log.info("Instantiating loggers...") 35 | loggers: List[Logger] = multi_instantiate(cfg.get("logger")) 36 | 37 | log.info("Instantiating trainer <%s>", cfg.trainer._target_) 38 | trainer: L.Trainer = hydra.utils.instantiate( 39 | cfg.trainer, callbacks=callbacks, logger=loggers 40 | ) 41 | 42 | if cfg.get("train"): 43 | log.info("Starting training!") 44 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 45 | 46 | train_metrics = trainer.callback_metrics 47 | 48 | summary_metrics_dict = {} 49 | if cfg.get("test"): 50 | log.info("Starting testing!") 51 | if cfg.get("train"): 52 | if ( 53 | trainer.checkpoint_callback is None 54 | or trainer.checkpoint_callback.best_model_path == "" 55 | ): 56 | ckpt_path = None 57 | else: 58 | ckpt_path = "best" 59 | else: 60 | ckpt_path = cfg.get("ckpt_path") 61 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 62 | summary_metrics_dict = model.summary_metrics.to_dict()[ 63 | model.summary_metrics.columns[0] 64 | ] 65 | 66 | test_metrics = trainer.callback_metrics 67 | # merge train and test metrics 68 | metric_dict = {**train_metrics, **test_metrics, **summary_metrics_dict} 69 | 70 | return metric_dict 71 | 72 | 73 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 74 | def main(cfg: DictConfig) -> float | None: 75 | runtime_context = {"cfg": cfg, "trial_number": HydraConfig.get().job.get("num")} 76 | 77 | ## Train the model 78 | global metric_dict 79 | metric_dict = train(runtime_context) 80 | 81 | ## Combined metric 82 | metrics_use = cfg.get("metrics_to_optimize") 83 | if metrics_use: 84 | combined_metric = sum( 85 | [metric_dict.get(metric) * weight for metric, weight in metrics_use.items()] 86 | ) 87 | return combined_metric 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /src/perturbench/modelcore/utils.py: -------------------------------------------------------------------------------- 
1 | from functools import partial 2 | import logging 3 | import warnings 4 | from typing import Callable, List, TypeVar, Any 5 | import hydra 6 | from omegaconf import DictConfig 7 | from lightning import Callback 8 | from lightning.pytorch.loggers import Logger 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | T = TypeVar("T", Logger, Callback) 13 | 14 | 15 | def multi_instantiate( 16 | cfg: DictConfig, context: dict[str, Any] | None = None 17 | ) -> List[T]: 18 | """Instantiates multiple classes from config. 19 | 20 | Instantiates multiple classes from a config object. The configuration object 21 | has the following structure: dict[submodule_name: str, conf: DictConfig]. The 22 | conf contains the configuration for the submodule and adheres to one of the 23 | following schemas: 24 | 25 | 1. dict[_target_: str, ...] where the _target_ key specifies the class to be 26 | instantiated and the rest of the keys are the arguments to be passed to 27 | the class constructor (usual hydra config schema). 28 | 29 | 2. dict[conf: DictConfig, dependencies: dict[str, str]] where conf contains 30 | the configuration for the submodule as defined in (1), minus the requested 31 | dependencies, and dependencies is a mapping from the class arguments to 32 | the keys in the context dictionary. 33 | 34 | 3. dict[conf: Callable, dependencies: DictConfig] where conf is a callable 35 | that acts as a factory function for the class and dependencies is a mapping 36 | from the class arguments to the keys in the context dictionary. 37 | 38 | Args: 39 | cfg: A DictConfig object containing configurations. 40 | context: A dictionary containing dependencies that the individual 41 | classes might request. 42 | 43 | Returns: 44 | A list of instantiated classes. 45 | 46 | Raises: 47 | TypeError: If the config is not a DictConfig. 48 | """ 49 | 50 | instances: List[T] = [] 51 | 52 | if not cfg: 53 | warnings.warn("No configs found! Skipping...") 54 | return instances 55 | 56 | if not isinstance(cfg, DictConfig): 57 | raise TypeError("Config must be a DictConfig!") 58 | 59 | for _, conf in cfg.items(): 60 | target_name = conf.__class__.__name__ 61 | instance_dependencies = {} 62 | # Resolve dependencies if requested 63 | if isinstance(conf, DictConfig) and "dependencies" in conf: 64 | if "conf" not in conf: 65 | raise TypeError( 66 | "Invalid config schema. If dependencies are requested, then " 67 | "the config must contain a 'conf' section specifying the " 68 | "class arguments not included in the dependencies." 69 | ) 70 | # Resolve dependencies if any specified 71 | if conf.dependencies: 72 | if context: 73 | instance_dependencies = { 74 | kwarg: context[key] for kwarg, key in conf.dependencies.items() 75 | } 76 | else: 77 | raise ValueError( 78 | "The config requests dependencies, but none were provided." 79 | ) 80 | conf: DictConfig | Callable = conf.conf 81 | # pylint: disable-next=protected-access 82 | target_name = conf.func.__name__ if callable(conf) else conf._target_ 83 | 84 | log.info("Instantiating an object of type <%s>", target_name) 85 | if isinstance(conf, partial): 86 | instances.append(conf(**instance_dependencies)) 87 | elif isinstance(conf, DictConfig): 88 | if "_target_" in conf: 89 | instances.append( 90 | hydra.utils.instantiate( 91 | conf, 92 | **instance_dependencies, 93 | _recursive_=False, 94 | ) 95 | ) 96 | else: 97 | raise ValueError( 98 | f"Invalid config schema ({conf}). The config must contain " 99 | "a '_target_' key specifying the class to be instantiated."
100 | ) 101 | # Object has already been instantiated 102 | else: 103 | instances.append(conf) 104 | 105 | return instances 106 | 107 | 108 | def instantiate_with_context( 109 | cfg: DictConfig, 110 | context: dict[str, Any] | None = None, 111 | ) -> Any: 112 | if not cfg: 113 | warnings.warn("No configs found! Skipping...") 114 | return None 115 | 116 | if not isinstance(cfg, DictConfig): 117 | raise TypeError("Config must be a DictConfig!") 118 | 119 | if cfg.dependencies: 120 | if context: 121 | dependencies = { 122 | kwarg: context[key] for kwarg, key in cfg.dependencies.items() 123 | } 124 | else: 125 | raise ValueError( 126 | "The config requests dependencies, but none were provided." 127 | ) 128 | else: 129 | dependencies = {} 130 | 131 | return cfg.conf(**dependencies) 132 | --------------------------------------------------------------------------------
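A minimal sketch (with a hypothetical callbacks config) of how multi_instantiate is typically fed from the callbacks/ or logger/ config groups: each top-level key holds one sub-config with a _target_, and one object is built per key.

# Hypothetical config; the _target_ classes are standard Lightning callbacks.
from omegaconf import OmegaConf
from perturbench.modelcore.utils import multi_instantiate

cfg = OmegaConf.create(
    {
        "model_checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "monitor": "val_loss",
        },
        "lr_monitor": {"_target_": "lightning.pytorch.callbacks.LearningRateMonitor"},
    }
)
callbacks = multi_instantiate(cfg)  # -> [ModelCheckpoint, LearningRateMonitor] instances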