├── src ├── data │ ├── __init__.py │ ├── protac_dataloader.py │ └── protac_dataset.py ├── models │ ├── __init__.py │ ├── xgboost │ │ └── __init__.py │ ├── poi_encoder │ │ ├── __init__.py │ │ └── poi_count_vectorizer.py │ ├── cell_type_encoder │ │ ├── __init__.py │ │ └── cell_type_ordinal_encoder.py │ ├── e3_ligase_encoder │ │ ├── __init__.py │ │ └── e3_ordinal_encoder.py │ ├── smiles_encoder │ │ ├── gnn │ │ │ ├── __init__.py │ │ │ └── torch_geom_architectures.py │ │ ├── mlp │ │ │ ├── __init__.py │ │ │ └── rdkit_fp_model.py │ │ └── transformer │ │ │ ├── __init__.py │ │ │ └── pretrained_transformer.py │ └── wrapper_model.py ├── hyperparameter_tuning │ ├── __init__.py │ ├── cli.py │ └── optuna_utils.py ├── utils │ ├── __init__.py │ └── fingerprints.py ├── README.md └── tune.py ├── LICENSE.md ├── __init__.py ├── environment.yml ├── models ├── poi_encoder.joblib ├── cell_type_encoder.joblib └── e3_ligase_encoder.joblib ├── AUTHORS.md ├── CONTRIBUTING.md ├── main.py ├── config_optuna.yml ├── .gitignore ├── config.yaml ├── config_default.yml ├── README.md └── notebooks ├── extra_features_encoders.ipynb └── complex_encoding.ipynb /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/xgboost/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/hyperparameter_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/poi_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/cell_type_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/e3_ligase_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/smiles_encoder/gnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/smiles_encoder/mlp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/smiles_encoder/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2023 Chalmers University of Technology 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 
sys.path.append(os.getcwd()) -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AstraZeneca/Machine-Learning-for-Predicting-Targeted-Protein-Degradation/HEAD/environment.yml -------------------------------------------------------------------------------- /models/poi_encoder.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AstraZeneca/Machine-Learning-for-Predicting-Targeted-Protein-Degradation/HEAD/models/poi_encoder.joblib -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | # Authors 2 | 3 | * Stefano Ribes 4 | * Eva Nittinger 5 | * Rocío Mercado 6 | * Christian Tyrchan 7 | 8 | # Maintainer 9 | 10 | * Stefano Ribes 11 | -------------------------------------------------------------------------------- /models/cell_type_encoder.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AstraZeneca/Machine-Learning-for-Predicting-Targeted-Protein-Degradation/HEAD/models/cell_type_encoder.joblib -------------------------------------------------------------------------------- /models/e3_ligase_encoder.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AstraZeneca/Machine-Learning-for-Predicting-Targeted-Protein-Degradation/HEAD/models/e3_ligase_encoder.joblib -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for your interest in contributing to this project! Please note that we are not accepting contributions to this repository. Further development will be done based on the code under the AILab-Bio GitHub page, which can be found at https://github.com/ailab-bio. 4 | 5 | If you have any questions or concerns, please feel free to contact us at: ribes dot stefano at gmail dot com. 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from pytorch_lightning.cli import LightningCLI
4 | 
5 | # Import the model and data classes so that the CLI can resolve them by name.
6 | from data.protac_dataloader import PROTACDataModule
7 | from models.wrapper_model import WrapperModel, ProtacModel  # noqa: F401
8 | from models.smiles_encoder.mlp.rdkit_fp_model import RDKitFingerprintEncoder  # noqa: F401
9 | 
10 | 
11 | def cli_main():
12 |     cli = LightningCLI(
13 |         ProtacModel,
14 |         PROTACDataModule,
15 |         # subclass_mode_model=True,
16 |         seed_everything_default=42,
17 |         parser_kwargs={'parser_mode': 'omegaconf'},
18 |         save_config_kwargs={'overwrite': True},
19 |     )
20 | 
21 | 
22 | if __name__ == '__main__':
23 |     cli_main()
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Tips and Notes
2 | 
3 | ## TODOs
4 | 
5 | * The directory tree might be simplified, as it is intended to be extended in the future.
6 | 
7 | ## CLI
8 | 
9 | * For any sub-module to be used in the CLI, it must have a docstring!
10 | * To disable all logging when validating and testing, append the following to the command: `--trainer.logger false --trainer.enable_checkpointing false`
11 | * The official Optuna pruning callback for PyTorch Lightning has some issues. If one desires to add it, use the following as `class_path` instead: `hyperparameter_tuning.optuna_utils.CustomPyTorchLightningPruningCallback`
12 | 
13 | > Parsers make a best effort to determine the correct names and types that the parser should accept. However, there can be cases not yet supported or cases for which it would be impossible to support. To somewhat overcome these limitations, there is a special key `dict_kwargs` that can be used to provide arguments that will not be validated during parsing, but will be used for class instantiation.
14 | > Multiple config files can be provided, and they will be parsed sequentially.
15 | > `$ python main.py fit --trainer trainer.yaml --model model.yaml --data data.yaml [...]` -------------------------------------------------------------------------------- /src/tune.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Dict, Set, Optional, Union, Type, List, Tuple, Callable 3 | 4 | from data.protac_dataloader import PROTACDataModule 5 | from models.wrapper_model import WrapperModel, ProtacModel # noqa: F401 6 | 7 | from hyperparameter_tuning.cli import TuneLightningCLI 8 | from hyperparameter_tuning.optuna_utils import objective 9 | 10 | from pytorch_lightning.cli import ArgsType 11 | 12 | def cli_main(): 13 | """Main function for hyperparameter tuning.""" 14 | cli = TuneLightningCLI(ProtacModel, 15 | PROTACDataModule, 16 | seed_everything_default=42, 17 | parser_kwargs={'parser_mode': 'omegaconf'}, 18 | run=False) 19 | # Create study and launch optimization 20 | study = cli.create_study() 21 | optimize_args = cli.config.optuna.get('optimize', {}) 22 | study.optimize(lambda trial: objective(trial, cli), **optimize_args) 23 | # Save best config to file 24 | best_config = study.best_trial.user_attrs['config'] 25 | print(f'Best config: {best_config}') 26 | filename = best_config.get('best-config-filename', 'best_config.yaml') 27 | cli.parser.save(best_config, filename, format='yaml', overwrite=True) 28 | 29 | 30 | if __name__ == '__main__': 31 | cli_main() -------------------------------------------------------------------------------- /config_optuna.yml: -------------------------------------------------------------------------------- 1 | trainer: 2 | callbacks: 3 | - class_path: EarlyStopping 4 | init_args: 5 | monitor: val_loss 6 | mode: min 7 | patience: 5 8 | check_finite: true 9 | - class_path: EarlyStopping 10 | init_args: 11 | monitor: val_acc 12 | mode: max 13 | patience: 5 14 | check_finite: true 15 | enable_checkpointing: false 16 | max_epochs: 15 17 | optuna: 18 | optimize: 19 | n_trials: 10 20 | metric: val_acc 21 | study: 22 | direction: maximize 23 | pruner: 24 | class_path: optuna.pruners.HyperbandPruner 25 | init_args: 26 | min_resource: 2 27 | max_resource: ${trainer.max_epochs} # Should be equal to the number of training epochs 28 | reduction_factor: 3 29 | sampler: 30 | class_path: optuna.samplers.TPESampler 31 | init_args: 32 | seed: 42 33 | hparams: 34 | # model.smiles_encoder.init_args.dropout: 35 | # function: suggest_float 36 | # kwargs: 37 | # low: 0.1 38 | # high: 0.4 39 | # model.smiles_encoder.init_args.hidden_channels: 40 | # function: suggest_categorical 41 | # kwargs: 42 | # choices: 43 | # - [1024, 512, 256] 44 | # - [512, 256, 256] 45 | # - [128, 128, 128, 256] 46 | # - [256, 256, 256] 47 | # model.head.init_args.hidden_channels: 48 | # function: suggest_categorical 49 | # kwargs: 50 | # choices: 51 | # - [1] 52 | # - [64, 32, 1] 53 | # - [128, 64, 32, 1] 54 | # - [512, 128, 1] 55 | data.batch_size: 56 | function: suggest_int 57 | kwargs: 58 | low: 32 59 | high: 256 60 | step: 32 61 | data.protac_dataset_args.morgan_atomic_radius: 62 | function: suggest_int 63 | kwargs: 64 | low: 3 65 | high: 8 66 | optimizer.init_args.lr: 67 | function: suggest_float 68 | kwargs: 69 | low: 1e-5 70 | high: 1e-3 71 | log: true -------------------------------------------------------------------------------- /src/models/smiles_encoder/transformer/pretrained_transformer.py: -------------------------------------------------------------------------------- 1 | import 
torch
2 | from transformers import AutoConfig, AutoModelForSequenceClassification
3 | import pytorch_lightning as pl
4 | 
5 | def mean_pooling(model_output, attention_mask):
6 |     # Token embeddings of all input tokens, from the last hidden state
7 |     token_embeddings = model_output['last_hidden_state']
8 |     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
9 |     sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
10 |     sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
11 |     return sum_embeddings / sum_mask
12 | 
13 | 
14 | class TransformerSubModel(pl.LightningModule):
15 | 
16 |     def __init__(self, checkpoint_path: str = 'seyonec/ChemBERTa-zinc-base-v1'):
17 |         super().__init__()
18 |         # Save the arguments passed to init
19 |         self.save_hyperparameters()
20 |         self.__dict__.update(locals())  # Add arguments as attributes
21 |         # ChemBERTa encoder for SMILES
22 |         self.config = AutoConfig.from_pretrained(checkpoint_path,
23 |                                                  output_hidden_states=True,
24 |                                                  num_labels=1)
25 |         self.chembert = AutoModelForSequenceClassification.from_pretrained(
26 |             checkpoint_path,
27 |             config=self.config
28 |         ).roberta
29 | 
30 |     def forward(self, x_in):
31 |         # Run ChemBERTa over the tokenized SMILES
32 |         input_ids = x_in['smiles_tokenized']['input_ids'].squeeze(dim=1)
33 |         attention_mask = x_in['smiles_tokenized']['attention_mask'].squeeze(dim=1)
34 |         smiles_embedding = self.chembert(input_ids, attention_mask)
35 |         # NOTE: Due to multi-head attention, the output of the Transformer is a
36 |         # sequence of hidden states, one for each input token. The following
37 |         # takes the mean of all token embeddings to get a single embedding.
38 |         smiles_embedding = mean_pooling(smiles_embedding, attention_mask)
39 |         return smiles_embedding
40 | 
41 |     def get_embedding_size(self):
42 |         return self.config.to_dict()['hidden_size']
--------------------------------------------------------------------------------
/src/utils/fingerprints.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import Chem
3 | from rdkit.Chem import AllChem, DataStructs, MACCSkeys
4 | 
5 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type
6 | 
7 | def get_fingerprint(smiles: str,
8 |                     n_bits: int = 1024,
9 |                     fp_type: Literal['morgan', 'maccs', 'path'] = 'morgan',
10 |                     min_path: int = 1,
11 |                     max_path: int = 2,
12 |                     atomic_radius: int = 2) -> np.ndarray:
13 |     """Returns the molecular fingerprint of a given molecule SMILES.
14 | 
15 |     Args:
16 |         smiles (str): SMILES string to convert.
17 |         n_bits (int, optional): Number of bits of the generated fingerprint. Defaults to 1024.
18 |         fp_type (Literal['morgan', 'maccs', 'path'], optional): Fingerprint type to generate. Defaults to 'morgan'.
19 |         min_path (int, optional): Minimum path length for path-based fingerprints. Defaults to 1.
20 |         max_path (int, optional): Maximum path length for path-based fingerprints. Defaults to 2.
21 |         atomic_radius (int, optional): Atomic radius for Morgan fingerprints. Defaults to 2.
22 | 
23 |     Raises:
24 |         ValueError: When wrong fingerprint type is requested.
25 | 
26 |     Returns:
27 |         np.ndarray: The generated fingerprint.
28 | """ 29 | mol = Chem.MolFromSmiles(smiles) 30 | if fp_type == 'morgan': 31 | fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, atomic_radius, 32 | nBits=n_bits) 33 | elif fp_type == 'maccs': 34 | fingerprint = MACCSkeys.GenMACCSKeys(mol) 35 | elif fp_type == 'path': 36 | fingerprint = Chem.rdmolops.RDKFingerprint(mol, fpSize=n_bits, 37 | minPath=min_path, 38 | maxPath=max_path) 39 | else: 40 | raise ValueError(f'Wrong type of fingerprint requested. Received "{fp_type}", expected one in: [morgan|maccs|path]') 41 | array = np.zeros((0,), dtype=np.int8) 42 | DataStructs.ConvertToNumpyArray(fingerprint, array) 43 | return array 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | checkpoints/ 132 | .vscode 133 | 134 | # TODO: The figures directory should be included later on... 
135 | figures
136 | 
137 | data/protac/
138 | data/*.pt
139 | data/**/*.jpg
140 | data/figures
141 | data/figures/*.png
142 | data/figures/*.pdf
143 | data/figures/*.gif
144 | data/figures/*.html
145 | lightning_logs
146 | **/*.csv
147 | **/*.backup
148 | **/*.pkl
149 | **/*.sdf
150 | logs/
151 | 
152 | **/**/SSL_*
--------------------------------------------------------------------------------
/src/models/e3_ligase_encoder/e3_ordinal_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type
3 | 
4 | from sklearn.preprocessing import OrdinalEncoder
5 | 
6 | import pytorch_lightning as pl
7 | from pytorch_lightning import LightningModule
8 | 
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | 
13 | import joblib
14 | import os
15 | 
16 | 
17 | class E3LigaseEncoder(pl.LightningModule):
18 | 
19 |     def __init__(self,
20 |                  normalize_output: bool = False,
21 |                  e3_ligase_encoder_filepath: str | None = 'e3_ligase_encoder.joblib',
22 |                  run_ordinal_encoder: bool = False,
23 |                  e3_ligase_train_data: np.ndarray | None = None,
24 |                  use_linear_layer: bool = False,
25 |                  embedding_size: int = 1,
26 |                  ):
27 |         """Encode the E3 ligase as an embedding vector.
28 | 
29 |         Args:
30 |             normalize_output (bool, optional): Whether to normalize the output vector. Defaults to False.
31 |             e3_ligase_encoder_filepath (str | None, optional): Path to get the ordinal encoder from. Defaults to 'e3_ligase_encoder.joblib'.
32 |             run_ordinal_encoder (bool, optional): Whether to run the ordinal encoder. If False, the encoder will handle the raw string class. Defaults to False.
33 |             e3_ligase_train_data (np.ndarray | None, optional): Training data to fit the ordinal encoder. Defaults to None.
34 |             use_linear_layer (bool, optional): Whether to use a linear layer to encode the E3 ligase. If False, the embedding size will be overwritten to 1. Defaults to False.
35 |             embedding_size (int, optional): Embedding size of the linear layer. Defaults to 1.
36 | 
37 |         Raises:
38 |             ValueError: If `e3_ligase_train_data` is not passed and `e3_ligase_encoder_filepath` does not exist.
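39 | 
40 |         Example:
41 |             With a linear layer, the embedding size follows `embedding_size`:
42 | 
43 |             >>> enc = E3LigaseEncoder(use_linear_layer=True, embedding_size=64)
44 |             >>> enc.get_embedding_size()
45 |             64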
46 |         """
47 |         super().__init__()
48 |         # Set our init args as class attributes
49 |         self.__dict__.update(locals())  # Add arguments as attributes
50 |         self.save_hyperparameters(ignore='e3_ligase_train_data')
51 |         # Load the pre-trained ordinal encoder if it exists, otherwise train it
52 |         if run_ordinal_encoder:
53 |             if os.path.exists(e3_ligase_encoder_filepath):
54 |                 self.e3_ligase_encoder = joblib.load(e3_ligase_encoder_filepath)
55 |             else:
56 |                 self.e3_ligase_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
57 |                                                         unknown_value=-1,
58 |                                                         encoded_missing_value=-1)
59 |                 if e3_ligase_train_data is not None:
60 |                     self.fit_ordinal_encoder(e3_ligase_train_data)
61 |                 else:
62 |                     raise ValueError('e3_ligase_train_data must be passed if e3_ligase_encoder_filepath does not exist')
63 |         if use_linear_layer:
64 |             self.lin_layer = nn.Linear(1, self.embedding_size)
65 |         else:
66 |             self.embedding_size = 1
67 | 
68 |     def forward(self, x_in):
69 |         if self.run_ordinal_encoder:
70 |             e3_emb = self.e3_ligase_encoder.transform(x_in['e3_ligase'].numpy())
71 |             if self.normalize_output:
72 |                 # `categories_` holds one array per feature: normalize by the
73 |                 # number of known E3 ligase categories
74 |                 e3_emb /= len(self.e3_ligase_encoder.categories_[0])
75 |             e3_emb = torch.tensor(e3_emb, dtype=torch.float32)
76 |         else:
77 |             e3_emb = x_in['e3_ligase']
78 |         if self.use_linear_layer:
79 |             return self.lin_layer(e3_emb)
80 |         else:
81 |             return e3_emb
82 | 
83 |     def get_embedding_size(self):
84 |         return self.embedding_size
85 | 
86 |     def fit_ordinal_encoder(self, e3_ligase_train_data: np.ndarray):
87 |         self.e3_ligase_encoder.fit(e3_ligase_train_data.reshape(-1, 1))
88 |         joblib.dump(self.e3_ligase_encoder, self.e3_ligase_encoder_filepath)
--------------------------------------------------------------------------------
/src/models/cell_type_encoder/cell_type_ordinal_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type
3 | 
4 | from sklearn.preprocessing import OrdinalEncoder
5 | 
6 | import pytorch_lightning as pl
7 | from pytorch_lightning import LightningModule
8 | 
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | 
13 | import joblib
14 | import os
15 | 
16 | 
17 | class CellTypeEncoder(pl.LightningModule):
18 | 
19 |     def __init__(self,
20 |                  normalize_output: bool = False,
21 |                  cell_type_encoder_filepath: str | None = 'cell_type_encoder.joblib',
22 |                  run_ordinal_encoder: bool = False,
23 |                  cell_type_train_data: np.ndarray | None = None,
24 |                  use_linear_layer: bool = False,
25 |                  embedding_size: int = 1,
26 |                  ):
27 |         """Encode the cell type as an embedding vector.
28 | 
29 |         Args:
30 |             normalize_output (bool, optional): Whether to normalize the output vector. Defaults to False.
31 |             cell_type_encoder_filepath (str | None, optional): Path to get the ordinal encoder from. Defaults to 'cell_type_encoder.joblib'.
32 |             run_ordinal_encoder (bool, optional): Whether to run the ordinal encoder. If False, the encoder will handle the raw string class. Defaults to False.
33 |             cell_type_train_data (np.ndarray | None, optional): Training data to fit the ordinal encoder. Defaults to None.
34 |             use_linear_layer (bool, optional): Whether to use a linear layer to encode the cell type. If False, the embedding size will be overwritten to 1. Defaults to False.
35 |             embedding_size (int, optional): Embedding size of the linear layer. Defaults to 1.
36 | 
37 |         Raises:
38 |             ValueError: If `cell_type_train_data` is not passed and `cell_type_encoder_filepath` does not exist.
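39 | 
40 |         Example:
41 |             Without a linear layer, the embedding size is overwritten to 1:
42 | 
43 |             >>> enc = CellTypeEncoder(use_linear_layer=False, embedding_size=32)
44 |             >>> enc.get_embedding_size()
45 |             1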
46 |         """
47 |         super().__init__()
48 |         # Set our init args as class attributes
49 |         self.__dict__.update(locals())  # Add arguments as attributes
50 |         self.save_hyperparameters(ignore='cell_type_train_data')
51 |         # Load the pre-trained ordinal encoder if it exists, otherwise train it
52 |         if run_ordinal_encoder:
53 |             if os.path.exists(cell_type_encoder_filepath):
54 |                 self.cell_type_encoder = joblib.load(cell_type_encoder_filepath)
55 |             else:
56 |                 self.cell_type_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
57 |                                                         unknown_value=-1,
58 |                                                         encoded_missing_value=-1)
59 |                 if cell_type_train_data is not None:
60 |                     self.fit_ordinal_encoder(cell_type_train_data)
61 |                 else:
62 |                     raise ValueError('cell_type_train_data must be passed if cell_type_encoder_filepath does not exist')
63 |         if use_linear_layer:
64 |             self.lin_layer = nn.Linear(1, self.embedding_size)
65 |         else:
66 |             self.embedding_size = 1
67 | 
68 |     def forward(self, x_in):
69 |         if self.run_ordinal_encoder:
70 |             cell_emb = self.cell_type_encoder.transform(x_in['cell_type'].numpy())
71 |             if self.normalize_output:
72 |                 # `categories_` holds one array per feature: normalize by the
73 |                 # number of known cell type categories
74 |                 cell_emb /= len(self.cell_type_encoder.categories_[0])
75 |             cell_emb = torch.tensor(cell_emb, dtype=torch.float32)
76 |         else:
77 |             cell_emb = x_in['cell_type']
78 |         if self.use_linear_layer:
79 |             return self.lin_layer(cell_emb)
80 |         else:
81 |             return cell_emb
82 | 
83 |     def get_embedding_size(self):
84 |         return self.embedding_size
85 | 
86 |     def fit_ordinal_encoder(self, cell_type_train_data: np.ndarray):
87 |         self.cell_type_encoder.fit(cell_type_train_data.reshape(-1, 1))
88 |         joblib.dump(self.cell_type_encoder, self.cell_type_encoder_filepath)
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==2.0.2
2 | seed_everything: 42
3 | trainer:
4 |   accelerator: gpu
5 |   strategy: auto
6 |   devices: auto
7 |   num_nodes: 1
8 |   precision: '16-mixed'
9 |   logger:
10 |     # class_path: CSVLogger
11 |     # init_args:
12 |     #   save_dir: ./logs
13 |     #   name: log
14 |     #   version: null
15 |     #   prefix: null
16 |     class_path: TensorBoardLogger
17 |     init_args:
18 |       save_dir: ./logs
19 |       name: log
20 |       version: null
21 |       prefix: null
22 |   callbacks:
23 |     - class_path: ModelCheckpoint
24 |       init_args:
25 |         monitor: val_acc
26 |         mode: max
27 |         save_top_k: 1
28 |         save_weights_only: false
29 |         dirpath: null
30 |         filename: null
31 |     - class_path: EarlyStopping
32 |       init_args:
33 |         monitor: val_loss
34 |         mode: min
35 |         patience: 5
36 |         check_finite: true
37 |     - class_path: EarlyStopping
38 |       init_args:
39 |         monitor: val_acc
40 |         mode: max
41 |         patience: 5
42 |         check_finite: true
43 |   enable_checkpointing: null
44 |   max_epochs: 15
45 |   fast_dev_run: false
46 |   min_epochs: null
47 |   max_steps: -1
48 |   min_steps: null
49 |   max_time: null
50 |   limit_train_batches: null
51 |   limit_val_batches: null
52 |   limit_test_batches: null
53 |   limit_predict_batches: null
54 |   overfit_batches: 0.0
55 |   val_check_interval: null
56 |   check_val_every_n_epoch: 1
57 |   num_sanity_val_steps: null
58 |   log_every_n_steps: 8
59 |   enable_progress_bar: null
60 |   enable_model_summary: null
61 |   accumulate_grad_batches: 1
62 |   gradient_clip_val: 1.0
63 |   gradient_clip_algorithm: norm
64 |   deterministic: null
65 |   benchmark: null
66 |   inference_mode: true
67 |   use_distributed_sampler: true
68 |   profiler: null
69 |   detect_anomaly: false
70 | 
barebones: false 71 | plugins: null 72 | sync_batchnorm: false 73 | reload_dataloaders_every_n_epochs: 0 74 | default_root_dir: null 75 | model: 76 | smiles_encoder: 77 | class_path: src.models.smiles_encoder.mlp.rdkit_fp_model.RDKitFingerprintEncoder 78 | init_args: 79 | dropout: 0.5 80 | # NOTE: The embedding size will be the last element of the list below. 81 | hidden_channels: [1024, 512, 256] 82 | poi_seq_encoder: 83 | class_path: src.models.poi_encoder.poi_count_vectorizer.POISequenceEncoder 84 | init_args: 85 | input_size: 403 86 | use_linear_layer: false 87 | e3_ligase_encoder: 88 | class_path: src.models.e3_ligase_encoder.e3_ordinal_encoder.E3LigaseEncoder 89 | init_args: 90 | embedding_size: 1 91 | use_linear_layer: false 92 | cell_type_encoder: 93 | class_path: src.models.cell_type_encoder.cell_type_ordinal_encoder.CellTypeEncoder 94 | init_args: 95 | embedding_size: 1 96 | use_linear_layer: false 97 | head: 98 | class_path: torchvision.ops.MLP 99 | init_args: 100 | # NOTE: The argument `in_channels` must be the sum of all the above 101 | # embedding sizes! 102 | in_channels: 661 103 | hidden_channels: [512, 256, 1] 104 | norm_layer: torch.nn.BatchNorm1d 105 | inplace: false 106 | dropout: 0.4 107 | # task: predict_active_inactive 108 | # freeze_smiles_encoder: false 109 | # learning_rate: 0.001 110 | optimizer: 111 | class_path: torch.optim.AdamW 112 | init_args: 113 | lr: 0.001 114 | data: 115 | train_df_path: ./data/train/train_bin_upsampled.csv 116 | val_df_path: ./data/val/val_bin.csv 117 | test_df_path: ./data/test/test_bin.csv 118 | predict_df_path: ./data/test/test_bin.csv 119 | batch_size: 256 120 | protac_dataset_args: 121 | precompute_fingerprints: false 122 | use_morgan_fp: true 123 | morgan_bits: 1024 124 | # smiles_tokenizer: 'seyonec/ChemBERTa-zinc-base-v1' 125 | poi_seq_enc: models/poi_encoder.joblib 126 | e3_ligase_enc: models/e3_ligase_encoder.joblib 127 | cell_type_enc: models/cell_type_encoder.joblib 128 | ckpt_path: null 129 | -------------------------------------------------------------------------------- /src/models/poi_encoder/poi_count_vectorizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type 3 | 4 | from sklearn.preprocessing import OrdinalEncoder 5 | from sklearn.feature_extraction.text import CountVectorizer 6 | 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import LightningModule 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | 14 | import joblib 15 | import os 16 | 17 | 18 | class POISequenceEncoder(pl.LightningModule): 19 | 20 | def __init__(self, 21 | ngram_min_range: int = 2, 22 | ngram_max_range: int = 2, 23 | normalize_output: bool = False, 24 | poi_seq_encoder_filepath: str | None = 'poi_seq_encoder.joblib', 25 | run_count_vectorizer: bool = False, 26 | poi_seq_train_data: np.ndarray | None = None, 27 | input_size: int = 403, 28 | use_linear_layer: bool = False, 29 | embedding_size: int = 32, 30 | ): 31 | """Encode POI sequence as an embedding vector. 32 | 33 | Args: 34 | ngram_min_range (int, optional): Minimum ngram range for the count vectorizer. Defaults to 2. 35 | ngram_max_range (int, optional): Maximum ngram range for the count vectorizer. Defaults to 2. 36 | normalize_output (bool, optional): Whether to normalize the output vector. Defaults to False. 
37 |             poi_seq_encoder_filepath (str | None, optional): Path to get the count vectorizer from. Defaults to 'poi_seq_encoder.joblib'.
38 |             run_count_vectorizer (bool, optional): Whether to run the count vectorizer. If False, the encoder will handle the raw string class. Defaults to False.
39 |             poi_seq_train_data (np.ndarray | None, optional): Training data to fit the count vectorizer. Defaults to None.
40 |             input_size (int, optional): Input size of the linear layer. Defaults to 403.
41 |             use_linear_layer (bool, optional): Whether to use a linear layer. Defaults to False.
42 |             embedding_size (int, optional): Embedding size of the linear layer. Defaults to 32.
43 |         """
44 |         super().__init__()
45 |         # Set our init args as class attributes
46 |         self.__dict__.update(locals())  # Add arguments as attributes
47 |         self.save_hyperparameters(ignore='poi_seq_train_data')
48 |         # Load the pre-trained count vectorizer if it exists, otherwise train it
49 |         if os.path.exists(poi_seq_encoder_filepath):
50 |             self.poi_seq_encoder = joblib.load(poi_seq_encoder_filepath)
51 |             input_size = self.poi_seq_encoder.get_feature_names_out().shape[-1]
52 |         elif run_count_vectorizer:
53 |             ngram_range = (ngram_min_range, ngram_max_range)
54 |             self.poi_seq_encoder = CountVectorizer(analyzer='char',
55 |                                                    ngram_range=ngram_range)
56 |             if poi_seq_train_data is not None:
57 |                 self.fit_count_vectorizer(poi_seq_train_data)
58 |             else:
59 |                 raise ValueError('poi_seq_train_data must be passed if poi_seq_encoder_filepath does not exist')
60 |             input_size = self.poi_seq_encoder.get_feature_names_out().shape[-1]
61 |         if use_linear_layer:
62 |             self.lin_layer = nn.Linear(input_size, self.embedding_size)
63 |         else:
64 |             self.embedding_size = input_size
65 | 
66 |     def forward(self, x_in):
67 |         if self.run_count_vectorizer:
68 |             # CountVectorizer.transform returns a sparse matrix: densify it
69 |             # before converting to a tensor
70 |             poi_emb = self.poi_seq_encoder.transform(x_in['poi_seq'].tolist()).toarray()
71 |             if self.normalize_output:
72 |                 # CountVectorizer has no `categories_` attribute: scale by the
73 |                 # vocabulary size instead
74 |                 poi_emb = poi_emb / len(self.poi_seq_encoder.vocabulary_)
75 |             poi_emb = torch.tensor(poi_emb, dtype=torch.float32)
76 |         else:
77 |             poi_emb = x_in['poi_seq']
78 |         if self.use_linear_layer:
79 |             return self.lin_layer(poi_emb)
80 |         else:
81 |             return poi_emb
82 | 
83 |     def get_embedding_size(self):
84 |         return self.embedding_size
85 | 
86 |     def fit_count_vectorizer(self, poi_seq_train_data: List[str]):
87 |         self.poi_seq_encoder.fit(poi_seq_train_data)
88 |         joblib.dump(self.poi_seq_encoder, self.poi_seq_encoder_filepath)
--------------------------------------------------------------------------------
/config_default.yml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==2.0.2
2 | seed_everything: 42
3 | trainer:
4 |   accelerator: gpu
5 |   strategy: auto
6 |   devices: auto
7 |   num_nodes: 1
8 |   precision: '16-mixed'
9 |   logger:
10 |     # - class_path: CSVLogger
11 |     #   init_args:
12 |     #     save_dir: ./logs
13 |     #     name: log
14 |     #     version: null
15 |     #     prefix: null
16 |     - class_path: TensorBoardLogger
17 |       init_args:
18 |         save_dir: ./logs
19 |         name: log
20 |         version: null
21 |         prefix: null
22 |   callbacks:
23 |     # - class_path: ModelCheckpoint
24 |     #   init_args:
25 |     #     monitor: val_acc
26 |     #     mode: max
27 |     #     save_top_k: 1
28 |     #     save_weights_only: false
29 |     #     dirpath: null
30 |     #     filename: null
31 |     # - class_path: EarlyStopping
32 |     #   init_args:
33 |     #     monitor: val_acc
34 |     #     mode: max
35 |     #     patience: 10
36 |     #     check_finite: true
37 |     - class_path: EarlyStopping
38 |       init_args:
39 |         monitor: val_loss
40 |         mode: min
41 |         patience: 10
42 |         check_finite: true
43 |   enable_checkpointing: null
44 | 
fast_dev_run: false 45 | max_epochs: 20 46 | min_epochs: null 47 | max_steps: -1 48 | min_steps: null 49 | max_time: null 50 | limit_train_batches: null 51 | limit_val_batches: null 52 | limit_test_batches: null 53 | limit_predict_batches: null 54 | overfit_batches: 0.0 55 | val_check_interval: null 56 | check_val_every_n_epoch: 1 57 | num_sanity_val_steps: null 58 | log_every_n_steps: 8 59 | enable_progress_bar: null 60 | enable_model_summary: null 61 | accumulate_grad_batches: 1 62 | gradient_clip_val: 1.0 63 | gradient_clip_algorithm: norm 64 | deterministic: null 65 | benchmark: null 66 | inference_mode: true 67 | use_distributed_sampler: true 68 | profiler: null 69 | detect_anomaly: false 70 | barebones: false 71 | plugins: null 72 | sync_batchnorm: false 73 | reload_dataloaders_every_n_epochs: 0 74 | default_root_dir: null 75 | model: 76 | # smiles_encoder: 77 | # class_path: models.smiles_encoder.mlp.rdkit_fp_model.RDKitFingerprintEncoder 78 | # init_args: 79 | # fp_bits: 4096 80 | # # NOTE: The embedding size will be the last element of the list below. 81 | # hidden_channels: [512, 128, 64] 82 | # dropout: 0.2 83 | smiles_encoder: 84 | class_path: models.smiles_encoder.gnn.torch_geom_architectures.GnnSubModel 85 | init_args: 86 | model_type: 'gat' 87 | hidden_channels: 128 88 | num_layers: 16 89 | out_channels: 64 90 | dropout: 0.1 91 | poi_seq_encoder: 92 | class_path: models.poi_encoder.poi_count_vectorizer.POISequenceEncoder 93 | init_args: 94 | input_size: 403 95 | embedding_size: 64 # 403 96 | use_linear_layer: true 97 | e3_ligase_encoder: 98 | class_path: models.e3_ligase_encoder.e3_ordinal_encoder.E3LigaseEncoder 99 | init_args: 100 | embedding_size: 64 # 1 101 | use_linear_layer: true 102 | cell_type_encoder: 103 | class_path: models.cell_type_encoder.cell_type_ordinal_encoder.CellTypeEncoder 104 | init_args: 105 | embedding_size: 64 # 1 106 | use_linear_layer: true 107 | head: 108 | class_path: torch.nn.Linear 109 | init_args: 110 | in_features: 64 111 | out_features: 1 112 | # class_path: torchvision.ops.MLP 113 | # init_args: 114 | # # NOTE: The argument `in_channels` must be the sum of all the above 115 | # # embedding sizes! 
116 | # in_channels: 256 # 661 # ${{ model.poi_seq_encoder.init_args.embedding_size + model.poi_seq_encoder.init_args.input_size }} 117 | # hidden_channels: [32, 1] 118 | # norm_layer: torch.nn.BatchNorm1d 119 | # inplace: false 120 | # dropout: 0.2 121 | join_branches: 'sum' 122 | # task: predict_active_inactive 123 | # freeze_smiles_encoder: false 124 | # learning_rate: 0.001 125 | optimizer: 126 | class_path: torch.optim.AdamW 127 | init_args: 128 | lr: 1e-4 129 | data: 130 | train_df_path: ./data/train/train_bin_upsampled.csv 131 | val_df_path: ./data/val/val_bin.csv 132 | test_df_path: ./data/test/test_bin.csv 133 | predict_df_path: ./data/test/test_bin.csv 134 | batch_size: 128 135 | protac_dataset_args: 136 | include_smiles_as_graphs: true 137 | 138 | # use_morgan_fp: true 139 | # morgan_bits: ${model.smiles_encoder.init_args.fp_bits} 140 | # morgan_atomic_radius: 6 141 | # use_maccs_fp: true 142 | 143 | # smiles_tokenizer: 'seyonec/ChemBERTa-zinc-base-v1' 144 | poi_seq_enc: models/poi_encoder.joblib 145 | e3_ligase_enc: models/e3_ligase_encoder.joblib 146 | cell_type_enc: models/cell_type_encoder.joblib 147 | # ckpt_path: null 148 | -------------------------------------------------------------------------------- /src/models/smiles_encoder/mlp/rdkit_fp_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning import LightningModule 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.utils.data import Dataset, DataLoader, random_split 11 | 12 | from torchvision.ops import MLP 13 | 14 | import torch_geometric 15 | import torch_geometric.nn as geom_nn 16 | import torch_geometric.data as geom_data 17 | from torch_geometric.utils.smiles import from_smiles 18 | 19 | from torchmetrics import (Accuracy, 20 | AUROC, 21 | ROC, 22 | Precision, 23 | Recall, 24 | F1Score, 25 | MeanAbsoluteError, 26 | MeanSquaredError) 27 | from torchmetrics.functional import (mean_absolute_error, 28 | mean_squared_error, 29 | mean_squared_log_error, 30 | pearson_corrcoef, 31 | r2_score) 32 | from torchmetrics.functional.classification import (binary_accuracy, 33 | binary_auroc, 34 | binary_precision, 35 | binary_recall, 36 | binary_f1_score) 37 | from torchmetrics import MetricCollection 38 | 39 | MACCS_BITWIDTH = 167 40 | 41 | 42 | class RDKitFingerprintEncoder(pl.LightningModule): 43 | 44 | def __init__(self, 45 | fp_type: Literal['morgan_fp', 'maccs_fp', 'path_fp'] = 'morgan_fp', 46 | fp_bits: int = 1024, 47 | hidden_channels: List[int] = [128, 128], 48 | norm_layer: Type[nn.Module] = nn.BatchNorm1d, 49 | dropout: float = 0.5): 50 | """SMILES encoder using RDKit fingerprints. 51 | 52 | Args: 53 | fp_type (Literal['morgan_fp', 'maccs_fp', 'path_fp'], optional): Type of fingerprint to use. Defaults to 'morgan_fp'. 54 | fp_bits (int, optional): Number of bits in the fingerprint. Defaults to 1024. 55 | hidden_channels (List[int], optional): Number of hidden channels in the MLP. Defaults to [128, 128]. 56 | norm_layer (Type[nn.Module], optional): Normalization layer to use. Defaults to nn.BatchNorm1d. 57 | dropout (float, optional): Dropout to use. Defaults to 0.5. 
58 | """ 59 | super().__init__() 60 | # Set our init args as class attributes 61 | self.__dict__.update(locals()) # Add arguments as attributes 62 | self.save_hyperparameters() 63 | self.fp_bits = MACCS_BITWIDTH if fp_type == 'maccs_fp' else fp_bits 64 | # Define PyTorch model 65 | self.fp_encoder = MLP(in_channels=self.fp_bits, 66 | hidden_channels=hidden_channels, 67 | norm_layer=norm_layer, 68 | inplace=False, 69 | dropout=dropout) 70 | self.maccs_encoder = MLP(in_channels=MACCS_BITWIDTH, 71 | hidden_channels=hidden_channels, 72 | norm_layer=norm_layer, 73 | inplace=False, 74 | dropout=dropout) 75 | 76 | def forward(self, x_in): 77 | # return self.fp_encoder(x_in[self.fp_type]) 78 | morgan_emb = self.fp_encoder(x_in[self.fp_type]) 79 | maccs_emb = self.maccs_encoder(x_in['maccs_fp']) 80 | return morgan_emb + maccs_emb 81 | 82 | def get_embedding_size(self): 83 | return self.hidden_channels[-1] 84 | 85 | 86 | class FingerprintSubModel(pl.LightningModule): 87 | 88 | def __init__(self, 89 | fp_type: Literal['morgan_fp', 'maccs_fp', 'path_fp'] = 'morgan_fp', 90 | fp_bits: int = 1024, 91 | hidden_channels: List[int] = [128, 128], 92 | norm_layer: object = nn.BatchNorm1d, 93 | dropout: float = 0.5): 94 | super().__init__() 95 | # Set our init args as class attributes 96 | self.__dict__.update(locals()) # Add arguments as attributes 97 | self.save_hyperparameters() 98 | self.fp_bits = MACCS_BITWIDTH if fp_type == 'maccs_fp' else fp_bits 99 | # Define PyTorch model 100 | self.fp_encoder = MLP(in_channels=self.fp_bits, 101 | hidden_channels=hidden_channels, 102 | norm_layer=norm_layer, 103 | inplace=False, 104 | dropout=dropout) 105 | 106 | def forward(self, x_in): 107 | return self.fp_encoder(x_in[self.fp_type]) 108 | 109 | def get_smiles_embedding_size(self): 110 | return self.hidden_channels[-1] -------------------------------------------------------------------------------- /src/data/protac_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type 3 | 4 | from data.protac_dataset import ProtacDataset 5 | 6 | import pytorch_lightning as pl 7 | from torch.utils.data import random_split, DataLoader 8 | 9 | from torch_geometric.data import Data, Batch 10 | from torch.utils.data._utils.collate import collate 11 | from torch.utils.data._utils.collate import default_collate_fn_map 12 | 13 | from torchvision import transforms 14 | 15 | import pandas as pd 16 | 17 | 18 | def graph_collate(batch, *, collate_fn_map=None): 19 | # Handle graph data separately: graph representation and computation can be 20 | # greatly optimized due to their sparse nature. In fact, multiple graphs in 21 | # a batch can be seen as a 'big' graph of unconnected sub-graphs. Hence, 22 | # their adjecency matrices can be combined together to form a single one. 
29 |     return Batch.from_data_list(batch)
30 | 
31 | 
32 | def custom_collate(batch):
33 |     collate_map = default_collate_fn_map.copy()
34 |     collate_map.update({Data: graph_collate})
35 |     return collate(batch, collate_fn_map=collate_map)
36 | 
37 | 
38 | class PROTACDataModule(pl.LightningDataModule):
39 | 
40 |     def __init__(self,
41 |                  train_df_path: str = './data/train/train_bin_upsampled.csv',
42 |                  val_df_path: str = './data/val/val_bin.csv',
43 |                  test_df_path: str = './data/test/test_bin.csv',
44 |                  predict_df_path: str = './data/test/test_bin.csv',
45 |                  protac_dataset_args: Mapping[str, Any] = {},
46 |                  batch_size: int = 32):
47 |         """Wrapper DataModule for PROTAC datasets.
48 | 
49 |         Args:
50 |             train_df_path (str, optional): Path to train dataset CSV. Defaults to './data/train/train_bin_upsampled.csv'.
51 |             val_df_path (str, optional): Path to validation dataset CSV. Defaults to './data/val/val_bin.csv'.
52 |             test_df_path (str, optional): Path to test dataset CSV. Defaults to './data/test/test_bin.csv'.
53 |             predict_df_path (str, optional): Path to prediction dataset CSV. Defaults to './data/test/test_bin.csv'.
54 |             protac_dataset_args (Mapping[str, Any], optional): Arguments to pass to ProtacDataset. Defaults to {}.
55 |             batch_size (int, optional): Batch size. Defaults to 32.
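56 | 
57 |         Example:
58 |             Typical usage sketch, assuming the curated CSV files are in place:
59 | 
60 |             >>> dm = PROTACDataModule(batch_size=64)
61 |             >>> dm.setup(stage='fit')
62 |             >>> train_loader = dm.train_dataloader()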
63 |         """
64 |         super().__init__()
65 |         self.__dict__.update(locals())  # Add arguments as attributes
66 |         self.save_hyperparameters()
67 | 
68 |     def prepare_data(self):
69 |         # Download and clean PROTAC-DB and PROTAC-Pedia
70 |         # TODO: Wrap the notebook code for data cleaning into a function
71 |         pass
72 | 
73 |     def setup(self, stage: Literal['fit', 'validate', 'test', 'predict'] = 'fit'):
74 |         cols_to_keep = [
75 |             'Smiles',
76 |             'Smiles_nostereo',
77 |             'DC50',
78 |             'pDC50',
79 |             'Dmax',
80 |             'poi_gene_id',
81 |             'poi_seq',
82 |             'cell_type',
83 |             'e3_ligase',
84 |             'active',
85 |         ]
86 |         # Assign train/val datasets for use in dataloaders
87 |         if stage == 'fit' or stage == 'validate':
88 |             train_df = pd.read_csv(self.train_df_path).reset_index(drop=True)
89 |             train_df = train_df[cols_to_keep]
90 |             self.train_ds = ProtacDataset(train_df, **self.protac_dataset_args)
91 |             val_df = pd.read_csv(self.val_df_path).reset_index(drop=True)
92 |             val_df = val_df[cols_to_keep]
93 |             self.val_ds = ProtacDataset(val_df, **self.protac_dataset_args)
94 |         # Assign test dataset for use in dataloader(s)
95 |         if stage == 'test' or stage == 'predict':
96 |             test_df = pd.read_csv(self.test_df_path).reset_index(drop=True)
97 |             test_df = test_df[cols_to_keep]
98 |             self.test_ds = ProtacDataset(test_df, **self.protac_dataset_args)
99 |             predict_df = pd.read_csv(self.predict_df_path).reset_index(drop=True)
100 |             predict_df = predict_df[cols_to_keep]
101 |             self.predict_ds = ProtacDataset(predict_df, **self.protac_dataset_args)
102 | 
103 |     def train_dataset(self):
104 |         return self.train_ds
105 | 
106 |     def val_dataset(self):
107 |         return self.val_ds
108 | 
109 |     def test_dataset(self):
110 |         return self.test_ds
111 | 
112 |     def predict_dataset(self):
113 |         return self.predict_ds
114 | 
115 | 
116 |     def train_dataloader(self):
117 |         return DataLoader(self.train_dataset(), batch_size=self.batch_size,
118 |                           shuffle=True, collate_fn=custom_collate,
119 |                           drop_last=True)
120 | 
121 |     def val_dataloader(self):
122 |         return DataLoader(self.val_dataset(), batch_size=self.batch_size,
123 |                           shuffle=False, collate_fn=custom_collate)
124 | 
125 |     def test_dataloader(self):
126 |         return DataLoader(self.test_dataset(), batch_size=self.batch_size,
127 |                           shuffle=False, collate_fn=custom_collate)
128 | 
129 |     def predict_dataloader(self):
130 |         return DataLoader(self.predict_dataset(), batch_size=self.batch_size,
131 |                           shuffle=False, collate_fn=custom_collate)
--------------------------------------------------------------------------------
/src/models/smiles_encoder/gnn/torch_geom_architectures.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type
3 | 
4 | import torch
5 | import pytorch_lightning as pl
6 | import torch_geometric.nn as geom_nn
7 | from torch_geometric.nn.models import GIN, GAT, GCN, AttentiveFP
8 | 
9 | class GnnSubModel(pl.LightningModule):
10 | 
11 |     def __init__(self,
12 |                  num_node_features: int = 9,
13 |                  node_edge_dim: int = 3,
14 |                  model_type: Literal['gin', 'gat', 'gcn', 'attentivefp'] = 'gin',
15 |                  hidden_channels: int = 32,
16 |                  num_layers: int = 8,
17 |                  out_channels: int = 8,
18 |                  dropout: float = 0.1,
19 |                  act: Literal['relu', 'elu'] = 'relu',
20 |                  jk: Literal['max', 'last', 'cat', 'lstm'] = 'max',
21 |                  norm: Literal['batch', 'layer'] = 'batch',
22 |                  num_timesteps: int = 16):
23 |         """Initialize a GNN submodel for encoding SMILES strings into a fixed-length vector representation.
24 | 
25 |         Args:
26 |             num_node_features (int, optional): Number of node features. Defaults to 9. See `from_smiles` [implementation](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/smiles.html#from_smiles).
27 |             node_edge_dim (int, optional): Number of edge features. Defaults to 3. See `from_smiles` [implementation](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/smiles.html#from_smiles).
28 |             model_type (Literal['gin', 'gat', 'gcn', 'attentivefp'], optional): Type of GNN to use. Defaults to 'gin'.
29 |             hidden_channels (int, optional): Number of hidden channels. Defaults to 32.
30 |             num_layers (int, optional): Number of GNN layers. Defaults to 8.
31 |             out_channels (int, optional): Number of output channels. Defaults to 8.
32 |             dropout (float, optional): Dropout probability. Defaults to 0.1.
33 |             act (Literal['relu', 'elu'], optional): Activation function. Defaults to 'relu'.
34 |             jk (Literal['max', 'last', 'cat', 'lstm'], optional): JK aggregation type. Defaults to 'max'.
35 |             norm (Literal['batch', 'layer'], optional): Normalization type. Defaults to 'batch'.
36 |             num_timesteps (int, optional): Number of timesteps for AttentiveFP. Defaults to 16.
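37 | 
38 |         Example:
39 |             Sketch of the GAT configuration used in `config_default.yml`:
40 | 
41 |             >>> gnn = GnnSubModel(model_type='gat', hidden_channels=128,
42 |             ...                   num_layers=16, out_channels=64, dropout=0.1)
43 |             >>> gnn.get_embedding_size()
44 |             64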
37 | """ 38 | super().__init__() 39 | # Set our init args as class attributes 40 | self.__dict__.update(locals()) # Add arguments as attributes 41 | self.save_hyperparameters() 42 | self.smiles_embedding_size = out_channels 43 | if model_type == 'gin': 44 | self.smiles_embedding_size = hidden_channels 45 | self.gnn = GIN(in_channels=num_node_features, 46 | hidden_channels=hidden_channels, 47 | num_layers=num_layers, 48 | dropout=dropout, 49 | act=act, 50 | norm=norm, 51 | jk=jk) 52 | elif model_type == 'gat': 53 | self.gnn = GAT(in_channels=num_node_features, 54 | hidden_channels=hidden_channels, 55 | num_layers=num_layers, 56 | out_channels=out_channels, 57 | dropout=dropout, 58 | act=act, 59 | norm=norm, 60 | jk=jk) 61 | elif model_type == 'gcn': 62 | self.gnn = GCN(in_channels=num_node_features, 63 | hidden_channels=hidden_channels, 64 | num_layers=num_layers, 65 | out_channels=out_channels, 66 | dropout=dropout, 67 | act=act, 68 | norm=norm, 69 | jk=jk) 70 | elif model_type == 'attentivefp': 71 | self.gnn = AttentiveFP(in_channels=num_node_features, 72 | hidden_channels=hidden_channels, 73 | out_channels=out_channels, 74 | edge_dim=node_edge_dim, 75 | num_layers=num_layers, 76 | num_timesteps=num_timesteps, 77 | dropout=dropout) 78 | else: 79 | raise ValueError(f'Unknown model type: {model_type}. Available: gin, gat, gcn, attentivefp') 80 | 81 | 82 | def forward(self, batch): 83 | if self.model_type == 'gin': 84 | x = self.gnn(batch['smiles_graph'].x, 85 | batch['smiles_graph'].edge_index) 86 | smiles_emb = geom_nn.global_add_pool(x, batch['smiles_graph'].batch) 87 | elif self.model_type == 'gat': 88 | x = self.gnn(x=batch['smiles_graph'].x.to(torch.float), 89 | edge_index=batch['smiles_graph'].edge_index, 90 | edge_attr=batch['smiles_graph'].edge_attr) 91 | smiles_emb = geom_nn.global_add_pool(x, batch['smiles_graph'].batch) 92 | elif self.model_type == 'gcn': 93 | x = self.gnn(x=batch['smiles_graph'].x.to(torch.float), 94 | edge_index=batch['smiles_graph'].edge_index, 95 | edge_attr=batch['smiles_graph'].edge_attr) 96 | smiles_emb = geom_nn.global_add_pool(x, batch['smiles_graph'].batch) 97 | elif self.model_type == 'attentivefp': 98 | smiles_emb = self.gnn(batch['smiles_graph'].x.to(torch.float), 99 | batch['smiles_graph'].edge_index, 100 | batch['smiles_graph'].edge_attr, 101 | batch['smiles_graph'].batch) 102 | return smiles_emb 103 | 104 | def get_embedding_size(self): 105 | return self.smiles_embedding_size -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red) 2 | 3 | # ML for Predicting Targeted Protein Degradation 4 | 5 | This repository contains the code developed within the master thesis project: _"Machine Learning for Predicting Targetd Protain Degradation"_. A brief overview of the project and the thesis report can be read at this [repository](https://github.com/ribesstefano/ml-for-protacs). 6 | 7 | ## Code Overview 8 | 9 | The `ProtacModel` class, defined in the `src/models/wrapper_model.py` file, is a subclass of the PyTorch Lightning `LightningModule` class. It is a wrapper class that makes predictions on PROTAC data. 10 | The `ProtacModel` class takes in several encoders and a head module, which are used to encode the input data and make predictions. 
28 | 
29 | The model roughly follows the architecture below:
30 | 
31 | ![image](https://github.com/ribesstefano/ml-for-protacs/assets/17163014/5adedfd7-9e5e-419b-bc8f-5334c9a41c4f)
32 | 
33 | The `ProtacModel` class is used in `src/main.py` for training and testing the model through a `LightningCLI` module. The `LightningCLI` module is a command-line interface that allows training and testing the model using a YAML configuration file. The configuration file contains the hyperparameters and dataset arguments that are used to define and automatically instantiate the model and the dataset before performing training and testing.
34 | 
35 | The `src/models/wrapper_model.py` file also defines the `WrapperModel` class, which is a wrapper around the `ProtacModel` class. The purpose of this class is to provide a simple and flexible way to train and test various models using the same interface.
36 | The `WrapperModel` class can be used as a stand-alone class to train and test the model without the `LightningCLI` module. This is useful when one wants to train and test the model from a Jupyter Notebook or a Python script.
37 | 
38 | ### Encoders
39 | 
40 | Different encoders are defined in the `src/models` folder. The encoders are used to encode the diverse input data about PROTACs before passing it to the head module. They are all defined as subclasses of the `nn.Module` class and are used in the `ProtacModel` class to encode the input data.
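41 | 
42 | All encoders follow the same implicit interface, sketched below (inferred from the encoder implementations; `SomeFeatureEncoder` is not a class actually defined in the repository):
43 | 
44 | ```python
45 | class SomeFeatureEncoder(pl.LightningModule):
46 | 
47 |     def forward(self, x_in: dict) -> torch.Tensor:
48 |         # `x_in` is the batch dictionary produced by `ProtacDataset`
49 |         return self.net(x_in['some_feature'])
50 | 
51 |     def get_embedding_size(self) -> int:
52 |         # Used to size the input of the head module
53 |         return self.embedding_size
54 | ```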
55 | 
56 | ### PROTACDataset and PROTACDataLoader
57 | 
58 | The `PROTACDataset` class represents a dataset used for training and testing the PROTAC models. It takes in a Pandas DataFrame containing the PROTAC data and various arguments to preprocess the data.
59 | 
60 | The `PROTACDataModule` class is defined in the `src/data/protac_dataloader.py` file and is a subclass of the PyTorch Lightning `LightningDataModule` class. It provides a convenient and customizable way to load and preprocess the PROTAC data and create PyTorch dataloaders for training, validation, and testing.
61 | 
62 | ## Data Curation
63 | 
64 | The data curation process is detailed and carried out in the notebooks `notebooks/protac_db_data_curation.ipynb` and `notebooks/protac_pedia_data_curation.ipynb`. They are based on the [PROTAC-DB](http://cadd.zju.edu.cn/protacdb/about) and [PROTAC-Pedia](https://protacpedia.weizmann.ac.il/ptcb/main) datasets. The directory `data` already contains curated versions of the aforementioned datasets. In order to perform data curation, one shall download the raw datasets from the respective sources and run the respective notebooks.
65 | 
66 | ## Quick Start
67 | 
68 | 1. Install the required dependencies by running `conda env create -f environment.yml` in your terminal.
69 | 
70 | 2. To train an MLP model, run the following command:
71 | 
72 | ```bash
73 | python main.py fit \
74 |     --trainer="{'max_epochs': 10, 'accelerator': 'gpu', 'precision': 16}" \
75 |     --model.smiles_encoder=src.models.smiles_encoder.mlp.rdkit_fp_model.RDKitFingerprintEncoder \
76 |     --data.protac_dataset_args="{'use_morgan_fp': True, 'morgan_bits': 1024, 'precompute_fingerprints': True, 'poi_vectorizer': 'models/poi_encoder.joblib', 'e3_ligase_enc': 'models/e3_ligase_encoder.joblib', 'cell_type_enc': 'models/cell_type_encoder.joblib'}"
77 | ```
78 | This will train an MLP model with the specified hyperparameters and dataset arguments.
79 | 
80 | 3. To test the MLP model, run the following command:
81 | 
82 | ```bash
83 | python main.py test --ckpt_path=.\lightning_logs\version_7\checkpoints\epoch=9-step=930.ckpt -c .\lightning_logs\version_7\config.yaml
84 | ```
85 | This will test the MLP model using the specified checkpoint path and configuration file.
86 | 
87 | 4. To train a GNN model, for instance, run the following command instead:
88 | 
89 | ```bash
90 | python main.py fit --trainer="{'max_epochs': 10, 'accelerator': 'gpu', 'precision': 16}" \
91 |     --model.smiles_encoder=src.models.smiles_encoder.gnn.torch_geom_architectures.GnnSubModel \
92 |     --model.use_smiles_only=False \
93 |     --data.protac_dataset_args="{'include_smiles_as_graphs': True, 'precompute_smiles_as_graphs': True, 'poi_vectorizer': 'models/poi_encoder.joblib', 'e3_ligase_enc': 'models/e3_ligase_encoder.joblib', 'cell_type_enc': 'models/cell_type_encoder.joblib'}" \
94 |     --model.cell_type_encoder=src.models.cell_type_encoder.cell_type_ordinal_encoder.CellTypeEncoder \
95 |     --model.e3_ligase_encoder=src.models.e3_ligase_encoder.e3_ordinal_encoder.E3LigaseEncoder \
96 |     --model.poi_seq_encoder=src.models.poi_encoder.poi_count_vectorizer.POISequenceEncoder \
97 |     --model.poi_seq_encoder_args="{'poi_seq_encoder_filepath': 'models/poi_encoder.joblib'}"
98 | ```
99 | 
100 | This will train a GNN model with the specified hyperparameters and dataset arguments.
101 | 
102 | That's it! You can modify the hyperparameters and dataset arguments as needed to train and test different models. Additionally, there are some TODOs in the code that you can work on to improve the functionality of the code.
103 | 
104 | ## Hyperparameter Search with Optuna
105 | 
106 | The `src/tune.py` file provides a convenient way to perform hyperparameter tuning using the Optuna library. It defines several functions that create PyTorch Lightning objects with the specified hyperparameters, and an objective function that is optimized by Optuna. By running it with the appropriate configuration files, one can perform hyperparameter tuning on the `ProtacModel` and obtain the best configuration as a YAML file.
107 | 
108 | The `config_optuna.yml` file contains an example configuration for the Optuna hyperparameter tuning. It defines the search space for the hyperparameters and the number of trials to run. The `config_default.yml` file contains the default configuration for the model and dataset arguments of the `ProtacModel` and `ProtacDataset`. You can modify the configuration files as needed to perform hyperparameter tuning on your models.
109 | 
110 | The `objective` function in `src/hyperparameter_tuning/optuna_utils.py` is optimized by Optuna during hyperparameter tuning: it returns the validation loss of the model trained with the suggested hyperparameters.
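111 | 
112 | For instance, an entry like the following under `optuna.hparams` in `config_optuna.yml`:
113 | 
114 | ```yaml
115 | optimizer.init_args.lr:
116 |   function: suggest_float
117 |   kwargs:
118 |     low: 1e-5
119 |     high: 1e-3
120 |     log: true
121 | ```
122 | 
123 | is resolved, roughly, to `trial.suggest_float('optimizer.init_args.lr', low=1e-5, high=1e-3, log=True)`, and the sampled value overrides that key in the trial's configuration before the model and datamodule are instantiated.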
Note that the CLI script will only change the arguments passed to the `nn.Module` classes in the model before instantiating them. For more advanced hyperparameter configurations, one should write a custom objective function. Please refer to the file `notebooks/machine_learning.ipynb` for some tailored examples. 80 | 81 | Here's a brief overview of what the `objective` function does: 82 | 83 | * Overfits a trial model with the suggested hyperparameters on a minibatch. 84 | * If the training accuracy on the minibatch is less than 0.95, it returns a score of 0.0. 85 | * Otherwise, it fits a trial model with the suggested hyperparameters on the full training set. 86 | * Finally, it returns the configured validation metric of the model. 87 | 88 | Example of usage: 89 | 90 | ```bash 91 | python .\src\tune.py --config .\config_default.yml --config .\config_optuna.yml 92 | ``` -------------------------------------------------------------------------------- /src/hyperparameter_tuning/cli.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Dict, Set, Optional, Union, Type, Callable 3 | from os import PathLike 4 | 5 | from data.protac_dataloader import PROTACDataModule 6 | from models.wrapper_model import WrapperModel, ProtacModel # noqa: F401 7 | 8 | 9 | from optuna.study import Study, create_study 10 | from optuna.pruners import BasePruner 11 | 12 | from pytorch_lightning.cli import ( 13 | LightningCLI, 14 | LightningArgumentParser, 15 | ArgsType, 16 | SaveConfigCallback 17 | ) 18 | 19 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer 20 | 21 | import importlib 22 | from jsonargparse import ( 23 | ActionConfigFile, 24 | ArgumentParser, 25 | class_from_function, 26 | Namespace, 27 | register_unresolvable_import_paths, 28 | set_config_read_mode, 29 | CLI, 30 | lazy_instance, 31 | ) 32 | 33 | 34 | def getclass(class_string): 35 | """Given a string, return the class type it represents.""" 36 | module_name, class_name = class_string.rsplit('.', 1) 37 | module = importlib.import_module(module_name) 38 | return getattr(module, class_name) 39 | 40 | 41 | def create_instance(class_string, args_dict, instanciate_nested: bool = False, lazy_init: bool = False): 42 | if isinstance(args_dict, Namespace) or isinstance(args_dict, dict) or isinstance(args_dict, list): 43 | tmp = args_dict 44 | if isinstance(args_dict, Namespace): 45 | tmp = args_dict.as_dict().items() 46 | if isinstance(args_dict, dict): 47 | tmp = args_dict.items() 48 | for arg_key, arg_value in tmp: 49 | # print('-' * 80) 50 | # print(f'arg_key: {arg_key}, arg_value: {arg_value}') 51 | # print('-' * 80) 52 | if isinstance(arg_value, dict): 53 | if 'class_path' in arg_value.keys(): 54 | # TODO: Do not create nested instances... 
current limitation 55 | if instanciate_nested: 56 | if 'init_args' in arg_value.keys(): 57 | args_dict[arg_key] = create_instance(arg_value['class_path'], arg_value['init_args']) 58 | else: 59 | args_dict[arg_key] = getclass(arg_value['class_path']) 60 | else: 61 | args_dict[arg_key] = getclass(arg_value['class_path']) 62 | class_ = getclass(class_string) 63 | if lazy_init: 64 | return lazy_instance(class_, **args_dict) 65 | else: 66 | return class_(**args_dict) 67 | 68 | 69 | class TuneLightningCLI(LightningCLI): 70 | 71 | def __init__( 72 | self, 73 | model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None, 74 | datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, 75 | save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, 76 | save_config_kwargs: Optional[Dict[str, Any]] = None, 77 | trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, 78 | trainer_defaults: Optional[Dict[str, Any]] = None, 79 | seed_everything_default: Union[bool, int] = True, 80 | parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, 81 | subclass_mode_model: bool = False, 82 | subclass_mode_data: bool = False, 83 | args: ArgsType = None, 84 | run: bool = True, 85 | auto_configure_optimizers: bool = True, 86 | ) -> None: 87 | """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which 88 | are called / instantiated using a parsed configuration file and / or command line args. 89 | 90 | Parsing of configuration from environment variables can be enabled by setting ``parser_kwargs={"default_env": 91 | True}``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed 92 | from variables named for example ``PL_TRAINER__MAX_EPOCHS``. 93 | 94 | For more info, read :ref:`the CLI docs `. 95 | 96 | Args: 97 | model_class: An optional :class:`~lightning.pytorch.core.module.LightningModule` class to train on or a 98 | callable which returns a :class:`~lightning.pytorch.core.module.LightningModule` instance when 99 | called. If ``None``, you can pass a registered model with ``--model=MyModel``. 100 | datamodule_class: An optional :class:`~lightning.pytorch.core.datamodule.LightningDataModule` class or a 101 | callable which returns a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` instance when 102 | called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``. 103 | save_config_callback: A callback class to save the config. 104 | save_config_kwargs: Parameters that will be used to instantiate the save_config_callback. 105 | trainer_class: An optional subclass of the :class:`~lightning.pytorch.trainer.trainer.Trainer` class or a 106 | callable which returns a :class:`~lightning.pytorch.trainer.trainer.Trainer` instance when called. 107 | trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through 108 | this argument will not be configurable from a configuration file and will always be present for 109 | this particular CLI. Alternatively, configurable callbacks can be added as explained in 110 | :ref:`the CLI docs `. 111 | seed_everything_default: Number for the :func:`~lightning.fabric.utilities.seed.seed_everything` 112 | seed value. Set to True to automatically choose a seed value. 113 | Setting it to False will avoid calling ``seed_everything``. 
114 | parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``. 115 | subclass_mode_model: Whether model can be any `subclass 116 | `_ 117 | of the given class. 118 | subclass_mode_data: Whether datamodule can be any `subclass 119 | `_ 120 | of the given class. 121 | args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style 122 | arguments can be given in a ``list``. Alternatively, structured config options can be given in a 123 | ``dict`` or ``jsonargparse.Namespace``. 124 | run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer` 125 | method. If set to ``False``, the trainer and model classes will be instantiated only. 126 | """ 127 | self.save_config_callback = save_config_callback 128 | self.save_config_kwargs = save_config_kwargs or {} 129 | self.trainer_class = trainer_class 130 | self.trainer_defaults = trainer_defaults or {} 131 | self.seed_everything_default = seed_everything_default 132 | self.parser_kwargs = parser_kwargs or {} # type: ignore[var-annotated] # github.com/python/mypy/issues/6463 133 | self.auto_configure_optimizers = auto_configure_optimizers 134 | 135 | self.model_class = model_class 136 | # used to differentiate between the original value and the processed value 137 | self._model_class = model_class or LightningModule 138 | self.subclass_mode_model = (model_class is None) or subclass_mode_model 139 | 140 | self.datamodule_class = datamodule_class 141 | # used to differentiate between the original value and the processed value 142 | self._datamodule_class = datamodule_class or LightningDataModule 143 | self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data 144 | 145 | main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs) 146 | self.setup_parser(run, main_kwargs, subparser_kwargs) 147 | self.parse_arguments(self.parser, args) 148 | 149 | self.subcommand = self.config["subcommand"] if run else None 150 | 151 | self._set_seed() 152 | 153 | # self.before_instantiate_classes() 154 | # self.instantiate_classes() 155 | # if self.subcommand is not None: 156 | # self._run_subcommand(self.subcommand) 157 | 158 | 159 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 160 | """Implement to add extra arguments to the parser or link arguments. 
161 | 162 | Args: 163 | parser: The parser object to which arguments can be added 164 | """ 165 | parser.add_argument('--optuna', type=dict, help='Optuna hyperparameters', required=True) 166 | parser.add_argument('--best-config-filename', type=PathLike, help='Best config file path', default='best_config.yaml') 167 | # parser.add_subclass_arguments(BasePruner, '--optuna.study.pruner', instantiate=False) 168 | 169 | 170 | def create_study(self) -> Study: 171 | config = {} 172 | if 'study' in self.config.optuna.keys(): 173 | for k, v in self.config.optuna['study'].items(): 174 | if isinstance(v, dict): 175 | if 'class_path' in v.keys(): 176 | config[k] = create_instance(v['class_path'], v['init_args']) 177 | else: 178 | config[k] = v 179 | return create_study(**config) -------------------------------------------------------------------------------- /src/hyperparameter_tuning/optuna_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Dict, Set, Optional, Union, Type, Callable 3 | 4 | import optuna 5 | import copy 6 | import warnings 7 | import logging 8 | import importlib 9 | from data.protac_dataloader import PROTACDataModule 10 | from models.wrapper_model import WrapperModel, ProtacModel # noqa: F401 11 | 12 | from hyperparameter_tuning.cli import TuneLightningCLI 13 | from packaging import version 14 | from optuna.storages._cached_storage import _CachedStorage 15 | from optuna.storages._rdb.storage import RDBStorage 16 | from pytorch_lightning import LightningModule, Trainer 17 | from pytorch_lightning.callbacks import Callback 18 | 19 | from jsonargparse import ( 20 | ActionConfigFile, 21 | ArgumentParser, 22 | class_from_function, 23 | Namespace, 24 | register_unresolvable_import_paths, 25 | set_config_read_mode, 26 | CLI, 27 | lazy_instance, 28 | ) 29 | 30 | 31 | def suggest_optuna_param(trial: optuna.Trial, 32 | name: str, 33 | config: Dict[str, Any]) -> Any: 34 | """Given a config, suggest a value for the hyperparameter. 35 | 36 | Args: 37 | trial (optuna.Trial): The current Optuna trial 38 | name (str): The name of the hyperparameter 39 | config (Dict[str, Any]): The configuration of the hyperparameter 40 | 41 | Returns: 42 | Any: Suggested hyperparameter value 43 | """ 44 | return getattr(trial, config['function'])(name, **config['kwargs']) 45 | 46 | 47 | def objective(trial: optuna.Trial, cli: Optional[TuneLightningCLI] = None, config: Optional[Dict[str, Any]] = None) -> float: 48 | """Optuna objective function. Creates a new TuneLightningCLI object with the suggested hyperparameters and then fits the newly instantiated model. 49 | 50 | Args: 51 | trial (optuna.Trial): Optuna trial 52 | cli (Optional[TuneLightningCLI], optional): Custom LightningCLI object. Defaults to None. 53 | config (Optional[Dict[str, Any]], optional): Extra additional configurations as parsed arguments to the custom CLI. Defaults to None. 54 | 55 | Returns: 56 | float: Value of the objective function 57 | """ 58 | if cli is None and config is None: 59 | raise ValueError('Either cli or config must be provided') 60 | elif cli is None and config is not None: 61 | cli_trial = TuneLightningCLI(ProtacModel, 62 | PROTACDataModule, 63 | seed_everything_default=42, 64 | parser_kwargs={'parser_mode': 'omegaconf'}, 65 | args=config, 66 | run=False) 67 | else: 68 | cli_trial = copy.deepcopy(cli) 69 | if config is not None: 70 | print('WARNING. Both cli and config are provided. 
CLI "optuna" config will be updated.') 71 | cli_trial.config['optuna'].update(config) 72 | # Change CLI config with Optuna suggested values 73 | for hparam, cfg in cli_trial.config['optuna']['hparams'].items(): 74 | cli_trial.config[hparam] = suggest_optuna_param(trial, hparam, cfg) 75 | # Change logging name if TensorBoardLogger is used 76 | if 'trainer.logger.TensorBoardLogger' in cli_trial.config: 77 | name = cli_trial.config['trainer.logger.TensorBoardLogger.init_args.name'] 78 | name = f'{name}_{trial.number}' 79 | cli_trial.config['trainer.logger.TensorBoardLogger.init_args.name'] = name 80 | # Store current configuration in the trial attributes 81 | trial.set_user_attr('config', cli_trial.config) 82 | # TODO: Change Trainer config `overfit_batches` if not there already, then 83 | # fit the model and check its performance. If train accuracy not close to 84 | # 100%, then return a bad score. Else, remove the `overfit_batches` config 85 | # and fit the model. 86 | # 87 | # Turn on `overfit_batches` (with 5% of the data) and fit the model 88 | cli_trial.config['trainer.overfit_batches'] = 1 89 | # cli_trial.config['trainer.limit_val_batches'] = 1.0 90 | cli_trial.instantiate_classes() 91 | cli_trial.trainer.fit(cli_trial.model, cli_trial.datamodule) 92 | # Check if train accuracy is close to 100% 93 | if cli_trial.trainer.callback_metrics['train_acc'] < 0.95: 94 | logging.warning(f'WARNING. Train accuracy is {cli_trial.trainer.callback_metrics["train_acc"]}. Returning bad score.') 95 | return 0.0 96 | # Turn off `overfit_batches` and fit the model 97 | cli_trial.config['trainer.overfit_batches'] = 0.0 98 | # NOTE: Instantiating again the classes would hopefully overwrite the model 99 | # and the dataloader, without wasting memory... 100 | # TODO: Add Optuna callback for pruning to Trainer callbacks 101 | # Instantiate classes and fit model 102 | cli_trial.instantiate_classes() 103 | cli_trial.trainer.fit(cli_trial.model, cli_trial.datamodule) 104 | # Obtain the metric value 105 | return cli_trial.trainer.callback_metrics[cli_trial.config.optuna['metric']] 106 | 107 | 108 | # Define key names of `Trial.system_attrs`. 109 | _PRUNED_KEY = "ddp_pl:pruned" 110 | _EPOCH_KEY = "ddp_pl:epoch" 111 | 112 | class CustomPyTorchLightningPruningCallback(Callback): 113 | """PyTorch Lightning callback to prune unpromising trials. 114 | 115 | See `the example `__ 117 | if you want to add a pruning callback which observes accuracy. 118 | 119 | Args: 120 | trial: 121 | A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the 122 | objective function. 123 | monitor: 124 | An evaluation metric for pruning, e.g., ``val_loss`` or 125 | ``val_acc``. The metrics are obtained from the returned dictionaries from e.g. 126 | ``pytorch_lightning.LightningModule.training_step`` or 127 | ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on 128 | how this dictionary is formatted. 129 | 130 | .. note:: 131 | For the distributed data parallel training, the version of PyTorchLightning needs to be 132 | higher than or equal to v1.5.0. In addition, :class:`~optuna.study.Study` should be 133 | instantiated with RDB storage. 
134 | """ 135 | 136 | def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None: 137 | super().__init__() 138 | 139 | self._trial = trial 140 | self.monitor = monitor 141 | self.is_ddp_backend = False 142 | 143 | def on_init_start(self, trainer: Trainer) -> None: 144 | self.is_ddp_backend = ( 145 | trainer._accelerator_connector.distributed_backend is not None # type: ignore 146 | ) 147 | if self.is_ddp_backend: 148 | if version.parse(pl.__version__) < version.parse("1.5.0"): # type: ignore 149 | raise ValueError("PyTorch Lightning>=1.5.0 is required in DDP.") 150 | if not ( 151 | isinstance(self._trial.study._storage, _CachedStorage) 152 | and isinstance(self._trial.study._storage._backend, RDBStorage) 153 | ): 154 | raise ValueError( 155 | "optuna.integration.PyTorchLightningPruningCallback" 156 | " supports only optuna.storages.RDBStorage in DDP." 157 | ) 158 | 159 | def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 160 | 161 | # When the trainer calls `on_validation_end` for sanity check, 162 | # do not call `trial.report` to avoid calling `trial.report` multiple times 163 | # at epoch 0. The related page is 164 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391. 165 | if trainer.sanity_checking: 166 | return 167 | 168 | epoch = pl_module.current_epoch 169 | 170 | current_score = trainer.callback_metrics.get(self.monitor) 171 | if current_score is None: 172 | message = ( 173 | "The metric '{}' is not in the evaluation logs for pruning. " 174 | "Please make sure you set the correct metric name.".format(self.monitor) 175 | ) 176 | warnings.warn(message) 177 | return 178 | 179 | should_stop = False 180 | if trainer.is_global_zero: 181 | self._trial.report(current_score.item(), step=epoch) 182 | should_stop = self._trial.should_prune() 183 | # TODO: The following line breaks the current version of Pytorch 184 | # Lightning. But I suspect it's necessary in a distributed training 185 | # environment... so it shouldn't matter for us... 186 | # should_stop = trainer.training_type_plugin.broadcast(should_stop) 187 | trainer.should_stop = should_stop 188 | if not should_stop: 189 | return 190 | 191 | if not self.is_ddp_backend: 192 | message = "Trial was pruned at epoch {}.".format(epoch) 193 | raise optuna.TrialPruned(message) 194 | else: 195 | # Stop every DDP process if global rank 0 process decides to stop. 196 | trainer.should_stop = True 197 | if trainer.is_global_zero: 198 | self._trial.storage.set_trial_system_attr(self._trial._trial_id, _PRUNED_KEY, True) 199 | self._trial.storage.set_trial_system_attr(self._trial._trial_id, _EPOCH_KEY, epoch) 200 | 201 | def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 202 | if not self.is_ddp_backend: 203 | return 204 | 205 | # Because on_validation_end is executed in spawned processes, 206 | # _trial.report is necessary to update the memory in main process, not to update the RDB. 
207 | _trial_id = self._trial._trial_id 208 | _study = self._trial.study 209 | _trial = _study._storage._backend.get_trial(_trial_id) # type: ignore 210 | _trial_system_attrs = _study._storage.get_trial_system_attrs(_trial_id) 211 | is_pruned = _trial_system_attrs.get(_PRUNED_KEY) 212 | epoch = _trial_system_attrs.get(_EPOCH_KEY) 213 | intermediate_values = _trial.intermediate_values 214 | for step, value in intermediate_values.items(): 215 | self._trial.report(value, step=step) 216 | 217 | if is_pruned: 218 | message = "Trial was pruned at epoch {}.".format(epoch) 219 | raise optuna.TrialPruned(message) -------------------------------------------------------------------------------- /src/data/protac_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import (Dict, Optional, Literal, Callable, Any, Tuple, Type) 3 | 4 | from utils.fingerprints import get_fingerprint 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | import torch_geometric 12 | import torch_geometric.nn as geom_nn 13 | import torch_geometric.data as geom_data 14 | from torch_geometric.utils.smiles import from_smiles 15 | 16 | from torchvision.transforms import Compose 17 | 18 | from sklearn.feature_extraction.text import CountVectorizer 19 | from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder 20 | 21 | import pandas as pd 22 | import joblib 23 | 24 | from transformers import ( 25 | AutoTokenizer, 26 | AutoModelForMaskedLM, 27 | TrainingArguments, 28 | Trainer, 29 | DataCollatorForLanguageModeling, 30 | RobertaTokenizerFast, 31 | RobertaForMaskedLM, 32 | ) 33 | 34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 35 | 36 | class ProtacDataset(Dataset): 37 | 38 | def __init__(self, 39 | dataframe, 40 | task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive', 41 | scale_concentration: bool = False, 42 | include_smiles_as_str: bool = False, 43 | include_smiles_as_graphs: bool = False, 44 | smiles_tokenizer: Optional[str | Callable] = None, 45 | smiles_tokenizer_type: Type = AutoTokenizer, 46 | smiles_tokenizer_args: Dict = {}, 47 | ngram_range: Tuple[int, int] = (2, 2), 48 | precompute_smiles_as_graphs: bool = False, 49 | precompute_fingerprints: bool = False, 50 | use_for_ssl: bool = False, 51 | use_morgan_fp: bool = False, 52 | morgan_bits: int = 1024, 53 | morgan_atomic_radius: int = 2, 54 | use_maccs_fp: bool = False, 55 | use_path_fp: bool = False, 56 | path_bits: int = 1024, 57 | fp_min_path: int = 1, 58 | fp_max_path: int = 16, 59 | include_poi_seq: bool = True, 60 | include_poi_gene: bool = False, 61 | include_e3_ligase: bool = True, 62 | include_cell_type: bool = True, 63 | tokenize_poi_seq: bool = False, 64 | poi_tokenizer: Optional[Callable | str] = None, 65 | poi_seq_enc: Optional[Callable | str] = None, 66 | preencode_poi_seq: bool = False, 67 | poi_gene_enc: Optional[Callable | str] = None, 68 | e3_ligase_enc: Optional[Callable | str] = None, 69 | cell_type_enc: Optional[Callable | str] = None, 70 | use_default_poi_seq_enc: bool = False, 71 | use_default_e3_ligase_enc: bool = False, 72 | use_default_cell_type_enc: bool = False, 73 | normalize_poi_seq_enc: bool = False, 74 | normalize_e3_ligase_enc: bool = False, 75 | normalize_cell_type_enc: bool = False, 76 | normalize_poi_gene_enc: bool = False, 77 | transform: Optional[Callable] = None): 78 | """Pytorch Dataset for PROTAC data. 
Each element will consist of a dictionary of different processed features. 79 | When processed by a DataLoader, the dictionary structure will remain, but each value will be converted to a batch of tensors. 80 | 81 | Args: 82 | dataframe (pd.DataFrame): Dataframe containing the PROTAC data 83 | task (Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'], optional): Task to perform. Defaults to 'predict_active_inactive'. 84 | scale_concentration (bool, optional): Whether to scale the concentration values. Defaults to False. 85 | 86 | """ 87 | self.__dict__.update(locals()) # Add arguments as attributes 88 | self.hparams = {k: v for k, v in locals().items() if k != 'dataframe' and k != 'self'} # Store hyperparameters 89 | self.maccs_bits = 167 # Hardcoded, see RDKit documentation 90 | self.dataset_len = len(self.dataframe) 91 | # Handle SMILES information 92 | self.smiles = self.dataframe['Smiles_nostereo'] 93 | # if include_selfies: 94 | # self.selfies = [sf.encoder(s) for s in self.smiles] 95 | if precompute_fingerprints: 96 | if self.use_morgan_fp: 97 | self.morgan_fp = np.array([get_fingerprint(s, n_bits=self.morgan_bits, fp_type='morgan', atomic_radius=morgan_atomic_radius).astype(np.float32) for s in self.smiles]) 98 | if self.use_maccs_fp: 99 | self.maccs_fp = np.array([get_fingerprint(s, fp_type='maccs').astype(np.float32) for s in self.smiles]) 100 | if self.use_path_fp: 101 | self.path_fp = np.array([get_fingerprint(s, n_bits=self.path_bits, fp_type='path', min_path=self.fp_min_path, max_path=self.fp_max_path).astype(np.float32) for s in self.smiles]) 102 | if include_smiles_as_graphs or precompute_smiles_as_graphs: 103 | # NOTE: self.graph_smiles is a list of PytorchGeometric Data objects 104 | self.graph_smiles = [from_smiles(s) for s in self.smiles] 105 | if smiles_tokenizer is not None: 106 | if isinstance(smiles_tokenizer, str): 107 | self.smiles_tokenizer = smiles_tokenizer_type.from_pretrained(smiles_tokenizer, **smiles_tokenizer_args) 108 | else: 109 | self.smiles_tokenizer = smiles_tokenizer 110 | # NOTE: Do NOT return tensors when doing SSL, i.e., MLM, as reported 111 | # in this conversation: https://discuss.huggingface.co/t/extra-dimension-with-datacollatorfor-languagemodeling-into-bertformaskedlm/6400/6 112 | if use_for_ssl: 113 | self.smiles_tokenized = [ 114 | self.smiles_tokenizer(s, padding='max_length', truncation=True) for s in self.smiles 115 | ] 116 | assert len(self.smiles_tokenized) == len(self.smiles), ( 117 | f'ERROR. 
Len tokenized {len(self.smiles_tokenized)} /= len SMILES {len(self.smiles)}' 118 | ) 119 | else: 120 | self.smiles_tokenized = [ 121 | self.smiles_tokenizer(s, padding='max_length', truncation=True, return_tensors='pt') for s in self.smiles 122 | ] 123 | # Handle the POI sequence 124 | if include_poi_seq: 125 | self.poi_seq = self.dataframe['poi_seq'].to_list() 126 | if poi_seq_enc is not None: 127 | if isinstance(poi_seq_enc, str): 128 | self.poi_seq_enc = joblib.load(poi_seq_enc) 129 | else: 130 | self.poi_seq_enc = poi_seq_enc 131 | else: 132 | self.poi_seq_enc = CountVectorizer(analyzer='char', 133 | ngram_range=ngram_range) 134 | self.poi_seq_enc.fit(self.poi_seq) 135 | if preencode_poi_seq: 136 | self.poi_seq = self.poi_seq_enc.transform(self.poi_seq) 137 | self.poi_seq = self.poi_seq.toarray().astype(np.float32) 138 | # Tokenize the POI sequence (for example for BERT-based models) 139 | if tokenize_poi_seq: 140 | if poi_tokenizer is None: 141 | self.poi_seq = self.dataframe['poi_seq'] 142 | else: 143 | self.poi_seq = [self.poi_tokenizer(seq, padding='max_length', truncation=True, return_tensors='pt') for seq in self.dataframe['poi_seq']] 144 | # Handle the POI gene 145 | if include_poi_gene: 146 | self.gene = self.dataframe['poi_gene_id'].to_numpy().reshape(-1, 1) 147 | if poi_gene_enc is not None: 148 | self.gene = poi_gene_enc.transform(self.gene) 149 | self.gene = self.gene.astype(np.float32) 150 | if normalize_poi_gene_enc: 151 | self.gene /= len(poi_gene_enc.categories_) 152 | else: 153 | self.poi_gene_enc = OrdinalEncoder( 154 | handle_unknown='use_encoded_value', 155 | unknown_value=-1, 156 | encoded_missing_value=-1 157 | ) 158 | tmp = self.poi_gene_enc.fit_transform(self.gene) 159 | self.gene = tmp.astype(np.float32).flatten() 160 | # Handle the E3 ligase 161 | if include_e3_ligase: 162 | self.e3_ligase = self.dataframe['e3_ligase'].to_numpy().reshape(-1, 1) 163 | if e3_ligase_enc is not None: 164 | if isinstance(e3_ligase_enc, str): 165 | self.e3_ligase_enc = joblib.load(e3_ligase_enc) 166 | else: 167 | self.e3_ligase_enc = e3_ligase_enc 168 | self.e3_ligase = self.e3_ligase_enc.transform(self.e3_ligase) 169 | self.e3_ligase = self.e3_ligase.astype(np.float32) 170 | if normalize_e3_ligase_enc: 171 | self.e3_ligase /= len(self.e3_ligase_enc.categories_) 172 | elif use_default_e3_ligase_enc: 173 | self.e3_ligase_enc = OrdinalEncoder( 174 | handle_unknown='use_encoded_value', 175 | unknown_value=-1, 176 | encoded_missing_value=-1 177 | ) 178 | tmp = self.e3_ligase_enc.fit_transform(self.e3_ligase) 179 | self.e3_ligase = tmp.astype(np.float32) 180 | # Handle cell type 181 | if include_cell_type: 182 | self.cell_type = self.dataframe['cell_type'].to_numpy().reshape(-1, 1) 183 | if cell_type_enc is not None: 184 | if isinstance(cell_type_enc, str): 185 | self.cell_type_enc = joblib.load(cell_type_enc) 186 | else: 187 | self.cell_type_enc = cell_type_enc 188 | self.cell_type = self.cell_type_enc.transform(self.cell_type) 189 | self.cell_type = self.cell_type.astype(np.float32) 190 | if normalize_cell_type_enc: 191 | self.cell_type /= len(self.cell_type_enc.categories_) 192 | elif use_default_cell_type_enc: 193 | self.cell_type_enc = OrdinalEncoder( 194 | handle_unknown='use_encoded_value', 195 | unknown_value=-1, 196 | encoded_missing_value=-1 197 | ) 198 | tmp = self.cell_type_enc.fit_transform(self.cell_type) 199 | self.cell_type = tmp.astype(np.float32).flatten() 200 | # Handle PROTAC activity information 201 | if not use_for_ssl: 202 | if task == 'predict_active_inactive': 
203 | num_nan = len(self.dataframe[self.dataframe['active'].isna()]) 204 | if num_nan > 0: 205 | print('-' * 80) 206 | print(f'Number of NaNs in active column: {num_nan}') 207 | print('-' * 80) 208 | raise ValueError('NaNs found in active column') 209 | # self.dataframe = dataframe.dropna(subset=['active']) 210 | self.dataframe['active'] = self.dataframe['active'].replace({True: 1, False: 0}) 211 | # Get the concentration and degradation values 212 | self.active = self.dataframe['active'].to_numpy().astype(np.float32).reshape(-1, 1) 213 | else: 214 | # TODO: Scaling the concentrations and degradations??? 215 | if scale_concentration: 216 | self.pDC50 = (self.dataframe['pDC50'] * 0.1).astype(np.float32) 217 | else: 218 | self.pDC50 = self.dataframe['pDC50'].astype(np.float32) 219 | self.Dmax = (self.dataframe['Dmax']).astype(np.float32) 220 | 221 | @staticmethod 222 | def load(pt_file): 223 | # TODO: Work in progress 224 | return torch.load(pt_file) 225 | 226 | def __len__(self): 227 | return len(self.smiles) 228 | 229 | def __getitem__(self, idx): 230 | smiles = self.smiles.iloc[idx] 231 | if self.use_for_ssl: 232 | elem = {} 233 | if self.smiles_tokenizer: 234 | smiles_tokenized = self.smiles_tokenized[idx] 235 | elem['input_ids'] = smiles_tokenized['input_ids'] 236 | elem['attention_mask'] = smiles_tokenized['attention_mask'] 237 | elem['labels'] = smiles_tokenized['input_ids'].copy() 238 | else: 239 | elem['smiles'] = smiles 240 | return elem 241 | elem = {} 242 | if self.include_poi_seq: 243 | elem['poi_seq'] = self.poi_seq[idx] 244 | if self.poi_seq_enc is not None and not self.preencode_poi_seq: 245 | poi_seq = self.poi_seq_enc.transform([self.poi_seq[idx]]) 246 | poi_seq = poi_seq.toarray().flatten().astype(np.float32) 247 | elem['poi_seq'] = poi_seq 248 | elif self.tokenize_poi_seq is not None: 249 | poi_seq = self.tokenize_poi_seq(self.poi_seq[idx]) 250 | elem['poi_seq'] = poi_seq 251 | if self.include_poi_gene: 252 | elem['poi_gene_id'] = self.gene[idx] 253 | if self.include_e3_ligase: 254 | elem['e3_ligase'] = self.e3_ligase[idx] 255 | if self.include_cell_type: 256 | elem['cell_type'] = self.cell_type[idx] 257 | if self.task == 'predict_active_inactive': 258 | elem['labels'] = self.active[idx] 259 | elif self.task == 'predict_pDC50_and_Dmax': 260 | Dmax = self.Dmax.iloc[idx] 261 | pDC50 = self.pDC50.iloc[idx] 262 | elem['labels'] = np.array([Dmax, pDC50]) 263 | else: 264 | raise ValueError(f'Task "{self.task}" not recognized. 
Available: "predict_active_inactive" \| "predict_pDC50_and_Dmax"') 265 | if self.include_smiles_as_graphs or self.precompute_smiles_as_graphs: 266 | if self.precompute_smiles_as_graphs: 267 | elem['smiles_graph'] = self.graph_smiles[idx] 268 | else: 269 | elem['smiles_graph'] = from_smiles(smiles) 270 | if self.smiles_tokenizer: 271 | elem['smiles_tokenized'] = self.smiles_tokenized[idx] 272 | if self.include_smiles_as_str: 273 | elem['smiles'] = smiles 274 | if self.use_morgan_fp: 275 | if self.precompute_fingerprints: 276 | fp = self.morgan_fp[idx].copy() 277 | else: 278 | fp = get_fingerprint(smiles, n_bits=self.morgan_bits).astype(np.float32) 279 | elem['morgan_fp'] = fp 280 | if self.use_maccs_fp: 281 | if self.precompute_fingerprints: 282 | fp = self.maccs_fp[idx].copy() 283 | else: 284 | fp = get_fingerprint(smiles, fp_type='maccs').astype(np.float32) 285 | elem['maccs_fp'] = fp 286 | if self.use_path_fp: 287 | if self.precompute_fingerprints: 288 | fp = self.path_fp[idx].copy() 289 | else: 290 | fp = get_fingerprint(smiles, n_bits=self.path_bits, 291 | fp_type='path', 292 | min_path=self.fp_min_path, 293 | max_path=self.fp_max_path).astype(np.float32) 294 | elem['path_fp'] = fp 295 | if self.transform is not None: 296 | elem = self.transform(elem) 297 | return elem 298 | 299 | def get_fingerprint(self, fp_type: Literal['morgan_fp', 'maccs_fp', 'path_fp'] = 'morgan_fp'): 300 | # TODO: Add the proper checks if fingerprints are used 301 | if self.precompute_fingerprints: 302 | if fp_type == 'morgan_fp': 303 | return self.morgan_fp 304 | elif fp_type == 'maccs_fp': 305 | return self.maccs_fp 306 | elif fp_type == 'path_fp': 307 | return self.path_fp 308 | else: 309 | raise ValueError(f'Fingerprint type "{fp_type}" not recognized. Available: "morgan_fp" \| "maccs_fp" \| "path_fp"') 310 | else: 311 | smiles = self.smiles 312 | if fp_type == 'morgan_fp': 313 | return np.array([get_fingerprint(s, n_bits=self.morgan_bits).astype(np.float32) for s in smiles]) 314 | elif fp_type == 'maccs_fp': 315 | return np.array([get_fingerprint(s, fp_type='maccs_fp').astype(np.float32) for s in smiles]) 316 | elif fp_type == 'path_fp': 317 | return np.array([get_fingerprint(s, n_bits=self.path_bits, fp_type='path_fp', min_path=self.fp_min_path, max_path=self.fp_max_path).astype(np.float32) for s in smiles]) 318 | else: 319 | raise ValueError(f'Fingerprint type "{fp_type}" not recognized. Available: "morgan_fp" \| "maccs_fp" \| "path_fp"') 320 | 321 | def get_poi_seq_emb_size(self): 322 | if self.include_poi_seq: 323 | return len(self.poi_seq_enc.get_feature_names_out()) 324 | else: 325 | return 0 326 | 327 | def __str__(self) -> str: 328 | return f'ProtacDataset for {self.task} task with {len(self)} samples.' 
-------------------------------------------------------------------------------- /notebooks/extra_features_encoders.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from IPython.display import display_html\n", 10 | "\n", 11 | "import logging\n", 12 | "import warnings\n", 13 | "import re\n", 14 | "import os\n", 15 | "import numpy as np\n", 16 | "import pandas as pd\n", 17 | "import pickle\n", 18 | "import sklearn.metrics\n", 19 | "import pickle\n", 20 | "import requests\n", 21 | "import collections\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import seaborn as sns\n", 24 | "from rdkit import Chem\n", 25 | "from rdkit.Chem import AllChem, DataStructs, MACCSkeys\n", 26 | "from datetime import date\n", 27 | "from typing import Literal, Optional, Union, List, Dict, Tuple, Any, Callable\n", 28 | "from collections import defaultdict\n", 29 | "\n", 30 | "from IPython.display import display_html\n", 31 | "# from IPython.core.interactiveshell import InteractiveShell\n", 32 | "# InteractiveShell.ast_node_interactivity = 'all'\n", 33 | "\n", 34 | "import collections\n", 35 | "import itertools\n", 36 | "import re\n", 37 | "import gc\n", 38 | "import math\n", 39 | "import numpy as np\n", 40 | "import pandas as pd\n", 41 | "import pickle\n", 42 | "import requests as r\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import seaborn as sns\n", 45 | "import shutil\n", 46 | "import random\n", 47 | "import copy\n", 48 | "import os\n", 49 | "\n", 50 | "import typing\n", 51 | "from typing import Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type\n", 52 | "\n", 53 | "from uuid import uuid4\n", 54 | "from rdkit import Chem\n", 55 | "from rdkit.Chem import AllChem, DataStructs, MACCSkeys, Draw\n", 56 | "from rdkit.Chem.Draw import IPythonConsole\n", 57 | "from datetime import date\n", 58 | "from scipy.sparse import csr_matrix, vstack\n", 59 | "from tqdm import tqdm\n", 60 | "\n", 61 | "import sklearn\n", 62 | "from sklearn.feature_extraction.text import CountVectorizer\n", 63 | "from sklearn.model_selection import train_test_split, GroupShuffleSplit\n", 64 | "from sklearn import preprocessing\n", 65 | "from sklearn import metrics\n", 66 | "from sklearn.metrics import classification_report, f1_score, roc_auc_score\n", 67 | "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n", 68 | "from sklearn.utils import resample, class_weight\n", 69 | "\n", 70 | "import joblib\n", 71 | "\n", 72 | "pd.set_option('display.max_columns', 1000, 'display.width', 2000, 'display.max_colwidth', 100)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 2, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "data_dir = os.path.join(os.getcwd(), '..', 'data')\n", 82 | "src_dir = os.path.join(os.getcwd(), '..', 'src')\n", 83 | "fig_dir = os.path.join(data_dir, 'figures')\n", 84 | "checkpoint_dir = os.path.join(os.getcwd(), '..', 'checkpoints')\n", 85 | "dirs_to_make = [\n", 86 | " data_dir,\n", 87 | " os.path.join(data_dir, 'raw'),\n", 88 | " os.path.join(data_dir, 'processed'),\n", 89 | " os.path.join(data_dir, 'train'),\n", 90 | " os.path.join(data_dir, 'val'),\n", 91 | " os.path.join(data_dir, 'test'),\n", 92 | " src_dir,\n", 93 | " fig_dir,\n", 94 | " checkpoint_dir,\n", 95 | "]\n", 96 | "for d in dirs_to_make:\n", 97 | " if not os.path.exists(d):\n", 98 | " os.makedirs(d)" 99 | ] 100 | }, 101 
| { 102 | "cell_type": "code", 103 | "execution_count": 7, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/html": [ 109 | "
\n", 110 | "\n", 123 | "\n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | "
SmilesSmiles_nostereoDC50pDC50Dmaxpoi_gene_idpoi_seqcell_typee3_ligaseactive
0Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CCCCCCCCCCNC(=O)c2cc3c(cc2CS(C)(=O)=...C(CCCCCNC(=O)c1cc2c(-c3cn(C)c(=O)c4c3c(c[nH]4)CN2c2ncc(F)cc2F)cc1CS(=O)(C)=O)CCCCC(=O)NC(C(N1CC(...3.400000e-098.4685210.980000BRD4MSAESGPGTRLRNLPVMGDGLETSQMSTTQAQAQPQPANAASTNPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPD...PC3-S1VHLTrue
1CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](NC(=O)COCCOCCOCCN2CCN(c3ccc(Nc4ncc5c(C)c(C(C)=O)c(=O)n(C6C...C(=O)(C(C(C)(C)C)NC(=O)C(NC)C)N1CC(NC(COCCOCCOCCN2CCN(c3ccc(Nc4ncc5c(C)c(C(=O)C)c(=O)n(C6CCCC6)c...1.600000e-087.7958800.871440CDK6MEKDGLCRADQQYECVAEIGEGAYGKVFKARDLKNGGRFVALKRVRVQTGEEGMPLSTIREVAVLRHLETFEHPNVVRLFDVCTVSRTDRETKLTL...JURKATIAPTrue
2O=C(CCCCCCC(=O)N/N=C/c1ccc(OCCCC#Cc2cccc3c2CN(C2CCC(=O)NC2=O)C3=O)cc1)NOC(Oc1ccc(C=NNC(CCCCCCC(NO)=O)=O)cc1)CCC#Cc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O1.940000e-098.7121980.896614HDAC6MTSTGQDSTTTRQRRSRQNPQSPPQDSSVTSKRNIKKGAVPRSIPNLAEVKKKGKMKKLGQAMEEDLIVGLQGMDLNLEAEALAGTGLVLDEQLNE...MM1SCRBNTrue
3Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CCc2cc(N)cc(CCOCCOCC#Cc3ccc(C4=N[C@@...CC(C)(C(C(N1C(C(=O)NCc2ccc(-c3c(C)ncs3)cc2)CC(O)C1)=O)NC(CCc1cc(N)cc(CCOCCOCC#Cc2ccc(C3=NC(Cc4nc...8.700000e-087.0604811.000000BRD4MSAESGPGTRLRNLPVMGDGLETSQMSTTQAQAQPQPANAASTNPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPD...PC3-S1VHLTrue
4Nc1ncnc2c1c(-c1ccc(Oc3ccccc3)cc1)nn2[C@@H]1CCCN(C(=O)/C=C/COCCOCCOCCOCCOCCOc2cccc3c2C(=O)N(C2CCC...C(Oc1cccc2c1C(=O)N(C1CCC(=O)NC1=O)C2=O)COCCOCCOCCOCCOCC=CC(N1CC(n2c3c(c(N)ncn3)c(-c3ccc(Oc4ccccc...8.600000e-098.0655020.910000BTKMAAVILESIFLKRSQQKKKTSPLNFKKRLFLLTVHKLSYYEYDFERGRRGSKKGSIDVEKITCVETVVPEKNPPPERQIPRRGEESSEMEQISIIE...MINOCRBNTrue
\n", 207 | "
" 208 | ], 209 | "text/plain": [ 210 | " Smiles Smiles_nostereo DC50 pDC50 Dmax poi_gene_id poi_seq cell_type e3_ligase active\n", 211 | "0 Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CCCCCCCCCCNC(=O)c2cc3c(cc2CS(C)(=O)=... C(CCCCCNC(=O)c1cc2c(-c3cn(C)c(=O)c4c3c(c[nH]4)CN2c2ncc(F)cc2F)cc1CS(=O)(C)=O)CCCCC(=O)NC(C(N1CC(... 3.400000e-09 8.468521 0.980000 BRD4 MSAESGPGTRLRNLPVMGDGLETSQMSTTQAQAQPQPANAASTNPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPD... PC3-S1 VHL True\n", 212 | "1 CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](NC(=O)COCCOCCOCCN2CCN(c3ccc(Nc4ncc5c(C)c(C(C)=O)c(=O)n(C6C... C(=O)(C(C(C)(C)C)NC(=O)C(NC)C)N1CC(NC(COCCOCCOCCN2CCN(c3ccc(Nc4ncc5c(C)c(C(=O)C)c(=O)n(C6CCCC6)c... 1.600000e-08 7.795880 0.871440 CDK6 MEKDGLCRADQQYECVAEIGEGAYGKVFKARDLKNGGRFVALKRVRVQTGEEGMPLSTIREVAVLRHLETFEHPNVVRLFDVCTVSRTDRETKLTL... JURKAT IAP True\n", 213 | "2 O=C(CCCCCCC(=O)N/N=C/c1ccc(OCCCC#Cc2cccc3c2CN(C2CCC(=O)NC2=O)C3=O)cc1)NO C(Oc1ccc(C=NNC(CCCCCCC(NO)=O)=O)cc1)CCC#Cc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O 1.940000e-09 8.712198 0.896614 HDAC6 MTSTGQDSTTTRQRRSRQNPQSPPQDSSVTSKRNIKKGAVPRSIPNLAEVKKKGKMKKLGQAMEEDLIVGLQGMDLNLEAEALAGTGLVLDEQLNE... MM1S CRBN True\n", 214 | "3 Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CCc2cc(N)cc(CCOCCOCC#Cc3ccc(C4=N[C@@... CC(C)(C(C(N1C(C(=O)NCc2ccc(-c3c(C)ncs3)cc2)CC(O)C1)=O)NC(CCc1cc(N)cc(CCOCCOCC#Cc2ccc(C3=NC(Cc4nc... 8.700000e-08 7.060481 1.000000 BRD4 MSAESGPGTRLRNLPVMGDGLETSQMSTTQAQAQPQPANAASTNPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPD... PC3-S1 VHL True\n", 215 | "4 Nc1ncnc2c1c(-c1ccc(Oc3ccccc3)cc1)nn2[C@@H]1CCCN(C(=O)/C=C/COCCOCCOCCOCCOCCOc2cccc3c2C(=O)N(C2CCC... C(Oc1cccc2c1C(=O)N(C1CCC(=O)NC1=O)C2=O)COCCOCCOCCOCCOCC=CC(N1CC(n2c3c(c(N)ncn3)c(-c3ccc(Oc4ccccc... 8.600000e-09 8.065502 0.910000 BTK MAAVILESIFLKRSQQKKKTSPLNFKKRLFLLTVHKLSYYEYDFERGRRGSKKGSIDVEKITCVETVVPEKNPPPERQIPRRGEESSEMEQISIIE... 
MINO CRBN True" 216 | ] 217 | }, 218 | "execution_count": 7, 219 | "metadata": {}, 220 | "output_type": "execute_result" 221 | } 222 | ], 223 | "source": [ 224 | "ssl_df_path = os.path.join(data_dir, 'processed', 'protac_db_ssl.csv')\n", 225 | "train_df_path = os.path.join(data_dir, 'train', 'train_bin_upsampled.csv')\n", 226 | "val_df_path = os.path.join(data_dir, 'val', 'val_bin.csv')\n", 227 | "test_df_path = os.path.join(data_dir, 'test', 'test_bin.csv')\n", 228 | "cols_to_keep = [\n", 229 | " 'Smiles',\n", 230 | " 'Smiles_nostereo',\n", 231 | " 'DC50',\n", 232 | " 'pDC50',\n", 233 | " 'Dmax',\n", 234 | " 'poi_gene_id',\n", 235 | " 'poi_seq',\n", 236 | " 'cell_type',\n", 237 | " 'e3_ligase',\n", 238 | " 'active',\n", 239 | "]\n", 240 | "# Assign train/val datasets for use in dataloaders\n", 241 | "train_df = pd.read_csv(train_df_path).reset_index(drop=True)\n", 242 | "val_df = pd.read_csv(val_df_path).reset_index(drop=True)\n", 243 | "test_df = pd.read_csv(test_df_path).reset_index(drop=True)\n", 244 | "ssl_df = pd.read_csv(ssl_df_path).reset_index(drop=True)\n", 245 | "train_df = train_df[cols_to_keep]\n", 246 | "val_df = val_df[cols_to_keep]\n", 247 | "test_df = test_df[cols_to_keep]\n", 248 | "ssl_df = ssl_df[cols_to_keep]\n", 249 | "\n", 250 | "protac_df = pd.concat([train_df, val_df, ssl_df]).reset_index(drop=True)\n", 251 | "protac_df = protac_df[cols_to_keep]\n", 252 | "protac_df.head()" 253 | ] 254 | }, 255 | { 256 | "attachments": {}, 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "## POI Sequence Encoding" 261 | ] 262 | }, 263 | { 264 | "attachments": {}, 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "#### POI Sequence to $N_{grams}$\n", 269 | "\n", 270 | "Count-vectorize the POI amino acid sequence.\n", 271 | "\n", 272 | "(Not ideal and very simple, but it's a start)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 11, 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "Training POI vectorizer...\n", 285 | "POI embedding size: 403\n", 286 | "Done!\n" 287 | ] 288 | } 289 | ], 290 | "source": [ 291 | "ngram_min_range = 2 # Orginal: 3\n", 292 | "ngram_max_range = 2 # Orginal: 3\n", 293 | "# poi_vectorizer = CountVectorizer(analyzer='char', ngram_range=(ngram_min_range, ngram_max_range))\n", 294 | "# X = poi_vectorizer.fit_transform(protac_df['poi_seq'].tolist())\n", 295 | "# rec_n_grams_df = pd.DataFrame(X.toarray(), columns=list(s.replace(' ', '') for s in poi_vectorizer.get_feature_names_out()))\n", 296 | "# print(f'POI embedding size: {rec_n_grams_df.shape[-1]}')\n", 297 | "\n", 298 | "protac_df['poi_seq'] = protac_df['poi_seq'].fillna('')\n", 299 | "\n", 300 | "# Load the pre-trained countvectorizer if it exists, otherwise train it\n", 301 | "poi_encoder_filepath = os.path.join(checkpoint_dir, 'poi_encoder.joblib')\n", 302 | "if os.path.exists(poi_encoder_filepath):\n", 303 | " print('Loading pre-trained POI vectorizer...')\n", 304 | " poi_encoder = joblib.load(poi_encoder_filepath)\n", 305 | "else:\n", 306 | " print('Training POI vectorizer...')\n", 307 | " poi_encoder = CountVectorizer(analyzer='char', ngram_range=(ngram_min_range, ngram_max_range))\n", 308 | " X = poi_encoder.fit_transform(protac_df['poi_seq'].tolist())\n", 309 | " rec_n_grams_df = pd.DataFrame(X.toarray(), columns=list(s.replace(' ', '') for s in poi_encoder.get_feature_names_out()))\n", 310 | " print(f'POI embedding size: 
{rec_n_grams_df.shape[-1]}')\n", 311 | " joblib.dump(poi_encoder, poi_encoder_filepath)\n", 312 | "print('Done!')" 313 | ] 314 | }, 315 | { 316 | "attachments": {}, 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "#### POI Gene Ordinal Encoding" 321 | ] 322 | }, 323 | { 324 | "attachments": {}, 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "Add the \"Unknown\" class to the POI genes.\n", 329 | "\n", 330 | "Since genes ultimately encode proteins, we can use the gene ID as a categorical feature to include information about the POIs.\n", 331 | "\n", 332 | "(The information loss is considerable, since the gene ID is not that informative compared to the entire amino acid sequence)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "poi_gene_enc = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',\n", 342 | " unknown_value=-1,\n", 343 | " encoded_missing_value=-1)\n", 344 | "poi_gene_id = protac_df['poi_gene_id'].to_numpy().reshape(-1, 1)\n", 345 | "poi_gene_enc.fit(poi_gene_id)" 346 | ] 347 | }, 348 | { 349 | "attachments": {}, 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "## E3 Ligase and Cell Type Ordinal Encoding\n", 354 | "\n", 355 | "Notice that the \"other E3\" entries were dropped in the previous steps, leaving only 5 possible values." 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 13, 361 | "metadata": {}, 362 | "outputs": [ 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "Training E3 encoder...\n", 368 | "Done!\n" 369 | ] 370 | } 371 | ], 372 | "source": [ 373 | "# Load the pre-trained ordinal encoder if it exists, otherwise train it\n", 374 | "e3_encoder_filepath = os.path.join(checkpoint_dir, 'e3_ligase_encoder.joblib')\n", 375 | "if os.path.exists(e3_encoder_filepath):\n", 376 | " print('Loading pre-trained E3 encoder...')\n", 377 | " e3_encoder = joblib.load(e3_encoder_filepath)\n", 378 | "else:\n", 379 | " print('Training E3 encoder...')\n", 380 | " e3_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',\n", 381 | " unknown_value=-1,\n", 382 | " encoded_missing_value=-1)\n", 383 | " e3_ligase = protac_df['e3_ligase'].to_numpy().reshape(-1, 1)\n", 384 | " e3_encoder.fit(e3_ligase)\n", 385 | " joblib.dump(e3_encoder, e3_encoder_filepath)\n", 386 | "print('Done!')" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 12, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "Training cell type encoder...\n", 399 | "Done!\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "# Load the pre-trained ordinal encoder if it exists, otherwise train it\n", 405 | "cell_encoder_filepath = os.path.join(checkpoint_dir, 'cell_type_encoder.joblib')\n", 406 | "if os.path.exists(cell_encoder_filepath):\n", 407 | " print('Loading pre-trained cell type encoder...')\n", 408 | " cell_encoder = joblib.load(cell_encoder_filepath)\n", 409 | "else:\n", 410 | " print('Training cell type encoder...')\n", 411 | " cell_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',\n", 412 | " unknown_value=-1,\n", 413 | " encoded_missing_value=-1)\n", 414 | " cell_type = protac_df['cell_type'].to_numpy().reshape(-1, 1)\n", 415 | " cell_encoder.fit(cell_type)\n", 416 | " joblib.dump(cell_encoder, 
cell_encoder_filepath)\n", 417 | "print('Done!')" 418 | ] 419 | } 420 | ], 421 | "metadata": { 422 | "kernelspec": { 423 | "display_name": "Python 3", 424 | "language": "python", 425 | "name": "python3" 426 | }, 427 | "language_info": { 428 | "codemirror_mode": { 429 | "name": "ipython", 430 | "version": 3 431 | }, 432 | "file_extension": ".py", 433 | "mimetype": "text/x-python", 434 | "name": "python", 435 | "nbconvert_exporter": "python", 436 | "pygments_lexer": "ipython3", 437 | "version": "3.11.0" 438 | }, 439 | "orig_nbformat": 4 440 | }, 441 | "nbformat": 4, 442 | "nbformat_minor": 2 443 | } 444 | -------------------------------------------------------------------------------- /src/models/wrapper_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Optional, Mapping, Literal, Callable, List, ClassVar, Any, Tuple, Type, Dict 3 | 4 | from data.protac_dataset import ProtacDataset 5 | from data.protac_dataloader import custom_collate 6 | 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import LightningModule 9 | 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | from torchvision.ops import MLP 16 | 17 | import torch_geometric 18 | import torch_geometric.nn as geom_nn 19 | import torch_geometric.data as geom_data 20 | from torch_geometric.utils.smiles import from_smiles 21 | 22 | from torchmetrics import (Accuracy, 23 | AUROC, 24 | ROC, 25 | Precision, 26 | Recall, 27 | F1Score, 28 | MeanAbsoluteError, 29 | MeanSquaredError) 30 | from torchmetrics.functional import (mean_absolute_error, 31 | mean_squared_error, 32 | mean_squared_log_error, 33 | pearson_corrcoef, 34 | r2_score) 35 | from torchmetrics.functional.classification import (binary_accuracy, 36 | binary_auroc, 37 | binary_precision, 38 | binary_recall, 39 | binary_f1_score) 40 | from torchvision.ops import MLP 41 | 42 | from torchmetrics import MetricCollection 43 | 44 | 45 | class ProtacModel(pl.LightningModule): 46 | 47 | def __init__(self, 48 | smiles_encoder: nn.Module | nn.Sequential, 49 | poi_seq_encoder: Optional[nn.Module | nn.Sequential] = None, 50 | e3_ligase_encoder: Optional[nn.Module | nn.Sequential] = None, 51 | cell_type_encoder: Optional[nn.Module | nn.Sequential] = None, 52 | head: Optional[nn.Module | nn.Sequential] = None, 53 | join_branches: Literal['cat', 'sum'] = 'cat',): 54 | """Wrapper class to make prediction on PROTAC data. 55 | 56 | Args: 57 | smiles_encoder (nn.Module): SMILES encoder class or instance. 58 | poi_seq_encoder (Optional[nn.Module], optional): POI sequence encoder class or instance. Defaults to None. 59 | e3_ligase_encoder (Optional[nn.Module], optional): E3 ligase encoder class or instance. Defaults to None. 60 | cell_type_encoder (Optional[nn.Module], optional): Cell type encoder class or instance. Defaults to None. 61 | head (Optional[nn.Module], optional): Head class or instance. Defaults to None. 62 | join_branches (Literal['cat', 'sum'], optional): How to join the branches embeddings. Defaults to 'cat'. 
63 | """ 64 | super().__init__() 65 | # Set our init args as class attributes 66 | # self.__dict__.update(locals()) # Add arguments as attributes 67 | # # Save the arguments passed to init 68 | # ignore_args = [] 69 | # self.save_hyperparameters() 70 | self.join_branches = join_branches 71 | # Define SMILES encoder and head input size 72 | self.smiles_encoder = smiles_encoder 73 | head_input_size = self.smiles_encoder.get_embedding_size() 74 | # Define or load POI sequence, cell type and E3 ligase encoders 75 | # POI sequence encoder 76 | self.poi_seq_encoder = None 77 | if poi_seq_encoder is not None: 78 | self.poi_seq_encoder = poi_seq_encoder 79 | head_input_size += self.poi_seq_encoder.get_embedding_size() 80 | # E3 ligase encoder 81 | self.e3_ligase_encoder = None 82 | if e3_ligase_encoder is not None: 83 | self.e3_ligase_encoder = e3_ligase_encoder 84 | head_input_size += self.e3_ligase_encoder.get_embedding_size() 85 | # Cell type encoder 86 | self.cell_type_encoder = None 87 | if cell_type_encoder is not None: 88 | self.cell_type_encoder = cell_type_encoder 89 | head_input_size += self.cell_type_encoder.get_embedding_size() 90 | 91 | # self.extra_feat_enc = MLP(in_channels=head_input_size - self.smiles_encoder.get_embedding_size(), 92 | # hidden_channels=[1], 93 | # norm_layer=nn.BatchNorm1d, 94 | # inplace=False, 95 | # dropout=0.3) 96 | 97 | 98 | # Define head module 99 | if head is None: 100 | head_args = { 101 | 'hidden_channels': [1], 102 | 'norm_layer': nn.BatchNorm1d, 103 | 'inplace': False, 104 | 'dropout': 0.3, 105 | } 106 | self.head = MLP(in_channels=head_input_size, **head_args) 107 | else: 108 | self.head = head 109 | # Define loss function 110 | self.bin_loss = nn.BCEWithLogitsLoss() 111 | # Metrics, a separate metrics collection is defined for each stage 112 | # NOTE: According to the PyTorch Lightning docs, "similar" metrics, 113 | # i.e., requiring the same computation, should be optimized w/in a 114 | # metrics collection. 
115 | stages = ['train_metrics', 'val_metrics', 'test_metrics'] 116 | self.metrics = nn.ModuleDict({s: MetricCollection({ 117 | 'acc': Accuracy(task='binary'), 118 | 'roc_auc': AUROC(task='binary'), 119 | 'precision': Precision(task='binary'), 120 | 'recall': Recall(task='binary'), 121 | 'f1_score': F1Score(task='binary'), 122 | 'opt_score': Accuracy(task='binary') + F1Score(task='binary'), 123 | 'hp_metric': Accuracy(task='binary'), 124 | }, prefix=s.replace('metrics', '')) for s in stages}) 125 | 126 | # def configure_optimizers(self): 127 | # optimizer = torch.optim.Adam(self.parameters()) 128 | # return optimizer 129 | 130 | def forward(self, x_in): 131 | mol_emb = self.smiles_encoder(x_in) 132 | if self.poi_seq_encoder is not None: 133 | poi_seq_emb = self.poi_seq_encoder(x_in) 134 | if self.join_branches == 'cat': 135 | mol_emb = torch.cat((mol_emb, poi_seq_emb), dim=-1) 136 | elif self.join_branches == 'sum': 137 | mol_emb = mol_emb + poi_seq_emb 138 | if self.e3_ligase_encoder is not None: 139 | e3_ligase_emb = self.e3_ligase_encoder(x_in) 140 | if self.join_branches == 'cat': 141 | mol_emb = torch.cat((mol_emb, e3_ligase_emb), dim=-1) 142 | elif self.join_branches == 'sum': 143 | mol_emb = mol_emb + e3_ligase_emb 144 | if self.cell_type_encoder is not None: 145 | cell_type_emb = self.cell_type_encoder(x_in) 146 | if self.join_branches == 'cat': 147 | mol_emb = torch.cat((mol_emb, cell_type_emb), dim=-1) 148 | elif self.join_branches == 'sum': 149 | mol_emb = mol_emb + cell_type_emb 150 | return self.head(mol_emb) 151 | 152 | def step(self, batch, stage='train'): 153 | y = batch['labels'] 154 | preds = self.forward(batch) 155 | loss = self.bin_loss(preds, y) 156 | self.metrics[f'{stage}_metrics'].update(preds, y) 157 | self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True) 158 | self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True) 159 | return loss 160 | 161 | def training_step(self, batch, batch_idx): 162 | return self.step(batch, stage='train') 163 | 164 | def validation_step(self, batch, batch_idx): 165 | return self.step(batch, stage='val') 166 | 167 | def test_step(self, batch, batch_idx): 168 | return self.step(batch, stage='test') 169 | 170 | 171 | class WrapperModel(pl.LightningModule): 172 | 173 | def __init__(self, 174 | smiles_encoder: Type[nn.Module] | nn.Module, 175 | smiles_encoder_args: Optional[Dict] = None, 176 | task: Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'] = 'predict_active_inactive', 177 | freeze_smiles_encoder: bool = False, 178 | cell_type_encoder: Optional[Type[nn.Module] | nn.Module] = None, 179 | cell_type_encoder_args: Optional[Dict] = None, 180 | e3_ligase_encoder: Optional[Type[nn.Module] | nn.Module] = None, 181 | e3_ligase_encoder_args: Optional[Dict] = None, 182 | poi_seq_encoder: Optional[Type[nn.Module] | nn.Module] = None, 183 | poi_seq_encoder_args: Optional[Dict] = None, 184 | head_class: Optional[Type[nn.Module] | nn.Module] = None, 185 | head: Optional[nn.Module] = None, 186 | head_args: Optional[Type[nn.Module] | nn.Module] = None, 187 | train_dataset: ProtacDataset = None, 188 | val_dataset: ProtacDataset = None, 189 | test_dataset: ProtacDataset = None, 190 | batch_size: int = 32, 191 | learning_rate: float = 1e-3, 192 | regr_loss: Type[nn.Module] | nn.Module | Callable = nn.HuberLoss, 193 | regr_loss_args: Optional[Dict] = None): 194 | """Wrapper class to make prediction on PROTAC data. 195 | 196 | Args: 197 | smiles_encoder (Type[nn.Module] | nn.Module): SMILES encoder class or instance. 
            smiles_encoder_args (Optional[Dict], optional): Arguments to pass to the SMILES encoder class. Defaults to None.
            task (Literal['predict_active_inactive', 'predict_pDC50_and_Dmax'], optional): Task to perform. Defaults to 'predict_active_inactive'.
            freeze_smiles_encoder (bool, optional): Whether to freeze the SMILES encoder. Defaults to False.
            cell_type_encoder (Optional[Type[nn.Module] | nn.Module], optional): Cell type encoder class or instance. Defaults to None.
            cell_type_encoder_args (Optional[Dict], optional): Arguments to pass to the cell type encoder class. Defaults to None.
            e3_ligase_encoder (Optional[Type[nn.Module] | nn.Module], optional): E3 ligase encoder class or instance. Defaults to None.
            e3_ligase_encoder_args (Optional[Dict], optional): Arguments to pass to the E3 ligase encoder class. Defaults to None.
            poi_seq_encoder (Optional[Type[nn.Module] | nn.Module], optional): POI sequence encoder class or instance. Defaults to None.
            poi_seq_encoder_args (Optional[Dict], optional): Arguments to pass to the POI sequence encoder class. Defaults to None.
            head_class (Optional[Type[nn.Module] | nn.Module], optional): Head class (or pre-built instance). Defaults to None.
            head (Optional[nn.Module], optional): Pre-built head module; takes precedence over `head_class`. Defaults to None.
            head_args (Optional[Dict], optional): Arguments to pass to the head class. Defaults to None.
            train_dataset (Optional[ProtacDataset], optional): Training dataset. Defaults to None.
            val_dataset (Optional[ProtacDataset], optional): Validation dataset. Defaults to None.
            test_dataset (Optional[ProtacDataset], optional): Test dataset. Defaults to None.
            batch_size (int, optional): Batch size. Defaults to 32.
            learning_rate (float, optional): Learning rate. Defaults to 1e-3.
            regr_loss (Type[nn.Module] | nn.Module | Callable, optional): Regression loss function. Defaults to nn.HuberLoss.
            regr_loss_args (Optional[Dict], optional): Arguments to pass to the regression loss function. Defaults to None.
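        Example (illustrative sketch, not a tested snippet: the encoder class,
        its arguments, and the datasets are placeholders for project-specific
        objects):

            model = WrapperModel(
                smiles_encoder=MySmilesEncoder,         # hypothetical encoder class
                smiles_encoder_args={'fp_bits': 1024},  # hypothetical arguments
                task='predict_active_inactive',
                train_dataset=train_ds,
                val_dataset=val_ds,
            )
            trainer = pl.Trainer(max_epochs=10)
            trainer.fit(model)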
216 | """ 217 | super().__init__() 218 | # Set our init args as class attributes 219 | self.__dict__.update(locals()) # Add arguments as attributes 220 | # Save the arguments passed to init 221 | ignore_args_as_hyperparams = [ 222 | 'train_dataset', 223 | 'test_dataset', 224 | 'val_dataset', 225 | ] 226 | self.save_hyperparameters(ignore=ignore_args_as_hyperparams) 227 | # Define or load SMILES encoder sub-model 228 | smiles_encoder_args = smiles_encoder_args or {} 229 | self.smiles_encoder = smiles_encoder(**smiles_encoder_args) 230 | if freeze_smiles_encoder: 231 | self.smiles_encoder.freeze() 232 | head_input_size = self.smiles_encoder.get_embedding_size() 233 | # Define or load POI sequence, cell type and E3 ligase encoders 234 | # POI sequence encoder 235 | if poi_seq_encoder_args is not None: 236 | if poi_seq_encoder is None: 237 | raise ValueError('`poi_seq_encoder` must be provided if `poi_seq_encoder_args` is not None.') 238 | self.poi_seq_encoder = poi_seq_encoder(**poi_seq_encoder_args) 239 | elif poi_seq_encoder is not None: 240 | self.poi_seq_encoder = poi_seq_encoder 241 | if self.poi_seq_encoder is not None: 242 | head_input_size += self.poi_seq_encoder.get_embedding_size() 243 | # E3 ligase encoder 244 | if e3_ligase_encoder_args is not None: 245 | if e3_ligase_encoder is None: 246 | raise ValueError('`e3_ligase_encoder` must be provided if `e3_ligase_encoder_args` is not None.') 247 | self.e3_ligase_encoder = e3_ligase_encoder(**e3_ligase_encoder_args) 248 | elif e3_ligase_encoder is not None: 249 | self.e3_ligase_encoder = e3_ligase_encoder 250 | if self.e3_ligase_encoder is not None: 251 | head_input_size += self.e3_ligase_encoder.get_embedding_size() 252 | # Cell type encoder 253 | if cell_type_encoder_args is not None: 254 | if cell_type_encoder is None: 255 | raise ValueError('`cell_type_encoder` must be provided if `cell_type_encoder_args` is not None.') 256 | self.cell_type_encoder = cell_type_encoder(**cell_type_encoder_args) 257 | elif cell_type_encoder is not None: 258 | self.cell_type_encoder = cell_type_encoder 259 | if self.cell_type_encoder is not None: 260 | head_input_size += self.cell_type_encoder.get_embedding_size() 261 | # Define head module 262 | if head_args is None: 263 | num_outputs = 2 if task == 'predict_pDC50_and_Dmax' else 1 264 | head_args = { 265 | 'hidden_channels': [num_outputs], 266 | 'norm_layer': nn.BatchNorm1d, 267 | 'inplace': False, 268 | 'dropout': 0.3, 269 | } 270 | if head_class is None: 271 | self.head = MLP(in_channels=head_input_size, **head_args) 272 | elif head_args is not None: 273 | self.head = head_class(**head_args) 274 | else: 275 | self.head = head_class 276 | # Define PROTAC model 277 | self.model = ProtacModel(smiles_encoder=self.smiles_encoder, 278 | poi_seq_encoder=self.poi_seq_encoder, 279 | e3_ligase_encoder=self.e3_ligase_encoder, 280 | cell_type_encoder=self.cell_type_encoder, 281 | head=self.head) 282 | # Define losses 283 | if task == 'predict_pDC50_and_Dmax': 284 | if regr_loss_args is None: 285 | regr_loss_args = {'reduction': 'mean'} 286 | self.regr_loss = regr_loss(**regr_loss_args) 287 | else: 288 | self.bin_loss = nn.BCEWithLogitsLoss() 289 | # Metrics, a separate metrics collection is defined for each stage 290 | # NOTE: According to the PyTorch Lightning docs, "similar" metrics, 291 | # i.e., requiring the same computation, should be optimized w/in a 292 | # metrics collection. 
        stages = ['train_metrics', 'val_metrics', 'test_metrics']
        self.metrics = nn.ModuleDict({s: MetricCollection({
            'acc': Accuracy(task='binary'),
            'roc_auc': AUROC(task='binary'),
            'precision': Precision(task='binary'),
            'recall': Recall(task='binary'),
            'f1_score': F1Score(task='binary'),
            'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
            'hp_metric': Accuracy(task='binary'),
        }, prefix=s.replace('metrics', '')) for s in stages})
        # Misc settings
        self.missing_dataset_error = \
            '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:

            model = {1}.load_from_checkpoint('checkpoint.ckpt')
            model.{0} = my_{0}
            '''

    def forward(self, x_in):
        return self.model(x_in)

    def step(self, batch, stage='train'):
        y = batch['labels']
        preds = self.forward(batch)
        if self.task == 'predict_active_inactive':
            loss = self.bin_loss(preds, y)
            self.metrics[f'{stage}_metrics'].update(preds, y)
            self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
            self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)
        else:
            loss = self.regr_loss(preds, y)
            self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
            if stage == 'val':
                self.log('hp_metric', loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, stage='train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, stage='val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, stage='test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def load_smiles_encoder(self, checkpoint_path):
        ckpt = torch.load(checkpoint_path, map_location=self.device)
        # Lightning checkpoints nest the weights under 'state_dict'; fall back
        # to the raw object when given a plain state-dict file.
        state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
        self.smiles_encoder.load_state_dict(state_dict, strict=False)

    def train_dataloader(self):
        if self.train_dataset is None:
            err_args = 'train_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*err_args))
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, collate_fn=custom_collate,
                          drop_last=True)

    def val_dataloader(self):
        if self.val_dataset is None:
            err_args = 'val_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*err_args))
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, collate_fn=custom_collate)

    def test_dataloader(self):
        if self.test_dataset is None:
            err_args = 'test_dataset', self.__class__.__name__
            raise ValueError(self.missing_dataset_error.format(*err_args))
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          shuffle=False, collate_fn=custom_collate)
--------------------------------------------------------------------------------
/notebooks/complex_encoding.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PROTAC Complex Encoding\n",
    "\n",
    "Collection of ideas and unfinished work. To be ignored for now..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Features Encoding\n",
    "\n",
    "The EVOformer from AlphaFold might be used for predicting molecule structures. The problem with general molecules is that they lack the regular structure that proteins have: proteins are polymers in which each individual amino acid has its own shape, and AlphaFold leverages this information in its predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters = {\n",
    "    'use_morgan_fp': ('categorical', [True, False]),\n",
    "    'use_maccs_fp': ('categorical', [True, False]),\n",
    "    'use_path_fp': ('categorical', [True, False]),\n",
    "    'pathfp_min_path': (int, 1, 32),\n",
    "    'pathfp_max_path': (int, 1, 64),\n",
    "    'morgan_bitwidth': (int, 1024, 2048),\n",
    "    'pathfp_bitwidth': (int, 1024, 2048),\n",
    "    'morgan_encoder_hidden_sz': (int, 256, 2048),\n",
    "    'maccs_encoder_hidden_sz': (int, 256, 2048),\n",
    "    'pathfp_encoder_hidden_sz': (int, 256, 2048),\n",
    "    'learning_rate': (float, 1e-5, 1e-3),\n",
    "    'gnn_layer_type': ('categorical', ['GraphConv', 'GCN', 'GAT']),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoding SELFIES via ChemGPT (DEPRECATED)\n",
    "\n",
    "> Christian: SELFIES are not useful, we have already investigated and studied them.\n",
    "\n",
    "Let's start by installing and importing the required dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install selfies sentencepiece transformers datasets wandb -qqq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification\n",
    "from datasets import Dataset, DatasetDict, load_dataset, load_from_disk\n",
    "import selfies as sf\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each entry in the dataset will consist of a pair of the SELFIES encoding and the degradation percentage. In order to get the SELFIES encoding, each SMILES entry (without stereochemistry information) is converted."
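,
    "\n",
    "\n",
    "As a minimal illustrative sketch of that conversion (the SMILES string below is a made-up example, not from the dataset):\n",
    "\n",
    "```python\n",
    "import selfies as sf\n",
    "\n",
    "smiles = 'CCO'  # hypothetical molecule (ethanol)\n",
    "selfies_str = sf.encoder(smiles)  # SMILES -> SELFIES\n",
    "smiles_back = sf.decoder(selfies_str)  # SELFIES -> SMILES (round trip)\n",
    "print(selfies_str, smiles_back)\n",
    "```"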
81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "entries = [(sf.encoder(x['Smiles_nostereo']), x['degradation']) for x in train_upsampled.to_dict(orient='records')]\n", 90 | "df = pd.DataFrame(entries, columns=['text', 'labels'])\n", 91 | "train_dataset = Dataset.from_pandas(df, preserve_index=False)\n", 92 | "\n", 93 | "entries = [(sf.encoder(x['Smiles_nostereo']), x['degradation']) for x in test.to_dict(orient='records')]\n", 94 | "df = pd.DataFrame(entries, columns=['text', 'labels'])\n", 95 | "test_dataset = Dataset.from_pandas(df, preserve_index=False)\n", 96 | "\n", 97 | "dataset = DatasetDict({\n", 98 | " 'train': train_dataset,\n", 99 | " 'test': test_dataset,\n", 100 | "})\n", 101 | "# dataset" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "We can now import the tokenizer and tokenize the SELFIES strings in the dataset." 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "tokenizer = AutoTokenizer.from_pretrained('ncfrey/ChemGPT-4.7M')\n", 118 | "tokenizer.add_special_tokens({'pad_token': '[PAD]'})" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def tokenize_function(entries):\n", 128 | " return tokenizer(entries['text'], padding='max_length', truncation=True, max_length=256, return_tensors='pt')\n", 129 | "\n", 130 | "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n", 131 | "tokenized_datasets = tokenized_datasets.remove_columns(['text'])" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "tokenized_datasets" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "Next, we can download the pretrained ChemGPT model." 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "%%capture\n", 157 | "# model = AutoModelForCausalLM.from_pretrained('ncfrey/ChemGPT-4.7M', num_labels=1) # Original\n", 158 | "model = AutoModelForSequenceClassification.from_pretrained('ncfrey/ChemGPT-4.7M', num_labels=1)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "Freeze all un-initialized layers in order to avoid \"catastrophic forgetting\"." 
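,
    "\n",
    "\n",
    "The cell below pins everything except a hand-collected list of newly-initialized parameters. A coarser, list-free sketch of the same idea (assuming only the new `score` head should train) would be to freeze the whole base model:\n",
    "\n",
    "```python\n",
    "for p in model.base_model.parameters():  # all pretrained weights\n",
    "    p.requires_grad = False  # keep them frozen; only the new head trains\n",
    "```"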
166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "uninit_layers =[\n", 175 | " 'score.weight',\n", 176 | " 'transformer.h.1.attn.attention.bias',\n", 177 | " 'transformer.h.3.attn.attention.bias',\n", 178 | " 'transformer.h.5.attn.attention.bias',\n", 179 | " 'transformer.h.7.attn.attention.bias',\n", 180 | " 'transformer.h.9.attn.attention.bias',\n", 181 | " 'transformer.h.11.attn.attention.bias',\n", 182 | " 'transformer.h.13.attn.attention.bias',\n", 183 | " 'transformer.h.15.attn.attention.bias',\n", 184 | " 'transformer.h.17.attn.attention.bias',\n", 185 | " 'transformer.h.19.attn.attention.bias',\n", 186 | " 'transformer.h.21.attn.attention.bias',\n", 187 | " 'transformer.h.23.attn.attention.bias',\n", 188 | "]\n", 189 | "\n", 190 | "for name, param in model.named_parameters():\n", 191 | " if name in uninit_layers:\n", 192 | " param.requires_grad = True\n", 193 | " print(name, param.requires_grad)\n", 194 | " else:\n", 195 | " param.requires_grad = False" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "from datasets import load_metric\n", 205 | "from sklearn.metrics import mean_squared_error\n", 206 | "\n", 207 | "def compute_metrics(eval_pred):\n", 208 | " predictions, labels = eval_pred\n", 209 | " rmse = mean_squared_error(labels, predictions, squared=False)\n", 210 | " return {'rmse': rmse}" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "%%wandb\n", 220 | "from transformers import TrainingArguments, Trainer\n", 221 | "\n", 222 | "model.config.pad_token_id = tokenizer.pad_token_id\n", 223 | "\n", 224 | "training_args = TrainingArguments(\n", 225 | " # label_names='degradation,\n", 226 | " report_to='wandb',\n", 227 | " output_dir='test_trainer',\n", 228 | " logging_strategy='epoch',\n", 229 | " evaluation_strategy='epoch',\n", 230 | " per_device_train_batch_size=32,\n", 231 | " per_device_eval_batch_size=32,\n", 232 | " learning_rate=5e-5,\n", 233 | " num_train_epochs=15,\n", 234 | " save_total_limit=2,\n", 235 | " save_strategy='no')\n", 236 | "\n", 237 | "trainer = Trainer(\n", 238 | " model=model,\n", 239 | " args=training_args,\n", 240 | " train_dataset=tokenized_datasets['train'],\n", 241 | " eval_dataset=tokenized_datasets['test'],\n", 242 | " compute_metrics=compute_metrics\n", 243 | ")\n", 244 | "trainer.train()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "trainer.state.log_history" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "predictions = trainer.predict(tokenized_datasets['test'])\n", 263 | "g = plt.plot(predictions.predictions, label='predictions')\n", 264 | "g = plt.plot(predictions.label_ids, label='label_ids')\n", 265 | "g = plt.legend()\n", 266 | "g = plt.grid(alpha=0.8)\n", 267 | "g = plt.xlabel('Test ID')\n", 268 | "g = plt.ylabel('Degradation (%)')\n", 269 | "plt.show()" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "model = AutoModelForCausalLM.from_pretrained('ncfrey/ChemGPT-4.7M', num_labels=1) # Original\n", 279 | "model.eval()" 280 | ] 281 | }, 
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From this [post](https://github.com/huggingface/transformers/issues/7540):\n",
    "\n",
    "> BERT (the base model without any heads on top) outputs 2 things: `last_hidden_state` and `pooler_output`.\n",
    "> \n",
    "> * `last_hidden_state` contains the hidden representations for each token in each sequence of the batch. So the size is `(batch_size, seq_len, hidden_size)`.\n",
    "> * `pooler_output` contains a \"representation\" of each sequence in the batch, and is of size `(batch_size, hidden_size)`. What it basically does is take the hidden representation of the [CLS] token of each sequence in the batch (which is a vector of size `hidden_size`), and then run that through the [BertPooler](https://github.com/huggingface/transformers/blob/de4d7b004a24e4bb087eb46d742ea7939bc74644/src/transformers/modeling_bert.py#L498) nn.Module. This consists of a linear layer followed by a Tanh activation function. The weights of this linear layer are already pretrained on the next sentence prediction task (note that BERT is pretrained on 2 tasks: masked language modeling and next sentence prediction). I assume that the authors of the Transformers library have taken the weights from the original TF implementation, and initialized the layer with them. In theory, they would come from [BertForPretraining](https://github.com/huggingface/transformers/blob/de4d7b004a24e4bb087eb46d742ea7939bc74644/src/transformers/modeling_bert.py#L862) - which is BERT with the 2 pretraining heads on top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import selfies as sf\n",
    "\n",
    "smi = protac_db_df.iloc[42]['Smiles_nostereo']\n",
    "# NOTE: keep the result in its own variable; rebinding `sf` would shadow the\n",
    "# imported selfies module and break later calls to `sf.encoder`.\n",
    "selfies_str = sf.encoder(smi)\n",
    "print(f'smi: {smi}')\n",
    "print(f'selfies: {selfies_str}')\n",
    "\n",
    "inputs = tokenizer(selfies_str, return_tensors='pt')\n",
    "outputs_transformer = model.transformer(**inputs, output_hidden_states=True)\n",
    "outputs = model(**inputs, output_hidden_states=True)\n",
    "print(f'Model Transformer output keys: {outputs_transformer.keys()}')\n",
    "print(f'Model Transformer+Head output keys: {outputs.keys()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs['logits'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs['hidden_states'] # A tuple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoding POI Sequence (TODO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encoding POI Sequence via Protein Embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following this [implementation](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc). There are more models available at this [repository](https://github.com/agemagician/ProtTrans)."
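,
    "\n",
    "\n",
    "Tying the quoted BERT explanation above to code, a minimal sketch (with a hypothetical `model` and tokenized `inputs`) of the two common ways to collapse the per-token states into one vector per sequence:\n",
    "\n",
    "```python\n",
    "out = model(**inputs)\n",
    "cls_emb = out.pooler_output  # (batch, hidden): pooled [CLS] representation\n",
    "mask = inputs['attention_mask'].unsqueeze(-1)  # zero out padding tokens\n",
    "mean_emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1)  # masked mean\n",
    "```"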
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "rem_rare_amino_acids = lambda seq: re.sub(r'[UZOB]', 'X', seq)\n",
    "poi_seq = input_df['poi_seq'].apply(rem_rare_amino_acids)\n",
    "print('POIs:', poi_seq.to_list())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 128\n",
    "tmp = [' [SEP] '.join([seq[i:i+n] for i in range(0, len(seq), n)]) for seq in poi_seq]\n",
    "tmp[17]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5EncoderModel\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Load the tokenizer\n",
    "tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model\n",
    "model = T5EncoderModel.from_pretrained(\"Rostlab/prot_t5_xl_half_uniref50-enc\").to(device)\n",
    "\n",
    "# Only GPUs support half-precision currently; on CPU run in full precision\n",
    "# (not recommended, much slower). NOTE: compare `device.type` (a string) to\n",
    "# 'cpu', and use `.float()`, which is the actual nn.Module API.\n",
    "model = model.float() if device.type == 'cpu' else model.half()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare your protein sequences as a list\n",
    "sequence_examples = ['PRTEINO', 'SEQWENCE [SEP] SEQWENCE']\n",
    "\n",
    "# replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids\n",
    "sequence_examples = [' '.join(list(re.sub(r'[UZOB]', 'X', sequence))) for sequence in sequence_examples]\n",
    "\n",
    "# tokenize sequences and pad up to the longest sequence in the batch\n",
    "ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding='longest')\n",
    "\n",
    "input_ids = torch.tensor(ids['input_ids']).to(device)\n",
    "attention_mask = torch.tensor(ids['attention_mask']).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate embeddings\n",
    "with torch.no_grad():\n",
    "    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "\n",
    "# extract residue embeddings for the first ([0,:]) sequence in the batch and remove padded & special tokens ([0,:7])\n",
    "emb_0 = embedding_repr.last_hidden_state[0, :7]  # shape (7 x 1024)\n",
    "# same for the second ([1,:]) sequence but taking into account different sequence lengths ([1,:8])\n",
    "emb_1 = embedding_repr.last_hidden_state[1, :8]  # shape (8 x 1024)\n",
    "\n",
    "# if you want to derive a single representation (per-protein embedding) for the whole protein\n",
    "emb_0_per_protein = emb_0.mean(dim=0)  # shape (1024)\n",
    "\n",
    "print(emb_0)\n",
    "print(emb_1)\n",
"print(emb_0_per_protein)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "#### Encoding POI Sequence via ProtBERT" 459 | ] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": {}, 464 | "source": [ 465 | "Following this [implementation](https://huggingface.co/Rostlab/prot_bert). There are more models available at this [repository](https://github.com/agemagician/ProtTrans)." 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "from transformers import BertModel, BertTokenizer\n", 475 | "import re\n", 476 | "\n", 477 | "poi_tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "metadata": {}, 483 | "source": [ 484 | "Get BERT [output](https://huggingface.co/docs/transformers/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions)." 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "sequence_Example = 'AETCZAO'\n", 494 | "sequence_Example = re.sub(r'[UZOB]', 'X', sequence_Example)\n", 495 | "poi_tokenizer(sequence_Example, return_tensors='pt')" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "rem_rare_amino_acids = lambda seq: re.sub(r'[UZOB]', 'X', seq)\n", 505 | "poi_seq = input_df['poi_seq'].apply(rem_rare_amino_acids)\n", 506 | "print('POIs:', poi_seq.to_list())" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": {}, 512 | "source": [ 513 | "**TODO: The sequence max sequence length is at the moment requiring too much RAM to handle it. 
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nearest_pow2(x):\n",
    "    return 1 << (x - 1).bit_length()\n",
    "\n",
    "longest_seq = max([len(seq) for seq in poi_seq])\n",
    "seq_max_len = nearest_pow2(longest_seq)\n",
    "# seq_max_len = 128\n",
    "# NOTE: `model_max_length` is the attribute the tokenizer actually honors.\n",
    "poi_tokenizer.model_max_length = seq_max_len\n",
    "seq_max_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BINARY_CLASSIFICATION = False\n",
    "\n",
    "# Space-separate the residues (ProtBERT's expected input format) and mask\n",
    "# rare/ambiguous amino acids with X.\n",
    "prep_poi_seq = lambda seq: re.sub(r'[UZOB]', 'X', ' '.join(seq))\n",
    "train_upsampled_with_poi = train_upsampled.copy()\n",
    "train_upsampled_with_poi['poi_seq'] = train_upsampled_with_poi['poi_seq'].apply(prep_poi_seq)\n",
    "test_with_poi = test.copy()\n",
    "test_with_poi['poi_seq'] = test_with_poi['poi_seq'].apply(prep_poi_seq)\n",
    "\n",
    "train_dataset = ProtacDataset(train_upsampled_with_poi,\n",
    "                              poi_tokenizer=poi_tokenizer,\n",
    "                              binary_classification=BINARY_CLASSIFICATION)\n",
    "test_dataset = ProtacDataset(test_with_poi,\n",
    "                             poi_tokenizer=poi_tokenizer,\n",
    "                             binary_classification=BINARY_CLASSIFICATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class POIEncoder(pl.LightningModule):\n",
    "\n",
    "    def __init__(self,\n",
    "                 hidden_size: int = 64,\n",
    "                 n_layers: int = 3,\n",
    "                 batch_size: int = 64,\n",
    "                 learning_rate: float = 1e-3):\n",
    "        super().__init__()\n",
    "        # Save the arguments passed to init\n",
    "        self.save_hyperparameters()\n",
    "        self.__dict__.update(locals())  # Add arguments as attributes\n",
    "        # Define PyTorch models\n",
    "        hidden_channels = [hidden_size] * n_layers\n",
    "        self.extra_features_encoder = MLP(in_channels=3,\n",
    "                                          hidden_channels=hidden_channels,\n",
    "                                          norm_layer=nn.BatchNorm1d,\n",
    "                                          inplace=False,\n",
    "                                          dropout=0.5)\n",
    "        self.poi_encoder = BertModel.from_pretrained('Rostlab/prot_bert')\n",
    "        poi_embedding_size = self.poi_encoder.config.hidden_size  # 1024 for ProtBERT\n",
    "        self.head = nn.Linear(hidden_size + poi_embedding_size, 1)\n",
    "        # Define loss metrics\n",
    "        self.val_mse = MeanSquaredError()\n",
    "        self.test_mse = MeanSquaredError()\n",
    "\n",
    "    def forward(self, x_in):\n",
    "        # Encode \"extra\" features\n",
    "        concentrations = x_in['concentrations']\n",
    "        e3_ligase = x_in['e3_ligase']\n",
    "        cell_type = x_in['cell_type']\n",
    "        x = torch.cat((concentrations, e3_ligase, cell_type), dim=-1)\n",
    "        extra_features_embedding = self.extra_features_encoder(x)\n",
    "        # Encode POI sequence\n",
    "        input_ids = x_in['poi_seq']['input_ids'].squeeze(dim=1)\n",
    "        token_type_ids = x_in['poi_seq']['token_type_ids'].squeeze(dim=1)\n",
    "        attention_mask = x_in['poi_seq']['attention_mask'].squeeze(dim=1)\n",
    "        poi_embedding = self.poi_encoder(input_ids, token_type_ids,\n",
    "                                         attention_mask)['pooler_output']\n",
    "        # Run linear head\n",
    "        x = torch.cat((extra_features_embedding, poi_embedding), dim=-1)\n",
    "        return self.head(x)\n",
    "\n",
    "    def step(self, batch, phase='train'):\n",
    "        y = batch['labels']\n",
    "        preds = self.forward(batch)\n",
    "        
loss = F.mse_loss(preds, y)\n", 607 | " self.log(f'{phase}_loss', loss, prog_bar=True)\n", 608 | " return loss\n", 609 | "\n", 610 | " def training_step(self, batch, batch_idx):\n", 611 | " return self.step(batch, phase='train')\n", 612 | "\n", 613 | " def validation_step(self, batch, batch_idx):\n", 614 | " return self.step(batch, phase='val')\n", 615 | "\n", 616 | " def test_step(self, batch, batch_idx):\n", 617 | " return self.step(batch, phase='test')\n", 618 | "\n", 619 | " def configure_optimizers(self):\n", 620 | " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", 621 | " return optimizer\n", 622 | "\n", 623 | " ####################\n", 624 | " # DATA RELATED HOOKS\n", 625 | " ####################\n", 626 | "\n", 627 | " # def prepare_data(self):\n", 628 | " # # download\n", 629 | " # MNIST(self.data_dir, train=True, download=True)\n", 630 | " # MNIST(self.data_dir, train=False, download=True)\n", 631 | "\n", 632 | " def train_dataloader(self):\n", 633 | " return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=custom_collate)\n", 634 | "\n", 635 | " def val_dataloader(self):\n", 636 | " return DataLoader(test_dataset, batch_size=self.batch_size, collate_fn=custom_collate)\n", 637 | "\n", 638 | " def test_dataloader(self):\n", 639 | " return DataLoader(test_dataset, batch_size=self.batch_size, collate_fn=custom_collate)" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": null, 645 | "metadata": {}, 646 | "outputs": [], 647 | "source": [ 648 | "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", 649 | "\n", 650 | "model = POIEncoder(hidden_size=32,\n", 651 | " n_layers=3,\n", 652 | " batch_size=4,\n", 653 | " learning_rate=1e-6)\n", 654 | "\n", 655 | "callbacks = [\n", 656 | " TQDMProgressBar(refresh_rate=20),\n", 657 | " EarlyStopping(monitor='val_loss', mode='min'),\n", 658 | " ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss'),\n", 659 | "]\n", 660 | "\n", 661 | "trainer = pl.Trainer(max_epochs=5,\n", 662 | " gradient_clip_val=0.5,\n", 663 | " gradient_clip_algorithm='norm',\n", 664 | " accelerator='auto',\n", 665 | " devices=1 if torch.cuda.is_available() else None,\n", 666 | " log_every_n_steps=8,\n", 667 | " callbacks=callbacks,\n", 668 | " logger=CSVLogger(save_dir='logs/'))\n", 669 | "trainer.fit(model)" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": null, 675 | "metadata": {}, 676 | "outputs": [], 677 | "source": [ 678 | "trainer.test(ckpt_path='best')" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": null, 684 | "metadata": {}, 685 | "outputs": [], 686 | "source": [ 687 | "metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')\n", 688 | "del metrics['step']\n", 689 | "metrics.set_index('epoch', inplace=True)\n", 690 | "display(metrics.dropna(axis=1, how='all').head())\n", 691 | "g = sns.relplot(data=metrics, kind='line')\n", 692 | "g = plt.grid(alpha=0.7)\n", 693 | "plt.show()" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": null, 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "predictions = []\n", 703 | "y = []\n", 704 | "# Make predictions\n", 705 | "with torch.no_grad():\n", 706 | " _ = model.eval()\n", 707 | " for batch in model.test_dataloader():\n", 708 | " predictions.extend(model(batch).detach().tolist())\n", 709 | " y.extend(batch['labels'].detach().tolist())\n", 710 | "predictions = np.array(predictions).flatten()\n", 711 | "y = 
np.array(y).flatten()\n",
    "sorted_idx = np.argsort(y)\n",
    "# Plot predictions (sorted)\n",
    "g = plt.plot(predictions[sorted_idx], label='Predicted degradation (%)')\n",
    "g = plt.plot(y[sorted_idx], label='Reference degradation (%)')\n",
    "g = plt.legend()\n",
    "g = plt.grid(alpha=0.8)\n",
    "g = plt.xlabel('Test ID (sorted by degradation perc.)')\n",
    "g = plt.ylabel('Degradation (%)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LightGBM\n",
    "\n",
    "> LightGBM can use categorical features as input directly. It doesn’t need to convert to one-hot encoding, and is much faster than one-hot encoding (about 8x speed-up).\n",
    ">\n",
    "> Note: You should convert your categorical features to int type before you construct Dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install optuna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "src_dir = os.path.join('/content/drive/', 'MyDrive', 'Colab Notebooks', 'thesis', 'src')\n",
    "sys.path.append(src_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import optuna\n",
    "import lightgbm as lgb\n",
    "from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score\n",
    "from binary_label_metrics import BinaryLabelMetrics\n",
    "\n",
    "optuna.logging.set_verbosity(optuna.logging.WARN)  # INFO, WARN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prm = {\n",
    "    # Number of Optuna trials\n",
    "    'NTRIALS': 20,\n",
    "    # Number of boosted trees to be created\n",
    "    'NBOOST': 50,\n",
    "    # Number of classes in response variable\n",
    "    'NCLASS': 2,\n",
    "    # Morgan fingerprint bit length\n",
    "    'FP_BITS': 1024,\n",
    "}\n",
    "\n",
    "offs = max(map(len, prm.keys()))\n",
    "print('Parameters:')\n",
    "for k, v in prm.items():\n",
    "    print(f'\\t{k:>{offs}}: {v}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LightGBMObjective(object):\n",
    "    def __init__(self, X_train, y_train):\n",
    "        self.best_booster = None\n",
    "        self._booster = None\n",
    "        self.X = X_train\n",
    "        self.y = y_train\n",
    "        self.dtrain = lgb.Dataset(self.X, self.y)\n",
    "        self.nclass = 2\n",
    "        self.prm_lgb = {\n",
    "            'objective': 'multiclass' if self.nclass > 2 else 'binary',\n",
    "            'metric': None,\n",
    "            'verbosity': -1,\n",
    "            'boosting_type': 'gbdt',\n",
    "            'force_row_wise': True,\n",
    "            'min_gain_to_split': .5,\n",
    "        }\n",
    "\n",
    "    def __call__(self, trial):\n",
    "        if self.nclass > 2:\n",
    "            def f1_eval(preds, dataset):\n",
    "                y = preds.reshape(-1, self.nclass).argmax(axis=1)\n",
    "                f_score = f1_score(dataset.get_label(), y, average='micro')\n",
    "                return 'f1_score', f_score, True\n",
    "        else:\n",
    "            def f1_eval(preds, dataset):\n",
    "                # Rank samples by predicted score and flag the top k as\n",
    "                # positive, where k is the number of positives in the fold.\n",
    "                pred1 = np.zeros(dataset.get_label().shape[0], dtype=int)\n",
    "                
pred1[-dataset.get_label().sum().astype(int):] = 1\n", 825 | " f_score = f1_score(dataset.get_label()[preds.argsort()], pred1,\n", 826 | " average='micro')\n", 827 | " return 'f1_score', f_score, True\n", 828 | "\n", 829 | " trial_prm = {\n", 830 | " 'learning_rate': trial.suggest_float('learning_rate', .01, .3, log=True),\n", 831 | " 'lambda_l1': trial.suggest_float('lambda_l1', 1E-3, 1., log=True),\n", 832 | " 'lambda_l2': trial.suggest_float('lambda_l2', .5, 3., log=True),\n", 833 | " 'num_leaves': trial.suggest_int('num_leaves', 8, 32),\n", 834 | " 'min_data_in_leaf': trial.suggest_int('min_data_in_leaf', 50, 100),\n", 835 | " 'feature_fraction': trial.suggest_float('feature_fraction', .3, .6),\n", 836 | " 'bagging_fraction': trial.suggest_float('bagging_fraction', .4, 1.),\n", 837 | " 'bagging_freq': trial.suggest_int('bagging_freq', 2, 6),\n", 838 | " 'boosting_type': trial.suggest_categorical('boosting_type', ['gbdt', 'rf'])\n", 839 | " }\n", 840 | " prm_lgb = dict(self.prm_lgb)\n", 841 | " prm_lgb.update(trial_prm)\n", 842 | " eval_hist = lgb.cv(prm_lgb, self.dtrain, nfold=5, seed=12345,\n", 843 | " num_boost_round=prm['NBOOST'], feval=f1_eval,\n", 844 | " callbacks=[lgb.early_stopping(20, verbose=False)])\n", 845 | " return eval_hist['f1_score-mean'][-1]\n", 846 | "\n", 847 | " def callback(self, study, trial):\n", 848 | " if study.best_trial == trial:\n", 849 | " print(f'{study.best_trial.number} ({study.best_trial.values[0]:.3f}) -> ', end=' ', flush=True)\n", 850 | " self.best_booster = self._booster\n", 851 | " return\n", 852 | " if trial.number % 20 == 0:\n", 853 | " print(f'{trial.number}', end=' ', flush=True)" 854 | ] 855 | }, 856 | { 857 | "cell_type": "markdown", 858 | "metadata": {}, 859 | "source": [ 860 | "Only consider the features that the model can process, _e.g._, get rid of the SMILES, which are strings. Also, remove the degradation percentage, since we are trying to predict the \"binarized\" version of it.\n", 861 | "\n", 862 | "TODO: From investigating the feature importance, we see that the _concentration_ is actually the most important one. Because of that, it is removed in the following experiments.\n", 863 | "\n", 864 | "TODO: If we remove the _concentration_, however, we might have many different entries with the _same data_ in the remaining columns." 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "execution_count": null, 870 | "metadata": {}, 871 | "outputs": [], 872 | "source": [ 873 | "removed_features = [\n", 874 | " 'degradation', # NOTE: Must be removed, it's the \"regression version\" of y\n", 875 | " # 'concentration',\n", 876 | " 'Smiles',\n", 877 | " 'Smiles_nostereo',\n", 878 | " 'poi_seq',\n", 879 | "]" 880 | ] 881 | }, 882 | { 883 | "cell_type": "markdown", 884 | "metadata": {}, 885 | "source": [ 886 | "Instantiate Optuna objective and start optimization, _i.e._, training." 
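,
    "\n",
    "\n",
    "Side note, tying back to the LightGBM quote above: to actually opt in to the native categorical handling, the integer-coded columns must be declared when the `Dataset` is built. A minimal sketch (the column names here are hypothetical):\n",
    "\n",
    "```python\n",
    "cat_cols = ['cell_type', 'e3_ligase']  # hypothetical integer-coded columns\n",
    "dtrain = lgb.Dataset(X, y, categorical_feature=cat_cols)\n",
    "```"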
887 | ] 888 | }, 889 | { 890 | "cell_type": "code", 891 | "execution_count": null, 892 | "metadata": {}, 893 | "outputs": [], 894 | "source": [ 895 | "X = X_train_upsampled.drop(removed_features, axis=1)\n", 896 | "objective = LightGBMObjective(X, y_train_upsampled)\n", 897 | "print(f\"Number of trials: {prm['NTRIALS']}\")\n", 898 | "print(f\"Trial ID (F1 score): \", end='')\n", 899 | "study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=10), \n", 900 | " sampler=optuna.samplers.TPESampler(seed=1234),\n", 901 | " direction='maximize')\n", 902 | "study.optimize(objective, n_trials=prm['NTRIALS'], callbacks=[objective.callback])" 903 | ] 904 | }, 905 | { 906 | "cell_type": "code", 907 | "execution_count": null, 908 | "metadata": {}, 909 | "outputs": [], 910 | "source": [ 911 | "attributes = ('params', 'user_attrs', 'value', 'duration')\n", 912 | "trials_df = study.trials_dataframe(attrs=attributes)\n", 913 | "for y in ['params', 'user_attrs']:\n", 914 | " trials_df.columns = [x[1 + len(y):] if x.startswith(y) else x for x in trials_df.columns]\n", 915 | "trials_df['duration'] = trials_df['duration'].apply(lambda x: x.total_seconds())\n", 916 | "\n", 917 | "with pd.option_context('display.max_rows', 6, 'display.float_format', '{:.4f}'.format):\n", 918 | " display_html(trials_df.sort_values('value', ascending=False))" 919 | ] 920 | }, 921 | { 922 | "cell_type": "code", 923 | "execution_count": null, 924 | "metadata": {}, 925 | "outputs": [], 926 | "source": [ 927 | "print(f\"Training time: {trials_df['duration'].sum() / 60:.1f}min\")" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": null, 933 | "metadata": {}, 934 | "outputs": [], 935 | "source": [ 936 | "trials_df.groupby('boosting_type').agg(meanv=('value', 'mean'), sdv=('value', 'std'))" 937 | ] 938 | }, 939 | { 940 | "cell_type": "markdown", 941 | "metadata": {}, 942 | "source": [ 943 | "Identify best model and re-train it." 
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_idx = np.argmax(trials_df['value'].values)\n",
    "lgb_prm = study.trials[best_idx].params\n",
    "lgb_prm.update({\n",
    "    'objective': 'multiclass',\n",
    "    'metric': None,\n",
    "    'num_class': 2,\n",
    "    'force_row_wise': True,\n",
    "    'verbosity': -1,\n",
    "    'min_gain_to_split': .5,\n",
    "})\n",
    "\n",
    "def f1_eval(preds, dtrain):\n",
    "    preds = preds.reshape(prm['NCLASS'], -1).T.argmax(axis=1)\n",
    "    f_score = f1_score(dtrain.get_label(), preds, average='micro')\n",
    "    return 'f1_score', f_score, True\n",
    "\n",
    "# Get dataset\n",
    "# Balance classes via class weighting (unnecessary, we are already upsampling)\n",
    "# wt = class_weight.compute_sample_weight(class_weight='balanced', y=y)\n",
    "# dtrain = lgb.Dataset(X, y, weight=wt)\n",
    "X = X_train_upsampled.drop(removed_features, axis=1)\n",
    "dtrain = lgb.Dataset(X, y_train_upsampled)\n",
    "model = lgb.train(lgb_prm, dtrain, feval=f1_eval, num_boost_round=prm['NBOOST'])\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = X_test.drop(removed_features, axis=1)\n",
    "y_hat = np.array([val[1] for val in model.predict(X)])\n",
    "scores_df = pd.DataFrame({'label': list(y_test), 'score': list(y_hat)})\n",
    "\n",
    "blm = BinaryLabelMetrics()\n",
    "blm.add_model('binary_gbm', scores_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blm.plot_roc(params={'legloc': 4})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# blm.plot(chart_types=[2, 5], params={'legloc': 2, 'chart_thresh': 0.5})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = X_test.drop(removed_features, axis=1)\n",
    "# model.predict returns per-class probabilities; threshold the probability\n",
    "# of the 'active' class at 0.5 to get hard labels.\n",
    "y_pred = np.array([1 if p_active >= 0.5 else 0 for _, p_active in model.predict(X)])\n",
    "conf_mat = confusion_matrix(y_test, y_pred)\n",
    "disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat,\n",
    "                              display_labels=['inactive', 'active'])\n",
    "g = disp.plot(cmap=plt.cm.Blues)\n",
    "g = plt.title('Confusion Matrix')\n",
    "# plt.savefig(os.path.join(fig_dir, f'conf_mat_{model_type}.pdf'))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Plot feature importance](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.plot_importance.html#lightgbm.plot_importance)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = lgb.plot_importance(model, max_num_features=20, figsize=(9, 7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for c in train.columns:\n",
    "    if 'dss' in c:\n",
    "        print(c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics on Hold-Out Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the non-numeric/leaky columns before predicting, as in the cells above\n",
    "y_hat = np.array([val[1] for val in model.predict(X_test.drop(removed_features, axis=1))])\n",
    "print(y_hat.shape, y_test.shape)\n",
    "scores_df = pd.DataFrame({'label': list(y_test), 'score': list(y_hat)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blm = BinaryLabelMetrics()\n",
    "blm.add_model('binary_gbm', scores_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blm.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blm.plot_roc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "f1_vals = [x for x in blm._f1[0] if not math.isnan(x)]\n",
    "original = ['original', blm._auc, blm._prrec, max(f1_vals)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature Importances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shuffle(var, df):\n",
    "    # Pick the column group to permute based on the column-naming conventions\n",
    "    cols = test.columns.tolist()\n",
    "    if var == 'cellType':\n",
    "        sel_cols = [c for c in cols if 'ct' in c]\n",
    "    elif var == 'e3':\n",
    "        sel_cols = [c for c in cols if 'e3' in c]\n",
    "    elif var == 'ligand':\n",
    "        sel_cols = [c for c in cols if 'sm' in c]\n",
    "    elif var == 'receptor':\n",
    "        sel_cols = [c for c in cols if (len(c) == 2 or len(c) == 3) and 'sm' not in c]\n",
    "    else:\n",
    "        sel_cols = []\n",
    "    new_df = df.copy(deep=True)\n",
    "    for col in sel_cols:\n",
    "        new_df[col] = new_df[col].sample(frac=1).values\n",
    "    return new_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source":
[ 1170 | "import math\n", 1171 | "\n", 1172 | "metrics = [['shuffled_var', 'auc', 'pr_rec', 'f1']]\n", 1173 | "for var in ['cellType', 'e3', 'ligand', 'receptor']:\n", 1174 | " new_df = shuffle(var, test)\n", 1175 | " X_test = new_df.drop(['resp_categorical', 'resp', 'Smiles'], axis=1).values\n", 1176 | " y_hat = np.array([val[1] for val in model.predict(X_test)])\n", 1177 | " print(y_hat.shape, y_test.shape, y_hat)\n", 1178 | " scores_df = pd.DataFrame({'label': list(y_test), 'score': list(y_hat)})\n", 1179 | " blm = BinaryLabelMetrics()\n", 1180 | " blm.add_model('binary_gbm', scores_df)\n", 1181 | " lst = blm._f1[0]\n", 1182 | " newlist = [x for x in lst if math.isnan(x) == False]\n", 1183 | " metrics.append([var, blm._auc, blm._prrec, max(newlist)])\n", 1184 | "metrics.append(original)" 1185 | ] 1186 | }, 1187 | { 1188 | "cell_type": "code", 1189 | "execution_count": null, 1190 | "metadata": {}, 1191 | "outputs": [], 1192 | "source": [ 1193 | "metrics" 1194 | ] 1195 | }, 1196 | { 1197 | "cell_type": "code", 1198 | "execution_count": null, 1199 | "metadata": {}, 1200 | "outputs": [], 1201 | "source": [] 1202 | }, 1203 | { 1204 | "attachments": {}, 1205 | "cell_type": "markdown", 1206 | "metadata": {}, 1207 | "source": [ 1208 | "## POI Encoding" 1209 | ] 1210 | }, 1211 | { 1212 | "cell_type": "code", 1213 | "execution_count": null, 1214 | "metadata": {}, 1215 | "outputs": [], 1216 | "source": [] 1217 | } 1218 | ], 1219 | "metadata": { 1220 | "language_info": { 1221 | "name": "python" 1222 | }, 1223 | "orig_nbformat": 4 1224 | }, 1225 | "nbformat": 4, 1226 | "nbformat_minor": 2 1227 | } 1228 | --------------------------------------------------------------------------------