├── .gitignore ├── utils ├── __init__.py ├── ddi_config.py ├── weighted_trainer.py ├── genmol_data.py ├── lipo_datasets.py ├── sampl_datasets.py ├── hiv_datasets.py ├── tox21_datasets.py ├── delaney_datasets.py ├── cep_datasets.py ├── bbbp_datasets.py ├── malaria_datasets.py ├── clintox_datasets.py ├── mol.py ├── sider_datasets.py ├── gen_utils.py ├── bace_datasets.py ├── sascorer.py ├── torchvocab.py ├── gen_dataset.py ├── molnet_loader.py ├── features.py ├── molnet_dataloader.py ├── ddi_dataset.py └── dti_asset.py ├── iupac_regex ├── config.json ├── merges.txt └── vocab.json ├── smiles_tokenizer ├── config.json └── merges.txt ├── cliff_sim.png ├── vocab └── Merge_vocab.pkl ├── download_pubchem ├── remove_bracket.py ├── remove_nums.py ├── remove_bracs_nums.py ├── download.sh ├── extract_info.py └── txt2csv.py ├── calculate_fg.py ├── canonicalize.py ├── README.md ├── extract_feat_labels.py ├── parsing.py ├── cliff_pair.csv ├── gcn.py └── pe_2d └── config.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iupac_regex/config.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iupac_regex/merges.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smiles_tokenizer/config.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smiles_tokenizer/merges.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cliff_sim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fengshikun/UniMAP/HEAD/cliff_sim.png -------------------------------------------------------------------------------- /vocab/Merge_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fengshikun/UniMAP/HEAD/vocab/Merge_vocab.pkl -------------------------------------------------------------------------------- /utils/ddi_config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | DDI_CONFIG = edict() 5 | DDI_CONFIG.data_dir = '/sharefs/sharefs-test_data/deepddi/data/DrugBank5.0_Approved_drugs' 6 | DDI_CONFIG.label_file = '/sharefs/sharefs-test_data/deepddi/data/DrugBank_known_ddi.txt' 7 | DDI_CONFIG.train_ratio = 0.6 -------------------------------------------------------------------------------- /download_pubchem/remove_bracket.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | if __name__=='__main__': 4 | print('Reading...') 5 | data = pd.read_csv('/sharefs//chem_data/pubchem/data_1m/iupacs.csv') 6 | print('Replacing...') 7 | data['no_bracs'] = data['Preferred'].apply(lambda x:x.replace('(',' ').replace(')',' ').replace('[',' ').replace(']',' ').strip()) 8 | 
data.drop(['Preferred'],axis=1,inplace=True) 9 | print('Writing...') 10 | data.to_csv('/sharefs//chem_data/pubchem/data_1m/iupacs_no_bracs.csv',index=False) -------------------------------------------------------------------------------- /download_pubchem/remove_nums.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | 4 | if __name__=='__main__': 5 | print('Reading...') 6 | data = pd.read_csv('/sharefs//chem_data/pubchem/data_1m/iupacs.csv') 7 | print('Replacing...') 8 | #data['no_bracs'] = data['Preferred'].apply(lambda x:x.replace('(',' ').replace(')',' ').replace('[',' ').replace(']',' ').strip()) 9 | 10 | pattern = r'[0-9]' 11 | data['no_nums'] = data['Preferred'].apply(lambda x: re.sub(pattern,'',x)).apply(lambda x:x.replace(',','')).apply(lambda x:x.replace('-',' ')) 12 | 13 | data.drop(['Preferred'],axis=1,inplace=True) 14 | # data.drop(['no_bracs'],axis=1,inplace=True) 15 | 16 | print('Writing...') 17 | data.to_csv('/sharefs//chem_data/pubchem/data_1m/iupacs_no_nums.csv',index=False) -------------------------------------------------------------------------------- /download_pubchem/remove_bracs_nums.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | 4 | if __name__=='__main__': 5 | print('Reading...') 6 | data = pd.read_csv('/sharefs//chem_data/pubchem/data_1m/iupacs.csv') 7 | print('Replacing...') 8 | data['no_bracs'] = data['Preferred'].apply(lambda x:x.replace('(',' ').replace(')',' ').replace('[',' ').replace(']',' ').strip()) 9 | 10 | pattern = r'[0-9]' 11 | data['no_bracs_nums'] = data['no_bracs'].apply(lambda x: re.sub(pattern,'',x)).apply(lambda x:x.replace(',','')).apply(lambda x:x.replace('-',' ')) 12 | 13 | data.drop(['Preferred'],axis=1,inplace=True) 14 | data.drop(['no_bracs'],axis=1,inplace=True) 15 | 16 | print('Writing...') 17 | data.to_csv('/sharefs//chem_data/pubchem/data_1m/iupacs_no_bracs_nums.csv',index=False) -------------------------------------------------------------------------------- /calculate_fg.py: -------------------------------------------------------------------------------- 1 | 2 | from rdkit import DataStructs, Chem 3 | from rdkit.Chem import AllChem 4 | import numpy as np 5 | 6 | # csv_file = '/sharefs/sharefs-test_data/PUBCHEM/iso_processed_txt/ALL_CSV/smiles.csv.can' 7 | csv_file = '/home/AI4Science/fengsk/pubchem/iso_processed_txt/smiles.csv.can' 8 | 9 | 10 | import multiprocess 11 | from tqdm import tqdm 12 | pool = multiprocess.Pool(32) 13 | 14 | import io 15 | def lines(): 16 | with io.open(csv_file, 'r', encoding='utf8', newline='\n') as srcf: 17 | for line in srcf: 18 | yield line.strip() 19 | 20 | total = len(io.open(csv_file, 'r', encoding='utf8', newline='\n').readlines()) 21 | # import pdb; pdb.set_trace() 22 | 23 | def get_bitvec(smiles): 24 | try: 25 | mol = Chem.MolFromSmiles(smiles) 26 | if mol is None: 27 | return np.full((2048), -1, dtype=np.int8) 28 | mol1 = Chem.AddHs(mol) 29 | fps1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=2048, useChirality=False) 30 | fp_array = np.zeros((0, ), dtype=np.int8) 31 | DataStructs.ConvertToNumpyArray(fps1, fp_array) 32 | return fp_array 33 | except: 34 | return np.full((2048), -1, dtype=np.int8) 35 | 36 | 37 | ALL_FP = np.zeros((total, 2048), dtype=np.int8) 38 | 39 | cnt = 0 40 | for fp in tqdm(pool.imap(get_bitvec, lines(), chunksize=100), total=total): 41 | ALL_FP[cnt] = fp 42 | cnt += 1 43 | 44 | # 
np.save('/sharefs/sharefs-test_data/PUBCHEM/iso_processed_txt/ALL_CSV/all_fp.npy', ALL_FP) 45 | np.save('/home/AI4Science/fengsk/pubchem/iso_processed_txt/all_fp.npy', ALL_FP) 46 | pool.close() 47 | print('Finished') -------------------------------------------------------------------------------- /utils/weighted_trainer.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def get_sampler_weight(train_dataset, val_dataset): 7 | train_labels = train_dataset.labels 8 | val_labels = val_dataset.labels 9 | all_labels = np.concatenate([train_labels, val_labels], axis=0) 10 | sample_num, task_num = all_labels.shape 11 | # label one ratio for each task 12 | label_one_ratio = all_labels.sum(axis=0) / sample_num 13 | label_zero_ratio = 1 - label_one_ratio 14 | label_one_ratio = label_one_ratio.reshape(-1, 1) 15 | label_zero_ratio = label_zero_ratio.reshape(-1, 1) 16 | label_ratio = np.concatenate([label_zero_ratio, label_one_ratio], axis=1) 17 | # assume we have binary label of training dataset and val dataset; use label as index 18 | all_label_int = all_labels.astype(np.int) 19 | 20 | all_label_ratio = np.zeros_like(all_labels) 21 | for i in range(task_num): 22 | all_label_ratio[:, i] = label_ratio[i][all_label_int[:, i]] 23 | 24 | # split train label ratio 25 | train_num = train_labels.shape[0] 26 | train_label_ratio = all_label_ratio[:train_num] 27 | 28 | # change the ratio to the weight 29 | train_label_weight = (1.0 / train_label_ratio).sum(axis=1) 30 | 31 | # normalise 32 | normalise_alpha = train_num / train_label_weight.sum() 33 | 34 | train_label_weight_norm = normalise_alpha * train_label_weight 35 | 36 | return train_label_weight_norm 37 | 38 | class WeightedTrainer(Trainer): 39 | def set_weight(self, weight): 40 | self.weight = weight 41 | self.sample_len = weight.shape[0] 42 | def _get_train_sampler(self): 43 | generator = torch.Generator() 44 | generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) 45 | sampler = torch.utils.data.WeightedRandomSampler(self.weight, self.sample_len, replacement=True, generator=generator) 46 | return sampler 47 | -------------------------------------------------------------------------------- /canonicalize.py: -------------------------------------------------------------------------------- 1 | import re 2 | import io 3 | import argparse 4 | from tqdm import tqdm 5 | import multiprocessing 6 | from rdkit import Chem 7 | 8 | 9 | def rm_map_number(smiles): 10 | t = re.sub(':\d*', '', smiles) 11 | return t 12 | 13 | 14 | def canonicalize(smiles): 15 | try: 16 | smiles, keep_atommap = smiles 17 | if not keep_atommap: 18 | # smiles = rm_map_number(smiles) 19 | pass 20 | mol = Chem.MolFromSmiles(smiles) 21 | if mol is None: 22 | return None 23 | else: 24 | smiles = Chem.MolToSmiles(mol) 25 | if Chem.MolFromSmiles(smiles) is not None: 26 | return smiles 27 | # return Chem.MolToSmiles(mol) 28 | else: 29 | return None 30 | except: 31 | return None 32 | 33 | 34 | def main(args): 35 | input_fn = args.fn 36 | 37 | def lines(): 38 | with io.open(input_fn, 'r', encoding='utf8', newline='\n') as srcf: 39 | for line in srcf: 40 | yield line.strip(), args.keep_atommapnum 41 | 42 | results = [] 43 | total = len(io.open(input_fn, 'r', encoding='utf8', newline='\n').readlines()) 44 | 45 | pool = multiprocessing.Pool(args.workers) 46 | for res in tqdm(pool.imap(canonicalize, lines(), chunksize=100000), total=total): 47 | if res is 
not None: 48 | results.append('{}\n'.format(res)) 49 | 50 | if args.output_fn is None: 51 | output_fn = '{}.can'.format(input_fn) 52 | else: 53 | output_fn = args.output_fn 54 | io.open(output_fn, 'w', encoding='utf8', newline='\n').writelines(results) 55 | print('{}/{}'.format(len(results), total)) 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('fn', type=str) 61 | parser.add_argument('--workers', type=int, default=1) 62 | parser.add_argument('--output-fn', type=str, default=None) 63 | parser.add_argument('--keep-atommapnum', action='store_true', default=False) 64 | args = parser.parse_args() 65 | main(args) -------------------------------------------------------------------------------- /download_pubchem/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # rlaunch --private-machine=group --charged-group=health --cpu=16 --gpu=0 --memory=100000 -- zsh download.sh 3 | # there must be enough memory to support multiprocessing. 4 | # 0413: 5 cols: Preferred Canonical Formula Mass Log P 5 | MIN=5 6 | MAX=5 # 1555 7 | 8 | PREFIX="ftp://ftp.ncbi.nlm.nih.gov/pubchem/Compound/CURRENT-Full/XML/" 9 | # fill this in 10 | DOWNLOAD_DIR="/sharefs/chem_data/pubchem/" 11 | EXTRACTION_FILE='names_properties_8_cols.txt' #'iupacs_properties_10.txt' 12 | FINAL_OUTPUT_DIR='data_1m_8cols/' 13 | mkdir -p ${DOWNLOAD_DIR}${FINAL_OUTPUT_DIR} 14 | 15 | prev_num="0000" 16 | for i in $(seq $MIN 5 $MAX); do 17 | 18 | num=$(printf "%04d" $i) 19 | echo $num 20 | fn="Compound_${prev_num}00001_${num}00000.xml" 21 | prev_num=$num 22 | echo "getting" $fn 23 | if ! [[ -f $DOWNLOAD_DIR$fn ]]; then 24 | orig_dir=$(pwd) 25 | cd $DOWNLOAD_DIR 26 | wget "${PREFIX}${fn}.gz" 27 | wget "${PREFIX}${fn}.gz.md5" 28 | if md5sum -c ${fn}.gz.md5; then 29 | echo md5 passed 30 | # rm ${fn}.gz.md5 31 | # pigz does multithreaded unzipping. 
If you don't have pigz, 32 | # you can use gunzip by uncommenting the line below 33 | gunzip $fn 34 | # pigz -d -p 8 $fn 35 | else 36 | echo md5 failed 37 | fi 38 | cd $orig_dir 39 | fi 40 | echo "extracting" 41 | # python extract_info.py $DOWNLOAD_DIR$fn "" Preferred 11 34 -26 Traditional 11 34 -26 "Canonical<" 11 34 -26 Mass 12 34 -26 Formula 11 34 -26 "Log P" 11 34 -26 >> ${DOWNLOAD_DIR}iupacs_properties.txt 42 | # rm $DOWNLOAD_DIR$fn 43 | # echo ${DOWNLOAD_DIR}${EXTRACTION_FILE} 44 | echo ${DOWNLOAD_DIR}${FINAL_OUTPUT_DIR}${EXTRACTION_FILE} 45 | echo $DOWNLOAD_DIR$fn 46 | #python extract_info.py $DOWNLOAD_DIR$fn "" Preferred 11 34 -26 "Canonical<" 11 34 -26 Formula 11 34 -26 >> ${DOWNLOAD_DIR}${EXTRACTION_FILE} #iupacs_properties.txt 47 | python extract_info.py $DOWNLOAD_DIR$fn "" Preferred 11 34 -26 'CAS-like Style' 11 34 -26 Systematic 11 34 -26 Traditional 11 34 -26 "Canonical<" 11 34 -26 Formula 11 34 -26 Mass 12 34 -26 "Log P" 11 34 -26 >> ${DOWNLOAD_DIR}${FINAL_OUTPUT_DIR}${EXTRACTION_FILE} 48 | 49 | wc -l ${DOWNLOAD_DIR}${FINAL_OUTPUT_DIR}${EXTRACTION_FILE} 50 | 51 | done 52 | 53 | python txt2csv.py \ 54 | --input_dir=${DOWNLOAD_DIR}${FINAL_OUTPUT_DIR}${EXTRACTION_FILE} \ 55 | --output_dir=${DOWNLOAD_DIR}${FINAL_OUTPUT_DIR} 56 | -------------------------------------------------------------------------------- /download_pubchem/extract_info.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | from multiprocessing import Pool 4 | import numpy as np 5 | import itertools 6 | 7 | # look through fn for all the provided search terms (keys), 8 | # and extract values as directed by the offset, start & end cols 9 | 10 | # example usage to get IUPAC: 11 | # python extract_mass_formula.py Compounds.xml Systematic 11 34 -26 12 | 13 | # Systematic 11 34 -26 14 | # Mass 12 34 -26 15 | # Formula 11 34 -26 16 | # Log P 11 34 -26 17 | 18 | # chemicals are separated by the 2nd arg 19 | 20 | LINES_PER_PROC = 10000 21 | 22 | fn = sys.argv[1] 23 | 24 | assert len(sys.argv) > 3, "need to provide search terms, etc." 
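# Note (added for clarity, inferred from the argument parsing below and the
# example calls in download.sh): each search term is followed by three
# integers: a line offset (how many lines below the matched tag the value
# appears) and the start/end columns used to slice that line out,
# e.g. "Preferred 11 34 -26".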
25 | assert len(sys.argv[3:]) % 4 == 0, "each search term needs offset & cols"
26 |
27 | chemical_separator = sys.argv[2]
28 |
29 | search_terms = []
30 | line_offsets = []
31 | start_cols = []
32 | end_cols = []
33 |
34 | for i in range(3, len(sys.argv), 4):
35 |     search_terms.append(sys.argv[i])
36 |     line_offsets.append(int(sys.argv[i+1]))
37 |     start_cols.append(int(sys.argv[i+2]))
38 |     end_cols.append(int(sys.argv[i+3]))
39 |
40 | lines = []
41 |
42 | def find_relevant(start_line):
43 |     relevant_lines = []
44 |     max_length = len(lines)
45 |     for i in range(LINES_PER_PROC):
46 |         if start_line + i >= max_length:
47 |             return relevant_lines
48 |         line = lines[start_line + i]
49 |         if chemical_separator in line:
50 |             relevant_lines.append(start_line + i)
51 |         for search_term in search_terms:
52 |             if search_term in line:
53 |                 relevant_lines.append(start_line + i)
54 |     return relevant_lines
55 |
56 | with open(fn, "r") as xml_file:
57 |     # first line is headers
58 |     found_values = copy.deepcopy(search_terms)
59 |     # print('found_values:',found_values)
60 |
61 |     lines = xml_file.readlines()
62 |
63 |     p = Pool(32)
64 |     relevant_lines = p.map(find_relevant,
65 |                            range(0, len(lines), LINES_PER_PROC))
66 |     relevant_lines = itertools.chain.from_iterable(relevant_lines)
67 |     relevant_lines = np.array(list(relevant_lines))
68 |     # print(relevant_lines[0])
69 |
70 |     for idx,i in enumerate(relevant_lines):
71 |         line = lines[i]
72 |         if chemical_separator in line:
73 |             # new chemical -- reset search term lines & found_values
74 |             # print("|".join(found_values))
75 |             if idx>0:
76 |                 print("\t".join(found_values))
77 |             found_values = ["" for _ in search_terms]
78 |             continue
79 |
80 |         for j, search_term in enumerate(search_terms):
81 |             if search_term in line:
82 |                 # found the jth search term on line i
83 |                 found = i + line_offsets[j]
84 |                 found_values[j] = lines[found][start_cols[j]:end_cols[j]]
85 |
86 |
87 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # UniMAP code
2 |
3 | This is the official implementation of the paper "UniMAP: Universal SMILES-Graph Representation Learning".
4 |
5 |
6 |
7 |
8 | The code is based on Hugging Face transformers (4.9.2), [chemBerta](https://github.com/seyonechithrananda/bert-loves-chemistry), and [C5T5](https://github.com/dhroth/c5t5).
9 |
10 |
11 |
12 | ### Data downloading and extraction
13 |
14 |
15 | #### Download PubChem
16 | The scripts in ./download_pubchem download data from [PubChem](https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/CURRENT-Full/XML/) and extract certain fields (such as the IUPAC name, SMILES, and formula).
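For orientation, a typical end-to-end run of these scripts looks roughly like the sketch below. The directories are placeholders; `DOWNLOAD_DIR`, the `MIN`/`MAX` shard range, and the extraction file name are set inside `download.sh` itself. Note that `download.sh` already chains `extract_info.py` and `txt2csv.py`, so the second command is only needed to re-run the CSV conversion on its own.

```
# fetch the PubChem XML shards, verify checksums, and extract the
# name / SMILES / property fields into a tab-separated text file
bash download_pubchem/download.sh

# (re-)convert the extracted fields into iupacs.csv and smiles.csv
python download_pubchem/txt2csv.py \
    --input_dir=/path/to/names_properties_8_cols.txt \
    --output_dir=/path/to/output_dir
```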
17 | 18 | #### Generate ECFP fingerprints and function group labels 19 | 20 | ```python -u extract_feat_labels.py --smiles_file pubchme_smiles_file.lst``` 21 | 22 | 23 | ### Pre-training: 24 | 25 | ``` 26 | python -u -m torch.distributed.launch --nproc_per_node 8 --master_port 1233 train_multimodal_uni.py --run_name unimap_pretrain --dataset_path pubchme_smiles_file.lst --logging_steps 5 --tokenizer_path iupac_regex/ --smiles_tokenizer_path smiles_tokenizer/ --num_train_epochs 40 --output_dir unimap_pretraining --per_device_train_batch_size 64 --atom_vocab_file vocab/Merge_vocab.pkl --atom_vocab_size 10535 --function_group_file fg_labels.npy --finger_print_file ecfp_regression.npy --mlm_probability 0.2 --smiles_only --get_frag --pooler_type avg --fp16 --gnn_number_layer 3 --graph_max_seq_size 128 --mlm_group_probability 0.6 --gnn_dropout 0 --check_frag 27 | ``` 28 | 29 | We also provide the pre-trained model at link: [https://drive.google.com/drive/folders/11_L4dq_n9P8cEWp6sEtb0d-MCgATDD-O?usp=sharing](https://drive.google.com/drive/folders/11_L4dq_n9P8cEWp6sEtb0d-MCgATDD-O?usp=sharing) 30 | 31 | 32 | 33 | 34 | ### Fine-tuning(MoleculeNet and DTA): 35 | 36 | 37 | 38 | 39 | test_gnn_smiles_uni_avg_srun.sh 40 | ``` 41 | cd finetune; 42 | declare -a test_sets=("bbbp" "clintox" "kiba" "davis") 43 | 44 | for test_set in "${test_sets[@]}"; do 45 | CUDA_VISIBLE_DEVICES=0 python -u finetune_mm_split.py --datasets $test_set --pretrained_model_name_or_path $1 --tokenizer_path iupac_regex/ --smiles_tokenizer_path smiles_tokenizer/ --lang_only --graph_uni --output_dir $1/graph_uni --pooler_type avg --split scaffold --n_seeds 3 --split scaffold --n_seeds 1 --n_trials 20 --graph_max_seq_size 128 --gnn_number_layer 3 --per_device_train_batch_size 64 --num_train_epochs_max 100 --number_seed 3 46 | done 47 | ``` 48 | 49 | bash test_gnn_smiles_uni_avg_srun.sh ./train_uni_smile 50 | 51 | 52 | 53 | ### Fine-tuning(DDI) 54 | 55 | test_ddi_srun.sh: 56 | ``` 57 | cd finetune; 58 | python finetune_mm_split_ddi.py --datasets deepddi --split random --pretrained_model_name_or_path $1 --tokenizer_path ../iupac_regex/ --smiles_tokenizer_path ../smiles_tokenizer --output_dir $1/graph_uni --lang_only --graph_uni --per_device_train_batch_size 32 --pooler_type avg --n_seeds 3 --n_trials 20 --gnn_number_layer 3 59 | 60 | ``` 61 | 62 | bash test_ddi_srun.sh ./train_uni_smile 63 | -------------------------------------------------------------------------------- /utils/genmol_data.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | from torch import optim 5 | import torch.nn.functional as F 6 | import subprocess 7 | import pickle 8 | import selfies as sf 9 | import csv 10 | from tqdm import tqdm 11 | import os 12 | from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split 13 | from rdkit.Chem.Crippen import MolLogP 14 | from rdkit.Chem import MolFromSmiles, QED 15 | import time 16 | import math 17 | 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | delta_g_to_kd = lambda x: math.exp(x / (0.00198720425864083 * 298.15)) 22 | 23 | 24 | class MolDataModule(pl.LightningDataModule): 25 | def __init__(self, batch_size, file): 26 | super(MolDataModule, self).__init__() 27 | self.batch_size = batch_size 28 | self.dataset = Dataset(file) 29 | self.train_data, self.test_data = random_split(self.dataset, [int(round(len(self.dataset) * 0.8)), 
int(round(len(self.dataset) * 0.2))]) 30 | 31 | def train_dataloader(self): 32 | return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=16, pin_memory=True) 33 | 34 | def val_dataloader(self): 35 | return DataLoader(self.test_data, batch_size=self.batch_size, drop_last=True, num_workers=16, pin_memory=True) 36 | 37 | 38 | class PropDataModule(pl.LightningDataModule): 39 | def __init__(self, x, y, batch_size): 40 | super(PropDataModule, self).__init__() 41 | self.batch_size = batch_size 42 | self.dataset = TensorDataset(x, y) 43 | self.train_data, self.test_data = random_split(self.dataset, [int(round(len(self.dataset) * 0.9)), int(round(len(self.dataset) * 0.1))]) 44 | 45 | def train_dataloader(self): 46 | return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True) 47 | 48 | def val_dataloader(self): 49 | return DataLoader(self.test_data, batch_size=self.batch_size, drop_last=True) 50 | 51 | 52 | 53 | class Dataset(Dataset): 54 | def __init__(self, file): 55 | selfies = [sf.encoder(line.split()[0]) for line in open(file, 'r')] 56 | self.alphabet = set() 57 | for s in selfies: 58 | self.alphabet.update(sf.split_selfies(s)) 59 | self.alphabet = ['[nop]'] + list(sorted(self.alphabet)) 60 | self.max_len = max(len(list(sf.split_selfies(s))) for s in selfies) 61 | self.symbol_to_idx = {s: i for i, s in enumerate(self.alphabet)} 62 | self.idx_to_symbol = {i: s for i, s in enumerate(self.alphabet)} 63 | self.encodings = [[self.symbol_to_idx[symbol] for symbol in sf.split_selfies(s)] for s in selfies] 64 | 65 | def __len__(self): 66 | return len(self.encodings) 67 | 68 | def __getitem__(self, i): 69 | return torch.tensor(self.encodings[i] + [self.symbol_to_idx['[nop]'] for i in range(self.max_len - len(self.encodings[i]))]) -------------------------------------------------------------------------------- /utils/lipo_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lipophilicity dataset loader. 
3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | LIPO_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv" 13 | LIPO_TASKS = ['exp'] 14 | 15 | LIPO_IUPAC_FILE = "../test_data/Lipophilicity_iupac.csv" 16 | 17 | 18 | 19 | class _LipoLoader(_MolnetLoader): 20 | 21 | def create_dataset(self) -> Dataset: 22 | # dataset_file = os.path.join(self.data_dir, "Lipophilicity.csv") 23 | dataset_file = LIPO_IUPAC_FILE 24 | if not os.path.exists(dataset_file): 25 | dc.utils.data_utils.download_url(url=LIPO_URL, dest_dir=self.data_dir) 26 | loader = CSVLoader( 27 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 28 | return loader.create_dataset(dataset_file, shard_size=8192) 29 | 30 | 31 | def load_lipo( 32 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 33 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 34 | # transformers: List[Union[TransformerGenerator, str]] = ['normalization'], 35 | transformers: List[Union[TransformerGenerator, str]] = [], 36 | reload: bool = False, 37 | data_dir: Optional[str] = None, 38 | save_dir: Optional[str] = None, 39 | **kwargs 40 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 41 | """Load Lipophilicity dataset 42 | 43 | Lipophilicity is an important feature of drug molecules that affects both 44 | membrane permeability and solubility. The lipophilicity dataset, curated 45 | from ChEMBL database, provides experimental results of octanol/water 46 | distribution coefficient (logD at pH 7.4) of 4200 compounds. 47 | 48 | Scaffold splitting is recommended for this dataset. 49 | 50 | The raw data csv file contains columns below: 51 | 52 | - "smiles" - SMILES representation of the molecular structure 53 | - "exp" - Measured octanol/water distribution coefficient (logD) of the 54 | compound, used as label 55 | 56 | Parameters 57 | ---------- 58 | featurizer: Featurizer or str 59 | the featurizer to use for processing the data. Alternatively you can pass 60 | one of the names from dc.molnet.featurizers as a shortcut. 61 | splitter: Splitter or str 62 | the splitter to use for splitting the data into training, validation, and 63 | test sets. Alternatively you can pass one of the names from 64 | dc.molnet.splitters as a shortcut. If this is None, all the data 65 | will be included in a single dataset. 66 | transformers: list of TransformerGenerators or strings 67 | the Transformers to apply to the data. Each one is specified by a 68 | TransformerGenerator or, as a shortcut, one of the names from 69 | dc.molnet.transformers. 70 | reload: bool 71 | if True, the first call for a particular featurizer and splitter will cache 72 | the datasets to disk, and subsequent calls will reload the cached datasets. 73 | data_dir: str 74 | a directory to save the raw data in 75 | save_dir: str 76 | a directory to save the dataset in 77 | 78 | References 79 | ---------- 80 | .. [1] Hersey, A. ChEMBL Deposited Data Set - AZ dataset; 2015. 
81 | https://doi.org/10.6019/chembl3301361 82 | """ 83 | loader = _LipoLoader(featurizer, splitter, transformers, LIPO_TASKS, data_dir, 84 | save_dir, **kwargs) 85 | return loader.load_dataset('lipo', reload) 86 | -------------------------------------------------------------------------------- /utils/sampl_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | SAMPL dataset loader. 3 | """ 4 | import os 5 | import deepchem as dc 6 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from deepchem.data import Dataset 8 | from typing import List, Optional, Tuple, Union 9 | 10 | from utils.data_loader import CSVLoader 11 | 12 | SAMPL_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv" 13 | SAMPL_TASKS = ['expt'] 14 | 15 | 16 | class _SAMPLLoader(_MolnetLoader): 17 | 18 | def create_dataset(self) -> Dataset: 19 | dataset_file = os.path.join(self.data_dir, "SAMPL.csv") 20 | if not os.path.exists(dataset_file): 21 | dc.utils.data_utils.download_url(url=SAMPL_URL, dest_dir=self.data_dir) 22 | loader = CSVLoader( 23 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac") 24 | return loader.create_dataset(dataset_file, shard_size=8192) 25 | 26 | 27 | def load_sampl_iupac( 28 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 29 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 30 | # transformers: List[Union[TransformerGenerator, str]] = ['normalization'], 31 | transformers: List[Union[TransformerGenerator, str]] = [], 32 | reload: bool = True, 33 | data_dir: Optional[str] = None, 34 | save_dir: Optional[str] = None, 35 | **kwargs 36 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 37 | """Load SAMPL(FreeSolv) dataset 38 | 39 | The Free Solvation Database, FreeSolv(SAMPL), provides experimental and 40 | calculated hydration free energy of small molecules in water. The calculated 41 | values are derived from alchemical free energy calculations using molecular 42 | dynamics simulations. The experimental values are included in the benchmark 43 | collection. 44 | 45 | Random splitting is recommended for this dataset. 46 | 47 | The raw data csv file contains columns below: 48 | 49 | - "iupac" - IUPAC name of the compound 50 | - "smiles" - SMILES representation of the molecular structure 51 | - "expt" - Measured solvation energy (unit: kcal/mol) of the compound, 52 | used as label 53 | - "calc" - Calculated solvation energy (unit: kcal/mol) of the compound 54 | 55 | Parameters 56 | ---------- 57 | featurizer: Featurizer or str 58 | the featurizer to use for processing the data. Alternatively you can pass 59 | one of the names from dc.molnet.featurizers as a shortcut. 60 | splitter: Splitter or str 61 | the splitter to use for splitting the data into training, validation, and 62 | test sets. Alternatively you can pass one of the names from 63 | dc.molnet.splitters as a shortcut. If this is None, all the data 64 | will be included in a single dataset. 65 | transformers: list of TransformerGenerators or strings 66 | the Transformers to apply to the data. Each one is specified by a 67 | TransformerGenerator or, as a shortcut, one of the names from 68 | dc.molnet.transformers. 69 | reload: bool 70 | if True, the first call for a particular featurizer and splitter will cache 71 | the datasets to disk, and subsequent calls will reload the cached datasets. 
72 | data_dir: str 73 | a directory to save the raw data in 74 | save_dir: str 75 | a directory to save the dataset in 76 | 77 | References 78 | ---------- 79 | .. [1] Mobley, David L., and J. Peter Guthrie. "FreeSolv: a database of 80 | experimental and calculated hydration free energies, with input files." 81 | Journal of computer-aided molecular design 28.7 (2014): 711-720. 82 | """ 83 | loader = _SAMPLLoader(featurizer, splitter, transformers, SAMPL_TASKS, 84 | data_dir, save_dir, **kwargs) 85 | return loader.load_dataset('sampl', reload) 86 | -------------------------------------------------------------------------------- /download_pubchem/txt2csv.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | from pathlib import Path 5 | 6 | ''' 7 | DOWNLOAD_DIR="/sharefs//chem_data/pubchem/" 8 | EXTRACTION_FILE='names_properties_1m.txt' #'iupacs_properties_10.txt' 9 | FINAL_OUTPUT_DIR='data_1m' 10 | 11 | python txt2csv.py \ 12 | --input_dir=/sharefs//chem_data/pubchem/names_properties_1m.txt \ 13 | --output_dir=/sharefs//chem_data/pubchem/data_1m 14 | 15 | rlaunch --private-machine=group --charged-group=health --cpu=16 --gpu=0 --memory=100000 \ 16 | -- python txt2csv.py \ 17 | --input_dir=/sharefs//chem_data/pubchem/data_1m_5cols/names_properties.txt \ 18 | --output_dir=/sharefs//chem_data/pubchem/data_1m_5cols 19 | ''' 20 | 21 | if __name__=='__main__': 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--input_dir", required=True, type=str) 25 | parser.add_argument("--output_dir", required=True, type=str) 26 | args = parser.parse_args() 27 | input_dir = args.input_dir 28 | output_dir = args.output_dir 29 | 30 | # path = '/sharefs//chem_data/pubchem/data/' 31 | print('Reading...') 32 | df = pd.read_csv(input_dir,delimiter='\t',header=None) 33 | # df.columns = ['Preferred','Canonical','Formula'] # Canonical< 34 | df.columns = ['Preferred','CAS','Systematic','Traditional','Canonical','Formula','Mass','LogP'] # add column names 35 | print(df.head()) 36 | print('Original df.shape', df.shape) # 898483, 4 37 | # print(df.info) 38 | 39 | print('Dropping NAs in IUPACs and SMILES') 40 | df.dropna(subset=['Preferred','Canonical'],inplace=True) # not drpopping nas in task labels like formula, mass. 
log p 41 | print('Ater dropna: df.shape', df.shape) # 896718, 4 42 | 43 | Path(output_dir).mkdir(parents=True, exist_ok=True) 44 | print('writing...') 45 | iupac = df[['Preferred','CAS','Systematic','Traditional']] 46 | smiles = df['Canonical'] 47 | 48 | 49 | iupac.to_csv(Path(args.output_dir) / 'iupacs.csv',index=False) # m:ode='a', index=False, header=False # not using appending mode 50 | print('Finished writing iupac!') 51 | print(iupac.shape) 52 | print(iupac.head()) 53 | 54 | smiles.to_csv(Path(args.output_dir) / 'smiles.csv',index=False) # mode='a', index=False, header=False 55 | print('Finished writing smiles!') 56 | print(smiles.shape) 57 | print(smiles.head()) 58 | 59 | ''' 60 | ### discretize the continuous mass and log p labels 61 | groups = 20 62 | quantile_bins = [(1/groups)*x for x in range(groups+1)] 63 | labels = [i for i in range(groups)] 64 | 65 | print(df['LogP'].values) 66 | print(df['LogP'].values.shape) 67 | 68 | mass_quantiles = np.nanquantile(df['Mass'].values,q=quantile_bins) 69 | logp_quantiles = np.nanquantile(df['LogP'].values,q=quantile_bins) 70 | print('mass_quantiles',mass_quantiles) 71 | print('logp_quantiles',logp_quantiles) 72 | # groups = 10,mass_quantiles [1.0e+00 1.9e+02 2.3e+02 2.6e+02 2.8e+02 3.0e+02 3.2e+02 3.4e+02 3.7e+02 4.6e+02 1.8e+04] 73 | # groups = 10,logp_quantiles [-70.2 0.6 1.5 2.1 2.5 3. 3.4 3.8 4.4 5.2 78. ] 74 | 75 | # df['Mass_label'] = pd.cut(x=df['Mass'],bins=mass_quantiles)#,labels=labels) 76 | # df['LogP_label'] = pd.cut(x=df['LogP'],bins=logp_quantiles)#,labels=labels) 77 | df['Mass_label'] = pd.cut(x=df['Mass'],bins=groups,precision=1)# retbins=True)#,labels=labels) #mass竟然有负数... 78 | df['LogP_label'] = pd.cut(x=df['LogP'],bins=groups,precision=1)# retbins=True)#,labels=labels) 79 | ''' 80 | 81 | df.to_csv(Path(args.output_dir) / 'names_properties_8cols.csv',index=False) # names_properties_5cols.csv 82 | print('Finished writing names_properties!') 83 | print(df.shape) 84 | print(df.head()) 85 | 86 | -------------------------------------------------------------------------------- /utils/hiv_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | hiv dataset loader. 
3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | HIV_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv" 13 | HIV_TASKS = ["HIV_active"] 14 | 15 | HIV_IUPAC_FILE = "../test_data/HIV_iupac.csv" 16 | 17 | 18 | 19 | class _HIVLoader(_MolnetLoader): 20 | 21 | def create_dataset(self) -> Dataset: 22 | # dataset_file = os.path.join(self.data_dir, "HIV.csv") 23 | dataset_file = HIV_IUPAC_FILE 24 | if not os.path.exists(dataset_file): 25 | dc.utils.data_utils.download_url(url=HIV_URL, dest_dir=self.data_dir) 26 | loader = CSVLoader( 27 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 28 | return loader.create_dataset(dataset_file, shard_size=8192) 29 | 30 | 31 | def load_hiv( 32 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 33 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 34 | transformers: List[Union[TransformerGenerator, str]] = ['balancing'], 35 | reload: bool = False, 36 | data_dir: Optional[str] = None, 37 | save_dir: Optional[str] = None, 38 | **kwargs 39 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 40 | """Load HIV dataset 41 | 42 | The HIV dataset was introduced by the Drug Therapeutics 43 | Program (DTP) AIDS Antiviral Screen, which tested the ability 44 | to inhibit HIV replication for over 40,000 compounds. 45 | Screening results were evaluated and placed into three 46 | categories: confirmed inactive (CI),confirmed active (CA) and 47 | confirmed moderately active (CM). We further combine the 48 | latter two labels, making it a classification task between 49 | inactive (CI) and active (CA and CM). 50 | 51 | Scaffold splitting is recommended for this dataset. 52 | 53 | The raw data csv file contains columns below: 54 | 55 | - "smiles": SMILES representation of the molecular structure 56 | - "activity": Three-class labels for screening results: CI/CM/CA 57 | - "HIV_active": Binary labels for screening results: 1 (CA/CM) and 0 (CI) 58 | 59 | Parameters 60 | ---------- 61 | featurizer: Featurizer or str 62 | the featurizer to use for processing the data. Alternatively you can pass 63 | one of the names from dc.molnet.featurizers as a shortcut. 64 | splitter: Splitter or str 65 | the splitter to use for splitting the data into training, validation, and 66 | test sets. Alternatively you can pass one of the names from 67 | dc.molnet.splitters as a shortcut. If this is None, all the data 68 | will be included in a single dataset. 69 | transformers: list of TransformerGenerators or strings 70 | the Transformers to apply to the data. Each one is specified by a 71 | TransformerGenerator or, as a shortcut, one of the names from 72 | dc.molnet.transformers. 73 | reload: bool 74 | if True, the first call for a particular featurizer and splitter will cache 75 | the datasets to disk, and subsequent calls will reload the cached datasets. 76 | data_dir: str 77 | a directory to save the raw data in 78 | save_dir: str 79 | a directory to save the dataset in 80 | 81 | References 82 | ---------- 83 | .. [1] AIDS Antiviral Screen Data. 
84 | https://wiki.nci.nih.gov/display/NCIDTPdata/AIDS+Antiviral+Screen+Data 85 | """ 86 | loader = _HIVLoader(featurizer, splitter, transformers, HIV_TASKS, data_dir, 87 | save_dir, **kwargs) 88 | return loader.load_dataset('hiv', reload) 89 | -------------------------------------------------------------------------------- /utils/tox21_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tox21 dataset loader. 3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | TOX21_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz" 13 | TOX21_TASKS = [ 14 | 'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 15 | 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53' 16 | ] 17 | 18 | TOX21_IUPAC_FILE = "../test_data/tox21_iupac.csv" 19 | 20 | 21 | class _Tox21Loader(_MolnetLoader): 22 | 23 | def create_dataset(self) -> Dataset: 24 | # dataset_file = os.path.join(self.data_dir, "tox21.csv.gz") 25 | dataset_file = TOX21_IUPAC_FILE 26 | if not os.path.exists(dataset_file): 27 | dc.utils.data_utils.download_url(url=TOX21_URL, dest_dir=self.data_dir) 28 | loader = CSVLoader( 29 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 30 | return loader.create_dataset(dataset_file, shard_size=8192) 31 | 32 | 33 | def load_tox21( 34 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 35 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 36 | transformers: List[Union[TransformerGenerator, str]] = ['balancing'], 37 | reload: bool = False, 38 | data_dir: Optional[str] = None, 39 | save_dir: Optional[str] = None, 40 | **kwargs 41 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 42 | """Load Tox21 dataset 43 | 44 | The "Toxicology in the 21st Century" (Tox21) initiative created a public 45 | database measuring toxicity of compounds, which has been used in the 2014 46 | Tox21 Data Challenge. This dataset contains qualitative toxicity measurements 47 | for 8k compounds on 12 different targets, including nuclear receptors and 48 | stress response pathways. 49 | 50 | Random splitting is recommended for this dataset. 51 | 52 | The raw data csv file contains columns below: 53 | 54 | - "smiles" - SMILES representation of the molecular structure 55 | - "NR-XXX" - Nuclear receptor signaling bioassays results 56 | - "SR-XXX" - Stress response bioassays results 57 | 58 | please refer to https://tripod.nih.gov/tox21/challenge/data.jsp for details. 59 | 60 | Parameters 61 | ---------- 62 | featurizer: Featurizer or str 63 | the featurizer to use for processing the data. Alternatively you can pass 64 | one of the names from dc.molnet.featurizers as a shortcut. 65 | splitter: Splitter or str 66 | the splitter to use for splitting the data into training, validation, and 67 | test sets. Alternatively you can pass one of the names from 68 | dc.molnet.splitters as a shortcut. If this is None, all the data 69 | will be included in a single dataset. 70 | transformers: list of TransformerGenerators or strings 71 | the Transformers to apply to the data. 
Each one is specified by a 72 | TransformerGenerator or, as a shortcut, one of the names from 73 | dc.molnet.transformers. 74 | reload: bool 75 | if True, the first call for a particular featurizer and splitter will cache 76 | the datasets to disk, and subsequent calls will reload the cached datasets. 77 | data_dir: str 78 | a directory to save the raw data in 79 | save_dir: str 80 | a directory to save the dataset in 81 | 82 | References 83 | ---------- 84 | .. [1] Tox21 Challenge. https://tripod.nih.gov/tox21/challenge/ 85 | """ 86 | loader = _Tox21Loader(featurizer, splitter, transformers, TOX21_TASKS, 87 | data_dir, save_dir, **kwargs) 88 | return loader.load_dataset('tox21', reload) 89 | -------------------------------------------------------------------------------- /utils/delaney_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Delaney dataset loader. 3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | DELANEY_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv" 13 | DELANEY_TASKS = ['measured log solubility in mols per litre'] 14 | 15 | DELANEY_IUPAC_FILE = "../test_data/esol_iupac.csv" 16 | 17 | class _DelaneyLoader(_MolnetLoader): 18 | 19 | def create_dataset(self) -> Dataset: 20 | # dataset_file = os.path.join(self.data_dir, "delaney-processed.csv") 21 | dataset_file = DELANEY_IUPAC_FILE 22 | if not os.path.exists(dataset_file): 23 | dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=self.data_dir) 24 | loader = CSVLoader( 25 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 26 | return loader.create_dataset(dataset_file, shard_size=8192) 27 | 28 | 29 | def load_delaney( 30 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 31 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 32 | # transformers: List[Union[TransformerGenerator, str]] = ['normalization'], 33 | transformers: List[Union[TransformerGenerator, str]] = [], 34 | reload: bool = False, 35 | data_dir: Optional[str] = None, 36 | save_dir: Optional[str] = None, 37 | **kwargs 38 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 39 | """Load Delaney dataset 40 | 41 | The Delaney (ESOL) dataset a regression dataset containing structures and 42 | water solubility data for 1128 compounds. The dataset is widely used to 43 | validate machine learning models on estimating solubility directly from 44 | molecular structures (as encoded in SMILES strings). 45 | 46 | Scaffold splitting is recommended for this dataset. 47 | 48 | The raw data csv file contains columns below: 49 | 50 | - "Compound ID" - Name of the compound 51 | - "smiles" - SMILES representation of the molecular structure 52 | - "measured log solubility in mols per litre" - Log-scale water solubility 53 | of the compound, used as label 54 | 55 | Parameters 56 | ---------- 57 | featurizer: Featurizer or str 58 | the featurizer to use for processing the data. Alternatively you can pass 59 | one of the names from dc.molnet.featurizers as a shortcut. 60 | splitter: Splitter or str 61 | the splitter to use for splitting the data into training, validation, and 62 | test sets. 
Alternatively you can pass one of the names from 63 | dc.molnet.splitters as a shortcut. If this is None, all the data 64 | will be included in a single dataset. 65 | transformers: list of TransformerGenerators or strings 66 | the Transformers to apply to the data. Each one is specified by a 67 | TransformerGenerator or, as a shortcut, one of the names from 68 | dc.molnet.transformers. 69 | reload: bool 70 | if True, the first call for a particular featurizer and splitter will cache 71 | the datasets to disk, and subsequent calls will reload the cached datasets. 72 | data_dir: str 73 | a directory to save the raw data in 74 | save_dir: str 75 | a directory to save the dataset in 76 | 77 | References 78 | ---------- 79 | .. [1] Delaney, John S. "ESOL: estimating aqueous solubility directly from 80 | molecular structure." Journal of chemical information and computer 81 | sciences 44.3 (2004): 1000-1005. 82 | """ 83 | loader = _DelaneyLoader(featurizer, splitter, transformers, DELANEY_TASKS, 84 | data_dir, save_dir, **kwargs) 85 | return loader.load_dataset('delaney', reload) 86 | -------------------------------------------------------------------------------- /utils/cep_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Delaney dataset loader. 3 | """ 4 | import os 5 | import deepchem as dc 6 | from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | # from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | # CEP_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv" 13 | CEP_TASKS = ['PCE'] 14 | 15 | CEP_FILE = "../test_data/cep.csv" 16 | 17 | class _CepLoader(_MolnetLoader): 18 | 19 | def create_dataset(self) -> Dataset: 20 | # dataset_file = os.path.join(self.data_dir, "delaney-processed.csv") 21 | dataset_file = CEP_FILE 22 | if not os.path.exists(dataset_file): 23 | dc.utils.data_utils.download_url(url=CEP_FILE, dest_dir=self.data_dir) 24 | # loader = CSVLoader( 25 | # tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 26 | loader = dc.data.CSVLoader( 27 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer) 28 | return loader.create_dataset(dataset_file, shard_size=8192) 29 | 30 | 31 | def load_cep( 32 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 33 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 34 | # transformers: List[Union[TransformerGenerator, str]] = ['normalization'], 35 | transformers: List[Union[TransformerGenerator, str]] = [], 36 | reload: bool = False, 37 | data_dir: Optional[str] = None, 38 | save_dir: Optional[str] = None, 39 | **kwargs 40 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 41 | """Load Delaney dataset 42 | 43 | The Delaney (ESOL) dataset a regression dataset containing structures and 44 | water solubility data for 1128 compounds. The dataset is widely used to 45 | validate machine learning models on estimating solubility directly from 46 | molecular structures (as encoded in SMILES strings). 47 | 48 | Scaffold splitting is recommended for this dataset. 
49 | 50 | The raw data csv file contains columns below: 51 | 52 | - "Compound ID" - Name of the compound 53 | - "smiles" - SMILES representation of the molecular structure 54 | - "measured log solubility in mols per litre" - Log-scale water solubility 55 | of the compound, used as label 56 | 57 | Parameters 58 | ---------- 59 | featurizer: Featurizer or str 60 | the featurizer to use for processing the data. Alternatively you can pass 61 | one of the names from dc.molnet.featurizers as a shortcut. 62 | splitter: Splitter or str 63 | the splitter to use for splitting the data into training, validation, and 64 | test sets. Alternatively you can pass one of the names from 65 | dc.molnet.splitters as a shortcut. If this is None, all the data 66 | will be included in a single dataset. 67 | transformers: list of TransformerGenerators or strings 68 | the Transformers to apply to the data. Each one is specified by a 69 | TransformerGenerator or, as a shortcut, one of the names from 70 | dc.molnet.transformers. 71 | reload: bool 72 | if True, the first call for a particular featurizer and splitter will cache 73 | the datasets to disk, and subsequent calls will reload the cached datasets. 74 | data_dir: str 75 | a directory to save the raw data in 76 | save_dir: str 77 | a directory to save the dataset in 78 | 79 | References 80 | ---------- 81 | .. [1] Delaney, John S. "ESOL: estimating aqueous solubility directly from 82 | molecular structure." Journal of chemical information and computer 83 | sciences 44.3 (2004): 1000-1005. 84 | """ 85 | loader = _CepLoader(featurizer, splitter, transformers, CEP_TASKS, 86 | data_dir, save_dir, **kwargs) 87 | return loader.load_dataset('delaney', reload) 88 | -------------------------------------------------------------------------------- /utils/bbbp_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Blood-Brain Barrier Penetration dataset loader. 
3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | BBBP_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv" 13 | BBBP_TASKS = ["p_np"] 14 | BBBP_IUPAC_FILE = "../test_data/BBBP_iupac.csv" 15 | 16 | 17 | 18 | class _BBBPLoader(_MolnetLoader): 19 | 20 | def create_dataset(self) -> Dataset: 21 | # dataset_file = os.path.join(self.data_dir, "BBBP.csv") 22 | dataset_file = BBBP_IUPAC_FILE 23 | if not os.path.exists(dataset_file): 24 | dc.utils.data_utils.download_url(url=BBBP_URL, dest_dir=self.data_dir) 25 | loader = CSVLoader( 26 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 27 | return loader.create_dataset(dataset_file, shard_size=8192) 28 | 29 | 30 | def load_bbbp( 31 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 32 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 33 | transformers: List[Union[TransformerGenerator, str]] = ['balancing'], 34 | reload: bool = False, 35 | data_dir: Optional[str] = None, 36 | save_dir: Optional[str] = None, 37 | **kwargs 38 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 39 | """Load BBBP dataset 40 | 41 | The blood-brain barrier penetration (BBBP) dataset is designed for the 42 | modeling and prediction of barrier permeability. As a membrane separating 43 | circulating blood and brain extracellular fluid, the blood-brain barrier 44 | blocks most drugs, hormones and neurotransmitters. Thus penetration of the 45 | barrier forms a long-standing issue in development of drugs targeting 46 | central nervous system. 47 | 48 | This dataset includes binary labels for over 2000 compounds on their 49 | permeability properties. 50 | 51 | Scaffold splitting is recommended for this dataset. 52 | 53 | The raw data csv file contains columns below: 54 | 55 | - "name" - Name of the compound 56 | - "smiles" - SMILES representation of the molecular structure 57 | - "p_np" - Binary labels for penetration/non-penetration 58 | 59 | Parameters 60 | ---------- 61 | featurizer: Featurizer or str 62 | the featurizer to use for processing the data. Alternatively you can pass 63 | one of the names from dc.molnet.featurizers as a shortcut. 64 | splitter: Splitter or str 65 | the splitter to use for splitting the data into training, validation, and 66 | test sets. Alternatively you can pass one of the names from 67 | dc.molnet.splitters as a shortcut. If this is None, all the data 68 | will be included in a single dataset. 69 | transformers: list of TransformerGenerators or strings 70 | the Transformers to apply to the data. Each one is specified by a 71 | TransformerGenerator or, as a shortcut, one of the names from 72 | dc.molnet.transformers. 73 | reload: bool 74 | if True, the first call for a particular featurizer and splitter will cache 75 | the datasets to disk, and subsequent calls will reload the cached datasets. 76 | data_dir: str 77 | a directory to save the raw data in 78 | save_dir: str 79 | a directory to save the dataset in 80 | 81 | References 82 | ---------- 83 | .. [1] Martins, Ines Filipa, et al. "A Bayesian approach to in silico 84 | blood-brain barrier penetration modeling." Journal of chemical 85 | information and modeling 52.6 (2012): 1686-1697. 
86 | """ 87 | loader = _BBBPLoader(featurizer, splitter, transformers, BBBP_TASKS, data_dir, 88 | save_dir, **kwargs) 89 | return loader.load_dataset('bbbp', reload) 90 | -------------------------------------------------------------------------------- /utils/malaria_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Delaney dataset loader. 3 | """ 4 | import os 5 | import deepchem as dc 6 | from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | # from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | # MALARIA_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv" 13 | MALARIA_TASKS = ['activity'] 14 | 15 | MALARIA_FILE = "../test_data/malaria.csv" 16 | 17 | class _MalariaLoader(_MolnetLoader): 18 | 19 | def create_dataset(self) -> Dataset: 20 | # dataset_file = os.path.join(self.data_dir, "delaney-processed.csv") 21 | dataset_file = MALARIA_FILE 22 | if not os.path.exists(dataset_file): 23 | dc.utils.data_utils.download_url(url=MALARIA_FILE, dest_dir=self.data_dir) 24 | # loader = CSVLoader( 25 | # tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 26 | loader = dc.data.CSVLoader( 27 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer) 28 | return loader.create_dataset(dataset_file, shard_size=8192) 29 | 30 | 31 | def load_malaria( 32 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 33 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 34 | # transformers: List[Union[TransformerGenerator, str]] = ['normalization'], 35 | transformers: List[Union[TransformerGenerator, str]] = [], 36 | reload: bool = False, 37 | data_dir: Optional[str] = None, 38 | save_dir: Optional[str] = None, 39 | **kwargs 40 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 41 | """Load Delaney dataset 42 | 43 | The Delaney (ESOL) dataset a regression dataset containing structures and 44 | water solubility data for 1128 compounds. The dataset is widely used to 45 | validate machine learning models on estimating solubility directly from 46 | molecular structures (as encoded in SMILES strings). 47 | 48 | Scaffold splitting is recommended for this dataset. 49 | 50 | The raw data csv file contains columns below: 51 | 52 | - "Compound ID" - Name of the compound 53 | - "smiles" - SMILES representation of the molecular structure 54 | - "measured log solubility in mols per litre" - Log-scale water solubility 55 | of the compound, used as label 56 | 57 | Parameters 58 | ---------- 59 | featurizer: Featurizer or str 60 | the featurizer to use for processing the data. Alternatively you can pass 61 | one of the names from dc.molnet.featurizers as a shortcut. 62 | splitter: Splitter or str 63 | the splitter to use for splitting the data into training, validation, and 64 | test sets. Alternatively you can pass one of the names from 65 | dc.molnet.splitters as a shortcut. If this is None, all the data 66 | will be included in a single dataset. 67 | transformers: list of TransformerGenerators or strings 68 | the Transformers to apply to the data. Each one is specified by a 69 | TransformerGenerator or, as a shortcut, one of the names from 70 | dc.molnet.transformers. 
71 |   reload: bool
72 |     if True, the first call for a particular featurizer and splitter will cache
73 |     the datasets to disk, and subsequent calls will reload the cached datasets.
74 |   data_dir: str
75 |     a directory to save the raw data in
76 |   save_dir: str
77 |     a directory to save the dataset in
78 |
79 |   References
80 |   ----------
81 |   .. [1] Gamo, Francisco-Javier, et al. "Thousands of chemical starting points
82 |          for antimalarial lead identification." Nature 465.7296 (2010):
83 |          305-310.
84 |   """
85 |   loader = _MalariaLoader(featurizer, splitter, transformers, MALARIA_TASKS,
86 |                           data_dir, save_dir, **kwargs)
87 |   return loader.load_dataset('malaria', reload)
88 |
--------------------------------------------------------------------------------
/utils/clintox_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Clinical Toxicity (clintox) dataset loader.
3 | @author Caleb Geniesse
4 | """
5 | import os
6 | import deepchem as dc
7 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
8 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader
9 | from deepchem.data import Dataset
10 | from typing import List, Optional, Tuple, Union
11 |
12 | CLINTOX_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz"
13 | CLINTOX_TASKS = ['FDA_APPROVED', 'CT_TOX']
14 | from utils.data_loader import CSVLoader
15 |
16 |
17 | CLINTOX_IUPAC_FILE = "../test_data/clintox_iupac.csv"
18 |
19 | class _ClintoxLoader(_MolnetLoader):
20 |
21 |   def create_dataset(self) -> Dataset:
22 |     # dataset_file = os.path.join(self.data_dir, "clintox.csv.gz")
23 |     dataset_file = CLINTOX_IUPAC_FILE
24 |     if not os.path.exists(dataset_file):
25 |       dc.utils.data_utils.download_url(url=CLINTOX_URL, dest_dir=self.data_dir)
26 |     loader = CSVLoader(
27 |         tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1")
28 |     return loader.create_dataset(dataset_file, shard_size=8192)
29 |
30 |
31 | def load_clintox(
32 |     featurizer: Union[dc.feat.Featurizer, str] = 'ECFP',
33 |     splitter: Union[dc.splits.Splitter, str, None] = 'scaffold',
34 |     transformers: List[Union[TransformerGenerator, str]] = ['balancing'],
35 |     reload: bool = True,
36 |     data_dir: Optional[str] = None,
37 |     save_dir: Optional[str] = None,
38 |     **kwargs
39 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
40 |   """Load ClinTox dataset
41 |
42 |   The ClinTox dataset compares drugs approved by the FDA and
43 |   drugs that have failed clinical trials for toxicity reasons.
44 |   The dataset includes two classification tasks for 1491 drug
45 |   compounds with known chemical structures:
46 |
47 |   #. clinical trial toxicity (or absence of toxicity)
48 |   #. FDA approval status.
49 |
50 |   The list of FDA-approved drugs is compiled from the SWEETLEAD
51 |   database, and the list of drugs that failed clinical trials for
52 |   toxicity reasons is compiled from the Aggregate Analysis of
53 |   ClinicalTrials.gov (AACT) database.
54 |
55 |   Random splitting is recommended for this dataset.
56 |
57 |   The raw data csv file contains columns below:
58 |
59 |   - "smiles" - SMILES representation of the molecular structure
60 |   - "FDA_APPROVED" - FDA approval status
61 |   - "CT_TOX" - Clinical trial results
62 |
63 |   Parameters
64 |   ----------
65 |   featurizer: Featurizer or str
66 |     the featurizer to use for processing the data.
Alternatively you can pass 67 | one of the names from dc.molnet.featurizers as a shortcut. 68 | splitter: Splitter or str 69 | the splitter to use for splitting the data into training, validation, and 70 | test sets. Alternatively you can pass one of the names from 71 | dc.molnet.splitters as a shortcut. If this is None, all the data 72 | will be included in a single dataset. 73 | transformers: list of TransformerGenerators or strings 74 | the Transformers to apply to the data. Each one is specified by a 75 | TransformerGenerator or, as a shortcut, one of the names from 76 | dc.molnet.transformers. 77 | reload: bool 78 | if True, the first call for a particular featurizer and splitter will cache 79 | the datasets to disk, and subsequent calls will reload the cached datasets. 80 | data_dir: str 81 | a directory to save the raw data in 82 | save_dir: str 83 | a directory to save the dataset in 84 | 85 | References 86 | ---------- 87 | .. [1] Gayvert, Kaitlyn M., Neel S. Madhukar, and Olivier Elemento. 88 | "A data-driven approach to predicting successes and failures of clinical 89 | trials." 90 | Cell chemical biology 23.10 (2016): 1294-1301. 91 | .. [2] Artemov, Artem V., et al. "Integrated deep learned transcriptomic and 92 | structure-based predictor of clinical trials outcomes." bioRxiv (2016): 93 | 095653. 94 | .. [3] Novick, Paul A., et al. "SWEETLEAD: an in silico database of approved 95 | drugs, regulated chemicals, and herbal isolates for computer-aided drug 96 | discovery." PloS one 8.11 (2013): e79568. 97 | .. [4] Aggregate Analysis of ClincalTrials.gov (AACT) Database. 98 | https://www.ctti-clinicaltrials.org/aact-database 99 | """ 100 | loader = _ClintoxLoader(featurizer, splitter, transformers, CLINTOX_TASKS, 101 | data_dir, save_dir, **kwargs) 102 | return loader.load_dataset('clintox', reload) 103 | -------------------------------------------------------------------------------- /utils/mol.py: -------------------------------------------------------------------------------- 1 | from utils.features import atom_to_feature_vector, bond_to_feature_vector 2 | from rdkit import Chem 3 | import numpy as np 4 | from torch_geometric.data import Data 5 | from collections import Counter 6 | import torch 7 | 8 | from rdkit import RDLogger 9 | RDLogger.DisableLog('rdApp.*') 10 | 11 | 12 | BOND_FEATURES = ['BondType', 'Stereo', 'BondDir'] 13 | def get_bond_feature_name(bond): 14 | """ 15 | Return the string format of bond features. 16 | Bond features are surrounded with () 17 | 18 | """ 19 | ret = [] 20 | for bond_feature in BOND_FEATURES: 21 | fea = eval(f"bond.Get{bond_feature}")() 22 | ret.append(str(fea)) 23 | 24 | return '(' + '-'.join(ret) + ')' 25 | 26 | def atom_to_vocab(mol, atom): 27 | """ 28 | Convert atom to vocabulary. The convention is based on atom type and bond type. 29 | :param mol: the molecular. 30 | :param atom: the target atom. 31 | :return: the generated atom vocabulary with its contexts. 32 | """ 33 | nei = Counter() 34 | for a in atom.GetNeighbors(): 35 | bond = mol.GetBondBetweenAtoms(atom.GetIdx(), a.GetIdx()) 36 | nei[str(a.GetSymbol()) + "-" + str(bond.GetBondType())] += 1 37 | keys = nei.keys() 38 | keys = list(keys) 39 | keys.sort() 40 | output = atom.GetSymbol() 41 | for k in keys: 42 | output = "%s_%s%d" % (output, k, nei[k]) 43 | 44 | # The generated atom_vocab is too long? 45 | return output 46 | 47 | def bond_to_vocab(mol, bond): 48 | """ 49 | Convert bond to vocabulary. The convention is based on atom type and bond type. 
50 | Considering one-hop neighbor atoms 51 | :param mol: the molecular. 52 | :param atom: the target atom. 53 | :return: the generated bond vocabulary with its contexts. 54 | """ 55 | nei = Counter() 56 | two_neighbors = (bond.GetBeginAtom(), bond.GetEndAtom()) 57 | two_indices = [a.GetIdx() for a in two_neighbors] 58 | for nei_atom in two_neighbors: 59 | for a in nei_atom.GetNeighbors(): 60 | a_idx = a.GetIdx() 61 | if a_idx in two_indices: 62 | continue 63 | tmp_bond = mol.GetBondBetweenAtoms(nei_atom.GetIdx(), a_idx) 64 | nei[str(nei_atom.GetSymbol()) + '-' + get_bond_feature_name(tmp_bond)] += 1 65 | keys = list(nei.keys()) 66 | keys.sort() 67 | output = get_bond_feature_name(bond) 68 | for k in keys: 69 | output = "%s_%s%d" % (output, k, nei[k]) 70 | return output 71 | 72 | 73 | def smiles2graph(smiles_string, atom_vocab=None): 74 | """ 75 | Converts SMILES string to graph Data object 76 | :input: SMILES string (str) 77 | :return: graph object 78 | """ 79 | 80 | mol = Chem.MolFromSmiles(smiles_string) 81 | 82 | # atoms 83 | mlabes = [] 84 | atom_features_list = [] 85 | for atom in mol.GetAtoms(): 86 | atom_features_list.append(atom_to_feature_vector(atom)) 87 | if atom_vocab is not None: 88 | mlabes.append(atom_vocab.stoi.get(atom_to_vocab(mol, atom), atom_vocab.other_index)) 89 | 90 | 91 | x = np.array(atom_features_list, dtype=np.int64) 92 | 93 | # bonds 94 | num_bond_features = 3 # bond type, bond stereo, is_conjugated 95 | if len(mol.GetBonds()) > 0: # mol has bonds 96 | edges_list = [] 97 | edge_features_list = [] 98 | for bond in mol.GetBonds(): 99 | i = bond.GetBeginAtomIdx() 100 | j = bond.GetEndAtomIdx() 101 | 102 | edge_feature = bond_to_feature_vector(bond) 103 | 104 | # add edges in both directions 105 | edges_list.append((i, j)) 106 | edge_features_list.append(edge_feature) 107 | edges_list.append((j, i)) 108 | edge_features_list.append(edge_feature) 109 | 110 | edge_index = np.array(edges_list, dtype=np.int64) 111 | 112 | # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] 113 | edge_attr = np.array(edge_features_list, dtype=np.int64) 114 | 115 | else: # mol has no bonds 116 | edge_index = np.empty((0, 2), dtype=np.int64) 117 | edge_attr = np.empty((0, num_bond_features), dtype=np.int64) 118 | 119 | # graph = dict() 120 | # graph['edge_index'] = edge_index 121 | # graph['edge_attr'] = edge_attr 122 | # graph['node_attr'] = x 123 | # graph['num_nodes'] = len(x) 124 | 125 | graph = Data(x=torch.tensor(x), edge_index=torch.tensor(edge_index.T), edge_attr=torch.tensor(edge_attr)) 126 | 127 | 128 | return graph, mlabes 129 | 130 | 131 | if __name__ == '__main__': 132 | graph = smiles2graph('O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5') 133 | print(graph) 134 | -------------------------------------------------------------------------------- /utils/sider_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | SIDER dataset loader. 
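A small sketch of the helpers in utils/mol.py above, assuming rdkit, torch and torch_geometric are installed; the printed vocabulary strings are examples of the atom-plus-neighbor-bond convention implemented by atom_to_vocab.

from rdkit import Chem
from utils.mol import smiles2graph, atom_to_vocab

graph, mlabes = smiles2graph('CCO')   # ethanol; without an atom_vocab, mlabes stays empty
print(graph.x.shape)                  # [3, num_atom_features]: one feature row per atom
print(graph.edge_index.shape)         # [2, 4]: each bond is stored in both directions
print(graph.edge_attr.shape)          # [4, 3]: bond type, stereo, conjugation

mol = Chem.MolFromSmiles('CCO')
print([atom_to_vocab(mol, atom) for atom in mol.GetAtoms()])
# e.g. ['C_C-SINGLE1', 'C_C-SINGLE1_O-SINGLE1', 'O_C-SINGLE1']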
3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | SIDER_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz" 13 | SIDER_TASKS = [ 14 | 'Hepatobiliary disorders', 'Metabolism and nutrition disorders', 15 | 'Product issues', 'Eye disorders', 'Investigations', 16 | 'Musculoskeletal and connective tissue disorders', 17 | 'Gastrointestinal disorders', 'Social circumstances', 18 | 'Immune system disorders', 'Reproductive system and breast disorders', 19 | 'Neoplasms benign, malignant and unspecified (incl cysts and polyps)', 20 | 'General disorders and administration site conditions', 21 | 'Endocrine disorders', 'Surgical and medical procedures', 22 | 'Vascular disorders', 'Blood and lymphatic system disorders', 23 | 'Skin and subcutaneous tissue disorders', 24 | 'Congenital, familial and genetic disorders', 'Infections and infestations', 25 | 'Respiratory, thoracic and mediastinal disorders', 'Psychiatric disorders', 26 | 'Renal and urinary disorders', 27 | 'Pregnancy, puerperium and perinatal conditions', 28 | 'Ear and labyrinth disorders', 'Cardiac disorders', 29 | 'Nervous system disorders', 'Injury, poisoning and procedural complications' 30 | ] 31 | 32 | SIDER_IUPAC_FILE = "../test_data/sider_iupac.csv" 33 | 34 | class _SiderLoader(_MolnetLoader): 35 | 36 | def create_dataset(self) -> Dataset: 37 | # dataset_file = os.path.join(self.data_dir, "sider.csv.gz") 38 | dataset_file = SIDER_IUPAC_FILE 39 | if not os.path.exists(dataset_file): 40 | dc.utils.data_utils.download_url(url=SIDER_URL, dest_dir=self.data_dir) 41 | loader = CSVLoader( 42 | tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer, id_field="iupac1") 43 | return loader.create_dataset(dataset_file, shard_size=8192) 44 | 45 | 46 | def load_sider( 47 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 48 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 49 | transformers: List[Union[TransformerGenerator, str]] = ['balancing'], 50 | reload: bool = False, 51 | data_dir: Optional[str] = None, 52 | save_dir: Optional[str] = None, 53 | **kwargs 54 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 55 | """Load SIDER dataset 56 | 57 | The Side Effect Resource (SIDER) is a database of marketed 58 | drugs and adverse drug reactions (ADR). The version of the 59 | SIDER dataset in DeepChem has grouped drug side effects into 60 | 27 system organ classes following MedDRA classifications 61 | measured for 1427 approved drugs. 62 | 63 | Random splitting is recommended for this dataset. 64 | 65 | The raw data csv file contains columns below: 66 | 67 | - "smiles": SMILES representation of the molecular structure 68 | - "Hepatobiliary disorders" ~ "Injury, poisoning and procedural 69 | complications": Recorded side effects for the drug. Please refer 70 | to http://sideeffects.embl.de/se/?page=98 for details on ADRs. 71 | 72 | Parameters 73 | ---------- 74 | featurizer: Featurizer or str 75 | the featurizer to use for processing the data. Alternatively you can pass 76 | one of the names from dc.molnet.featurizers as a shortcut. 77 | splitter: Splitter or str 78 | the splitter to use for splitting the data into training, validation, and 79 | test sets. 
Alternatively you can pass one of the names from 80 | dc.molnet.splitters as a shortcut. If this is None, all the data 81 | will be included in a single dataset. 82 | transformers: list of TransformerGenerators or strings 83 | the Transformers to apply to the data. Each one is specified by a 84 | TransformerGenerator or, as a shortcut, one of the names from 85 | dc.molnet.transformers. 86 | reload: bool 87 | if True, the first call for a particular featurizer and splitter will cache 88 | the datasets to disk, and subsequent calls will reload the cached datasets. 89 | data_dir: str 90 | a directory to save the raw data in 91 | save_dir: str 92 | a directory to save the dataset in 93 | 94 | References 95 | ---------- 96 | .. [1] Kuhn, Michael, et al. "The SIDER database of drugs and side effects." 97 | Nucleic acids research 44.D1 (2015): D1075-D1079. 98 | .. [2] Altae-Tran, Han, et al. "Low data drug discovery with one-shot 99 | learning." ACS central science 3.4 (2017): 283-293. 100 | .. [3] Medical Dictionary for Regulatory Activities. http://www.meddra.org/ 101 | """ 102 | loader = _SiderLoader(featurizer, splitter, transformers, SIDER_TASKS, 103 | data_dir, save_dir, **kwargs) 104 | return loader.load_dataset('sider', reload) 105 | -------------------------------------------------------------------------------- /utils/gen_utils.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem import MolFromSmiles, QED 2 | from utils.sascorer import calculateScore 3 | from rdkit.Chem.Crippen import MolLogP 4 | import pytorch_lightning as pl 5 | import torch 6 | from torch import nn 7 | from torch import optim 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split 10 | import selfies as sf 11 | from tqdm import tqdm 12 | from rdkit import Chem, DataStructs 13 | from rdkit.Chem import AllChem 14 | 15 | 16 | def calculate_tanimoto(smiles, smiles2): 17 | mol = Chem.MolFromSmiles(smiles) 18 | mol1 = Chem.AddHs(mol) 19 | fps1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=2048, useChirality=False) 20 | 21 | mol2 = Chem.MolFromSmiles(smiles2) 22 | mol2 = Chem.AddHs(mol2) 23 | fps2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2, nBits=2048, useChirality=False) 24 | fp_sim = DataStructs.TanimotoSimilarity(fps1, fps2) 25 | 26 | return fp_sim 27 | 28 | def one_hot_to_smiles(prob, idx_to_symbol): 29 | return sf.decoder(one_hot_to_selfies(prob, idx_to_symbol)) 30 | 31 | def one_hot_to_selfies(prob, idx_to_symbol): 32 | # todo ignore the pad_idx: 108, may out of idx_to_symbol's range 33 | return ''.join([idx_to_symbol[idx.item()] for idx in prob.argmax(1)]).replace(' ', '') 34 | 35 | def one_hots_to_penalized_logp(probs_lst, idx_to_symbol, return_smi=False): 36 | logps = [] 37 | smi_lst = [] 38 | for i, prob in enumerate(probs_lst): 39 | smile = one_hot_to_smiles(prob, idx_to_symbol) 40 | smi_lst.append(smile) 41 | mol = MolFromSmiles(smile) 42 | penalized_logp = MolLogP(mol) - calculateScore(mol) 43 | for ring in mol.GetRingInfo().AtomRings(): 44 | if len(ring) > 6: 45 | penalized_logp -= 1 46 | logps.append(penalized_logp) 47 | if return_smi: 48 | return logps, smi_lst 49 | return logps 50 | 51 | def smiles_to_penalized_logp(smi_lst): 52 | logps = [] 53 | for smile in tqdm(smi_lst): 54 | mol = MolFromSmiles(smile) 55 | penalized_logp = MolLogP(mol) - calculateScore(mol) 56 | for ring in mol.GetRingInfo().AtomRings(): 57 | if len(ring) > 6: 58 | penalized_logp -= 1 59 | 
logps.append(penalized_logp) 60 | return logps 61 | 62 | 63 | def smiles_to_indices(smiles, symbol_to_idx): 64 | encoding = [symbol_to_idx[symbol] for symbol in sf.split_selfies(sf.encoder(smiles))] 65 | return torch.tensor(encoding + [symbol_to_idx['[nop]']]) 66 | 67 | 68 | def smiles_to_one_hot(smiles, symbol_to_idx): 69 | idx_smi = smiles_to_indices(smiles, symbol_to_idx) 70 | out = torch.zeros((idx_smi.size(0), len(symbol_to_idx))) 71 | for i, index in enumerate(idx_smi): 72 | out[i][index] = 1 73 | return out.flatten() 74 | 75 | 76 | def generate_training_mols(num_mols, prop_func, device, generator): 77 | with torch.no_grad(): 78 | z = torch.randn((num_mols, 1024), device=device) 79 | x = generator(z) 80 | y = torch.tensor(prop_func(x), device=device).unsqueeze(1).float() 81 | return x, y 82 | 83 | class PropertyPredictor(pl.LightningModule): 84 | def __init__(self, in_dim, learning_rate=0.001): 85 | super(PropertyPredictor, self).__init__() 86 | self.learning_rate = learning_rate 87 | self.fc = nn.Sequential(nn.Linear(in_dim, 1000), 88 | nn.ReLU(), 89 | nn.Linear(1000, 1000), 90 | nn.ReLU(), 91 | nn.Linear(1000, 1)) 92 | 93 | def forward(self, x): 94 | return self.fc(x) 95 | 96 | def configure_optimizers(self): 97 | return optim.Adam(self.parameters(), lr=self.learning_rate) 98 | 99 | def loss_function(self, pred, real): 100 | return F.mse_loss(pred, real) 101 | 102 | def training_step(self, batch, batch_idx): 103 | x, y = batch 104 | out = self(x) 105 | loss = self.loss_function(out, y) 106 | self.log('train_loss', loss) 107 | return loss 108 | 109 | def validation_step(self, batch, batch_idx): 110 | x, y = batch 111 | out = self(x) 112 | loss = self.loss_function(out, y) 113 | self.log('val_loss', loss) 114 | return loss 115 | 116 | class PropDataModule(pl.LightningDataModule): 117 | def __init__(self, x, y, batch_size): 118 | super(PropDataModule, self).__init__() 119 | self.batch_size = batch_size 120 | self.dataset = TensorDataset(x, y) 121 | self.train_data, self.test_data = random_split(self.dataset, [int(round(len(self.dataset) * 0.9)), int(round(len(self.dataset) * 0.1))]) 122 | 123 | def train_dataloader(self): 124 | return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True) 125 | 126 | def val_dataloader(self): 127 | return DataLoader(self.test_data, batch_size=self.batch_size, drop_last=True) 128 | -------------------------------------------------------------------------------- /utils/bace_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | bace dataset loader. 
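The penalized logP objective computed throughout gen_utils.py above is Crippen logP minus the synthetic-accessibility score, with an additional penalty of 1 for every ring larger than six atoms. A sketch of the same computation on a single molecule; it assumes rdkit (plus the torch/selfies/pytorch-lightning stack that utils.gen_utils imports) and the fpscores data used by utils.sascorer are available.

from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Crippen import MolLogP
from utils.sascorer import calculateScore
from utils.gen_utils import calculate_tanimoto

mol = MolFromSmiles('CCOc1ccccc1')
penalized_logp = MolLogP(mol) - calculateScore(mol)
for ring in mol.GetRingInfo().AtomRings():
    if len(ring) > 6:                 # benzene has six atoms, so no penalty here
        penalized_logp -= 1
print(penalized_logp)

# Morgan-fingerprint Tanimoto similarity between two molecules
print(calculate_tanimoto('CCO', 'CCN'))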
3 | """ 4 | import os 5 | import deepchem as dc 6 | # from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader 7 | from utils.molnet_loader import TransformerGenerator, _MolnetLoader 8 | from deepchem.data import Dataset 9 | from typing import List, Optional, Tuple, Union 10 | from utils.data_loader import CSVLoader 11 | 12 | BACE_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv" 13 | BACE_REGRESSION_TASKS = ["pIC50"] 14 | BACE_CLASSIFICATION_TASKS = ["Class"] 15 | IUPAC_FILE = '../test_data/bace_iupac.csv' 16 | 17 | 18 | 19 | class _BaceLoader(_MolnetLoader): 20 | 21 | def create_dataset(self) -> Dataset: 22 | # dataset_file = os.path.join(self.data_dir, "bace.csv") 23 | dataset_file = IUPAC_FILE 24 | if not os.path.exists(dataset_file): 25 | dc.utils.data_utils.download_url(url=BACE_URL, dest_dir=self.data_dir) 26 | loader = CSVLoader( 27 | tasks=self.tasks, feature_field="mol", featurizer=self.featurizer, id_field='iupac1') 28 | return loader.create_dataset(dataset_file, shard_size=8192) 29 | 30 | 31 | def load_bace_regression( 32 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 33 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 34 | transformers: List[Union[TransformerGenerator, str]] = ['normalization'], 35 | reload: bool = False, 36 | data_dir: Optional[str] = None, 37 | save_dir: Optional[str] = None, 38 | **kwargs 39 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 40 | """ Load BACE dataset, regression labels 41 | 42 | The BACE dataset provides quantitative IC50 and qualitative (binary label) 43 | binding results for a set of inhibitors of human beta-secretase 1 (BACE-1). 44 | 45 | All data are experimental values reported in scientific literature over the 46 | past decade, some with detailed crystal structures available. A collection 47 | of 1522 compounds is provided, along with the regression labels of IC50. 48 | 49 | Scaffold splitting is recommended for this dataset. 50 | 51 | The raw data csv file contains columns below: 52 | 53 | - "mol" - SMILES representation of the molecular structure 54 | - "pIC50" - Negative log of the IC50 binding affinity 55 | - "class" - Binary labels for inhibitor 56 | 57 | Parameters 58 | ---------- 59 | featurizer: Featurizer or str 60 | the featurizer to use for processing the data. Alternatively you can pass 61 | one of the names from dc.molnet.featurizers as a shortcut. 62 | splitter: Splitter or str 63 | the splitter to use for splitting the data into training, validation, and 64 | test sets. Alternatively you can pass one of the names from 65 | dc.molnet.splitters as a shortcut. If this is None, all the data 66 | will be included in a single dataset. 67 | transformers: list of TransformerGenerators or strings 68 | the Transformers to apply to the data. Each one is specified by a 69 | TransformerGenerator or, as a shortcut, one of the names from 70 | dc.molnet.transformers. 71 | reload: bool 72 | if True, the first call for a particular featurizer and splitter will cache 73 | the datasets to disk, and subsequent calls will reload the cached datasets. 74 | data_dir: str 75 | a directory to save the raw data in 76 | save_dir: str 77 | a directory to save the dataset in 78 | 79 | References 80 | ---------- 81 | .. [1] Subramanian, Govindan, et al. "Computational modeling of β-secretase 1 82 | (BACE-1) inhibitors using ligand based approaches." Journal of chemical 83 | information and modeling 56.10 (2016): 1936-1949. 
84 | """ 85 | loader = _BaceLoader(featurizer, splitter, transformers, 86 | BACE_REGRESSION_TASKS, data_dir, save_dir, **kwargs) 87 | return loader.load_dataset('bace_r', reload) 88 | 89 | 90 | def load_bace_classification( 91 | featurizer: Union[dc.feat.Featurizer, str] = 'ECFP', 92 | splitter: Union[dc.splits.Splitter, str, None] = 'scaffold', 93 | transformers: List[Union[TransformerGenerator, str]] = ['balancing'], 94 | reload: bool = False, 95 | data_dir: Optional[str] = None, 96 | save_dir: Optional[str] = None, 97 | **kwargs 98 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 99 | """ Load BACE dataset, classification labels 100 | 101 | BACE dataset with classification labels ("class"). 102 | 103 | Parameters 104 | ---------- 105 | featurizer: Featurizer or str 106 | the featurizer to use for processing the data. Alternatively you can pass 107 | one of the names from dc.molnet.featurizers as a shortcut. 108 | splitter: Splitter or str 109 | the splitter to use for splitting the data into training, validation, and 110 | test sets. Alternatively you can pass one of the names from 111 | dc.molnet.splitters as a shortcut. If this is None, all the data 112 | will be included in a single dataset. 113 | transformers: list of TransformerGenerators or strings 114 | the Transformers to apply to the data. Each one is specified by a 115 | TransformerGenerator or, as a shortcut, one of the names from 116 | dc.molnet.transformers. 117 | reload: bool 118 | if True, the first call for a particular featurizer and splitter will cache 119 | the datasets to disk, and subsequent calls will reload the cached datasets. 120 | data_dir: str 121 | a directory to save the raw data in 122 | save_dir: str 123 | a directory to save the dataset in 124 | """ 125 | loader = _BaceLoader(featurizer, splitter, transformers, 126 | BACE_CLASSIFICATION_TASKS, data_dir, save_dir, **kwargs) 127 | return loader.load_dataset('bace_c', reload) 128 | -------------------------------------------------------------------------------- /extract_feat_labels.py: -------------------------------------------------------------------------------- 1 | from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors 2 | from rdkit import DataStructs, Chem 3 | from rdkit.Chem import AllChem 4 | import multiprocessing 5 | from tqdm import tqdm 6 | 7 | import multiprocessing 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | from rdkit import Chem 12 | from descriptastorus.descriptors import rdDescriptors 13 | from typing import Callable, Union 14 | 15 | 16 | from rdkit import RDLogger 17 | RDLogger.DisableLog('rdApp.*') 18 | 19 | Molecule = Union[str, Chem.Mol] 20 | import argparse 21 | 22 | 23 | RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN', 24 | 'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2', 25 | 'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0', 26 | 'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2', 27 | 'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide', 28 | 'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl', 29 | 'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine', 30 | 'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester', 31 | 'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone', 32 | 'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone', 33 | 'fr_ketone_Topliss', 
'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine', 34 | 'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho', 35 | 'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol', 36 | 'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine', 37 | 'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN', 38 | 'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole', 39 | 'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea'] 40 | 41 | 42 | def rdkit_functional_group_label_features_generator(mol: Molecule) -> np.ndarray: 43 | """ 44 | Generates functional group label for a molecule using RDKit. 45 | 46 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 47 | :return: A 1D numpy array containing the RDKit 2D features. 48 | """ 49 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 50 | generator = rdDescriptors.RDKit2D(RDKIT_PROPS) 51 | features = generator.process(smiles)[1:] 52 | features = np.array(features) 53 | features[features != 0] = 1 54 | return features 55 | 56 | def rdkit_2d_features_normalized_generator(mol: Molecule) -> np.ndarray: 57 | """ 58 | Generates RDKit 2D normalized features for a molecule. 59 | 60 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 61 | :return: A 1D numpy array containing the RDKit 2D normalized features. 62 | """ 63 | try: 64 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 65 | generator = rdNormalizedDescriptors.RDKit2DNormalized() 66 | features = generator.process(smiles)[1:] 67 | except: 68 | # import pdb;pdb.set_trace() 69 | features = np.zeros(200) 70 | return features 71 | 72 | def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray: 73 | """ 74 | Generates RDKit 2D features for a molecule. 75 | 76 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 77 | :return: A 1D numpy array containing the RDKit 2D features. 
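extract_feat_labels.py maps each SMILES to an 85-dimensional binary functional-group vector (one entry per name in RDKIT_PROPS) and a 2048-bit Morgan fingerprint. A small sketch of the two generators in this script, assuming rdkit and descriptastorus are installed and the repository root is on the Python path:

from extract_feat_labels import rdkit_functional_group_label_features_generator, get_bitvec

aspirin = 'CC(=O)Oc1ccccc1C(=O)O'
fg = rdkit_functional_group_label_features_generator(aspirin)
print(fg.shape, int(fg.sum()))     # (85,) and the number of functional groups present

fp = get_bitvec(aspirin)
print(fp.shape, fp.dtype)          # (2048,) int8 Morgan/ECFP bit vector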
78 | """ 79 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 80 | generator = rdDescriptors.RDKit2D() 81 | features = generator.process(smiles)[1:] 82 | 83 | return features 84 | 85 | 86 | 87 | 88 | def process_fg(line): 89 | smiles = line.strip() 90 | res = rdkit_functional_group_label_features_generator(smiles) 91 | return res 92 | 93 | 94 | def get_bitvec(smiles): 95 | try: 96 | mol = Chem.MolFromSmiles(smiles) 97 | if mol is None: 98 | return np.full((2048), -1, dtype=np.int8) 99 | mol1 = Chem.AddHs(mol) 100 | fps1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=2048, useChirality=False) 101 | fp_array = np.zeros((0, ), dtype=np.int8) 102 | DataStructs.ConvertToNumpyArray(fps1, fp_array) 103 | return fp_array 104 | except: 105 | return np.full((2048), -1, dtype=np.int8) 106 | 107 | import io 108 | def lines(): 109 | with io.open(data_file, 'r', encoding='utf8', newline='\n') as srcf: 110 | for line in srcf: 111 | yield line.strip() 112 | 113 | if __name__ == "__main__": 114 | 115 | parser = argparse.ArgumentParser(description='extract efcp fingerprints and function group labels') 116 | parser.add_argument("--smiles_file", type=str, default="pubchme_smiles_file.lst") 117 | 118 | args = parser.parse_args() 119 | 120 | 121 | data_file = args.smiles_file 122 | 123 | total = len(io.open(data_file, 'r', encoding='utf8', newline='\n').readlines()) 124 | 125 | 126 | 127 | pool = multiprocessing.Pool(64) 128 | 129 | rdkit_res = np.zeros((total, 85)) 130 | cnt = 0 131 | for res in tqdm(pool.imap(process_fg, lines(), chunksize=100000), total=total): 132 | rdkit_res[cnt] = res 133 | cnt += 1 134 | 135 | np.save("fg_labels.npy", rdkit_res) 136 | 137 | ecfp_feat_res = np.zeros((total, 2048)) 138 | cnt = 0 139 | for res in tqdm(pool.imap(get_bitvec, lines(), chunksize=100), total=total): 140 | ecfp_feat_res[cnt] = res 141 | cnt += 1 142 | 143 | np.save("ecfp_regression.npy", ecfp_feat_res) 144 | 145 | 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /utils/sascorer.py: -------------------------------------------------------------------------------- 1 | # NOT OUR CODE: 2 | # 3 | # calculation of synthetic accessibility score as described in: 4 | # 5 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 6 | # Peter Ertl and Ansgar Schuffenhauer 7 | # Journal of Cheminformatics 1:8 (2009) 8 | # http://www.jcheminf.com/content/1/1/8 9 | # 10 | # several small modifications to the original paper are included 11 | # particularly slightly different formula for marocyclic penalty 12 | # and taking into account also molecule symmetry (fingerprint density) 13 | # 14 | # for a set of 10k diverse molecules the agreement between the original method 15 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 16 | # 17 | # peter ertl & greg landrum, september 2013 18 | # 19 | 20 | 21 | from rdkit import Chem 22 | from rdkit.Chem import rdMolDescriptors 23 | import pickle 24 | 25 | import math 26 | from collections import defaultdict 27 | 28 | import os.path as op 29 | 30 | _fscores = None 31 | 32 | 33 | def readFragmentScores(name='fpscores'): 34 | import gzip 35 | global _fscores 36 | # generate the full path filename: 37 | if name == "fpscores": 38 | name = op.join(op.dirname(__file__), name) 39 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 40 | outDict = {} 41 | for i in data: 42 | for j in range(1, len(i)): 43 | 
outDict[i[j]] = float(i[0]) 44 | _fscores = outDict 45 | 46 | 47 | def numBridgeheadsAndSpiro(mol, ri=None): 48 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 49 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 50 | return nBridgehead, nSpiro 51 | 52 | 53 | def calculateScore(m): 54 | if _fscores is None: 55 | readFragmentScores() 56 | 57 | # fragment score 58 | fp = rdMolDescriptors.GetMorganFingerprint(m, 59 | 2) # <- 2 is the *radius* of the circular fingerprint 60 | fps = fp.GetNonzeroElements() 61 | score1 = 0. 62 | nf = 0 63 | for bitId, v in fps.items(): 64 | nf += v 65 | sfp = bitId 66 | score1 += _fscores.get(sfp, -4) * v 67 | score1 /= nf 68 | 69 | # features score 70 | nAtoms = m.GetNumAtoms() 71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 72 | ri = m.GetRingInfo() 73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 74 | nMacrocycles = 0 75 | for x in ri.AtomRings(): 76 | if len(x) > 8: 77 | nMacrocycles += 1 78 | 79 | sizePenalty = nAtoms**1.005 - nAtoms 80 | stereoPenalty = math.log10(nChiralCenters + 1) 81 | spiroPenalty = math.log10(nSpiro + 1) 82 | bridgePenalty = math.log10(nBridgeheads + 1) 83 | macrocyclePenalty = 0. 84 | # --------------------------------------- 85 | # This differs from the paper, which defines: 86 | # macrocyclePenalty = math.log10(nMacrocycles+1) 87 | # This form generates better results when 2 or more macrocycles are present 88 | if nMacrocycles > 0: 89 | macrocyclePenalty = math.log10(2) 90 | 91 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 92 | 93 | # correction for the fingerprint density 94 | # not in the original publication, added in version 1.1 95 | # to make highly symmetrical molecules easier to synthetise 96 | score3 = 0. 97 | if nAtoms > len(fps): 98 | score3 = math.log(float(nAtoms) / len(fps)) * .5 99 | 100 | sascore = score1 + score2 + score3 101 | 102 | # need to transform "raw" value into scale between 1 and 10 103 | min = -4.0 104 | max = 2.5 105 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 106 | # smooth the 10-end 107 | if sascore > 8.: 108 | sascore = 8. + math.log(sascore + 1. - 9.) 109 | if sascore > 10.: 110 | sascore = 10.0 111 | elif sascore < 1.: 112 | sascore = 1.0 113 | 114 | return sascore 115 | 116 | 117 | def processMols(mols): 118 | print('smiles\tName\tsa_score') 119 | for i, m in enumerate(mols): 120 | if m is None: 121 | continue 122 | 123 | s = calculateScore(m) 124 | 125 | smiles = Chem.MolToSmiles(m) 126 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 127 | 128 | 129 | if __name__ == '__main__': 130 | import sys 131 | import time 132 | 133 | t1 = time.time() 134 | readFragmentScores("fpscores") 135 | t2 = time.time() 136 | 137 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 138 | t3 = time.time() 139 | processMols(suppl) 140 | t4 = time.time() 141 | 142 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 143 | file=sys.stderr) 144 | 145 | # 146 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 147 | # All rights reserved. 148 | # 149 | # Redistribution and use in source and binary forms, with or without 150 | # modification, are permitted provided that the following conditions are 151 | # met: 152 | # 153 | # * Redistributions of source code must retain the above copyright 154 | # notice, this list of conditions and the following disclaimer. 
155 | # * Redistributions in binary form must reproduce the above 156 | # copyright notice, this list of conditions and the following 157 | # disclaimer in the documentation and/or other materials provided 158 | # with the distribution. 159 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 160 | # nor the names of its contributors may be used to endorse or promote 161 | # products derived from this software without specific prior written permission. 162 | # 163 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 164 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 165 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 166 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 167 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 168 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 169 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 170 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 171 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 172 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 173 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 174 | # -------------------------------------------------------------------------------- /utils/torchvocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | The contextual property. 3 | """ 4 | import pickle 5 | from collections import Counter 6 | from multiprocessing import Pool 7 | 8 | import tqdm 9 | from rdkit import Chem 10 | 11 | from utils.mol import atom_to_vocab 12 | from utils.mol import bond_to_vocab 13 | 14 | 15 | class TorchVocab(object): 16 | """ 17 | Defines the vocabulary for atoms/bonds in molecular. 18 | """ 19 | 20 | def __init__(self, counter, max_size=None, min_freq=1, specials=('', ''), vocab_type='atom'): 21 | """ 22 | 23 | :param counter: 24 | :param max_size: 25 | :param min_freq: 26 | :param specials: 27 | :param vocab_type: 'atom': atom atom_vocab; 'bond': bond atom_vocab. 
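The synthetic-accessibility score above combines a Morgan-fragment contribution term with size, stereo, spiro, bridgehead and macrocycle penalties and then rescales the result into the 1-10 range. Minimal usage, assuming the fpscores.pkl.gz fragment-score table sits next to utils/sascorer.py as readFragmentScores expects:

from rdkit import Chem
from utils import sascorer

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')   # aspirin
print(sascorer.calculateScore(mol))                 # ~1 (easy) to 10 (hard to synthesize)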
28 | """ 29 | self.freqs = counter 30 | counter = counter.copy() 31 | min_freq = max(min_freq, 1) 32 | if vocab_type in ('atom', 'bond'): 33 | self.vocab_type = vocab_type 34 | else: 35 | raise ValueError('Wrong input for vocab_type!') 36 | self.itos = list(specials) 37 | 38 | max_size = None if max_size is None else max_size + len(self.itos) 39 | # sort by frequency, then alphabetically 40 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 41 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 42 | 43 | for word, freq in words_and_frequencies: 44 | if freq < min_freq or len(self.itos) == max_size: 45 | break 46 | self.itos.append(word) 47 | # stoi is simply a reverse dict for itos 48 | self.stoi = {tok: i for i, tok in enumerate(self.itos)} 49 | self.other_index = 1 50 | self.pad_index = 0 51 | 52 | def __eq__(self, other): 53 | if self.freqs != other.freqs: 54 | return False 55 | if self.stoi != other.stoi: 56 | return False 57 | if self.itos != other.itos: 58 | return False 59 | # if self.vectors != other.vectors: 60 | # return False 61 | return True 62 | 63 | def __len__(self): 64 | return len(self.itos) 65 | 66 | def vocab_rerank(self): 67 | self.stoi = {word: i for i, word in enumerate(self.itos)} 68 | 69 | def extend(self, v, sort=False): 70 | words = sorted(v.itos) if sort else v.itos 71 | for w in words: 72 | if w not in self.stoi: 73 | self.itos.append(w) 74 | self.stoi[w] = len(self.itos) - 1 75 | self.freqs[w] = 0 76 | self.freqs[w] += v.freqs[w] 77 | 78 | def mol_to_seq(self, mol, with_len=False): 79 | mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol 80 | if self.vocab_type == 'atom': 81 | seq = [self.stoi.get(atom_to_vocab(mol, atom), self.other_index) for i, atom in enumerate(mol.GetAtoms())] 82 | else: 83 | seq = [self.stoi.get(bond_to_vocab(mol, bond), self.other_index) for i, bond in enumerate(mol.GetBonds())] 84 | return (seq, len(seq)) if with_len else seq 85 | 86 | @staticmethod 87 | def load_vocab(vocab_path: str) -> 'Vocab': 88 | with open(vocab_path, "rb") as f: 89 | return pickle.load(f) 90 | 91 | def save_vocab(self, vocab_path): 92 | with open(vocab_path, "wb") as f: 93 | pickle.dump(self, f) 94 | 95 | 96 | class MolVocab(TorchVocab): 97 | def __init__(self, smiles, max_size=None, min_freq=1, vocab_type='atom'): 98 | if vocab_type in ('atom', 'bond'): 99 | self.vocab_type = vocab_type 100 | else: 101 | raise ValueError('Wrong input for vocab_type!') 102 | 103 | print("Building %s vocab from smiles: %d" % (self.vocab_type, len(smiles))) 104 | counter = Counter() 105 | 106 | for smi in tqdm.tqdm(smiles): 107 | mol = Chem.MolFromSmiles(smi) 108 | if self.vocab_type == 'atom': 109 | for _, atom in enumerate(mol.GetAtoms()): 110 | v = atom_to_vocab(mol, atom) 111 | counter[v] += 1 112 | else: 113 | for _, bond in enumerate(mol.GetBonds()): 114 | v = bond_to_vocab(mol, bond) 115 | counter[v] += 1 116 | super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type) 117 | 118 | def __init__(self, file_path, max_size=None, min_freq=1, num_workers=1, total_lines=None, vocab_type='atom'): 119 | if vocab_type in ('atom', 'bond'): 120 | self.vocab_type = vocab_type 121 | else: 122 | raise ValueError('Wrong input for vocab_type!') 123 | print("Building %s vocab from file: %s" % (self.vocab_type, file_path)) 124 | 125 | from rdkit import RDLogger 126 | lg = RDLogger.logger() 127 | lg.setLevel(RDLogger.CRITICAL) 128 | 129 | if total_lines is None: 130 | def file_len(fname): 131 | f_len = 0 132 | with 
open(fname) as f: 133 | for f_len, _ in enumerate(f): 134 | pass 135 | return f_len + 1 136 | 137 | total_lines = file_len(file_path) 138 | 139 | counter = Counter() 140 | pbar = tqdm.tqdm(total=total_lines) 141 | pool = Pool(num_workers) 142 | res = [] 143 | batch = 50000 144 | callback = lambda a: pbar.update(batch) 145 | for i in range(int(total_lines / batch + 1)): 146 | start = int(batch * i) 147 | end = min(total_lines, batch * (i + 1)) 148 | # print("Start: %d, End: %d"%(start, end)) 149 | res.append(pool.apply_async(MolVocab.read_smiles_from_file, 150 | args=(file_path, start, end, vocab_type,), 151 | callback=callback)) 152 | # read_smiles_from_file(lock, file_path, start, end) 153 | pool.close() 154 | pool.join() 155 | for r in res: 156 | sub_counter = r.get() 157 | for k in sub_counter: 158 | if k not in counter: 159 | counter[k] = 0 160 | counter[k] += sub_counter[k] 161 | # print(counter) 162 | super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type) 163 | 164 | @staticmethod 165 | def read_smiles_from_file(file_path, start, end, vocab_type): 166 | # print("start") 167 | smiles = open(file_path, "r") 168 | smiles.readline() 169 | sub_counter = Counter() 170 | for i, smi in enumerate(smiles): 171 | if i < start: 172 | continue 173 | if i >= end: 174 | break 175 | 176 | 177 | if len(smi.split()) > 1: # contains iupac 178 | smi = smi.split()[0] 179 | mol = Chem.MolFromSmiles(smi) 180 | if vocab_type == 'atom': 181 | for atom in mol.GetAtoms(): 182 | v = atom_to_vocab(mol, atom) 183 | sub_counter[v] += 1 184 | else: 185 | for bond in mol.GetBonds(): 186 | v = bond_to_vocab(mol, bond) 187 | sub_counter[v] += 1 188 | # print("end") 189 | return sub_counter 190 | 191 | @staticmethod 192 | def load_vocab(vocab_path: str) -> 'MolVocab': 193 | with open(vocab_path, "rb") as f: 194 | return pickle.load(f) 195 | -------------------------------------------------------------------------------- /parsing.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | import argparse 4 | from pathlib import Path 5 | import pyparsing 6 | import json 7 | from tqdm import tqdm 8 | import sys 9 | import random 10 | 11 | ''' 12 | rlaunch --private-machine=group --charged-group=health --cpu=8 --gpu=0 --memory=50000 \ 13 | -- python parsing.py --dataset=/sharefs/ylx/chem_data/pubchem/data_1m/100.csv \ 14 | --output_dir=/sharefs/ylx/chem_data/pubchem/data_1m/processed/ 15 | 16 | nohup rlaunch --private-machine=group --charged-group=health --cpu=8 --gpu=0 --memory=50000 \ 17 | -- python parsing.py --dataset=/sharefs/ylx/chem_data/pubchem/data_1m_5cols/iupacs.csv \ 18 | --output_dir=/sharefs/ylx/chem_data/pubchem/data_1m_5cols/processed/ \ 19 | > parsing.log & 20 | ''' 21 | 22 | def print_list(lst, level=0): 23 | # print(' ' * (level - 1) + '+---' * (level > 0) + lst[0]) 24 | for l in lst: 25 | if type(l) is list: 26 | print_list(l, level + 1) 27 | else: 28 | print(' ' * level + '+---' + l) 29 | 30 | 31 | 32 | def get_dict_sub(lst, pos_list, parsed_iupac_dic, level=0): 33 | # lst = ['5', ['1H', 'isopropylbutyl'], '4', 'propylundecane'] 34 | # print(get_dict(lst)) 35 | ## {'isopropylbutyl': {0: [5], 1: [1]}, 'propylundecane': {0: [4]}} 36 | for l in lst: 37 | if type(l) is list: 38 | get_dict_sub(l, pos_list, parsed_iupac_dic, level + 1) 39 | else: 40 | if type(l) is str: 41 | temp = re.findall(r'\d+', l) # '1H'-> 1 42 | num_list = list(map(int, temp)) 43 | # if l.isdigit(): 44 | if len(num_list) > 0: 45 | # assert 
len(res)==1 46 | if len(pos_list) > 0 and max(pos_list.keys())>=level: 47 | for key in list(pos_list): 48 | if key>=level: 49 | pos_list.pop(key) 50 | 51 | if level not in list(pos_list): 52 | pos_list[level] = [] 53 | 54 | # pos_list[level].append(int(l)) 55 | pos_list[level]+=num_list 56 | 57 | else: 58 | # dic[l] = pos_list.copy() # not enough 59 | parsed_iupac_dic[l] = copy.deepcopy(pos_list) 60 | if len(pos_list)>0: 61 | pos_list.pop(max(pos_list.keys()),None) 62 | # print(dic) # stepwise check 63 | 64 | return parsed_iupac_dic 65 | 66 | 67 | def get_dict(lst): 68 | pos_list = {} 69 | parsed_iupac_dic={} 70 | res = get_dict_sub(lst, pos_list, parsed_iupac_dic) 71 | return res 72 | 73 | def get_val_dic_list(idx_value_list):#,iupac_list): 74 | indexes = [] 75 | final_dic_list = [] 76 | #assert len(idx_value_list) == len(iupac_list) 77 | for idx, dic in idx_value_list: 78 | indexes.append(idx) 79 | #dic['index'] = idx 80 | #dic['iupac'] = iupac_list[idx] 81 | final_dic_list.append(dic) 82 | return final_dic_list, indexes 83 | 84 | def get_train_dic_list(parsed_iupac_list,val_indexes): # iupac_list 85 | indexes = [] 86 | final_dic_list = [] 87 | #assert len(idx_value_list) == len(iupac_list) 88 | for idx, dic in enumerate(parsed_iupac_list): 89 | if idx not in val_indexes: 90 | indexes.append(idx) 91 | #dic['index'] = idx 92 | #dic['iupac'] = iupac_list[idx] 93 | final_dic_list.append(dic) 94 | return final_dic_list, indexes 95 | 96 | if __name__=='__main__': 97 | 98 | ''' 99 | #lst = ['a', ['b', 'c', ['d', 'i'], 'e'], 'f', ['g', 'h', ['j', 'k', 'l', 'm']]] 100 | #lst = ['5', ['1,2', 'isopropylbutyl'], '4', 'propylundecane'] 101 | lst = ['5', ['1', 'isopropylbutyl'], ['2','3', 'dimethane','4','xxx'], '4', 'propylundecane'] 102 | pos_list = {} # the variables must be claimed out of the recursion func 103 | parsed_iupac_dic={} 104 | print(get_dict(lst)) 105 | ''' 106 | 107 | sys.setrecursionlimit(10000) 108 | random.seed(42) 109 | 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument("--dataset", required=True, type=str) 112 | parser.add_argument("--output_dir", required=True, type=str) 113 | #parser.add_argument("--output_file", required=True, type=str) 114 | parser.add_argument("--val_percent", required=False, default='0.1' ,type=float) 115 | 116 | args = parser.parse_args() 117 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 118 | 119 | processed_name_list = [] 120 | iupac_list = [] 121 | 122 | thecontent = pyparsing.Word(pyparsing.alphanums) #| '-' # | '+' #pyparsing.alphanums 123 | parens = pyparsing.nestedExpr( '(', ')', content=thecontent) 124 | 125 | with open(args.dataset,'r') as f: 126 | # with open(args.output_dir/args.output_file,'w'): 127 | #names = f.readlines()[1:] 128 | names = f.read().splitlines()[1:] 129 | for i,line in enumerate(tqdm(names)): 130 | #print(line) 131 | line_ = line.replace('"','')\ 132 | .replace('[','(').replace(']',')')\ 133 | .replace(',','sep').replace('.','sep')\ 134 | .replace('-',' ').replace('+',' ').replace(';',' ').replace('&',' ').replace('?',' ') 135 | #clean = re.sub(r"[,.;+-@#?!&$]+", " ", line_) 136 | nested_line_ = '('+line_+')' 137 | name_list = parens.parseString(nested_line_).asList()[0] 138 | #print(name_list) 139 | # parsed_iupac_dic = {} 140 | # pos_list = {} 141 | instance_dic = {} 142 | name_dic = get_dict(name_list) 143 | instance_dic['index'] = i 144 | instance_dic['iupac'] = line.replace('"','') 145 | instance_dic['parsed_iupac'] = name_dic 146 | 147 | #print(name_dict) 148 | 
processed_name_list.append(instance_dic) 149 | #iupac_list.append(line.replace('"','')) # keep original iupac name 150 | f.close() 151 | 152 | with open(Path(args.output_dir) / 'full.jsonl','w') as w: 153 | # json.dump(processed_name_list,w) 154 | for row in processed_name_list: 155 | print(json.dumps(row), file=w) 156 | w.close() 157 | 158 | if args.val_percent > 0: 159 | val_size = int(len(processed_name_list)*args.val_percent) 160 | print('Sampling validation set,val_percent:',args.val_percent) 161 | val_idx_value = random.sample(list(enumerate(processed_name_list)), val_size) # [(id,value),...] 162 | 163 | print('Getting val&train dic list...') 164 | val_final_dic_list,val_indexes = get_val_dic_list(val_idx_value)#,iupac_list) 165 | train_final_dic_list,train_indexes = get_train_dic_list(processed_name_list,val_indexes) #iupac_list 166 | 167 | assert len(set(val_indexes) & set(train_indexes)) == 0 168 | 169 | ''' 170 | val_indexes = [] 171 | val_values = [] 172 | for idx, val in val_idx_value: 173 | val_indexes.append(idx) 174 | val['index'] = idx 175 | val['iupac'] = iupac_list[idx] 176 | val_values.append(val) 177 | print('Taking the left as training set...') 178 | 179 | train_indexes = [] 180 | train_values = [] 181 | for i, e in enumerate(processed_name_list): 182 | if i not in val_indexes: 183 | train_indexes.append(i) 184 | train_values.append(e) 185 | train_values = [processed_name_list[i] for i, e in enumerate(processed_name_list) if i not in val_indexes] 186 | ''' 187 | 188 | print('Val/Train size: ',len(val_final_dic_list),len(train_final_dic_list)) # Val/Train size: 89671 807047 189 | 190 | print('writing files...') 191 | with open(Path(args.output_dir) / 'val.jsonl','w') as w: # val 192 | for row in val_final_dic_list: 193 | print(json.dumps(row), file=w) 194 | w.close() 195 | 196 | with open(Path(args.output_dir) / 'train.jsonl','w') as w: # train 197 | for row in train_final_dic_list: 198 | print(json.dumps(row), file=w) 199 | w.close() -------------------------------------------------------------------------------- /utils/gen_dataset.py: -------------------------------------------------------------------------------- 1 | import selfies as sf 2 | from torch.utils.data import Dataset 3 | import os 4 | import pickle 5 | import torch 6 | import sys 7 | sys.path.insert(0,'..') 8 | from utils.mol import smiles2graph 9 | from typing import List, Dict 10 | from torch_geometric.data import Data, Batch 11 | from iupac_token import IUPACTokenizer, SmilesIUPACTokenizer, SmilesTokenizer 12 | from torch.nn import functional as F 13 | # data_folder: 14 | import pytorch_lightning as pl 15 | from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split 16 | 17 | class GenSelfiesDataset(Dataset): 18 | def __init__(self, data_folder, tokenizer, sim_file='zinc250k.smi', smi_self_file='train_self.smi', vocab_file='vocab_lst.pkl', pkl_file='dm_info.pkl', max_seq_length=128): 19 | super().__init__() 20 | self.data_folder = data_folder 21 | self.tokenizer = tokenizer 22 | self.max_seq_length = max_seq_length 23 | self.sim_file = os.path.join(self.data_folder, sim_file) 24 | self.smi_self_file = os.path.join(self.data_folder, smi_self_file) 25 | self.vocab_file = os.path.join(self.data_folder, vocab_file) 26 | 27 | self.smile_lst = [] 28 | self.selfies = [] 29 | self.alphabet = set() 30 | with open(self.sim_file, 'r') as sr: 31 | for line in sr: 32 | self.smile_lst.append(line.strip()) 33 | 34 | if os.path.exists(self.vocab_file): 35 | with open(self.vocab_file, "rb") as fp: 36 | 
self.alphabet = pickle.load(fp) 37 | with open(self.smi_self_file, 'r') as fr: 38 | for line in fr: 39 | self.selfies.append(line.strip()) 40 | elif pkl_file is not None: 41 | pkl_file_path = os.path.join(self.data_folder, pkl_file) 42 | with open(pkl_file_path, "rb") as fp: 43 | dataset_info = pickle.load(fp) 44 | self.alphabet = dataset_info['alphabet'] 45 | self.max_len = dataset_info['max_len'] + 1 46 | self.symbol_to_idx = dataset_info['symbol_to_idx'] 47 | self.idx_to_symbol = dataset_info['idx_to_symbol'] 48 | self.encodings = dataset_info['encodings'] 49 | self.pad_idx = len(self.symbol_to_idx) 50 | else: 51 | self.generate_file() 52 | 53 | self.length = len(self.smile_lst) 54 | 55 | if pkl_file is None: 56 | self.max_len = max(len(list(sf.split_selfies(s))) for s in self.selfies) 57 | self.symbol_to_idx = {s: i for i, s in enumerate(self.alphabet)} 58 | self.idx_to_symbol = {i: s for i, s in enumerate(self.alphabet)} 59 | self.encodings = [[self.symbol_to_idx[symbol] for symbol in sf.split_selfies(s)] for s in self.selfies] 60 | 61 | 62 | 63 | def generate_file(self): 64 | assert os.path.exists(self.sim_file) 65 | 66 | for smi in self.smile_lst: 67 | self.selfies.append(sf.encoder(smi)) 68 | for s in self.selfies: 69 | self.alphabet.update(sf.split_selfies(s)) 70 | self.alphabet = ['[nop]'] + list(sorted(self.alphabet)) 71 | 72 | # save self.selfies self.alphabet 73 | with open(self.vocab_file, "wb") as fp: 74 | pickle.dump(self.alphabet, fp) 75 | 76 | # save self.selfiles 77 | with open(self.smi_self_file, 'w') as fw: 78 | for line in self.selfies: 79 | fw.write(f"{line}\n") 80 | 81 | pass 82 | 83 | def __len__(self): 84 | return self.length 85 | 86 | def _getitem_smi(self, smi): 87 | item = {} 88 | inputs = self.tokenizer( 89 | smi, 90 | add_special_tokens=True, 91 | max_length=self.max_seq_length, 92 | padding="max_length", 93 | truncation=True, 94 | ) 95 | item['input_ids'] = inputs['input_ids'] 96 | item['attention_mask'] = inputs['attention_mask'] 97 | graph, _ = smiles2graph(smi) 98 | item['graph'] = graph 99 | return item 100 | 101 | def __getitem__(self, i): 102 | item = {} 103 | item['target_encoding'] = torch.tensor(self.encodings[i] + [self.symbol_to_idx['[nop]']] + [self.pad_idx for i in range(self.max_len - len(self.encodings[i]))]) 104 | smi = self.smile_lst[i] 105 | inputs = self.tokenizer( 106 | smi, 107 | add_special_tokens=True, 108 | max_length=self.max_seq_length, 109 | padding="max_length", 110 | truncation=True, 111 | ) 112 | item['input_ids'] = inputs['input_ids'] 113 | item['attention_mask'] = inputs['attention_mask'] 114 | graph, _ = smiles2graph(smi) 115 | item['graph'] = graph 116 | return item 117 | 118 | def one_hot_to_selfies(self, hot): 119 | return ''.join([self.idx_to_symbol[idx.item()] for idx in hot.view((self.max_len, -1)).argmax(1)]).replace(' ', '') 120 | 121 | def one_hot_to_selfies_multi(self, hot): 122 | hot = hot.view(self.max_len, -1) 123 | hot = F.softmax(hot, dim=-1) 124 | hot_idx = [torch.multinomial(hot[i], num_samples=1)[0] for i in range(self.max_len)] 125 | return ''.join([self.idx_to_symbol[idx.item()] for idx in hot_idx]).replace(' ', '') 126 | 127 | def one_hot_to_smiles(self, hot): 128 | # return sf.decoder(self.one_hot_to_selfies(hot)) 129 | return sf.decoder(self.one_hot_to_selfies_multi(hot)) 130 | 131 | def default_data_collator(features: List[Dict]) -> Dict[str, torch.Tensor]: 132 | first = features[0] 133 | batch = {} 134 | 135 | graph_lst = [] 136 | for k, v in first.items(): 137 | if v is not None and not isinstance(v, 
str): 138 | if isinstance(v, torch.Tensor): 139 | batch[k] = torch.stack([f[k] for f in features]) 140 | elif isinstance(v, Data): 141 | graph_lst = [f[k] for f in features] 142 | else: 143 | batch[k] = torch.tensor([f[k] for f in features]) 144 | 145 | if len(graph_lst): 146 | batch['graph'] = Batch.from_data_list(graph_lst) 147 | 148 | 149 | return batch 150 | 151 | 152 | class Dataset(Dataset): 153 | def __init__(self, file): 154 | selfies = [sf.encoder(line.split()[0]) for line in open(file, 'r')] 155 | self.alphabet = set() 156 | for s in selfies: 157 | self.alphabet.update(sf.split_selfies(s)) 158 | self.alphabet = ['[nop]'] + list(sorted(self.alphabet)) 159 | self.max_len = max(len(list(sf.split_selfies(s))) for s in selfies) 160 | self.symbol_to_idx = {s: i for i, s in enumerate(self.alphabet)} 161 | self.idx_to_symbol = {i: s for i, s in enumerate(self.alphabet)} 162 | self.encodings = [[self.symbol_to_idx[symbol] for symbol in sf.split_selfies(s)] for s in selfies] 163 | 164 | def __len__(self): 165 | return len(self.encodings) 166 | 167 | def __getitem__(self, i): 168 | return torch.tensor(self.encodings[i] + [self.symbol_to_idx['[nop]'] for i in range(self.max_len - len(self.encodings[i]))]) 169 | 170 | 171 | class MolDataModule(pl.LightningDataModule): 172 | def __init__(self, batch_size, file): 173 | super(MolDataModule, self).__init__() 174 | self.batch_size = batch_size 175 | self.dataset = Dataset(file) 176 | self.train_data, self.test_data = random_split(self.dataset, [int(round(len(self.dataset) * 0.8)), int(round(len(self.dataset) * 0.2))]) 177 | 178 | def train_dataloader(self): 179 | return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=16, pin_memory=True) 180 | 181 | def val_dataloader(self): 182 | return DataLoader(self.test_data, batch_size=self.batch_size, drop_last=True, num_workers=16, pin_memory=True) 183 | 184 | if __name__ == "__main__": 185 | smiles_tokenizer = SmilesTokenizer.from_pretrained('/home/fengshikun/iupac-pretrain_3/smiles_tokenizer', max_len=128) 186 | gen_self = GenSelfiesDataset(data_folder='/sharefs/sharefs-test_data/LIMO', tokenizer=smiles_tokenizer, sim_file='test.smi') -------------------------------------------------------------------------------- /utils/molnet_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common code for loading MoleculeNet datasets. 3 | """ 4 | import os 5 | import logging 6 | import deepchem as dc 7 | from deepchem.data import Dataset, DiskDataset 8 | from typing import List, Optional, Tuple, Type, Union 9 | from utils.splitters import * 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class TransformerGenerator(object): 15 | """Create Transformers for Datasets. 16 | 17 | When loading molnet datasets, you cannot directly pass in Transformers 18 | to use because many Transformers require the Dataset they will be applied to 19 | as a constructor argument. Instead you pass in TransformerGenerator objects 20 | which can create the Transformers once the Dataset is loaded. 21 | """ 22 | 23 | def __init__(self, transformer_class: Type[dc.trans.Transformer], **kwargs): 24 | """Construct an object for creating Transformers. 
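The generation datasets above all rely on the same SELFIES round trip: encode a SMILES, split it into symbols, build a symbol vocabulary with '[nop]' as padding, map symbols to indices, and decode back. A self-contained sketch of that round trip (assumes the selfies package is installed):

import selfies as sf

smiles = 'c1ccccc1O'                              # phenol
s = sf.encoder(smiles)                            # SELFIES string
symbols = list(sf.split_selfies(s))

alphabet = ['[nop]'] + sorted(set(symbols))       # '[nop]' doubles as the padding symbol
symbol_to_idx = {sym: i for i, sym in enumerate(alphabet)}
idx_to_symbol = {i: sym for i, sym in enumerate(alphabet)}

encoding = [symbol_to_idx[sym] for sym in symbols]
decoded = sf.decoder(''.join(idx_to_symbol[i] for i in encoding))
print(decoded)                                    # an equivalent (possibly kekulized) SMILES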
25 | 26 | Parameters 27 | ---------- 28 | transformer_class: Type[Transformer] 29 | the class of Transformer to create 30 | kwargs: 31 | any additional arguments are passed to the Transformer's constructor 32 | """ 33 | self.transformer_class = transformer_class 34 | self.kwargs = kwargs 35 | 36 | def create_transformer(self, dataset: Dataset) -> dc.trans.Transformer: 37 | """Construct a Transformer for a Dataset.""" 38 | return self.transformer_class(dataset=dataset, **self.kwargs) 39 | 40 | def get_directory_name(self) -> str: 41 | """Get a name for directories on disk describing this Transformer.""" 42 | name = self.transformer_class.__name__ 43 | for key, value in self.kwargs.items(): 44 | if isinstance(value, list): 45 | continue 46 | name += '_' + key + '_' + str(value) 47 | return name 48 | 49 | 50 | featurizers = { 51 | 'ecfp': dc.feat.CircularFingerprint(size=1024), 52 | 'graphconv': dc.feat.ConvMolFeaturizer(), 53 | 'raw': dc.feat.RawFeaturizer(), 54 | 'onehot': dc.feat.OneHotFeaturizer(), 55 | 'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std'), 56 | 'weave': dc.feat.WeaveFeaturizer(), 57 | } 58 | 59 | splitters = { 60 | 'index': IndexSplitter(), 61 | 'random': RandomSplitter(), 62 | 'scaffold': ScaffoldSplitter(), 63 | 'scaffold_balance': ScaffoldBalanceSplitter(), 64 | 'butina': ButinaSplitter(), 65 | 'fingerprint': FingerprintSplitter(), 66 | 'task': dc.splits.TaskSplitter(), 67 | 'stratified': RandomStratifiedSplitter() 68 | } 69 | 70 | transformers = { 71 | 'balancing': 72 | TransformerGenerator(dc.trans.BalancingTransformer), 73 | 'normalization': 74 | TransformerGenerator(dc.trans.NormalizationTransformer, transform_y=True), 75 | 'minmax': 76 | TransformerGenerator(dc.trans.MinMaxTransformer, transform_y=True), 77 | 'clipping': 78 | TransformerGenerator(dc.trans.ClippingTransformer, transform_y=True), 79 | 'log': 80 | TransformerGenerator(dc.trans.LogTransformer, transform_y=True) 81 | } 82 | 83 | 84 | class _MolnetLoader(object): 85 | """The class provides common functionality used by many molnet loader functions. 86 | It is an abstract class. Subclasses implement loading of particular datasets. 87 | """ 88 | 89 | def __init__(self, featurizer: Union[dc.feat.Featurizer, str], 90 | splitter: Union[dc.splits.Splitter, str, None], 91 | transformer_generators: List[Union[TransformerGenerator, str]], 92 | tasks: List[str], data_dir: Optional[str], 93 | save_dir: Optional[str], seed=0, **kwargs): 94 | """Construct an object for loading a dataset. 95 | 96 | Parameters 97 | ---------- 98 | featurizer: Featurizer or str 99 | the featurizer to use for processing the data. Alternatively you can pass 100 | one of the names from dc.molnet.featurizers as a shortcut. 101 | splitter: Splitter or str 102 | the splitter to use for splitting the data into training, validation, and 103 | test sets. Alternatively you can pass one of the names from 104 | dc.molnet.splitters as a shortcut. If this is None, all the data 105 | will be included in a single dataset. 106 | transformer_generators: list of TransformerGenerators or strings 107 | the Transformers to apply to the data. Each one is specified by a 108 | TransformerGenerator or, as a shortcut, one of the names from 109 | dc.molnet.transformers. 
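For example, passing the string 'balancing' is resolved (via the module-level transformers dict defined below) to TransformerGenerator(dc.trans.BalancingTransformer).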
110 | tasks: List[str] 111 | the names of the tasks in the dataset 112 | data_dir: str 113 | a directory to save the raw data in 114 | save_dir: str 115 | a directory to save the dataset in 116 | """ 117 | if 'split' in kwargs: 118 | splitter = kwargs['split'] 119 | logger.warning("'split' is deprecated. Use 'splitter' instead.") 120 | if isinstance(featurizer, str): 121 | featurizer = featurizers[featurizer.lower()] 122 | if isinstance(splitter, str): 123 | splitter = splitters[splitter.lower()] 124 | if data_dir is None: 125 | data_dir = dc.utils.data_utils.get_data_dir() 126 | if save_dir is None: 127 | save_dir = dc.utils.data_utils.get_data_dir() 128 | self.featurizer = featurizer 129 | self.splitter = splitter 130 | self.transformers = [ 131 | transformers[t.lower()] if isinstance(t, str) else t 132 | for t in transformer_generators 133 | ] 134 | self.tasks = list(tasks) 135 | self.data_dir = data_dir 136 | self.save_dir = save_dir 137 | self.args = kwargs 138 | self.seed = seed 139 | 140 | def load_dataset( 141 | self, name: str, reload: bool 142 | ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]: 143 | """Load the dataset. 144 | 145 | Parameters 146 | ---------- 147 | name: str 148 | the name of the dataset, used to identify the directory on disk 149 | reload: bool 150 | if True, the first call for a particular featurizer and splitter will cache 151 | the datasets to disk, and subsequent calls will reload the cached datasets. 152 | """ 153 | # Build the path to the dataset on disk. 154 | 155 | featurizer_name = str(self.featurizer) 156 | splitter_name = 'None' if self.splitter is None else str(self.splitter) 157 | save_folder = os.path.join(self.save_dir, name + "-featurized", 158 | featurizer_name, splitter_name) 159 | if len(self.transformers) > 0: 160 | transformer_name = '_'.join( 161 | t.get_directory_name() for t in self.transformers) 162 | save_folder = os.path.join(save_folder, transformer_name) 163 | 164 | # Try to reload cached datasets. 165 | 166 | if reload: 167 | if self.splitter is None: 168 | if os.path.exists(save_folder): 169 | transformers = dc.utils.data_utils.load_transformers(save_folder) 170 | return self.tasks, (DiskDataset(save_folder),), transformers 171 | else: 172 | loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk( 173 | save_folder) 174 | if all_dataset is not None: 175 | return self.tasks, all_dataset, transformers 176 | 177 | # Create the dataset 178 | 179 | logger.info("About to featurize %s dataset." % name) 180 | dataset = self.create_dataset() 181 | 182 | # Split and transform the dataset. 
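# Note on the flow below: when a splitter is configured, the data is split using
# this loader's seed and the transformers are fit on the training split only,
# then applied to train/valid/test; when splitter is None, the transformers are
# fit on and applied to the full dataset before optional caching to disk.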
183 | 184 | if self.splitter is None: 185 | transformer_dataset: Dataset = dataset 186 | else: 187 | logger.info("About to split dataset with {} splitter.".format( 188 | self.splitter.__class__.__name__)) 189 | train, valid, test = self.splitter.train_valid_test_split(dataset, seed=self.seed) 190 | transformer_dataset = train 191 | transformers = [ 192 | t.create_transformer(transformer_dataset) for t in self.transformers 193 | ] 194 | logger.info("About to transform data.") 195 | if self.splitter is None: 196 | for transformer in transformers: 197 | dataset = transformer.transform(dataset) 198 | if reload and isinstance(dataset, DiskDataset): 199 | dataset.move(save_folder) 200 | dc.utils.data_utils.save_transformers(save_folder, transformers) 201 | return self.tasks, (dataset,), transformers 202 | 203 | for transformer in transformers: 204 | train = transformer.transform(train) 205 | valid = transformer.transform(valid) 206 | test = transformer.transform(test) 207 | if reload and isinstance(train, DiskDataset) and isinstance( 208 | valid, DiskDataset) and isinstance(test, DiskDataset): 209 | dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test, 210 | transformers) 211 | return self.tasks, (train, valid, test), transformers 212 | 213 | def create_dataset(self) -> Dataset: 214 | """Subclasses must implement this to load the dataset.""" 215 | raise NotImplementedError() 216 | -------------------------------------------------------------------------------- /utils/features.py: -------------------------------------------------------------------------------- 1 | # allowable multiple choice node and edge features 2 | allowable_features = { 3 | 'possible_atomic_num_list': list(range(1, 119)) + ['[MASK]', 'misc'], 4 | 'possible_chirality_list': [ 5 | 'CHI_UNSPECIFIED', 6 | 'CHI_TETRAHEDRAL_CW', 7 | 'CHI_TETRAHEDRAL_CCW', 8 | 'CHI_OTHER' 9 | ], 10 | 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], 11 | 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], 12 | 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], 13 | 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], 14 | 'possible_hybridization_list': [ 15 | 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc' 16 | ], 17 | 'possible_is_aromatic_list': [False, True], 18 | 'possible_is_in_ring_list': [False, True], 19 | 'possible_bond_type_list': [ 20 | 'SINGLE', 21 | 'DOUBLE', 22 | 'TRIPLE', 23 | 'AROMATIC', 24 | '[MASK]', 25 | '[SELF]', 26 | 'misc' 27 | ], 28 | 'possible_bond_stereo_list': [ 29 | 'STEREONONE', 30 | 'STEREOZ', 31 | 'STEREOE', 32 | 'STEREOCIS', 33 | 'STEREOTRANS', 34 | 'STEREOANY', 35 | '[MASK]', 36 | ], 37 | 'possible_is_conjugated_list': [False, True], 38 | } 39 | 40 | 41 | def safe_index(l, e): 42 | """ 43 | Return index of element e in list l. 
If e is not present, return the last index 44 | """ 45 | try: 46 | return l.index(e) 47 | except: 48 | return len(l) - 1 49 | 50 | 51 | def get_self_loops_typeid(): 52 | return allowable_features['possible_bond_type_list'].index('[SELF]') 53 | 54 | 55 | def get_mask_atom_typeid(): 56 | return allowable_features["possible_atomic_num_list"].index('[MASK]') 57 | 58 | 59 | def get_mask_edge_typeid(): 60 | return allowable_features['possible_bond_type_list'].index('[MASK]') 61 | 62 | 63 | def get_mask_atom_feature(): 64 | atom_feature = [ 65 | safe_index(allowable_features['possible_atomic_num_list'], '[MASK]'), 66 | allowable_features['possible_chirality_list'].index('CHI_UNSPECIFIED'), 67 | safe_index(allowable_features['possible_degree_list'], 'misc'), 68 | safe_index(allowable_features['possible_formal_charge_list'], 'misc'), 69 | safe_index(allowable_features['possible_numH_list'], 'misc'), 70 | safe_index(allowable_features['possible_number_radical_e_list'], 'misc'), 71 | safe_index(allowable_features['possible_hybridization_list'], 'misc'), 72 | allowable_features['possible_is_aromatic_list'].index(False), 73 | allowable_features['possible_is_in_ring_list'].index(False), 74 | ] 75 | return atom_feature 76 | 77 | 78 | # # miscellaneous case 79 | # i = safe_index(allowable_features['possible_atomic_num_list'], 'asdf') 80 | # assert allowable_features['possible_atomic_num_list'][i] == 'misc' 81 | # # normal case 82 | # i = safe_index(allowable_features['possible_atomic_num_list'], 2) 83 | # assert allowable_features['possible_atomic_num_list'][i] == 2 84 | 85 | def atom_to_feature_vector(atom): 86 | """ 87 | Converts rdkit atom object to feature list of indices 88 | :param mol: rdkit atom object 89 | :return: list 90 | """ 91 | atom_feature = [ 92 | safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), 93 | allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())), 94 | safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), 95 | safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()), 96 | safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()), 97 | safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()), 98 | safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())), 99 | allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()), 100 | allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()), 101 | ] 102 | return atom_feature 103 | 104 | 105 | # from rdkit import Chem 106 | # mol = Chem.MolFromSmiles('Cl[C@H](/C=C/C)Br') 107 | # atom = mol.GetAtomWithIdx(1) # chiral carbon 108 | # atom_feature = atom_to_feature_vector(atom) 109 | # assert atom_feature == [5, 2, 4, 5, 1, 0, 2, 0, 0] 110 | 111 | 112 | def get_atom_feature_dims(): 113 | return list(map(len, [ 114 | allowable_features['possible_atomic_num_list'], 115 | allowable_features['possible_chirality_list'], 116 | allowable_features['possible_degree_list'], 117 | allowable_features['possible_formal_charge_list'], 118 | allowable_features['possible_numH_list'], 119 | allowable_features['possible_number_radical_e_list'], 120 | allowable_features['possible_hybridization_list'], 121 | allowable_features['possible_is_aromatic_list'], 122 | allowable_features['possible_is_in_ring_list'] 123 | ])) 124 | 125 | 126 | def bond_to_feature_vector(bond): 127 | """ 128 | Converts rdkit bond object to feature list of indices 129 
| :param mol: rdkit bond object 130 | :return: list 131 | """ 132 | bond_feature = [ 133 | safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())), 134 | allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())), 135 | allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()), 136 | ] 137 | return bond_feature 138 | 139 | 140 | def get_bond_mask_feature(): 141 | bond_feature = [ 142 | safe_index(allowable_features['possible_bond_type_list'], '[MASK]'), 143 | safe_index(allowable_features['possible_bond_stereo_list'], '[MASK]'), 144 | allowable_features['possible_is_conjugated_list'].index(False), 145 | ] 146 | return bond_feature 147 | 148 | # uses same molecule as atom_to_feature_vector test 149 | # bond = mol.GetBondWithIdx(2) # double bond with stereochem 150 | # bond_feature = bond_to_feature_vector(bond) 151 | # assert bond_feature == [1, 2, 0] 152 | 153 | def get_bond_feature_dims(): 154 | return list(map(len, [ 155 | allowable_features['possible_bond_type_list'], 156 | allowable_features['possible_bond_stereo_list'], 157 | allowable_features['possible_is_conjugated_list'] 158 | ])) 159 | 160 | 161 | def atom_feature_vector_to_dict(atom_feature): 162 | [atomic_num_idx, 163 | chirality_idx, 164 | degree_idx, 165 | formal_charge_idx, 166 | num_h_idx, 167 | number_radical_e_idx, 168 | hybridization_idx, 169 | is_aromatic_idx, 170 | is_in_ring_idx] = atom_feature 171 | 172 | feature_dict = { 173 | 'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx], 174 | 'chirality': allowable_features['possible_chirality_list'][chirality_idx], 175 | 'degree': allowable_features['possible_degree_list'][degree_idx], 176 | 'formal_charge': allowable_features['possible_formal_charge_list'][formal_charge_idx], 177 | 'num_h': allowable_features['possible_numH_list'][num_h_idx], 178 | 'num_rad_e': allowable_features['possible_number_radical_e_list'][number_radical_e_idx], 179 | 'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx], 180 | 'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx], 181 | 'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx] 182 | } 183 | 184 | return feature_dict 185 | 186 | 187 | # # uses same atom_feature as atom_to_feature_vector test 188 | # atom_feature_dict = atom_feature_vector_to_dict(atom_feature) 189 | # assert atom_feature_dict['atomic_num'] == 6 190 | # assert atom_feature_dict['chirality'] == 'CHI_TETRAHEDRAL_CCW' 191 | # assert atom_feature_dict['degree'] == 4 192 | # assert atom_feature_dict['formal_charge'] == 0 193 | # assert atom_feature_dict['num_h'] == 1 194 | # assert atom_feature_dict['num_rad_e'] == 0 195 | # assert atom_feature_dict['hybridization'] == 'SP3' 196 | # assert atom_feature_dict['is_aromatic'] == False 197 | # assert atom_feature_dict['is_in_ring'] == False 198 | 199 | def bond_feature_vector_to_dict(bond_feature): 200 | [bond_type_idx, 201 | bond_stereo_idx, 202 | is_conjugated_idx] = bond_feature 203 | 204 | feature_dict = { 205 | 'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx], 206 | 'bond_stereo': allowable_features['possible_bond_stereo_list'][bond_stereo_idx], 207 | 'is_conjugated': allowable_features['possible_is_conjugated_list'][is_conjugated_idx] 208 | } 209 | 210 | return feature_dict 211 | # # uses same bond as bond_to_feature_vector test 212 | # bond_feature_dict = bond_feature_vector_to_dict(bond_feature) 213 | # assert 
bond_feature_dict['bond_type'] == 'DOUBLE' 214 | # assert bond_feature_dict['bond_stereo'] == 'STEREOE' 215 | # assert bond_feature_dict['is_conjugated'] == False 216 | -------------------------------------------------------------------------------- /iupac_regex/vocab.json: -------------------------------------------------------------------------------- 1 | {"": 0, "": 1, "": 2, "": 3, "": 4, ";": 5, ".": 6, ">": 7, "": 8, " ": 9, "(": 10, "(1+)": 11, "(1-)": 12, "(2+)": 13, "(2-)": 14, "(3+)": 15, "(3-)": 16, "(4+)": 17, "(5+)": 18, "(6+)": 19, "(7+)": 20, "(8+)": 21, ")": 22, ",": 23, "-": 24, "0": 25, "1": 26, "10": 27, "11": 28, "12": 29, "13": 30, "14": 31, "15": 32, "16": 33, "17": 34, "18": 35, "19": 36, "2": 37, "20": 38, "21": 39, "22": 40, "23": 41, "24": 42, "25": 43, "26": 44, "27": 45, "28": 46, "29": 47, "3": 48, "30": 49, "31": 50, "32": 51, "33": 52, "34": 53, "35": 54, "36": 55, "37": 56, "38": 57, "39": 58, "4": 59, "40": 60, "41": 61, "42": 62, "43": 63, "44": 64, "45": 65, "46": 66, "47": 67, "48": 68, "49": 69, "5": 70, "50": 71, "51": 72, "52": 73, "53": 74, "54": 75, "55": 76, "56": 77, "57": 78, "58": 79, "59": 80, "6": 81, "60": 82, "61": 83, "62": 84, "63": 85, "64": 86, "65": 87, "66": 88, "67": 89, "68": 90, "69": 91, "7": 92, "71": 93, "72": 94, "73": 95, "74": 96, "75": 97, "76": 98, "77": 99, "78": 100, "79": 101, "8": 102, "81": 103, "82": 104, "83": 105, "84": 106, "85": 107, "86": 108, "87": 109, "88": 110, "89": 111, "9": 112, "91": 113, "92": 114, "93": 115, "94": 116, "95": 117, "96": 118, "97": 119, "98": 120, "99": 121, "C": 122, "E": 123, "H": 124, "N": 125, "O": 126, "R": 127, "S": 128, "Z": 129, "[": 130, "]": 131, "^": 132, "a": 133, "aceanthrylen": 134, "acenaphthylen": 135, "acephenanthrylen": 136, "acet": 137, "acetaldehyde": 138, "aceto": 139, "acetyl": 140, "acid": 141, "acridin": 142, "acrido": 143, "actinium": 144, "adamant": 145, "al": 146, "alumane": 147, "alumanyl": 148, "aluminum": 149, "americium": 150, "amide": 151, "amido": 152, "amine": 153, "amino": 154, "an": 155, "ane": 156, "aniline": 157, "anilino": 158, "annulen": 159, "ano": 160, "anthracen": 161, "antimony": 162, "argon": 163, "ars": 164, "arsanyl": 165, "arsenic": 166, "arsindol": 167, "arsonic": 168, "arsono": 169, "astatine": 170, "ate": 171, "az": 172, "aza": 173, "azanida": 174, "azanide": 175, "azanidyl": 176, "azanium": 177, "azanyl": 178, "azido": 179, "azonia": 180, "azonio": 181, "azulen": 182, "b": 183, "barium": 184, "benz": 185, "benzaldehyde": 186, "benzhydryl": 187, "benzo": 188, "benzyl": 189, "berkelium": 190, "beryllium": 191, "bi": 192, "bis": 193, "bismuth": 194, "bor": 195, "bora": 196, "boran": 197, "borane": 198, "boranuida": 199, "boranuide": 200, "boranyl": 201, "borate": 202, "borepin": 203, "borinan": 204, "borinic": 205, "borinin": 206, "borino": 207, "boriran": 208, "borol": 209, "borolan": 210, "boron": 211, "borono": 212, "bromanyl": 213, "bromide": 214, "bromine": 215, "bromo": 216, "but": 217, "buta": 218, "butanoyl": 219, "butyl": 220, "c": 221, "cadmium": 222, "calcium": 223, "californium": 224, "carbaldehyde": 225, "carbamate": 226, "carbamic": 227, "carbamimid": 228, "carbamimidate": 229, "carbamimidothioate": 230, "carbamimidoyl": 231, "carbamo": 232, "carbamoyl": 233, "carbanid": 234, "carbazol": 235, "carbo": 236, "carbodithioate": 237, "carbohydrazide": 238, "carbohydrazonate": 239, "carbon": 240, "carbonate": 241, "carbonimidoyl": 242, "carbonitrile": 243, "carbono": 244, "carbonochloridate": 245, "carbonyl": 246, "carbothialdehyde": 
247, "carbothioate": 248, "carbothioic": 249, "carbothioyl": 250, "carboxamide": 251, "carboximidamide": 252, "carboximidate": 253, "carboximidothioate": 254, "carboximidoyl": 255, "carboxy": 256, "carboxylate": 257, "carboxylato": 258, "carboxylic": 259, "cerium": 260, "cesium": 261, "chloranyl": 262, "chloride": 263, "chloridic": 264, "chlorido": 265, "chloridoyl": 266, "chlorine": 267, "chloro": 268, "chloroform": 269, "chromen": 270, "chromium": 271, "chrysen": 272, "cinnolin": 273, "cobalt": 274, "cont": 275, "conta": 276, "copper": 277, "coronen": 278, "corrin": 279, "cos": 280, "cosa": 281, "cuban": 282, "cumen": 283, "curium": 284, "cyanamide": 285, "cyanate": 286, "cyanato": 287, "cyanic": 288, "cyanide": 289, "cyano": 290, "cyclo": 291, "d": 292, "dec": 293, "deca": 294, "decanoyl": 295, "deuterio": 296, "di": 297, "do": 298, "dysprosium": 299, "e": 300, "ec": 301, "einsteinium": 302, "en": 303, "ene": 304, "eno": 305, "ep": 306, "erbium": 307, "et": 308, "eth": 309, "ethyl": 310, "europium": 311, "f": 312, "fermium": 313, "fluoranthen": 314, "fluoranyl": 315, "fluoren": 316, "fluoride": 317, "fluorine": 318, "fluoro": 319, "formaldehyde": 320, "formamide": 321, "formamido": 322, "formate": 323, "formic": 324, "formonitril": 325, "formyl": 326, "fulleren": 327, "furan": 328, "furo": 329, "g": 330, "gadolinium": 331, "gallane": 332, "gallanyl": 333, "gallium": 334, "germ": 335, "germa": 336, "germane": 337, "germanium": 338, "germyl": 339, "gold": 340, "guanidin": 341, "h": 342, "hafnium": 343, "hecta": 344, "helium": 345, "hen": 346, "heni": 347, "hept": 348, "hepta": 349, "heptalen": 350, "heptanoyl": 351, "heptyl": 352, "hex": 353, "hexa": 354, "hexacen": 355, "hexanoyl": 356, "hexyl": 357, "holmium": 358, "hydrate": 359, "hydrazide": 360, "hydrazin": 361, "hydrazine": 362, "hydrazonate": 363, "hydrazono": 364, "hydride": 365, "hydro": 366, "hydrogen": 367, "hydroxide": 368, "hydroxy": 369, "hydroxyl": 370, "hypobromite": 371, "hypochlorite": 372, "hypochlorous": 373, "hypofluorite": 374, "hypoiodite": 375, "hypoiodous": 376, "i": 377, "ic": 378, "icos": 379, "icosa": 380, "id": 381, "idene": 382, "idin": 383, "idine": 384, "imid": 385, "imidazo": 386, "imidazol": 387, "imidazolidin": 388, "imido": 389, "imine": 390, "imino": 391, "in": 392, "indacen": 393, "indazol": 394, "inden": 395, "indium": 396, "indol": 397, "indolizin": 398, "ino": 399, "iodane": 400, "iodanyl": 401, "iodide": 402, "iodine": 403, "iodo": 404, "ir": 405, "iridium": 406, "iron": 407, "iso": 408, "ium": 409, "j": 410, "k": 411, "kis": 412, "krypton": 413, "l": 414, "lanthanum": 415, "lawrencium": 416, "lead": 417, "lithium": 418, "lutetium": 419, "magnesium": 420, "manganese": 421, "mendelevium": 422, "mercury": 423, "meth": 424, "methyl": 425, "methylidene": 426, "molybdenum": 427, "mono": 428, "morpholin": 429, "n": 430, "naphthalen": 431, "naphtho": 432, "naphthyridin": 433, "neodymium": 434, "neon": 435, "neptunium": 436, "nickel": 437, "niobium": 438, "nitramide": 439, "nitramido": 440, "nitrate": 441, "nitrile": 442, "nitrite": 443, "nitro": 444, "nitrogen": 445, "nitroso": 446, "nitrous": 447, "nobelium": 448, "non": 449, "nona": 450, "nonanoyl": 451, "o": 452, "oate": 453, "oc": 454, "oct": 455, "octa": 456, "octanoyl": 457, "octyl": 458, "oic": 459, "ol": 460, "olo": 461, "on": 462, "one": 463, "osmium": 464, "ovalen": 465, "ox": 466, "oxa": 467, "oxalo": 468, "oxamide": 469, "oxanthren": 470, "oxide": 471, "oxido": 472, "oxino": 473, "oxo": 474, "oxy": 475, "oxygen": 476, "palladium": 477, 
"pent": 478, "penta": 479, "pentacen": 480, "pentalen": 481, "pentamin": 482, "pentanoyl": 483, "pentyl": 484, "perbromate": 485, "perchlorate": 486, "perchloric": 487, "perimidin": 488, "periodate": 489, "peroxo": 490, "peroxy": 491, "perylen": 492, "phen": 493, "phenacyl": 494, "phenanthren": 495, "phenanthridin": 496, "phenanthro": 497, "phenanthrolin": 498, "pheno": 499, "phenyl": 500, "phosph": 501, "phospha": 502, "phosphane": 503, "phosphanide": 504, "phosphanium": 505, "phosphanyl": 506, "phosphate": 507, "phosphinite": 508, "phosphinolin": 509, "phosphinous": 510, "phosphite": 511, "phosphonamidic": 512, "phosphonato": 513, "phosphonia": 514, "phosphono": 515, "phosphonous": 516, "phosphoroso": 517, "phosphorus": 518, "phosphoryl": 519, "phthalaldehyde": 520, "phthalate": 521, "phthalazin": 522, "phthalic": 523, "picen": 524, "piperazin": 525, "piperidin": 526, "platinum": 527, "pleiaden": 528, "plumbyl": 529, "plutonium": 530, "porphyrin": 531, "potassium": 532, "praseodymium": 533, "promethium": 534, "prop": 535, "propa": 536, "propanoyl": 537, "propyl": 538, "protactinium": 539, "protide": 540, "protio": 541, "pteridin": 542, "purin": 543, "pyran": 544, "pyranthren": 545, "pyrazin": 546, "pyrazol": 547, "pyren": 548, "pyridazin": 549, "pyridin": 550, "pyrido": 551, "pyrimidin": 552, "pyrimido": 553, "pyrrol": 554, "pyrrolidin": 555, "pyrrolizin": 556, "quinazolin": 557, "quinolin": 558, "quinolizin": 559, "quinoxalin": 560, "radon": 561, "rhenium": 562, "rhodium": 563, "rubidium": 564, "ruthenium": 565, "samarium": 566, "scandium": 567, "selanyl": 568, "selen": 569, "selenium": 570, "seleno": 571, "sil": 572, "sila": 573, "silane": 574, "silanide": 575, "silanyl": 576, "silicate": 577, "silicon": 578, "silole": 579, "silver": 580, "silyl": 581, "silylo": 582, "silyloxy": 583, "sodium": 584, "spiro": 585, "stannane": 586, "stannyl": 587, "stiba": 588, "stiboryl": 589, "sulfamate": 590, "sulfamic": 591, "sulfamoyl": 592, "sulfane": 593, "sulfanium": 594, "sulfanyl": 595, "sulfanylidene": 596, "sulfate": 597, "sulfide": 598, "sulfido": 599, "sulfinamide": 600, "sulfinamoyl": 601, "sulfinate": 602, "sulfinato": 603, "sulfinic": 604, "sulfinimidoyl": 605, "sulfino": 606, "sulfinyl": 607, "sulfite": 608, "sulfo": 609, "sulfonamide": 610, "sulfonamido": 611, "sulfonate": 612, "sulfonato": 613, "sulfonic": 614, "sulfonimidoyl": 615, "sulfonio": 616, "sulfono": 617, "sulfonyl": 618, "sulfur": 619, "tantalum": 620, "technetium": 621, "tellanyl": 622, "tellur": 623, "tellurium": 624, "telluro": 625, "terbium": 626, "terephthalaldehyde": 627, "terephthalate": 628, "terephthalic": 629, "tert-butyl": 630, "tetr": 631, "tetra": 632, "tetracen": 633, "tetramin": 634, "thallium": 635, "thi": 636, "thia": 637, "thial": 638, "thian": 639, "thianthren": 640, "thio": 641, "thiocyanate": 642, "thiol": 643, "thiolo": 644, "thione": 645, "thiophen": 646, "thiourea": 647, "thorium": 648, "thulium": 649, "tin": 650, "titanium": 651, "tri": 652, "tria": 653, "tris": 654, "tritio": 655, "trityl": 656, "tungsten": 657, "un": 658, "uranium": 659, "urea": 660, "vanadium": 661, "xanthen": 662, "xenon": 663, "yl": 664, "yn": 665, "yne": 666, "yohimban": 667, "ytterbium": 668, "yttrium": 669, "zinc": 670, "zirconium": 671} -------------------------------------------------------------------------------- /cliff_pair.csv: -------------------------------------------------------------------------------- 1 | cliff_pair_a,cliff_pair_b 2 | 
COc1ncc(-c2ccc3ncc4c(c3n2)n(C2CCN(C(=O)[C@H](C)O)CC2)c(=O)n4C)cn1,COc1ncc(-c2ccc3ncc4c(c3n2)n(C2CCN(C(=O)CN(C)C)CC2)c(=O)n4C)cn1 3 | CCc1cccc(N2CCN(CCCCNC(=O)c3cc4ccccc4[nH]3)CC2)c1Cl,CCc1cccc(N2CCN(CC(O)CCNC(=O)c3cc4ccccc4[nH]3)CC2)c1Cl 4 | CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCC(C)=O)[C@@H]2[C@@H]1O,CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCC(=O)NS(C)(=O)=O)[C@@H]2[C@@H]1O 5 | Fc1ccc(C(OCCN2CCN(CCCc3ccccc3)CC2)c2ccc(F)cc2)cc1,C=C(Cc1ccccc1)CN1CCN(CCOC(c2ccc(F)cc2)c2ccc(F)cc2)CC1 6 | O=C(CN1CCN(CCOC(c2ccccc2)c2ccccc2)CC1)c1ccccc1,C=C(CCN1CCN(CCOC(c2ccccc2)c2ccccc2)CC1)c1ccccc1 7 | CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCO)[C@@H]2[C@@H]1O,CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCNC(C)=O)[C@@H]2[C@@H]1O 8 | CC(C)c1onc(-c2c(Cl)cccc2Cl)c1COc1ccc(-c2nc3ccc(C(=O)O)cc3n2C2CCCC2)cc1,CC(=O)N1CCC(n2c(-c3ccc(OCc4c(-c5c(Cl)cccc5Cl)noc4C(C)C)cc3)nc3ccc(C(=O)O)cc32)CC1 9 | C=CC[C@@H]1C2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@]3(C)C([C@H](C)CCC(=O)O)CC[C@H]3[C@@H]2[C@@H]1O,C[C@H](CCC(=O)O)C1CC[C@H]2[C@H]3[C@H](CC[C@]12C)[C@@]1(C)CC[C@@H](O)CC1[C@@H](CCO)[C@H]3O 10 | CC(C)c1onc(-c2c(Cl)cccc2Cl)c1COc1ccc(-c2nc3ccc(C(=O)O)cc3n2C2CCC2)cc1,CC(=O)N1CCC(n2c(-c3ccc(OCc4c(-c5c(Cl)cccc5Cl)noc4C(C)C)cc3)nc3ccc(C(=O)O)cc32)CC1 11 | OCc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,Nc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 12 | CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCNC(=O)C(=O)O)[C@@H]2[C@@H]1O,CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCNC(=O)OC)[C@@H]2[C@@H]1O 13 | Oc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,O=C(O)c1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 14 | CC(C)c1onc(-c2c(Cl)cccc2Cl)c1COc1ccc(-c2nc3ccc(C(=O)O)cc3n2C2CCCCC2)cc1,CC(=O)N1CCC(n2c(-c3ccc(OCc4c(-c5c(Cl)cccc5Cl)noc4C(C)C)cc3)nc3ccc(C(=O)O)cc32)CC1 15 | CCc1cc(Cl)c(OC)c(N2CCN(CCCCNC(=O)c3cc4ccccc4[nH]3)CC2)c1,CCc1cc(Cl)c(OC)c(N2CCN(CC(O)CCNC(=O)c3cc4ccccc4[nH]3)CC2)c1 16 | CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCNC(N)=O)[C@@H]2[C@@H]1O,CC[C@@H]1[C@@H]2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@@]3(C)[C@@H](CC[C@@H]3[C@H](C)CCNC(=O)C(=O)O)[C@@H]2[C@@H]1O 17 | CC(=O)c1ccc(CNC(=O)c2ccc(C(C)(C)C)cc2)cc1,CC(C)(C)c1ccc(C(=O)NCc2ccc(C(N)=O)cc2)cc1 18 | Cc1ccc(F)cc1-c1ccc2cc(NC(=O)C3CCCC3)ncc2c1,Cc1ccc(F)cc1-c1ccc2cc(NC(=O)C3CCN(C)CC3)ncc2c1 19 | Fc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,OCc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 20 | O=C(NC1CCN(CCCCc2ccccc2)CC1)c1ccc2ccccc2c1,O=C(NC1CCN(Cc2ccccc2)CC1)c1ccc2ccccc2c1 21 | COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(Br)cc2)CC1,COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(CCF)cc2)CC1 22 | CC(C)OC(=O)C1=CN(C(=O)c2ccc(OCCCN3CCOCC3)cc2)CC(C)(C)c2c1[nH]c1ccccc21,CC(C)OC(=O)C1=CN(C(=O)c2ccc(OCCN3CCCCC3)cc2)CC(C)(C)c2c1[nH]c1ccccc21 23 | COc1ncc(-c2ccc3ncc4c(c3n2)n(C2CCN(C(=O)[C@H](C)O)CC2)c(=O)n4C)cn1,COc1ncc(-c2ccc3ncc4c(c3n2)n(C2CCN(C)CC2)c(=O)n4C)cn1 24 | Cc1ccc(CO)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccc(Cl)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 25 | Cc1ccc(O)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccc(C(C)O)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 26 | COc1cccc(C(=O)CCCCN2CCN(c3ccccc3OC)CC2)c1,COc1cccc(C(=O)NCCCN2CCN(c3ccccc3OC)CC2)c1 27 | CN(C)c1ccc(-c2ccc3ncc4c(c3n2)n(C2CCN(C(N)=O)CC2)c(=O)n4C)cn1,CN1CCC(n2c(=O)n(C)c3cnc4ccc(-c5ccc(N(C)C)nc5)nc4c32)CC1 28 | CCCc1cc(-c2nc(C)c(C(C)=O)s2)ccc1OCCCOc1ccc2c(c1)CC[C@H]2CC(=O)O,CCCc1cc(-c2nc(CO)c(C(=O)O)s2)ccc1OCCCOc1ccc2c(c1)CC[C@H]2CC(=O)O 29 | 
Cc1ccc([C@H]2C[C@@H]3CC[C@H]([C@H]2c2ncc(-c4ccc(Cl)cc4)s2)N3C)cc1,CN1[C@H]2CC[C@@H]1[C@@H](c1ncc(-c3ccc([N+](=O)[O-])cc3)s1)[C@@H](c1ccc(Cl)cc1)C2 30 | Cc1ccc(CO)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccc(C)c(-c2ccc3cc(NC(=O)C4CC4)ncc3c2)c1 31 | Oc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,COc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 32 | COc1ccc(-c2ccc3ncc4c(c3n2)n(C2CCOCC2)c(=O)n4C)cn1,COc1ccc(-c2ccc3ncc4c(c3n2)n(C2CCOCC2)c(=O)n4CCO)cn1 33 | COc1cccc(C(=O)CCCCN2CCN(c3ccccc3OC)CC2)c1,COc1cccc(C(=O)NCCCCN2CCN(c3ccccc3OC)CC2)c1 34 | Cc1cc(N)ncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1cc(CO)ncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 35 | COc1cccc(C(=O)NCCCCN2CCN(c3ccc(Cl)cc3)CC2)c1,COc1cccc(C(=O)CCCCCN2CCN(c3ccc(Cl)cc3)CC2)c1 36 | CCO[C@@H](Cc1cccc(/C(C)=N/OCc2ccc(O)cc2)c1)C(=O)O,CCO[C@@H](Cc1cccc(/C(C)=N/OCc2ccc(I)cc2)c1)C(=O)O 37 | Cc1nc(-c2cc(Cl)cnc2Nc2cccc3[nH]ncc23)c2nc[nH]c2n1,Cc1nc(-c2cc(CO)cnc2Nc2cccc3[nH]ncc23)c2nc[nH]c2n1 38 | Fc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,COc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 39 | COc1cccc(C(=O)CCCCN2CCN(c3ccc(Cl)cc3)CC2)c1,COc1cccc(C(=O)NCCCCCN2CCN(c3ccc(Cl)cc3)CC2)c1 40 | Clc1ccc(N2CCN(Cc3cnn(-c4ccccc4)c3)CC2)cc1,Clc1ccc(-n2cc(COCCN3CCN(c4ccccc4)CC3)cn2)cc1 41 | Oc1ccc(CN2CCO[C@H](CCc3ccccc3)C2)cc1,Fc1ccc(CN2CCOC(CCc3ccccc3)C2)cc1 42 | CC(=O)c1cc(NC(=O)NCCC[C@H]2C[C@H](Cc3ccc(F)cc3)CCN2C(=N)N)cc(C(C)=O)c1,CC(=O)c1cc(NC(=O)NCCC[C@H]2C[C@H](Cc3ccc(F)cc3)CCN2CC(N)=O)cc(C(C)=O)c1 43 | Cc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccncc1-c1ccc2cc(NC(=O)C3CCN(C)CC3)ncc2c1 44 | CCc1ccccc1-c1ccc(O[C@@H](Cc2ccccc2)C(=O)O)cc1,COc1ccccc1-c1ccc(O[C@@H](Cc2ccccc2)C(=O)O)cc1 45 | O=C(Nc1cc2ccc(-c3cnccc3Cl)cc2cn1)C1CC1,COc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 46 | O=Cc1cnn2ccc(OCCCCN3CCN(c4cccc(Cl)c4Cl)CC3)cc12,O/N=C/c1cnn2ccc(OCCCCN3CCN(c4cccc(Cl)c4Cl)CC3)cc12 47 | Cc1ccc(O)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,COc1ccc(C)c(-c2ccc3cc(NC(=O)C4CC4)ncc3c2)c1 48 | COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(Br)cc2)CC1,COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(N(C)C)cc2)CC1 49 | CNC(=O)c1cc2cc(C3(Cc4ccccc4)CCNC3)ccc2[nH]1,NC(=O)c1cc2cc(C3(Cc4ccccc4)CCNC3)ccc2[nH]1 50 | CC(C)c1onc(-c2c(Cl)cccc2Cl)c1COc1ccc(COc2cccc(C#N)c2)cc1,CC(C)c1onc(-c2c(Cl)cccc2Cl)c1COc1ccc(COc2cccc(C(=O)O)c2)cc1 51 | CCCc1c(OCCCSc2ccc(CC(=O)O)cc2Cl)ccc(/C(CC)=N/O)c1O,CCCc1c(OCCCSc2ccc(CC(=O)O)cc2Cl)ccc(C(C)=O)c1O 52 | CC(=O)[C@@]12CNCC[C@]1(c1ccc(Cl)c(Cl)c1)C2,NC(=O)[C@@]12CNCC[C@]1(c1ccc(Cl)c(Cl)c1)C2 53 | Cc1cc(C(=O)O)ccc1NC(=O)[C@H](C1CCCCC1)n1c(-c2ccc(Cl)cc2)nc2cc(F)c(F)cc21,O=C(O)c1ccc(NC(=O)C(C2CCCCC2)n2c(-c3ccc(Cl)cc3)nc3cc(F)c(F)cc32)nc1 54 | Clc1ccc(-c2cc(C3CCN(Cc4ccccc4)CC3)[nH]n2)cc1,Clc1ccc(-c2cc(C3CCN(CCCc4ccccc4)CC3)[nH]n2)cc1 55 | Cc1cc(CNC(=O)c2cc(-c3ccc(-c4ccoc4)cc3)n(C)n2)ccc1OC(C)(C)C(=O)O,Cc1cc(CNC(=O)c2cc(-c3ccc(-c4ccncc4)cc3)n(C)n2)ccc1OC(C)(C)C(=O)O 56 | CC(=O)c1cc(NC(=O)NCCC[C@H]2C[C@H](Cc3ccc(F)cc3)CCN2C(=N)N)cc(C(C)=O)c1,CCC(=O)N1CC[C@@H](Cc2ccc(F)cc2)C[C@@H]1CCCNC(=O)Nc1cc(C(C)=O)cc(C(C)=O)c1 57 | CC(=O)Nc1nc2ccc(-c3ccnc(Nc4ccccc4)n3)cc2s1,CC(=O)Nc1nc2ccc(-c3ccnc(NCc4ccccc4)n3)cc2s1 58 | CC(C)(C)c1ccc(C(=O)NCc2ccc(C(=O)O)cc2)cc1,CC(C)(C)c1ccc(C(=O)NCc2ccc(C(N)=O)cc2)cc1 59 | O=C(N[C@H]1CC[C@H](O)CC1)[C@H](C1CCCCC1)n1c(-c2ccc(Cl)cc2)nc2cc(F)c(F)cc21,O=C(O)C[C@H]1CC[C@H](NC(=O)[C@H](C2CCCCC2)n2c(-c3ccc(Cl)cc3)nc3cc(F)c(F)cc32)CC1 60 | CCC(Cc1ccc(OC)c(C(=O)NCc2ccc(OCCc3ccccc3)cc2)c1)C(=O)O,CCC(Cc1ccc(OC)c(C(=O)NCc2ccc(Oc3ccccc3)cc2)c1)C(=O)O 61 | Cc1ccc(NC(=O)c2ccccc2NC(=O)c2ccc(C(C)(C)C)cc2)cc1C(=O)O,CC(C)(C)c1ccc(C(=O)Nc2ccccc2C(=O)Nc2ccc(Br)c(C(=O)O)c2)cc1 62 | 
CCCCOc1ccc(CC(CCC)C(=O)O)cc1CNC(=O)c1ccc(C(F)(F)F)cc1F,CCCC(Cc1ccc(OC)c(CNC(=O)c2ccc(C(F)(F)F)cc2F)c1)C(=O)O 63 | Clc1ccc(OCCNCCCOc2ccc(Br)cc2)cc1,Fc1ccc(OCCCNCCOc2ccc(Cl)cc2)cc1 64 | CNC(=O)N1CCC(n2c(=O)n(C)c3cnc4ccc(-c5ccc(N(C)C)nc5)nc4c32)CC1,CN1CCC(n2c(=O)n(C)c3cnc4ccc(-c5ccc(N(C)C)nc5)nc4c32)CC1 65 | COc1cc(/C=C/C(=O)c2ccc3c(c2)sc(=O)n3CCOc2ccc(CC(OCC(F)(F)F)C(=O)O)cc2)cc(OC)c1,COc1cc(/C=C/c2ccc3c(c2)sc(=O)n3CCOc2ccc(CC(OCC(F)(F)F)C(=O)O)cc2)cc(OC)c1 66 | COC(C)(C)[C@H](O)[C@@H](O)C[C@@H](C)C1=C2C[C@H](O)[C@H]3[C@@]4(C)CCC(=O)C(C)(C)[C@@H]4CC[C@]3(C)[C@@]2(C)CC1,CCCCOC(C)(C)[C@H](O)[C@@H](O)C[C@@H](C)C1=C2C[C@H](O)[C@H]3[C@@]4(C)CCC(=O)C(C)(C)[C@@H]4CC[C@]3(C)[C@@]2(C)CC1 67 | Fc1ccc(C(OC2CC3CCC(C2)N3CCOCc2ccccc2)c2ccc(F)cc2)cc1,Fc1ccc(C(OC2CC3CCC(C2)N3CCCOc2ccccc2)c2ccc(F)cc2)cc1 68 | Cc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,O=C(Nc1cc2ccc(-c3ccncc3Cl)cc2cn1)C1CC1 69 | COc1ccc2c(c1)CCCN(C)CCCc1ccccc1C2,COc1ccc2c(c1)CCN(C)CCc1ccccc1C2 70 | Oc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,Nc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 71 | C[C@H](CCC(=O)O)C1CC[C@H]2[C@H]3[C@H](CC[C@]12C)[C@@]1(C)CC[C@@H](O)CC1[C@@H](C)[C@H]3O,CO[C@@H]1C2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@]3(C)C([C@H](C)CCC(=O)O)CC[C@H]3[C@@H]2[C@@H]1O 72 | Cc1nc(-c2ccc(C(F)(F)F)cc2)sc1C(=O)Nc1ccc(OC(C)(C)C(=O)O)cc1,Cc1nc(-c2ccc(C(F)(F)F)cc2)sc1C(=O)NCCc1ccc(OC(C)(C)C(=O)O)cc1 73 | COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(F)cc2)CC1,COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(N(C)C)cc2)CC1 74 | O=C(O)c1ccc2nc(N3CCC(OCc4c(-c5ccccc5C(F)(F)F)noc4C4CC4)CC3)sc2c1,O=C(O)c1ccc2nc(N3CCCC(OCc4c(-c5ccccc5OC(F)(F)F)noc4C4CC4)CC3)sc2c1 75 | C=C(Cc1ccccc1)CN1CCN(CCOC(c2ccccc2)c2ccccc2)CC1,O=C(Cc1ccccc1)CN1CCN(CCOC(c2ccccc2)c2ccccc2)CC1 76 | N#Cc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 77 | Fc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1,Oc1ccc(C2Nc3ccccc3-c3ccnc4[nH]cc2c34)cc1 78 | Cc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,COc1ccncc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 79 | C[C@H](CCC(=O)O)C1CC[C@H]2[C@H]3[C@H](CC[C@]12C)[C@@]1(C)CC[C@@H](O)CC1[C@@H](CCO)[C@H]3O,C#CC[C@@H]1C2C[C@H](O)CC[C@]2(C)[C@H]2CC[C@]3(C)C([C@H](C)CCC(=O)O)CC[C@H]3[C@@H]2[C@@H]1O 80 | Cc1nc(-c2ccc(N)cc2)sc1C(=O)NCc1ccc(OC(C)(C)C(=O)O)cc1,Cc1nc(-c2ccc(C(F)(F)F)cc2)sc1C(=O)NCc1ccc(OC(C)(C)C(=O)O)cc1 81 | CCO[C@@H](Cc1cccc(/C(C)=N/OCc2ccc(OC)cc2)c1)C(=O)O,CCO[C@@H](Cc1cccc(/C(C)=N/OCc2ccc(O)cc2)c1)C(=O)O 82 | Cc1cnccc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,O=C(Nc1cc2ccc(-c3cnccc3Cl)cc2cn1)C1CC1 83 | COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(F)cc2)CC1,COc1ccccc1N1CCN(CCCCNC(=O)c2ccc(CCF)cc2)CC1 84 | CN(C)c1ccc(-c2ccc3ncc4c(c3n2)n(C2CCOCC2)c(=O)n4C)cn1,CN(C)c1ccc(-c2ccc3ncc4c(c3n2)n(C2CCOCC2)c(=O)n4CCO)cn1 85 | Cc1nc(-c2ccc(N)cc2)sc1C(=O)NCc1ccc(OC(C)(C)C(=O)O)cc1,Cc1nc(-c2ccc(F)cc2)sc1C(=O)NCc1ccc(OC(C)(C)C(=O)O)cc1 86 | CCO[C@@H](Cc1cccc(/C(C)=N/OCc2ccc(O)cc2)c1)C(=O)O,CCO[C@@H](Cc1cccc(/C(C)=N/OCc2ccc(Br)cc2)c1)C(=O)O 87 | CC(=O)c1cccc(NC(=O)NCCC[C@H]2C[C@H](Cc3ccc(F)cc3)CCN2CCO)c1,CC(=O)c1cccc(NC(=O)NCCC[C@H]2C[C@H](Cc3ccc(F)cc3)CCN2C)c1 88 | CCCc1c(OCCCCOc2cccc(C(=O)O)c2)ccc2c(-c3ccccc3)noc12,CCCc1c(OCCCCOc2cccc(OC(C)(C)C(=O)O)c2)ccc2c(-c3ccccc3)noc12 89 | Cc1ccc(OC(=O)N(CC(=O)O)Cc2cccc(OCc3nc(-c4ccc(Cl)cc4)oc3C)c2)cc1,Cc1ccc(OC(=O)N(CC(=O)O)Cc2cccc(OCc3nc(-c4ccc(Cl)cc4)sc3C)c2)cc1 90 | CCCN(CCCCNC(=O)c1ccc(-c2ccccc2)cc1)C1CCn2ncc(Cl)c2C1,CCCN(CCCCNC(=O)c1ccc(-c2ccccc2)cc1)C1CCn2ncc(C=O)c2C1 91 | N#Cc1c(N2CCOCC2)sc(C(=O)O)c1-c1ccc(Cl)cc1,N#Cc1c(N2CCOCC2)sc(C(N)=O)c1-c1ccc(Cl)cc1 92 | Cc1ccc(-c2ccc3ncc4c(c3n2)n(C2CCOCC2)c(=O)n4C)cn1,Cc1ccc(-c2ccc3ncc4c(c3n2)n(C2CCOCC2)c(=O)n4CCO)cn1 93 | 
Cc1ccccc1OCCNCCCOc1ccccc1,Clc1ccccc1OCCNCCCOc1ccccc1 94 | Cc1ccc(O)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccc(Cl)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 95 | Cc1c(Cl)cccc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1c(O)cccc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1 96 | COc1cccc2c1CCN2C(=O)CN1CCN(Cc2ccc(Cl)cc2)CC1,Cc1cccc2c1CCN2C(=O)CN1CCN(Cc2ccc(Cl)cc2)CC1 97 | CC(C)n1nc(-c2ccc3oc(N)nc3c2)c2c(F)nc(N)nc21,CCc1nc(N)nc2c1c(-c1ccc3oc(N)nc3c1)nn2C(C)C 98 | Cc1ccc(O)cc1-c1ccc2cc(NC(=O)C3CC3)ncc2c1,Cc1ccc(C)c(-c2ccc3cc(NC(=O)C4CC4)ncc3c2)c1 99 | CN(Cc1cccc(F)c1)[C@H]1C2C3CC4C5C3CC2C5C41,CN(CCc1cccc(F)c1)[C@H]1C2C3CC4C5C3CC2C5C41 100 | Fc1cc(N2CCNCC2Cc2ccccc2)cc2cc[nH]c12,COc1cc(N2CCNCC2Cc2ccccc2)cc2cc[nH]c12 101 | C=C(Cc1ccccc1)CN1CCN(CCOC(c2ccc(F)cc2)c2ccc(F)cc2)CC1,O=C(Cc1ccccc1)CN1CCN(CCOC(c2ccc(F)cc2)c2ccc(F)cc2)CC1 102 | -------------------------------------------------------------------------------- /utils/molnet_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import numpy as np 4 | 5 | import pandas as pd 6 | from utils.sampl_datasets import load_sampl_iupac 7 | from deepchem.molnet import load_clearance, \ 8 | load_qm7, load_qm8, load_qm9, load_muv 9 | 10 | # load_delaney, load_hiv, load_lipo, load_bbbp, load_tox21 11 | from utils.clintox_datasets import load_clintox 12 | from utils.sider_datasets import load_sider 13 | 14 | from utils.delaney_datasets import load_delaney # esol 15 | from utils.hiv_datasets import load_hiv 16 | from utils.lipo_datasets import load_lipo 17 | from utils.bbbp_datasets import load_bbbp 18 | 19 | from utils.tox21_datasets import load_tox21 20 | from utils.toxcast_datasets import load_toxcast 21 | from utils.bace_datasets import load_bace_classification, load_bace_regression 22 | from utils.malaria_datasets import load_malaria 23 | from utils.cep_datasets import load_cep 24 | 25 | 26 | from rdkit import Chem 27 | 28 | MOLNET_DIRECTORY = { 29 | "bace_classification": { 30 | "dataset_type": "classification", 31 | "load_fn": load_bace_classification, 32 | "split": "scaffold", 33 | }, 34 | "bace_regression": { 35 | "dataset_type": "regression", 36 | "load_fn": load_bace_regression, 37 | "split": "scaffold", 38 | }, 39 | "bbbp": { 40 | "dataset_type": "classification", 41 | "load_fn": load_bbbp, 42 | "split": "scaffold", 43 | }, 44 | "clearance": { 45 | "dataset_type": "regression", 46 | "load_fn": load_clearance, 47 | "split": "scaffold", 48 | }, 49 | "clintox": { 50 | "dataset_type": "classification", 51 | "load_fn": load_clintox, 52 | "split": "scaffold", 53 | # "tasks_wanted": ["CT_TOX"], 54 | }, 55 | "delaney": { 56 | "dataset_type": "regression", 57 | "load_fn": load_delaney, 58 | "split": "scaffold", 59 | }, 60 | "hiv": { 61 | "dataset_type": "classification", 62 | "load_fn": load_hiv, 63 | "split": "scaffold", 64 | }, 65 | # pcba is very large and breaks the dataloader 66 | # "pcba": { 67 | # "dataset_type": "classification", 68 | # "load_fn": load_pcba, 69 | # "split": "scaffold", 70 | # }, 71 | "lipo": { 72 | "dataset_type": "regression", 73 | "load_fn": load_lipo, 74 | "split": "scaffold", 75 | }, 76 | "qm7": { 77 | "dataset_type": "regression", 78 | "load_fn": load_qm7, 79 | "split": "random", 80 | }, 81 | "qm8": { 82 | "dataset_type": "regression", 83 | "load_fn": load_qm8, 84 | "split": "random", 85 | }, 86 | "qm9": { 87 | "dataset_type": "regression", 88 | "load_fn": load_qm9, 89 | "split": "random", 90 | }, 91 | "sider": { 92 | "dataset_type": "classification", 93 | "load_fn": load_sider, 94 | 
"split": "scaffold", 95 | }, 96 | "tox21": { 97 | "dataset_type": "classification", 98 | "load_fn": load_tox21, 99 | "split": "scaffold", 100 | # "tasks_wanted": ["SR-p53"], 101 | }, 102 | "toxcast": { 103 | "dataset_type": "classification", 104 | "load_fn": load_toxcast, 105 | "split": "scaffold", 106 | # "tasks_wanted": ["SR-p53"], 107 | }, 108 | "muv": { 109 | "dataset_type": "classification", 110 | "load_fn": load_muv, 111 | "split": "scaffold", 112 | # "tasks_wanted": ["SR-p53"], 113 | }, 114 | "sampl": { 115 | "dataset_type": "regression", 116 | "load_fn": load_sampl_iupac, 117 | "split": "scaffold", 118 | }, 119 | "malaria": { 120 | "dataset_type": "regression", 121 | "load_fn": load_malaria, 122 | "split": "scaffold", 123 | }, 124 | "cep": { 125 | "dataset_type": "regression", 126 | "load_fn": load_cep, 127 | "split": "scaffold", 128 | } 129 | 130 | } 131 | 132 | 133 | def get_dataset_info(name: str): 134 | return MOLNET_DIRECTORY[name] 135 | 136 | 137 | def load_molnet_dataset( 138 | name: str, 139 | split: str = None, 140 | tasks_wanted: List = None, 141 | df_format: str = "chemberta", 142 | seed: int = 0, 143 | ): 144 | """Loads a MolNet dataset into a DataFrame ready for either chemberta or chemprop. 145 | 146 | Args: 147 | name: Name of MolNet dataset (e.g., "bbbp", "tox21"). 148 | split: Split name. Defaults to the split specified in MOLNET_DIRECTORY. 149 | tasks_wanted: List of tasks from dataset. Defaults to `tasks_wanted` in MOLNET_DIRECTORY, if specified, or else all available tasks. 150 | df_format: `chemberta` or `chemprop` 151 | 152 | Returns: 153 | tasks_wanted, (train_df, valid_df, test_df), transformers 154 | 155 | """ 156 | load_fn = MOLNET_DIRECTORY[name]["load_fn"] 157 | tasks, splits, transformers = load_fn( 158 | featurizer="Raw", splitter=split or MOLNET_DIRECTORY[name]["split"], seed=seed, 159 | ) 160 | 161 | # Default to all available tasks 162 | if tasks_wanted is None: 163 | tasks_wanted = MOLNET_DIRECTORY[name].get("tasks_wanted", tasks) 164 | # tasks_wanted = ['Nervous system disorders'] 165 | print(f"Using tasks {tasks_wanted} from available tasks for {name}: {tasks}") 166 | 167 | # tasks_wanted = tasks.copy() 168 | # # for sp in splits: 169 | # sp = splits[2] # test set 170 | # for ti, task in enumerate(tasks): 171 | # all_valid_label = sp.y[:, ti][sp.w[:, ti]!=0] 172 | # if (np.unique(all_valid_label)).size <= 1: 173 | # tasks_wanted.remove(task) 174 | 175 | 176 | 177 | return ( 178 | tasks_wanted, 179 | [ 180 | make_dataframe( 181 | s, 182 | MOLNET_DIRECTORY[name]["dataset_type"], 183 | tasks, 184 | tasks_wanted, 185 | df_format, 186 | ) 187 | for s in splits 188 | ], 189 | transformers, 190 | ) 191 | 192 | 193 | def write_molnet_dataset_for_chemprop( 194 | name: str, split: str = None, tasks_wanted: List = None, data_dir: str = None 195 | ): 196 | """Writes a MolNet dataset to separate train, val, test CSVs ready for chemprop. 197 | 198 | Args: 199 | name: Name of MolNet dataset (e.g., "bbbp", "tox21"). 200 | split: Split name. Defaults to the split specified in MOLNET_DIRECTORY. 201 | tasks_wanted: List of tasks from dataset. Defaults to all available tasks. 202 | data_dir: Location to write CSV files. Defaults to /tmp/molnet/{name}/. 
203 | 204 | Returns: 205 | tasks_wanted, (train_df, valid_df, test_df), transformers, out_paths 206 | 207 | """ 208 | if data_dir is None: 209 | data_dir = os.path.join("/tmp/molnet/", name) 210 | os.makedirs(data_dir, exist_ok=True) 211 | 212 | tasks, dataframes, transformers = load_molnet_dataset( 213 | name, split=split, tasks_wanted=tasks_wanted, df_format="chemprop" 214 | ) 215 | 216 | out_paths = [] 217 | for split_name, df in zip(["train", "val", "test"], dataframes): 218 | path = os.path.join(data_dir, f"{split_name}.csv") 219 | out_paths.append(path) 220 | df.to_csv(path, index=False) 221 | 222 | return tasks, dataframes, transformers, out_paths 223 | 224 | 225 | 226 | def to_dataframe(data_dict) -> pd.DataFrame: 227 | """Construct a pandas DataFrame containing the data from this Dataset. 228 | 229 | Returns 230 | ------- 231 | pd.DataFrame 232 | Pandas dataframe. If there is only a single feature per datapoint, 233 | will have column "X" else will have columns "X1,X2,..." for 234 | features. If there is only a single label per datapoint, will 235 | have column "y" else will have columns "y1,y2,..." for labels. If 236 | there is only a single weight per datapoint will have column "w" 237 | else will have columns "w1,w2,...". Will have column "ids" for 238 | identifiers. 239 | """ 240 | X = data_dict["X"] 241 | y = data_dict["y"] 242 | w = data_dict["w"] 243 | if len(X.shape) == 1 or X.shape[1] == 1: 244 | columns = ['X'] 245 | else: 246 | columns = [f'X{i+1}' for i in range(X.shape[1])] 247 | X_df = pd.DataFrame(X, columns=columns) 248 | if len(y.shape) == 1 or y.shape[1] == 1: 249 | columns = ['y'] 250 | else: 251 | columns = [f'y{i+1}' for i in range(y.shape[1])] 252 | y_df = pd.DataFrame(y, columns=columns) 253 | if len(w.shape) == 1 or w.shape[1] == 1: 254 | columns = ['w'] 255 | else: 256 | columns = [f'w{i+1}' for i in range(w.shape[1])] 257 | w_df = pd.DataFrame(w, columns=columns) 258 | 259 | ids = data_dict["smile_ids"] 260 | smiles_ids_df = pd.DataFrame(ids, columns=['smile_ids']) 261 | if 'iupac_ids' in data_dict: 262 | iupac_ids = data_dict["iupac_ids"] 263 | iupac_ids_df = pd.DataFrame(iupac_ids, columns=['iupac_ids']) 264 | return pd.concat([X_df, y_df, w_df, smiles_ids_df, iupac_ids_df], axis=1, sort=False) 265 | else: 266 | return pd.concat([X_df, y_df, w_df, smiles_ids_df], axis=1, sort=False) 267 | 268 | 269 | 270 | 271 | 272 | 273 | def make_dataframe( 274 | dataset, dataset_type, tasks, tasks_wanted, df_format: str = "chemberta" 275 | ): 276 | iupac = False 277 | if len(dataset.ids.shape) == 2 and dataset.ids.shape[1] == 2: # contains iupac 278 | iupac = True 279 | data_dict = {} 280 | data_dict["iupac_ids"] = dataset.ids[:,1] 281 | data_dict["smile_ids"] = dataset.ids[:,0] 282 | data_dict["y"] = dataset.y 283 | data_dict["w"] = dataset.w 284 | data_dict["X"] = dataset.X 285 | # no need for X; rdkit Chem rdchem Mol object 286 | # df = pd.DataFrame(data_dict) 287 | df = to_dataframe(data_dict) 288 | elif 'MUV' in tasks[0]: # muv dataset 289 | iupac = True 290 | data_dict = {} 291 | data_dict["smile_ids"] = dataset.ids 292 | data_dict["y"] = dataset.y 293 | data_dict["w"] = dataset.w 294 | data_dict["X"] = dataset.X 295 | df = to_dataframe(data_dict) 296 | else: 297 | df = dataset.to_dataframe() 298 | 299 | if len(tasks) == 1: 300 | mapper = {"y": tasks[0]} 301 | else: 302 | 303 | tasks_wanted_index_dict = {} 304 | for task in tasks_wanted: 305 | tasks_wanted_index_dict[task] = tasks.index(task) 306 | mapper = {f"y{y_i+1}": task for task, y_i in 
tasks_wanted_index_dict.items()} 307 | df.rename(mapper, axis="columns", inplace=True) 308 | 309 | 310 | # Canonicalize SMILES 311 | # smiles_list = [Chem.MolToSmiles(s, isomericSmiles=True) for s in df["X"]] 312 | smiles_list = [Chem.MolToSmiles(s) for s in df["X"]] 313 | # smiles_list = [Chem.MolToSmiles(s, isomericSmiles=False, canonical=True) for s in df["X"]] 314 | 315 | # smiles_list_2 = [Chem.MolToSmiles(s, isomericSmiles=False, canonical=True) for s in df["X"]] 316 | # df['smile_ids'] = smiles_list_2 317 | 318 | # df['smile_ids'] = smiles_list 319 | # Convert labels to integer for classification 320 | labels = df[tasks_wanted] 321 | if dataset_type == "classification": 322 | labels = labels.astype(int) 323 | 324 | elif dataset_type == "regression": 325 | labels = labels.astype(float) 326 | 327 | if iupac: 328 | return df 329 | 330 | if df_format == "chemberta": 331 | if len(tasks_wanted) == 1: 332 | labels = labels.values.flatten() 333 | else: 334 | # Convert labels to list for simpletransformers multi-label 335 | labels = labels.values.tolist() 336 | return pd.DataFrame({"text": smiles_list, "labels": labels}) 337 | elif df_format == "chemprop": 338 | df_out = pd.DataFrame({"smiles": smiles_list}) 339 | for task in tasks_wanted: 340 | df_out[task] = labels[task] 341 | return df_out 342 | else: 343 | raise ValueError(df_format) -------------------------------------------------------------------------------- /utils/ddi_dataset.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import os 3 | import torch 4 | from rdkit.Chem.PandasTools import LoadSDF 5 | from pe_2d.utils_pe_seq import InputExample, convert_examples_seq_to_features 6 | import numpy as np 7 | from utils.mol import smiles2graph 8 | 9 | 10 | 11 | def convert_to_smiles_seq_examples(smiles_ids): 12 | input_examples = [] 13 | for smiles_id in smiles_ids: 14 | input_examples.append(InputExample( 15 | seq=smiles_id, 16 | )) 17 | return input_examples 18 | 19 | 20 | class FinetuneDDIDataset(torch.utils.data.Dataset): 21 | def __init__(self, data_folder, pair_ids, tokenizer, smiles_dict, labels, data_suffix=".sdf", get_labels=True): 22 | # collect smiles 23 | self.smiles_lst = [] 24 | self.pair_labels = [] 25 | for sdf_ids in pair_ids: 26 | smiles_pair = [] 27 | for sdf_id in sdf_ids: 28 | smiles_pair.append(smiles_dict[sdf_id]) 29 | # smiles_pair.append(smiles) 30 | self.smiles_lst.append(smiles_pair) 31 | 32 | 33 | 34 | # self.pair_labels = None 35 | # self.labels = None 36 | # if labels is not None: 37 | 38 | self.get_labels = get_labels 39 | 40 | self.pair_labels = labels 41 | assert len(self.pair_labels) == len(self.smiles_lst) 42 | self.labels = np.array(self.pair_labels) 43 | 44 | 45 | np_smiles_lst = np.array(self.smiles_lst) 46 | left_smiles_lst = np_smiles_lst[:, 0].tolist() 47 | right_smiles_lst = np_smiles_lst[:, 1].tolist() 48 | 49 | 50 | 51 | # tokenizer 52 | left_input_examples = convert_to_smiles_seq_examples(left_smiles_lst) 53 | right_input_examples = convert_to_smiles_seq_examples(right_smiles_lst) 54 | self.left_encodings = convert_examples_seq_to_features(left_input_examples, max_seq_length=128,tokenizer=tokenizer) 55 | self.right_encodings = convert_examples_seq_to_features(right_input_examples, max_seq_length=128,tokenizer=tokenizer) 56 | 57 | 58 | def __len__(self): 59 | return len(self.smiles_lst) 60 | 61 | 62 | def __getitem__(self, idx): 63 | item = {} 64 | 65 | # no need to concat 66 | 67 | 
item['left_input_ids']=self.left_encodings[idx].input_ids 68 | item['left_attention_mask']=self.left_encodings[idx].attention_mask 69 | 70 | item['right_input_ids'] = self.right_encodings[idx].input_ids 71 | item['right_attention_mask'] = self.right_encodings[idx].attention_mask 72 | 73 | smiles_pair = self.smiles_lst[idx] 74 | 75 | left_graph, _ = smiles2graph(smiles_pair[0]) 76 | right_graph, _ = smiles2graph(smiles_pair[1]) 77 | 78 | 79 | # graph 80 | item['left_graph'] = left_graph 81 | item['right_graph'] = right_graph 82 | 83 | 84 | # label 85 | # if self.pair_labels is not None: 86 | if self.get_labels: 87 | item["labels"] = torch.tensor(self.pair_labels[idx], dtype=torch.float) 88 | 89 | return item 90 | 91 | # def __init__(self, df, tokenizer, include_labels=True, use_struct_pos=True, tasks_wanted=None, iupac_only=False, lang_only=False, gnn_only=False, iupac_smiles_concat=False, graph_uni=False, use_rdkit_feature=False): 92 | # #df = df[(df['iupac_ids'] != 'Request error') & (df['iupac_ids'] != '')] # assume df has iupac field 93 | # df = df[(df['iupac_ids'] != 'Request error') & (df['iupac_ids'] != '') & (pd.isna(df['iupac_ids']) == False)] # assume df has iupac field 94 | # self.iupac_only = iupac_only 95 | # self.lang_only = lang_only 96 | # self.gnn_only = gnn_only 97 | 98 | # self.use_struct_pos = use_struct_pos 99 | 100 | # # concat smiles and iupac as one sequence 101 | # self.iupac_smiles_concat = iupac_smiles_concat 102 | # self.graph_uni = graph_uni 103 | # # self.iupac_ids = False 104 | # # if 'iupac_ids' in df.keys(): 105 | 106 | # # filter 107 | 108 | # self.fp_features = None 109 | # if use_rdkit_feature: 110 | # self.fp_features = [] 111 | # pool = multiprocessing.Pool(24) 112 | # smiles_lst = df["smile_ids"].tolist() 113 | # total = len(smiles_lst) 114 | 115 | # for res in tqdm(pool.imap(rdkit_2d_features_normalized_generator, smiles_lst, chunksize=10), total=total): 116 | # replace_token = 0 117 | # fp_feature = np.where(np.isnan(res), replace_token, res) 118 | # self.fp_features.append(np.float32(fp_feature)) 119 | 120 | 121 | # # for smile in df["smile_ids"].tolist(): 122 | # # fp_feature = rdkit_2d_features_normalized_generator(smile) 123 | # # replace_token = 0 124 | # # fp_feature = np.where(np.isnan(fp_feature), replace_token, fp_feature) 125 | # # self.fp_features.append(np.float32(fp_feature)) 126 | 127 | 128 | # if self.lang_only: 129 | # self.smiles_lst = df["smile_ids"].tolist() # for uni 130 | # if self.iupac_smiles_concat: 131 | # input_examples = convert_to_iupac_seq_examples(df["iupac_ids"].tolist()) 132 | # self.iupac_features, self.iupac_length = convert_examples_seq_to_features_wlen(input_examples, max_seq_length=128,tokenizer=tokenizer[0]) 133 | # input_examples = convert_to_smiles_seq_examples(df["smile_ids"].tolist()) 134 | # self.smiles_features, self.smiles_length = convert_examples_seq_to_features_wlen(input_examples, max_seq_length=128,tokenizer=tokenizer[1]) 135 | 136 | # self.iupac_start_emb = tokenizer[1].vocab_size # smiles token size of the begining 137 | 138 | # self.labels = df.iloc[:, 2].values 139 | 140 | # elif self.iupac_only: 141 | # if use_struct_pos: 142 | # input_examples = convert_to_input_examples(df["iupac_ids"].tolist()) 143 | # self.features = convert_examples_with_strucpos_to_features(input_examples, max_seq_length=128, 144 | # max_pos_depth=16, tokenizer=tokenizer[0]) 145 | # else: 146 | # input_examples = convert_to_iupac_seq_examples(df["iupac_ids"].tolist()) 147 | # self.features = 
convert_examples_seq_to_features(input_examples, max_seq_length=128,tokenizer=tokenizer[0]) 148 | # # self.encodings = tokenizer(df["smiles"].tolist(), truncation=True, padding=True) 149 | # self.iupac_ids = True 150 | # self.labels = df.iloc[:, 2].values 151 | # else: 152 | # input_examples = convert_to_smiles_seq_examples(df["smile_ids"].tolist()) 153 | # self.encodings = convert_examples_seq_to_features(input_examples, max_seq_length=128,tokenizer=tokenizer[1]) 154 | # self.labels = df.iloc[:, 1].values 155 | # elif self.gnn_only: 156 | # # save smiles_ids 157 | # self.smiles_lst = df["smile_ids"].tolist() 158 | # else: 159 | # raise NotImplementedError 160 | 161 | 162 | # if tasks_wanted is not None: 163 | # if len(tasks_wanted) == 1: 164 | # self.labels = df[tasks_wanted[0]].values 165 | # else: 166 | # labels = [] 167 | # for task in tasks_wanted: 168 | # labels.append(df[task].values.reshape(-1, 1)) 169 | # self.labels = np.concatenate(labels, axis=1) 170 | 171 | # task_weight_cols = list(df.columns[1:-2]) # X ,..., smiles_ids, iupac_ids 172 | # self.task_weights = None 173 | # if 'w' in task_weight_cols or 'w1' in task_weight_cols: 174 | # # import pdb; pdb.set_trace() 175 | # task_weights = [] 176 | # for task in tasks_wanted: 177 | # task_idx = task_weight_cols.index(task) 178 | # task_weight_col_name = 'w' + str(task_idx + 1) 179 | # if len(task_weight_cols) == 2: 180 | # task_weight_col_name = task_weight_cols[1] 181 | # task_weights.append(df[task_weight_col_name].tolist()) 182 | 183 | # task_weights = np.array(task_weights, dtype=np.float32) 184 | 185 | # self.task_weights = task_weights.T 186 | # self.include_labels = include_labels 187 | 188 | 189 | # def _concat_lang(self, smiles_inputs, iupac_inputs, smile_len, iupac_len, concat_max_length=256): 190 | # return_dict = {} 191 | # # import pdb;pdb.set_trace() 192 | # return_dict['input_ids'] = torch.full((1, concat_max_length), 1)[0] # padding: 1 193 | # return_dict['attention_mask'] = torch.full((1, concat_max_length), 0)[0] # attention mask, default: 0 194 | 195 | 196 | # return_dict['input_ids'][:smile_len] = torch.tensor(smiles_inputs.input_ids[:smile_len]) 197 | 198 | # iupac_input_ids = torch.tensor(iupac_inputs.input_ids[1:iupac_len]) # earse the iupac cls token 199 | # iupac_input_ids[:-1] += self.iupac_start_emb # except the sep token 200 | # return_dict['input_ids'][smile_len: smile_len + iupac_len - 1] = iupac_input_ids 201 | 202 | # return_dict['attention_mask'][:smile_len] = torch.tensor(smiles_inputs.attention_mask[:smile_len]) 203 | # return_dict['attention_mask'][smile_len: smile_len + iupac_len - 1] = torch.tensor(iupac_inputs.attention_mask[1:iupac_len]) # erase the iupac cls token 204 | # return return_dict 205 | 206 | 207 | # def __getitem__(self, idx): 208 | # # get smiles and transfer to graph 209 | # if self.gnn_only: 210 | # item = {} 211 | # smiles = self.smiles_lst[idx] 212 | # graph, _ = smiles2graph(smiles) 213 | # item['graph'] = graph 214 | # elif self.lang_only: 215 | # if self.iupac_smiles_concat: 216 | # item = {} 217 | # # concat smiles and iupac features 218 | # item = self._concat_lang(self.smiles_features[idx], self.iupac_features[idx], \ 219 | # self.smiles_length[idx], self.iupac_length[idx], FLAGS.max_concat_len) 220 | # elif self.iupac_only: 221 | # if self.use_struct_pos: 222 | # item = {} 223 | # item['input_ids']=self.features[idx].input_ids 224 | # item['attention_mask']=self.features[idx].attention_mask 225 | # item['strucpos_ids']=self.features[idx].strucpos_ids 226 | # else: 
227 | # #item = {key: torch.tensor(val[idx]) for key, val in self.features.items()} 228 | # item = {} 229 | # item['input_ids']=self.features[idx].input_ids 230 | # item['attention_mask']=self.features[idx].attention_mask 231 | # else: 232 | # item = {} 233 | # item['input_ids']=self.encodings[idx].input_ids 234 | # item['attention_mask']=self.encodings[idx].attention_mask 235 | 236 | # if self.graph_uni: 237 | # smiles = self.smiles_lst[idx] 238 | # graph, _ = smiles2graph(smiles) 239 | # item['graph'] = graph # add graph 240 | # else: 241 | # raise NotImplementedError 242 | 243 | # if self.include_labels and self.labels is not None: 244 | # item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float) 245 | # if self.task_weights is not None: 246 | # item['weight'] = torch.tensor(self.task_weights[idx], dtype=torch.float) 247 | 248 | # if self.fp_features is not None: 249 | # item['fp_feature'] = self.fp_features[idx] 250 | 251 | # return item 252 | 253 | # def __len__(self): 254 | # if self.gnn_only: 255 | # return len(self.smiles_lst) 256 | # if self.iupac_smiles_concat: 257 | # return len(self.iupac_features) 258 | 259 | # if self.iupac_only: 260 | # return len(self.features)#["input_ids"]) 261 | # return len(self.encodings)#["input_ids"]) 262 | -------------------------------------------------------------------------------- /gcn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.modules.sparse import Embedding 5 | from torch_geometric.nn import MessagePassing 6 | from torch_scatter import scatter 7 | from torch import nn, Tensor 8 | from utils.features import get_atom_feature_dims, get_bond_feature_dims 9 | # from fairseq import utils 10 | from torch_geometric.nn import global_max_pool, global_mean_pool, global_sort_pool 11 | 12 | from torch_geometric.utils import add_self_loops 13 | 14 | 15 | class GINConv(MessagePassing): 16 | """ 17 | Extension of GIN aggregation to incorporate edge information by concatenation. 18 | 19 | Args: 20 | emb_dim (int): dimensionality of embeddings for nodes and edges. 21 | embed_input (bool): whether to embed input or not. 22 | 23 | 24 | See https://arxiv.org/abs/1810.00826 25 | """ 26 | def __init__(self, emb_dim, out_dim, num_bond_type, num_bond_direction, aggr = "add", **kwargs): 27 | kwargs.setdefault('aggr', aggr) 28 | self.aggr = aggr 29 | super(GINConv, self).__init__(**kwargs) 30 | #multi-layer perceptron 31 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, out_dim)) 32 | self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim) 33 | self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim) 34 | 35 | torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data) 36 | torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data) 37 | 38 | def forward(self, x, edge_index, edge_attr): 39 | #add self loops in the edge space 40 | edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0)) 41 | 42 | #add features corresponding to self-loop edges. 
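        # add_self_loops above only extends edge_index, so edge_attr has to be padded to
        # match: every self-loop edge built below gets bond-type index 4 (presumably a
        # dedicated 'self-loop' bond category) and bond-direction index 0 before the two
        # edge-embedding lookups.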
43 | self_loop_attr = torch.zeros(x.size(0), 2) 44 | self_loop_attr[:,0] = 4 #bond type for self-loop edge 45 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) 46 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0) 47 | 48 | edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1]) 49 | 50 | # return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings) 51 | return self.propagate(edge_index, x=x, edge_attr=edge_embeddings) 52 | 53 | def message(self, x_j, edge_attr): 54 | return x_j + edge_attr 55 | 56 | def update(self, aggr_out): 57 | return self.mlp(aggr_out) 58 | 59 | 60 | 61 | class CustomMessagePassing(MessagePassing): 62 | def __init__(self, aggr: Optional[str] = "maxminmean", embed_dim: Optional[int] = None): 63 | if aggr in ['maxminmean']: 64 | super().__init__(aggr=None) 65 | self.aggr = aggr 66 | assert embed_dim is not None 67 | self.aggrmlp = nn.Linear(3 * embed_dim, embed_dim) 68 | else: 69 | super().__init__(aggr=aggr) 70 | 71 | def aggregate(self, inputs: Tensor, index: Tensor, ptr: Optional[Tensor], 72 | dim_size: Optional[int]) -> Tensor: 73 | if self.aggr in ['maxminmean']: 74 | inputs_fp32 = inputs.float() 75 | input_max = scatter(inputs_fp32, 76 | index, 77 | dim=self.node_dim, 78 | dim_size=dim_size, 79 | reduce='max') 80 | input_min = scatter(inputs_fp32, 81 | index, 82 | dim=self.node_dim, 83 | dim_size=dim_size, 84 | reduce='min') 85 | input_mean = scatter(inputs_fp32, 86 | index, 87 | dim=self.node_dim, 88 | dim_size=dim_size, 89 | reduce='mean') 90 | aggr_out = torch.cat([input_max, input_min, input_mean], dim=-1).type_as(inputs) 91 | aggr_out = self.aggrmlp(aggr_out) 92 | return aggr_out 93 | else: 94 | return super().aggregate(inputs, index, ptr, dim_size) 95 | 96 | 97 | class MulOnehotEncoder(nn.Module): 98 | def __init__(self, embed_dim, get_feature_dims: Callable): 99 | super().__init__() 100 | self.atom_embedding_list = nn.ModuleList() 101 | 102 | for dim in get_feature_dims(): 103 | emb = nn.Embedding(dim, embed_dim) 104 | nn.init.xavier_uniform_(emb.weight.data) 105 | self.atom_embedding_list.append(emb) 106 | 107 | def forward(self, x): 108 | x_embedding = 0 109 | for i in range(x.shape[1]): 110 | x_embedding = x_embedding + self.atom_embedding_list[i](x[:, i]) 111 | return x_embedding 112 | 113 | 114 | 115 | class ResidualGINLayer(CustomMessagePassing): 116 | def __init__(self, 117 | in_dim, 118 | emb_dim, 119 | aggr='add', 120 | encode_edge=False, 121 | bond_encoder=False, 122 | edge_feat_dim=None): 123 | super().__init__(aggr, embed_dim=in_dim) 124 | # self.mlp = nn.Linear(in_dim, emb_dim) 125 | self.mlp = torch.nn.Sequential(torch.nn.Linear(in_dim, 2*in_dim), torch.nn.GELU(), torch.nn.Linear(2*in_dim, emb_dim)) 126 | self.encode_edge = encode_edge 127 | 128 | if encode_edge: 129 | if bond_encoder: 130 | self.edge_encoder = MulOnehotEncoder(in_dim, get_bond_feature_dims) 131 | else: 132 | self.edge_encoder = nn.Linear(edge_feat_dim, in_dim) 133 | 134 | def forward(self, x, edge_index, edge_attr=None): 135 | if self.encode_edge and edge_attr is not None: 136 | edge_emb = self.edge_encoder(edge_attr) 137 | else: 138 | edge_emb = None 139 | m = self.propagate(edge_index, x=x, edge_attr=edge_emb) 140 | h = x + m 141 | out = self.mlp(h) 142 | return out 143 | 144 | def message(self, x_j, edge_attr=None): 145 | if edge_attr is not None: 146 | msg = x_j + edge_attr 147 | else: 148 | msg = x_j 149 | 150 | return msg 151 | 152 | 153 | class 
ResidualConvLayer(CustomMessagePassing): 154 | def __init__(self, 155 | in_dim, 156 | emb_dim, 157 | aggr, 158 | encode_edge=False, 159 | bond_encoder=False, 160 | edge_feat_dim=None): 161 | super().__init__(aggr, embed_dim=in_dim) 162 | self.mlp = nn.Linear(in_dim, emb_dim) 163 | self.encode_edge = encode_edge 164 | 165 | if encode_edge: 166 | if bond_encoder: 167 | self.edge_encoder = MulOnehotEncoder(in_dim, get_bond_feature_dims) 168 | else: 169 | self.edge_encoder = nn.Linear(edge_feat_dim, in_dim) 170 | 171 | def forward(self, x, edge_index, edge_attr=None): 172 | if self.encode_edge and edge_attr is not None: 173 | edge_emb = self.edge_encoder(edge_attr) 174 | else: 175 | edge_emb = None 176 | m = self.propagate(edge_index, x=x, edge_attr=edge_emb) 177 | h = x + m 178 | out = self.mlp(h) 179 | return out 180 | 181 | def message(self, x_j, edge_attr=None): 182 | if edge_attr is not None: 183 | msg = x_j + edge_attr 184 | else: 185 | msg = x_j 186 | 187 | return msg 188 | 189 | 190 | def get_norm_layer(norm, fea_dim): 191 | norm = norm.lower() 192 | 193 | if norm == 'layer': 194 | return nn.LayerNorm(fea_dim) 195 | elif norm == "batch": 196 | return nn.BatchNorm1d(fea_dim) 197 | else: 198 | raise NotImplementedError() 199 | 200 | 201 | class AtomHead(nn.Module): 202 | def __init__(self, emb_dim, output_dim, activation_fn, weight=None, norm=None): 203 | super().__init__() 204 | self.dense = nn.Linear(emb_dim, emb_dim) 205 | # self.activation_fn = utils.get_activation_fn(activation_fn) 206 | if activation_fn == 'gelu': 207 | self.activation_fn = nn.GELU() 208 | else: 209 | self.activation_fn = nn.ReLU() 210 | self.norm = get_norm_layer(norm, emb_dim) 211 | 212 | if weight is None: 213 | weight = nn.Linear(emb_dim, output_dim, bias=False).weight 214 | self.weight = weight 215 | self.bias = nn.Parameter(torch.zeros(output_dim)) 216 | 217 | def forward(self, node_features, cls_features, masked_atom=None): 218 | if cls_features is not None: 219 | node_features = torch.cat((node_features, cls_features), 1) 220 | 221 | if masked_atom is not None: 222 | node_features = node_features[masked_atom, :] 223 | 224 | x = self.dense(node_features) 225 | x = self.activation_fn(x) 226 | x = self.norm(x) 227 | x = F.linear(x, self.weight) + self.bias 228 | return x 229 | 230 | 231 | class DeeperGCN(nn.Module): 232 | def __init__(self, args): 233 | super().__init__() 234 | self.num_layers = args['gnn_number_layer'] 235 | self.dropout = args['gnn_dropout'] 236 | self.conv_encode_edge = args['conv_encode_edge'] 237 | self.embed_dim = args['gnn_embed_dim'] 238 | self.aggr = args['gnn_aggr'] 239 | self.norm = args['gnn_norm'] 240 | self.act = args['gnn_act'] 241 | print("++++++++++use norm type: {}++++++++++++++++++++".format(self.norm)) 242 | 243 | self.gcns = nn.ModuleList() 244 | self.norms = nn.ModuleList() 245 | # self.activation_fn = utils.get_activation_fn(getattr(args, 'gnn_activation_fn', 'relu')) 246 | # self.activation_fn = nn.ReLU() 247 | if self.act == 'relu': 248 | self.activation_fn = nn.ReLU() 249 | else: 250 | self.activation_fn = nn.GELU() 251 | 252 | for layer in range(self.num_layers): 253 | self.gcns.append( 254 | ResidualConvLayer( 255 | self.embed_dim, 256 | self.embed_dim, 257 | self.aggr, 258 | encode_edge=self.conv_encode_edge, 259 | bond_encoder=True, 260 | )) 261 | self.norms.append(get_norm_layer(self.norm, self.embed_dim)) 262 | 263 | self.atom_encoder = MulOnehotEncoder(self.embed_dim, get_atom_feature_dims) 264 | if not self.conv_encode_edge: 265 | self.bond_encoder = 
MulOnehotEncoder(self.embed_dim, get_bond_feature_dims) 266 | 267 | self.graph_pred_linear = nn.Identity() 268 | self.output_features = 2 * self.embed_dim 269 | # self.atom_head = AtomHead(self.embed_dim, 270 | # get_atom_feature_dims()[0], 271 | # getattr(args, 'gnn_activation_fn', 'relu'), 272 | # norm=self.norm, 273 | # weight=self.atom_encoder.atom_embedding_list[0].weight) 274 | 275 | def forward(self, graph, masked_tokens=None, features_only=False): 276 | x = graph.x 277 | edge_index = graph.edge_index 278 | edge_attr = graph.edge_attr 279 | batch = graph.batch 280 | 281 | h = self.atom_encoder(x) 282 | 283 | if self.conv_encode_edge: 284 | edge_emb = edge_attr 285 | else: 286 | edge_emb = self.bond_encoder(edge_attr) 287 | 288 | h = self.gcns[0](h, edge_index, edge_emb) 289 | 290 | for layer in range(1, self.num_layers): 291 | residual = h 292 | h = self.norms[layer](h) 293 | h = self.activation_fn(h) 294 | h = F.dropout(h, p=self.dropout, training=self.training) 295 | h = self.gcns[layer](h, edge_index, edge_emb) 296 | h = h + residual 297 | h = self.norms[0](h) 298 | h = self.activation_fn(h) 299 | node_fea = F.dropout(h, p=self.dropout, training=self.training) 300 | 301 | graph_fea = self.pool(node_fea, batch) 302 | 303 | # if not features_only: 304 | # atom_pred = self.atom_head(node_fea, masked_tokens) 305 | # else: 306 | # atom_pred = None 307 | return graph_fea, node_fea 308 | # return (graph_fea, node_fea), atom_pred 309 | 310 | def pool(self, h, batch): 311 | h_fp32 = h.float() 312 | h_max = global_max_pool(h_fp32, batch) 313 | h_mean = global_mean_pool(h_fp32, batch) 314 | h = torch.cat([h_max, h_mean], dim=-1).type_as(h) 315 | h = self.graph_pred_linear(h) 316 | return h 317 | -------------------------------------------------------------------------------- /pe_2d/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
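# Note: this module mirrors the Hugging Face `transformers` BERT/RoBERTa configuration
# classes; the repository-specific additions are the extra constructor arguments of
# BertConfig.__init__ below (pos_lambda1, strucpos_lambda2, strucpos_func,
# num_labels_for_fg, num_labels_for_formula, problem_type_for_formula, formula_lambda).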
16 | """ BERT model configuration """ 17 | from collections import OrderedDict 18 | from typing import Mapping 19 | 20 | from transformers.configuration_utils import PretrainedConfig 21 | from transformers.utils import logging 22 | 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json", 28 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json", 29 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json", 30 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json", 31 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json", 32 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json", 33 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json", 34 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json", 35 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json", 36 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json", 40 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json", 41 | "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json", 42 | "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json", 43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json", 44 | "cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json", 45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json", 46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json", 47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json", 48 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json", 49 | # See all BERT models at https://huggingface.co/models?filter=bert 50 | } 51 | 52 | 53 | class BertConfig(PretrainedConfig): 54 | r""" 55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel` or a 56 | :class:`~transformers.TFBertModel`. It is used to instantiate a BERT model according to the specified arguments, 57 | defining the model architecture. 
Instantiating a configuration with the defaults will yield a similar configuration 58 | to that of the BERT `bert-base-uncased `__ architecture. 59 | 60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 61 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 62 | 63 | 64 | Args: 65 | vocab_size (:obj:`int`, `optional`, defaults to 30522): 66 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the 67 | :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or 68 | :class:`~transformers.TFBertModel`. 69 | hidden_size (:obj:`int`, `optional`, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (:obj:`int`, `optional`, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (:obj:`int`, `optional`, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (:obj:`int`, `optional`, defaults to 3072): 76 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 77 | hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): 78 | The non-linear activation function (function or string) in the encoder and pooler. If string, 79 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. 80 | hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 81 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 82 | attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (:obj:`int`, `optional`, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. Typically set this to something large 86 | just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (:obj:`int`, `optional`, defaults to 2): 88 | The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or 89 | :class:`~transformers.TFBertModel`. 90 | initializer_range (:obj:`float`, `optional`, defaults to 0.02): 91 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 92 | layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): 93 | The epsilon used by the layer normalization layers. 94 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): 95 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 96 | position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): 97 | Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, 98 | :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on 99 | :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) 100 | `__. For more information on :obj:`"relative_key_query"`, please refer to 101 | `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) 102 | `__. 103 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 104 | Whether or not the model should return the last key/values attentions (not used by all models). Only 105 | relevant if ``config.is_decoder=True``. 
106 | 107 | Examples:: 108 | 109 | >>> from transformers import BertModel, BertConfig 110 | 111 | >>> # Initializing a BERT bert-base-uncased style configuration 112 | >>> configuration = BertConfig() 113 | 114 | >>> # Initializing a model from the bert-base-uncased style configuration 115 | >>> model = BertModel(configuration) 116 | 117 | >>> # Accessing the model configuration 118 | >>> configuration = model.config 119 | """ 120 | model_type = "bert" 121 | 122 | def __init__( 123 | self, 124 | vocab_size=30522, 125 | hidden_size=768, 126 | num_hidden_layers=12, 127 | num_attention_heads=12, 128 | intermediate_size=3072, 129 | hidden_act="gelu", 130 | hidden_dropout_prob=0.1, 131 | attention_probs_dropout_prob=0.1, 132 | max_position_embeddings=512, 133 | type_vocab_size=2, 134 | initializer_range=0.02, 135 | layer_norm_eps=1e-12, 136 | pad_token_id=0, 137 | gradient_checkpointing=False, 138 | position_embedding_type="absolute", 139 | use_cache=True, 140 | pos_lambda1=1.0, 141 | strucpos_lambda2=1.0, 142 | strucpos_func='linear', 143 | num_labels_for_fg=0, 144 | num_labels_for_formula=0, 145 | problem_type_for_formula='regression', 146 | formula_lambda=0.1, 147 | **kwargs 148 | ): 149 | super().__init__(pad_token_id=pad_token_id, **kwargs) 150 | 151 | self.vocab_size = vocab_size 152 | self.hidden_size = hidden_size 153 | self.num_hidden_layers = num_hidden_layers 154 | self.num_attention_heads = num_attention_heads 155 | self.hidden_act = hidden_act 156 | self.intermediate_size = intermediate_size 157 | self.hidden_dropout_prob = hidden_dropout_prob 158 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 159 | self.max_position_embeddings = max_position_embeddings 160 | self.type_vocab_size = type_vocab_size 161 | self.initializer_range = initializer_range 162 | self.layer_norm_eps = layer_norm_eps 163 | self.gradient_checkpointing = gradient_checkpointing 164 | self.position_embedding_type = position_embedding_type 165 | self.use_cache = use_cache 166 | self.pos_lambda1 = pos_lambda1 167 | self.strucpos_lambda2 = strucpos_lambda2 168 | self.strucpos_func = strucpos_func 169 | self.num_labels_for_fg = num_labels_for_fg 170 | self.num_labels_for_formula = num_labels_for_formula 171 | self.problem_type_for_formula = problem_type_for_formula 172 | self.formula_lambda = formula_lambda 173 | 174 | 175 | 176 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 177 | "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json", 178 | "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json", 179 | "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json", 180 | "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json", 181 | "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json", 182 | "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json", 183 | } 184 | 185 | 186 | class RobertaConfig(BertConfig): 187 | r""" 188 | This is the configuration class to store the configuration of a :class:`~transformers.RobertaModel` or a 189 | :class:`~transformers.TFRobertaModel`. It is used to instantiate a RoBERTa model according to the specified 190 | arguments, defining the model architecture. 191 | 192 | 193 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 194 | outputs. 
Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 195 | 196 | The :class:`~transformers.RobertaConfig` class directly inherits :class:`~transformers.BertConfig`. It reuses the 197 | same defaults. Please check the parent class for more information. 198 | 199 | Examples:: 200 | 201 | >>> from transformers import RobertaConfig, RobertaModel 202 | 203 | >>> # Initializing a RoBERTa configuration 204 | >>> configuration = RobertaConfig() 205 | 206 | >>> # Initializing a model from the configuration 207 | >>> model = RobertaModel(configuration) 208 | 209 | >>> # Accessing the model configuration 210 | >>> configuration = model.config 211 | """ 212 | model_type = "roberta" 213 | 214 | def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): 215 | """Constructs RobertaConfig.""" 216 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) -------------------------------------------------------------------------------- /utils/dti_asset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch import nn 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from rdkit.Chem import AllChem 7 | from torch_geometric.data import InMemoryDataset 8 | from pe_2d.utils_pe_seq import InputExample, convert_examples_seq_to_features 9 | from utils.mol import smiles2graph 10 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel 11 | from utils.multilingual_regression import RobertaFeatureHead, RobertaHead 12 | from gcn import DeeperGCN 13 | from models import Pooler 14 | from utils.raw_text_dataset import collate_tokens 15 | from torch.nn import SmoothL1Loss 16 | # from .molecule_datasets import mol_to_graph_data_obj_simple 17 | 18 | seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ" 19 | seq_dict = {v:(i+1) for i,v in enumerate(seq_voc)} 20 | seq_dict_len = len(seq_dict) 21 | max_seq_len = 1000 22 | 23 | 24 | def seq_cat(prot): 25 | x = np.zeros(max_seq_len) 26 | for i, ch in enumerate(prot[:max_seq_len]): 27 | x[i] = seq_dict[ch] 28 | return x 29 | 30 | 31 | def convert_to_smiles_seq_examples(smiles_ids): 32 | input_examples = [] 33 | for smiles_id in smiles_ids: 34 | input_examples.append(InputExample( 35 | seq=smiles_id, 36 | )) 37 | return input_examples 38 | 39 | 40 | class MoleculeProteinDataset(InMemoryDataset): 41 | def __init__(self, root, dataset, smiles_tokenizer, mode, include_labels=False): 42 | super(InMemoryDataset, self).__init__() 43 | self.root = root 44 | self.dataset = dataset 45 | datapath = os.path.join(self.root, self.dataset, '{}.csv'.format(mode)) 46 | print('datapath\t', datapath) 47 | 48 | self.smiles_tokenizer = smiles_tokenizer 49 | 50 | self.process_molecule() 51 | self.process_protein() 52 | 53 | df = pd.read_csv(datapath) 54 | self.molecule_index_list = df['smiles_id'].tolist() 55 | self.protein_index_list = df['target_id'].tolist() 56 | self.label_list = df['affinity'].tolist() 57 | self.label_list = torch.FloatTensor(self.label_list) 58 | self.labels = self.label_list 59 | self.include_labels = include_labels 60 | 61 | return 62 | 63 | def process_molecule(self): 64 | input_path = os.path.join(self.root, self.dataset, 'smiles.csv') 65 | input_df = pd.read_csv(input_path, sep=',') 66 | self.smiles_list = input_df['smiles'] 67 | input_examples = convert_to_smiles_seq_examples(self.smiles_list) 68 | self.encodings = convert_examples_seq_to_features(input_examples, 
max_seq_length=128,tokenizer=self.smiles_tokenizer) 69 | 70 | 71 | # def process_molecule(self): 72 | # input_path = os.path.join(self.root, self.dataset, 'smiles.csv') 73 | # input_df = pd.read_csv(input_path, sep=',') 74 | # smiles_list = input_df['smiles'] 75 | 76 | # rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 77 | # preprocessed_rdkit_mol_objs_list = [m if m != None else None for m in rdkit_mol_objs_list] 78 | # preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m != None else None for m in preprocessed_rdkit_mol_objs_list] 79 | # assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) 80 | # assert len(smiles_list) == len(preprocessed_smiles_list) 81 | 82 | # smiles_list, rdkit_mol_objs = preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list 83 | 84 | # data_list = [] 85 | # for i in range(len(smiles_list)): 86 | # rdkit_mol = rdkit_mol_objs[i] 87 | # if rdkit_mol != None: 88 | # data = mol_to_graph_data_obj_simple(rdkit_mol) 89 | # data.id = torch.tensor([i]) 90 | # data_list.append(data) 91 | 92 | # self.molecule_list = data_list 93 | # return 94 | 95 | def process_protein(self): 96 | datapath = os.path.join(self.root, self.dataset, 'protein.csv') 97 | 98 | input_df = pd.read_csv(datapath, sep=',') 99 | protein_list = input_df['protein'].tolist() 100 | 101 | self.protein_list = [seq_cat(t) for t in protein_list] 102 | self.protein_list = torch.LongTensor(self.protein_list) 103 | return 104 | 105 | def __getitem__(self, idx): 106 | item = {} 107 | # molecule = self.molecule_list[self.molecule_index_list[idx]] 108 | item['input_ids']=self.encodings[self.molecule_index_list[idx]].input_ids 109 | item['attention_mask']=self.encodings[self.molecule_index_list[idx]].attention_mask 110 | smiles = self.smiles_list[self.molecule_index_list[idx]] 111 | graph, _ = smiles2graph(smiles) 112 | item['graph'] = graph # add graph 113 | protein = self.protein_list[self.protein_index_list[idx]] 114 | item['protein_encoding'] = protein 115 | if self.include_labels: 116 | label = self.label_list[idx] 117 | item['label'] = label 118 | return item 119 | 120 | def __len__(self): 121 | return len(self.label_list) 122 | 123 | 124 | class ProteinModel(nn.Module): 125 | def __init__(self, emb_dim=128, num_features=25, output_dim=128, n_filters=32, kernel_size=8): 126 | super(ProteinModel, self).__init__() 127 | self.n_filters = n_filters 128 | self.kernel_size = kernel_size 129 | self.intermediate_dim = emb_dim - kernel_size + 1 130 | 131 | self.embedding = nn.Embedding(num_features+1, emb_dim) 132 | self.n_filters = n_filters 133 | self.conv1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=kernel_size) 134 | self.fc = nn.Linear(n_filters*self.intermediate_dim, output_dim) 135 | 136 | def forward(self, x): 137 | x = self.embedding(x) 138 | x = self.conv1(x) 139 | x = x.view(-1, self.n_filters*self.intermediate_dim) 140 | x = self.fc(x) 141 | return x 142 | 143 | class MoleculeProteinModel(nn.Module): 144 | def __init__(self, molecule_model, protein_model, molecule_emb_dim, protein_emb_dim, output_dim=1, dropout=0.2): 145 | super(MoleculeProteinModel, self).__init__() 146 | self.fc1 = nn.Linear(molecule_emb_dim+protein_emb_dim, 1024) 147 | self.fc2 = nn.Linear(1024, 512) 148 | self.out = nn.Linear(512, output_dim) 149 | self.molecule_model = molecule_model 150 | self.protein_model = protein_model 151 | self.pool = global_mean_pool 152 | self.relu = nn.ReLU() 153 | self.dropout = nn.Dropout(dropout) 154 | 155 | def forward(self, molecule, protein): 
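        # Fusion head: mean-pool the GNN node representations per molecule, encode the
        # protein sequence with the 1D-CNN ProteinModel, concatenate the two vectors, and
        # regress binding affinity through fc1 -> fc2 -> out with ReLU and dropout between
        # the layers. Note that global_mean_pool (bound to self.pool in __init__) is not
        # imported at the top of this file; using this class as-is would additionally need
        # `from torch_geometric.nn import global_mean_pool`.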
156 | molecule_node_representation = self.molecule_model(molecule) 157 | molecule_representation = self.pool(molecule_node_representation, molecule.batch) 158 | protein_representation = self.protein_model(protein) 159 | 160 | x = torch.cat([molecule_representation, protein_representation], dim=1) 161 | 162 | x = self.fc1(x) 163 | x = self.relu(x) 164 | x = self.dropout(x) 165 | x = self.fc2(x) 166 | x = self.relu(x) 167 | x = self.dropout(x) 168 | x = self.out(x) 169 | 170 | return x 171 | 172 | 173 | class MultilingualModalUNIDTI(RobertaPreTrainedModel): 174 | _keys_to_ignore_on_load_missing = ["position_ids"] 175 | 176 | def __init__(self, config, gcn_config, is_regression=False, use_label_weight=False, use_rdkit_feature=False): 177 | super().__init__(config) 178 | self.num_labels = config.num_labels 179 | self.num_tasks = config.num_tasks # sider have 27 binary tasks, maybe multi head is useful for multi label classification 180 | 181 | self.register_buffer("norm_mean", torch.tensor(config.norm_mean)) 182 | # Replace any 0 stddev norms with 1 183 | self.register_buffer( 184 | "norm_std", 185 | torch.tensor( 186 | [label_std if label_std != 0 else 1 for label_std in config.norm_std] 187 | ), 188 | ) 189 | 190 | if self.num_tasks > 1: 191 | assert self.num_labels == 2 # binary multi label classification 192 | 193 | # iupac and smiles has same 194 | from multimodal.modeling_roberta import RobertaModel 195 | self.lang_roberta = RobertaModel(config, add_pooling_layer=True) 196 | # self.smiles_roberta = RobertaModel(smiles_config, add_pooling_layer=True) 197 | 198 | self.lang_pooler = Pooler(config.pooler_type) 199 | self.gnn = DeeperGCN(gcn_config) 200 | 201 | self.gcn_config = gcn_config 202 | self.config = config 203 | 204 | # transfer from gcn embeddings to lang shape 205 | self.gcn_embedding = nn.Linear(gcn_config['gnn_embed_dim'], config.hidden_size, bias=True) 206 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 207 | self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) 208 | 209 | 210 | self.use_rdkit_feature = use_rdkit_feature 211 | self.use_label_weight = use_label_weight 212 | 213 | 214 | if self.use_rdkit_feature: 215 | self.head = RobertaFeatureHead(config, regression=is_regression) 216 | else: 217 | self.head = RobertaHead(config, regression=is_regression) 218 | 219 | # self.head = RobertaFeatureHead(config, regression=is_regression) 220 | 221 | 222 | self.is_regression = is_regression 223 | if is_regression: 224 | self.register_buffer("norm_mean", torch.tensor(config.norm_mean)) 225 | # Replace any 0 stddev norms with 1 226 | self.register_buffer( 227 | "norm_std", 228 | torch.tensor( 229 | [label_std if label_std != 0 else 1 for label_std in config.norm_std] 230 | ), 231 | ) 232 | 233 | self.init_weights() 234 | 235 | self.task_weight = None 236 | 237 | self.protein = ProteinModel(emb_dim=config.hidden_size, output_dim=config.hidden_size) 238 | self.fc1 = nn.Linear(config.hidden_size + config.hidden_size, 1024) 239 | self.fc2 = nn.Linear(1024, 512) 240 | self.out = nn.Linear(512, 1) 241 | # self.molecule_model = molecule_model 242 | # self.protein_model = protein_model 243 | # self.pool = global_mean_pool 244 | self.relu = nn.ReLU() 245 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 246 | 247 | 248 | def set_task_weight(self, task_weight): 249 | self.task_weight = task_weight 250 | 251 | def forward( 252 | self, 253 | input_ids=None, 254 | attention_mask=None, 255 | token_type_ids=None, 256 | position_ids=None, 257 | 258 | graph=None, 259 | # 
strucpos_ids=None, 260 | head_mask=None, 261 | inputs_embeds=None, 262 | labels=None, 263 | weight=None, 264 | output_attentions=None, 265 | output_hidden_states=None, 266 | fp_feature = None, 267 | return_dict=None, 268 | protein_encoding=None, 269 | ): 270 | """ 271 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 272 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 273 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 274 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 275 | """ 276 | return_dict = ( 277 | return_dict if return_dict is not None else self.config.use_return_dict 278 | ) 279 | # graph_inputs = Batch.from_data_list(graph) 280 | graph.to(self.device) 281 | gcn_output = self.gnn(graph) 282 | # concat graph atome embeddings and langua embeddings 283 | gcn_embedding_output = self.gcn_embedding(gcn_output[1]) 284 | gcn_embedding_output = self.LayerNorm(gcn_embedding_output) 285 | gcn_embedding_output = self.dropout(gcn_embedding_output) 286 | 287 | 288 | # pad the gcn_embedding same shape with pos_coord_matrix_pad 289 | gcn_embedding_lst = [] 290 | batch_size = input_ids.shape[0] 291 | batch_idx = graph.batch 292 | 293 | graph_attention_mask = [] 294 | for bs in range(batch_size): 295 | gcn_embedding_lst.append(gcn_embedding_output[batch_idx == bs]) 296 | atom_num = (batch_idx == bs).sum().item() 297 | graph_attention_mask.append(torch.tensor([1 for _ in range(atom_num)]).to(self.device)) 298 | 299 | graph_attention_mask = collate_tokens(graph_attention_mask, pad_idx=0, pad_to_multiple=8) 300 | graph_attention_mask = graph_attention_mask.to(torch.bool) 301 | 302 | lang_gcn_outputs, lang_gcn_attention_mask = self.lang_roberta( 303 | input_ids, 304 | attention_mask=attention_mask, 305 | # token_type_ids=lingua['token_type_ids'], 306 | position_ids=None, 307 | head_mask=None, 308 | inputs_embeds=None, 309 | output_attentions=None, 310 | output_hidden_states=True if self.config.pooler_type in ['avg_top2', 'avg_first_last'] else False, 311 | return_dict=True, 312 | 313 | graph_input = gcn_embedding_lst, 314 | graph_batch = graph.batch, 315 | # graph_max_seq_size = self.gcn_config['graph_max_seq_size'], 316 | gnn_mask_labels = None, 317 | graph_attention_mask = graph_attention_mask, 318 | ) 319 | lang_gcn_pooler_output = self.lang_pooler(lang_gcn_attention_mask, lang_gcn_outputs) 320 | 321 | protein_representation = self.protein(protein_encoding) # get protein representations 322 | 323 | x = torch.cat([lang_gcn_pooler_output, protein_representation], dim=1) 324 | 325 | x = self.fc1(x) 326 | x = self.relu(x) 327 | x = self.dropout(x) 328 | x = self.fc2(x) 329 | x = self.relu(x) 330 | x = self.dropout(x) 331 | x = self.out(x) 332 | 333 | loss_fct = SmoothL1Loss() 334 | if labels is None: 335 | return self.unnormalize_logits(x).float() 336 | normalized_labels = self.normalize_logits(labels).float() 337 | loss = loss_fct(x.view(-1), normalized_labels) 338 | 339 | return [loss] 340 | 341 | 342 | def normalize_logits(self, tensor): 343 | return (tensor - self.norm_mean) / self.norm_std 344 | 345 | def unnormalize_logits(self, tensor): 346 | return (tensor * self.norm_std) + self.norm_mean --------------------------------------------------------------------------------
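The DeeperGCN encoder in gcn.py consumes a batched torch_geometric graph and returns a pooled graph embedding (max- and mean-pooled node features concatenated, hence 2 * gnn_embed_dim wide) together with per-node embeddings. The sketch below is a minimal smoke test, assuming the repository root is on PYTHONPATH, rdkit and torch_geometric are installed, and smiles2graph returns a torch_geometric Data object (as its usage in the dataset code above suggests); the config values and the 'CCO' toy molecule are illustrative, not the repository's training setup.

from torch_geometric.data import Batch

from gcn import DeeperGCN
from utils.mol import smiles2graph

# Illustrative hyper-parameters; only the keys are dictated by DeeperGCN.__init__.
gcn_config = {
    'gnn_number_layer': 3,
    'gnn_dropout': 0.1,
    'conv_encode_edge': True,
    'gnn_embed_dim': 128,
    'gnn_aggr': 'maxminmean',
    'gnn_norm': 'layer',
    'gnn_act': 'gelu',
}
model = DeeperGCN(gcn_config)

graph, _ = smiles2graph('CCO')           # ethanol as a toy molecule (assumed Data object)
batch = Batch.from_data_list([graph])    # adds the .batch vector that forward() expects

graph_fea, node_fea = model(batch)
print(graph_fea.shape)   # (1, 2 * gnn_embed_dim): concatenated max- and mean-pooled nodes
print(node_fea.shape)    # (num_atoms, gnn_embed_dim)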