├── paccmann_predictor ├── utils │ ├── __init__.py │ ├── hyperparams.py │ ├── utils.py │ ├── loss_functions.py │ ├── layers.py │ └── interpret.py ├── __init__.py └── models │ ├── __init__.py │ ├── dense.py │ ├── knn.py │ ├── paccmann.py │ ├── paccmann_v2.py │ └── bimodal_mca.py ├── assets └── paccmann.png ├── examples ├── affinity │ ├── requirements.txt │ ├── conda.yml │ ├── affinity.json │ ├── regression.json │ ├── predict_affinity.py │ ├── train_affinity_regression.py │ └── train_affinity.py └── IC50 │ ├── conda.yml │ ├── example_params.json │ ├── paccmann_v2_params.json │ ├── test_paccmann.py │ └── train_paccmann.py ├── .github └── workflows │ ├── push_pypi.yml │ └── build.yml ├── LICENSE ├── setup.py ├── .gitignore └── README.md /paccmann_predictor/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/paccmann.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/paccmann_predictor/HEAD/assets/paccmann.png -------------------------------------------------------------------------------- /paccmann_predictor/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialization for `paccmann.models` submodule.""" 2 | __version__ = "1.0.1" 3 | -------------------------------------------------------------------------------- /examples/affinity/requirements.txt: -------------------------------------------------------------------------------- 1 | pytoda @ git+https://github.com/PaccMann/paccmann_datasets@0.2.4 2 | numpy>=1.14.3 3 | scipy>=1.3.1 4 | torch>=1.3.0 5 | -------------------------------------------------------------------------------- /examples/IC50/conda.yml: -------------------------------------------------------------------------------- 1 | name: paccmann_predictor 2 | channels: 3 | - https://conda.anaconda.org/rdkit 4 | dependencies: 5 | - rdkit=2019.03.1 6 | - python>=3.6,<3.8 7 | - pip>=19.1 8 | - pip: 9 | - pytoda==1.0.0 10 | - numpy>=1.14.3 11 | - scipy>=1.3.1 12 | - torch>=1.7.1 13 | - tqdm 14 | - pandas 15 | 16 | -------------------------------------------------------------------------------- /examples/affinity/conda.yml: -------------------------------------------------------------------------------- 1 | name: paccmann_predictor 2 | channels: 3 | - https://conda.anaconda.org/rdkit 4 | dependencies: 5 | - rdkit=2019.03.1 6 | - python>=3.7,<3.8 7 | - pip>=19.1 8 | - pip: 9 | - pytoda @ git+https://github.com/PaccMann/paccmann_datasets@0.2.4 10 | - numpy>=1.14.3 11 | - scipy>=1.3.1 12 | - torch>=1.3.0 13 | -------------------------------------------------------------------------------- /paccmann_predictor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .bimodal_mca import BimodalMCA 2 | from .dense import Dense 3 | from .paccmann import MCA 4 | from .paccmann_v2 import PaccMannV2 5 | from .knn import knn # noqa 6 | 7 | # More models could follow 8 | MODEL_FACTORY = { 9 | 'mca': MCA, 10 | 'dense': Dense, 11 | 'bimodal_mca': BimodalMCA, 12 | 'paccmann_v2': PaccMannV2, 13 | 'knn': knn 14 | } 15 | -------------------------------------------------------------------------------- /examples/affinity/affinity.json: -------------------------------------------------------------------------------- 1 | { 2 | "augment_smiles": false, 3 | "smiles_canonical": false, 4 | "ligand_start_stop_token": true, 5 | "receptor_start_stop_token": true, 6 | "ligand_padding_length": 1024, 7 | "receptor_padding_length": 8192, 8 | "dense_hidden_sizes": [ 9 | 20 10 | ], 11 | "activation_fn": "relu", 12 | "dropout": 0.3, 13 | "batch_norm": true, 14 | "batch_size": 512, 15 | "lr": 0.001, 16 | "epochs": 200, 17 | "save_model": 25 18 | } -------------------------------------------------------------------------------- /examples/affinity/regression.json: -------------------------------------------------------------------------------- 1 | { 2 | "augment_smiles": false, 3 | "smiles_canonical": false, 4 | "ligand_start_stop_token": true, 5 | "receptor_start_stop_token": true, 6 | "ligand_padding_length": 512, 7 | "receptor_padding_length": 2048, 8 | "loss_fn": "mse", 9 | "dense_hidden_sizes": [ 10 | 512 11 | ], 12 | "activation_fn": "relu", 13 | "dropout": 0.3, 14 | "batch_norm": true, 15 | "batch_size": 64, 16 | "lr": 0.0001, 17 | "epochs": 200, 18 | "save_model": 25 19 | } -------------------------------------------------------------------------------- /.github/workflows/push_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-20.04 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-python@v2 13 | - name: Install dependencies 14 | run: | 15 | python -m pip install --upgrade pip 16 | pip install setuptools wheel twine 17 | - name: Build and publish 18 | env: 19 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 20 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 21 | run: | 22 | python setup.py sdist bdist_wheel 23 | twine upload --skip-existing dist/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Ali Oskooei, Jannis Born, Matteo Manica, Joris Cadow 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /examples/IC50/example_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "drug_sensitivity_min_max": true, 3 | "augment_smiles": true, 4 | "smiles_start_stop_token": true, 5 | "number_of_genes": 2128, 6 | "smiles_padding_length": 560, 7 | "stacked_dense_hidden_sizes": [ 8 | 1024, 9 | 512 10 | ], 11 | "activation_fn": "relu", 12 | "dropout": 0.5, 13 | "batch_norm": true, 14 | "filters": [ 15 | 64, 16 | 64, 17 | 64 18 | ], 19 | "multiheads": [ 20 | 4, 21 | 4, 22 | 4, 23 | 4 24 | ], 25 | "smiles_embedding_size": 16, 26 | "kernel_sizes": [ 27 | [ 28 | 3, 29 | 16 30 | ], 31 | [ 32 | 5, 33 | 16 34 | ], 35 | [ 36 | 11, 37 | 16 38 | ] 39 | ], 40 | "smiles_attention_size": 64, 41 | "embed_scale_grad": false, 42 | "final_activation": true, 43 | "gene_to_dense": false, 44 | "batch_size": 2048, 45 | "lr": 0.01, 46 | "optimizer": "adam", 47 | "loss_fn": "mse", 48 | "epochs": 200, 49 | "save_model": 25 50 | } -------------------------------------------------------------------------------- /paccmann_predictor/utils/hyperparams.py: -------------------------------------------------------------------------------- 1 | """Customizable model hyperparameter """ 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | 5 | from .loss_functions import ( 6 | mse_cc_loss, 7 | correlation_coefficient_loss, 8 | ) 9 | 10 | # LSTM(10, 20, 2) -> input has 10 features, 20 hidden size and 2 layers. 11 | # NOTE: Make sure to set batch_first=True. Optionally set bidirectional=True 12 | RNN_CELL_FACTORY = {'lstm': nn.LSTM, 'gru': nn.GRU} 13 | 14 | LOSS_FN_FACTORY = { 15 | 'mse': nn.MSELoss(), 16 | 'l1': nn.L1Loss(), 17 | 'mse_and_pearson': mse_cc_loss, 18 | 'pearson': correlation_coefficient_loss, 19 | 'binary_cross_entropy': nn.BCELoss() 20 | } 21 | 22 | ACTIVATION_FN_FACTORY = { 23 | 'relu': nn.ReLU(), 24 | 'sigmoid': nn.Sigmoid(), 25 | 'selu': nn.SELU(), 26 | 'tanh': nn.Tanh(), 27 | 'lrelu': nn.LeakyReLU(), 28 | 'elu': nn.ELU() 29 | } 30 | OPTIMIZER_FACTORY = { 31 | 'adam': optim.Adam, 32 | 'adadelta': optim.Adadelta, 33 | 'adagrad': optim.Adagrad, 34 | 'gd': optim.SGD, 35 | 'sparseadam': optim.SparseAdam, 36 | 'adamax': optim.Adamax, 37 | 'asgd': optim.ASGD, 38 | 'lbfgs': optim.LBFGS, 39 | 'rmsprop': optim.RMSprop, 40 | 'rprop': optim.Rprop 41 | } 42 | -------------------------------------------------------------------------------- /examples/IC50/paccmann_v2_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "drug_sensitivity_min_max": true, 3 | "augment_smiles": true, 4 | "smiles_start_stop_token": true, 5 | "number_of_genes": 2128, 6 | "smiles_padding_length": 512, 7 | "stacked_dense_hidden_sizes": [ 8 | 1024, 9 | 512 10 | ], 11 | "activation_fn": "relu", 12 | "dropout": 0.5, 13 | "batch_norm": true, 14 | "filters": [ 15 | 64, 16 | 64, 17 | 64 18 | ], 19 | "molecule_heads": [ 20 | 4, 21 | 4, 22 | 4, 23 | 4 24 | ], 25 | "gene_heads": [ 26 | 2, 27 | 2, 28 | 2, 29 | 2 30 | ], 31 | "smiles_embedding_size": 16, 32 | "kernel_sizes": [ 33 | [ 34 | 3, 35 | 16 36 | ], 37 | [ 38 | 5, 39 | 16 40 | ], 41 | [ 42 | 11, 43 | 16 44 | ] 45 | ], 46 | "smiles_attention_size": 64, 47 | "gene_attention_size": 1, 48 | "embed_scale_grad": false, 49 | "final_activation": true, 50 | "batch_size": 256, 51 | "lr": 0.01, 52 | "optimizer": "adam", 53 | "loss_fn": "mse", 54 | "epochs": 10, 55 | "save_model": 25, 56 | "dataset_device": "cpu" 57 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Install package.""" 2 | import codecs 3 | import os 4 | 5 | from setuptools import find_packages, setup 6 | 7 | 8 | def read(rel_path): 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | with codecs.open(os.path.join(here, rel_path), 'r') as fp: 11 | return fp.read() 12 | 13 | 14 | def get_version(rel_path): 15 | for line in read(rel_path).splitlines(): 16 | if line.startswith('__version__'): 17 | delim = '"' if '"' in line else "'" 18 | return line.split(delim)[1] 19 | else: 20 | raise RuntimeError('Unable to find version string.') 21 | 22 | 23 | setup( 24 | name='paccmann_predictor', 25 | version=get_version('paccmann_predictor/__init__.py'), 26 | description=('PyTorch implementation of PaccMann'), 27 | long_description=open('README.md').read(), 28 | long_description_content_type='text/markdown', 29 | url='https://github.com/PaccMann/paccmann_predictor', 30 | author='Ali Oskooei, Jannis Born, Matteo Manica, Joris Cadow', 31 | author_email=( 32 | 'ali.oskooei@gmail.com, jab@zurich.ibm.com, ' 33 | 'drugilsberg@gmail.com, joriscadow@gmail.com' 34 | ), 35 | install_requires=[ 36 | 'numpy', 37 | 'scipy', 38 | 'torch>=1.0.0', 39 | 'pandas', 40 | 'tqdm', 41 | 'rdkit', 42 | 'pytoda>=1.1.5', 43 | ], 44 | packages=find_packages('.'), 45 | zip_safe=False, 46 | ) 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # secrets detector 2 | .secrets.baseline 3 | 4 | # mac files 5 | .DS_Store 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # data files 110 | .pdf 111 | .csv 112 | 113 | # swap files 114 | *swp 115 | 116 | # shell scripts 117 | *.sh 118 | 119 | # trained models 120 | /models 121 | 122 | # IDE 123 | .vscode/ -------------------------------------------------------------------------------- /paccmann_predictor/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_device(): 6 | return torch.device('cuda' if cuda() else 'cpu') 7 | 8 | 9 | def cuda(): 10 | return torch.cuda.is_available() 11 | 12 | 13 | def to_np(x): 14 | return x.data.cpu().numpy() 15 | 16 | 17 | def attention_list_to_matrix(coding_tuple, dim=2): 18 | """[summary] 19 | 20 | Args: 21 | coding_tuple (list((torch.Tensor, torch.Tensor))): iterable of 22 | (outputs, att_weights) tuples coming from the attention function 23 | dim (int, optional): The dimension along which expansion takes place to 24 | concatenate the attention weights. Defaults to 2. 25 | 26 | Returns: 27 | (torch.Tensor, torch.Tensor): raw_coeff, coeff 28 | 29 | raw_coeff: with the attention weights of all multiheads and 30 | convolutional kernel sizes concatenated along the given dimension, 31 | by default the last dimension. 32 | coeff: where the dimension is collapsed by averaging. 33 | """ 34 | raw_coeff = torch.cat( 35 | [torch.unsqueeze(tpl[1], 2) for tpl in coding_tuple], dim=dim 36 | ) 37 | return raw_coeff, torch.mean(raw_coeff, dim=dim) 38 | 39 | 40 | def get_log_molar(y, ic50_max=None, ic50_min=None): 41 | """ 42 | Converts PaccMann predictions from [0,1] to log(micromolar) range. 43 | """ 44 | return y * (ic50_max - ic50_min) + ic50_min 45 | 46 | 47 | class Squeeze(nn.Module): 48 | """Squeeze wrapper for nn.Sequential.""" 49 | 50 | def forward(self, data): 51 | return torch.squeeze(data) 52 | 53 | 54 | class Unsqueeze(nn.Module): 55 | """Unsqueeze wrapper for nn.Sequential.""" 56 | 57 | def __init__(self, dim): 58 | super(Unsqueeze, self).__init__() 59 | self.dim = dim 60 | 61 | def forward(self, data): 62 | return torch.unsqueeze(data, self.dim) 63 | 64 | 65 | class Temperature(nn.Module): 66 | """Temperature wrapper for nn.Sequential.""" 67 | 68 | def __init__(self, temperature): 69 | super(Temperature, self).__init__() 70 | self.temperature = temperature 71 | 72 | def forward(self, data): 73 | return data / self.temperature 74 | -------------------------------------------------------------------------------- /paccmann_predictor/utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | """Loss function definitions for paccmann.""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def pearsonr(x, y): 7 | """Compute Pearson correlation. 8 | 9 | Args: 10 | x (torch.Tensor): 1D vector 11 | y (torch.Tensor): 1D vector of the same size as y. 12 | 13 | Raises: 14 | TypeError: not torch.Tensors. 15 | ValueError: not same shape or at least length 2. 16 | 17 | Returns: 18 | Pearson correlation coefficient. 19 | """ 20 | if not isinstance(x, torch.Tensor) or not isinstance(y, torch.Tensor): 21 | raise TypeError('Function expects torch Tensors.') 22 | 23 | if len(x.shape) > 1 or len(y.shape) > 1: 24 | raise ValueError(' x and y must be 1D Tensors.') 25 | 26 | if len(x) != len(y): 27 | raise ValueError('x and y must have the same length.') 28 | 29 | if len(x) < 2: 30 | raise ValueError('x and y must have length at least 2.') 31 | 32 | # If an input is constant, the correlation coefficient is not defined. 33 | if bool((x == x[0]).all()) or bool((y == y[0]).all()): 34 | raise ValueError('Constant input, r is not defined.') 35 | 36 | mx = x - torch.mean(x) 37 | my = y - torch.mean(y) 38 | cost = ( 39 | torch.sum(mx * my) / 40 | (torch.sqrt(torch.sum(mx**2)) * torch.sqrt(torch.sum(my**2))) 41 | ) 42 | return torch.clamp(cost, min=-1.0, max=1.0) 43 | 44 | 45 | def correlation_coefficient_loss(labels, predictions): 46 | """Compute loss based on Pearson correlation. 47 | 48 | Args: 49 | labels (torch.Tensor): reference values 50 | predictions (torch.Tensor): predicted values 51 | 52 | Returns: 53 | torch.Tensor: A loss that when minimized forces high squared correlation coefficient: 54 | \$1 - r(labels, predictions)^2\$ # noqa 55 | """ 56 | return 1 - pearsonr(labels, predictions)**2 57 | 58 | 59 | def mse_cc_loss(labels, predictions): 60 | """Compute loss based on MSE and Pearson correlation. 61 | 62 | The main assumption is that MSE lies in [0,1] range, i.e.: range is 63 | comparable with Pearson correlation-based loss. 64 | 65 | Args: 66 | labels (torch.Tensor): reference values 67 | predictions (torch.Tensor): predicted values 68 | 69 | Returns: 70 | torch.Tensor: A loss that computes the following: 71 | \$mse(labels, predictions) + 1 - r(labels, predictions)^2\$ # noqa 72 | """ 73 | mse_loss_fn = nn.MSELoss() 74 | mse_loss = mse_loss_fn(predictions, labels) 75 | cc_loss = correlation_coefficient_loss(labels, predictions) 76 | return mse_loss + cc_loss 77 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: build 3 | on: [push] 4 | jobs: 5 | conda-tests: 6 | name: Test with conda (${{ matrix.os }}) 7 | runs-on: ${{ matrix.os }} 8 | continue-on-error: ${{ matrix.experimental }} 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | include: 13 | - os: ubuntu-latest 14 | pip_cache_path: ~/.cache/pip 15 | experimental: false 16 | defaults: 17 | run: 18 | shell: bash -l {0} # For conda 19 | env: 20 | # Increase this value to reset cache if conda.yml and requirements.txt 21 | # have not changed 22 | CACHE_NUMBER: 0 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Checkout and setup python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.10' 29 | architecture: 'x64' 30 | 31 | - name: Cache conda 32 | uses: actions/cache@v4 33 | with: 34 | path: ~/conda_pkgs_dir 35 | key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ 36 | hashFiles('conda.yml') }} 37 | 38 | - name: Cache pip 39 | uses: actions/cache@v4 40 | with: 41 | path: ${{ matrix.pip_cache_path }} 42 | key: ${{ runner.os }}-pip--${{ env.CACHE_NUMBER }}-${{ 43 | hashFiles('requirements.txt') }} 44 | 45 | - name: IC50 - Conda environment setup 46 | uses: conda-incubator/setup-miniconda@v2 47 | with: 48 | activate-environment: paccmann_predictor 49 | environment-file: examples/IC50/conda.yml 50 | auto-activate-base: false 51 | use-only-tar-bz2: true # This needs to be set for proper caching 52 | auto-update-conda: true # Required for windows for `use-only-tar-bz2` 53 | 54 | - name: IC50 - Install dependencies and run tests 55 | run: | 56 | python3 -m pip install --upgrade pip 57 | pip3 install --no-deps . 58 | python3 -c "import paccmann_predictor" 59 | python3 examples/IC50/train_paccmann.py -h 60 | - name: Affinity - Conda environment setup 61 | uses: conda-incubator/setup-miniconda@v2 62 | with: 63 | activate-environment: paccmann_predictor 64 | environment-file: examples/affinity/conda.yml 65 | auto-activate-base: false 66 | use-only-tar-bz2: true # This needs to be set for proper caching 67 | auto-update-conda: true # Required for windows for `use-only-tar-bz2` 68 | 69 | - name: Install dependencies and test code 70 | run: | 71 | pip3 install -e . 72 | python3 -c "import paccmann_predictor.models" 73 | 74 | - name: Affinity - Install dependencies and run tests 75 | run: | 76 | python3 -m pip install --upgrade pip 77 | pip3 install --no-cache-dir -r examples/affinity/requirements.txt 78 | pip3 install --no-deps . 79 | python3 -c "import paccmann_predictor" 80 | python3 examples/affinity/train_affinity.py -h 81 | - name: Send Slack notification 82 | uses: 8398a7/action-slack@v2 83 | if: always() 84 | with: 85 | status: ${{ job.status }} 86 | text: "CI Build ${{ matrix.os }}" 87 | author_name: ${{ github.actor }} 88 | env: 89 | SLACK_WEBHOOK_URL: ${{ secrets.SLACK_HOOK_URL }} 90 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /paccmann_predictor/models/dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..utils.layers import dense_layer 4 | from ..utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY 5 | from ..utils.utils import get_device 6 | 7 | 8 | class Dense(nn.Module): 9 | """ This is a Dense model for validation """ 10 | 11 | def __init__(self, params, *args, **kwargs): 12 | """Constructor. 13 | 14 | Args: 15 | params (dict): A dictionary containing the parameter to built the 16 | dense Decoder. 17 | TODO params should become actual arguments (use **params). 18 | 19 | Items in params: 20 | dense_sizes (list[int]): Number of neurons in the hidden layers. 21 | number_of_genes (int, optional): Number of -omics features of cell. 22 | Defaults to 2128. 23 | num_drug_features (int, optional): Number of features for molecule. 24 | Defaults to 512. 25 | activation_fn (string, optional): Activation function used in all 26 | layers for specification in ACTIVATION_FN_FACTORY. 27 | Defaults to 'relu'. 28 | batch_norm (bool, optional): Whether batch normalization is 29 | applied. Defaults to True. 30 | dropout (float, optional): Dropout probability in all 31 | except parametric layer. Defaults to 0.0. 32 | *args, **kwargs: positional and keyword arguments are ignored. 33 | 34 | Example params: 35 | ``` 36 | { 37 | "dense_sizes": [2048, 1024, 512, 1], 38 | "dropout" : 0.1, 39 | "activation_fn": 'relu', 40 | } 41 | ``` 42 | """ 43 | 44 | super(Dense, self).__init__(*args, **kwargs) 45 | 46 | # Model Parameter 47 | self.device = get_device() 48 | self.params = params 49 | self.loss_fn = LOSS_FN_FACTORY[params.get('loss_fn', 'mse')] 50 | self.number_of_genes = params.get('number_of_genes', 2128) 51 | self.num_drug_features = params.get('num_drug_features', 512) 52 | self.hidden_sizes = params.get( 53 | 'stacked_dense_hidden_sizes', [ 54 | self.number_of_genes + self.num_drug_features, 1024, 512, 256, 55 | 64 56 | ] 57 | ) 58 | 59 | self.dropout = params.get('dropout', 0.0) 60 | self.act_fn = ACTIVATION_FN_FACTORY[ 61 | params.get('activation_fn', 'relu')] 62 | 63 | self.dense_layers = [ 64 | dense_layer( 65 | self.hidden_sizes[ind], self.hidden_sizes[ind + 1], 66 | self.act_fn, self.dropout 67 | ).to(self.device) for ind in range(len(self.hidden_sizes) - 1) 68 | ] 69 | 70 | self.final_dense = nn.Linear(self.hidden_sizes[-1], 1) 71 | 72 | def forward(self, fps, gep): 73 | """Forward pass through the dense model. 74 | 75 | Args: 76 | fps (torch.Tensor) of type int and shape `[bs, 512 (bits). 77 | gep (torch.Tensor): of shape `[bs, num_genes]`. 78 | 79 | Returns: 80 | (torch.Tensor, torch.Tensor): predictions, prediction_dict 81 | 82 | predictions is IC50 drug sensitivity prediction of shape `[bs, 1]`. 83 | prediction_dict includes the prediction and attention weights. 84 | """ 85 | 86 | inputs = torch.cat([fps.float(), gep], dim=1) 87 | 88 | for dl in self.dense_layers: 89 | inputs = dl(inputs) 90 | 91 | predictions = self.final_dense(inputs) 92 | prediction_dict = {'IC50': predictions} 93 | return predictions, prediction_dict 94 | 95 | def loss(self, yhat, y): 96 | return self.loss_fn(yhat, y) 97 | 98 | def load(self, path, *args, **kwargs): 99 | """Load model from path.""" 100 | weights = torch.load(path, *args, **kwargs) 101 | self.load_state_dict(weights) 102 | 103 | def save(self, path, *args, **kwargs): 104 | """Save model to path.""" 105 | torch.save(self.state_dict(), path, *args, **kwargs) 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 2 | [![Build Status](https://github.com/PaccMann/paccmann_predictor/actions/workflows/build.yml/badge.svg)](https://github.com/PaccMann/paccmann_predictor/actions/workflows/build.yml) 3 | 4 | # paccmann_predictor 5 | 6 | Drug interaction prediction with PaccMann. 7 | 8 | `paccmann_predictor` is a package for drug interaction prediction, with examples of 9 | anticancer drug sensitivity prediction and drug target affinity prediction. Please see our papers: 10 | 11 | - [_Toward explainable anticancer compound sensitivity prediction via multimodal attention-based convolutional encoders_](https://doi.org/10.1021/acs.molpharmaceut.9b00520) (*Molecular Pharmaceutics*, 2019). This is the original paper on IC50 prediction using drug properties and tissue-specific cell line information (gene expression profiles). While the original code was written in `tensorflow` and is available [here](https://github.com/drugilsberg/paccmann), this is the `pytorch` implementation of the best PaccMann architecture (multiscale convolutional encoder). 12 | 13 | 14 | **PaccMann for affinity prediction:** 15 | - [Data-driven molecular design for discovery and synthesis of novel ligands: a case study on SARS-CoV-2](https://iopscience.iop.org/article/10.1088/2632-2153/abe808) (_Machine Learning: Science and Technology_, 2021). In there, we propose a slightly modified version to predict drug-target binding affinities based on protein sequences and SMILES 16 | 17 | ![Graphical abstract](https://github.com/PaccMann/paccmann_predictor/blob/master/assets/paccmann.png "Graphical abstract") 18 | 19 | ## Installation 20 | The library itself has few dependencies (see [setup.py](setup.py)) with loose requirements. 21 | First, set up the environment as follows: 22 | ```sh 23 | conda env create -f examples/IC50/conda.yml 24 | conda activate paccmann_predictor 25 | pip install -e . 26 | ``` 27 | 28 | 29 | ## Evaluate pretrained drug sensitivty model on your own data 30 | First, please consider using our public [PaccMann webservice](https://ibm.biz/paccmann-aas) as described in the [NAR paper](https://academic.oup.com/nar/article/48/W1/W502/5836770). 31 | 32 | To use our pretrained model, please download the model from: https://ibm.biz/paccmann-data (just download `models/single_pytorch_model`). 33 | For example, assuming that you: 34 | 1. Set up your conda environment as described above; 35 | 2. Downloaded the model linked above in a directory called `single_pytorch_model` and 36 | 3. Downloaded the data from https://ibm.box.com/v/paccmann-pytoda-data in folders `data` and `splitted_data`; 37 | then, the following command should work: 38 | ```console 39 | (paccmann_predictor) $ python examples/IC50/test_paccmann.py \ 40 | splitted_data/gdsc_cell_line_ic50_test_fraction_0.1_id_997_seed_42.csv \ 41 | data/gene_expression/gdsc-rnaseq_gene-expression.csv \ 42 | data/smiles/gdsc.smi \ 43 | data/2128_genes.pkl \ 44 | single_pytorch_model/smiles_language \ 45 | single_pytorch_model/weights/best_mse_paccmann_v2.pt \ 46 | results \ 47 | single_pytorch_model/model_params.json 48 | ``` 49 | *NOTE*: If you bring your own data, please make sure to provide the omic data for the 2128 genes specified in `data/2128_genes.pkl`. Your omic data (here it is `data/gene_expression/gdsc-rnaseq_gene-expression.csv`) can contain more columns and it does not need to follow the order of the pickled gene list. But please dont change this pickle file. Also note that this is PaccMannV2 which is slightly improved compared to the paper version (context attention on both modalities). 50 | 51 | ## Finetuning on your own data 52 | You can also **finetune** our pretrained model on your data instead of training a model from scratch. For that, please follow the instruction below for training on scratch and just set: 53 | - `model_path` --> directory where the `single_pytorch_model` is stored 54 | - `training_name` --> this should be `single_pytorch_model` 55 | - `params_filepath` --> `base_path/single_pytorch_model/model_params.json` 56 | 57 | 58 | ## Training a model from scratch 59 | To run the example training script we provide environment files under `examples/IC50/`. 60 | In the `examples` directory is a training script [train_paccmann.py](./examples/IC50/train_paccmann.py) that makes use 61 | of `paccmann_predictor`. 62 | 63 | ```console 64 | (paccmann_predictor) $ python examples/IC50/train_paccmann.py -h 65 | usage: train_paccmann.py [-h] 66 | train_sensitivity_filepath test_sensitivity_filepath 67 | gep_filepath smi_filepath gene_filepath 68 | smiles_language_filepath model_path params_filepath 69 | training_name 70 | 71 | positional arguments: 72 | train_sensitivity_filepath 73 | Path to the drug sensitivity (IC50) data. 74 | test_sensitivity_filepath 75 | Path to the drug sensitivity (IC50) data. 76 | gep_filepath Path to the gene expression profile data. 77 | smi_filepath Path to the SMILES data. 78 | gene_filepath Path to a pickle object containing list of genes. 79 | smiles_language_filepath 80 | Path to a pickle object a SMILES language object. 81 | model_path Directory where the model will be stored. 82 | params_filepath Path to the parameter file. 83 | training_name Name for the training. 84 | 85 | optional arguments: 86 | -h, --help show this help message and exit 87 | ``` 88 | 89 | `params_filepath` could point to [examples/IC50/example_params.json](examples/IC50/example_params.json), examples for other files can be downloaded from [here](https://ibm.box.com/v/paccmann-pytoda-data). 90 | 91 | ## References 92 | 93 | If you use `paccmann_predictor` in your projects, please cite the following: 94 | 95 | ```bib 96 | @article{manica2019paccmann, 97 | title={Toward explainable anticancer compound sensitivity prediction via multimodal attention-based convolutional encoders}, 98 | author={Manica, Matteo and Oskooei, Ali and Born, Jannis and Subramanian, Vigneshwari and S{\'a}ez-Rodr{\'\i}guez, Julio and Mart{\'\i}nez, Mar{\'\i}a Rodr{\'\i}guez}, 99 | journal={Molecular pharmaceutics}, 100 | volume={16}, 101 | number={12}, 102 | pages={4797--4806}, 103 | year={2019}, 104 | publisher={ACS Publications}, 105 | doi = {10.1021/acs.molpharmaceut.9b00520}, 106 | note = {PMID: 31618586} 107 | } 108 | 109 | @article{born2021datadriven, 110 | author = {Born, Jannis and Manica, Matteo and Cadow, Joris and Markert, Greta and Mill, Nil Adell and Filipavicius, Modestas and Janakarajan, Nikita and Cardinale, Antonio and Laino, Teodoro and {Rodr{\'{i}}guez Mart{\'{i}}nez}, Mar{\'{i}}a}, 111 | doi = {10.1088/2632-2153/abe808}, 112 | issn = {2632-2153}, 113 | journal = {Machine Learning: Science and Technology}, 114 | number = {2}, 115 | pages = {025024}, 116 | title = {{Data-driven molecular design for discovery and synthesis of novel ligands: a case study on SARS-CoV-2}}, 117 | url = {https://iopscience.iop.org/article/10.1088/2632-2153/abe808}, 118 | volume = {2}, 119 | year = {2021} 120 | } 121 | ``` 122 | -------------------------------------------------------------------------------- /paccmann_predictor/models/knn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from rdkit import Chem, DataStructs 6 | from rdkit.Chem import AllChem 7 | from tqdm import tqdm 8 | from time import time 9 | from scipy.stats import pearsonr 10 | 11 | 12 | def knn( 13 | train_df: pd.DataFrame, 14 | test_df: pd.DataFrame, 15 | drug_df: pd.DataFrame, 16 | cell_df: pd.DataFrame, 17 | k: int = 1, 18 | return_knn_labels: bool = False, 19 | verbose: bool = False, 20 | result_path: str = None, 21 | chirality: bool = False, 22 | radius: int = 2, 23 | ): 24 | """Baseline model for CPI prediction. Applies KNN classification using as 25 | similarity the Euclidean disance between cell representations (e.g. RNA-Seq) 26 | and FP similarity of the drugs. 27 | Predictions conceptually correspond to the predict_proba method of 28 | sklearn.neighbors.KNeighborsClassifier. 29 | 30 | Args: 31 | train_df (pd.DataFrame): DF with training samples in rows. Columns are: 32 | 'drug', 'cell_line' and 'label'. 33 | test_df (pd.DataFrame): DF with testing samples in rows. Columns are: 34 | 'drug', 'cell_line' and 'label'. 35 | drug_df (pd.DataFrame): DF with drug name as identifier and SMILES column. 36 | cell_df (pd.DataFrame): DF with cell line name as identifier and omic values 37 | as columns. 38 | k (int, optional): Hyperparameter for KNN classification. Defaults to 1. 39 | return_knn_labels (bool, optional): If set, the labels of the K nearest 40 | neighbors are also returned. 41 | """ 42 | assert isinstance(train_df, pd.DataFrame) 43 | assert isinstance(test_df, pd.DataFrame) 44 | assert isinstance(drug_df, pd.DataFrame) 45 | assert isinstance(cell_df, pd.DataFrame) 46 | 47 | # Compute FPs of training data: 48 | drug_fp_dict = dict( 49 | zip( 50 | drug_df.index, 51 | [ 52 | AllChem.GetMorganFingerprintAsBitVect( 53 | Chem.MolFromSmiles(smi), radius, useChirality=chirality 54 | ) 55 | for smi in drug_df["SMILES"].values 56 | ], 57 | ) 58 | ) 59 | # Compute pairwise distances 60 | print(f"Computing pairwise distances of {len(cell_df)} expression profiles") 61 | cell_dist_dict = {} 62 | max_cell_dist = 0 63 | for cell_a_name, cell_a in tqdm(cell_df.T.items()): 64 | cell_dist_dict[cell_a_name] = {} 65 | for cell_b_name, cell_b in cell_df.T.items(): 66 | d = np.linalg.norm(cell_a - cell_b) 67 | cell_dist_dict[cell_a_name][cell_b_name] = d 68 | if d > max_cell_dist: 69 | max_cell_dist = d 70 | 71 | # Will store computed distances to avoid re-computation 72 | tani_dict = {} 73 | 74 | predictions, drugs, cells, labels = [], [], [], [] 75 | knn_labels_sample, knn_labels_full = [], [] 76 | flipper = lambda x: x * -1 + 1 77 | t = time() 78 | train_labels = np.array(train_df["label"]) 79 | for idx_loc, test_sample in tqdm(test_df.iterrows()): 80 | 81 | idx = test_df.index.get_loc(idx_loc) 82 | 83 | if verbose and idx % 10 == 0: 84 | print(f"Idx {idx}/{len(test_df)}") 85 | 86 | cell_name = test_sample["cell_line"] 87 | drug_name = test_sample["drug"] 88 | label = test_sample["label"] 89 | fp = drug_fp_dict[drug_name] 90 | 91 | new_mol = False 92 | if drug_name not in tani_dict.keys(): 93 | tani_dict[drug_name] = {} 94 | new_mol = True 95 | 96 | if new_mol: 97 | 98 | def get_mol_dist(train_drug): 99 | if train_drug in tani_dict[drug_name].keys(): 100 | return tani_dict[drug_name][train_drug] 101 | else: 102 | tani_dict[drug_name][train_drug] = flipper( 103 | DataStructs.FingerprintSimilarity(fp, drug_fp_dict[train_drug]) 104 | ) 105 | return tani_dict[drug_name][train_drug] 106 | 107 | else: 108 | 109 | def get_mol_dist(train_drug): 110 | return tani_dict[drug_name][train_drug] 111 | 112 | get_cell_dist = lambda x: cell_dist_dict[cell_name][x] 113 | 114 | # new_cell = False 115 | # if cell_name not in omic_dict.keys(): 116 | # omic_dict[cell_name] = {} 117 | # new_cell = True 118 | 119 | # if new_cell: 120 | 121 | # def get_cell_dist(train_cell_name): 122 | 123 | # if train_cell_name in omic_dict[cell_name].keys(): 124 | # return omic_dict[cell_name][train_cell_name] 125 | # else: 126 | # omic_dict[cell_name][train_cell_name] = np.linalg.norm( 127 | # cell_profile - cell_df_dict[train_cell_name] 128 | # ) 129 | # return omic_dict[cell_name][train_cell_name] 130 | 131 | # else: 132 | 133 | # def get_cell_dist(train_cell_name): 134 | # return omic_dict[cell_name][train_cell_name] 135 | 136 | mol_dists, cell_dists = np.zeros((len(train_df),)), np.zeros((len(train_df),)) 137 | 138 | # print(f"Rest took {time()-t}") 139 | # t = time() 140 | mol_dists = np.array(list(map(get_mol_dist, train_df["drug"].values))) 141 | # print(f" Mol dists took {time()-t}") 142 | # t = time() 143 | cell_dists = np.array(list(map(get_cell_dist, train_df["cell_line"].values))) 144 | # print(f"Cell dists took {time()-t}") 145 | # t = time() 146 | 147 | # Normalize cell distances 148 | cell_dists_sample = cell_dists / np.max(cell_dists) 149 | cell_dists_full = cell_dists / max_cell_dist 150 | 151 | knns_sample = np.argsort(mol_dists + cell_dists_sample)[:k] 152 | knns_full = np.argsort(mol_dists + cell_dists_full)[:k] 153 | 154 | _knn_labels_sample = train_labels[knns_sample] 155 | _knn_labels_full = train_labels[knns_full] 156 | 157 | knn_labels_sample.append(_knn_labels_sample) 158 | knn_labels_full.append(_knn_labels_full) 159 | 160 | predictions.append( 161 | (np.mean(_knn_labels_sample) + np.mean(_knn_labels_full)) / 2 162 | ) 163 | 164 | drugs.append(drug_name) 165 | cells.append(cell_name) 166 | labels.append(label) 167 | 168 | if result_path is not None and idx % 100 == 0 and idx > 0: 169 | for x, y in zip( 170 | [knn_labels_sample, knn_labels_full], ["sample_norm", "full_norm"] 171 | ): 172 | df = pd.DataFrame(x) 173 | df.insert(0, "label", labels) 174 | df.insert(0, "cell", cells) 175 | df.insert(0, "drug", drugs) 176 | df.to_csv(os.path.join(result_path, f"knn_{y}_{idx}.csv")) 177 | p = pearsonr(labels, np.array(x).mean(axis=1)) 178 | print(f"Running pearson ({y}): {round(p[0], 4)}") 179 | 180 | if result_path is not None: 181 | for x, y in zip( 182 | [knn_labels_sample, knn_labels_full], ["sample_norm", "full_norm"] 183 | ): 184 | df = pd.DataFrame(x) 185 | df.insert(0, "label", labels) 186 | df.insert(0, "cell", cells) 187 | df.insert(0, "drug", drugs) 188 | df.to_csv(os.path.join(result_path, f"knn_{y}_{idx}.csv")) 189 | p = pearsonr(labels, np.array(x).mean(axis=1)) 190 | print(f"===Final pearson ({y}): {round(p[0], 4)}===") 191 | 192 | return (predictions, knn_labels_sample) if return_knn_labels else predictions 193 | -------------------------------------------------------------------------------- /examples/IC50/test_paccmann.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test PaccMann predictor.""" 3 | import argparse 4 | import json 5 | import logging 6 | import os 7 | import pickle 8 | import sys 9 | from copy import deepcopy 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from tqdm import tqdm 15 | from paccmann_predictor.models import MODEL_FACTORY 16 | from paccmann_predictor.utils.utils import get_device 17 | from pytoda.datasets import DrugSensitivityDataset 18 | from pytoda.smiles.smiles_language import SMILESTokenizer 19 | from scipy.stats import pearsonr 20 | 21 | # setup logging 22 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 23 | 24 | # yapf: disable 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | 'test_sensitivity_filepath', type=str, 28 | help='Path to the drug sensitivity (IC50) data.' 29 | ) 30 | parser.add_argument( 31 | 'gep_filepath', type=str, 32 | help='Path to the gene expression profile data.' 33 | ) 34 | parser.add_argument( 35 | 'smi_filepath', type=str, 36 | help='Path to the SMILES data.' 37 | ) 38 | parser.add_argument( 39 | 'gene_filepath', type=str, 40 | help='Path to a pickle object containing list of genes.' 41 | ) 42 | parser.add_argument( 43 | 'smiles_language_filepath', type=str, 44 | help='Path to a folder with SMILES language .json files.' 45 | ) 46 | parser.add_argument( 47 | 'model_filepath', type=str, 48 | help='Path to the stored model.' 49 | ) 50 | parser.add_argument( 51 | 'predictions_filepath', type=str, 52 | help='Path to the predictions.' 53 | ) 54 | parser.add_argument( 55 | 'params_filepath', type=str, 56 | help='Path to the parameter file.' 57 | ) 58 | # yapf: enable 59 | 60 | 61 | def main( 62 | test_sensitivity_filepath, gep_filepath, 63 | smi_filepath, gene_filepath, smiles_language_filepath, model_filepath, predictions_filepath, 64 | params_filepath 65 | ): 66 | 67 | logger = logging.getLogger('test') 68 | # Process parameter file: 69 | params = {} 70 | with open(params_filepath) as fp: 71 | params.update(json.load(fp)) 72 | 73 | 74 | # Prepare the dataset 75 | logger.info("Start data preprocessing...") 76 | 77 | # Load SMILES language 78 | smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath) 79 | smiles_language.set_encoding_transforms( 80 | add_start_and_stop=params.get('add_start_and_stop', True), 81 | padding=params.get('padding', True), 82 | padding_length=params.get('smiles_padding_length', None) 83 | ) 84 | test_smiles_language = deepcopy(smiles_language) 85 | smiles_language.set_smiles_transforms( 86 | augment=params.get('augment_smiles', False), 87 | canonical=params.get('smiles_canonical', False), 88 | kekulize=params.get('smiles_kekulize', False), 89 | all_bonds_explicit=params.get('smiles_bonds_explicit', False), 90 | all_hs_explicit=params.get('smiles_all_hs_explicit', False), 91 | remove_bonddir=params.get('smiles_remove_bonddir', False), 92 | remove_chirality=params.get('smiles_remove_chirality', False), 93 | selfies=params.get('selfies', False), 94 | sanitize=params.get('selfies', False) 95 | ) 96 | test_smiles_language.set_smiles_transforms( 97 | augment=False, 98 | canonical=params.get('test_smiles_canonical', False), 99 | kekulize=params.get('smiles_kekulize', False), 100 | all_bonds_explicit=params.get('smiles_bonds_explicit', False), 101 | all_hs_explicit=params.get('smiles_all_hs_explicit', False), 102 | remove_bonddir=params.get('smiles_remove_bonddir', False), 103 | remove_chirality=params.get('smiles_remove_chirality', False), 104 | selfies=params.get('selfies', False), 105 | sanitize=params.get('selfies', False) 106 | ) 107 | 108 | # Load the gene list 109 | with open(gene_filepath, 'rb') as f: 110 | gene_list = pickle.load(f) 111 | 112 | # Assemble test dataset 113 | test_dataset = DrugSensitivityDataset( 114 | drug_sensitivity_filepath=test_sensitivity_filepath, 115 | smi_filepath=smi_filepath, 116 | gene_expression_filepath=gep_filepath, 117 | smiles_language=test_smiles_language, 118 | gene_list=gene_list, 119 | drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True), 120 | gene_expression_standardize=params.get( 121 | 'gene_expression_standardize', True 122 | ), 123 | gene_expression_min_max=params.get('gene_expression_min_max', False), 124 | gene_expression_processing_parameters=params.get( 125 | 'gene_expression_processing_parameters', {} 126 | ), 127 | device=torch.device(params.get('dataset_device', 'cpu')), 128 | iterate_dataset=False 129 | ) 130 | test_loader = torch.utils.data.DataLoader( 131 | dataset=test_dataset, 132 | batch_size=params['batch_size'], 133 | shuffle=False, 134 | drop_last=False, 135 | num_workers=params.get('num_workers', 0) 136 | ) 137 | logger.info( 138 | f'Test dataset has {len(test_dataset)} samples with {len(test_loader)} batches' 139 | ) 140 | 141 | device = get_device() 142 | logger.info( 143 | f'Device for data loader is {test_dataset.device} and for ' 144 | f'model is {device}' 145 | ) 146 | 147 | model_name = params.get('model_fn', 'paccmann') 148 | model = MODEL_FACTORY[model_name](params).to(device) 149 | model._associate_language(smiles_language) 150 | try: 151 | logger.info(f'Attempting to restore model from {model_filepath}...') 152 | model.load(model_filepath, map_location=device) 153 | except Exception: 154 | raise ValueError(f'Error in restoring model from {model_filepath}!') 155 | 156 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 157 | params.update({'number_of_parameters': num_params}) 158 | logger.info(f'Number of parameters {num_params}') 159 | 160 | # Start testing 161 | logger.info('Testing about to start... \n') 162 | model.eval() 163 | 164 | with torch.no_grad(): 165 | test_loss = 0 166 | predictions = [] 167 | # gene_attentions = [] 168 | # epistemic_confs = [] 169 | # aleatoric_confs = [] 170 | labels = [] 171 | for ind, (smiles, gep, y) in tqdm(enumerate(test_loader), total=len(test_loader)): 172 | y_hat, pred_dict = model( 173 | torch.squeeze(smiles.to(device)), gep.to(device), confidence = False 174 | ) 175 | predictions.extend(list(y_hat.detach().cpu().squeeze().numpy())) 176 | # gene_attentions.append(pred_dict['gene_attention']) 177 | # epistemic_confs.append(pred_dict['epistemic_confidence']) 178 | # aleatoric_confs.append(pred_dict['aleatoric_confidence']) 179 | labels.extend(list(y.detach().cpu().squeeze().numpy())) 180 | loss = model.loss(y_hat, y.to(device)) 181 | test_loss += loss.item() 182 | 183 | #gene_attentions = np.array([a.cpu().numpy() for atts in gene_attentions for a in atts]) 184 | #epistemic_confs = np.array([c.cpu().numpy() for conf in epistemic_confs for c in conf]).ravel() 185 | #aleatoric_confs = np.array([c.cpu().numpy() for conf in aleatoric_confs for c in conf]).ravel() 186 | predictions = np.array(predictions) 187 | labels = np.array(labels) 188 | 189 | pearson = pearsonr(predictions, labels)[0] 190 | rmse = np.sqrt(np.mean((predictions - labels)**2)) 191 | loss = test_loss / len(test_loader) 192 | logger.info( 193 | f"\t**RESULT**\t loss:{loss:.5f}, Pearson: {pearson:.3f}, RMSE: {rmse:.3f}" 194 | ) 195 | 196 | df = test_dataset.drug_sensitivity_df 197 | df['prediction'] = predictions 198 | df.to_csv(predictions_filepath+'.csv') 199 | 200 | #np.save(predictions_filepath+'_gene_attention.npy', gene_attentions) 201 | #np.save(predictions_filepath+'_epistemic_confidence.npy', epistemic_confs) 202 | #np.save(predictions_filepath+'_aleatoric_confidence.npy', aleatoric_confs) 203 | 204 | if __name__ == '__main__': 205 | # parse arguments 206 | args = parser.parse_args() 207 | # run the testing 208 | main( 209 | args.test_sensitivity_filepath, 210 | args.gep_filepath, args.smi_filepath, args.gene_filepath, 211 | args.smiles_language_filepath, args.model_filepath, args.predictions_filepath, args.params_filepath 212 | ) 213 | -------------------------------------------------------------------------------- /examples/affinity/predict_affinity.py: -------------------------------------------------------------------------------- 1 | """Predict Affinity for a list of proteins and SMILES.""" 2 | import argparse 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | 8 | import pandas as pd 9 | import torch 10 | from paccmann_predictor.models import MODEL_FACTORY 11 | from paccmann_predictor.utils.utils import get_device 12 | from pytoda.files import read_smi 13 | from pytoda.proteins.protein_language import ProteinLanguage 14 | from pytoda.smiles.smiles_language import SMILESTokenizer 15 | from pytoda.transforms import LeftPadding, ToTensor 16 | from pytoda.datasets import SMILESTokenizerDataset 17 | from paccmann_predictor.utils.interpret import ( 18 | monte_carlo_dropout, 19 | test_time_augmentation, 20 | ) 21 | from pytoda.smiles.transforms import AugmentTensor 22 | 23 | # setup logging 24 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 25 | 26 | # yapf: disable 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | 'model_path', type=str, 30 | help='Path to the trained model' 31 | ) 32 | parser.add_argument( 33 | 'protein_filepath', type=str, 34 | help='Path to a .smi file with protein sequences.' 35 | ) 36 | parser.add_argument( 37 | 'smi_filepath', type=str, 38 | help='Path to a .smi file with SMILES sequences.' 39 | ) 40 | parser.add_argument( 41 | 'output_folder', type=str, 42 | help='Directory where the output .csv will be stored.' 43 | ) 44 | parser.add_argument( 45 | '-m', '--model_id', type=str, 46 | help='ID for model factory', default='bimodal_mca' 47 | ) 48 | parser.add_argument( 49 | '-s', '--smiles_language_filepath', type=str, default='.', 50 | help='Path to a SMILES language object.' 51 | ) 52 | parser.add_argument( 53 | '-p', '--protein_language_filepath', type=str, default='.', 54 | help='Path to a pickle of a Protein language object.' 55 | ) 56 | parser.add_argument( 57 | '-l', '--label_filepath', type=str, default=None, required=False, 58 | help='Optional path to a file with labels' 59 | ) 60 | parser.add_argument( 61 | '-c', '--confidence', action='store_true', 62 | help='Whether or not confidence predictions should be performed' 63 | ) 64 | # yapf: enable 65 | 66 | 67 | def main( 68 | model_path, 69 | protein_filepath, 70 | smi_filepath, 71 | output_folder, 72 | model_id, 73 | smiles_language_filepath, 74 | protein_language_filepath, 75 | label_filepath, 76 | confidence, 77 | ): 78 | 79 | logger = logging.getLogger('affinity_prediction') 80 | 81 | # Process parameter file: 82 | params = {} 83 | with open(os.path.join(model_path, 'model_params.json'), 'r') as fp: 84 | params.update(json.load(fp)) 85 | 86 | # Create model directory 87 | os.makedirs(output_folder, exist_ok=True) 88 | 89 | device = get_device() 90 | weights_path = os.path.join(model_path, 'weights', 'best_ROC-AUC_bimodal_mca.pt') 91 | 92 | if label_filepath is not None: 93 | label_df = pd.read_csv(label_filepath, index_col=0) 94 | 95 | if smiles_language_filepath == '.': 96 | smiles_language_filepath = os.path.join(model_path, 'smiles_language.json') 97 | if protein_language_filepath == '.': 98 | protein_language_filepath = os.path.join(model_path, 'protein_language.pkl') 99 | # Load languages 100 | protein_language = ProteinLanguage.load(protein_language_filepath) 101 | smiles_language = SMILESTokenizer( 102 | vocab_file=smiles_language_filepath, 103 | padding=params.get('smiles_padding', True), 104 | padding_length=params.get('smiles_padding_length', None), 105 | add_start_and_stop=params.get('smiles_add_start_stop', True), 106 | augment=False, 107 | canonical=params.get('smiles_test_canonical', False), 108 | kekulize=params.get('smiles_kekulize', False), 109 | all_bonds_explicit=params.get('smiles_bonds_explicit', False), 110 | all_hs_explicit=params.get('smiles_all_hs_explicit', False), 111 | remove_bonddir=params.get('smiles_remove_bonddir', False), 112 | remove_chirality=params.get('smiles_remove_chirality', False), 113 | selfies=params.get('selfies', False), 114 | ) 115 | augment = AugmentTensor(smiles_language) 116 | 117 | model = MODEL_FACTORY[model_id](params).to(device) 118 | 119 | if os.path.isfile(weights_path): 120 | try: 121 | model.load(weights_path, map_location=device) 122 | except Exception: 123 | raise ValueError(f'Error in model restoring from {weights_path}') 124 | else: 125 | logger.info(f'Did not find weights at {weights_path}, name weights "best.pt".') 126 | model.eval() 127 | 128 | # Transforms 129 | to_tensor = ToTensor() 130 | pad_seq = LeftPadding(model.protein_padding_length, protein_language.padding_index) 131 | 132 | # Read data 133 | sequences = read_smi(protein_filepath, names=['Sequence', 'Name']) 134 | ligands = read_smi(smi_filepath) 135 | 136 | smiles_data = SMILESTokenizerDataset( 137 | smi_filepath, smiles_language=smiles_language, iterate_dataset=False 138 | ) 139 | smiles_loader = torch.utils.data.DataLoader( 140 | smiles_data, batch_size=256, drop_last=False, num_workers=0, shuffle=False 141 | ) 142 | 143 | for idx, (sequence_id, row) in enumerate(sequences.iterrows()): 144 | logger.info(f'Target {idx+1}/{len(sequences)}: {sequence_id}') 145 | 146 | proteins = to_tensor( 147 | pad_seq(protein_language.sequence_to_token_indexes(row['Sequence'])) 148 | ).unsqueeze(0) 149 | 150 | target_preds = [] 151 | epi_confs, epi_preds, ale_confs, ale_preds = [], [], [], [] 152 | for sidx, smiles_batch in enumerate(smiles_loader): 153 | protein_batch = proteins.repeat(len(smiles_batch), 1) 154 | preds, pred_dict = model(smiles_batch, protein_batch) 155 | target_preds.extend(preds.detach().squeeze().tolist()) 156 | 157 | # Get confidences 158 | if confidence: 159 | 160 | ale_conf, ale_pred = test_time_augmentation( 161 | model, 162 | regime='tensors', 163 | tensors=(smiles_batch, protein_batch), 164 | augmenter=augment, 165 | tensors_to_augment=0, 166 | ) 167 | epi_conf, epi_pred = monte_carlo_dropout( 168 | model, regime='tensors', tensors=(smiles_batch, protein_batch) 169 | ) 170 | epi_confs.extend(epi_conf.detach().squeeze().tolist()) 171 | epi_preds.extend(epi_pred.detach().squeeze().tolist()) 172 | ale_confs.extend(ale_conf.detach().squeeze().tolist()) 173 | ale_preds.extend(ale_pred.detach().squeeze().tolist()) 174 | 175 | save_name = ( 176 | sequence_id.strip() 177 | .replace(' ', '_') 178 | .replace('\\', '_') 179 | .replace('/', '_') 180 | .replace('=', '_') 181 | ) 182 | df = pd.DataFrame({'SMILES': ligands['SMILES'], 'affinity': target_preds}) 183 | if confidence: 184 | df['epistemic_confidence'] = epi_confs 185 | df['aleatoric_confidence'] = ale_confs 186 | df['epistemic_affinity'] = epi_preds 187 | df['aleatoric_affinity'] = ale_preds 188 | 189 | # Retrieve labels 190 | if label_filepath is not None: 191 | labels, ligand_names = [], [] 192 | for smiles in ligands['SMILES']: 193 | try: 194 | selected_row = label_df[ 195 | ( 196 | label_df['ligand_name'] 197 | == ligands[ligands['SMILES'] == smiles].index[0] 198 | ) 199 | & (label_df['sequence_id'] == row['Name']) 200 | ] 201 | labels.append(selected_row['label'].values[0]) 202 | ligand_names.append(selected_row['ligand_name'].values[0]) 203 | except IndexError: 204 | labels.append(-1) 205 | ligand_names.append(' ') 206 | df['ligand_name'] = ligand_names 207 | df['labels'] = labels 208 | 209 | df.to_csv(os.path.join(output_folder, f'{save_name}.csv'), index=False) 210 | 211 | # Free memory 212 | del preds, pred_dict 213 | 214 | logger.info('Done, shutting down.') 215 | 216 | 217 | if __name__ == '__main__': 218 | # parse arguments 219 | args = parser.parse_args() 220 | # run the predictions 221 | main( 222 | args.model_path, 223 | args.protein_filepath, 224 | args.smi_filepath, 225 | args.output_folder, 226 | args.model_id, 227 | args.smiles_language_filepath, 228 | args.protein_language_filepath, 229 | args.label_filepath, 230 | args.confidence, 231 | ) 232 | -------------------------------------------------------------------------------- /paccmann_predictor/utils/layers.py: -------------------------------------------------------------------------------- 1 | """Custom layers implementation.""" 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .utils import Squeeze, get_device, Temperature, Unsqueeze 8 | 9 | DEVICE = get_device() 10 | 11 | 12 | def dense_layer( 13 | input_size, hidden_size, act_fn=nn.ReLU(), batch_norm=False, dropout=0.0 14 | ): 15 | return nn.Sequential( 16 | OrderedDict( 17 | [ 18 | ('projection', nn.Linear(input_size, hidden_size)), 19 | ( 20 | 'batch_norm', 21 | nn.BatchNorm1d(hidden_size) 22 | if batch_norm else nn.Identity(), 23 | ), 24 | ('act_fn', act_fn), 25 | ('dropout', nn.Dropout(p=dropout)), 26 | ] 27 | ) 28 | ) 29 | 30 | 31 | def dense_attention_layer( 32 | number_of_features: int, 33 | temperature: float = 1.0, 34 | dropout=0.0 35 | ) -> nn.Sequential: 36 | """Attention mechanism layer for dense inputs. 37 | 38 | Args: 39 | number_of_features (int): Size to allocate weight matrix. 40 | temperature (float): Softmax temperature parameter (0, inf). Lower 41 | temperature (< 1) result in a more descriminative/spiky softmax, 42 | higher temperature (> 1) results in a smoother attention. 43 | Returns: 44 | callable: a function that can be called with inputs. 45 | """ 46 | return nn.Sequential( 47 | OrderedDict( 48 | [ 49 | ('dense', nn.Linear(number_of_features, number_of_features)), 50 | ('dropout', nn.Dropout(p=dropout)), 51 | ('temperature', Temperature(temperature)), 52 | ('softmax', nn.Softmax(dim=-1)), 53 | ] 54 | ) 55 | ) 56 | 57 | 58 | def convolutional_layer( 59 | num_kernel, 60 | kernel_size, 61 | act_fn=nn.ReLU(), 62 | batch_norm=False, 63 | dropout=0.0, 64 | input_channels=1, 65 | ): 66 | """Convolutional layer. 67 | 68 | Args: 69 | num_kernel (int): Number of convolution kernels. 70 | kernel_size (tuple[int, int]): Size of the convolution kernels. 71 | act_fn (callable): Functional of the nonlinear activation. 72 | batch_norm (bool): whether batch normalization is applied. 73 | dropout (float): Probability for each input value to be 0. 74 | input_channels (int): Number of input channels (defaults to 1). 75 | 76 | Returns: 77 | callable: a function that can be called with inputs. 78 | """ 79 | return nn.Sequential( 80 | OrderedDict( 81 | [ 82 | ( 83 | 'convolve', 84 | torch.nn.Conv2d( 85 | input_channels, # channel_in 86 | num_kernel, # channel_out 87 | kernel_size, # kernel_size 88 | padding=[kernel_size[0] // 2, 89 | 0], # pad for valid conv. 90 | ), 91 | ), 92 | ('squeeze', Squeeze()), 93 | ('act_fn', act_fn), 94 | ('dropout', nn.Dropout(p=dropout)), 95 | ( 96 | 'batch_norm', 97 | nn.BatchNorm1d(num_kernel) 98 | if batch_norm else nn.Identity(), 99 | ), 100 | ] 101 | ) 102 | ) 103 | 104 | 105 | class ContextAttentionLayer(nn.Module): 106 | """ 107 | Implements context attention as in the PaccMann paper (Figure 2C) in 108 | Molecular Pharmaceutics. 109 | With the additional option of having a hidden size in the context. 110 | NOTE: 111 | In tensorflow, weights were initialized from N(0,0.1). Instead, pytorch 112 | uses U(-stddev, stddev) where stddev=1./math.sqrt(weight.size(1)). 113 | """ 114 | 115 | def __init__( 116 | self, 117 | reference_hidden_size: int, 118 | reference_sequence_length: int, 119 | context_hidden_size: int, 120 | context_sequence_length: int = 1, 121 | attention_size: int = 16, 122 | individual_nonlinearity: type = nn.Sequential(), 123 | temperature: float = 1.0, 124 | ): 125 | """Constructor 126 | Arguments: 127 | reference_hidden_size (int): Hidden size of the reference input 128 | over which the attention will be computed (H). 129 | reference_sequence_length (int): Sequence length of the reference 130 | (T). 131 | context_hidden_size (int): This is either simply the amount of 132 | features used as context (G) or, if the context is a sequence 133 | itself, the hidden size of each time point. 134 | context_sequence_length (int): Hidden size in the context, useful 135 | if context is also textual data, i.e. coming from nn.Embedding. 136 | Defaults to 1. 137 | attention_size (int): Hyperparameter of the attention layer, 138 | defaults to 16. 139 | individual_nonlinearities (type): This is an optional 140 | nonlinearity applied to each projection. Defaults to 141 | nn.Sequential(), i.e. no nonlinearity. Otherwise it expects a 142 | torch.nn activation function, e.g. nn.ReLU(). 143 | temperature (float): Temperature parameter to smooth or sharpen the 144 | softmax. Defaults to 1. Temperature > 1 flattens the 145 | distribution, temperature below 1 makes it spikier. 146 | """ 147 | super().__init__() 148 | 149 | self.reference_sequence_length = reference_sequence_length 150 | self.reference_hidden_size = reference_hidden_size 151 | self.context_sequence_length = context_sequence_length 152 | self.context_hidden_size = context_hidden_size 153 | self.attention_size = attention_size 154 | self.individual_nonlinearity = individual_nonlinearity 155 | self.temperature = temperature 156 | 157 | # Project the reference into the attention space 158 | self.reference_projection = nn.Sequential( 159 | OrderedDict( 160 | [ 161 | ( 162 | 'projection', 163 | nn.Linear(reference_hidden_size, attention_size), 164 | ), 165 | ('act_fn', individual_nonlinearity), 166 | ] 167 | ) 168 | ) # yapf: disable 169 | 170 | # Project the context into the attention space 171 | self.context_projection = nn.Sequential( 172 | OrderedDict( 173 | [ 174 | ( 175 | 'projection', 176 | nn.Linear(context_hidden_size, attention_size), 177 | ), 178 | ('act_fn', individual_nonlinearity), 179 | ] 180 | ) 181 | ) # yapf: disable 182 | 183 | # Optionally reduce the hidden size in context 184 | if context_sequence_length > 1: 185 | self.context_hidden_projection = nn.Sequential( 186 | OrderedDict( 187 | [ 188 | ( 189 | 'projection', 190 | nn.Linear( 191 | context_sequence_length, 192 | reference_sequence_length, 193 | ), 194 | ), 195 | ('act_fn', individual_nonlinearity), 196 | ] 197 | ) 198 | ) # yapf: disable 199 | else: 200 | self.context_hidden_projection = nn.Sequential() 201 | 202 | self.alpha_projection = nn.Sequential( 203 | OrderedDict( 204 | [ 205 | ('projection', nn.Linear(attention_size, 1, bias=False)), 206 | ('squeeze', Squeeze()), 207 | ('temperature', Temperature(self.temperature)), 208 | ('softmax', nn.Softmax(dim=1)), 209 | ] 210 | ) 211 | ) 212 | 213 | def forward( 214 | self, 215 | reference: torch.Tensor, 216 | context: torch.Tensor, 217 | average_seq: bool = True 218 | ): 219 | """ 220 | Forward pass through a context attention layer 221 | Arguments: 222 | reference (torch.Tensor): This is the reference input on which 223 | attention is computed. Shape: bs x ref_seq_length x ref_hidden_size 224 | context (torch.Tensor): This is the context used for attention. 225 | Shape: bs x context_seq_length x context_hidden_size 226 | average_seq (bool): Whether the filtered attention is averaged over the 227 | sequence length. 228 | NOTE: This is recommended to be True, however if the ref_hidden_size 229 | is 1, this can be used to prevent collapsing to a single float. 230 | Defaults to True. 231 | Returns: 232 | (output, attention_weights): A tuple of two Tensors, first one 233 | containing the reference filtered by attention (shape: 234 | bs x ref_hidden_size) and the second one the 235 | attention weights (bs x ref_seq_length). 236 | NOTE: If average_seq is False, the output is: bs x ref_seq_length 237 | """ 238 | assert len(reference.shape) == 3, 'Reference tensor needs to be 3D' 239 | assert len(context.shape) == 3, 'Context tensor needs to be 3D' 240 | 241 | reference_attention = self.reference_projection(reference) 242 | context_attention = self.context_hidden_projection( 243 | self.context_projection(context).permute(0, 2, 1) 244 | ).permute(0, 2, 1) 245 | alphas = self.alpha_projection( 246 | torch.tanh(reference_attention + context_attention) 247 | ) 248 | 249 | output = reference * torch.unsqueeze(alphas, -1) 250 | output = torch.sum(output, 1) if average_seq else torch.squeeze(output) 251 | 252 | return output, alphas 253 | 254 | 255 | def gene_projection(num_genes, attention_size, ind_nonlin=nn.Sequential()): 256 | return nn.Sequential( 257 | OrderedDict( 258 | [ 259 | ('projection', nn.Linear(num_genes, attention_size)), 260 | ('act_fn', ind_nonlin), 261 | ('expand', Unsqueeze(1)), 262 | ] 263 | ) 264 | ).to(DEVICE) 265 | 266 | 267 | def smiles_projection( 268 | smiles_hidden_size, attention_size, ind_nonlin=nn.Sequential() 269 | ): 270 | return nn.Sequential( 271 | OrderedDict( 272 | [ 273 | ('projection', nn.Linear(smiles_hidden_size, attention_size)), 274 | ('act_fn', ind_nonlin), 275 | ] 276 | ) 277 | ).to(DEVICE) 278 | 279 | 280 | def alpha_projection(attention_size): 281 | return nn.Sequential( 282 | OrderedDict( 283 | [ 284 | ('projection', nn.Linear(attention_size, 1, bias=False)), 285 | ('squeeze', Squeeze()), 286 | ('softmax', nn.Softmax(dim=1)), 287 | ] 288 | ) 289 | ).to(DEVICE) 290 | -------------------------------------------------------------------------------- /paccmann_predictor/utils/interpret.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from .utils import get_device 7 | 8 | # We use standard deviation to measure uncertainity since entropy is not 9 | # defined for continuous variables and differential entropy is not ideal. 10 | # In case all predictions are identical, std is 0. If 50% are 0 and 50% are 11 | # one, it is maximal, i.e. 0.5. 12 | MAX_STD = 0.5 13 | MIN_STD = 0.0 14 | 15 | DEVICE = get_device() 16 | 17 | 18 | def map_to_device(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: 19 | return tuple(x.to(DEVICE) for x in inputs) 20 | 21 | 22 | def monte_carlo_dropout( 23 | model, regime="loader", loader=None, tensors=None, repetitions=20 24 | ): 25 | """ 26 | Attempts to approximate epistemic uncertainity through MC dropout. 27 | Performs Monte Carlo dropout for a given model and returns a list of 28 | sample-wise confidence estimates. 29 | This method can be used in two regimes, either by passing a dataloader 30 | or by passing a tensor with the raw input to the model. 31 | 32 | NOTE: The method only works for binary classification tasks (possibly 33 | multi-task like in Tox21). It does *not* work for a multi-class 34 | classification like MNIST. 35 | 36 | 37 | Arguments: 38 | model (torch.nn.Module): The torch network to be investigated. 39 | NOTE: Model is assumed to return either a single tensor of 40 | predictions or a n-tupel with the first part being a tensor 41 | of predictions. They need to be [0, 1] where 0 and 1 represent 42 | two classes. 43 | regime (str): from {'loader', 'tensors'}. If 'loader' is used the 44 | the loader argument needs to be fed. If 'tensors' is used all 45 | necessary input tensors need to be fed in the right shape 46 | loader (torch.utils.data.DataLoader): The dataset to be tested 47 | The loader is expected to return a tuple with the last item 48 | being the labels and all others the model inputs. 49 | Is only used if 'regime'=='loader' 50 | tensors (torch.Tensor, tuple): The input tensor(s) for the model 51 | Can either be a single tensor or a tuple of tensors (in the 52 | right order) 53 | repetitions (int): Amount of forward passes for each sample 54 | 55 | Returns: 56 | confidences (torch.Tensor) - shape: loader.dataset x num_tasks 57 | Contains the inverse normalized standard deviation of the MC 58 | dropout estimates. 59 | predictions (torch.Tensor) - shape: loader.dataset x num_tasks 60 | Contains the averaged predictions across all MC dropout estimates. 61 | """ 62 | 63 | if regime != "loader" and regime != "tensors": 64 | raise ValueError("Choose regime from {'loader', 'tensors'}") 65 | 66 | # Activate dropout layers while keeping other rest in eval mode. 67 | def enable_dropout(m): 68 | if type(m) == nn.Dropout: 69 | m.train() 70 | 71 | model.eval() 72 | model.apply(enable_dropout) 73 | 74 | if regime == "loader": 75 | 76 | # Error handling 77 | if not isinstance(loader.sampler, torch.utils.data.sampler.SequentialSampler): 78 | raise AttributeError( 79 | "Data loader does not use sequential sampling. Consider set" 80 | "ting shuffle=False when instantiating the data loader." 81 | ) 82 | 83 | # Run over all batches in the loader 84 | 85 | def call_fn(): 86 | preds = [] 87 | for inputs in loader: 88 | # inputs is a tuple with the last element being the labels 89 | # outs can be a n-tuple returned by the model 90 | outs = model(*map_to_device(inputs[:-1])) 91 | preds.append( 92 | outs[0].detach().cpu() 93 | if isinstance(outs, tuple) 94 | else outs.detach().cpu() 95 | ) 96 | 97 | return torch.cat(preds) 98 | 99 | elif regime == "tensors": 100 | 101 | if not isinstance(tensors, tuple) and not isinstance(tensors, torch.Tensor): 102 | raise ValueError("Tensor needs to either tuple or torch.Tensor") 103 | 104 | inputs = tensors if isinstance(tensors, tuple) else (tensors,) 105 | 106 | def call_fn(): 107 | outs = model(*map_to_device(inputs)) 108 | return outs[0] if isinstance(outs, tuple) else outs 109 | 110 | with torch.no_grad(): 111 | predictions = [torch.unsqueeze(call_fn(), -1) for _ in range(repetitions)] 112 | predictions = torch.cat(predictions, dim=-1) 113 | 114 | # Scale confidences to [0, 1] 115 | confidences = -1 * ((predictions.std(dim=-1) - MIN_STD) / (MAX_STD - MIN_STD)) + 1 116 | 117 | model.eval() 118 | 119 | return confidences, torch.mean(predictions, -1) 120 | 121 | 122 | def test_time_augmentation( 123 | model, 124 | regime="loader", 125 | loader=None, 126 | tensors=None, 127 | repetitions=20, 128 | augmenter=None, 129 | tensors_to_augment=None, 130 | ): 131 | """ 132 | Attempts to measure aleatoric uncertainity through augmentation during test 133 | time. It returns a list of sample-wise confidence estimates. 134 | This method can be used in two regimes, either by passing a dataloader 135 | or by passing a tensor with the raw input to the model. 136 | 137 | NOTE: The method only works for binary classification tasks (possibly 138 | multi-task like in Tox21). So each output of the model should be [0, 1] 139 | where 0 represent two classes. It does *not* work for a multi-class 140 | classification like MNIST. 141 | 142 | Arguments: 143 | model (torch.nn.Module): The torch network to be investigated. 144 | NOTE: Model is assumed to return either a single tensor of 145 | predictions or a n-tupel with the first part being a tensor 146 | of predictions. They need to be [0, 1] where 0 and 1 represent 147 | two classes. 148 | regime (str): from {'loader', 'tensors'}: If 'loader' is used the 149 | the loader argument needs to be fed. If 'tensors' is used all 150 | necessary input tensors need to be fed in the right shape 151 | loader (torch.utils.data.DataLoader): The dataset to be tested 152 | The loader is expected to return a tuple with the last item 153 | being the labels and all others the model inputs. The loader should 154 | natively perform data augmentation. 155 | Is only used if 'regime'=='loader'. 156 | tensors (torch.Tensor, tuple): The input tensor(s) for the model 157 | Can either be a single tensor or a tuple of tensors (in the 158 | right order) 159 | repetitions (int): Amount of forward passes for each sample 160 | augmenter (transform object, list): This can either be function that 161 | performs the augmentation, e.g. an object of type 162 | pytoda.smiles.AugmentTensor (if `tensors` represents a SMILES 163 | tensor). Alternatively, it can also be a list of augmenters with 164 | the same length like tensors_to_augment. 165 | Only used if regime=='tensors'. 166 | tensors_to_augment (Union[int, list]): This can either be an integer 167 | pointing to the tensor to be augmented. E.g. tensors_to_augment = 0 168 | augments the first tensor in tensors. Can also be a list of the 169 | same length as augmenter (if several augmentations should be 170 | performed on several tensors simultaneously). 171 | Only used if regime=='tensors'. 172 | 173 | Returns: 174 | confidences (torch.Tensor) - shape: loader.dataset x num_tasks 175 | Contains the inverse normalized standard deviation of the MC 176 | dropout estimates. 177 | predictions (torch.Tensor) - shape: loader.dataset x num_tasks 178 | Contains the averaged predictions across estimates. 179 | """ 180 | 181 | if regime != "loader" and regime != "tensors": 182 | raise ValueError("Choose regime from {'loader', 'tensors'}") 183 | 184 | model.eval() 185 | 186 | if regime == "loader": 187 | 188 | # Error handling 189 | if not isinstance(loader.sampler, torch.utils.data.sampler.SequentialSampler): 190 | raise AttributeError( 191 | "Data loader does not use sequential sampling. Consider set" 192 | "ting shuffle=False when instantiating the data loader." 193 | ) 194 | 195 | # Run over all batches in the loader 196 | 197 | def call_fn(): 198 | preds = [] 199 | for inputs in loader: 200 | # inputs is a tuple with the last element being the labels 201 | # outs can be a n-tuple returned by the model 202 | outs = model(*map_to_device(inputs[:-1])) 203 | preds.append(outs[0] if isinstance(outs, tuple) else outs) 204 | 205 | return torch.cat(preds) 206 | 207 | elif regime == "tensors": 208 | 209 | if not isinstance(tensors, tuple) and not isinstance(tensors, torch.Tensor): 210 | raise ValueError("Tensor needs to either tuple or torch.Tensor") 211 | if not isinstance(tensors_to_augment, list) and not isinstance( 212 | tensors_to_augment, int 213 | ): 214 | raise ValueError("tensors_to_augment needs to be list or int") 215 | 216 | # Convert input to common formats (tuples and lists) 217 | tensors_to_augment = ( 218 | [tensors_to_augment] 219 | if isinstance(tensors_to_augment, int) 220 | else tensors_to_augment 221 | ) 222 | inputs = tensors if isinstance(tensors, tuple) else (tensors,) 223 | aug_fns = augmenter if isinstance(augmenter, tuple) else (augmenter,) 224 | 225 | # Error handling 226 | if not len(aug_fns) == len(tensors_to_augment): 227 | raise ValueError( 228 | "Provide one augmenter for each tensor you want to augment." 229 | ) 230 | if max(tensors_to_augment) > len(inputs): 231 | raise ValueError( 232 | "tensors_to_augment should be indexes to the tensors used for " 233 | f"augmentation. {max(tensors_to_augment)} is larger than " 234 | f"length of inputs ({len(inputs)})." 235 | ) 236 | 237 | def call_fn(): 238 | # Perform augmentation on all designated functions 239 | augmented_inputs = [ 240 | ( 241 | aug_fns[tensors_to_augment[tensors_to_augment == ind]](tensor) 242 | if ind in tensors_to_augment 243 | else tensor 244 | ) 245 | for ind, tensor in enumerate(tensors) 246 | ] 247 | outs = model(*map_to_device(augmented_inputs)) 248 | return outs[0] if isinstance(outs, tuple) else outs 249 | 250 | with torch.no_grad(): 251 | predictions = [torch.unsqueeze(call_fn(), -1) for _ in range(repetitions)] 252 | predictions = torch.cat(predictions, dim=-1) 253 | 254 | # Scale confidences to [0, 1] 255 | confidences = -1 * ((predictions.std(dim=-1) - MIN_STD) / (MAX_STD - MIN_STD)) + 1 256 | 257 | return torch.clamp(confidences, min=0), torch.mean(predictions, -1) 258 | -------------------------------------------------------------------------------- /examples/IC50/train_paccmann.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Train PaccMann predictor.""" 3 | import argparse 4 | import json 5 | import logging 6 | import os 7 | import pickle 8 | import sys 9 | from copy import deepcopy 10 | from time import time 11 | 12 | import numpy as np 13 | import torch 14 | from paccmann_predictor.models import MODEL_FACTORY 15 | from paccmann_predictor.utils.hyperparams import OPTIMIZER_FACTORY 16 | from paccmann_predictor.utils.loss_functions import pearsonr 17 | from paccmann_predictor.utils.utils import get_device 18 | from pytoda.datasets import DrugSensitivityDataset 19 | from pytoda.smiles.smiles_language import SMILESTokenizer 20 | 21 | # setup logging 22 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 23 | 24 | # yapf: disable 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | 'train_sensitivity_filepath', type=str, 28 | help='Path to the drug sensitivity (IC50) data.' 29 | ) 30 | parser.add_argument( 31 | 'test_sensitivity_filepath', type=str, 32 | help='Path to the drug sensitivity (IC50) data.' 33 | ) 34 | parser.add_argument( 35 | 'gep_filepath', type=str, 36 | help='Path to the gene expression profile data.' 37 | ) 38 | parser.add_argument( 39 | 'smi_filepath', type=str, 40 | help='Path to the SMILES data.' 41 | ) 42 | parser.add_argument( 43 | 'gene_filepath', type=str, 44 | help='Path to a pickle object containing list of genes.' 45 | ) 46 | parser.add_argument( 47 | 'smiles_language_filepath', type=str, 48 | help='Path to a folder with SMILES language .json files.' 49 | ) 50 | parser.add_argument( 51 | 'model_path', type=str, 52 | help='Directory where the model will be stored.' 53 | ) 54 | parser.add_argument( 55 | 'params_filepath', type=str, 56 | help='Path to the parameter file.' 57 | ) 58 | parser.add_argument( 59 | 'training_name', type=str, 60 | help='Name for the training.' 61 | ) 62 | # yapf: enable 63 | 64 | 65 | def main( 66 | train_sensitivity_filepath, 67 | test_sensitivity_filepath, 68 | gep_filepath, 69 | smi_filepath, 70 | gene_filepath, 71 | smiles_language_filepath, 72 | model_path, 73 | params_filepath, 74 | training_name, 75 | ): 76 | 77 | logger = logging.getLogger(f"{training_name}") 78 | # Process parameter file: 79 | params = {} 80 | with open(params_filepath) as fp: 81 | params.update(json.load(fp)) 82 | 83 | # Create model directory and dump files 84 | model_dir = os.path.join(model_path, training_name) 85 | os.makedirs(os.path.join(model_dir, "weights"), exist_ok=True) 86 | os.makedirs(os.path.join(model_dir, "results"), exist_ok=True) 87 | with open(os.path.join(model_dir, "model_params.json"), "w") as fp: 88 | json.dump(params, fp, indent=4) 89 | 90 | # Prepare the dataset 91 | logger.info("Start data preprocessing...") 92 | 93 | # Load SMILES language 94 | smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath) 95 | smiles_language.set_encoding_transforms( 96 | add_start_and_stop=params.get("add_start_and_stop", True), 97 | padding=params.get("padding", True), 98 | padding_length=params.get("smiles_padding_length", None), 99 | ) 100 | test_smiles_language = deepcopy(smiles_language) 101 | smiles_language.set_smiles_transforms( 102 | augment=params.get("augment_smiles", False), 103 | canonical=params.get("smiles_canonical", False), 104 | kekulize=params.get("smiles_kekulize", False), 105 | all_bonds_explicit=params.get("smiles_bonds_explicit", False), 106 | all_hs_explicit=params.get("smiles_all_hs_explicit", False), 107 | remove_bonddir=params.get("smiles_remove_bonddir", False), 108 | remove_chirality=params.get("smiles_remove_chirality", False), 109 | selfies=params.get("selfies", False), 110 | sanitize=params.get("selfies", False), 111 | ) 112 | test_smiles_language.set_smiles_transforms( 113 | augment=False, 114 | canonical=params.get("test_smiles_canonical", True), 115 | kekulize=params.get("smiles_kekulize", False), 116 | all_bonds_explicit=params.get("smiles_bonds_explicit", False), 117 | all_hs_explicit=params.get("smiles_all_hs_explicit", False), 118 | remove_bonddir=params.get("smiles_remove_bonddir", False), 119 | remove_chirality=params.get("smiles_remove_chirality", False), 120 | selfies=params.get("selfies", False), 121 | sanitize=params.get("selfies", False), 122 | ) 123 | 124 | # Load the gene list 125 | with open(gene_filepath, "rb") as f: 126 | gene_list = pickle.load(f) 127 | 128 | # Assemble datasets 129 | train_dataset = DrugSensitivityDataset( 130 | drug_sensitivity_filepath=train_sensitivity_filepath, 131 | smi_filepath=smi_filepath, 132 | gene_expression_filepath=gep_filepath, 133 | smiles_language=smiles_language, 134 | gene_list=gene_list, 135 | drug_sensitivity_min_max=params.get("drug_sensitivity_min_max", True), 136 | drug_sensitivity_processing_parameters=params.get( 137 | "drug_sensitivity_processing_parameters", {} 138 | ), 139 | gene_expression_standardize=params.get("gene_expression_standardize", True), 140 | gene_expression_min_max=params.get("gene_expression_min_max", False), 141 | gene_expression_processing_parameters=params.get( 142 | "gene_expression_processing_parameters", {} 143 | ), 144 | device=torch.device(params.get("dataset_device", "cpu")), 145 | iterate_dataset=False, 146 | ) 147 | train_loader = torch.utils.data.DataLoader( 148 | dataset=train_dataset, 149 | batch_size=params["batch_size"], 150 | shuffle=True, 151 | drop_last=True, 152 | num_workers=params.get("num_workers", 0), 153 | ) 154 | 155 | test_dataset = DrugSensitivityDataset( 156 | drug_sensitivity_filepath=test_sensitivity_filepath, 157 | smi_filepath=smi_filepath, 158 | gene_expression_filepath=gep_filepath, 159 | smiles_language=smiles_language, 160 | gene_list=gene_list, 161 | drug_sensitivity_min_max=params.get("drug_sensitivity_min_max", True), 162 | drug_sensitivity_processing_parameters=params.get( 163 | "drug_sensitivity_processing_parameters", 164 | train_dataset.drug_sensitivity_processing_parameters, 165 | ), 166 | gene_expression_standardize=params.get("gene_expression_standardize", True), 167 | gene_expression_min_max=params.get("gene_expression_min_max", False), 168 | gene_expression_processing_parameters=params.get( 169 | "gene_expression_processing_parameters", 170 | train_dataset.gene_expression_dataset.processing, 171 | ), 172 | device=torch.device(params.get("dataset_device", "cpu")), 173 | iterate_dataset=False, 174 | ) 175 | test_loader = torch.utils.data.DataLoader( 176 | dataset=test_dataset, 177 | batch_size=params["batch_size"], 178 | shuffle=True, 179 | drop_last=True, 180 | num_workers=params.get("num_workers", 0), 181 | ) 182 | logger.info( 183 | f"Training dataset has {len(train_dataset)} samples, test set has " 184 | f"{len(test_dataset)}." 185 | ) 186 | 187 | device = get_device() 188 | logger.info( 189 | f"Device for data loader is {train_dataset.device} and for " 190 | f"model is {device}" 191 | ) 192 | save_top_model = os.path.join(model_dir, "weights/{}_{}_{}.pt") 193 | params.update( 194 | { # yapf: disable 195 | "number_of_genes": len(gene_list), 196 | "smiles_vocabulary_size": smiles_language.number_of_tokens, 197 | "drug_sensitivity_processing_parameters": train_dataset.drug_sensitivity_processing_parameters, 198 | "gene_expression_processing_parameters": train_dataset.gene_expression_dataset.processing, 199 | } 200 | ) 201 | model_name = params.get("model_fn", "paccmann_v2") 202 | model = MODEL_FACTORY[model_name](params).to(device) 203 | model._associate_language(smiles_language) 204 | 205 | if os.path.isfile(os.path.join(model_dir, "weights", f"best_mse_{model_name}.pt")): 206 | logger.info("Found existing model, restoring now...") 207 | model.load(os.path.join(model_dir, "weights", f"best_mse_{model_name}.pt")) 208 | 209 | with open(os.path.join(model_dir, "results", "mse.json"), "r") as f: 210 | info = json.load(f) 211 | 212 | min_rmse = info["best_rmse"] 213 | max_pearson = info["best_pearson"] 214 | min_loss = info["test_loss"] 215 | 216 | else: 217 | min_loss, min_rmse, max_pearson = 100, 1000, 0 218 | 219 | # Define optimizer 220 | optimizer = OPTIMIZER_FACTORY[params.get("optimizer", "Adam")]( 221 | model.parameters(), lr=params.get("lr", 0.01) 222 | ) 223 | 224 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 225 | params.update({"number_of_parameters": num_params}) 226 | logger.info(f"Number of parameters {num_params}") 227 | logger.info(model) 228 | 229 | # Overwrite params.json file with updated parameters. 230 | with open(os.path.join(model_dir, "model_params.json"), "w") as fp: 231 | json.dump(params, fp) 232 | 233 | # Start training 234 | logger.info("Training about to start...\n") 235 | t = time() 236 | 237 | model.save(save_top_model.format("epoch", "0", model_name)) 238 | 239 | for epoch in range(params["epochs"]): 240 | 241 | model.train() 242 | logger.info(params_filepath.split("/")[-1]) 243 | logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==") 244 | train_loss = 0 245 | 246 | for ind, (smiles, gep, y) in enumerate(train_loader): 247 | y_hat, pred_dict = model(torch.squeeze(smiles.to(device)), gep.to(device)) 248 | loss = model.loss(y_hat, y.to(device)) 249 | optimizer.zero_grad() 250 | loss.backward() 251 | # Apply gradient clipping 252 | # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6) 253 | optimizer.step() 254 | train_loss += loss.item() 255 | 256 | logger.info( 257 | "\t **** TRAINING **** " 258 | f"Epoch [{epoch + 1}/{params['epochs']}], " 259 | f"loss: {train_loss / len(train_loader):.5f}. " 260 | f"This took {time() - t:.1f} secs." 261 | ) 262 | t = time() 263 | 264 | # Measure validation performance 265 | model.eval() 266 | with torch.no_grad(): 267 | test_loss = 0 268 | predictions = [] 269 | labels = [] 270 | for ind, (smiles, gep, y) in enumerate(test_loader): 271 | y_hat, pred_dict = model( 272 | torch.squeeze(smiles.to(device)), gep.to(device) 273 | ) 274 | predictions.append(y_hat) 275 | labels.append(y) 276 | loss = model.loss(y_hat, y.to(device)) 277 | test_loss += loss.item() 278 | 279 | predictions = np.array([p.cpu() for preds in predictions for p in preds]) 280 | labels = np.array([l.cpu() for label in labels for l in label]) 281 | test_pearson_a = pearsonr(torch.Tensor(predictions), torch.Tensor(labels)) 282 | test_rmse_a = np.sqrt(np.mean((predictions - labels) ** 2)) 283 | test_loss_a = test_loss / len(test_loader) 284 | logger.info( 285 | f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], " 286 | f"loss: {test_loss_a:.5f}, " 287 | f"Pearson: {test_pearson_a:.3f}, " 288 | f"RMSE: {test_rmse_a:.3f}" 289 | ) 290 | 291 | def save(path, metric, typ, val=None): 292 | model.save(path.format(typ, metric, model_name)) 293 | with open(os.path.join(model_dir, "results", metric + ".json"), "w") as f: 294 | json.dump(info, f) 295 | np.save( 296 | os.path.join(model_dir, "results", metric + "_preds.npy"), 297 | np.vstack([predictions, labels]), 298 | ) 299 | if typ == "best": 300 | logger.info( 301 | f'\t New best performance in "{metric}"' 302 | f" with value : {val:.7f} in epoch: {epoch}" 303 | ) 304 | 305 | def update_info(): 306 | return { 307 | "best_rmse": str(min_rmse), 308 | "best_pearson": str(float(max_pearson)), 309 | "test_loss": str(min_loss), 310 | "predictions": [float(p) for p in predictions], 311 | } 312 | 313 | if test_loss_a < min_loss: 314 | min_rmse = test_rmse_a 315 | min_loss = test_loss_a 316 | min_loss_pearson = test_pearson_a 317 | info = update_info() 318 | save(save_top_model, "mse", "best", min_loss) 319 | ep_loss = epoch 320 | if test_pearson_a > max_pearson: 321 | max_pearson = test_pearson_a 322 | max_pearson_loss = test_loss_a 323 | info = update_info() 324 | save(save_top_model, "pearson", "best", max_pearson) 325 | ep_pearson = epoch 326 | if (epoch + 1) % params.get("save_model", 100) == 0: 327 | save(save_top_model, "epoch", str(epoch)) 328 | logger.info( 329 | "Overall best performances are: \n \t" 330 | f"Loss = {min_loss:.4f} in epoch {ep_loss} " 331 | f"\t (Pearson was {min_loss_pearson:4f}) \n \t" 332 | f"Pearson = {max_pearson:.4f} in epoch {ep_pearson} " 333 | f"\t (Loss was {max_pearson_loss:2f})" 334 | ) 335 | save(save_top_model, "training", "done") 336 | logger.info("Done with training, models saved, shutting down.") 337 | 338 | 339 | if __name__ == "__main__": 340 | # parse arguments 341 | args = parser.parse_args() 342 | # run the training 343 | main( 344 | args.train_sensitivity_filepath, 345 | args.test_sensitivity_filepath, 346 | args.gep_filepath, 347 | args.smi_filepath, 348 | args.gene_filepath, 349 | args.smiles_language_filepath, 350 | args.model_path, 351 | args.params_filepath, 352 | args.training_name, 353 | ) 354 | -------------------------------------------------------------------------------- /examples/affinity/train_affinity_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Train Affinity predictor model.""" 3 | import argparse 4 | import json 5 | import logging 6 | import os 7 | import sys 8 | from copy import deepcopy 9 | from time import time 10 | 11 | import numpy as np 12 | import torch 13 | from paccmann_predictor.models import MODEL_FACTORY 14 | from paccmann_predictor.utils.hyperparams import OPTIMIZER_FACTORY 15 | from paccmann_predictor.utils.utils import get_device 16 | from pytoda.datasets import DrugAffinityDataset 17 | from pytoda.smiles.smiles_language import SMILESTokenizer 18 | from scipy.stats import pearsonr 19 | 20 | # setup logging 21 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 22 | 23 | # yapf: disable 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | 'train_affinity_filepath', type=str, 27 | help='Path to the drug affinity data.' 28 | ) 29 | parser.add_argument( 30 | 'test_affinity_filepath', type=str, 31 | help='Path to the drug affinity data.' 32 | ) 33 | parser.add_argument( 34 | 'protein_filepath', type=str, 35 | help='Path to the protein profile data.' 36 | ) 37 | parser.add_argument( 38 | 'smi_filepath', type=str, 39 | help='Path to the SMILES data.' 40 | ) 41 | parser.add_argument( 42 | 'smiles_language_filepath', type=str, 43 | help='Path to a json for a SMILES language object.' 44 | ) 45 | parser.add_argument( 46 | 'model_path', type=str, 47 | help='Directory where the model will be stored.' 48 | ) 49 | parser.add_argument( 50 | 'params_filepath', type=str, 51 | help='Path to the parameter file.' 52 | ) 53 | parser.add_argument( 54 | 'training_name', type=str, 55 | help='Name for the training.' 56 | ) 57 | # yapf: enable 58 | 59 | 60 | def main( 61 | train_affinity_filepath, 62 | test_affinity_filepath, 63 | protein_filepath, 64 | smi_filepath, 65 | smiles_language_filepath, 66 | model_path, 67 | params_filepath, 68 | training_name, 69 | ): 70 | 71 | logger = logging.getLogger(f'{training_name}') 72 | # Process parameter file: 73 | params = {} 74 | with open(params_filepath) as fp: 75 | params.update(json.load(fp)) 76 | 77 | # Create model directory and dump files 78 | model_dir = os.path.join(model_path, training_name) 79 | os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True) 80 | os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True) 81 | with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp: 82 | json.dump(params, fp, indent=4) 83 | 84 | # Prepare the dataset 85 | logger.info("Start data preprocessing...") 86 | device = get_device() 87 | 88 | # Load languages 89 | smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath) 90 | # Set transform 91 | test_smiles_language = deepcopy(smiles_language) 92 | smiles_language.set_smiles_transforms( 93 | augment=params.get('augment_smiles', False), 94 | canonical=params.get('smiles_canonical', False), 95 | kekulize=params.get('smiles_kekulize', False), 96 | all_bonds_explicit=params.get('smiles_bonds_explicit', False), 97 | all_hs_explicit=params.get('smiles_all_hs_explicit', False), 98 | remove_bonddir=params.get('smiles_remove_bonddir', False), 99 | remove_chirality=params.get('smiles_remove_chirality', False), 100 | selfies=params.get('selfies', False), 101 | sanitize=params.get('selfies', False) 102 | ) 103 | smiles_language.set_encoding_transforms( 104 | padding=params.get('smiles_padding', True), 105 | padding_length=params.get('smiles_padding_length', None), 106 | add_start_and_stop=params.get('smiles_add_start_stop', True) 107 | ) 108 | test_smiles_language.set_smiles_transforms( 109 | augment=False, 110 | canonical=params.get('test_smiles_canonical', False), 111 | kekulize=params.get('smiles_kekulize', False), 112 | all_bonds_explicit=params.get('smiles_bonds_explicit', False), 113 | all_hs_explicit=params.get('smiles_all_hs_explicit', False), 114 | remove_bonddir=params.get('smiles_remove_bonddir', False), 115 | remove_chirality=params.get('smiles_remove_chirality', False), 116 | selfies=params.get('selfies', False), 117 | sanitize=params.get('selfies', False) 118 | ) 119 | test_smiles_language.set_encoding_transforms( 120 | padding=params.get('smiles_padding', True), 121 | padding_length=params.get('smiles_padding_length', None), 122 | add_start_and_stop=params.get('smiles_add_start_stop', True) 123 | ) 124 | 125 | # Assemble datasets 126 | train_dataset = DrugAffinityDataset( 127 | drug_affinity_filepath=train_affinity_filepath, 128 | column_names=['ligand_name', 'sequence_id', 'affinity'], 129 | smi_filepath=smi_filepath, 130 | protein_filepath=protein_filepath, 131 | smiles_language=smiles_language, 132 | protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'), 133 | protein_padding=params.get('protein_padding', True), 134 | protein_padding_length=params.get('protein_padding_length', None), 135 | protein_add_start_and_stop=params.get('protein_add_start_stop', True), 136 | protein_augment_by_revert=params.get('protein_augment', False), 137 | device=device, 138 | drug_affinity_dtype=torch.float, 139 | backend='eager', 140 | iterate_dataset=params.get('iterate_dataset', False) 141 | ) 142 | train_loader = torch.utils.data.DataLoader( 143 | dataset=train_dataset, 144 | batch_size=params['batch_size'], 145 | shuffle=True, 146 | drop_last=True, 147 | num_workers=params.get('num_workers', 0), 148 | ) 149 | 150 | test_dataset = DrugAffinityDataset( 151 | drug_affinity_filepath=test_affinity_filepath, 152 | column_names=['ligand_name', 'sequence_id', 'affinity'], 153 | smi_filepath=smi_filepath, 154 | protein_filepath=protein_filepath, 155 | smiles_language=smiles_language, 156 | smiles_padding=params.get('smiles_padding', True), 157 | smiles_padding_length=params.get('smiles_padding_length', None), 158 | smiles_add_start_and_stop=params.get('smiles_add_start_stop', True), 159 | protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'), 160 | protein_padding=params.get('protein_padding', True), 161 | protein_padding_length=params.get('protein_padding_length', None), 162 | protein_add_start_and_stop=params.get('protein_add_start_stop', True), 163 | protein_augment_by_revert=False, 164 | device=device, 165 | drug_affinity_dtype=torch.float, 166 | backend='eager', 167 | iterate_dataset=params.get('iterate_dataset', False) 168 | ) 169 | test_loader = torch.utils.data.DataLoader( 170 | dataset=test_dataset, 171 | batch_size=params['batch_size'], 172 | shuffle=True, 173 | drop_last=False, 174 | num_workers=params.get('num_workers', 0), 175 | ) 176 | logger.info( 177 | f'Training dataset has {len(train_dataset)} samples, test set has ' 178 | f'{len(test_dataset)}.' 179 | ) 180 | 181 | logger.info( 182 | f'Device for data loader is {train_dataset.device} and for ' 183 | f'model is {device}' 184 | ) 185 | save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt') 186 | protein_language = train_dataset.protein_sequence_dataset.protein_language 187 | params.update( 188 | { 189 | 'smiles_vocabulary_size': smiles_language.number_of_tokens, 190 | 'protein_vocabulary_size': protein_language.number_of_tokens, 191 | } 192 | ) 193 | smiles_language.save_pretrained(model_dir) 194 | protein_language.save(os.path.join(model_dir, 'protein_language.pkl')) 195 | 196 | model_fn = params.get('model_fn', 'bimodal_mca') 197 | model = MODEL_FACTORY[model_fn](params).to(device) 198 | model._associate_language(smiles_language) 199 | model._associate_language(protein_language) 200 | 201 | if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mca.pt')): 202 | logger.info('Found existing model, restoring now...') 203 | try: 204 | model.load(os.path.join(model_dir, 'weights', 'best_mca.pt')) 205 | 206 | with open( 207 | os.path.join(model_dir, 'results', 'mse.json'), 'r' 208 | ) as f: 209 | info = json.load(f) 210 | 211 | max_pearson = info['best_pearson'] 212 | min_loss = info['test_loss'] 213 | min_rmse = info['best_rmse'] 214 | 215 | except Exception: 216 | min_loss, max_pearson, min_rmse = 10000, -1, 10000 217 | else: 218 | min_loss, max_pearson, min_rmse = 10000, -1, 10000 219 | 220 | # Define optimizer 221 | optimizer = OPTIMIZER_FACTORY[ 222 | params.get('optimizer', 223 | 'adam')](model.parameters(), lr=params.get('lr', 0.001)) 224 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 225 | params.update({'number_of_parameters': num_params}) 226 | logger.info(f'Number of parameters: {num_params}') 227 | logger.info(f'Model: {model}') 228 | 229 | # Overwrite params.json file with updated parameters. 230 | with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp: 231 | json.dump(params, fp) 232 | 233 | # Start training 234 | logger.info('Training about to start...\n') 235 | t = time() 236 | 237 | logger.info(train_dataset.smiles_dataset.smiles_language.transform_smiles) 238 | logger.info( 239 | train_dataset.smiles_dataset.smiles_language.transform_encoding 240 | ) 241 | logger.info(test_dataset.smiles_dataset.smiles_language.transform_smiles) 242 | logger.info(test_dataset.smiles_dataset.smiles_language.transform_encoding) 243 | logger.info(train_dataset.protein_sequence_dataset.language_transforms) 244 | logger.info(test_dataset.protein_sequence_dataset.language_transforms) 245 | 246 | for epoch in range(params['epochs']): 247 | 248 | model.train() 249 | logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==") 250 | train_loss = 0 251 | 252 | for ind, (smiles, proteins, y) in enumerate(train_loader): 253 | if ind % 1000 == 0: 254 | logger.info(f'Batch {ind}/{len(train_loader)}') 255 | y_hat, pred_dict = model(smiles, proteins) 256 | loss = model.loss(y_hat, y.to(device)) 257 | optimizer.zero_grad() 258 | loss.backward() 259 | # Apply gradient clipping 260 | # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6) 261 | optimizer.step() 262 | train_loss += loss.item() 263 | 264 | logger.info( 265 | "\t **** TRAINING **** " 266 | f"Epoch [{epoch + 1}/{params['epochs']}], " 267 | f"loss: {train_loss / len(train_loader):.5f}. " 268 | f"This took {time() - t:.1f} secs." 269 | ) 270 | t = time() 271 | 272 | # Measure validation performance 273 | model.eval() 274 | with torch.no_grad(): 275 | test_loss = 0 276 | predictions = [] 277 | labels = [] 278 | for ind, (smiles, proteins, y) in enumerate(test_loader): 279 | y_hat, pred_dict = model( 280 | smiles.to(device), proteins.to(device) 281 | ) 282 | predictions.append(y_hat) 283 | labels.append(y.clone()) 284 | loss = model.loss(y_hat, y.to(device)) 285 | test_loss += loss.item() 286 | 287 | predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy() 288 | labels = torch.cat(labels, dim=0).flatten().cpu().numpy() 289 | 290 | test_loss = test_loss / len(test_loader) 291 | test_pearson = pearsonr(predictions, labels)[0] 292 | test_rmse = np.sqrt(np.mean((predictions - labels)**2)) 293 | logger.info( 294 | f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], " 295 | f"loss: {test_loss:.5f}, " 296 | f"Pearson: {test_pearson:.3f}, " 297 | f"RMSE: {test_rmse:.3f}" 298 | ) 299 | 300 | def save(path, metric, typ, val=None): 301 | model.save(path.format(typ, metric, model_fn)) 302 | info = { 303 | 'best_pearson': str(max_pearson), 304 | 'best_rmse': str(min_rmse), 305 | 'test_rmse': str(test_rmse), 306 | 'test_pearson': str(test_pearson), 307 | 'test_loss': str(min_loss), 308 | } 309 | with open( 310 | os.path.join(model_dir, 'results', metric + '.json'), 'w' 311 | ) as f: 312 | json.dump(info, f) 313 | np.save( 314 | os.path.join(model_dir, 'results', metric + '_preds.npy'), 315 | np.vstack([predictions, labels]), 316 | ) 317 | if typ == 'best': 318 | logger.info( 319 | f'\t New best performance in "{metric}"' 320 | f' with value : {val:.7f} in epoch: {epoch}' 321 | ) 322 | 323 | if test_pearson > max_pearson: 324 | max_pearson = test_pearson 325 | max_pearson_loss = test_loss 326 | save(save_top_model, 'pearson', 'best', max_pearson) 327 | ep_pearson = epoch 328 | if test_loss < min_loss: 329 | min_loss = test_loss 330 | min_rmse = test_rmse 331 | min_loss_pearson = test_pearson 332 | save(save_top_model, 'mse', 'best', min_loss) 333 | ep_loss = epoch 334 | if (epoch + 1) % params.get('save_model', 100) == 0: 335 | save(save_top_model, 'epoch', str(epoch)) 336 | 337 | logger.info( 338 | 'Overall best performances are: \n \t' 339 | f'Loss = {min_loss:.4f} in epoch {ep_loss} ' 340 | f'\t (Pearson was {min_loss_pearson:4f}) \n \t' 341 | f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} ' 342 | f'\t (Loss was {max_pearson_loss:4f})' 343 | ) 344 | save(save_top_model, 'training', 'done') 345 | logger.info('Done with training, models saved, shutting down.') 346 | 347 | 348 | if __name__ == '__main__': 349 | # parse arguments 350 | args = parser.parse_args() 351 | # run the training 352 | main( 353 | args.train_affinity_filepath, args.test_affinity_filepath, 354 | args.protein_filepath, args.smi_filepath, 355 | args.smiles_language_filepath, args.model_path, args.params_filepath, 356 | args.training_name 357 | ) 358 | -------------------------------------------------------------------------------- /paccmann_predictor/models/paccmann.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from collections import OrderedDict 4 | 5 | import pytoda 6 | import torch 7 | import torch.nn as nn 8 | from pytoda.smiles.transforms import AugmentTensor 9 | 10 | from ..utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY 11 | from ..utils.interpret import monte_carlo_dropout, test_time_augmentation 12 | from ..utils.layers import ( 13 | ContextAttentionLayer, convolutional_layer, dense_attention_layer, 14 | dense_layer 15 | ) 16 | from ..utils.utils import get_device, get_log_molar 17 | 18 | # setup logging 19 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class MCA(nn.Module): 24 | """Multiscale Convolutional Attentive Encoder. 25 | 26 | This is the MCA model as presented in the authors publication in 27 | Molecular Pharmaceutics: 28 | https://pubs.acs.org/doi/10.1021/acs.molpharmaceut.9b00520. 29 | """ 30 | 31 | def __init__(self, params, *args, **kwargs): 32 | """Constructor. 33 | 34 | Args: 35 | params (dict): A dictionary containing the parameter to built the 36 | dense encoder. 37 | TODO params should become actual arguments (use **params). 38 | 39 | Items in params: 40 | smiles_embedding_size (int): dimension of tokens' embedding. 41 | smiles_vocabulary_size (int): size of the tokens vocabulary. 42 | 43 | activation_fn (string, optional): Activation function used in all 44 | layers for specification in ACTIVATION_FN_FACTORY. 45 | Defaults to 'relu'. 46 | batch_norm (bool, optional): Whether batch normalization is 47 | applied. Defaults to True. 48 | dropout (float, optional): Dropout probability in all 49 | except parametric layer. Defaults to 0.5. 50 | filters (list[int], optional): Numbers of filters to learn per 51 | convolutional layer. Defaults to [64, 64, 64]. 52 | kernel_sizes (list[list[int]], optional): Sizes of kernels per 53 | convolutional layer. Defaults to [ 54 | [3, params['smiles_embedding_size']], 55 | [5, params['smiles_embedding_size']], 56 | [11, params['smiles_embedding_size']] 57 | ] 58 | NOTE: The kernel sizes should match the dimensionality of the 59 | smiles_embedding_size, so if the latter is 8, the images are 60 | t x 8, then treat the 8 embedding dimensions like channels 61 | in an RGB image. 62 | multiheads (list[int], optional): Amount of attentive multiheads 63 | per SMILES embedding. Should have len(filters)+1. 64 | Defaults to [4, 4, 4, 4]. 65 | stacked_dense_hidden_sizes (list[int], optional): Sizes of the 66 | hidden dense layers. Defaults to [1024, 512]. 67 | smiles_attention_size (int, optional): size of the attentive layer 68 | for the smiles sequence. Defaults to 64. 69 | temperature (float, optional): Softmax temperature parameter for 70 | gene attention (0, inf). 71 | 72 | Example params: 73 | ``` 74 | { 75 | "smiles_attention_size": 8, 76 | "smiles_vocabulary_size": 28, 77 | "smiles_embedding_size": 8, 78 | "filters": [128, 128], 79 | "kernel_sizes": [[3, 8], [5, 8]], 80 | "multiheads":[32, 32, 32] 81 | "stacked_dense_hidden_sizes": [512, 64, 16] 82 | } 83 | ``` 84 | """ 85 | 86 | super(MCA, self).__init__(*args, **kwargs) 87 | 88 | # Model Parameter 89 | self.device = get_device() 90 | self.params = params 91 | self.loss_fn = LOSS_FN_FACTORY[params.get('loss_fn', 'mse')] 92 | self.min_max_scaling = True if params.get( 93 | 'drug_sensitivity_processing_parameters', {} 94 | ) != {} else False 95 | if self.min_max_scaling: 96 | self.IC50_max = params[ 97 | 'drug_sensitivity_processing_parameters' 98 | ]['parameters']['max'] # yapf: disable 99 | self.IC50_min = params[ 100 | 'drug_sensitivity_processing_parameters' 101 | ]['parameters']['min'] # yapf: disable 102 | 103 | # Model inputs 104 | self.number_of_genes = params.get('number_of_genes', 2128) 105 | self.smiles_attention_size = params.get('smiles_attention_size', 64) 106 | 107 | # Model architecture (hyperparameter) 108 | self.multiheads = params.get('multiheads', [4, 4, 4, 4]) 109 | self.filters = params.get('filters', [64, 64, 64]) 110 | self.hidden_sizes = ( 111 | [ 112 | self.multiheads[0] * params['smiles_embedding_size'] + sum( 113 | [h * f for h, f in zip(self.multiheads[1:], self.filters)] 114 | ) 115 | ] + params.get('stacked_dense_hidden_sizes', [1024, 512]) 116 | ) 117 | 118 | if params.get('gene_to_dense', False): # Optional skip connection 119 | self.hidden_sizes[0] += self.number_of_genes 120 | self.dropout = params.get('dropout', 0.5) 121 | self.temperature = params.get('temperature', 1.) 122 | self.act_fn = ACTIVATION_FN_FACTORY[ 123 | params.get('activation_fn', 'relu')] 124 | self.kernel_sizes = params.get( 125 | 'kernel_sizes', [ 126 | [3, params['smiles_embedding_size']], 127 | [5, params['smiles_embedding_size']], 128 | [11, params['smiles_embedding_size']] 129 | ] 130 | ) 131 | if len(self.filters) != len(self.kernel_sizes): 132 | raise ValueError( 133 | 'Length of filter and kernel size lists do not match.' 134 | ) 135 | if len(self.filters) + 1 != len(self.multiheads): 136 | raise ValueError( 137 | 'Length of filter and multihead lists do not match' 138 | ) 139 | 140 | # Build the model 141 | self.smiles_embedding = nn.Embedding( 142 | self.params['smiles_vocabulary_size'], 143 | self.params['smiles_embedding_size'], 144 | scale_grad_by_freq=params.get('embed_scale_grad', False) 145 | ) 146 | self.gene_attention_layer = dense_attention_layer( 147 | self.number_of_genes, 148 | temperature=self.temperature, 149 | dropout=self.dropout 150 | ).to(self.device) 151 | 152 | self.convolutional_layers = nn.Sequential( 153 | OrderedDict( 154 | [ 155 | ( 156 | f'convolutional_{index}', 157 | convolutional_layer( 158 | num_kernel, 159 | kernel_size, 160 | act_fn=self.act_fn, 161 | batch_norm=params.get('batch_norm', False), 162 | dropout=self.dropout 163 | ).to(self.device) 164 | ) for index, (num_kernel, kernel_size) in 165 | enumerate(zip(self.filters, self.kernel_sizes)) 166 | ] 167 | ) 168 | ) 169 | 170 | smiles_hidden_sizes = [params['smiles_embedding_size']] + self.filters 171 | 172 | self.context_attention_layers = nn.Sequential(OrderedDict([ 173 | ( 174 | f'context_attention_{layer}_head_{head}', 175 | ContextAttentionLayer( 176 | smiles_hidden_sizes[layer], 177 | 42, # Can be anything since context is only 1D (omic) 178 | self.number_of_genes, 179 | attention_size=self.smiles_attention_size, 180 | individual_nonlinearity=params.get( 181 | 'context_nonlinearity', nn.Sequential() 182 | ) 183 | ) 184 | ) for layer in range(len(self.multiheads)) 185 | for head in range(self.multiheads[layer]) 186 | ])) # yapf: disable 187 | 188 | # Only applied if params['batch_norm'] = True 189 | self.batch_norm = nn.BatchNorm1d(self.hidden_sizes[0]) 190 | self.dense_layers = nn.Sequential( 191 | OrderedDict( 192 | [ 193 | ( 194 | 'dense_{}'.format(ind), 195 | dense_layer( 196 | self.hidden_sizes[ind], 197 | self.hidden_sizes[ind + 1], 198 | act_fn=self.act_fn, 199 | dropout=self.dropout, 200 | batch_norm=params.get('batch_norm', True) 201 | ).to(self.device) 202 | ) for ind in range(len(self.hidden_sizes) - 1) 203 | ] 204 | ) 205 | ) 206 | 207 | self.final_dense = ( 208 | nn.Linear(self.hidden_sizes[-1], 1) 209 | if not params.get('final_activation', False) else nn.Sequential( 210 | OrderedDict( 211 | [ 212 | ('projection', nn.Linear(self.hidden_sizes[-1], 1)), 213 | ('sigmoidal', ACTIVATION_FN_FACTORY['sigmoid']) 214 | ] 215 | ) 216 | ) 217 | ) 218 | 219 | def forward(self, smiles, gep, confidence=False): 220 | """Forward pass through the MCA. 221 | 222 | Args: 223 | smiles (torch.Tensor): of type int and shape `[bs, seq_length]`. 224 | gep (torch.Tensor): of shape `[bs, num_genes]`. 225 | confidence (bool, optional) whether the confidence estimates are 226 | performed. 227 | 228 | Returns: 229 | (torch.Tensor, torch.Tensor): predictions, prediction_dict 230 | 231 | predictions is IC50 drug sensitivity prediction of shape `[bs, 1]`. 232 | prediction_dict includes the prediction and attention weights. 233 | """ 234 | embedded_smiles = self.smiles_embedding(smiles.to(dtype=torch.int64)) 235 | 236 | # Gene attention weights 237 | gene_alphas = self.gene_attention_layer(gep) 238 | 239 | # Filter the gene expression with the weights. 240 | encoded_genes = gene_alphas * gep 241 | 242 | # NOTE: SMILES Convolutions. Unsqueeze has shape bs x 1 x T x H. 243 | encoded_smiles = [embedded_smiles] + [ 244 | self.convolutional_layers[ind] 245 | (torch.unsqueeze(embedded_smiles, 1)).permute(0, 2, 1) 246 | for ind in range(len(self.convolutional_layers)) 247 | ] 248 | 249 | # NOTE: SMILES Attention mechanism 250 | encodings, smiles_alphas = [], [] 251 | context = torch.unsqueeze(encoded_genes, 1) 252 | for layer in range(len(self.multiheads)): 253 | for head in range(self.multiheads[layer]): 254 | ind = self.multiheads[0] * layer + head 255 | e, a = self.context_attention_layers[ind]( 256 | encoded_smiles[layer], context 257 | ) 258 | encodings.append(e) 259 | smiles_alphas.append(a) 260 | 261 | encodings = torch.cat(encodings, dim=1) 262 | if self.params.get('gene_to_dense', False): 263 | encodings = torch.cat([encodings, gep], dim=1) 264 | 265 | # Apply batch normalization if specified 266 | inputs = self.batch_norm(encodings) if self.params.get( 267 | 'batch_norm', False 268 | ) else encodings 269 | # NOTE: stacking dense layers as a bottleneck 270 | for dl in self.dense_layers: 271 | inputs = dl(inputs) 272 | 273 | predictions = self.final_dense(inputs) 274 | 275 | prediction_dict = {} 276 | 277 | if not self.training: 278 | # The below is to ease postprocessing 279 | smiles_attention_weights = torch.mean( 280 | torch.cat( 281 | [torch.unsqueeze(p, -1) for p in smiles_alphas], dim=-1 282 | ), 283 | dim=-1 284 | ) 285 | prediction_dict.update({ 286 | 'gene_attention': gene_alphas, 287 | 'smiles_attention': smiles_attention_weights, 288 | 'IC50': predictions, 289 | 'log_micromolar_IC50': 290 | get_log_molar( 291 | predictions, 292 | ic50_max=self.IC50_max, 293 | ic50_min=self.IC50_min 294 | ) if self.min_max_scaling else predictions 295 | }) # yapf: disable 296 | 297 | if confidence: 298 | augmenter = AugmentTensor(self.smiles_language) 299 | epi_conf, epi_pred = monte_carlo_dropout( 300 | self, 301 | regime='tensors', 302 | tensors=(smiles, gep), 303 | repetitions=5 304 | ) 305 | ale_conf, ale_pred = test_time_augmentation( 306 | self, 307 | regime='tensors', 308 | tensors=(smiles, gep), 309 | repetitions=5, 310 | augmenter=augmenter, 311 | tensors_to_augment=0 312 | ) 313 | 314 | prediction_dict.update({ 315 | 'epistemic_confidence': epi_conf, 316 | 'epistemic_predictions': epi_pred, 317 | 'aleatoric_confidence': ale_conf, 318 | 'aleatoric_predictions': ale_pred 319 | }) # yapf: disable 320 | 321 | elif confidence: 322 | logger.info('Using confidence in training mode is not supported.') 323 | 324 | return predictions, prediction_dict 325 | 326 | def loss(self, yhat, y): 327 | return self.loss_fn(yhat, y) 328 | 329 | def _associate_language(self, smiles_language): 330 | """ 331 | Bind a SMILES language object to the model. Is only used inside the 332 | confidence estimation. 333 | 334 | Arguments: 335 | smiles_language {[pytoda.smiles.smiles_language.SMILESLanguage]} 336 | -- [A SMILES language object] 337 | 338 | Raises: 339 | TypeError: 340 | """ 341 | if not isinstance( 342 | smiles_language, pytoda.smiles.smiles_language.SMILESLanguage 343 | ): 344 | raise TypeError( 345 | 'Please insert a smiles language (object of type ' 346 | 'pytoda.smiles.smiles_language.SMILESLanguage). Given was ' 347 | f'{type(smiles_language)}' 348 | ) 349 | self.smiles_language = smiles_language 350 | 351 | def load(self, path, *args, **kwargs): 352 | """Load model from path.""" 353 | weights = torch.load(path, *args, **kwargs) 354 | self.load_state_dict(weights) 355 | 356 | def save(self, path, *args, **kwargs): 357 | """Save model to path.""" 358 | torch.save(self.state_dict(), path, *args, **kwargs) 359 | -------------------------------------------------------------------------------- /examples/affinity/train_affinity.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Train Affinity predictor model.""" 3 | import argparse 4 | import json 5 | import logging 6 | import os 7 | import sys 8 | from time import time 9 | 10 | import numpy as np 11 | import torch 12 | from sklearn.metrics import ( 13 | auc, 14 | average_precision_score, 15 | precision_recall_curve, 16 | roc_curve, 17 | ) 18 | from paccmann_predictor.models import MODEL_FACTORY 19 | from paccmann_predictor.utils.hyperparams import OPTIMIZER_FACTORY 20 | from paccmann_predictor.utils.utils import get_device 21 | from pytoda.datasets import DrugAffinityDataset 22 | from pytoda.proteins import ProteinLanguage, ProteinFeatureLanguage 23 | from pytoda.smiles import metadata 24 | from pytoda.smiles.smiles_language import SMILESLanguage, SMILESTokenizer 25 | 26 | # setup logging 27 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 28 | 29 | # yapf: disable 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | 'train_affinity_filepath', type=str, 33 | help='Path to the drug affinity data.' 34 | ) 35 | parser.add_argument( 36 | 'test_affinity_filepath', type=str, 37 | help='Path to the drug affinity data.' 38 | ) 39 | parser.add_argument( 40 | 'receptor_filepath', type=str, 41 | help='Path to the protein profile data. Receptors must be encoded as amino \ 42 | acids' 43 | ) 44 | parser.add_argument( 45 | 'ligand_filepath', type=str, 46 | help='Path to the ligand data. Ligands must be encoded as SMILES' 47 | ) 48 | parser.add_argument( 49 | 'model_path', type=str, 50 | help='Directory where the model will be stored.' 51 | ) 52 | parser.add_argument( 53 | 'params_filepath', type=str, 54 | help='Path to the parameter file.' 55 | ) 56 | parser.add_argument( 57 | 'training_name', type=str, 58 | help='Name for the training.' 59 | ) 60 | parser.add_argument( 61 | '-smiles_language_filepath', type=str, default='', required=False, 62 | help='Path to smiles language (contains token_count.json, vocab.json and \ 63 | tokenizer_config.json files). If not specified, language from pytoda \ 64 | metadata is loaded.' 65 | ) 66 | # yapf: enable 67 | 68 | 69 | def main( 70 | train_affinity_filepath, test_affinity_filepath, receptor_filepath, 71 | ligand_filepath, model_path, params_filepath, training_name, 72 | smiles_language_filepath 73 | ): 74 | 75 | logger = logging.getLogger(f'{training_name}') 76 | # Process parameter file: 77 | params = {} 78 | with open(params_filepath) as fp: 79 | params.update(json.load(fp)) 80 | 81 | # Create model directory and dump files 82 | model_dir = os.path.join(model_path, training_name) 83 | os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True) 84 | os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True) 85 | with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp: 86 | json.dump(params, fp, indent=4) 87 | 88 | # Prepare the dataset 89 | logger.info("Start data preprocessing...") 90 | device = get_device() 91 | 92 | # Load languages 93 | if smiles_language_filepath == '': 94 | smiles_language_filepath = os.path.join( 95 | os.sep, 96 | *metadata.__file__.split(os.sep)[:-1], 'smiles_language' 97 | ) 98 | smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath) 99 | smiles_language.set_encoding_transforms( 100 | randomize=None, 101 | add_start_and_stop=params.get('ligand_start_stop_token', True), 102 | padding=params.get('ligand_padding', True), 103 | padding_length=params.get('ligand_padding_length', True), 104 | device=device, 105 | ) 106 | smiles_language.set_smiles_transforms( 107 | augment=params.get('augment_smiles', False), 108 | canonical=params.get('smiles_canonical', False), 109 | kekulize=params.get('smiles_kekulize', False), 110 | all_bonds_explicit=params.get('smiles_bonds_explicit', False), 111 | all_hs_explicit=params.get('smiles_all_hs_explicit', False), 112 | remove_bonddir=params.get('smiles_remove_bonddir', False), 113 | remove_chirality=params.get('smiles_remove_chirality', False), 114 | selfies=params.get('selfies', False), 115 | sanitize=params.get('sanitize', False) 116 | ) 117 | 118 | if params.get('receptor_embedding', 'learned') == 'predefined': 119 | protein_language = ProteinFeatureLanguage( 120 | features=params.get('predefined_embedding', 'blosum') 121 | ) 122 | else: 123 | protein_language = ProteinLanguage() 124 | 125 | if params.get('ligand_embedding', 'learned') == 'one_hot': 126 | logger.warning( 127 | 'ligand_embedding_size parameter in param file is ignored in ' 128 | 'one_hot embedding setting, ligand_vocabulary_size used instead.' 129 | ) 130 | if params.get('receptor_embedding', 'learned') == 'one_hot': 131 | logger.warning( 132 | 'receptor_embedding_size parameter in param file is ignored in ' 133 | 'one_hot embedding setting, receptor_vocabulary_size used instead.' 134 | ) 135 | 136 | # Assemble datasets 137 | train_dataset = DrugAffinityDataset( 138 | drug_affinity_filepath=train_affinity_filepath, 139 | smi_filepath=ligand_filepath, 140 | protein_filepath=receptor_filepath, 141 | protein_language=protein_language, 142 | smiles_language=smiles_language, 143 | smiles_padding=params.get('ligand_padding', True), 144 | smiles_padding_length=params.get('ligand_padding_length', None), 145 | smiles_add_start_and_stop=params.get('ligand_add_start_stop', True), 146 | smiles_augment=params.get('augment_smiles', False), 147 | smiles_canonical=params.get('smiles_canonical', False), 148 | smiles_kekulize=params.get('smiles_kekulize', False), 149 | smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False), 150 | smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False), 151 | smiles_remove_bonddir=params.get('smiles_remove_bonddir', False), 152 | smiles_remove_chirality=params.get('smiles_remove_chirality', False), 153 | smiles_selfies=params.get('selfies', False), 154 | protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'), 155 | protein_padding=params.get('receptor_padding', True), 156 | protein_padding_length=params.get('receptor_padding_length', None), 157 | protein_add_start_and_stop=params.get('receptor_add_start_stop', True), 158 | protein_augment_by_revert=params.get('protein_augment', False), 159 | device=device, 160 | drug_affinity_dtype=torch.float, 161 | backend='eager', 162 | iterate_dataset=params.get('iterate_dataset', False), 163 | ) 164 | train_loader = torch.utils.data.DataLoader( 165 | dataset=train_dataset, 166 | batch_size=params['batch_size'], 167 | shuffle=True, 168 | drop_last=True, 169 | num_workers=params.get('num_workers', 0), 170 | ) 171 | 172 | test_dataset = DrugAffinityDataset( 173 | drug_affinity_filepath=test_affinity_filepath, 174 | smi_filepath=ligand_filepath, 175 | protein_filepath=receptor_filepath, 176 | protein_language=protein_language, 177 | smiles_language=smiles_language, 178 | smiles_padding=params.get('ligand_padding', True), 179 | smiles_padding_length=params.get('ligand_padding_length', None), 180 | smiles_add_start_and_stop=params.get('ligand_add_start_stop', True), 181 | smiles_augment=False, 182 | smiles_canonical=params.get('smiles_test_canonical', False), 183 | smiles_kekulize=params.get('smiles_kekulize', False), 184 | smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False), 185 | smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False), 186 | smiles_remove_bonddir=params.get('smiles_remove_bonddir', False), 187 | smiles_remove_chirality=params.get('smiles_remove_chirality', False), 188 | smiles_selfies=params.get('selfies', False), 189 | protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'), 190 | protein_padding=params.get('receptor_padding', True), 191 | protein_padding_length=params.get('receptor_padding_length', None), 192 | protein_add_start_and_stop=params.get('receptor_add_start_stop', True), 193 | protein_augment_by_revert=False, 194 | device=device, 195 | drug_affinity_dtype=torch.float, 196 | backend='eager', 197 | iterate_dataset=params.get('iterate_dataset', False), 198 | ) 199 | test_loader = torch.utils.data.DataLoader( 200 | dataset=test_dataset, 201 | batch_size=params['batch_size'], 202 | shuffle=True, 203 | drop_last=True, 204 | num_workers=params.get('num_workers', 0), 205 | ) 206 | logger.info( 207 | f'Training dataset has {len(train_dataset)} samples, test set has ' 208 | f'{len(test_dataset)}.' 209 | ) 210 | 211 | logger.info( 212 | f'Device for data loader is {train_dataset.device} and for ' 213 | f'model is {device}' 214 | ) 215 | save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt') 216 | params.update( 217 | { 218 | 'ligand_vocabulary_size': 219 | ( 220 | train_dataset.smiles_dataset.smiles_language. 221 | number_of_tokens 222 | ), 223 | 'receptor_vocabulary_size': protein_language.number_of_tokens, 224 | } 225 | ) 226 | logger.info( 227 | f'Receptor vocabulary size is {protein_language.number_of_tokens} and ' 228 | f'ligand vocabulary size is {train_dataset.smiles_dataset.smiles_language.number_of_tokens}' 229 | ) 230 | model_fn = params.get('model_fn', 'bimodal_mca') 231 | model = MODEL_FACTORY[model_fn](params).to(device) 232 | model._associate_language(smiles_language) 233 | model._associate_language(protein_language) 234 | 235 | if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mca.pt')): 236 | logger.info('Found existing model, restoring now...') 237 | try: 238 | model.load(os.path.join(model_dir, 'weights', 'best_mca.pt')) 239 | 240 | with open( 241 | os.path.join(model_dir, 'results', 'mse.json'), 'r' 242 | ) as f: 243 | info = json.load(f) 244 | 245 | max_roc_auc = info['best_roc_auc'] 246 | min_loss = info['test_loss'] 247 | 248 | except Exception: 249 | min_loss, max_roc_auc = 100, 0 250 | else: 251 | min_loss, max_roc_auc = 100, 0 252 | 253 | # Define optimizer 254 | optimizer = OPTIMIZER_FACTORY[ 255 | params.get('optimizer', 256 | 'adam')](model.parameters(), lr=params.get('lr', 0.001)) 257 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 258 | params.update({'number_of_parameters': num_params}) 259 | logger.info(f'Number of parameters: {num_params}') 260 | logger.info(f'Model: {model}') 261 | 262 | # Overwrite params.json file with updated parameters. 263 | with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp: 264 | json.dump(params, fp) 265 | 266 | # Start training 267 | logger.info('Training about to start...\n') 268 | t = time() 269 | 270 | model.save(save_top_model.format('epoch', '0', model_fn)) 271 | 272 | for epoch in range(params['epochs']): 273 | 274 | model.train() 275 | logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==") 276 | train_loss = 0 277 | 278 | for ind, (ligands, receptors, y) in enumerate(train_loader): 279 | if ind % 100 == 0: 280 | logger.info(f'Batch {ind}/{len(train_loader)}') 281 | y_hat, pred_dict = model(ligands, receptors) 282 | loss = model.loss(y_hat, y.to(device)) 283 | optimizer.zero_grad() 284 | loss.backward() 285 | # Apply gradient clipping 286 | # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6) 287 | optimizer.step() 288 | train_loss += loss.item() 289 | 290 | logger.info( 291 | "\t **** TRAINING **** " 292 | f"Epoch [{epoch + 1}/{params['epochs']}], " 293 | f"loss: {train_loss / len(train_loader):.5f}. " 294 | f"This took {time() - t:.1f} secs." 295 | ) 296 | t = time() 297 | 298 | # Measure validation performance 299 | model.eval() 300 | with torch.no_grad(): 301 | test_loss = 0 302 | predictions = [] 303 | labels = [] 304 | for ind, (ligands, receptors, y) in enumerate(test_loader): 305 | y_hat, pred_dict = model( 306 | ligands.to(device), receptors.to(device) 307 | ) 308 | predictions.append(y_hat) 309 | labels.append(y.clone()) 310 | loss = model.loss(y_hat, y.to(device)) 311 | test_loss += loss.item() 312 | 313 | predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy() 314 | labels = torch.cat(labels, dim=0).flatten().cpu().numpy() 315 | 316 | test_loss = test_loss / len(test_loader) 317 | fpr, tpr, _ = roc_curve(labels, predictions) 318 | test_roc_auc = auc(fpr, tpr) 319 | 320 | # calculations for visualization plot 321 | precision, recall, _ = precision_recall_curve(labels, predictions) 322 | avg_precision = average_precision_score(labels, predictions) 323 | 324 | test_loss = test_loss / len(test_loader) 325 | logger.info( 326 | f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], " 327 | f"loss: {test_loss:.5f}, ROC-AUC: {test_roc_auc:.3f}, " 328 | f"Average precision: {avg_precision:.3f}." 329 | ) 330 | 331 | def save(path, metric, typ, val=None): 332 | model.save(path.format(typ, metric, model_fn)) 333 | info = { 334 | 'best_roc_auc': str(max_roc_auc), 335 | 'test_loss': str(min_loss), 336 | } 337 | with open( 338 | os.path.join(model_dir, 'results', metric + '.json'), 'w' 339 | ) as f: 340 | json.dump(info, f) 341 | np.save( 342 | os.path.join(model_dir, 'results', metric + '_preds.npy'), 343 | np.vstack([predictions, labels]), 344 | ) 345 | if typ == 'best': 346 | logger.info( 347 | f'\t New best performance in "{metric}"' 348 | f' with value : {val:.7f} in epoch: {epoch}' 349 | ) 350 | 351 | if test_roc_auc > max_roc_auc: 352 | max_roc_auc = test_roc_auc 353 | save(save_top_model, 'ROC-AUC', 'best', max_roc_auc) 354 | ep_roc = epoch 355 | roc_auc_loss = test_loss 356 | 357 | if test_loss < min_loss: 358 | min_loss = test_loss 359 | save(save_top_model, 'loss', 'best', min_loss) 360 | ep_loss = epoch 361 | loss_roc_auc = test_roc_auc 362 | if (epoch + 1) % params.get('save_model', 100) == 0: 363 | save(save_top_model, 'epoch', str(epoch)) 364 | logger.info( 365 | 'Overall best performances are: \n \t' 366 | f'Loss = {min_loss:.4f} in epoch {ep_loss} ' 367 | f'\t (ROC-AUC was {loss_roc_auc:4f}) \n \t' 368 | f'ROC-AUC = {max_roc_auc:.4f} in epoch {ep_roc} ' 369 | f'\t (Loss was {roc_auc_loss:4f})' 370 | ) 371 | save(save_top_model, 'training', 'done') 372 | logger.info('Done with training, models saved, shutting down.') 373 | 374 | 375 | if __name__ == '__main__': 376 | # parse arguments 377 | args = parser.parse_args() 378 | # run the training 379 | main( 380 | args.train_affinity_filepath, args.test_affinity_filepath, 381 | args.receptor_filepath, args.ligand_filepath, args.model_path, 382 | args.params_filepath, args.training_name, args.smiles_language_filepath 383 | ) 384 | -------------------------------------------------------------------------------- /paccmann_predictor/models/paccmann_v2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from collections import OrderedDict 4 | 5 | import pytoda 6 | import torch 7 | import torch.nn as nn 8 | from pytoda.smiles.transforms import AugmentTensor 9 | 10 | from ..utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY 11 | from ..utils.interpret import monte_carlo_dropout, test_time_augmentation 12 | from ..utils.layers import ( 13 | ContextAttentionLayer, convolutional_layer, dense_layer 14 | ) 15 | from ..utils.utils import get_device, get_log_molar 16 | 17 | # setup logging 18 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class PaccMannV2(nn.Module): 23 | """Based on the MCA model in Molecular Pharmaceutics: 24 | https://pubs.acs.org/doi/10.1021/acs.molpharmaceut.9b00520. 25 | Updates: 26 | - Context instead of self attention on omic data 27 | """ 28 | 29 | def __init__(self, params, *args, **kwargs): 30 | """Constructor. 31 | 32 | Args: 33 | params (dict): A dictionary containing the parameter to built the 34 | dense encoder. 35 | TODO params should become actual arguments (use **params). 36 | 37 | Items in params: 38 | smiles_padding_length (int): Padding length for SMILES. 39 | smiles_embedding_size (int): dimension of tokens' embedding. 40 | smiles_vocabulary_size (int): size of the tokens vocabulary. 41 | activation_fn (string, optional): Activation function used in all 42 | layers for specification in ACTIVATION_FN_FACTORY. 43 | Defaults to 'relu'. 44 | batch_norm (bool, optional): Whether batch normalization is 45 | applied. Defaults to True. 46 | dropout (float, optional): Dropout probability in all 47 | except parametric layer. Defaults to 0.5. 48 | filters (list[int], optional): Numbers of filters to learn per 49 | SMILES convolutional layer. Defaults to [64, 64, 64]. 50 | kernel_sizes (list[list[int]], optional): Sizes of kernels per 51 | SMILES convolutional layer. Defaults to [ 52 | [3, params['smiles_embedding_size']], 53 | [5, params['smiles_embedding_size']], 54 | [11, params['smiles_embedding_size']] 55 | ] 56 | NOTE: The kernel sizes should match the dimensionality of the 57 | smiles_embedding_size, so if the latter is 8, the images are 58 | t x 8, then treat the 8 embedding dimensions like channels 59 | in an RGB image. 60 | molecule_heads (list[int], optional): Amount of attentive molecule_heads 61 | per SMILES embedding. Should have len(filters)+1. 62 | Defaults to [4, 4, 4, 4]. 63 | stacked_dense_hidden_sizes (list[int], optional): Sizes of the 64 | hidden dense layers. Defaults to [1024, 512]. 65 | smiles_attention_size (int, optional): size of the attentive layer 66 | for the smiles sequence. Defaults to 64. 67 | """ 68 | 69 | super(PaccMannV2, self).__init__(*args, **kwargs) 70 | 71 | # Model Parameter 72 | self.device = get_device() 73 | self.params = params 74 | self.loss_fn = LOSS_FN_FACTORY[params.get('loss_fn', 'mse')] 75 | self.min_max_scaling = True if params.get( 76 | 'drug_sensitivity_processing_parameters', {} 77 | ) != {} else False 78 | if self.min_max_scaling: 79 | self.IC50_max = params[ 80 | 'drug_sensitivity_processing_parameters' 81 | ]['parameters']['max'] # yapf: disable 82 | self.IC50_min = params[ 83 | 'drug_sensitivity_processing_parameters' 84 | ]['parameters']['min'] # yapf: disable 85 | 86 | # Model inputs 87 | self.smiles_padding_length = params['smiles_padding_length'] 88 | self.number_of_genes = params.get('number_of_genes', 2128) 89 | self.smiles_attention_size = params.get('smiles_attention_size', 64) 90 | self.gene_attention_size = params.get('gene_attention_size', 1) 91 | self.molecule_temperature = params.get('molecule_temperature', 1.) 92 | self.gene_temperature = params.get('gene_temperature', 1.) 93 | 94 | # Model architecture (hyperparameter) 95 | self.molecule_heads = params.get('molecule_heads', [4, 4, 4, 4]) 96 | self.gene_heads = params.get('gene_heads', [2, 2, 2, 2]) 97 | if len(self.gene_heads) != len(self.molecule_heads): 98 | raise ValueError('Length of gene and molecule_heads do not match.') 99 | 100 | self.filters = params.get('filters', [64, 64, 64]) 101 | 102 | self.hidden_sizes = ( 103 | [ 104 | self.molecule_heads[0] * params['smiles_embedding_size'] + sum( 105 | [ 106 | h * f 107 | for h, f in zip(self.molecule_heads[1:], self.filters) 108 | ] 109 | ) + sum(self.gene_heads) * self.number_of_genes 110 | ] + params.get('stacked_dense_hidden_sizes', [1024, 512]) 111 | ) 112 | 113 | self.dropout = params.get('dropout', 0.5) 114 | self.temperature = params.get('temperature', 1.) 115 | self.act_fn = ACTIVATION_FN_FACTORY[ 116 | params.get('activation_fn', 'relu')] 117 | self.kernel_sizes = params.get( 118 | 'kernel_sizes', [ 119 | [3, params['smiles_embedding_size']], 120 | [5, params['smiles_embedding_size']], 121 | [11, params['smiles_embedding_size']] 122 | ] 123 | ) 124 | if len(self.filters) != len(self.kernel_sizes): 125 | raise ValueError( 126 | 'Length of filter and kernel size lists do not match.' 127 | ) 128 | if len(self.filters) + 1 != len(self.molecule_heads): 129 | raise ValueError( 130 | 'Length of filter and multihead lists do not match' 131 | ) 132 | 133 | # Build the model 134 | self.smiles_embedding = nn.Embedding( 135 | self.params['smiles_vocabulary_size'], 136 | self.params['smiles_embedding_size'], 137 | scale_grad_by_freq=params.get('embed_scale_grad', False) 138 | ) 139 | 140 | self.convolutional_layers = nn.Sequential( 141 | OrderedDict( 142 | [ 143 | ( 144 | f'convolutional_{index}', 145 | convolutional_layer( 146 | num_kernel, 147 | kernel_size, 148 | act_fn=self.act_fn, 149 | batch_norm=params.get('batch_norm', False), 150 | dropout=self.dropout 151 | ).to(self.device) 152 | ) for index, (num_kernel, kernel_size) in 153 | enumerate(zip(self.filters, self.kernel_sizes)) 154 | ] 155 | ) 156 | ) 157 | 158 | smiles_hidden_sizes = [params['smiles_embedding_size']] + self.filters 159 | 160 | self.molecule_attention_layers = nn.Sequential(OrderedDict([ 161 | ( 162 | f'molecule_attention_{layer}_head_{head}', 163 | ContextAttentionLayer( 164 | reference_hidden_size=smiles_hidden_sizes[layer], 165 | reference_sequence_length=self.smiles_padding_length, 166 | context_hidden_size=1, 167 | context_sequence_length=self.number_of_genes, 168 | attention_size=self.smiles_attention_size, 169 | individual_nonlinearity=params.get( 170 | 'context_nonlinearity', nn.Sequential() 171 | ), 172 | temperature=self.molecule_temperature 173 | ) 174 | ) for layer in range(len(self.molecule_heads)) 175 | for head in range(self.molecule_heads[layer]) 176 | ])) # yapf: disable 177 | 178 | # Gene attention stream 179 | self.gene_attention_layers = nn.Sequential(OrderedDict([ 180 | ( 181 | f'gene_attention_{layer}_head_{head}', 182 | ContextAttentionLayer( 183 | reference_hidden_size=1, 184 | reference_sequence_length=self.number_of_genes, 185 | context_hidden_size=smiles_hidden_sizes[layer], 186 | context_sequence_length=self.smiles_padding_length, 187 | attention_size=self.gene_attention_size, 188 | individual_nonlinearity=params.get( 189 | 'context_nonlinearity', nn.Sequential() 190 | ), 191 | temperature=self.gene_temperature 192 | ) 193 | ) for layer in range(len(self.molecule_heads)) 194 | for head in range(self.gene_heads[layer]) 195 | ])) # yapf: disable 196 | 197 | # Only applied if params['batch_norm'] = True 198 | self.batch_norm = nn.BatchNorm1d(self.hidden_sizes[0]) 199 | self.dense_layers = nn.Sequential( 200 | OrderedDict( 201 | [ 202 | ( 203 | 'dense_{}'.format(ind), 204 | dense_layer( 205 | self.hidden_sizes[ind], 206 | self.hidden_sizes[ind + 1], 207 | act_fn=self.act_fn, 208 | dropout=self.dropout, 209 | batch_norm=params.get('batch_norm', True) 210 | ).to(self.device) 211 | ) for ind in range(len(self.hidden_sizes) - 1) 212 | ] 213 | ) 214 | ) 215 | 216 | self.final_dense = ( 217 | nn.Linear(self.hidden_sizes[-1], 1) 218 | if not params.get('final_activation', False) else nn.Sequential( 219 | OrderedDict( 220 | [ 221 | ('projection', nn.Linear(self.hidden_sizes[-1], 1)), 222 | ('sigmoidal', ACTIVATION_FN_FACTORY['sigmoid']) 223 | ] 224 | ) 225 | ) 226 | ) 227 | 228 | def forward(self, smiles, gep, confidence=False): 229 | """Forward pass through the PaccMannV2. 230 | 231 | Args: 232 | smiles (torch.Tensor): of type int and shape: [bs, smiles_padding_length] 233 | gep (torch.Tensor): of shape `[bs, number_of_genes]`. 234 | confidence (bool, optional) whether the confidence estimates are 235 | performed. 236 | 237 | Returns: 238 | (torch.Tensor, dict): predictions, prediction_dict 239 | predictions is IC50 drug sensitivity prediction of shape `[bs, 1]`. 240 | prediction_dict includes the prediction and attention weights. 241 | """ 242 | 243 | gep = torch.unsqueeze(gep, dim=-1) 244 | embedded_smiles = self.smiles_embedding(smiles.to(dtype=torch.int64)) 245 | 246 | # SMILES Convolutions. Unsqueeze has shape bs x 1 x T x H. 247 | encoded_smiles = [embedded_smiles] + [ 248 | self.convolutional_layers[ind] 249 | (torch.unsqueeze(embedded_smiles, 1)).permute(0, 2, 1) 250 | for ind in range(len(self.convolutional_layers)) 251 | ] 252 | 253 | # Molecule context attention 254 | encodings, smiles_alphas, gene_alphas = [], [], [] 255 | for layer in range(len(self.molecule_heads)): 256 | for head in range(self.molecule_heads[layer]): 257 | 258 | ind = self.molecule_heads[0] * layer + head 259 | e, a = self.molecule_attention_layers[ind]( 260 | encoded_smiles[layer], gep 261 | ) 262 | encodings.append(e) 263 | smiles_alphas.append(a) 264 | 265 | # Gene context attention 266 | for layer in range(len(self.gene_heads)): 267 | for head in range(self.gene_heads[layer]): 268 | ind = self.gene_heads[0] * layer + head 269 | 270 | e, a = self.gene_attention_layers[ind]( 271 | gep, encoded_smiles[layer], average_seq=False 272 | ) 273 | encodings.append(e) 274 | gene_alphas.append(a) 275 | 276 | encodings = torch.cat(encodings, dim=1) 277 | 278 | # Apply batch normalization if specified 279 | inputs = self.batch_norm(encodings) if self.params.get( 280 | 'batch_norm', False 281 | ) else encodings 282 | # NOTE: stacking dense layers as a bottleneck 283 | for dl in self.dense_layers: 284 | inputs = dl(inputs) 285 | 286 | predictions = self.final_dense(inputs) 287 | prediction_dict = {} 288 | 289 | if not self.training: 290 | # The below is to ease postprocessing 291 | smiles_attention = torch.cat( 292 | [torch.unsqueeze(p, -1) for p in smiles_alphas], dim=-1 293 | ) 294 | gene_attention = torch.cat( 295 | [torch.unsqueeze(p, -1) for p in gene_alphas], dim=-1 296 | ) 297 | prediction_dict.update({ 298 | 'gene_attention': gene_attention, 299 | 'smiles_attention': smiles_attention, 300 | 'IC50': predictions, 301 | 'log_micromolar_IC50': 302 | get_log_molar( 303 | predictions, 304 | ic50_max=self.IC50_max, 305 | ic50_min=self.IC50_min 306 | ) if self.min_max_scaling else predictions 307 | }) # yapf: disable 308 | 309 | if confidence: 310 | augmenter = AugmentTensor(self.smiles_language) 311 | epi_conf, epi_pred = monte_carlo_dropout( 312 | self, 313 | regime='tensors', 314 | tensors=(smiles, gep), 315 | repetitions=5 316 | ) 317 | ale_conf, ale_pred = test_time_augmentation( 318 | self, 319 | regime='tensors', 320 | tensors=(smiles, gep), 321 | repetitions=5, 322 | augmenter=augmenter, 323 | tensors_to_augment=0 324 | ) 325 | 326 | prediction_dict.update({ 327 | 'epistemic_confidence': epi_conf, 328 | 'epistemic_predictions': epi_pred, 329 | 'aleatoric_confidence': ale_conf, 330 | 'aleatoric_predictions': ale_pred 331 | }) # yapf: disable 332 | 333 | elif confidence: 334 | logger.info('Using confidence in training mode is not supported.') 335 | 336 | return predictions, prediction_dict 337 | 338 | def loss(self, yhat, y): 339 | return self.loss_fn(yhat, y) 340 | 341 | def _associate_language(self, smiles_language): 342 | """ 343 | Bind a SMILES language object to the model. Is only used inside the 344 | confidence estimation. 345 | 346 | Arguments: 347 | smiles_language {[pytoda.smiles.smiles_language.SMILESLanguage]} 348 | -- [A SMILES language object] 349 | 350 | Raises: 351 | TypeError: 352 | """ 353 | if not isinstance( 354 | smiles_language, pytoda.smiles.smiles_language.SMILESLanguage 355 | ): 356 | raise TypeError( 357 | 'Please insert a smiles language (object of type ' 358 | 'pytoda.smiles.smiles_language.SMILESLanguage). Given was ' 359 | f'{type(smiles_language)}' 360 | ) 361 | self.smiles_language = smiles_language 362 | 363 | def load(self, path, *args, **kwargs): 364 | """Load model from path.""" 365 | weights = torch.load(path, *args, **kwargs) 366 | self.load_state_dict(weights) 367 | 368 | def save(self, path, *args, **kwargs): 369 | """Save model to path.""" 370 | torch.save(self.state_dict(), path, *args, **kwargs) 371 | -------------------------------------------------------------------------------- /paccmann_predictor/models/bimodal_mca.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import pickle 3 | 4 | import pytoda 5 | import torch 6 | import torch.nn as nn 7 | from pytoda.smiles.transforms import AugmentTensor 8 | 9 | from ..utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY 10 | from ..utils.interpret import monte_carlo_dropout, test_time_augmentation 11 | from ..utils.layers import ( 12 | ContextAttentionLayer, 13 | convolutional_layer, 14 | dense_layer, 15 | ) 16 | from ..utils.utils import get_device 17 | 18 | 19 | class BimodalMCA(nn.Module): 20 | """Bimodal Multiscale Convolutional Attentive Encoder. 21 | 22 | This is based on the MCA model as presented in the publication in 23 | Molecular Pharmaceutics: 24 | https://pubs.acs.org/doi/10.1021/acs.molpharmaceut.9b00520. 25 | """ 26 | 27 | def __init__(self, params, *args, **kwargs): 28 | """Constructor. 29 | 30 | Args: 31 | params (dict): A dictionary containing the parameter to built the 32 | dense encoder. 33 | TODO params should become actual arguments (use **params). 34 | 35 | Required items in params: 36 | ligand_padding_length (int): dimension of tokens' embedding. 37 | ligand_vocabulary_size (int): size of the tokens vocabulary. 38 | receptor_padding_length (int): dimension of tokens' embedding. 39 | receptor_vocabulary_size (int): size of the tokens vocabulary. 40 | Optional items in params: 41 | activation_fn (str): Activation function used in all ayers for 42 | specification in ACTIVATION_FN_FACTORY. Defaults to 'relu'. 43 | batch_norm (bool): Whether batch normalization is applied. Defaults 44 | to True. 45 | dropout (float): Dropout probability in all except context 46 | attention layer. Defaults to 0.5. 47 | ligand_embedding (str): Way to numberically embed ligand sequence. 48 | Options: 'predefined' (sequence is already embedded using 49 | predefined token representations like BLOSUM matrix), 50 | 'one-hot', 'pretrained' (loads embedding from ligand_embedding 51 | path) or 'learned (model learns an embedding from data). 52 | Defaults to 'learned'. 53 | ligand_embedding_path (str): Path where pretrained embedding 54 | weights are stored. Needed if ligand_embedding is 'pretrained'. 55 | receptor_embedding (str): Way to numberically embed receptor sequence. 56 | Options: 'predefined' (sequence is already embedded using 57 | predefined token representations like BLOSUM matrix), 58 | 'one-hot', 'pretrained' (loads embedding from receptor_embedding 59 | path) or 'learned (model learns an embedding from data). 60 | Defaults to 'learned'. 61 | receptor_embedding_path (str): Path where pretrained embedding 62 | weights are stored. Needed if receptor_embedding is 'pretrained'. 63 | ligand_embedding_size (int): Embedding dimensionality, default: 32 64 | receptor_embedding_size (int): Embedding dimensionality, default: 8 65 | ligand_filters (list[int]): Numbers of filters to learn per 66 | convolutional layer. Defaults to [32, 32, 32]. 67 | receptor_filters (list[int]): Numbers of filters to learn per 68 | convolutional layer. Defaults to [32, 32, 32]. 69 | ligand_kernel_sizes (list[list[int]]): Sizes of kernels per 70 | convolutional layer. Defaults to [ 71 | [3, params['ligand_embedding_size']], 72 | [5, params['ligand_embedding_size']], 73 | [11, params['ligand_embedding_size']] 74 | ] 75 | receptor_kernel_sizes (list[list[int]]): Sizes of kernels per 76 | convolutional layer. Defaults to [ 77 | [3, params['receptor_embedding_size']], 78 | [11, params['receptor_embedding_size']], 79 | [25, params['receptor_embedding_size']] 80 | ] 81 | NOTE: The kernel sizes should match the dimensionality of the 82 | ligand_embedding_size, so if the latter is 8, the images are 83 | t x 8, then treat the 8 embedding dimensions like channels 84 | in an RGB image. 85 | ligand_attention_size (int): size of the attentive layer for the 86 | ligand sequence. Defaults to 16. 87 | receptor_attention_size (int): size of the attentive layer for the 88 | receptor sequence. Defaults to 16. 89 | dense_hidden_sizes (list[int]): Sizes of the hidden dense layers. 90 | Defaults to [20]. 91 | final_activation: (bool): Whether a (sigmoid) activation function 92 | is used in the final layer. Defaults to False. 93 | """ 94 | 95 | super(BimodalMCA, self).__init__(*args, **kwargs) 96 | 97 | # Model Parameter 98 | self.device = get_device() 99 | self.params = params 100 | self.ligand_padding_length = params['ligand_padding_length'] 101 | self.receptor_padding_length = params['receptor_padding_length'] 102 | 103 | self.loss_fn = LOSS_FN_FACTORY[ 104 | params.get('loss_fn', 'binary_cross_entropy') 105 | ] # yapf: disable 106 | self.ligand_embedding_type = params.get('ligand_embedding', 'learned') 107 | self.receptor_embedding_type = params.get( 108 | 'receptor_embedding', 'learned' 109 | ) 110 | 111 | # Hyperparameter 112 | self.act_fn = ACTIVATION_FN_FACTORY[ 113 | params.get('activation_fn', 'relu') 114 | ] # yapf: disable 115 | self.dropout = params.get('dropout', 0.5) 116 | self.use_batch_norm = params.get('batch_norm', True) 117 | self.temperature = params.get('temperature', 1.0) 118 | self.ligand_filters = params.get('ligand_filters', [32, 32, 32]) 119 | self.receptor_filters = params.get('receptor_filters', [32, 32, 32]) 120 | 121 | # set embedding_size to vocabulary_size if one_hot encoding is chosen 122 | if params.get('ligand_embedding', 'learned') == 'one_hot': 123 | self.ligand_embedding_size = params.get( 124 | 'ligand_vocabulary_size', 32 125 | ) 126 | else: 127 | self.ligand_embedding_size = params.get( 128 | 'ligand_embedding_size', 32 129 | ) 130 | if params.get('receptor_embedding', 'learned') == 'one_hot': 131 | self.receptor_embedding_size = params.get( 132 | 'receptor_vocabulary_size', 35 133 | ) 134 | else: 135 | self.receptor_embedding_size = params.get( 136 | 'receptor_embedding_size', 35 137 | ) 138 | 139 | self.ligand_kernel_sizes = params.get( 140 | 'ligand_kernel_sizes', 141 | [ 142 | [3, self.ligand_embedding_size], 143 | [5, self.ligand_embedding_size], 144 | [11, self.ligand_embedding_size], 145 | ], 146 | ) 147 | self.receptor_kernel_sizes = params.get( 148 | 'receptor_kernel_sizes', 149 | [ 150 | [3, self.receptor_embedding_size], 151 | [11, self.receptor_embedding_size], 152 | [25, self.receptor_embedding_size], 153 | ], 154 | ) 155 | 156 | self.ligand_attention_size = params.get('ligand_attention_size', 16) 157 | self.receptor_attention_size = params.get( 158 | 'receptor_attention_size', 16 159 | ) 160 | 161 | self.ligand_hidden_sizes = [ 162 | self.ligand_embedding_size 163 | ] + self.ligand_filters 164 | self.receptor_hidden_sizes = [ 165 | self.receptor_embedding_size 166 | ] + self.receptor_filters 167 | self.hidden_sizes = [ 168 | self.ligand_embedding_size + sum(self.ligand_filters) + 169 | self.receptor_embedding_size + sum(self.receptor_filters) 170 | ] + params.get('dense_hidden_sizes', [20]) 171 | if self.use_batch_norm: 172 | self.batch_norm = nn.BatchNorm1d(self.hidden_sizes[0]) 173 | 174 | # Sanity checking of model sizes 175 | if len(self.ligand_filters) != len(self.ligand_kernel_sizes): 176 | raise ValueError( 177 | 'Length of ligand filter and kernel size lists do not match.' 178 | ) 179 | if len(self.receptor_filters) != len(self.receptor_kernel_sizes): 180 | raise ValueError( 181 | 'Length of receptor filter and kernel size lists do not match.' 182 | ) 183 | if len(self.ligand_filters) != len(self.receptor_filters): 184 | raise ValueError( 185 | 'Length of ligand_filters and receptor_filters array must match' 186 | f', found ligand_filters: {len(self.ligand_filters)} and ' 187 | f'receptor_filters: {len(self.receptor_filters)}.' 188 | ) 189 | """ Construct model """ 190 | # Embeddings 191 | if params.get('ligand_embedding', 'learned') == 'pretrained': 192 | # Load the pretrained embeddings 193 | try: 194 | with open(params['ligand_embedding_path'], 'rb') as f: 195 | embeddings = pickle.load(f) 196 | except KeyError: 197 | raise KeyError('Path for ligand embeddings missing in params.') 198 | 199 | # Plug into layer 200 | self.ligand_embedding = nn.Embedding( 201 | embeddings.shape[0], embeddings.shape[1] 202 | ) 203 | self.ligand_embedding.load_state_dict( 204 | {'weight': torch.Tensor(embeddings)} 205 | ) 206 | if params.get('fix_ligand_embeddings', True): 207 | self.ligand_embedding.weight.requires_grad = False 208 | 209 | elif params.get('ligand_embedding', 'learned') == 'one_hot': 210 | self.ligand_embedding = nn.Embedding( 211 | self.params['ligand_vocabulary_size'], 212 | self.params['ligand_vocabulary_size'], 213 | ) 214 | # Plug in one hot-vectors and freeze weights 215 | self.ligand_embedding.load_state_dict( 216 | { 217 | 'weight': 218 | torch.nn.functional.one_hot( 219 | torch.arange( 220 | self.params['ligand_vocabulary_size'] 221 | ) 222 | ) 223 | } 224 | ) 225 | self.ligand_embedding.weight.requires_grad = False 226 | 227 | elif params.get('ligand_embedding', 'learned') == 'learned': 228 | self.ligand_embedding = nn.Embedding( 229 | self.params['ligand_vocabulary_size'], 230 | self.ligand_embedding_size, 231 | scale_grad_by_freq=params.get('embed_scale_grad', False) 232 | ) 233 | else: 234 | assert params.get( 235 | 'ligand_embedding', 'learned' 236 | ) == 'predefined', 'Choose either pretrained, one_hot, predefined \ 237 | or learned as ligand_embedding. Defaults to learned' 238 | 239 | if params.get('receptor_embedding', 'learned') == 'pretrained': 240 | # Load the pretrained embeddings 241 | try: 242 | with open(params['receptor_embedding_path'], 'rb') as f: 243 | embeddings = pickle.load(f) 244 | except KeyError: 245 | raise KeyError( 246 | 'Path for receptor embeddings missing in params.' 247 | ) 248 | 249 | # Plug into layer 250 | self.receptor_embedding = nn.Embedding( 251 | embeddings.shape[0], embeddings.shape[1] 252 | ) 253 | self.receptor_embedding.load_state_dict( 254 | {'weight': torch.Tensor(embeddings)} 255 | ) 256 | if params.get('fix_receptor_embeddings', True): 257 | self.receptor_embedding.weight.requires_grad = False 258 | 259 | elif params.get('receptor_embedding', 'learned') == 'one_hot': 260 | self.receptor_embedding = nn.Embedding( 261 | self.params['receptor_vocabulary_size'], 262 | self.params['receptor_vocabulary_size'], 263 | ) 264 | # Plug in one hot-vectors and freeze weights 265 | self.receptor_embedding.load_state_dict( 266 | { 267 | 'weight': 268 | torch.nn.functional.one_hot( 269 | torch.arange( 270 | self.params['receptor_vocabulary_size'] 271 | ) 272 | ) 273 | } 274 | ) 275 | self.receptor_embedding.weight.requires_grad = False 276 | 277 | elif params.get('receptor_embedding', 'learned') == 'learned': 278 | self.receptor_embedding = nn.Embedding( 279 | self.params['receptor_vocabulary_size'], 280 | self.receptor_embedding_size, 281 | scale_grad_by_freq=params.get('embed_scale_grad', False), 282 | ) 283 | else: 284 | assert params.get( 285 | 'receptor_embedding', 'learned' 286 | ) == 'predefined', 'Choose either pretrained, one_hot, predefined \ 287 | or learned as ligand_embedding. Defaults to learned' 288 | 289 | # Convolutions 290 | # TODO: Use nn.ModuleDict instead of the nn.Seq/OrderedDict 291 | self.ligand_convolutional_layers = nn.Sequential( 292 | OrderedDict( 293 | [ 294 | ( 295 | f'ligand_convolutional_{index}', 296 | convolutional_layer( 297 | num_kernel, 298 | kernel_size, 299 | act_fn=self.act_fn, 300 | dropout=self.dropout, 301 | batch_norm=self.use_batch_norm, 302 | ).to(self.device), 303 | ) 304 | for index, (num_kernel, kernel_size) in enumerate( 305 | zip(self.ligand_filters, self.ligand_kernel_sizes) 306 | ) 307 | ] 308 | ) 309 | ) # yapf: disable 310 | 311 | self.receptor_convolutional_layers = nn.Sequential( 312 | OrderedDict( 313 | [ 314 | ( 315 | f'receptor_convolutional_{index}', 316 | convolutional_layer( 317 | num_kernel, 318 | kernel_size, 319 | act_fn=self.act_fn, 320 | dropout=self.dropout, 321 | batch_norm=self.use_batch_norm, 322 | ).to(self.device), 323 | ) 324 | for index, (num_kernel, kernel_size) in enumerate( 325 | zip(self.receptor_filters, self.receptor_kernel_sizes) 326 | ) 327 | ] 328 | ) 329 | ) # yapf: disable 330 | 331 | # Context attention 332 | self.context_attention_ligand_layers = nn.Sequential( 333 | OrderedDict( 334 | [ 335 | ( 336 | f'context_attention_ligand_{layer}', 337 | ContextAttentionLayer( 338 | self.ligand_hidden_sizes[layer], 339 | self.params['ligand_padding_length'], 340 | self.receptor_hidden_sizes[layer], 341 | context_sequence_length=( 342 | self.receptor_padding_length 343 | ), 344 | attention_size=self.ligand_attention_size, 345 | individual_nonlinearity=params.get( 346 | 'context_nonlinearity', nn.Sequential() 347 | ), 348 | temperature=self.temperature, 349 | ), 350 | ) for layer in range(len(self.ligand_filters) + 1) 351 | ] 352 | ) 353 | ) 354 | 355 | self.context_attention_receptor_layers = nn.Sequential( 356 | OrderedDict( 357 | [ 358 | ( 359 | f'context_attention_receptor_{layer}', 360 | ContextAttentionLayer( 361 | self.receptor_hidden_sizes[layer], 362 | self.params['receptor_padding_length'], 363 | self.ligand_hidden_sizes[layer], 364 | context_sequence_length=self.ligand_padding_length, 365 | attention_size=self.receptor_attention_size, 366 | individual_nonlinearity=params.get( 367 | 'context_nonlinearity', nn.Sequential() 368 | ), 369 | temperature=self.temperature, 370 | ), 371 | ) for layer in range(len(self.receptor_filters) + 1) 372 | ] 373 | ) 374 | ) 375 | 376 | self.dense_layers = nn.Sequential( 377 | OrderedDict( 378 | [ 379 | ( 380 | f'dense_{ind}', 381 | dense_layer( 382 | self.hidden_sizes[ind], 383 | self.hidden_sizes[ind + 1], 384 | act_fn=self.act_fn, 385 | dropout=self.dropout, 386 | batch_norm=self.use_batch_norm, 387 | ).to(self.device), 388 | ) for ind in range(len(self.hidden_sizes) - 1) 389 | ] 390 | ) 391 | ) 392 | 393 | self.final_dense = nn.Linear(self.hidden_sizes[-1], 1) 394 | if params.get('final_activation', True): 395 | self.final_dense = nn.Sequential( 396 | self.final_dense, ACTIVATION_FN_FACTORY['sigmoid'] 397 | ) 398 | 399 | def forward(self, ligand, receptors, confidence=False): 400 | """Forward pass through the biomodal MCA. 401 | 402 | Args: 403 | ligand (torch.Tensor): of type int and shape 404 | `[bs, ligand_padding_length]`. 405 | receptors (torch.Tensor): of type int and shape 406 | `[bs, receptor_padding_length]`. 407 | confidence (bool, optional) whether the confidence estimates are 408 | performed. 409 | 410 | Returns: 411 | (torch.Tensor, torch.Tensor): predictions, prediction_dict 412 | 413 | predictions is IC50 drug sensitivity prediction of shape `[bs, 1]`. 414 | prediction_dict includes the prediction and attention weights. 415 | """ 416 | # Embedding 417 | if self.ligand_embedding_type == 'predefined': 418 | embedded_ligand = ligand.to(torch.float) 419 | else: 420 | embedded_ligand = self.ligand_embedding(ligand.to(torch.int64)) 421 | if self.receptor_embedding_type == 'predefined': 422 | embedded_receptor = receptors.to(torch.float) 423 | else: 424 | embedded_receptor = self.receptor_embedding( 425 | receptors.to(torch.int64) 426 | ) 427 | 428 | # Convolutions 429 | encoded_ligand = [embedded_ligand] + [ 430 | layer(torch.unsqueeze(embedded_ligand, 1)).permute(0, 2, 1) 431 | for layer in self.ligand_convolutional_layers 432 | ] 433 | encoded_receptor = [embedded_receptor] + [ 434 | layer(torch.unsqueeze(embedded_receptor, 1)).permute(0, 2, 1) 435 | for layer in self.receptor_convolutional_layers 436 | ] 437 | 438 | # Context attention on ligand 439 | ligand_encodings, ligand_alphas = zip( 440 | *[ 441 | layer(reference, context) for layer, reference, context in zip( 442 | self.context_attention_ligand_layers, 443 | encoded_ligand, 444 | encoded_receptor, 445 | ) 446 | ] 447 | ) 448 | 449 | # Context attention on receptor 450 | receptor_encodings, receptor_alphas = zip( 451 | *[ 452 | layer(reference, context) for layer, reference, context in zip( 453 | self.context_attention_receptor_layers, 454 | encoded_receptor, 455 | encoded_ligand, 456 | ) 457 | ] 458 | ) 459 | 460 | # Concatenate all encodings 461 | encodings = torch.cat( 462 | [ 463 | torch.cat(ligand_encodings, dim=1), 464 | torch.cat(receptor_encodings, dim=1), 465 | ], 466 | dim=1, 467 | ) 468 | 469 | # Apply batch normalization if specified 470 | out = self.batch_norm(encodings) if self.use_batch_norm else encodings 471 | 472 | # Stack dense layers 473 | for dl in self.dense_layers: 474 | out = dl(out) 475 | predictions = self.final_dense(out) 476 | 477 | prediction_dict = {} 478 | if not self.training: 479 | # The below is to ease postprocessing 480 | ligand_attention_weights = torch.mean( 481 | torch.cat( 482 | [torch.unsqueeze(p, -1) for p in ligand_alphas], dim=-1 483 | ), 484 | dim=-1, 485 | ) 486 | receptor_attention_weights = torch.mean( 487 | torch.cat( 488 | [torch.unsqueeze(p, -1) for p in receptor_alphas], dim=-1 489 | ), 490 | dim=-1, 491 | ) 492 | prediction_dict.update( 493 | { 494 | 'ligand_attention': ligand_attention_weights, 495 | 'receptor_attention': receptor_attention_weights, 496 | } 497 | ) # yapf: disable 498 | 499 | if confidence: 500 | augmenter = AugmentTensor(self.smiles_language) 501 | epistemic_conf = monte_carlo_dropout( 502 | self, 503 | regime='tensors', 504 | tensors=(ligand, receptors), 505 | repetitions=5, 506 | ) 507 | aleatoric_conf = test_time_augmentation( 508 | self, 509 | regime='tensors', 510 | tensors=(ligand, receptors), 511 | repetitions=5, 512 | augmenter=augmenter, 513 | tensors_to_augment=0, 514 | ) 515 | 516 | prediction_dict.update( 517 | { 518 | 'epistemic_confidence': epistemic_conf, 519 | 'aleatoric_confidence': aleatoric_conf, 520 | } 521 | ) # yapf: disable 522 | 523 | return predictions, prediction_dict 524 | 525 | def loss(self, yhat, y): 526 | return self.loss_fn(yhat, y) 527 | 528 | def _associate_language(self, language): 529 | """ 530 | Bind a SMILES or Protein language object to the model. 531 | Is only used inside the confidence estimation. 532 | 533 | Arguments: 534 | language {Union[ 535 | pytoda.smiles.smiles_language.SMILESLanguage, 536 | pytoda.proteins.protein_langauge.ProteinLanguage 537 | ]} -- [A SMILES or Protein language object] 538 | 539 | Raises: 540 | TypeError: 541 | """ 542 | if isinstance(language, pytoda.smiles.smiles_language.SMILESLanguage): 543 | self.smiles_language = language 544 | 545 | elif isinstance( 546 | language, pytoda.proteins.protein_language.ProteinLanguage 547 | ): 548 | self.protein_language = language 549 | else: 550 | raise TypeError( 551 | 'Please insert a smiles language (object of type ' 552 | 'pytoda.smiles.smiles_language.SMILESLanguage or ' 553 | 'pytoda.proteins.protein_language.ProteinLanguage). Given was ' 554 | f'{type(language)}' 555 | ) 556 | 557 | def load(self, path, *args, **kwargs): 558 | """Load model from path.""" 559 | weights = torch.load(path, *args, **kwargs) 560 | self.load_state_dict(weights) 561 | 562 | def save(self, path, *args, **kwargs): 563 | """Save model to path.""" 564 | torch.save(self.state_dict(), path, *args, **kwargs) 565 | --------------------------------------------------------------------------------