├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── Dockerfile ├── LICENSE ├── README.md ├── chemCPA ├── __init__.py ├── data │ ├── __init__.py │ ├── data.py │ ├── dataset │ │ ├── __init__.py │ │ ├── compute_degs.py │ │ ├── dataset.py │ │ ├── drug_names_to_once_canon_smiles.py │ │ └── subdataset.py │ └── perturbation_data_module.py ├── embedding.py ├── helper.py ├── lightning_module.py ├── model.py ├── paths.py ├── profiling.py ├── train.py ├── train_hydra.py └── train_hydra_tmux.sh ├── config ├── README.md ├── dataset │ ├── biolord.yaml │ ├── biolord_split_30.yaml │ ├── combinatorial.yaml │ ├── default.yaml │ ├── lincs.yaml │ ├── lincs_2000_genes.yaml │ ├── sciplex.yaml │ └── sciplex_2000_genes.yaml ├── finetune.yaml ├── finetune_2000_genes.yaml ├── finetune_combinatorial.yaml ├── hydra │ └── default.yaml ├── lincs.yaml ├── lincs_2000_genes.yaml ├── main.yaml ├── model │ ├── .pretrain_combinatorial.yaml.swp │ ├── additional_params │ │ └── default.yaml │ ├── combinatorial_rdkit.yaml │ ├── default.yaml │ ├── embedding │ │ ├── biolord_split_30.yaml │ │ ├── combinatorial_rdkit.yaml │ │ ├── default.yaml │ │ ├── lincs.yaml │ │ ├── sciplex_lincs_genes.yaml │ │ └── sciplex_middle.yaml │ ├── finetune.yaml │ ├── finetune_2000_genes.yaml │ ├── finetune_combinatorial.yaml │ ├── finetune_grover.yaml │ ├── hparams │ │ └── default.yaml │ ├── lincs.yaml │ ├── pretrain_combinatorial.yaml │ └── sciplex.yaml ├── pretrain_combinatorial.yaml ├── sciplex.yaml ├── training │ └── default.yaml └── wandb │ └── default.yaml ├── docker_entrypoint.sh ├── docs └── chemCPA.png ├── download_training_output.sh ├── embeddings ├── chemvae │ ├── README.md │ ├── generate_embeddings.ipynb │ ├── generate_embeddings.py │ ├── train.py │ └── train_chemvae.sh ├── dgl │ ├── embedding_pretrained_gnn.ipynb │ └── embedding_pretrained_gnn.py ├── grover │ ├── README.md │ ├── data │ │ ├── embeddings │ │ │ └── .gitkeep │ │ └── model │ │ │ └── .gitkeep │ ├── environment.yml │ ├── generate_embeddings.ipynb │ ├── generate_embeddings.py │ ├── grover │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── dist_sampler.py │ │ │ ├── groverdataset.py │ │ │ ├── moldataset.py │ │ │ ├── molfeaturegenerator.py │ │ │ ├── molgraph.py │ │ │ ├── scaler.py │ │ │ ├── task_labels.py │ │ │ └── torchvocab.py │ │ ├── model │ │ │ ├── layers.py │ │ │ └── models.py │ │ └── util │ │ │ ├── metrics.py │ │ │ ├── multi_gpu_wrapper.py │ │ │ ├── nn_utils.py │ │ │ ├── parsing.py │ │ │ ├── scheduler.py │ │ │ └── utils.py │ ├── main.py │ ├── requirements.txt │ ├── scripts │ │ ├── __init__.py │ │ ├── build_vocab.py │ │ ├── save_features.py │ │ └── split_data.py │ └── task │ │ ├── __init__.py │ │ ├── cross_validate.py │ │ ├── fingerprint.py │ │ ├── grovertrainer.py │ │ ├── predict.py │ │ ├── pretrain.py │ │ ├── run_evaluation.py │ │ └── train.py ├── jtvae │ ├── README.md │ ├── analyze_smiles.ipynb │ ├── analyze_smiles.py │ ├── environment.yml │ ├── generate_embeddings.ipynb │ ├── generate_embeddings.py │ ├── jtvae_train_all.yaml │ ├── jtvae_vaetrain_all.yaml │ ├── pretrain.py │ ├── reconstruct.py │ ├── seml_train.py │ ├── utils.py │ └── vaetrain.py ├── lincs_drugs_smiles.csv ├── lincs_trapnell.smiles ├── rdkit │ ├── __init__.py │ ├── embedding_rdkit.ipynb │ └── embedding_rdkit.py ├── seq2seq │ ├── README.md │ ├── environment.yml │ ├── generate_embeddings.ipynb │ ├── generate_embeddings.py │ ├── slurm_train.sh │ └── train_model.py ├── trapnell_drugs_smiles.csv ├── zinc_smiles_test.txt └── zinc_smiles_train.csv ├── environment.yaml ├── environment.yml ├── experiments ├── 
README.md ├── baseline_comparison │ ├── baseline_experiment.ipynb │ ├── baseline_experiment.py │ ├── baseline_experiment.yaml │ ├── baseline_experiment_high_dose.ipynb │ ├── baseline_experiment_high_dose.py │ ├── baseline_experiment_highest_dose.yaml │ └── results.md ├── dom_experiments │ ├── analyse_single_config_biolord.ipynb │ ├── analyse_single_config_biolord.py │ ├── analyse_single_config_sciplex.ipynb │ ├── analyse_single_config_sciplex.py │ ├── analyze_biolord_runs.ipynb │ ├── analyze_biolord_runs.py │ ├── analyze_sciplex_runs.ipynb │ ├── analyze_sciplex_runs.py │ ├── combine_adata_biolord.ipynb │ ├── combine_adata_biolord.py │ ├── compute_embedding_rdkit.ipynb │ ├── compute_embedding_rdkit.py │ ├── config_biolord.yaml │ ├── config_sciplex.yaml │ └── utils.py ├── finetuning_num_genes │ ├── README.md │ ├── analyze_sciplex_finetune_num_genes.ipynb │ ├── analyze_sciplex_finetune_num_genes.py │ └── config_sciplex_finetune_num_genes.yaml ├── fold_change │ ├── fold_experiment.ipynb │ ├── fold_experiment.py │ ├── fold_experiment.yaml │ ├── fold_experiment_highest_dose.ipynb │ ├── fold_experiment_highest_dose.py │ └── fold_experiment_highest_dose.yaml ├── lincs_rdkit_hparam │ ├── README.md │ ├── analyze_lincs_all_embeddings_hparam.ipynb │ ├── analyze_lincs_all_embeddings_hparam.py │ ├── analyze_lincs_rdkit_hparam.ipynb │ ├── analyze_lincs_rdkit_hparam.py │ ├── config_lincs_all_embbeddings_hparam_sweep.yaml │ └── config_lincs_rdkit_hparam_sweep.yaml └── sciplex_hparam │ ├── README.md │ ├── analyze_sciplex_finetuning_hparam.ipynb │ ├── analyze_sciplex_finetuning_hparam.py │ ├── analyze_sciplex_rdkit_hparam.ipynb │ ├── analyze_sciplex_rdkit_hparam.py │ ├── config_sciplex_finetuning_hparam.yaml │ └── config_sciplex_rdkit_hparam.yaml ├── load_lightning.ipynb ├── load_lightning.py ├── manual_run.yaml ├── manual_seml_sweep.py ├── notebooks ├── Additional │ ├── analyse_baseline.ipynb │ ├── analyse_baseline.py │ ├── analysis_results.md │ ├── experiment_analysis.ipynb │ ├── experiment_analysis.py │ ├── sciplex_cpa.ipynb │ ├── sciplex_cpa.py │ ├── sciplex_scgen.ipynb │ ├── sciplex_scgen.py │ └── utils.py ├── README.md ├── chemCPA_Figure_2_grover.ipynb ├── chemCPA_Figure_2_grover.py ├── chemCPA_Figure_2_jtvae.ipynb ├── chemCPA_Figure_2_jtvae.py ├── chemCPA_Figure_2_rdkit.ipynb ├── chemCPA_Figure_2_rdkit.py ├── chemCPA_Figure_3.ipynb ├── chemCPA_Figure_3.py ├── chemCPA_Figure_4_grover.ipynb ├── chemCPA_Figure_4_grover.py ├── chemCPA_Figure_4_jtvae.ipynb ├── chemCPA_Figure_4_jtvae.py ├── chemCPA_Figure_4_rdkit.ipynb ├── chemCPA_Figure_4_rdkit.py ├── chemCPA_Table_2.ipynb ├── chemCPA_Table_2.py ├── chemCPA_Table_3.ipynb ├── chemCPA_Table_3.py ├── chemCPA_Table_4.ipynb ├── chemCPA_Table_4.py ├── finetuning_num_genes.json └── utils.py ├── preprocessing ├── 1_lincs.ipynb ├── 1_lincs.py ├── 2_lincs_SMILES.ipynb ├── 2_lincs_SMILES.py ├── 3_lincs_sciplex_comb.py ├── 3_lincs_sciplex_gene_matching.ipynb ├── 3_lincs_sciplex_gene_matching.py ├── 4_sciplex_SMILES.ipynb ├── 4_sciplex_SMILES.py ├── 5_sciplex_ood_splits.ipynb ├── 5_sciplex_ood_splits.py ├── 6_baseline_sciplex_dataset.ipynb ├── 6_baseline_sciplex_dataset.py ├── 7_compute_embeddings.ipynb ├── 7_compute_embeddings.py ├── README.md ├── analysis_smiles_lincs_trapnell.ipynb ├── analysis_smiles_lincs_trapnell.py ├── convert_notebooks.sh ├── drug_dict.json ├── notebook_utils.py ├── run_notebooks.py └── supress_output.py ├── project_folder ├── pyproject.toml ├── raw_data ├── __init__.py ├── datasets.py ├── download_data.py └── download_utils.py ├── setup.py ├── 
test_config.yaml ├── test_config_biolord.yaml └── tests ├── test_dataset.py ├── test_dosers.py └── test_embedding.py /.gitignore: -------------------------------------------------------------------------------- 1 | #Individual 2 | notebooks/trapnell_2.ipynb 3 | .idea/ 4 | .vscode/ 5 | kang_preprocessed* 6 | model_seed* 7 | sweeps/logs/* 8 | datasets/* 9 | pretrained_models/* 10 | git-lfs-folders/* 11 | cpa_binaries.tar 12 | notebooks/*.pth 13 | embeddings/*/data/* 14 | outputs/* 15 | lightning_logs/* 16 | wandb/* 17 | # embeddings 18 | *.parquet 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # macos 151 | .DS_Store 152 | 153 | # VSCode 154 | settings.json 155 | scripts/*.yaml 156 | embeddings_dir 157 | training_output 158 | training_output.zip 159 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.10.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | args: ["--line-length", "120"] 8 | 9 | - repo: https://github.com/pycqa/isort 10 | rev: 5.11.5 11 | hooks: 12 | - id: isort 13 | args: ["--profile", "black", "--filter-files"] 14 | verbose: true 15 | 16 | - repo: https://github.com/mwouts/jupytext 17 | rev: v1.14.1 18 | hooks: 19 | - id: jupytext 20 | args: 21 | - --from=ipynb 22 | - --to=py:percent 23 | - --pipe 24 | - black --line-length 120 - 25 | - --pipe 26 | - isort --profile black --filter-files - 27 | - --opt=notebook_metadata_filter=-kernelspec 28 | additional_dependencies: 29 | - black==22.10.0 # Matches hook 30 | - isort==5.11.5 31 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: >- 6 | Predicting Cellular Responses to Novel Drug 7 | Perturbations at a Single-Cell Resolution 8 | message: >- 9 | If you use this software, please cite it using the 10 | metadata from this file. 11 | type: software 12 | authors: 13 | - given-names: Leon 14 | family-names: Hetzel 15 | - given-names: Simon 16 | family-names: Boehm 17 | - given-names: Niki 18 | family-names: Kilbertus 19 | - given-names: Stephan 20 | family-names: Günnemann 21 | - given-names: Mohammad 22 | family-names: Lotfollahi 23 | - given-names: Fabian 24 | name-particle: J 25 | family-names: Theis 26 | identifiers: 27 | - type: url 28 | value: 'https://neurips.cc/virtual/2022/poster/53227' 29 | repository-code: 'https://github.com/theislab/chemCPA' 30 | abstract: >+ 31 | Single-cell transcriptomics enabled the study of 32 | cellular heterogeneity in response to perturbations 33 | at the resolution of individual cells. However, 34 | scaling high-throughput screens (HTSs) to measure 35 | cellular responses for many drugs remains a 36 | challenge due to technical limitations and, more 37 | importantly, the cost of such multiplexed 38 | experiments. Thus, transferring information from 39 | routinely performed bulk RNA HTS is required to 40 | enrich single-cell data meaningfully.We introduce 41 | chemCPA, a new encoder-decoder architecture to 42 | study the perturbational effects of unseen drugs. 
43 | We combine the model with an architecture surgery 44 | for transfer learning and demonstrate how training 45 | on existing bulk RNA HTS datasets can improve 46 | generalisation performance. Better generalisation 47 | reduces the need for extensive and costly screens 48 | at single-cell resolution. We envision that our 49 | proposed method will facilitate more efficient 50 | experiment designs through its ability to generate 51 | in-silico hypotheses, ultimately accelerating drug 52 | discovery. 53 | 54 | keywords: 55 | - transfer learning 56 | - disentanglement 57 | - perturbation 58 | - single cell 59 | - genomics 60 | - Drug Discovery 61 | - unsupervised 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution 2 | 3 | Code accompanying the [NeurIPS 2022 paper](https://neurips.cc/virtual/2022/poster/53227) ([PDF](https://openreview.net/pdf?id=vRrFVHxFiXJ)). 4 | 5 | ![architecture of CCPA](docs/chemCPA.png) 6 | 7 | Our talk on chemCPA at the M2D2 reading club is available [here](https://m2d2.io/talks/m2d2/predicting-single-cell-perturbation-responses-for-unseen-drugs/). 8 | A [previous version](https://arxiv.org/abs/2204.13545) of this work was a spotlight paper at ICLR MLDD 2022. 9 | Code for this previous version can be found under the `v1.0` git tag. 10 | 11 | ## Codebase overview 12 | 13 | - `chemCPA/`: contains the code for the model, the data, and the training loop. 14 | - `embeddings`: There is one folder for each molecular embedding model we benchmarked. Each contains an `environment.yml` with dependencies. We generated the embeddings using the provided notebooks and saved them to disk, to load them during the main training loop. 15 | - `experiments`: Each folder contains a `README.md` with the experiment description, a `.yaml` file with the seml configuration, and a notebook to analyze the results. 16 | - `notebooks`: Example analysis notebooks. 17 | - `preprocessing`: Notebooks for processing the data. For each dataset there is one notebook that loads the raw data. 18 | - `tests`: A few very basic tests. 
19 | 
20 | All experiments were run through [seml](https://github.com/TUM-DAML/seml). 
21 | The entry function is `ExperimentWrapper.__init__` in `chemCPA/seml_sweep_icb.py`. 
22 | For convenience, we provide a script to run experiments manually for debugging purposes at `chemCPA/manual_seml_sweep.py`. 
23 | The script expects a `manual_run.yaml` file containing the experiment configuration. 
24 | 
25 | All notebooks also exist as Python scripts (converted through [jupytext](https://github.com/mwouts/jupytext)) to make them easier to review. 
26 | 
27 | ## Getting started 
28 | 
29 | #### Environment 
30 | The easiest way to get started is to use the Docker image we provide: 
31 | ``` 
32 | docker run -it -p 8888:8888 --platform=linux/amd64 registry.hf.space/b1ro-chemcpa:latest 
33 | ``` 
34 | This image contains the source code and all dependencies needed to run the experiments. 
35 | By default, it runs a Jupyter server on port 8888. 
36 | 
37 | Alternatively, you may clone this repository and set up your own environment by running: 
38 | 
39 | ``` 
40 | conda env create -f environment.yml 
41 | pip install -e . 
42 | ``` 
43 | 
44 | 
45 | 
46 | #### Datasets 
47 | The datasets are not included in the Docker image, but they are downloaded automatically when you run the notebooks that require them. Alternatively, the datasets may be downloaded manually using the Python tool `raw_data/dataset.py`. Usage is: 
48 | ``` 
49 | python raw_data/dataset.py --list 
50 | python raw_data/dataset.py --dataset <dataset_name> 
51 | ``` 
52 | 
53 | or you may use the following links: 
54 | - [weight checkpoints](https://f003.backblazeb2.com/file/chemCPA-models/chemCPA_models.zip) 
55 | - [hyperparameter configuration](https://f003.backblazeb2.com/file/chemCPA-models/finetuning_num_genes.json) 
56 | - [raw datasets](https://dl.fbaipublicfiles.com/dlp/cpa_binaries.tar) 
57 | - [processed datasets](https://f003.backblazeb2.com/file/chemCPA-datasets/) 
58 | - [embeddings](https://drive.google.com/drive/folders/1KzkhYptcW3uT3j4GQpDdAC1DXEuXe49J?usp=share_link) 
59 | 
60 | Some of the notebooks use a *drugbank_all.csv* file, which can be downloaded from [here](https://go.drugbank.com/) (registration needed). 
61 | 
62 | #### Data preparation 
63 | To train the models, the raw data first needs to be processed. 
64 | This can be done by running the notebooks inside the `preprocessing/` folder in sequential order. 
65 | Alternatively, you may run 
66 | 
67 | ``` 
68 | python preprocessing/run_notebooks.py 
69 | ``` 
70 | A description of the preprocessing steps is given in the `preprocessing/README.md` file and in the headers 
71 | of individual notebooks. Section 4 of the paper is also highly relevant.
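Before moving on to training, it can be useful to check that preprocessing produced the processed `.h5ad` datasets and the drug-embedding `.parquet` files that the Hydra configs under `config/dataset/` and `config/model/embedding/` point to. The snippet below is only an illustrative sketch: the paths are taken from the LINCS default configs, and the assumption that the parquet file holds one row per drug is ours.

```python
import pandas as pd
import scanpy as sc

# Processed dataset referenced by config/dataset/lincs.yaml
adata = sc.read_h5ad("project_folder/datasets/lincs_full_smiles.h5ad")
print(adata)  # AnnData object; obs should include e.g. `condition`, `cell_type`, `canonical_smiles`, `split`

# Drug embedding referenced by config/model/embedding/lincs.yaml
emb = pd.read_parquet("project_folder/embeddings/rdkit/lincs_full_smiles_rdkit2D_embedding.parquet")
print(emb.shape)  # assumed layout: one row per drug, one column per embedding dimension
```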
72 | 73 | #### Training the models 74 | Run 75 | ``` 76 | python chemCPA/train_hydra.py 77 | ``` 78 | 79 | ## Citation 80 | 81 | You can cite our work as: 82 | 83 | ``` 84 | @inproceedings{hetzel2022predicting, 85 | title={Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution}, 86 | author={Hetzel, Leon and Böhm, Simon and Kilbertus, Niki and Günnemann, Stephan and Lotfollahi, Mohammad and Theis, Fabian J}, 87 | booktitle={NeurIPS 2022}, 88 | year={2022} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /chemCPA/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/chemCPA/__init__.py -------------------------------------------------------------------------------- /chemCPA/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, SubDataset, drug_names_to_once_canon_smiles 2 | from .perturbation_data_module import PerturbationDataModule 3 | from .data import load_data, load_dataset_splits 4 | 5 | __all__ = [ 6 | "load_data", 7 | "load_dataset_splits", 8 | "Dataset", 9 | "SubDataset", 10 | "PerturbationDataModule", 11 | "drug_names_to_once_canon_smiles" 12 | ] 13 | -------------------------------------------------------------------------------- /chemCPA/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, drug_names_to_once_canon_smiles 2 | from .subdataset import SubDataset 3 | 4 | __all__ = ['Dataset', 'SubDataset', 'drug_names_to_once_canon_smiles'] 5 | -------------------------------------------------------------------------------- /chemCPA/data/dataset/compute_degs.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | def compute_degs(drugs_names, covariate_names, dose_names, de_genes, var_names): 8 | """ 9 | Compute differential gene expression (DEG) tensor for given drug-covariate-dose combinations. 10 | 11 | This vectorized version avoids explicit loops over combinations by using NumPy and PyTorch operations. 12 | 13 | Args: 14 | drugs_names (list): List of drug names. 15 | covariate_names (list): List of covariate names. 16 | dose_names (list): List of dose names (if dose-specific). 17 | de_genes (dict): Dictionary mapping drug-covariate-dose keys to lists of differentially expressed genes. 18 | var_names (list): List of all gene names to be considered. 19 | 20 | Returns: 21 | torch.Tensor: A binary tensor of shape (number of combinations, number of genes), 22 | where 1 indicates the gene is differentially expressed for that combination, 23 | and 0 indicates it is not. 
24 | """ 25 | start_time = time.time() 26 | dose_specific = len(list(de_genes.keys())[0].split("_")) == 3 27 | 28 | gene_to_index = {gene: i for i, gene in enumerate(var_names)} 29 | var_names_set = set(var_names) 30 | 31 | covariate_names = np.array(covariate_names, dtype=str) 32 | drugs_names = np.array(drugs_names, dtype=str) 33 | if dose_specific: 34 | dose_names = np.array(dose_names, dtype=str) 35 | 36 | if dose_specific: 37 | keys = np.char.add(np.char.add(np.char.add(covariate_names, '_'), drugs_names), '_') 38 | keys = np.char.add(keys, dose_names) 39 | else: 40 | keys = np.char.add(np.char.add(covariate_names, '_'), drugs_names) 41 | 42 | N = len(keys) 43 | print(f"Number of combinations: {N}") 44 | 45 | key_to_index = {key: i for i, key in enumerate(keys)} 46 | 47 | control_drugs = {'control', 'DMSO', 'Vehicle'} 48 | is_control = np.isin(drugs_names, list(control_drugs)) 49 | 50 | degs = torch.zeros((N, len(var_names)), dtype=torch.float32) 51 | 52 | row_indices = [] 53 | col_indices = [] 54 | 55 | # Decode byte strings in de_genes keys and values 56 | de_genes_decoded = {} 57 | for key, genes in de_genes.items(): 58 | # Decode the key if it's a byte string 59 | if isinstance(key, bytes): 60 | key = key.decode('utf-8') 61 | # Decode each gene in the list if it's a byte string 62 | genes = [gene.decode('utf-8') if isinstance(gene, bytes) else gene for gene in genes] 63 | de_genes_decoded[key] = genes 64 | 65 | de_genes = de_genes_decoded 66 | 67 | for key, genes in tqdm(de_genes.items(), desc="Processing DEGs"): 68 | # Decode key if necessary 69 | if isinstance(key, bytes): 70 | key = key.decode('utf-8') 71 | # Decode genes if they are byte strings 72 | genes = [gene.decode('utf-8') if isinstance(gene, bytes) else gene for gene in genes] 73 | idx = key_to_index.get(key) 74 | if idx is not None: 75 | if not is_control[idx]: 76 | valid_genes = var_names_set.intersection(genes) 77 | indices = [gene_to_index[gene] for gene in valid_genes] 78 | if indices: 79 | row_indices.extend([idx] * len(indices)) 80 | col_indices.extend(indices) 81 | 82 | if row_indices: 83 | degs[row_indices, col_indices] = 1.0 84 | 85 | print(f"DEGs tensor shape: {degs.shape}") 86 | print(f"Time taken: {time.time() - start_time:.2f} seconds") 87 | return degs 88 | 89 | -------------------------------------------------------------------------------- /chemCPA/data/dataset/drug_names_to_once_canon_smiles.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from chemCPA.helper import canonicalize_smiles 3 | 4 | def drug_names_to_once_canon_smiles( 5 | drug_names: List[str], 6 | obs: dict, 7 | perturbation_key: str, 8 | smiles_key: str 9 | ): 10 | """ 11 | For each row in obs, split combination drug names (on '+') 12 | and combo SMILES (on '..'), then canonicalize each sub-part 13 | and store them in a dictionary keyed by the *individual* drug name. 14 | 15 | That way, if drug_names includes both single and sub-drug names, 16 | we have an entry for each sub-drug. 
17 | """ 18 | drug_names_array = obs[perturbation_key] 19 | smiles_array = obs[smiles_key] 20 | 21 | # Build a set of (full_combo_drug_name, full_combo_smiles) 22 | unique_pairs = set(zip(drug_names_array, smiles_array)) 23 | name_to_smiles_map = {} 24 | 25 | for combo_name, combo_smiles in unique_pairs: 26 | # If this row doesn't have valid strings, skip 27 | if not isinstance(combo_name, str) or not isinstance(combo_smiles, str): 28 | continue 29 | 30 | # Split the drug name on '+' 31 | sub_drugs = combo_name.split('+') 32 | # Split the SMILES on '..' 33 | sub_smiles = combo_smiles.split('..') 34 | 35 | # If lengths don't match, handle or skip 36 | if len(sub_drugs) != len(sub_smiles): 37 | # Example: skip this row or raise an error 38 | continue 39 | 40 | # Canonicalize each sub-smiles, store in map keyed by each sub-drug 41 | for drug, raw_smi in zip(sub_drugs, sub_smiles): 42 | drug = drug.strip() 43 | smi = raw_smi.strip() 44 | try: 45 | # Canonicalize each sub-smiles 46 | canon = canonicalize_smiles(smi) 47 | except Exception as e: 48 | # Optionally handle parsing errors 49 | canon = None 50 | name_to_smiles_map[drug] = canon 51 | 52 | # Now build the output list: for each (sub-)drug_name in the requested list 53 | # return the canonical SMILES if present, else None or raise an error 54 | result = [] 55 | for name in drug_names: 56 | name = name.strip() 57 | if name in name_to_smiles_map: 58 | result.append(name_to_smiles_map[name]) 59 | else: 60 | # Decide how to handle unknown sub-drugs 61 | result.append(None) 62 | 63 | return result 64 | 65 | 66 | -------------------------------------------------------------------------------- /chemCPA/data/dataset/subdataset.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | 4 | indx = lambda a, i: a[i] if a is not None else None 5 | 6 | class SubDataset: 7 | """ 8 | Subsets a `Dataset` or another `SubDataset` by selecting the examples given by `indices`. 
9 | """ 10 | 11 | def __init__(self, dataset, indices): 12 | # Store a reference to the parent dataset (could be Dataset or SubDataset) 13 | self.dataset = dataset 14 | self.use_drugs_idx = dataset.use_drugs_idx 15 | 16 | # Map indices to the original dataset if necessary 17 | if isinstance(dataset, SubDataset): 18 | # Map original indices to keep track for future subsetting 19 | self.original_indices = [dataset.original_indices[i] for i in indices] 20 | else: 21 | self.original_indices = indices 22 | 23 | # Access data using indices directly from the parent dataset 24 | self.genes = dataset.genes[indices] 25 | 26 | if self.use_drugs_idx: 27 | self.drugs_idx = indx(dataset.drugs_idx, indices) 28 | self.dosages = indx(dataset.dosages, indices) 29 | else: 30 | self.drugs = indx(dataset.drugs, indices) 31 | 32 | # Retrieve dosage values from the parent dataset 33 | self.dose_values = [dataset.dose_values[i] for i in indices] 34 | 35 | # Proceed with other attributes 36 | self.degs = indx(dataset.degs, indices) 37 | self.covariates = [indx(cov, indices) for cov in dataset.covariates] 38 | 39 | self.drugs_names = indx(dataset.drugs_names, indices) 40 | self.pert_categories = indx(dataset.pert_categories, indices) 41 | self.covariate_keys = dataset.covariate_keys 42 | self.covariate_names = {} 43 | for cov in self.covariate_keys: 44 | self.covariate_names[cov] = indx(dataset.covariate_names[cov], indices) 45 | 46 | self.var_names = dataset.var_names 47 | self.de_genes = dataset.de_genes 48 | self.ctrl_name = dataset.ctrl_name 49 | 50 | self.num_covariates = dataset.num_covariates 51 | self.num_genes = dataset.num_genes 52 | self.num_drugs = dataset.num_drugs 53 | 54 | # **Added missing attributes** 55 | self.perturbation_key = dataset.perturbation_key 56 | self.dose_key = dataset.dose_key 57 | self.smiles_key = dataset.smiles_key 58 | self.degs_key = dataset.degs_key 59 | self.pert_category = dataset.pert_category 60 | self.split_key = dataset.split_key 61 | 62 | def subset(self, dosage_range=None, dosage_filter=None): 63 | """ 64 | Creates a new SubDataset by filtering the current SubDataset based on dosage criteria. 65 | 66 | Parameters: 67 | dosage_range (tuple): A tuple (min_dose, max_dose) to filter samples where all dosages fall within this range. 68 | dosage_filter (callable): A function that takes a list of dosages and returns True if the sample should be included. 69 | 70 | Returns: 71 | SubDataset: A new SubDataset instance with the filtered data. 
72 | """ 73 | idx = list(range(len(self.genes))) 74 | 75 | # Filter based on dosage_range if provided 76 | if dosage_range is not None: 77 | min_dose, max_dose = dosage_range 78 | idx = [ 79 | i for i in idx 80 | if all(min_dose <= dose <= max_dose for dose in self.dose_values[i]) 81 | ] 82 | 83 | # Apply a custom dosage_filter function if provided 84 | if dosage_filter is not None: 85 | idx = [i for i in idx if dosage_filter(self.dose_values[i])] 86 | 87 | return SubDataset(self, idx) 88 | 89 | def __getitem__(self, i): 90 | if self.use_drugs_idx: 91 | return ( 92 | self.genes[i], 93 | indx(self.drugs_idx, i), 94 | indx(self.dosages, i), 95 | indx(self.degs, i), 96 | *[indx(cov, i) for cov in self.covariates], 97 | ) 98 | else: 99 | return ( 100 | self.genes[i], 101 | indx(self.drugs, i), 102 | indx(self.degs, i), 103 | *[indx(cov, i) for cov in self.covariates], 104 | ) 105 | 106 | def __len__(self): 107 | return len(self.genes) 108 | -------------------------------------------------------------------------------- /chemCPA/data/perturbation_data_module.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from torch.utils.data import DataLoader 3 | import torch 4 | 5 | def custom_collate(batch): 6 | transposed = zip(*batch) 7 | concat_batch = [] 8 | for samples in transposed: 9 | if samples[0] is None: 10 | concat_batch.append(None) 11 | else: 12 | # Move to CUDA here so that prefetching in the DataLoader yields ready-to-process CUDA tensors 13 | concat_batch.append(torch.stack(samples, 0).to("cuda")) 14 | return concat_batch 15 | 16 | class PerturbationDataModule(L.LightningDataModule): 17 | def __init__(self, datasplits, train_bs=32, val_bs=32, test_bs=32): 18 | super().__init__() 19 | self.datasplits = datasplits 20 | self.train_bs = train_bs 21 | self.val_bs = val_bs 22 | self.test_bs = test_bs 23 | 24 | def setup(self, stage: str): 25 | # Assign datasets for use in dataloaders 26 | if stage == "fit": 27 | self.train_dataset = self.datasplits["training"] 28 | self.train_control_dataset = self.datasplits["training_control"] 29 | self.train_treated_dataset = self.datasplits["training_treated"] 30 | self.test_dataset = self.datasplits["test"] 31 | self.test_control_dataset = self.datasplits["test_control"] 32 | self.test_treated_dataset = self.datasplits["test_treated"] 33 | self.ood_control_dataset = self.datasplits["test_control"] 34 | self.ood_treated_dataset = self.datasplits["ood"] 35 | 36 | if stage == "validate" or stage == "test": 37 | self.test_dataset = self.datasplits["test"] 38 | self.test_control_dataset = self.datasplits["test_control"] 39 | self.test_treated_dataset = self.datasplits["test_treated"] 40 | 41 | if stage == "predict": 42 | self.ood_control_dataset = self.datasplits["test_control"] 43 | self.ood_treated_dataset = self.datasplits["ood"] 44 | 45 | def train_dataloader(self): 46 | return DataLoader(self.train_dataset, batch_size=self.train_bs, shuffle=True, collate_fn=custom_collate) 47 | 48 | def val_dataloader(self): 49 | return { 50 | "test": DataLoader(self.test_dataset, batch_size=self.val_bs), 51 | "test_control": DataLoader(self.test_control_dataset, batch_size=self.val_bs), 52 | "test_treated": DataLoader(self.test_treated_dataset, batch_size=self.val_bs), 53 | } 54 | 55 | def test_dataloader(self): 56 | return DataLoader(self.test_dataset, batch_size=self.test_bs) 57 | 58 | def predict_dataloader(self): 59 | return DataLoader(self.ood_dataset, batch_size=self.test_bs) 
-------------------------------------------------------------------------------- /chemCPA/helper.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Optional 3 | 4 | # import dgl 5 | import pandas as pd 6 | import scanpy as sc 7 | 8 | # from dgllife.utils import ( 9 | # AttentiveFPAtomFeaturizer, 10 | # AttentiveFPBondFeaturizer, 11 | # CanonicalAtomFeaturizer, 12 | # CanonicalBondFeaturizer, 13 | # PretrainAtomFeaturizer, 14 | # PretrainBondFeaturizer, 15 | # smiles_to_bigraph, 16 | # ) 17 | from rdkit import Chem 18 | 19 | 20 | def rank_genes_groups_by_cov( 21 | adata, 22 | groupby, 23 | control_group, 24 | covariate, 25 | pool_doses=False, 26 | n_genes=50, 27 | rankby_abs=True, 28 | key_added="rank_genes_groups_cov", 29 | return_dict=False, 30 | ): 31 | 32 | """ 33 | Function that generates a list of differentially expressed genes computed 34 | separately for each covariate category, and using the respective control 35 | cells as reference. 36 | 37 | Usage example: 38 | 39 | rank_genes_groups_by_cov( 40 | adata, 41 | groupby='cov_product_dose', 42 | covariate_key='cell_type', 43 | control_group='Vehicle_0' 44 | ) 45 | 46 | Parameters 47 | ---------- 48 | adata : AnnData 49 | AnnData dataset 50 | groupby : str 51 | Obs column that defines the groups, should be 52 | cartesian product of covariate_perturbation_cont_var, 53 | it is important that this format is followed. 54 | control_group : str 55 | String that defines the control group in the groupby obs 56 | covariate : str 57 | Obs column that defines the main covariate by which we 58 | want to separate DEG computation (eg. cell type, species, etc.) 59 | n_genes : int (default: 50) 60 | Number of DEGs to include in the lists 61 | rankby_abs : bool (default: True) 62 | If True, rank genes by absolute values of the score, thus including 63 | top downregulated genes in the top N genes. If False, the ranking will 64 | have only upregulated genes at the top. 
65 | key_added : str (default: 'rank_genes_groups_cov') 66 | Key used when adding the dictionary to adata.uns 67 | return_dict : str (default: False) 68 | Signals whether to return the dictionary or not 69 | 70 | Returns 71 | ------- 72 | Adds the DEG dictionary to adata.uns 73 | 74 | If return_dict is True returns: 75 | gene_dict : dict 76 | Dictionary where groups are stored as keys, and the list of DEGs 77 | are the corresponding values 78 | 79 | """ 80 | 81 | gene_dict = {} 82 | cov_categories = adata.obs[covariate].unique() 83 | for cov_cat in cov_categories: 84 | print(cov_cat) 85 | # name of the control group in the groupby obs column 86 | control_group_cov = "_".join([cov_cat, control_group]) 87 | 88 | # subset adata to cells belonging to a covariate category 89 | adata_cov = adata[adata.obs[covariate] == cov_cat] 90 | 91 | # compute DEGs 92 | sc.tl.rank_genes_groups( 93 | adata_cov, 94 | groupby=groupby, 95 | reference=control_group_cov, 96 | rankby_abs=rankby_abs, 97 | n_genes=n_genes, 98 | ) 99 | 100 | # add entries to dictionary of gene sets 101 | de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"]) 102 | for group in de_genes: 103 | gene_dict[group] = de_genes[group].tolist() 104 | 105 | adata.uns[key_added] = gene_dict 106 | 107 | if return_dict: 108 | return gene_dict 109 | 110 | 111 | def canonicalize_smiles(smiles: Optional[str]): 112 | if smiles: 113 | mol = Chem.MolFromSmiles(smiles) 114 | if mol is None: 115 | return None 116 | return Chem.MolToSmiles( 117 | mol, 118 | isomericSmiles=True, # Keep stereochemistry information 119 | canonical=True, # Ensure canonical atom ordering 120 | doRandom=False, # Don't introduce randomness 121 | allBondsExplicit=False, # Don't make all bonds explicit 122 | allHsExplicit=False, # Don't make all hydrogens explicit 123 | kekuleSmiles=False # Don't kekulize the molecule 124 | ) 125 | else: 126 | return None 127 | -------------------------------------------------------------------------------- /chemCPA/paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | ROOT = Path(__file__).parent.resolve().parent 4 | 5 | PROJECT_DIR = ROOT / "project_folder" 6 | DATA_DIR = PROJECT_DIR / "datasets" 7 | EMBEDDING_DIR = PROJECT_DIR / "embeddings" 8 | CHECKPOINT_DIR = PROJECT_DIR / "checkpoints" 9 | FIGURE_DIR = PROJECT_DIR / "figures" 10 | WB_DIR = PROJECT_DIR / "wandb" 11 | -------------------------------------------------------------------------------- /chemCPA/profiling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import signal 4 | import subprocess 5 | from pathlib import Path 6 | 7 | import sacred 8 | 9 | 10 | class Profiler: 11 | outpath: Path 12 | _process: subprocess.Popen 13 | 14 | def __init__(self, seed: str, save_dir: str): 15 | """ 16 | Creates a new profiler without start it yet. 17 | @param seed: random string used for generating unique filepath. 18 | @param save_dir: directory to save the file to. 19 | """ 20 | assert Path(save_dir).is_dir(), f"{save_dir} is not a directory!" 21 | self.outpath = Path(save_dir) / f"profile_{seed}.speedscope" 22 | assert shutil.which("py-spy"), "py-spy not found, please install it first." 
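    # Illustrative usage sketch (the seed and save_dir values below are placeholders):
    #
    #     profiler = Profiler(seed="run42", save_dir="/tmp/profiles")
    #     profiler.start()
    #     ...                        # code to be profiled, e.g. the training loop
    #     profiler.stop(experiment)  # writes profile_run42.speedscope and attaches it to the sacred run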
23 | 24 | def start(self): 25 | """Start recording the current Python process""" 26 | # starts py-spy in a new subprocess 27 | self._process = subprocess.Popen( 28 | [ 29 | shutil.which("py-spy"), 30 | "record", 31 | "--pid", 32 | str(os.getpid()), # tells py-spy to profile the current Python process 33 | "--rate", 34 | "3", # three samples per second should be fine-grained enough and the outfile won't get too large 35 | "--format", 36 | "speedscope", # look at profiles via https://speedscope.app 37 | "--output", 38 | str( 39 | self.outpath 40 | ), # file to save results at (once profiling has finished) 41 | ] 42 | ) 43 | 44 | def stop(self, experiment: sacred.Experiment): 45 | """ 46 | Stop recording and save the results to a file and to MongoDB 47 | @param experiment: The seml / sacred experiment. 48 | """ 49 | # First, send same signal as CTRL+C would. Py-spy should quit and save the results. 50 | self._process.send_signal(signal.SIGINT) 51 | try: 52 | # if the profiler didn't exit after 10s, kill it 53 | self._process.wait(timeout=10) 54 | except subprocess.TimeoutExpired: 55 | # sends SIGKILL. py-spy will quit, but will not save a profile. 56 | self._process.kill() 57 | print("killed py-spy due to timeout.") 58 | # collect the zombie process 59 | self._process.wait(timeout=2) 60 | 61 | # upload the profiling results to mongoDB as a binary 62 | if self.outpath.is_file(): 63 | experiment.add_artifact( 64 | str(self.outpath), 65 | name="py_spy_profile", 66 | content_type="application/octet-stream", 67 | ) 68 | -------------------------------------------------------------------------------- /chemCPA/train_hydra_tmux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Name of the tmux session 4 | SESSION_NAME="chemcpa_training" 5 | 6 | # Path to your Python script 7 | SCRIPT_PATH="chemCPA/train_hydra.py" 8 | 9 | # Check if the session already exists 10 | if tmux has-session -t $SESSION_NAME 2>/dev/null; then 11 | echo "Session $SESSION_NAME already exists. Attaching to it." 12 | tmux attach-session -t $SESSION_NAME 13 | else 14 | # Create a new session 15 | tmux new-session -d -s $SESSION_NAME 16 | 17 | # Rename the first window 18 | tmux rename-window -t $SESSION_NAME:0 'training' 19 | 20 | # Send the command to run the Python script 21 | tmux send-keys -t $SESSION_NAME:0 "python $SCRIPT_PATH" C-m 22 | 23 | # Attach to the session 24 | tmux attach-session -t $SESSION_NAME 25 | fi -------------------------------------------------------------------------------- /config/README.md: -------------------------------------------------------------------------------- 1 | # This folder contains the hydra config files for the project. 2 | 3 | The config files are organized as follows: 4 | main.yaml/lincs.yaml/sciplex.yaml/sciplex_finetune.yaml - the root config files 5 | dataset folder - specifies path to dataset and names of keys into it 6 | training folder - configuration of validation/checkpoint/logging behavior 7 | model folder - configuration of model architecture including the embeddings 8 | 9 | For convenience, we provide one main config file for each of the 10 | main experiments in the paper. 11 | 12 | 13 | To train on sciplex use the sciplex.yaml root config file. 14 | This config will train the model on sciplex_complete_v2.h5ad which is created in the fifth preprocessing notebook. 15 | 16 | To (pre)train on LINCS use the lincs.yaml root config file. 
17 | This config will train the model on lincs_full_smiles_sciplex_genes.h5ad as created in the third preprocessing notebook. 
18 | 
19 | To finetune a LINCS model on SciPlex3 use the finetune.yaml root config file. 
20 | This utilizes the sciplex_complete_lincs_genes_v2.h5ad dataset created in the fifth preprocessing notebook. 
21 | Note that you need to specify a model hash inside the model config (config/model/finetune.yaml), in the pretrained_model_hashes.model field. 
22 | This hash is the name of the checkpoint folder inside the training_output directory, which contains the model checkpoints. 
23 | 
24 | 
25 | 
26 | 
27 | 
-------------------------------------------------------------------------------- /config/dataset/biolord.yaml: -------------------------------------------------------------------------------- 
1 | perturbation_key: drug # stores name of the drug 
2 | pert_category: condition # stores celltype_drugname_drugdose 
3 | dose_key: dose # stores drug dose as a float 
4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 
5 | smiles_key: smiles 
6 | use_drugs_idx: True # If false, will use One-hot encoding instead 
7 | split_key: split 
8 | dataset_path: project_folder/datasets/adata_biolord_split_30.h5ad 
9 | degs_key: rank_genes_groups_cov_all 
10 | 
-------------------------------------------------------------------------------- /config/dataset/biolord_split_30.yaml: -------------------------------------------------------------------------------- 
1 | perturbation_key: drug # stores name of the drug 
2 | pert_category: condition # stores celltype_drugname_drugdose 
3 | dose_key: dose # stores drug dose as a float 
4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 
5 | smiles_key: smiles 
6 | use_drugs_idx: True # If false, will use One-hot encoding instead 
7 | split_key: split 
8 | dataset_path: project_folder/datasets/adata_biolord_split_30.h5ad 
9 | degs_key: rank_genes_groups_cov_all 
10 | 
-------------------------------------------------------------------------------- /config/dataset/combinatorial.yaml: -------------------------------------------------------------------------------- 
1 | perturbation_key: condition 
2 | pert_category: cov_drug_dose 
3 | dose_key: dose_value 
4 | covariate_keys: cell_type 
5 | smiles_key: smiles_rdkit 
6 | use_drugs_idx: True 
7 | split_key: split 
8 | dataset_path: project_folder/datasets/combo_sciplex_prep_hvg_filtered.h5ad 
9 | degs_key: rank_genes_groups_cov 
10 | 
-------------------------------------------------------------------------------- /config/dataset/default.yaml: -------------------------------------------------------------------------------- 
1 | perturbation_key: condition # stores name of the drug 
2 | pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 
3 | dose_key: dose # stores drug dose as a float 
4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present.
5 | smiles_key: SMILES 6 | use_drugs_idx: True # If false, will use One-hot encoding instead 7 | split_key: split_cellcycle_ood 8 | dataset_path: project_folder/datasets/sciplex_complete_v2.h5ad 9 | # dataset_path: project_folder/datasets/adata_biolord_split_30_subset.h5ad # full path to the anndata dataset 10 | degs_key: all_DEGs # `uns` column name denoting the DEGs for each perturbation 11 | -------------------------------------------------------------------------------- /config/dataset/lincs.yaml: -------------------------------------------------------------------------------- 1 | perturbation_key: condition # stores name of the drug 2 | pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 3 | dose_key: dose_val # stores drug dose as a float 4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 5 | smiles_key: canonical_smiles 6 | use_drugs_idx: True # If false, will use One-hot encoding instead 7 | split_key: split 8 | dataset_path: project_folder/datasets/lincs_full_smiles.h5ad # full path to the anndata dataset 9 | # dataset_path: project_folder/datasets/adata_biolord_split_30_subset.h5ad # full path to the anndata dataset 10 | degs_key: rank_genes_groups_cov 11 | -------------------------------------------------------------------------------- /config/dataset/lincs_2000_genes.yaml: -------------------------------------------------------------------------------- 1 | perturbation_key: condition # stores name of the drug 2 | pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 3 | dose_key: dose_val # stores drug dose as a float 4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 5 | smiles_key: canonical_smiles 6 | use_drugs_idx: True # If false, will use One-hot encoding instead 7 | split_key: split 8 | dataset_path: project_folder/datasets/lincs_full_smiles.h5ad # full path to the anndata dataset 9 | # dataset_path: project_folder/datasets/adata_biolord_split_30_subset.h5ad # full path to the anndata dataset 10 | degs_key: rank_genes_groups_cov 11 | -------------------------------------------------------------------------------- /config/dataset/sciplex.yaml: -------------------------------------------------------------------------------- 1 | perturbation_key: condition # stores name of the drug 2 | pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 3 | dose_key: dose_val # stores drug dose as a float 4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 5 | smiles_key: SMILES 6 | use_drugs_idx: True # If false, will use One-hot encoding instead 7 | split_key: split_random 8 | dataset_path: project_folder/datasets/sciplex_complete_lincs_genes_v2.h5ad # full path to the anndata dataset 9 | # dataset_path: project_folder/datasets/adata_biolord_split_30_subset.h5ad # full path to the anndata dataset 10 | degs_key: all_DEGs 11 | -------------------------------------------------------------------------------- /config/dataset/sciplex_2000_genes.yaml: -------------------------------------------------------------------------------- 1 | perturbation_key: condition # stores name of the drug 2 | pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 3 | dose_key: dose_val # stores drug dose as a float 4 | covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 
5 | smiles_key: SMILES 6 | use_drugs_idx: True # If false, will use One-hot encoding instead 7 | split_key: split_random 8 | dataset_path: project_folder/datasets/sciplex_complete_v2.h5ad # full path to the anndata dataset 9 | # dataset_path: project_folder/datasets/adata_biolord_split_30_subset.h5ad # full path to the anndata dataset 10 | degs_key: all_DEGs 11 | -------------------------------------------------------------------------------- /config/finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: finetune 4 | - dataset: sciplex 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/finetune_2000_genes.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: finetune_2000_genes 4 | - dataset: sciplex_2000_genes 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/finetune_combinatorial.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: finetune_combinatorial 4 | - dataset: combinatorial 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | output_subdir: null 2 | -------------------------------------------------------------------------------- /config/lincs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: lincs 4 | - dataset: lincs 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/lincs_2000_genes.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: lincs 4 | - dataset: lincs_2000_genes 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: finetune 4 | - dataset: sciplex_lincs_genes 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/model/.pretrain_combinatorial.yaml.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/config/model/.pretrain_combinatorial.yaml.swp -------------------------------------------------------------------------------- /config/model/additional_params/default.yaml: 
-------------------------------------------------------------------------------- 1 | patience: 50 # patience for early stopping. Effective epochs: patience * checkpoint_freq. 2 | decoder_activation: ReLU # last layer of the decoder 'linear' or 'ReLU' 3 | doser_type: amortized # non-linearity for doser function 4 | seed: 1337 5 | -------------------------------------------------------------------------------- /config/model/combinatorial_rdkit.yaml: -------------------------------------------------------------------------------- 1 | model: rdkit 2 | datapath: project_folder/embeddings/rdkit/combo_sciplex_prep_hvg_filtered_rdkit2D_embedding.parquet 3 | -------------------------------------------------------------------------------- /config/model/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hparams: default 3 | - embedding: sciplex_lincs_genes 4 | - additional_params: default 5 | 6 | enable_cpa_mode: False 7 | load_pretrained: True 8 | append_ae_layer: False 9 | pretrained_model_path: /path/to/pretrained/models 10 | pretrained_model_hashes: 11 | rdkit: 4f061dbfc7af05cf84f06a724b0c8563 12 | grover: c30016a7469feb78a8ee9ebb18ed9b1f 13 | jtvae: 915345a522c29fa709b995d6149083b9 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /config/model/embedding/biolord_split_30.yaml: -------------------------------------------------------------------------------- 1 | datapath: project_folder/embeddings/rdkit/data/embeddings/rdkit2D_embedding_biolord.parquet 2 | model: rdkit -------------------------------------------------------------------------------- /config/model/embedding/combinatorial_rdkit.yaml: -------------------------------------------------------------------------------- 1 | model: rdkit 2 | datapath: project_folder/embeddings/rdkit/combo_sciplex_prep_hvg_filtered_rdkit2D_embedding.parquet 3 | -------------------------------------------------------------------------------- /config/model/embedding/default.yaml: -------------------------------------------------------------------------------- 1 | datapath: project_folder/embeddings/rdkit/data/embeddings/rdkit2D_embedding_biolord.parquet 2 | model: rdkit -------------------------------------------------------------------------------- /config/model/embedding/lincs.yaml: -------------------------------------------------------------------------------- 1 | datapath: project_folder/embeddings/rdkit/lincs_full_smiles_rdkit2D_embedding.parquet 2 | model: rdkit -------------------------------------------------------------------------------- /config/model/embedding/sciplex_lincs_genes.yaml: -------------------------------------------------------------------------------- 1 | datapath: project_folder/embeddings/rdkit/sciplex_complete_lincs_genes_v2_rdkit2D_embedding.parquet 2 | model: rdkit 3 | -------------------------------------------------------------------------------- /config/model/embedding/sciplex_middle.yaml: -------------------------------------------------------------------------------- 1 | datapath: project_folder/embeddings/rdkit/data/embeddings/sciplex_complete_middle_subset_v2_rdkit_embeddings.parquet 2 | model: rdkit -------------------------------------------------------------------------------- /config/model/finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hparams: default 3 | - embedding: sciplex_lincs_genes 4 | - additional_params: default 5 | 6 | 
enable_cpa_mode: False 7 | load_pretrained: True 8 | append_ae_layer: False 9 | pretrained_model_path: ${training.save_dir} 10 | pretrained_model_hashes: 11 | model: omon82fh 12 | -------------------------------------------------------------------------------- /config/model/finetune_2000_genes.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hparams: default 3 | - embedding: sciplex_lincs_genes 4 | - additional_params: default 5 | 6 | enable_cpa_mode: False 7 | load_pretrained: True 8 | append_ae_layer: True 9 | pretrained_model_path: ${training.save_dir} 10 | pretrained_model_hashes: 11 | model: omon82fh 12 | -------------------------------------------------------------------------------- /config/model/finetune_combinatorial.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hparams: default 3 | - embedding: combinatorial_rdkit 4 | - additional_params: default 5 | 6 | enable_cpa_mode: False 7 | load_pretrained: True 8 | append_ae_layer: False 9 | pretrained_model_path: ${training.save_dir} 10 | pretrained_model_hashes: 11 | model: szhlrfyl 12 | -------------------------------------------------------------------------------- /config/model/finetune_grover.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - finetune 3 | - embedding: grover -------------------------------------------------------------------------------- /config/model/hparams/default.yaml: -------------------------------------------------------------------------------- 1 | adversary_depth: 3 2 | adversary_lr: 0.0011926173789223548 3 | adversary_steps: 3 4 | adversary_wd: 0.000009846738873614555 5 | adversary_width: 128 6 | autoencoder_depth: 4 7 | autoencoder_lr: 0.0015751320499779737 8 | autoencoder_wd: 6.251373574521742e-7 9 | autoencoder_width: 256 10 | batch_size: 256 11 | dim: 32 12 | dosers_depth: 3 13 | dosers_lr: 0.0015751320499779737 14 | dosers_wd: 6.251373574521742e-7 15 | dosers_width: 64 16 | dropout: 0.262378 17 | embedding_encoder_depth: 4 18 | embedding_encoder_width: 128 19 | penalty_adversary: 0.4550475813202185 20 | reg_adversary: 9.100951626404369 21 | reg_adversary_cov: 16.165583124257587 22 | reg_multi_task: 0 23 | step_size_lr: 6 24 | -------------------------------------------------------------------------------- /config/model/lincs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hparams: default 3 | - embedding: lincs 4 | - additional_params: default 5 | 6 | enable_cpa_mode: False 7 | load_pretrained: False 8 | append_ae_layer: False 9 | #pretrained_model_path: /path/to/pretrained/models 10 | #pretrained_model_hashes: 11 | # rdkit: 4f061dbfc7af05cf84f06a724b0c8563 12 | # grover: c30016a7469feb78a8ee9ebb18ed9b1f 13 | # jtvae: 915345a522c29fa709b995d6149083b9 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /config/model/pretrain_combinatorial.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hparams: default 3 | - embedding: combinatorial_rdkit 4 | - additional_params: default 5 | 6 | enable_cpa_mode: False 7 | load_pretrained: False 8 | append_ae_layer: False 9 | -------------------------------------------------------------------------------- /config/model/sciplex.yaml: -------------------------------------------------------------------------------- 1 | 
defaults: 2 | - hparams: default 3 | - embedding: sciplex_lincs_genes 4 | - additional_params: default 5 | 6 | enable_cpa_mode: False 7 | load_pretrained: False 8 | append_ae_layer: False -------------------------------------------------------------------------------- /config/pretrain_combinatorial.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: pretrain_combinatorial 4 | - dataset: combinatorial 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/sciplex.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: sciplex 4 | - dataset: sciplex 5 | - training: default 6 | - wandb: default 7 | - hydra: default 8 | 9 | profiling.run_profiler: False 10 | profiling.outdir: "./" 11 | -------------------------------------------------------------------------------- /config/training/default.yaml: -------------------------------------------------------------------------------- 1 | checkpoint_freq: 50 # checkpoint frequency to run evaluate, and maybe save checkpoint 2 | num_epochs: 201 # maximum epochs for training. One epoch updates either autoencoder, or adversary, depending on adversary_steps. 3 | max_minutes: 00:20:00:00 # maximum computation time, DD:hh:MM:SS 4 | full_eval_during_train: False 5 | run_eval_disentangle: True # whether to calc the disentanglement loss when running the full eval 6 | run_eval_r2: True 7 | run_eval_r2_sc: False 8 | run_eval_logfold: False 9 | save_checkpoints: True # checkpoints tend to be ~250MB large for LINCS. 
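# Hedged usage sketch (assumes the Hydra entry point chemCPA/train_hydra.py and standard
# Hydra CLI overrides for this group as `training.<key>=<value>`; adjust to your setup):
#   python chemCPA/train_hydra.py training.num_epochs=101 training.checkpoint_freq=25 training.save_dir=./my_run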
10 | save_dir: ./training_output -------------------------------------------------------------------------------- /config/wandb/default.yaml: -------------------------------------------------------------------------------- 1 | entity: biroscak 2 | project: chemCPA -------------------------------------------------------------------------------- /docker_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | JUPYTER_PASSWORD="chemCPA" 3 | 4 | # Source conda 5 | source /home/user/conda/etc/profile.d/conda.sh 6 | 7 | # Activate conda environment 8 | conda activate chemCPA 9 | 10 | # Set LD_LIBRARY_PATH 11 | export LD_LIBRARY_PATH=/home/user/conda/envs/chemCPA/lib:$LD_LIBRARY_PATH 12 | 13 | # Create jupyter config directory if it doesn't exist 14 | mkdir -p ~/.jupyter 15 | 16 | # Set up the password 17 | python -c "from jupyter_server.auth import passwd; print(passwd('$JUPYTER_PASSWORD'))" > ~/.jupyter/jupyter_server_password.txt 18 | HASHED_PASSWORD=$(cat ~/.jupyter/jupyter_server_password.txt) 19 | 20 | # Create jupyter server config 21 | cat > ~/.jupyter/jupyter_server_config.py << EOF 22 | c.ServerApp.password = '$HASHED_PASSWORD' 23 | c.ServerApp.password_required = True 24 | EOF 25 | 26 | chmod 1777 /tmp 27 | chmod 755 /home/user 28 | 29 | sudo apt update; 30 | DEBIAN_FRONTEND=noninteractive sudo apt-get install openssh-server -y; 31 | mkdir -p ~/.ssh; 32 | cd $_; 33 | chmod 700 ~/.ssh; 34 | echo "$PUBLIC_KEY" >> authorized_keys; 35 | chmod 700 authorized_keys; 36 | service ssh start; 37 | 38 | echo "Starting Jupyter Lab with password authentication" 39 | 40 | NOTEBOOK_DIR="/home/user/chemCPA" 41 | 42 | jupyter-lab \ 43 | --ip 0.0.0.0 \ 44 | --port 8888 \ 45 | --no-browser \ 46 | --allow-root \ 47 | --notebook-dir=$NOTEBOOK_DIR 48 | -------------------------------------------------------------------------------- /docs/chemCPA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/docs/chemCPA.png -------------------------------------------------------------------------------- /download_training_output.sh: -------------------------------------------------------------------------------- 1 | gdown "https://drive.google.com/uc?id=1wmcVdSxXMWnlccJnfDrrUVlY3Fr87Ru0" -O training_output.zip 2 | unzip training_output.zip -d training_output 3 | rm training_output.zip 4 | -------------------------------------------------------------------------------- /embeddings/chemvae/README.md: -------------------------------------------------------------------------------- 1 | # Training the Chemical VAE 2 | 3 | This is the chemical VAE as presented in [this paper](https://pubs.acs.org/doi/abs/10.1021/acscentsci.7b00572). 4 | The only difference is that we're not jointly training a mol property predictor. 5 | 6 | The files need to contain a single SMILES per row, no trailing comma. 7 | Header needs to be `SMILES`. 8 | 9 | We run with the standard hyperparameters, except for the max KL weight (β), which we set to `1.0` to get 10 | a more disentangled latent space. 
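(For orientation: β weights the KL term in the usual VAE objective, roughly `loss = reconstruction_loss + β · KL(q(z|x) || p(z))`; in `train_chemvae.sh` it is passed as `--kl_w_end 1.0`. This is a sketch of the objective, not the exact MOSES implementation.)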
11 | 12 | Dimension of embedding: 128 13 | 14 | ```bash 15 | # header 16 | echo "SMILES" > ../lincs_trapnell_zinc.csv 17 | # add train SMILES to file 18 | tail -n +2 ../lincs_trapnell.smiles >> ../lincs_trapnell_zinc.csv # skip the first line (`smiles`) 19 | cat ../zinc_smiles_train.csv >> ../lincs_trapnell_zincs.csv 20 | # start the training 21 | ./train_chemvae.sh 22 | ``` -------------------------------------------------------------------------------- /embeddings/chemvae/generate_embeddings.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% 13 | from pathlib import Path 14 | 15 | # %% 16 | import moses 17 | import numpy as np 18 | import pandas as pd 19 | import torch 20 | from moses.vae import VAE 21 | from rdkit import Chem, RDLogger 22 | from tqdm import tqdm 23 | 24 | RDLogger.logger().setLevel(RDLogger.CRITICAL) 25 | RDLogger.DisableLog("rdApp.*") 26 | 27 | # %% 28 | config_fpath = Path( 29 | "/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/config.txt" 30 | ) 31 | state_dict_fpath = Path( 32 | "/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/vae_checkpoint_final.pt" 33 | ) 34 | 35 | # %% pycharm={"name": "#%%\n"} 36 | config = torch.load(config_fpath) 37 | vocab = torch.load(config.vocab_save) 38 | state = torch.load(state_dict_fpath) 39 | 40 | # %% 41 | model = VAE(vocab, config) 42 | model.load_state_dict(state) 43 | model.to("cuda") 44 | model.eval() 45 | 46 | # %% 47 | all_smiles = list( 48 | pd.read_csv( 49 | "/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/lincs_trapnell.smiles" 50 | )["smiles"].values 51 | ) 52 | 53 | # %% 54 | embeddings = [] 55 | for s in tqdm(all_smiles): 56 | with torch.no_grad(): 57 | tensors = [model.string2tensor(s)] 58 | emb, _ = model.forward_encoder(tensors) 59 | embeddings.append(emb.cpu().numpy()) 60 | 61 | # %% 62 | emb = np.concatenate(embeddings, axis=0) 63 | final_df = pd.DataFrame( 64 | emb, index=all_smiles, columns=[f"latent_{i+1}" for i in range(emb.shape[1])] 65 | ) 66 | final_df.to_parquet("chemvae.parquet") 67 | final_df 68 | 69 | 70 | # %% [markdown] 71 | # ## Bit of testing 72 | # 73 | # Testing sampled SMILES for validitiy 74 | 75 | # %% 76 | def smiles_is_syntatically_valid(smiles): 77 | return Chem.MolFromSmiles(smiles, sanitize=False) is not None 78 | 79 | 80 | def smiles_is_semantically_valid(smiles): 81 | valid = True 82 | try: 83 | Chem.SanitizeMol(Chem.MolFromSmiles(smiles, sanitize=False)) 84 | except: 85 | valid = False 86 | return valid 87 | 88 | 89 | # %% 90 | samples = model.sample(1000) 91 | 92 | # %% 93 | syn_valid = sum(smiles_is_syntatically_valid(s) for s in samples) / len(samples) 94 | sem_valid = sum(smiles_is_syntatically_valid(s) for s in samples) / len(samples) 95 | print(f"TOTAL: {len(samples)} SYN: {syn_valid} SEM: {sem_valid}") 96 | 97 | # %% 98 | samples 99 | 100 | # %% 101 | -------------------------------------------------------------------------------- /embeddings/chemvae/train.py: -------------------------------------------------------------------------------- 1 | # this is a dump of the train.py file at moses/vae/train.py. 
I didn't want to clone the whole repo 2 | import argparse 3 | import os 4 | import sys 5 | 6 | import rdkit 7 | import torch 8 | from moses.dataset import get_dataset 9 | from moses.models_storage import ModelsStorage 10 | from moses.script_utils import add_train_args, read_smiles_csv, set_seed 11 | 12 | lg = rdkit.RDLogger.logger() 13 | lg.setLevel(rdkit.RDLogger.CRITICAL) 14 | 15 | MODELS = ModelsStorage() 16 | 17 | 18 | def get_parser(): 19 | parser = argparse.ArgumentParser() 20 | subparsers = parser.add_subparsers( 21 | title="Models trainer script", description="available models" 22 | ) 23 | for model in MODELS.get_model_names(): 24 | add_train_args( 25 | MODELS.get_model_train_parser(model)(subparsers.add_parser(model)) 26 | ) 27 | return parser 28 | 29 | 30 | def main(model, config): 31 | set_seed(config.seed) 32 | device = torch.device(config.device) 33 | 34 | if config.config_save is not None: 35 | torch.save(config, config.config_save) 36 | 37 | # For CUDNN to work properly 38 | if device.type.startswith("cuda"): 39 | torch.cuda.set_device(device.index or 0) 40 | if config.train_load is None: 41 | train_data = get_dataset("train") 42 | else: 43 | train_data = read_smiles_csv(config.train_load) 44 | if config.val_load is None: 45 | val_data = get_dataset("test") 46 | else: 47 | val_data = read_smiles_csv(config.val_load) 48 | trainer = MODELS.get_model_trainer(model)(config) 49 | 50 | if config.vocab_load is not None: 51 | assert os.path.exists(config.vocab_load), "vocab_load path does not exist!" 52 | vocab = torch.load(config.vocab_load) 53 | else: 54 | vocab = trainer.get_vocabulary(train_data) 55 | 56 | if config.vocab_save is not None: 57 | torch.save(vocab, config.vocab_save) 58 | 59 | model = MODELS.get_model_class(model)(vocab, config).to(device) 60 | trainer.fit(model, train_data, val_data) 61 | 62 | model = model.to("cpu") 63 | torch.save(model.state_dict(), config.model_save) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = get_parser() 68 | config = parser.parse_args() 69 | model = sys.argv[1] 70 | main(model, config) 71 | -------------------------------------------------------------------------------- /embeddings/chemvae/train_chemvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python train.py vae --train_load ../lincs_trapnell_zinc.csv \ 4 | --val_load ../zinc_smiles_test.txt \ 5 | --config_save /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/config.txt \ 6 | --model_save /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/model.pt \ 7 | --vocab_save /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/vocab.txt \ 8 | --device cuda:0 \ 9 | --kl_w_end 1.0 -------------------------------------------------------------------------------- /embeddings/dgl/embedding_pretrained_gnn.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% [markdown] 13 | # ## General imports 14 | 15 | # %% 16 | import sys 17 | 18 | sys.path.insert( 19 | 0, "/" 20 | ) # this depends on the notebook depth and must be adapted per notebook 21 | # %% 22 | import numpy as np 23 | from compert.paths import DATA_DIR, EMBEDDING_DIR 24 | from dgllife.utils import ( 25 | 
CanonicalAtomFeaturizer, 26 | CanonicalBondFeaturizer, 27 | smiles_to_bigraph, 28 | ) 29 | 30 | # %% [markdown] 31 | # ## Load Smiles list 32 | 33 | # %% 34 | dataset_name = "lincs_trapnell" 35 | 36 | # %% 37 | import pandas as pd 38 | 39 | smiles_df = pd.read_csv(EMBEDDING_DIR / f"{dataset_name}.smiles") 40 | smiles_list = smiles_df["smiles"].values 41 | 42 | # %% 43 | print(f"Number of smiles strings: {len(smiles_list)}") 44 | 45 | # %% [markdown] 46 | # ## Featurizer functions 47 | 48 | # %% 49 | node_feats = CanonicalAtomFeaturizer(atom_data_field="h") 50 | edge_feats = CanonicalBondFeaturizer(bond_data_field="h", self_loop=True) 51 | 52 | # %% [markdown] 53 | # ## Create graphs from smiles and featurizers 54 | 55 | # %% 56 | mol_graphs = [] 57 | 58 | for smiles in smiles_list: 59 | mol_graphs.append( 60 | smiles_to_bigraph( 61 | smiles=smiles, 62 | add_self_loop=True, 63 | node_featurizer=node_feats, 64 | edge_featurizer=edge_feats, 65 | ) 66 | ) 67 | 68 | # %% 69 | print(f"Number of molecular graphs: {len(mol_graphs)}") 70 | 71 | # %% [markdown] 72 | # ## Batch graphs 73 | 74 | # %% 75 | import dgl 76 | 77 | mol_batch = dgl.batch(mol_graphs) 78 | 79 | # %% 80 | mol_batch 81 | 82 | # %% [markdown] 83 | # ## Load pretrained model 84 | 85 | # %% [markdown] 86 | # Choose a model form [here](https://lifesci.dgl.ai/api/model.pretrain.html) 87 | 88 | # %% 89 | model_name = "GCN_canonical_PCBA" 90 | # model_name = 'MPNN_canonical_PCBA' 91 | # model_name = 'AttentiveFP_canonical_PCBA' 92 | # model_name = 'Weave_canonical_PCBA' 93 | # model_name = 'GCN_Tox21' 94 | 95 | # %% 96 | from dgllife.model import load_pretrained 97 | 98 | model = load_pretrained(model_name) 99 | 100 | verbose = True 101 | if verbose: 102 | print(model) 103 | 104 | # %% [markdown] 105 | # ## Predict with pretrained model 106 | 107 | # %% [markdown] 108 | # ### Take readout, just before prediction 109 | 110 | # %% 111 | model.eval() 112 | # no edge features 113 | prediction = model(mol_batch, mol_batch.ndata["h"]) 114 | # # with edge features 115 | # prediction = model(mol_batch, mol_batch.ndata['h'], mol_batch.edata['h']) 116 | print(f"Prediction has shape: {prediction.shape}") 117 | prediction 118 | 119 | # %% [markdown] 120 | # ## Save 121 | 122 | # %% 123 | import pandas as pd 124 | 125 | df = pd.DataFrame( 126 | data=prediction.detach().numpy(), 127 | index=smiles_list, 128 | columns=[f"latent_{i+1}" for i in range(prediction.size()[1])], 129 | ) 130 | 131 | # %% 132 | import os 133 | 134 | fname = f"{model_name}_embedding_{dataset_name}.parquet" 135 | 136 | directory = EMBEDDING_DIR / "dgl" / "data" / "embeddings" 137 | if not directory.exists(): 138 | os.makedirs(directory) 139 | print(f"Created folder: {directory}") 140 | 141 | df.to_parquet(directory / fname) 142 | 143 | # %% [markdown] 144 | # Check that it worked 145 | 146 | # %% 147 | df = pd.read_parquet(directory / fname) 148 | df 149 | 150 | # %% 151 | df.std() 152 | 153 | # %% [markdown] 154 | # ## Drawing molecules 155 | 156 | # %% 157 | from IPython.display import SVG 158 | from rdkit import Chem 159 | from rdkit.Chem import Draw 160 | 161 | # %% 162 | mols = [Chem.MolFromSmiles(s) for s in smiles_list[:14]] 163 | Draw.MolsToGridImage(mols, molsPerRow=7, subImgSize=(180, 150)) 164 | 165 | # %% 166 | -------------------------------------------------------------------------------- /embeddings/grover/README.md: -------------------------------------------------------------------------------- 1 | 2 | GROVER 3 | === 4 | This is a dump of the GROVER repository 
(+ a small bugfix [commit](https://github.com/tencent-ailab/grover/issues/8#issuecomment-908864507)). 5 | [link](https://github.com/tencent-ailab/grover) to the original repository. 6 | 7 | To setup the environment: `mamba env create` (or `conda env create` if you have a lot of time). 8 | 9 | ## Pretained Model Download 10 | Download the pretrained models from Tencent and put them into the `data/model` directory. 11 | - [GROVERbase](https://ai.tencent.com/ailab/ml/ml-data/grover-models/pretrain/grover_base.tar.gz) 12 | - [GROVERlarge](https://ai.tencent.com/ailab/ml/ml-data/grover-models/pretrain/grover_large.tar.gz) 13 | 14 | ## Generating the embeddings 15 | In `generate_embeddings.ipynb` there is code that extracts all SMILES from LINCs & trapnell, generates the 16 | grover embeddings from them and saves them to a DataFrame. 17 | 18 | -------------------------------------------------------------------------------- /embeddings/grover/data/embeddings/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/embeddings/grover/data/embeddings/.gitkeep -------------------------------------------------------------------------------- /embeddings/grover/data/model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/embeddings/grover/data/model/.gitkeep -------------------------------------------------------------------------------- /embeddings/grover/environment.yml: -------------------------------------------------------------------------------- 1 | name: grover 2 | channels: 3 | - pytorch 4 | - rdkit 5 | - conda-forge 6 | - rmg 7 | dependencies: 8 | - boost=1.68.0=py36h8619c78_1001 9 | - boost-cpp=1.68.0=h11c811c_1000 10 | - descriptastorus=2.2.0=py_0 11 | - numpy=1.16.4=py36h7e9f1db_0 12 | - numpy-base=1.16.4=py36hde5b4d6_0 13 | - pandas=0.25.0=py36hb3f55d8_0 14 | - python=3.6.8=h0371630_0 15 | - pytorch=1.1.0=py3.6_cuda9.0.176_cudnn7.5.1_0 16 | - tensorboard=1.13.1=py36_0 17 | - torchvision=0.3.0=py36_cu9.0.176_1 18 | - rdkit=2019.03.4.0=py36hc20afe1_1 19 | - readline=7.0=h7b6447c_5 20 | - scikit-learn=0.21.2=py36hcdab131_1 21 | - scipy=1.3.0=py36h921218d_1 22 | - tqdm=4.32.1=py_0 23 | - typing=3.6.4=py36_0 24 | - pyarrow 25 | - jupyter 26 | - scanpy 27 | -------------------------------------------------------------------------------- /embeddings/grover/generate_embeddings.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% [markdown] 13 | # # GROVER 14 | # Generate GROVER fingerprints for SMILES-drugs coming from LINCS + SciPlex3. 15 | # 16 | # Steps: 17 | # 1. Load `lincs_trapnell.smiles` as the list of SMILES to be encoded 18 | # 2. Generate fingerprints using GROVER 19 | # 3. Save SMILES -> fingerprint mapping as a pandas df. 
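# %%
# Minimal fix, assuming the original import cell was lost on export: `Path`, `numpy`
# and `pandas` are used by the cells further down but otherwise appear only as
# commented-out text in this file, so import them here.
from pathlib import Path

import numpy as np
import pandas as pd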
20 | # 21 | # from pathlib import Path 22 | # 23 | # import numpy as np 24 | # import pandas as pd 25 | 26 | # %% 27 | import rdkit 28 | 29 | # %% 30 | import scanpy as sc 31 | from rdkit import Chem 32 | 33 | rdkit.__version__ 34 | 35 | # %% 36 | # SET 37 | datasets_fpath = Path("/home/icb/simon.boehm/Masters_thesis/MT_code/datasets") 38 | all_smiles_fpath = Path.cwd().parent / "lincs_trapnell.smiles" 39 | 40 | # %% [markdown] 41 | # ## Step 1: Generate fingerprints 42 | # 43 | # - TODO: Right now we generate `rdkit_2d_normalized` features. Are these the correct ones? 44 | # - TODO: There are pretrained & finetuned models also available, maybe that's useful for us: 45 | # - SIDER: Drug side effect prediction task 46 | # - ClinTox: Drug toxicity prediction task 47 | # - ChEMBL log P prediction task 48 | 49 | # %% language="bash" 50 | # set -euox pipefail 51 | # 52 | # # move csv of all smiles to be encoded into current workdir 53 | # cp ../lincs_trapnell.smiles data/embeddings/lincs_trapnell.csv 54 | # file="data/embeddings/lincs_trapnell.csv" 55 | # 56 | # # First we generate the feature embedding for the SMILES, which is an extra input 57 | # # into GROVER 58 | # echo "FILE: $file" 59 | # features=$(echo $file | sed 's:.csv:.npz:') 60 | # if [[ ! -f $features ]]; then 61 | # echo "Generating features: $features" 62 | # python scripts/save_features.py --data_path "$file" \ 63 | # --save_path "$features" \ 64 | # --features_generator rdkit_2d_normalized \ 65 | # --restart 66 | # fi; 67 | # 68 | # # Second we input SMILES + Features into grover and get the fingerprint out 69 | # # 'both' means we get a concatenated fingerprint of combined atoms + bonds features 70 | # outfile=$(echo $file | sed 's:.csv:_grover_base_both.npz:') 71 | # echo "EMB: $outfile" 72 | # if [[ ! 
-f $outfile ]]; then 73 | # echo "Generating embedding: $outfile" 74 | # python main.py fingerprint --data_path "$file" \ 75 | # --features_path "$features" \ 76 | # --checkpoint_path data/model/grover_base.pt \ 77 | # --fingerprint_source both \ 78 | # --output "$outfile" 79 | # fi; 80 | 81 | # %% 82 | lincs_trapnell_base = np.load("data/embeddings/lincs_trapnell_grover_base_both.npz") 83 | print("Shape of GROVER_base embedding:", lincs_trapnell_base["fps"].shape) 84 | 85 | 86 | # %% [markdown] 87 | # ## Step 2: Generate DataFrame with SMILES -> Embedding mapping 88 | 89 | # %% 90 | def flatten(x: np.ndarray): 91 | assert len(x.shape) == 2 and x.shape[0] == 1 92 | return x[0] 93 | 94 | 95 | embeddings_fpath = Path("data/embeddings") 96 | smiles_file = embeddings_fpath / "lincs_trapnell.csv" 97 | emb_file = embeddings_fpath / "lincs_trapnell_grover_base_both.npz" 98 | 99 | # read list of smiles 100 | smiles_df = pd.read_csv(smiles_file) 101 | # read generated embedding (.npz has only one key, 'fps') 102 | emb = np.load(emb_file)["fps"] 103 | assert len(smiles_df) == emb.shape[0] 104 | 105 | # generate a DataFrame with SMILES and Embedding in each row 106 | final_df = pd.DataFrame( 107 | emb, 108 | index=smiles_df["smiles"].values, 109 | columns=[f"latent_{i+1}" for i in range(emb.shape[1])], 110 | ) 111 | # remove duplicates indices (=SMILES) (This is probably useless) 112 | final_df = final_df[~final_df.index.duplicated(keep="first")] 113 | final_df.to_parquet(embeddings_fpath / "grover_base.parquet") 114 | 115 | # %% 116 | df = pd.read_parquet("data/embeddings/grover_base.parquet") 117 | 118 | # %% 119 | df 120 | 121 | # %% [markdown] 122 | # ## Step 3: Check 123 | # Make extra sure the index of the generated dataframe is correct by loading our list of canonical SMILES again 124 | 125 | # %% 126 | all_smiles_fpath = Path.cwd().parent / "lincs_trapnell.smiles" 127 | all_smiles = pd.read_csv(all_smiles_fpath)["smiles"].values 128 | assert sorted(list(df.index)) == sorted(list(all_smiles)) 129 | -------------------------------------------------------------------------------- /embeddings/grover/grover/data/__init__.py: -------------------------------------------------------------------------------- 1 | from grover.data.moldataset import MoleculeDatapoint, MoleculeDataset 2 | from grover.data.molfeaturegenerator import ( 3 | get_available_features_generators, 4 | get_features_generator, 5 | ) 6 | from grover.data.molgraph import ( 7 | BatchMolGraph, 8 | MolCollator, 9 | MolGraph, 10 | get_atom_fdim, 11 | get_bond_fdim, 12 | mol2graph, 13 | ) 14 | from grover.data.scaler import StandardScaler 15 | 16 | # from .utils import load_features, save_features 17 | -------------------------------------------------------------------------------- /embeddings/grover/grover/data/dist_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | The re-implemented distributed sampler for the distributed training of GROVER. 3 | """ 4 | import math 5 | import time 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch.utils.data.sampler import Sampler 10 | 11 | 12 | class DistributedSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 
19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__( 31 | self, dataset, num_replicas=None, rank=None, shuffle=True, sample_per_file=None 32 | ): 33 | if num_replicas is None: 34 | if not dist.is_available(): 35 | raise RuntimeError("Requires distributed package to be available") 36 | num_replicas = dist.get_world_size() 37 | if rank is None: 38 | if not dist.is_available(): 39 | raise RuntimeError("Requires distributed package to be available") 40 | rank = dist.get_rank() 41 | self.dataset = dataset 42 | self.num_replicas = num_replicas 43 | self.rank = rank 44 | self.epoch = 0 45 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 46 | self.total_size = self.num_samples * self.num_replicas 47 | self.sample_per_file = sample_per_file 48 | self.shuffle = shuffle 49 | 50 | def get_indices(self): 51 | 52 | indices = list(range(len(self.dataset))) 53 | 54 | if self.sample_per_file is not None: 55 | indices = self.sub_indices_of_rank(indices) 56 | else: 57 | # add extra samples to make it evenly divisible 58 | indices += indices[: (self.total_size - len(indices))] 59 | assert len(indices) == self.total_size 60 | # subsample 61 | s = self.rank * self.num_samples 62 | e = min((self.rank + 1) * self.num_samples, len(indices)) 63 | 64 | # indices = indices[self.rank:self.total_size:self.num_replicas] 65 | indices = indices[s:e] 66 | 67 | if self.shuffle: 68 | g = torch.Generator() 69 | # the seed need to be considered. 70 | g.manual_seed((self.epoch + 1) * (self.rank + 1) * time.time()) 71 | idx = torch.randperm(len(indices), generator=g).tolist() 72 | indices = [indices[i] for i in idx] 73 | 74 | # disable this since sub_indices_of_rank. 75 | # assert len(indices) == self.num_samples 76 | 77 | return indices 78 | 79 | def sub_indices_of_rank(self, indices): 80 | 81 | # fix generator for each epoch 82 | g = torch.Generator() 83 | # All data should be loaded in each epoch. 
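        # Note: this seed depends only on the epoch (not the rank), so every rank draws
        # the same permutation of file indices and then slices out its own share below.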
84 | g.manual_seed((self.epoch + 1) * 2 + 3) 85 | 86 | # the fake file indices to cache 87 | f_indices = list( 88 | range(int(math.ceil(len(indices) * 1.0 / self.sample_per_file))) 89 | ) 90 | idx = torch.randperm(len(f_indices), generator=g).tolist() 91 | f_indices = [f_indices[i] for i in idx] 92 | 93 | file_per_rank = int(math.ceil(len(f_indices) * 1.0 / self.num_replicas)) 94 | # add extra fake file to make it evenly divisible 95 | f_indices += f_indices[: (file_per_rank * self.num_replicas - len(f_indices))] 96 | 97 | # divide index by rank 98 | rank_s = self.rank * file_per_rank 99 | rank_e = min((self.rank + 1) * file_per_rank, len(f_indices)) 100 | 101 | # get file index for this rank 102 | f_indices = f_indices[rank_s:rank_e] 103 | # print("f_indices") 104 | # print(f_indices) 105 | res_indices = [] 106 | for fi in f_indices: 107 | # get real indices for this rank 108 | si = fi * self.sample_per_file 109 | ei = min((fi + 1) * self.sample_per_file, len(indices)) 110 | cur_idx = [indices[i] for i in range(si, ei)] 111 | res_indices += cur_idx 112 | 113 | self.num_samples = len(res_indices) 114 | return res_indices 115 | 116 | def __iter__(self): 117 | return iter(self.get_indices()) 118 | 119 | def __len__(self): 120 | return self.num_samples 121 | 122 | def set_epoch(self, epoch): 123 | self.epoch = epoch 124 | 125 | 126 | if __name__ == "__main__": 127 | # dataset = [1] * 9 128 | # ds = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True) 129 | # print(ds.get_indices()) 130 | # ds = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True) 131 | # print(ds.get_indices()) 132 | 133 | dataset = [1] * 190001 134 | res = [] 135 | ds = DistributedSampler( 136 | dataset, num_replicas=2, rank=0, shuffle=True, sample_per_file=777 137 | ) 138 | res.extend(ds.get_indices()) 139 | print(len(ds.get_indices())) 140 | ds = DistributedSampler( 141 | dataset, num_replicas=2, rank=1, shuffle=True, sample_per_file=777 142 | ) 143 | res.extend(ds.get_indices()) 144 | print(len(ds.get_indices())) 145 | print(len(set(res))) 146 | print("hello") 147 | -------------------------------------------------------------------------------- /embeddings/grover/grover/data/scaler.py: -------------------------------------------------------------------------------- 1 | """ 2 | The scaler for the regression task. 3 | This implementation is adapted from 4 | https://github.com/chemprop/chemprop/blob/master/chemprop/data/scaler.py 5 | """ 6 | from typing import Any, List 7 | 8 | import numpy as np 9 | 10 | 11 | class StandardScaler: 12 | """A StandardScaler normalizes a dataset. 13 | 14 | When fit on a dataset, the StandardScaler learns the mean and standard deviation across the 0th axis. 15 | When transforming a dataset, the StandardScaler subtracts the means and divides by the standard deviations. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | means: np.ndarray = None, 21 | stds: np.ndarray = None, 22 | replace_nan_token: Any = None, 23 | ): 24 | """ 25 | Initialize StandardScaler, optionally with means and standard deviations precomputed. 26 | 27 | :param means: An optional 1D numpy array of precomputed means. 28 | :param stds: An optional 1D numpy array of precomputed standard deviations. 29 | :param replace_nan_token: The token to use in place of nans. 30 | """ 31 | self.means = means 32 | self.stds = stds 33 | self.replace_nan_token = replace_nan_token 34 | 35 | def fit(self, X: List[List[float]]) -> "StandardScaler": 36 | """ 37 | Learns means and standard deviations across the 0th axis. 
38 | 39 | :param X: A list of lists of floats. 40 | :return: The fitted StandardScaler. 41 | """ 42 | X = np.array(X).astype(float) 43 | self.means = np.nanmean(X, axis=0) 44 | self.stds = np.nanstd(X, axis=0) 45 | self.means = np.where( 46 | np.isnan(self.means), np.zeros(self.means.shape), self.means 47 | ) 48 | self.stds = np.where(np.isnan(self.stds), np.ones(self.stds.shape), self.stds) 49 | self.stds = np.where(self.stds == 0, np.ones(self.stds.shape), self.stds) 50 | 51 | return self 52 | 53 | def transform(self, X: List[List[float]]): 54 | """ 55 | Transforms the data by subtracting the means and dividing by the standard deviations. 56 | 57 | :param X: A list of lists of floats. 58 | :return: The transformed data. 59 | """ 60 | X = np.array(X).astype(float) 61 | transformed_with_nan = (X - self.means) / self.stds 62 | transformed_with_none = np.where( 63 | np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan 64 | ) 65 | 66 | return transformed_with_none 67 | 68 | def inverse_transform(self, X: List[List[float]]): 69 | """ 70 | Performs the inverse transformation by multiplying by the standard deviations and adding the means. 71 | 72 | :param X: A list of lists of floats. 73 | :return: The inverse transformed data. 74 | """ 75 | if isinstance(X, np.ndarray) or isinstance(X, list): 76 | X = np.array(X).astype(float) 77 | transformed_with_nan = X * self.stds + self.means 78 | transformed_with_none = np.where( 79 | np.isnan(transformed_with_nan), 80 | self.replace_nan_token, 81 | transformed_with_nan, 82 | ) 83 | return transformed_with_none 84 | -------------------------------------------------------------------------------- /embeddings/grover/grover/data/task_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | The label generator for the pretraining. 3 | """ 4 | from collections import Counter 5 | from typing import Callable, Union 6 | 7 | import numpy as np 8 | from descriptastorus.descriptors import rdDescriptors 9 | from grover.data.molfeaturegenerator import register_features_generator 10 | from rdkit import Chem 11 | 12 | Molecule = Union[str, Chem.Mol] 13 | FeaturesGenerator = Callable[[Molecule], np.ndarray] 14 | 15 | # The functional group descriptors in RDkit. 
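# Each entry is an RDKit fragment-count descriptor (the fr_* family); the `fgtasklabel`
# generator below computes them via descriptastorus and binarizes the counts into
# multi-label functional-group targets used for pretraining.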
16 | RDKIT_PROPS = [ 17 | "fr_Al_COO", 18 | "fr_Al_OH", 19 | "fr_Al_OH_noTert", 20 | "fr_ArN", 21 | "fr_Ar_COO", 22 | "fr_Ar_N", 23 | "fr_Ar_NH", 24 | "fr_Ar_OH", 25 | "fr_COO", 26 | "fr_COO2", 27 | "fr_C_O", 28 | "fr_C_O_noCOO", 29 | "fr_C_S", 30 | "fr_HOCCN", 31 | "fr_Imine", 32 | "fr_NH0", 33 | "fr_NH1", 34 | "fr_NH2", 35 | "fr_N_O", 36 | "fr_Ndealkylation1", 37 | "fr_Ndealkylation2", 38 | "fr_Nhpyrrole", 39 | "fr_SH", 40 | "fr_aldehyde", 41 | "fr_alkyl_carbamate", 42 | "fr_alkyl_halide", 43 | "fr_allylic_oxid", 44 | "fr_amide", 45 | "fr_amidine", 46 | "fr_aniline", 47 | "fr_aryl_methyl", 48 | "fr_azide", 49 | "fr_azo", 50 | "fr_barbitur", 51 | "fr_benzene", 52 | "fr_benzodiazepine", 53 | "fr_bicyclic", 54 | "fr_diazo", 55 | "fr_dihydropyridine", 56 | "fr_epoxide", 57 | "fr_ester", 58 | "fr_ether", 59 | "fr_furan", 60 | "fr_guanido", 61 | "fr_halogen", 62 | "fr_hdrzine", 63 | "fr_hdrzone", 64 | "fr_imidazole", 65 | "fr_imide", 66 | "fr_isocyan", 67 | "fr_isothiocyan", 68 | "fr_ketone", 69 | "fr_ketone_Topliss", 70 | "fr_lactam", 71 | "fr_lactone", 72 | "fr_methoxy", 73 | "fr_morpholine", 74 | "fr_nitrile", 75 | "fr_nitro", 76 | "fr_nitro_arom", 77 | "fr_nitro_arom_nonortho", 78 | "fr_nitroso", 79 | "fr_oxazole", 80 | "fr_oxime", 81 | "fr_para_hydroxylation", 82 | "fr_phenol", 83 | "fr_phenol_noOrthoHbond", 84 | "fr_phos_acid", 85 | "fr_phos_ester", 86 | "fr_piperdine", 87 | "fr_piperzine", 88 | "fr_priamide", 89 | "fr_prisulfonamd", 90 | "fr_pyridine", 91 | "fr_quatN", 92 | "fr_sulfide", 93 | "fr_sulfonamd", 94 | "fr_sulfone", 95 | "fr_term_acetylene", 96 | "fr_tetrazole", 97 | "fr_thiazole", 98 | "fr_thiocyan", 99 | "fr_thiophene", 100 | "fr_unbrch_alkane", 101 | "fr_urea", 102 | ] 103 | 104 | BOND_FEATURES = ["BondType", "Stereo", "BondDir"] 105 | 106 | 107 | # BOND_FEATURES = ['BondType', 'Stereo'] 108 | # BOND_FEATURES = ['Stereo'] 109 | 110 | 111 | @register_features_generator("fgtasklabel") 112 | def rdkit_functional_group_label_features_generator(mol: Molecule) -> np.ndarray: 113 | """ 114 | Generates functional group label for a molecule using RDKit. 115 | 116 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 117 | :return: A 1D numpy array containing the RDKit 2D features. 118 | """ 119 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 120 | generator = rdDescriptors.RDKit2D(RDKIT_PROPS) 121 | features = generator.process(smiles)[1:] 122 | features = np.array(features) 123 | features[features != 0] = 1 124 | return features 125 | 126 | 127 | def atom_to_vocab(mol, atom): 128 | """ 129 | Convert atom to vocabulary. The convention is based on atom type and bond type. 130 | :param mol: the molecular. 131 | :param atom: the target atom. 132 | :return: the generated atom vocabulary with its contexts. 133 | """ 134 | nei = Counter() 135 | for a in atom.GetNeighbors(): 136 | bond = mol.GetBondBetweenAtoms(atom.GetIdx(), a.GetIdx()) 137 | nei[str(a.GetSymbol()) + "-" + str(bond.GetBondType())] += 1 138 | keys = nei.keys() 139 | keys = list(keys) 140 | keys.sort() 141 | output = atom.GetSymbol() 142 | for k in keys: 143 | output = "%s_%s%d" % (output, k, nei[k]) 144 | 145 | # The generated atom_vocab is too long? 146 | return output 147 | 148 | 149 | def bond_to_vocab(mol, bond): 150 | """ 151 | Convert bond to vocabulary. The convention is based on atom type and bond type. 152 | Considering one-hop neighbor atoms 153 | :param mol: the molecular. 154 | :param atom: the target atom. 
155 | :return: the generated bond vocabulary with its contexts. 156 | """ 157 | nei = Counter() 158 | two_neighbors = (bond.GetBeginAtom(), bond.GetEndAtom()) 159 | two_indices = [a.GetIdx() for a in two_neighbors] 160 | for nei_atom in two_neighbors: 161 | for a in nei_atom.GetNeighbors(): 162 | a_idx = a.GetIdx() 163 | if a_idx in two_indices: 164 | continue 165 | tmp_bond = mol.GetBondBetweenAtoms(nei_atom.GetIdx(), a_idx) 166 | nei[str(nei_atom.GetSymbol()) + "-" + get_bond_feature_name(tmp_bond)] += 1 167 | keys = list(nei.keys()) 168 | keys.sort() 169 | output = get_bond_feature_name(bond) 170 | for k in keys: 171 | output = "%s_%s%d" % (output, k, nei[k]) 172 | return output 173 | 174 | 175 | def get_bond_feature_name(bond): 176 | """ 177 | Return the string format of bond features. 178 | Bond features are surrounded with () 179 | 180 | """ 181 | ret = [] 182 | for bond_feature in BOND_FEATURES: 183 | fea = eval(f"bond.Get{bond_feature}")() 184 | ret.append(str(fea)) 185 | 186 | return "(" + "-".join(ret) + ")" 187 | -------------------------------------------------------------------------------- /embeddings/grover/grover/util/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | The evaluation metrics. 3 | """ 4 | import math 5 | from typing import Callable, List, Union 6 | 7 | from sklearn.metrics import ( 8 | accuracy_score, 9 | auc, 10 | confusion_matrix, 11 | mean_absolute_error, 12 | mean_squared_error, 13 | precision_recall_curve, 14 | r2_score, 15 | recall_score, 16 | roc_auc_score, 17 | ) 18 | 19 | 20 | def accuracy(targets: List[int], preds: List[float], threshold: float = 0.5) -> float: 21 | """ 22 | Computes the accuracy of a binary prediction task using a given threshold for generating hard predictions. 23 | 24 | :param targets: A list of binary targets. 25 | :param preds: A list of prediction probabilities. 26 | :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0 27 | :return: The computed accuracy. 28 | """ 29 | hard_preds = [1 if p > threshold else 0 for p in preds] 30 | return accuracy_score(targets, hard_preds) 31 | 32 | 33 | def recall(targets: List[int], preds: List[float], threshold: float = 0.5) -> float: 34 | """ 35 | Computes the recall of a binary prediction task using a given threshold for generating hard predictions. 36 | 37 | :param targets: A list of binary targets. 38 | :param preds: A list of prediction probabilities. 39 | :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0 40 | :return: The computed recall. 41 | """ 42 | hard_preds = [1 if p > threshold else 0 for p in preds] 43 | return recall_score(targets, hard_preds) 44 | 45 | 46 | def sensitivity( 47 | targets: List[int], preds: List[float], threshold: float = 0.5 48 | ) -> float: 49 | """ 50 | Computes the sensitivity of a binary prediction task using a given threshold for generating hard predictions. 51 | 52 | :param targets: A list of binary targets. 53 | :param preds: A list of prediction probabilities. 54 | :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0 55 | :return: The computed sensitivity. 
56 | """ 57 | return recall(targets, preds, threshold) 58 | 59 | 60 | def specificity( 61 | targets: List[int], preds: List[float], threshold: float = 0.5 62 | ) -> float: 63 | """ 64 | Computes the specificity of a binary prediction task using a given threshold for generating hard predictions. 65 | 66 | :param targets: A list of binary targets. 67 | :param preds: A list of prediction probabilities. 68 | :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0 69 | :return: The computed specificity. 70 | """ 71 | hard_preds = [1 if p > threshold else 0 for p in preds] 72 | tn, fp, _, _ = confusion_matrix(targets, hard_preds).ravel() 73 | return tn / float(tn + fp) 74 | 75 | 76 | def rmse(targets: List[float], preds: List[float]) -> float: 77 | """ 78 | Computes the root mean squared error. 79 | 80 | :param targets: A list of targets. 81 | :param preds: A list of predictions. 82 | :return: The computed rmse. 83 | """ 84 | return math.sqrt(mean_squared_error(targets, preds)) 85 | 86 | 87 | def get_metric_func( 88 | metric: str, 89 | ) -> Callable[[Union[List[int], List[float]], List[float]], float]: 90 | """ 91 | Gets the metric function corresponding to a given metric name. 92 | 93 | :param metric: Metric name. 94 | :return: A metric function which takes as arguments a list of targets and a list of predictions and returns. 95 | """ 96 | # Note: If you want to add a new metric, please also update the parser argument --metric in parsing.py. 97 | if metric == "auc": 98 | return roc_auc_score 99 | 100 | if metric == "prc-auc": 101 | return prc_auc 102 | 103 | if metric == "rmse": 104 | return rmse 105 | 106 | if metric == "mae": 107 | return mean_absolute_error 108 | 109 | if metric == "r2": 110 | return r2_score 111 | 112 | if metric == "accuracy": 113 | return accuracy 114 | 115 | if metric == "recall": 116 | return recall 117 | 118 | if metric == "sensitivity": 119 | return sensitivity 120 | 121 | if metric == "specificity": 122 | return specificity 123 | 124 | raise ValueError(f'Metric "{metric}" not supported.') 125 | 126 | 127 | def prc_auc(targets: List[int], preds: List[float]) -> float: 128 | """ 129 | Computes the area under the precision-recall curve. 130 | 131 | :param targets: A list of binary targets. 132 | :param preds: A list of prediction probabilities. 133 | :return: The computed prc-auc. 134 | """ 135 | precision, recall, _ = precision_recall_curve(targets, preds) 136 | return auc(recall, precision) 137 | -------------------------------------------------------------------------------- /embeddings/grover/grover/util/multi_gpu_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper for multi-GPU training. 
3 | """ 4 | # use Hovorod for multi-GPU pytorch training 5 | try: 6 | import horovod.torch as mgw 7 | import torch 8 | 9 | print("using Horovod for multi-GPU training") 10 | except ImportError: 11 | print("[WARNING] Horovod cannot be imported; multi-GPU training is unsupported") 12 | pass 13 | 14 | 15 | class MultiGpuWrapper(object): 16 | """Wrapper for multi-GPU training.""" 17 | 18 | def __init__(self): 19 | """Constructor function.""" 20 | pass 21 | 22 | @classmethod 23 | def init(cls, *args): 24 | """Initialization.""" 25 | 26 | try: 27 | return mgw.init(*args) 28 | except NameError: 29 | raise NameError("module not imported") 30 | 31 | @classmethod 32 | def size(cls, *args): 33 | """Get the number of workers at all nodes.""" 34 | 35 | try: 36 | return mgw.size(*args) 37 | except NameError: 38 | raise NameError("module not imported") 39 | 40 | @classmethod 41 | def rank(cls, *args): 42 | """Get the rank of current worker at all nodes.""" 43 | 44 | try: 45 | return mgw.rank(*args) 46 | except NameError: 47 | raise NameError("module not imported") 48 | 49 | @classmethod 50 | def local_size(cls, *args): 51 | """Get the number of workers at the current node.""" 52 | 53 | try: 54 | return mgw.local_size(*args) 55 | except NameError: 56 | raise NameError("module not imported") 57 | 58 | @classmethod 59 | def local_rank(cls, *args): 60 | """Get the rank of current worker at the current node.""" 61 | 62 | try: 63 | return mgw.local_rank(*args) 64 | except NameError: 65 | raise NameError("module not imported") 66 | 67 | @classmethod 68 | def DistributedOptimizer(cls, *args, **kwargs): 69 | """Get a distributed optimizer from the base optimizer.""" 70 | 71 | try: 72 | return mgw.DistributedOptimizer(*args, **kwargs) 73 | except NameError: 74 | raise NameError("module not imported") 75 | 76 | @classmethod 77 | def broadcast_parameters(cls, *args, **kwargs): 78 | """Get a operation to broadcast all the parameters.""" 79 | 80 | try: 81 | return mgw.broadcast_parameters(*args, **kwargs) 82 | except NameError: 83 | raise NameError("module not imported") 84 | 85 | @classmethod 86 | def broadcast_optimizer_state(cls, *args, **kwargs): 87 | """Get a operation to broadcast all the optimizer state.""" 88 | 89 | try: 90 | return mgw.broadcast_optimizer_state(*args, **kwargs) 91 | except NameError: 92 | raise NameError("module not imported") 93 | 94 | @classmethod 95 | def broadcast(cls, *args, **kwargs): 96 | """Get a operation to broadcast all the optimizer state.""" 97 | 98 | try: 99 | return mgw.broadcast(*args, **kwargs) 100 | except NameError: 101 | raise NameError("module not imported") 102 | 103 | @classmethod 104 | def barrier(cls): 105 | """Add a barrier to synchronize different processes""" 106 | 107 | try: 108 | return mgw.allreduce(torch.tensor(0), name="barrier") 109 | except NameError: 110 | raise NameError("module not imported") 111 | -------------------------------------------------------------------------------- /embeddings/grover/grover/util/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The utility function for model construction. 3 | This implementation is adapted from 4 | https://github.com/chemprop/chemprop/blob/master/chemprop/nn_utils.py 5 | """ 6 | import torch 7 | from torch import nn as nn 8 | 9 | 10 | def param_count(model: nn.Module) -> int: 11 | """ 12 | Determines number of trainable parameters. 13 | :param model: An nn.Module. 14 | :return: The number of trainable parameters. 
15 | """ 16 | return sum(param.numel() for param in model.parameters() if param.requires_grad) 17 | 18 | 19 | def index_select_nd(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor: 20 | """ 21 | Selects the message features from source corresponding to the atom or bond indices in index. 22 | 23 | :param source: A tensor of shape (num_bonds, hidden_size) containing message features. 24 | :param index: A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond 25 | indices to select from source. 26 | :return: A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message 27 | features corresponding to the atoms/bonds specified in index. 28 | """ 29 | index_size = index.size() # (num_atoms/num_bonds, max_num_bonds) 30 | suffix_dim = source.size()[1:] # (hidden_size,) 31 | final_size = ( 32 | index_size + suffix_dim 33 | ) # (num_atoms/num_bonds, max_num_bonds, hidden_size) 34 | 35 | target = source.index_select( 36 | dim=0, index=index.view(-1) 37 | ) # (num_atoms/num_bonds * max_num_bonds, hidden_size) 38 | target = target.view( 39 | final_size 40 | ) # (num_atoms/num_bonds, max_num_bonds, hidden_size) 41 | 42 | return target 43 | 44 | 45 | def get_activation_function(activation: str) -> nn.Module: 46 | """ 47 | Gets an activation function module given the name of the activation. 48 | 49 | :param activation: The name of the activation function. 50 | :return: The activation function module. 51 | """ 52 | if activation == "ReLU": 53 | return nn.ReLU() 54 | elif activation == "LeakyReLU": 55 | return nn.LeakyReLU(0.1) 56 | elif activation == "PReLU": 57 | return nn.PReLU() 58 | elif activation == "tanh": 59 | return nn.Tanh() 60 | elif activation == "SELU": 61 | return nn.SELU() 62 | elif activation == "ELU": 63 | return nn.ELU() 64 | elif activation == "Linear": 65 | return lambda x: x 66 | else: 67 | raise ValueError(f'Activation "{activation}" not supported.') 68 | 69 | 70 | def initialize_weights(model: nn.Module, distinct_init=False, model_idx=0): 71 | """ 72 | Initializes the weights of a model in place. 73 | 74 | :param model: An nn.Module. 75 | """ 76 | init_fns = [ 77 | nn.init.kaiming_normal_, 78 | nn.init.kaiming_uniform_, 79 | nn.init.xavier_normal_, 80 | nn.init.xavier_uniform_, 81 | ] 82 | for param in model.parameters(): 83 | if param.dim() == 1: 84 | nn.init.constant_(param, 0) 85 | else: 86 | if distinct_init: 87 | init_fn = init_fns[model_idx % 4] 88 | if "kaiming" in init_fn.__name__: 89 | init_fn(param, nonlinearity="relu") 90 | else: 91 | init_fn(param) 92 | else: 93 | nn.init.xavier_normal_(param) 94 | 95 | 96 | def select_neighbor_and_aggregate(feature, index): 97 | """ 98 | The basic operation in message passing. 99 | Caution: the index_selec_ND would cause the reproducibility issue when performing the training on CUDA. 100 | See: https://pytorch.org/docs/stable/notes/randomness.html 101 | :param feature: the candidate feature for aggregate. (n_nodes, hidden) 102 | :param index: the selected index (neighbor indexes). 103 | :return: 104 | """ 105 | neighbor = index_select_nd(feature, index) 106 | return neighbor.sum(dim=1) 107 | -------------------------------------------------------------------------------- /embeddings/grover/grover/util/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | The learning rate scheduler. 
3 | This implementation is adapted from 4 | https://github.com/chemprop/chemprop/blob/master/chemprop/nn_utils.py 5 | """ 6 | from typing import List, Union 7 | 8 | import numpy as np 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | 11 | 12 | class NoamLR(_LRScheduler): 13 | """ 14 | Noam learning rate scheduler with piecewise linear increase and exponential decay. 15 | 16 | The learning rate increases linearly from init_lr to max_lr over the course of 17 | the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch). 18 | Then the learning rate decreases exponentially from max_lr to final_lr over the 19 | course of the remaining total_steps - warmup_steps (where total_steps = 20 | total_epochs * steps_per_epoch). This is roughly based on the learning rate 21 | schedule from SelfAttention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762). 22 | """ 23 | 24 | def __init__( 25 | self, 26 | optimizer, 27 | warmup_epochs: List[Union[float, int]], 28 | total_epochs: List[int], 29 | steps_per_epoch: int, 30 | init_lr: List[float], 31 | max_lr: List[float], 32 | final_lr: List[float], 33 | fine_tune_coff: float = 1.0, 34 | fine_tune_param_idx: int = 0, 35 | ): 36 | """ 37 | Initializes the learning rate scheduler. 38 | 39 | 40 | :param optimizer: A PyTorch optimizer. 41 | :param warmup_epochs: The number of epochs during which to linearly increase the learning rate. 42 | :param total_epochs: The total number of epochs. 43 | :param steps_per_epoch: The number of steps (batches) per epoch. 44 | :param init_lr: The initial learning rate. 45 | :param max_lr: The maximum learning rate (achieved after warmup_epochs). 46 | :param final_lr: The final learning rate (achieved after total_epochs). 47 | :param fine_tune_coff: The fine tune coefficient for the target param group. The true learning rate for the 48 | target param group would be lr*fine_tune_coff. 49 | :param fine_tune_param_idx: The index of target param group. Default is index 0. 50 | """ 51 | 52 | # assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \ 53 | # len(max_lr) == len(final_lr) 54 | 55 | self.num_lrs = len(optimizer.param_groups) 56 | 57 | self.optimizer = optimizer 58 | self.warmup_epochs = np.array([warmup_epochs] * self.num_lrs) 59 | self.total_epochs = np.array([total_epochs] * self.num_lrs) 60 | self.steps_per_epoch = steps_per_epoch 61 | self.init_lr = np.array([init_lr] * self.num_lrs) 62 | self.max_lr = np.array([max_lr] * self.num_lrs) 63 | self.final_lr = np.array([final_lr] * self.num_lrs) 64 | self.lr_coff = np.array([1] * self.num_lrs) 65 | self.fine_tune_param_idx = fine_tune_param_idx 66 | self.lr_coff[self.fine_tune_param_idx] = fine_tune_coff 67 | 68 | self.current_step = 0 69 | self.lr = [init_lr] * self.num_lrs 70 | self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int) 71 | self.total_steps = self.total_epochs * self.steps_per_epoch 72 | self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps 73 | 74 | self.exponential_gamma = (self.final_lr / self.max_lr) ** ( 75 | 1 / (self.total_steps - self.warmup_steps) 76 | ) 77 | super(NoamLR, self).__init__(optimizer) 78 | 79 | def get_lr(self) -> List[float]: 80 | """Gets a list of the current learning rates.""" 81 | return list(self.lr) 82 | 83 | def step(self, current_step: int = None): 84 | """ 85 | Updates the learning rate by taking a step. 86 | 87 | :param current_step: Optionally specify what step to set the learning rate to. 
88 | If None, current_step = self.current_step + 1. 89 | """ 90 | if current_step is not None: 91 | self.current_step = current_step 92 | else: 93 | self.current_step += 1 94 | for i in range(self.num_lrs): 95 | if self.current_step <= self.warmup_steps[i]: 96 | self.lr[i] = ( 97 | self.init_lr[i] + self.current_step * self.linear_increment[i] 98 | ) 99 | elif self.current_step <= self.total_steps[i]: 100 | self.lr[i] = self.max_lr[i] * ( 101 | self.exponential_gamma[i] 102 | ** (self.current_step - self.warmup_steps[i]) 103 | ) 104 | else: # theoretically this case should never be reached since training should stop at total_steps 105 | self.lr[i] = self.final_lr[i] 106 | self.lr[i] *= self.lr_coff[i] 107 | self.optimizer.param_groups[i]["lr"] = self.lr[i] 108 | -------------------------------------------------------------------------------- /embeddings/grover/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from grover.data.torchvocab import MolVocab 6 | from grover.util.parsing import get_newest_train_args, parse_args 7 | from grover.util.utils import create_logger 8 | from rdkit import RDLogger 9 | from task.cross_validate import cross_validate 10 | from task.fingerprint import generate_fingerprints 11 | from task.predict import make_predictions, write_prediction 12 | from task.pretrain import pretrain_model 13 | 14 | 15 | def setup(seed): 16 | # frozen random seed 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | np.random.seed(seed) 20 | random.seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | if __name__ == "__main__": 25 | # setup random seed 26 | setup(seed=42) 27 | # Avoid the pylint warning. 
28 | a = MolVocab 29 | # supress rdkit logger 30 | lg = RDLogger.logger() 31 | lg.setLevel(RDLogger.CRITICAL) 32 | 33 | # Initialize MolVocab 34 | mol_vocab = MolVocab 35 | 36 | args = parse_args() 37 | if args.parser_name == "finetune": 38 | logger = create_logger(name="train", save_dir=args.save_dir, quiet=False) 39 | cross_validate(args, logger) 40 | elif args.parser_name == "pretrain": 41 | logger = create_logger(name="pretrain", save_dir=args.save_dir) 42 | pretrain_model(args, logger) 43 | elif args.parser_name == "eval": 44 | logger = create_logger(name="eval", save_dir=args.save_dir, quiet=False) 45 | cross_validate(args, logger) 46 | elif args.parser_name == "fingerprint": 47 | train_args = get_newest_train_args() 48 | logger = create_logger(name="fingerprint", save_dir=None, quiet=False) 49 | feas = generate_fingerprints(args, logger) 50 | np.savez_compressed(args.output_path, fps=feas) 51 | elif args.parser_name == "predict": 52 | train_args = get_newest_train_args() 53 | avg_preds, test_smiles = make_predictions(args, train_args) 54 | write_prediction(avg_preds, test_smiles, args) 55 | -------------------------------------------------------------------------------- /embeddings/grover/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | boost=1.68.0=py36h8619c78_1001 5 | boost-cpp=1.68.0=h11c811c_1000 6 | descriptastorus=2.2.0=py_0 7 | numpy=1.16.4=py36h7e9f1db_0 8 | numpy-base=1.16.4=py36hde5b4d6_0 9 | pandas=0.25.0=py36hb3f55d8_0 10 | python=3.6.8=h0371630_0 11 | pytorch=1.1.0=py3.6_cuda9.0.176_cudnn7.5.1_0 12 | tensorboard=1.13.1=py36_0 13 | torchvision=0.3.0=py36_cu9.0.176_1 14 | rdkit=2019.03.4.0=py36hc20afe1_1 15 | readline=7.0=h7b6447c_5 16 | scikit-learn=0.21.2=py36hcdab131_1 17 | scipy=1.3.0=py36h921218d_1 18 | tqdm=4.32.1=py_0 19 | typing=3.6.4=py36_0 -------------------------------------------------------------------------------- /embeddings/grover/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/embeddings/grover/scripts/__init__.py -------------------------------------------------------------------------------- /embeddings/grover/scripts/build_vocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | The vocabulary building scripts. 3 | """ 4 | import os 5 | 6 | from grover.data.torchvocab import MolVocab 7 | 8 | 9 | def build(): 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--data_path", 15 | default="../../dataset/grover_new_dataset/druglike_merged_refine2.csv", 16 | type=str, 17 | ) 18 | parser.add_argument( 19 | "--vocab_save_folder", default="../../dataset/grover_new_dataset", type=str 20 | ) 21 | parser.add_argument( 22 | "--dataset_name", 23 | type=str, 24 | default=None, 25 | help="Will be the first part of the vocab file name. 
If it is None," 26 | "the vocab files will be: atom_vocab.pkl and bond_vocab.pkl", 27 | ) 28 | parser.add_argument("--vocab_max_size", type=int, default=None) 29 | parser.add_argument("--vocab_min_freq", type=int, default=1) 30 | args = parser.parse_args() 31 | 32 | # fin = open(args.data_path, 'r') 33 | # lines = fin.readlines() 34 | 35 | for vocab_type in ["atom", "bond"]: 36 | vocab_file = f"{vocab_type}_vocab.pkl" 37 | if args.dataset_name is not None: 38 | vocab_file = args.dataset_name + "_" + vocab_file 39 | vocab_save_path = os.path.join(args.vocab_save_folder, vocab_file) 40 | 41 | os.makedirs(os.path.dirname(vocab_save_path), exist_ok=True) 42 | vocab = MolVocab( 43 | file_path=args.data_path, 44 | max_size=args.vocab_max_size, 45 | min_freq=args.vocab_min_freq, 46 | num_workers=100, 47 | vocab_type=vocab_type, 48 | ) 49 | print(f"{vocab_type} vocab size", len(vocab)) 50 | vocab.save_vocab(vocab_save_path) 51 | 52 | 53 | if __name__ == "__main__": 54 | build() 55 | -------------------------------------------------------------------------------- /embeddings/grover/scripts/save_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Computes and saves molecular features for a dataset. 3 | """ 4 | import os 5 | import shutil 6 | import sys 7 | from argparse import ArgumentParser, Namespace 8 | from multiprocessing import Pool 9 | from typing import List, Tuple 10 | 11 | from tqdm import tqdm 12 | 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 14 | 15 | from grover.data.molfeaturegenerator import ( 16 | get_available_features_generators, 17 | get_features_generator, 18 | ) 19 | from grover.data.task_labels import rdkit_functional_group_label_features_generator 20 | from grover.util.utils import get_data, load_features, makedirs, save_features 21 | 22 | 23 | def load_temp(temp_dir: str) -> Tuple[List[List[float]], int]: 24 | """ 25 | Loads all features saved as .npz files in load_dir. 26 | 27 | Assumes temporary files are named in order 0.npz, 1.npz, ... 28 | 29 | :param temp_dir: Directory in which temporary .npz files containing features are stored. 30 | :return: A tuple with a list of molecule features, where each molecule's features is a list of floats, 31 | and the number of temporary files. 32 | """ 33 | features = [] 34 | temp_num = 0 35 | temp_path = os.path.join(temp_dir, f"{temp_num}.npz") 36 | 37 | while os.path.exists(temp_path): 38 | features.extend(load_features(temp_path)) 39 | temp_num += 1 40 | temp_path = os.path.join(temp_dir, f"{temp_num}.npz") 41 | 42 | return features, temp_num 43 | 44 | 45 | def generate_and_save_features(args: Namespace): 46 | """ 47 | Computes and saves features for a dataset of molecules as a 2D array in a .npz file. 48 | 49 | :param args: Arguments. 50 | """ 51 | # Create directory for save_path 52 | makedirs(args.save_path, isfile=True) 53 | 54 | # Get data and features function 55 | data = get_data(path=args.data_path, max_data_size=None) 56 | features_generator = get_features_generator(args.features_generator) 57 | temp_save_dir = args.save_path + "_temp" 58 | 59 | # Load partially complete data 60 | if args.restart: 61 | if os.path.exists(args.save_path): 62 | os.remove(args.save_path) 63 | if os.path.exists(temp_save_dir): 64 | shutil.rmtree(temp_save_dir) 65 | else: 66 | if os.path.exists(args.save_path): 67 | raise ValueError( 68 | f'"{args.save_path}" already exists and args.restart is False.' 
69 | ) 70 | 71 | if os.path.exists(temp_save_dir): 72 | features, temp_num = load_temp(temp_save_dir) 73 | 74 | if not os.path.exists(temp_save_dir): 75 | makedirs(temp_save_dir) 76 | features, temp_num = [], 0 77 | 78 | # Build features map function 79 | data = data[ 80 | len(features) : 81 | ] # restrict to data for which features have not been computed yet 82 | mols = (d.smiles for d in data) 83 | 84 | if args.sequential: 85 | features_map = map(features_generator, mols) 86 | else: 87 | features_map = Pool(30).imap(features_generator, mols) 88 | 89 | # Get features 90 | temp_features = [] 91 | for i, feats in tqdm(enumerate(features_map), total=len(data)): 92 | temp_features.append(feats) 93 | 94 | # Save temporary features every save_frequency 95 | if (i > 0 and (i + 1) % args.save_frequency == 0) or i == len(data) - 1: 96 | save_features(os.path.join(temp_save_dir, f"{temp_num}.npz"), temp_features) 97 | features.extend(temp_features) 98 | temp_features = [] 99 | temp_num += 1 100 | 101 | try: 102 | # Save all features 103 | save_features(args.save_path, features) 104 | 105 | # Remove temporary features 106 | shutil.rmtree(temp_save_dir) 107 | except OverflowError: 108 | print( 109 | "Features array is too large to save as a single file. Instead keeping features as a directory of files." 110 | ) 111 | 112 | 113 | if __name__ == "__main__": 114 | 115 | parser = ArgumentParser() 116 | parser.add_argument("--data_path", type=str, required=True, help="Path to data CSV") 117 | parser.add_argument( 118 | "--features_generator", 119 | type=str, 120 | required=True, 121 | choices=get_available_features_generators(), 122 | help="Type of features to generate", 123 | ) 124 | parser.add_argument( 125 | "--save_path", 126 | type=str, 127 | default=None, 128 | help="Path to .npz file where features will be saved as a compressed numpy archive", 129 | ) 130 | parser.add_argument( 131 | "--save_frequency", 132 | type=int, 133 | default=10000, 134 | help="Frequency with which to save the features", 135 | ) 136 | parser.add_argument( 137 | "--restart", 138 | action="store_true", 139 | default=False, 140 | help="Whether to not load partially complete featurization and instead start from scratch", 141 | ) 142 | parser.add_argument( 143 | "--max_data_size", type=int, help="Maximum number of data points to load" 144 | ) 145 | parser.add_argument( 146 | "--sequential", 147 | action="store_true", 148 | default=False, 149 | help="Whether to task sequentially rather than in parallel", 150 | ) 151 | args = parser.parse_args() 152 | if args.save_path is None: 153 | args.save_path = args.data_path.split("csv")[0] + "npz" 154 | generate_and_save_features(args) 155 | -------------------------------------------------------------------------------- /embeddings/grover/scripts/split_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data splitting script for pretraining. 
3 | """ 4 | import csv 5 | import os 6 | import shutil 7 | from argparse import ArgumentParser 8 | 9 | import grover.util.utils as fea_utils 10 | import numpy as np 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument( 14 | "--data_path", default="../drug_data/grover_data/delaneyfreesolvlipo.csv" 15 | ) 16 | parser.add_argument( 17 | "--features_path", 18 | default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz", 19 | ) 20 | parser.add_argument("--sample_per_file", type=int, default=1000) 21 | parser.add_argument( 22 | "--output_path", default="../drug_data/grover_data/delaneyfreesolvlipo" 23 | ) 24 | 25 | 26 | def load_smiles(data_path): 27 | with open(data_path) as f: 28 | reader = csv.reader(f) 29 | header = next(reader) 30 | res = [] 31 | for line in reader: 32 | res.append(line) 33 | return res, header 34 | 35 | 36 | def load_features(data_path): 37 | fea = fea_utils.load_features(data_path) 38 | return fea 39 | 40 | 41 | def save_smiles(data_path, index, data, header): 42 | fn = os.path.join(data_path, str(index) + ".csv") 43 | with open(fn, "w") as f: 44 | fw = csv.writer(f) 45 | fw.writerow(header) 46 | for d in data: 47 | fw.writerow(d) 48 | 49 | 50 | def save_features(data_path, index, data): 51 | fn = os.path.join(data_path, str(index) + ".npz") 52 | np.savez_compressed(fn, features=data) 53 | 54 | 55 | def run(): 56 | args = parser.parse_args() 57 | res, header = load_smiles(data_path=args.data_path) 58 | fea = load_features(data_path=args.features_path) 59 | assert len(res) == fea.shape[0] 60 | 61 | n_graphs = len(res) 62 | perm = np.random.permutation(n_graphs) 63 | 64 | nfold = int(n_graphs / args.sample_per_file + 1) 65 | print("Number of files: %d" % nfold) 66 | if os.path.exists(args.output_path): 67 | shutil.rmtree(args.output_path) 68 | os.makedirs(args.output_path, exist_ok=True) 69 | graph_path = os.path.join(args.output_path, "graph") 70 | fea_path = os.path.join(args.output_path, "feature") 71 | os.makedirs(graph_path, exist_ok=True) 72 | os.makedirs(fea_path, exist_ok=True) 73 | 74 | for i in range(nfold): 75 | sidx = i * args.sample_per_file 76 | eidx = min((i + 1) * args.sample_per_file, n_graphs) 77 | indexes = perm[sidx:eidx] 78 | sres = [res[j] for j in indexes] 79 | sfea = fea[indexes] 80 | save_smiles(graph_path, i, sres, header) 81 | save_features(fea_path, i, sfea) 82 | 83 | summary_path = os.path.join(args.output_path, "summary.txt") 84 | summary_fout = open(summary_path, "w") 85 | summary_fout.write("n_files:%d\n" % nfold) 86 | summary_fout.write("n_samples:%d\n" % n_graphs) 87 | summary_fout.write("sample_per_file:%d\n" % args.sample_per_file) 88 | summary_fout.close() 89 | 90 | 91 | if __name__ == "__main__": 92 | run() 93 | -------------------------------------------------------------------------------- /embeddings/grover/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/embeddings/grover/task/__init__.py -------------------------------------------------------------------------------- /embeddings/grover/task/cross_validate.py: -------------------------------------------------------------------------------- 1 | """ 2 | The cross validation function for finetuning. 
3 | This implementation is adapted from 4 | https://github.com/chemprop/chemprop/blob/master/chemprop/train/cross_validate.py 5 | """ 6 | import os 7 | import time 8 | from argparse import Namespace 9 | from logging import Logger 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | from grover.util.utils import get_task_names, makedirs 14 | from task.run_evaluation import run_evaluation 15 | from task.train import run_training 16 | 17 | 18 | def cross_validate(args: Namespace, logger: Logger = None) -> Tuple[float, float]: 19 | """ 20 | k-fold cross validation. 21 | 22 | :return: A tuple of mean_score and std_score. 23 | """ 24 | info = logger.info if logger is not None else print 25 | 26 | # Initialize relevant variables 27 | init_seed = args.seed 28 | save_dir = args.save_dir 29 | task_names = get_task_names(args.data_path) 30 | 31 | # Run training with different random seeds for each fold 32 | all_scores = [] 33 | time_start = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) 34 | for fold_num in range(args.num_folds): 35 | info(f"Fold {fold_num}") 36 | args.seed = init_seed + fold_num 37 | args.save_dir = os.path.join(save_dir, f"fold_{fold_num}") 38 | makedirs(args.save_dir) 39 | if args.parser_name == "finetune": 40 | model_scores = run_training(args, time_start, logger) 41 | else: 42 | model_scores = run_evaluation(args, logger) 43 | all_scores.append(model_scores) 44 | all_scores = np.array(all_scores) 45 | 46 | # Report scores for each fold 47 | info(f"{args.num_folds}-fold cross validation") 48 | 49 | for fold_num, scores in enumerate(all_scores): 50 | info( 51 | f"Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}" 52 | ) 53 | 54 | if args.show_individual_scores: 55 | for task_name, score in zip(task_names, scores): 56 | info( 57 | f"Seed {init_seed + fold_num} ==> test {task_name} {args.metric} = {score:.6f}" 58 | ) 59 | 60 | # Report scores across models 61 | avg_scores = np.nanmean( 62 | all_scores, axis=1 63 | ) # average score for each model across tasks 64 | mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores) 65 | info(f"overall_{args.split_type}_test_{args.metric}={mean_score:.6f}") 66 | info(f"std={std_score:.6f}") 67 | 68 | if args.show_individual_scores: 69 | for task_num, task_name in enumerate(task_names): 70 | info( 71 | f"Overall test {task_name} {args.metric} = " 72 | f"{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}" 73 | ) 74 | 75 | return mean_score, std_score 76 | -------------------------------------------------------------------------------- /embeddings/grover/task/fingerprint.py: -------------------------------------------------------------------------------- 1 | """ 2 | The fingerprint generation function. 3 | """ 4 | from argparse import Namespace 5 | from logging import Logger 6 | from typing import List 7 | 8 | import torch 9 | import torch.nn as nn 10 | from grover.data import MolCollator, MoleculeDataset 11 | from grover.util.utils import create_logger, get_data, load_checkpoint 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | def do_generate( 16 | model: nn.Module, 17 | data: MoleculeDataset, 18 | args: Namespace, 19 | ) -> List[List[float]]: 20 | """ 21 | Do the fingerprint generation on a dataset using the pre-trained models. 22 | 23 | :param model: A model. 24 | :param data: A MoleculeDataset. 25 | :param args: A StandardScaler object fit on the training targets. 26 | :return: A list of fingerprints. 
27 | """ 28 | model.eval() 29 | args.bond_drop_rate = 0 30 | preds = [] 31 | 32 | mol_collator = MolCollator(args=args, shared_dict={}) 33 | 34 | num_workers = 4 35 | mol_loader = DataLoader( 36 | data, 37 | batch_size=32, 38 | shuffle=False, 39 | num_workers=num_workers, 40 | collate_fn=mol_collator, 41 | ) 42 | for item in mol_loader: 43 | _, batch, features_batch, _, _ = item 44 | with torch.no_grad(): 45 | batch_preds = model(batch, features_batch) 46 | preds.extend(batch_preds.data.cpu().numpy()) 47 | return preds 48 | 49 | 50 | def generate_fingerprints(args: Namespace, logger: Logger = None) -> List[List[float]]: 51 | """ 52 | Generate the fingerprints. 53 | 54 | :param logger: 55 | :param args: Arguments. 56 | :return: A list of lists of target fingerprints. 57 | """ 58 | 59 | checkpoint_path = args.checkpoint_paths[0] 60 | if logger is None: 61 | logger = create_logger("fingerprints", quiet=False) 62 | print("Loading data") 63 | test_data = get_data( 64 | path=args.data_path, 65 | args=args, 66 | use_compound_names=False, 67 | max_data_size=float("inf"), 68 | skip_invalid_smiles=False, 69 | ) 70 | test_data = MoleculeDataset(test_data) 71 | 72 | logger.info(f"Total size = {len(test_data):,}") 73 | logger.info(f"Generating...") 74 | # Load model 75 | model = load_checkpoint( 76 | checkpoint_path, cuda=args.cuda, current_args=args, logger=logger 77 | ) 78 | model_preds = do_generate(model=model, data=test_data, args=args) 79 | 80 | return model_preds 81 | -------------------------------------------------------------------------------- /embeddings/jtvae/README.md: -------------------------------------------------------------------------------- 1 | # Notes on JTVAE 2 | 3 | GPU runs out of memory if too many workers are used. 4 | Currently the training code in the main repository is broken, a fix is at 5 | https://github.com/siboehm/dgl-lifesci/tree/jtvae. 
6 | 7 | - `lincs_trapnell.smiles`: 17870 SMILES 8 | - `~/.dgl/jtvae/train.txt` (ZINC): 220011 SMILES -------------------------------------------------------------------------------- /embeddings/jtvae/analyze_smiles.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% 13 | from pathlib import Path 14 | 15 | import matplotlib 16 | import seaborn as sn 17 | 18 | matplotlib.style.use("fivethirtyeight") 19 | matplotlib.style.use("seaborn-talk") 20 | matplotlib.rcParams["font.family"] = "monospace" 21 | matplotlib.pyplot.rcParams["savefig.facecolor"] = "white" 22 | sn.set_context("poster") 23 | 24 | # %% pycharm={"name": "#%%\n"} 25 | zinc_dgl = Path.home() / ".dgl" / "jtvae" / "train.txt" 26 | lincs_trapnell = Path.cwd().parent / "lincs_trapnell.smiles" 27 | outfile = Path.cwd().parent / "lincs_trapnell.smiles.short" 28 | assert zinc_dgl.exists() and lincs_trapnell.exists() 29 | 30 | # %% pycharm={"name": "#%%\n"} 31 | for p in [zinc_dgl, lincs_trapnell]: 32 | with open(p) as f: 33 | max_length = 0 34 | for smile in f: 35 | if len(smile.strip()) > max_length: 36 | max_length = len(smile.strip()) 37 | print(f"Max length of {p} is {max_length}") 38 | 39 | # %% pycharm={"name": "#%%\n"} 40 | with open(lincs_trapnell) as f: 41 | count = 0 42 | for smile in f: 43 | smile = smile.strip() 44 | if len(smile) >= 200: 45 | count += 1 46 | print(f"There are {count} SMILES >= 200") 47 | 48 | # %% pycharm={"name": "#%%\n"} 49 | with open(lincs_trapnell) as f: 50 | h = [] 51 | for smile in f: 52 | h.append(len(smile.strip())) 53 | 54 | # %% pycharm={"name": "#%%\n"} 55 | ax = sn.histplot(h) 56 | ax.set_title("SMILES-length in LINCS") 57 | 58 | # %% [markdown] 59 | # ## Generate a new smiles list 60 | # We generate a new list of SMILES that are pruned to length <= 200 61 | 62 | # %% pycharm={"name": "#%%\n"} 63 | with open(outfile, "w") as outfile, open(lincs_trapnell) as infile: 64 | for line in infile: 65 | line = line.strip() 66 | if len(line) < 200: 67 | outfile.write(line + "\n") 68 | 69 | # %% pycharm={"name": "#%%\n"} 70 | with open(Path.cwd().parent / "lincs_trapnell.smiles.mini", "w") as outfile, open( 71 | lincs_trapnell 72 | ) as infile: 73 | for line in infile: 74 | line = line.strip() 75 | if len(line) <= 120: 76 | outfile.write(line + "\n") 77 | -------------------------------------------------------------------------------- /embeddings/jtvae/environment.yml: -------------------------------------------------------------------------------- 1 | name: jtvae_dgl 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - pytorch::pytorch 6 | - cudatoolkit=10.2 7 | - python 8 | - rdkit=2018.09.3 9 | - jupyter 10 | - pyarrow 11 | - tqdm 12 | - dglteam::dgl-cuda10.2 13 | - pip 14 | - seml 15 | -------------------------------------------------------------------------------- /embeddings/jtvae/generate_embeddings.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% [markdown] 13 | # # JTVAE embedding 14 | # This is a molecule embedding using the JunctionTree VAE, as 
implemented in DGLLifeSci. 15 | # 16 | # It's pretrained on LINCS + Trapnell + half of ZINC (~220K molecules total). 17 | # LINCS contains a `Cl.[Li]` molecule which fails during encoding, so it just gets a dummy encoding. 18 | 19 | # %% 20 | import pickle 21 | 22 | import pandas as pd 23 | import rdkit 24 | import torch 25 | from dgllife.data import JTVAECollator, JTVAEDataset 26 | from dgllife.model import load_pretrained 27 | from tqdm import tqdm 28 | 29 | print(rdkit.__version__) 30 | print(torch.__version__) 31 | assert torch.cuda.is_available() 32 | 33 | # %% pycharm={"name": "#%%\n"} 34 | from dgllife.model import JTNNVAE 35 | 36 | from_pretrained = False 37 | if from_pretrained: 38 | model = load_pretrained("JTVAE_ZINC_no_kl") 39 | else: 40 | trainfile = "data/train_077a9bedefe77f2a34187eb57be2d416.txt" 41 | modelfile = "data/model-vaetrain-final.pt" 42 | vocabfile = "data/vocab-final.pkl" 43 | 44 | with open(vocabfile, "rb") as f: 45 | vocab = pickle.load(f) 46 | 47 | model = JTNNVAE(vocab=vocab, hidden_size=450, latent_size=56, depth=3) 48 | model.load_state_dict(torch.load(modelfile, map_location="cpu")) 49 | 50 | 51 | # %% pycharm={"name": "#%%\n"} 52 | model = model.to("cuda") 53 | 54 | # %% pycharm={"name": "#%%\n"} 55 | smiles = pd.read_csv("../lincs_trapnell.smiles") 56 | # need to remove the header, before passing it to JTVAE 57 | smiles.to_csv("jtvae_dataset.smiles", index=False, header=None) 58 | 59 | # %% pycharm={"name": "#%%\n"} 60 | dataset = JTVAEDataset("jtvae_dataset.smiles", vocab=model.vocab, training=False) 61 | collator = JTVAECollator(training=False) 62 | dataloader = torch.utils.data.DataLoader( 63 | dataset, batch_size=1, shuffle=False, collate_fn=collator, drop_last=True 64 | ) 65 | 66 | # %% [markdown] 67 | # ## Reconstruction demo 68 | # Reconstruct a couple of molecules to check reconstruction performance (it's not good). 69 | 70 | # %% pycharm={"name": "#%%\n"} 71 | acc = 0.0 72 | device = "cuda" 73 | for it, (tree, tree_graph, mol_graph) in enumerate(dataloader): 74 | if it > 10: 75 | break 76 | tot = it + 1 77 | smiles = tree.smiles 78 | tree_graph = tree_graph.to(device) 79 | mol_graph = mol_graph.to(device) 80 | dec_smiles = model.reconstruct(tree_graph, mol_graph) 81 | print(dec_smiles) 82 | print(smiles) 83 | print() 84 | if dec_smiles == smiles: 85 | acc += 1 86 | print("Final acc: {:.4f}".format(acc / tot)) 87 | 88 | # %% [markdown] 89 | # ## Generate embeddings for all LINCS + Trapnell molecules 90 | 91 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 92 | get_data = lambda idx: collator([dataset[idx]]) 93 | errors = [] 94 | smiles = [] 95 | latents = [] 96 | for i in tqdm(range(len(dataset))): 97 | try: 98 | _, batch_tree_graphs, batch_mol_graphs = get_data(i) 99 | batch_tree_graphs = batch_tree_graphs.to("cuda") 100 | batch_mol_graphs = batch_mol_graphs.to("cuda") 101 | with torch.no_grad(): 102 | _, tree_vec, mol_vec = model.encode(batch_tree_graphs, batch_mol_graphs) 103 | latent = torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1) 104 | latents.append(latent) 105 | smiles.append(dataset.data[i]) 106 | except Exception as e: 107 | errors.append((dataset.data[i], e)) 108 | 109 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 110 | # There should only be one error, a Cl.[Li] molecule. 
111 | errors 112 | 113 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 114 | # Add a dummy embedding for the Cl.[Li] molecule 115 | dummy_emb = torch.mean(torch.concat(latents), dim=0).unsqueeze(dim=0) 116 | assert dummy_emb.shape == latents[0].shape 117 | smiles.append(errors[0][0]) 118 | latents.append(dummy_emb) 119 | assert len(latents) == len(smiles) 120 | 121 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 122 | np_latents = [latent.squeeze().cpu().detach().numpy() for latent in latents] 123 | final_df = pd.DataFrame( 124 | np_latents, 125 | index=smiles, 126 | columns=[f"latent_{i + 1}" for i in range(np_latents[0].shape[0])], 127 | ) 128 | final_df.to_parquet("data/jtvae_dgl.parquet") 129 | 130 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 131 | final_df 132 | 133 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 134 | smiles = pd.read_csv("../lincs_trapnell.smiles") 135 | smiles2 = final_df.index 136 | 137 | # %% pycharm={"is_executing": true, "name": "#%%\n"} 138 | set(list(smiles["smiles"])) == set(list(smiles2)) 139 | -------------------------------------------------------------------------------- /embeddings/jtvae/jtvae_train_all.yaml: -------------------------------------------------------------------------------- 1 | # Config for hyperparameter-tuning CPA on SciPlex 3 with pretrained (and frozen) HierVAE drug embeddings. 2 | seml: 3 | executable: seml_train.py 4 | name: jtvae_all 5 | output_dir: sweeps/logs 6 | conda_environment: jtvae_dgl 7 | project_root_dir: . 8 | 9 | slurm: 10 | max_simultaneous_jobs: 4 11 | experiments_per_job: 1 12 | sbatch_options_template: GPU 13 | sbatch_options: 14 | gres: gpu:1 # num GPUs 15 | mem: 32G # memory 16 | cpus-per-task: 6 # num cores 17 | time: 2-00:00 # max time, D-HH:MM 18 | ###### BEGIN PARAMETER CONFIGURATION ###### 19 | 20 | fixed: 21 | # Configured to take ~2 days to train. 22 | training.save_path: "pre_model_all" 23 | training.batch_size: 40 24 | training.hidden_size: 450 25 | training.latent_size: 56 26 | training.depth: 3 27 | training.lr: 0.001 28 | training.gamma: 0.9 29 | training.max_epoch: 2 30 | training.print_iter: 20 31 | training.save_iter: 200 32 | training.incl_zinc: True 33 | training.subsample_zinc_percent: 0.4 34 | training.training_path: "../lincs_trapnell.smiles" 35 | training.num_workers: 3 36 | -------------------------------------------------------------------------------- /embeddings/jtvae/jtvae_vaetrain_all.yaml: -------------------------------------------------------------------------------- 1 | # Config for hyperparameter-tuning CPA on SciPlex 3 with pretrained (and frozen) HierVAE drug embeddings. 2 | seml: 3 | executable: seml_train.py 4 | name: jtvae_vae_all 5 | output_dir: sweeps/logs 6 | conda_environment: jtvae_dgl 7 | project_root_dir: . 8 | 9 | slurm: 10 | max_simultaneous_jobs: 4 11 | experiments_per_job: 1 12 | sbatch_options_template: GPU 13 | sbatch_options: 14 | gres: gpu:1 # num GPUs 15 | mem: 32G # memory 16 | cpus-per-task: 6 # num cores 17 | time: 2-00:00 # max time, D-HH:MM 18 | ###### BEGIN PARAMETER CONFIGURATION ###### 19 | 20 | fixed: 21 | # Configured to take ~2 days to train. 
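  # Note: this is the VAE-training stage of the JTVAE setup; it loads the model checkpoint and vocab
  # written by the pretraining run configured in jtvae_train_all.yaml (cf. training.model_path and
  # training.vocab_path below).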
22 | training.save_path: "vae_model_all" 23 | training.model_path: "pre_model_all/model.epoch-0" 24 | training.vocab_path: "pre_model_all/vocab_1f1775f24668d31640df46ce45fe3577.pkl" 25 | training.batch_size: 40 26 | training.hidden_size: 450 27 | training.latent_size: 56 28 | training.depth: 3 29 | training.lr: 0.007 30 | training.beta: 0.001 31 | training.gamma: 0.9 32 | training.max_epoch: 2 33 | training.print_iter: 20 34 | training.save_iter: 200 35 | training.incl_zinc: True 36 | training.subsample_zinc_percent: 0.4 37 | training.training_path: "../lincs_trapnell.smiles" 38 | training.num_workers: 3 39 | 40 | training.pretrain_only: False 41 | -------------------------------------------------------------------------------- /embeddings/jtvae/reconstruct.py: -------------------------------------------------------------------------------- 1 | import rdkit 2 | import torch 3 | from dgllife.data import JTVAEZINC, JTVAECollator, JTVAEDataset 4 | from dgllife.model import JTNNVAE, load_pretrained 5 | from dgllife.utils import JTVAEVocab 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | def main(args): 10 | lg = rdkit.RDLogger.logger() 11 | lg.setLevel(rdkit.RDLogger.CRITICAL) 12 | 13 | if args.use_cpu or not torch.cuda.is_available(): 14 | device = torch.device("cpu") 15 | else: 16 | device = torch.device("cuda:0") 17 | 18 | vocab = JTVAEVocab(file_path=args.train_path) 19 | if args.test_path is None: 20 | dataset = JTVAEZINC("test", vocab) 21 | else: 22 | dataset = JTVAEDataset(args.test_path, vocab, training=False) 23 | dataloader = DataLoader( 24 | dataset, batch_size=1, collate_fn=JTVAECollator(training=False) 25 | ) 26 | 27 | if args.model_path is None: 28 | model = load_pretrained("JTVAE_ZINC_no_kl") 29 | else: 30 | model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth) 31 | model.load_state_dict(torch.load(args.model_path, map_location="cpu")) 32 | model = model.to(device) 33 | 34 | acc = 0.0 35 | for it, (tree, tree_graph, mol_graph) in enumerate(dataloader): 36 | tot = it + 1 37 | smiles = tree.smiles 38 | tree_graph = tree_graph.to(device) 39 | mol_graph = mol_graph.to(device) 40 | dec_smiles = model.reconstruct(tree_graph, mol_graph) 41 | if dec_smiles == smiles: 42 | acc += 1 43 | if tot % args.print_iter == 0: 44 | print( 45 | "Iter {:d}/{:d} | Acc {:.4f}".format( 46 | tot // args.print_iter, 47 | len(dataloader) // args.print_iter, 48 | acc / tot, 49 | ) 50 | ) 51 | print("Final acc: {:.4f}".format(acc / tot)) 52 | 53 | 54 | if __name__ == "__main__": 55 | from argparse import ArgumentParser 56 | 57 | parser = ArgumentParser() 58 | parser.add_argument( 59 | "-tr", 60 | "--train-path", 61 | type=str, 62 | help="Path to the training molecules, with one SMILES string a line", 63 | ) 64 | parser.add_argument( 65 | "-te", 66 | "--test-path", 67 | type=str, 68 | help="Path to the test molecules, with one SMILES string a line", 69 | ) 70 | parser.add_argument( 71 | "-m", "--model-path", type=str, help="Path to pre-trained model checkpoint" 72 | ) 73 | parser.add_argument( 74 | "-w", "--hidden-size", type=int, default=450, help="Hidden size" 75 | ) 76 | parser.add_argument("-l", "--latent-size", type=int, default=56, help="Latent size") 77 | parser.add_argument( 78 | "-d", "--depth", type=int, default=3, help="Number of GNN layers" 79 | ) 80 | parser.add_argument( 81 | "-pi", 82 | "--print-iter", 83 | type=int, 84 | default=20, 85 | help="Frequency for printing evaluation metrics", 86 | ) 87 | parser.add_argument( 88 | "-cpu", 89 | "--use-cpu", 90 | 
action="store_true", 91 | help="By default, the script uses GPU whenever available. " 92 | "This flag enforces the use of CPU.", 93 | ) 94 | args = parser.parse_args() 95 | 96 | main(args) 97 | -------------------------------------------------------------------------------- /embeddings/jtvae/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # 6 | # pylint: disable= no-member, arguments-differ, invalid-name 7 | # 8 | # Utils for JTVAE 9 | import datetime 10 | import errno 11 | import os 12 | 13 | 14 | def get_timestamp(): 15 | return datetime.datetime.now().strftime("%d-%b-%Y-%H:%M:%S") 16 | 17 | 18 | def mkdir_p(path): 19 | """Create a folder for the given path. 20 | 21 | Parameters 22 | ---------- 23 | path: str 24 | Folder to create 25 | """ 26 | try: 27 | os.makedirs(path) 28 | print("Created directory {}".format(path)) 29 | except OSError as exc: 30 | if exc.errno == errno.EEXIST and os.path.isdir(path): 31 | print("Directory {} already exists.".format(path)) 32 | else: 33 | raise 34 | -------------------------------------------------------------------------------- /embeddings/rdkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/embeddings/rdkit/__init__.py -------------------------------------------------------------------------------- /embeddings/seq2seq/README.md: -------------------------------------------------------------------------------- 1 | # Seq2Seq 2 | 3 | This is a baseline model, a simple LSTM-based Seq2Seq model without any bells-or-whistles (eg no teacher-forcing). 4 | Code largely taken from [DeepChem examples](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Learning_Unsupervised_Embeddings_for_Molecules.ipynb). 5 | 6 | This model isn't great and shouldn't be used for any real production runs. 
-------------------------------------------------------------------------------- /embeddings/seq2seq/environment.yml: -------------------------------------------------------------------------------- 1 | name: cpa_seq2seq 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - rdkit 6 | - deepchem==2.5.0 7 | - jupyter 8 | - pandas 9 | - pip 10 | - pyarrow 11 | - pip: 12 | - tensorflow-gpu 13 | -------------------------------------------------------------------------------- /embeddings/seq2seq/generate_embeddings.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% pycharm={"name": "#%%\n"} 13 | import deepchem as dc 14 | import pandas as pd 15 | import tensorflow as tf 16 | from rdkit import Chem 17 | from train_model import MAX_LENGTH, TOKENS 18 | 19 | print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU"))) 20 | 21 | # %% [markdown] 22 | # ## Load the most recent checkpoint 23 | # I stored all checkpoints in `embeddings/seq2seq/data` 24 | 25 | # %% pycharm={"name": "#%%\n"} 26 | model = dc.models.SeqToSeq( 27 | TOKENS, 28 | TOKENS, 29 | MAX_LENGTH, 30 | encoder_layers=2, 31 | decoder_layers=2, 32 | embedding_dimension=256, 33 | batch_size=100, 34 | model_dir="data", 35 | ) 36 | 37 | # %% pycharm={"name": "#%%\n"} 38 | model.get_checkpoints() 39 | 40 | # %% pycharm={"name": "#%%\n"} 41 | # loads the newest checkpoint 42 | model.restore() 43 | 44 | # %% [markdown] pycharm={"name": "#%% md\n"} 45 | # ## Load all SMILES 46 | # and predict their embedding 47 | 48 | # %% pycharm={"name": "#%%\n"} 49 | canonicalize = lambda smile: Chem.MolToSmiles(Chem.MolFromSmiles(smile)) 50 | all_smiles = list(pd.read_csv("../lincs_trapnell.smiles")["smiles"].values) 51 | 52 | # %% pycharm={"name": "#%%\n"} 53 | # quick check on subset of all embeddings 54 | pred = model.predict_from_sequences(all_smiles[0:15]) 55 | for s_pred, s_real in zip(pred, all_smiles[0:15]): 56 | s_pred = "".join(s_pred) 57 | print(f"{s_pred == s_real}\n-- {s_real}\n-- {s_pred}") 58 | 59 | # %% pycharm={"name": "#%%\n"} 60 | # actually predict all embeddings 61 | emb = model.predict_embeddings(all_smiles) 62 | 63 | # %% [markdown] pycharm={"name": "#%% md\n"} 64 | # ## Store the resulting embedding 65 | 66 | # %% pycharm={"name": "#%%\n"} 67 | final_df = pd.DataFrame( 68 | emb, index=all_smiles, columns=[f"latent_{i+1}" for i in range(emb.shape[1])] 69 | ) 70 | 71 | # %% pycharm={"name": "#%%\n"} 72 | final_df.to_parquet("data/seq2seq.parquet") 73 | final_df 74 | 75 | # %% pycharm={"name": "#%%\n"} 76 | assert sorted(pd.read_csv("../lincs_trapnell.smiles")["smiles"].values) == sorted( 77 | final_df.index.values 78 | ) 79 | -------------------------------------------------------------------------------- /embeddings/seq2seq/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env zsh 2 | 3 | #SBATCH -o slurm_output.txt 4 | #SBATCH -e slurm_error.txt 5 | #SBATCH -J seq2seq 6 | #SBATCH --partition gpu_p 7 | #SBATCH --cpus-per-task 6 8 | #SBATCH --mem=16G 9 | #SBATCH --exclude=supergpu05 10 | #SBATCH --gres=gpu:1 11 | #SBATCH --gres=mps:40 12 | #SBATCH --qos=gpu 13 | #SBATCH --time 05:00:00 14 | #SBATCH --nice=10000 15 | 16 | echo "Started running $(date)" 17 | 
/home/icb/simon.boehm/miniconda3/envs/cpa_seq2seq/bin/python3 train_model.py 18 | echo "Ending $(date)" 19 | -------------------------------------------------------------------------------- /embeddings/seq2seq/train_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import deepchem as dc 4 | import numpy as np 5 | import pandas as pd 6 | import tensorflow as tf 7 | from deepchem.models.optimizers import ExponentialDecay 8 | from rdkit import Chem 9 | 10 | # I generated these on 15.11.2021. If we update the smiles or add new drug 11 | # then these should be stored anew. 12 | TOKENS = [ 13 | "#", 14 | "(", 15 | ")", 16 | "*", 17 | "+", 18 | "-", 19 | ".", 20 | "/", 21 | "1", 22 | "2", 23 | "3", 24 | "4", 25 | "5", 26 | "6", 27 | "7", 28 | "8", 29 | "=", 30 | "@", 31 | "A", 32 | "B", 33 | "C", 34 | "F", 35 | "H", 36 | "I", 37 | "L", 38 | "M", 39 | "N", 40 | "O", 41 | "P", 42 | "S", 43 | "[", 44 | "\\", 45 | "]", 46 | "a", 47 | "c", 48 | "d", 49 | "e", 50 | "g", 51 | "i", 52 | "l", 53 | "n", 54 | "o", 55 | "r", 56 | "s", 57 | "t", 58 | "u", 59 | ] 60 | MAX_LENGTH = 461 61 | 62 | 63 | def load_train_val(datasets_fpath="../../datasets"): 64 | datasets_fpath = Path(datasets_fpath) 65 | 66 | # read in all relevant smiles 67 | train_smiles = [] 68 | for f in ["all_smiles_lincs_trapnell.csv", "train_smiles_muv.csv"]: 69 | x = pd.read_csv(datasets_fpath / f, header=None) 70 | train_smiles += list(x[0]) 71 | 72 | val_smiles = list( 73 | pd.read_csv(datasets_fpath / "validation_smiles_muv.csv", header=None)[0] 74 | ) 75 | 76 | # get canoncialized train / val split 77 | canonicalize = lambda smile: Chem.MolToSmiles(Chem.MolFromSmiles(smile)) 78 | train_smiles = np.array([canonicalize(smile) for smile in list(train_smiles)]) 79 | val_smiles = np.array([canonicalize(smile) for smile in list(val_smiles)]) 80 | return train_smiles, val_smiles 81 | 82 | 83 | def get_model( 84 | train_smiles, model_dir="data/big_256", encoder_layers=4, decoder_layers=4 85 | ): 86 | tokens = set() 87 | for s in train_smiles: 88 | tokens = tokens.union(set(c for c in s)) 89 | tokens = sorted(list(tokens)) 90 | 91 | max_length = max(len(s) for s in train_smiles) 92 | batch_size = 100 93 | batches_per_epoch = int(len(train_smiles) / batch_size) 94 | model = dc.models.SeqToSeq( 95 | TOKENS, 96 | TOKENS, 97 | MAX_LENGTH, 98 | encoder_layers=encoder_layers, 99 | decoder_layers=decoder_layers, 100 | embedding_dimension=256, 101 | model_dir=model_dir, 102 | batch_size=batch_size, 103 | learning_rate=ExponentialDecay(0.001, 0.9, batches_per_epoch), 104 | ) 105 | return model 106 | 107 | 108 | def train_model(model, train_smiles): 109 | def generate_sequences(epochs): 110 | for i in range(epochs): 111 | print("Epoch:", i) 112 | for s in train_smiles: 113 | yield (s, s) 114 | 115 | # there are ~92K molecules, batchsize is 100 -> ~920 train steps per epoch 116 | model.fit_sequences( 117 | generate_sequences(200), 118 | ) 119 | 120 | 121 | if __name__ == "__main__": 122 | print(Path().cwd()) 123 | # make sure GPU is available 124 | assert len(tf.config.list_physical_devices("GPU")) > 0 125 | train_smiles, val_smiles = load_train_val() 126 | model = get_model(train_smiles) 127 | train_model(model, train_smiles) 128 | 129 | # load the most recent checkpoint 130 | model.restore() 131 | pred = model.predict_from_sequences(val_smiles) 132 | n_restored_smiles = 0 133 | for s_pred, s_real in zip(pred, val_smiles): 134 | s_pred = "".join(s_pred) 135 | if s_pred == s_real: 
136 | n_restored_smiles += 1 137 | print(f"Acc: {n_restored_smiles / len(val_smiles)}") 138 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: chemCPA 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - python=3.9 8 | - adjusttext 9 | - matplotlib 10 | - numpy 11 | - pandas 12 | - scanpy 13 | - scipy 14 | - seaborn 15 | - scikit-learn 16 | - submitit 17 | - jupyter 18 | - jupytext 19 | - pytorch::pytorch=1.12.1 20 | - cudatoolkit=11.3 21 | - lightning 22 | - torchmetrics 23 | - h5py 24 | - dglteam::dgl-cuda11.3 25 | - deepchem 26 | - rdkit=2021.09.2 27 | - pytest 28 | - pre-commit 29 | - seml 30 | - py-spy 31 | - fastparquet 32 | - wandb 33 | - hydra-core 34 | # UMAP 35 | - umap-learn 36 | - datashader 37 | - bokeh 38 | - holoviews 39 | - colorcet 40 | - scikit-image 41 | - pip 42 | - pip: 43 | - pyarrow 44 | - tensorflow 45 | - dgllife 46 | - scgen 47 | - wandb -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Predicting OOD drug perturbations via transfer learning: Experiment setup 3 | This file explains the experiment setup for the chemical compositional perturbational autoencoder (CCPA). The idea of this project is to investigate how well counterfactual predictions can be made for unseen drugs at the single-cell level. 4 | 5 | To this end, we extended the original [CPA model](https://github.com/facebookresearch/CPA) to include drug embedding neural networks that can map SMILES strings to a chemically meaningful embedding. Since these neural networks are not limited to the set of drugs used in a real experiment, it is possible to investigate perturbations for any drug. We hypothesise that such an OOD perturbational prediction can only be meaningful when the model has seen a sufficient number of different drugs. However, scRNA-seq perturbation screens are limited in the number of drugs they can investigate. The sciplex dataset, for example, contains only 188 drugs. To alleviate this shortcoming, we aim to enrich the drug latent space with a transfer learning approach. We investigate how and, more importantly, whether a bulk-seq perturbation screen, LINCS, which includes more than 17k different drugs, can be used to improve the OOD generalisation of the CCPA model for single-cell resolved perturbation predictions. 6 | 7 | We split the experiments into multiple parts: 8 | 9 | ## 1. EXP: `lincs_rdkit_hparam` 10 | This experiment is used to identify model configurations that have sufficient performance on the L1000 dataset. The resulting model configurations are then transferred to the finetuning experiments `sciplex_hparam` and `finetuning_num_genes`. 11 | 12 | This experiment runs in 2 steps: 13 | 1. We run a (large) hyperparameter sweep using chemCPA with RDKit embeddings. We use the results to pick good hyperparameters for the autoencoder and the adversarial predictors. See `config_lincs_rdkit_hparam_sweep.yaml`. 14 | 2. We run a (small) hyperparameter sweep using chemCPA with all other embeddings. We sweep just over the embedding-related hparams (drug encoder, drug doser), while fixing the AE & ADV related hparams as selected through (1). See `config_lincs_all_embeddings_hparam_sweep.yaml`. 15 | 16 | 17 | ## 2.
EXP: `sciplex_hparam` 18 | This experiment is run to determine suitable optimisation hparams for the adversary when fine-tuning on the sciplex dataset. These hparams are meant to be shared when evaluating transfer performance for different drug embedding models. 19 | 20 | Similar to `lincs_rdkit_hparam`, we subset to the `grover_base`, `jtvae`, and `rdkit` embeddings to be considerate of compute resources. 21 | 22 | Setup: 23 | Importantly, we sweep over a split that sets some drugs as OOD. In this setting the original CPA model is not applicable anymore. The drugs were chosen according to the results from the original [sciplex publication](https://www.science.org/doi/full/10.1126/science.aax6234), cf. Fig. S6 in the supplements of the publication. 24 | 25 | ## 3. EXP: `finetuning_num_genes` 26 | * In this experiment we test how pretraining on LINCS with a smaller set of genes helps to increase the performance on a larger gene set for sciplex. We use the OOD split `split_ood_finetuning` in `lincs_sciplex_gene_matching.ipynb`. 27 | 28 | **Why is this interesting?** 29 | * This is biologically relevant as different single-cell datasets have different important gene sets that explain their variation. 30 | 31 | **Experiment steps**: 32 | 33 | 1. Pretrain on LINCS (~900 genes), finetune on Trapnell (same ~900 genes) - `'config_sciplex_finetune_lincs_genes.yaml'` 34 | 2. Pretrain on LINCS (~900 genes), finetune on Trapnell (2000 genes) 35 | 3. Train from Scratch on Trapnell (900 genes) - `'config_sciplex_finetune_lincs_genes.yaml'` 36 | 4. Train from Scratch on Trapnell (2000 genes) 37 | 38 | Compare performances between chemCPA with pretraining (1. and 2.) and chemCPA without pretraining (3. and 4.) for each of the two settings. 39 | -------------------------------------------------------------------------------- /experiments/dom_experiments/combine_adata_biolord.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% 13 | # If not done before: combine adatas!
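# Summary of what this notebook-style script does: it reads the train/test/ood splits of the
# biolord dataset, concatenates them into a single AnnData, carries over the
# `rank_genes_groups_cov_all` DEG dictionary from the first split, replaces the placeholder
# "nan" SMILES with the DMSO SMILES ("CS(C)=O"), and writes the combined
# `adata_biolord_split_30.h5ad`.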
14 | 15 | from pathlib import Path 16 | 17 | import pandas as pd 18 | import scanpy as sc 19 | 20 | from chemCPA.paths import PROJECT_DIR 21 | 22 | adatas = [ 23 | sc.read(PROJECT_DIR / "datasets" / f"adata_{split}_biolord_split_30.h5ad") for split in ["train", "test", "ood"] 24 | ] 25 | # adatas = [sc.read(Path("project_folder") / "datasets" / f"adata_{split}_biolord_split_30.h5ad") for split in ["train", "test"]] 26 | 27 | for adata in adatas: 28 | df = adata.obs 29 | for col in df.select_dtypes(["category"]).columns: 30 | _type = type(df[col].cat.categories[0]) 31 | print(f"{col}: {_type}") 32 | try: 33 | df[col] = df[col].astype(_type) 34 | except: 35 | print(col) 36 | df[col] = df[col].astype(str) 37 | 38 | 39 | adata = sc.concat(adatas) 40 | 41 | 42 | key_check = ( 43 | ~pd.Series(adatas[0].uns["rank_genes_groups_cov_all"].keys()).isin( 44 | list(adatas[1].uns["rank_genes_groups_cov_all"].keys()) 45 | ) 46 | ).sum() 47 | print(f"Key check: {key_check} should be 0.") 48 | 49 | adata.uns["rank_genes_groups_cov_all"] = adatas[0].uns["rank_genes_groups_cov_all"] 50 | 51 | # Add proper DMSO 52 | adata.obs["smiles"].replace("nan", "CS(C)=O", inplace=True) 53 | 54 | # %% 55 | sc.write(Path("project_folder") / "datasets" / "adata_biolord_split_30.h5ad", adata) 56 | -------------------------------------------------------------------------------- /experiments/dom_experiments/compute_embedding_rdkit.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% [markdown] 13 | # **Requirements** 14 | # * According to this [paper](https://arxiv.org/pdf/1904.01561.pdf), features are computed with [descriptastorus](https://github.com/bp-kelley/descriptastorus) package 15 | # * Install via: `pip install git+https://github.com/bp-kelley/descriptastorus` 16 | 17 | # %% [markdown] 18 | # ## General imports 19 | 20 | # %% 21 | import sys 22 | 23 | # this depends on the notebook depth and must be adapted per notebook 24 | sys.path.insert(0, "/") 25 | # %% 26 | import numpy as np 27 | 28 | # %% 29 | import scanpy as sc 30 | from joblib import Parallel, delayed 31 | from tqdm.notebook import tqdm 32 | 33 | from chemCPA.helper import canonicalize_smiles 34 | from chemCPA.paths import DATA_DIR, EMBEDDING_DIR 35 | 36 | # %% [markdown] 37 | # ## Load Smiles list 38 | 39 | 40 | # %% 41 | adata = sc.read(DATA_DIR / "adata_biolord_split_30.h5ad") 42 | 43 | # %% 44 | smiles_list = adata.obs["smiles"].unique() 45 | # exclude nan from smiles_list 46 | smiles_list = [canonicalize_smiles(s) for s in smiles_list if s != "nan"] 47 | 48 | # %% 49 | print(f"Number of smiles strings: {len(smiles_list)}") 50 | 51 | # %% 52 | from descriptastorus.descriptors.DescriptorGenerator import MakeGenerator 53 | 54 | generator = MakeGenerator(("RDKit2D",)) 55 | for name, numpy_type in generator.GetColumns(): 56 | print(f"{name}({numpy_type.__name__})") 57 | 58 | # %% 59 | n_jobs = 16 60 | data = Parallel(n_jobs=n_jobs)( 61 | delayed(generator.process)(smiles) for smiles in tqdm(smiles_list, position=0, leave=True) 62 | ) 63 | 64 | # %% 65 | data = [d[1:] for d in data] 66 | 67 | # %% 68 | embedding = np.array(data) 69 | embedding.shape 70 | 71 | # %% [markdown] 72 | # ## Check `nans` and `infs` 73 | 74 | # %% [markdown] 75 | # Check for `nans` 76 | 77 | # %% 78 | 
drug_idx, feature_idx = np.where(np.isnan(embedding)) 79 | print(f"drug_idx:\n {drug_idx}") 80 | print(f"feature_idx:\n {feature_idx}") 81 | 82 | # %% [markdown] 83 | # Check for `infs` and add to idx lists 84 | 85 | # %% 86 | drug_idx_infs, feature_idx_infs = np.where(np.isinf(embedding)) 87 | 88 | drug_idx = np.concatenate((drug_idx, drug_idx_infs)) 89 | feature_idx = np.concatenate((feature_idx, feature_idx_infs)) 90 | 91 | # %% [markdown] 92 | # Features that have these invalid values: 93 | 94 | # %% tags=[] 95 | np.array(generator.GetColumns())[np.unique(feature_idx)] 96 | 97 | # %% [markdown] 98 | # Set values to `0` 99 | 100 | # %% 101 | embedding[drug_idx, feature_idx] 102 | 103 | # %% 104 | embedding[drug_idx, feature_idx] = 0 105 | 106 | # %% [markdown] 107 | # ## Save 108 | 109 | # %% 110 | import pandas as pd 111 | 112 | df = pd.DataFrame(data=embedding, index=smiles_list, columns=[f"latent_{i}" for i in range(embedding.shape[1])]) 113 | 114 | # Drop first feature from generator (RDKit2D_calculated) 115 | df.drop(columns=["latent_0"], inplace=True) 116 | 117 | # Drop columns with 0 standard deviation 118 | threshold = 0.01 119 | columns = [f"latent_{idx+1}" for idx in np.where(df.std() <= threshold)[0]] 120 | print(f"Deleting columns with std<={threshold}: {columns}") 121 | df.drop(columns=[f"latent_{idx+1}" for idx in np.where(df.std() <= 0.01)[0]], inplace=True) 122 | 123 | # %% [markdown] 124 | # Check that correct columns were deleted: 125 | 126 | # %% 127 | np.where(df.std() <= threshold) 128 | 129 | # %% [markdown] 130 | # ### Normalise dataframe 131 | 132 | # %% 133 | normalized_df = (df - df.mean()) / df.std() 134 | 135 | # %% 136 | normalized_df 137 | 138 | # %% [markdown] 139 | # Check destination folder 140 | 141 | # %% 142 | model_name = "rdkit2D" 143 | dataset_name = "biolord" 144 | fname = f"{model_name}_embedding_{dataset_name}.parquet" 145 | 146 | directory = EMBEDDING_DIR / "rdkit" / "data" / "embeddings" 147 | directory.mkdir(parents=True, exist_ok=True) 148 | 149 | # %% [markdown] 150 | # Save normalised version 151 | 152 | # %% 153 | normalized_df.to_parquet(directory / fname) 154 | 155 | # %% [markdown] 156 | # Check that it worked 157 | 158 | # %% 159 | df = pd.read_parquet(directory / fname) 160 | df 161 | 162 | # %% 163 | directory / fname 164 | 165 | # %% 166 | -------------------------------------------------------------------------------- /experiments/finetuning_num_genes/README.md: -------------------------------------------------------------------------------- 1 | **Summary** 2 | * In this experiment we test how the pretraining on lincs with a smaller set of genes helps to increase the performance on a larger gene set for sciplex. We use the ood split `split_ood_finetuning` in `lincs_sciplex_gene_matching.ipynb`. 3 | 4 | **Why is this interesting?** 5 | * This is biologically relevant as different single-cell datasets have different important gene sets that explain their variation. 6 | 7 | **Experiment steps**: 8 | 9 | 1. Pretrain on LINCS (~900 genes), finetune on Trapnell (same ~900 genes) - `'config_sciplex_finetune_lincs_genes.yaml'` 10 | 2. Pretrain on LINCS (~900 genes), finetune on Trapnell (2000 genes) 11 | 3. Train from Scratch on Trapnell (900 genes) - `'config_sciplex_finetune_lincs_genes.yaml'` 12 | 4. Train from Scratch on Trapnell (2000 genes) 13 | 14 | Compare performances between chemCPA with pretraining (1. and 2.) and chemCPA without pretraining (3. and 4.) for each of the two settings. 
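
To make the two gene-set settings concrete, here is a minimal sketch of the subsetting step (assuming both AnnData objects share gene symbols in `var_names` and contain log-normalised expression; file names are placeholders, the real matching is done in `lincs_sciplex_gene_matching.ipynb`):

```python
import scanpy as sc

# Placeholder paths; the actual files are produced by the preprocessing notebooks.
adata_lincs = sc.read("lincs_full_smiles.h5ad")
adata_sciplex = sc.read("sciplex_complete.h5ad")

# Settings 1/3: restrict sciplex to the ~900 genes that are also measured in LINCS.
shared_genes = adata_sciplex.var_names.intersection(adata_lincs.var_names)
adata_lincs_genes = adata_sciplex[:, shared_genes].copy()

# Settings 2/4: keep the 2000 most highly variable sciplex genes instead.
sc.pp.highly_variable_genes(adata_sciplex, n_top_genes=2000)
adata_2000_genes = adata_sciplex[:, adata_sciplex.var["highly_variable"]].copy()

print(adata_lincs_genes.shape, adata_2000_genes.shape)
```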
-------------------------------------------------------------------------------- /experiments/lincs_rdkit_hparam/README.md: -------------------------------------------------------------------------------- 1 | This experiment runs in 2 steps: 2 | 1. We run a (large) hyperparameter sweep using chemCPA with RDKit embeddings. We use the results to pick good hyperparameters for the autoencoder and the adversarial predictors. See `config_lincs_rdkit_hparam_sweep.yaml`. 3 | 2. We run a (small) hyperparameter sweep using chemCPA with all other embeddings. We sweep just over the embedding-related hparams (drug encoder, drug doser), while fixing the AE & ADV related hparams as selected through (1). See `config_lincs_all_embeddings_hparam_sweep.yaml`. 4 | -------------------------------------------------------------------------------- /experiments/lincs_rdkit_hparam/config_lincs_all_embbeddings_hparam_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Config for hyperparameter-tuning chemCPA on L1000, using all embeddings, while keeping AE hparams fixed 2 | # (part 2 of lincs_rdkit_hparam) 3 | seml: 4 | executable: chemCPA/seml_sweep_icb.py 5 | name: lincs_all_emb_hparam 6 | output_dir: sweeps/logs 7 | conda_environment: chemical_CPA 8 | project_root_dir: ../.. 9 | 10 | slurm: 11 | max_simultaneous_jobs: 20 12 | experiments_per_job: 1 13 | sbatch_options_template: GPU 14 | sbatch_options: 15 | gres: gpu:1 # num GPUs 16 | mem: 32G # memory 17 | cpus-per-task: 6 # num cores 18 | # speeds is roughly 3 epochs / minute 19 | time: 1-00:01 # max time, D-HH:MM 20 | ###### BEGIN PARAMETER CONFIGURATION ###### 21 | 22 | fixed: 23 | profiling.run_profiler: False 24 | profiling.outdir: "./" 25 | 26 | training.checkpoint_freq: 25 # checkpoint frequency to run evaluate, and maybe save checkpoint 27 | training.num_epochs: 1500 # maximum epochs for training. One epoch updates either autoencoder, or adversary, depending on adversary_steps. 28 | training.max_minutes: 1200 # maximum computation time 29 | training.full_eval_during_train: False 30 | training.run_eval_disentangle: True # whether to calc the disentanglement loss when running the full eval 31 | training.save_checkpoints: True # checkpoints tend to be ~250MB large for LINCS. 32 | training.save_dir: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/sweeps/checkpoints 33 | 34 | dataset.dataset_type: lincs 35 | dataset.data_params.dataset_path: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/lincs_full_smiles_sciplex_genes.h5ad # full path to the anndata dataset 36 | dataset.data_params.perturbation_key: pert_id # stores name of the drug 37 | dataset.data_params.pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 38 | dataset.data_params.dose_key: pert_dose # stores drug dose as a float 39 | dataset.data_params.covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 40 | dataset.data_params.smiles_key: canonical_smiles 41 | dataset.data_params.degs_key: rank_genes_groups_cov # `uns` column name denoting the DEGs for each perturbation 42 | dataset.data_params.split_key: random_split # necessary field for train, test, ood splits. 43 | dataset.data_params.use_drugs_idx: True # If false, will use One-hot encoding instead 44 | 45 | model.load_pretrained: False 46 | model.pretrained_model_path: null 47 | model.pretrained_model_hashes: null 48 | model.additional_params.patience: 4 # patience for early stopping. 
Effective epochs: patience * checkpoint_freq. 49 | model.additional_params.decoder_activation: linear # last layer of the decoder 50 | model.additional_params.doser_type: amortized # non-linearity for doser function 51 | model.embedding.directory: null # null will load the path from paths.py 52 | 53 | model.additional_params.seed: 1337 54 | 55 | # these were picked using the first part of the experiment 56 | model.hparams.dim: 32 57 | model.hparams.dropout: 0.262378 58 | model.hparams.autoencoder_width: 256 59 | model.hparams.autoencoder_depth: 4 60 | model.hparams.autoencoder_lr: 0.001121 61 | model.hparams.autoencoder_wd: 3.752056e-7 62 | model.hparams.adversary_width: 128 63 | model.hparams.adversary_depth: 3 64 | model.hparams.adversary_lr: 0.000806 65 | model.hparams.adversary_wd: 0.000004 66 | model.hparams.adversary_steps: 2 # every X steps, update the adversary INSTEAD OF the autoencoder. 67 | model.hparams.reg_adversary: 24.082073 68 | model.hparams.penalty_adversary: 3.347776 69 | model.hparams.batch_size: 128 70 | 71 | grid: 72 | model.embedding.model: 73 | type: choice 74 | options: 75 | - vanilla 76 | - grover_base 77 | - GCN 78 | - weave 79 | - MPNN 80 | - rdkit 81 | - jtvae 82 | - seq2seq 83 | 84 | random: 85 | samples: 18 86 | seed: 42 87 | model.hparams.dosers_width: 88 | type: choice 89 | options: 90 | - 64 91 | - 128 92 | - 256 93 | - 512 94 | model.hparams.dosers_depth: 95 | type: choice 96 | options: 97 | - 1 98 | - 2 99 | - 3 100 | model.hparams.dosers_lr: 101 | type: loguniform 102 | min: 1e-4 103 | max: 1e-2 104 | model.hparams.dosers_wd: 105 | type: loguniform 106 | min: 1e-8 107 | max: 1e-5 108 | model.hparams.step_size_lr: # this applies to all optimizers (AE, ADV, DRUG) 109 | type: choice 110 | options: 111 | - 50 112 | - 100 113 | - 200 114 | model.hparams.embedding_encoder_width: 115 | type: choice 116 | options: 117 | - 128 118 | - 256 119 | - 512 120 | model.hparams.embedding_encoder_depth: 121 | type: choice 122 | options: 123 | - 2 124 | - 3 125 | - 4 126 | -------------------------------------------------------------------------------- /experiments/lincs_rdkit_hparam/config_lincs_rdkit_hparam_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Config for hyperparameter-tuning chemCPA on L1000, using rdkit embeddings (part 1 of lincs_rdkit_hparam) 2 | seml: 3 | executable: chemCPA/seml_sweep_icb.py 4 | name: lincs_rdkit_hparam 5 | output_dir: sweeps/logs 6 | conda_environment: chemical_CPA 7 | project_root_dir: ../.. 8 | 9 | slurm: 10 | max_simultaneous_jobs: 10 11 | experiments_per_job: 1 12 | sbatch_options_template: GPU 13 | sbatch_options: 14 | gres: gpu:1 # num GPUs 15 | mem: 32G # memory 16 | cpus-per-task: 6 # num cores 17 | # speeds is roughly 3 epochs / minute 18 | time: 1-00:01 # max time, D-HH:MM 19 | ###### BEGIN PARAMETER CONFIGURATION ###### 20 | 21 | fixed: 22 | profiling.run_profiler: False 23 | profiling.outdir: "./" 24 | 25 | training.checkpoint_freq: 25 # checkpoint frequency to run evaluate, and maybe save checkpoint 26 | training.num_epochs: 1500 # maximum epochs for training. One epoch updates either autoencoder, or adversary, depending on adversary_steps. 27 | training.max_minutes: 1200 # maximum computation time 28 | training.full_eval_during_train: False 29 | training.run_eval_disentangle: True # whether to calc the disentanglement loss when running the full eval 30 | training.save_checkpoints: True # checkpoints tend to be ~250MB large for LINCS. 
31 | training.save_dir: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/sweeps/checkpoints 32 | 33 | dataset.dataset_type: lincs 34 | dataset.data_params.dataset_path: /storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/datasets/lincs_full_smiles_sciplex_genes.h5ad # full path to the anndata dataset 35 | dataset.data_params.perturbation_key: pert_id # stores name of the drug 36 | dataset.data_params.pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose 37 | dataset.data_params.dose_key: pert_dose # stores drug dose as a float 38 | dataset.data_params.covariate_keys: cell_type # necessary field for cell types. Fill it with a dummy variable if no celltypes present. 39 | dataset.data_params.smiles_key: canonical_smiles 40 | dataset.data_params.degs_key: rank_genes_groups_cov # `uns` column name denoting the DEGs for each perturbation 41 | dataset.data_params.split_key: random_split # necessary field for train, test, ood splits. 42 | dataset.data_params.use_drugs_idx: True # If false, will use One-hot encoding instead 43 | 44 | model.load_pretrained: False 45 | model.pretrained_model_path: null 46 | model.pretrained_model_hashes: null 47 | model.additional_params.patience: 3 # patience for early stopping. Effective epochs: patience * checkpoint_freq. 48 | model.additional_params.decoder_activation: linear # last layer of the decoder 49 | model.additional_params.doser_type: amortized # non-linearity for doser function 50 | model.embedding.directory: null # null will load the path from paths.py 51 | 52 | model.embedding.model: rdkit 53 | model.additional_params.seed: 1337 54 | 55 | random: 56 | samples: 25 57 | seed: 42 58 | model.hparams.dim: 59 | type: choice 60 | options: 61 | - 64 62 | - 32 63 | model.hparams.dropout: 64 | type: uniform 65 | min: 0.0 66 | max: 0.5 67 | model.hparams.dosers_width: 68 | type: choice 69 | options: 70 | - 64 71 | - 128 72 | - 256 73 | - 512 74 | model.hparams.dosers_depth: 75 | type: choice 76 | options: 77 | - 1 78 | - 2 79 | - 3 80 | model.hparams.dosers_lr: 81 | type: loguniform 82 | min: 1e-4 83 | max: 1e-2 84 | model.hparams.dosers_wd: 85 | type: loguniform 86 | min: 1e-8 87 | max: 1e-5 88 | model.hparams.autoencoder_width: 89 | type: choice 90 | options: 91 | - 128 92 | - 256 93 | - 512 94 | model.hparams.autoencoder_depth: 95 | type: choice 96 | options: 97 | - 3 98 | - 4 99 | - 5 100 | model.hparams.autoencoder_lr: 101 | type: loguniform 102 | min: 1e-4 103 | max: 1e-2 104 | model.hparams.autoencoder_wd: 105 | type: loguniform 106 | min: 1e-8 107 | max: 1e-5 108 | model.hparams.adversary_width: 109 | type: choice 110 | options: 111 | - 64 112 | - 128 113 | - 256 114 | model.hparams.adversary_depth: 115 | type: choice 116 | options: 117 | - 2 118 | - 3 119 | - 4 120 | model.hparams.adversary_lr: 121 | type: loguniform 122 | min: 5e-5 123 | max: 1e-2 124 | model.hparams.adversary_wd: 125 | type: loguniform 126 | min: 1e-8 127 | max: 1e-3 128 | model.hparams.adversary_steps: # every X steps, update the adversary INSTEAD OF the autoencoder. 
129 | type: choice 130 | options: 131 | - 2 132 | - 3 133 | model.hparams.reg_adversary: 134 | type: loguniform 135 | min: 5 136 | max: 100 137 | model.hparams.penalty_adversary: 138 | type: loguniform 139 | min: 1 140 | max: 10 141 | model.hparams.batch_size: 142 | type: choice 143 | options: 144 | - 32 145 | - 64 146 | - 128 147 | model.hparams.step_size_lr: 148 | type: choice 149 | options: 150 | - 50 151 | - 100 152 | - 200 153 | model.hparams.embedding_encoder_width: 154 | type: choice 155 | options: 156 | - 128 157 | - 256 158 | - 512 159 | model.hparams.embedding_encoder_depth: 160 | type: choice 161 | options: 162 | - 2 163 | - 3 164 | - 4 165 | -------------------------------------------------------------------------------- /experiments/sciplex_hparam/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## EXP: `sciplex_hparam` 3 | This experiment is run to determine suitable optimisation hparams for the adversary when fine-tuning on the sciplex dataset. These hparams are meant to be shared when evaluating transfer performance for different drug embedding models. 4 | 5 | Similar to `lincs_rdkit_hparam`, we subset to the `grover_base`, `jtvae`, and `rdkit` embeddings to keep the compute requirements manageable. 6 | 7 | Setup: 8 | Importantly, we sweep over a split that sets some drugs as ood. In this setting, the original CPA model is not applicable anymore. The drugs were chosen according to the results from the original [sciplex publication](https://www.science.org/doi/full/10.1126/science.aax6234), cf. Fig. S6 in the supplements of the publication. 9 | -------------------------------------------------------------------------------- /manual_seml_sweep.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from pprint import pprint 3 | 4 | from seml.config import generate_configs, read_config 5 | 6 | from chemCPA.experiments_run import ExperimentWrapper 7 | 8 | if __name__ == "__main__": 9 | exp = ExperimentWrapper(init_all=False) 10 | 11 | # this is how seml loads the config file internally 12 | config = "test_config_biolord.yaml" 13 | assert Path(config).exists(), "config file not found" 14 | seml_config, slurm_config, experiment_config = read_config(config) 15 | # we take the first config generated 16 | configs = generate_configs(experiment_config) 17 | if len(configs) > 1: 18 | print("Careful, more than one config generated from the yaml file") 19 | args = configs[0] 20 | pprint(args) 21 | 22 | exp.seed = 1337 23 | # loads the dataset splits 24 | exp.init_dataset(**args["dataset"]) 25 | 26 | exp.init_drug_embedding(embedding=args["model"]["embedding"]) 27 | exp.init_model( 28 | hparams=args["model"]["hparams"], 29 | additional_params=args["model"]["additional_params"], 30 | load_pretrained=args["model"]["load_pretrained"], 31 | append_ae_layer=args["model"]["append_ae_layer"], 32 | enable_cpa_mode=args["model"]["enable_cpa_mode"], 33 | pretrained_model_path=args["model"]["pretrained_model_path"], 34 | pretrained_model_hashes=args["model"]["pretrained_model_hashes"], 35 | ) 36 | # setup the torch DataLoader 37 | exp.update_datasets() 38 | 39 | exp.train(**args["training"]) 40 | -------------------------------------------------------------------------------- /notebooks/Additional/analysis_results.md: -------------------------------------------------------------------------------- 1 | Dose: 0.1 muMol 2 | 3 | | model | ('Median all genes', 'R2') | ('Median all genes', 'delta') |
('Median DEGs', 'R2') | ('Median DEGs', 'delta') | ('Mean all genes', 'R2') | ('Mean all genes', 'delta') | ('Mean DEGs', 'R2') | ('Mean DEGs', 'delta') | 4 | |:-------------------|-----------------------------:|--------------------------------:|------------------------:|---------------------------:|---------------------------:|------------------------------:|----------------------:|-------------------------:| 5 | | Baseline | 82.5 | 0 | 62.2 | 0 | 69.3 | 0 | 51 | 0 | 6 | | scGen | 76.7 | -0.9 | 68.1 | 3.4 | 72.7 | 3.4 | 59.1 | 8.2 | 7 | | CPA | 85.6 | 2.3 | 66.8 | 3.3 | 71.7 | 2.5 | 54.1 | 3.1 | 8 | | chemCPA | 85.7 | 2.2 | 66.1 | 4.8 | 73.8 | 4.5 | 60.5 | 9.5 | 9 | | chemCPA pretrained | 85.3 | 2 | 75.5 | 6.4 | 76.9 | 7.6 | 68.3 | 17.3 | 10 | 11 | Dose: 1.0 muMol 12 | 13 | | model | ('Median all genes', 'R2') | ('Median all genes', 'delta') | ('Median DEGs', 'R2') | ('Median DEGs', 'delta') | ('Mean all genes', 'R2') | ('Mean all genes', 'delta') | ('Mean DEGs', 'R2') | ('Mean DEGs', 'delta') | 14 | |:-------------------|-----------------------------:|--------------------------------:|------------------------:|---------------------------:|---------------------------:|------------------------------:|----------------------:|-------------------------:| 15 | | Baseline | 48.5 | 0 | 12.2 | 0 | 50.3 | 0 | 28.7 | 0 | 16 | | scGen | 66.2 | 5.2 | 48.9 | 13.9 | 62.2 | 11.9 | 47.2 | 18.5 | 17 | | CPA | 52.5 | 3.4 | 25.7 | 4 | 53.9 | 3.6 | 34.2 | 5.5 | 18 | | chemCPA | 76.7 | 5.5 | 64.4 | 10.9 | 71.2 | 20.9 | 58.4 | 29.7 | 19 | | chemCPA pretrained | 82.4 | 6.6 | 78.6 | 15.5 | 76.5 | 26.2 | 67.8 | 39.1 | 20 | -------------------------------------------------------------------------------- /notebooks/Additional/experiment_analysis.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # notebook_metadata_filter: -kernelspec 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.1 10 | # --- 11 | 12 | # %% 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | import pandas as pd 16 | import seaborn as sns 17 | 18 | # %% 19 | # %load_ext lab_black 20 | # %load_ext autoreload 21 | # %autoreload 2 22 | 23 | # %% 24 | BLACK = False 25 | 26 | if BLACK: 27 | plt.style.use("dark_background") 28 | else: 29 | matplotlib.style.use("fivethirtyeight") 30 | matplotlib.style.use("seaborn-talk") 31 | matplotlib.pyplot.rcParams["savefig.facecolor"] = "white" 32 | sns.set_style("whitegrid") 33 | 34 | matplotlib.rcParams["font.family"] = "monospace" 35 | matplotlib.rcParams["figure.dpi"] = 120 36 | sns.set_context("poster") 37 | 38 | # %% 39 | df = pd.concat( 40 | [ 41 | pd.read_parquet("cpa_predictions.parquet"), 42 | pd.read_parquet("scgen_predictions.parquet"), 43 | ] 44 | ) 45 | # df = pd.concat( 46 | # [ 47 | # pd.read_parquet("cpa_predictions_high_dose.parquet"), 48 | # pd.read_parquet("scgen_predictions_high_dose.parquet"), 49 | # ] 50 | # ) 51 | 52 | # %% 53 | df 54 | 55 | # %% 56 | fig, ax = plt.subplots(figsize=(20, 9)) 57 | 58 | sns.boxplot( 59 | data=df[df["genes"] == "all"], 60 | x="condition", 61 | y="R2", 62 | hue="model", 63 | palette="tab10", 64 | ) 65 | ax.set_xticklabels(ax.get_xticklabels(), rotation=75, ha="right") 66 | ax.set_xlabel("") 67 | ax.set_ylabel("$E[r^2]$ on all genes") 68 | ax.legend( 69 | title="Model type", 70 | # fontsize=18, 71 | # title_fontsize=24, 72 | loc="lower left", 73 | bbox_to_anchor=(1, 0.2), 74 | ) 75 | 
ax.grid(".", color="darkgrey") 76 | plt.tight_layout() 77 | 78 | # %% 79 | fig, ax = plt.subplots(figsize=(20, 9)) 80 | 81 | sns.boxplot( 82 | data=df[df["genes"] == "degs"], 83 | x="condition", 84 | y="R2", 85 | hue="model", 86 | palette="tab10", 87 | ) 88 | ax.set_xticklabels(ax.get_xticklabels(), rotation=75, ha="right") 89 | ax.set_xlabel("") 90 | ax.set_ylabel("$E[r^2]$ on all genes") 91 | ax.legend( 92 | title="Model type", 93 | # fontsize=18, 94 | # title_fontsize=24, 95 | loc="lower left", 96 | bbox_to_anchor=(1, 0.2), 97 | ) 98 | ax.grid(".", color="darkgrey") 99 | plt.tight_layout() 100 | 101 | # %% 102 | df.groupby(["model", "genes"]).std() 103 | 104 | # %% 105 | df.groupby(["model", "genes"]).mean() 106 | 107 | # %% 108 | DELTA = False 109 | 110 | if DELTA: 111 | df["delta"] = 0 112 | 113 | for cond, _df in df.groupby(["cell_type", "condition", "genes"]): 114 | df.loc[ 115 | df[["cell_type", "condition", "genes"]].isin(cond).prod(1).astype(bool), 116 | "delta", 117 | ] = ( 118 | _df["R2"].values - _df.loc[_df["model"] == "baseline", "R2"].values[0] 119 | ) 120 | 121 | # %% 122 | df 123 | 124 | # %% 125 | df1 = df[df.genes == "all"].groupby(["model"]).mean().round(2) 126 | df2 = df[df.genes == "degs"].groupby(["model"]).mean().round(2) 127 | df3 = df[df.genes == "all"].groupby(["model"]).median().round(2) 128 | df4 = df[df.genes == "degs"].groupby(["model"]).median().round(2) 129 | 130 | # %% 131 | result_df = ( 132 | pd.concat( 133 | [df1, df2, df3, df4], 134 | axis=1, 135 | keys=["Mean all genes", "Mean DEGs", "Median all genes", "Median DEGs"], 136 | ) 137 | .reindex(["baseline", "scGen", "cpa", "chemCPA", "chemCPA_pretrained"]) 138 | .rename( 139 | index={ 140 | "baseline": "Baseline", 141 | "cpa": "CPA", 142 | "chemCPA_pretrained": "chemCPA pretrained", 143 | } 144 | ) 145 | ) 146 | 147 | result_df 148 | 149 | # %% 150 | print(result_df.to_markdown()) 151 | 152 | # %% 153 | print(result_df.to_latex()) 154 | 155 | # %% 156 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Predicting single-cell perturbation responses for unseen drugs - Notebooks 2 | 3 | These notebooks are meant to showcase how to analyse a trained chemCPA model. They also reproduce the results from the paper. 4 | 5 | To load the model configs please use the provided `.json` file and define your `load_config` function similar to this: 6 | 7 | ```python 8 | import json 9 | from tqdm.auto import tqdm 10 | from chemCPA.paths import PROJECT_DIR 11 | 12 | def load_config(seml_collection, model_hash): 13 | file_path = PROJECT_DIR / f"{seml_collection}.json" # Provide path to json 14 | 15 | with open(file_path) as f: 16 | file_data = json.load(f) 17 | 18 | for _config in tqdm(file_data): 19 | if _config["config_hash"] == model_hash: 20 | # print(config) 21 | config = _config["config"] 22 | config["config_hash"] = _config["config_hash"] 23 | return config 24 | ``` 25 | 26 | Make sure that the dataset paths are set correctly. 
Here is how to manually change this in the config: 27 | 28 | ```python 29 | from chemCPA.paths import DATA_DIR 30 | 31 | config["dataset"]["data_params"]["dataset_path"] = DATA_DIR / config["dataset"]["data_params"]["dataset_path"].split('/')[-1] 32 | ``` 33 | 34 | Similarly, the `CHECKPOINT_DIR` should align with the folder where you have stored the trained chemCPA models, this is used in the `utils.py`: 35 | 36 | ```python 37 | from chemCPA.paths import CHECKPOINT_DIR 38 | ``` 39 | -------------------------------------------------------------------------------- /preprocessing/5_sciplex_ood_splits.ipynb: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a609c2a16860cc8e65c93dc893ea6d0ecc08dc922f9aead1870fe8085b593a32 3 | size 44371165 4 | -------------------------------------------------------------------------------- /preprocessing/6_baseline_sciplex_dataset.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: light 7 | # format_version: '1.5' 8 | # jupytext_version: 1.16.1 9 | # kernelspec: 10 | # display_name: Python 3.7.12 ('chemical_CPA') 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # # 6 BASELINE SCIPLEX DATASET 16 | 17 | # **Requires** 18 | # sciplex_complete_middle_subset_lincs_genes.h5ad 19 | # 20 | # **Outputs** 21 | # adata_baseline_high_dose.h5ad 22 | # 23 | # 24 | 25 | # + 26 | import pandas as pd 27 | import scanpy as sc 28 | 29 | from chemCPA.paths import DATA_DIR 30 | 31 | pd.set_option('display.max_columns', 200) 32 | # - 33 | 34 | list(DATA_DIR.iterdir()) 35 | 36 | adata_sciplex = sc.read(DATA_DIR/ "sciplex_complete_middle_subset_lincs_genes.h5ad") 37 | 38 | adata_sciplex.obs.columns 39 | 40 | adata_sciplex.obs.loc[adata_sciplex.obs.split_ood_multi_task == 'ood', 'condition'].unique() 41 | 42 | # + 43 | # Subset to second largest dose 44 | 45 | print(adata_sciplex.obs.dose.unique()) 46 | adata_sciplex = adata_sciplex[adata_sciplex.obs.dose.isin([0., 1e4])].copy() 47 | 48 | # + 49 | # Add new splits for dose=1000 and cell_type (A549, MCF7, K562) being unseen for ood drugs 50 | 51 | for cell_type in adata_sciplex.obs.cell_type.unique(): 52 | print(cell_type) 53 | adata_sciplex.obs[f'split_baseline_{cell_type}'] = adata_sciplex.obs['split_ood_multi_task'] 54 | sub_df = adata_sciplex.obs.loc[(adata_sciplex.obs[f'split_baseline_{cell_type}'] == 'ood') * (adata_sciplex.obs.cell_type != cell_type)] 55 | 56 | train_test = sub_df.index 57 | test = sub_df.sample(frac=0.5).index 58 | 59 | adata_sciplex.obs.loc[train_test,f'split_baseline_{cell_type}'] = 'train' 60 | adata_sciplex.obs.loc[test,f'split_baseline_{cell_type}'] = 'test' 61 | # - 62 | 63 | adata_sciplex.obs['split_baseline_A549'].value_counts() 64 | 65 | pd.crosstab(adata_sciplex.obs['split_ood_multi_task'], adata_sciplex.obs['condition']) 66 | 67 | # + 68 | # Quick check that everything is correct 69 | 70 | cell_type = 'K562' 71 | 72 | # pd.crosstab(adata_sciplex.obs[f'split_baseline_{cell_type}'], adata_sciplex.obs['condition']) 73 | pd.crosstab(adata_sciplex.obs[f'split_baseline_{cell_type}'], adata_sciplex.obs['cell_type']) 74 | 75 | # + 76 | # write adata 77 | 78 | adata_sciplex.write(DATA_DIR/'adata_baseline_high_dose.h5ad', compression="gzip") 79 | # - 80 | 81 | 82 | -------------------------------------------------------------------------------- 
/preprocessing/7_compute_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Computes embeddings for the dataset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# RDKIT" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Computes embeddings for the dataset\n", 24 | "\n", 25 | "# RDKIT\n", 26 | "from chemCPA.paths import DATA_DIR, PROJECT_DIR, ROOT, EMBEDDING_DIR\n", 27 | "import sys\n", 28 | "import os\n", 29 | "from tqdm.auto import tqdm\n", 30 | "\n", 31 | "# Add the parent directory of embeddings to Python path\n", 32 | "sys.path.append(str(ROOT))\n", 33 | "\n", 34 | "import embeddings.rdkit.embedding_rdkit as embedding_rdkit\n", 35 | "\n", 36 | "# Define the datasets to process with their corresponding SMILES keys\n", 37 | "datasets = [\n", 38 | " ('lincs_smiles.h5ad', 'SMILES'),\n", 39 | " ('lincs_full_smiles.h5ad', 'smiles'), # Changed SMILES key to lowercase\n", 40 | " ('sciplex_complete.h5ad', 'SMILES'),\n", 41 | " ('adata_MCF7.h5ad', 'SMILES'),\n", 42 | " ('adata_MCF7_lincs_genes.h5ad', 'SMILES'),\n", 43 | " ('adata_K562.h5ad', 'SMILES'),\n", 44 | " ('adata_K562_lincs_genes.h5ad', 'SMILES'),\n", 45 | " ('adata_A549.h5ad', 'SMILES'),\n", 46 | " ('adata_A549_lincs_genes.h5ad', 'SMILES'),\n", 47 | " ('sciplex_complete_subset_lincs_genes_v2.h5ad', 'SMILES'),\n", 48 | " ('sciplex_complete_middle_subset_v2.h5ad', 'SMILES'),\n", 49 | " ('sciplex_complete_middle_subset_lincs_genes_v2.h5ad', 'SMILES'),\n", 50 | " ('sciplex_complete_v2.h5ad', 'SMILES'),\n", 51 | " ('sciplex_complete_lincs_genes_v2.h5ad', 'SMILES')\n", 52 | "]\n", 53 | "\n", 54 | "# Process each dataset\n", 55 | "for dataset, smiles_key in tqdm(datasets, desc=\"Computing RDKit embeddings\"):\n", 56 | " h5ad_path = os.path.join(DATA_DIR, dataset)\n", 57 | " base_name = os.path.splitext(dataset)[0]\n", 58 | " output_filename = f\"{base_name}_rdkit2D_embedding.parquet\"\n", 59 | " output_path = os.path.join(EMBEDDING_DIR, 'rdkit', output_filename)\n", 60 | " \n", 61 | " # Create the output directory if it doesn't exist\n", 62 | " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n", 63 | " \n", 64 | " try:\n", 65 | " embedding_rdkit.compute_rdkit_embeddings(h5ad_path, output_path=output_path, smiles_key=smiles_key)\n", 66 | " except Exception as e:\n", 67 | " tqdm.write(f\"Error processing {dataset}: {str(e)}\")" 68 | ] 69 | } 70 | ], 71 | "metadata": { 72 | "language_info": { 73 | "name": "python" 74 | } 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 2 78 | } 79 | -------------------------------------------------------------------------------- /preprocessing/7_compute_embeddings.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: light 7 | # format_version: '1.5' 8 | # jupytext_version: 1.16.1 9 | # --- 10 | 11 | # # Computes embeddings for the dataset and prints their dimensions 12 | 13 | from chemCPA.paths import DATA_DIR, PROJECT_DIR, ROOT, EMBEDDING_DIR 14 | import sys 15 | import os 16 | from tqdm.auto import tqdm 17 | import pandas as pd 18 | 19 | # Add the parent directory of embeddings to Python path 20 | sys.path.append(str(ROOT)) 21 | 22 | import embeddings.rdkit.embedding_rdkit as 
embedding_rdkit 23 | 24 | # Define the datasets to process with their corresponding SMILES keys 25 | datasets = [ 26 | ('lincs_smiles.h5ad', 'SMILES'), 27 | ('lincs_full_smiles.h5ad', 'canonical_smiles'), # Changed SMILES key to lowercase 28 | #('sciplex_complete.h5ad', 'SMILES'), 29 | #('adata_MCF7.h5ad', 'SMILES'), 30 | #('adata_MCF7_lincs_genes.h5ad', 'SMILES'), 31 | #('adata_K562.h5ad', 'SMILES'), 32 | #('adata_K562_lincs_genes.h5ad', 'SMILES'), 33 | #('adata_A549.h5ad', 'SMILES'), 34 | #('adata_A549_lincs_genes.h5ad', 'SMILES'), 35 | #('sciplex_complete_subset_lincs_genes_v2.h5ad', 'SMILES'), 36 | #('sciplex_complete_middle_subset_v2.h5ad', 'SMILES'), 37 | #('sciplex_complete_middle_subset_lincs_genes_v2.h5ad', 'SMILES'), 38 | ('sciplex_complete_v2.h5ad', 'SMILES'), 39 | ('sciplex_complete_lincs_genes_v2.h5ad', 'SMILES') 40 | #('combo_sciplex_prep_hvg_filtered.h5ad', 'smiles_rdkit') 41 | ] 42 | 43 | # Define desired embedding dimension 44 | FIXED_EMBEDDING_DIM = 200 # or whatever dimension you want 45 | 46 | # Define whether to skip variance filtering to keep dimensions consistent 47 | SKIP_VARIANCE_FILTER = False # Set this to True to keep all dimensions 48 | 49 | print("\nComputing and analyzing embeddings:") 50 | print(f"Using fixed embedding dimension: {FIXED_EMBEDDING_DIM}") 51 | print(f"Skip variance filtering: {SKIP_VARIANCE_FILTER}") 52 | print("-" * 50) 53 | 54 | # Process each dataset 55 | for dataset, smiles_key in tqdm(datasets, desc="Computing RDKit embeddings"): 56 | h5ad_path = os.path.join(DATA_DIR, dataset) 57 | base_name = os.path.splitext(dataset)[0] 58 | output_filename = f"{base_name}_rdkit2D_embedding.parquet" 59 | output_path = os.path.join(EMBEDDING_DIR, 'rdkit', output_filename) 60 | 61 | # Create the output directory if it doesn't exist 62 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 63 | 64 | try: 65 | # Compute embeddings without variance filtering 66 | embedding_rdkit.compute_rdkit_embeddings( 67 | h5ad_path, 68 | output_path=output_path, 69 | smiles_key=smiles_key, 70 | skip_variance_filter=SKIP_VARIANCE_FILTER 71 | ) 72 | 73 | # Read and analyze the generated embeddings 74 | embeddings_df = pd.read_parquet(output_path) 75 | 76 | print(f"\nEmbedding analysis for {dataset}:") 77 | print(f"Shape: {embeddings_df.shape}") 78 | print(f"Number of features: {embeddings_df.shape[1]}") 79 | print(f"Memory usage: {embeddings_df.memory_usage().sum() / 1024**2:.2f} MB") 80 | print(f"File location: {output_path}") 81 | print("-" * 50) 82 | 83 | except Exception as e: 84 | tqdm.write(f"Error processing {dataset}: {str(e)}") 85 | -------------------------------------------------------------------------------- /preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing 2 | 3 | This folder contains preprocessing notebooks that convert the raw data to datasets that may 4 | be used for training. 5 | 6 | ## Description 7 | Briefly: 8 | 1. The first notebook cleans up the LINCS dataset, computes DEGS, and splits. 9 | 2. The second notebook adds the SMILES information to the LINCS dataset. 10 | 3. The third notebook finds matching genes between LINCS and SciPlex-3 datasets, and creates datasets with only subsets of genes that match in some way. 11 | 4. The fourth notebook adds the SMILES information to the SciPlex-3 dataset 12 | 5. The fifth notebook creates various sub-datasets with varying observations and train/test/ood splits 13 | 6. The sixth notebook computes a baseline dataset 14 | 7. 
The seventh notebook computes embedding files for all datasets 15 | 16 | For more details read the notebooks. 17 | 18 | ### Clarification on the available datasets 19 | 20 | The `lincs_full.h5ad` is a combination of the available L1000 datasets from phase 1 and phase 2, available here: 21 | - L1000 Connectivity Map Phase I: [GSE70138](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE70138) 22 | - L1000 Connectivity Map Phase II: [GSE92742](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE92742) 23 | 24 | We combined the data, which is only available in the `.gctx` format, with its metadata, cf. `<...>_inst_info.txt`, and make it available as a scanpy-compatible `.h5ad` object. 25 | 26 | Note that not all perturbations correspond to small molecules, which is why we subsetted the data to only contain the perturbation types `trt_cp` and `ctl_vehicle`, resulting in a total of 1034271 observations. 27 | 28 | The provided data is normalised. 29 | 30 | For the training on the LINCS data, we ignored the treatment time, `adata_lincs_full.obs["pert_time"]`. 31 | 32 | #### Preprocess data 33 | The data preprocessing should run through with the provided files. For the matching of genes between LINCS and the SciPlex-3 data in `3_lincs_sciplex_gene_matching.ipynb`, we provide a [`symbols_dict`](https://drive.google.com/file/d/16V5nyj3xKlsUk_cJtRYtkpFiglGzP9Xl/view?usp=sharing) which replaces the matching via `sfaira`. Note that you have to execute `4_sciplex_SMILES.ipynb` for both gene sets. The same notebook also contains multiple checks against the [`trapnell_final_V7.h5ad`](https://drive.google.com/file/d/1_JUg631r_QfZhKl9NZXXzVefgCMPXE_9/view?usp=share_link) file to make sure that the SMILES are correctly matched. You could ignore these and implement your own solution. For this, we provide the `drug_dict.json` file; a minimal usage sketch follows below.
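The sketch below is only an illustration of how such a mapping could be used. It assumes that `drug_dict.json` maps drug names (as they appear in the `condition` column of the SciPlex data) to SMILES strings; please verify the actual structure of the file before relying on it.

```python
import json

# Assumption: drug_dict.json maps drug names to SMILES strings.
# The file lives in the preprocessing folder of this repository.
with open("preprocessing/drug_dict.json") as f:
    drug_to_smiles = json.load(f)

print(len(drug_to_smiles), "entries loaded")

# Hypothetical usage on a SciPlex AnnData object:
# adata.obs["SMILES"] = adata.obs["condition"].map(drug_to_smiles)
```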
34 | 35 | #### Files produced 36 | 37 | -------------------------------------------------------------------------------- /preprocessing/analysis_smiles_lincs_trapnell.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: light 7 | # format_version: '1.5' 8 | # jupytext_version: 1.16.1 9 | # kernelspec: 10 | # display_name: Python 3 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # + jupyter={"outputs_hidden": true} pycharm={"name": "#%%\n"} 16 | import matplotlib 17 | import matplotlib.pyplot as plt 18 | import pandas as pd 19 | import seaborn as sn 20 | from rdkit import Chem, DataStructs 21 | from rdkit.Chem import Draw 22 | from rdkit.Chem.Draw import IPythonConsole 23 | 24 | matplotlib.style.use("fivethirtyeight") 25 | matplotlib.style.use("seaborn-talk") 26 | matplotlib.rcParams['font.family'] = "monospace" 27 | matplotlib.rcParams['figure.dpi'] = 200 28 | matplotlib.pyplot.rcParams['savefig.facecolor'] = 'white' 29 | sn.set_context("poster") 30 | IPythonConsole.ipython_useSVG = False 31 | 32 | 33 | # + jupyter={"outputs_hidden": false} pycharm={"name": "#%%\n"} 34 | trapnell_df = pd.read_csv("../embeddings/trapnell_drugs_smiles.csv", names=["drug", "smiles", "pathway"]) 35 | trapnell_df["smiles"] = trapnell_df.smiles.str.strip() 36 | lincs_df = pd.read_csv("../embeddings/lincs_drugs_smiles.csv", names=["drug", "smiles"]) 37 | lincs_df["smiles"] = lincs_df.smiles.str.strip() 38 | 39 | 40 | # + jupyter={"outputs_hidden": false} pycharm={"name": "#%%\n"} 41 | def tanimoto_score(input_smiles, target_smiles): 42 | input_fp = Chem.RDKFingerprint(Chem.MolFromSmiles(input_smiles)) 43 | target_fp = Chem.RDKFingerprint(Chem.MolFromSmiles(target_smiles)) 44 | return DataStructs.TanimotoSimilarity(input_fp, target_fp) 45 | 46 | 47 | # - 48 | 49 | # ## Checking 3 hold out drugs 50 | # Looking for the most similar drugs in LINCS to our 3 hold out drug in Trapnell 51 | 52 | # + jupyter={"outputs_hidden": false} pycharm={"name": "#%%\n"} 53 | loo_drugs = trapnell_df[trapnell_df.drug.isin(["Quisinostat", "Flavopiridol", "BMS-754807"])] 54 | loo_drugs 55 | 56 | # + jupyter={"outputs_hidden": false} pycharm={"name": "#%%\n"} 57 | smiles_orig = [] 58 | smiles_lincs = [] 59 | for i, (drug, smiles, pathway) in loo_drugs.iterrows(): 60 | tanimoto_sim_col = f"tanimoto_sim_{drug}" 61 | lincs_df[tanimoto_sim_col] = lincs_df.smiles.apply(lambda lincs_smiles: tanimoto_score(lincs_smiles, smiles)) 62 | most_similar = lincs_df.sort_values(tanimoto_sim_col, ascending=False).head(1) 63 | smiles_orig.append(smiles) 64 | smiles_lincs.append(most_similar["smiles"].item()) 65 | print(drug, any(lincs_df.smiles.isin([smiles])), most_similar[tanimoto_sim_col].item(), most_similar["drug"].item()) 66 | print(lincs_df.sort_values(tanimoto_sim_col, ascending=False).head(5)[["drug", tanimoto_sim_col]]) 67 | 68 | # + jupyter={"outputs_hidden": false} pycharm={"name": "#%%\n"} 69 | for orig, lincs in zip(smiles_orig, smiles_lincs): 70 | im = Draw.MolsToGridImage([Chem.MolFromSmiles(orig), Chem.MolFromSmiles(lincs)], subImgSize=(600, 400), 71 | legends=[orig, lincs]) 72 | plt.tight_layout() 73 | display(im) 74 | -------------------------------------------------------------------------------- /preprocessing/convert_notebooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # helper script to convert all jupyter notebooks 
in the preprocessing folder to python scripts without having to open up jupyter lab 3 | 4 | # Get the directory of the script (preprocessing folder) 5 | NOTEBOOK_DIR="$(dirname "$0")" 6 | 7 | echo "Starting conversion of Jupyter notebooks to Python scripts in directory: $NOTEBOOK_DIR" 8 | 9 | # Convert all .ipynb files in the directory to .py files 10 | for notebook in "$NOTEBOOK_DIR"/*.ipynb; do 11 | # Extract the filename without the path 12 | filename=$(basename -- "$notebook") 13 | 14 | echo "Converting $filename to Python script..." 15 | 16 | # Run the conversion 17 | jupytext --to py "$notebook" 18 | 19 | # Check if the conversion was successful 20 | if [ $? -eq 0 ]; then 21 | echo "Successfully converted $filename to ${filename%.ipynb}.py" 22 | else 23 | echo "Failed to convert $filename" 24 | fi 25 | done 26 | 27 | echo "Conversion process completed." 28 | 29 | -------------------------------------------------------------------------------- /preprocessing/drug_dict.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/chemCPA/43e830eb0958c54e4aa64442c17ec0fed19b3f15/preprocessing/drug_dict.json -------------------------------------------------------------------------------- /preprocessing/notebook_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from IPython import get_ipython 4 | 5 | # Define a context manager to suppress output 6 | class suppress_output: 7 | def __enter__(self): 8 | self._stdout = os.dup(1) 9 | self._stderr = os.dup(2) 10 | self._null = os.open(os.devnull, os.O_RDWR) 11 | os.dup2(self._null, 1) 12 | os.dup2(self._null, 2) 13 | return self 14 | 15 | def __exit__(self, *args): 16 | # First restore the original file descriptors 17 | os.dup2(self._stdout, 1) 18 | os.dup2(self._stderr, 2) 19 | # Then close all our saved descriptors 20 | os.close(self._stdout) 21 | os.close(self._stderr) 22 | os.close(self._null) 23 | 24 | def is_notebook() -> bool: 25 | """ 26 | Returns True if running in a Jupyter notebook, False otherwise. 
27 | """ 28 | try: 29 | # Get the shell object from IPython 30 | shell = get_ipython().__class__.__name__ 31 | # Check if we're in a notebook-like environment 32 | if shell == 'ZMQInteractiveShell': # Jupyter notebook or qtconsole 33 | return True 34 | elif shell == 'TerminalInteractiveShell': # Terminal IPython 35 | return False 36 | else: 37 | return False 38 | except NameError: # If get_ipython is not defined (standard Python interpreter) 39 | return False 40 | -------------------------------------------------------------------------------- /preprocessing/run_notebooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | import subprocess 5 | from tqdm import tqdm 6 | import sys 7 | 8 | # Set up logging 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='%(asctime)s - %(levelname)s - %(message)s' 12 | ) 13 | logger = logging.getLogger(__name__) 14 | 15 | def run_python_script(script_path, env_vars=None): 16 | """Execute a single Python script with optional environment variables.""" 17 | try: 18 | logger.info(f"Running script: {script_path}") 19 | 20 | # Create a copy of the current environment 21 | env = os.environ.copy() 22 | # Update with any additional environment variables 23 | if env_vars: 24 | env.update(env_vars) 25 | 26 | # Execute the Python script using subprocess with real-time output 27 | process = subprocess.Popen( 28 | ['python', str(script_path)], 29 | stdout=subprocess.PIPE, 30 | stderr=subprocess.PIPE, 31 | text=True, 32 | bufsize=1, 33 | universal_newlines=True, 34 | env=env # Pass the modified environment 35 | ) 36 | 37 | # Print output in real-time 38 | while True: 39 | output = process.stdout.readline() 40 | error = process.stderr.readline() 41 | 42 | if output: 43 | print(output.strip()) 44 | if error: 45 | print(error.strip(), file=sys.stderr) 46 | 47 | # Check if process has finished 48 | if output == '' and error == '' and process.poll() is not None: 49 | break 50 | 51 | return_code = process.poll() 52 | 53 | if return_code == 0: 54 | logger.info(f"Successfully executed: {script_path}") 55 | return True 56 | else: 57 | logger.error(f"Error executing {script_path} (return code: {return_code})") 58 | return False 59 | 60 | except Exception as e: 61 | logger.error(f"Error executing {script_path}: {str(e)}") 62 | return False 63 | 64 | def main(): 65 | # Get preprocessing directory 66 | preprocessing_dir = Path("preprocessing") 67 | if not preprocessing_dir.exists(): 68 | preprocessing_dir = Path(__file__).parent 69 | 70 | # Get all python files that start with a digit 71 | python_files = sorted([f for f in preprocessing_dir.glob("[0-9]*.py")]) 72 | 73 | logger.info(f"Found {len(python_files)} Python scripts to execute") 74 | 75 | # Execute scripts in order 76 | results = [] 77 | for script_path in tqdm(python_files, desc="Executing scripts"): 78 | # Special handling for the SMILES script 79 | if "4_sciplex_SMILES" in script_path.name: 80 | # Run with LINCS_GENES=True 81 | success_true = run_python_script(script_path, {"LINCS_GENES": "true"}) 82 | results.append((f"{script_path.name} (LINCS_GENES=True)", success_true)) 83 | 84 | # Run with LINCS_GENES=False 85 | success_false = run_python_script(script_path, {"LINCS_GENES": "false"}) 86 | results.append((f"{script_path.name} (LINCS_GENES=False)", success_false)) 87 | else: 88 | # Run other scripts normally 89 | success = run_python_script(script_path) 90 | results.append((script_path.name, success)) 91 | 92 | 
# Print summary 93 | print("\nExecution Summary:") 94 | print("-----------------") 95 | for name, success in results: 96 | status = "✓ Success" if success else "✗ Failed" 97 | print(f"{name}: {status}") 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /preprocessing/supress_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | # Define a context manager to suppress output 5 | class suppress_output: 6 | def __enter__(self): 7 | self._stdout = os.dup(1) 8 | self._stderr = os.dup(2) 9 | self._null = os.open(os.devnull, os.O_RDWR) 10 | os.dup2(self._null, 1) 11 | os.dup2(self._null, 2) 12 | return self 13 | 14 | def __exit__(self, *args): 15 | # First restore the original file descriptors 16 | os.dup2(self._stdout, 1) 17 | os.dup2(self._stderr, 2) 18 | # Then close all our saved descriptors 19 | os.close(self._stdout) 20 | os.close(self._stderr) 21 | os.close(self._null) 22 | -------------------------------------------------------------------------------- /project_folder: -------------------------------------------------------------------------------- 1 | /nfs/homedirs/hetzell/hdd/project_chemCPA -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | exclude = ''' 3 | /( 4 | \.eggs 5 | | \.git 6 | | \.venv 7 | | build 8 | | dist 9 | )/ 10 | ''' 11 | -------------------------------------------------------------------------------- /raw_data/__init__.py: -------------------------------------------------------------------------------- 1 | # Empty file to make the directory a Python package -------------------------------------------------------------------------------- /raw_data/download_data.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import os 3 | import urllib.request 4 | import gzip 5 | import tarfile 6 | from tqdm import tqdm 7 | import argparse 8 | import shutil 9 | 10 | 11 | def download_file(url, output, use_gdown=True): 12 | if not os.path.exists(output): 13 | print(f"Downloading file from {url}") 14 | if use_gdown: 15 | gdown.download(url, output, quiet=False) 16 | else: 17 | download_file_with_progress(url, output) 18 | print(f"File downloaded as {output}") 19 | else: 20 | print(f"File {output} already exists. Skipping download.") 21 | 22 | 23 | def download_file_with_progress(url, output): 24 | response = urllib.request.urlopen(url) 25 | total_size = int(response.info().get('Content-Length', -1)) 26 | block_size = 8192 # 8 KB 27 | 28 | with tqdm(total=total_size, unit='iB', unit_scale=True, desc=output) as progress_bar: 29 | with open(output, 'wb') as file: 30 | while True: 31 | buffer = response.read(block_size) 32 | if not buffer: 33 | break 34 | size = file.write(buffer) 35 | progress_bar.update(size) 36 | 37 | 38 | def download_and_extract_gzip(url, output, max_retries=3): 39 | # Check if final output file already exists 40 | if os.path.exists(output): 41 | print(f"File {output} already exists. 
Skipping download and extraction.") 42 | return 43 | 44 | gzip_file = output + '.gz' 45 | 46 | for attempt in range(max_retries): 47 | try: 48 | if not os.path.exists(gzip_file) or attempt > 0: 49 | download_file(url, gzip_file, use_gdown=False) 50 | 51 | print(f"Extracting {gzip_file}") 52 | with gzip.open(gzip_file, 'rb') as f_in: 53 | with open(output, 'wb') as f_out: 54 | with tqdm(unit='B', unit_scale=True, desc="Extracting", total=os.path.getsize(gzip_file)) as pbar: 55 | while True: 56 | chunk = f_in.read(8192) 57 | if not chunk: 58 | break 59 | f_out.write(chunk) 60 | pbar.update(len(chunk)) 61 | 62 | os.remove(gzip_file) 63 | print(f"Extracted file saved as {output}") 64 | return # Successful extraction, exit the function 65 | except (EOFError, gzip.BadGzipFile) as e: 66 | print(f"Error during extraction (attempt {attempt + 1}/{max_retries}): {str(e)}") 67 | if os.path.exists(gzip_file): 68 | os.remove(gzip_file) 69 | if os.path.exists(output): 70 | os.remove(output) 71 | if attempt == max_retries - 1: 72 | print(f"Failed to download and extract {url} after {max_retries} attempts.") 73 | raise 74 | 75 | 76 | def extract_tar(tar_file, output_dir): 77 | print(f"Extracting {tar_file} to {output_dir}") 78 | with tarfile.open(tar_file, 'r') as tar: 79 | members = tar.getmembers() 80 | for member in tqdm(members, desc="Extracting"): 81 | tar.extract(member, output_dir) 82 | os.remove(tar_file) 83 | print(f"Extracted files to {output_dir}") 84 | 85 | 86 | def download_files(force_redownload=False): 87 | project_folder = "project_folder" 88 | 89 | # List of files to download: (url, relative_path, is_gzip, is_tar) 90 | files_to_download = [ 91 | ("https://drive.google.com/uc?export=download&id=18QkyADzuM8b7lMxRg94jufHaKRPkzEFw", 92 | "datasets/adata_biolord_split_30.h5ad", False, False), 93 | ("https://drive.google.com/uc?export=download&id=1oV2o5dVEVE3OwBVZzuuJTXuaamZJeFL9", 94 | "embeddings/rdkit/data/embeddings/rdkit2D_embedding_biolord.parquet", False, False), 95 | ("https://f003.backblazeb2.com/file/chemCPA-datasets/lincs_full.h5ad.gz", 96 | "datasets/lincs_full.h5ad", True, False), 97 | ("https://dl.fbaipublicfiles.com/dlp/cpa_binaries.tar", 98 | "binaries/cpa_binaries.tar", False, True), 99 | ] 100 | 101 | for url, relative_path, is_gzip, is_tar in files_to_download: 102 | # Create the full path 103 | full_path = os.path.join(project_folder, relative_path) 104 | 105 | # Ensure the directory exists 106 | os.makedirs(os.path.dirname(full_path), exist_ok=True) 107 | 108 | # If force_redownload is True, remove the existing file 109 | if force_redownload and os.path.exists(full_path): 110 | print(f"Removing existing file: {full_path}") 111 | os.remove(full_path) 112 | 113 | # Download and extract based on file type 114 | if is_gzip: 115 | download_and_extract_gzip(url, full_path) 116 | elif is_tar: 117 | # Download tar file if it doesn't exist 118 | download_file(url, full_path, use_gdown=False) 119 | # Extract if the tar file exists 120 | if os.path.exists(full_path): 121 | extract_tar(full_path, os.path.dirname(full_path)) 122 | else: 123 | download_file(url, full_path) 124 | 125 | -------------------------------------------------------------------------------- /raw_data/download_utils.py: -------------------------------------------------------------------------------- 1 | # dataset_downloader.py 2 | 3 | import gdown 4 | import os 5 | import urllib.request 6 | import gzip 7 | import tarfile 8 | from tqdm import tqdm 9 | import shutil 10 | import requests 11 | 12 | 13 | def 
get_file_size(url, use_gdown=False): 14 | """Get file size in bytes.""" 15 | try: 16 | if use_gdown: 17 | # For Google Drive, we need to use a different approach 18 | response = requests.head(url, allow_redirects=True) 19 | size = None # Google Drive doesn't provide size in headers directly 20 | else: 21 | response = requests.head(url, allow_redirects=True) 22 | size = int(response.headers.get('content-length', 0)) 23 | return size 24 | except: 25 | return None 26 | 27 | 28 | def format_size(size_in_bytes): 29 | """Convert size in bytes to human readable format.""" 30 | if size_in_bytes is None: 31 | return "unknown size" 32 | 33 | for unit in ['B', 'KB', 'MB', 'GB']: 34 | if size_in_bytes < 1024: 35 | return f"{size_in_bytes:.1f} {unit}" 36 | size_in_bytes /= 1024 37 | return f"{size_in_bytes:.1f} TB" 38 | 39 | 40 | def confirm_download(url, path, use_gdown=False): 41 | """Ask user for confirmation before downloading.""" 42 | size = get_file_size(url, use_gdown) 43 | formatted_size = format_size(size) 44 | 45 | print(f"\n📦 Preparing to download:") 46 | print(f" • File: {os.path.basename(path)}") 47 | print(f" • Size: {formatted_size}") 48 | print(f" • Destination: {path}") 49 | 50 | while True: 51 | response = input("\n⚡ Continue with download? [y/n]: ").lower() 52 | if response in ['y', 'yes']: 53 | return True 54 | elif response in ['n', 'no']: 55 | print("❌ Download cancelled") 56 | return False 57 | print("Please answer 'y' or 'n'") 58 | 59 | 60 | def download_file(url, path, use_gdown=False): 61 | """Download a file from a URL to a specified path.""" 62 | if os.path.exists(path): 63 | print(f"✨ File already exists at {path}") 64 | return 65 | 66 | if not confirm_download(url, path, use_gdown): 67 | exit(0) 68 | 69 | print(f"🚀 Starting download...") 70 | 71 | if use_gdown: 72 | gdown.download(url, path, quiet=False) 73 | else: 74 | response = requests.get(url, stream=True) 75 | total_size = int(response.headers.get('content-length', 0)) 76 | 77 | with open(path, 'wb') as f: 78 | if total_size == 0: 79 | f.write(response.content) 80 | else: 81 | downloaded = 0 82 | for data in response.iter_content(chunk_size=4096): 83 | downloaded += len(data) 84 | f.write(data) 85 | done = int(50 * downloaded / total_size) 86 | downloaded_formatted = format_size(downloaded) 87 | total_formatted = format_size(total_size) 88 | print(f"\r💫 Progress: [{'=' * done}{' ' * (50-done)}] {downloaded_formatted}/{total_formatted}", end='') 89 | print("\n✅ Download complete!") 90 | 91 | 92 | def download_and_extract_gzip(url, output_path): 93 | """Download a gzip file and extract it.""" 94 | if os.path.exists(output_path): 95 | print(f"✨ File already exists at {output_path}") 96 | return 97 | 98 | # Download to temporary gzip file 99 | temp_gz_path = output_path + '.gz' 100 | download_file(url, temp_gz_path) 101 | 102 | print(f"📂 Extracting {temp_gz_path}...") 103 | with gzip.open(temp_gz_path, 'rb') as f_in: 104 | with open(output_path, 'wb') as f_out: 105 | shutil.copyfileobj(f_in, f_out) 106 | 107 | # Clean up 108 | os.remove(temp_gz_path) 109 | print("✨ Extraction complete!") 110 | 111 | 112 | def extract_tar(tar_path, output_dir): 113 | """Extract a tar file to the specified directory.""" 114 | print(f"📂 Extracting {tar_path}...") 115 | with tarfile.open(tar_path) as tar: 116 | tar.extractall(path=output_dir) 117 | print("✨ Extraction complete!") 118 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name="chemCPA", 5 | version="1.0.0", 6 | description="", 7 | url="http://github.com/theislab/chemCPA", 8 | author="See README.md", 9 | author_email="See paper", 10 | license="MIT", 11 | packages=["chemCPA"], 12 | ) 13 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch.testing 3 | 4 | from chemCPA.data.data import Dataset 5 | 6 | 7 | def test_dataset_idx_ohe(): 8 | kwargs = { 9 | "perturbation_key": "condition", 10 | "pert_category": "cov_drug_dose_name", 11 | "dose_key": "dose", 12 | "covariate_keys": "cell_type", 13 | "smiles_key": "SMILES", 14 | "split_key": "split", 15 | } 16 | d_idx = Dataset( 17 | fname="datasets/trapnell_cpa_subset.h5ad", 18 | **kwargs, 19 | use_drugs_idx=True, 20 | ) 21 | 22 | d_ohe = Dataset( 23 | fname="datasets/trapnell_cpa_subset.h5ad", 24 | **kwargs, 25 | use_drugs_idx=False, 26 | ) 27 | 28 | numpy.testing.assert_equal( 29 | d_ohe.encoder_drug.categories_[0], d_idx.drugs_names_unique_sorted 30 | ) 31 | 32 | for i in range(len(d_idx)): 33 | genes_idx, idx, dosage, cov_idx = d_idx[i] 34 | genes_ohe, drug, cov_ohe = d_ohe[i] 35 | torch.testing.assert_close(genes_idx, genes_ohe) 36 | # make sure the OHE and the index representation encode the same info 37 | torch.testing.assert_close(drug[idx], dosage) 38 | torch.testing.assert_close(cov_idx, cov_ohe) 39 | -------------------------------------------------------------------------------- /tests/test_dosers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch.nn 3 | import torch.nn.functional as F 4 | 5 | from chemCPA.model import ComPert, GeneralizedSigmoid 6 | 7 | 8 | @pytest.mark.parametrize("nonlin", ["sigm", "logsigm", None]) 9 | def test_sigm_ohe_idx(nonlin): 10 | # test to make sure the generalized sigmoid doser has the same outputs 11 | # with indices and OHE 12 | sigm = GeneralizedSigmoid(10, "cpu", nonlin=nonlin) 13 | 14 | beta = torch.nn.Parameter( 15 | torch.tensor( 16 | [[x / 10 for x in range(0, 10)]], dtype=torch.float32, device="cpu" 17 | ) 18 | ) 19 | assert sigm.beta.shape == beta.shape 20 | sigm.beta = beta 21 | bias = torch.nn.Parameter( 22 | torch.tensor([[x / 5 for x in range(-5, 5)]], dtype=torch.float32, device="cpu") 23 | ) 24 | assert sigm.bias.shape == bias.shape 25 | sigm.bias = bias 26 | 27 | dosages = torch.tensor([1.0, 1.0, 1.0, 1.0], device="cpu", dtype=torch.float32) 28 | x = torch.tensor([0, 2, 9, 2], dtype=torch.long) 29 | ohe = F.one_hot(x, num_classes=10) 30 | ohe_scaled = torch.einsum("a,ab->ab", [dosages, ohe]) 31 | ohe_s = sigm(ohe_scaled) 32 | idx_s = sigm(dosages, idx=x) 33 | assert ohe_s[0][0] == idx_s[0] 34 | assert ohe_s[1][2] == idx_s[1] 35 | assert ohe_s[2][9] == idx_s[2] 36 | assert ohe_s[3][2] == idx_s[3] 37 | 38 | 39 | @pytest.mark.parametrize("doser_type", ["logsigm", "sigm", "mlp", None]) 40 | def test_drug_embedding(doser_type): 41 | drug_emb = torch.nn.Embedding.from_pretrained( 42 | torch.tensor(list(range(10 * 10)), dtype=torch.float32, device="cpu").view( 43 | 10, 10 44 | ) 45 | ) 46 | model = ComPert( 47 | num_genes=50, 48 | num_drugs=10, 49 | num_covariates=[1], 50 | doser_type=doser_type, 51 | device="cpu", 52 | drug_embeddings=drug_emb, 53 | use_drugs_idx=False, 54 | ) 55 | idx = 
torch.tensor([0, 2, 9, 2], dtype=torch.long, device="cpu") 56 | dosages = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float32, device="cpu") 57 | ohe = F.one_hot(idx, num_classes=10).to(dtype=torch.float32, device="cpu") 58 | ohe_scaled = torch.einsum("a,ab->ab", [dosages, ohe]) 59 | 60 | emb_ohe = model.compute_drug_embeddings_(drugs=ohe_scaled) 61 | model.use_drugs_idx = True 62 | emb_idx = model.compute_drug_embeddings_(drugs_idx=idx, dosages=dosages) 63 | 64 | torch.testing.assert_close(emb_ohe, emb_idx) 65 | -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch.testing 3 | 4 | from chemCPA.data.data import Dataset 5 | from chemCPA.embedding import get_chemical_representation 6 | from chemCPA.model import ComPert 7 | 8 | 9 | def test_embedding_idx_roundtrip(): 10 | # test to make sure that the same drug embeddings are computed for all drugs 11 | # in trapnell_subset, independent of whether we use indices or one-hot-encodings 12 | kwargs = { 13 | "perturbation_key": "condition", 14 | "pert_category": "cov_drug_dose_name", 15 | "dose_key": "dose", 16 | "covariate_keys": "cell_type", 17 | "smiles_key": "SMILES", 18 | "split_key": "split", 19 | } 20 | 21 | # load the embedding of DSMO 22 | control_emb = torch.tensor( 23 | pd.read_parquet("embeddings/grover/data/embeddings/grover_base.parquet") 24 | .loc["CS(C)=O"] 25 | .values 26 | ) 27 | 28 | for use_drugs_idx in [True, False]: 29 | dataset = Dataset( 30 | fname="datasets/trapnell_cpa_subset.h5ad", 31 | **kwargs, 32 | use_drugs_idx=use_drugs_idx 33 | ) 34 | embedding = get_chemical_representation( 35 | data_dir="embeddings/", 36 | smiles=dataset.canon_smiles_unique_sorted, 37 | embedding_model="grover_base", 38 | ) 39 | device = embedding.weight.device 40 | 41 | # make sure "control" is correctly encoded as the all zero vector 42 | control = torch.tensor( 43 | list(dataset.drugs_names_unique_sorted).index("control"), 44 | device=device, 45 | ) 46 | torch.testing.assert_close(embedding(control), control_emb.to(device)) 47 | 48 | model = ComPert( 49 | dataset.num_genes, 50 | dataset.num_drugs, 51 | dataset.num_covariates, 52 | device=device, 53 | doser_type="sigm", 54 | drug_embeddings=embedding, 55 | use_drugs_idx=use_drugs_idx, 56 | ) 57 | if use_drugs_idx: 58 | genes, idx, dosages, covariates = dataset[:] 59 | idx_emb = model.compute_drug_embeddings_(drugs_idx=idx, dosages=dosages) 60 | else: 61 | genes, drugs, covariates = dataset[:] 62 | ohe_emb = model.compute_drug_embeddings_(drugs=drugs) 63 | 64 | # assert both model return the same embedding for the drugs in the dataset 65 | torch.testing.assert_close(idx_emb, ohe_emb) 66 | --------------------------------------------------------------------------------
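A closing note on the tests above: `test_dosers.py` and `test_embedding.py` both check that the index-based drug representation (`use_drugs_idx=True`) and the one-hot representation yield identical results. The following standalone sketch (not part of the test suite) illustrates the underlying equivalence: a pair of drug index and dosage carries the same information as the dosage-scaled one-hot vector built with the same `einsum` pattern used in the tests.

```python
import torch
import torch.nn.functional as F

num_drugs = 10
idx = torch.tensor([0, 2, 9, 2], dtype=torch.long)   # drug index per observation
dosages = torch.tensor([0.1, 0.2, 0.3, 0.4])         # matching dosages

# Build the dosage-scaled one-hot matrix: row i holds dosages[i] at column idx[i].
ohe = F.one_hot(idx, num_classes=num_drugs).float()
ohe_scaled = torch.einsum("a,ab->ab", dosages, ohe)

# The dosage can be recovered from the scaled one-hot representation,
# so both encodings describe the same (drug, dose) assignment.
assert torch.allclose(ohe_scaled[torch.arange(len(idx)), idx], dosages)
```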