├── README.md
├── deepdta_retrain.py
├── examples
│   ├── bindingDB_processed.csv
│   ├── cleaned_mpro.csv
│   ├── deepdta_original_split-prk12-ldk8.pt
│   ├── ligand_dict-prk12-ldk8.json
│   ├── protein_dict-prk12-ldk8.json
│   ├── test-result-prk12-ldk8.txt
│   └── training-prk12-ldk8.log
├── model.py
├── model_result.ipynb
├── requirements.txt
└── trainer.py

/README.md:
--------------------------------------------------------------------------------
1 | # DeepDTA-Pytorch
2 | PyTorch implementation of the original DeepDTA paper. [Original GitHub Repo](https://github.com/hkmztrk/DeepDTA/)
3 | 
4 | Requirements (most of them come with Anaconda except `torch`, `pytorch-cuda`, and `tqdm`):
5 | ```
6 | python==3.8.16
7 | numpy==1.24.1
8 | pandas==1.5.2
9 | matplotlib==3.5.3
10 | scipy==1.8.1
11 | torch==2.1.0
12 | pytorch-cuda==11.7
13 | tqdm==4.65.0
14 | ```
15 | 
16 | The input data should be a CSV file with four columns: `proteins`, `ligands`, `affinity`, and `split`. The `proteins` column stores the protein sequences, `ligands` stores the isomeric SMILES strings of the molecular binders, and `affinity` is either the Kd/Ki value or the binding affinity in kcal/mol (use one convention consistently for all data). The `split` column takes one of three values indicating the train-val-test split: 'train', 'val', or 'test'. See `examples/cleaned_mpro.csv` for an example; a short preparation sketch is shown further below.
17 | 
18 | To run the code, edit the file path `fp` in `deepdta_retrain.py` to point to your CSV and then run `python deepdta_retrain.py`.
19 | 
20 | For analysis, there is a separate Jupyter notebook (`model_result.ipynb`) with some preliminary scatter plots and code for applying the trained model to a held-out dataset. Make sure to change the `ligand_dict`, `protein_dict`, and model file names to your own choices. This analysis is mainly for choosing the best protein and ligand kernel sizes. For the first `fp` you can use `examples/cleaned_mpro.csv`; for the held-out data CSV you can use `examples/bindingDB_processed.csv`.
21 | 
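If you are assembling this CSV yourself, the snippet below is a minimal sketch (not part of this repository) of how the four columns and a random split could be put together with pandas. The raw file name and the source column names `sequence`, `smiles`, and `pKd` are placeholders for whatever your own data provides.

```python
import numpy as np
import pandas as pd

# hypothetical raw data; rename the file and columns to match your own source
raw = pd.read_csv("my_raw_binding_data.csv")

df = pd.DataFrame({
    "proteins": raw["sequence"],   # protein sequences
    "ligands": raw["smiles"],      # isomeric SMILES strings
    "affinity": raw["pKd"],        # one consistent affinity convention for all rows
})

# random 80/10/10 train-val-test split
rng = np.random.default_rng(0)
df["split"] = rng.choice(["train", "val", "test"], size=len(df), p=[0.8, 0.1, 0.1])

df.to_csv("path_to_your_data.csv", index=False)
```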
22 | Also, feel free to check out our paper, which tests this implementation on an improved PDBBind split: [![arXiv](https://img.shields.io/badge/arXiv-2308.09639v2-B31B1B)](https://arxiv.org/abs/2308.09639v2)
23 | 
24 | ## Citation
25 | ```bibtex
26 | @article{lppdbbind,
27 | title = {Leak {Proof} {PDBBind}: {A} {Reorganized} {Dataset} of {Protein}-{Ligand} {Complexes} for {More} {Generalizable} {Binding} {Affinity} {Prediction}},
28 | journal = {ArXiv},
29 | author = {Li, Jie and Guan, Xingyi and Zhang, Oufan and Sun, Kunyang and Wang, Yingze and Bagni, Dorian and Head-Gordon, Teresa},
30 | month = may,
31 | year = {2024},
32 | pmid = {37645037},
33 | pmcid = {PMC10462179},
34 | pages = {arXiv:2308.09639v2},
35 | }
36 | 
37 | @article{10.1093/bioinformatics/bty593,
38 | author = {Öztürk, Hakime and Özgür, Arzucan and Ozkirimli, Elif},
39 | title = "{DeepDTA: deep drug–target binding affinity prediction}",
40 | journal = {Bioinformatics},
41 | volume = {34},
42 | number = {17},
43 | pages = {i821-i829},
44 | year = {2018},
45 | month = {09},
46 | issn = {1367-4803},
47 | doi = {10.1093/bioinformatics/bty593},
48 | url = {https://doi.org/10.1093/bioinformatics/bty593},
49 | eprint = {https://academic.oup.com/bioinformatics/article-pdf/34/17/i821/25702584/bty593.pdf},
50 | }
51 | ```
52 | 
--------------------------------------------------------------------------------
/deepdta_retrain.py:
--------------------------------------------------------------------------------
1 | from model import DeepDTA
2 | from trainer import Trainer
3 | import pandas as pd
4 | 
5 | # this CSV file has 4 columns: proteins, ligands, affinity, split
6 | 
7 | fp = 'path_to_your_data.csv'
8 | 
9 | df = pd.read_csv(fp)
10 | train_idx = df[df['split'] == 'train'].index.values
11 | val_idx = df[df['split'] == 'val'].index.values
12 | test_idx = df[df['split'] == 'test'].index.values
13 | 
14 | model = DeepDTA
15 | channel = 32
16 | protein_kernel = [8, 12]
17 | ligand_kernel = [4, 8]
18 | 
19 | for prk in protein_kernel:
20 |     for ldk in ligand_kernel:
21 |         # adjust num_epochs as needed; 30 epochs was used here, but other datasets may need more for convergence
22 |         trainer = Trainer(model, channel, prk, ldk, df, train_idx, val_idx, test_idx, "training-prk{}-ldk{}.log".format(prk, ldk))
23 |         trainer.train(num_epochs=30, batch_size=256, lr=0.001, save_path='deepdta_retrain-prk{}-ldk{}.pt'.format(prk, ldk))
24 | 
--------------------------------------------------------------------------------
/examples/deepdta_original_split-prk12-ldk8.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KSUN63/DeepDTA-Pytorch/78c1ffb8e44a8b3ed03ce58e3f4e3e805cc03c31/examples/deepdta_original_split-prk12-ldk8.pt
--------------------------------------------------------------------------------
/examples/ligand_dict-prk12-ldk8.json:
--------------------------------------------------------------------------------
1 | {"~": 0, "I": 1, "4": 2, "+": 3, "5": 4, "s": 5, "l": 6, "#": 7, "/": 8, "\\": 9, "o": 10, "P": 11, "B": 12, "n": 13, "(": 14, "@": 15, "H": 16, "F": 17, "6": 18, "1": 19, "O": 20, "S": 21, "r": 22, "c": 23, "2": 24, "3": 25, ")": 26, "]": 27, "dummy": 28, "[": 29, "C": 30, "=": 31, "-": 32, "N": 33}
--------------------------------------------------------------------------------
/examples/protein_dict-prk12-ldk8.json:
--------------------------------------------------------------------------------
1 | {"I": 0, "D": 1, "V": 2, "W": 3, "X": 4, "M": 5, "K": 6, "P": 7, "R": 8, "A": 9, "G": 10, "H": 11, "F": 12, "Q": 13, ":": 14, "L": 15, "T": 16, "Y": 17, "S": 18, "E": 19, "dummy": 20, "C": 21,
"N": 22} -------------------------------------------------------------------------------- /examples/test-result-prk12-ldk8.txt: -------------------------------------------------------------------------------- 1 | 3.597082376480102539e+00 2 | 2.890498638153076172e+00 3 | 6.756830692291259766e+00 4 | 5.875916004180908203e+00 5 | 6.186888217926025391e+00 6 | 6.615565776824951172e+00 7 | 5.445238590240478516e+00 8 | 5.148606300354003906e+00 9 | -------------------------------------------------------------------------------- /examples/training-prk12-ldk8.log: -------------------------------------------------------------------------------- 1 | 2023-05-21 23:28:30,884 - INFO - Epoch: 1 - Training Loss: 16.421245 2 | 2023-05-21 23:28:31,326 - INFO - Best Model So Far in Epoch: 1 3 | 2023-05-21 23:28:31,327 - INFO - Epoch: 1 - Validation Loss: 3.224245 4 | 2023-05-21 23:28:37,866 - INFO - Epoch: 2 - Training Loss: 3.216737 5 | 2023-05-21 23:28:38,300 - INFO - Best Model So Far in Epoch: 2 6 | 2023-05-21 23:28:38,302 - INFO - Epoch: 2 - Validation Loss: 2.930337 7 | 2023-05-21 23:28:44,835 - INFO - Epoch: 3 - Training Loss: 2.896854 8 | 2023-05-21 23:28:45,274 - INFO - Best Model So Far in Epoch: 3 9 | 2023-05-21 23:28:45,275 - INFO - Epoch: 3 - Validation Loss: 2.805947 10 | 2023-05-21 23:28:51,790 - INFO - Epoch: 4 - Training Loss: 2.612218 11 | 2023-05-21 23:28:52,226 - INFO - Best Model So Far in Epoch: 4 12 | 2023-05-21 23:28:52,227 - INFO - Epoch: 4 - Validation Loss: 2.508415 13 | 2023-05-21 23:28:58,734 - INFO - Epoch: 5 - Training Loss: 2.413236 14 | 2023-05-21 23:28:59,170 - INFO - Best Model So Far in Epoch: 5 15 | 2023-05-21 23:28:59,171 - INFO - Epoch: 5 - Validation Loss: 2.416417 16 | 2023-05-21 23:29:05,689 - INFO - Epoch: 6 - Training Loss: 2.174455 17 | 2023-05-21 23:29:06,126 - INFO - Best Model So Far in Epoch: 6 18 | 2023-05-21 23:29:06,128 - INFO - Epoch: 6 - Validation Loss: 2.187035 19 | 2023-05-21 23:29:12,646 - INFO - Epoch: 7 - Training Loss: 2.053800 20 | 2023-05-21 23:29:13,084 - INFO - Epoch: 7 - Validation Loss: 2.305705 21 | 2023-05-21 23:29:19,628 - INFO - Epoch: 8 - Training Loss: 1.878061 22 | 2023-05-21 23:29:20,066 - INFO - Best Model So Far in Epoch: 8 23 | 2023-05-21 23:29:20,068 - INFO - Epoch: 8 - Validation Loss: 1.904182 24 | 2023-05-21 23:29:26,587 - INFO - Epoch: 9 - Training Loss: 1.603704 25 | 2023-05-21 23:29:27,024 - INFO - Best Model So Far in Epoch: 9 26 | 2023-05-21 23:29:27,025 - INFO - Epoch: 9 - Validation Loss: 1.774848 27 | 2023-05-21 23:29:33,570 - INFO - Epoch: 10 - Training Loss: 1.513370 28 | 2023-05-21 23:29:33,996 - INFO - Best Model So Far in Epoch: 10 29 | 2023-05-21 23:29:33,996 - INFO - Epoch: 10 - Validation Loss: 1.726570 30 | 2023-05-21 23:29:40,518 - INFO - Epoch: 11 - Training Loss: 1.388718 31 | 2023-05-21 23:29:40,956 - INFO - Best Model So Far in Epoch: 11 32 | 2023-05-21 23:29:40,956 - INFO - Epoch: 11 - Validation Loss: 1.621160 33 | 2023-05-21 23:29:47,473 - INFO - Epoch: 12 - Training Loss: 1.301370 34 | 2023-05-21 23:29:47,910 - INFO - Epoch: 12 - Validation Loss: 2.047675 35 | 2023-05-21 23:29:54,418 - INFO - Epoch: 13 - Training Loss: 1.239951 36 | 2023-05-21 23:29:54,860 - INFO - Epoch: 13 - Validation Loss: 1.667548 37 | 2023-05-21 23:30:01,397 - INFO - Epoch: 14 - Training Loss: 1.139910 38 | 2023-05-21 23:30:01,830 - INFO - Epoch: 14 - Validation Loss: 1.636454 39 | 2023-05-21 23:30:08,362 - INFO - Epoch: 15 - Training Loss: 1.280690 40 | 2023-05-21 23:30:08,797 - INFO - Best Model So Far in Epoch: 15 41 | 
2023-05-21 23:30:08,798 - INFO - Epoch: 15 - Validation Loss: 1.564113 42 | 2023-05-21 23:30:15,317 - INFO - Epoch: 16 - Training Loss: 1.072031 43 | 2023-05-21 23:30:15,750 - INFO - Epoch: 16 - Validation Loss: 1.602668 44 | 2023-05-21 23:30:22,276 - INFO - Epoch: 17 - Training Loss: 1.070402 45 | 2023-05-21 23:30:22,712 - INFO - Epoch: 17 - Validation Loss: 1.593379 46 | 2023-05-21 23:30:29,232 - INFO - Epoch: 18 - Training Loss: 0.925829 47 | 2023-05-21 23:30:29,668 - INFO - Epoch: 18 - Validation Loss: 1.564833 48 | 2023-05-21 23:30:36,207 - INFO - Epoch: 19 - Training Loss: 0.888907 49 | 2023-05-21 23:30:36,640 - INFO - Best Model So Far in Epoch: 19 50 | 2023-05-21 23:30:36,642 - INFO - Epoch: 19 - Validation Loss: 1.519529 51 | 2023-05-21 23:30:43,176 - INFO - Epoch: 20 - Training Loss: 0.894562 52 | 2023-05-21 23:30:43,618 - INFO - Epoch: 20 - Validation Loss: 1.560987 53 | 2023-05-21 23:30:50,144 - INFO - Epoch: 21 - Training Loss: 0.925372 54 | 2023-05-21 23:30:50,579 - INFO - Epoch: 21 - Validation Loss: 1.693252 55 | 2023-05-21 23:30:57,187 - INFO - Epoch: 22 - Training Loss: 0.880360 56 | 2023-05-21 23:30:57,618 - INFO - Epoch: 22 - Validation Loss: 1.708040 57 | 2023-05-21 23:31:04,148 - INFO - Epoch: 23 - Training Loss: 0.949208 58 | 2023-05-21 23:31:04,587 - INFO - Epoch: 23 - Validation Loss: 1.560593 59 | 2023-05-21 23:31:11,111 - INFO - Epoch: 24 - Training Loss: 0.759356 60 | 2023-05-21 23:31:11,553 - INFO - Epoch: 24 - Validation Loss: 1.531301 61 | 2023-05-21 23:31:18,086 - INFO - Epoch: 25 - Training Loss: 0.677824 62 | 2023-05-21 23:31:18,521 - INFO - Best Model So Far in Epoch: 25 63 | 2023-05-21 23:31:18,522 - INFO - Epoch: 25 - Validation Loss: 1.493793 64 | 2023-05-21 23:31:25,040 - INFO - Epoch: 26 - Training Loss: 0.659823 65 | 2023-05-21 23:31:25,472 - INFO - Epoch: 26 - Validation Loss: 1.763883 66 | 2023-05-21 23:31:32,032 - INFO - Epoch: 27 - Training Loss: 0.664094 67 | 2023-05-21 23:31:32,465 - INFO - Epoch: 27 - Validation Loss: 1.579167 68 | 2023-05-21 23:31:39,000 - INFO - Epoch: 28 - Training Loss: 0.742393 69 | 2023-05-21 23:31:39,442 - INFO - Epoch: 28 - Validation Loss: 1.547086 70 | 2023-05-21 23:31:45,959 - INFO - Epoch: 29 - Training Loss: 0.702457 71 | 2023-05-21 23:31:46,399 - INFO - Epoch: 29 - Validation Loss: 2.110060 72 | 2023-05-21 23:31:52,930 - INFO - Epoch: 30 - Training Loss: 0.952292 73 | 2023-05-21 23:31:53,376 - INFO - Best Model So Far in Epoch: 30 74 | 2023-05-21 23:31:53,377 - INFO - Epoch: 30 - Validation Loss: 1.478125 75 | 2023-05-21 23:31:53,432 - INFO - Best Model Loaded from Epoch: 30 76 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Replication of the model architecture used in the paper DeepDTA in pytorch 2 | # Author: @ksun63 3 | # Date: 2023-04-14 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | class Conv1d(nn.Module): 9 | """ 10 | Three 1d convolutional layer with relu activation stacked on top of each other 11 | with a final global maxpooling layer 12 | """ 13 | def __init__(self, vocab_size, channel, kernel_size, stride=1, padding=0): 14 | super(Conv1d, self).__init__() 15 | self.embedding = nn.Embedding(vocab_size, embedding_dim=128) 16 | self.conv1 = nn.Conv1d(128, channel, kernel_size, stride, padding) 17 | self.conv2 = nn.Conv1d(channel, channel*2, kernel_size, stride, padding) 18 | self.conv3 = nn.Conv1d(channel*2, channel*3, kernel_size, stride, padding) 19 | self.relu = 
nn.ReLU() 20 | self.globalmaxpool = nn.AdaptiveMaxPool1d(1) 21 | 22 | def forward(self, x): 23 | x = self.embedding(x) 24 | x = x.permute(0, 2, 1) 25 | x = self.conv1(x) 26 | x = self.conv2(x) 27 | x = self.conv3(x) 28 | x = self.relu(x) 29 | x = self.globalmaxpool(x) 30 | x = x.squeeze(-1) 31 | return x 32 | 33 | class DeepDTA(nn.Module): 34 | """DeepDTA model architecture, Y-shaped net that does 1d convolution on 35 | both the ligand and the protein representation and then concatenates the 36 | result into a final predictor of binding affinity""" 37 | 38 | def __init__(self, pro_vocab_size, lig_vocab_size, channel, protein_kernel_size, ligand_kernel_size): 39 | super(DeepDTA, self).__init__() 40 | self.ligand_conv = Conv1d(lig_vocab_size, channel, ligand_kernel_size) 41 | self.protein_conv = Conv1d(pro_vocab_size, channel, protein_kernel_size) 42 | self.fc1 = nn.Linear(channel*6, 1024) 43 | self.fc2 = nn.Linear(1024, 1024) 44 | self.fc3 = nn.Linear(1024, 512) 45 | self.fc4 = nn.Linear(512, 1) 46 | self.dropout = nn.Dropout(0.1) 47 | self.relu = nn.ReLU() 48 | 49 | def forward(self, protein, ligand): 50 | x1 = self.ligand_conv(ligand) 51 | x2 = self.protein_conv(protein) 52 | x = torch.cat((x1, x2), dim=1) 53 | x = self.fc1(x) 54 | x = self.relu(x) 55 | x = self.dropout(x) 56 | x = self.fc2(x) 57 | x = self.relu(x) 58 | x = self.dropout(x) 59 | x = self.fc3(x) 60 | x = self.relu(x) 61 | x = self.dropout(x) 62 | x = self.fc4(x) 63 | return x.squeeze() 64 | -------------------------------------------------------------------------------- /model_result.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from scipy.stats import pearsonr, spearmanr\n", 11 | "import pandas as pd\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import glob" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "fp = 'path_to_your_data.csv'" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "df = pd.read_csv(fp)\n", 32 | "test_set_values = df[df['split'] == 'test']['affinity'].values" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "model_results = glob.glob('test-result*.txt')\n", 42 | "fig, ax = plt.subplots(len(model_results), 1, figsize=(5, 20))\n", 43 | "i = 0\n", 44 | "\n", 45 | "for model_result in model_results:\n", 46 | " model_param = model_result.split('-')[2], model_result.split('-')[3].split('.')[0]\n", 47 | " print(model_param)\n", 48 | " result = np.loadtxt(model_result)\n", 49 | " print(len(result), test_set_values.shape)\n", 50 | " ax[i].scatter(test_set_values, result)\n", 51 | " ax[i].set_title('Protein kernel: {}, Ligand kernel: {}'.format(model_param[0], model_param[1]))\n", 52 | " ax[i].set_xlabel('True affinity')\n", 53 | " ax[i].set_ylabel('Predicted affinity')\n", 54 | " ax[i].plot([0,12], [0,12], 'k--', lw=4)\n", 55 | " i += 1\n", 56 | " # calculate the correlation coefficient\n", 57 | " print(\"Pearson correlation coefficient: {}\".format(pearsonr(test_set_values, result)[0]))\n", 58 | " print(\"Spearman correlation coefficient: {}\".format(spearmanr(test_set_values, result)[0]))\n", 59 | " print(\"-\" * 20)\n", 60 | 
"\n", 61 | "plt.show()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "#load a held-out test set\n", 71 | "\n", 72 | "from rdkit import Chem\n", 73 | "\n", 74 | "other_fp = 'file_to_other_data.csv'\n", 75 | "df2 = pd.read_csv(other_fp).dropna()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "def convert_to_isomeric(smiles):\n", 85 | " \"\"\"\n", 86 | " convert a smile string to an isomeric smile string\n", 87 | " \"\"\"\n", 88 | " m = Chem.MolFromSmiles(smiles)\n", 89 | " return Chem.MolToSmiles(m, isomericSmiles=True)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# convert the smiles to isomeric smiles\n", 99 | "df2['ligands'] = df2['smiles'].apply(convert_to_isomeric)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "# load the best DeepDTA model to predict for these test sets\n", 109 | "from model import DeepDTA\n", 110 | "import torch, json\n", 111 | "from torchsummary import summary\n", 112 | "\n", 113 | "\n", 114 | "# convert the smiles to one-hot encoding; CHANGE TO YOUR OWN PATH OF YOUR BEST MODEL\n", 115 | "ligand_dict = json.load(open('ligand_dict.json'))\n", 116 | "protein_dict = json.load(open('protein_dict.json'))\n", 117 | "smilelen, seqlen = 200, 2000\n", 118 | "\n", 119 | "# load model\n", 120 | "model = DeepDTA(len(protein_dict)+1, len(ligand_dict)+1, 32, 8, 8)\n", 121 | "model.load_state_dict(torch.load('deepdta_retrain.pt'))\n", 122 | "model.eval()\n", 123 | "\n", 124 | "df2_result = []\n", 125 | "for i in range(len(df2)):\n", 126 | " ligand = df2.iloc[i]['ligands']\n", 127 | " protein = df2.iloc[i]['proteins']\n", 128 | " protein = [protein_dict[x] for x in protein] + [protein_dict['dummy']] * (seqlen - len(protein))\n", 129 | " ligand = [ligand_dict[x] for x in ligand] + [ligand_dict['dummy']] * (smilelen - len(ligand))\n", 130 | " ligand = torch.tensor(ligand).unsqueeze(0)\n", 131 | " protein = torch.tensor(protein).unsqueeze(0)\n", 132 | " with torch.no_grad():\n", 133 | " result = model(protein, ligand)\n", 134 | " df2_result.append(result.item())\n", 135 | "\n", 136 | "df2_result = np.array(df2_result)\n", 137 | "ground_truth = df2['affinity'].values" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "plt.scatter(ground_truth, df2_result)\n", 147 | "plt.plot([4,8], [4,8], 'k--', lw=4)\n", 148 | "print(\"Pearson correlation coefficient: {}\".format(pearsonr(ground_truth, df2_result)[0]))\n", 149 | "print(\"Spearman correlation coefficient: {}\".format(spearmanr(ground_truth, df2_result)[0]))" 150 | ] 151 | } 152 | ], 153 | "metadata": { 154 | "kernelspec": { 155 | "display_name": "ml_dev", 156 | "language": "python", 157 | "name": "ml_dev" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 3 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython3", 169 | "version": "3.8.16" 170 | }, 171 | "orig_nbformat": 4 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.8.16 2 | numpy==1.24.1 3 | pandas==1.5.2 4 | matplotlib==3.5.3 5 | scipy==1.8.1 6 | torch==2.1.0 7 | pytorch-cuda==11.7 8 | tqdm==4.65.0 9 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # data processing and training of the DeepDTA paper in pytorch code with your own data 2 | # Author: @ksun63 3 | # Date: 2023-04-14 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import json 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset, Subset, DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from tqdm import tqdm 13 | from copy import deepcopy 14 | import logging 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | class Dataset(Dataset): 19 | """ 20 | Here, the input dataset should be a pandas dataframe with the following columns: 21 | protein, ligands, affinity, where proteins are the protein seqeunces, ligands are the 22 | isomeric SMILES representation of the ligand, and affinity is the binding affinity 23 | """ 24 | def __init__(self, df, seqlen=2000, smilen=200): 25 | """ 26 | df: pandas dataframe with the columns proteins, ligands, affinity 27 | seqlen: max length of the protein sequence 28 | smilen: max length of the ligand SMILES representation 29 | """ 30 | self.proteins = df['proteins'].values 31 | self.ligands = df['ligands'].values 32 | self.affinity = df['affinity'].values 33 | self.smilelen = smilen 34 | self.seqlen = seqlen 35 | self.protein_vocab = set() 36 | self.ligand_vocab = set() 37 | for lig in self.ligands: 38 | for i in lig: 39 | self.ligand_vocab.update(i) 40 | for pr in self.proteins: 41 | for i in pr: 42 | self.protein_vocab.update(i) 43 | 44 | # having a dummy token to pad the sequences to the max length 45 | self.protein_vocab.update(['dummy']) 46 | self.ligand_vocab.update(['dummy']) 47 | self.protein_dict = {x: i for i, x in enumerate(self.protein_vocab)} 48 | self.ligand_dict = {x: i for i, x in enumerate(self.ligand_vocab)} 49 | 50 | 51 | def __len__(self): 52 | """ 53 | Returns the length of the dataset 54 | """ 55 | return len(self.proteins) 56 | 57 | def __getitem__(self, idx): 58 | """ 59 | Get the protein, ligand, and affinity of the idx-th sample 60 | 61 | param idx: index of the sample 62 | """ 63 | pr = self.proteins[idx] 64 | lig = self.ligands[idx] 65 | target = self.affinity[idx] 66 | protein = [self.protein_dict[x] for x in pr] + [self.protein_dict['dummy']] * (self.seqlen - len(pr)) 67 | ligand = [self.ligand_dict[x] for x in lig] + [self.ligand_dict['dummy']] * (self.smilelen - len(lig)) 68 | 69 | return torch.tensor(protein), torch.tensor(ligand), torch.tensor(target, dtype=torch.float) 70 | 71 | def collate_fn(batch): 72 | """ 73 | Collate function for the DataLoader 74 | """ 75 | proteins, ligands, targets = zip(*batch) 76 | proteins = torch.stack(proteins, dim=0) 77 | ligands = torch.stack(ligands, dim=0) 78 | targets = torch.stack(targets, dim=0) 79 | return proteins, ligands, targets 80 | 81 | 82 | class Trainer: 83 | """ 84 | Trainer class of the DeepDTA model 85 | """ 86 | def __init__(self, model, channel, protein_kernel, ligand_kernel, df, train_idx, val_idx, test_idx, 87 | log_file, smilen=200, seqlen=2000): 88 | """ 89 | model: DeepDTA model defined in 
model.py 90 | df: pandas dataframe with the columns protein, ligands, affinity 91 | train_idx: indices of the training set 92 | val_idx: indices of the validation set 93 | smilen: max length of the ligand SMILES representation 94 | seqlen: max length of the protein sequence 95 | log_file: file to save the training logs 96 | """ 97 | self.dataset = Dataset(df, smilen=smilen, seqlen=seqlen) 98 | self.protein_vocab = len(self.dataset.protein_vocab) + 1 99 | self.ligand_vocab = len(self.dataset.ligand_vocab) + 1 100 | self.train_dataset = Subset(self.dataset, train_idx) 101 | self.val_dataset = Subset(self.dataset, val_idx) 102 | self.test_dataset = Subset(self.dataset, test_idx) 103 | self.protein_kernel = protein_kernel 104 | self.ligand_kernel = ligand_kernel 105 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 106 | self.model = model(self.protein_vocab, self.ligand_vocab, channel, protein_kernel, ligand_kernel).to(self.device) 107 | self.log_file = log_file 108 | 109 | self.logger = logging.getLogger(__name__) 110 | self.logger.setLevel(logging.DEBUG) 111 | file_handler = logging.FileHandler(self.log_file) 112 | file_handler.setLevel(logging.DEBUG) 113 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 114 | file_handler.setFormatter(formatter) 115 | self.logger.addHandler(file_handler) 116 | 117 | def train(self, lr, num_epochs, batch_size, save_path): 118 | """ 119 | Train the model 120 | """ 121 | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5) 122 | criterion = nn.MSELoss() 123 | 124 | writer = SummaryWriter() 125 | 126 | train_loader = DataLoader(self.train_dataset, batch_size=batch_size, drop_last = False, shuffle=True, collate_fn=collate_fn) 127 | val_loader = DataLoader(self.val_dataset, batch_size=batch_size, drop_last = False, collate_fn=collate_fn) 128 | test_loader = DataLoader(self.test_dataset, batch_size=batch_size, drop_last = False, collate_fn=collate_fn) 129 | 130 | # save the encoding dictionaries into json files 131 | with open('protein_dict-prk{}-ldk{}.json'.format(self.protein_kernel, self.ligand_kernel), 'w') as f: 132 | json.dump(self.dataset.protein_dict, f) 133 | with open('ligand_dict-prk{}-ldk{}.json'.format(self.protein_kernel, self.ligand_kernel), 'w') as f: 134 | json.dump(self.dataset.ligand_dict, f) 135 | 136 | 137 | best_weights = self.model.state_dict() 138 | best_val_loss = np.inf 139 | best_epoch = 0 140 | 141 | for epoch in range(num_epochs): 142 | self.model.train() 143 | train_loss = 0.0 144 | 145 | with tqdm(total=len(train_loader)) as pbar: 146 | for protein, ligand, target in train_loader: 147 | protein, ligand, target = protein.to(self.device), ligand.to(self.device), target.to(self.device) 148 | 149 | optimizer.zero_grad() 150 | output = self.model(protein, ligand) 151 | loss = criterion(output, target) 152 | loss.backward() 153 | optimizer.step() 154 | 155 | train_loss += loss.item() 156 | 157 | pbar.update(1) 158 | 159 | train_loss /= len(train_loader) 160 | self.logger.info('Epoch: {} - Training Loss: {:.6f}'.format(epoch+1, train_loss)) 161 | writer.add_scalar('train_loss', train_loss, epoch) 162 | 163 | # switch to evaluation mode 164 | self.model.eval() 165 | val_loss = 0.0 166 | 167 | with torch.no_grad(): 168 | for protein, ligand, target in val_loader: 169 | protein, ligand, target = protein.to(self.device), ligand.to(self.device), target.to(self.device) 170 | 171 | output = self.model(protein, ligand) 172 | loss = criterion(output, target) 173 | 
val_loss += loss.item() 174 | 175 | val_loss /= len(val_loader) 176 | if val_loss < best_val_loss: 177 | best_val_loss = val_loss 178 | best_weights = deepcopy(self.model.state_dict()) 179 | best_epoch = epoch 180 | self.logger.info('Best Model So Far in Epoch: {}'.format(epoch+1)) 181 | self.logger.info('Epoch: {} - Validation Loss: {:.6f}'.format(epoch+1, val_loss)) 182 | writer.add_scalar('val_loss', val_loss, epoch) 183 | 184 | self.model.load_state_dict(best_weights) 185 | test_result = [] 186 | with torch.no_grad(): 187 | for protein, ligand, target in test_loader: 188 | protein, ligand, target = protein.to(self.device), ligand.to(self.device), target.to(self.device) 189 | 190 | output = self.model(protein, ligand) 191 | test_result.append(output.cpu().numpy()) 192 | test_result = np.concatenate(test_result) 193 | np.savetxt('test-result-prk{}-ldk{}.txt'.format(self.protein_kernel, self.ligand_kernel), test_result) 194 | 195 | self.logger.info('Best Model Loaded from Epoch: {}'.format(best_epoch+1)) 196 | torch.save(self.model.state_dict(), save_path) 197 | self.logger.handlers[0].close() 198 | self.logger.removeHandler(self.logger.handlers[0]) 199 | writer.close() 200 | 201 | --------------------------------------------------------------------------------
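For a quick sanity check of the architecture in `model.py`, the sketch below (not part of the repository) pushes random token batches through `DeepDTA`. The vocabulary sizes follow the `len(vocab) + 1` convention used in `trainer.py` and `model_result.ipynb`, with the example dictionaries containing 23 protein tokens and 34 ligand tokens (including 'dummy'); the batch size and kernel sizes here are arbitrary.

```python
import torch
from model import DeepDTA

# vocabulary sizes follow trainer.py's len(vocab) + 1 convention
model = DeepDTA(pro_vocab_size=24, lig_vocab_size=35, channel=32,
                protein_kernel_size=12, ligand_kernel_size=8)
model.eval()

protein = torch.randint(0, 24, (2, 2000))  # batch of 2 padded protein sequences
ligand = torch.randint(0, 35, (2, 200))    # batch of 2 padded SMILES strings

with torch.no_grad():
    pred = model(protein, ligand)

print(pred.shape)  # torch.Size([2]): one predicted affinity per protein-ligand pair
```

Each convolutional branch ends with a global max pool over `channel*3` feature maps, so the concatenated input to `fc1` is always `channel*6` features, independent of the kernel sizes and sequence lengths.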
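A caveat when scoring held-out data, as in the last cells of `model_result.ipynb`: characters that never appeared in the training set are missing from the saved `ligand_dict`/`protein_dict`, so the plain dictionary lookups raise a `KeyError`, and sequences longer than `seqlen`/`smilelen` are never truncated. The helper below is a minimal sketch of a more defensive encoder; mapping unknown characters to the 'dummy' padding token and truncating to the maximum length are assumptions layered on top of the original lookup-and-pad code, not its behaviour.

```python
import torch

def encode(seq, vocab, max_len):
    """Turn a protein sequence or SMILES string into a padded index tensor.

    Unknown characters fall back to the 'dummy' token and the input is
    truncated to max_len; both behaviours are assumptions added here.
    """
    dummy = vocab["dummy"]
    idx = [vocab.get(ch, dummy) for ch in seq[:max_len]]
    idx += [dummy] * (max_len - len(idx))
    return torch.tensor(idx).unsqueeze(0)  # shape (1, max_len)
```

For example, `protein = encode(df2.iloc[i]['proteins'], protein_dict, 2000)` and `ligand = encode(df2.iloc[i]['ligands'], ligand_dict, 200)` could replace the two list comprehensions in the notebook.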