├── data.rar
├── model.png
├── data
│   ├── kiba_tokenizer.pkl
│   ├── bindingdb_tokenizer.pkl
│   └── Data_files.txt
├── DEMO
│   ├── required_files_for_demo
│   │   ├── __pycache__
│   │   │   ├── model_aff.cpython-38.pyc
│   │   │   ├── model_gen.cpython-38.pyc
│   │   │   └── demo_utils.cpython-38.pyc
│   │   ├── demo_utils.py
│   │   ├── model_aff.py
│   │   └── model_gen.py
│   ├── DEMO_Affinity.py
│   ├── DEMO_Generation.py
│   └── affinity_generation.py
├── Data_files.txt
├── environment.yml
├── generation_eveluation.py
├── generate.py
├── test.py
├── create_data.py
├── FetterGrad.py
├── training.py
├── README.md
├── utils.py
└── model.py

/data.rar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/data.rar
--------------------------------------------------------------------------------
/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/model.png
--------------------------------------------------------------------------------
/data/kiba_tokenizer.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/data/kiba_tokenizer.pkl
--------------------------------------------------------------------------------
/data/bindingdb_tokenizer.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/data/bindingdb_tokenizer.pkl
--------------------------------------------------------------------------------
/DEMO/required_files_for_demo/__pycache__/model_aff.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/DEMO/required_files_for_demo/__pycache__/model_aff.cpython-38.pyc
--------------------------------------------------------------------------------
/DEMO/required_files_for_demo/__pycache__/model_gen.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/DEMO/required_files_for_demo/__pycache__/model_gen.cpython-38.pyc
--------------------------------------------------------------------------------
/DEMO/required_files_for_demo/__pycache__/demo_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSUBioGroup/DeepDTAGen/HEAD/DEMO/required_files_for_demo/__pycache__/demo_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Data_files.txt:
--------------------------------------------------------------------------------
1 | If you'd like to retrain our model, please download the data file from the following link:
2 | **Data File:** [Download Here](https://drive.google.com/file/d/1vZlbluFt9lmYVUADBZaP-N4mnuY-ETeJ/view?usp=drive_link)
3 | 
4 | Alternatively, if you want to test our pretrained models, you can download them from the link below:
5 | **Pretrained Models:** [Download Here](https://drive.google.com/file/d/1p7cnV2CVgI2gD34EsnzwdfDlTEOPKd7y/view?usp=drive_link)
6 | 
--------------------------------------------------------------------------------
/data/Data_files.txt:
--------------------------------------------------------------------------------
1 | If you'd like to retrain our model, please download the data file from the following link:
2 | **Data
File:** [Download Here](https://drive.google.com/file/d/1vZlbluFt9lmYVUADBZaP-N4mnuY-ETeJ/view?usp=drive_link) 3 | 4 | Alternatively, if you want to test our pretrained models, you can download them from the link below: 5 | **Pretrained Models:** [Download Here](https://drive.google.com/file/d/1p7cnV2CVgI2gD34EsnzwdfDlTEOPKd7y/view?usp=drive_link) 6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DeepDTAGen 2 | 3 | channels: 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | 8 | dependencies: 9 | - numpy=1.23.5=py38h7042d01_0 10 | - pandas=1.5.2=py38h8f669ce_0 11 | - pip=22.3.1=pyhd8ed1ab_0 12 | - python=3.8.15=h4a9ceb5_0_cpython 13 | - rdkit=2022.09.1=py38h5acb366_1 14 | - tqdm=4.64.1=pyhd8ed1ab_0 15 | - requests 16 | - scikit-learn 17 | - scipy 18 | - torchaudio 19 | - torchvision 20 | - pip: 21 | - --extra-index-url https://download.pytorch.org/whl/cu102 22 | - einops==0.6.0 23 | - fairseq==0.10.2 24 | - torch==1.12.1+cu102 25 | -------------------------------------------------------------------------------- /generation_eveluation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from rdkit import Chem 3 | import argparse 4 | import os 5 | 6 | 7 | def is_valid_smiles(smiles: str) -> bool: 8 | """Check if a SMILES string is chemically valid.""" 9 | try: 10 | return Chem.MolFromSmiles(smiles) is not None 11 | except: 12 | return False 13 | 14 | 15 | def evaluate_smiles(smiles_list, reference_set=None): 16 | """Evaluate validity, uniqueness, and novelty of SMILES.""" 17 | valid = [s for s in smiles_list if is_valid_smiles(s)] 18 | unique = set(valid) 19 | novel = [s for s in unique if s not in reference_set] if reference_set else list(unique) 20 | 21 | return { 22 | "validity_ratio": len(valid) / len(smiles_list) if smiles_list else 0, 23 | "uniqueness_ratio": len(unique) / len(valid) if valid else 0, 24 | "novelty_ratio": len(novel) / len(unique) if unique else 0 25 | } 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--dataset', type=str, default='bindingdb', help='Dataset prefix (default: "bindingdb")') 31 | args = parser.parse_args() 32 | 33 | file_path = f"{args.dataset}_generated.csv" 34 | if not os.path.exists(file_path): 35 | print(f"Error: File '{file_path}' not found.") 36 | return 37 | 38 | df = pd.read_csv(file_path) 39 | if 'Generated_SMILES' not in df.columns: 40 | print("Error: Column 'Generated_SMILES' not found in the dataset.") 41 | return 42 | 43 | generated = df['Generated_SMILES'].dropna().tolist() 44 | reference_set = set(df['target_smiles'].dropna()) if 'target_smiles' in df.columns else None 45 | 46 | results = evaluate_smiles(generated, reference_set) 47 | 48 | print(f"Validity Ratio : {results['validity_ratio']:.2f}") 49 | print(f"Uniqueness Ratio : {results['uniqueness_ratio']:.2f}") 50 | print(f"Novelty Ratio : {results['novelty_ratio']:.2f}") 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /DEMO/DEMO_Affinity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pickle 4 | from tqdm import tqdm 5 | import pandas as pd 6 | from torch.utils.data import DataLoader 7 | from required_files_for_demo.demo_utils import * 8 | 9 | from 
required_files_for_demo.model_aff import DeepDTAGen 10 | 11 | def demo(): 12 | dataset_name = 'bindingdb' 13 | 14 | # Setup device 15 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | device = torch.device('cpu') 17 | 18 | # Paths 19 | model_path = f'./models/deepdtagen_model_{dataset_name}.pth' 20 | tokenizer_path = f'./data/{dataset_name}_tokenizer.pkl' 21 | 22 | smiles = "O=C(c1nc(NS(=O)(=O)c2cc(Br)cc(Cl)c2O)cn1C1CCCC1)N1CCC(C2CCCN2)CC1" 23 | protein_sequence = "MATEEKKPETEAARAQPTPSSSATQSKPTPVKPNYALKFTLAGHTKAVSSVKFSPNGEWLASSSADKLIKIWGAYDGKFEKTISGHKLGISDVAWSSDSNLLVSASDDKTLKIWDVSSGKCLKTLKGHSNYVFCCNFNPQSNLIVSGSFDESVRIWDVKTGKCLKTLPAHSDPVSAVHFNRDGSLIVSSSYDGLCRIWDTASGQCLKTLIDDDNPPVSFVKFSPNGKYILAATLDNTLKLWDYSKGKCLKTYTGHKNEKYCIFANFSVTGGKWIVSGSEDNLVYIWNLQTKEIVQKLQGHTDVVISTACHPTENIIASAALENDKTIKLWKSDC" 24 | 25 | # Load tokenizer 26 | with open(tokenizer_path, 'rb') as f: 27 | tokenizer = pickle.load(f) 28 | 29 | # Load model 30 | model = DeepDTAGen(tokenizer) 31 | model.load_state_dict(torch.load(model_path, map_location=device)) 32 | model.to(device) 33 | 34 | # Load test data 35 | processed_data = f'./data/processed/{smiles}.pt' 36 | if not os.path.isfile(processed_data): 37 | test_data = process_latent_a(smiles, protein_sequence) 38 | else: 39 | test_data = torch.load(processed_data) 40 | print(test_data) 41 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, collate_fn=collate) 42 | 43 | # Evaluate the model 44 | model.eval() 45 | with torch.no_grad(): 46 | for data in tqdm(test_loader, desc='Testing'): 47 | predictions = model(data.to(device)) 48 | print("Predicted Affinity :", predictions) 49 | 50 | if __name__ == "__main__": 51 | demo() -------------------------------------------------------------------------------- /DEMO/DEMO_Generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pickle 4 | from tqdm import tqdm 5 | import pandas as pd 6 | from torch.utils.data import DataLoader 7 | from required_files_for_demo.demo_utils import * 8 | 9 | from required_files_for_demo.model_gen import DeepDTAGen 10 | 11 | def demo(): 12 | dataset_name = 'bindingdb' 13 | 14 | # Setup device 15 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | device = torch.device('cpu') 17 | 18 | # Paths 19 | model_path = f'./models/deepdtagen_model_{dataset_name}.pth' 20 | tokenizer_path = f'./data/{dataset_name}_tokenizer.pkl' 21 | 22 | smiles = "O=C(c1nc(NS(=O)(=O)c2cc(Br)cc(Cl)c2O)cn1C1CCCC1)N1CCC(C2CCCN2)CC1" 23 | protein_sequence = "MATEEKKPETEAARAQPTPSSSATQSKPTPVKPNYALKFTLAGHTKAVSSVKFSPNGEWLASSSADKLIKIWGAYDGKFEKTISGHKLGISDVAWSSDSNLLVSASDDKTLKIWDVSSGKCLKTLKGHSNYVFCCNFNPQSNLIVSGSFDESVRIWDVKTGKCLKTLPAHSDPVSAVHFNRDGSLIVSSSYDGLCRIWDTASGQCLKTLIDDDNPPVSFVKFSPNGKYILAATLDNTLKLWDYSKGKCLKTYTGHKNEKYCIFANFSVTGGKWIVSGSEDNLVYIWNLQTKEIVQKLQGHTDVVISTACHPTENIIASAALENDKTIKLWKSDC" 24 | conditional_affinity = 6.0 25 | # Load tokenizer 26 | with open(tokenizer_path, 'rb') as f: 27 | tokenizer = pickle.load(f) 28 | 29 | # Load model 30 | model = DeepDTAGen(tokenizer) 31 | model.load_state_dict(torch.load(model_path, map_location=device)) 32 | model.to(device) 33 | 34 | # Load test data 35 | processed_data = f'./data/processed/{smiles}.pt' 36 | if not os.path.isfile(processed_data): 37 | test_data = process_latent(smiles, protein_sequence, conditional_affinity) 38 | else: 39 | test_data = torch.load(processed_data) 40 | print(test_data) 41 | test_loader = 
torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, collate_fn=collate) 42 | 43 | # Evaluate the model 44 | model.eval() 45 | with torch.no_grad(): 46 | for data in tqdm(test_loader, desc='Testing'): 47 | res = tokenizer.get_text(model.generate(data.to(device))) 48 | print("Generated Drug :", res) 49 | 50 | if __name__ == "__main__": 51 | demo() 52 | -------------------------------------------------------------------------------- /DEMO/affinity_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pickle 4 | from tqdm import tqdm 5 | import pandas as pd 6 | from torch.utils.data import DataLoader 7 | from required_files_for_demo.demo_utils import * 8 | 9 | from required_files_for_demo.model_aff import DeepDTAGen 10 | 11 | def demo(): 12 | dataset_name = 'bindingdb' 13 | 14 | # Setup device 15 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | device = torch.device('cpu') 17 | 18 | # Paths 19 | model_path = f'./models/deepdtagen_model_{dataset_name}.pth' 20 | tokenizer_path = f'./data/{dataset_name}_tokenizer.pkl' 21 | 22 | smiles = "O=C(c1nc(NS(=O)(=O)c2cc(Br)cc(Cl)c2O)cn1C1CCCC1)N1CCC(C2CCCN2)CC1" 23 | protein_sequence = "MATEEKKPETEAARAQPTPSSSATQSKPTPVKPNYALKFTLAGHTKAVSSVKFSPNGEWLASSSADKLIKIWGAYDGKFEKTISGHKLGISDVAWSSDSNLLVSASDDKTLKIWDVSSGKCLKTLKGHSNYVFCCNFNPQSNLIVSGSFDESVRIWDVKTGKCLKTLPAHSDPVSAVHFNRDGSLIVSSSYDGLCRIWDTASGQCLKTLIDDDNPPVSFVKFSPNGKYILAATLDNTLKLWDYSKGKCLKTYTGHKNEKYCIFANFSVTGGKWIVSGSEDNLVYIWNLQTKEIVQKLQGHTDVVISTACHPTENIIASAALENDKTIKLWKSDC" 24 | conditional_affinity = 6.0 25 | # Load tokenizer 26 | with open(tokenizer_path, 'rb') as f: 27 | tokenizer = pickle.load(f) 28 | 29 | # Load model 30 | model = DeepDTAGen(tokenizer) 31 | model.load_state_dict(torch.load(model_path, map_location=device)) 32 | model.to(device) 33 | 34 | # Load test data 35 | processed_data = f'./data/processed/{smiles}.pt' 36 | if not os.path.isfile(processed_data): 37 | test_data = process_latent(smiles, protein_sequence, conditional_affinity) 38 | else: 39 | test_data = torch.load(processed_data) 40 | # print(test_data) 41 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, collate_fn=collate) 42 | 43 | # Evaluate the model 44 | model.eval() 45 | with torch.no_grad(): 46 | for data in tqdm(test_loader, desc='Testing'): 47 | predictions = model(data.to(device)) 48 | res = tokenizer.get_text(model.generate(data.to(device))) 49 | print("Generated Drug :", res) 50 | print("Predicted Affinity:", predictions.item()) 51 | 52 | if __name__ == "__main__": 53 | demo() 54 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | import pandas as pd 5 | from rdkit import RDLogger, Chem 6 | from torch.utils.data import DataLoader 7 | from tqdm.auto import tqdm 8 | from model import DeepDTAGen 9 | from utils import * 10 | import torch 11 | 12 | RDLogger.DisableLog('rdApp.*') 13 | 14 | 15 | def load_model(model_path, tokenizer_path): 16 | with open(tokenizer_path, 'rb') as f: 17 | tokenizer = pickle.load(f) 18 | 19 | model = DeepDTAGen(tokenizer) 20 | states = torch.load(model_path, map_location='cpu') 21 | print(model.load_state_dict(states, strict=False)) 22 | 23 | return model, tokenizer 24 | 25 | 26 | def format_smiles(smiles): 27 | mol = 
Chem.MolFromSmiles(smiles) 28 | if mol is None: 29 | return None 30 | return Chem.MolToSmiles(mol, isomericSmiles=True) 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('dataset', type=str, choices=['kiba', 'bindingdb'], help='the dataset name (kiba or bindingdb)') 36 | parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'], help='device to use (cpu or cuda)') 37 | args = parser.parse_args() 38 | 39 | dataset = args.dataset 40 | device = args.device 41 | 42 | config = { 43 | 'input_path': f'data/{dataset}_test.csv', 44 | 'model_path': f'models/deepdtagen_model_{dataset}.pth', 45 | 'tokenizer_path': f'data/{dataset}_tokenizer.pkl', 46 | 'n_mol': 40000, 47 | 'filter': True, 48 | 'batch_size': 1, 49 | 'seed': -1 50 | } 51 | 52 | # Load the input CSV 53 | input_df = pd.read_csv(config['input_path']) 54 | input_df['Generated_SMILES'] = None 55 | 56 | # Load dataset and model 57 | test_data = TestbedDataset(root="data", dataset=f"{dataset}_test") 58 | test_loader = DataLoader(test_data, batch_size=1, shuffle=False) 59 | model, tokenizer = load_model(config['model_path'], config['tokenizer_path']) 60 | 61 | model.eval() 62 | model.to(device) 63 | 64 | # Generate SMILES 65 | for i, data in enumerate(tqdm(test_loader)): 66 | data.to(device) 67 | generated = tokenizer.get_text(model.generate(data)) 68 | generated = generated[:config['n_mol']] 69 | 70 | if config['filter']: 71 | generated = [format_smiles(smi) for smi in generated] 72 | generated = [smi for smi in generated if smi] 73 | 74 | input_df.loc[i, 'Generated_SMILES'] = generated[0] if generated else None 75 | 76 | output_path = Path(f"{dataset}_generated.csv") 77 | input_df.to_csv(output_path, index=False) 78 | print(f'Generation complete. 
Output saved to {output_path}')
79 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import pickle
from tqdm import tqdm
from torch.utils.data import DataLoader
import argparse

from utils import *
from model import DeepDTAGen

def main(dataset_name):
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Paths
    model_path = f'models/deepdtagen_model_{dataset_name}.pth'
    tokenizer_path = f'data/{dataset_name}_tokenizer.pkl'
    test_batch_size = 128

    # Threshold values based on the dataset
    if dataset_name == 'kiba':
        thresholds = [10.0, 10.50, 11.0, 11.50, 12.0, 12.50]
        aupr_threshold = 12.1
    elif dataset_name in ['davis', 'bindingdb']:
        thresholds = [5.0, 5.50, 6.0, 6.50, 7.0, 7.50, 8.0, 8.50]
        aupr_threshold = 7.0
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Load tokenizer
    with open(tokenizer_path, 'rb') as f:
        tokenizer = pickle.load(f)

    # Load model
    model = DeepDTAGen(tokenizer)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    # Load test data
    test_data = TestbedDataset(root='data', dataset=f'{dataset_name}_test')
    test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False)

    # Evaluate the model
    model.eval()
    total_predict = torch.Tensor().to(device)
    total_true = torch.Tensor().to(device)

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            predictions, _, lm_loss, kl_loss = model(data.to(device))
            total_true = torch.cat((total_true, data.y.view(-1, 1)), dim=0)
            total_predict = torch.cat((total_predict, predictions), dim=0)

    # Convert to numpy arrays
    ground_truth = total_true.cpu().numpy().flatten()
    predicted = total_predict.cpu().numpy().flatten()

    # Calculate metrics
    mse_loss = mse(ground_truth, predicted)
    concordance_index = get_cindex(ground_truth, predicted)
    rm2_value = get_rm2(ground_truth, predicted)
    rms_error = rmse(ground_truth, predicted)
    pearson_corr = pearson(ground_truth, predicted)
    spearman_corr = spearman(ground_truth, predicted)
    aupr_value = get_aupr(predicted, ground_truth, aupr_threshold)

    # Calculate AUC values for each threshold over the full test set
    auc_values = [
        get_auc((ground_truth > threshold).astype(int), predicted)
        for threshold in thresholds
    ]

    # Print the results
    print(f'MSE: {mse_loss:.4f}, CI: {concordance_index:.4f}, RM2: {rm2_value:.4f}')
    print(f'RMS Error: {rms_error:.4f}')
    print(f'PCC: {pearson_corr:.4f}, Spearman: {spearman_corr:.4f}')
    print(f'AUPR (threshold {aupr_threshold}): {aupr_value:.4f}')
    print(f'AUC Values: {auc_values}')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate DeepDTAGen on a dataset.')
    parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset (e.g., kiba, davis, bindingdb)')
    args = parser.parse_args()

    main(args.dataset)
--------------------------------------------------------------------------------
/create_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | import json, pickle
5 | 
from collections import OrderedDict 6 | from rdkit import Chem 7 | from rdkit.Chem import MolFromSmiles 8 | import networkx as nx 9 | from utils import * 10 | import re 11 | from typing import List 12 | 13 | def one_of_k_encoding(x, allowable_set): 14 | if x not in allowable_set: 15 | x = allowable_set[-1] 16 | return [x == s for s in allowable_set] 17 | 18 | def one_of_k_encoding_unk(x, allowable_set): 19 | if x not in allowable_set: 20 | x = allowable_set[-1] 21 | return [x == s for s in allowable_set] + [x not in allowable_set] 22 | 23 | def atom_features(atom): 24 | return np.array(one_of_k_encoding_unk(atom.GetSymbol(),['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb','Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H','Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr','Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) + #Atom symbol 25 | one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + #Number of adjacent atoms 26 | one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + # Number of adjacent hydrogens 27 | one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + #Implicit valence 28 | one_of_k_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) + #Formal charge 29 | one_of_k_encoding_unk(atom.GetHybridization(), [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]) + #Hybridization 30 | [atom.GetIsAromatic()] + #Aromaticity 31 | [atom.IsInRing()] #In ring 32 | ) 33 | 34 | def bond_features(bond): 35 | bt = bond.GetBondType() 36 | bond_feats = [0, 0, 0, 0, bond.GetBondTypeAsDouble()] 37 | if bt == Chem.rdchem.BondType.SINGLE: 38 | bond_feats = [1, 0, 0, 0, bond.GetBondTypeAsDouble()] 39 | elif bt == Chem.rdchem.BondType.DOUBLE: 40 | bond_feats = [0, 1, 0, 0, bond.GetBondTypeAsDouble()] 41 | elif bt == Chem.rdchem.BondType.TRIPLE: 42 | bond_feats = [0, 0, 1, 0, bond.GetBondTypeAsDouble()] 43 | elif bt == Chem.rdchem.BondType.AROMATIC: 44 | bond_feats = [0, 0, 0, 1, bond.GetBondTypeAsDouble()] 45 | return np.array(bond_feats) 46 | 47 | def smile_to_graph(smile): 48 | mol = Chem.MolFromSmiles(smile) 49 | 50 | c_size = mol.GetNumAtoms() 51 | 52 | features = [] 53 | for atom in mol.GetAtoms(): 54 | feature = atom_features(atom) 55 | features.append(feature / sum(feature)) 56 | 57 | edges = [] 58 | for bond in mol.GetBonds(): 59 | edge_feats = bond_features(bond) 60 | edges.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), {'edge_feats': edge_feats})) 61 | 62 | g = nx.Graph() 63 | g.add_edges_from(edges) 64 | g = g.to_directed() 65 | edge_index = [] 66 | edge_feats = [] 67 | for e1, e2, feats in g.edges(data=True): 68 | edge_index.append([e1, e2]) 69 | edge_feats.append(feats['edge_feats']) 70 | 71 | return c_size, features, edge_index, edge_feats 72 | 73 | def smile_parse(smiles, tokenizer: Tokenizer): 74 | tokenizer = Tokenizer(Tokenizer.gen_vocabs(smiles)) 75 | smi = tokenizer.parse(smiles) 76 | return smi 77 | 78 | def seq_cat(prot): 79 | x = np.zeros(max_seq_len) 80 | for i, ch in enumerate(prot[:max_seq_len]): 81 | x[i] = seq_dict[ch] 82 | return x 83 | 84 | seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ" 85 | seq_dict = {v:(i+1) for i,v in enumerate(seq_voc)} 86 | seq_dict_len = len(seq_dict) 87 | max_seq_len = 1000 88 | 89 | compound_iso_smiles = [] 90 | for dt_name in ['kiba','davis', 'bindingdb']: 91 | opts = ['train','test'] 92 
| for opt in opts: 93 | df = pd.read_csv('data/' + dt_name + '_' + opt + '.csv') 94 | compound_iso_smiles += list( df['compound_iso_smiles'] ) 95 | compound_iso_smiles = set(compound_iso_smiles) 96 | smile_graph = {} 97 | for smile in compound_iso_smiles: 98 | g = smile_to_graph(smile) 99 | smile_graph[smile] = g 100 | dir = 'data' 101 | datasets = ['davis', 'kiba', 'bindingdb'] 102 | # convert to PyTorch data format 103 | for dataset in datasets: 104 | processed_data_file_train = 'data/processed/' + dataset + '_train.pt' 105 | processed_data_file_test = 'data/processed/' + dataset + '_test.pt' 106 | tokenizer_file = f'{dir}/{dataset}_tokenizer.pkl' 107 | if ((not os.path.isfile(processed_data_file_train)) or (not os.path.isfile(processed_data_file_test))): 108 | df_train = pd.read_csv('data/' + dataset + '_train.csv') 109 | df_test = pd.read_csv('data/' + dataset + '_test.csv') 110 | 111 | all_smiles = set(df_train['compound_iso_smiles']).union(set(df_test['compound_iso_smiles'])) 112 | tokenizer = Tokenizer(Tokenizer.gen_vocabs(all_smiles)) 113 | 114 | with open(tokenizer_file, 'wb') as file: 115 | pickle.dump(tokenizer, file) 116 | # Process train set 117 | train_drugs, train_MTS, train_prots, train_Y = list(df_train['compound_iso_smiles']), list(df_train['target_smiles']), list(df_train['target_sequence']), list(df_train['affinity']) 118 | XT = [seq_cat(t) for t in train_prots] 119 | train_drugs, train_MTS, train_prots, train_Y = np.asarray(train_drugs), np.asarray(train_MTS), np.asarray(XT), np.asarray(train_Y) 120 | train_XD = [torch.LongTensor(tokenizer.parse(smile)) for smile in train_MTS] 121 | 122 | # Process test set 123 | test_drugs, test_MTS, test_prots, test_Y = list(df_test['compound_iso_smiles']), list(df_test['target_smiles']), list(df_test['target_sequence']), list(df_test['affinity']) 124 | XT = [seq_cat(t) for t in test_prots] 125 | test_drugs, test_MTS, test_prots, test_Y = np.asarray(test_drugs), np.asarray(test_MTS), np.asarray(XT), np.asarray(test_Y) 126 | test_XD = [torch.LongTensor(tokenizer.parse(smile)) for smile in test_MTS] 127 | 128 | print('preparing ', dataset + '_train.pt in pytorch format!') 129 | train_data = TestbedDataset(root='data', dataset=dataset+'_train', xd=train_drugs, xdt=train_XD, xt=train_prots, y=train_Y,smile_graph=smile_graph) 130 | print('preparing ', dataset + '_test.pt in pytorch format!') 131 | test_data = TestbedDataset(root='data', dataset=dataset+'_test', xd=test_drugs, xdt=test_XD, xt=test_prots, y=test_Y,smile_graph=smile_graph) 132 | print(processed_data_file_train, ' and ', processed_data_file_test, ' have been created') 133 | else: 134 | print(processed_data_file_train, ' and ', processed_data_file_test, ' are already created') 135 | -------------------------------------------------------------------------------- /FetterGrad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import pdb 6 | import numpy as np 7 | import copy 8 | import random 9 | 10 | 11 | class FetterGrad(): 12 | def __init__(self, optimizer, reduction='mean'): 13 | self._optim, self._reduction = optimizer, reduction 14 | return 15 | 16 | @property 17 | def optimizer(self): 18 | return self._optim 19 | 20 | def zero_grad(self): 21 | ''' 22 | clear the gradient of the parameters 23 | ''' 24 | 25 | return self._optim.zero_grad(set_to_none=True) 26 | 27 | def step(self): 28 | ''' 29 | update the parameters with the 
gradient
30 |         '''
31 | 
32 |         return self._optim.step()
33 | 
34 |     def ft_backward(self, objectives):
35 |         '''
36 |         calculate the gradients of the parameters
37 | 
38 |         input:
39 |         - objectives: a list of objectives
40 |         '''
41 | 
42 |         grads, shapes, has_grads = self._pack_grad(objectives)
43 |         fetter_grad = self._project_conflicting(grads, has_grads)
44 |         fetter_grad = self._unflatten_grad(fetter_grad, shapes[0])
45 |         self._set_grad(fetter_grad)
46 |         return
47 | 
48 |     def _project_conflicting(self, grads, has_grads, shapes=None):
49 |         shared = torch.stack(has_grads).prod(0).bool()
50 |         fetter_grad, num_task = copy.deepcopy(grads), len(grads)
51 |         for g_i in fetter_grad:
52 |             random.shuffle(grads)
53 |             for g_j in grads:
54 |                 g_i_g_j = torch.dist(g_i, g_j, p=2)  # Euclidean distance
55 |                 g_i_g_j = 1 / (1 + g_i_g_j)          # similarity weight in (0, 1]
56 |                 if g_i_g_j < 0.5:
57 |                     g_i += (g_i_g_j) * g_j
58 |         merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
59 |         if self._reduction == 'mean':
60 |             merged_grad[shared] = torch.stack([g[shared]
61 |                                                for g in fetter_grad]).mean(dim=0)
62 |         elif self._reduction == 'sum':
63 |             merged_grad[shared] = torch.stack([g[shared]
64 |                                                for g in fetter_grad]).sum(dim=0)
65 |         else: raise ValueError('invalid reduction method')
66 | 
67 |         merged_grad[~shared] = torch.stack([g[~shared]
68 |                                             for g in fetter_grad]).sum(dim=0)
69 |         return merged_grad
70 | 
71 |     def _set_grad(self, grads):
72 |         '''
73 |         set the modified gradients to the network
74 |         '''
75 | 
76 |         idx = 0
77 |         for group in self._optim.param_groups:
78 |             for p in group['params']:
79 |                 # if p.grad is None: continue
80 |                 p.grad = grads[idx]
81 |                 idx += 1
82 |         return
83 | 
84 |     def _pack_grad(self, objectives):
85 |         '''
86 |         pack the gradients of the parameters of the network for each objective
87 | 
88 |         output:
89 |         - grads: a list of the gradients of the parameters
90 |         - shapes: a list of the shapes of the parameters
91 |         - has_grads: a list of masks representing whether each parameter has a gradient
92 |         '''
93 | 
94 |         grads, shapes, has_grads = [], [], []
95 |         for obj in objectives:
96 |             self._optim.zero_grad(set_to_none=True)
97 |             obj.backward(retain_graph=True)
98 |             grad, shape, has_grad = self._retrieve_grad()
99 |             grads.append(self._flatten_grad(grad, shape))
100 |             has_grads.append(self._flatten_grad(has_grad, shape))
101 |             shapes.append(shape)
102 |         return grads, shapes, has_grads
103 | 
104 |     def _unflatten_grad(self, grads, shapes):
105 |         unflatten_grad, idx = [], 0
106 |         for shape in shapes:
107 |             length = np.prod(shape)
108 |             unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
109 |             idx += length
110 |         return unflatten_grad
111 | 
112 |     def _flatten_grad(self, grads, shapes):
113 |         flatten_grad = torch.cat([g.flatten() for g in grads])
114 |         return flatten_grad
115 | 
116 |     def _retrieve_grad(self):
117 |         '''
118 |         get the gradients of the parameters of the network for a specific
119 |         objective
120 | 
121 |         output:
122 |         - grad: a list of the gradients of the parameters
123 |         - shape: a list of the shapes of the parameters
124 |         - has_grad: a list of masks representing whether each parameter has a gradient
125 |         '''
126 | 
127 |         grad, shape, has_grad = [], [], []
128 |         for group in self._optim.param_groups:
129 |             for p in group['params']:
130 |                 # tackle the multi-head scenario: parameters without a gradient
131 |                 # for this objective contribute zero gradients and a zero mask
132 |                 if p.grad is None:
133 |                     shape.append(p.shape)
134 |                     grad.append(torch.zeros_like(p).to(p.device))
135 |                     has_grad.append(torch.zeros_like(p).to(p.device))
136 |                     continue
137 |                 shape.append(p.grad.shape)
138 |                 grad.append(p.grad.clone())
139 |                 has_grad.append(torch.ones_like(p).to(p.device))
140 |         return grad, shape, has_grad
141 | 
142 | 
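# Worked example of the blending rule in _project_conflicting (an illustration
# only, not executed by the algorithm): for two flattened gradients
# g_i = [1, 0] and g_j = [0, 1], the Euclidean distance is
# ||g_i - g_j|| = sqrt(2) ~= 1.414, so the similarity weight is
# w = 1 / (1 + 1.414) ~= 0.414. Since w < 0.5, the update g_i += w * g_j gives
# g_i ~= [1, 0.414]: dissimilar gradients are pulled toward each other, while
# similar ones (w >= 0.5) are left unchanged. The self-test networks below
# exercise this behaviour on a fully shared and a multi-head model.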
143 | class TestNet(nn.Module):
144 |     def __init__(self):
145 |         super().__init__()
146 |         self._linear = nn.Linear(3, 4)
147 | 
148 |     def forward(self, x):
149 |         return self._linear(x)
150 | 
151 | 
152 | class MultiHeadTestNet(nn.Module):
153 |     def __init__(self):
154 |         super().__init__()
155 |         self._linear = nn.Linear(3, 2)
156 |         self._head1 = nn.Linear(2, 4)
157 |         self._head2 = nn.Linear(2, 4)
158 | 
159 |     def forward(self, x):
160 |         feat = self._linear(x)
161 |         return self._head1(feat), self._head2(feat)
162 | 
163 | 
164 | if __name__ == '__main__':
165 | 
166 |     # fully shared network test
167 |     torch.manual_seed(4)
168 |     x, y = torch.randn(2, 3), torch.randn(2, 4)
169 |     net = TestNet()
170 |     y_pred = net(x)
171 |     pc_adam = FetterGrad(optim.Adam(net.parameters()))
172 |     pc_adam.zero_grad()
173 |     loss1_fn, loss2_fn = nn.L1Loss(), nn.MSELoss()
174 |     loss1, loss2 = loss1_fn(y_pred, y), loss2_fn(y_pred, y)
175 | 
176 |     pc_adam.ft_backward([loss1, loss2])
177 |     for p in net.parameters():
178 |         print(p.grad)
179 | 
180 |     print('-' * 80)
181 |     # separated shared network test
182 | 
183 |     torch.manual_seed(4)
184 |     x, y = torch.randn(2, 3), torch.randn(2, 4)
185 |     net = MultiHeadTestNet()
186 |     y_pred_1, y_pred_2 = net(x)
187 |     pc_adam = FetterGrad(optim.Adam(net.parameters()))
188 |     pc_adam.zero_grad()
189 |     loss1_fn, loss2_fn = nn.MSELoss(), nn.MSELoss()
190 |     loss1, loss2 = loss1_fn(y_pred_1, y), loss2_fn(y_pred_2, y)
191 | 
192 |     pc_adam.ft_backward([loss1, loss2])
193 |     for p in net.parameters():
194 |         print(p.grad)
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from utils import *
from model import DeepDTAGen
from FetterGrad import FetterGrad

from tqdm import tqdm
import sys, os
import time
import pickle
import random

seed = 4221
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

if torch.cuda.is_available():
    generator = torch.Generator('cuda').manual_seed(seed)
else:
    generator = torch.Generator().manual_seed(seed)


def train(model, device, train_loader, optimizer, mse_f, epoch, train_data, FLAGS):
    """Train the DeepDTAGen model for one epoch using the specified data and hyperparameters."""
    model.train()

    with tqdm(train_loader, desc=f"Epoch {epoch + 1}") as t:
        for i, data in enumerate(t):
            optimizer.zero_grad()
            batch = data.batch.to(device)
            prediction, new_drug, lm_loss, kl_loss = model(data.to(device))

            mse_loss = mse_f(prediction, data.y.view(-1, 1).float().to(device))

            train_ci = get_cindex(prediction.cpu().detach().numpy(), data.y.view(-1, 1).float().cpu().detach().numpy())

            loss = kl_loss * 0.001 + mse_loss + lm_loss
            # a standard single-objective update would be:
            # loss.backward(); optimizer.step()

            losses = [loss, mse_loss]
            optimizer.ft_backward(losses)
            optimizer.step()
            t.set_postfix(MSE=mse_loss.item(), Train_cindex=train_ci, KL=kl_loss.item(), LM=lm_loss.item())
    msg = f"Epoch {epoch+1}, total loss={loss.item()}, MSE={mse_loss.item()}, KL_loss={kl_loss.item()}, LM={lm_loss.item()}"
    logging(msg, FLAGS)
    return model

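# Minimal usage pattern for the FetterGrad wrapper used above (a sketch
# mirroring the self-test in FetterGrad.py): wrap any torch optimizer, collect
# one loss per objective, then let ft_backward merge their gradients before
# stepping.
#
#     optimizer = FetterGrad(optim.Adam(model.parameters(), lr=2e-4))
#     optimizer.zero_grad()
#     optimizer.ft_backward([total_loss, mse_loss])  # one entry per objective
#     optimizer.step()
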
def test(model, device, test_loader, dataset, FLAGS):
    """Test the DeepDTAGen model on the specified data and report the results."""
    print('Testing on {} samples...'.format(len(test_loader.dataset)))
    model.eval()
    total_true = torch.Tensor()
    total_predict = torch.Tensor()
    total_loss = 0

    if dataset == "kiba":
        thresholds = [10.0, 10.50, 11.0, 11.50, 12.0, 12.50]
    else:
        thresholds = [5.0, 5.50, 6.0, 6.50, 7.0, 7.50, 8.0, 8.50]

    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader)):

            prediction, new_drug, lm_loss, kl_loss = model(data.to(device))

            total_true = torch.cat((total_true, data.y.view(-1, 1).cpu()), 0)
            total_predict = torch.cat((total_predict, prediction.cpu()), 0)
            loss = lm_loss + kl_loss
            total_loss += loss.item() * data.num_graphs

    G = total_true.numpy().flatten()
    P = total_predict.numpy().flatten()
    mse_loss = mse(G, P)
    test_ci = get_cindex(G, P)
    rm2 = get_rm2(G, P)
    auc_values = []
    for t in thresholds:
        auc_values.append(get_auc(np.int32(G > t), P))
    return total_loss, mse_loss, test_ci, rm2, auc_values, G, P

def experiment(FLAGS, dataset, device):
    logging('Starting program', FLAGS)

    # Hyperparameters
    BATCH_SIZE = 32
    LR = 0.0002
    NUM_EPOCHS = 500

    # Print hyperparameters
    print(f"Dataset: {dataset}")
    print(f"Device: {device}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Learning rate: {LR}")
    print(f"Epochs: {NUM_EPOCHS}")

    # Log hyperparameters
    msg = f"Dataset {dataset}, Device {device}, batch size {BATCH_SIZE}, learning rate {LR}, epochs {NUM_EPOCHS}"
    logging(msg, FLAGS)

    # Load tokenizer
    with open(f'data/{dataset}_tokenizer.pkl', 'rb') as f:
        tokenizer = pickle.load(f)

    # Load processed data
    processed_data_file_train = f"data/processed/{dataset}_train.pt"
    processed_data_file_test = f"data/processed/{dataset}_test.pt"
    if not (os.path.isfile(processed_data_file_train) and os.path.isfile(processed_data_file_test)):
        print("Please run create_data.py to prepare the data in PyTorch format!")
    else:
        train_data = TestbedDataset(root="data", dataset=f"{dataset}_train")
        test_data = TestbedDataset(root="data", dataset=f"{dataset}_test")

        # Prepare PyTorch mini-batches
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

        # Initialize model, optimizer, and loss function
        model = DeepDTAGen(tokenizer).to(device)
        optimizer = FetterGrad(optim.Adam(model.parameters(), lr=LR))
        mse_f = nn.MSELoss()

        # Train model
        best_mse = float('inf')
        for epoch in range(NUM_EPOCHS):
            model = train(model, device, train_loader, optimizer, mse_f, epoch, train_data, FLAGS)

            if (epoch + 1) % 20 == 0:
                # Test model
                total_loss, mse_loss, test_ci, rm2, auc_values, G, P = test(model, device, test_loader, dataset, FLAGS)
                filename = f"saved_models/deepdtagen_model_{dataset}.pth"
                if mse_loss < best_mse:
                    best_mse = mse_loss
                    torch.save(model.state_dict(), filename)
                    print('model saved')

                print(f"MSE: {mse_loss.item():.4f}")
                print(f"CI: {test_ci:.4f}")
                print(f"RM2: {rm2:.4f}")
                print(f"AUCs: {', '.join([f'{auc:.4f}' for auc in auc_values])}")

                # Save estimated and true labels
                folder_path = "Affinities/"
                np.savetxt(folder_path + f"estimated_labels_{dataset}.txt", P)
                np.savetxt(folder_path + f"true_labels_{dataset}.txt", G)

    logging('Program finished', FLAGS)

if __name__ == "__main__":

    datasets = ['davis', 'kiba', 'bindingdb']
    dataset_idx = int(sys.argv[1]) if len(sys.argv) > 1 else 2
    dataset = datasets[dataset_idx]

    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device("cuda:" + str(int(sys.argv[2])) if len(sys.argv) > 2 and torch.cuda.is_available() else default_device)

    FLAGS = lambda: None
    FLAGS.log_dir = 'logs'
    FLAGS.dataset_name = f'dataset_{dataset}_{int(time.time())}'

    os.makedirs(FLAGS.log_dir, exist_ok=True)
    os.makedirs('Affinities', exist_ok=True)
    os.makedirs('saved_models', exist_ok=True)

    experiment(FLAGS, dataset, device)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepDTAGen
## 💡 Description
This is the implementation of *DeepDTAGen: Multitask Deep Learning Framework for Predicting Drug-Target Affinity and Generating Target-Specific Drugs*.

## 📋 Table of Contents
1. [💡 Description](#description)
2. [🔍 Dataset](#dataset)
3. [🧠 Model Architecture](#model-architecture)
4. [🛠️ Preprocessing](#Preprocessing)
5. [📊 System Requirements](#System-requirements)
6. [⚙️ Installation and Requirements](#installation)
7. [📁 Source codes](#sourcecode)
8. [🖥️ Demo](#demo)
9. [🤖🎛️ Training](#training)
10. [📧 Contact](#contact)
11. [🙏 Acknowledgments](#acknowledgments)


## 🔍 Datasets:
### Data:
The data is available in CSV format within the 'data.rar' file. Each file is named according to its respective dataset and whether it is for training or testing.

## 🧠 Model Architecture
The DeepDTAGen architecture consists of the following components:

1. 💊⚛️ **Graph-Encoder module**: The Graph-Encoder, denoted as q(Z_Drug|X, A), is designed to process graph data represented as node feature vectors X and an adjacency matrix A. The input is organized in mini-batches of size [batch_size, drug_features], where each drug is characterized by its feature vector. The goal of the Drug Encoder is to transform this high-dimensional input into a lower-dimensional representation: it employs a multivariate Gaussian distribution to map the input data points into a continuous latent space, producing novel features derived from the original drug features. The condition vector C is then added. For affinity prediction, however, it is necessary to keep the actual representation of the input drug to make accurate predictions. We therefore designed the Drug Encoder to yield a pair of outputs:

   (I): For the affinity prediction task, we use the features obtained prior to the mean and log-variance operation (PMVO). These features are more appropriate for predicting drug affinity, as they retain the original characteristics of the input drug without being altered by the AMVO step.

   (II): For novel drug generation, we utilize the features obtained after performing the mean and log-variance operation (AMVO); a minimal sketch of this step follows the component list.

2. 🔄 **Gated-CNN Module for Target-Proteins**: The Gated Convolutional Neural Network (GCNN) block is specifically designed to extract the features of target sequences. The GCNN takes the protein sequences in the form of an embedding matrix, where each amino acid is represented by a 128-dimensional feature vector, and extracts the target features as output.
3. 💊 **Transformer-Decoder Module**: The Transformer-Decoder p(DrugSMILES|Z_Drug) uses the latent space (AMVO) and Modified Target SMILES (MTS) and generates novel drug SMILES in an autoregressive manner (more details are available in Section 1.3 of the main article).

4. 🎯 **Prediction (Fully-Connected Module)**: The prediction block utilizes the extracted features from the Drug Encoder (PMVO) and the GCNN for target proteins, and predicts the affinity between the given drug and the target.

![Model](model.png)
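For intuition, here is a minimal, self-contained sketch of the two-output encoder head described in component 1. It only illustrates the PMVO/AMVO idea; it is not the repository's `model.py`, and all layer sizes and names are assumptions.

```python
import torch
import torch.nn as nn

class TwoHeadEncoderSketch(nn.Module):
    """Toy encoder head: returns PMVO features for affinity prediction and an
    AMVO (reparameterized) latent sample for generation."""
    def __init__(self, in_dim=78, hidden_dim=128, latent_dim=128):
        super().__init__()
        self.backbone = nn.Linear(in_dim, hidden_dim)  # stands in for the graph encoder
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = torch.relu(self.backbone(x))  # PMVO: features before the mean/log-variance step
        mu, logvar = self.mean(h), self.logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # AMVO: reparameterized sample
        return h, z, mu, logvar  # h -> affinity head, z -> SMILES decoder

x = torch.randn(32, 78)  # a mini-batch of [batch_size, drug_features]
pmvo, amvo, mu, logvar = TwoHeadEncoderSketch()(x)
```

In training.py, the KL term derived from the mean and log-variance is weighted by 0.001 before being added to the MSE and language-model losses.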
## 🛠️ Preprocessing
+ Drugs: Each SMILES string is converted to a chemical structure using the RDKit library, which NetworkX then converts into a graph representation, as sketched below.
+ Proteins: Each protein sequence is converted into a numerical representation using label encoding. Some further preprocessing steps were applied (more detail is provided in the main text).
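A condensed version of the drug pipeline implemented in `create_data.py` (`atom_features`/`smile_to_graph`), here reduced to atom symbols only for brevity:

```python
from rdkit import Chem
import networkx as nx

def smiles_to_graph_sketch(smiles):
    """Reduced form of create_data.smile_to_graph: atoms become nodes, bonds become directed edges."""
    mol = Chem.MolFromSmiles(smiles)
    nodes = [atom.GetSymbol() for atom in mol.GetAtoms()]  # the full code uses rich atom features
    g = nx.Graph()
    g.add_edges_from((b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds())
    edge_index = list(g.to_directed().edges())  # both directions, as in the repository
    return len(nodes), nodes, edge_index

print(smiles_to_graph_sketch("CCO"))  # (3, ['C', 'C', 'O'], [(0, 1), (1, 0), (1, 2), (2, 1)])
```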
## 📊 System requirements
+ Operating System: Ubuntu 16.04.7 LTS
+ CPU: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz
+ GPU: GeForce RTX 2080 Ti
+ CUDA: 10.2


## ⚙️ Installation and Requirements
To set up the environment, run:
```sh
conda env create -f environment.yml
```
This will install all the required libraries.

Or install manually:
```sh
conda create -n DeepDTAGen python=3.8
conda activate DeepDTAGen
conda install -y -c conda-forge rdkit
conda install pytorch torchvision cudatoolkit -c pytorch
pip install torch-cluster==1.6.0+pt112cu102
pip install torch-scatter==2.1.0+pt112cu102
pip install torch-sparse==0.6.16+pt112cu102
pip install torch-spline-conv==1.2.1+pt112cu102
pip install torch-geometric==2.2.0
pip install fairseq==0.10.2
pip install einops==0.6.0
```
+ The whole installation takes at most about 30 minutes.

## 📁 Source codes:
The whole implementation of DeepDTAGen is based on PyTorch.

+ create_data.py: This script generates data in PyTorch format.
+ utils.py: This module contains a variety of useful functions and classes employed by the other scripts. One notable class is TestbedDataset, which create_data.py uses to generate data in PyTorch format. It also contains the Tokenizer class responsible for preparing SMILES data for the transformer decoder; a usage sketch follows this list.
+ training.py: This module trains the DeepDTAGen model.
+ model.py: This module receives graph data as input for drugs and sequence data for proteins, with the corresponding actual labels (affinity values).
+ FetterGrad.py: This script is the implementation of our proposed Fetter Gradients algorithm.
+ test.py: This script is utilized to assess the performance of our saved models.
+ generate.py: This script is employed to generate drugs based on a given condition, using the latent space and random noise.
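A short sketch of how the Tokenizer in utils.py round-trips a SMILES string (the token ids are illustrative; the actual ids depend on the generated vocabulary):

```python
from utils import Tokenizer

smiles = ["CC(=O)Oc1ccccc1C(=O)O", "c1ccccc1"]
tokenizer = Tokenizer(Tokenizer.gen_vocabs(smiles))  # build the vocabulary from the corpus

ids = tokenizer.parse(smiles[0])      # [start token, token ids..., end token]
text = tokenizer.get_text([ids[1:]])  # decode, stopping at the end token
print(text)                           # ['CC(=O)Oc1ccccc1C(=O)O']
```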
## 🖥️ Demo
We have provided a DEMO directory containing two files, "DEMO_Affinity.py" and "DEMO_Generation.py". "DEMO_Affinity.py" demonstrates affinity prediction, allowing users to test our model on a sample input, while "DEMO_Generation.py" provides a test case for evaluating the model's performance in generating drugs.
+ DEMO_Affinity.py for affinity prediction
+ DEMO_Generation.py for drug generation

Running these files takes approximately 1 to 2 seconds.
The expected result for the given input in DEMO_Affinity.py is a predicted affinity of 6.255425453186035.
The expected result for the given input in DEMO_Generation.py is the generated drug O=C(c1cc(C(F)(F)F)ccc1F)N(C1CCN(C(=O)c2ccc(Br)cc2)CC1)C(=O)N1CCCC1.

## 🤖🎛️ Training
DeepDTAGen is trained using the PyTorch and PyTorch Geometric libraries, with an NVIDIA GeForce RTX 2080 Ti GPU as the back-end hardware.

i. Create the data
```sh
conda activate DeepDTAGen
python create_data.py
```
The create_data.py script generates six PyTorch-formatted data files from kiba_train.csv, kiba_test.csv, davis_train.csv, davis_test.csv, bindingdb_train.csv, and bindingdb_test.csv, and stores them in data/processed/ as kiba_train.pt, kiba_test.pt, davis_train.pt, davis_test.pt, bindingdb_train.pt, and bindingdb_test.pt.

ii. Train the model
```sh
conda activate DeepDTAGen
python training.py
```
## 💊 Molecule Generation
To generate molecules using the trained model, run the following script with the dataset name (kiba or bindingdb):
```sh
python generate.py kiba
```
## 📊 Model Evaluation
To evaluate the performance of the predictive model, run:
```sh
python test.py --dataset kiba
```
## 🎯 Generative Model Evaluation
To evaluate the generative performance of the model, run:
```sh
python generation_eveluation.py --dataset bindingdb
```
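The metrics reported by generation_eveluation.py can also be computed directly; here is a small sketch using its evaluate_smiles helper (the toy SMILES below are illustrative):

```python
from generation_eveluation import evaluate_smiles

generated = ["CCO", "CCO", "c1ccccc1", "not_a_smiles"]
reference = {"CCO"}  # e.g. the target SMILES from the test set

print(evaluate_smiles(generated, reference))
# validity_ratio 0.75, uniqueness_ratio ~0.67, novelty_ratio 0.5
```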
## 📧 Contact
Have a question or suggestion? Feel free to reach out!

**📨 Email:** [pirmasoomshah@gmail.com](mailto:pirmasoomshah@gmail.com)
**🌐 Google Site:** [Pir Masoom Shah](https://sites.google.com/view/pirmasoomshah/home?authuser=0)

## 📜 Reference
paper reference
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from math import sqrt
4 | from scipy import stats
5 | from torch_geometric.data import InMemoryDataset
6 | from torch_geometric.loader import DataLoader
7 | from sklearn.metrics import auc, precision_recall_curve, average_precision_score
8 | from torch_geometric import data as DATA
9 | import torch
10 | from torch.nn.utils.rnn import pad_sequence
11 | 
12 | 
13 | from torch.utils.data import Dataset
14 | from tqdm.auto import tqdm
15 | import re
16 | from typing import List
17 | 
18 | 
19 | class Tokenizer:
20 |     NUM_RESERVED_TOKENS = 32
21 |     SPECIAL_TOKENS = ('<pad>', '<sos>', '<eos>', '<unk>', '<mask>', '<sep>')
22 |     SPECIAL_TOKENS += tuple([f'<unused{i}>' for i in range(len(SPECIAL_TOKENS), 32)])  # saved for future use
23 | 
24 |     PATTEN = re.compile(r'\[[^\]]+\]'
25 |                         # only some B|C|N|O|P|S|F|Cl|Br|I atoms can omit square brackets
26 |                         r'|B[r]?|C[l]?|N|O|P|S|F|I'
27 |                         r'|[bcnops]'
28 |                         r'|@@|@'
29 |                         r'|%\d{2}'
30 |                         r'|.')
31 | 
32 |     ATOM_PATTEN = re.compile(r'\[[^\]]+\]'
33 |                              r'|B[r]?|C[l]?|N|O|P|S|F|I'
34 |                              r'|[bcnops]')
35 | 
36 |     @staticmethod
37 |     def gen_vocabs(smiles_list):
38 |         smiles_set = set(smiles_list)
39 |         vocabs = set()
40 | 
41 |         for a in tqdm(smiles_set):
42 |             vocabs.update(re.findall(Tokenizer.PATTEN, a))
43 | 
44 |         return vocabs
45 | 
46 |     def __init__(self, vocabs):
47 |         special_tokens = list(Tokenizer.SPECIAL_TOKENS)
48 |         vocabs = special_tokens + sorted(set(vocabs) - set(special_tokens), key=lambda x: (len(x), x))
49 |         self.vocabs = vocabs
50 |         self.i2s = {i: s for i, s in enumerate(vocabs)}
51 |         self.s2i = {s: i for i, s in self.i2s.items()}
52 | 
53 |     def __len__(self):
54 |         return len(self.vocabs)
55 | 
56 |     def parse(self, smiles, return_atom_idx=False):
57 |         l = []
58 |         if return_atom_idx:
59 |             atom_idx = []
60 |         for i, s in enumerate(('<sos>', *re.findall(Tokenizer.PATTEN, smiles), '<eos>')):
61 |             if s not in self.s2i:
62 |                 a = 3  # 3 = <unk> for out-of-vocabulary tokens
63 |             else:
64 |                 a = self.s2i[s]
65 |             l.append(a)
66 | 
67 |             if return_atom_idx and re.fullmatch(Tokenizer.ATOM_PATTEN, s) is not None:
68 |                 atom_idx.append(i)
69 |         if return_atom_idx:
70 |             return l, atom_idx
71 |         return l
72 | 
73 |     def get_text(self, predictions):
74 |         if isinstance(predictions, torch.Tensor):
75 |             predictions = predictions.tolist()
76 | 
77 |         smiles = []
78 |         for p in predictions:
79 |             s = []
80 |             for i in p:
81 |                 c = self.i2s[i]
82 |                 if c == '<eos>':
83 |                     break
84 |                 s.append(c)
85 |             smiles.append(''.join(s))
86 | 
87 |         return smiles
88 | 
89 | # tokenizer = Tokenizer(Tokenizer.gen_vocabs(smiles))
90 | 
91 | # pre_filter is a method that is applied to each individual sample of the dataset before it is added to the list of processed data. The purpose of pre_filter is to remove samples from the dataset that do not meet certain criteria. For example, you might use pre_filter to remove samples that have missing data or that do not meet some quality threshold.
92 | # pre_transform, on the other hand, is a method that is applied to each individual sample after it has been processed by collate, but before it is returned by the data loader. The purpose of pre_transform is to transform the individual samples in some way.
For example, you might use pre_transform to normalize the features of the graph, to add noise to the graph, or to perform data augmentation.
93 | class TestbedDataset(InMemoryDataset):
94 |     def __init__(self, root='/tmp', dataset='davis',
95 |                  xd=None, xdt=None, xt=None, y=None, transform=None,
96 |                  pre_transform=None, smile_graph=None):
97 | 
98 |         # root is required for saving preprocessed data, default is '/tmp'
99 |         super(TestbedDataset, self).__init__(root, transform, pre_transform)
100 |         # benchmark dataset, default = 'davis'
101 |         self.dataset = dataset
102 |         self.pad_token = Tokenizer.SPECIAL_TOKENS.index('<pad>')
103 |         if os.path.isfile(self.processed_paths[0]):
104 |             print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0]))
105 |             self.data, self.slices = torch.load(self.processed_paths[0])
106 |         else:
107 |             print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0]))
108 |             self.process(xd, xdt, xt, y, smile_graph)
109 |             self.data, self.slices = torch.load(self.processed_paths[0])
110 | 
111 |     @property
112 |     def raw_file_names(self):
113 |         pass
114 |         # return ['some_file_1', 'some_file_2', ...]
115 | 
116 |     @property
117 |     def processed_file_names(self):
118 |         return [self.dataset + '.pt']
119 | 
120 |     def download(self):
121 |         # Download to `self.raw_dir`.
122 |         pass
123 | 
124 |     def _download(self):
125 |         pass
126 | 
127 |     def _process(self):
128 |         if not os.path.exists(self.processed_dir):
129 |             os.makedirs(self.processed_dir)
130 | 
131 |     def process(self, xd, xdt, xt, y, smile_graph):
132 |         assert (len(xd) == len(xt) and len(xt) == len(y) == len(xdt)), "The four lists must be the same length!"
133 | 
134 |         smi = pad_sequence(xdt, batch_first=True, padding_value=self.pad_token)
135 |         data_list = []
136 |         data_len = len(xd)
137 |         for i in range(data_len):
138 |             print('Preparing data in PyTorch format: {}/{}'.format(i + 1, data_len))
139 |             smiles = xd[i]
140 |             target = xt[i]
141 |             labels = y[i]
142 |             tok_smi = smi[i]
143 |             # print(tok_smi)
144 |             tok_smi = tok_smi.tolist()
145 |             c_size, features, edge_index, edge_feats = smile_graph[smiles]
146 |             GCNData = DATA.Data(x=torch.Tensor(features),
147 |                                 edge_index=torch.LongTensor(edge_index).transpose(1, 0),
148 |                                 edge_attr=torch.Tensor(edge_feats),
149 |                                 y=torch.FloatTensor([labels]))
150 |             GCNData.target = torch.LongTensor([target])
151 |             GCNData.target_seq = torch.LongTensor([tok_smi])
152 |             GCNData.__setitem__('c_size', torch.LongTensor([c_size]))
153 |             data_list.append(GCNData)
154 | 
155 |         if self.pre_filter is not None:
156 |             data_list = [data for data in data_list if self.pre_filter(data)]
157 | 
158 |         if self.pre_transform is not None:
159 |             data_list = [self.pre_transform(data) for data in data_list]
160 |         print('Data preparation done. 
Saving to file.') 161 | data, slices = self.collate(data_list) 162 | torch.save((data, slices), self.processed_paths[0]) 163 | 164 | def logging(msg, FLAGS): 165 | fpath = os.path.join(FLAGS.log_dir, f"log_{FLAGS.dataset_name}.txt") 166 | with open(fpath, "a") as fw: 167 | fw.write("%s\n" % msg) 168 | 169 | def save_best_model(mse_loss, model, best_mse, model_path): 170 | if mse_loss < best_mse: 171 | best_mse = mse_loss 172 | torch.save(model.state_dict(), model_path) 173 | print("Best model saved!") 174 | 175 | 176 | def rmse(y,f): 177 | rmse = sqrt(((y - f)**2).mean(axis=0)) 178 | return rmse 179 | def mse(y,f): 180 | mse = ((y - f)**2).mean(axis=0) 181 | return mse 182 | def pearson(y,f): 183 | rp = np.corrcoef(y, f)[0,1] 184 | return rp 185 | def spearman(y,f): 186 | rs = stats.spearmanr(y, f)[0] 187 | return rs 188 | def get_cindex(Y, P): 189 | P = P[:,np.newaxis] - P 190 | P = np.float32(P==0) * 0.5 + np.float32(P>0) 191 | 192 | Y = Y[:,np.newaxis] - Y 193 | Y = np.tril(np.float32(Y>0), 0) 194 | 195 | P_sum = np.sum(P*Y) 196 | Y_sum = np.sum(Y) 197 | 198 | 199 | if Y_sum==0: 200 | return 0 201 | else: 202 | return P_sum/Y_sum 203 | 204 | def r_squared_error(y_obs,y_pred): 205 | y_obs = np.array(y_obs) 206 | y_pred = np.array(y_pred) 207 | y_obs_mean = [np.mean(y_obs) for y in y_obs] 208 | y_pred_mean = [np.mean(y_pred) for y in y_pred] 209 | 210 | mult = sum((y_pred - y_pred_mean) * (y_obs - y_obs_mean)) 211 | mult = mult * mult 212 | 213 | y_obs_sq = sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean)) 214 | y_pred_sq = sum((y_pred - y_pred_mean) * (y_pred - y_pred_mean) ) 215 | 216 | return mult / float(y_obs_sq * y_pred_sq) 217 | 218 | 219 | def get_k(y_obs,y_pred): 220 | y_obs = np.array(y_obs) 221 | y_pred = np.array(y_pred) 222 | 223 | return sum(y_obs*y_pred) / float(sum(y_pred*y_pred)) 224 | 225 | def squared_error_zero(y_obs,y_pred): 226 | k = get_k(y_obs,y_pred) 227 | 228 | y_obs = np.array(y_obs) 229 | y_pred = np.array(y_pred) 230 | y_obs_mean = [np.mean(y_obs) for y in y_obs] 231 | upp = sum((y_obs - (k*y_pred)) * (y_obs - (k* y_pred))) 232 | down= sum((y_obs - y_obs_mean)*(y_obs - y_obs_mean)) 233 | 234 | return 1 - (upp / float(down)) 235 | 236 | def get_rm2(ys_orig,ys_line): 237 | r2 = r_squared_error(ys_orig, ys_line) 238 | r02 = squared_error_zero(ys_orig, ys_line) 239 | return r2 * (1 - np.sqrt(np.absolute((r2*r2)-(r02*r02)))) 240 | 241 | def get_auc(y_true,y_pred): 242 | precision, recall, thresholds = precision_recall_curve(y_true,y_pred) 243 | roc_aupr = auc(recall,precision) 244 | return roc_aupr 245 | 246 | def get_aupr(predictions, true_labels, threshold): 247 | """Calculate Area Under Precision-Recall Curve (AUPR) at a given threshold.""" 248 | binary_pred = (predictions > threshold).astype(int) 249 | binary_true = (true_labels > threshold).astype(int) 250 | 251 | # Calculate AUPR 252 | return average_precision_score(binary_true, binary_pred) 253 | -------------------------------------------------------------------------------- /DEMO/required_files_for_demo/demo_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pickle 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import argparse 7 | from rdkit import Chem 8 | import sys 9 | import os 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) 11 | from utils import * 12 | import rdkit 13 | import networkx as nx 14 | 15 | import os 16 | from torch_geometric.data import InMemoryDataset 
17 | from torch_geometric import data as DATA 18 | import torch 19 | import numpy as np 20 | import torch 21 | from torch_geometric.data import InMemoryDataset, Batch 22 | import re 23 | from datetime import datetime 24 | 25 | 26 | def format_smiles(smiles): 27 | mol = Chem.MolFromSmiles(smiles) 28 | if mol is None: 29 | return None 30 | smiles = rdkit.Chem.MolToSmiles(mol, isomericSmiles=True) 31 | 32 | return smiles 33 | 34 | def one_of_k_encoding(x, allowable_set): 35 | if x not in allowable_set: 36 | x = allowable_set[-1] 37 | return [x == s for s in allowable_set] 38 | 39 | def one_of_k_encoding_unk(x, allowable_set): 40 | if x not in allowable_set: 41 | x = allowable_set[-1] 42 | return [x == s for s in allowable_set] + [x not in allowable_set] 43 | 44 | def atom_features(atom): 45 | return np.array(one_of_k_encoding_unk(atom.GetSymbol(),['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb','Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H','Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr','Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) + 46 | one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + 47 | one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + 48 | one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + 49 | one_of_k_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) + 50 | one_of_k_encoding_unk(atom.GetHybridization(), [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]) + 51 | [atom.GetIsAromatic()] + 52 | [atom.IsInRing()]) 53 | 54 | def bond_features(bond): 55 | bt = bond.GetBondType() 56 | bond_feats = [0, 0, 0, 0, bond.GetBondTypeAsDouble()] 57 | if bt == Chem.rdchem.BondType.SINGLE: 58 | bond_feats = [1, 0, 0, 0, bond.GetBondTypeAsDouble()] 59 | elif bt == Chem.rdchem.BondType.DOUBLE: 60 | bond_feats = [0, 1, 0, 0, bond.GetBondTypeAsDouble()] 61 | elif bt == Chem.rdchem.BondType.TRIPLE: 62 | bond_feats = [0, 0, 1, 0, bond.GetBondTypeAsDouble()] 63 | elif bt == Chem.rdchem.BondType.AROMATIC: 64 | bond_feats = [0, 0, 0, 1, bond.GetBondTypeAsDouble()] 65 | return np.array(bond_feats) 66 | 67 | def smile_to_graph(smile): 68 | mol = Chem.MolFromSmiles(smile) 69 | 70 | c_size = mol.GetNumAtoms() 71 | 72 | features = [] 73 | for atom in mol.GetAtoms(): 74 | feature = atom_features(atom) 75 | features.append(feature / sum(feature)) 76 | 77 | edges = [] 78 | for bond in mol.GetBonds(): 79 | edge_feats = bond_features(bond) 80 | edges.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), {'edge_feats': edge_feats})) 81 | 82 | g = nx.Graph() 83 | g.add_edges_from(edges) 84 | g = g.to_directed() 85 | edge_index = [] 86 | edge_feats = [] 87 | for e1, e2, feats in g.edges(data=True): 88 | edge_index.append([e1, e2]) 89 | edge_feats.append(feats['edge_feats']) 90 | 91 | return c_size, features, edge_index, edge_feats 92 | 93 | def seq_cat(prot): 94 | seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ" 95 | seq_dict = {v:(i+1) for i,v in enumerate(seq_voc)} 96 | max_seq_len = 1000 97 | x = np.zeros(max_seq_len) 98 | for i, ch in enumerate(prot[:max_seq_len]): 99 | x[i] = seq_dict[ch] 100 | return x 101 | 102 | def process_latent_a(smile, protein_seq): 103 | tokenizer_file = f'data/{smile}_tokenizer.pkl' 104 | tokenizer = Tokenizer(Tokenizer.gen_vocabs(smile)) 105 | 106 | smile_graph = {} 107 | g = smile_to_graph(smile) 108 | 
102 | def process_latent_a(smile, protein_seq):
103 |     tokenizer_file = f'data/{smile}_tokenizer.pkl'  # caution: SMILES containing '/' or '\' will break this path
104 |     tokenizer = Tokenizer(Tokenizer.gen_vocabs(smile))  # a bare string is iterated per character, so multi-character tokens (Cl, Br, bracket atoms) end up as <unk>
105 | 
106 |     smile_graph = {}
107 |     g = smile_to_graph(smile)
108 |     smile_graph[smile] = g
109 | 
110 |     with open(tokenizer_file, 'wb') as file:
111 |         pickle.dump(tokenizer, file)
112 |     XT = seq_cat(protein_seq)
113 |     name = datetime.now().strftime("%Y%m%d%H%M%S")
114 |     data = TestbedDataset(root='data', dataset=name, xd=np.asarray([smile]), xt=np.asarray([XT]), smile_graph=smile_graph)
115 | 
116 |     return data
117 | 
118 | def process_latent(smile, protein_seq, affinity):
119 |     tokenizer_file = f'data/{smile}_tokenizer.pkl'
120 |     tokenizer = Tokenizer(Tokenizer.gen_vocabs(smile))
121 | 
122 |     smile_graph = {}
123 |     g = smile_to_graph(smile)
124 |     smile_graph[smile] = g
125 | 
126 |     with open(tokenizer_file, 'wb') as file:
127 |         pickle.dump(tokenizer, file)
128 |     XT = seq_cat(protein_seq)
129 |     Y = float(affinity)
130 |     Y = np.asarray([Y])
131 |     name = datetime.now().strftime("%Y%m%d%H%M%S")
132 |     data = TestbedDataset2(root='data', dataset=name, xd=np.asarray([smile]), xt=np.asarray([XT]), y=Y, smile_graph=smile_graph)
133 | 
134 |     return data
135 | 
136 | 
137 | 
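# Illustrative usage sketch of the helpers above; the SMILES and protein strings
# are arbitrary example inputs, and collate() is defined at the bottom of this file.
def _demo_process_latent():
    data = process_latent_a('CC(=O)Oc1ccccc1C(=O)O', 'MKTAYIAKQR' * 10)
    batch = collate([data[0]])
    print(batch.x.shape, batch.target.shape)  # node features, integer-encoded protein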
138 | class Tokenizer:
139 |     NUM_RESERVED_TOKENS = 32
140 |     SPECIAL_TOKENS = ('<sos>', '<eos>', '<pad>', '<unk>', '<mask>', '<sep>')  # order matters: parse() and the models index <sos>/<eos>/<pad>/<unk> by name below
141 |     SPECIAL_TOKENS += tuple([f'<t_{i}>' for i in range(len(SPECIAL_TOKENS), 32)])  # saved for future use
142 | 
143 |     PATTEN = re.compile(r'\[[^\]]+\]'
144 |                         # only some B|C|N|O|P|S|F|Cl|Br|I atoms can omit square brackets
145 |                         r'|B[r]?|C[l]?|N|O|P|S|F|I'
146 |                         r'|[bcnops]'
147 |                         r'|@@|@'
148 |                         r'|%\d{2}'
149 |                         r'|.')
150 | 
151 |     ATOM_PATTEN = re.compile(r'\[[^\]]+\]'
152 |                              r'|B[r]?|C[l]?|N|O|P|S|F|I'
153 |                              r'|[bcnops]')
154 | 
155 |     @staticmethod
156 |     def gen_vocabs(smiles_list):
157 |         smiles_set = set(smiles_list)
158 |         vocabs = set()
159 | 
160 |         for a in tqdm(smiles_set):
161 |             vocabs.update(re.findall(Tokenizer.PATTEN, a))
162 | 
163 |         return vocabs
164 | 
165 |     def __init__(self, vocabs):
166 |         special_tokens = list(Tokenizer.SPECIAL_TOKENS)
167 |         vocabs = special_tokens + sorted(set(vocabs) - set(special_tokens), key=lambda x: (len(x), x))
168 |         self.vocabs = vocabs
169 |         self.i2s = {i: s for i, s in enumerate(vocabs)}
170 |         self.s2i = {s: i for i, s in self.i2s.items()}
171 | 
172 |     def __len__(self):
173 |         return len(self.vocabs)
174 | 
175 |     def parse(self, smiles, return_atom_idx=False):
176 |         l = []
177 |         if return_atom_idx:
178 |             atom_idx = []
179 |         for i, s in enumerate(('<sos>', *re.findall(Tokenizer.PATTEN, smiles), '<eos>')):
180 |             if s not in self.s2i:
181 |                 a = 3  # 3 is the index of '<unk>'
182 |             else:
183 |                 a = self.s2i[s]
184 |             l.append(a)
185 | 
186 |             if return_atom_idx and re.fullmatch(Tokenizer.ATOM_PATTEN, s) is not None:
187 |                 atom_idx.append(i)
188 |         if return_atom_idx:
189 |             return l, atom_idx
190 |         return l
191 | 
192 |     def get_text(self, predictions):
193 |         if isinstance(predictions, torch.Tensor):
194 |             predictions = predictions.tolist()
195 | 
196 |         smiles = []
197 |         for p in predictions:
198 |             s = []
199 |             for i in p:
200 |                 c = self.i2s[i]
201 |                 if c == '<eos>':
202 |                     break
203 |                 s.append(c)
204 |             smiles.append(''.join(s))
205 | 
206 |         return smiles
207 | 
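# Illustrative round trip through the Tokenizer above; the special-token names
# (<sos>, <eos>, <pad>, <unk>, ...) follow from how they are indexed elsewhere in this repo.
def _demo_tokenizer():
    tk = Tokenizer(Tokenizer.gen_vocabs(['CC(=O)Oc1ccccc1C(=O)O', 'c1ccccc1O']))
    ids = tk.parse('c1ccccc1O')    # [<sos>, token ids..., <eos>]
    print(ids)
    print(tk.get_text([ids[1:]]))  # drop <sos>; get_text() stops at <eos>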
208 | class TestbedDataset(InMemoryDataset):
209 |     def __init__(self, root='/tmp', dataset='davis',
210 |                  xd=None, xt=None, transform=None,
211 |                  pre_transform=None, smile_graph=None):
212 | 
213 |         super(TestbedDataset, self).__init__(root, transform, pre_transform)
214 |         self.dataset = dataset
215 |         self.data_l = []
216 |         self.pad_token = Tokenizer.SPECIAL_TOKENS.index('<pad>')
217 |         if os.path.isfile(self.processed_paths[0]):
218 |             print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0]))
219 |             self.data, self.slices = torch.load(self.processed_paths[0])
220 |         else:
221 |             print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0]))
222 |             self.process(xd, xt, smile_graph)
223 |             self.data, self.slices = torch.load(self.processed_paths[0])
224 | 
225 |     @property
226 |     def raw_file_names(self):
227 |         pass
228 |         # return ['some_file_1', 'some_file_2', ...]
229 | 
230 |     @property
231 |     def processed_file_names(self):
232 |         return [self.dataset + '.pt']
233 | 
234 |     def download(self):
235 |         # Download to `self.raw_dir`.
236 |         pass
237 | 
238 |     def _download(self):
239 |         pass
240 | 
241 |     def _process(self):
242 |         if not os.path.exists(self.processed_dir):
243 |             os.makedirs(self.processed_dir)
244 | 
245 |     def process(self, xd, xt, smile_graph):
246 |         assert len(xd) == len(xt) and len(xt), "The two lists must be the same length!"
247 | 
248 |         data_list = []
249 |         data_len = len(xd)
250 |         for i in range(data_len):
251 |             print('Preparing data in PyTorch format: {}/{}'.format(i + 1, data_len))
252 |             smiles = xd[i]
253 |             target = xt[i]
254 |             c_size, features, edge_index, edge_feats = smile_graph[smiles]
255 |             GCNData = DATA.Data(x=torch.Tensor(np.array(features)),
256 |                                 edge_index=torch.LongTensor(edge_index).transpose(1, 0),
257 |                                 edge_attr=torch.Tensor(np.array(edge_feats)))
258 |             GCNData.target = torch.from_numpy(target).long()
259 |             GCNData.__setitem__('c_size', torch.LongTensor([c_size]))
260 |             data_list.append(GCNData)
261 | 
262 |         if self.pre_filter is not None:
263 |             data_list = [data for data in data_list if self.pre_filter(data)]
264 | 
265 |         if self.pre_transform is not None:
266 |             data_list = [self.pre_transform(data) for data in data_list]
267 |         print('Data preparation done. Saving to file.')
268 |         data, slices = self.collate(data_list)
269 |         torch.save((data, slices), self.processed_paths[0])
270 | 
271 |         self.data_l = data_list
272 | 
273 | 
274 |     def __len__(self):
275 |         return len(self.data_l)
276 | 
277 |     def __getitem__(self, idx):
278 |         return self.data_l[idx]
279 | 
280 | class TestbedDataset2(InMemoryDataset):
281 |     def __init__(self, root='/tmp', dataset='davis',
282 |                  xd=None, xt=None, y=None, transform=None,
283 |                  pre_transform=None, smile_graph=None):
284 | 
285 |         super(TestbedDataset2, self).__init__(root, transform, pre_transform)
286 |         self.dataset = dataset
287 |         self.data_l = []
288 |         self.pad_token = Tokenizer.SPECIAL_TOKENS.index('<pad>')
289 |         if os.path.isfile(self.processed_paths[0]):
290 |             print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0]))
291 |             self.data, self.slices = torch.load(self.processed_paths[0])
292 |         else:
293 |             print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0]))
294 |             self.process(xd, xt, y, smile_graph)
295 |             self.data, self.slices = torch.load(self.processed_paths[0])
296 | 
297 |     @property
298 |     def raw_file_names(self):
299 |         pass
300 |         # return ['some_file_1', 'some_file_2', ...]
301 | 
302 |     @property
303 |     def processed_file_names(self):
304 |         return [self.dataset + '.pt']
305 | 
306 |     def download(self):
307 |         # Download to `self.raw_dir`.
308 |         pass
309 | 
310 |     def _download(self):
311 |         pass
312 | 
313 |     def _process(self):
314 |         if not os.path.exists(self.processed_dir):
315 |             os.makedirs(self.processed_dir)
316 | 
317 |     def process(self, xd, xt, y, smile_graph):
318 |         assert len(xd) == len(xt) == len(y) and len(xt), "The three lists must be the same length!"
319 | 
320 |         data_list = []
321 |         data_len = len(xd)
322 |         for i in range(data_len):
323 |             print('Preparing data in PyTorch format: {}/{}'.format(i + 1, data_len))
324 |             smiles = xd[i]
325 |             target = xt[i]
326 |             labels = y[i]
327 |             c_size, features, edge_index, edge_feats = smile_graph[smiles]
328 |             GCNData = DATA.Data(x=torch.Tensor(np.array(features)),
329 |                                 edge_index=torch.LongTensor(edge_index).transpose(1, 0),
330 |                                 edge_attr=torch.Tensor(np.array(edge_feats)),
331 |                                 y=torch.FloatTensor([labels]))
332 |             GCNData.target = torch.from_numpy(target).long()
333 |             GCNData.__setitem__('c_size', torch.LongTensor([c_size]))
334 |             data_list.append(GCNData)
335 | 
336 |         if self.pre_filter is not None:
337 |             data_list = [data for data in data_list if self.pre_filter(data)]
338 | 
339 |         if self.pre_transform is not None:
340 |             data_list = [self.pre_transform(data) for data in data_list]
341 |         print('Data preparation done. 
Saving to file.') 342 | data, slices = self.collate(data_list) 343 | torch.save((data, slices), self.processed_paths[0]) 344 | 345 | self.data_l = data_list 346 | 347 | 348 | def __len__(self): 349 | return len(self.data_l) 350 | 351 | def __getitem__(self, idx): 352 | return self.data_l[idx] 353 | 354 | #prepare the protein and drug pairs 355 | def collate(data_list): 356 | batchA = Batch.from_data_list([data for data in data_list]) 357 | return batchA -------------------------------------------------------------------------------- /DEMO/required_files_for_demo/model_aff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, global_max_pool as gmp 5 | from typing import Optional, Dict 6 | import math 7 | from fairseq.models import FairseqIncrementalDecoder 8 | from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer 9 | from torch.nn.utils.rnn import pad_sequence 10 | from utils import Tokenizer 11 | from einops.layers.torch import Rearrange 12 | 13 | class PositionalEncoding(nn.Module): 14 | 15 | def __init__(self, d_model, dropout=0.1, max_len=5000): 16 | super().__init__() 17 | self.dropout = nn.Dropout(p=dropout) 18 | 19 | pe = torch.zeros(max_len, d_model) 20 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 21 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 22 | pe[:, 0::2] = torch.sin(position * div_term) 23 | pe[:, 1::2] = torch.cos(position * div_term) 24 | pe = pe.unsqueeze(0).transpose(0, 1) 25 | self.register_buffer('pe', pe) 26 | 27 | def forward(self, x): 28 | r"""Inputs of forward function 29 | Args: 30 | x: the sequence fed to the positional encoder model (required). 
31 | Shape: 32 | x: [sequence length, batch size, embed dim] 33 | output: [sequence length, batch size, embed dim] 34 | Examples: 35 | >>> output = pos_encoder(x) 36 | """ 37 | 38 | x = x + self.pe[:x.size(0), :] 39 | return self.dropout(x) 40 | 41 | class Namespace: 42 | def __init__(self, argvs): 43 | for k, v in argvs.items(): 44 | setattr(self, k, v) 45 | 46 | 47 | class TransformerEncoder(nn.Module): 48 | def __init__(self, dim, ff_dim, num_head, num_layer): 49 | super().__init__() 50 | 51 | self.layer = nn.ModuleList([ 52 | TransformerEncoderLayer(Namespace({ 53 | 'encoder_embed_dim': dim, 54 | 'encoder_attention_heads': num_head, 55 | 'attention_dropout': 0.1, 56 | 'dropout': 0.1, 57 | 'encoder_normalize_before': True, 58 | 'encoder_ffn_embed_dim': ff_dim, 59 | })) for i in range(num_layer) 60 | ]) 61 | 62 | self.layer_norm = nn.LayerNorm(dim) 63 | 64 | def forward(self, x, encoder_padding_mask=None): 65 | # print('my TransformerDecode forward()') 66 | for layer in self.layer: 67 | x = layer(x, encoder_padding_mask) 68 | x = self.layer_norm(x) 69 | return x # T x B x C 70 | 71 | class Encoder(torch.nn.Module): 72 | def __init__(self, Drug_Features, dropout, Final_dim): 73 | super(Encoder, self).__init__() 74 | self.hidden_dim = 376 75 | self.GraphConv1 = GCNConv(Drug_Features, Drug_Features * 2) 76 | self.GraphConv2 = GCNConv(Drug_Features * 2, Drug_Features * 3) 77 | self.GraphConv3 = GCNConv(Drug_Features * 3, Drug_Features * 4) 78 | self.cond = nn.Linear(96 * 107, self.hidden_dim) 79 | self.cond2 = nn.Linear(451, self.hidden_dim) 80 | 81 | self.mean = nn.Sequential( 82 | nn.Linear(self.hidden_dim, self.hidden_dim), 83 | nn.ReLU(), 84 | nn.Linear(self.hidden_dim, self.hidden_dim)) 85 | self.var = nn.Sequential( 86 | nn.Linear(self.hidden_dim, self.hidden_dim), 87 | nn.ReLU(), 88 | nn.Linear(self.hidden_dim, self.hidden_dim)) 89 | 90 | self.Drug_FCs = nn.Sequential( 91 | nn.Linear(Drug_Features * 4, 1024), 92 | nn.ReLU(), 93 | nn.Dropout(dropout), 94 | nn.Linear(1024, Final_dim) 95 | ) 96 | self.Relu_activation = nn.ReLU() 97 | self.pp_seg_encoding = nn.Parameter(torch.randn(376)) 98 | 99 | def reparameterize(self, z_mean, logvar, batch, con): 100 | # Compute the KL divergence loss 101 | z_log_var = -torch.abs(logvar) 102 | kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / 64 103 | 104 | # Reparameterization trick: sample from N(0, 1) 105 | epsilon = torch.randn_like(z_mean).to(z_mean.device) 106 | z_ = z_mean + torch.exp(z_log_var / 2) * epsilon 107 | 108 | # Reshape and apply conditioning 109 | con = con.view(-1, 96 * 107) 110 | con_embedding = self.cond(con) 111 | 112 | # Add conditioning and auxiliary factor 113 | z_ = z_ + con_embedding 114 | 115 | return z_, kl_loss 116 | 117 | 118 | def process_p(self, node_features, num_nodes, batch_size): 119 | # Convert graph features to sequence 120 | d_node_features = pad_sequence(torch.split(node_features, num_nodes.tolist()), batch_first=False, padding_value=-999) 121 | # Initialize padded sequence with -999 122 | padded_sequence = d_node_features.new_ones((d_node_features.shape[0], 123 | d_node_features.shape[1], 124 | d_node_features.shape[2])) * -999 125 | # Fill padded sequence with node features 126 | padded_sequence[:d_node_features.shape[0], :, :] = d_node_features 127 | d_node_features = padded_sequence 128 | # Create mask for padding positions 129 | padding_mask = (d_node_features[:, :, 0].T == -999).bool() 130 | padded_sequence_with_encoding = d_node_features + self.pp_seg_encoding 131 | return 
padded_sequence_with_encoding, padding_mask 132 | 133 | def forward(self, data, con): 134 | x, edge_index, batch, num_nodes, affinity = data.x, data.edge_index, data.batch, data.c_size, data.y 135 | GCNConv = self.GraphConv1(x, edge_index) 136 | GCNConv = self.Relu_activation(GCNConv) 137 | GCNConv = self.GraphConv2(GCNConv, edge_index) 138 | GCNConv = self.Relu_activation(GCNConv) 139 | PMVO = self.GraphConv3(GCNConv, edge_index) 140 | x = self.Relu_activation(PMVO) 141 | d_sequence, Mask = self.process_p(x, num_nodes, batch) 142 | mu = self.mean(d_sequence) 143 | logvar = self.var(d_sequence) 144 | AMVO, kl_loss = self.reparameterize(mu, logvar, batch, con) 145 | x2 = gmp(x, batch) 146 | PMVO = self.Drug_FCs(x2) 147 | return d_sequence, AMVO, Mask, PMVO, kl_loss 148 | 149 | 150 | class Decoder(nn.Module): 151 | def __init__(self, dim, ff_dim, num_head, num_layer): 152 | super().__init__() 153 | 154 | self.layer = nn.ModuleList([ 155 | TransformerDecoderLayer(Namespace({ 156 | 'decoder_embed_dim': dim, 157 | 'decoder_attention_heads': num_head, 158 | 'attention_dropout': 0.1, 159 | 'dropout': 0.1, 160 | 'decoder_normalize_before': True, 161 | 'decoder_ffn_embed_dim': ff_dim, 162 | })) for i in range(num_layer) 163 | ]) 164 | self.layer_norm = nn.LayerNorm(dim) 165 | 166 | def forward(self, x, mem, x_mask=None, x_padding_mask=None, mem_padding_mask=None): 167 | for layer in self.layer: 168 | x = layer(x, mem, 169 | self_attn_mask=x_mask, self_attn_padding_mask=x_padding_mask, 170 | encoder_padding_mask=mem_padding_mask)[0] 171 | x = self.layer_norm(x) 172 | return x 173 | 174 | @torch.jit.export 175 | def forward_one(self, 176 | x: torch.Tensor, 177 | mem: torch.Tensor, 178 | incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]], 179 | mem_padding_mask: torch.BoolTensor = None, 180 | ) -> torch.Tensor: 181 | x = x[-1:] 182 | for layer in self.layer: 183 | x = layer(x, mem, incremental_state=incremental_state, encoder_padding_mask=mem_padding_mask)[0] 184 | x = self.layer_norm(x) 185 | return x 186 | 187 | 188 | class GatedCNN(nn.Module): 189 | def __init__(self, Protein_Features, Num_Filters, Embed_dim, Final_dim, K_size): 190 | super(GatedCNN, self).__init__() 191 | self.Protein_Embed = nn.Embedding(Protein_Features + 1, Embed_dim) 192 | self.Protein_Conv1 = nn.Conv1d(in_channels=1000, out_channels=Num_Filters, kernel_size=K_size) 193 | self.Protein_Gate1 = nn.Conv1d(in_channels=1000, out_channels=Num_Filters, kernel_size=K_size) 194 | self.Protein_Conv2 = nn.Conv1d(in_channels=Num_Filters, out_channels=Num_Filters * 2, kernel_size=K_size) 195 | self.Protein_Gate2 = nn.Conv1d(in_channels=Num_Filters, out_channels=Num_Filters * 2, kernel_size=K_size) 196 | self.Protein_Conv3 = nn.Conv1d(in_channels=Num_Filters * 2, out_channels=Num_Filters * 3, kernel_size=K_size) 197 | self.Protein_Gate3 = nn.Conv1d(in_channels=Num_Filters * 2, out_channels=Num_Filters * 3, kernel_size=K_size) 198 | self.relu = nn.ReLU() 199 | self.Protein_FC = nn.Linear(96 * 107, Final_dim) 200 | 201 | def forward(self, data): 202 | target = data.target 203 | Embed = self.Protein_Embed(target) 204 | conv1 = self.Protein_Conv1(Embed) 205 | gate1 = torch.sigmoid(self.Protein_Gate1(Embed)) 206 | GCNN1_Output = conv1 * gate1 207 | GCNN1_Output = self.relu(GCNN1_Output) 208 | #GATED CNN 2ND LAYER 209 | conv2 = self.Protein_Conv2(GCNN1_Output) 210 | gate2 = torch.sigmoid(self.Protein_Gate2(GCNN1_Output)) 211 | GCNN2_Output = conv2 * gate2 212 | GCNN2_Output = self.relu(GCNN2_Output) 213 | #GATED CNN 3RD 
LAYER
214 |         conv3 = self.Protein_Conv3(GCNN2_Output)
215 |         gate3 = torch.sigmoid(self.Protein_Gate3(GCNN2_Output))
216 |         GCNN3_Output = conv3 * gate3
217 |         GCNN3_Output = self.relu(GCNN3_Output)
218 |         #FLAT TENSOR
219 |         xt = GCNN3_Output.view(-1, 96 * 107)  # 96 = 3 * Num_Filters channels, 107 = 128 - 3 * (K_size - 1) positions
220 |         #PROTEIN FULLY CONNECTED LAYER
221 |         xt = self.Protein_FC(xt)
222 |         return xt, GCNN3_Output
223 | 
224 | class FC(torch.nn.Module):
225 |     def __init__(self, output_dim, n_output, dropout):
226 |         super(FC, self).__init__()
227 |         self.FC_layers = nn.Sequential(
228 |             nn.Linear(output_dim * 2, 1024),
229 |             nn.ReLU(),
230 |             nn.Dropout(dropout),
231 |             nn.Linear(1024, 512),
232 |             nn.ReLU(),
233 |             nn.Dropout(dropout),
234 |             nn.Linear(512, 256),
235 |             nn.ReLU(),
236 |             nn.Dropout(dropout),
237 |             nn.Linear(256, n_output)
238 |         )
239 | 
240 |     def forward(self, Drug_Features, Protein_Features):
241 |         Combined = torch.cat((Drug_Features, Protein_Features), 1)
242 |         prediction = self.FC_layers(Combined)
243 |         return prediction
244 | 
245 | # Main class
246 | class DeepDTAGen(torch.nn.Module):
247 |     def __init__(self, tokenizer):
248 |         super(DeepDTAGen, self).__init__()
249 |         self.hidden_dim = 376
250 |         self.max_len = 128
251 |         self.node_feature = 94
252 |         self.output_dim = 128
253 |         self.ff_dim = 1024
254 |         self.heads = 8
255 |         self.layers = 8
256 |         self.encoder_dropout = 0.2
257 |         self.dropout = 0.3
258 |         self.protein_f = 25
259 |         self.filters = 32
260 |         self.kernel = 8
261 |         # Encoder, Decoder, and related components
262 |         self.encoder = Encoder(Drug_Features=self.node_feature, dropout=self.encoder_dropout, Final_dim=self.output_dim)
263 |         self.decoder = Decoder(dim=self.hidden_dim, ff_dim=self.ff_dim, num_head=self.heads, num_layer=self.layers)
264 |         self.dencoder = TransformerEncoder(dim=self.hidden_dim, ff_dim=self.ff_dim, num_head=self.heads, num_layer=self.layers)
265 |         self.pos_encoding = PositionalEncoding(self.hidden_dim, max_len=138)
266 | 
267 |         # CNN for processing protein features
268 |         self.cnn = GatedCNN(Protein_Features=self.protein_f, Num_Filters=self.filters,
269 |                             Embed_dim=self.output_dim, Final_dim=self.output_dim, K_size=self.kernel)
270 | 
271 |         # Fully connected layer
272 |         self.fc = FC(output_dim=self.output_dim, n_output=1, dropout=self.dropout)
273 | 
274 |         # Learnable parameter for segment encoding
275 |         self.zz_seg_encoding = nn.Parameter(torch.randn(self.hidden_dim))
276 | 
277 |         # Word prediction layers
278 |         vocab_size = len(tokenizer)
279 |         self.word_pred = nn.Sequential(
280 |             nn.Linear(self.hidden_dim, self.hidden_dim),
281 |             nn.PReLU(),
282 |             nn.LayerNorm(self.hidden_dim),
283 |             nn.Linear(self.hidden_dim, vocab_size)
284 |         )
285 |         torch.nn.init.zeros_(self.word_pred[3].bias)
286 | 
287 |         # Other properties
288 |         self.vocab_size = vocab_size
289 |         self.sos_value = tokenizer.s2i['<sos>']
290 |         self.eos_value = tokenizer.s2i['<eos>']
291 |         self.pad_value = tokenizer.s2i['<pad>']
292 |         self.word_embed = nn.Embedding(vocab_size, self.hidden_dim)
293 |         self.unk_index = Tokenizer.SPECIAL_TOKENS.index('<unk>')
294 | 
295 |         # Expand and Fusion layers
296 |         self.expand = nn.Sequential(
297 |             nn.Linear(self.hidden_dim, self.hidden_dim),
298 |             nn.ReLU(),
299 |             nn.LayerNorm(self.hidden_dim),
300 |             nn.Linear(self.hidden_dim, self.hidden_dim),
301 |             Rearrange('batch_size h -> 1 batch_size h')
302 |         )
303 | 
304 |     def expand_then_fusing(self, z, pp_mask, vvs):
305 |         # Expand and fuse latent variables
306 |         zz = z
307 |         # zz = self.expand(z)
308 |         zzs = zz + self.zz_seg_encoding
309 | 
310 |         # Create full mask for decoder
311 |         full_mask = 
zz.new_zeros(zz.shape[1], zz.shape[0]) 312 | full_mask = torch.cat((pp_mask, full_mask), dim=1) # batch seq_plus 313 | 314 | # Concatenate latent variables and segment encoding 315 | zzz = torch.cat((vvs, zzs), dim=0) # seq_plus batch feat 316 | 317 | # Encode the concatenated sequence 318 | zzz = self.dencoder(zzz, full_mask) 319 | 320 | return zzz, full_mask 321 | 322 | def sample(self, batch_size, device): 323 | z = torch.randn(1, self.hidden_dim).to(device) 324 | return z 325 | 326 | def forward(self, data): 327 | # Process protein features through CNN 328 | Protein_vector, con = self.cnn(data) 329 | 330 | # Encode the input graph 331 | vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 332 | 333 | # Expand and fuse latent variables 334 | zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 335 | 336 | Pridection = self.fc(PMVO, Protein_vector) 337 | 338 | return Pridection 339 | 340 | 341 | def _generate(self, zzz, encoder_mask, random_sample, return_score=False): 342 | # Determine the batch size and device 343 | batch_size = zzz.shape[1] 344 | device = zzz.device 345 | 346 | # Initialize token tensor for sequence generation 347 | token = torch.full((batch_size, self.max_len), self.pad_value, dtype=torch.long, device=device) 348 | token[:, 0] = self.sos_value 349 | 350 | # Initialize positional encoding for text 351 | text_pos = self.pos_encoding.pe 352 | 353 | # Initialize text embedding for the first token 354 | text_embed = self.word_embed(token[:, 0]) 355 | text_embed = text_embed + text_pos[0] 356 | text_embed = text_embed.unsqueeze(0) 357 | 358 | # Initialize incremental state for transformer decoder 359 | incremental_state = torch.jit.annotate( 360 | Dict[str, Dict[str, Optional[torch.Tensor]]], 361 | torch.jit.annotate(Dict[str, Dict[str, Optional[torch.Tensor]]], {}), 362 | ) 363 | 364 | # Initialize scores if return_score is True 365 | if return_score: 366 | scores = [] 367 | 368 | # Initialize flag for finished sequences 369 | finished = torch.zeros(batch_size, dtype=torch.bool, device=device) 370 | 371 | # Loop for sequence generation 372 | for t in range(1, self.max_len): 373 | # Decode one token at a time 374 | one = self.decoder.forward_one(text_embed, zzz, incremental_state, mem_padding_mask=encoder_mask) 375 | one = one.squeeze(0) 376 | l = self.word_pred(one) # Get predicted scores for tokens 377 | 378 | # Append scores if return_score is True 379 | if return_score: 380 | scores.append(l) 381 | 382 | # Sample the next token either randomly or by choosing the one with maximum probability 383 | if random_sample: 384 | k = torch.multinomial(torch.softmax(l, 1), 1).squeeze(1) 385 | else: 386 | k = torch.argmax(l, -1) # Predict token with maximum probability 387 | 388 | # Update token tensor with the predicted token 389 | token[:, t] = k 390 | 391 | # Check if sequences are finished 392 | finished |= k == self.eos_value 393 | if finished.all(): 394 | break 395 | 396 | # Update text embedding for the next token 397 | text_embed = self.word_embed(k) 398 | text_embed = text_embed + text_pos[t] # Add positional encoding 399 | text_embed = text_embed.unsqueeze(0) 400 | 401 | # Extract the predicted sequence 402 | predict = token[:, 1:] 403 | 404 | # Return predicted sequence along with scores if return_score is True 405 | if return_score: 406 | return predict, torch.stack(scores, dim=1) 407 | return predict 408 | 409 | def generate(self, data, random_sample=False, return_z=False): 410 | # use protein condition from GatedCNN 411 | _, con = self.cnn(data) 412 | 413 
| # Encode the input graph 414 | vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 415 | 416 | # Sample latent variables 417 | z = self.sample(data.batch, device=vss.device) 418 | 419 | # z = z + con1 + con2 420 | 421 | # zzz, encoder_mask = self.expand_then_fusing(z, mask, vss) 422 | 423 | # Expand and fuse latent variables 424 | zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 425 | 426 | # Generate sequence based on latent variables 427 | predict = self._generate(zzz, encoder_mask, random_sample=random_sample, return_score=False) 428 | 429 | # Return predicted sequence along with latent variables if specified 430 | if return_z: 431 | return predict, z.detach().cpu().numpy() 432 | return predict 433 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, global_max_pool as gmp 5 | from typing import Optional, Dict 6 | import math 7 | from fairseq.models import FairseqIncrementalDecoder 8 | from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer 9 | from torch.nn.utils.rnn import pad_sequence 10 | from utils import Tokenizer 11 | from einops.layers.torch import Rearrange 12 | 13 | class PositionalEncoding(nn.Module): 14 | 15 | def __init__(self, d_model, dropout=0.1, max_len=5000): 16 | super().__init__() 17 | self.dropout = nn.Dropout(p=dropout) 18 | 19 | pe = torch.zeros(max_len, d_model) 20 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 21 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 22 | pe[:, 0::2] = torch.sin(position * div_term) 23 | pe[:, 1::2] = torch.cos(position * div_term) 24 | pe = pe.unsqueeze(0).transpose(0, 1) 25 | self.register_buffer('pe', pe) 26 | 27 | def forward(self, x): 28 | r"""Inputs of forward function 29 | Args: 30 | x: the sequence fed to the positional encoder model (required). 
31 | Shape: 32 | x: [sequence length, batch size, embed dim] 33 | output: [sequence length, batch size, embed dim] 34 | Examples: 35 | >>> output = pos_encoder(x) 36 | """ 37 | 38 | x = x + self.pe[:x.size(0), :] 39 | return self.dropout(x) 40 | 41 | class Namespace: 42 | def __init__(self, argvs): 43 | for k, v in argvs.items(): 44 | setattr(self, k, v) 45 | 46 | 47 | class TransformerEncoder(nn.Module): 48 | def __init__(self, dim, ff_dim, num_head, num_layer): 49 | super().__init__() 50 | 51 | self.layer = nn.ModuleList([ 52 | TransformerEncoderLayer(Namespace({ 53 | 'encoder_embed_dim': dim, 54 | 'encoder_attention_heads': num_head, 55 | 'attention_dropout': 0.1, 56 | 'dropout': 0.1, 57 | 'encoder_normalize_before': True, 58 | 'encoder_ffn_embed_dim': ff_dim, 59 | })) for i in range(num_layer) 60 | ]) 61 | 62 | self.layer_norm = nn.LayerNorm(dim) 63 | 64 | def forward(self, x, encoder_padding_mask=None): 65 | # print('my TransformerDecode forward()') 66 | for layer in self.layer: 67 | x = layer(x, encoder_padding_mask) 68 | x = self.layer_norm(x) 69 | return x # T x B x C 70 | 71 | class Encoder(torch.nn.Module): 72 | def __init__(self, Drug_Features, dropout, Final_dim): 73 | super(Encoder, self).__init__() 74 | self.hidden_dim = 376 75 | self.GraphConv1 = GCNConv(Drug_Features, Drug_Features * 2) 76 | self.GraphConv2 = GCNConv(Drug_Features * 2, Drug_Features * 3) 77 | self.GraphConv3 = GCNConv(Drug_Features * 3, Drug_Features * 4) 78 | self.cond = nn.Linear(96 * 107, self.hidden_dim) 79 | self.cond2 = nn.Linear(451, self.hidden_dim) 80 | 81 | self.mean = nn.Sequential( 82 | nn.Linear(self.hidden_dim, self.hidden_dim), 83 | nn.ReLU(), 84 | nn.Linear(self.hidden_dim, self.hidden_dim)) 85 | self.var = nn.Sequential( 86 | nn.Linear(self.hidden_dim, self.hidden_dim), 87 | nn.ReLU(), 88 | nn.Linear(self.hidden_dim, self.hidden_dim)) 89 | 90 | self.Drug_FCs = nn.Sequential( 91 | nn.Linear(Drug_Features * 4, 1024), 92 | nn.ReLU(), 93 | nn.Dropout(dropout), 94 | nn.Linear(1024, Final_dim) 95 | ) 96 | self.Relu_activation = nn.ReLU() 97 | self.pp_seg_encoding = nn.Parameter(torch.randn(376)) 98 | 99 | def reparameterize(self, z_mean, logvar, batch, con, a): 100 | # Compute the KL divergence loss 101 | z_log_var = -torch.abs(logvar) 102 | kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / 64 103 | 104 | # Reparameterization trick: sample from N(0, 1) 105 | epsilon = torch.randn_like(z_mean).to(z_mean.device) 106 | z_ = z_mean + torch.exp(z_log_var / 2) * epsilon 107 | 108 | # Reshape and apply conditioning 109 | con = con.view(-1, 96 * 107) 110 | con_embedding = self.cond(con) 111 | 112 | # Add conditioning and auxiliary factor 113 | z_ = z_ + con_embedding + a 114 | 115 | return z_, kl_loss 116 | 117 | 118 | def process_p(self, node_features, num_nodes, batch_size): 119 | # Convert graph features to sequence 120 | d_node_features = pad_sequence(torch.split(node_features, num_nodes.tolist()), batch_first=False, padding_value=-999) 121 | # Initialize padded sequence with -999 122 | padded_sequence = d_node_features.new_ones((d_node_features.shape[0], 123 | d_node_features.shape[1], 124 | d_node_features.shape[2])) * -999 125 | # Fill padded sequence with node features 126 | padded_sequence[:d_node_features.shape[0], :, :] = d_node_features 127 | d_node_features = padded_sequence 128 | # Create mask for padding positions 129 | padding_mask = (d_node_features[:, :, 0].T == -999).bool() 130 | padded_sequence_with_encoding = d_node_features + self.pp_seg_encoding 131 | 
return padded_sequence_with_encoding, padding_mask 132 | 133 | def forward(self, data, con): 134 | x, edge_index, batch, num_nodes, affinity = data.x, data.edge_index, data.batch, data.c_size, data.y 135 | a = affinity.view(-1, 1) 136 | GCNConv = self.GraphConv1(x, edge_index) 137 | GCNConv = self.Relu_activation(GCNConv) 138 | GCNConv = self.GraphConv2(GCNConv, edge_index) 139 | GCNConv = self.Relu_activation(GCNConv) 140 | PMVO = self.GraphConv3(GCNConv, edge_index) 141 | x = self.Relu_activation(PMVO) 142 | d_sequence, Mask = self.process_p(x, num_nodes, batch) 143 | mu = self.mean(d_sequence) 144 | logvar = self.var(d_sequence) 145 | AMVO, kl_loss = self.reparameterize(mu, logvar, batch, con, a) 146 | x2 = gmp(x, batch) 147 | PMVO = self.Drug_FCs(x2) 148 | return d_sequence, AMVO, Mask, PMVO, kl_loss 149 | 150 | 151 | class Decoder(nn.Module): 152 | def __init__(self, dim, ff_dim, num_head, num_layer): 153 | super().__init__() 154 | 155 | self.layer = nn.ModuleList([ 156 | TransformerDecoderLayer(Namespace({ 157 | 'decoder_embed_dim': dim, 158 | 'decoder_attention_heads': num_head, 159 | 'attention_dropout': 0.1, 160 | 'dropout': 0.1, 161 | 'decoder_normalize_before': True, 162 | 'decoder_ffn_embed_dim': ff_dim, 163 | })) for i in range(num_layer) 164 | ]) 165 | self.layer_norm = nn.LayerNorm(dim) 166 | 167 | def forward(self, x, mem, x_mask=None, x_padding_mask=None, mem_padding_mask=None): 168 | for layer in self.layer: 169 | x = layer(x, mem, 170 | self_attn_mask=x_mask, self_attn_padding_mask=x_padding_mask, 171 | encoder_padding_mask=mem_padding_mask)[0] 172 | x = self.layer_norm(x) 173 | return x 174 | 175 | @torch.jit.export 176 | def forward_one(self, 177 | x: torch.Tensor, 178 | mem: torch.Tensor, 179 | incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]], 180 | mem_padding_mask: torch.BoolTensor = None, 181 | ) -> torch.Tensor: 182 | x = x[-1:] 183 | for layer in self.layer: 184 | x = layer(x, mem, incremental_state=incremental_state, encoder_padding_mask=mem_padding_mask)[0] 185 | x = self.layer_norm(x) 186 | return x 187 | 188 | 189 | class GatedCNN(nn.Module): 190 | def __init__(self, Protein_Features, Num_Filters, Embed_dim, Final_dim, K_size): 191 | super(GatedCNN, self).__init__() 192 | self.Protein_Embed = nn.Embedding(Protein_Features + 1, Embed_dim) 193 | self.Protein_Conv1 = nn.Conv1d(in_channels=1000, out_channels=Num_Filters, kernel_size=K_size) 194 | self.Protein_Gate1 = nn.Conv1d(in_channels=1000, out_channels=Num_Filters, kernel_size=K_size) 195 | self.Protein_Conv2 = nn.Conv1d(in_channels=Num_Filters, out_channels=Num_Filters * 2, kernel_size=K_size) 196 | self.Protein_Gate2 = nn.Conv1d(in_channels=Num_Filters, out_channels=Num_Filters * 2, kernel_size=K_size) 197 | self.Protein_Conv3 = nn.Conv1d(in_channels=Num_Filters * 2, out_channels=Num_Filters * 3, kernel_size=K_size) 198 | self.Protein_Gate3 = nn.Conv1d(in_channels=Num_Filters * 2, out_channels=Num_Filters * 3, kernel_size=K_size) 199 | self.relu = nn.ReLU() 200 | self.Protein_FC = nn.Linear(96 * 107, Final_dim) 201 | 202 | def forward(self, data): 203 | target = data.target 204 | Embed = self.Protein_Embed(target) 205 | conv1 = self.Protein_Conv1(Embed) 206 | gate1 = torch.sigmoid(self.Protein_Gate1(Embed)) 207 | GCNN1_Output = conv1 * gate1 208 | GCNN1_Output = self.relu(GCNN1_Output) 209 | #GATED CNN 2ND LAYER 210 | conv2 = self.Protein_Conv2(GCNN1_Output) 211 | gate2 = torch.sigmoid(self.Protein_Gate2(GCNN1_Output)) 212 | GCNN2_Output = conv2 * gate2 213 | GCNN2_Output = 
self.relu(GCNN2_Output)
214 |         #GATED CNN 3RD LAYER
215 |         conv3 = self.Protein_Conv3(GCNN2_Output)
216 |         gate3 = torch.sigmoid(self.Protein_Gate3(GCNN2_Output))
217 |         GCNN3_Output = conv3 * gate3
218 |         GCNN3_Output = self.relu(GCNN3_Output)
219 |         #FLAT TENSOR
220 |         xt = GCNN3_Output.view(-1, 96 * 107)  # 96 = 3 * Num_Filters channels, 107 = 128 - 3 * (K_size - 1) positions
221 |         #PROTEIN FULLY CONNECTED LAYER
222 |         xt = self.Protein_FC(xt)
223 |         return xt, GCNN3_Output
224 | 
225 | class FC(torch.nn.Module):
226 |     def __init__(self, output_dim, n_output, dropout):
227 |         super(FC, self).__init__()
228 |         self.FC_layers = nn.Sequential(
229 |             nn.Linear(output_dim * 2, 1024),
230 |             nn.ReLU(),
231 |             nn.Dropout(dropout),
232 |             nn.Linear(1024, 512),
233 |             nn.ReLU(),
234 |             nn.Dropout(dropout),
235 |             nn.Linear(512, 256),
236 |             nn.ReLU(),
237 |             nn.Dropout(dropout),
238 |             nn.Linear(256, n_output)
239 |         )
240 | 
241 |     def forward(self, Drug_Features, Protein_Features):
242 |         Combined = torch.cat((Drug_Features, Protein_Features), 1)
243 |         prediction = self.FC_layers(Combined)
244 |         return prediction
245 | 
246 | # Main class
247 | class DeepDTAGen(torch.nn.Module):
248 |     def __init__(self, tokenizer):
249 |         super(DeepDTAGen, self).__init__()
250 |         self.hidden_dim = 376
251 |         self.max_len = 128
252 |         self.node_feature = 94
253 |         self.output_dim = 128
254 |         self.ff_dim = 1024
255 |         self.heads = 8
256 |         self.layers = 8
257 |         self.encoder_dropout = 0.2
258 |         self.dropout = 0.3
259 |         self.protein_f = 25
260 |         self.filters = 32
261 |         self.kernel = 8
262 |         # Encoder, Decoder, and related components
263 |         self.encoder = Encoder(Drug_Features=self.node_feature, dropout=self.encoder_dropout, Final_dim=self.output_dim)
264 |         self.decoder = Decoder(dim=self.hidden_dim, ff_dim=self.ff_dim, num_head=self.heads, num_layer=self.layers)
265 |         self.dencoder = TransformerEncoder(dim=self.hidden_dim, ff_dim=self.ff_dim, num_head=self.heads, num_layer=self.layers)
266 |         self.pos_encoding = PositionalEncoding(self.hidden_dim, max_len=138)
267 | 
268 |         # CNN for processing protein features
269 |         self.cnn = GatedCNN(Protein_Features=self.protein_f, Num_Filters=self.filters,
270 |                             Embed_dim=self.output_dim, Final_dim=self.output_dim, K_size=self.kernel)
271 | 
272 |         # Fully connected layer
273 |         self.fc = FC(output_dim=self.output_dim, n_output=1, dropout=self.dropout)
274 | 
275 |         # Learnable parameter for segment encoding
276 |         self.zz_seg_encoding = nn.Parameter(torch.randn(self.hidden_dim))
277 | 
278 |         # Word prediction layers
279 |         vocab_size = len(tokenizer)
280 |         self.word_pred = nn.Sequential(
281 |             nn.Linear(self.hidden_dim, self.hidden_dim),
282 |             nn.PReLU(),
283 |             nn.LayerNorm(self.hidden_dim),
284 |             nn.Linear(self.hidden_dim, vocab_size)
285 |         )
286 |         torch.nn.init.zeros_(self.word_pred[3].bias)
287 | 
288 |         # Other properties
289 |         self.vocab_size = vocab_size
290 |         self.sos_value = tokenizer.s2i['<sos>']
291 |         self.eos_value = tokenizer.s2i['<eos>']
292 |         self.pad_value = tokenizer.s2i['<pad>']
293 |         self.word_embed = nn.Embedding(vocab_size, self.hidden_dim)
294 |         self.unk_index = Tokenizer.SPECIAL_TOKENS.index('<unk>')
295 | 
296 |         # Expand and Fusion layers
297 |         self.expand = nn.Sequential(
298 |             nn.Linear(self.hidden_dim, self.hidden_dim),
299 |             nn.ReLU(),
300 |             nn.LayerNorm(self.hidden_dim),
301 |             nn.Linear(self.hidden_dim, self.hidden_dim),
302 |             Rearrange('batch_size h -> 1 batch_size h')
303 |         )
304 | 
305 |     def expand_then_fusing(self, z, pp_mask, vvs):
306 |         # Expand and fuse latent variables
307 |         zz = z
308 |         # zz = self.expand(z)
309 |         zzs = zz + self.zz_seg_encoding
310 | 
311 |         # Create full mask 
for decoder 312 | full_mask = zz.new_zeros(zz.shape[1], zz.shape[0]) 313 | full_mask = torch.cat((pp_mask, full_mask), dim=1) # batch seq_plus 314 | 315 | # Concatenate latent variables and segment encoding 316 | zzz = torch.cat((vvs, zzs), dim=0) # seq_plus batch feat 317 | 318 | # Encode the concatenated sequence 319 | zzz = self.dencoder(zzz, full_mask) 320 | 321 | return zzz, full_mask 322 | 323 | def sample(self, batch_size, device): 324 | z = torch.randn(1, self.hidden_dim).to(device) 325 | return z 326 | 327 | def forward(self, data): 328 | # Process protein features through CNN 329 | Protein_vector, con = self.cnn(data) 330 | 331 | # Encode the input graph 332 | vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 333 | 334 | # Expand and fuse latent variables 335 | zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 336 | 337 | # Prepare target sequence for decoding 338 | targets = data.target_seq 339 | _, target_length = targets.shape 340 | target_mask = torch.triu(torch.ones(target_length, target_length, dtype=torch.bool), diagonal=1).to(targets.device) 341 | target_embed = self.word_embed(targets) 342 | target_embed = self.pos_encoding(target_embed.permute(1, 0, 2).contiguous()) 343 | 344 | # Decode the target sequence 345 | output = self.decoder(target_embed, zzz, x_mask=target_mask, mem_padding_mask=encoder_mask).permute(1, 0, 2).contiguous() 346 | prediction_scores = self.word_pred(output) # batch_size, sequence_length, class 347 | 348 | # Compute loss and predictions 349 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 350 | targets = targets[:, 1:].contiguous() 351 | batch_size, sequence_length, vocab_size = shifted_prediction_scores.size() 352 | shifted_prediction_scores = shifted_prediction_scores.view(-1, vocab_size) 353 | targets = targets.view(-1) 354 | 355 | Pridection = self.fc(PMVO, Protein_vector) 356 | lm_loss = F.cross_entropy(shifted_prediction_scores, targets, ignore_index=self.pad_value) 357 | 358 | return Pridection, prediction_scores, lm_loss, kl_loss 359 | 360 | 361 | def _generate(self, zzz, encoder_mask, random_sample, return_score=False): 362 | # Determine the batch size and device 363 | batch_size = zzz.shape[1] 364 | device = zzz.device 365 | 366 | # Initialize token tensor for sequence generation 367 | token = torch.full((batch_size, self.max_len), self.pad_value, dtype=torch.long, device=device) 368 | token[:, 0] = self.sos_value 369 | 370 | # Initialize positional encoding for text 371 | text_pos = self.pos_encoding.pe 372 | 373 | # Initialize text embedding for the first token 374 | text_embed = self.word_embed(token[:, 0]) 375 | text_embed = text_embed + text_pos[0] 376 | text_embed = text_embed.unsqueeze(0) 377 | 378 | # Initialize incremental state for transformer decoder 379 | incremental_state = torch.jit.annotate( 380 | Dict[str, Dict[str, Optional[torch.Tensor]]], 381 | torch.jit.annotate(Dict[str, Dict[str, Optional[torch.Tensor]]], {}), 382 | ) 383 | 384 | # Initialize scores if return_score is True 385 | if return_score: 386 | scores = [] 387 | 388 | # Initialize flag for finished sequences 389 | finished = torch.zeros(batch_size, dtype=torch.bool, device=device) 390 | 391 | # Loop for sequence generation 392 | for t in range(1, self.max_len): 393 | # Decode one token at a time 394 | one = self.decoder.forward_one(text_embed, zzz, incremental_state, mem_padding_mask=encoder_mask) 395 | one = one.squeeze(0) 396 | l = self.word_pred(one) # Get predicted scores for tokens 397 | 398 | # Append scores if 
return_score is True 399 | if return_score: 400 | scores.append(l) 401 | 402 | # Sample the next token either randomly or by choosing the one with maximum probability 403 | if random_sample: 404 | k = torch.multinomial(torch.softmax(l, 1), 1).squeeze(1) 405 | else: 406 | k = torch.argmax(l, -1) # Predict token with maximum probability 407 | 408 | # Update token tensor with the predicted token 409 | token[:, t] = k 410 | 411 | # Check if sequences are finished 412 | finished |= k == self.eos_value 413 | if finished.all(): 414 | break 415 | 416 | # Update text embedding for the next token 417 | text_embed = self.word_embed(k) 418 | text_embed = text_embed + text_pos[t] # Add positional encoding 419 | text_embed = text_embed.unsqueeze(0) 420 | 421 | # Extract the predicted sequence 422 | predict = token[:, 1:] 423 | 424 | # Return predicted sequence along with scores if return_score is True 425 | if return_score: 426 | return predict, torch.stack(scores, dim=1) 427 | return predict 428 | 429 | def generate(self, data, random_sample=False, return_z=False): 430 | # use protein condition from GatedCNN 431 | _, con = self.cnn(data) 432 | 433 | # Encode the input graph 434 | vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 435 | 436 | # Sample latent variables 437 | z = self.sample(data.batch, device=vss.device) 438 | 439 | # z = z + con1 + con2 440 | 441 | # zzz, encoder_mask = self.expand_then_fusing(z, mask, vss) 442 | 443 | # Expand and fuse latent variables 444 | zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 445 | 446 | # Generate sequence based on latent variables 447 | predict = self._generate(zzz, encoder_mask, random_sample=random_sample, return_score=False) 448 | 449 | # Return predicted sequence along with latent variables if specified 450 | if return_z: 451 | return predict, z.detach().cpu().numpy() 452 | return predict 453 | -------------------------------------------------------------------------------- /DEMO/required_files_for_demo/model_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, global_max_pool as gmp 5 | from typing import Optional, Dict 6 | import math 7 | from fairseq.models import FairseqIncrementalDecoder 8 | from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer 9 | from torch.nn.utils.rnn import pad_sequence 10 | from utils import Tokenizer 11 | from einops.layers.torch import Rearrange 12 | 13 | class PositionalEncoding(nn.Module): 14 | 15 | def __init__(self, d_model, dropout=0.1, max_len=5000): 16 | super().__init__() 17 | self.dropout = nn.Dropout(p=dropout) 18 | 19 | pe = torch.zeros(max_len, d_model) 20 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 21 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 22 | pe[:, 0::2] = torch.sin(position * div_term) 23 | pe[:, 1::2] = torch.cos(position * div_term) 24 | pe = pe.unsqueeze(0).transpose(0, 1) 25 | self.register_buffer('pe', pe) 26 | 27 | def forward(self, x): 28 | r"""Inputs of forward function 29 | Args: 30 | x: the sequence fed to the positional encoder model (required). 
31 | Shape: 32 | x: [sequence length, batch size, embed dim] 33 | output: [sequence length, batch size, embed dim] 34 | Examples: 35 | >>> output = pos_encoder(x) 36 | """ 37 | 38 | x = x + self.pe[:x.size(0), :] 39 | return self.dropout(x) 40 | 41 | class Namespace: 42 | def __init__(self, argvs): 43 | for k, v in argvs.items(): 44 | setattr(self, k, v) 45 | 46 | 47 | class TransformerEncoder(nn.Module): 48 | def __init__(self, dim, ff_dim, num_head, num_layer): 49 | super().__init__() 50 | 51 | self.layer = nn.ModuleList([ 52 | TransformerEncoderLayer(Namespace({ 53 | 'encoder_embed_dim': dim, 54 | 'encoder_attention_heads': num_head, 55 | 'attention_dropout': 0.1, 56 | 'dropout': 0.1, 57 | 'encoder_normalize_before': True, 58 | 'encoder_ffn_embed_dim': ff_dim, 59 | })) for i in range(num_layer) 60 | ]) 61 | 62 | self.layer_norm = nn.LayerNorm(dim) 63 | 64 | def forward(self, x, encoder_padding_mask=None): 65 | # print('my TransformerDecode forward()') 66 | for layer in self.layer: 67 | x = layer(x, encoder_padding_mask) 68 | x = self.layer_norm(x) 69 | return x # T x B x C 70 | 71 | class Encoder(torch.nn.Module): 72 | def __init__(self, Drug_Features, dropout, Final_dim): 73 | super(Encoder, self).__init__() 74 | self.hidden_dim = 376 75 | self.GraphConv1 = GCNConv(Drug_Features, Drug_Features * 2) 76 | self.GraphConv2 = GCNConv(Drug_Features * 2, Drug_Features * 3) 77 | self.GraphConv3 = GCNConv(Drug_Features * 3, Drug_Features * 4) 78 | self.cond = nn.Linear(96 * 107, self.hidden_dim) 79 | self.cond2 = nn.Linear(451, self.hidden_dim) 80 | 81 | self.mean = nn.Sequential( 82 | nn.Linear(self.hidden_dim, self.hidden_dim), 83 | nn.ReLU(), 84 | nn.Linear(self.hidden_dim, self.hidden_dim)) 85 | self.var = nn.Sequential( 86 | nn.Linear(self.hidden_dim, self.hidden_dim), 87 | nn.ReLU(), 88 | nn.Linear(self.hidden_dim, self.hidden_dim)) 89 | 90 | self.Drug_FCs = nn.Sequential( 91 | nn.Linear(Drug_Features * 4, 1024), 92 | nn.ReLU(), 93 | nn.Dropout(dropout), 94 | nn.Linear(1024, Final_dim) 95 | ) 96 | self.Relu_activation = nn.ReLU() 97 | self.pp_seg_encoding = nn.Parameter(torch.randn(376)) 98 | 99 | def reparameterize(self, z_mean, logvar, batch, con, a): 100 | # Compute the KL divergence loss 101 | z_log_var = -torch.abs(logvar) 102 | kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / 64 103 | 104 | # Reparameterization trick: sample from N(0, 1) 105 | epsilon = torch.randn_like(z_mean).to(z_mean.device) 106 | z_ = z_mean + torch.exp(z_log_var / 2) * epsilon 107 | 108 | # Reshape and apply conditioning 109 | con = con.view(-1, 96 * 107) 110 | con_embedding = self.cond(con) 111 | 112 | # Add conditioning and auxiliary factor 113 | z_ = z_ + con_embedding + a 114 | 115 | return z_, kl_loss 116 | 117 | 118 | def process_p(self, node_features, num_nodes, batch_size): 119 | # Convert graph features to sequence 120 | d_node_features = pad_sequence(torch.split(node_features, num_nodes.tolist()), batch_first=False, padding_value=-999) 121 | # Initialize padded sequence with -999 122 | padded_sequence = d_node_features.new_ones((d_node_features.shape[0], 123 | d_node_features.shape[1], 124 | d_node_features.shape[2])) * -999 125 | # Fill padded sequence with node features 126 | padded_sequence[:d_node_features.shape[0], :, :] = d_node_features 127 | d_node_features = padded_sequence 128 | # Create mask for padding positions 129 | padding_mask = (d_node_features[:, :, 0].T == -999).bool() 130 | padded_sequence_with_encoding = d_node_features + self.pp_seg_encoding 131 | 
return padded_sequence_with_encoding, padding_mask 132 | 133 | def forward(self, data, con): 134 | x, edge_index, batch, num_nodes, affinity = data.x, data.edge_index, data.batch, data.c_size, data.y 135 | a = affinity.view(-1, 1) 136 | GCNConv = self.GraphConv1(x, edge_index) 137 | GCNConv = self.Relu_activation(GCNConv) 138 | GCNConv = self.GraphConv2(GCNConv, edge_index) 139 | GCNConv = self.Relu_activation(GCNConv) 140 | PMVO = self.GraphConv3(GCNConv, edge_index) 141 | x = self.Relu_activation(PMVO) 142 | d_sequence, Mask = self.process_p(x, num_nodes, batch) 143 | mu = self.mean(d_sequence) 144 | logvar = self.var(d_sequence) 145 | AMVO, kl_loss = self.reparameterize(mu, logvar, batch, con, a) 146 | x2 = gmp(x, batch) 147 | PMVO = self.Drug_FCs(x2) 148 | return d_sequence, AMVO, Mask, PMVO, kl_loss 149 | 150 | 151 | class Decoder(nn.Module): 152 | def __init__(self, dim, ff_dim, num_head, num_layer): 153 | super().__init__() 154 | 155 | self.layer = nn.ModuleList([ 156 | TransformerDecoderLayer(Namespace({ 157 | 'decoder_embed_dim': dim, 158 | 'decoder_attention_heads': num_head, 159 | 'attention_dropout': 0.1, 160 | 'dropout': 0.1, 161 | 'decoder_normalize_before': True, 162 | 'decoder_ffn_embed_dim': ff_dim, 163 | })) for i in range(num_layer) 164 | ]) 165 | self.layer_norm = nn.LayerNorm(dim) 166 | 167 | def forward(self, x, mem, x_mask=None, x_padding_mask=None, mem_padding_mask=None): 168 | for layer in self.layer: 169 | x = layer(x, mem, 170 | self_attn_mask=x_mask, self_attn_padding_mask=x_padding_mask, 171 | encoder_padding_mask=mem_padding_mask)[0] 172 | x = self.layer_norm(x) 173 | return x 174 | 175 | @torch.jit.export 176 | def forward_one(self, 177 | x: torch.Tensor, 178 | mem: torch.Tensor, 179 | incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]], 180 | mem_padding_mask: torch.BoolTensor = None, 181 | ) -> torch.Tensor: 182 | x = x[-1:] 183 | for layer in self.layer: 184 | x = layer(x, mem, incremental_state=incremental_state, encoder_padding_mask=mem_padding_mask)[0] 185 | x = self.layer_norm(x) 186 | return x 187 | 188 | 189 | class GatedCNN(nn.Module): 190 | def __init__(self, Protein_Features, Num_Filters, Embed_dim, Final_dim, K_size): 191 | super(GatedCNN, self).__init__() 192 | self.Protein_Embed = nn.Embedding(Protein_Features + 1, Embed_dim) 193 | self.Protein_Conv1 = nn.Conv1d(in_channels=1000, out_channels=Num_Filters, kernel_size=K_size) 194 | self.Protein_Gate1 = nn.Conv1d(in_channels=1000, out_channels=Num_Filters, kernel_size=K_size) 195 | self.Protein_Conv2 = nn.Conv1d(in_channels=Num_Filters, out_channels=Num_Filters * 2, kernel_size=K_size) 196 | self.Protein_Gate2 = nn.Conv1d(in_channels=Num_Filters, out_channels=Num_Filters * 2, kernel_size=K_size) 197 | self.Protein_Conv3 = nn.Conv1d(in_channels=Num_Filters * 2, out_channels=Num_Filters * 3, kernel_size=K_size) 198 | self.Protein_Gate3 = nn.Conv1d(in_channels=Num_Filters * 2, out_channels=Num_Filters * 3, kernel_size=K_size) 199 | self.relu = nn.ReLU() 200 | self.Protein_FC = nn.Linear(96 * 107, Final_dim) 201 | 202 | def forward(self, data): 203 | target = data.target 204 | Embed = self.Protein_Embed(target) 205 | conv1 = self.Protein_Conv1(Embed) 206 | gate1 = torch.sigmoid(self.Protein_Gate1(Embed)) 207 | GCNN1_Output = conv1 * gate1 208 | GCNN1_Output = self.relu(GCNN1_Output) 209 | #GATED CNN 2ND LAYER 210 | conv2 = self.Protein_Conv2(GCNN1_Output) 211 | gate2 = torch.sigmoid(self.Protein_Gate2(GCNN1_Output)) 212 | GCNN2_Output = conv2 * gate2 213 | GCNN2_Output = 
self.relu(GCNN2_Output)
214 |         #GATED CNN 3RD LAYER
215 |         conv3 = self.Protein_Conv3(GCNN2_Output)
216 |         gate3 = torch.sigmoid(self.Protein_Gate3(GCNN2_Output))
217 |         GCNN3_Output = conv3 * gate3
218 |         GCNN3_Output = self.relu(GCNN3_Output)
219 |         #FLAT TENSOR
220 |         xt = GCNN3_Output.view(-1, 96 * 107)  # 96 = 3 * Num_Filters channels, 107 = 128 - 3 * (K_size - 1) positions
221 |         #PROTEIN FULLY CONNECTED LAYER
222 |         xt = self.Protein_FC(xt)
223 |         return xt, GCNN3_Output
224 | 
225 | class FC(torch.nn.Module):
226 |     def __init__(self, output_dim, n_output, dropout):
227 |         super(FC, self).__init__()
228 |         self.FC_layers = nn.Sequential(
229 |             nn.Linear(output_dim * 2, 1024),
230 |             nn.ReLU(),
231 |             nn.Dropout(dropout),
232 |             nn.Linear(1024, 512),
233 |             nn.ReLU(),
234 |             nn.Dropout(dropout),
235 |             nn.Linear(512, 256),
236 |             nn.ReLU(),
237 |             nn.Dropout(dropout),
238 |             nn.Linear(256, n_output)
239 |         )
240 | 
241 |     def forward(self, Drug_Features, Protein_Features):
242 |         Combined = torch.cat((Drug_Features, Protein_Features), 1)
243 |         prediction = self.FC_layers(Combined)
244 |         return prediction
245 | 
246 | # Main class
247 | class DeepDTAGen(torch.nn.Module):
248 |     def __init__(self, tokenizer):
249 |         super(DeepDTAGen, self).__init__()
250 |         self.hidden_dim = 376
251 |         self.max_len = 128
252 |         self.node_feature = 94
253 |         self.output_dim = 128
254 |         self.ff_dim = 1024
255 |         self.heads = 8
256 |         self.layers = 8
257 |         self.encoder_dropout = 0.2
258 |         self.dropout = 0.3
259 |         self.protein_f = 25
260 |         self.filters = 32
261 |         self.kernel = 8
262 |         # Encoder, Decoder, and related components
263 |         self.encoder = Encoder(Drug_Features=self.node_feature, dropout=self.encoder_dropout, Final_dim=self.output_dim)
264 |         self.decoder = Decoder(dim=self.hidden_dim, ff_dim=self.ff_dim, num_head=self.heads, num_layer=self.layers)
265 |         self.dencoder = TransformerEncoder(dim=self.hidden_dim, ff_dim=self.ff_dim, num_head=self.heads, num_layer=self.layers)
266 |         self.pos_encoding = PositionalEncoding(self.hidden_dim, max_len=138)
267 | 
268 |         # CNN for processing protein features
269 |         self.cnn = GatedCNN(Protein_Features=self.protein_f, Num_Filters=self.filters,
270 |                             Embed_dim=self.output_dim, Final_dim=self.output_dim, K_size=self.kernel)
271 | 
272 |         # Fully connected layer
273 |         self.fc = FC(output_dim=self.output_dim, n_output=1, dropout=self.dropout)
274 | 
275 |         # Learnable parameter for segment encoding
276 |         self.zz_seg_encoding = nn.Parameter(torch.randn(self.hidden_dim))
277 | 
278 |         # Word prediction layers
279 |         vocab_size = len(tokenizer)
280 |         self.word_pred = nn.Sequential(
281 |             nn.Linear(self.hidden_dim, self.hidden_dim),
282 |             nn.PReLU(),
283 |             nn.LayerNorm(self.hidden_dim),
284 |             nn.Linear(self.hidden_dim, vocab_size)
285 |         )
286 |         torch.nn.init.zeros_(self.word_pred[3].bias)
287 | 
288 |         # Other properties
289 |         self.vocab_size = vocab_size
290 |         self.sos_value = tokenizer.s2i['<sos>']
291 |         self.eos_value = tokenizer.s2i['<eos>']
292 |         self.pad_value = tokenizer.s2i['<pad>']
293 |         self.word_embed = nn.Embedding(vocab_size, self.hidden_dim)
294 |         self.unk_index = Tokenizer.SPECIAL_TOKENS.index('<unk>')
295 | 
296 |         # Expand and Fusion layers
297 |         self.expand = nn.Sequential(
298 |             nn.Linear(self.hidden_dim, self.hidden_dim),
299 |             nn.ReLU(),
300 |             nn.LayerNorm(self.hidden_dim),
301 |             nn.Linear(self.hidden_dim, self.hidden_dim),
302 |             Rearrange('batch_size h -> 1 batch_size h')
303 |         )
304 | 
305 |     def expand_then_fusing(self, z, pp_mask, vvs):
306 |         # Expand and fuse latent variables
307 |         zz = z
308 |         # zz = self.expand(z)
309 |         zzs = zz + self.zz_seg_encoding
310 | 
311 |         # Create full mask 
for decoder 312 | full_mask = zz.new_zeros(zz.shape[1], zz.shape[0]) 313 | full_mask = torch.cat((pp_mask, full_mask), dim=1) # batch seq_plus 314 | 315 | # Concatenate latent variables and segment encoding 316 | zzz = torch.cat((vvs, zzs), dim=0) # seq_plus batch feat 317 | 318 | # Encode the concatenated sequence 319 | zzz = self.dencoder(zzz, full_mask) 320 | 321 | return zzz, full_mask 322 | 323 | def sample(self, batch_size, device): 324 | z = torch.randn(1, self.hidden_dim).to(device) 325 | return z 326 | 327 | def forward(self, data): 328 | # Process protein features through CNN 329 | Protein_vector, con = self.cnn(data) 330 | 331 | # Encode the input graph 332 | vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 333 | 334 | # Expand and fuse latent variables 335 | zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 336 | 337 | # Prepare target sequence for decoding 338 | targets = data.target_seq 339 | _, target_length = targets.shape 340 | target_mask = torch.triu(torch.ones(target_length, target_length, dtype=torch.bool), diagonal=1).to(targets.device) 341 | target_embed = self.word_embed(targets) 342 | target_embed = self.pos_encoding(target_embed.permute(1, 0, 2).contiguous()) 343 | 344 | # Decode the target sequence 345 | output = self.decoder(target_embed, zzz, x_mask=target_mask, mem_padding_mask=encoder_mask).permute(1, 0, 2).contiguous() 346 | prediction_scores = self.word_pred(output) # batch_size, sequence_length, class 347 | 348 | # Compute loss and predictions 349 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 350 | targets = targets[:, 1:].contiguous() 351 | batch_size, sequence_length, vocab_size = shifted_prediction_scores.size() 352 | shifted_prediction_scores = shifted_prediction_scores.view(-1, vocab_size) 353 | targets = targets.view(-1) 354 | 355 | Pridection = self.fc(PMVO, Protein_vector) 356 | lm_loss = F.cross_entropy(shifted_prediction_scores, targets, ignore_index=self.pad_value) 357 | 358 | return Pridection, prediction_scores, lm_loss, kl_loss 359 | 360 | 361 | def _generate(self, zzz, encoder_mask, random_sample, return_score=False): 362 | # Determine the batch size and device 363 | batch_size = zzz.shape[1] 364 | device = zzz.device 365 | 366 | # Initialize token tensor for sequence generation 367 | token = torch.full((batch_size, self.max_len), self.pad_value, dtype=torch.long, device=device) 368 | token[:, 0] = self.sos_value 369 | 370 | # Initialize positional encoding for text 371 | text_pos = self.pos_encoding.pe 372 | 373 | # Initialize text embedding for the first token 374 | text_embed = self.word_embed(token[:, 0]) 375 | text_embed = text_embed + text_pos[0] 376 | text_embed = text_embed.unsqueeze(0) 377 | 378 | # Initialize incremental state for transformer decoder 379 | incremental_state = torch.jit.annotate( 380 | Dict[str, Dict[str, Optional[torch.Tensor]]], 381 | torch.jit.annotate(Dict[str, Dict[str, Optional[torch.Tensor]]], {}), 382 | ) 383 | 384 | # Initialize scores if return_score is True 385 | if return_score: 386 | scores = [] 387 | 388 | # Initialize flag for finished sequences 389 | finished = torch.zeros(batch_size, dtype=torch.bool, device=device) 390 | 391 | # Loop for sequence generation 392 | for t in range(1, self.max_len): 393 | # Decode one token at a time 394 | one = self.decoder.forward_one(text_embed, zzz, incremental_state, mem_padding_mask=encoder_mask) 395 | one = one.squeeze(0) 396 | l = self.word_pred(one) # Get predicted scores for tokens 397 | 398 | # Append scores if 
return_score is True 399 | if return_score: 400 | scores.append(l) 401 | 402 | # Sample the next token either randomly or by choosing the one with maximum probability 403 | if random_sample: 404 | k = torch.multinomial(torch.softmax(l, 1), 1).squeeze(1) 405 | else: 406 | k = torch.argmax(l, -1) # Predict token with maximum probability 407 | 408 | # Update token tensor with the predicted token 409 | token[:, t] = k 410 | 411 | # Check if sequences are finished 412 | finished |= k == self.eos_value 413 | if finished.all(): 414 | break 415 | 416 | # Update text embedding for the next token 417 | text_embed = self.word_embed(k) 418 | text_embed = text_embed + text_pos[t] # Add positional encoding 419 | text_embed = text_embed.unsqueeze(0) 420 | 421 | # Extract the predicted sequence 422 | predict = token[:, 1:] 423 | 424 | # Return predicted sequence along with scores if return_score is True 425 | if return_score: 426 | return predict, torch.stack(scores, dim=1) 427 | return predict 428 | 429 | def generate(self, data, random_sample=False, return_z=False): 430 | # use protein condition from GatedCNN 431 | _, con = self.cnn(data) 432 | 433 | # Encode the input graph 434 | vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 435 | 436 | # Sample latent variables 437 | z = self.sample(data.batch, device=vss.device) 438 | 439 | # z = z + con1 + con2 440 | 441 | # zzz, encoder_mask = self.expand_then_fusing(z, mask, vss) 442 | 443 | # Expand and fuse latent variables 444 | zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 445 | 446 | # Generate sequence based on latent variables 447 | predict = self._generate(zzz, encoder_mask, random_sample=random_sample, return_score=False) 448 | 449 | # Return predicted sequence along with latent variables if specified 450 | if return_z: 451 | return predict, z.detach().cpu().numpy() 452 | return predict 453 | 454 | 455 | # def generate(self, data, random_sample=False, return_z=False): 456 | # # use protein condition from GatedCNN 457 | # _, con = self.cnn(data) 458 | 459 | # # Encode the input graph 460 | # vss, AMVO, mask, PMVO, kl_loss = self.encoder(data, con) 461 | 462 | # # Sample latent variables 463 | # # z = self.sample(data.batch, device=vss.device) 464 | 465 | # z = z + con1 + con2 466 | 467 | # zzz, encoder_mask = self.expand_then_fusing(z, mask, vss) 468 | 469 | # # Expand and fuse latent variables 470 | # # zzz, encoder_mask = self.expand_then_fusing(AMVO, mask, vss) 471 | 472 | # # Generate sequence based on latent variables 473 | # predict = self._generate(zzz, encoder_mask, random_sample=random_sample, return_score=False) 474 | 475 | # # Return predicted sequence along with latent variables if specified 476 | # if return_z: 477 | # return predict, z.detach().cpu().numpy() 478 | # return predict 479 | --------------------------------------------------------------------------------
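Taken together, the demo utilities and models above suggest the following end-to-end generation flow. This is a minimal sketch rather than code from the repository: the checkpoint filename, the example SMILES/protein/affinity inputs, and running from inside the DEMO directory are all assumptions (the shipped DEMO_Generation.py is the authoritative entry point).

    import pickle
    import torch
    from rdkit import Chem
    from required_files_for_demo.demo_utils import process_latent, collate
    from required_files_for_demo.model_gen import DeepDTAGen

    with open('data/kiba_tokenizer.pkl', 'rb') as f:  # tokenizer shipped in data/
        tokenizer = pickle.load(f)

    model = DeepDTAGen(tokenizer)
    state = torch.load('kiba_pretrained.pt', map_location='cpu')  # assumed checkpoint name
    model.load_state_dict(state)
    model.eval()

    # Build a one-pair dataset; generation conditions on the protein and the desired affinity.
    data = process_latent('CC(=O)Oc1ccccc1C(=O)O', 'MKTAYIAKQR' * 30, affinity=11.5)
    batch = collate([data[0]])

    with torch.no_grad():
        token_ids = model.generate(batch, random_sample=True)
    smiles = tokenizer.get_text(token_ids)[0]
    print(smiles, 'valid' if Chem.MolFromSmiles(smiles) else 'invalid')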