├── .env.example ├── .gitattributes ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── codecov.yml ├── dependabot.yml ├── release-drafter.yml └── workflows │ ├── code-quality-main.yaml │ ├── code-quality-pr.yaml │ ├── release-drafter.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .project-root ├── Makefile ├── README.md ├── configs ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── data │ ├── gears.yaml │ └── perturb.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── eval.yaml ├── experiment │ ├── mlp_norman_inference.yaml │ ├── mlp_norman_train.yaml │ ├── mlp_replogle_k562_inference.yaml │ └── mlp_replogle_k562_train.yaml ├── extras │ └── default.yaml ├── hparams_search │ ├── optuna_architecture.yaml │ └── optuna_lr.yaml ├── hydra │ ├── default.yaml │ ├── spectra_data_sweep_1.yaml │ ├── spectra_data_sweep_2.yaml │ └── spectra_data_sweep_3.yaml ├── local │ └── .gitkeep ├── logger │ ├── aim.yaml │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── gears.yaml │ ├── mean.yaml │ └── mlp.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── data └── .gitkeep ├── environment.yaml ├── figures └── PertEval-scFM.png ├── logs └── .gitkeep ├── notebooks ├── .gitkeep ├── plots │ ├── aggregated_perturbation_results.ipynb │ ├── contextual_alignment.ipynb │ ├── edistance_vs_mse_analysis.ipynb │ ├── expression_analysis_top20_vs_tail_genes.ipynb │ ├── gene_level_deg_boxplot.ipynb │ ├── individual_perturbation_results.ipynb │ ├── supp_plot_experiment2.ipynb │ ├── visualize_sparsity.ipynb │ └── visualize_spectra_norman_1.ipynb └── preprocessing │ ├── diff_exp_refactored.ipynb │ ├── 
generate_deg_scripts.ipynb │ └── significant_perts_edist.ipynb ├── pyproject.toml ├── requirements.txt ├── scripts └── schedule.sh ├── setup.py ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── embeddings.py │ ├── perturb_datamodule.py │ ├── perturb_dataset.py │ └── reproduction │ │ ├── __init__.py │ │ └── gears │ │ ├── __init__.py │ │ └── gears_datamodule.py ├── eval.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── losses.py │ │ └── predictors.py │ ├── gears_module.py │ ├── prediction_module.py │ ├── pretrained_ckpts │ │ └── __init__.py │ └── reproduction │ │ ├── __init__.py │ │ ├── gears │ │ ├── __init__.py │ │ └── gears.py │ │ └── gears_spectra │ │ ├── gears_environment.yml │ │ ├── gears_with_spectra_singlegene.py │ │ └── gears_with_spectra_twogene.py ├── train.py └── utils │ ├── __init__.py │ ├── instantiators.py │ ├── logging_utils.py │ ├── pylogger.py │ ├── rich_utils.py │ ├── spectra │ ├── __init__.py │ ├── dataset.py │ ├── get_splits.py │ ├── independent_set_algo.py │ ├── perturb.py │ ├── spectra.py │ └── utils.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── helpers ├── __init__.py ├── package_available.py ├── run_if.py └── run_sh_command.py ├── test_configs.py ├── test_datamodules.py ├── test_eval.py ├── test_sweeps.py └── test_train.py /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | * linguist-vendored 2 | *.py linguist-vendored=false -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | OPTIONAL: 11 | Fixes #\ 12 | 13 | ## Before submitting 14 | 15 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 16 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 17 | - [ ] Did you list all the **breaking changes** introduced by this pull request, if they are present? 18 | - [ ] Did you make sure to **keep any local configs out** of the remote repository? 19 | - [ ] Make sure to **use dev branches for development**. Develop and implement on a personal dev branch. One branch per implementation. Branch naming convention: --, e.g. aaron-dev-gears 20 | 21 | 22 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | # measures overall project coverage 4 | project: 5 | default: 6 | threshold: 100% # how much decrease in coverage is needed to not consider success 7 | 8 | # measures PR or single commit coverage 9 | patch: 10 | default: 11 | threshold: 100% # how much decrease in coverage is needed to not consider success 12 | 13 | 14 | # project: off 15 | # patch: off 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | ignore: 13 | - dependency-name: "pytorch-lightning" 14 | update-types: ["version-update:semver-patch"] 15 | - dependency-name: "torchmetrics" 16 | update-types: ["version-update:semver-patch"] 17 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | 4 | categories: 5 | - title: "🚀 Features" 6 | labels: 7 | - "feature" 8 | - "enhancement" 9 | - title: "🐛 Bug Fixes" 10 | labels: 11 | - "fix" 12 | - "bugfix" 13 | - "bug" 14 | - title: "🧹 Maintenance" 15 | labels: 16 | - "maintenance" 17 | - "dependencies" 18 | - "refactoring" 19 | - "cosmetic" 20 | - "chore" 21 | - title: "📝️ Documentation" 22 | labels: 23 | - "documentation" 24 | - "docs" 25 | 26 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 27 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 28 | 29 | version-resolver: 30 | major: 31 | labels: 32 | - "major" 33 | minor: 34 | labels: 35 | - "minor" 36 | patch: 37 | labels: 38 | - "patch" 39 | default: patch 40 | 41 | template: | 42 | ## Changes 43 | 44 | $CHANGES 45 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-main.yaml: -------------------------------------------------------------------------------- 1 | # Same as `code-quality-pr.yaml` but triggered on commit to main branch 2 | # and runs on all files (instead of only the changed ones) 3 | 4 | name: Code Quality Main 5 | 6 | on: 7 | 
push: 8 | branches: [main] 9 | 10 | jobs: 11 | code-quality: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | 21 | - name: Run pre-commits 22 | uses: pre-commit/action@v2.0.3 23 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-pr.yaml: -------------------------------------------------------------------------------- 1 | # This workflow finds which files were changed, prints them, 2 | # and runs `pre-commit` on those files. 3 | 4 | # Inspired by the sktime library: 5 | # https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml 6 | 7 | name: Code Quality PR 8 | 9 | on: 10 | pull_request: 11 | branches: [main, "release/*", "dev"] 12 | 13 | jobs: 14 | code-quality: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v2 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | 24 | - name: Find modified files 25 | id: file_changes 26 | uses: trilom/file-changes-action@v1.2.4 27 | with: 28 | output: " " 29 | 30 | - name: List modified files 31 | run: echo '${{ steps.file_changes.outputs.files}}' 32 | 33 | - name: Run pre-commits 34 | uses: pre-commit/action@v2.0.3 35 | with: 36 | extra_args: --files ${{ steps.file_changes.outputs.files}} 37 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | # branches to consider in the event; optional, defaults to all 6 | branches: 7 | - main 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | update_release_draft: 14 | permissions: 15 | # write permission is required to create a github release 16 | contents: write 17 | # write permission is required for 
autolabeler 18 | # otherwise, read permission is required at least 19 | pull-requests: write 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | # Drafts your next Release notes as Pull Requests are merged into "master" 25 | - uses: release-drafter/release-drafter@v5 26 | env: 27 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 28 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main, "release/*", "dev"] 8 | 9 | jobs: 10 | run_tests_ubuntu: 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: ["ubuntu-latest"] 17 | python-version: ["3.8", "3.9", "3.10"] 18 | 19 | timeout-minutes: 20 20 | 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v3 24 | 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements.txt 34 | pip install pytest 35 | pip install sh 36 | 37 | - name: List dependencies 38 | run: | 39 | python -m pip list 40 | 41 | - name: Run pytest 42 | run: | 43 | pytest -v 44 | 45 | run_tests_macos: 46 | runs-on: ${{ matrix.os }} 47 | 48 | strategy: 49 | fail-fast: false 50 | matrix: 51 | os: ["macos-latest"] 52 | python-version: ["3.8", "3.9", "3.10"] 53 | 54 | timeout-minutes: 20 55 | 56 | steps: 57 | - name: Checkout 58 | uses: actions/checkout@v3 59 | 60 | - name: Set up Python ${{ matrix.python-version }} 61 | uses: actions/setup-python@v3 62 | with: 63 | python-version: ${{ matrix.python-version }} 64 | 65 | - name: Install dependencies 66 | run: | 67 | python -m pip install --upgrade pip 68 | pip install -r requirements.txt 69 | pip install pytest 70 | pip 
install sh 71 | 72 | - name: List dependencies 73 | run: | 74 | python -m pip list 75 | 76 | - name: Run pytest 77 | run: | 78 | pytest -v 79 | 80 | run_tests_windows: 81 | runs-on: ${{ matrix.os }} 82 | 83 | strategy: 84 | fail-fast: false 85 | matrix: 86 | os: ["windows-latest"] 87 | python-version: ["3.8", "3.9", "3.10"] 88 | 89 | timeout-minutes: 20 90 | 91 | steps: 92 | - name: Checkout 93 | uses: actions/checkout@v3 94 | 95 | - name: Set up Python ${{ matrix.python-version }} 96 | uses: actions/setup-python@v3 97 | with: 98 | python-version: ${{ matrix.python-version }} 99 | 100 | - name: Install dependencies 101 | run: | 102 | python -m pip install --upgrade pip 103 | pip install -r requirements.txt 104 | pip install pytest 105 | 106 | - name: List dependencies 107 | run: | 108 | python -m pip list 109 | 110 | - name: Run pytest 111 | run: | 112 | pytest -v 113 | 114 | # upload code coverage report 115 | code-coverage: 116 | runs-on: ubuntu-latest 117 | 118 | steps: 119 | - name: Checkout 120 | uses: actions/checkout@v2 121 | 122 | - name: Set up Python 3.10 123 | uses: actions/setup-python@v2 124 | with: 125 | python-version: "3.10" 126 | 127 | - name: Install dependencies 128 | run: | 129 | python -m pip install --upgrade pip 130 | pip install -r requirements.txt 131 | pip install pytest 132 | pip install pytest-cov[toml] 133 | pip install sh 134 | 135 | - name: Run tests and collect coverage 136 | run: pytest --cov src # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER 137 | 138 | - name: Upload coverage to Codecov 139 | uses: codecov/codecov-action@v3 140 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: 
https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 23.1.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "99"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python upgrading syntax to newer version 35 | - repo: https://github.com/asottile/pyupgrade 36 | rev: v3.3.1 37 | hooks: 38 | - id: pyupgrade 39 | args: [--py38-plus] 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: v1.7.4 44 | hooks: 45 | - id: docformatter 46 | args: 47 | [ 48 | --in-place, 49 | --wrap-summaries=99, 50 | --wrap-descriptions=99, 51 | --style=sphinx, 52 | --black, 53 | ] 54 | 55 | # python docstring coverage checking 56 | - repo: https://github.com/econchick/interrogate 57 | rev: 1.5.0 # or master if you're bold 58 | hooks: 59 | - id: interrogate 60 | args: 61 | [ 62 | --verbose, 63 | --fail-under=80, 64 | --ignore-init-module, 65 | --ignore-init-method, 66 | --ignore-module, 67 | --ignore-nested-functions, 68 | -vv, 69 | ] 70 | 71 | # python check (PEP8), programming errors and code complexity 72 | - repo: https://github.com/PyCQA/flake8 73 | rev: 6.0.0 74 | hooks: 75 | - id: flake8 76 | args: 77 | [ 78 | "--extend-ignore", 79 | "E203,E402,E501,F401,F841,RST2,RST301", 80 | "--exclude", 81 | "logs/*,data/*", 82 | ] 83 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 84 | 85 | # python security linter 86 | - repo: https://github.com/PyCQA/bandit 87 | rev: "1.7.5" 88 | hooks: 89 | - id: bandit 90 | args: ["-s", "B101"] 91 
| 92 | # yaml formatting 93 | - repo: https://github.com/pre-commit/mirrors-prettier 94 | rev: v3.0.0-alpha.6 95 | hooks: 96 | - id: prettier 97 | types: [yaml] 98 | exclude: "environment.yaml" 99 | 100 | # shell scripts linter 101 | - repo: https://github.com/shellcheck-py/shellcheck-py 102 | rev: v0.9.0.2 103 | hooks: 104 | - id: shellcheck 105 | 106 | # md formatting 107 | - repo: https://github.com/executablebooks/mdformat 108 | rev: 0.7.16 109 | hooks: 110 | - id: mdformat 111 | args: ["--number"] 112 | additional_dependencies: 113 | - mdformat-gfm 114 | - mdformat-tables 115 | - mdformat_frontmatter 116 | # - mdformat-toc 117 | # - mdformat-black 118 | 119 | # word spelling linter 120 | - repo: https://github.com/codespell-project/codespell 121 | rev: v2.2.4 122 | hooks: 123 | - id: codespell 124 | args: 125 | - --skip=logs/**,data/**,*.ipynb 126 | # - --ignore-words-list=abc,def 127 | 128 | # jupyter notebook cell output clearing 129 | - repo: https://github.com/kynan/nbstripout 130 | rev: 0.6.1 131 | hooks: 132 | - id: nbstripout 133 | 134 | # jupyter notebook linting 135 | - repo: https://github.com/nbQA-dev/nbQA 136 | rev: 1.6.3 137 | hooks: 138 | - id: nbqa-black 139 | args: ["--line-length=99"] 140 | - id: nbqa-isort 141 | args: ["--profile=black"] 142 | - id: nbqa-flake8 143 | args: 144 | [ 145 | "--extend-ignore=E203,E402,E501,F401,F841", 146 | "--exclude=logs/*,data/*", 147 | ] 148 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf 
"\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # PertEval: Evaluating Single-Cell Foundation Models for Perturbation Response Prediction 4 | 5 | PyTorch 6 | Lightning 7 | Config: Hydra 8 | Template
9 | 12 |
13 | 14 | PertEval is a comprehensive evaluation framework designed for perturbation response 15 | prediction. 16 | 17 | Key features: 18 | 19 | - **Extensive Model Support**: Evaluate a wide range of single-cell foundation models 20 | using simple probes for perturbation response prediction. 21 | - **Standardized Evaluations**: Consistent benchmarking protocols and metrics for fair 22 | comparisons in transcriptomic perturbation prediction. 23 | - **Flexible Integration**: Easily extend the codebase with custom models and datasets for 24 | perturbation prediction tasks. 25 | - **Modular Design**: Built on top of PyTorch Lightning and Hydra, ensuring code 26 | organization and configurability. 27 | 28 | PertEval-scFM is composed of three mains parts: data pre-processing, model training and 29 | evaluation 30 | 31 | ![PertEval-scFM Graphical Abstract](figures/PertEval-scFM.png) 32 | 33 | ## Installation 34 | 35 | 40 | 41 | To get PertEval up and running, first clone the GitHub repo: 42 | 43 | ```bash 44 | # clone project 45 | git clone https://github.com/aaronwtr/PertEval 46 | cd PertEval 47 | ``` 48 | 49 | Set up a new conda or virtual environment and install the required dependencies: 50 | 51 | ```bash 52 | # Conda 53 | conda create -n perteval python=3.10 54 | conda activate perteval 55 | 56 | # Virtualenv 57 | python3.10 -m venv perteval 58 | ### Windows: 59 | perteval\Scripts\activate 60 | ### MacOS/Linux 61 | source perteval/bin/activate 62 | ``` 63 | 64 | For a Windows install of torch with CUDA support, first run: 65 | 66 | ```bash 67 | # Windows 68 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 69 | ``` 70 | then run: 71 | 72 | ```bash 73 | #Windows/MacOS/Linux 74 | pip install -r requirements.txt 75 | ``` 76 | You'll need to create a Weights and Biases account if you don't have one already, and then login by pasting your API key when prompted 77 | ```bash 78 | wandb login 79 | ``` 80 | 81 | alternatively, you 
can log in via command line before starting the training process. 82 | 83 | ## Single-cell Foundation Models 84 | 85 | We currently host embeddings of the Norman _et al._, 2019 labeled Perturb-seq dataset 86 | for 1 gene (`Norman_1`) and 2 gene perturbations (`Norman_2`) 87 | extracted with the following single-cell foundation models (scFM): 88 | 89 | | **Model name** | **Architecture** | **Pre-training objective** | **# of cells** | **Organism** | **Emb. dim.** | 90 | |-------------------|----------------|--------------------------|------------------|--------------|--------------| 91 | | **scBERT** | Performer | Masked language modeling | ~5 million | human mouse | 200 | 92 | | **Geneformer** | Transformer | Masked language modeling | ~30 million | human | 256 | 93 | | **scGPT** | Transformer | Specialized attention-masking mechanism | ~33 million | human | 512 | 94 | | **UCE** | Transformer | Masked language modeling | ~36 million | 8 species | 1,280 | 95 | | **scFoundation** | Transformer | Read-depth-aware modeling | ~50 million | human | 3,072 | 96 | 97 | ### Embeddings 98 | 99 | The control expression data and scFM embeddings will be automatically 100 | downloaded, stored and preprocessed during the initial training run. The 101 | embeddings will be stored in the `/data/splits/perturb/norman_x/embeddings` directory. 102 | 103 | ## Training and Evaluation 104 | 105 | The main entry point for training and validation of a model is `train.py`, 106 | which will load your data, model, configs and run the training and validation process. 107 | `eval.py` will evaluate a trained model on the test set. You can run training and 108 | testing using the best checkpoints for the run by setting both `train` and `test` to 109 | `True` in `train.yaml`. 110 | 111 | To run a specific experiment, point to the corresponding configuration file from the 112 | [configs/experiment/](configs/experiment/) directory. 
For example: 113 | 114 | ```bash 115 | python src/train.py experiment=mlp_norman_train.yaml 116 | ``` 117 | 118 | In the config file you will be able to modify the configuration file to suit 119 | your needs, such as batch size, learning rate, or number of epochs. You can also 120 | override these parameters from the command line. For example: 121 | 122 | ```bash 123 | python src/train.py trainer.max_epochs=20 data.batch_size=64 124 | ``` 125 | 126 | For **Norman_1** the input size is **2060**, and for **Norman_2** the input size 127 | is **2064**. The embedding dimension of each scFM can be found in the table above. 128 | The hidden layer is embed. dim. / 2. 129 | 130 | ### Modeling distribution shift with SPECTRA 131 | 132 | We model distribution shift by creating increasingly challenging train-test splits with 133 | SPECTRA, a graph-based method which controls for cross-split overlap between 134 | train-test data. The splits are created during the initial training run and stored in 135 | `/data/splits/perturb/norman_x/norman_x_SPECTRA_splits` directory. The sparsification 136 | probability (_s_) controls the connectivity in the sample-to-sample similarity graph. To 137 | assess distribution shift, you will have to train and test the model on the different 138 | values of _s_ in the config file `split: 0.0` to `split: 0.5`. If you want to investigate the train-test splits, this can be done in 139 | [plots/visualize_spectra_norman_1.ipynb](notebooks/plots/visualize_spectra_norman_1.ipynb) 140 | 141 | ### Evaluation 142 | 143 | If you want to evaluate the test set with specific model weights, run the following 144 | command with the path to the checkpoint file 145 | 146 | ```bash 147 | python src/eval.py ckpt_path="/path/to/ckpt/name.ckpt" 148 | ``` 149 | 150 | ## Evaluating on differentially expressed gene (DEGs) 151 | 152 | **WIP**: we are working on integrating this workflow into the main pipeline. 
Meanwhile, 153 | you can follow the steps below to evaluate a perturbation on DEGs. 154 | 155 | Step 1) Calculate significant perturbations with 156 | E-test [notebooks/preprocessing/significant_perts_edist.ipynb](notebooks/preprocessing/significant_perts_edist.ipynb) 157 | 158 | Step 2) Calculate differentially expressed genes for all significant 159 | perturbations [notebooks/preprocessing/diff_exp_refactored.ipynb](notebooks/preprocessing/diff_exp_refactored.ipynb) 160 | 161 | Step 3) Prepare the inference config [configs/experiment/mlp_norman_inference.yaml](configs/experiment/mlp_norman_inference.yaml) with the following parameters: 162 | 163 | - Add the path to the .ckpt file 164 | - Add the model you want to use 165 | - Add the perturbation to be inspected 166 | - Set the proper split and replicate corresponding to the perturbation. 167 | - Update the corresponding hidden and embedding dimensions 168 | 169 | Step 4) Run `eval.py` with the inference config file. 170 | 171 | #### Using a Subset of Data 172 | 173 | In some cases, you might want to train or evaluate your model on a smaller subset of your 174 | data, either for debugging purposes or to speed up the training process. PyTorch Lightning 175 | provides options to limit the number of batches used for training, validation, and 176 | testing. For example, to use only 20% of your data for each of these stages, you can run 177 | the following command: 178 | 179 | ```bash 180 | python train.py +trainer.limit_train_batches=0.2 \ 181 | +trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2 182 | ``` 183 | 184 | This mode can be useful when you want to quickly test your code or debug issues with a 185 | smaller subset of your data, or when you want to perform a quick sanity check on your 186 | model's performance before running the full training or evaluation process. 
187 | 188 | 189 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint 3 | # - early_stopping 4 | - model_summary 5 | # - rich_progress_bar # only enable if running in terminal 6 | - _self_ 7 | 8 | #model_checkpoint: 9 | # dirpath: ${paths.output_dir}/checkpoints 10 | # filename: "best_model_at_epoch_{epoch:03d}" 11 | # monitor: "val/mse" 12 | # mode: "min" 13 | # save_last: True 14 | # auto_insert_metric_name: False 15 | 16 | #early_stopping: 17 | # monitor: "val/acc" 18 | # patience: 100 19 | # mode: "max" 20 | 21 | model_summary: 22 | max_depth: -1 23 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: val/rmse # quantity to be monitored, must be specified !!! 
6 | min_delta: 0.0001 # minimum change in the monitored quantity to qualify as an improvement 7 | patience: 20 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: False # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: ${paths.log_dir}train/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}/checkpoints # path to save the model file 6 | filename: ${data.fm}_${data.split}_${data.replicate} # checkpoint filename 7 | monitor: "val/mse" # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: True # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | 
save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/data/gears.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.reproduction.gears.gears_datamodule.GEARSDataModule 2 | data_dir: ${paths.data_dir} 3 | batch_size: 32 # Needs to be divisible by the number of devices (e.g., if in a 
distributed setup) 4 | num_workers: 0 5 | pin_memory: False 6 | -------------------------------------------------------------------------------- /configs/data/perturb.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.perturb_datamodule.PertDataModule 2 | data_dir: ${paths.data_dir} 3 | data_name: null # Has to be specified from experiment config 4 | data_type: null # Has to be specified from experiment config 5 | batch_size: 32 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) 6 | spectra_parameters: 7 | 'number_repeats': 3 8 | 'random_seed': [42, 44, 46] 9 | 'sparsification_step': 0.1 10 | 'force_reconstruct': True 11 | split: 0.00 12 | replicate: 0 13 | eval_type: null # Has to be specified from experiment config 14 | fm: 'raw_expression' 15 | num_workers: 0 16 | pin_memory: False 17 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | 
data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - 
_self_ 5 | - data: perturb 6 | - model: null # must be specified in experiment 7 | - callbacks: default 8 | - logger: null # set logger here or use command line (e.g. `python eval.py logger=tensorboard`) 9 | - trainer: cpu 10 | - paths: default 11 | - extras: default 12 | - hydra: default 13 | 14 | # experiment configs allow for version control of specific hyperparameters 15 | # e.g. best hyperparameters for given model and datamodule 16 | - experiment: null 17 | 18 | model_type: null # set in experiment config 19 | 20 | task_name: "test" 21 | 22 | tags: ["dev"] 23 | 24 | # passing checkpoint path is necessary for evaluation 25 | ckpt_path: "test" 26 | -------------------------------------------------------------------------------- /configs/experiment/mlp_norman_inference.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `python src/eval.py experiment=mlp_norman_inference` 4 | 5 | model_type: "mlp" # pick from mlp or lr 6 | 7 | defaults: 8 | - override /model: mlp 9 | - override /logger: wandb 10 | 11 | total_genes: 2060 12 | emb_dim: 2060 13 | hidden_dim: 128 14 | mean_adjusted: false 15 | save_dir: ${paths.data_dir}/${data.data_name}/pert_effects/${data.eval_pert}/pert_effect_pred_${data.fm}.pkl 16 | train_date: "2024-10-12" # date of the training run in the format YYYY-MM-DD 17 | timestamp: "18-10-41" # time of the training run in the format HH-MM-SS 18 | 19 | data: 20 | data_name: "norman_1" 21 | data_type: "geneformer" 22 | deg_eval: false 23 | eval_pert: null 24 | split: 0.0 25 | replicate: 0 26 | fm: "geneformer" 27 | 28 | trainer: 29 | num_sanity_val_steps: 0 30 | inference_mode: true 31 | accelerator: gpu 32 | 33 | ckpt_path: ${paths.log_dir}train/runs/${train_date}/${timestamp}/checkpoints/${callbacks.model_checkpoint.filename}.ckpt 34 | 35 | logger: 36 | wandb: 37 | tags: ["eval", "norman", "${data.eval_pert}", "${data.fm}", "split_${data.split}",
"replicate_${data.replicate}"] 38 | group: "${model_type}_${data.data_name}_${data.split}" 39 | project: "perturbench-local" 40 | -------------------------------------------------------------------------------- /configs/experiment/mlp_norman_train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `python src/train.py experiment=mlp_norman_train` 4 | 5 | model_type: "mlp" # pick from mean, mlp or lr 6 | 7 | defaults: 8 | - override /model: mlp 9 | - override /logger: wandb 10 | 11 | total_genes: 2060 # 2060 for norman_1, 2064 for norman_2 12 | emb_dim: 200 # 256/512 for geneformer Norman_1/Norman_2, 200 for scBert, 512 for scGPT, 1280 for UCE, 3072 for scFoundation 13 | hidden_dim: 100 # embed_dim / 2 14 | mean_adjusted: false 15 | save_dir: ${paths.data_dir}/${data.data_name}/pert_effects/${data.eval_pert}/pert_effect_pred_${data.fm}.pkl 16 | 17 | 18 | data: 19 | data_name: "norman_1" 20 | data_type: "scbert" 21 | split: 0.0 22 | deg_eval: false 23 | eval_pert: null 24 | replicate: 0 25 | batch_size: 64 26 | fm: "scbert" 27 | 28 | trainer: 29 | max_epochs: 100 30 | accelerator: gpu 31 | devices: 1 32 | 33 | callbacks: 34 | learning_rate_monitor: 35 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 36 | logging_interval: 'epoch' 37 | 38 | logger: 39 | wandb: 40 | tags: ["${model_type}", "${data.data_name}", "${data.fm}","split_${data.split}", "replicate_${data.replicate}", "hpo"] 41 | group: "${model_type}_${data.data_name}_${data.split}" 42 | project: "perturbench-local" 43 | 44 | model: 45 | optimizer: 46 | _target_: torch.optim.Adam 47 | _partial_: true 48 | lr: 5e-6 49 | weight_decay: 0 50 | 51 | net: 52 | _target_: src.models.components.predictors.MLP 53 | in_dim: ${eval:'${emb_dim}*2'} 54 | # in_dim: ${emb_dim} 55 | 56 | scheduler: 57 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 58 | _partial_: true 59 | mode: 'min' 60 | factor: 0.1 61 | patience: 10 62 | 
min_lr: 5e-9 63 | 64 | data_name: "${data.data_name}" 65 | -------------------------------------------------------------------------------- /configs/experiment/mlp_replogle_k562_inference.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `python src/eval.py experiment=mlp_replogle_k562_inference` 4 | 5 | model_type: "mlp" # pick from mlp or lr 6 | 7 | defaults: 8 | - override /model: mean 9 | - override /logger: wandb 10 | 11 | total_genes: 3601 12 | emb_dim: 3601 13 | hidden_dim: 1024 14 | mean_adjusted: false 15 | save_dir: ${paths.data_dir}/${data.data_name}/pert_effects/${data.eval_pert}/pert_effect_pred_${data.fm}.pkl 16 | train_date: "2024-12-31" # date of the training run in the format YYYY-MM-DD 17 | timestamp: "17-14-21" # time of the training run in the format HH-MM-SS 18 | 19 | data: 20 | data_name: "replogle_k562" 21 | data_type: "scfoundation" 22 | deg_eval: false 23 | eval_pert: null 24 | split: 0.5 25 | replicate: 2 26 | batch_size: 1 27 | fm: "scfoundation" 28 | 29 | trainer: 30 | num_sanity_val_steps: 0 31 | inference_mode: true 32 | accelerator: gpu 33 | devices: 1 34 | precision: 16-mixed 35 | 36 | # ckpt_path: ${paths.log_dir}train/runs/${train_date}/${timestamp}/checkpoints/${callbacks.model_checkpoint.filename}.ckpt 37 | # ckpt_path: ${paths.log_dir}train/runs/${train_date}/${timestamp}/checkpoints/last.ckpt 38 | ckpt_path: null 39 | 40 | logger: 41 | wandb: 42 | tags: ["eval", "${data.data_name}", "${data.fm}", "split_${data.split}", "replicate_${data.replicate}"] 43 | group: "${model_type}_${data.data_name}_${data.split}" 44 | project: "perturbench-local" 45 | -------------------------------------------------------------------------------- /configs/experiment/mlp_replogle_k562_train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `python src/train.py experiment=mlp_replogle_k562_train` 4 | 5 |
model_type: "mlp" # pick from mlp or lr 6 | 7 | defaults: 8 | - override /model: mlp 9 | - override /logger: wandb 10 | 11 | total_genes: 3601 # 3601 for replogle_k562, 2061 for norman_1, 2064 for norman_2 12 | emb_dim: 3601 # 256/512 for geneformer Norman_1/Norman_2, 200 for scBert, 512 for scGPT, 1280 for UCE, 3072 for scFoundation 13 | hidden_dim: 1024 # embed_dim / 2 14 | mean_adjusted: false 15 | save_dir: ${paths.data_dir}/${data.data_name}/pert_effects/${data.eval_pert}/pert_effect_pred_${data.fm}.pkl 16 | 17 | 18 | data: 19 | data_name: "replogle_k562" 20 | data_type: "raw_expression" 21 | split: 0.2 22 | deg_eval: false 23 | eval_pert: null 24 | replicate: 1 25 | batch_size: 1 26 | fm: "raw_expression" 27 | 28 | trainer: 29 | max_epochs: 100 30 | accelerator: gpu 31 | devices: 1 32 | precision: 32 33 | 34 | callbacks: 35 | learning_rate_monitor: 36 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 37 | logging_interval: 'epoch' 38 | 39 | logger: 40 | wandb: 41 | tags: ["${model_type}", "${data.data_name}", "${data.fm}","split_${data.split}", "replicate_${data.replicate}", "hpo"] 42 | group: "${model_type}_${data.data_name}_${data.split}" 43 | project: "perturbench-local" 44 | 45 | model: 46 | optimizer: 47 | _target_: torch.optim.Adam 48 | _partial_: true 49 | lr: 5e-6 50 | weight_decay: 0 51 | 52 | scheduler: 53 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 54 | _partial_: true 55 | mode: 'min' 56 | factor: 0.1 57 | patience: 10 58 | min_lr: 5e-9 59 | 60 | data_name: "${data.data_name}" -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: False 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | 
10 | 11 | distributed_storage: True -------------------------------------------------------------------------------- /configs/hparams_search/optuna_architecture.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/mse" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: perturbench_mlp_arch_hpo 28 | 29 | # number of parallel workers 30 | n_jobs: 2 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: minimize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 30 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 5 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.net.hidden_dim: choice(256, 512, 
1024, 2048) 49 | -------------------------------------------------------------------------------- /configs/hparams_search/optuna_lr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/mse" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: perturbench_mlp_lr_hpo 28 | 29 | # number of parallel workers 30 | n_jobs: 2 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: minimize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 30 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 5 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: choice(9e-6, 1e-5, 2e-5, 3e-5, 4e-3, 5e-5) 49 | 
model.scheduler.patience: choice(5, 10, 15, 20, 25) 50 | model.scheduler.factor: choice(0.1, 0.2, 0.3, 0.4, 0.5) 51 | 52 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}_${data.fm}_${data.split}_${data.replicate} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}_${data.fm}_${data.split}_${data.replicate} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: "${hydra.runtime.output_dir}/train_${model_type}_${data.split}_${data.replicate}.log" 20 | #filename: "${hydra.runtime.output_dir}_${task_name}_${model_type}_${data.split}_${data.replicate}/train.log" 21 | -------------------------------------------------------------------------------- /configs/hydra/spectra_data_sweep_1.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `--multirun hydra=spectra_data_sweep` flag 4 | hydra: 5 | sweeper: 6 | params: 7 | data.split: 0 8 | data.replicate: 0 9 | -------------------------------------------------------------------------------- /configs/hydra/spectra_data_sweep_2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `--multirun hydra=spectra_data_sweep` flag 4 | hydra: 5 | sweeper: 6 | params: 7 | data.split: 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7 8 | data.replicate: 1 
-------------------------------------------------------------------------------- /configs/hydra/spectra_data_sweep_3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run with `--multirun hydra=spectra_data_sweep` flag 4 | hydra: 5 | sweeper: 6 | params: 7 | data.split: 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7 8 | data.replicate: 2 9 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 
22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: 
lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}/checkpoints" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 
9 | anonymous: null # enable anonymous logging 10 | project: "perturbench-local" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: ${tags} 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/gears.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.gears_module.GEARSLitModule 2 | 3 | net: 4 | _target_: src.models.reproduction.gears.gears.GEARSNetwork 5 | hidden_size: 64 6 | num_go_gnn_layers: 1 7 | num_gene_gnn_layers: 1 8 | decoder_hidden_size: 16 9 | num_similar_genes_go_graph: 20 10 | num_similar_genes_co_express_graph: 20 11 | coexpress_threshold: 0.4 12 | uncertainty: false 13 | uncertainty_reg: 1 14 | direction_lambda: 0.1 15 | G_go: 16 | G_go_weight: 17 | G_coexpress: 18 | G_coexpress_weight: 19 | no_perturb: false 20 | pert_emb_lambda: 0.2 21 | num_genes: 5045 22 | num_perts: 9853 23 | 24 | pertmodule: 25 | _target_: src.data.perturb_datamodule.PertDataModule 26 | 27 | optimizer: 28 | _target_: torch.optim.Adam 29 | _partial_: true 30 | lr: 0.001 31 | weight_decay: 0.0005 32 | 33 | scheduler: 34 | _target_: torch.optim.lr_scheduler.StepLR 35 | _partial_: true 36 | step_size: 1 37 | gamma: 0.5 38 | 39 | model_name: gears 40 | 41 | # compile model for faster training with pytorch 2.0 42 | compile: false 43 | -------------------------------------------------------------------------------- /configs/model/mean.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.prediction_module.PredictionModule 2 | 3 | net: 4 | _target_: src.models.components.predictors.MeanExpression 5 | -------------------------------------------------------------------------------- /configs/model/mlp.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: src.models.prediction_module.PredictionModule 2 | 3 | net: 4 | _target_: src.models.components.predictors.MLP 5 | in_dim: ${eval:'${emb_dim}*2'} 6 | hidden_dim: ${hidden_dim} 7 | out_dim: ${total_genes} 8 | num_layers: 1 9 | 10 | optimizer: 11 | _target_: torch.optim.Adam 12 | _partial_: true 13 | lr: 1e-3 14 | weight_decay: 0.0 15 | 16 | criterion: 17 | _target_: torch.nn.MSELoss 18 | 19 | # set these in experiment config 20 | scheduler: null 21 | save_dir: ${save_dir} # path to save the predictions 22 | mean_adjusted: ${mean_adjusted} 23 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/splits/perturb 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: perturb 8 | - model: null # must be specified in experiment 9 | 
- callbacks: default 10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: gpu 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | model_type: null # set in experiment config 31 | 32 | # task name, determines output directory path 33 | task_name: "train" 34 | 35 | ##number of features in the input data, and embedding dimension of fm in consideration 36 | ##must be specified in the experiment config 37 | total_genes: null 38 | emb_dim: null 39 | 40 | # tags to help you identify your experiments 41 | # you can overwrite this in experiment configs 42 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 43 | tags: [] 44 | 45 | # set False to skip model training 46 | train: True 47 | 48 | # evaluate on test set, using best model weights achieved during training 49 | # lightning chooses best weights based on the metric specified in checkpoint callback 50 | test: True 51 | 52 | # simply provide checkpoint path to resume training 53 | ckpt_path: null 54 | 55 | # seed for random number generators in pytorch, numpy and python.random 56 | seed: null 57 | 58 | # number of times to repeat training runs with different seeds for uncertainty estimation 59 | repeats: 3 60 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: 4 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 20 7 | 8 | accelerator: gpu 9 | devices: 1 10 | 11 | num_sanity_val_steps: 0 12 | 13 | # mixed precision for extra speed-up 14 | # precision: 16 15 | 16 | # perform a validation loop every N training epochs 17 | check_val_every_n_epoch: 1 18 | 19 | # set True to ensure deterministic results 20 | # makes training slower but gives more reproducibility than just setting seeds 21 | deterministic: False 22 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/data/.gitkeep -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | # reasons you might want to use `environment.yaml` instead of `requirements.txt`: 2 | # - pip installs packages in a loop, without ensuring dependencies across all packages 3 | # are fulfilled simultaneously, but conda achieves proper dependency control across 4 | # all packages 5 | # - conda allows for installing packages without requiring certain compilers or 6 | # libraries to be available in the system, since it installs precompiled binaries 7 | 8 | name: myenv 9 | 10 | channels: 11 | - pytorch 12 | - conda-forge 13 | - defaults 14 | 15 | # it is strongly recommended to specify versions of packages installed through conda 16 | # to avoid situation when version-unspecified packages install their latest major 17 | # versions which can sometimes break things 18 | 19 | # current approach below keeps the dependencies in the same major versions across all 20 | # users, but allows for different minor and patch versions of packages where backwards 21 | # compatibility is usually guaranteed 22 | 23 | dependencies: 24 | - python=3.10 25 | - pytorch=2.* 26 | - torchvision=0.* 27 | - lightning=2.* 28 | - torchmetrics=0.* 29 | - hydra-core=1.* 30 | - rich=13.* 31 | - pre-commit=3.* 32 | - pytest=7.* 33 | 34 | # --------- loggers --------- # 35 | # - wandb 36 | # - neptune-client 37 | # - mlflow 38 | # - comet-ml 39 | # - aim>=3.16.2 # no lower than 3.16.2, see 
https://github.com/aimhubio/aim/issues/2550 40 | 41 | - pip>=23 42 | - pip: 43 | - hydra-optuna-sweeper 44 | - hydra-colorlog 45 | - rootutils 46 | -------------------------------------------------------------------------------- /figures/PertEval-scFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/figures/PertEval-scFM.png -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/logs/.gitkeep -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/notebooks/.gitkeep -------------------------------------------------------------------------------- /notebooks/plots/expression_analysis_top20_vs_tail_genes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "id": "initial_id", 6 | "metadata": { 7 | "collapsed": true, 8 | "ExecuteTime": { 9 | "end_time": "2024-09-10T10:44:56.222481Z", 10 | "start_time": "2024-09-10T10:44:56.220510Z" 11 | } 12 | }, 13 | "source": "# Author: A. 
Wenteler", 14 | "outputs": [], 15 | "execution_count": 1 16 | }, 17 | { 18 | "metadata": { 19 | "ExecuteTime": { 20 | "end_time": "2024-09-25T17:23:36.094106Z", 21 | "start_time": "2024-09-25T17:23:34.826531Z" 22 | } 23 | }, 24 | "cell_type": "code", 25 | "source": [ 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import seaborn as sns\n", 30 | "import anndata as ad \n", 31 | "import pickle as pkl \n", 32 | "import scanpy as sc \n", 33 | "from tqdm import tqdm\n", 34 | "\n", 35 | "from scipy.sparse import csr_matrix" 36 | ], 37 | "id": "f445facac01a68de", 38 | "outputs": [], 39 | "execution_count": 7 40 | }, 41 | { 42 | "metadata": { 43 | "ExecuteTime": { 44 | "end_time": "2024-09-25T17:24:00.671677Z", 45 | "start_time": "2024-09-25T17:23:36.095108Z" 46 | } 47 | }, 48 | "cell_type": "code", 49 | "source": [ 50 | "# Load data\n", 51 | "sc_data_raw = ad.read_h5ad('../../data/norman_2019_raw.h5ad')" 52 | ], 53 | "id": "fac0ea882876ba10", 54 | "outputs": [], 55 | "execution_count": 8 56 | }, 57 | { 58 | "metadata": { 59 | "ExecuteTime": { 60 | "end_time": "2024-09-25T17:24:26.021982Z", 61 | "start_time": "2024-09-25T17:24:26.015705Z" 62 | } 63 | }, 64 | "cell_type": "code", 65 | "source": [ 66 | "def preprocess_adata(adata, min_gene_counts=None, min_cell_counts=None, no_highly_var=2000):\n", 67 | " \"\"\"\n", 68 | " Input is an adata object has a condition column with either \"ctrl\" for negative controls or GENE_SYMBOL for perturbed cells\n", 69 | " \"\"\"\n", 70 | " \n", 71 | " adata = adata.copy()\n", 72 | "\n", 73 | " #filter genes \n", 74 | " if min_gene_counts is not None:\n", 75 | " sc.pp.filter_genes(adata, min_counts=min_gene_counts)\n", 76 | "\n", 77 | " #filter cells\n", 78 | " if min_cell_counts is not None:\n", 79 | " sc.pp.filter_cells(adata, min_counts=min_cell_counts)\n", 80 | "\n", 81 | " #filter only single gene perturbations and controls\n", 82 | " conditions_to_keep = list()\n", 83 | " for 
cond in list(adata.obs['guide_ids']):\n", 84 | " if \",\" not in cond:\n", 85 | " conditions_to_keep.append(cond)\n", 86 | " adata = adata[adata.obs['guide_ids'].isin(conditions_to_keep), :]\n", 87 | "\n", 88 | " #apply preprocessing transformation\n", 89 | " sc.pp.normalize_total(adata, inplace=True)\n", 90 | " sc.pp.log1p(adata)\n", 91 | " sc.pp.highly_variable_genes(adata, n_top_genes=no_highly_var)\n", 92 | " highly_variable_genes = adata.var_names[adata.var['highly_variable']]\n", 93 | " adata = adata[:, highly_variable_genes]\n", 94 | "\n", 95 | " return adata " 96 | ], 97 | "id": "e0e39f647ba0ce86", 98 | "outputs": [], 99 | "execution_count": 11 100 | }, 101 | { 102 | "metadata": { 103 | "ExecuteTime": { 104 | "end_time": "2024-09-25T17:24:39.308944Z", 105 | "start_time": "2024-09-25T17:24:26.819919Z" 106 | } 107 | }, 108 | "cell_type": "code", 109 | "source": "adata_pp = preprocess_adata(sc_data_raw, min_gene_counts=5, min_cell_counts=None, no_highly_var=2000)", 110 | "id": "bb15abf0a2b7cd93", 111 | "outputs": [ 112 | { 113 | "name": "stderr", 114 | "output_type": "stream", 115 | "text": [ 116 | "/Users/aaronw/Desktop/PhD/Research/QMUL/Research/scBench/venv/lib/python3.10/site-packages/scanpy/preprocessing/_normalization.py:206: UserWarning: Received a view of an AnnData. 
Making a copy.\n", 117 | " view_to_actual(adata)\n" 118 | ] 119 | } 120 | ], 121 | "execution_count": 12 122 | }, 123 | { 124 | "metadata": {}, 125 | "cell_type": "code", 126 | "source": [ 127 | "# Load differentially expressed genes \n", 128 | "diff_genes = pkl.load(open('../../data/splits/perturb/norman_1/de_test/deg_pert_dict.pkl', 'rb'))" 129 | ], 130 | "id": "fdeb5450b5525041", 131 | "outputs": [], 132 | "execution_count": null 133 | }, 134 | { 135 | "metadata": {}, 136 | "cell_type": "code", 137 | "source": [ 138 | "expr_matrix = adata_pp.X.todense()\n", 139 | "perts = adata_pp.obs['condition'].tolist()\n", 140 | "expr_matrix = pd.DataFrame(expr_matrix, columns=adata_pp.var.gene_symbols)\n", 141 | "expr_matrix['perturbations'] = perts\n", 142 | "adata_obs = adata_pp.obs" 143 | ], 144 | "id": "84f3f818a22c11c8", 145 | "outputs": [], 146 | "execution_count": null 147 | }, 148 | { 149 | "metadata": {}, 150 | "cell_type": "code", 151 | "source": [ 152 | "expr_matrix_ctrl = expr_matrix.loc[expr_matrix['perturbations'] == '']\n", 153 | "expr_matrix_pert = expr_matrix.loc[expr_matrix['perturbations'] != '']" 154 | ], 155 | "id": "fe584e242e124bc2", 156 | "outputs": [], 157 | "execution_count": null 158 | }, 159 | { 160 | "metadata": {}, 161 | "cell_type": "code", 162 | "source": "df = pd.DataFrame(columns=['perturbations', 'gene', 'expression'])", 163 | "id": "135798645b0dd08c", 164 | "outputs": [], 165 | "execution_count": null 166 | }, 167 | { 168 | "metadata": {}, 169 | "cell_type": "code", 170 | "source": [ 171 | "new_rows = []\n", 172 | "\n", 173 | "for pert, de_genes in diff_genes.items():\n", 174 | " for gene in de_genes: \n", 175 | " gene_symbol = adata_pp.var.gene_symbols[adata_pp.var.index == gene].iloc[0]\n", 176 | " expression_values = expr_matrix_pert[expr_matrix_pert['perturbations'] == pert][gene_symbol].tolist()\n", 177 | "\n", 178 | " new_rows.append([pert, gene_symbol, expression_values])\n", 179 | "\n", 180 | "top20_pert_df = 
pd.DataFrame(new_rows, columns=['perturbations', 'gene', 'expression'])" 181 | ], 182 | "id": "f7470af70a0858d1", 183 | "outputs": [], 184 | "execution_count": null 185 | }, 186 | { 187 | "metadata": {}, 188 | "cell_type": "code", 189 | "source": [ 190 | "# Precompute the gene index to symbol mapping outside the loop\n", 191 | "gene_symbol_mapping = adata_pp.var['gene_symbols'].to_dict()\n", 192 | "\n", 193 | "# Loop over the perturbations and differentially expressed genes\n", 194 | "for pert, de_genes in tqdm(diff_genes.items()):\n", 195 | " all_genes = set(adata_pp.var.index) # Set of all genes\n", 196 | " remaining_genes = all_genes - set(de_genes) # Compute the remaining genes\n", 197 | "\n", 198 | " # Get the expression matrix for the current perturbation\n", 199 | " expression_values_pert = expr_matrix_pert[expr_matrix_pert['perturbations'] == pert]\n", 200 | " temp_rows = []\n", 201 | "\n", 202 | " for gene in remaining_genes:\n", 203 | " gene_symbol = gene_symbol_mapping[gene] \n", 204 | "\n", 205 | " expression_values = expression_values_pert[gene_symbol].tolist()\n", 206 | " temp_rows.append([pert, gene_symbol, expression_values])\n", 207 | "\n", 208 | " new_rows.extend(temp_rows)\n", 209 | " \n", 210 | "nontop20_pert_df = pd.DataFrame(new_rows, columns=['perturbations', 'gene', 'expression'])" 211 | ], 212 | "id": "ddbf9618a1e46b50", 213 | "outputs": [], 214 | "execution_count": null 215 | }, 216 | { 217 | "metadata": {}, 218 | "cell_type": "code", 219 | "source": [ 220 | "ikzf3_top20 = top20_pert_df[top20_pert_df['perturbations'] == 'IKZF3']\n", 221 | "ikzf3_nontop20 = nontop20_pert_df[nontop20_pert_df['perturbations'] == 'IKZF3']\n", 222 | "\n", 223 | "glb1l2_top20 = top20_pert_df[top20_pert_df['perturbations'] == 'GLB1L2']\n", 224 | "glb1l2_nontop20 = nontop20_pert_df[nontop20_pert_df['perturbations'] == 'GLB1L2']\n", 225 | "\n", 226 | "set_top20 = top20_pert_df[top20_pert_df['perturbations'] == 'SET']\n", 227 | "set_nontop20 = 
nontop20_pert_df[nontop20_pert_df['perturbations'] == 'SET']" 228 | ], 229 | "id": "9218d28bc3334d0f", 230 | "outputs": [], 231 | "execution_count": null 232 | }, 233 | { 234 | "metadata": {}, 235 | "cell_type": "code", 236 | "source": [ 237 | "ikzf3_top20_expr = np.vstack(ikzf3_top20['expression'].values).mean(axis=0)\n", 238 | "ikzf3_nontop20_expr = np.vstack(ikzf3_nontop20['expression'].values).mean(axis=0)\n", 239 | "\n", 240 | "glb1l2_top20_expr = np.vstack(glb1l2_top20['expression'].values).mean(axis=0)\n", 241 | "glb1l2_nontop20_expr = np.vstack(glb1l2_nontop20['expression'].values).mean(axis=0)\n", 242 | "\n", 243 | "set_top20_expr = np.vstack(set_top20['expression'].values).mean(axis=0)\n", 244 | "set_nontop20_expr = np.vstack(set_nontop20['expression'].values).mean(axis=0)" 245 | ], 246 | "id": "e640faa60d9f63e1", 247 | "outputs": [], 248 | "execution_count": null 249 | }, 250 | { 251 | "metadata": {}, 252 | "cell_type": "code", 253 | "source": [ 254 | "pert_comp = {\n", 255 | " \"Gene\": [\"IKZF3\", \"IKZF3\", \"GLB1L2\", \"GLB1L2\", \"SET\", \"SET\"],\n", 256 | " \"Group\": [\"Top 20 DEGs\", \"Tail genes\", \"Top 20 DEGs\", \"Tail genes\", \"Top 20 DEGs\", \"Tail genes\"],\n", 257 | " \"Expression\": [ikzf3_top20_expr, ikzf3_nontop20_expr, glb1l2_top20_expr, glb1l2_nontop20_expr, set_top20_expr, set_nontop20_expr]\n", 258 | "}" 259 | ], 260 | "id": "238253a7511d4131", 261 | "outputs": [], 262 | "execution_count": null 263 | }, 264 | { 265 | "metadata": {}, 266 | "cell_type": "code", 267 | "source": [ 268 | "pert_comp_df = pd.DataFrame(pert_comp)\n", 269 | "pert_comp_df" 270 | ], 271 | "id": "c8b63d628cd85980", 272 | "outputs": [], 273 | "execution_count": null 274 | }, 275 | { 276 | "metadata": {}, 277 | "cell_type": "code", 278 | "source": [ 279 | "expression_data = pert_comp_df.explode('Expression')\n", 280 | "expression_data_tail = expression_data[expression_data['Group'] == 'Tail genes']\n", 281 | "expression_data_top20 = 
expression_data[expression_data['Group'] == 'Top 20 DEGs']" 282 | ], 283 | "id": "31e2c39eaa6ae89e", 284 | "outputs": [], 285 | "execution_count": null 286 | }, 287 | { 288 | "metadata": {}, 289 | "cell_type": "code", 290 | "source": "avg_expression_data_tail = expression_data_tail.groupby(\"Gene\")[\"Expression\"].mean()", 291 | "id": "6d59b2d4e5024add", 292 | "outputs": [], 293 | "execution_count": null 294 | }, 295 | { 296 | "metadata": {}, 297 | "cell_type": "code", 298 | "source": [ 299 | "# set dpi = 300 \n", 300 | "plt.figure(dpi=300)\n", 301 | "sns.violinplot(x='Gene', y='Expression', hue='Group', data=expression_data_top20)\n", 302 | "plt.axhline(y=0.1606, color='C1', linestyle='--', label='Mean expression of tail genes')\n", 303 | "plt.xlabel('Perturbation')\n", 304 | "\n", 305 | "handles, labels = plt.gca().get_legend_handles_labels()\n", 306 | "unique_labels = dict(zip(labels, handles))\n", 307 | "plt.legend(unique_labels.values(), unique_labels.keys(), loc='upper right')\n", 308 | "plt.savefig('paper_figs/top20_vs_tail_genes.pdf')" 309 | ], 310 | "id": "17d8b32053d226fb", 311 | "outputs": [], 312 | "execution_count": null 313 | }, 314 | { 315 | "metadata": {}, 316 | "cell_type": "code", 317 | "source": [ 318 | "top20_pert_expl = top20_pert_df.explode('expression')\n", 319 | "nontop20_pert_expl = nontop20_pert_df.explode('expression')\n", 320 | "top20_pert_expl" 321 | ], 322 | "id": "cdfacc8faf865db7", 323 | "outputs": [], 324 | "execution_count": null 325 | }, 326 | { 327 | "metadata": {}, 328 | "cell_type": "code", 329 | "source": [ 330 | "# calculate the average and standard deviation of the expression across all perturbations and genes for both top20 and non top 20\n", 331 | "top20_avg = top20_pert_expl['expression'].mean()\n", 332 | "top20_std = top20_pert_expl['expression'].std()\n", 333 | "nontop20_avg = nontop20_pert_expl['expression'].mean()\n", 334 | "nontop20_std = nontop20_pert_expl['expression'].std()\n", 335 | "# find out the minimum of 
nontop20\n", 336 | "nontop20_min = nontop20_pert_expl['expression'].min()\n", 337 | "print(f\"Top 20 average expression: {top20_avg}, Top 20 standard deviation: {top20_std}\")\n", 338 | "print(f\"Non top 20 average expression: {nontop20_avg}, Non top 20 standard deviation: {nontop20_std}\")\n", 339 | "print(f\"Non top 20 minimum expression: {nontop20_min}\")" 340 | ], 341 | "id": "de8e1f38de2566e4", 342 | "outputs": [], 343 | "execution_count": null 344 | }, 345 | { 346 | "metadata": {}, 347 | "cell_type": "code", 348 | "outputs": [], 349 | "execution_count": null, 350 | "source": "", 351 | "id": "948f0b814ab84393" 352 | } 353 | ], 354 | "metadata": { 355 | "kernelspec": { 356 | "display_name": "Python 3", 357 | "language": "python", 358 | "name": "python3" 359 | }, 360 | "language_info": { 361 | "codemirror_mode": { 362 | "name": "ipython", 363 | "version": 2 364 | }, 365 | "file_extension": ".py", 366 | "mimetype": "text/x-python", 367 | "name": "python", 368 | "nbconvert_exporter": "python", 369 | "pygments_lexer": "ipython2", 370 | "version": "2.7.6" 371 | } 372 | }, 373 | "nbformat": 4, 374 | "nbformat_minor": 5 375 | } 376 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | # temporary fix for cross-platform torch install with cuda compatibility 3 | torch>=2.0.0; sys_platform != "win32" # "darwin" and "linux" 4 | torchvision>=0.15.0; sys_platform != "win32" 5 | lightning>=2.0.0 6 | torchmetrics>=0.11.4 7 | 8 | # --------- hydra --------- # 9 | hydra-core==1.3.2 10 | hydra-colorlog==1.2.0 11 | hydra-optuna-sweeper==1.2.0 12 | 13 | # --------- loggers --------- # 14 | # wandb 15 | # neptune-client 16 | # mlflow 17 | # comet-ml 18 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 19 | 20 | # --------- others --------- # 21 | rootutils # standardizing the project root setup 22 | pre-commit # hooks for applying linters on commit 23 | rich # beautiful text formatting in terminal 24 | pytest # tests 25 | anndata==0.10.6 26 | networkx==3.2.1 27 | numpy==1.26.4 28 | numba==0.59.1 29 | omegaconf==2.3.0 30 | packaging==24.0 31 | pandas==2.2.2 32 | pertpy==0.7.0 33 | pytest==8.1.1 34 | rich==13.7.1 35 | rootutils==1.0.7 36 | scanpy==1.10.0 37 | scipy==1.13.1 38 | setuptools==68.2.0 39 | tqdm==4.66.2 40 | gdown==5.2.0 41 | wandb===0.18.3 -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="src", 7 | version="0.0.1", 8 | description="Describe Your Cool 
def _ids(ctrl: str, pert: str) -> dict:
    """Return a {'ctrl': ..., 'pert': ...} pair of Google Drive file IDs."""
    return {"ctrl": ctrl, "pert": pert}


# Google Drive file IDs for precomputed foundation-model embeddings, keyed by
# model name, then dataset/split name, then 'ctrl' (control cells) vs. 'pert'
# (perturbed cells).  An empty string means no embedding has been uploaded yet.
embedding_links = {
    "geneformer": {
        "norman_1": _ids("1yyQRcZEhdcsLOeMQKjZp7eSq6uHIYpwc", "1nK-YeenYax84vV1LVLghg2aXF92CvvuA"),
        "norman_2": _ids("1DTuCqigOr4lIg8J2qBFnUIHXz3oMQQC1", "1VWiv7VOaqra63HFocNWkWd-u-MxvGdfY"),
        "replogle_k562": _ids("16NfFvlZuSXIjyJwmfw22FvITW0XHXKLv", "1kkSUjXZ6wxmPTC-ZUBfGmLwempDjylQO"),
    },
    "scgpt": {
        "norman_1": _ids("1ECHWwA5idPQspwfS74PMunxtN9Uj0CzR", "1T1Vd779feygiDhW1zmWI67uN8dDQJcnq"),
        "norman_2": _ids("1Oy9u-YxyoQGjYLEclKyrSOf8G4LkrcKk", "1v2wH3pr9TcSrTceRfFG_KBDsLg7RfTiR"),
        "replogle_k562": _ids("1UnlBL2os1HuKCWa67eOOhkWfv6c66z9i", "1qZJi3A3XVmCsrNcZJqCXUhF3DGF0gW_z"),
    },
    "scfoundation": {
        "norman_1": _ids("1JVLfShRXjwUgovX78qWDMiL1ZXhVgepQ", "1CeYuSuUP408h33o11L1e82chV9WVwZXY"),
        "norman_2": _ids("1VHEV-lgPb2xe362yM3h6_NLhSBkn1jtp", "1KaWNIJe--NPj5u7k00D0CCrTRJuunu98"),
        "replogle_k562": _ids("1G-DnnfskJMVcXSxYsyaWLvHhA_2QfWIG", "1rjMIKErTYhNYVW_CzhhYFGMCKUwN5dOb"),
    },
    "scbert": {
        "norman_1": _ids("15p0kvoImPNfl31qYmTtGuyhB4GsaSbeu", "1S1lMR6UhM5QUik0imv8A7G5f3i4xFj7R"),
        "norman_2": _ids("10JD689TmbvRAGzsQ9vwwHyvMsb4kv30f", "1wl5GXnXbCU7ACtO4YF3Ii0kFVsIJhyOr"),
        "replogle_k562": _ids("", ""),
    },
    "uce": {
        "norman_1": _ids("1CbAVdnmzaKF1p-VKKMynbR0PkBnXWCK8", "1fFQB8mgjB63v3OkwyCn599yMKYy9XPDb"),
        "norman_2": _ids("1-WYImnrVxuu9RtHOdm1WDnfQD61QMX2U", "1OD7_wkh9LB7fHLdIPBDVD2saQY0hZW4O"),
        "replogle_k562": _ids("", ""),
    },
}
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(SCRIPT_DIR)
ROOT_DIR = os.path.dirname(SRC_DIR)

# The data directory path is cached on disk by an earlier setup step; read it once
# at import time.
with open(f'{ROOT_DIR}/cache/data_dir_cache.txt', 'r') as f:
    DATA_DIR = f.read().strip()


class PertDataModule(LightningDataModule):
    """`LightningDataModule` for perturbation data. Based on GEARS PertData class,
    but adapted for PyTorch Lightning.

    A `LightningDataModule` implements 7 key methods: ``prepare_data`` (one-process
    download/preprocess), ``setup`` (per-process data loading), the four
    ``*_dataloader`` factories, and ``teardown`` (per-process cleanup).

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
    https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = '',
        data_name: str = "norman",
        split: float = 0.00,
        replicate: int = 0,
        batch_size: int = 64,
        spectra_parameters: Optional[Dict[str, Any]] = None,
        deg_eval: Optional[str] = None,
        eval_pert: Optional[str] = None,
        num_workers: int = 0,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize a `PertDataModule`.

        :param data_dir: The data directory. Defaults to `""`.
        :param data_name: The name of the dataset. Defaults to `"norman"`. Valid
            values (enforced by :meth:`prepare_data`) are "norman_1", "norman_2",
            "replogle_k562" and "replogle_rpe1".
        :param split: SPECTRA sparsification parameter; numeric values are rendered
            with two decimals to match the on-disk split naming scheme.
        :param replicate: Replicate index appended to the split identifier.
        :param batch_size: The batch size. Defaults to `64`.
        :param spectra_parameters: Optional SPECTRA configuration forwarded to
            `PerturbData`.
        :param deg_eval: When truthy, build the test set restricted to
            differentially expressed genes.
        :param eval_pert: Perturbation to evaluate when `deg_eval` is set.
        :param num_workers: The number of workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        :param kwargs: Extra options; only ``fm`` (foundation-model name) is read.
        :raises ValueError: If `split` is not a float, int or string.
        """
        super().__init__()
        self.deg_dict = None
        self.num_genes = None
        self.num_perts = None
        self.pert_data = None
        self.pertmodule = None
        self.adata = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.spectra_parameters = spectra_parameters
        self.data_name = data_name
        self.deg_eval = deg_eval
        self.eval_pert = eval_pert

        self.fm = kwargs.get("fm", None)

        # Numeric splits (float or int) are normalised to a fixed two-decimal
        # representation; string splits are used verbatim.
        if isinstance(split, (float, int)):
            self.spectral_parameter = f"{split:.2f}_{str(replicate)}"
        elif isinstance(split, str):
            self.spectral_parameter = f"{split}_{str(replicate)}"
        else:
            raise ValueError("Split must be a float, int or a string!")

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        self.data_path = os.path.join(data_dir, self.data_name)

        os.makedirs(self.data_path, exist_ok=True)

        self.data_train: Optional[DataLoader] = None
        self.data_val: Optional[DataLoader] = None
        self.data_test: Optional[DataLoader] = None

        # Maps dataset family name to the corresponding pertpy loader attribute.
        self.load_scpert_data = {
            "norman": "norman_2019_raw",
            "replogle_k562": "replogle_2022_k562_essential",
            "replogle_rpe1": "replogle_2022_rpe1",
        }

        self.batch_size_per_device = batch_size

        # need to call prepare and setup manually to guarantee proper model setup
        self.prepare_data()
        self.setup()

    def prepare_data(self) -> None:
        """Download the raw dataset if it is not already cached on disk.

        Lightning ensures that `self.prepare_data()` is called only within a single
        process on CPU, so you can safely add downloading/preprocessing logic here.
        In case of multi-node training, the execution of this hook depends upon
        `self.prepare_data_per_node()`.

        Currently supports the "norman_1", "norman_2", "replogle_k562" and
        "replogle_rpe1" datasets.

        Do not use it to assign state (self.x = y).

        :raises ValueError: If `self.data_name` is not a supported dataset name.
        """
        if self.data_name in ["norman_1", "norman_2", "replogle_k562", "replogle_rpe1"]:
            # norman_1 and norman_2 are two splits of the same raw dataset.
            if "norman" in self.data_name:
                data_name = "norman"
            else:
                data_name = self.data_name
            # NOTE(review): this inspects the relative "data/" directory, so it assumes
            # the process is launched from the project root -- TODO confirm.
            if f"{self.load_scpert_data[data_name]}.h5ad" not in os.listdir("data/"):
                scpert_loader = getattr(scpert_data, self.load_scpert_data[data_name])
                scpert_loader()
        else:
            raise ValueError(f"Data name {self.data_name} not recognized. Choose from: 'norman_1', 'norman_2', "
                             f"'replogle_k562', or 'replogle_rpe1'")

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.train_dataset`, `self.val_dataset`, `self.test_dataset`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`, so be careful not to execute things like
        random split twice! Also, it is called after `self.prepare_data()` and there is a
        barrier in between which ensures that all the processes proceed to `self.setup()`
        once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or
            `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the batch size is not divisible by the device count.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices "
                    f"({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # load and split datasets only if not loaded already
        # NOTE(review): self.data_train/val/test are never assigned anywhere, so this
        # guard is always true and the datasets are rebuilt on every call -- confirm
        # whether the guard was meant to check self.train_dataset et al.
        if not self.data_train and not self.data_val and not self.data_test:
            if 'norman' in self.data_name:
                data_name = "norman"
            else:
                data_name = self.data_name
            scpert_loader = getattr(scpert_data, self.load_scpert_data[data_name])
            adata = scpert_loader()

            self.train_dataset = PerturbData(adata, self.data_path, self.spectral_parameter,
                                             self.spectra_parameters, self.fm, stage="train")

            self.val_dataset = PerturbData(adata, self.data_path, self.spectral_parameter,
                                           self.spectra_parameters, self.fm, stage="val")

            if not self.deg_eval:
                self.test_dataset = PerturbData(adata, self.data_path, self.spectral_parameter,
                                                self.spectra_parameters, self.fm, stage="test")
            else:
                # Use a context manager so the pickle file handle is not leaked.
                with open(f"{self.data_path}/de_test/deg_pert_dict.pkl", "rb") as handle:
                    deg_dict = pkl.load(handle)
                self.test_dataset = PerturbData(adata, self.data_path, self.spectral_parameter,
                                                self.spectra_parameters, self.fm,
                                                perturbation=self.eval_pert, deg_dict=deg_dict,
                                                stage="test")

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        :return: The train dataloader (shuffled).
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader (not shuffled).
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        :return: The test dataloader (not shuffled).
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def predict_dataloader(self) -> DataLoader[Any]:
        """Create and return the predict dataloader.

        :return: The predict dataloader (backed by the test dataset, not shuffled).
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`,
            or `"predict"`. Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given
        datamodule `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    _ = PertDataModule()
29 | 30 | def setup(self, stage): 31 | # Things to do on every process in DDP. 32 | # Data loading, set variables, etc... 33 | 34 | def train_dataloader(self): 35 | # return train dataloader 36 | 37 | def val_dataloader(self): 38 | # return validation dataloader 39 | 40 | def test_dataloader(self): 41 | # return test dataloader 42 | 43 | def predict_dataloader(self): 44 | # return predict dataloader 45 | 46 | def teardown(self, stage): 47 | # Called on every process in DDP. 48 | # Clean up after fit or test. 49 | ``` 50 | 51 | This allows you to share a full dataset without explaining how to download, 52 | split, transform and process the data. 53 | 54 | Read the docs: 55 | https://lightning.ai/docs/pytorch/latest/data/datamodule.html 56 | """ 57 | 58 | def __init__( 59 | self, 60 | data_dir: str = DATA_DIR, 61 | data_name: str = "norman", 62 | batch_size: int = 64, 63 | num_workers: int = 0, 64 | pin_memory: bool = False, 65 | ) -> None: 66 | """Initialize a `PertDataModule`. 67 | 68 | :param data_dir: The data directory. Defaults to `""`. 69 | :param data_name: The name of the dataset. Defaults to `"norman"`. Can pick from "norman", "adamson", "dixit", 70 | "replogle_k562_essential" and "replogle_rpe1_essential". 71 | :param train_val_test_split: The train, validation and test split. Defaults to `(0.8, 0.05, 0.15)`. 72 | :param batch_size: The batch size. Defaults to `64`. 73 | :param num_workers: The number of workers. Defaults to `0`. 74 | :param pin_memory: Whether to pin memory. Defaults to `False`. 
75 | """ 76 | super().__init__() 77 | 78 | self.num_genes = None 79 | self.num_perts = None 80 | self.pert_data = None 81 | self.data_name = data_name 82 | 83 | # this line allows to access init params with 'self.hparams' attribute 84 | # also ensures init params will be stored in ckpt 85 | self.save_hyperparameters(logger=False) 86 | 87 | self.data_path = f"{data_dir}/{self.data_name}" 88 | # if not os.path.exists(self.data_path): 89 | # os.makedirs(self.data_path) 90 | 91 | self.data_train: Optional[DataLoader] = None 92 | self.data_val: Optional[DataLoader] = None 93 | self.data_test: Optional[DataLoader] = None 94 | 95 | self.batch_size_per_device = batch_size 96 | 97 | # need to call prepare and setup manually to guarantee proper model setup 98 | self.prepare_data() 99 | self.setup() 100 | 101 | def prepare_data(self) -> None: 102 | """Put all downloading and preprocessing logic that only needs to happen on one device here. Lightning ensures 103 | that `self.prepare_data()` is called only within a single process on CPU, so you can safely add your logic 104 | within. In case of multi-node training, the execution of this hook depends upon `self.prepare_data_per_node()`. 105 | 106 | Downloading: 107 | Currently, supports "adamson", "norman", "dixit", "replogle_k562_essential" and "replogle_rpe1_essential" 108 | datasets. 109 | 110 | Do not use it to assign state (self.x = y). 
111 | """ 112 | # TODO: Add support for downloading from a specified url 113 | print(f"Downloading {self.data_name} data...") 114 | if os.path.exists(self.data_path): 115 | print(f"Found local copy of {self.data_name} data...") 116 | elif self.data_name in ['norman', 'adamson', 'dixit', 'replogle_k562_essential', 'replogle_rpe1_essential']: 117 | ## load from harvard dataverse 118 | if self.data_name == 'norman': 119 | url = 'https://dataverse.harvard.edu/api/access/datafile/6154020' 120 | elif self.data_name == 'adamson': 121 | url = 'https://dataverse.harvard.edu/api/access/datafile/6154417' 122 | elif self.data_name == 'dixit': 123 | url = 'https://dataverse.harvard.edu/api/access/datafile/6154416' 124 | elif self.data_name == 'replogle_k562_essential': 125 | ## Note: This is not the complete dataset and has been filtered 126 | url = 'https://dataverse.harvard.edu/api/access/datafile/7458695' 127 | elif self.data_name == 'replogle_rpe1_essential': 128 | ## Note: This is not the complete dataset and has been filtered 129 | url = 'https://dataverse.harvard.edu/api/access/datafile/7458694' 130 | zip_data_download_wrapper(url, self.data_path) 131 | print(f"Successfully downloaded {self.data_name} data and saved to {self.data_path}") 132 | else: 133 | raise ValueError("data_name should be either 'norman', 'adamson', 'dixit', 'replogle_k562_essential' or " 134 | "'replogle_rpe1_essential'") 135 | PertData(self.data_path) 136 | 137 | def setup(self, stage: Optional[str] = None) -> None: 138 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 139 | 140 | This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and 141 | `trainer.predict()`, so be careful not to execute things like random split twice! 
Also, it is called after 142 | `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to 143 | `self.setup()` once the data is prepared and available for use. 144 | 145 | :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. 146 | """ 147 | # Divide batch size by the number of devices. 148 | if self.trainer is not None: 149 | if self.hparams.batch_size % self.trainer.world_size != 0: 150 | raise RuntimeError( 151 | f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices " 152 | f"({self.trainer.world_size})." 153 | ) 154 | self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size 155 | 156 | # load and split datasets only if not loaded already 157 | if not self.data_train and not self.data_val and not self.data_test: 158 | pert_data = PertData(self.data_path) 159 | pert_data.load(data_path=self.data_path) 160 | pert_data.prepare_split(split='simulation', seed=1) 161 | pert_data.get_dataloader(batch_size=self.batch_size_per_device, test_batch_size=128) 162 | self.pert_data = pert_data 163 | self.gene_list = pert_data.gene_names.values.tolist() 164 | self.pert_list = pert_data.pert_names.tolist() 165 | # calculating num_genes and num_perts for GEARS 166 | self.num_genes = len(self.gene_list) 167 | self.num_perts = len(self.pert_list) 168 | # adding num_genes and num_perts to hydra configs 169 | yaml = YAML() 170 | yaml.preserve_quotes = True 171 | yaml.width = 4096 172 | with open(f'{ROOT_DIR}/configs/model/gears.yaml', 'r') as f: 173 | yaml_data = yaml.load(f) 174 | yaml_data['net']['num_genes'] = self.num_genes 175 | yaml_data['net']['num_perts'] = self.num_perts 176 | with open(f'{ROOT_DIR}/configs/model/gears.yaml', 'w') as f: 177 | yaml.dump(yaml_data, f) 178 | dataloaders = pert_data.dataloader 179 | self.data_train = dataloaders['train_loader'] 180 | self.data_val = dataloaders['val_loader'] 181 | 
self.data_test = dataloaders['test_loader'] 182 | 183 | def train_dataloader(self) -> DataLoader[Any]: 184 | """Create and return the train dataloader. 185 | 186 | :return: The train dataloader. 187 | """ 188 | return self.data_train 189 | 190 | def val_dataloader(self) -> DataLoader[Any]: 191 | """Create and return the validation dataloader. 192 | 193 | :return: The validation dataloader. 194 | """ 195 | return self.data_val 196 | 197 | def test_dataloader(self) -> DataLoader[Any]: 198 | """Create and return the test dataloader. 199 | 200 | :return: The test dataloader. 201 | """ 202 | return self.data_test 203 | 204 | def teardown(self, stage: Optional[str] = None) -> None: 205 | """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, 206 | `trainer.test()`, and `trainer.predict()`. 207 | 208 | :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 209 | Defaults to ``None``. 210 | """ 211 | pass 212 | 213 | def state_dict(self) -> Dict[Any, Any]: 214 | """Called when saving a checkpoint. Implement to generate and save the datamodule state. 215 | 216 | :return: A dictionary containing the datamodule state that you want to save. 217 | """ 218 | return {} 219 | 220 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 221 | """Called when loading a checkpoint. Implement to reload datamodule state given datamodule 222 | `state_dict()`. 223 | 224 | :param state_dict: The datamodule state returned by `self.state_dict()`. 
225 | """ 226 | pass 227 | 228 | def get_pert_data(self): 229 | return self.pert_data 230 | 231 | 232 | if __name__ == "__main__": 233 | _ = GEARSDataModule() 234 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import hydra 4 | import rootutils 5 | import torch 6 | from lightning import LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import OmegaConf, DictConfig 9 | 10 | OmegaConf.register_new_resolver("eval", eval) 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." 
in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | from src.utils import ( 31 | RankedLogger, 32 | extras, 33 | instantiate_loggers, 34 | log_hyperparameters, 35 | task_wrapper, 36 | ) 37 | 38 | log = RankedLogger(__name__, rank_zero_only=True) 39 | 40 | 41 | @task_wrapper 42 | def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 43 | """Evaluates given checkpoint on a datamodule testset. 44 | 45 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 46 | failure. Useful for multiruns, saving info about the crash, etc. 47 | 48 | :param cfg: DictConfig configuration composed by Hydra. 49 | :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. 50 | """ 51 | # assert cfg.ckpt_path 52 | 53 | if torch.cuda.is_available(): 54 | torch.set_float32_matmul_precision('medium') 55 | 56 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 57 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 58 | 59 | log.info(f"Instantiating model <{cfg.model._target_}>") 60 | model: LightningModule = hydra.utils.instantiate(cfg.model) 61 | 62 | log.info("Instantiating loggers...") 63 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 64 | 65 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 66 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) 67 | 68 | object_dict = { 69 | "cfg": cfg, 70 | "datamodule": datamodule, 71 | "model": model, 72 | "logger": logger, 73 | "trainer": trainer, 74 | } 75 | 76 | if logger: 77 | log.info("Logging hyperparameters!") 78 | log_hyperparameters(object_dict) 79 | 80 | log.info("Starting testing!") 81 | trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) 82 | 83 | # for predictions use trainer.predict(...) 
84 | # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) 85 | 86 | metric_dict = trainer.callback_metrics 87 | 88 | return metric_dict, object_dict 89 | 90 | 91 | @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") 92 | def main(cfg: DictConfig) -> None: 93 | """Main entry point for evaluation. 94 | 95 | :param cfg: DictConfig configuration composed by Hydra. 96 | """ 97 | # apply extra utilities 98 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 99 | extras(cfg) 100 | 101 | evaluate(cfg) 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/src/models/components/__init__.py -------------------------------------------------------------------------------- /src/models/components/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class WeightedRMSELoss(nn.Module): 6 | def __init__(self): 7 | super(WeightedRMSELoss, self).__init__() 8 | 9 | def forward(self, y_pred, y_true): 10 | abs_pert_effect = torch.abs(y_true) 11 | 12 | weights = nn.functional.softmax(abs_pert_effect, dim=-1) 13 | 14 | return torch.sqrt(torch.mean(weights * (y_true - y_pred) ** 2)) 15 | -------------------------------------------------------------------------------- /src/models/components/predictors.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple 2 | 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class LinearRegressionModel(torch.nn.Module): 10 | def __init__(self, 11 | in_dim: int): 12 | super().__init__() 13 | 14 | self.linear = torch.nn.Linear(in_dim, 1) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | return self.linear(x) 18 | 19 | 20 | class MLP(torch.nn.Module): 21 | def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, 22 | layer_activation: nn.Module = nn.ReLU()): 23 | super().__init__() 24 | self.layer_activation = layer_activation 25 | 26 | self.layers = nn.ModuleList() 27 | self.layers.append(nn.Linear(in_dim, hidden_dim)) 28 | for _ in range(num_layers - 1): 29 | self.layers.append(nn.Linear(hidden_dim, hidden_dim)) 30 | self.layers.append(nn.Linear(hidden_dim, out_dim)) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | for layer in self.layers[:-1]: 34 | x = self.layer_activation(layer(x)) 35 | x = self.layers[-1](x) 36 | return x 37 | 38 | 39 | class MeanExpression(torch.nn.Module): 40 | def __init__(self): 41 | super().__init__() 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | mean_expr = torch.mean(x[:, :x.shape[1] // 2], dim=0) 45 | return mean_expr - x[:, :x.shape[1] // 2] 46 | -------------------------------------------------------------------------------- /src/models/gears_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import Any, Dict, Tuple 4 | 5 | import torch 6 | from lightning import LightningModule 7 | from torch_geometric.data import Batch 8 | from torchmetrics import MaxMetric, MeanMetric 9 | from torchmetrics.regression import SpearmanCorrCoef, PearsonCorrCoef, MeanSquaredError 10 | 11 | from .reproduction.gears.gears import GEARSNetwork 12 
from gears.utils import loss_fct
from gears import PertData


class GEARSLitModule(LightningModule):
    """LightningModule wrapper for GEARS.

    A `LightningModule` implements 8 key methods:

    ```python
    def __init__(self):
        # Define initialization code here.

    def setup(self, stage):
        # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
        # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
        # The complete training step.

    def validation_step(self, batch, batch_idx):
        # The complete validation step.

    def test_step(self, batch, batch_idx):
        # The complete test step.

    def predict_step(self, batch, batch_idx):
        # The complete predict step.

    def configure_optimizers(self):
        # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,

        net: GEARSNetwork,
        pertmodule: PertData,
        optimizer: Any,
        scheduler: Any,
        model_name: Any,
        compile: bool = False
    ) -> None:
        """Initialize a `GEARSLitModule`.

        :param net: The GEARS network to train.
        :param pertmodule: Object exposing `pert_data` and `pert_list`.
            NOTE(review): annotated as `PertData` but accessed like the
            datamodule (`pertmodule.pert_data`) -- confirm the intended type.
        :param optimizer: Optimizer factory stored in hparams.
        :param scheduler: Optional LR-scheduler factory stored in hparams.
        :param model_name: Identifier stored in hparams.
        :param compile: Whether to `torch.compile` the network at fit time.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function (GEARS' autofocus + direction-aware loss)
        self.criterion = loss_fct

        # Map condition_name -> condition and keep non-zero gene index arrays
        # only for conditions actually present in the data; used to filter the
        # loss per perturbation.
        adata = pertmodule.pert_data.adata
        pert_full_id2pert = dict(adata.obs[['condition_name', 'condition']].values)
        self.dict_filter = {pert_full_id2pert[i]: j for i, j in adata.uns['non_zeros_gene_idx'].items()
                            if i in pert_full_id2pert}

        self.pert_list = pertmodule.pert_list

        # Mean control expression per gene, consumed by the loss.
        self.ctrl_expression = torch.tensor(np.mean(adata.X[adata.obs.condition == 'ctrl'], axis=0)).reshape(-1, )

        # Accumulators filled in test_step and consumed in on_test_epoch_end.
        self.test_results = {}
        self.pert_cat = []
        self.test_pred = []
        self.test_truth = []
        self.test_pred_de = []
        self.test_truth_de = []

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()

        # for tracking best so far validation loss
        # NOTE(review): MaxMetric tracks the *maximum* of a loss here --
        # confirm whether a minimum tracker was intended.
        self.val_loss_best = MaxMetric()

        self.test_spr = SpearmanCorrCoef()
        self.test_prs = PearsonCorrCoef()
        self.test_mse = MeanSquaredError()

        self.metric2fct = {
            'mse': self.test_mse,
            'pearson': self.test_prs,
            'spearman': self.test_spr,
        }

        # Build GEARS' internal graphs/weights from the perturbation data.
        self.net.model_initialize(pertmodule)

        self.net.to(self.device)

    def forward(self, x: torch.Tensor, pert_idx: list, batch: Batch) -> torch.Tensor:
        """Perform a forward pass through the model `self.net`.

        :param x: Flattened representation of GEARS input graphs.
        :param pert_idx: The index of the perturbation.
        :param batch: PyG Batch object.
        :return: A tensor of gene-level RNA expression.
        """
        return self.net(x, pert_idx, batch)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()

    def model_step(
        self, batch: Batch
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A PyG batch carrying `x`, `pert_idx`, `pert` and targets `y`.

        :return: A tuple containing (in order):
            - A tensor of losses.
            - A tensor of predictions.
            - A tensor of target labels.
        """
        y = batch.y

        dir_lambda = self.net.direction_lambda
        preds = self.forward(batch.x, batch.pert_idx, batch)

        # Loss is filtered per perturbation via dict_filter and uses the mean
        # control expression for its direction-aware term.
        loss = self.criterion(preds, y, perts=batch.pert, ctrl=self.ctrl_expression.to(self.device),
                              dict_filter=self.dict_filter, direction_lambda=dir_lambda)
        return loss, preds, y

    def training_step(
        self, batch: Batch, batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A PyG batch of perturbation data.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        loss, preds, targets = self.model_step(batch)
        # update and log metrics
        self.train_loss(loss)
        self.log("train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends."""
        pass

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A PyG batch of perturbation data.
        :param batch_idx: The index of the current batch.
        """
        loss, _, _ = self.model_step(batch)
        # update and log metrics
        self.val_loss(loss)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        pass

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A PyG batch of perturbation data.
        :param batch_idx: The index of the current batch.
192 | """ 193 | self.pert_cat.extend(batch.pert) 194 | 195 | _, p, t = self.model_step(batch) 196 | self.test_pred.extend(p) 197 | self.test_truth.extend(t) 198 | 199 | # Differentially expressed genes 200 | for itr, de_idx in enumerate(batch.de_idx): 201 | self.test_pred_de.append(p[itr, de_idx]) 202 | self.test_truth_de.append(t[itr, de_idx]) 203 | 204 | # all genes 205 | self.test_results['pert_cat'] = np.array(self.pert_cat) 206 | pred = torch.stack(self.test_pred) 207 | truth = torch.stack(self.test_truth) 208 | self.test_results['pred'] = pred 209 | self.test_results['truth'] = truth 210 | 211 | pred_de = torch.stack(self.test_pred_de) 212 | truth_de = torch.stack(self.test_truth_de) 213 | self.test_results['pred_de'] = pred_de 214 | self.test_results['truth_de'] = truth_de 215 | 216 | def on_test_epoch_end(self) -> None: 217 | """Lightning hook that is called when a test epoch ends.""" 218 | metrics = {} 219 | metrics_pert = {} 220 | 221 | for m in self.metric2fct.keys(): 222 | metrics[m] = [] 223 | metrics[m + '_de'] = [] 224 | 225 | for pert in np.unique(self.test_results['pert_cat']): 226 | metrics_pert[pert] = {} 227 | p_idx = np.where(self.test_results['pert_cat'] == pert)[0] 228 | 229 | pert_preds = torch.tensor(self.test_results['pred'][p_idx].mean(0)) 230 | pert_truth = torch.tensor(self.test_results['truth'][p_idx].mean(0)) 231 | pert_truth_de = torch.tensor(self.test_results['truth_de'][p_idx].mean(0)) 232 | pert_preds_de = torch.tensor(self.test_results['pred_de'][p_idx].mean(0)) 233 | for m, fct in self.metric2fct.items(): 234 | if m == 'pearson': 235 | val = fct(pert_preds, pert_truth).item() 236 | if np.isnan(val): 237 | val = 0 238 | else: 239 | val = fct(pert_preds, pert_truth).item() 240 | 241 | metrics_pert[pert][m] = val 242 | metrics[m].append(metrics_pert[pert][m]) 243 | 244 | if pert != 'ctrl': 245 | for m, fct in self.metric2fct.items(): 246 | if m == 'pearson': 247 | val = fct(pert_preds_de, pert_truth_de).item() 248 | if 
np.isnan(val): 249 | val = 0 250 | else: 251 | val = fct(pert_preds_de, pert_truth_de).item() 252 | 253 | metrics_pert[pert][m + '_de'] = val 254 | metrics[m + '_de'].append(metrics_pert[pert][m + '_de']) 255 | 256 | else: 257 | for m, fct in self.metric2fct.items(): 258 | metrics_pert[pert][m + '_de'] = 0 259 | 260 | for m in self.metric2fct.keys(): 261 | stacked_metrics = np.stack(metrics[m]) 262 | metrics[m] = np.mean(stacked_metrics) 263 | 264 | stacked_metrics_de = np.stack(metrics[m + '_de']) 265 | metrics[m + '_de'] = np.mean(stacked_metrics_de) 266 | 267 | metric_names = ['mse', 'pearson', 'spearman'] 268 | 269 | for m in metric_names: 270 | self.log("test/" + m, metrics[m]) 271 | self.log("test_de/" + m, metrics[m + '_de']) 272 | 273 | def setup(self, stage: str) -> None: 274 | """Lightning hook that is called at the beginning of fit (train + validate), validate, 275 | test, or predict. 276 | 277 | This is a good hook when you need to build models dynamically or adjust something about 278 | them. This hook is called on every process when using DDP. 279 | 280 | :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 281 | """ 282 | if self.hparams.compile and stage == "fit": 283 | self.net = torch.compile(self.net) 284 | 285 | def configure_optimizers(self) -> Dict[str, Any]: 286 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 287 | Normally you'd need one. But in the case of GANs or similar you might have multiple. 288 | 289 | Examples: 290 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 291 | 292 | :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 
293 | """ 294 | optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) 295 | if self.hparams.scheduler is not None: 296 | scheduler = self.hparams.scheduler(optimizer=optimizer) 297 | return { 298 | "optimizer": optimizer, 299 | "lr_scheduler": { 300 | "scheduler": scheduler, 301 | "monitor": "val/loss", 302 | "interval": "epoch", 303 | "frequency": 1, 304 | }, 305 | } 306 | return {"optimizer": optimizer} 307 | 308 | 309 | if __name__ == "__main__": 310 | _ = GEARSLitModule(None, None) 311 | -------------------------------------------------------------------------------- /src/models/prediction_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Optional, Dict, Tuple, List 2 | import torch 3 | import torch.nn as nn 4 | from lightning import LightningModule 5 | from torchmetrics import MeanSquaredError, MeanMetric 6 | import pickle as pkl 7 | 8 | 9 | class PredictionModule(LightningModule): 10 | def __init__( 11 | self, 12 | net: torch.nn.Module, 13 | model_type: Literal["mean", "linear_regression", "mlp"] = "mlp", 14 | optimizer: torch.optim.Optimizer = torch.optim.Adam, 15 | criterion: Optional[torch.nn.Module] = nn.MSELoss(), 16 | compile: Optional[bool] = False, 17 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, 18 | mean_adjusted: Optional[bool] = False, 19 | data_name: Optional[str] = None, 20 | save_dir: Optional[str] = None 21 | ) -> None: 22 | super().__init__() 23 | 24 | self.save_dir = save_dir 25 | self.save_hyperparameters(logger=False) 26 | self.mean_adjusted = mean_adjusted 27 | self.data_name = data_name 28 | 29 | self.net = net 30 | self.model_type = model_type 31 | 32 | self.criterion = criterion 33 | self.compile = compile 34 | 35 | self.train_loss = MeanMetric() 36 | self.val_loss = MeanMetric() 37 | self.test_loss = MeanMetric() 38 | 39 | self.train_mse = MeanSquaredError() 40 | self.val_mse = MeanSquaredError() 41 | 
self.test_mse = MeanSquaredError() 42 | self.baseline_mse = MeanSquaredError() 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | return self.net(x) 46 | 47 | # noinspection PyTupleAssignmentBalance 48 | def model_step(self, batch: Tuple[torch.Tensor, torch.Tensor, Optional[dict], Optional[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 49 | if len(batch) == 4: 50 | x, y, deg_dict, input_expr = batch 51 | else: 52 | x, y, input_expr = batch 53 | 54 | if x.dtype != torch.float32: 55 | x = x.to(torch.float32) 56 | 57 | if y.dtype != torch.float32: 58 | y = y.to(torch.float32) 59 | 60 | if input_expr.dtype != torch.float32: 61 | input_expr = input_expr.to(torch.float32) 62 | 63 | preds = self.forward(x) 64 | pert_effect = y - input_expr 65 | loss = torch.sqrt(self.criterion(preds, pert_effect)) 66 | 67 | return loss, preds, pert_effect 68 | 69 | def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: 70 | loss, preds, targets = self.model_step(batch) 71 | self.train_loss(loss) 72 | self.train_mse(preds, targets) 73 | self.log("train/mse", self.train_mse, on_step=False, on_epoch=True, prog_bar=True) 74 | return loss 75 | 76 | def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: 77 | loss, preds, targets = self.model_step(batch) 78 | self.val_loss(loss) 79 | self.val_mse(preds, targets) 80 | self.log("val/mse", self.val_mse, on_step=False, on_epoch=True, prog_bar=True) 81 | 82 | # noinspection PyTupleAssignmentBalance 83 | def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, Optional[list], Optional[torch.Tensor]], 84 | batch_idx: int) -> None: 85 | de_dict = None 86 | if len(batch) == 4: 87 | x, y, _de_dict_or_test_pert, input_expr = batch 88 | if isinstance(_de_dict_or_test_pert, dict): 89 | de_dict = _de_dict_or_test_pert 90 | else: 91 | test_perts = _de_dict_or_test_pert 92 | test_perts_idx = [i for i, pert in enumerate(test_perts) if '+' in 
pert] 93 | input_expr = input_expr[test_perts_idx, :] 94 | x = x[test_perts_idx, :] 95 | y = y[test_perts_idx, :] 96 | 97 | elif len(batch) == 3: 98 | x, y, _expr_or_de_dict = batch 99 | if isinstance(_expr_or_de_dict, dict): 100 | de_dict = _expr_or_de_dict 101 | input_expr = x[:, :x.shape[1] // 2] 102 | else: 103 | input_expr = _expr_or_de_dict 104 | else: 105 | x, y = batch 106 | input_expr = x 107 | if not de_dict: 108 | loss, preds, targets = self.model_step((x, y, input_expr)) 109 | self.test_mse(preds, targets) 110 | self.log("test/mse", self.test_mse, on_step=False, on_epoch=True, prog_bar=True) 111 | else: 112 | de_idx = de_dict['de_idx'] 113 | loss, preds, targets = self.model_step((x, y, de_idx, input_expr)) 114 | de_idx = torch.tensor([int(idx[0]) for idx in de_idx]) 115 | de_idx = torch.tensor(de_idx) 116 | preds = preds[:, de_idx] 117 | targets = targets[:, de_idx] 118 | self.test_mse(preds, targets) 119 | self.log("test/mse", self.test_mse, on_step=False, on_epoch=True, prog_bar=True) 120 | 121 | def setup(self, stage: str) -> None: 122 | """Lightning hook that is called at the beginning of fit (train + validate), validate, 123 | test, or predict. 124 | 125 | This is a good hook when you need to build models dynamically or adjust something about 126 | them. This hook is called on every process when using DDP. 127 | 128 | :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 
129 | """ 130 | if self.hparams.compile and stage == "fit": 131 | self.net = torch.compile(self.net) 132 | 133 | def configure_optimizers(self) -> Dict[str, Any]: 134 | optimizer = self.hparams.optimizer(params=self.parameters()) 135 | if self.hparams.scheduler is not None: 136 | scheduler = self.hparams.scheduler(optimizer=optimizer) 137 | return { 138 | "optimizer": optimizer, 139 | "lr_scheduler": { 140 | "scheduler": scheduler, 141 | "monitor": "val/mse", 142 | "interval": "epoch", 143 | "frequency": 1, 144 | }, 145 | } 146 | return {"optimizer": optimizer} 147 | -------------------------------------------------------------------------------- /src/models/pretrained_ckpts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/src/models/pretrained_ckpts/__init__.py -------------------------------------------------------------------------------- /src/models/reproduction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/src/models/reproduction/__init__.py -------------------------------------------------------------------------------- /src/models/reproduction/gears/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/src/models/reproduction/gears/__init__.py -------------------------------------------------------------------------------- /src/models/reproduction/gears/gears.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from gears.utils import get_similarity_network, GeneSimNetwork 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_geometric.nn import SGConv 8 | 9 | 10 | class 
class GEARSNetwork(torch.nn.Module):
    """
    GEARS model

    Reproduction of the GEARS perturbation-prediction network: gene and
    perturbation embeddings refined by two SGConv GNNs (co-expression and
    GO-similarity graphs), followed by shared and per-gene decoders that
    predict post-perturbation expression as a residual on the input profile.
    """

    def __init__(
        self,
        hidden_size: int,
        num_go_gnn_layers: int,
        num_gene_gnn_layers: int,
        decoder_hidden_size: int,
        num_similar_genes_go_graph: int,
        num_similar_genes_co_express_graph: int,
        coexpress_threshold: float,
        uncertainty: bool,
        uncertainty_reg: float,
        direction_lambda: float,
        # NOTE(review): mutable tensor defaults are shared across instances and
        # mean the `is None` checks below / in model_initialize() only trigger
        # when the caller explicitly passes None — TODO confirm the Hydra
        # configs pass None when the graphs should be computed lazily.
        G_go: Optional[torch.Tensor] = torch.Tensor([]),
        G_go_weight: Optional[torch.Tensor] = torch.Tensor([]),
        G_coexpress: Optional[torch.Tensor] = torch.Tensor([]),
        G_coexpress_weight: Optional[torch.Tensor] = torch.Tensor([]),
        no_perturb: bool = False,
        pert_emb_lambda: float = 0.2,
        num_genes: int = None,
        num_perts: int = None
    ):
        super(GEARSNetwork, self).__init__()

        self.hidden_size = hidden_size
        self.uncertainty = uncertainty
        self.num_layers = num_go_gnn_layers
        self.indv_out_hidden_size = decoder_hidden_size
        self.num_similar_genes_go_graph = num_similar_genes_go_graph
        self.num_similar_genes_co_express_graph = num_similar_genes_co_express_graph
        self.G_go = G_go
        self.G_go_weight = G_go_weight
        self.G_coexpress = G_coexpress
        self.G_coexpress_weight = G_coexpress_weight
        self.coexpress_threshold = coexpress_threshold
        self.uncertainty_reg = uncertainty_reg
        self.direction_lambda = direction_lambda
        self.num_layers_gene_pos = num_gene_gnn_layers
        self.no_perturb = no_perturb
        self.pert_emb_lambda = pert_emb_lambda
        self.num_genes = num_genes
        self.num_perts = num_perts

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # perturbation positional embedding added only to the perturbed genes
        self.pert_w = nn.Linear(1, self.hidden_size)

        # gene/global perturbation embedding dictionary lookup
        self.gene_emb = nn.Embedding(self.num_genes, self.hidden_size, max_norm=True)
        self.pert_emb = nn.Embedding(self.num_perts, self.hidden_size, max_norm=True)

        # transformation layer
        self.emb_trans = nn.ReLU()
        self.pert_base_trans = nn.ReLU()
        self.transform = nn.ReLU()
        self.emb_trans_v2 = MLP([self.hidden_size, self.hidden_size, self.hidden_size], last_layer_act='ReLU')
        self.pert_fuse = MLP([self.hidden_size, self.hidden_size, self.hidden_size], last_layer_act='ReLU')

        # gene co-expression GNN
        # NOTE(review): with the empty-tensor default the `else` branch is dead
        # code (it only runs if G_coexpress is passed as None) and is a no-op.
        if self.G_coexpress is not None:
            self.G_coexpress = self.G_coexpress.to(self.device)
            self.G_coexpress_weight = self.G_coexpress_weight.to(self.device)
        else:
            self.G_coexpress = self.G_coexpress
            self.G_coexpress_weight = self.G_coexpress_weight

        self.emb_pos = nn.Embedding(self.num_genes, self.hidden_size, max_norm=True).to(self.device)
        self.layers_emb_pos = torch.nn.ModuleList()
        for i in range(1, self.num_layers_gene_pos + 1):
            self.layers_emb_pos.append(SGConv(self.hidden_size, self.hidden_size, 1))

        ### perturbation gene ontology GNN
        # Same dead-else pattern as the co-expression graph above.
        if self.G_go is not None:
            self.G_sim = self.G_go.to(self.device)
            self.G_sim_weight = self.G_go_weight.to(self.device)
        else:
            self.G_sim = self.G_go
            self.G_sim_weight = self.G_go_weight

        self.sim_layers = torch.nn.ModuleList()
        for i in range(1, self.num_layers + 1):
            self.sim_layers.append(SGConv(self.hidden_size, self.hidden_size, 1))

        # decoder shared MLP
        self.recovery_w = MLP([self.hidden_size, self.hidden_size * 2, self.hidden_size], last_layer_act='linear')

        # gene specific decoder: one (hidden_size x 1) weight vector per gene
        self.indv_w1 = nn.Parameter(torch.rand(self.num_genes,
                                               self.hidden_size, 1))
        self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1))
        self.act = nn.ReLU()
        # NOTE(review): xavier init on a bias tensor is unusual but valid here
        # because indv_b1 is 2-D; kept to match the upstream GEARS reference.
        nn.init.xavier_normal_(self.indv_w1)
        nn.init.xavier_normal_(self.indv_b1)

        # Cross gene MLP: summarises the whole-transcriptome state per sample
        self.cross_gene_state = MLP([self.num_genes, self.hidden_size,
                                     self.hidden_size])
        # final gene specific decoder
        self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes,
                                               self.hidden_size + 1))
        self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes))
        nn.init.xavier_normal_(self.indv_w2)
        nn.init.xavier_normal_(self.indv_b2)

        # batchnorms
        self.bn_emb = nn.BatchNorm1d(self.hidden_size)
        self.bn_pert_base = nn.BatchNorm1d(self.hidden_size)
        self.bn_pert_base_trans = nn.BatchNorm1d(self.hidden_size)

        # uncertainty mode: extra head predicting per-gene log-variance
        if self.uncertainty:
            self.uncertainty_w = MLP([self.hidden_size, self.hidden_size * 2, self.hidden_size, 1],
                                     last_layer_act='linear')

    def forward(self, x, pert_idx, batch):
        """
        Forward pass of the model.

        :param x: flattened control expression values; reshaped internally to
            one value per (sample, gene) — assumes len(x) == num_graphs *
            num_genes, TODO confirm against the datamodule.
        :param pert_idx: per-sample iterable of perturbation indices, with -1
            used as the "no perturbation" sentinel.
        :param batch: torch_geometric batch object; only ``batch.batch`` is
            used, to recover the number of graphs (samples) in the batch.
        :return: (num_graphs, num_genes) predicted expression; additionally a
            matching log-variance tensor when ``self.uncertainty`` is set.
        """
        # x, pert_idx = data.x, data.pert_idx
        if self.no_perturb:
            # Ablation mode: return the input expression unchanged.
            out = x.reshape(-1, 1)
            out = torch.split(torch.flatten(out), self.num_genes)
            return torch.stack(out)
        else:
            num_graphs = len(batch.batch.unique())

            ## get base gene embeddings (same gene-id sequence repeated per sample)
            emb = self.gene_emb(
                torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.device))
            emb = self.bn_emb(emb)
            base_emb = self.emb_trans(emb)

            # positional embedding refined over the co-expression graph
            pos_emb = self.emb_pos(
                torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.device)).to(self.device)
            for idx, layer in enumerate(self.layers_emb_pos):
                pos_emb = layer(pos_emb, self.G_coexpress.to(self.device), self.G_coexpress_weight.to(self.device))
                if idx < len(self.layers_emb_pos) - 1:
                    pos_emb = pos_emb.relu()

            # NOTE(review): 0.2 here is a hard-coded copy of pert_emb_lambda's
            # default; self.pert_emb_lambda is not read — kept as upstream.
            base_emb = base_emb + 0.2 * pos_emb
            base_emb = self.emb_trans_v2(base_emb)

            ## get perturbation index and embeddings

            # Collect (sample_idx, pert_id) pairs, skipping the -1 sentinel.
            pert_index = []
            for idx, i in enumerate(pert_idx):
                for j in i:
                    if j != -1:
                        pert_index.append([idx, j])
            pert_index = torch.tensor(pert_index).T

            pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.device))

            ## augment global perturbation embedding with GNN
            for idx, layer in enumerate(self.sim_layers):
                pert_global_emb = layer(pert_global_emb, self.G_go.to(self.device), self.G_go_weight.to(self.device))
                if idx < self.num_layers - 1:
                    pert_global_emb = pert_global_emb.relu()

            ## add global perturbation embedding to each gene in each cell in the batch
            base_emb = base_emb.reshape(num_graphs, self.num_genes, -1)

            if pert_index.shape[0] != 0:
                ### in case all samples in the batch are controls, then there is no indexing for pert_index.
                # Sum the embeddings of all perturbations applied to each sample.
                pert_track = {}
                for i, j in enumerate(pert_index[0]):
                    if j.item() in pert_track:
                        pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]]
                    else:
                        pert_track[j.item()] = pert_global_emb[pert_index[1][i]]

                if len(list(pert_track.values())) > 0:
                    if len(list(pert_track.values())) == 1:
                        # circumvent when batch size = 1 with single perturbation and cannot feed into MLP
                        # (the single row is duplicated so BatchNorm inside
                        # pert_fuse sees a batch of 2; only row 0 is used below)
                        emb_total = self.pert_fuse(torch.stack(list(pert_track.values()) * 2))
                    else:
                        emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))

                    # Broadcast the fused perturbation embedding onto every
                    # gene of the corresponding sample.
                    for idx, j in enumerate(pert_track.keys()):
                        base_emb[j] = base_emb[j] + emb_total[idx]

            base_emb = base_emb.reshape(num_graphs * self.num_genes, -1)
            base_emb = self.bn_pert_base(base_emb)

            ## apply the first MLP
            base_emb = self.transform(base_emb)
            out = self.recovery_w(base_emb)
            out = out.reshape(num_graphs, self.num_genes, -1)
            # per-gene linear head: one scalar per (sample, gene)
            out = out.unsqueeze(-1) * self.indv_w1
            w = torch.sum(out, axis=2)
            out = w + self.indv_b1

            # Cross gene: whole-transcriptome summary fed back to every gene
            cross_gene_embed = self.cross_gene_state(out.reshape(num_graphs, self.num_genes, -1).squeeze(2))
            cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes)

            cross_gene_embed = cross_gene_embed.reshape([num_graphs, self.num_genes, -1])
            cross_gene_out = torch.cat([out, cross_gene_embed], 2)

            cross_gene_out = cross_gene_out * self.indv_w2
            cross_gene_out = torch.sum(cross_gene_out, axis=2)
            out = cross_gene_out + self.indv_b2
            # Residual connection: prediction is a delta on input expression.
            out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1, 1)
            out = torch.split(torch.flatten(out), self.num_genes)

            ## uncertainty head
            if self.uncertainty:
                out_logvar = self.uncertainty_w(base_emb)
                out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes)
                return torch.stack(out), torch.stack(out_logvar)

            return torch.stack(out)

    def model_initialize(self, pertmodule) -> None:
        """Initialize the model.

        Lazily builds the co-expression and GO similarity graphs from the
        datamodule's PertData when they were not supplied to ``__init__``.

        :param pertmodule: datamodule exposing ``pert_data``, ``gene_list``
            and ``pert_list`` (see gears_datamodule.py).

        NOTE(review): these branches only run when G_coexpress / G_go were
        explicitly set to None — the constructor defaults are empty tensors,
        which are not None. Verify callers pass None when lazy construction
        is expected.
        """
        pert_data = pertmodule.pert_data
        if self.G_coexpress is None:
            ## calculating co expression similarity graph
            edge_list = get_similarity_network(
                network_type='co-express',
                adata=pert_data.adata,
                threshold=self.coexpress_threshold,
                k=self.num_similar_genes_co_express_graph,
                data_path=pert_data.data_path,
                data_name='',
                split=pert_data.split,
                seed=pert_data.seed,
                train_gene_set_size=pert_data.train_gene_set_size,
                set2conditions=pert_data.set2conditions
            )

            sim_network = GeneSimNetwork(edge_list, pertmodule.gene_list, node_map=pert_data.node_map)
            self.G_coexpress = sim_network.edge_index
            self.G_coexpress_weight = sim_network.edge_weight

        if self.G_go is None:
            ## calculating gene ontology similarity graph
            edge_list = get_similarity_network(
                network_type='go',
                adata=pert_data.adata,
                threshold=self.coexpress_threshold,
                k=self.num_similar_genes_go_graph,
                pert_list=pertmodule.pert_list,
                data_path=pert_data.data_path,
                data_name='',
                split=pert_data.split,
                seed=pert_data.seed,
                train_gene_set_size=pert_data.train_gene_set_size,
                set2conditions=pert_data.set2conditions,
                default_pert_graph=pert_data.default_pert_graph
            )

            sim_network = GeneSimNetwork(edge_list, pertmodule.pert_list, node_map=pert_data.node_map_pert)
            self.G_go = sim_network.edge_index
            self.G_go_weight = sim_network.edge_weight


class MLP(torch.nn.Module):

    def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
        """
        Multi-layer perceptron
        :param sizes: list of sizes of the layers
        :param batch_norm: whether to use batch normalization
        :param last_layer_act: activation function of the last layer

        NOTE(review): last_layer_act is stored but never applied — the [:-1]
        slice strips the final ReLU and forward() runs self.network only, so
        the output layer ends in Linear(+BatchNorm) regardless. self.relu is
        likewise unused. This matches the upstream GEARS implementation, so
        it is kept as-is for reproduction fidelity.
        """
        super(MLP, self).__init__()
        layers = []
        for s in range(len(sizes) - 1):
            # `s < len(sizes) - 1` is always true inside this range; the
            # effective condition is just `batch_norm` (kept as upstream).
            layers = layers + [
                torch.nn.Linear(sizes[s], sizes[s + 1]),
                torch.nn.BatchNorm1d(sizes[s + 1])
                if batch_norm and s < len(sizes) - 1 else None,
                torch.nn.ReLU()
            ]

        # Drop the None placeholders and the trailing ReLU.
        layers = [l for l in layers if l is not None][:-1]
        self.activation = last_layer_act
        self.network = torch.nn.Sequential(*layers)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.network(x)
ld_impl_linux-64=2.40=hf3520f5_1 18 | - libedit=3.1.20191231=he28a2e2_2 19 | - libffi=3.4.2=h7f98852_5 20 | - libgcc-ng=13.2.0=h77fa898_7 21 | - libgomp=13.2.0=h77fa898_7 22 | - libnsl=2.0.1=hd590300_0 23 | - libsodium=1.0.18=h36c2ea0_1 24 | - libsqlite=3.45.3=h2797004_0 25 | - libstdcxx-ng=13.2.0=hc0a3c3a_7 26 | - libuuid=2.38.1=h0b41bf4_0 27 | - libxcrypt=4.4.36=hd590300_1 28 | - libzlib=1.3.1=h4ab18f5_1 29 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 30 | - ncurses=6.5=h59595ed_0 31 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 32 | - openssl=3.3.1=hb9d3cd8_3 33 | - parso=0.8.4=pyhd8ed1ab_0 34 | - pexpect=4.9.0=pyhd8ed1ab_0 35 | - pickleshare=0.7.5=py_1003 36 | - pip=24.0=pyhd8ed1ab_0 37 | - ptyprocess=0.7.0=pyhd3deb0d_0 38 | - pure_eval=0.2.3=pyhd8ed1ab_0 39 | - pygments=2.18.0=pyhd8ed1ab_0 40 | - python=3.10.14=hd12c33a_0_cpython 41 | - python_abi=3.10=5_cp310 42 | - readline=8.2=h8228510_1 43 | - setuptools=70.0.0=pyhd8ed1ab_0 44 | - six=1.16.0=pyh6c4a22f_0 45 | - stack_data=0.6.2=pyhd8ed1ab_0 46 | - tk=8.6.13=noxft_h4845f30_101 47 | - traitlets=5.14.3=pyhd8ed1ab_0 48 | - typing_extensions=4.12.2=pyha770c72_0 49 | - wcwidth=0.2.13=pyhd8ed1ab_0 50 | - wheel=0.43.0=pyhd8ed1ab_1 51 | - xz=5.2.6=h166bdaf_0 52 | - zeromq=4.3.5=h75354e8_4 53 | - zipp=3.21.0=pyhd8ed1ab_0 54 | - pip: 55 | - aiohttp==3.9.5 56 | - aiosignal==1.3.1 57 | - anndata==0.10.7 58 | - array-api-compat==1.7.1 59 | - async-timeout==4.0.3 60 | - attrs==23.2.0 61 | - cell-gears==0.1.2 62 | - certifi==2024.6.2 63 | - charset-normalizer==3.3.2 64 | - contourpy==1.2.1 65 | - cycler==0.12.1 66 | - dcor==0.6 67 | - debugpy==1.8.1 68 | - exceptiongroup==1.2.1 69 | - executing==2.0.1 70 | - filelock==3.14.0 71 | - fonttools==4.53.0 72 | - frozenlist==1.4.1 73 | - fsspec==2024.6.0 74 | - h5py==3.11.0 75 | - idna==3.7 76 | - ipykernel==6.29.4 77 | - ipython==8.25.0 78 | - jedi==0.19.1 79 | - jinja2==3.1.4 80 | - joblib==1.4.2 81 | - jupyter-client==8.6.2 82 | - kiwisolver==1.4.5 83 | - legacy-api-wrap==1.4 84 | - 
llvmlite==0.42.0 85 | - markupsafe==2.1.5 86 | - matplotlib==3.9.0 87 | - mpmath==1.3.0 88 | - multidict==6.0.5 89 | - natsort==8.4.0 90 | - networkx==3.3 91 | - numba==0.59.1 92 | - numpy==1.26.4 93 | - nvidia-cublas-cu12==12.1.3.1 94 | - nvidia-cuda-cupti-cu12==12.1.105 95 | - nvidia-cuda-nvrtc-cu12==12.1.105 96 | - nvidia-cuda-runtime-cu12==12.1.105 97 | - nvidia-cudnn-cu12==8.9.2.26 98 | - nvidia-cufft-cu12==11.0.2.54 99 | - nvidia-curand-cu12==10.3.2.106 100 | - nvidia-cusolver-cu12==11.4.5.107 101 | - nvidia-cusparse-cu12==12.1.0.106 102 | - nvidia-nccl-cu12==2.20.5 103 | - nvidia-nvjitlink-cu12==12.5.40 104 | - nvidia-nvtx-cu12==12.1.105 105 | - packaging==24.0 106 | - pandas==2.2.2 107 | - patsy==0.5.6 108 | - pillow==10.3.0 109 | - platformdirs==4.2.2 110 | - prompt-toolkit==3.0.45 111 | - psutil==5.9.8 112 | - pure-eval==0.2.2 113 | - pynndescent==0.5.12 114 | - pyparsing==3.1.2 115 | - python-dateutil==2.9.0.post0 116 | - pytz==2024.1 117 | - pyzmq==26.0.3 118 | - requests==2.32.3 119 | - scanpy==1.10.1 120 | - scikit-learn==1.5.0 121 | - scipy==1.13.1 122 | - seaborn==0.13.2 123 | - session-info==1.0.0 124 | - spectrae==1.0.3 125 | - stack-data==0.6.3 126 | - statsmodels==0.14.2 127 | - stdlib-list==0.10.0 128 | - sympy==1.12.1 129 | - threadpoolctl==3.5.0 130 | - torch==2.3.0 131 | - torch-geometric==2.5.3 132 | - tornado==6.4 133 | - tqdm==4.66.4 134 | - triton==2.3.0 135 | - tzdata==2024.2 136 | - umap-learn==0.5.6 137 | - urllib3==2.2.1 138 | - yarl==1.9.4 139 | prefix: /data/SBCS-BessantLab/martina/gears/gears_env2 140 | -------------------------------------------------------------------------------- /src/models/reproduction/gears_spectra/gears_with_spectra_singlegene.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | print(torch.__version__) 4 | print(torch.version.cuda) 5 | 6 | # Package imports 7 | from gears import PertData, GEARS 8 | from gears.utils import 
"""Train GEARS on the Norman single-gene perturbation data using custom
SPECTRA train/test splits supplied as pickled index lists."""
import argparse
import torch
print(torch.__version__)
print(torch.version.cuda)

# Package imports
from gears import PertData, GEARS
from gears.utils import dataverse_download
from zipfile import ZipFile
import tarfile
import numpy as np
import pickle
from spectrae import SpectraDataset
import scanpy as sc
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm

class PerturbGraphData(SpectraDataset):
    """SPECTRA dataset over an AnnData of perturbed cells: samples are the
    non-control perturbation names; items are mean log2 fold-changes vs ctrl."""

    def parse(self, pert_data):
        """Cache the AnnData and the mean control expression; return the
        list of sample identifiers (perturbation names, excluding 'ctrl')."""
        if isinstance(pert_data, PertData):
            self.adata = pert_data.adata
        else:
            self.adata = pert_data
        self.control_expression = self.adata[self.adata.obs['condition'] == 'ctrl'].X.toarray().mean(axis=0)
        return [p for p in self.adata.obs['condition'].unique() if p != 'ctrl']

    def get_mean_logfold_change(self, perturbation):
        """Mean log2 fold-change of `perturbation` cells over control cells
        (pseudocount of 1; NaNs mapped to 0)."""
        perturbation_expression = self.adata[self.adata.obs['condition'] == perturbation].X.toarray().mean(axis=0)
        logfold_change = np.nan_to_num(np.log2(perturbation_expression + 1) - np.log2(self.control_expression + 1))
        return logfold_change

    def sample_to_index(self, sample):
        """Reverse lookup from a '-'-joined item representation to its index.

        Bug fix: items are float ndarrays, and ``'-'.join(list(x))`` raises
        TypeError because str.join requires string items — the values must be
        stringified first.
        NOTE(review): the keys are stringified log-fold vectors; presumably
        the intent was to key on sample names — confirm against callers.
        """
        if not hasattr(self, 'index_to_sequence'):
            print("Generating index to sequence")
            self.index_to_sequence = {}
            for i in tqdm(range(len(self))):
                x = self.__getitem__(i)
                self.index_to_sequence['-'.join(str(v) for v in x)] = i

        return self.index_to_sequence[sample]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        perturbation = self.samples[idx]
        return self.get_mean_logfold_change(perturbation)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GEARS model with custom splits.')
    parser.add_argument('--split_folder', type=str, required=True,
                        help='Path to the split folder containing train.pkl and test.pkl')
    parser.add_argument('--gears_path', type=str, default='/data/SBCS-BessantLab/martina/gears',
                        help='Path to the gears directory')
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of epochs to train the model')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use for training (e.g., "cuda" or "cpu")')
    args = parser.parse_args()
    split_folder = args.split_folder
    gears_path = args.gears_path.rstrip('/')  # Remove trailing slash if any
    epochs = args.epochs
    device = args.device

    print(split_folder)

    # Ensure necessary directories exist
    if not os.path.exists(gears_path):
        os.makedirs(gears_path)

    # Download dataloader if not already present
    data_file = f'{gears_path}/norman_umi_go.tar.gz'
    if not os.path.exists(data_file):
        dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979957', data_file)
        with tarfile.open(data_file, 'r:gz') as tar:
            tar.extractall(path=gears_path)

    # Download model if not already present
    # model_file = os.path.join(gears_path, 'model.zip')
    # if not os.path.exists(model_file):
    #     dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979956', model_file)
    #     with ZipFile(model_file, 'r') as zip_ref:
    #         zip_ref.extractall(path=gears_path)

    # Load custom train test splits (pickled lists of sample indices)
    with open(f'{split_folder}/train.pkl', 'rb') as file:
        train_splits = pickle.load(file)

    with open(f'{split_folder}/test.pkl', 'rb') as file:
        test_splits = pickle.load(file)

    # Load adata
    adata = sc.read(f'{gears_path}/Norman_2019_raw.h5ad')

    # Filter genes: keep genes with total count > 5, then single-gene guides only
    nonzero_genes = (adata.X.sum(axis=0) > 5).A1
    filtered_adata = adata[:, nonzero_genes]
    single_gene_mask = [True if "," not in name else False for name in adata.obs['guide_ids']]
    sg_adata = filtered_adata[single_gene_mask, :]
    sg_adata.obs['condition'] = sg_adata.obs['guide_ids'].replace('', 'ctrl')

    genes = sg_adata.var['gene_symbols'].to_list()
    genes_and_ctrl = genes + ['ctrl']

    # Remove cells with perts not in the genes
    sg_pert_adata = sg_adata[sg_adata.obs['condition'].isin(genes_and_ctrl), :]

    # Create PerturbGraphData
    perturb_graph_data = PerturbGraphData(sg_pert_adata, 'norman')
    # Free intermediates before loading the GEARS copy of the data.
    del nonzero_genes, filtered_adata, single_gene_mask, sg_adata, sg_pert_adata, genes, genes_and_ctrl

    # Function to get perturbation names
    def get_pert_names(pert_idxs):
        return [perturb_graph_data.samples[idx] for idx in pert_idxs]

    our_train_splits = get_pert_names(train_splits)
    our_test_splits = get_pert_names(test_splits)

    # Get splits in format needed for GEARS ('<gene>+ctrl')
    our_train_perts = [split + '+' + 'ctrl' for split in our_train_splits]
    our_test_perts = [split + '+' + 'ctrl' for split in our_test_splits]

    # Split our_train_perts into train and validation sets
    train_perts, val_perts = train_test_split(our_train_perts, test_size=0.2, random_state=42)

    # Load pert_data
    pert_data_folder = gears_path
    pert_data = PertData(pert_data_folder)
    data_name = 'norman_umi_go'
    pert_data.load(data_path = pert_data_folder + '/' + data_name)
    gear_perts = pert_data.adata.obs['condition'].cat.remove_unused_categories().cat.categories.tolist()

    # Filter perts: keep only conditions that exist in the GEARS dataset,
    # trying the reversed 'ctrl+<gene>' spelling before giving up.
    def filter_perts(pert_list, gear_perts):
        filtered_perts = []
        for pert in pert_list:
            if pert in gear_perts:
                filtered_perts.append(pert)
            else:
                # Some perts might be 'ctrl+pert' instead of 'pert+ctrl'
                pn = pert.split('+')[0]
                new_pert_fmt = 'ctrl' + '+' + pn
                if new_pert_fmt in gear_perts:
                    filtered_perts.append(new_pert_fmt)
                else:
                    print(f"Perturbation {pert} not found in gear_perts.")
        return filtered_perts

    train_perts = filter_perts(train_perts, gear_perts)
    val_perts = filter_perts(val_perts, gear_perts)
    test_perts = filter_perts(our_test_perts, gear_perts)

    # Remove problematic perts (train/val only; deliberately kept in test)
    problematic_perts = ["IER5L+ctrl", "SLC38A2+ctrl", "RHOXF2+ctrl"]
    for pert in problematic_perts:
        if pert in train_perts:
            train_perts.remove(pert)
        if pert in val_perts:
            val_perts.remove(pert)

    # Set up set2conditions
    set2conditions = {
        "train": train_perts,
        "val": val_perts,
        "test": test_perts
    }

    # Ensure that the sets are not empty
    if not train_perts:
        raise ValueError("Training set is empty after filtering.")
    if not val_perts:
        raise ValueError("Validation set is empty after filtering.")
    if not test_perts:
        raise ValueError("Test set is empty after filtering.")

    # Set up pert_data with the custom split before building dataloaders
    pert_data.set2conditions = set2conditions
    pert_data.split = "custom"
    pert_data.subgroup = None
    pert_data.seed = 1
    pert_data.get_dataloader(batch_size=32, test_batch_size=128)

    # Train the model
    gears_model = GEARS(pert_data, device=device)
    gears_model.model_initialize(hidden_size=64)
    gears_model.train(epochs=epochs)
#!/usr/bin/env python
"""Train GEARS on the Norman data (single- AND two-gene perturbations)
using custom SPECTRA train/test splits supplied as pickled index lists."""

import torch
import argparse
import os
import pickle
import numpy as np
from gears import PertData, GEARS
from gears.utils import dataverse_download
from zipfile import ZipFile
import tarfile
from spectrae import SpectraDataset
import scanpy as sc
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Define the PerturbGraphData class
class PerturbGraphData(SpectraDataset):
    """SPECTRA dataset over an AnnData of perturbed cells: samples are the
    non-control perturbation names; items are mean log2 fold-changes vs ctrl."""

    def parse(self, pert_data):
        """Cache the AnnData and the mean control expression; return the
        list of sample identifiers (perturbation names, excluding 'ctrl')."""
        if isinstance(pert_data, PertData):
            self.adata = pert_data.adata
        else:
            self.adata = pert_data
        self.control_expression = self.adata[self.adata.obs['condition'] == 'ctrl'].X.toarray().mean(axis=0)
        return [p for p in self.adata.obs['condition'].unique() if p != 'ctrl']

    def get_mean_logfold_change(self, perturbation):
        """Mean log2 fold-change of `perturbation` cells over control cells
        (pseudocount of 1; NaNs mapped to 0)."""
        perturbation_expression = self.adata[self.adata.obs['condition'] == perturbation].X.toarray().mean(axis=0)
        logfold_change = np.nan_to_num(np.log2(perturbation_expression + 1) - np.log2(self.control_expression + 1))
        return logfold_change

    def sample_to_index(self, sample):
        """Reverse lookup from a '-'-joined item representation to its index.

        Bug fix (matches the single-gene script): items are float ndarrays,
        so the values must be stringified before str.join, otherwise this
        raises TypeError.
        NOTE(review): the keys are stringified log-fold vectors; presumably
        the intent was to key on sample names — confirm against callers.
        """
        if not hasattr(self, 'index_to_sequence'):
            print("Generating index to sequence")
            self.index_to_sequence = {}
            for i in tqdm(range(len(self))):
                x = self.__getitem__(i)
                self.index_to_sequence['-'.join(str(v) for v in x)] = i

        return self.index_to_sequence[sample]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        perturbation = self.samples[idx]
        return self.get_mean_logfold_change(perturbation)

def main():
    parser = argparse.ArgumentParser(description='Train GEARS model with custom splits.')
    parser.add_argument('--gears_path', type=str, default='/data/SBCS-BessantLab/martina/gears',
                        help='Path to GEARS data directory')
    parser.add_argument('--spectra_splits_dir', type=str, required=True,
                        help='Directory containing the spectra splits')
    parser.add_argument('--split_name', type=str, required=True,
                        help='Name of the split to use (e.g., SP_0.00_0)')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--test_batch_size', type=int, default=128, help='Batch size for testing')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use for training')
    parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save outputs')
    args = parser.parse_args()

    # Print the arguments
    print("Arguments:")
    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    gears_path = args.gears_path.rstrip('/')  # Remove trailing slash if any

    # Download dataloader if not already present
    data_file = f'{gears_path}/norman_umi_go.tar.gz'
    if not os.path.exists(data_file):
        dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979957', data_file)
        with tarfile.open(data_file, 'r:gz') as tar:
            tar.extractall(path=gears_path)

    # Load the train and test splits (pickled lists of sample indices)
    spectra_splits_path = os.path.join(args.spectra_splits_dir, args.split_name)

    print(spectra_splits_path)


    with open(os.path.join(spectra_splits_path, 'train.pkl'), 'rb') as file:
        train_splits = pickle.load(file)

    with open(os.path.join(spectra_splits_path, 'test.pkl'), 'rb') as file:
        test_splits = pickle.load(file)

    # Load the dataset
    adata = sc.read(os.path.join(args.gears_path, 'Norman_2019_raw.h5ad'))

    # Filter genes with sufficient expression (total count > 5)
    nonzero_genes = (adata.X.sum(axis=0) > 5).A1
    filtered_adata = adata[:, nonzero_genes]

    # Initialize lists to track conditions
    conditions = []

    # Process perturbations while preserving order
    for guide_id in filtered_adata.obs['guide_ids']:
        if "," in guide_id:  # Two-gene perturbation
            conditions.append(guide_id.replace(',', '+'))
        elif guide_id == "":  # Empty guide_id, treat as control
            conditions.append("ctrl")
        else:  # Single-gene perturbation
            conditions.append(guide_id)

    # Assign the processed conditions back to the AnnData object
    filtered_adata.obs['condition'] = conditions

    # Create a mask to keep only single and two-gene perturbations, excluding "ctrl"
    perturbation_mask = filtered_adata.obs['condition'] != 'ctrl'
    pert_adata = filtered_adata[perturbation_mask, :]

    pert_adata.obs['condition'] = pert_adata.obs['condition'].astype('category')

    # Generate the PerturbGraphData object
    perturb_graph_data = PerturbGraphData(pert_adata, 'norman')

    def get_pert_names(pert_idxs):
        return [perturb_graph_data.samples[idx] for idx in pert_idxs]

    # Split train_splits into train and validation indices
    # train_indices, val_indices = train_test_split(train_splits, test_size=0.2, random_state=42)

    # Get perturbation names for train, val, and test
    our_train_splits = get_pert_names(train_splits)
    #our_val_splits = get_pert_names(val_indices)
    our_test_splits = get_pert_names(test_splits)

    # Function to add '+ctrl' to single-gene perturbations
    def add_ctrl_to_single_gene_perts(pert_list):
        updated_list = [
            pert + '+ctrl' if '+' not in pert else pert  # Add "+ctrl" if no "+" exists in the perturbation
            for pert in pert_list
        ]
        return updated_list

    # Process the perturbations
    our_train_perts = add_ctrl_to_single_gene_perts(our_train_splits)
    test_perts = add_ctrl_to_single_gene_perts(our_test_splits)

    train_perts, val_perts = train_test_split(our_train_perts, test_size=0.2, random_state=42)

    # Load pert_data
    pert_data_folder = args.gears_path
    pert_data = PertData(pert_data_folder)
    data_name = 'norman_umi_go'
    pert_data.load(data_path=os.path.join(pert_data_folder, data_name))

    # Get the list of perturbations in GEARS data
    gear_perts = pert_data.adata.obs['condition'].cat.remove_unused_categories().cat.categories.tolist()

    # Function to adjust perturbations to the spelling used by the GEARS data:
    # single genes may appear as 'ctrl+<gene>', two-gene perts in either order.
    def adjust_perturbations(pert_list, gear_perts):
        adjusted_pert_list = []
        for pert in pert_list:
            if pert in gear_perts:
                adjusted_pert_list.append(pert)
                continue
            if '+' not in pert:  # Single-gene perturbation
                # Add "+ctrl" to single-gene perturbations
                pn = pert
                new_pert_fmt = 'ctrl' + '+' + pn
                if new_pert_fmt in gear_perts:
                    print(f"Reformatted single-gene perturbation: {pert} -> {new_pert_fmt}")
                    adjusted_pert_list.append(new_pert_fmt)
                else:
                    print(f"Perturbation {pert} not found in gear_perts after adding '+ctrl'. Skipping.")
            else:  # Two-gene perturbation
                # Switch order of genes for two-gene perturbations
                genes = pert.split('+')
                switched_pert = '+'.join(genes[::-1])  # Reverse the gene order
                if switched_pert in gear_perts:
                    print(f"Switched two-gene perturbation: {pert} -> {switched_pert}")
                    adjusted_pert_list.append(switched_pert)
                else:
                    print(f"Perturbation {pert} not found in gear_perts after switching order. Skipping.")
        return adjusted_pert_list

    # Adjust the perturbations
    train = adjust_perturbations(train_perts, gear_perts)
    val = adjust_perturbations(val_perts, gear_perts)
    test = adjust_perturbations(test_perts, gear_perts)

    # Create set2conditions
    set2conditions = {
        "train": train,
        "val": val,
        "test": test
    }

    # Ensure that the sets are not empty
    if not train:
        raise ValueError("Training set is empty after filtering.")
    if not val:
        raise ValueError("Validation set is empty after filtering.")
    if not test:
        raise ValueError("Test set is empty after filtering.")

    # Update pert_data with the custom split before building dataloaders
    pert_data.set2conditions = set2conditions
    pert_data.split = "custom"
    pert_data.subgroup = None
    pert_data.seed = 1

    # Get the dataloaders
    pert_data.get_dataloader(batch_size=args.batch_size, test_batch_size=args.test_batch_size)

    # Initialize and train the model
    gears_model = GEARS(pert_data, device=args.device)
    gears_model.model_initialize(hidden_size=64)
    gears_model.train(epochs=args.epochs)

    # Save the trained model
    model_save_path = os.path.join(args.output_dir, f'gears_model_{args.split_name}.pt')
    gears_model.save_model(model_save_path)
    print(f"Model saved to {model_save_path}")

if __name__ == '__main__':
    main()
in "configs/paths/default.yaml" 27 | # 28 | # more info: https://github.com/ashleve/rootutils 29 | # ------------------------------------------------------------------------------------ # 30 | 31 | from src.utils import ( 32 | RankedLogger, 33 | extras, 34 | get_metric_value, 35 | instantiate_callbacks, 36 | instantiate_loggers, 37 | log_hyperparameters, 38 | task_wrapper, 39 | ) 40 | 41 | log = RankedLogger(__name__, rank_zero_only=True) 42 | 43 | 44 | @task_wrapper 45 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 46 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 47 | training. 48 | 49 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 50 | failure. Useful for multiruns, saving info about the crash, etc. 51 | 52 | :param cfg: A DictConfig configuration composed by Hydra. 53 | :return: A tuple with metrics and dict with all instantiated objects. 54 | """ 55 | # set seed for random number generators in pytorch, numpy and python.random 56 | if cfg.get("seed"): 57 | L.seed_everything(cfg.seed, workers=True) 58 | 59 | if torch.cuda.is_available(): 60 | torch.set_float32_matmul_precision('medium') 61 | 62 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") 63 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 64 | 65 | log.info(f"Instantiating model <{cfg.model._target_}>") 66 | model: LightningModule = hydra.utils.instantiate(cfg.model) 67 | 68 | log.info("Instantiating callbacks...") 69 | callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) 70 | 71 | log.info("Instantiating loggers...") 72 | logger: List[Logger] = instantiate_loggers(cfg.get("logger")) 73 | 74 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 75 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 76 | 77 | object_dict = { 78 | "cfg": cfg, 79 | "datamodule": datamodule, 80 | "model": 
model, 81 | "callbacks": callbacks, 82 | "logger": logger, 83 | "trainer": trainer, 84 | } 85 | 86 | if logger: 87 | log.info("Logging hyperparameters!") 88 | log_hyperparameters(object_dict) 89 | 90 | if cfg.get("train"): 91 | log.info("Starting training!") 92 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 93 | 94 | train_metrics = trainer.callback_metrics 95 | 96 | if cfg.get("test"): 97 | log.info("Starting testing!") 98 | ckpt_path = trainer.checkpoint_callback.best_model_path 99 | if ckpt_path == "": 100 | log.warning("Best ckpt not found! Using current weights for testing...") 101 | ckpt_path = None 102 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 103 | log.info(f"Best ckpt path: {ckpt_path}") 104 | 105 | test_metrics = trainer.callback_metrics 106 | 107 | # merge train and test metrics 108 | metric_dict = {**train_metrics, **test_metrics} 109 | 110 | return metric_dict, object_dict 111 | 112 | 113 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 114 | def main(cfg: DictConfig) -> Optional[float]: 115 | """Main entry point for training. 116 | 117 | :param cfg: DictConfig configuration composed by Hydra. 118 | :return: Optional[float] with optimized metric value. 119 | """ 120 | # apply extra utilities 121 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
122 | extras(cfg) 123 | 124 | # train the model 125 | metric_dict, _ = train(cfg) 126 | 127 | # safely retrieve metric value for hydra-based hyperparameter optimization 128 | metric_value = get_metric_value( 129 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") 130 | ) 131 | 132 | # return optimized metric 133 | return metric_value 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from src.utils.logging_utils import log_hyperparameters 3 | from src.utils.pylogger import RankedLogger 4 | from src.utils.rich_utils import enforce_tags, print_config_tree 5 | from src.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /src/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! 
Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning_utilities.core.rank_zero import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from src.utils import pylogger 7 | 8 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 
20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum( 38 | p.numel() for p in model.parameters() if p.requires_grad 39 | ) 40 | hparams["model/params/non_trainable"] = sum( 41 | p.numel() for p in model.parameters() if not p.requires_grad 42 | ) 43 | 44 | hparams["data"] = cfg["data"] 45 | hparams["trainer"] = cfg["trainer"] 46 | 47 | hparams["callbacks"] = cfg.get("callbacks") 48 | hparams["extras"] = cfg.get("extras") 49 | 50 | hparams["task_name"] = cfg.get("task_name") 51 | hparams["tags"] = cfg.get("tags") 52 | hparams["ckpt_path"] = cfg.get("ckpt_path") 53 | hparams["seed"] = cfg.get("seed") 54 | 55 | # send hparams to all loggers 56 | for logger in trainer.loggers: 57 | logger.log_hyperparams(hparams) 58 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = False, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 
18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: 28 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 29 | of the process it's being logged from. If `'rank'` is provided, then the log will only 30 | occur on that rank/process. 31 | 32 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 33 | :param msg: The message to log. 34 | :param rank: The rank to log at. 35 | :param args: Additional args to pass to the underlying logging function. 36 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 
37 | """ 38 | if self.isEnabledFor(level): 39 | msg, kwargs = self.process(msg, kwargs) 40 | current_rank = getattr(rank_zero_only, "rank", None) 41 | if current_rank is None: 42 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") 43 | msg = rank_prefixed_message(msg, current_rank) 44 | if self.rank_zero_only: 45 | if current_rank == 0: 46 | self.logger.log(level, msg, *args, **kwargs) 47 | else: 48 | if rank is None: 49 | self.logger.log(level, msg, *args, **kwargs) 50 | elif current_rank == rank: 51 | self.logger.log(level, msg, *args, **kwargs) 52 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 
38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | queue.append(field) if field in cfg else log.warning( 48 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 49 | ) 50 | 51 | # add all the other fields to queue (not specified in `print_order`) 52 | for field in cfg: 53 | if field not in queue: 54 | queue.append(field) 55 | 56 | # generate config tree from queue 57 | for field in queue: 58 | branch = tree.add(field, style=style, guide_style=style) 59 | 60 | config_group = cfg[field] 61 | if isinstance(config_group, DictConfig): 62 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 63 | else: 64 | branch_content = str(config_group) 65 | 66 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 67 | 68 | # print config tree 69 | rich.print(tree) 70 | 71 | # save config tree to file 72 | if save_to_file: 73 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 74 | rich.print(tree, file=file) 75 | 76 | 77 | @rank_zero_only 78 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 79 | """Prompts user to input tags from command line if no tags are provided in config. 80 | 81 | :param cfg: A DictConfig composed by Hydra. 82 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 83 | """ 84 | if not cfg.get("tags"): 85 | if "id" in HydraConfig().cfg.hydra.job: 86 | raise ValueError("Specify tags before launching a multirun!") 87 | 88 | log.warning("No tags provided in config. 
Prompting user to input tags...") 89 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 90 | tags = [t.strip() for t in tags.split(",") if t != ""] 91 | 92 | with open_dict(cfg): 93 | cfg.tags = tags 94 | 95 | log.info(f"Tags: {cfg.tags}") 96 | 97 | if save_to_file: 98 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 99 | rich.print(cfg.tags, file=file) 100 | -------------------------------------------------------------------------------- /src/utils/spectra/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/src/utils/spectra/__init__.py -------------------------------------------------------------------------------- /src/utils/spectra/dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class SpectraDataset(ABC): 5 | 6 | def __init__(self, input_file, name): 7 | self.input_file = input_file 8 | self.name = name 9 | self.samples = self.parse(input_file) 10 | 11 | @abstractmethod 12 | def sample_to_index(self, idx): 13 | """ 14 | Given a sample, return the data idx 15 | """ 16 | pass 17 | 18 | @abstractmethod 19 | def parse(self, input_file): 20 | """ 21 | Given a dataset file, parse the dataset file. 22 | Make sure there are only unique entries! 
23 | """ 24 | pass 25 | 26 | @abstractmethod 27 | def __len__(self): 28 | """ 29 | Return the length of the dataset 30 | """ 31 | pass 32 | 33 | @abstractmethod 34 | def __getitem__(self, idx): 35 | """ 36 | Given a dataset idx, return the element at that index 37 | """ 38 | pass -------------------------------------------------------------------------------- /src/utils/spectra/get_splits.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import PurePath 3 | import numpy as np 4 | 5 | from src.utils.spectra.perturb import PerturbGraphData, SPECTRAPerturb 6 | 7 | 8 | def spectra(sghv_pert_data, data_path, spectra_params, spectral_parameter): 9 | data_name = PurePath(data_path).parts[-1] 10 | perturb_graph_data = PerturbGraphData(sghv_pert_data, data_name) 11 | 12 | sc_spectra = SPECTRAPerturb(perturb_graph_data, binary=False) 13 | sc_spectra.pre_calculate_spectra_properties(f"{data_path}/{data_name}") 14 | 15 | sparsification_step = spectra_params['sparsification_step'] 16 | sparsification = ["{:.2f}".format(i) for i in np.arange(0, 1.01, float(sparsification_step))] 17 | spectra_params['number_repeats'] = int(spectra_params['number_repeats']) 18 | spectra_params['spectral_parameters'] = sparsification 19 | spectra_params['data_path'] = data_path + "/" 20 | 21 | if not os.path.exists(f"{data_path}/{data_name}_SPECTRA_splits"): 22 | sc_spectra.generate_spectra_splits(**spectra_params) 23 | elif not os.listdir(f"{data_path}/{data_name}_SPECTRA_splits"): 24 | sc_spectra.generate_spectra_splits(**spectra_params) 25 | else: 26 | print("Splits already exist. Proceeding. . 
.") 27 | 28 | sp = spectral_parameter.split('_')[0] 29 | rpt = spectral_parameter.split('_')[1] 30 | train, test = sc_spectra.return_split_samples(sp, rpt, 31 | f"{data_path}/{data_name}") 32 | pert_list = perturb_graph_data.samples 33 | 34 | return train, test, pert_list 35 | -------------------------------------------------------------------------------- /src/utils/spectra/independent_set_algo.py: -------------------------------------------------------------------------------- 1 | import random 2 | import networkx as nx 3 | from .utils import is_clique, connected_components, is_integer 4 | from scipy import stats 5 | import numpy as np 6 | 7 | 8 | def run_independent_set(spectral_parameter, input_G, seed=None, 9 | debug=False, distribution=None, binary=True): 10 | total_deleted = 0 11 | independent_set = [] 12 | 13 | if seed is not None: 14 | random.seed(seed) 15 | 16 | G = input_G.copy() 17 | 18 | if binary: 19 | # First check if any connected component of the graph is a clique, if so, add it as one unit to the independent set 20 | components = list(connected_components(G)) 21 | deleted = 0 22 | for i, component in enumerate(components): 23 | subgraph = G.subgraph(component) 24 | if is_clique(subgraph): 25 | print( 26 | f"Component {i} is too densly connected, adding samples as a single unit to independent set and deleting them from the graph") 27 | independent_set.append(list(subgraph.nodes())) 28 | G.remove_nodes_from(subgraph.nodes()) 29 | else: 30 | for node in list(subgraph.nodes()): 31 | if subgraph.degree(node) == len(subgraph.nodes()) - 1: 32 | deleted += 1 33 | G.remove_node(node) 34 | 35 | print(f"Deleted {deleted} nodes from the graph since they were connected to all other nodes") 36 | 37 | iterations = 0 38 | 39 | while not nx.is_empty(G): 40 | chosen_node = random.sample(list(G.nodes()), 1)[0] 41 | 42 | independent_set.append(chosen_node) 43 | neighbors = G.neighbors(chosen_node) 44 | neighbors_to_delete = [] 45 | 46 | for neighbor in neighbors: 47 
| if not binary: 48 | if spectral_parameter == 1.0: 49 | neighbors_to_delete.append(neighbor) 50 | else: 51 | edge_weight = G[chosen_node][neighbor]['weight'] 52 | if distribution is None: 53 | raise Exception( 54 | "Distribution must be provided if binary is set to False, must precompute similarities") 55 | if random.random() < spectral_parameter and ( 56 | 1 - spectral_parameter) * 100 < stats.percentileofscore(distribution, edge_weight): 57 | neighbors_to_delete.append(neighbor) 58 | else: 59 | if spectral_parameter == 1.0: 60 | neighbors_to_delete.append(neighbor) 61 | elif spectral_parameter != 0.0: 62 | if random.random() < spectral_parameter: 63 | neighbors_to_delete.append(neighbor) 64 | 65 | if debug: 66 | print(f"Iteration {iterations} Stats") 67 | print( 68 | f"Deleted {len(neighbors_to_delete)} nodes from {G.degree(chosen_node)} neighbors of node {chosen_node}") 69 | total_deleted += len(neighbors_to_delete) 70 | 71 | for neighbor in neighbors_to_delete: 72 | G.remove_node(neighbor) 73 | 74 | if chosen_node not in neighbors_to_delete: 75 | G.remove_node(chosen_node) 76 | 77 | iterations += 1 78 | 79 | for node in list(G.nodes()): 80 | # Append the nodes left to G 81 | independent_set.append(node) 82 | 83 | if debug: 84 | print(f"{len(input_G.nodes())} nodes in the original graph") 85 | print(f"Total deleted {total_deleted}") 86 | print(f"{len(independent_set)} nodes in the independent set") 87 | 88 | return independent_set 89 | -------------------------------------------------------------------------------- /src/utils/spectra/perturb.py: -------------------------------------------------------------------------------- 1 | from src.utils.spectra.spectra import Spectra 2 | from src.utils.spectra.dataset import SpectraDataset 3 | from anndata import AnnData 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | #from gears.pertdata import PertData #removed n.b. 
class PerturbGraphData(SpectraDataset):
    """SPECTRA dataset over perturbations: each sample is the mean
    log-fold-change expression profile of one perturbation vs. control."""

    def _mean_expression(self, condition):
        # Dense per-gene mean over all cells carrying the given condition label.
        mask = self.adata.obs['condition'] == condition
        return self.adata[mask].X.toarray().mean(axis=0)

    def parse(self, pert_data):
        """Store the AnnData handle, cache the control expression profile, and
        return the list of non-control perturbation labels."""
        self.adata = pert_data if isinstance(pert_data, AnnData) else pert_data.adata
        self.control_expression = self._mean_expression('ctrl')
        conditions = self.adata.obs['condition'].unique()
        return [c for c in conditions if c != 'ctrl']

    def get_mean_logfold_change(self, perturbation):
        """log2 fold change (with a +1 pseudocount) of a perturbation's mean
        expression vs. control, with NaNs replaced by zeros."""
        pert_expr = self._mean_expression(perturbation)
        delta = np.log2(pert_expr + 1) - np.log2(self.control_expression + 1)
        return np.nan_to_num(delta)

    def sample_to_index(self, sample):
        """Map a '-'-joined profile string back to its dataset index.

        The lookup table is built lazily on first use and cached."""
        if not hasattr(self, 'index_to_sequence'):
            print("Generating index to sequence")
            self.index_to_sequence = {}
            for i in tqdm(range(len(self))):
                profile = self[i]
                self.index_to_sequence['-'.join(list(profile))] = i

        return self.index_to_sequence[sample]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.get_mean_logfold_change(self.samples[idx])
def is_clique(G):
    """Return True when ``G`` is complete, i.e. it contains every one of the
    n*(n-1)/2 possible edges between its n nodes."""
    n = G.order()
    return G.size() == n * (n - 1) / 2
") 45 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 46 | 47 | if cfg.extras.get('distributed_storage'): 48 | # writing data_dir to cache for easy access 49 | if not os.path.exists(f'{os.getcwd()}/cache'): 50 | os.makedirs(f'{os.getcwd()}/cache') 51 | with open(f'{os.getcwd()}/cache/data_dir_cache.txt', 'w') as f: 52 | f.write(cfg.get("data")['data_dir']) 53 | 54 | 55 | def task_wrapper(task_func: Callable) -> Callable: 56 | """Optional decorator that controls the failure behavior when executing the task function. 57 | 58 | This wrapper can be used to: 59 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 60 | - save the exception to a `.log` file 61 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 62 | - etc. (adjust depending on your needs) 63 | 64 | Example: 65 | ``` 66 | @utils.task_wrapper 67 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 68 | ... 69 | return metric_dict, object_dict 70 | ``` 71 | 72 | :param task_func: The task function to be wrapped. 73 | 74 | :return: The wrapped task function. 
75 | """ 76 | 77 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 78 | # execute the task 79 | try: 80 | metric_dict, object_dict = task_func(cfg=cfg) 81 | 82 | # things to do if exception occurs 83 | except Exception as ex: 84 | # save exception to `.log` file 85 | log.exception("") 86 | 87 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 88 | # so when using hparam search plugins like Optuna, you might want to disable 89 | # raising the below exception to avoid multirun failure 90 | raise ex 91 | 92 | # things to always do after either success or exception 93 | finally: 94 | # display output dir path in terminal 95 | log.info(f"Output dir: {cfg.paths.output_dir}") 96 | 97 | # always close wandb run (even if exception occurs so multirun won't fail) 98 | if find_spec("wandb"): # check if wandb is installed 99 | import wandb 100 | 101 | if wandb.run: 102 | log.info("Closing wandb!") 103 | wandb.finish() 104 | 105 | return metric_dict, object_dict 106 | 107 | return wrap 108 | 109 | 110 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: 111 | """Safely retrieves value of the metric logged in LightningModule. 112 | 113 | :param metric_dict: A dict containing metric values. 114 | :param metric_name: If provided, the name of the metric to retrieve. 115 | :return: If a metric name was provided, the value of the metric. 116 | """ 117 | if not metric_name: 118 | log.info("Metric name is None! Skipping metric value retrieval...") 119 | return None 120 | 121 | if metric_name not in metric_dict: 122 | raise Exception( 123 | f"Metric value not found! \n" 124 | "Make sure metric name logged in LightningModule is correct!\n" 125 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 126 | ) 127 | 128 | metric_value = metric_dict[metric_name].item() 129 | log.info(f"Retrieved metric value! 
def dataverse_download(url: str = "", save_path: str = ""):
    """Dataverse download helper with progress bar.

    Skips the download entirely when a local copy already exists at ``save_path``.

    :param url: the url of the dataset
    :param save_path: the path to save the dataset
    :raises requests.HTTPError: if the server responds with an error status.
    """
    if os.path.exists(save_path):
        print('Found local copy...')
        return

    response = requests.get(url, stream=True)
    # Fail fast on HTTP errors instead of silently writing an error page to disk.
    response.raise_for_status()

    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024

    # os.path.dirname is portable; the previous '/'-split broke on Windows and
    # crashed with makedirs('') when save_path had no directory component.
    dir_path = os.path.dirname(save_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    # tqdm as a context manager guarantees the bar is closed even if a write fails.
    with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
        with open(save_path, 'wb') as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
161 | """ 162 | if os.path.exists(zip_path): 163 | print('Found local copy of .zip file...') 164 | else: 165 | dataverse_download(url, zip_path + '.zip') 166 | print('Extracting zip file...') 167 | with ZipFile((zip_path + '.zip'), 'r') as z: 168 | z.extractall(path=os.path.dirname(zip_path)) 169 | os.remove(zip_path + '.zip') 170 | print("Done!") 171 | 172 | 173 | def find_root_dir(current_dir): 174 | while True: 175 | if '.project-root' in os.listdir(current_dir): 176 | return current_dir 177 | else: 178 | current_dir = os.path.dirname(current_dir) 179 | 180 | if current_dir == '/': # if we have reached the root of the filesystem 181 | raise FileNotFoundError("Could not find .project-root file.") 182 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """This file prepares config fixtures for other tests.""" 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | import rootutils 7 | from hydra import compose, initialize 8 | from hydra.core.global_hydra import GlobalHydra 9 | from omegaconf import DictConfig, open_dict 10 | 11 | 12 | @pytest.fixture(scope="package") 13 | def cfg_train_global() -> DictConfig: 14 | """A pytest fixture for setting up a default Hydra DictConfig for training. 15 | 16 | :return: A DictConfig object containing a default Hydra configuration for training. 
17 | """ 18 | with initialize(version_base="1.3", config_path="../configs"): 19 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 20 | 21 | # set defaults for all tests 22 | with open_dict(cfg): 23 | cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root")) 24 | cfg.trainer.max_epochs = 1 25 | cfg.trainer.limit_train_batches = 0.01 26 | cfg.trainer.limit_val_batches = 0.1 27 | cfg.trainer.limit_test_batches = 0.1 28 | cfg.trainer.accelerator = "cpu" 29 | cfg.trainer.devices = 1 30 | cfg.data.num_workers = 0 31 | cfg.data.pin_memory = False 32 | cfg.extras.print_config = False 33 | cfg.extras.enforce_tags = False 34 | cfg.logger = None 35 | 36 | return cfg 37 | 38 | 39 | @pytest.fixture(scope="package") 40 | def cfg_eval_global() -> DictConfig: 41 | """A pytest fixture for setting up a default Hydra DictConfig for evaluation. 42 | 43 | :return: A DictConfig containing a default Hydra configuration for evaluation. 44 | """ 45 | with initialize(version_base="1.3", config_path="../configs"): 46 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) 47 | 48 | # set defaults for all tests 49 | with open_dict(cfg): 50 | cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root")) 51 | cfg.trainer.max_epochs = 1 52 | cfg.trainer.limit_test_batches = 0.1 53 | cfg.trainer.accelerator = "cpu" 54 | cfg.trainer.devices = 1 55 | cfg.data.num_workers = 0 56 | cfg.data.pin_memory = False 57 | cfg.extras.print_config = False 58 | cfg.extras.enforce_tags = False 59 | cfg.logger = None 60 | 61 | return cfg 62 | 63 | 64 | @pytest.fixture(scope="function") 65 | def cfg_train(cfg_train_global: DictConfig, tmp_path: Path) -> DictConfig: 66 | """A pytest fixture built on top of the `cfg_train_global()` fixture, which accepts a temporary 67 | logging path `tmp_path` for generating a temporary logging path. 68 | 69 | This is called by each test which uses the `cfg_train` arg. 
Each test generates its own temporary logging path. 70 | 71 | :param cfg_train_global: The input DictConfig object to be modified. 72 | :param tmp_path: The temporary logging path. 73 | 74 | :return: A DictConfig with updated output and log directories corresponding to `tmp_path`. 75 | """ 76 | cfg = cfg_train_global.copy() 77 | 78 | with open_dict(cfg): 79 | cfg.paths.output_dir = str(tmp_path) 80 | cfg.paths.log_dir = str(tmp_path) 81 | 82 | yield cfg 83 | 84 | GlobalHydra.instance().clear() 85 | 86 | 87 | @pytest.fixture(scope="function") 88 | def cfg_eval(cfg_eval_global: DictConfig, tmp_path: Path) -> DictConfig: 89 | """A pytest fixture built on top of the `cfg_eval_global()` fixture, which accepts a temporary 90 | logging path `tmp_path` for generating a temporary logging path. 91 | 92 | This is called by each test which uses the `cfg_eval` arg. Each test generates its own temporary logging path. 93 | 94 | :param cfg_train_global: The input DictConfig object to be modified. 95 | :param tmp_path: The temporary logging path. 96 | 97 | :return: A DictConfig with updated output and log directories corresponding to `tmp_path`. 
98 | """ 99 | cfg = cfg_eval_global.copy() 100 | 101 | with open_dict(cfg): 102 | cfg.paths.output_dir = str(tmp_path) 103 | cfg.paths.log_dir = str(tmp_path) 104 | 105 | yield cfg 106 | 107 | GlobalHydra.instance().clear() 108 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronwtr/PertEval/efbfa51991fbd6faa6039619d754e354be40fc07/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from lightning.fabric.accelerators import TPUAccelerator 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment. 9 | 10 | :param package_name: The name of the package to be checked. 11 | 12 | :return: `True` if the package is available. `False` otherwise. 
13 | """ 14 | try: 15 | return pkg_resources.require(package_name) is not None 16 | except pkg_resources.DistributionNotFound: 17 | return False 18 | 19 | 20 | _TPU_AVAILABLE = TPUAccelerator.is_available() 21 | 22 | _IS_WINDOWS = platform.system() == "Windows" 23 | 24 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 25 | 26 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 27 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 28 | 29 | _WANDB_AVAILABLE = _package_available("wandb") 30 | _NEPTUNE_AVAILABLE = _package_available("neptune") 31 | _COMET_AVAILABLE = _package_available("comet_ml") 32 | _MLFLOW_AVAILABLE = _package_available("mlflow") 33 | -------------------------------------------------------------------------------- /tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | 3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 4 | """ 5 | 6 | import sys 7 | from typing import Any, Dict, Optional 8 | 9 | import pytest 10 | import torch 11 | from packaging.version import Version 12 | from pkg_resources import get_distribution 13 | from pytest import MarkDecorator 14 | 15 | from tests.helpers.package_available import ( 16 | _COMET_AVAILABLE, 17 | _DEEPSPEED_AVAILABLE, 18 | _FAIRSCALE_AVAILABLE, 19 | _IS_WINDOWS, 20 | _MLFLOW_AVAILABLE, 21 | _NEPTUNE_AVAILABLE, 22 | _SH_AVAILABLE, 23 | _TPU_AVAILABLE, 24 | _WANDB_AVAILABLE, 25 | ) 26 | 27 | 28 | class RunIf: 29 | """RunIf wrapper for conditional skipping of tests. 30 | 31 | Fully compatible with `@pytest.mark`. 
32 | 33 | Example: 34 | 35 | ```python 36 | @RunIf(min_torch="1.8") 37 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 38 | def test_wrapper(arg1): 39 | assert arg1 > 0 40 | ``` 41 | """ 42 | 43 | def __new__( 44 | cls, 45 | min_gpus: int = 0, 46 | min_torch: Optional[str] = None, 47 | max_torch: Optional[str] = None, 48 | min_python: Optional[str] = None, 49 | skip_windows: bool = False, 50 | sh: bool = False, 51 | tpu: bool = False, 52 | fairscale: bool = False, 53 | deepspeed: bool = False, 54 | wandb: bool = False, 55 | neptune: bool = False, 56 | comet: bool = False, 57 | mlflow: bool = False, 58 | **kwargs: Dict[Any, Any], 59 | ) -> MarkDecorator: 60 | """Creates a new `@RunIf` `MarkDecorator` decorator. 61 | 62 | :param min_gpus: Min number of GPUs required to run test. 63 | :param min_torch: Minimum pytorch version to run test. 64 | :param max_torch: Maximum pytorch version to run test. 65 | :param min_python: Minimum python version required to run test. 66 | :param skip_windows: Skip test for Windows platform. 67 | :param tpu: If TPU is available. 68 | :param sh: If `sh` module is required to run the test. 69 | :param fairscale: If `fairscale` module is required to run the test. 70 | :param deepspeed: If `deepspeed` module is required to run the test. 71 | :param wandb: If `wandb` module is required to run the test. 72 | :param neptune: If `neptune` module is required to run the test. 73 | :param comet: If `comet` module is required to run the test. 74 | :param mlflow: If `mlflow` module is required to run the test. 75 | :param kwargs: Native `pytest.mark.skipif` keyword arguments. 
76 | """ 77 | conditions = [] 78 | reasons = [] 79 | 80 | if min_gpus: 81 | conditions.append(torch.cuda.device_count() < min_gpus) 82 | reasons.append(f"GPUs>={min_gpus}") 83 | 84 | if min_torch: 85 | torch_version = get_distribution("torch").version 86 | conditions.append(Version(torch_version) < Version(min_torch)) 87 | reasons.append(f"torch>={min_torch}") 88 | 89 | if max_torch: 90 | torch_version = get_distribution("torch").version 91 | conditions.append(Version(torch_version) >= Version(max_torch)) 92 | reasons.append(f"torch<{max_torch}") 93 | 94 | if min_python: 95 | py_version = ( 96 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 97 | ) 98 | conditions.append(Version(py_version) < Version(min_python)) 99 | reasons.append(f"python>={min_python}") 100 | 101 | if skip_windows: 102 | conditions.append(_IS_WINDOWS) 103 | reasons.append("does not run on Windows") 104 | 105 | if tpu: 106 | conditions.append(not _TPU_AVAILABLE) 107 | reasons.append("TPU") 108 | 109 | if sh: 110 | conditions.append(not _SH_AVAILABLE) 111 | reasons.append("sh") 112 | 113 | if fairscale: 114 | conditions.append(not _FAIRSCALE_AVAILABLE) 115 | reasons.append("fairscale") 116 | 117 | if deepspeed: 118 | conditions.append(not _DEEPSPEED_AVAILABLE) 119 | reasons.append("deepspeed") 120 | 121 | if wandb: 122 | conditions.append(not _WANDB_AVAILABLE) 123 | reasons.append("wandb") 124 | 125 | if neptune: 126 | conditions.append(not _NEPTUNE_AVAILABLE) 127 | reasons.append("neptune") 128 | 129 | if comet: 130 | conditions.append(not _COMET_AVAILABLE) 131 | reasons.append("comet") 132 | 133 | if mlflow: 134 | conditions.append(not _MLFLOW_AVAILABLE) 135 | reasons.append("mlflow") 136 | 137 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 138 | return pytest.mark.skipif( 139 | condition=any(conditions), 140 | reason=f"Requires: [{' + '.join(reasons)}]", 141 | **kwargs, 142 | ) 143 | 
def run_sh_command(command: List[str]) -> None:
    """Default method for executing shell commands with `pytest` and `sh` package.

    Fails the calling test whenever the command exits with a non-zero status.

    :param command: A list of shell command arguments, passed to ``sh.python``.
    """
    try:
        sh.python(command)
    except sh.ErrorReturnCode as e:
        # The previous implementation only failed when stderr was non-empty, so
        # a failing command that printed nothing would silently pass the test.
        msg = e.stderr.decode()
        pytest.fail(msg=msg or f"Command {command} failed with exit code {e.exit_code}.")
27 | """ 28 | assert cfg_eval 29 | assert cfg_eval.data 30 | assert cfg_eval.model 31 | assert cfg_eval.trainer 32 | 33 | HydraConfig().set_config(cfg_eval) 34 | 35 | hydra.utils.instantiate(cfg_eval.data) 36 | hydra.utils.instantiate(cfg_eval.model) 37 | hydra.utils.instantiate(cfg_eval.trainer) 38 | -------------------------------------------------------------------------------- /tests/test_datamodules.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from src.data.mnist_datamodule import MNISTDataModule 7 | 8 | 9 | @pytest.mark.parametrize("batch_size", [32, 128]) 10 | def test_mnist_datamodule(batch_size: int) -> None: 11 | """Tests `MNISTDataModule` to verify that it can be downloaded correctly, that the necessary 12 | attributes were created (e.g., the dataloader objects), and that dtypes and batch sizes 13 | correctly match. 14 | 15 | :param batch_size: Batch size of the data to be loaded by the dataloader. 
16 | """ 17 | data_dir = "data/" 18 | 19 | dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) 20 | dm.prepare_data() 21 | 22 | assert not dm.data_train and not dm.data_val and not dm.data_test 23 | assert Path(data_dir, "MNIST").exists() 24 | assert Path(data_dir, "MNIST", "raw").exists() 25 | 26 | dm.setup() 27 | assert dm.data_train and dm.data_val and dm.data_test 28 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() 29 | 30 | num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) 31 | assert num_datapoints == 70_000 32 | 33 | batch = next(iter(dm.train_dataloader())) 34 | x, y = batch 35 | assert len(x) == batch_size 36 | assert len(y) == batch_size 37 | assert x.dtype == torch.float32 38 | assert y.dtype == torch.int64 39 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | from hydra.core.hydra_config import HydraConfig 6 | from omegaconf import DictConfig, open_dict 7 | 8 | from src.eval import evaluate 9 | from src.train import train 10 | 11 | 12 | @pytest.mark.slow 13 | def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None: 14 | """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with 15 | `eval.py`. 16 | 17 | :param tmp_path: The temporary logging path. 18 | :param cfg_train: A DictConfig containing a valid training configuration. 19 | :param cfg_eval: A DictConfig containing a valid evaluation configuration. 
20 | """ 21 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir 22 | 23 | with open_dict(cfg_train): 24 | cfg_train.trainer.max_epochs = 1 25 | cfg_train.test = True 26 | 27 | HydraConfig().set_config(cfg_train) 28 | train_metric_dict, _ = train(cfg_train) 29 | 30 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") 31 | 32 | with open_dict(cfg_eval): 33 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 34 | 35 | HydraConfig().set_config(cfg_eval) 36 | test_metric_dict, _ = evaluate(cfg_eval) 37 | 38 | assert test_metric_dict["test/acc"] > 0.0 39 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001 40 | -------------------------------------------------------------------------------- /tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from tests.helpers.run_if import RunIf 6 | from tests.helpers.run_sh_command import run_sh_command 7 | 8 | startfile = "src/train.py" 9 | overrides = ["logger=[]"] 10 | 11 | 12 | @RunIf(sh=True) 13 | @pytest.mark.slow 14 | def test_experiments(tmp_path: Path) -> None: 15 | """Test running all available experiment configs with `fast_dev_run=True.` 16 | 17 | :param tmp_path: The temporary logging path. 18 | """ 19 | command = [ 20 | startfile, 21 | "-m", 22 | "experiment=glob(*)", 23 | "hydra.sweep.dir=" + str(tmp_path), 24 | "++trainer.fast_dev_run=true", 25 | ] + overrides 26 | run_sh_command(command) 27 | 28 | 29 | @RunIf(sh=True) 30 | @pytest.mark.slow 31 | def test_hydra_sweep(tmp_path: Path) -> None: 32 | """Test default hydra sweep. 33 | 34 | :param tmp_path: The temporary logging path. 
35 | """ 36 | command = [ 37 | startfile, 38 | "-m", 39 | "hydra.sweep.dir=" + str(tmp_path), 40 | "model.optimizer.lr=0.005,0.01", 41 | "++trainer.fast_dev_run=true", 42 | ] + overrides 43 | 44 | run_sh_command(command) 45 | 46 | 47 | @RunIf(sh=True) 48 | @pytest.mark.slow 49 | def test_hydra_sweep_ddp_sim(tmp_path: Path) -> None: 50 | """Test default hydra sweep with ddp sim. 51 | 52 | :param tmp_path: The temporary logging path. 53 | """ 54 | command = [ 55 | startfile, 56 | "-m", 57 | "hydra.sweep.dir=" + str(tmp_path), 58 | "trainer=ddp_sim", 59 | "trainer.max_epochs=3", 60 | "+trainer.limit_train_batches=0.01", 61 | "+trainer.limit_val_batches=0.1", 62 | "+trainer.limit_test_batches=0.1", 63 | "model.optimizer.lr=0.005,0.01,0.02", 64 | ] + overrides 65 | run_sh_command(command) 66 | 67 | 68 | @RunIf(sh=True) 69 | @pytest.mark.slow 70 | def test_optuna_sweep(tmp_path: Path) -> None: 71 | """Test Optuna hyperparam sweeping. 72 | 73 | :param tmp_path: The temporary logging path. 74 | """ 75 | command = [ 76 | startfile, 77 | "-m", 78 | "hparams_search=mnist_optuna", 79 | "hydra.sweep.dir=" + str(tmp_path), 80 | "hydra.sweeper.n_trials=10", 81 | "hydra.sweeper.sampler.n_startup_trials=5", 82 | "++trainer.fast_dev_run=true", 83 | ] + overrides 84 | run_sh_command(command) 85 | 86 | 87 | @RunIf(wandb=True, sh=True) 88 | @pytest.mark.slow 89 | def test_optuna_sweep_ddp_sim_wandb(tmp_path: Path) -> None: 90 | """Test Optuna sweep with wandb logging and ddp sim. 91 | 92 | :param tmp_path: The temporary logging path. 
93 | """ 94 | command = [ 95 | startfile, 96 | "-m", 97 | "hparams_search=mnist_optuna", 98 | "hydra.sweep.dir=" + str(tmp_path), 99 | "hydra.sweeper.n_trials=5", 100 | "trainer=ddp_sim", 101 | "trainer.max_epochs=3", 102 | "+trainer.limit_train_batches=0.01", 103 | "+trainer.limit_val_batches=0.1", 104 | "+trainer.limit_test_batches=0.1", 105 | "logger=wandb", 106 | ] 107 | run_sh_command(command) 108 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | from hydra.core.hydra_config import HydraConfig 6 | from omegaconf import DictConfig, open_dict 7 | 8 | from src.train import train 9 | from tests.helpers.run_if import RunIf 10 | 11 | 12 | def test_train_fast_dev_run(cfg_train: DictConfig) -> None: 13 | """Run for 1 train, val and test step. 14 | 15 | :param cfg_train: A DictConfig containing a valid training configuration. 16 | """ 17 | HydraConfig().set_config(cfg_train) 18 | with open_dict(cfg_train): 19 | cfg_train.trainer.fast_dev_run = True 20 | cfg_train.trainer.accelerator = "cpu" 21 | train(cfg_train) 22 | 23 | 24 | @RunIf(min_gpus=1) 25 | def test_train_fast_dev_run_gpu(cfg_train: DictConfig) -> None: 26 | """Run for 1 train, val and test step on GPU. 27 | 28 | :param cfg_train: A DictConfig containing a valid training configuration. 29 | """ 30 | HydraConfig().set_config(cfg_train) 31 | with open_dict(cfg_train): 32 | cfg_train.trainer.fast_dev_run = True 33 | cfg_train.trainer.accelerator = "gpu" 34 | train(cfg_train) 35 | 36 | 37 | @RunIf(min_gpus=1) 38 | @pytest.mark.slow 39 | def test_train_epoch_gpu_amp(cfg_train: DictConfig) -> None: 40 | """Train 1 epoch on GPU with mixed-precision. 41 | 42 | :param cfg_train: A DictConfig containing a valid training configuration. 
43 | """ 44 | HydraConfig().set_config(cfg_train) 45 | with open_dict(cfg_train): 46 | cfg_train.trainer.max_epochs = 1 47 | cfg_train.trainer.accelerator = "gpu" 48 | cfg_train.trainer.precision = 16 49 | train(cfg_train) 50 | 51 | 52 | @pytest.mark.slow 53 | def test_train_epoch_double_val_loop(cfg_train: DictConfig) -> None: 54 | """Train 1 epoch with validation loop twice per epoch. 55 | 56 | :param cfg_train: A DictConfig containing a valid training configuration. 57 | """ 58 | HydraConfig().set_config(cfg_train) 59 | with open_dict(cfg_train): 60 | cfg_train.trainer.max_epochs = 1 61 | cfg_train.trainer.val_check_interval = 0.5 62 | train(cfg_train) 63 | 64 | 65 | @pytest.mark.slow 66 | def test_train_ddp_sim(cfg_train: DictConfig) -> None: 67 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes. 68 | 69 | :param cfg_train: A DictConfig containing a valid training configuration. 70 | """ 71 | HydraConfig().set_config(cfg_train) 72 | with open_dict(cfg_train): 73 | cfg_train.trainer.max_epochs = 2 74 | cfg_train.trainer.accelerator = "cpu" 75 | cfg_train.trainer.devices = 2 76 | cfg_train.trainer.strategy = "ddp_spawn" 77 | train(cfg_train) 78 | 79 | 80 | @pytest.mark.slow 81 | def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None: 82 | """Run 1 epoch, finish, and resume for another epoch. 83 | 84 | :param tmp_path: The temporary logging path. 85 | :param cfg_train: A DictConfig containing a valid training configuration. 
86 | """ 87 | with open_dict(cfg_train): 88 | cfg_train.trainer.max_epochs = 1 89 | 90 | HydraConfig().set_config(cfg_train) 91 | metric_dict_1, _ = train(cfg_train) 92 | 93 | files = os.listdir(tmp_path / "checkpoints") 94 | assert "last.ckpt" in files 95 | assert "epoch_000.ckpt" in files 96 | 97 | with open_dict(cfg_train): 98 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 99 | cfg_train.trainer.max_epochs = 2 100 | 101 | metric_dict_2, _ = train(cfg_train) 102 | 103 | files = os.listdir(tmp_path / "checkpoints") 104 | assert "epoch_001.ckpt" in files 105 | assert "epoch_002.ckpt" not in files 106 | 107 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] 108 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"] 109 | --------------------------------------------------------------------------------