├── .github
│   └── workflows
│       └── code-quality.yaml
├── .gitignore
├── LICENSE.txt
├── README.md
├── notebooks
│   ├── demos
│   │   ├── data_splitter_demo.ipynb
│   │   ├── evaluator_demo.ipynb
│   │   └── generate_prediction_dataframe.ipynb
│   └── neurips2025
│       ├── build_data_scaling_splits.ipynb
│       ├── build_imbalance_splits.ipynb
│       ├── build_jiang24_frangieh21_splits.ipynb
│       ├── cpa_theis_fork_hparams.yaml
│       ├── data_curation
│       │   ├── curate_Frangieh21.ipynb
│       │   ├── curate_Jiang24_step1.R
│       │   ├── curate_Jiang24_step2.ipynb
│       │   ├── curate_McFalineFigueroa23_step1.R
│       │   ├── curate_McFalineFigueroa23_step2.ipynb
│       │   ├── curate_Norman19.ipynb
│       │   ├── curate_Srivatsan20.ipynb
│       │   └── readme.txt
│       ├── gears
│       │   ├── final_evaluation.ipynb
│       │   ├── gears_helpers.py
│       │   └── preprocess_norman19_gears.ipynb
│       └── generate_scgpt_embeddings.ipynb
├── pyproject.toml
├── setup.py
└── src
    └── perturbench
        ├── analysis
        │   ├── __init__.py
        │   ├── benchmarks
        │   │   ├── __init__.py
        │   │   ├── aggregation.py
        │   │   ├── evaluation.py
        │   │   ├── evaluator.py
        │   │   └── metrics.py
        │   ├── plotting.py
        │   ├── preprocess.py
        │   └── utils.py
        ├── configs
        │   ├── __init__.py
        │   ├── callbacks
        │   │   ├── default.yaml
        │   │   ├── early_stopping.yaml
        │   │   ├── lr_monitor.yaml
        │   │   ├── model_checkpoint.yaml
        │   │   ├── model_summary.yaml
        │   │   ├── none.yaml
        │   │   └── rich_progress_bar.yaml
        │   ├── data
        │   │   ├── devel.yaml
        │   │   ├── evaluation
        │   │   │   ├── default.yaml
        │   │   │   └── final_test.yaml
        │   │   ├── frangieh21.yaml
        │   │   ├── jiang24.yaml
        │   │   ├── mcfaline23.yaml
        │   │   ├── norman19.yaml
        │   │   ├── sciplex3.yaml
        │   │   ├── splitter
        │   │   │   ├── cell_type_transfer_task.yaml
        │   │   │   ├── cell_type_treatment_transfer_task.yaml
        │   │   │   ├── combination_prediction_task.yaml
        │   │   │   ├── mcfaline23_split.yaml
        │   │   │   └── saved_split.yaml
        │   │   └── transform
        │   │       └── linear_model_pipeline.yaml
        │   ├── experiment
        │   │   └── neurips2024
        │   │       ├── frangieh21
        │   │       │   ├── biolord_best_params_frangieh21.yaml
        │   │       │   ├── cpa_best_params_frangieh21.yaml
        │   │       │   ├── cpa_no_adv_best_params_frangieh21.yaml
        │   │       │   ├── cpa_scgpt_best_params_frangieh21.yaml
        │   │       │   ├── decoder_best_params_frangieh21.yaml
        │   │       │   ├── decoder_cov_best_params_frangieh21.yaml
        │   │       │   ├── latent_best_params_frangieh21.yaml
        │   │       │   ├── latent_scgpt_best_params_frangieh21.yaml
        │   │       │   ├── linear_best_params_frangieh21.yaml
        │   │       │   ├── sams_best_params_frangieh21.yaml
        │   │       │   └── sams_modified_best_params_frangieh21.yaml
        │   │       ├── jiang24
        │   │       │   ├── cpa_best_params_jiang24.yaml
        │   │       │   ├── cpa_no_adv_best_params_jiang24.yaml
        │   │       │   ├── decoder_best_params_jiang24.yaml
        │   │       │   ├── latent_best_params_jiang24.yaml
        │   │       │   ├── linear_best_params_jiang24.yaml
        │   │       │   ├── sams_best_params_jiang24.yaml
        │   │       │   └── sams_modified_best_params_jiang24.yaml
        │   │       ├── mcfaline23
        │   │       │   ├── cpa_best_params_mcfaline23_full.yaml
        │   │       │   ├── cpa_best_params_mcfaline23_medium.yaml
        │   │       │   ├── cpa_best_params_mcfaline23_small.yaml
        │   │       │   ├── cpa_no_adv_best_params_mcfaline23_full.yaml
        │   │       │   ├── cpa_no_adv_best_params_mcfaline23_medium.yaml
        │   │       │   ├── cpa_no_adv_best_params_mcfaline23_small.yaml
        │   │       │   ├── decoder_only_best_params_mcfaline23_full.yaml
        │   │       │   ├── decoder_only_best_params_mcfaline23_medium.yaml
        │   │       │   ├── decoder_only_best_params_mcfaline23_small.yaml
        │   │       │   ├── latent_additive_best_params_mcfaline23_full.yaml
        │   │       │   ├── latent_additive_best_params_mcfaline23_medium.yaml
        │   │       │   ├── latent_additive_best_params_mcfaline23_small.yaml
        │   │       │   ├── linear_additive_best_params_mcfaline23_full.yaml
        │   │       │   ├── linear_additive_best_params_mcfaline23_medium.yaml
        │   │       │   ├── linear_additive_best_params_mcfaline23_small.yaml
        │   │       │   ├── sams_best_params_mcfaline23_full.yaml
        │   │       │   ├── sams_best_params_mcfaline23_medium.yaml
        │   │       │   ├── sams_best_params_mcfaline23_small.yaml
        │   │       │   ├── sams_modified_best_params_mcfaline23_full.yaml
        │   │       │   ├── sams_modified_best_params_mcfaline23_medium.yaml
        │   │       │   └── sams_modified_best_params_mcfaline23_small.yaml
        │   │       ├── norman19
        │   │       │   ├── biolord_best_params_norman19.yaml
        │   │       │   ├── cpa_best_params_norman19.yaml
        │   │       │   ├── cpa_no_adv_best_params_norman19.yaml
        │   │       │   ├── cpa_scgpt_best_params_norman19.yaml
        │   │       │   ├── decoder_best_params_norman19.yaml
        │   │       │   ├── latent_best_params_norman19.yaml
        │   │       │   ├── latent_scgpt_best_params_norman19.yaml
        │   │       │   ├── linear_best_params_norman19.yaml
        │   │       │   ├── sams_best_params_norman19.yaml
        │   │       │   └── sams_modified_best_params_norman19.yaml
        │   │       └── sciplex3
        │   │           ├── biolord_best_params_sciplex3.yaml
        │   │           ├── cpa_best_params_sciplex3.yaml
        │   │           ├── cpa_no_adv_best_params_sciplex3.yaml
        │   │           ├── cpa_scgpt_best_params_sciplex3.yaml
        │   │           ├── decoder_best_params_sciplex3.yaml
        │   │           ├── decoder_cov_best_params_sciplex3.yaml
        │   │           ├── latent_best_params_sciplex3.yaml
        │   │           ├── latent_scgpt_best_params_sciplex3.yaml
        │   │           ├── linear_best_params_sciplex3.yaml
        │   │           ├── sams_best_params_sciplex3.yaml
        │   │           └── sams_modified_best_params_sciplex3.yaml
        │   ├── hpo
        │   │   ├── biolord_hpo.yaml
        │   │   ├── cpa_hpo.yaml
        │   │   ├── decoder_only_hpo.yaml
        │   │   ├── latent_additive_hpo.yaml
        │   │   ├── linear_additive_hpo.yaml
        │   │   ├── local.yaml
        │   │   └── sams_vae_hpo.yaml
        │   ├── hydra
        │   │   └── default.yaml
        │   ├── logger
        │   │   ├── csv.yaml
        │   │   ├── default.yaml
        │   │   └── tensorboard.yaml
        │   ├── model
        │   │   ├── biolord.yaml
        │   │   ├── cpa.yaml
        │   │   ├── decoder_only.yaml
        │   │   ├── latent_additive.yaml
        │   │   ├── linear_additive.yaml
        │   │   └── sams_vae.yaml
        │   ├── paths
        │   │   └── default.yaml
        │   ├── predict.yaml
        │   ├── train.yaml
        │   └── trainer
        │       ├── cpu.yaml
        │       ├── default.yaml
        │       └── gpu.yaml
        ├── data
        │   ├── __init__.py
        │   ├── accessors
        │   │   ├── base.py
        │   │   ├── frangieh21.py
        │   │   ├── jiang24.py
        │   │   ├── mcfaline23.py
        │   │   ├── norman19.py
        │   │   └── srivatsan20.py
        │   ├── collate.py
        │   ├── datasets
        │   │   ├── __init__.py
        │   │   ├── population.py
        │   │   └── singlecell.py
        │   ├── datasplitter.py
        │   ├── modules.py
        │   ├── resources
        │   │   ├── __init__.py
        │   │   └── devel.h5ad
        │   ├── transforms
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   ├── encoders.py
        │   │   ├── ops.py
        │   │   └── pipelines.py
        │   ├── types.py
        │   └── utils.py
        └── modelcore
            ├── __init__.py
            ├── models
            │   ├── __init__.py
            │   ├── average.py
            │   ├── base.py
            │   ├── biolord.py
            │   ├── cpa.py
            │   ├── decoder_only.py
            │   ├── latent_additive.py
            │   ├── linear_additive.py
            │   └── sams_vae.py
            ├── nn
            │   ├── __init__.py
            │   ├── decoders.py
            │   ├── mlp.py
            │   ├── utils.py
            │   └── vae.py
            ├── predict.py
            ├── train.py
            └── utils.py
/.github/workflows/code-quality.yaml:
--------------------------------------------------------------------------------
1 | name: code-quality
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | run-linter:
7 | name: run linter
8 |
9 | runs-on: ubuntu-latest
10 |
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: chartboost/ruff-action@v1
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | *.sqlite
28 | *.db
29 | *perturbench_data*
30 | *logs/*
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Machine Learning
40 | tensorboard/
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 | __pycache__
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # pdm
113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114 | #pdm.lock
115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116 | # in version control.
117 | # https://pdm.fming.dev/#use-with-ide
118 | .pdm.toml
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # VSCode
171 | .vscode/
172 |
173 | # Lightning logs
174 | **/lightning_logs/
175 |
176 | # Data
177 | *.gz
178 |
179 | # Machine Learning
180 | tensorboard/
181 | [0-9]*
182 |
183 | # Biomart
184 | .pybiomart*
185 |
186 | # Jupyter Notebook
187 | !*.ipynb
188 | !*.py
189 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Altos Labs
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 | Additional License Information
31 | -------------------
32 | Certain files within this repository are subject to different licensing terms.
33 | Please see the specific files in the perturbench/modelcore/models for their
34 | respective licenses.
35 |
--------------------------------------------------------------------------------
/notebooks/demos/generate_prediction_dataframe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 2024-03-11-Demo: Creating a prediction dataframe"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "A notebook demonstrating how to generate a prediction dataframe on disk for the `predict.py` script"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import scanpy as sc\n",
24 | "import pandas as pd"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "data_cache_dir = '../neurips2024/perturbench_data' ## Change this to your local data directory"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 4,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "data": {
43 | "text/plain": [
44 | "AnnData object with n_obs × n_vars = 183856 × 9198 backed at '../neurips2024/perturbench_data/srivatsan20_processed.h5ad'\n",
45 | " obs: 'ncounts', 'well', 'plate', 'cell_line', 'replicate', 'time', 'dose_value', 'pathway_level_1', 'pathway_level_2', 'perturbation', 'target', 'pathway', 'dose_unit', 'celltype', 'disease', 'cancer', 'tissue_type', 'organism', 'perturbation_type', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts', 'chembl-ID', 'dataset', 'cell_type', 'treatment', 'condition', 'dose', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'\n",
46 | " var: 'ensembl_id', 'ncounts', 'ncells', 'gene_symbol', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'\n",
47 | " uns: 'hvg', 'log1p', 'rank_genes_groups_cov'\n",
48 | " layers: 'counts'"
49 | ]
50 | },
51 | "execution_count": 4,
52 | "metadata": {},
53 | "output_type": "execute_result"
54 | }
55 | ],
56 | "source": [
57 | "adata = sc.read_h5ad(f'{data_cache_dir}/srivatsan20_processed.h5ad', backed='r')\n",
58 | "adata"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 5,
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "data": {
68 | "text/plain": [
69 | "['mcf7', 'k562', 'a549']\n",
70 | "Categories (3, object): ['a549', 'k562', 'mcf7']"
71 | ]
72 | },
73 | "execution_count": 5,
74 | "metadata": {},
75 | "output_type": "execute_result"
76 | }
77 | ],
78 | "source": [
79 | "adata.obs.cell_type.unique()"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 6,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "data": {
89 | "text/plain": [
90 | "188"
91 | ]
92 | },
93 | "execution_count": 6,
94 | "metadata": {},
95 | "output_type": "execute_result"
96 | }
97 | ],
98 | "source": [
99 | "unique_perturbations = [p for p in adata.obs.perturbation.unique() if p != 'control']\n",
100 | "len(unique_perturbations)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 7,
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "data": {
163 | "text/plain": [
164 | " condition cell_type\n",
165 | "0 TAK-901 k562\n",
166 | "1 Busulfan k562\n",
167 | "2 BMS-536924 k562\n",
168 | "3 Enzastaurin (LY317615) k562\n",
169 | "4 BMS-911543 k562"
170 | ]
171 | },
172 | "execution_count": 7,
173 | "metadata": {},
174 | "output_type": "execute_result"
175 | }
176 | ],
177 | "source": [
178 | "prediction_df = pd.DataFrame(\n",
179 | " {\n",
180 | " 'condition': unique_perturbations,\n",
181 | " 'cell_type': 'k562',\n",
182 | " }\n",
183 | ")\n",
184 | "prediction_df.head()"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 8,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "prediction_df.to_csv(f'{data_cache_dir}/prediction_dataframe.csv', index=False)"
194 | ]
195 | }
196 | ],
197 | "metadata": {
198 | "kernelspec": {
199 | "display_name": "perturbench-dev",
200 | "language": "python",
201 | "name": "python3"
202 | },
203 | "language_info": {
204 | "codemirror_mode": {
205 | "name": "ipython",
206 | "version": 3
207 | },
208 | "file_extension": ".py",
209 | "mimetype": "text/x-python",
210 | "name": "python",
211 | "nbconvert_exporter": "python",
212 | "pygments_lexer": "ipython3",
213 | "version": "3.11.9"
214 | }
215 | },
216 | "nbformat": 4,
217 | "nbformat_minor": 2
218 | }
219 |
--------------------------------------------------------------------------------
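The notebook above requests predictions for a single cell type. A minimal sketch (assuming the same `srivatsan20_processed.h5ad` layout) that builds a prediction dataframe covering every cell type for every perturbation:

```python
import itertools

import pandas as pd
import scanpy as sc

data_cache_dir = '../neurips2024/perturbench_data'  # change this to your local data directory
adata = sc.read_h5ad(f'{data_cache_dir}/srivatsan20_processed.h5ad', backed='r')

perturbations = [p for p in adata.obs.perturbation.unique() if p != 'control']
cell_types = list(adata.obs.cell_type.unique())  # ['mcf7', 'k562', 'a549']

# One row per (condition, cell_type) pair the model should predict
prediction_df = pd.DataFrame(
    list(itertools.product(perturbations, cell_types)),
    columns=['condition', 'cell_type'],
)
prediction_df.to_csv(f'{data_cache_dir}/prediction_dataframe_all_celltypes.csv', index=False)
```
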
/notebooks/neurips2025/cpa_theis_fork_hparams.yaml:
--------------------------------------------------------------------------------
1 | ae_hparams:
2 | autoencoder_depth: 4
3 | autoencoder_width: 256
4 | adversary_depth: 3
5 | adversary_width: 256
6 | use_batch_norm: False
7 | use_layer_norm: True
8 | output_activation: linear
9 | dropout_rate: 0.2
10 | variational: True
11 | seed: 0
12 |
13 | trainer_hparams:
14 | n_epochs_warmup: 10
15 | n_epochs_kl_warmup: 10
16 | max_kl_weight: 0.1
17 | adversary_lr: 0.00008847032648856746
18 | adversary_wd: 9.629190571404551e-07
19 | adversary_steps: 3
20 | autoencoder_lr: 0.00007208558788012054
21 | autoencoder_wd: 1.2280838320404273e-07
22 | dosers_lr: 0.00008835011062896268
23 | dosers_wd: 5.886005123780177e-06
24 | penalty_adversary: 63.44954424334805
25 | reg_adversary: 48.73324753854268
26 | cycle_coeff: 7.19539336141403
27 | step_size_lr: 25
--------------------------------------------------------------------------------
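A minimal sketch of reading these hyperparameters back in Python; how they are passed into the CPA fork's trainer is not shown here and would depend on that fork's API:

```python
import yaml

# Path of the config file shown above
with open('notebooks/neurips2025/cpa_theis_fork_hparams.yaml') as f:
    hparams = yaml.safe_load(f)

ae_hparams = hparams['ae_hparams']            # architecture settings
trainer_hparams = hparams['trainer_hparams']  # optimisation settings
print(ae_hparams['autoencoder_width'], trainer_hparams['autoencoder_lr'])
```
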
/notebooks/neurips2025/data_curation/curate_Jiang24_step1.R:
--------------------------------------------------------------------------------
1 | if (!requireNamespace("remotes", quietly = TRUE)) {
2 | install.packages("remotes")
3 | }
4 |
5 | remotes::install_version("Matrix", version = "1.6.5", repos = "http://cran.us.r-project.org")
6 | install.packages("Seurat")
7 | remotes::install_github("mojaveazure/seurat-disk")
8 |
9 | library(Seurat)
10 | library(SeuratDisk)
11 |
12 | ## Set your working directory to the current file (only works on Rstudio)
13 | setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
14 |
15 | ## Data cache directory
16 | data_cache_dir = '../perturbench_data' ## Change this to your local data directory
17 |
18 | ## Define paths to Seurat objects and download from Zenodo
19 | ifng_seurat_path = paste0(data_cache_dir, "/Seurat_object_IFNG_Perturb_seq.rds", sep="")
20 | system2(
21 | "wget",
22 | c(
23 | "https://zenodo.org/records/10520190/files/Seurat_object_IFNG_Perturb_seq.rds?download=1",
24 | "-O",
25 | ifng_seurat_path
26 | )
27 | )
28 |
29 | ifnb_seurat_path = paste0(data_cache_dir, "/Seurat_object_IFNB_Perturb_seq.rds", sep="")
30 | system2(
31 | "wget",
32 | c(
33 | "https://zenodo.org/records/10520190/files/Seurat_object_IFNB_Perturb_seq.rds?download=1",
34 | "-O",
35 | ifnb_seurat_path
36 | )
37 | )
38 |
39 | ins_seurat_path = paste0(data_cache_dir, "/Seurat_object_INS_Perturb_seq.rds", sep="")
40 | system2(
41 | "wget",
42 | c(
43 | "https://zenodo.org/records/10520190/files/Seurat_object_INS_Perturb_seq.rds?download=1",
44 | "-O",
45 | ins_seurat_path
46 | )
47 | )
48 |
49 | tgfb_seurat_path = paste0(data_cache_dir, "/Seurat_object_TGFB_Perturb_seq.rds", sep="")
50 | system2(
51 | "wget",
52 | c(
53 | "https://zenodo.org/records/10520190/files/Seurat_object_TGFB_Perturb_seq.rds?download=1",
54 | "-O",
55 | tgfb_seurat_path
56 | )
57 | )
58 |
59 | tnfa_seurat_path = paste0(data_cache_dir, "/Seurat_object_TNFA_Perturb_seq.rds", sep="")
60 | system2(
61 | "wget",
62 | c(
63 | "https://zenodo.org/records/10520190/files/Seurat_object_TNFA_Perturb_seq.rds?download=1",
64 | "-O",
65 | tnfa_seurat_path
66 | )
67 | )
68 |
69 |
70 | seurat_files = c(
71 | ifng_seurat_path,
72 | ifnb_seurat_path,
73 | ins_seurat_path,
74 | tgfb_seurat_path,
75 | tnfa_seurat_path
76 | )
77 |
78 | for (seurat_path in seurat_files) {
79 | print(seurat_path)
80 | obj = readRDS(seurat_path)
81 | print(obj)
82 |
83 | out_h5seurat_path = gsub(".rds", ".h5Seurat", seurat_path)
84 | SaveH5Seurat(obj, filename = out_h5seurat_path)
85 | Convert(out_h5seurat_path, dest = "h5ad")
86 | }
87 |
88 |
--------------------------------------------------------------------------------
/notebooks/neurips2025/data_curation/curate_McFalineFigueroa23_step1.R:
--------------------------------------------------------------------------------
1 | ## Install packages
2 | if (!require("BiocManager", quietly = TRUE))
3 | install.packages("BiocManager")
4 |
5 |
6 | if (!requireNamespace("remotes", quietly = TRUE)) {
7 | install.packages("remotes")
8 | }
9 |
10 | ## Install libcairo and libxt-dev using apt-get
11 | # system2(
12 | # "apt-get",
13 | # c("install", "libcairo2-dev", "libxt-dev")
14 | # )
15 |
16 | remotes::install_version("Matrix", version = "1.6.5", repos = "http://cran.us.r-project.org")
17 | install.packages("Seurat")
18 | remotes::install_github("mojaveazure/seurat-disk")
19 | BiocManager::install(c('BiocGenerics', 'DelayedArray', 'DelayedMatrixStats',
20 | 'limma', 'lme4', 'S4Vectors', 'SingleCellExperiment',
21 | 'SummarizedExperiment', 'batchelor', 'HDF5Array',
22 | 'terra', 'ggrastr'))
23 | remotes::install_github('cole-trapnell-lab/monocle3')
24 |
25 | ## Load libraries
26 | library(Seurat)
27 | library(SingleCellExperiment)
28 | library(SeuratDisk)
29 | library(Matrix)
30 |
31 | ## Set your working directory to the current file (only works on Rstudio)
32 | setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
33 |
34 | ## Data cache directory
35 | data_cache_dir = '../perturbench_data' ## Change this to your local data directory
36 |
37 | ## GXE1
38 | gxe1_cds_path = paste0(data_cache_dir, "/GSM7056148_sciPlexGxE_1_preprocessed_cds.rds.gz", sep="")
39 | print(gxe1_cds_path)
40 |
41 | system2(
42 | "wget",
43 | c(
44 | "https://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7056nnn/GSM7056148/suppl/GSM7056148_sciPlexGxE_1_preprocessed_cds.rds.gz",
45 | "-O",
46 | gxe1_cds_path
47 | )
48 | )
49 | system2(
50 | "gzip",
51 | c("-d", gxe1_cds_path)
52 | )
53 |
54 | gxe1 = readRDS(gsub(".gz", "", gxe1_cds_path))
55 | gxe1_counts = gxe1@assays@data$counts
56 | gxe1_gene_meta = data.frame(gxe1@rowRanges@elementMetadata@listData)
57 | gxe1_cell_meta = data.frame(gxe1@colData)
58 | head(gxe1_cell_meta)
59 |
60 | colnames(gxe1_counts) = rownames(gxe1_cell_meta)
61 | rownames(gxe1_counts) = gxe1_gene_meta$id
62 | gxe1_seurat = CreateSeuratObject(gxe1_counts, meta.data = gxe1_cell_meta)
63 | gxe1_seurat@assays$RNA@meta.features <- gxe1_gene_meta
64 | gxe1_seurat$cell_type = 'A172'
65 | for (col in colnames(gxe1_seurat@meta.data)) {
66 | if (is.factor(gxe1_seurat@meta.data[[col]])) {
67 | print(col)
68 | gxe1_seurat@meta.data[[col]] = as.character(gxe1_seurat@meta.data[[col]])
69 | }
70 | }
71 |
72 | SaveH5Seurat(gxe1_seurat, filename = paste0(data_cache_dir, "/gxe1.h5Seurat"), overwrite = T)
73 | Convert(paste0(data_cache_dir, "/gxe1.h5Seurat"), dest = "h5ad", overwrite = T)
74 |
75 | ## GXE2
76 | gxe2_cds_path = paste0(data_cache_dir, "/GSM7056149_sciPlexGxE_2_preprocessed_cds.list.RDS.gz")
77 | system2(
78 | "wget",
79 | c(
80 | "https://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7056nnn/GSM7056149/suppl/GSM7056149%5FsciPlexGxE%5F2%5Fpreprocessed%5Fcds.list.RDS.gz",
81 | "-O",
82 | gxe2_cds_path
83 | )
84 | )
85 | system2(
86 | "gzip",
87 | c("-d", gxe2_cds_path)
88 | )
89 | gxe2_list = readRDS(gsub(".gz", "", gxe2_cds_path))
90 |
91 | base_path = data_cache_dir
92 | gxe2_seurat_list = lapply(1:length(gxe2_list), function(i) {
93 | sce = gxe2_list[[i]]
94 | counts = sce@assays@.xData$data$counts
95 | gene_meta = data.frame(sce@rowRanges@elementMetadata@listData)
96 | cell_meta = data.frame(sce@colData)
97 | head(cell_meta)
98 |
99 | colnames(counts) = rownames(cell_meta)
100 | rownames(counts) = gene_meta$id
101 | seurat_obj = CreateSeuratObject(counts, meta.data = cell_meta)
102 | seurat_obj@assays$RNA@meta.features <- gene_meta
103 | for (col in colnames(seurat_obj@meta.data)) {
104 | if (is.factor(seurat_obj@meta.data[[col]])) {
105 | print(col)
106 | seurat_obj@meta.data[[col]] = as.character(seurat_obj@meta.data[[col]])
107 | }
108 | }
109 | cl = names(gxe2_list)[[i]]
110 | seurat_obj$cell_type = cl
111 |
112 | out_h5Seurat = paste0(base_path, "/gxe2_", cl, ".h5Seurat")
113 | SaveH5Seurat(seurat_obj, filename = out_h5Seurat, overwrite = T)
114 | Convert(out_h5Seurat, dest = "h5ad", overwrite = T)
115 |
116 | seurat_obj
117 | })
--------------------------------------------------------------------------------
/notebooks/neurips2025/data_curation/curate_Norman19.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 2023-04-17-Curation: Norman19 Combo Screen"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Add cross validation splits to Norman19"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import scanpy as sc\n",
24 | "import os\n",
25 | "import subprocess as sp\n",
26 | "from perturbench.analysis.preprocess import preprocess\n",
27 | "from scipy.sparse import csr_matrix\n",
28 | "\n",
29 | "%reload_ext autoreload\n",
30 | "%autoreload 2"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "Download from: https://zenodo.org/records/7041849/files/NormanWeissman2019_filtered.h5ad?download=1"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "data_url = 'https://zenodo.org/records/7041849/files/NormanWeissman2019_filtered.h5ad?download=1'\n",
47 | "data_cache_dir = '../perturbench_data' ## Change this to your local data directory\n",
48 | "\n",
49 | "if not os.path.exists(data_cache_dir):\n",
50 | " os.makedirs(data_cache_dir)\n",
51 | "\n",
52 | "tmp_data_dir = f'{data_cache_dir}/norman19_downloaded.h5ad'\n",
53 | "\n",
54 | "if not os.path.exists(tmp_data_dir):\n",
55 | " sp.call(f'wget {data_url} -O {tmp_data_dir}', shell=True)"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 3,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "data": {
65 | "text/plain": [
66 | "AnnData object with n_obs × n_vars = 111445 × 33694\n",
67 | " obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo'\n",
68 | " var: 'ensemble_id', 'ncounts', 'ncells'"
69 | ]
70 | },
71 | "execution_count": 3,
72 | "metadata": {},
73 | "output_type": "execute_result"
74 | }
75 | ],
76 | "source": [
77 | "adata = sc.read_h5ad(tmp_data_dir)\n",
78 | "adata"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 4,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "adata.obs.rename(columns = {\n",
88 | " 'nCount_RNA': 'ncounts',\n",
89 | " 'nFeature_RNA': 'ngenes',\n",
90 | " 'percent.mt': 'percent_mito',\n",
91 | " 'cell_line': 'cell_type',\n",
92 | "}, inplace=True)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 5,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/plain": [
103 | "['ARID1A', 'BCORL1', 'FOSB', 'SET_KLF1', 'OSR2', ..., 'CEBPB_OSR2', 'PRDM1_CBFA2T3', 'FOSB_CEBPB', 'ZBTB10_DLX2', 'FEV_CBFA2T3']\n",
104 | "Length: 237\n",
105 | "Categories (237, object): ['AHR', 'AHR_FEV', 'AHR_KLF1', 'ARID1A', ..., 'ZC3HAV1_HOXC13', 'ZNF318', 'ZNF318_FOXL2', 'control']"
106 | ]
107 | },
108 | "execution_count": 5,
109 | "metadata": {},
110 | "output_type": "execute_result"
111 | }
112 | ],
113 | "source": [
114 | "adata.obs.perturbation.unique()"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 6,
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "data": {
124 | "text/plain": [
125 | "perturbation\n",
126 | "control 11855\n",
127 | "KLF1 1960\n",
128 | "BAK1 1457\n",
129 | "CEBPE 1233\n",
130 | "CEBPE+RUNX1T1 1219\n",
131 | " ... \n",
132 | "CEBPB+CEBPA 64\n",
133 | "CBL+UBASH3A 64\n",
134 | "C3orf72+FOXL2 59\n",
135 | "JUN+CEBPB 59\n",
136 | "JUN+CEBPA 54\n",
137 | "Name: count, Length: 237, dtype: int64"
138 | ]
139 | },
140 | "execution_count": 6,
141 | "metadata": {},
142 | "output_type": "execute_result"
143 | }
144 | ],
145 | "source": [
146 | "adata.obs['perturbation'] = adata.obs['perturbation'].str.replace('_', '+')\n",
147 | "adata.obs['perturbation'] = adata.obs['perturbation'].astype('category')\n",
148 | "adata.obs.perturbation.value_counts()"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 7,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "adata.obs['condition'] = adata.obs.perturbation.copy()"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 8,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "adata.X = csr_matrix(adata.X)"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 9,
172 | "metadata": {},
173 | "outputs": [
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "Preprocessing ...\n",
179 | "Filtering for highly variable genes or differentially expressed genes ...\n",
180 | "Processed dataset summary:\n",
181 | "View of AnnData object with n_obs × n_vars = 111445 × 5666\n",
182 | " obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_type', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'condition', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'\n",
183 | " var: 'ensemble_id', 'ncounts', 'ncells', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'\n",
184 | " uns: 'log1p', 'hvg', 'rank_genes_groups_cov'\n",
185 | " layers: 'counts'\n"
186 | ]
187 | }
188 | ],
189 | "source": [
190 | "adata = preprocess(\n",
191 | " adata,\n",
192 | " perturbation_key='condition',\n",
193 | " covariate_keys=['cell_type'],\n",
194 | ")"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 12,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "adata = adata.copy()\n",
204 | "output_data_path = f'{data_cache_dir}/norman19_processed.h5ad'\n",
205 | "adata.write_h5ad(output_data_path)"
206 | ]
207 | }
208 | ],
209 | "metadata": {
210 | "kernelspec": {
211 | "display_name": "perturbench-dev",
212 | "language": "python",
213 | "name": "python3"
214 | },
215 | "language_info": {
216 | "codemirror_mode": {
217 | "name": "ipython",
218 | "version": 3
219 | },
220 | "file_extension": ".py",
221 | "mimetype": "text/x-python",
222 | "name": "python",
223 | "nbconvert_exporter": "python",
224 | "pygments_lexer": "ipython3",
225 | "version": "3.11.9"
226 | }
227 | },
228 | "nbformat": 4,
229 | "nbformat_minor": 4
230 | }
231 |
--------------------------------------------------------------------------------
/notebooks/neurips2025/data_curation/readme.txt:
--------------------------------------------------------------------------------
1 | URLs to download public data are at the beginning of each curation notebook or R script.
2 | 
3 | Some datasets were uploaded as R objects and thus required a two-step conversion to AnnData (h5ad).
--------------------------------------------------------------------------------
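A minimal sketch of the second step on the Python side, assuming the h5ad files written by `SeuratDisk::Convert` in the R scripts above (e.g. `Seurat_object_IFNG_Perturb_seq.h5ad`) sit in the shared `perturbench_data` cache directory:

```python
import scanpy as sc

data_cache_dir = '../perturbench_data'  # change this to your local data directory

# Step 2: load the h5ad produced from the Seurat object and sanity-check it
adata = sc.read_h5ad(f'{data_cache_dir}/Seurat_object_IFNG_Perturb_seq.h5ad')
print(adata)                       # n_obs x n_vars plus obs/var columns carried over from Seurat
print(adata.obs.columns.tolist())  # cell-level metadata harmonised downstream in curate_Jiang24_step2.ipynb
```
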
/notebooks/neurips2025/gears/gears_helpers.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import anndata as ad
3 | import pandas as pd
4 | import numpy as np
5 | import torch
6 | import optuna
7 |
8 | from gears import PertData, GEARS
9 | from perturbench.analysis.benchmarks.evaluation import Evaluation
10 |
11 |
12 | class GEARSHParamsRange:
13 | """
14 | Hyperparameter ranges for GEARS.
15 | """
16 |
17 | @staticmethod
18 | def get_distributions():
19 | return {
20 | "hidden_size": optuna.distributions.IntDistribution(32, 512, step=32),
21 | "num_go_gnn_layers": optuna.distributions.IntDistribution(1, 3),
22 | "num_gene_gnn_layers": optuna.distributions.IntDistribution(1, 3),
23 | "decoder_hidden_size": optuna.distributions.IntDistribution(
24 | 16, 48, step=16
25 | ),
26 | "num_similar_genes_go_graph": optuna.distributions.IntDistribution(
27 | 10, 30, step=5
28 | ),
29 | "num_similar_genes_co_express_graph": optuna.distributions.IntDistribution(
30 | 10, 30, step=5
31 | ),
32 | "coexpress_threshold": optuna.distributions.FloatDistribution(
33 | 0.2, 0.5, step=0.01
34 | ),
35 | "lr": optuna.distributions.FloatDistribution(5e-6, 1e-3, log=True),
36 | "wd": optuna.distributions.FloatDistribution(1e-8, 1e-3, log=True),
37 | }
38 |
39 |
40 | def run_gears(
41 | pert_data_path: str = "/weka/prime-shared/prime-data/gears/",
42 | dataset_name: str = "norman19",
43 | split_dict_path: str = "/weka/prime-shared/prime-data/gears/norman19_gears_split.pkl",
44 | eval_split: str = "val",
45 |     batch_size: int = 32,
46 | epochs: int = 10,
47 | lr: float = 5e-4,
48 | wd: float = 1e-4,
49 |     hidden_size: int = 128,
50 | num_go_gnn_layers: int = 1,
51 | num_gene_gnn_layers: int = 1,
52 | decoder_hidden_size: int = 16,
53 | num_similar_genes_go_graph: int = 20,
54 | num_similar_genes_co_express_graph: int = 20,
55 | coexpress_threshold: float = 0.4,
56 | device="cuda:0",
57 | seed=0,
58 | ):
59 | """
60 | Helper function to train and evaluate a GEARS model
61 | """
62 | pert_data = PertData(pert_data_path) # specific saved folder
63 | pert_data.load(
64 | data_path=pert_data_path + dataset_name
65 | ) # load the processed data, the path is saved folder + dataset_name
66 | pert_data.prepare_split(split="custom", split_dict_path=split_dict_path)
67 | pert_data.get_dataloader(batch_size=batch_size, test_batch_size=batch_size)
68 |
69 | gears_model = GEARS(
70 | pert_data,
71 | device=device,
72 | weight_bias_track=False,
73 | proj_name="pertnet",
74 | exp_name="pertnet",
75 | )
76 | gears_model.model_initialize(
77 | hidden_size=hidden_size,
78 | num_go_gnn_layers=num_go_gnn_layers,
79 | num_gene_gnn_layers=num_gene_gnn_layers,
80 | decoder_hidden_size=decoder_hidden_size,
81 | num_similar_genes_go_graph=num_similar_genes_go_graph,
82 | num_similar_genes_co_express_graph=num_similar_genes_co_express_graph,
83 | coexpress_threshold=coexpress_threshold,
84 | seed=seed,
85 | )
86 | gears_model.train(epochs=epochs, lr=lr, weight_decay=wd)
87 | torch.cuda.empty_cache()
88 |
89 | val_perts = []
90 | for p in pert_data.set2conditions[eval_split]:
91 | newp_list = []
92 | for gene in p.split("+"):
93 | if gene in gears_model.pert_list:
94 | newp_list.append(gene)
95 | if len(newp_list) > 0:
96 | val_perts.append(newp_list)
97 |
98 | val_avg_pred = gears_model.predict(val_perts)
99 | pred_df = pd.DataFrame(val_avg_pred).T
100 | pred_df.columns = gears_model.adata.var_names.values
101 | torch.cuda.empty_cache()
102 |
103 | ctrl_adata = gears_model.adata[gears_model.adata.obs.condition == "ctrl"]
104 | val_conditions = ["+".join(p) for p in val_perts] + ["ctrl"]
105 | ref_adata = gears_model.adata[gears_model.adata.obs.condition.isin(val_conditions)]
106 |
107 | pred_adata = sc.AnnData(pred_df)
108 | pred_adata.obs["condition"] = [x.replace("_", "+") for x in pred_adata.obs_names]
109 | pred_adata.obs["condition"] = pred_adata.obs["condition"].astype("category")
110 | pred_adata = ad.concat([pred_adata, ctrl_adata])
111 |
112 | ev = Evaluation(
113 | model_adatas={
114 | "GEARS": pred_adata,
115 | },
116 | ref_adata=ref_adata,
117 | pert_col="condition",
118 | ctrl="ctrl",
119 | )
120 |
121 | aggr_metrics = [
122 | ("average", "rmse"),
123 | ("logfc", "cosine"),
124 | ]
125 | summary_metrics_dict = {}
126 | for aggr, metric in aggr_metrics:
127 | ev.evaluate(aggr_method=aggr, metric=metric)
128 | ev.evaluate_pairwise(aggr_method=aggr, metric=metric)
129 | ev.evaluate_rank(aggr_method=aggr, metric=metric)
130 |
131 | metric_df = ev.get_eval(aggr_method=aggr, metric=metric)
132 | rank_df = ev.get_rank_eval(aggr_method=aggr, metric=metric)
133 | summary_metrics_dict[f"{metric}_{aggr}"] = np.mean(metric_df["GEARS"])
134 | summary_metrics_dict[f"{metric}_rank_{aggr}"] = np.mean(rank_df["GEARS"])
135 |
136 | return pd.Series(summary_metrics_dict)
137 |
--------------------------------------------------------------------------------
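A sketch of how `run_gears` might be plugged into an Optuna search. The ranges mirror `GEARSHParamsRange.get_distributions()`, data paths fall back to the defaults in `run_gears`, and the study settings are assumptions:

```python
import optuna

from gears_helpers import run_gears


def objective(trial: optuna.Trial) -> float:
    # Sample from the same ranges as GEARSHParamsRange.get_distributions()
    params = dict(
        hidden_size=trial.suggest_int("hidden_size", 32, 512, step=32),
        num_go_gnn_layers=trial.suggest_int("num_go_gnn_layers", 1, 3),
        num_gene_gnn_layers=trial.suggest_int("num_gene_gnn_layers", 1, 3),
        decoder_hidden_size=trial.suggest_int("decoder_hidden_size", 16, 48, step=16),
        num_similar_genes_go_graph=trial.suggest_int("num_similar_genes_go_graph", 10, 30, step=5),
        num_similar_genes_co_express_graph=trial.suggest_int(
            "num_similar_genes_co_express_graph", 10, 30, step=5
        ),
        coexpress_threshold=trial.suggest_float("coexpress_threshold", 0.2, 0.5, step=0.01),
        lr=trial.suggest_float("lr", 5e-6, 1e-3, log=True),
        wd=trial.suggest_float("wd", 1e-8, 1e-3, log=True),
    )
    metrics = run_gears(eval_split="val", epochs=10, **params)
    # run_gears returns a pandas Series keyed by "<metric>_<aggregation>"
    return metrics["rmse_average"]


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
```
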
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "perturbench"
7 | description = "PerturBench: Benchmarking Machine Learning Models for Cellular Perturbation Analysis"
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | license = {text = "..."}
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | ]
14 | dependencies = [
15 | 'numpy',
16 | 'pyyaml',
17 | 'pandas',
18 | 'anndata',
19 | 'scanpy',
20 | 'scipy==1.12.0',
21 | 'tqdm',
22 | 'lightning',
23 | "torch",
24 | "torchvision",
25 | "tensorboard",
26 | "hydra-core",
27 | "hydra_colorlog",
28 | 'mlflow-skinny',
29 | 'matplotlib',
30 | 'seaborn',
31 | 'scikit-learn',
32 | 'scikit-misc',
33 | 'adjusttext',
34 | 'pytest',
35 | 'rich',
36 | 'psycopg2-binary',
37 | 'optuna',
38 | 'ray',
39 | 'python-dotenv',
40 | ]
41 | dynamic = ["version"]
42 |
43 | [project.optional-dependencies]
44 | dev = [
45 | "ruff",
46 | ]
47 | test = [
48 | "pytest",
49 | "pytest-cov",
50 | ]
51 | cli = [
52 | "click",
53 | "rich",
54 | ]
55 |
56 | [project.scripts]
57 | train = "perturbench.modelcore.train:main"
58 | predict = "perturbench.modelcore.predict:main"
59 |
60 | [tool.setuptools.packages.find]
61 | where = ["src"] # list of folders that contain the packages (["."] by default)
62 | include = ["perturbench*"] # package names should match these glob patterns (["*"] by default)
63 | exclude = ["tests", "docs", "examples"] # exclude packages matching these glob patterns (empty by default)
64 | # namespaces = false # to disable scanning PEP 420 namespaces (true by default)
65 |
66 | [tool.setuptools.dynamic]
67 | version = {attr = "perturbench.modelcore.VERSION"}
68 |
69 | [tool.pytest.ini_options]
70 | addopts = ["--import-mode=importlib"]
71 |
72 |
73 | [tool.ruff.lint.per-file-ignores]
74 | # Ignore unused imports (`F401`) in all `__init__.py` files; relax import-order and unused-import checks in notebooks.
75 | "__init__.py" = ["F401"]
76 | "*.ipynb" = ["E402", "F401"]
77 | # "path/to/file.py" = ["E402"]
78 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup()
4 |
--------------------------------------------------------------------------------
/src/perturbench/analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/analysis/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/analysis/benchmarks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/analysis/benchmarks/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/analysis/benchmarks/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import r2_score
2 | from sklearn.metrics.pairwise import euclidean_distances, rbf_kernel
3 | from scipy.stats import pearsonr
4 | import numpy as np
5 | from numpy.linalg import norm
6 | import pandas as pd
7 | import tqdm
8 | from ..utils import merge_cols
9 |
10 |
11 | def compute_metric(x, y, metric):
12 | """Compute specified similarity/distance metric between x and y vectors"""
13 |
14 | if metric == "pearson":
15 | score = pearsonr(x, y)[0]
16 | # elif metric == 'spearman':
17 | # score = spearmanr(x, y)[0]
18 | elif metric == "r2_score":
19 | score = r2_score(x, y)
20 | elif metric == "cosine":
21 | score = np.dot(x, y) / (norm(x) * norm(y))
22 | elif metric == "mse":
23 | score = np.mean(np.square(x - y))
24 | elif metric == "rmse":
25 | score = np.sqrt(np.mean(np.square(x - y)))
26 | elif metric == "mae":
27 | score = np.mean(np.abs(x - y))
28 |
29 | return score
30 |
31 |
32 | def compare_perts(
33 | pred, ref, features=None, perts=None, metric="pearson", deg_dict=None
34 | ):
35 | """Compare expression similarities between `pred` and `ref` DataFrames using the specified metric"""
36 |
37 | if perts is None:
38 | perts = list(set(pred.index).intersection(ref.index))
39 | assert len(perts) > 0
40 | else:
41 | perts = list(perts)
42 |
43 | if features is not None:
44 | pred = pred.loc[:, features]
45 | ref = ref.loc[:, features]
46 |
47 | pred = pred.loc[perts, :]
48 | ref = ref.loc[perts, :]
49 |
50 | pred = pred.replace([np.inf, -np.inf], 0)
51 | ref = ref.replace([np.inf, -np.inf], 0)
52 |
53 | eval_metric = []
54 | for p in perts:
55 | if deg_dict is not None:
56 | genes = deg_dict[p]
57 | else:
58 | genes = ref.columns
59 |
60 | eval_metric.append(
61 | compute_metric(pred.loc[p, genes], ref.loc[p, genes], metric)
62 | )
63 |
64 | eval_scores = pd.Series(index=perts, data=eval_metric)
65 | return eval_scores
66 |
67 |
68 | def pairwise_metric_helper(
69 | df,
70 | df2=None,
71 | metric="rmse",
72 | pairwise_deg_dict=None,
73 | verbose=False,
74 | ):
75 | if df2 is None:
76 | df2 = df
77 |
78 | mat = pd.DataFrame(0.0, index=df.index, columns=df2.index)
79 | for p1 in tqdm.tqdm(df.index, disable=not verbose):
80 | for p2 in df2.index:
81 | if pairwise_deg_dict is not None:
82 | pp = frozenset([p1, p2])
83 | genes_ix = pairwise_deg_dict[pp]
84 |
85 | m = compute_metric(
86 | df.loc[p1, genes_ix],
87 | df2.loc[p2, genes_ix],
88 | metric=metric,
89 | )
90 | else:
91 | m = compute_metric(
92 | df.loc[p1],
93 | df2.loc[p2],
94 | metric=metric,
95 | )
96 | mat.at[p1, p2] = m
97 |
98 | return mat
99 |
100 |
101 | def rank_helper(pred_ref_mat, metric_type):
102 | rel_ranks = pd.Series(1.0, index=pred_ref_mat.columns)
103 | for p in pred_ref_mat.columns:
104 | pred_metrics = pred_ref_mat.loc[:, p]
105 | pred_metrics = pred_metrics.sample(frac=1.0) ## Shuffle to avoid ties
106 | if metric_type == "distance":
107 | pred_metrics = pred_metrics.sort_values(ascending=True)
108 | elif metric_type == "similarity":
109 | pred_metrics = pred_metrics.sort_values(ascending=False)
110 | else:
111 | raise ValueError(
112 | "Invalid metric_type, should be either distance or similarity"
113 | )
114 |
115 | rel_ranks.loc[p] = np.where(pred_metrics.index == p)[0][0]
116 |
117 | rel_ranks = rel_ranks / len(rel_ranks)
118 | return rel_ranks
119 |
120 |
121 | def mmd_energy_distance_helper(
122 |     eval,  # an Evaluation object
123 | model_name,
124 | pert_col,
125 | cov_cols,
126 | ctrl,
127 | delim='_',
128 | kernel='energy_distance',
129 | gamma=None,
130 | ):
131 | model_adata = eval.adatas[model_name]
132 | ref_adata = eval.adatas['ref']
133 |
134 | model_adata.obs[pert_col] = model_adata.obs[pert_col].astype('category')
135 | ref_adata.obs[pert_col] = ref_adata.obs[pert_col].astype('category')
136 |
137 | if len(cov_cols) == 0:
138 | model_adata.obs['_dummy_cov'] = '1'
139 | ref_adata.obs['_dummy_cov'] = '1'
140 | cov_cols = ['_dummy_cov']
141 |
142 | for col in cov_cols:
143 | assert col in model_adata.obs.columns
144 | assert col in ref_adata.obs.columns
145 |
146 | if kernel == 'energy_distance':
147 | kernel_fns = [lambda x, y: - euclidean_distances(x, y)]
148 | elif kernel == 'rbf_kernel':
149 | if gamma is None:
150 | all_gamma = np.logspace(1, -3, num=5)
151 | elif isinstance(gamma, list):
152 | all_gamma = np.array(gamma)
153 | else:
154 | all_gamma = np.array([gamma])
155 |         kernel_fns = [lambda x, y, gamma=gamma: rbf_kernel(x, y, gamma=gamma) for gamma in all_gamma]  # bind each gamma at definition time
156 |         print('rbf kernels with gammas:', all_gamma)
157 | else:
158 | raise ValueError('Invalid kernel')
159 |
160 | model_adata_covs = merge_cols(model_adata.obs, cov_cols, delim=delim)
161 | ref_adata_covs = merge_cols(ref_adata.obs, cov_cols, delim=delim)
162 |
163 | ret = {'cov_pert': [], 'model': [], 'metric': []}
164 | for cov in model_adata_covs.cat.categories:
165 |
166 | if len(model_adata[model_adata_covs == cov, :].obs[pert_col].unique()) > 1: # has any perturbations beside control
167 |
168 | model_adata_subset_cov = model_adata[model_adata_covs == cov, :]
169 | ref_adata_subset_cov = ref_adata[ref_adata_covs == cov, :]
170 | model_adata_covs_perts = merge_cols(model_adata_subset_cov.obs, [pert_col], delim=delim)
171 | ref_adata_covs_perts = merge_cols(ref_adata_subset_cov.obs, [pert_col], delim=delim)
172 |
173 | for i, pert in enumerate(model_adata_covs_perts.cat.categories):
174 | if pert == ctrl:
175 | continue
176 |
177 | population_pred = model_adata_subset_cov[model_adata_covs_perts.isin([pert]), :].X
178 | population_truth = ref_adata_subset_cov[ref_adata_covs_perts.isin([pert]), :].X
179 |
180 | all_mmd = []
181 |                 for kernel_fn in kernel_fns:
182 |                     xx = kernel_fn(population_pred, population_pred)
183 |                     xy = kernel_fn(population_pred, population_truth)
184 |                     yy = kernel_fn(population_truth, population_truth)
185 | mmd = xx.mean() + yy.mean() - 2 * xy.mean()
186 | all_mmd.append(mmd)
187 | mmd = np.nanmean(all_mmd)
188 |
189 | ret['cov_pert'].append(f'{cov}{delim}{pert}')
190 | ret['model'].append(model_name)
191 | ret['metric'].append(mmd)
192 |
193 | eval.mmd_df = pd.DataFrame.from_dict(ret)
194 |
195 | return eval.mmd_df
--------------------------------------------------------------------------------
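A small usage sketch for `compute_metric` and `compare_perts` above; the perturbation names and expression values are illustrative:

```python
import pandas as pd

from perturbench.analysis.benchmarks.metrics import compare_perts, compute_metric

genes = ["g1", "g2", "g3"]
pred = pd.DataFrame([[0.1, 0.5, 0.2], [0.4, 0.0, 0.3]], index=["KLF1", "BAK1"], columns=genes)
ref = pd.DataFrame([[0.2, 0.4, 0.2], [0.5, 0.1, 0.2]], index=["KLF1", "BAK1"], columns=genes)

# Single pair of expression vectors
print(compute_metric(pred.loc["KLF1"].values, ref.loc["KLF1"].values, metric="rmse"))

# Per-perturbation scores over the shared index; pass deg_dict to restrict each
# comparison to that perturbation's differentially expressed genes
scores = compare_perts(pred, ref, metric="pearson")
print(scores)  # pandas Series indexed by perturbation
```
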
/src/perturbench/analysis/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | from adjustText import adjust_text
4 |
5 |
6 | def scatter_labels(
7 | x,
8 | y,
9 | df=None,
10 | hue=None,
11 | labels=None,
12 | label_size=6,
13 | x_title=None,
14 | y_title=None,
15 | axis_title_size=14,
16 | axis_text_size=12,
17 | plot_title=None,
18 | title_size=15,
19 | ax=None,
20 | figsize=None,
21 | hide_legend=True,
22 | size=60,
23 | alpha=0.8,
24 | xlim=None,
25 | ylim=None,
26 | ident_line=False,
27 | quadrants=False,
28 | **kwargs,
29 | ):
30 | """Generate a scatterplot with optional text labels for data points"""
31 |
32 | if ax is None:
33 | fig, ax = plt.subplots(figsize=figsize)
34 |
35 | if df is not None:
36 | sns.scatterplot(ax=ax, x=x, y=y, hue=hue, data=df, marker=".", s=size, **kwargs)
37 | else:
38 | if type(size) in [int, float]:
39 | size = [size] * len(x)
40 |
41 | sns.scatterplot(ax=ax, x=x, y=y, hue=hue, marker=".", s=size, **kwargs)
42 |
43 | ax.set_xlabel(x_title, size=axis_title_size)
44 | ax.set_ylabel(y_title, size=axis_title_size)
45 | ax.tick_params(axis="x", labelsize=axis_text_size)
46 | ax.tick_params(axis="y", labelsize=axis_text_size)
47 | ax.set_title(plot_title, size=title_size)
48 |
49 | if xlim is not None:
50 | ax.set_xlim(xmin=xlim[0], xmax=xlim[1])
51 | if ylim is not None:
52 | ax.set_ylim(ymin=ylim[0], ymax=ylim[1])
53 |
54 | xlim = ax.get_xlim()
55 | ylim = ax.get_ylim()
56 | min_pt = max(xlim[0], ylim[0])
57 | max_pt = min(xlim[1], ylim[1])
58 | if ident_line:
59 | ax.plot((min_pt, max_pt), (min_pt, max_pt), ls="--", color="red", alpha=0.7)
60 |
61 | if quadrants:
62 | ax.axvline(ls="--", color="red", alpha=0.3)
63 | ax.axhline(ls="--", color="red", alpha=0.3)
64 |
65 | if hide_legend and (ax.get_legend() is not None):
66 | ax.get_legend().remove()
67 |
68 | if labels is not None:
69 | assert df is not None
70 | texts = []
71 | for lab in labels:
72 | texts.append(ax.text(df.loc[lab, x], df.loc[lab, y], lab, size=label_size))
73 |
74 | adjust_text(texts, arrowprops=dict(arrowstyle="-", color="black"), ax=ax)
75 |
76 |
77 | ## Base boxplot
78 | def boxplot_jitter(
79 | x,
80 | y,
81 | df,
82 | hue=None,
83 | x_title=None,
84 | y_title=None,
85 | axis_title_size=14,
86 | axis_text_size=12,
87 | jitter_size=10,
88 | alpha=0.8,
89 | figsize=None,
90 | plot_title=None,
91 | title_size=15,
92 | ax=None,
93 | violin=False,
94 | ):
95 | """Generate a boxplot or violin plot with jitter for the data points"""
96 | if ax is None:
97 | _, ax = plt.subplots(figsize=figsize)
98 |
99 | if violin:
100 | sns.violinplot(
101 | ax=ax, x=x, y=y, hue=hue, data=df, palette=sns.color_palette("colorblind")
102 | )
103 | plt.setp(ax.collections, alpha=alpha)
104 |
105 | else:
106 | sns.boxplot(
107 | ax=ax,
108 | x=x,
109 | y=y,
110 | hue=hue,
111 | data=df,
112 | showfliers=False,
113 | palette=sns.color_palette("colorblind"),
114 | )
115 | for patch in ax.patches:
116 | r, g, b, a = patch.get_facecolor()
117 | patch.set_facecolor((r, g, b, alpha))
118 |
119 | sns.stripplot(
120 | ax=ax,
121 | x=x,
122 | y=y,
123 | hue=hue,
124 | data=df,
125 | color="black",
126 | marker=".",
127 | size=jitter_size,
128 | dodge=True,
129 | label=False,
130 | legend=None,
131 | )
132 |
133 | ax.set_xlabel(x_title, size=axis_title_size)
134 | ax.set_ylabel(y_title, size=axis_title_size)
135 | ax.set_title(plot_title, size=title_size)
136 | ax.tick_params(axis="x", labelrotation=90, labelsize=axis_text_size)
137 | ax.tick_params(axis="y", labelsize=axis_text_size)
138 |
139 | if ax is None:
140 | plt.show()
141 |
--------------------------------------------------------------------------------
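A brief usage sketch for the plotting helpers above; the dataframe and column names are illustrative:

```python
import pandas as pd

from perturbench.analysis.plotting import boxplot_jitter, scatter_labels

df = pd.DataFrame(
    {"pred": [0.1, 0.4, 0.8], "obs": [0.2, 0.5, 0.7], "model": ["A", "B", "A"]},
    index=["KLF1", "BAK1", "CEBPE"],
)

# Scatter of predicted vs observed scores, labelling two perturbations
scatter_labels(
    x="pred", y="obs", df=df, labels=["KLF1", "BAK1"],
    x_title="Predicted", y_title="Observed", ident_line=True,
)

# Boxplot (with jittered points) of a score grouped by model
boxplot_jitter(x="model", y="obs", df=df, y_title="Score")
```
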
/src/perturbench/analysis/preprocess.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import anndata as ad
3 | from .utils import merge_cols
4 | import pandas as pd
5 | import warnings
6 |
7 |
8 | def differential_expression_by_covariate(
9 | adata,
10 | perturbation_key: str,
11 | perturbation_control_value: str,
12 | covariate_keys: list[str] = [],
13 | n_differential_genes=25,
14 | rankby_abs=True,
15 | key_added="rank_genes_groups_cov",
16 | return_dict=False,
17 | delim="_",
18 | ):
19 | """Finds the top `n_differential_genes` differentially expressed genes for
20 | each perturbation. The differential expression is run separately for
21 | each unique covariate category in the dataset.
22 |
23 | Args:
24 | adata: AnnData dataset
25 | perturbation_key: the key in adata.obs that contains the perturbations
26 |         delim: the delimiter used to join covariate values and perturbation
27 |             names when building the keys of the DEG dictionary
28 | covariate_keys: a list of keys in adata.obs that contain the covariates
29 | perturbation_control_value: the value in adata.obs[perturbation_key] that
30 | corresponds to control cells
31 | n_differential_genes: number of top differentially expressed genes for
32 | each perturbation
33 | rankby_abs: if True, rank genes by absolute values of the score, thus including
34 | top downregulated genes in the top N genes. If False, the ranking will
35 | have only upregulated genes at the top.
36 | key_added: key used when adding the DEG dictionary to adata.uns
37 | return_dict: if True, return the DEG dictionary
38 |
39 | Returns:
40 | Adds the DEG dictionary to adata.uns
41 |
42 | If return_dict is True returns:
43 | gene_dict : dict
44 | Dictionary where groups are stored as keys, and the list of DEGs
45 | are the corresponding values
46 | """
47 | if "base" not in adata.uns["log1p"]:
48 | adata.uns["log1p"]["base"] = None
49 |
50 | gene_dict = {}
51 | if len(covariate_keys) == 0:
52 | sc.tl.rank_genes_groups(
53 | adata,
54 | groupby=perturbation_key,
55 | reference=perturbation_control_value,
56 | rankby_abs=rankby_abs,
57 | n_genes=n_differential_genes,
58 | )
59 |
60 | top_de_genes = pd.DataFrame(adata.uns["rank_genes_groups"]["names"])
61 | for group in top_de_genes:
62 | gene_dict[group] = top_de_genes[group].tolist()
63 |
64 | else:
65 | merged_covariates = merge_cols(adata.obs, covariate_keys, delim=delim)
66 | for unique_covariate in merged_covariates.unique():
67 | adata_cov = adata[merged_covariates == unique_covariate]
68 | sc.tl.rank_genes_groups(
69 | adata_cov,
70 | groupby=perturbation_key,
71 | reference=perturbation_control_value,
72 | rankby_abs=rankby_abs,
73 | n_genes=n_differential_genes,
74 | )
75 |
76 | top_de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
77 | for group in top_de_genes:
78 | cov_group = unique_covariate + delim + group
79 | gene_dict[cov_group] = top_de_genes[group].tolist()
80 |
81 | adata.uns[key_added] = gene_dict
82 |
83 | if return_dict:
84 | return gene_dict
85 |
86 |
87 | def preprocess(
88 | adata: ad.AnnData,
89 | perturbation_key: str,
90 | covariate_keys: list[str],
91 | control_value: str = "control",
92 | combination_delimiter: str = "+",
93 | highly_variable: int = 4000,
94 | degs: int = 25,
95 | ):
96 | adata.raw = None
97 | adata.var_names_make_unique()
98 | adata.obs_names_make_unique()
99 |
100 | ## Merge covariate columns
101 | adata.obs["cov_merged"] = merge_cols(adata.obs, covariate_keys)
102 |
103 | ## Preprocess
104 | print("Preprocessing ...")
105 | sc.pp.filter_genes(adata, min_cells=10)
106 | sc.pp.calculate_qc_metrics(adata, inplace=True)
107 |
108 |     ## Normalize and log-transform
109 | adata.layers["counts"] = adata.X.copy()
110 | sc.pp.normalize_total(adata)
111 | sc.pp.log1p(adata)
112 |
113 | ## Pull out perturbed genes
114 | unique_perturbations = set()
115 | for comb in adata.obs[perturbation_key].unique():
116 | unique_perturbations.update(comb.split(combination_delimiter))
117 | unique_perturbations = unique_perturbations.intersection(adata.var_names)
118 |
119 | ## Subset to highly variable or differentially expressed genes
120 | if highly_variable > 0:
121 | print(
122 | "Filtering for highly variable genes or differentially expressed genes ..."
123 | )
124 | sc.pp.highly_variable_genes(
125 | adata,
126 | batch_key="cov_merged",
127 | flavor="seurat_v3",
128 | layer="counts",
129 | n_top_genes=int(highly_variable),
130 | subset=False,
131 | )
132 |
133 | with warnings.catch_warnings():
134 | warnings.simplefilter("ignore")
135 | deg_gene_dict = differential_expression_by_covariate(
136 | adata,
137 | perturbation_key,
138 | control_value,
139 | covariate_keys,
140 | n_differential_genes=degs,
141 | rankby_abs=True,
142 | key_added="rank_genes_groups_cov",
143 | return_dict=True,
144 | delim="_",
145 | )
146 | deg_genes = set()
147 | for genes in deg_gene_dict.values():
148 | deg_genes.update(genes)
149 |
150 | var_genes = list(adata.var_names[adata.var["highly_variable"]])
151 | var_genes = list(unique_perturbations.union(var_genes).union(deg_genes))
152 | adata = adata[:, var_genes]
153 |
154 | print("Processed dataset summary:")
155 | print(adata)
156 |
157 | return adata
158 |
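Example usage of the preprocessing entry point above (a minimal sketch, assuming an AnnData with raw counts in .X, a `condition` column holding perturbation labels with controls labelled "control", and a `cell_type` covariate column; the file paths are illustrative, not files shipped with this repository):

import scanpy as sc
from perturbench.analysis.preprocess import preprocess

adata = sc.read_h5ad("raw_dataset.h5ad")  # illustrative input path
adata_processed = preprocess(
    adata,
    perturbation_key="condition",
    covariate_keys=["cell_type"],
    control_value="control",
    highly_variable=4000,  # number of highly variable genes to keep (0 skips gene subsetting)
    degs=25,               # top differentially expressed genes retained per perturbation/covariate
)
adata_processed.write_h5ad("processed_dataset.h5ad")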
--------------------------------------------------------------------------------
/src/perturbench/analysis/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 |
4 |
5 | def merge_cols(obs_df, cols, delim="_"):
6 | """Merge columns in DataFrame"""
7 |
8 | covs = obs_df[cols[0]]
9 | if len(cols) > 1:
10 | for i in range(1, len(cols)):
11 | covs = covs.astype(str) + delim + obs_df[cols[i]].astype(str)
12 | covs = covs.astype("category")
13 | return covs
14 |
15 |
16 | def subsample_by_group(adata, obs_key, max_cells=10000):
17 | """
18 | Downsample an AnnData on a per-group basis, keeping at most max_cells cells per group
19 | """
20 | cells_keep = []
21 | for k in list(adata.obs[obs_key].unique()):
22 | cells = list(adata[adata.obs[obs_key] == k].obs_names)
23 | if len(cells) > max_cells:
24 | cells = random.sample(cells, k=max_cells)
25 | cells_keep.extend(cells)
26 |
27 | print("Input cells: " + str(adata.shape[0]))
28 | print("Sampled cells: " + str(len(cells_keep)))
29 |
30 | adata_downsampled = adata[cells_keep, :]
31 | return adata_downsampled
32 |
33 |
34 | def pert_cluster_filter(adata, pert_key, cluster_key, delim="_", min_cells=20):
35 | """
36 | Filter an AnnData to the perturbations that have at least min_cells cells in every cluster
37 | """
38 |
39 | adata.obs["pert_cluster"] = (
40 | adata.obs[pert_key].astype(str) + delim + adata.obs[cluster_key].astype(str)
41 | )
42 | pert_cl_counts = adata.obs["pert_cluster"].value_counts()
43 | pert_cl_keep = pert_cl_counts.loc[[x >= min_cells for x in pert_cl_counts]].index
44 |
45 | pert_counts_dict = defaultdict(int)
46 | for x in list(pert_cl_keep):
47 | pert_counts_dict[x.split(delim)[0]] += 1
48 |
49 | perts_keep = [
50 | x
51 | for x, n in pert_counts_dict.items()
52 | if n == len(adata.obs[cluster_key].unique())
53 | ]
54 |
55 | print("Original perturbations: " + str(len(adata.obs[pert_key].unique())))
56 | print("Filtered perturbations: " + str(len(perts_keep)))
57 |
58 | adata_filtered = adata[adata.obs[pert_key].isin(perts_keep), :]
59 | return adata_filtered
60 |
61 |
62 | def get_ensembl_mappings():
63 | try:
64 | from pybiomart import Dataset
65 | except ImportError:
66 | raise ImportError("Please install the pybiomart package to use this function")
67 |
68 | # Set up connection to server
69 | dataset = Dataset(name="hsapiens_gene_ensembl", host="http://www.ensembl.org")
70 |
71 | id_gene_df = dataset.query(attributes=["ensembl_gene_id", "hgnc_symbol"])
72 |
73 | ensembl_to_genesymbol = {}
74 | # Store the data in a dict
75 | for gene_id, gene_symbol in zip(
76 | id_gene_df["Gene stable ID"], id_gene_df["HGNC symbol"]
77 | ):
78 | ensembl_to_genesymbol[gene_id] = gene_symbol
79 |
80 | return ensembl_to_genesymbol
81 |
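A small usage sketch of the helpers above, using toy data for illustration only:

import numpy as np
import pandas as pd
import anndata as ad
from perturbench.analysis.utils import merge_cols, subsample_by_group

obs = pd.DataFrame({
    "cell_type": ["A549", "A549", "MCF7"],
    "treatment": ["drugX", "control", "drugX"],
})
merged = merge_cols(obs, ["cell_type", "treatment"])  # values like "A549_drugX"

adata = ad.AnnData(X=np.zeros((3, 2)), obs=obs)
adata_small = subsample_by_group(adata, "cell_type", max_cells=1)  # at most 1 cell per cell_type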
--------------------------------------------------------------------------------
/src/perturbench/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/configs/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - early_stopping
3 | - lr_monitor
4 | - model_summary
5 | - model_checkpoint
6 | - rich_progress_bar
7 | - _self_
8 |
9 | model_summary:
10 | max_depth: -1
11 |
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/early_stopping.yaml:
--------------------------------------------------------------------------------
1 | early_stopping:
2 | _target_: lightning.pytorch.callbacks.EarlyStopping
3 | monitor: val_loss # quantity to monitor
4 | patience: 50 # early stopping patience
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/lr_monitor.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html
2 |
3 | lr_monitor:
4 | _target_: lightning.pytorch.callbacks.LearningRateMonitor
5 | log_momentum: True
6 | logging_interval: epoch
7 |
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/model_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
3 | monitor: val_loss # quantity to monitor
4 | dirpath: ${hydra:runtime.output_dir}/checkpoints # directory to save checkpoints
5 | # every_n_train_steps: 80 # save checkpoint every n train steps
6 | # every_n_epochs: 5 # save checkpoint every n epochs
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/model_summary.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2 |
3 | model_summary:
4 | _target_: lightning.pytorch.callbacks.RichModelSummary
5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include
6 |
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/configs/callbacks/none.yaml
--------------------------------------------------------------------------------
/src/perturbench/configs/callbacks/rich_progress_bar.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2 |
3 | rich_progress_bar:
4 | _target_: lightning.pytorch.callbacks.RichProgressBar
5 |
--------------------------------------------------------------------------------
/src/perturbench/configs/data/devel.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - transform: linear_model_pipeline
3 | - collate: linear_model_collate
4 | - splitter: default
5 | - evaluation: default
6 |
7 | _target_: perturbench.data.modules.AnnDataLitModule
8 | datapath: src/perturbench/data/resources/devel.h5ad
9 | perturbation_key: condition
10 | perturbation_combination_delimiter: +
11 | perturbation_control_value: control
12 | covariate_keys: [cell_type]
13 | batch_size: 8
14 | num_workers: 0
15 | batch_sample: True
16 | add_controls: True
17 | use_counts: False
--------------------------------------------------------------------------------
/src/perturbench/configs/data/evaluation/default.yaml:
--------------------------------------------------------------------------------
1 | split_value_to_evaluate: val
2 | max_control_cells_per_covariate: 1000
3 |
4 | evaluation_pipelines:
5 | - aggregation: average
6 | metric: rmse
7 | rank: True
8 |
9 | - aggregation: logfc
10 | metric: cosine
11 | rank: True
12 |
13 | save_evaluation: True
14 | save_dir: "${paths.output_dir}/evaluation/"
15 | chunk_size: 20
16 | print_summary: True
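For intuition, each `evaluation_pipelines` entry pairs an aggregation with a metric. A minimal sketch of what `aggregation: average` combined with `metric: rmse` denotes (an illustration of the idea, not the evaluator implementation in this package):

import numpy as np

def average_rmse(pred_cells: np.ndarray, obs_cells: np.ndarray) -> float:
    """Average expression across cells per condition, then RMSE between the mean profiles."""
    pred_mean = pred_cells.mean(axis=0)  # aggregation: average
    obs_mean = obs_cells.mean(axis=0)
    return float(np.sqrt(np.mean((pred_mean - obs_mean) ** 2)))  # metric: rmse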
--------------------------------------------------------------------------------
/src/perturbench/configs/data/evaluation/final_test.yaml:
--------------------------------------------------------------------------------
1 | split_value_to_evaluate: test
2 | max_control_cells_per_covariate: 1000
3 |
4 | evaluation_pipelines:
5 | - aggregation: average
6 | metric: rmse
7 | rank: True
8 |
9 | - aggregation: logfc
10 | metric: cosine
11 | rank: True
12 |
13 | save_evaluation: True
14 | save_dir: "${paths.output_dir}/evaluation/"
15 | chunk_size: 20
16 | print_summary: True
--------------------------------------------------------------------------------
/src/perturbench/configs/data/frangieh21.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - transform: linear_model_pipeline
3 | - splitter: saved_split
4 | - evaluation: default
5 |
6 | _target_: perturbench.data.modules.AnnDataLitModule
7 | datapath: ${paths.data_dir}/frangieh21_processed.h5ad
8 | perturbation_key: condition
9 | perturbation_combination_delimiter: +
10 | perturbation_control_value: control
11 | covariate_keys: [treatment]
12 | batch_size: 8000
13 | num_workers: 8
14 | batch_sample: True
15 | add_controls: True
16 | use_counts: False
17 |
18 | splitter:
19 | split_path: ${paths.data_dir}/frangieh21_split.csv
--------------------------------------------------------------------------------
/src/perturbench/configs/data/jiang24.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - transform: linear_model_pipeline
3 | - splitter: saved_split
4 | - evaluation: default
5 |
6 | _target_: perturbench.data.modules.AnnDataLitModule
7 | datapath: ${paths.data_dir}/jiang24_processed.h5ad
8 | perturbation_key: condition
9 | perturbation_combination_delimiter: +
10 | perturbation_control_value: control
11 | covariate_keys: [cell_type,treatment]
12 | batch_size: 2000
13 | num_workers: 8
14 | batch_sample: True
15 | add_controls: True
16 | use_counts: False
17 |
18 | splitter:
19 | split_path: ${paths.data_dir}/jiang24_split.csv
--------------------------------------------------------------------------------
/src/perturbench/configs/data/mcfaline23.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - transform: linear_model_pipeline
3 | - splitter: mcfaline23_split
4 | - evaluation: default
5 |
6 | _target_: perturbench.data.modules.AnnDataLitModule
7 | datapath: ${paths.data_dir}/mcfaline23_gxe_processed.h5ad
8 | perturbation_key: condition
9 | perturbation_combination_delimiter: +
10 | perturbation_control_value: control
11 | covariate_keys: [cell_type,treatment]
12 | batch_size: 8000
13 | num_workers: 12
14 | num_val_workers: 2
15 | num_test_workers: 0
16 | batch_sample: True
17 | add_controls: True
18 | use_counts: False
19 | embedding_key: null
--------------------------------------------------------------------------------
/src/perturbench/configs/data/norman19.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - transform: linear_model_pipeline
3 | - splitter: combination_prediction_task
4 | - evaluation: default
5 |
6 | _target_: perturbench.data.modules.AnnDataLitModule
7 | datapath: ${paths.data_dir}/norman19_processed.h5ad
8 | perturbation_key: condition
9 | perturbation_combination_delimiter: +
10 | perturbation_control_value: control
11 | covariate_keys: [cell_type]
12 | batch_size: 4000
13 | num_workers: 8
14 | batch_sample: True
15 | add_controls: True
16 | use_counts: False
17 | embedding_key: null
18 |
19 | splitter:
20 | max_heldout_fraction_per_covariate: 0.7
--------------------------------------------------------------------------------
/src/perturbench/configs/data/sciplex3.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - transform: linear_model_pipeline
3 | - splitter: cell_type_transfer_task
4 | - evaluation: default
5 |
6 | _target_: perturbench.data.modules.AnnDataLitModule
7 | datapath: ${paths.data_dir}/sciplex3_processed.h5ad
8 | perturbation_key: condition
9 | perturbation_combination_delimiter: +
10 | perturbation_control_value: control
11 | covariate_keys: [cell_type]
12 | batch_size: 4000
13 | num_workers: 8
14 | batch_sample: True
15 | add_controls: True
16 | use_counts: False
17 | embedding_key: null
--------------------------------------------------------------------------------
/src/perturbench/configs/data/splitter/cell_type_transfer_task.yaml:
--------------------------------------------------------------------------------
1 | # split_path: data/splits ## Specify this path if you want to load a previously saved split
2 | task: transfer ## Either `transfer`, `combine`, or `combine_inverse`
3 | covariate_keys: [cell_type]
4 | min_train_covariates: 1 ## Minimum number of covariates to train on per perturbation
5 | max_heldout_covariates: 1 ## Maximum number of covariates to hold out per perturbation
6 | max_heldout_fraction_per_covariate: 0.3
7 | max_heldout_perturbations_per_covariate: 200
8 | train_control_fraction: 0.5
9 | downsample_fraction: 1.0
10 | splitter_seed: 42
11 | save: True ## Whether to save the split to disk
12 | output_path: "${paths.output_dir}/" ## Specify this path if you want to save the split
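For intuition, a toy sketch of the `transfer` task this config parameterizes: for each covariate value, hold out a fraction of its perturbations, but only when the perturbation remains in training for at least `min_train_covariates` other covariates. This illustrates the splitting idea only and is not the splitter implementation used by the package:

import random

def toy_transfer_split(perts_per_covariate, max_heldout_fraction=0.3,
                       min_train_covariates=1, seed=42):
    """perts_per_covariate maps a covariate value to the set of perturbations measured in it."""
    rng = random.Random(seed)
    heldout = {cov: set() for cov in perts_per_covariate}
    for cov, perts in perts_per_covariate.items():
        candidates = sorted(perts)
        rng.shuffle(candidates)
        budget = int(len(perts) * max_heldout_fraction)
        for pert in candidates[:budget]:
            trained_elsewhere = sum(
                pert in other_perts and pert not in heldout[other_cov]
                for other_cov, other_perts in perts_per_covariate.items()
                if other_cov != cov
            )
            if trained_elsewhere >= min_train_covariates:
                heldout[cov].add(pert)
    return heldout  # held-out (covariate, perturbation) pairs form the validation/test split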
--------------------------------------------------------------------------------
/src/perturbench/configs/data/splitter/cell_type_treatment_transfer_task.yaml:
--------------------------------------------------------------------------------
1 | # split_path: data/splits ## Specify this path if you want to load a previously saved split
2 | task: transfer ## Either `transfer`, `combine`, or `combine_inverse`
3 | covariate_keys: [cell_type,treatment]
4 | min_train_covariates: 1 ## Minimum number of covariates to train on per perturbation
5 | max_heldout_covariates: 1 ## Maximum number of covariates to hold out per perturbation
6 | max_heldout_fraction_per_covariate: 0.3
7 | max_heldout_perturbations_per_covariate: 200
8 | train_control_fraction: 0.5
9 | downsample_fraction: 1.0
10 | splitter_seed: 42
11 | save: True ## Whether to save the split to disk
12 | output_path: "${paths.output_dir}/" ## Specify this path if you want to save the split
--------------------------------------------------------------------------------
/src/perturbench/configs/data/splitter/combination_prediction_task.yaml:
--------------------------------------------------------------------------------
1 | # split_path: data/splits ## Specify this path if you want to load a previously saved split
2 | task: combine ## Either `transfer`, `combine`, or `combine_inverse`
3 | covariate_keys: [cell_type]
4 | max_heldout_fraction_per_covariate: 0.3
5 | max_heldout_perturbations_per_covariate: 200
6 | train_control_fraction: 0.5
7 | downsample_fraction: 1.0
8 | splitter_seed: 42
9 | save: True ## Whether to save the split to disk
10 | output_path: "${paths.output_dir}/" ## Specify this path if you want to save the split
--------------------------------------------------------------------------------
/src/perturbench/configs/data/splitter/mcfaline23_split.yaml:
--------------------------------------------------------------------------------
1 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv ## Specify this path if you want to use a specific saved path for the splits
--------------------------------------------------------------------------------
/src/perturbench/configs/data/splitter/saved_split.yaml:
--------------------------------------------------------------------------------
1 | split_path: null
--------------------------------------------------------------------------------
/src/perturbench/configs/data/transform/linear_model_pipeline.yaml:
--------------------------------------------------------------------------------
1 | conf:
2 | _target_: perturbench.data.transforms.pipelines.SingleCellPipeline
3 | _partial_: true
4 | dependencies:
5 | perturbation_uniques: perturbation_uniques
6 | covariate_uniques: covariate_uniques
7 |
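Side note on `_partial_: true` above: Hydra's `instantiate` then returns a `functools.partial` of the target instead of constructing it immediately, so the remaining arguments (here the transform's runtime dependencies) can be supplied later. A minimal illustration with a stand-in target, since the SingleCellPipeline signature is not shown in this listing:

from functools import partial

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "collections.OrderedDict", "_partial_": True})
factory = instantiate(cfg)   # a functools.partial; nothing is built yet
obj = factory(genes=4000)    # construction is finished with runtime arguments
assert isinstance(factory, partial) and obj["genes"] == 4000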
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/biolord_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: biolord
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | _target_: anubis.models.BiolordStar
15 | dropout: 0.30000000000000004
16 | encoder_width: 5376
17 | latent_dim: 512
18 | lr: 2.643716468011295e-05
19 | n_genes: null
20 | n_layers: 1
21 | n_perts: null
22 | penalty_weight: 49.426728002719415
23 | softplus_output: true
24 | wd: 3.8596437745886474e-08
25 |
26 | data:
27 | add_controls: false
28 | batch_size: 1000
29 | evaluation:
30 | chunk_size: 10
31 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/cpa_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: cpa
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | adv_classifier_hidden_dim: 559
14 | adv_classifier_n_layers: 4
15 | adv_steps: 5
16 | adv_weight: 0.1853767280161523
17 | dropout: 0.1
18 | elementwise_affine: false
19 | hidden_dim: 4352
20 | kl_weight: 0.21855673778824847
21 | lr: 3.680992074544368e-05
22 | n_latent: 128
23 | n_layers_covar_emb: 1
24 | n_layers_encoder: 1
25 | n_layers_pert_emb: 1
26 | n_warmup_epochs: 20
27 | penalty_weight: 7.605167161204701
28 | softplus_output: false
29 | variational: true
30 | wd: 1.9023489520016998e-05
31 |
32 | data:
33 | add_controls: false
34 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/cpa_no_adv_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: cpa
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | use_adversary: false
14 | adv_classifier_hidden_dim: 559
15 | adv_classifier_n_layers: 4
16 | adv_steps: 5
17 | adv_weight: 0.1853767280161523
18 | dropout: 0.1
19 | elementwise_affine: false
20 | hidden_dim: 4352
21 | kl_weight: 0.21855673778824847
22 | lr: 3.680992074544368e-05
23 | n_latent: 128
24 | n_layers_covar_emb: 1
25 | n_layers_encoder: 1
26 | n_layers_pert_emb: 1
27 | n_warmup_epochs: 400
28 | penalty_weight: 7.605167161204701
29 | softplus_output: false
30 | variational: true
31 | wd: 1.9023489520016998e-05
32 |
33 | data:
34 | add_controls: false
35 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/cpa_scgpt_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: cpa
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | adv_classifier_hidden_dim: 249
14 | adv_classifier_n_layers: 5
15 | adv_steps: 2
16 | adv_weight: 0.4809114972414237
17 | dropout: 0.1
18 | elementwise_affine: false
19 | hidden_dim: 4352
20 | kl_weight: 0.1519255328598612
21 | lr: 0.0006148770073381143
22 | n_latent: 128
23 | n_layers_covar_emb: 1
24 | n_layers_encoder: 1
25 | n_layers_pert_emb: 1
26 | n_warmup_epochs: 10
27 | penalty_weight: 0.19363459943522998
28 | softplus_output: false
29 | variational: true
30 | wd: 7.097486120374588e-08
31 |
32 | data:
33 | add_controls: false
34 | datapath: ${paths.data_dir}/prime-data/frangieh21_processed_scgpt.h5ad
35 | embedding_key: scgpt_embeddings
36 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/decoder_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: decoder_only
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | encoder_width: 3328
14 | lr: 0.001489196297126292
15 | n_genes: null
16 | n_layers: 1
17 | n_perts: null
18 | softplus_output: true
19 | use_covariates: true
20 | use_perturbations: true
21 | wd: 4.786646240929937e-06
22 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/decoder_cov_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: decoder_only
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | encoder_width: 4352
14 | lr: 0.0003078912524764639
15 | n_genes: null
16 | n_layers: 5
17 | n_perts: null
18 | softplus_output: false
19 | use_covariates: true
20 | use_perturbations: false
21 | wd: 9.633717477684286e-07
22 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/latent_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: latent_additive
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | dropout: 0.30000000000000004
14 | encoder_width: 2304
15 | inject_covariates_decoder: true
16 | inject_covariates_encoder: true
17 | latent_dim: 256
18 | lr: 0.00052308363076951
19 | n_genes: null
20 | n_layers: 1
21 | n_perts: null
22 | softplus_output: true
23 | wd: 1.564792695502722e-07
24 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/latent_scgpt_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: latent_additive
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | dropout: 0.30000000000000004
14 | encoder_width: 3328
15 | inject_covariates_decoder: true
16 | inject_covariates_encoder: true
17 | latent_dim: 192
18 | lr: 0.0004036126472173799
19 | n_genes: null
20 | n_layers: 1
21 | n_perts: null
22 | softplus_output: true
23 | wd: 2.0909467656368513e-08
24 |
25 | data:
26 | datapath: ${paths.data_dir}/prime-data/frangieh21_processed_scgpt.h5ad
27 | embedding_key: scgpt_embeddings
28 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/linear_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: linear_additive
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | inject_covariates: true
14 | lr: 0.003915672085065509
15 | n_genes: null
16 | n_perts: null
17 | wd: 1.4293895399278078e-08
18 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/sams_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | dropout: 0.1
15 | hidden_dim_cond: 2560
16 | hidden_dim_x: 4096
17 | inject_covariates_decoder: false
18 | inject_covariates_encoder: false
19 | latent_dim: 24
20 | lr: 5.182988883689651e-07
21 | mask_prior_probability: 0.012572644354323684
22 | n_genes: null
23 | n_layers_decoder: 3
24 | n_layers_encoder_e: 5
25 | n_layers_encoder_x: 1
26 | n_perts: null
27 | softplus_output: true
28 | wd: 7.422084694841772e-10
29 |
30 | data:
31 | add_controls: false
32 | batch_size: 256
33 | evaluation:
34 | chunk_size: 10
35 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/frangieh21/sams_modified_best_params_frangieh21.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: frangieh21
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | disable_e_dist: true
15 | disable_sparsity: true
16 | dropout: 0.0
17 | generative_counterfactual: true
18 | hidden_dim_cond: 329
19 | hidden_dim_x: 2560
20 | inject_covariates_decoder: true
21 | inject_covariates_encoder: true
22 | latent_dim: 128
23 | lr: 5.579821197158771e-06
24 | mask_prior_probability: 0.09647872920959553
25 | n_layers_decoder: 1
26 | n_layers_encoder_e: 4
27 | n_layers_encoder_x: 4
28 | softplus_output: true
29 | wd: 2.0857105211314087e-08
30 |
31 | data:
32 | add_controls: false
33 | batch_size: 256
34 | evaluation:
35 | chunk_size: 10
36 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/cpa_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: cpa
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | adv_classifier_hidden_dim: 537
14 | adv_classifier_n_layers: 1
15 | adv_steps: 7
16 | adv_weight: 1.4870029218495044
17 | dropout: 0.30000000000000004
18 | elementwise_affine: false
19 | hidden_dim: 3328
20 | kl_weight: 6.904783806453286
21 | lr: 2.6302273130616072e-05
22 | n_latent: 512
23 | n_layers_covar_emb: 1
24 | n_layers_encoder: 7
25 | n_layers_pert_emb: 2
26 | n_warmup_epochs: 5
27 | penalty_weight: 2.8152900769395943
28 | softplus_output: false
29 | variational: true
30 | wd: 3.693112285894233e-07
31 |
32 | data:
33 | add_controls: false
34 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/cpa_no_adv_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: cpa
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | use_adversary: false
14 | adv_classifier_hidden_dim: 537
15 | adv_classifier_n_layers: 1
16 | adv_steps: 7
17 | adv_weight: 1.4870029218495044
18 | dropout: 0.30000000000000004
19 | elementwise_affine: false
20 | hidden_dim: 3328
21 | kl_weight: 6.904783806453286
22 | lr: 2.6302273130616072e-05
23 | n_latent: 512
24 | n_layers_covar_emb: 1
25 | n_layers_encoder: 7
26 | n_layers_pert_emb: 2
27 | n_warmup_epochs: 400
28 | penalty_weight: 2.8152900769395943
29 | softplus_output: false
30 | variational: true
31 | wd: 3.693112285894233e-07
32 |
33 | data:
34 | add_controls: false
35 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/decoder_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: decoder_only
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | encoder_width: 5376
14 | lr: 9.426394813968772e-05
15 | n_genes: null
16 | n_layers: 3
17 | n_perts: null
18 | softplus_output: true
19 | use_covariates: true
20 | use_perturbations: true
21 | wd: 8.513975008706282e-08
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/latent_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: latent_additive
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | dropout: 0.6000000000000001
14 | encoder_width: 3328
15 | inject_covariates_decoder: true
16 | inject_covariates_encoder: true
17 | latent_dim: 256
18 | lr: 0.004991266977288785
19 | n_genes: null
20 | n_layers: 3
21 | n_perts: null
22 | softplus_output: true
23 | wd: 8.471259185195355e-07
24 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/linear_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: linear_additive
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 |
12 | model:
13 | inject_covariates: true
14 | lr: 0.0014212823361302294
15 | n_genes: null
16 | n_perts: null
17 | wd: 1.0079009696147276e-08
18 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/sams_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | dropout: 0.5
15 | generative_counterfactual: true
16 | hidden_dim_cond: 329
17 | hidden_dim_x: 3328
18 | inject_covariates_decoder: true
19 | inject_covariates_encoder: true
20 | latent_dim: 64
21 | lr: 7.594588526179077e-06
22 | mask_prior_probability: 0.20684484580190832
23 | n_layers_decoder: 3
24 | n_layers_encoder_e: 5
25 | n_layers_encoder_x: 3
26 | softplus_output: true
27 | wd: 1.5915315253373636e-08
28 |
29 | data:
30 | add_controls: false
31 | batch_size: 256
32 | evaluation:
33 | chunk_size: 10
34 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/jiang24/sams_modified_best_params_jiang24.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: jiang24
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | disable_e_dist: true
15 | disable_sparsity: true
16 | dropout: 0.6000000000000001
17 | generative_counterfactual: true
18 | hidden_dim_cond: 329
19 | hidden_dim_x: 1536
20 | inject_covariates_decoder: true
21 | inject_covariates_encoder: true
22 | latent_dim: 32
23 | lr: 0.0003892340882567657
24 | mask_prior_probability: 0.01017246847284454
25 | n_layers_decoder: 3
26 | n_layers_encoder_e: 2
27 | n_layers_encoder_x: 5
28 | softplus_output: true
29 | wd: 9.99382916747775e-07
30 |
31 | data:
32 | add_controls: false
33 | batch_size: 256
34 | evaluation:
35 | chunk_size: 10
36 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: mcfaline23
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | evaluation:
21 | chunk_size: 5
22 | batch_size: 8000
23 | num_workers: 12
24 | splitter:
25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
26 |
27 | model:
28 | adv_classifier_hidden_dim: 281
29 | adv_classifier_n_layers: 1
30 | adv_steps: 3
31 | adv_weight: 3.3267124584103804
32 | dropout: 0.30000000000000004
33 | elementwise_affine: false
34 | hidden_dim: 4352
35 | kl_weight: 0.4194843985243823
36 | lr: 1.7339648673958316e-05
37 | n_latent: 64
38 | n_layers_covar_emb: 1
39 | n_layers_encoder: 1
40 | n_layers_pert_emb: 2
41 | n_warmup_epochs: 10
42 | penalty_weight: 0.6115268823779908
43 | softplus_output: false
44 | variational: true
45 | wd: 2.346560291976718e-06
46 |
47 |
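Following the run comment at the top of this file, an experiment config like this one would be selected through the `experiment` group, for example (entry-point name taken from the comment above, not re-verified here):

python train.py experiment=neurips2024/mcfaline23/cpa_best_params_mcfaline23_full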
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: mcfaline23
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | evaluation:
21 | chunk_size: 5
22 | batch_size: 8000
23 | num_workers: 12
24 | splitter:
25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
26 |
27 | model:
28 | adv_classifier_hidden_dim: 205
29 | adv_classifier_n_layers: 1
30 | adv_steps: 3
31 | adv_weight: 0.9729207769467227
32 | dropout: 0.2
33 | elementwise_affine: false
34 | hidden_dim: 4352
35 | kl_weight: 8.582254554076528
36 | lr: 6.168904762775334e-05
37 | n_latent: 64
38 | n_layers_covar_emb: 1
39 | n_layers_encoder: 7
40 | n_layers_pert_emb: 2
41 | n_warmup_epochs: 15
42 | penalty_weight: 12.269805096188989
43 | softplus_output: false
44 | variational: true
45 | wd: 3.2081170417806746e-07
46 |
47 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: mcfaline23
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | evaluation:
21 | chunk_size: 5
22 | batch_size: 8000
23 | num_workers: 12
24 | splitter:
25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
26 |
27 | model:
28 | adv_classifier_hidden_dim: 446
29 | adv_classifier_n_layers: 3
30 | adv_steps: 3
31 | adv_weight: 6.657453598080865
32 | dropout: 0.1
33 | elementwise_affine: false
34 | hidden_dim: 2304
35 | kl_weight: 1.1791506221938572
36 | lr: 1.9312841187693903e-05
37 | n_latent: 256
38 | n_layers_covar_emb: 1
39 | n_layers_encoder: 7
40 | n_layers_pert_emb: 3
41 | n_warmup_epochs: 10
42 | penalty_weight: 9.392286981965224
43 | softplus_output: false
44 | variational: true
45 | wd: 1.9552773792043903e-08
46 |
47 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_no_adv_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: mcfaline23
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | evaluation:
21 | chunk_size: 5
22 | batch_size: 8000
23 | num_workers: 12
24 | splitter:
25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
26 |
27 | model:
28 | use_adversary: false
29 | adv_classifier_hidden_dim: 281
30 | adv_classifier_n_layers: 1
31 | adv_steps: 3
32 | adv_weight: 3.3267124584103804
33 | dropout: 0.30000000000000004
34 | elementwise_affine: false
35 | hidden_dim: 4352
36 | kl_weight: 0.4194843985243823
37 | lr: 1.7339648673958316e-05
38 | n_latent: 64
39 | n_layers_covar_emb: 1
40 | n_layers_encoder: 1
41 | n_layers_pert_emb: 2
42 | n_warmup_epochs: 500
43 | penalty_weight: 0.6115268823779908
44 | softplus_output: false
45 | variational: true
46 | wd: 2.346560291976718e-06
47 |
48 |
49 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_no_adv_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: mcfaline23
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | evaluation:
21 | chunk_size: 5
22 | batch_size: 8000
23 | num_workers: 12
24 | splitter:
25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
26 |
27 | model:
28 | use_adversary: false
29 | adv_classifier_hidden_dim: 205
30 | adv_classifier_n_layers: 1
31 | adv_steps: 3
32 | adv_weight: 0.9729207769467227
33 | dropout: 0.2
34 | elementwise_affine: false
35 | hidden_dim: 4352
36 | kl_weight: 8.582254554076528
37 | lr: 6.168904762775334e-05
38 | n_latent: 64
39 | n_layers_covar_emb: 1
40 | n_layers_encoder: 7
41 | n_layers_pert_emb: 2
42 | n_warmup_epochs: 15
43 | penalty_weight: 12.269805096188989
44 | softplus_output: false
45 | variational: true
46 | wd: 3.2081170417806746e-07
47 |
48 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/cpa_no_adv_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: mcfaline23
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | evaluation:
21 | chunk_size: 5
22 | batch_size: 8000
23 | num_workers: 12
24 | splitter:
25 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
26 |
27 | model:
28 | use_adversary: false
29 | adv_classifier_hidden_dim: 446
30 | adv_classifier_n_layers: 3
31 | adv_steps: 3
32 | adv_weight: 6.657453598080865
33 | dropout: 0.1
34 | elementwise_affine: false
35 | hidden_dim: 2304
36 | kl_weight: 1.1791506221938572
37 | lr: 1.9312841187693903e-05
38 | n_latent: 256
39 | n_layers_covar_emb: 1
40 | n_layers_encoder: 7
41 | n_layers_pert_emb: 3
42 | n_warmup_epochs: 10
43 | penalty_weight: 9.392286981965224
44 | softplus_output: false
45 | variational: true
46 | wd: 1.9552773792043903e-08
47 |
48 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/decoder_only_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: decoder_only
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 12
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
19 |
20 | model:
21 | encoder_width: 3328
22 | lr: 0.00019009753022442835
23 | n_layers: 1
24 | softplus_output: true
25 | wd: 5.697560765121058e-07
26 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/decoder_only_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: decoder_only
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 8
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
19 |
20 | model:
21 | encoder_width: 4352
22 | lr: 0.00023414037024915073
23 | n_layers: 1
24 | softplus_output: true
25 | wd: 7.503473517441124e-07
26 |
27 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/decoder_only_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: decoder_only
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 8
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
19 |
20 | model:
21 | encoder_width: 4352
22 | lr: 0.0007134752592105323
23 | n_layers: 1
24 | softplus_output: true
25 | wd: 6.892595385769382e-07
26 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: latent_additive
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 12
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
19 |
20 | model:
21 | dropout: 0.1
22 | encoder_width: 4352
23 | latent_dim: 256
24 | lr: 0.00046952967507921957
25 | n_layers: 1
26 | wd: 3.348258680704949e-08
27 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: latent_additive
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 8
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
19 |
20 | model:
21 | dropout: 0.0
22 | encoder_width: 3328
23 | latent_dim: 256
24 | lr: 2.0752864206129073e-05
25 | n_layers: 1
26 | wd: 9.448743387490416e-08
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/latent_additive_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: latent_additive
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 8
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
19 |
20 | model:
21 | dropout: 0.0
22 | encoder_width: 1280
23 | latent_dim: 512
24 | lr: 0.00021234604809346303
25 | n_layers: 1
26 | wd: 1.4101074283996682e-07
27 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/linear_additive_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: linear_additive
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 8
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
19 |
20 | model:
21 | inject_covariates: True
22 | lr: 0.0013589931928117893
23 | wd: 1.0042027312774061e-08
24 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/linear_additive_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: linear_additive
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 8
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
19 |
20 | model:
21 | inject_covariates: true
22 | lr: 0.0022999550256692877
23 | wd: 1.0685179638547894e-08
24 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/linear_additive_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: linear_additive
4 | - override /data: mcfaline23
5 |
6 | seed: 245
7 |
8 | trainer:
9 | min_epochs: 5
10 | max_epochs: 500
11 |
12 | data:
13 | evaluation:
14 | chunk_size: 5
15 | batch_size: 8000
16 | num_workers: 12
17 | splitter:
18 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
19 |
20 | model:
21 | inject_covariates: true
22 | lr: 0.004950837663306272
23 | wd: 1.0397598707776114e-08
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: mcfaline23
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | dropout: 0.1
15 | generative_counterfactual: true
16 | hidden_dim_cond: 329
17 | hidden_dim_x: 3328
18 | inject_covariates_decoder: true
19 | inject_covariates_encoder: true
20 | latent_dim: 128
21 | lr: 3.6723171184888935e-06
22 | mask_prior_probability: 0.018435146998978795
23 | n_layers_decoder: 3
24 | n_layers_encoder_e: 4
25 | n_layers_encoder_x: 5
26 | softplus_output: true
27 | wd: 2.569943996675991e-07
28 |
29 | data:
30 | splitter:
31 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
32 | add_controls: false
33 | batch_size: 256
34 | evaluation:
35 | chunk_size: 10
36 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: mcfaline23
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | dropout: 0.30000000000000004
15 | generative_counterfactual: true
16 | hidden_dim_cond: 329
17 | hidden_dim_x: 1536
18 | inject_covariates_decoder: true
19 | inject_covariates_encoder: true
20 | latent_dim: 128
21 | lr: 6.4220933723681095e-06
22 | mask_prior_probability: 0.021735343521054572
23 | n_layers_decoder: 2
24 | n_layers_encoder_e: 3
25 | n_layers_encoder_x: 4
26 | softplus_output: true
27 | wd: 3.2516220842555925e-09
28 |
29 | data:
30 | splitter:
31 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
32 | add_controls: false
33 | batch_size: 256
34 | evaluation:
35 | chunk_size: 10
36 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: mcfaline23
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | dropout: 0.30000000000000004
15 | generative_counterfactual: true
16 | hidden_dim_cond: 329
17 | hidden_dim_x: 2304
18 | inject_covariates_decoder: true
19 | inject_covariates_encoder: true
20 | latent_dim: 24
21 | lr: 6.950438037779035e-06
22 | mask_prior_probability: 0.03505373699588523
23 | n_layers_decoder: 1
24 | n_layers_encoder_e: 2
25 | n_layers_encoder_x: 3
26 | softplus_output: true
27 | wd: 8.162330576054943e-09
28 |
29 | data:
30 | splitter:
31 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
32 | add_controls: false
33 | batch_size: 256
34 | evaluation:
35 | chunk_size: 10
36 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_modified_best_params_mcfaline23_full.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: mcfaline23
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | disable_e_dist: true
15 | disable_sparsity: true
16 | dropout: 0.1
17 | generative_counterfactual: true
18 | hidden_dim_cond: 329
19 | hidden_dim_x: 3584
20 | inject_covariates_decoder: true
21 | inject_covariates_encoder: true
22 | latent_dim: 24
23 | lr: 3.952811264448736e-05
24 | mask_prior_probability: 0.07551256281258875
25 | n_layers_decoder: 3
26 | n_layers_encoder_e: 3
27 | n_layers_encoder_x: 1
28 | softplus_output: true
29 | wd: 6.380767802230304e-06
30 |
31 | data:
32 | splitter:
33 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/full_covariate_split.csv
34 | add_controls: false
35 | batch_size: 256
36 | evaluation:
37 | chunk_size: 10
38 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_modified_best_params_mcfaline23_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: mcfaline23
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | disable_e_dist: true
15 | disable_sparsity: true
16 | dropout: 0.30000000000000004
17 | generative_counterfactual: true
18 | hidden_dim_cond: 329
19 | hidden_dim_x: 3584
20 | inject_covariates_decoder: true
21 | inject_covariates_encoder: true
22 | latent_dim: 128
23 | lr: 2.97852617667106e-05
24 | mask_prior_probability: 0.012905035283331237
25 | n_layers_decoder: 2
26 | n_layers_encoder_e: 1
27 | n_layers_encoder_x: 1
28 | softplus_output: true
29 | wd: 1.0248794045174245e-05
30 |
31 | data:
32 | splitter:
33 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/medium_covariate_split.csv
34 | add_controls: false
35 | batch_size: 256
36 | evaluation:
37 | chunk_size: 10
38 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/mcfaline23/sams_modified_best_params_mcfaline23_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: sams_vae
5 | - override /callbacks: default
6 | - override /data: mcfaline23
7 |
8 | trainer:
9 | max_epochs: 400
10 | min_epochs: 5
11 | precision: 32
12 |
13 | model:
14 | disable_e_dist: true
15 | disable_sparsity: true
16 | dropout: 0.30000000000000004
17 | generative_counterfactual: true
18 | hidden_dim_cond: 329
19 | hidden_dim_x: 3328
20 | inject_covariates_decoder: true
21 | inject_covariates_encoder: true
22 | latent_dim: 32
23 | lr: 3.4324600190407314e-05
24 | mask_prior_probability: 0.05761980928292434
25 | n_layers_decoder: 1
26 | n_layers_encoder_e: 4
27 | n_layers_encoder_x: 1
28 | softplus_output: true
29 | wd: 1.1815419587902247e-08
30 |
31 | data:
32 | splitter:
33 | split_path: ${paths.data_dir}/mcfaline23_gxe_splits/small_covariate_split.csv
34 | add_controls: false
35 | batch_size: 256
36 | evaluation:
37 | chunk_size: 10
38 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/biolord_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: biolord
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | batch_size: 1000
21 | splitter:
22 | max_heldout_fraction_per_covariate: 0.7
23 | add_controls: False
24 |
25 | model:
26 | dropout: 0.4
27 | encoder_width: 2304
28 | latent_dim: 512
29 | lr: 0.00016701245023478605
30 | n_layers: 1
31 | penalty_weight: 2621.3333751927075
32 | wd: 0.00046077468143989676
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/cpa_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 | add_controls: False
23 |
24 | model:
25 | adv_classifier_hidden_dim: 776
26 | adv_classifier_n_layers: 2
27 | adv_steps: 2
28 | adv_weight: 3.804798319367216
29 | dropout: 0.2
30 | elementwise_affine: false
31 | hidden_dim: 3328
32 | kl_weight: 19.756374864509844
33 | lr: 6.844044407644798e-05
34 | n_latent: 128
35 | n_layers_covar_emb: 1
36 | n_layers_encoder: 1
37 | n_layers_pert_emb: 2
38 | n_warmup_epochs: 15
39 | penalty_weight: 1.8984251429709187
40 | softplus_output: false
41 | variational: true
42 | wd: 1.6600312433505752e-07
43 |
44 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/cpa_no_adv_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 | add_controls: False
23 |
24 | model:
25 | use_adversary: False
26 | adv_classifier_hidden_dim: 776
27 | adv_classifier_n_layers: 2
28 | adv_steps: 2
29 | adv_weight: 3.804798319367216
30 | dropout: 0.2
31 | elementwise_affine: false
32 | hidden_dim: 3328
33 | kl_weight: 19.756374864509844
34 | lr: 6.844044407644798e-05
35 | n_latent: 128
36 | n_layers_covar_emb: 1
37 | n_layers_encoder: 1
38 | n_layers_pert_emb: 2
39 | n_warmup_epochs: 500
40 | penalty_weight: 1.8984251429709187
41 | softplus_output: false
42 | variational: true
43 | wd: 1.6600312433505752e-07
44 |
45 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/cpa_scgpt_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 400
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 | datapath: ${paths.data_dir}/norman19_preprocessed_with_embeddings.h5ad
23 | embedding_key: scgpt_embbeddings
24 | add_controls: False
25 |
26 | model:
27 | adv_classifier_hidden_dim: 837
28 | adv_classifier_n_layers: 2
29 | adv_steps: 3
30 | adv_weight: 0.12619384778246517
31 | dropout: 0.6000000000000001
32 | elementwise_affine: false
33 | hidden_dim: 5376
34 | kl_weight: 11.931005035532673
35 | lr: 0.0005134219006570818
36 | n_latent: 256
37 | n_layers_covar_emb: 1
38 | n_layers_encoder: 1
39 | n_layers_pert_emb: 1
40 | n_warmup_epochs: 10
41 | penalty_weight: 0.9962990820647598
42 | softplus_output: false
43 | variational: true
44 | wd: 6.849486237372675e-05
45 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/decoder_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: decoder_only
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 |
23 | model:
24 | use_covariates: False
25 | encoder_width: 3328
26 | lr: 0.00013753391233021738
27 | n_layers: 3
28 | softplus_output: false
29 | wd: 4.417109615721373e-05
30 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/latent_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: latent_additive
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 |
23 | model:
24 | inject_covariates_encoder: False
25 | inject_covariates_decoder: False
26 | dropout: 0.1
27 | encoder_width: 5376
28 | latent_dim: 64
29 | lr: 6.779597503293815e-05
30 | n_layers: 1
31 | wd: 1.0406767176550133e-08
32 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/latent_scgpt_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: latent_additive
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 | datapath: ${paths.data_dir}/norman19_preprocessed_with_embeddings.h5ad
23 | embedding_key: scgpt_embbeddings
24 |
25 | model:
26 | inject_covariates_encoder: False
27 | inject_covariates_decoder: False
28 | dropout: 0.4
29 | encoder_width: 5376
30 | latent_dim: 512
31 | lr: 8.965448576753094e-05
32 | n_layers: 1
33 | wd: 1.9716752225147476e-05
34 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/linear_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: linear_additive
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | splitter:
21 | max_heldout_fraction_per_covariate: 0.7
22 |
23 | model:
24 | inject_covariates: false
25 | lr: 0.004716813309487752
26 | wd: 1.7588044643755207e-08
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/sams_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: sams_vae
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 137
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 | precision: 32
19 |
20 | data:
21 | batch_size: 256
22 | splitter:
23 | max_heldout_fraction_per_covariate: 0.7
24 | add_controls: False
25 | evaluation:
26 | chunk_size: 10
27 |
28 | model:
29 | dropout: 0.4
30 | hidden_dim_cond: 1280
31 | hidden_dim_x: 3584
32 | inject_covariates_decoder: false
33 | inject_covariates_encoder: false
34 | latent_dim: 128
35 | lr: 3.0263552583537424e-05
36 | mask_prior_probability: 0.11384202578456981
37 | n_genes: null
38 | n_layers_decoder: 2
39 | n_layers_encoder_e: 3
40 | n_layers_encoder_x: 1
41 | n_perts: null
42 | softplus_output: true
43 | wd: 4.244962253886731e-09
44 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/norman19/sams_modified_best_params_norman19.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: sams_vae
8 | - override /data: norman19
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 137
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 | precision: 32
19 |
20 | data:
21 | batch_size: 256
22 | splitter:
23 | max_heldout_fraction_per_covariate: 0.7
24 | add_controls: False
25 | evaluation:
26 | chunk_size: 10
27 |
28 | model:
29 | disable_e_dist: true
30 | disable_sparsity: true
31 | dropout: 0.30000000000000004
32 | generative_counterfactual: true
33 | hidden_dim_cond: 329
34 | hidden_dim_x: 2048
35 | inject_covariates_decoder: false
36 | inject_covariates_encoder: false
37 | latent_dim: 32
38 | lr: 1.5981872317838004e-05
39 | mask_prior_probability: 0.018746135628343086
40 | n_layers_decoder: 3
41 | n_layers_encoder_e: 1
42 | n_layers_encoder_x: 1
43 | softplus_output: true
44 | wd: 6.55647585152116e-05
45 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/biolord_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: biolord
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 |
14 | seed: 245
15 |
16 | trainer:
17 | min_epochs: 5
18 | max_epochs: 500
19 | precision: 16
20 |
21 | data:
22 | batch_size: 1000
23 | add_controls: False
24 |
25 | model:
26 | dropout: 0.4
27 | encoder_width: 2304
28 | latent_dim: 512
29 | lr: 0.00016701245023478605
30 | wd: 0.00046077468143989676
31 | n_layers: 1
32 | penalty_weight: 2621.3333751927075
33 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/cpa_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | add_controls: False
21 |
22 | model:
23 | adv_classifier_hidden_dim: 383
24 | adv_classifier_n_layers: 2
25 | adv_steps: 3
26 | adv_weight: 0.21007098128462928
27 | dropout: 0.1
28 | elementwise_affine: false
29 | hidden_dim: 1280
30 | kl_weight: 0.2643107960544633
31 | lr: 0.0002688668829765555
32 | n_latent: 512
33 | n_layers_covar_emb: 1
34 | n_layers_encoder: 3
35 | n_layers_pert_emb: 2
36 | n_warmup_epochs: 5
37 | penalty_weight: 10.2126759375119
38 | softplus_output: false
39 | variational: true
40 | wd: 2.992280181677909e-06
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/cpa_no_adv_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | add_controls: False
21 |
22 | model:
23 | use_adversary: False
24 | adv_classifier_hidden_dim: 383
25 | adv_classifier_n_layers: 2
26 | adv_steps: 3
27 | adv_weight: 0.21007098128462928
28 | decoder_distribution: IsotropicGaussian
29 | dropout: 0.1
30 | elementwise_affine: false
31 | hidden_dim: 1280
32 | kl_weight: 0.2643107960544633
33 | lr: 0.0002688668829765555
34 | n_latent: 512
35 | n_layers_covar_emb: 1
36 | n_layers_encoder: 3
37 | n_layers_pert_emb: 2
38 | n_warmup_epochs: 500
39 | penalty_weight: 10.2126759375119
40 | softplus_output: false
41 | variational: true
42 | wd: 2.992280181677909e-06
43 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/cpa_scgpt_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: cpa
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | datapath: ${paths.data_dir}/srivatsan20_highest_dose_preprocessed_with_embeddings.h5ad
21 | embedding_key: scgpt_embbeddings
22 | add_controls: False
23 |
24 | model:
25 | adv_classifier_hidden_dim: 681
26 | adv_classifier_n_layers: 1
27 | adv_steps: 2
28 | adv_weight: 1.9003113280995905
29 | dropout: 0.1
30 | elementwise_affine: false
31 | hidden_dim: 5376
32 | kl_weight: 1.095403737862168
33 | lr: 0.0001711130140991506
34 | n_latent: 512
35 | n_layers_covar_emb: 1
36 | n_layers_encoder: 1
37 | n_layers_pert_emb: 4
38 | n_warmup_epochs: 10
39 | penalty_weight: 19.75923566901279
40 | softplus_output: false
41 | variational: true
42 | wd: 0.0002629416918365637
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/decoder_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: decoder_only
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | model:
20 | encoder_width: 4352
21 | lr: 6.798204826057849e-05
22 | n_layers: 5
23 | softplus_output: true
24 | wd: 7.601981738045318e-06
25 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/decoder_cov_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: decoder_only
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | model:
20 | use_perturbations: False
21 | encoder_width: 4352
22 | lr: 5.330813318795293e-06
23 | wd: 1.0553034125634522e-08
24 | n_layers: 5
25 | softplus_output: true
26 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/latent_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: latent_additive
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | model:
20 | dropout: 0.6000000000000001
21 | encoder_width: 3328
22 | latent_dim: 64
23 | lr: 0.0013195915995113784
24 | n_layers: 1
25 | wd: 9.371223651937301e-07
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/latent_scgpt_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: latent_additive
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | data:
20 | datapath: ${paths.data_dir}/srivatsan20_highest_dose_preprocessed_with_embeddings.h5ad
21 | embedding_key: scgpt_embbeddings
22 |
23 | model:
24 | lr_scheduler_freq: 5
25 | dropout: 0.1
26 | encoder_width: 2304
27 | latent_dim: 128
28 | lr: 1.2051920391885433e-05
29 | n_layers: 3
30 | wd: 2.701962713462281e-07
31 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/linear_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: linear_additive
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 245
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 |
19 | model:
20 | lr: 0.0007787610860511324
21 | wd: 1.0632418512521614e-08
22 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/sams_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: sams_vae
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 137
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 | precision: 32
19 |
20 | data:
21 | batch_size: 256
22 | add_controls: False
23 | evaluation:
24 | chunk_size: 10
25 |
26 | model:
27 | dropout: 0.30000000000000004
28 | hidden_dim_cond: 2816
29 | hidden_dim_x: 1536
30 | inject_covariates_decoder: false
31 | inject_covariates_encoder: false
32 | latent_dim: 64
33 | lr: 0.00040019008240191795
34 | mask_prior_probability: 0.011486688007756388
35 | n_genes: null
36 | n_layers_decoder: 1
37 | n_layers_encoder_e: 2
38 | n_layers_encoder_x: 4
39 | n_perts: null
40 | softplus_output: true
41 | wd: 1.749495655302928e-06
42 |
--------------------------------------------------------------------------------
/src/perturbench/configs/experiment/neurips2024/sciplex3/sams_modified_best_params_sciplex3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /model: sams_vae
8 | - override /data: sciplex3
9 |
10 | # all parameters below will be merged with parameters from default configurations set above
11 | # this allows you to overwrite only specified parameters
12 |
13 | seed: 137
14 |
15 | trainer:
16 | min_epochs: 5
17 | max_epochs: 500
18 | precision: 32
19 |
20 | data:
21 | batch_size: 256
22 | add_controls: False
23 | evaluation:
24 | chunk_size: 10
25 |
26 | model:
27 | disable_e_dist: true
28 | disable_sparsity: true
29 | dropout: 0.4
30 | generative_counterfactual: true
31 | hidden_dim_cond: 329
32 | hidden_dim_x: 2560
33 | inject_covariates_decoder: true
34 | inject_covariates_encoder: true
35 | latent_dim: 32
36 | lr: 0.000341765289234546
37 | mask_prior_probability: 0.05618183575503891
38 | n_layers_decoder: 3
39 | n_layers_encoder_e: 2
40 | n_layers_encoder_x: 4
41 | softplus_output: true
42 | wd: 6.943657713062555e-08
43 |
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/biolord_hpo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | params:
24 | model.penalty_weight: tag(log, interval(1e1, 10**5))
25 | model.n_layers: range(1, 7, step=2)
26 | model.encoder_width: range(256, 5376, step=1024)
27 | model.latent_dim: choice(64, 128, 192, 256, 512)
28 | model.lr: tag(log, interval(5e-6, 5e-3))
29 | model.wd: tag(log, interval(1e-8, 1e-3))
30 | model.dropout: range(0.0, 0.8, step=0.1)
31 | launcher:
32 | ray:
33 | init:
34 | ## for local runs
35 | num_gpus: 2 # number of total gpus to use
36 | remote:
37 | num_gpus: 1 # number of gpus per trial
38 | max_calls: 1
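
The sweep grammar under `params` is Hydra's Optuna-sweeper syntax: `tag(log, interval(a, b))` samples on a log scale, `range(...)` enumerates a stepped grid, and `choice(...)` is categorical. As a rough, illustrative sketch of the sampling this implies (the sweeper performs this conversion internally; exact endpoint handling may differ, and the objective body is a placeholder):

import optuna

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("model.lr", 5e-6, 5e-3, log=True)        # tag(log, interval(5e-6, 5e-3))
    wd = trial.suggest_float("model.wd", 1e-8, 1e-3, log=True)        # tag(log, interval(1e-8, 1e-3))
    n_layers = trial.suggest_int("model.n_layers", 1, 7, step=2)      # ~ range(1, 7, step=2)
    latent_dim = trial.suggest_categorical("model.latent_dim", [64, 128, 192, 256, 512])
    dropout = trial.suggest_float("model.dropout", 0.0, 0.8, step=0.1)
    # ... run one training trial with these overrides, then return the weighted
    # objective 1.0 * rmse_average + 0.1 * rmse_rank_average (metrics_to_optimize)
    return 0.0

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=4, n_jobs=2)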
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/cpa_hpo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | params:
24 | model.n_layers_encoder: range(1, 7, step=2)
25 | model.n_layers_pert_emb: range(1, 5, step=1)
26 | model.adv_classifier_n_layers: range(1, 5, step=1)
27 | model.hidden_dim: range(256, 5376, step=1024)
28 | model.adv_classifier_hidden_dim: tag(log, int(interval(128, 1024))) # log scale
29 | model.adv_steps: choice(2, 3, 5, 7, 10, 20, 30)
30 | model.n_latent: choice(64, 128, 192, 256, 512)
31 | model.lr: tag(log, interval(5e-6, 1e-3))
32 | model.wd: tag(log, interval(1e-8, 1e-3))
33 | model.dropout: range(0.0, 0.8, step=0.1)
34 | model.kl_weight: tag(log, interval(0.1, 20))
35 | model.adv_weight: tag(log, interval(0.1, 20))
36 | model.penalty_weight: tag(log, interval(0.1, 20))
37 | launcher:
38 | ray:
39 | init:
40 | ## for local runs
41 | num_gpus: 2 # number of total gpus to use
42 | remote:
43 | num_gpus: 1 # number of gpus per trial
44 | max_calls: 1
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/decoder_only_hpo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | params:
24 | model.n_layers: range(1, 7, step=2)
25 | model.encoder_width: range(256, 5376, step=1024)
26 | model.lr: tag(log, interval(5e-6, 5e-3))
27 | model.wd: tag(log, interval(1e-8, 1e-3))
28 | model.softplus_output: choice(true, false)
29 | launcher:
30 | ray:
31 | init:
32 | ## for local runs
33 | num_gpus: 2 # number of total gpus to use
34 | remote:
35 | num_gpus: 1 # number of gpus per trial
36 | max_calls: 1
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/latent_additive_hpo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | params:
24 | model.n_layers: range(1, 7, step=2)
25 | model.encoder_width: range(256, 5376, step=1024)
26 | model.latent_dim: choice(64, 128, 192, 256, 512)
27 | model.lr: tag(log, interval(5e-6, 5e-3))
28 | model.wd: tag(log, interval(1e-8, 1e-3))
29 | model.dropout: range(0.0, 0.8, step=0.1)
30 | launcher:
31 | ray:
32 | init:
33 | ## for local runs
34 | num_gpus: 2 # number of total gpus to use
35 | remote:
36 | num_gpus: 1 # number of gpus per trial
37 | max_calls: 1
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/linear_additive_hpo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | params:
24 | model.lr: tag(log, interval(5e-6, 5e-3))
25 | model.wd: tag(log, interval(1e-8, 1e-3))
26 | launcher:
27 | ray:
28 | init:
29 | ## for local runs
30 | num_gpus: 2 # number of total gpus to use
31 | remote:
32 | num_gpus: 1 # number of gpus per trial
33 | max_calls: 1
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/local.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | launcher:
24 | ray:
25 | init:
26 | ## for local runs
27 | num_gpus: 2 # number of total gpus to use
28 | remote:
29 | num_gpus: 1 # number of gpus per trial
30 | max_calls: 1
--------------------------------------------------------------------------------
/src/perturbench/configs/hpo/sams_vae_hpo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | #for HPO runs.
4 | defaults:
5 | - override /hydra/sweeper: optuna
6 | - override /hydra/launcher: ray
7 | - override /hydra/sweeper/sampler: tpe
8 |
9 | ## Specify multiple metrics and how they should be added together
10 | metrics_to_optimize:
11 | rmse_average: 1.0
12 | rmse_rank_average: 0.1
13 |
14 | hydra:
15 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
16 | sweeper:
17 | direction: minimize
18 | study_name: ${model._target_}_${oc.env:USER}_${now:%Y-%m-%d}/${now:%H-%M-%S} # if there is an existing study with the same name, it will be resumed
19 | storage: null # optuna database to store HPO results.
20 | n_trials: 4 # number of total trials
21 | n_jobs: 2 # number of parallel jobs
22 | max_failure_rate: 1 # overall HPO job will not fail if less than this ratio of trials fail
23 | params:
24 | model.n_layers_encoder_x: range(1, 7, step=2)
25 | model.n_layers_encoder_e: range(1, 7, step=2)
26 | model.n_layers_decoder: range(1, 7, step=2)
27 | model.latent_dim: choice(64, 128, 192, 256, 512)
28 | model.hidden_dim_x: range(256, 5376, step=1024)
29 | model.hidden_dim_cond: range(50, 500, step=50)
30 | model.sparse_additive_mechanism: choice(true, false)
31 | model.mean_field_encoding: choice(true, false)
32 | model.lr: tag(log, interval(5e-6, 5e-3))
33 | model.wd: tag(log, interval(1e-8, 1e-3))
34 | model.mask_prior_probability: tag(log, interval(1e-4, 0.99))
35 | model.dropout: range(0.0, 0.8, step=0.1)
36 | launcher:
37 | ray:
38 | init:
39 | ## for local runs
40 | num_gpus: 2 # number of total gpus to use
41 | remote:
42 | num_gpus: 1 # number of gpus per trial
43 | max_calls: 1
--------------------------------------------------------------------------------
/src/perturbench/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
11 | sweep:
12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
13 | subdir: ${hydra.job.num}
14 |
15 | job_logging:
16 | handlers:
17 | file:
18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
19 | filename: ${hydra.runtime.output_dir}/${task_name}.log
20 |
--------------------------------------------------------------------------------
/src/perturbench/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # csv logger built in lightning
2 |
3 | csv:
4 | _target_: lightning.pytorch.loggers.CSVLogger
5 | save_dir: "${paths.output_dir}"
6 | name: "csv/"
7 | prefix: ""
--------------------------------------------------------------------------------
/src/perturbench/configs/logger/default.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 | - tensorboard
5 | - csv
--------------------------------------------------------------------------------
/src/perturbench/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # TensorBoard logger built in lightning
2 |
3 | tensorboard:
4 | _target_: lightning.pytorch.loggers.TensorBoardLogger
5 | save_dir: "${paths.output_dir}"
6 | name: "tensorboard/"
7 | prefix: ""
--------------------------------------------------------------------------------
/src/perturbench/configs/model/biolord.yaml:
--------------------------------------------------------------------------------
1 | _target_: perturbench.modelcore.models.BiolordStar
2 | n_genes: null
3 | n_perts: null
4 |
5 | ## Optimizer params
6 | lr: 6.72106935566634e-05
7 | wd: 7.90200208844498e-09
8 | penalty_weight: 1.0
9 |
10 | ## Model params
11 | n_layers: 3
12 | encoder_width: 3072
13 | latent_dim: 32
14 | softplus_output: True
15 | dropout: 0.1
--------------------------------------------------------------------------------
/src/perturbench/configs/model/cpa.yaml:
--------------------------------------------------------------------------------
1 | _target_: perturbench.modelcore.models.CPA
2 | n_genes: null
3 | n_perts: null
4 |
5 | n_latent: 128
6 | hidden_dim: 256
7 | n_layers_encoder: 3
8 | n_layers_pert_emb: 2
9 | n_layers_covar_emb: 1
10 |
11 | variational: True
12 |
13 | adv_classifier_n_layers: 2
14 | adv_classifier_hidden_dim: 128
15 |
16 | lr: 1e-5
17 | wd: 1e-8
18 |
19 | dropout: 1e-1
20 | kl_weight: 1.0
21 | adv_weight: 1.0
22 | penalty_weight: 10.0
23 | adv_steps: 2
24 | n_warmup_epochs: 5
25 |
26 | softplus_output: False
27 | elementwise_affine: False
--------------------------------------------------------------------------------
/src/perturbench/configs/model/decoder_only.yaml:
--------------------------------------------------------------------------------
1 | _target_: perturbench.modelcore.models.DecoderOnly
2 | n_genes: null
3 | n_perts: null
4 |
5 | ## Optimizer params
6 | lr: 6.72106935566634e-05
7 | wd: 7.90200208844498e-09
8 |
9 | ## Model params
10 | n_layers: 3
11 | encoder_width: 3072
12 | softplus_output: True
13 | use_covariates: True
14 | use_perturbations: True
--------------------------------------------------------------------------------
/src/perturbench/configs/model/latent_additive.yaml:
--------------------------------------------------------------------------------
1 | _target_: perturbench.modelcore.models.LatentAdditive
2 | n_genes: null
3 | n_perts: null
4 |
5 | ## Optimizer params
6 | lr: 6.72106935566634e-05
7 | wd: 7.90200208844498e-09
8 |
9 | ## Model params
10 | n_layers: 3
11 | encoder_width: 3072
12 | latent_dim: 160
13 | softplus_output: True
14 | inject_covariates_encoder: True
15 | inject_covariates_decoder: True
16 | dropout: 0.1
--------------------------------------------------------------------------------
/src/perturbench/configs/model/linear_additive.yaml:
--------------------------------------------------------------------------------
1 | _target_: perturbench.modelcore.models.LinearAdditive
2 | n_genes: null
3 | n_perts: null
4 | lr: 1e-3
5 | wd: 1e-4
6 | inject_covariates: True
--------------------------------------------------------------------------------
/src/perturbench/configs/model/sams_vae.yaml:
--------------------------------------------------------------------------------
1 | _target_: perturbench.modelcore.models.SparseAdditiveVAE
2 | n_genes: null
3 | n_perts: null
4 |
5 | ## Model params
6 | n_layers_encoder_x: 5
7 | n_layers_encoder_e: 3
8 | n_layers_decoder: 4
9 | hidden_dim_x: 2332
10 | hidden_dim_cond: 329
11 | latent_dim: 55
12 | dropout: 0.2
13 | inject_covariates_encoder: False
14 | inject_covariates_decoder: False
15 | mask_prior_probability: 0.01
16 | softplus_output: True
17 | disable_sparsity: false
18 | disable_e_dist: false
19 | generative_counterfactual: false
20 |
21 | ## Optimizer params
22 | lr: 8.184599223121104e-05
23 | wd: 2.4153279969771546e-05
--------------------------------------------------------------------------------
/src/perturbench/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | # path to data directory
2 | data_dir: ./notebooks/neurips2024/perturbench_data/
3 |
4 | # path to logging directory
5 | log_dir: ./logs/
6 |
7 | # path to output directory, created dynamically by hydra
8 | # path generation pattern is specified in `configs/hydra/default.yaml`
9 | # use it to store all files generated during the run, like ckpts and metrics
10 | output_dir: ${hydra:runtime.output_dir}
11 |
12 | # path to working directory
13 | work_dir: ${hydra:runtime.cwd}
14 |
--------------------------------------------------------------------------------
/src/perturbench/configs/predict.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default configuration
4 | # order of defaults determines the order in which configs override each other
5 | defaults:
6 | - _self_
7 | - data: ??? ## Specify data config
8 | - model: ??? ## Specify model config
9 | - trainer: default
10 | - paths: default
11 | - hydra: default
12 |
13 | # experiment configs allow for version control of specific hyperparameters
14 | # e.g. best hyperparameters for given model and datamodule
15 | - experiment: null
16 |
17 | # Provide checkpoint path to pretrained model
18 | ckpt_path: ???
19 |
20 | # task name, determines output directory path
21 | task_name: "predict"
22 |
23 | # Path to prediction dataframe (generate with notebooks/demos/generate_prediction_dataframe)
24 | prediction_dataframe_path: ???
25 |
26 | # seed for random number generators in pytorch, numpy and python.random
27 | seed: null
28 |
29 | # Path to save predictions
30 | output_path: "${paths.output_dir}/predictions/"
31 |
32 | # Number of perturbations to generate in memory at once
33 | chunk_size: 50
--------------------------------------------------------------------------------
/src/perturbench/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default configuration
4 | # order of defaults determines the order in which configs override each other
5 | defaults:
6 | - _self_
7 | - data: ???
8 | - model: linear_additive
9 | - callbacks: default
10 | - logger: default
11 | - trainer: default
12 | - paths: default
13 | - hydra: default
14 |
15 | # experiment configs allow for version control of specific hyperparameters
16 | # e.g. best hyperparameters for given model and datamodule
17 | - experiment: null
18 |
19 | # config for hyperparameter optimization
20 | - hpo: null
21 |
22 | # task name, determines output directory path
23 | task_name: "train"
24 |
25 | # set False to skip model training
26 | train: True
27 |
28 | # evaluate on test set, using best model weights achieved during training
29 | # lightning chooses best weights based on the metric specified in checkpoint callback
30 | test: True
31 |
32 | # simply provide checkpoint path to resume training
33 | ckpt_path: null
34 |
35 | # seed for random number generators in pytorch, numpy and python.random
36 | seed: null
37 |
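
A hedged sketch of how the experiment configs earlier in this tree overlay these defaults, using Hydra's compose API to inspect the merged result (the config_path is an assumption about where the caller sits relative to this directory; any path under configs/experiment can be substituted):

from hydra import initialize, compose
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="src/perturbench/configs"):
    cfg = compose(
        config_name="train",
        overrides=["experiment=neurips2024/norman19/cpa_best_params_norman19"],
    )
# The experiment file swaps in model/data via its `override` defaults and then
# merges its own trainer/model/data keys on top of them.
print(OmegaConf.to_yaml(cfg.model))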
--------------------------------------------------------------------------------
/src/perturbench/configs/trainer/cpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 |
4 | accelerator: cpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/src/perturbench/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: lightning.pytorch.trainer.Trainer
2 |
3 | default_root_dir: ${paths.output_dir}
4 |
5 | min_epochs: 1 # prevents early stopping
6 | max_epochs: 10
7 |
8 | accelerator: gpu
9 | devices: 1
10 |
11 | # mixed precision for extra speed-up
12 | precision: 16
13 |
14 | # perform a validation loop every N training epochs
15 | check_val_every_n_epoch: 1
16 |
17 | # set True to ensure deterministic results

18 | # makes training slower but gives more reproducibility than just setting seeds
19 | deterministic: False
20 |
21 | # log every N steps
22 | log_every_n_steps: 10
--------------------------------------------------------------------------------
/src/perturbench/configs/trainer/gpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 |
4 | accelerator: gpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/src/perturbench/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/data/accessors/base.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import os
3 | from abc import abstractmethod
4 | import subprocess as sp
5 |
6 | from perturbench.data.datasets import (
7 | Counterfactual,
8 | CounterfactualWithReference,
9 | SingleCellPerturbation,
10 | SingleCellPerturbationWithControls,
11 | )
12 | from perturbench.data.transforms.pipelines import SingleCellPipeline
13 |
14 |
15 | def download_scperturb_adata(data_url, data_cache_dir, filename):
16 | """
17 | Helper function to download and cache anndata files. Returns an in-memory
18 | anndata object as-is with no curation.
19 | """
20 | if not os.path.exists(data_cache_dir):
21 | os.makedirs(data_cache_dir)
22 |
23 | tmp_data_path = f"{data_cache_dir}/{filename}"
24 |
25 | if not os.path.exists(tmp_data_path):
26 | sp.call(f"wget {data_url} -O {tmp_data_path}", shell=True)
27 |
28 | adata = sc.read_h5ad(tmp_data_path)
29 | return adata
30 |
31 |
32 | def download_file(url: str, output_dir: str, output_filename: str) -> str:
33 |     """
34 |     Downloads a file from a URL to output_dir/output_filename using wget,
35 |     decompresses it if gzipped, and returns the resulting local path.
36 |     Args:
37 |         url (str): The URL of the file to download
38 |         output_dir (str): The local directory where the file should be saved
39 |     """
40 | if not os.path.exists(output_dir):
41 | os.makedirs(output_dir)
42 | output_path = f"{output_dir}/{output_filename}"
43 | sp.call(f"wget {url} -O {output_path}", shell=True)
44 |
45 | if ".gz" in output_filename:
46 | sp.call(f"gzip -d {output_path}", shell=True)
47 |
48 | return output_path.replace(".gz", "")
49 |
50 |
51 | class Accessor:
52 | data_cache_dir: str
53 | dataset_hf_url: str
54 | dataset_orig_url: str
55 | dataset_name: str
56 | processed_data_path: str
57 |
58 | def __init__(
59 | self,
60 | dataset_hf_url,
61 | dataset_orig_url,
62 | dataset_name,
63 | data_cache_dir="../perturbench_data",
64 | ):
65 | self.dataset_hf_url = dataset_hf_url
66 | self.dataset_orig_url = dataset_orig_url
67 | self.dataset_name = dataset_name
68 | self.data_cache_dir = data_cache_dir
69 |
70 | def get_dataset(
71 | self,
72 | dataset_class=SingleCellPerturbation,
73 | add_default_transforms=True,
74 | **dataset_kwargs,
75 | ):
76 | if dataset_class not in [
77 | SingleCellPerturbation,
78 | SingleCellPerturbationWithControls,
79 | Counterfactual,
80 | CounterfactualWithReference,
81 | ]:
82 | raise ValueError("Invalid dataset class.")
83 |
84 |         ## Load the curated AnnData for this dataset (downloading and caching it if needed)
85 | adata = self.get_anndata()
86 |
87 | if "perturbation_key" not in dataset_kwargs:
88 | dataset_kwargs["perturbation_key"] = "condition"
89 | if "covariate_keys" not in dataset_kwargs:
90 | dataset_kwargs["covariate_keys"] = ["cell_type"]
91 | if "perturbation_control_value" not in dataset_kwargs:
92 | dataset_kwargs["perturbation_control_value"] = "control"
93 |
94 | dataset, context = dataset_class.from_anndata(
95 | adata=adata,
96 | **dataset_kwargs,
97 | )
98 |
99 | if add_default_transforms:
100 | dataset.transform = SingleCellPipeline(
101 | perturbation_uniques=context["perturbation_uniques"],
102 | covariate_uniques=context["covariate_uniques"],
103 | )
104 |
105 | return dataset, context
106 |
107 | @abstractmethod
108 | def get_anndata(self):
109 | pass
110 |
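
A minimal usage sketch of the Accessor API above, assuming the Norman19 subclass that appears later in this tree (the cache directory is a placeholder):

from perturbench.data.accessors.norman19 import Norman19

accessor = Norman19(data_cache_dir="./perturbench_data")
adata = accessor.get_anndata()             # download (Hugging Face, falling back to scPerturb), curate, and cache
dataset, context = accessor.get_dataset()  # SingleCellPerturbation with the default SingleCellPipeline transform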
--------------------------------------------------------------------------------
/src/perturbench/data/accessors/frangieh21.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import os
3 | from scipy.sparse import csr_matrix
4 |
5 | from perturbench.analysis.preprocess import preprocess
6 | from perturbench.data.accessors.base import (
7 | download_scperturb_adata,
8 | download_file,
9 | Accessor,
10 | )
11 |
12 |
13 | class Frangieh21(Accessor):
14 | def __init__(self, data_cache_dir="../perturbench_data"):
15 | super().__init__(
16 | data_cache_dir=data_cache_dir,
17 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/resolve/main/frangieh19_preprocessed.h5ad.gz",
18 | dataset_orig_url="https://zenodo.org/records/7041849/files/FrangiehIzar2021_RNA.h5ad?download=1",
19 | dataset_name="frangieh21",
20 | )
21 |
22 | def get_anndata(self):
23 | """
24 | Downloads, curates, and preprocesses the Frangieh21 dataset from Hugging
25 | Face or the scPerturb database. Saves the preprocessed data to disk and
26 | returns it in-memory.
27 |
28 | Returns:
29 | adata (anndata.AnnData): Anndata object containing the processed data.
30 |
31 | """
32 | self.processed_data_path = (
33 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad"
34 | )
35 | if os.path.exists(self.processed_data_path):
36 | print("Loading processed data from:", self.processed_data_path)
37 | adata = sc.read_h5ad(self.processed_data_path)
38 |
39 | else:
40 | try:
41 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz"
42 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename)
43 | adata = sc.read_h5ad(self.processed_data_path)
44 |
45 | except Exception as e:
46 | print(f"Error downloading file from {self.dataset_hf_url}: {e}")
47 | print(f"Downloading file from {self.dataset_orig_url}")
48 |
49 | adata = download_scperturb_adata(
50 | self.dataset_orig_url,
51 | self.data_cache_dir,
52 | filename=f"{self.dataset_name}_downloaded.h5ad",
53 | )
54 |
55 | ## Format column names
56 | treatment_map = {
57 | "Co-culture": "co-culture",
58 | "Control": "none",
59 | }
60 | adata.obs["treatment"] = [
61 | treatment_map[x] if x in treatment_map else x
62 | for x in adata.obs.perturbation_2
63 | ]
64 | adata.obs["cell_type"] = "melanocyte"
65 | adata.obs["condition"] = adata.obs.perturbation.copy()
66 | adata.obs["perturbation_type"] = "CRISPRi"
67 | adata.obs["dataset"] = "frangieh21"
68 |
69 | adata.X = csr_matrix(adata.X)
70 | adata = preprocess(
71 | adata,
72 | perturbation_key="condition",
73 | covariate_keys=["treatment"],
74 | )
75 |
76 | adata = adata.copy()
77 | adata.write_h5ad(self.processed_data_path)
78 |
79 | print("Saved processed data to:", self.processed_data_path)
80 |
81 | return adata
82 |
--------------------------------------------------------------------------------
/src/perturbench/data/accessors/jiang24.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import os
3 |
4 | from perturbench.data.accessors.base import (
5 | download_file,
6 | Accessor,
7 | )
8 |
9 |
10 | class Jiang24(Accessor):
11 | def __init__(self, data_cache_dir="../perturbench_data"):
12 | super().__init__(
13 | data_cache_dir=data_cache_dir,
14 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/blob/main/jiang24_preprocessed.h5ad.gz",
15 | dataset_orig_url=None,
16 | dataset_name="jiang24",
17 | )
18 |
19 | def get_anndata(self):
20 | """
21 | Downloads, curates, and preprocesses the Jiang24 dataset from Hugging Face.
22 | Saves the preprocessed data to disk and returns it in-memory.
23 |
24 | Returns:
25 | adata (anndata.AnnData): Anndata object containing the processed data.
26 |
27 | """
28 | self.processed_data_path = (
29 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad"
30 | )
31 | if os.path.exists(self.processed_data_path):
32 | print("Loading processed data from:", self.processed_data_path)
33 | adata = sc.read_h5ad(self.processed_data_path)
34 |
35 | else:
36 | try:
37 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz"
38 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename)
39 | adata = sc.read_h5ad(self.processed_data_path)
40 |
41 | except Exception as e:
42 | print(f"Error downloading file from {self.dataset_hf_url}: {e}")
43 | raise ValueError(
44 | "Automatic data curation not available for this dataset. \
45 | Use the notebooks in notebooks/neurips2025/data_curation \
46 | to download and preprocess the data."
47 | )
48 |
49 | print("Saved processed data to:", self.processed_data_path)
50 |
51 | return adata
52 |
--------------------------------------------------------------------------------
/src/perturbench/data/accessors/mcfaline23.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import os
3 |
4 | from perturbench.data.accessors.base import (
5 | download_file,
6 | Accessor,
7 | )
8 |
9 |
10 | class McFaline23(Accessor):
11 | def __init__(self, data_cache_dir="../perturbench_data"):
12 | super().__init__(
13 | data_cache_dir=data_cache_dir,
14 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/blob/main/mcfaline23_gxe_preprocessed.h5ad.gz",
15 | dataset_orig_url=None,
16 | dataset_name="mcfaline23",
17 | )
18 |
19 | def get_anndata(self):
20 | """
21 | Downloads, curates, and preprocesses the McFalineFigueroa23 dataset from
22 | Hugging Face. Saves the preprocessed data to disk and returns it in-memory.
23 |
24 | Returns:
25 | adata (anndata.AnnData): Anndata object containing the processed data.
26 |
27 | """
28 | self.processed_data_path = (
29 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad"
30 | )
31 | if os.path.exists(self.processed_data_path):
32 | print("Loading processed data from:", self.processed_data_path)
33 | adata = sc.read_h5ad(self.processed_data_path)
34 |
35 | else:
36 | try:
37 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz"
38 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename)
39 | adata = sc.read_h5ad(self.processed_data_path)
40 |
41 | except Exception as e:
42 | print(f"Error downloading file from {self.dataset_hf_url}: {e}")
43 | raise ValueError(
44 | "Automatic data curation not available for this dataset. \
45 | Use the notebooks in notebooks/neurips2025/data_curation \
46 | to download and preprocess the data."
47 | )
48 |
49 | print("Saved processed data to:", self.processed_data_path)
50 |
51 | return adata
52 |
--------------------------------------------------------------------------------
/src/perturbench/data/accessors/norman19.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import os
3 | from scipy.sparse import csr_matrix
4 |
5 | from perturbench.analysis.preprocess import preprocess
6 | from perturbench.data.accessors.base import (
7 | download_scperturb_adata,
8 | download_file,
9 | Accessor,
10 | )
11 |
12 |
13 | class Norman19(Accessor):
14 | def __init__(self, data_cache_dir="../perturbench_data"):
15 | super().__init__(
16 | data_cache_dir=data_cache_dir,
17 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/resolve/main/norman19_preprocessed.h5ad.gz",
18 | dataset_orig_url="https://zenodo.org/records/7041849/files/NormanWeissman2019_filtered.h5ad?download=1",
19 | dataset_name="norman19",
20 | )
21 |
22 | def get_anndata(self):
23 | """
24 | Downloads, curates, and preprocesses the norman19 dataset from either
25 | Hugging Face or the scPerturb database. Saves the preprocessed data to
26 | disk and returns it in-memory.
27 |
28 | Returns:
29 | adata (anndata.AnnData): Anndata object containing the processed data.
30 |
31 | """
32 | self.processed_data_path = (
33 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad"
34 | )
35 | if os.path.exists(self.processed_data_path):
36 | print("Loading processed data from:", self.processed_data_path)
37 | adata = sc.read_h5ad(self.processed_data_path)
38 |
39 | else:
40 | try:
41 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz"
42 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename)
43 | adata = sc.read_h5ad(self.processed_data_path)
44 |
45 | except Exception as e:
46 | print(f"Error downloading file from {self.dataset_hf_url}: {e}")
47 | print(f"Downloading file from {self.dataset_orig_url}")
48 |
49 | adata = download_scperturb_adata(
50 | self.dataset_orig_url,
51 | self.data_cache_dir,
52 | filename=f"{self.dataset_name}_downloaded.h5ad",
53 | )
54 |
55 | adata.obs.rename(
56 | columns={
57 | "nCount_RNA": "ncounts",
58 | "nFeature_RNA": "ngenes",
59 | "percent.mt": "percent_mito",
60 | "cell_line": "cell_type",
61 | },
62 | inplace=True,
63 | )
64 |
65 | adata.obs["perturbation"] = adata.obs["perturbation"].str.replace("_", "+")
66 | adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")
67 | adata.obs["condition"] = adata.obs.perturbation.copy()
68 |
69 | adata.X = csr_matrix(adata.X)
70 |
71 | adata = preprocess(
72 | adata,
73 | perturbation_key="condition",
74 | covariate_keys=["cell_type"],
75 | )
76 |
77 | adata = adata.copy()
78 | adata.write_h5ad(self.processed_data_path)
79 |
80 | print("Saved processed data to:", self.processed_data_path)
81 |
82 | return adata
83 |
--------------------------------------------------------------------------------
/src/perturbench/data/accessors/srivatsan20.py:
--------------------------------------------------------------------------------
1 | import scanpy as sc
2 | import os
3 | from scipy.sparse import csr_matrix
4 |
5 | from perturbench.analysis.preprocess import preprocess
6 | from perturbench.analysis.utils import get_ensembl_mappings
7 | from perturbench.data.accessors.base import (
8 | download_scperturb_adata,
9 | download_file,
10 | Accessor,
11 | )
12 |
13 |
14 | class Sciplex3(Accessor):
15 | def __init__(self, data_cache_dir="../perturbench_data"):
16 | super().__init__(
17 | data_cache_dir=data_cache_dir,
18 | dataset_hf_url="https://huggingface.co/datasets/altoslabs/perturbench/resolve/main/srivatsan20_highest_dose_preprocessed.h5ad.gz",
19 | dataset_orig_url="https://zenodo.org/records/7041849/files/SrivatsanTrapnell2020_sciplex3.h5ad?download=1",
20 | dataset_name="sciplex3",
21 | )
22 |
23 | def get_anndata(self):
24 | """
25 | Downloads, curates, and preprocesses the sciplex3 dataset from the scPerturb
26 | database. Saves the preprocessed data to disk and returns it in-memory.
27 |
28 | Returns:
29 | adata (anndata.AnnData): Anndata object containing the processed data.
30 |
31 | """
32 | self.processed_data_path = (
33 | f"{self.data_cache_dir}/{self.dataset_name}_processed.h5ad"
34 | )
35 | if os.path.exists(self.processed_data_path):
36 | print("Loading processed data from:", self.processed_data_path)
37 | adata = sc.read_h5ad(self.processed_data_path)
38 |
39 | else:
40 | try:
41 | hf_filename = f"{self.dataset_name}_processed.h5ad.gz"
42 | download_file(self.dataset_hf_url, self.data_cache_dir, hf_filename)
43 | adata = sc.read_h5ad(self.processed_data_path)
44 |
45 | except Exception as e:
46 | print(f"Error downloading file from {self.dataset_hf_url}: {e}")
47 | print(f"Downloading file from {self.dataset_orig_url}")
48 |
49 | adata = download_scperturb_adata(
50 | self.dataset_orig_url,
51 | self.data_cache_dir,
52 | filename=f"{self.dataset_name}_downloaded.h5ad",
53 | )
54 |
55 | unique_genes = ~adata.var.ensembl_id.duplicated()
56 | adata = adata[:, unique_genes]
57 |
58 | ## Map ENSEMBL IDs to gene symbols
59 | adata.var_names = adata.var.ensembl_id.astype(str)
60 | human_ids = [x for x in adata.var_names if "ENSG" in x]
61 |
62 | adata = adata[:, human_ids]
63 | gene_mappings = get_ensembl_mappings()
64 | gene_mappings = {
65 | k: v for k, v in gene_mappings.items() if isinstance(v, str) and v != ""
66 | }
67 | adata = adata[:, [x in gene_mappings for x in adata.var_names]]
68 | adata.var["gene_symbol"] = [gene_mappings[x] for x in adata.var_names]
69 | adata.var_names = adata.var["gene_symbol"]
70 | adata.var_names_make_unique()
71 | adata.var.index.name = None
72 |
73 | ## Format column names
74 | adata.obs.rename(
75 | columns={
76 | "n_genes": "ngenes",
77 | "n_counts": "ncounts",
78 | },
79 | inplace=True,
80 | )
81 |
82 | ## Format cell line names
83 | adata.obs["cell_type"] = adata.obs["cell_line"].copy()
84 | adata = adata[adata.obs.cell_type.isin(["MCF7", "A549", "K562"])]
85 | adata.obs["cell_type"] = [x.lower() for x in adata.obs.cell_type]
86 |
87 | ## Rename some chemicals with the "+" symbol
88 | perturbation_remap = {
89 | "(+)-JQ1": "JQ1",
90 | "ENMD-2076 L-(+)-Tartaric acid": "ENMD-2076",
91 | }
92 | adata.obs["perturbation"] = [
93 | perturbation_remap.get(x, x) for x in adata.obs.perturbation.astype(str)
94 | ]
95 | adata.obs["condition"] = adata.obs["perturbation"].copy()
96 |
97 | ## Subset to highest dose only
98 | adata = adata[
99 | (adata.obs.dose_value == 10000) | (adata.obs.condition == "control")
100 | ].copy()
101 |
102 | adata.X = csr_matrix(adata.X)
103 | adata = preprocess(
104 | adata,
105 | perturbation_key="condition",
106 | covariate_keys=["cell_type"],
107 | )
108 |
109 | adata = adata.copy()
110 | adata.write_h5ad(self.processed_data_path)
111 |
112 | print("Saved processed data to:", self.processed_data_path)
113 |
114 | return adata
115 |
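
A minimal usage sketch for the accessors above, assuming the Norman19 class lives in perturbench.data.accessors.norman19 and using the default cache directory from the constructors:

    from perturbench.data.accessors.norman19 import Norman19
    from perturbench.data.accessors.srivatsan20 import Sciplex3

    # Each accessor caches its processed .h5ad under data_cache_dir and reuses it on later calls.
    norman19_adata = Norman19(data_cache_dir="../perturbench_data").get_anndata()
    sciplex3_adata = Sciplex3(data_cache_dir="../perturbench_data").get_anndata()
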
--------------------------------------------------------------------------------
/src/perturbench/data/collate.py:
--------------------------------------------------------------------------------
1 | class noop_collate:
2 | """No operation collate function. Returns the batch as is."""
3 |
4 | def __call__(self, batch: list):
5 | if len(batch) == 1:
6 | return batch[0]
7 | else:
8 | return batch
9 |
--------------------------------------------------------------------------------
/src/perturbench/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .singlecell import SingleCellPerturbation, SingleCellPerturbationWithControls
2 | from .population import Counterfactual, CounterfactualWithReference
3 |
--------------------------------------------------------------------------------
/src/perturbench/data/resources/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/resources/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/data/resources/devel.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/resources/devel.h5ad
--------------------------------------------------------------------------------
/src/perturbench/data/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/data/transforms/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/data/transforms/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from abc import ABC, abstractmethod
3 | from typing import Iterator, TypeVar
4 |
5 | from ..types import Example, Batch
6 |
7 | Datum = TypeVar("Datum", Example, Batch)
8 |
9 |
10 | class Transform(ABC):
11 | """Abstract transform interface."""
12 |
13 | @abstractmethod
14 | def __call__(self, data: Datum) -> Datum:
15 | pass
16 |
17 | @abstractmethod
18 | def __repr__(self) -> str:
19 | return self.__class__.__name__ + "({!s})"
20 |
21 |
22 | class ExampleTransform(Transform):
23 | """Transforms an example."""
24 |
25 | @abstractmethod
26 | def __call__(self, example: Example) -> Example:
27 | pass
28 |
29 | def __repr__(self) -> str:
30 | _base = super().__repr__()
31 | return "[example-wise]" + _base
32 |
33 | def batchify(self, collate_fn: callable = list) -> Map:
34 | """Converts an example transform to a batch transform."""
35 | return Map(self, collate_fn)
36 |
37 |
38 | class Map(Transform):
39 | """Maps a transform to a batch of examples."""
40 |
41 | def __init__(self, transform: ExampleTransform, collate_fn: callable = list):
42 | self.transform = transform
43 | self.collate_fn = collate_fn
44 |
45 | def __call__(self, batch: Iterator[Example]) -> Batch:
46 | return self.collate_fn(list(map(self.transform, batch)))
47 |
48 | def __repr__(self) -> str:
49 | _base = super().__repr__()
50 | # strip the [example-wise] prefix from the self.transform repr
51 | _instance_repr = repr(self.transform)
52 | if _instance_repr.startswith("[example-wise]"):
53 | _instance_repr = _instance_repr[len("[example-wise]") :]
54 | args_repr = f"{_instance_repr}, collate_fn={self.collate_fn.__name__}"
55 | return "[batch-wise]" + _base.format(args_repr)
56 |
57 |
58 | class Dispatch(dict, Transform):
59 | """Dispatches a transform to an example based on a key field.
60 |
61 | Attributes:
62 | self: A map of key to transform.
63 | """
64 |
65 | def __call__(self, data: Datum) -> Datum:
66 | """Apply each transform to the field of an example matching its key."""
67 | result = {}
68 | try:
69 | for key, transform in self.items():
70 | result[key] = transform(getattr(data, key))
71 |         except AttributeError as exc:
72 | raise TypeError(
73 | f"Invalid {key=} in transforms. All keys need to match the "
74 | f"fields of an example."
75 | ) from exc
76 |
77 | return data._replace(**result)
78 |
79 | def __repr__(self) -> str:
80 | _base = Transform.__repr__(self)
81 | transforms_repr = ", ".join(
82 | f"{key}: {repr(transform)}" for key, transform in self.items()
83 | )
84 | return _base.format(transforms_repr)
85 |
86 |
87 | class Compose(list, Transform):
88 | """Creates a transform from a sequence of transforms."""
89 |
90 | def __call__(self, data: Datum) -> Datum:
91 | for transform in self:
92 | data = transform(data)
93 | return data
94 |
95 | def __repr__(self) -> str:
96 | transforms_repr = " \u2192 ".join(repr(transform) for transform in self)
97 | return f"[{transforms_repr}]"
98 |
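
An illustrative sketch of how Dispatch and Compose fit together; the expression vector and perturbation label below are made-up placeholders:

    from scipy.sparse import csr_matrix

    from perturbench.data.types import Example
    from perturbench.data.transforms.base import Compose, Dispatch
    from perturbench.data.transforms.ops import ToDense, ToFloat

    # Dispatch routes each transform to the Example field matching its key.
    pipeline = Dispatch(gene_expression=Compose([ToDense(), ToFloat()]))

    example = Example(
        gene_expression=csr_matrix([[0.0, 1.0, 3.0]]),  # placeholder expression vector
        perturbations=["GENE1"],                        # placeholder perturbation label
    )
    example = pipeline(example)  # gene_expression is now a dense float torch.Tensor
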
--------------------------------------------------------------------------------
/src/perturbench/data/transforms/encoders.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import functools
3 | from typing import Collection, Sequence
4 |
5 | import torch
6 | import numpy as np
7 | from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder, MultiLabelBinarizer
8 |
9 | from .base import Transform
10 | from ..types import ExampleMultiLabel, BatchMultiLabel
11 |
12 |
13 | class OneHotEncode(Transform):
14 | """One-hot encode a categorical variable.
15 |
16 | Attributes:
17 |         one_hot_encoder: the wrapped encoder instance
18 | """
19 |
20 | one_hot_encoder: OneHotEncoder
21 |
22 | def __init__(self, categories: Collection[str], **kwargs):
23 | categories = [list(categories)]
24 | self.one_hot_encoder = OneHotEncoder(
25 | categories=categories,
26 | sparse_output=False,
27 | **kwargs,
28 | )
29 |
30 | def __call__(self, labels: Sequence[str]):
31 | string_array = np.array(labels).reshape(-1, 1)
32 | encoded = self.one_hot_encoder.fit_transform(string_array)
33 | return torch.Tensor(encoded)
34 |
35 | def __repr__(self):
36 | _base = super().__repr__()
37 | categories = ", ".join(self.one_hot_encoder.categories[0])
38 | return _base.format(categories)
39 |
40 |
41 | class LabelEncode(Transform):
42 | """Label encode categorical variables.
43 |
44 | Attributes:
45 | ordinal_encoder: sklearn.preprocessing.OrdinalEncoder
46 | """
47 |
48 | ordinal_encoder: OrdinalEncoder
49 |
50 | def __init__(self, values: Sequence[str]):
51 | categories = [np.array(values)]
52 | self.ordinal_encoder = OrdinalEncoder(categories=categories)
53 |
54 | def __call__(self, labels: Sequence[str]):
55 | string_array = np.array(labels).reshape(-1, 1)
56 | return torch.Tensor(self.ordinal_encoder.fit_transform(string_array))
57 |
58 | def __repr__(self):
59 | _base = super().__repr__()
60 | categories = ", ".join(self.ordinal_encoder.categories[0])
61 | return _base.format(categories)
62 |
63 |
64 | class MultiLabelEncode(Transform):
65 | """Transforms a sequence of labels into a binary vector.
66 |
67 | Attributes:
68 | label_binarizer: the wrapped binarizer instance
69 |
70 | Raises:
71 | ValueError: if any of the labels are not found in the encoder classes
72 | """
73 |
74 | label_binarizer: MultiLabelBinarizer
75 |
76 | def __init__(self, classes: Collection[str]):
77 | self.label_binarizer = MultiLabelBinarizer(
78 | classes=list(classes), sparse_output=False
79 | )
80 |
81 | @functools.cached_property
82 | def classes(self):
83 | return set(self.label_binarizer.classes)
84 |
85 | def __call__(self, labels: ExampleMultiLabel | BatchMultiLabel) -> torch.Tensor:
86 | # If labels is a single example, convert it to a batch
87 | if not labels or isinstance(labels[0], str):
88 | labels = [labels]
89 | self._check_inputs(labels)
90 | encoded = self.label_binarizer.fit_transform(labels)
91 | return torch.from_numpy(encoded)
92 |
93 | def _check_inputs(self, labels: BatchMultiLabel):
94 | unique_labels = set(itertools.chain.from_iterable(labels))
95 | if not unique_labels <= self.classes:
96 | missing_labels = unique_labels - self.classes
97 | raise ValueError(
98 | f"Labels {missing_labels} not found in the encoder classes {self.classes}"
99 | )
100 |
101 | def __repr__(self):
102 | _base = super().__repr__()
103 | classes = ", ".join(self.label_binarizer.classes)
104 | return _base.format(classes)
105 |
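
Brief usage sketches for the encoders above; the cell lines and gene names are placeholder values:

    from perturbench.data.transforms.encoders import MultiLabelEncode, OneHotEncode

    one_hot = OneHotEncode(categories=["a549", "k562", "mcf7"])
    one_hot(["k562", "a549"])               # -> float tensor of shape (2, 3)

    multi = MultiLabelEncode(classes=["GENE1", "GENE2"])
    multi([["GENE1"], ["GENE1", "GENE2"]])  # -> binary tensor of shape (2, 2)
    # multi([["GENE3"]])                    # would raise ValueError: label not in the encoder classes
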
--------------------------------------------------------------------------------
/src/perturbench/data/transforms/ops.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable
2 |
3 | import torch
4 | from scipy.sparse import csr_matrix
5 |
6 | from .base import Transform
7 |
8 |
9 | class ToDense(Transform):
10 | """Convert a sparse matrix/tensor to a dense matrix/tensor."""
11 |
12 | def __call__(self, value: torch.Tensor | csr_matrix) -> torch.Tensor:
13 | if isinstance(value, torch.Tensor):
14 | return value.to_dense()
15 | elif isinstance(value, csr_matrix):
16 | return torch.Tensor(value.toarray())
17 | else:
18 | return value
19 |
20 | def __repr__(self):
21 | return "ToDense"
22 |
23 |
24 | class ToFloat(Transform):
25 | """Convert a tensor to float."""
26 |
27 | def __call__(self, value: torch.Tensor):
28 | return value.float()
29 |
30 | def __repr__(self):
31 | return "ToFloat"
32 |
33 |
34 | class MapApply(Transform):
35 | """Map each transform to an input based on a key.
36 |
37 | Attributes:
38 | transform_map: A map of key to transform.
39 | """
40 |
41 | transform_map: dict[str, Transform | Callable]
42 |
43 | def __init__(
44 | self,
45 | transforms: dict[str, Transform | Callable],
46 | init_params_map: dict | None = None,
47 | ) -> None:
48 | """Initializes the instance based on passed transforms.
49 |
50 |         This class supports two ways of initializing the transforms. The first
51 | is by passing a map of key to transform. The second is by passing a map of
52 | key to factory callable. The factory callable will be called with the
53 | corresponding init params from the init_params_map. The factory callable
54 | should return a Transform.
55 |
56 | Args:
57 | transforms: A map of key to transform.
58 | init_params_map: A map of key to init params for the transforms.
59 |
60 | Raises:
61 | ValueError: If init_params_map is not None when using a dict of
62 | Transforms.
63 | TypeError: If the transform is not a dict of Transform or a callable.
64 | """
65 | super().__init__()
66 | self.transform_map = {}
67 | for key, transform in transforms.items():
68 | # Transforms are dict[str, Transform], directly assign them
69 | if isinstance(transform, Transform):
70 | if init_params_map is not None:
71 | raise ValueError(
72 | "init_params_map should be None when using a dict of "
73 | "Transforms."
74 | )
75 | self.transform_map[key] = transform
76 | # Transforms are dict[str, factory_callable], call the factory
77 | elif callable(transform):
78 | self.transform_map[key] = transform(init_params_map[key])
79 | else:
80 | raise TypeError(
81 | f"Invalid type for {key=} in transform. Must be either a "
82 | f"Transform or a callable."
83 | )
84 |
85 | def __call__(self, value_map: dict[str, Any]) -> dict[str, Any]:
86 | return {key: self.transform_map[key](val) for key, val in value_map.items()}
87 |
88 | def __repr__(self) -> str:
89 | transforms_repr = ", ".join(
90 | f"{key}: {repr(transform)}" for key, transform in self.transform_map.items()
91 | )
92 | return "{" + transforms_repr + "}"
93 |
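
A sketch of the two MapApply initialization styles described in the docstring; keys and category lists are placeholders:

    from perturbench.data.transforms.encoders import OneHotEncode
    from perturbench.data.transforms.ops import MapApply, ToFloat

    # 1) A dict of ready-made Transform instances (init_params_map must stay None).
    apply_ready = MapApply({"cell_type": ToFloat(), "batch": ToFloat()})

    # 2) A dict of factory callables; each factory is called with its entry from init_params_map.
    apply_factory = MapApply(
        {"cell_type": OneHotEncode, "batch": OneHotEncode},
        init_params_map={"cell_type": ["a549", "k562"], "batch": ["rep1", "rep2"]},
    )
    encoded = apply_factory({"cell_type": ["a549"], "batch": ["rep2"]})
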
--------------------------------------------------------------------------------
/src/perturbench/data/transforms/pipelines.py:
--------------------------------------------------------------------------------
1 | from .base import Dispatch, Compose
2 | from .encoders import OneHotEncode, MultiLabelEncode
3 | from .ops import ToDense, ToFloat, MapApply
4 |
5 |
6 | class SingleCellPipeline(Dispatch):
7 | """Single cell transform pipeline."""
8 |
9 | def __init__(
10 | self,
11 | perturbation_uniques: set[str],
12 |         covariate_uniques: dict[str, set],
13 | ) -> None:
14 | # Set up covariates transform
15 | covariate_transform = {
16 | key: Compose([OneHotEncode(uniques), ToFloat()])
17 | for key, uniques in covariate_uniques.items()
18 | }
19 | # Initialize the pipeline
20 | super().__init__(
21 | perturbations=Compose(
22 | [
23 | MultiLabelEncode(perturbation_uniques),
24 | ToFloat(),
25 | ]
26 | ),
27 | gene_expression=ToDense(),
28 | covariates=MapApply(covariate_transform),
29 | )
30 |
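
A construction sketch for the pipeline above, with placeholder perturbation and covariate vocabularies:

    from perturbench.data.transforms.pipelines import SingleCellPipeline

    pipeline = SingleCellPipeline(
        perturbation_uniques={"GENE1", "GENE2"},
        covariate_uniques={"cell_type": {"a549", "k562"}},
    )
    # Applied to an Example/Batch, it binarizes the perturbation labels, one-hot
    # encodes each covariate, and densifies the sparse gene expression.
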
--------------------------------------------------------------------------------
/src/perturbench/data/types.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import NamedTuple, Sequence, Any
3 |
4 | from scipy.sparse import csr_matrix, csr_array
5 | import numpy as np
6 |
7 | SparseMatrix = csr_matrix
8 | SparseVector = csr_array
9 |
10 | ExampleMultiLabel = Sequence[str]
11 | BatchMultiLabel = Sequence[ExampleMultiLabel]
12 |
13 |
14 | class Example(NamedTuple):
15 | """Single Cell Expression Example."""
16 |
17 | # A vector of size (num_genes, )
18 | gene_expression: SparseVector
19 | # A list of perturbations applied to the cell
20 | perturbations: Sequence[str] ## TODO: Should be [] if control
21 | # A map from covariate name to covariate value
22 | covariates: dict[str, str] | None = None
23 | # A map from control condition name to control gene expression of
24 | # shape (num_controls_in_condition, num_genes)
25 | controls: SparseVector | None = None
26 | # A cell id
27 | id: str | None = None
28 | # A list of gene names of length num_genes
29 | gene_names: Sequence[str] | None = None
30 | # Optional foundation model embeddings
31 | embeddings: np.ndarray | None = None
32 |
33 |
34 | class Batch(NamedTuple):
35 | """Single Cell Expression Batch."""
36 |
37 | gene_expression: SparseMatrix
38 | perturbations: Sequence[list[str]] | SparseMatrix
39 | covariates: dict[str, Sequence[str] | SparseMatrix] | None = None
40 | controls: SparseMatrix | None = None
41 | id: Sequence[str] | None = None
42 | gene_names: Sequence[str] | None = None
43 | embeddings: np.ndarray | None = None
44 |
45 |
46 | class FrozenDictKeyMap(dict):
47 | """A dictionary that uses dictionaries as keys.
48 |
49 | Dictionaries cannot be used directly as keys to another dictionary because
50 | they are mutable. As a result this class first converts the dictionary to a
51 | frozenset of key-value pairs before using it as a key. The underlying data
52 | is stored using the dictionary data structure and this class just modifies
53 | the accessor and mutator methods.
54 |
55 | Example:
56 | >>> d = FrozenDictKeyMap()
57 | >>> d[{"a": 1, "b": 2}] = 1
58 | >>> d[{"a": 1, "b": 2}] = 2
59 | >>> d[{"a": 1, "b": 2}] = 3
60 | >>> d
61 | {frozenset({('a', 1), ('b', 2)}): 3}
62 |
63 | Attributes: see dict class
64 | """
65 |
66 | def __init__(self, data: Sequence[tuple[dict, Any]] | None = None):
67 | """Initialize the dictionary.
68 |
69 | Args:
70 | data: a sequence of (key, value) pairs to initialize the dictionary
71 | """
72 | if data is not None:
73 | try:
74 | _data = [(frozenset(key.items()), value) for key, value in data]
75 | except AttributeError as exc:
76 | raise ValueError(
77 | "data must be a sequence of (key, value) pairs where key is a "
78 | "dictionary"
79 | ) from exc
80 | else:
81 | _data = []
82 | super().__init__(_data)
83 |
84 | def __getitem__(self, key: dict) -> Any:
85 | """Get the value associated with the key.
86 |
87 | Args:
88 | key: a dictionary.
89 |
90 | Returns:
91 | The value associated with the key.
92 | """
93 | if isinstance(key, frozenset):
94 | key = dict(key)
95 | return super().__getitem__(frozenset(key.items()))
96 |
97 | def __setitem__(self, key: dict, value: Any) -> None:
98 | """Set the value associated with the key.
99 |
100 | Args:
101 | key: a dictionary.
102 | value: the value to set.
103 | """
104 | if isinstance(key, frozenset):
105 | key = dict(key)
106 | super().__setitem__(frozenset(key.items()), value)
107 |
--------------------------------------------------------------------------------
/src/perturbench/modelcore/__init__.py:
--------------------------------------------------------------------------------
1 | VERSION = "0.0.1"
2 |
--------------------------------------------------------------------------------
/src/perturbench/modelcore/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear_additive import LinearAdditive
2 | from .latent_additive import LatentAdditive
3 | from .decoder_only import DecoderOnly
4 | from .sams_vae import SparseAdditiveVAE
5 | from .base import PerturbationModel
6 | from .biolord import BiolordStar
7 | from .average import Average
8 | from .cpa import CPA
9 |
--------------------------------------------------------------------------------
/src/perturbench/modelcore/models/average.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2024,
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | """
31 |
32 | import lightning as L
33 | import torch
34 | from torch import optim
35 | from perturbench.data.types import Batch
36 | from .base import PerturbationModel
37 |
38 |
39 | class Average(PerturbationModel):
40 | """
41 | A perturbation prediction baseline model that returns the average expression of each perturbation in the training data.
42 | """
43 |
44 | def __init__(
45 | self,
46 | n_genes: int,
47 | n_perts: int,
48 | datamodule: L.LightningDataModule | None = None,
49 | ) -> None:
50 | """
51 | The constructor for the Average class.
52 |
53 | Args:
54 | n_genes (int): Number of genes in the dataset
55 | n_perts (int): Number of perturbations in the dataset (not including controls)
56 | """
57 | super(Average, self).__init__(datamodule)
58 | self.save_hyperparameters(ignore=["datamodule"])
59 |
60 | if n_genes is None:
61 | n_genes = datamodule.num_genes
62 |
63 | if n_perts is None:
64 | n_perts = datamodule.num_perturbations
65 |
66 | self.n_genes = n_genes
67 | self.n_perts = n_perts
68 | self.average_expression = torch.nn.Parameter(
69 | torch.zeros(n_perts, n_genes), requires_grad=False
70 | )
71 | self.sum_expression = torch.zeros(n_perts, n_genes)
72 | self.num_cells = torch.zeros(n_perts)
73 | self.dummy_nn = torch.nn.Linear(1, 1)
74 |
75 | def configure_optimizers(self):
76 | optimizer = optim.Adam(self.parameters())
77 | return optimizer
78 |
79 | def backward(self, use_amp, loss, optimizer):
80 | return
81 |
82 | def on_train_start(self):
83 | self.sum_expression = self.sum_expression.to(self.device)
84 | self.num_cells = self.num_cells.to(self.device)
85 |
86 | def training_step(self, batch: Batch, batch_idx: int | list[int]):
87 | # Unpack the batch
88 | observed_perturbed_expression = batch.gene_expression.squeeze()
89 | perturbation = batch.perturbations.squeeze()
90 | self.sum_expression += torch.matmul(
91 | perturbation.t(), observed_perturbed_expression
92 | )
93 | self.num_cells += perturbation.sum(0)
94 |
95 | def on_train_epoch_end(self):
96 | average_expression = self.sum_expression.t() / self.num_cells
97 | self.average_expression = torch.nn.Parameter(
98 | average_expression.t(), requires_grad=False
99 | )
100 |
101 | self.sum_expression = torch.zeros(self.n_perts, self.n_genes)
102 | self.num_cells = torch.zeros(self.n_perts)
103 |
104 | def predict(self, batch: Batch):
105 | perturbation = batch.perturbations.squeeze()
106 | perturbation = perturbation.to(self.device)
107 | predicted_perturbed_expression = torch.matmul(
108 | perturbation, self.average_expression
109 | )
110 | predicted_perturbed_expression = (
111 | predicted_perturbed_expression.t() / perturbation.sum(1)
112 | )
113 | return predicted_perturbed_expression.t()
114 |
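
In matrix terms, training_step above accumulates sum_expression = Pᵀ·X and num_cells = Pᵀ·1, where P is the (cells × n_perts) multi-hot perturbation matrix and X the (cells × n_genes) expression matrix, so average_expression[p] = sum_expression[p] / num_cells[p]. At prediction time a perturbation vector p yields (p · average_expression) / Σp, i.e. the mean of the stored profiles for the perturbations applied to that cell.
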
--------------------------------------------------------------------------------
/src/perturbench/modelcore/models/decoder_only.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2024,
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | """
31 |
32 | import torch
33 | import torch.nn.functional as F
34 | import lightning as L
35 | import numpy as np
36 |
37 | from ..nn.mlp import MLP
38 | from .base import PerturbationModel
39 | from perturbench.data.types import Batch
40 |
41 |
42 | class DecoderOnly(PerturbationModel):
43 | """
44 |     A decoder-only model for predicting perturbation effects
45 | """
46 |
47 | def __init__(
48 | self,
49 | n_genes: int,
50 | n_perts: int,
51 | n_layers=2,
52 | encoder_width=128,
53 | softplus_output=True,
54 | use_covariates=True,
55 | use_perturbations=True,
56 | lr: float | None = None,
57 | wd: float | None = None,
58 | lr_scheduler_freq: int | None = None,
59 | lr_scheduler_patience: int | None = None,
60 | lr_scheduler_factor: float | None = None,
61 | datamodule: L.LightningDataModule | None = None,
62 | ) -> None:
63 | """
64 | The constructor for the DecoderOnly class.
65 |
66 | Args:
67 | n_genes (int): Number of genes to use for prediction
68 | n_perts (int): Number of perturbations in the dataset (not including controls)
69 |             n_layers (int): Number of layers in the decoder
70 | lr (float): Learning rate
71 | wd (float): Weight decay
72 | lr_scheduler_freq (int): How often the learning rate scheduler checks val_loss
73 | lr_scheduler_patience (int): Learning rate scheduler patience
74 | lr_scheduler_factor (float): Factor by which to reduce learning rate when learning rate scheduler triggers
75 | softplus_output (bool): Whether to apply a softplus activation to the output of the decoder to enforce non-negativity
76 | """
77 |
78 | super(DecoderOnly, self).__init__(
79 | datamodule=datamodule,
80 | lr=lr,
81 | wd=wd,
82 | lr_scheduler_freq=lr_scheduler_freq,
83 | lr_scheduler_patience=lr_scheduler_patience,
84 | lr_scheduler_factor=lr_scheduler_factor,
85 | )
86 | self.save_hyperparameters(ignore=["datamodule"])
87 |
88 | if not (use_covariates or use_perturbations):
89 | raise ValueError(
90 | "'use_covariates' and 'use_perturbations' can not both be false. Either covariates or perturbations have to be used."
91 | )
92 |
93 | if n_genes is not None:
94 | self.n_genes = n_genes
95 | if n_perts is not None:
96 | self.n_perts = n_perts
97 |
98 | n_total_covariates = (
99 | np.sum(
100 | [
101 | len(unique_covs)
102 | for unique_covs in datamodule.train_context[
103 | "covariate_uniques"
104 | ].values()
105 | ]
106 | )
107 | if use_covariates
108 | else 0
109 | )
110 |
111 | n_perts = self.n_perts if use_perturbations else 0
112 |
113 | decoder_input_dim = n_total_covariates + n_perts
114 |
115 | self.decoder = MLP(decoder_input_dim, encoder_width, self.n_genes, n_layers)
116 | self.softplus_output = softplus_output
117 | self.use_covariates = use_covariates
118 | self.use_perturbations = use_perturbations
119 |
120 | def forward(
121 | self,
122 | control_expression: torch.Tensor,
123 | perturbation: torch.Tensor,
124 | covariates: dict[str, torch.Tensor],
125 | ):
126 | if self.use_covariates and self.use_perturbations:
127 | embedding = torch.cat([cov.squeeze() for cov in covariates.values()], dim=1)
128 | embedding = torch.cat([perturbation, embedding], dim=1)
129 | elif self.use_covariates:
130 | embedding = torch.cat([cov.squeeze() for cov in covariates.values()], dim=1)
131 | elif self.use_perturbations:
132 | embedding = perturbation
133 |
134 | predicted_perturbed_expression = self.decoder(embedding)
135 |
136 | if self.softplus_output:
137 | predicted_perturbed_expression = F.softplus(predicted_perturbed_expression)
138 | return predicted_perturbed_expression
139 |
140 | def training_step(self, batch: Batch, batch_idx: int):
141 | (
142 | observed_perturbed_expression,
143 | control_expression,
144 | perturbation,
145 | covariates,
146 | _,
147 | ) = self.unpack_batch(batch)
148 | predicted_perturbed_expression = self.forward(
149 | control_expression, perturbation, covariates
150 | )
151 | loss = F.mse_loss(predicted_perturbed_expression, observed_perturbed_expression)
152 | self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=len(batch))
153 | return loss
154 |
155 | def validation_step(self, batch: Batch, batch_idx: int):
156 | (
157 | observed_perturbed_expression,
158 | control_expression,
159 | perturbation,
160 | covariates,
161 | _,
162 | ) = self.unpack_batch(batch)
163 | predicted_perturbed_expression = self.forward(
164 | control_expression, perturbation, covariates
165 | )
166 | val_loss = F.mse_loss(
167 | predicted_perturbed_expression, observed_perturbed_expression
168 | )
169 | self.log(
170 | "val_loss",
171 | val_loss,
172 | on_step=True,
173 | prog_bar=True,
174 | logger=True,
175 | batch_size=len(batch),
176 | )
177 | return val_loss
178 |
179 | def predict(self, batch):
180 | control_expression = batch.gene_expression.squeeze().to(self.device)
181 | perturbation = batch.perturbations.squeeze().to(self.device)
182 | covariates = {k: v.to(self.device) for k, v in batch.covariates.items()}
183 |
184 | predicted_perturbed_expression = self.forward(
185 | control_expression,
186 | perturbation,
187 | covariates,
188 | )
189 | return predicted_perturbed_expression
190 |
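
Note that DecoderOnly never uses control_expression in forward: the prediction is simply MLP([perturbation ; covariates]), passed through softplus when softplus_output is set, so expression is decoded purely from the one-hot condition encodings.
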
--------------------------------------------------------------------------------
/src/perturbench/modelcore/models/linear_additive.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2024,
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | """
31 |
32 | import lightning as L
33 | import torch
34 | import torch.nn as nn
35 | import torch.nn.functional as F
36 | import numpy as np
37 |
38 | from perturbench.data.types import Batch
39 | from .base import PerturbationModel
40 |
41 |
42 | class LinearAdditive(PerturbationModel):
43 | """
44 |     A linear additive model for predicting perturbation effects
45 | """
46 |
47 | def __init__(
48 | self,
49 | n_genes: int,
50 | n_perts: int,
51 | inject_covariates: bool = False,
52 | lr: float | None = None,
53 | wd: float | None = None,
54 | lr_scheduler_freq: int | None = None,
55 | lr_scheduler_interval: str | None = None,
56 | lr_scheduler_patience: int | None = None,
57 |         lr_scheduler_factor: float | None = None,
58 | softplus_output: bool = True,
59 | datamodule: L.LightningDataModule | None = None,
60 | ) -> None:
61 | """
62 | The constructor for the LinearAdditive class.
63 |
64 | Args:
65 | n_genes (int): Number of genes in the dataset
66 | n_perts (int): Number of perturbations in the dataset (not including controls)
67 | lr (float): Learning rate
68 | wd (float): Weight decay
69 | lr_scheduler_freq (int): How often the learning rate scheduler checks val_loss
70 | lr_scheduler_interval (str): Whether to check val_loss every epoch or every step
71 | lr_scheduler_patience (int): Learning rate scheduler patience
72 | lr_scheduler_factor (float): Factor by which to reduce learning rate when learning rate scheduler triggers
73 | inject_covariates: Whether to condition the linear layer on
74 | covariates
75 | softplus_output: Whether to apply a softplus activation to the
76 | output of the decoder to enforce non-negativity
77 | datamodule: The datamodule used to train the model
78 | """
79 | super(LinearAdditive, self).__init__(
80 | datamodule=datamodule,
81 | lr=lr,
82 | wd=wd,
83 | lr_scheduler_interval=lr_scheduler_interval,
84 | lr_scheduler_freq=lr_scheduler_freq,
85 | lr_scheduler_patience=lr_scheduler_patience,
86 | lr_scheduler_factor=lr_scheduler_factor,
87 | )
88 | self.save_hyperparameters(ignore=["datamodule"])
89 | self.softplus_output = softplus_output
90 |
91 | if n_genes is not None:
92 | self.n_genes = n_genes
93 | if n_perts is not None:
94 | self.n_perts = n_perts
95 |
96 | self.inject_covariates = inject_covariates
97 | if inject_covariates:
98 | if datamodule is None or datamodule.train_context is None:
99 | raise ValueError(
100 | "If inject_covariates is True, datamodule must be provided"
101 | )
102 | n_total_covariates = np.sum(
103 | [
104 | len(unique_covs)
105 | for unique_covs in datamodule.train_context[
106 | "covariate_uniques"
107 | ].values()
108 | ]
109 | )
110 | self.fc_pert = nn.Linear(self.n_perts + n_total_covariates, self.n_genes)
111 | else:
112 | self.fc_pert = nn.Linear(self.n_perts, self.n_genes)
113 |
114 | def forward(
115 | self,
116 | control_expression: torch.Tensor,
117 | perturbation: torch.Tensor,
118 | covariates: dict,
119 | ):
120 | if self.inject_covariates:
121 | merged_covariates = torch.cat(
122 | [cov.squeeze() for cov in covariates.values()], dim=1
123 | )
124 | perturbation = torch.cat([perturbation, merged_covariates], dim=1)
125 |
126 | predicted_perturbed_expression = control_expression + self.fc_pert(perturbation)
127 | if self.softplus_output:
128 | predicted_perturbed_expression = F.softplus(predicted_perturbed_expression)
129 | return predicted_perturbed_expression
130 |
131 | def training_step(self, batch: Batch, batch_idx: int):
132 | (
133 | observed_perturbed_expression,
134 | control_expression,
135 | perturbation,
136 | covariates,
137 | _,
138 | ) = self.unpack_batch(batch)
139 | predicted_perturbed_expression = self.forward(
140 | control_expression, perturbation, covariates
141 | )
142 | loss = F.mse_loss(predicted_perturbed_expression, observed_perturbed_expression)
143 | self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=len(batch))
144 | return loss
145 |
146 | def validation_step(self, batch: Batch, batch_idx: int):
147 | (
148 | observed_perturbed_expression,
149 | control_expression,
150 | perturbation,
151 | covariates,
152 | _,
153 | ) = self.unpack_batch(batch)
154 | predicted_perturbed_expression = self.forward(
155 | control_expression, perturbation, covariates
156 | )
157 | val_loss = F.mse_loss(
158 | predicted_perturbed_expression, observed_perturbed_expression
159 | )
160 | self.log(
161 | "val_loss",
162 | val_loss,
163 | on_step=True,
164 | prog_bar=True,
165 | logger=True,
166 | batch_size=len(batch),
167 | )
168 | return val_loss
169 |
170 | def predict(self, batch: Batch):
171 | control_expression = batch.gene_expression.squeeze().to(self.device)
172 | perturbation = batch.perturbations.squeeze().to(self.device)
173 | covariates = {k: v.to(self.device) for k, v in batch.covariates.items()}
174 | predicted_perturbed_expression = self.forward(
175 | control_expression,
176 | perturbation,
177 | covariates,
178 | )
179 | return predicted_perturbed_expression
180 |
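
The forward pass above amounts to ŷ = x_control + W·[p ; c] + b, optionally followed by softplus, where p is the multi-hot perturbation vector and c the concatenated one-hot covariates (included only when inject_covariates is enabled).
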
--------------------------------------------------------------------------------
/src/perturbench/modelcore/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/altoslabs/perturbench/e57c2fbafeb1b22df4923b4a1b3d3c82d2ba57ef/src/perturbench/modelcore/nn/__init__.py
--------------------------------------------------------------------------------
/src/perturbench/modelcore/nn/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.distributions import RelaxedBernoulli
3 |
4 |
5 | def gumbel_softmax_bernoulli(probs, temperature=0.5):
6 | # Create a RelaxedBernoulli distribution with the specified temperature
7 | relaxed_bernoulli = RelaxedBernoulli(temperature, probs=probs)
8 |
9 | # Sample from the relaxed distribution
10 | soft_sample = relaxed_bernoulli.rsample()
11 |
12 | # Quantize the soft sample to get a hard sample
13 | hard_sample = (soft_sample > 0.5).float()
14 |
15 | # Use straight-through estimator
16 | st_sample = hard_sample - soft_sample.detach() + soft_sample
17 |
18 | return st_sample
19 |
20 |
21 | class MLP(nn.Module):
22 | def __init__(
23 | self,
24 | input_dim: int,
25 | hidden_dim: int,
26 | output_dim: int,
27 | n_layers: int,
28 | dropout: float | None = None,
29 | norm: str | None = "layer",
30 | elementwise_affine: bool = False,
31 | ):
32 | """Class for defining MLP with arbitrary number of layers"""
33 | super(MLP, self).__init__()
34 |
35 | if norm not in ["layer", "batch", None]:
36 | raise ValueError("norm must be one of ['layer', 'batch', None]")
37 |
38 | layers = nn.Sequential()
39 | layers.append(nn.Linear(input_dim, hidden_dim))
40 | for _ in range(0, n_layers):
41 | layers.append(nn.Linear(hidden_dim, hidden_dim))
42 | if norm == "layer":
43 | layers.append(
44 | nn.LayerNorm(hidden_dim, elementwise_affine=elementwise_affine)
45 | )
46 | elif norm == "batch":
47 | layers.append(nn.BatchNorm1d(hidden_dim, momentum=0.01, eps=0.001))
48 | layers.append(nn.ReLU())
49 | if dropout is not None:
50 | layers.append(nn.Dropout(dropout))
51 |
52 | layers.append(nn.Linear(hidden_dim, output_dim))
53 | self.network = layers
54 |
55 | def forward(self, x):
56 | return self.network(x)
57 |
58 |
59 | class ResMLP(nn.Module):
60 | def __init__(
61 | self,
62 | input_dim: int,
63 | hidden_dim: int,
64 | n_layers: int,
65 | dropout: float | None = None,
66 | norm: str | None = "layer",
67 | elementwise_affine: bool = False,
68 | ):
69 | super(ResMLP, self).__init__()
70 |
71 | layers = nn.Sequential()
72 | layers.append(nn.Linear(input_dim, hidden_dim))
73 | layers.append(nn.LayerNorm(hidden_dim))
74 | layers.append(nn.ReLU())
75 |
76 | for i in range(0, n_layers - 1):
77 | layers.append(nn.Linear(hidden_dim, hidden_dim))
78 | layers.append(nn.LayerNorm(hidden_dim))
79 | layers.append(nn.ReLU())
80 |
81 | layers.append(nn.Linear(hidden_dim, input_dim))
82 | self.layers = layers
83 |
84 | def forward(self, x):
85 | return x + self.layers(x)
86 |
87 |
88 | class MaskNet(nn.Module):
89 | def __init__(
90 | self,
91 | input_dim: int,
92 | hidden_dim: int,
93 | output_dim: int,
94 | n_layers: int,
95 | dropout: float | None = None,
96 | norm: str | None = "layer",
97 | elementwise_affine: bool = False,
98 | ):
99 | """
100 | Implements a mask module. Similar to a standard MLP, but the output will be discrete 0's and 1's
101 | and gradients are calculated with the Gumbel-Softmax relaxation.
102 | """
103 | super(MaskNet, self).__init__()
104 |
105 | self.mlp = MLP(
106 | input_dim=input_dim,
107 | hidden_dim=hidden_dim,
108 | output_dim=output_dim,
109 | n_layers=n_layers,
110 | dropout=dropout,
111 | norm=norm,
112 | elementwise_affine=elementwise_affine,
113 | )
114 |
115 | def forward(self, x):
116 | m_probs = self.mlp(x).sigmoid()
117 | m = gumbel_softmax_bernoulli(m_probs)
118 | return m
119 |
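
In gumbel_softmax_bernoulli above, st_sample is numerically equal to the hard 0/1 sample (hard - soft.detach() + soft == hard), while its gradient is that of the relaxed soft_sample; this is the standard straight-through estimator, which keeps the discrete masks produced by MaskNet differentiable during training.
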
--------------------------------------------------------------------------------
/src/perturbench/modelcore/nn/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions as dist
3 | import torch.nn.functional as F
4 |
5 |
6 | class ZeroInflatedNegativeBinomial(dist.Distribution):
7 | def __init__(self, base_dist, zero_prob_logits):
8 | super(ZeroInflatedNegativeBinomial, self).__init__()
9 | self.base_dist = base_dist
10 | self.zero_prob_logits = zero_prob_logits
11 | self._mean = None
12 | self._variance = None
13 |
14 | def sample(self, sample_shape=torch.Size()):
15 | # note: this is not actually sampling from NB
16 | base_samples = self.base_dist.sample(sample_shape)
17 | zero_mask = torch.bernoulli(torch.sigmoid(self.zero_prob_logits)).bool()
18 | return torch.where(zero_mask, torch.zeros_like(base_samples), base_samples)
19 |
20 | def log_prob(self, value, eps=1e-8):
21 | # the original way but can be numerically unstable
22 | # base_log_prob = self.base_dist.log_prob(value)
23 | # zero_probs = torch.sigmoid(self.zero_prob_logits)
24 | # log_prob_non_zero = torch.log1p(-zero_probs + 1e-8) + base_log_prob
25 | # log_prob = torch.where(
26 | # value == 0,
27 | # torch.log(zero_probs + (1 - zero_probs) * torch.exp(base_log_prob) + 1e-8),
28 | # log_prob_non_zero
29 | # )
30 |
31 | # Adapted from SCVI's implementation of ZINB log_prob
32 | base_log_prob = self.base_dist.log_prob(value)
33 | softplus_neg_logits = F.softplus(-self.zero_prob_logits)
34 | case_zero = (
35 | F.softplus(-self.zero_prob_logits + base_log_prob) - softplus_neg_logits
36 | )
37 | mul_case_zero = torch.mul((value < eps).type(torch.float32), case_zero)
38 | case_non_zero = -self.zero_prob_logits - softplus_neg_logits + base_log_prob
39 | mul_case_non_zero = torch.mul((value > eps).type(torch.float32), case_non_zero)
40 | log_prob = mul_case_zero + mul_case_non_zero
41 |
42 | return log_prob
43 |
44 | @property
45 | def mean(self):
46 | if self._mean is None:
47 | base_mean = self.base_dist.mean
48 | self._mean = (1 - torch.sigmoid(self.zero_prob_logits)) * base_mean
49 | return self._mean
50 |
51 | @property
52 | def variance(
53 | self,
54 | ): # https://docs.pyro.ai/en/dev/_modules/pyro/distributions/zero_inflated.html#ZeroInflatedNegativeBinomial
55 | if self._variance is None:
56 | base_mean = self.base_dist.mean
57 | base_variance = self.base_dist.variance
58 | self._variance = (1 - torch.sigmoid(self.zero_prob_logits)) * (
59 | base_mean**2 + base_variance
60 | ) - self.mean**2
61 | return self._variance
62 |
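
The softplus form used in log_prob follows from the zero-inflated mixture: with $z$ the zero-inflation logits ($\pi = \sigma(z)$) and $b$ the base NB log-probability, $\log p(0) = \log(\pi + (1-\pi)e^{b}) = \mathrm{softplus}(b - z) - \mathrm{softplus}(-z)$ and $\log p(x > 0) = \log(1 - \pi) + b = -z - \mathrm{softplus}(-z) + b$, which are exactly case_zero and case_non_zero in the code.
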
--------------------------------------------------------------------------------
/src/perturbench/modelcore/predict.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from omegaconf import DictConfig
3 | import lightning as L
4 | import logging
5 | import hydra
6 | import os
7 |
8 | from perturbench.data.datasets import Counterfactual
9 | from perturbench.data.utils import batch_dataloader
10 | from perturbench.data.collate import noop_collate
11 | from .models.base import PerturbationModel
12 |
13 | log = logging.getLogger(__name__)
14 |
15 |
16 | def predict(
17 | cfg: DictConfig,
18 | ):
19 | """Predict counterfactual perturbation effects"""
20 | # Set seed for random number generators in pytorch, numpy and python.random
21 | if cfg.get("seed"):
22 | L.seed_everything(cfg.seed, workers=True)
23 |
24 | log.info("Instantiating datamodule <%s>", cfg.data._target_)
25 | datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
26 |
27 | log.info("Instantiating model <%s>", cfg.model._target_)
28 | model_class: PerturbationModel = hydra.utils.get_class(cfg.model._target_)
29 |
30 | # Load trained model
31 | if not os.path.exists(cfg.ckpt_path):
32 | raise ValueError(f"Checkpoint path {cfg.ckpt_path} does not exist")
35 |
36 | trained_model: PerturbationModel = model_class.load_from_checkpoint(
37 | cfg.ckpt_path,
38 | datamodule=datamodule,
39 | )
40 |
41 | # Load prediction dataframe
42 | if not os.path.exists(cfg.prediction_dataframe_path):
43 | raise ValueError(
44 | f"Prediction dataframe path {cfg.prediction_dataframe_path} does not exist"
45 | )
46 | pred_df = pd.read_csv(cfg.prediction_dataframe_path)
47 |
48 | if cfg.data.perturbation_key not in pred_df.columns:
49 | raise ValueError(
50 | f"Prediction dataframe must contain column {cfg.data.perturbation_key}"
51 | )
52 | for covariate_key in cfg.data.covariate_keys:
53 | if covariate_key not in pred_df.columns:
54 | raise ValueError(
55 | f"Prediction dataframe must contain column {covariate_key}"
56 | )
57 |
58 | # Create inference dataloader
59 | test_adata = datamodule.test_dataset.reference_adata
60 | control_adata = test_adata[
61 | test_adata.obs[cfg.data.perturbation_key] == cfg.data.perturbation_control_value
62 | ]
63 | del test_adata
64 |
65 | inference_dataset, _ = Counterfactual.from_anndata(
66 | control_adata,
67 | pred_df,
68 | cfg.data.perturbation_key,
69 | perturbation_combination_delimiter=cfg.data.perturbation_combination_delimiter,
70 | covariate_keys=cfg.data.covariate_keys,
71 | perturbation_control_value=cfg.data.perturbation_control_value,
72 | seed=cfg.seed,
73 | max_control_cells_per_covariate=cfg.data.evaluation.max_control_cells_per_covariate,
74 | )
75 | inference_dataset.transform = trained_model.training_record["transform"]
76 | inference_dataloader = batch_dataloader(
77 | inference_dataset,
78 | batch_size=cfg.chunk_size,
79 | num_workers=cfg.data.num_workers,
80 | shuffle=False,
81 | collate_fn=noop_collate(),
82 | )
83 |
84 | log.info("Instantiating trainer <%s>", cfg.trainer._target_)
85 | trainer: L.Trainer = hydra.utils.instantiate(cfg.trainer)
86 |
87 | log.info("Generating predictions")
88 | if not os.path.exists(cfg.output_path):
89 | os.makedirs(cfg.output_path)
90 | trained_model.prediction_output_path = cfg.output_path
91 | trainer.predict(model=trained_model, dataloaders=inference_dataloader)
92 |
93 |
94 | @hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml")
95 | def main(cfg: DictConfig):
96 | predict(cfg)
97 |
98 |
99 | if __name__ == "__main__":
100 | main()
101 |
--------------------------------------------------------------------------------
/src/perturbench/modelcore/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List
3 | import hydra
4 | import lightning as L
5 | from omegaconf import DictConfig
6 | from lightning.pytorch.loggers import Logger
7 | from perturbench.modelcore.utils import multi_instantiate
8 | from perturbench.modelcore.models import PerturbationModel
9 | from hydra.core.hydra_config import HydraConfig
10 |
11 |
12 | log = logging.getLogger(__name__)
13 |
14 |
15 | def train(runtime_context: dict):
16 | cfg = runtime_context["cfg"]
17 |
18 | # Set seed for random number generators in pytorch, numpy and python.random
19 | if cfg.get("seed"):
20 | L.seed_everything(cfg.seed, workers=True)
21 |
22 | log.info("Instantiating datamodule <%s>", cfg.data._target_)
23 | datamodule: L.LightningDataModule = hydra.utils.instantiate(
24 | cfg.data,
25 | seed=cfg.seed,
26 | )
27 |
28 | log.info("Instantiating model <%s>", cfg.model._target_)
29 | model: PerturbationModel = hydra.utils.instantiate(cfg.model, datamodule=datamodule)
30 |
31 | log.info("Instantiating callbacks...")
32 | callbacks: List[L.Callback] = multi_instantiate(cfg.get("callbacks"))
33 |
34 | log.info("Instantiating loggers...")
35 | loggers: List[Logger] = multi_instantiate(cfg.get("logger"))
36 |
37 | log.info("Instantiating trainer <%s>", cfg.trainer._target_)
38 | trainer: L.Trainer = hydra.utils.instantiate(
39 | cfg.trainer, callbacks=callbacks, logger=loggers
40 | )
41 |
42 | if cfg.get("train"):
43 | log.info("Starting training!")
44 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
45 |
46 | train_metrics = trainer.callback_metrics
47 |
48 | summary_metrics_dict = {}
49 | if cfg.get("test"):
50 | log.info("Starting testing!")
51 | if cfg.get("train"):
52 | if (
53 | trainer.checkpoint_callback is None
54 | or trainer.checkpoint_callback.best_model_path == ""
55 | ):
56 | ckpt_path = None
57 | else:
58 | ckpt_path = "best"
59 | else:
60 | ckpt_path = cfg.get("ckpt_path")
61 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
62 | summary_metrics_dict = model.summary_metrics.to_dict()[
63 | model.summary_metrics.columns[0]
64 | ]
65 |
66 | test_metrics = trainer.callback_metrics
67 | # merge train and test metrics
68 | metric_dict = {**train_metrics, **test_metrics, **summary_metrics_dict}
69 |
70 | return metric_dict
71 |
72 |
73 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
74 | def main(cfg: DictConfig) -> float | None:
75 | runtime_context = {"cfg": cfg, "trial_number": HydraConfig.get().job.get("num")}
76 |
77 | ## Train the model
78 | global metric_dict
79 | metric_dict = train(runtime_context)
80 |
81 | ## Combined metric
82 | metrics_use = cfg.get("metrics_to_optimize")
83 | if metrics_use:
84 | combined_metric = sum(
85 | [metric_dict.get(metric) * weight for metric, weight in metrics_use.items()]
86 | )
87 | return combined_metric
88 |
89 |
90 | if __name__ == "__main__":
91 | main()
92 |
--------------------------------------------------------------------------------
/src/perturbench/modelcore/utils.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import logging
3 | import warnings
4 | from typing import Callable, List, TypeVar, Any
5 | import hydra
6 | from omegaconf import DictConfig
7 | from lightning import Callback
8 | from lightning.pytorch.loggers import Logger
9 |
10 | log = logging.getLogger(__name__)
11 |
12 | T = TypeVar("T", Logger, Callback)
13 |
14 |
15 | def multi_instantiate(
16 | cfg: DictConfig, context: dict[str, Any] | None = None
17 | ) -> List[T]:
18 | """Instantiates multiple classes from config.
19 |
20 | Instantiates multiple classes from a config object. The configuration object
21 | has the following structure: dict[submodule_name: str, conf: DictConfig]. The
22 | conf contains the configuration for the submodule and adheres to one of the
23 | following schema:
24 |
25 | 1. dict[_target_: str, ...] where _target_ key specifies the class to be
26 | instantiated and the rest of the keys are the arguments to be passed to
27 | the class constructor (usual hydra config schema).
28 |
29 | 2. dict[conf: DictConfig, dependencies: dict[str, str]] where conf contains
30 |        the configuration for the submodule as defined in (1), minus the requested
31 |        dependencies, and dependencies is a mapping from the class arguments to
32 | the keys in the context dictionary.
33 |
34 | 3. dict[conf: Callable, dependencies: DictConfig] where conf is a callable
35 |        that acts as a factory function for the class and dependencies is a mapping
36 | from the class arguments to the keys in the context dictionary.
37 |
38 | Args:
39 | cfg: A DictConfig object containing configurations.
40 | context: A dictionary containing dependencies that the individual
41 |             classes may request.
42 |
43 | Returns:
44 | A list of instantiated classes.
45 |
46 | Raises:
47 | TypeError: If the config is not a DictConfig.
48 | """
49 |
50 | instances: List[T] = []
51 |
52 | if not cfg:
53 | warnings.warn("No configs found! Skipping...")
54 | return instances
55 |
56 | if not isinstance(cfg, DictConfig):
57 | raise TypeError("Config must be a DictConfig!")
58 |
59 | for _, conf in cfg.items():
60 | target_name = conf.__class__.__name__
61 | instance_dependencies = {}
62 | # Resolve dependencies if requested
63 | if isinstance(conf, DictConfig) and "dependencies" in conf:
64 | if "conf" not in conf:
65 | raise TypeError(
66 | "Invalid config schema. If dependencies are requested, then "
67 | "the config must contain a 'conf' section specifying the "
68 | "class arguments not included in the dependencies."
69 | )
70 | # Resolve dependencies if any specified
71 | if conf.dependencies:
72 | if context:
73 | instance_dependencies = {
74 | kwarg: context[key] for kwarg, key in conf.dependencies.items()
75 | }
76 | else:
77 | raise ValueError(
78 | "The config requests dependencies, but none were provided."
79 | )
80 | conf: DictConfig | Callable = conf.conf
81 | # pylint: disable-next=protected-access
82 | target_name = conf.func.__name__ if callable(conf) else conf._target_
83 |
84 | log.info("Instantiating an object of type <%s>", target_name)
85 | if isinstance(conf, partial):
86 | instances.append(conf(**instance_dependencies))
87 | elif isinstance(conf, DictConfig):
88 | if "_target_" in conf:
89 | instances.append(
90 | hydra.utils.instantiate(
91 | conf,
92 | **instance_dependencies,
93 | _recursive_=False,
94 | )
95 | )
96 | else:
97 | raise ValueError(
98 | f"Invalid config schema ({conf}). The config must contain "
99 | "a '_target_' key specifying the class to be instantiated."
100 | )
101 | # Object has already been instantiated
102 | else:
103 | instances.append(conf)
104 |
105 | return instances
106 |
107 |
108 | def instantiate_with_context(
109 | cfg: DictConfig,
110 | context: dict[str, Any] | None = None,
111 | ) -> Any:
112 | if not cfg:
113 | warnings.warn("No configs found! Skipping...")
114 | return None
115 |
116 | if not isinstance(cfg, DictConfig):
117 | raise TypeError("Config must be a DictConfig!")
118 |
119 | if cfg.dependencies:
120 | if context:
121 | dependencies = {
122 | kwarg: context[key] for kwarg, key in cfg.dependencies.items()
123 | }
124 | else:
125 | raise ValueError(
126 | "The config requests dependencies, but none were provided."
127 | )
128 | else:
129 | dependencies = {}
130 |
131 | return cfg.conf(**dependencies)
132 |
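
A sketch of a config accepted by multi_instantiate, covering schema (1); the callback class and monitor key are illustrative assumptions rather than a config shipped with the repository:

    from omegaconf import OmegaConf

    from perturbench.modelcore.utils import multi_instantiate

    cfg = OmegaConf.create(
        {
            # schema (1): plain hydra-style target plus constructor kwargs
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "monitor": "val_loss",
            },
            # schemas (2)/(3) would instead nest {"conf": ..., "dependencies": {...}}
        }
    )
    callbacks = multi_instantiate(cfg)  # -> [ModelCheckpoint(...)]
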
--------------------------------------------------------------------------------