├── __init__.py
├── chemprop
│   ├── models
│   │   ├── __init__.py
│   │   ├── concrete_dropout.py
│   │   └── mpn.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── scaler.py
│   │   ├── scaffold.py
│   │   ├── data.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── features
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── features_generators.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── active_learning.py
│   │   ├── active_learning_ofd.py
│   │   ├── active_learning_scf.py
│   │   ├── active_learning_scf_ofd20.py
│   │   ├── active_learning_scf_ofd50.py
│   │   ├── active_learning_scf_mixofd.py
│   │   ├── cross_validate.py
│   │   ├── cross_validate_multimodel.py
│   │   ├── train.py
│   │   ├── train_multimodel.py
│   │   ├── evaluate_multimodel.py
│   │   ├── predict.py
│   │   ├── make_predictions_atomicUnc_multiMol.py
│   │   ├── evaluate.py
│   │   ├── make_predictions_atomic_unc_onemol.py
│   │   ├── make_predictions_atomic_unc.py
│   │   ├── make_predictions.py
│   │   └── run_training.py
│   ├── atom_plot
│   │   ├── utils.py
│   │   └── molecule_drawer.py
│   ├── random_forest.py
│   └── nn_utils.py
├── images
│   ├── TOC.jpeg
│   ├── TOC2.jpeg
│   ├── image.jpeg
│   ├── image2.jpeg
│   └── draw_predicted_molecule_images.png
├── utils
│   ├── makedir.py
│   ├── scaffold_check.py
│   ├── sort_unc-r0.py
│   ├── plot_2unit_2layer.py
│   ├── uncertainty_metrics.py
│   └── hyperparameter_optimization.py
├── .gitignore
├── predict_atomicunc.py
├── predict.py
├── train.py
├── active_learning
│   ├── train_multimodel.py
│   ├── train_atl.py
│   ├── train_atl_3.py
│   ├── train_atl_ofd.py
│   ├── train_atl_ofd_5.py
│   └── train_atl_scf.py
├── draw_predicted_molecules.py
├── LICENSE
└── README.md

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/chemprop/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import build_model, MoleculeModel
2 | 
--------------------------------------------------------------------------------
/images/TOC.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chuiyang/atom-based_uncertainty_model/HEAD/images/TOC.jpeg
--------------------------------------------------------------------------------
/images/TOC2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chuiyang/atom-based_uncertainty_model/HEAD/images/TOC2.jpeg
--------------------------------------------------------------------------------
/images/image.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chuiyang/atom-based_uncertainty_model/HEAD/images/image.jpeg
--------------------------------------------------------------------------------
/images/image2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chuiyang/atom-based_uncertainty_model/HEAD/images/image2.jpeg
--------------------------------------------------------------------------------
/utils/makedir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | print(sys.argv[1])
4 | os.makedirs(f'{sys.argv[1]}', exist_ok=True)
5 | 
--------------------------------------------------------------------------------
/images/draw_predicted_molecule_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chuiyang/atom-based_uncertainty_model/HEAD/images/draw_predicted_molecule_images.png
-------------------------------------------------------------------------------- /chemprop/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import MoleculeDatapoint, MoleculeDataset 2 | from .scaffold import scaffold_to_smiles 3 | from .scaler import StandardScaler 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | *.ipynb 4 | *.png 5 | *.svg 6 | *.yml 7 | 8 | !/images 9 | !*.py 10 | !.gitignore 11 | 12 | /data 13 | /model 14 | /saved 15 | /dev_files -------------------------------------------------------------------------------- /chemprop/__init__.py: -------------------------------------------------------------------------------- 1 | import chemprop.data 2 | import chemprop.features 3 | import chemprop.models 4 | import chemprop.train 5 | 6 | import chemprop.nn_utils 7 | import chemprop.parsing 8 | import chemprop.utils 9 | -------------------------------------------------------------------------------- /chemprop/features/__init__.py: -------------------------------------------------------------------------------- 1 | from .features_generators import get_available_features_generators, get_features_generator 2 | from .featurization import atom_features, bond_features, BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph, clear_cache 3 | from .utils import load_features, save_features 4 | -------------------------------------------------------------------------------- /predict_atomicunc.py: -------------------------------------------------------------------------------- 1 | """Loads a trained model checkpoint and makes predictions on a dataset.""" 2 | 3 | from chemprop.parsing import parse_draw_molecules_args 4 | from chemprop.train import make_predictions_atomic_unc_onemol 5 | 6 | if __name__ == '__main__': 7 | args = parse_draw_molecules_args() 8 | make_predictions_atomic_unc_onemol(args) 9 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """Loads a trained model checkpoint and makes predictions on a dataset.""" 2 | 3 | from chemprop.parsing import parse_predict_args 4 | from chemprop.train import make_predictions 5 | from pprint import pformat 6 | 7 | if __name__ == '__main__': 8 | args = parse_predict_args() 9 | print(pformat(vars(args))) 10 | make_predictions(args) 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import cross_validate 5 | from chemprop.utils import create_logger 6 | 7 | 8 | if __name__ == '__main__': 9 | args = parse_train_args() 10 | logger = create_logger(name='train', save_dir=args.save_dir, quiet=args.quiet) 11 | cross_validate(args, logger) 12 | -------------------------------------------------------------------------------- /active_learning/train_multimodel.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.utils import create_logger 5 | from chemprop.train import cross_validate_multimodel 6 | 7 | if __name__ == '__main__': 8 | 
args = parse_train_args() 9 | logger = create_logger(name='train', save_dir=args.save_dir, quiet=args.quiet) 10 | cross_validate_multimodel(args, logger) 11 | -------------------------------------------------------------------------------- /chemprop/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_validate import cross_validate 2 | from .evaluate import evaluate, evaluate_predictions 3 | from .make_predictions import make_predictions 4 | from .predict import predict 5 | from .run_training import run_training 6 | from .train import train 7 | from .active_learning import active_learning 8 | from .run_training_atl import run_training_atl 9 | from .make_predictions_atomic_unc import make_predictions_atomic_unc 10 | from .make_predictions_atomic_unc_onemol import make_predictions_atomic_unc_onemol 11 | from .make_predictions_atomicUnc_multiMol import make_predictions_atomicUnc_multiMol 12 | from .train_multimodel import train_multimodel 13 | from .cross_validate_multimodel import cross_validate_multimodel 14 | from .evaluate_multimodel import evaluate_predictions_multimodel, evaluate_multimodel -------------------------------------------------------------------------------- /draw_predicted_molecules.py: -------------------------------------------------------------------------------- 1 | """Loads a trained model checkpoint and makes predictions on a dataset.""" 2 | import os 3 | import logging 4 | 5 | from chemprop.parsing import parse_draw_molecules_args 6 | from chemprop.train import make_predictions_atomicUnc_multiMol 7 | 8 | 9 | def setup_logger(name, log_file, level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'): 10 | logger = logging.getLogger(name) 11 | logger.setLevel(level) 12 | 13 | file_handler = logging.FileHandler(log_file, 'w') 14 | file_handler.setLevel(level) 15 | 16 | formatter = logging.Formatter(format) 17 | file_handler.setFormatter(formatter) 18 | 19 | logger.addHandler(file_handler) 20 | return logger 21 | 22 | if __name__ == '__main__': 23 | args = parse_draw_molecules_args() 24 | 25 | log_file_path = os.path.join(args.draw_mols_dir, 'draw_molecules.log') 26 | logger = setup_logger('', log_file_path) 27 | make_predictions_atomicUnc_multiMol(args, smiles=None, logger=logger) 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chu-I Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /chemprop/train/active_learning.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | import csv 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | 9 | from .run_training_atl import run_training_atl 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def active_learning(args: Namespace, logger: Logger = None, active_iter: int = 0, logger_all: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | info_all = logger_all.info if logger_all is not None else print 18 | 19 | # Initialize relevant variables 20 | save_dir = args.save_dir 21 | 22 | # only one fold 23 | args.active_dir = os.path.join(save_dir, f'active_iter{active_iter}') 24 | makedirs(args.active_dir) 25 | 26 | # test logger_all 27 | model_scores, model_rmse, model_mae = run_training_atl(args, logger, active_iter, logger_all) 28 | 29 | info_all(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test_mae: {model_mae:.6f}\n') 30 | info(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n') 31 | 32 | active_log = open(f'{save_dir}/active_log.csv', 'a') 33 | writer = csv.writer(active_log) 34 | writer.writerow([active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 35 | active_log.close() 36 | 37 | -------------------------------------------------------------------------------- /chemprop/train/active_learning_ofd.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | import csv 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | 9 | from .run_training_atl_ofd import run_training_atl_ofd 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def active_learning_ofd(args: Namespace, logger: Logger = None, active_iter: int = 0, logger_all: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | info_all = logger_all.info if logger_all is not None else print 18 | 19 | # Initialize relevant variables 20 | save_dir = args.save_dir 21 | 22 | # only one fold 23 | args.active_dir = os.path.join(save_dir, f'active_iter{active_iter}') 24 | makedirs(args.active_dir) 25 | 26 | # test logger_all 27 | model_scores, model_rmse, model_mae = run_training_atl_ofd(args, logger, active_iter, logger_all) 28 | 29 | info_all(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test_mae: {model_mae:.6f}\n') 30 | info(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n') 31 | 32 | active_log = open(f'{save_dir}/active_log.csv', 'a') 33 | writer = csv.writer(active_log) 34 | 
writer.writerow([active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 35 | active_log.close() 36 | 37 | -------------------------------------------------------------------------------- /chemprop/train/active_learning_scf.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | import csv 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | 9 | from .run_training_atl_scf import run_training_atl_scf 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def active_learning_scf(args: Namespace, logger: Logger = None, active_iter: int = 0, logger_all: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | info_all = logger_all.info if logger_all is not None else print 18 | 19 | # Initialize relevant variables 20 | save_dir = args.save_dir 21 | 22 | # only one fold 23 | args.active_dir = os.path.join(save_dir, f'active_iter{active_iter}') 24 | makedirs(args.active_dir) 25 | 26 | # test logger_all 27 | model_scores, model_rmse, model_mae = run_training_atl_scf(args, logger, active_iter, logger_all) 28 | 29 | info_all(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test_mae: {model_mae:.6f}\n') 30 | info(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n') 31 | 32 | active_log = open(f'{save_dir}/active_log.csv', 'a') 33 | writer = csv.writer(active_log) 34 | writer.writerow([active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 35 | active_log.close() 36 | 37 | -------------------------------------------------------------------------------- /chemprop/train/active_learning_scf_ofd20.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | import csv 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | 9 | from .run_training_atl_scf_ofd20 import run_training_atl_scf_ofd20 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def active_learning_scf_ofd20(args: Namespace, logger: Logger = None, active_iter: int = 0, logger_all: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | info_all = logger_all.info if logger_all is not None else print 18 | 19 | # Initialize relevant variables 20 | save_dir = args.save_dir 21 | 22 | # only one fold 23 | args.active_dir = os.path.join(save_dir, f'active_iter{active_iter}') 24 | makedirs(args.active_dir) 25 | 26 | # test logger_all 27 | model_scores, model_rmse, model_mae = run_training_atl_scf_ofd20(args, logger, active_iter, logger_all) 28 | 29 | info_all(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test_mae: {model_mae:.6f}\n') 30 | info(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n') 31 | 32 | active_log = open(f'{save_dir}/active_log.csv', 'a') 33 | writer = csv.writer(active_log) 34 | writer.writerow([active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 35 | active_log.close() 36 | 
37 | -------------------------------------------------------------------------------- /chemprop/train/active_learning_scf_ofd50.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | import csv 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | 9 | from .run_training_atl_scf_ofd50 import run_training_atl_scf_ofd50 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def active_learning_scf_ofd50(args: Namespace, logger: Logger = None, active_iter: int = 0, logger_all: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | info_all = logger_all.info if logger_all is not None else print 18 | 19 | # Initialize relevant variables 20 | save_dir = args.save_dir 21 | 22 | # only one fold 23 | args.active_dir = os.path.join(save_dir, f'active_iter{active_iter}') 24 | makedirs(args.active_dir) 25 | 26 | # test logger_all 27 | model_scores, model_rmse, model_mae = run_training_atl_scf_ofd50(args, logger, active_iter, logger_all) 28 | 29 | info_all(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test_mae: {model_mae:.6f}\n') 30 | info(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n') 31 | 32 | active_log = open(f'{save_dir}/active_log.csv', 'a') 33 | writer = csv.writer(active_log) 34 | writer.writerow([active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 35 | active_log.close() 36 | 37 | -------------------------------------------------------------------------------- /chemprop/train/active_learning_scf_mixofd.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | import csv 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | 9 | from .run_training_atl_scf_mixofd import run_training_atl_scf_mixofd 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def active_learning_scf_mixofd(args: Namespace, logger: Logger = None, active_iter: int = 0, logger_all: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | info_all = logger_all.info if logger_all is not None else print 18 | 19 | # Initialize relevant variables 20 | save_dir = args.save_dir 21 | 22 | # only one fold 23 | args.active_dir = os.path.join(save_dir, f'active_iter{active_iter}') 24 | makedirs(args.active_dir) 25 | 26 | # test logger_all 27 | model_scores, model_rmse, model_mae = run_training_atl_scf_mixofd(args, logger, active_iter, logger_all) 28 | 29 | info_all(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test_mae: {model_mae:.6f}\n') 30 | info(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n') 31 | 32 | active_log = open(f'{save_dir}/active_log.csv', 'a') 33 | writer = csv.writer(active_log) 34 | writer.writerow([active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 35 | active_log.close() 36 | 37 | -------------------------------------------------------------------------------- 
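The six active_learning*.py modules above share the same body and differ only in which run_training_atl* routine they call; the per-iteration directory setup, the duplicated logging lines, and the active_log.csv bookkeeping are identical in every file. A minimal sketch of that shared wrapper is given below; the names _run_active_iteration and run_training_fn are illustrative only (they do not exist in this repository), and returning the three scores is an addition of the sketch, since the originals compute them but never return them.

from argparse import Namespace
from logging import Logger
from typing import Callable, Tuple
import csv
import os

from chemprop.utils import makedirs


def _run_active_iteration(run_training_fn: Callable,
                          args: Namespace,
                          logger: Logger = None,
                          active_iter: int = 0,
                          logger_all: Logger = None) -> Tuple[float, float, float]:
    """Shared body of the active_learning* wrappers; only the training routine differs."""
    info = logger.info if logger is not None else print
    info_all = logger_all.info if logger_all is not None else print

    # one sub-directory per active-learning iteration, as in the originals
    args.active_dir = os.path.join(args.save_dir, f'active_iter{active_iter}')
    makedirs(args.active_dir)

    model_scores, model_rmse, model_mae = run_training_fn(args, logger, active_iter, logger_all)

    for log in (info_all, info):
        log(f'active learning iter: {active_iter}, test {args.metric}: {model_scores:.6f}, '
            f'test rmse: {model_rmse:.6f}, test mae: {model_mae:.6f}\n')

    # append one row per iteration to the shared active-learning log
    with open(os.path.join(args.save_dir, 'active_log.csv'), 'a') as active_log:
        csv.writer(active_log).writerow(
            [active_iter, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)])

    return model_scores, model_rmse, model_mae

With such a helper, active_learning(args, logger, i, logger_all) would reduce to _run_active_iteration(run_training_atl, args, logger, i, logger_all), and likewise for the ofd/scf variants.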
/utils/scaffold_check.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from rdkit import Chem 4 | from rdkit.Chem.Scaffolds import MurckoScaffold 5 | 6 | 7 | train_data = pd.read_csv('./saved_models/qm9_130k_scaf/qm9_130k_pear_scale_2l_21e_scaf_lr/fold_0/train_smiles.csv').values[:, 0] 8 | 9 | scaf_set = set() 10 | for train_smiles in train_data: 11 | 12 | scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=Chem.MolFromSmiles(train_smiles)) 13 | scaf_set.add(scaffold) 14 | 15 | print(f'len(scaf_set): {len(scaf_set)}') 16 | 17 | 18 | 19 | output = [] 20 | 21 | 22 | ccsd = pd.read_csv('./data/new_heavy_atom/new_binary/heavy_atom_9.csv').values 23 | print(f'ccsd.shape: {ccsd.shape}') 24 | 25 | for ccsd_smiles, ccsd_hf in zip(ccsd[:, 0], ccsd[:, 1]): 26 | scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=Chem.MolFromSmiles(ccsd_smiles)) 27 | if scaffold not in scaf_set: 28 | output.append([ccsd_smiles, ccsd_hf]) 29 | 30 | print(f'len(output): {len(output)}') 31 | 32 | 33 | ccsd = pd.read_csv('./data/new_heavy_atom/new_binary/heavy_atom_10.csv').values 34 | print(f'ccsd.shape: {ccsd.shape}') 35 | for ccsd_smiles, ccsd_hf in zip(ccsd[:, 0], ccsd[:, 1]): 36 | scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=Chem.MolFromSmiles(ccsd_smiles)) 37 | if scaffold not in scaf_set: 38 | output.append([ccsd_smiles, ccsd_hf]) 39 | 40 | output = np.array(output) 41 | 42 | print(f'output.shape: {output.shape}') 43 | 44 | 45 | pd.DataFrame(output, columns=['smiles', 'Hf']).to_csv('./saved_models/qm9_130k_scaf/qm9_130k_pear_scale_2l_21e_scaf_lr/fold_0/ccsd_hf_scaffoldNotInTrain.csv', index=False) -------------------------------------------------------------------------------- /utils/sort_unc-r0.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import os 5 | 6 | load_remain_data = pd.read_csv('./saved_models/qm9_130k_rbf_unscale_atl_rand_bs50/active_iter0/saved_data/remain_pred.csv').values 7 | np.random.seed(10) 8 | np.random.shuffle(load_remain_data) 9 | train_data = pd.read_csv('./saved_models/qm9_130k_rbf_unscale_atl_rand_bs50/active_iter0/saved_data/train_full.csv').values 10 | val_data = pd.read_csv('./saved_models/qm9_130k_rbf_unscale_atl_rand_bs50/active_iter0/saved_data/val_full.csv').values 11 | 12 | 13 | k_samples = load_remain_data.shape[0] // 6 14 | print(f'amount of remain data at first round: {load_remain_data.shape[0]}') 15 | print(f'arg.k_samples: {k_samples}') 16 | kremain_data = load_remain_data[:k_samples, :2] 17 | load_remain_data = load_remain_data[k_samples:, :2] 18 | concat_train_val_kremain = np.vstack((train_data, val_data, kremain_data)) 19 | 20 | 21 | np.random.shuffle(concat_train_val_kremain) 22 | 23 | new_amount_of_train_data = int(concat_train_val_kremain.shape[0]*0.9) 24 | 25 | train_data = concat_train_val_kremain[:new_amount_of_train_data, :] 26 | val_data = concat_train_val_kremain[new_amount_of_train_data:, :] 27 | 28 | print(f'train_data: {train_data.shape}, val_data: {val_data.shape}, remain: {load_remain_data.shape}') 29 | 30 | 31 | for dataset, name in [(train_data, 'train'), (val_data, 'val'), (load_remain_data, 'remain')]: 32 | pd.DataFrame(dataset, columns=['smiles', 'true']).to_csv(os.path.join('./saved_models/qm9_130k_rbf_unscale_atl_rand_bs50/active_iter1/saved_data', name + '_full.csv'), index=False) 33 | 
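# The hard-coded block above performs the first-round re-split for the random-selection
# baseline: it moves k randomly chosen molecules from the remaining pool into the pooled
# train/validation data and re-splits that pool 90/10. A reusable sketch of the same step
# follows; resplit_active_learning_data is an illustrative name, not a function defined
# in this repository.
import numpy as np

def resplit_active_learning_data(train_data, val_data, remain_data, k_samples,
                                 train_frac=0.9, seed=10):
    """Acquire k_samples rows from the remaining pool and re-split into train/val/remain."""
    rng = np.random.default_rng(seed)
    remain_data = remain_data.copy()
    rng.shuffle(remain_data)                 # random acquisition, as in this script
    acquired = remain_data[:k_samples, :2]   # keep only the smiles/target columns
    remain_data = remain_data[k_samples:, :2]
    pool = np.vstack((train_data, val_data, acquired))
    rng.shuffle(pool)
    n_train = int(pool.shape[0] * train_frac)
    return pool[:n_train], pool[n_train:], remain_data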
-------------------------------------------------------------------------------- /active_learning/train_atl.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import active_learning 5 | from chemprop.utils import create_logger, create_logger_atl_all 6 | 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | 10 | if __name__ == '__main__': 11 | args = parse_train_args() 12 | logger_all = create_logger_atl_all(name='train_atl', save_dir=args.save_dir, quiet=True) 13 | for active_iter in range(0, 7): 14 | logger = create_logger(name=f'train_atl_{active_iter}', save_dir=args.save_dir, quiet=args.quiet, active_iter=active_iter) 15 | active_learning(args, logger, active_iter=active_iter, logger_all=logger_all) 16 | 17 | # plot results 18 | active_log = pd.read_csv(f'{args.save_dir}/active_log.csv', header=None).values 19 | fig, ax1 = plt.subplots() 20 | color = 'tab:red' 21 | ax1.plot(active_log[:, 0], active_log[:, 1], color=color, marker='o') 22 | ax1.set_xlabel('Iterations') 23 | ax1.set_ylabel('Heterosedastic loss of testing data', color=color) 24 | ax1.tick_params(axis='y', labelcolor=color) 25 | color = 'tab:blue' 26 | ax2 = ax1.twinx() 27 | ax2.plot(active_log[:, 0], active_log[:, 2], color=color, marker='o', label='rmse') 28 | ax2.plot(active_log[:, 0], active_log[:, 3], color='darkcyan', marker='o', label='mae') 29 | ax2.set_ylabel('RMSE/MAE of testing data', color=color) 30 | ax2.tick_params(axis='y', labelcolor=color) 31 | fig.suptitle('Active Learning Log') 32 | plt.legend() 33 | plt.tight_layout() 34 | plt.savefig(f'{args.save_dir}/active_log.png', dpi=300) 35 | -------------------------------------------------------------------------------- /active_learning/train_atl_3.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import active_learning 5 | from chemprop.utils import create_logger, create_logger_atl_all 6 | 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | 10 | if __name__ == '__main__': 11 | args = parse_train_args() 12 | logger_all = create_logger_atl_all(name='train_atl', save_dir=args.save_dir, quiet=True) 13 | for active_iter in range(3, 7): 14 | logger = create_logger(name=f'train_atl_{active_iter}', save_dir=args.save_dir, quiet=args.quiet, active_iter=active_iter) 15 | active_learning(args, logger, active_iter=active_iter, logger_all=logger_all) 16 | 17 | # plot results 18 | active_log = pd.read_csv(f'{args.save_dir}/active_log.csv', header=None).values 19 | fig, ax1 = plt.subplots() 20 | color = 'tab:red' 21 | ax1.plot(active_log[:, 0], active_log[:, 1], color=color, marker='o') 22 | ax1.set_xlabel('Iterations') 23 | ax1.set_ylabel('Heterosedastic loss of testing data', color=color) 24 | ax1.tick_params(axis='y', labelcolor=color) 25 | color = 'tab:blue' 26 | ax2 = ax1.twinx() 27 | ax2.plot(active_log[:, 0], active_log[:, 2], color=color, marker='o', label='rmse') 28 | ax2.plot(active_log[:, 0], active_log[:, 3], color='darkcyan', marker='o', label='mae') 29 | ax2.set_ylabel('RMSE/MAE of testing data', color=color) 30 | ax2.tick_params(axis='y', labelcolor=color) 31 | fig.suptitle('Active Learning Log') 32 | plt.legend() 33 | plt.tight_layout() 34 | plt.savefig(f'{args.save_dir}/active_log.png', dpi=300) 35 | 
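# The dual-axis plotting block above appears verbatim in train_atl.py and train_atl_3.py,
# and again in the train_atl_ofd*/train_atl_scf drivers later in the repository. A small
# helper along the following lines could replace it; plot_active_log is an illustrative
# name, not a function defined in this repository. train_atl_scf.py writes a header row
# to active_log.csv, which is what the skiprows argument is for.
import matplotlib.pyplot as plt
import pandas as pd

def plot_active_log(save_dir: str, skiprows: int = 0) -> None:
    """Plot heteroscedastic test loss (left axis) and RMSE/MAE (right axis) per iteration."""
    log = pd.read_csv(f'{save_dir}/active_log.csv', header=None, skiprows=skiprows).values
    fig, ax1 = plt.subplots()
    ax1.plot(log[:, 0], log[:, 1], color='tab:red', marker='o')
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Heteroscedastic loss of testing data', color='tab:red')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax2 = ax1.twinx()
    ax2.plot(log[:, 0], log[:, 2], color='tab:blue', marker='o', label='rmse')
    ax2.plot(log[:, 0], log[:, 3], color='darkcyan', marker='o', label='mae')
    ax2.set_ylabel('RMSE/MAE of testing data', color='tab:blue')
    ax2.tick_params(axis='y', labelcolor='tab:blue')
    fig.suptitle('Active Learning Log')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'{save_dir}/active_log.png', dpi=300)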
-------------------------------------------------------------------------------- /active_learning/train_atl_ofd.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import active_learning_ofd 5 | from chemprop.utils import create_logger, create_logger_atl_all 6 | 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | 10 | if __name__ == '__main__': 11 | args = parse_train_args() 12 | logger_all = create_logger_atl_all(name='train_atl', save_dir=args.save_dir, quiet=True) 13 | for active_iter in range(0, 7): 14 | logger = create_logger(name=f'train_atl_{active_iter}', save_dir=args.save_dir, quiet=args.quiet, active_iter=active_iter) 15 | active_learning_ofd(args, logger, active_iter=active_iter, logger_all=logger_all) 16 | 17 | # plot results 18 | active_log = pd.read_csv(f'{args.save_dir}/active_log.csv', header=None).values 19 | fig, ax1 = plt.subplots() 20 | color = 'tab:red' 21 | ax1.plot(active_log[:, 0], active_log[:, 1], color=color, marker='o') 22 | ax1.set_xlabel('Iterations') 23 | ax1.set_ylabel('Heterosedastic loss of testing data', color=color) 24 | ax1.tick_params(axis='y', labelcolor=color) 25 | color = 'tab:blue' 26 | ax2 = ax1.twinx() 27 | ax2.plot(active_log[:, 0], active_log[:, 2], color=color, marker='o', label='rmse') 28 | ax2.plot(active_log[:, 0], active_log[:, 3], color='darkcyan', marker='o', label='mae') 29 | ax2.set_ylabel('RMSE/MAE of testing data', color=color) 30 | ax2.tick_params(axis='y', labelcolor=color) 31 | fig.suptitle('Active Learning Log') 32 | plt.legend() 33 | plt.tight_layout() 34 | plt.savefig(f'{args.save_dir}/active_log.png', dpi=300) 35 | -------------------------------------------------------------------------------- /active_learning/train_atl_ofd_5.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import active_learning_ofd 5 | from chemprop.utils import create_logger, create_logger_atl_all 6 | 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | 10 | if __name__ == '__main__': 11 | args = parse_train_args() 12 | logger_all = create_logger_atl_all(name='train_atl', save_dir=args.save_dir, quiet=True) 13 | for active_iter in range(5, 7): 14 | logger = create_logger(name=f'train_atl_{active_iter}', save_dir=args.save_dir, quiet=args.quiet, active_iter=active_iter) 15 | active_learning_ofd(args, logger, active_iter=active_iter, logger_all=logger_all) 16 | 17 | # plot results 18 | active_log = pd.read_csv(f'{args.save_dir}/active_log.csv', header=None).values 19 | fig, ax1 = plt.subplots() 20 | color = 'tab:red' 21 | ax1.plot(active_log[:, 0], active_log[:, 1], color=color, marker='o') 22 | ax1.set_xlabel('Iterations') 23 | ax1.set_ylabel('Heterosedastic loss of testing data', color=color) 24 | ax1.tick_params(axis='y', labelcolor=color) 25 | color = 'tab:blue' 26 | ax2 = ax1.twinx() 27 | ax2.plot(active_log[:, 0], active_log[:, 2], color=color, marker='o', label='rmse') 28 | ax2.plot(active_log[:, 0], active_log[:, 3], color='darkcyan', marker='o', label='mae') 29 | ax2.set_ylabel('RMSE/MAE of testing data', color=color) 30 | ax2.tick_params(axis='y', labelcolor=color) 31 | fig.suptitle('Active Learning Log') 32 | plt.legend() 33 | plt.tight_layout() 34 | plt.savefig(f'{args.save_dir}/active_log.png', dpi=300) 35 | 
-------------------------------------------------------------------------------- /utils/plot_2unit_2layer.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | 6 | scaling = ['scale', 'unscale'] 7 | pairing = ['rbf', 'cos'] 8 | output_layer = ['2unit', '2layer'] 9 | 10 | mol_filename = [f'qm9_130k_mol_{scaling_i}' for scaling_i in scaling] 11 | atom_filename = [f'qm9_130k_{pairing_i}_{scaling_i}_{output_layer_i}' for output_layer_i in output_layer for scaling_i in scaling for pairing_i in pairing] 12 | 13 | mol_note = [f'm_{scaling_i}' for scaling_i in scaling] 14 | atom_note = [f'a_{pairing_i}_{scaling_i}_{output_layer_i}' for output_layer_i in output_layer for scaling_i in scaling for pairing_i in pairing] 15 | 16 | filename_list = mol_filename + atom_filename 17 | notes = mol_note + atom_note 18 | mean_rmse = [] 19 | mean_mae = [] 20 | 21 | plt.figure(figsize=(17, 11)) 22 | 23 | for note, filename in zip(notes, filename_list): 24 | fold_log = pd.read_csv(os.path.join('./saved_models', filename, 'fold_log.csv'), header=None).values 25 | fold = fold_log[:, 0] 26 | loss = fold_log[:, 1] 27 | rmse = fold_log[:, 2] 28 | mae = fold_log[:, 3] 29 | mean_rmse.append(np.round(np.mean(rmse), 3)) 30 | mean_mae.append(np.round(np.mean(mae), 3)) 31 | print(f'{note}, {np.round(np.mean(loss), 3)}, {np.round(np.mean(rmse), 3)}, {np.round(np.mean(mae), 3)}') 32 | 33 | 34 | ax1 = plt.subplot(2, 1, 1) 35 | ax1.bar(np.arange(len(mean_rmse)), mean_rmse) 36 | ax1.set_xticks(np.arange(len(mean_rmse))) 37 | ax1.set_xticklabels(notes) 38 | ax1.set_title('rmse') 39 | ax2 = plt.subplot(2, 1, 2) 40 | ax2.bar(np.arange(len(mean_mae)), mean_mae) 41 | ax2.set_xticks(np.arange(len(mean_mae))) 42 | ax2.set_xticklabels(notes) 43 | ax2.set_title('mae') 44 | plt.tight_layout() 45 | plt.savefig('./plot_2unit_2layer.png', dpi=800) 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /active_learning/train_atl_scf.py: -------------------------------------------------------------------------------- 1 | """Trains a model on a dataset.""" 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import active_learning_scf 5 | from chemprop.utils import create_logger, create_logger_atl_all 6 | 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | import csv 10 | 11 | if __name__ == '__main__': 12 | args = parse_train_args() 13 | logger_all = create_logger_atl_all(name='train_atl', save_dir=args.save_dir, quiet=True) 14 | 15 | active_log = open(f'{args.save_dir}/active_log.csv', 'w') 16 | writer = csv.writer(active_log) 17 | writer.writerow([args.active_uncertainty,'','','']) 18 | active_log.close() 19 | 20 | for active_iter in range(0, 7): 21 | logger = create_logger(name=f'train_atl_{active_iter}', save_dir=args.save_dir, quiet=args.quiet, active_iter=active_iter) 22 | active_learning_scf(args, logger, active_iter=active_iter, logger_all=logger_all) 23 | 24 | # plot results 25 | active_log = pd.read_csv(f'{args.save_dir}/active_log.csv', header=None).values[1:, :] 26 | print(f'active log shape: {active_log.shape}') 27 | fig, ax1 = plt.subplots() 28 | color = 'tab:red' 29 | ax1.plot(active_log[:, 0], active_log[:, 1], color=color, marker='o') 30 | ax1.set_xlabel('Iterations') 31 | ax1.set_ylabel('Heterosedastic loss of testing data', color=color) 32 | ax1.tick_params(axis='y', labelcolor=color) 33 | color = 'tab:blue' 
34 | ax2 = ax1.twinx() 35 | ax2.plot(active_log[:, 0], active_log[:, 2], color=color, marker='o', label='rmse') 36 | ax2.plot(active_log[:, 0], active_log[:, 3], color='darkcyan', marker='o', label='mae') 37 | ax2.set_ylabel('RMSE/MAE of testing data', color=color) 38 | ax2.tick_params(axis='y', labelcolor=color) 39 | fig.suptitle('Active Learning Log') 40 | plt.legend() 41 | plt.tight_layout() 42 | plt.savefig(f'{args.save_dir}/active_log.png', dpi=300) 43 | -------------------------------------------------------------------------------- /chemprop/features/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pickle 4 | from typing import List 5 | 6 | import numpy as np 7 | 8 | 9 | def save_features(path: str, features: List[np.ndarray]): 10 | """ 11 | Saves features to a compressed .npz file with array name "features". 12 | 13 | :param path: Path to a .npz file where the features will be saved. 14 | :param features: A list of 1D numpy arrays containing the features for molecules. 15 | """ 16 | np.savez_compressed(path, features=features) 17 | 18 | 19 | def load_features(path: str) -> np.ndarray: 20 | """ 21 | Loads features saved in a variety of formats. 22 | 23 | Supported formats: 24 | - .npz compressed (assumes features are saved with name "features") 25 | - .npz (assumes features are saved with name "features") 26 | - .npy 27 | - .csv/.txt (assumes comma-separated features with a header and with one line per molecule) 28 | - .pkl/.pckl/.pickle containing a sparse numpy array (TODO: remove this option once we are no longer dependent on it) 29 | 30 | All formats assume that the SMILES strings loaded elsewhere in the code are in the same 31 | order as the features loaded here. 32 | 33 | :param path: Path to a file containing features. 34 | :return: A 2D numpy array of size (num_molecules, features_size) containing the features. 
35 | """ 36 | extension = os.path.splitext(path)[1] 37 | 38 | if extension == '.npz': 39 | features = np.load(path)['features'] 40 | elif extension == '.npy': 41 | features = np.load(path) 42 | elif extension in ['.csv', '.txt']: 43 | with open(path) as f: 44 | reader = csv.reader(f) 45 | next(reader) # skip header 46 | features = np.array([[float(value) for value in row] for row in reader]) 47 | elif extension in ['.pkl', '.pckl', '.pickle']: 48 | with open(path, 'rb') as f: 49 | features = np.array([np.squeeze(np.array(feat.todense())) for feat in pickle.load(f)]) 50 | else: 51 | raise ValueError(f'Features path extension {extension} not supported.') 52 | 53 | return features 54 | -------------------------------------------------------------------------------- /chemprop/train/cross_validate.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import csv 8 | 9 | from .run_training import run_training 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def cross_validate(args: Namespace, logger: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | 18 | # Initialize relevant variables 19 | init_seed = args.seed 20 | save_dir = args.save_dir 21 | fold_log_dir = args.save_dir 22 | task_names = get_task_names(args.data_path) 23 | # Run training on different random seeds for each fold 24 | all_scores = [] 25 | for fold_num in range(args.num_folds): 26 | info(f'Fold {fold_num}') 27 | args.seed = init_seed + fold_num 28 | args.save_dir = os.path.join(save_dir, f'fold_{fold_num}') 29 | makedirs(args.save_dir) 30 | model_scores, model_rmse, model_mae = run_training(args, logger) 31 | all_scores.append(model_scores) 32 | 33 | fold_log = open(f'{fold_log_dir}/fold_log.csv', 'a') 34 | writer = csv.writer(fold_log) 35 | writer.writerow([fold_num, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 36 | fold_log.close() 37 | 38 | 39 | all_scores = np.array(all_scores) 40 | # Report results 41 | info(f'{args.num_folds}-fold cross validation') 42 | 43 | # Report scores for each fold 44 | for fold_num, scores in enumerate(all_scores): 45 | info(f'Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}') 46 | 47 | if args.show_individual_scores: 48 | for task_name, score in zip(task_names, scores): 49 | info(f'Seed {init_seed + fold_num} ==> test {task_name} {args.metric} = {score:.6f}') 50 | 51 | # Report scores across models 52 | avg_scores = np.nanmean(all_scores, axis=0) # average score for each model across tasks #axis=1 53 | mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores) 54 | info(f'Overall test {args.metric} = {mean_score:.6f} +/- {std_score:.6f}') 55 | 56 | if args.show_individual_scores: 57 | for task_num, task_name in enumerate(task_names): 58 | info(f'Overall test {task_name} {args.metric} = ' 59 | f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}') 60 | 61 | return mean_score, std_score 62 | 63 | -------------------------------------------------------------------------------- /chemprop/train/cross_validate_multimodel.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | import os 4 | from 
typing import Tuple 5 | 6 | import numpy as np 7 | import csv 8 | 9 | from .run_training_multimodel import run_training_multimodel 10 | from chemprop.data.utils import get_task_names 11 | from chemprop.utils import makedirs 12 | 13 | 14 | def cross_validate_multimodel(args: Namespace, logger: Logger = None) -> Tuple[float, float]: 15 | """k-fold cross validation""" 16 | info = logger.info if logger is not None else print 17 | 18 | # Initialize relevant variables 19 | init_seed = args.seed 20 | save_dir = args.save_dir 21 | fold_log_dir = args.save_dir 22 | task_names = get_task_names(args.data_path) 23 | # Run training on different random seeds for each fold 24 | all_scores = [] 25 | for fold_num in range(args.num_folds): 26 | info(f'Fold {fold_num}') 27 | args.seed = init_seed + fold_num 28 | args.save_dir = os.path.join(save_dir, f'fold_{fold_num}') 29 | makedirs(args.save_dir) 30 | model_scores, model_rmse, model_mae = run_training_multimodel(args, logger) 31 | all_scores.append(model_scores) 32 | 33 | fold_log = open(f'{fold_log_dir}/fold_log.csv', 'a') 34 | writer = csv.writer(fold_log) 35 | writer.writerow([fold_num, round(model_scores, 5), round(model_rmse, 5), round(model_mae, 5)]) 36 | fold_log.close() 37 | 38 | 39 | all_scores = np.array(all_scores) 40 | # Report results 41 | info(f'{args.num_folds}-fold cross validation') 42 | 43 | # Report scores for each fold 44 | for fold_num, scores in enumerate(all_scores): 45 | info(f'Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}') 46 | 47 | if args.show_individual_scores: 48 | for task_name, score in zip(task_names, scores): 49 | info(f'Seed {init_seed + fold_num} ==> test {task_name} {args.metric} = {score:.6f}') 50 | 51 | # Report scores across models 52 | avg_scores = np.nanmean(all_scores, axis=0) # average score for each model across tasks #axis=1 53 | mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores) 54 | info(f'Overall test {args.metric} = {mean_score:.6f} +/- {std_score:.6f}') 55 | 56 | if args.show_individual_scores: 57 | for task_num, task_name in enumerate(task_names): 58 | info(f'Overall test {task_name} {args.metric} = ' 59 | f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}') 60 | 61 | return mean_score, std_score 62 | 63 | -------------------------------------------------------------------------------- /chemprop/models/concrete_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | # Partially based on https://github.com/yaringal/ConcreteDropout 6 | 7 | class ConcreteDropout(nn.Module): 8 | def __init__(self, layer, reg_acc, weight_regularizer=1e-6, 9 | dropout_regularizer=1e-5, init_min=0.1, init_max=0.1, depth=1): 10 | super(ConcreteDropout, self).__init__() 11 | 12 | 13 | self.weight_regularizer = weight_regularizer 14 | self.dropout_regularizer = dropout_regularizer 15 | self.layer = layer 16 | 17 | self.reg_acc = reg_acc 18 | self.reg_acc.notify_loss(depth) 19 | 20 | init_min = np.log(init_min) - np.log(1. - init_min) 21 | init_max = np.log(init_max) - np.log(1. 
- init_max) 22 | 23 | self.p_logit = nn.Parameter(torch.empty(1).uniform_(init_min, init_max)) 24 | 25 | 26 | def forward(self, x): 27 | p = torch.sigmoid(self.p_logit) 28 | 29 | out = self.layer(self._concrete_dropout(x, p)) 30 | 31 | if self.training: 32 | sum_of_square = 0 33 | for param in self.layer.parameters(): 34 | sum_of_square += torch.sum(torch.pow(param, 2)) 35 | 36 | #weights_regularizer = self.weight_regularizer * sum_of_square / (1 - p) 37 | weights_regularizer = self.weight_regularizer * sum_of_square * (1 - p) 38 | 39 | dropout_regularizer = p * torch.log(p) 40 | dropout_regularizer += (1. - p) * torch.log(1. - p) 41 | 42 | input_dimensionality = x[0].numel() # Number of elements of first item in batchs 43 | dropout_regularizer *= self.dropout_regularizer * input_dimensionality 44 | 45 | regularization = weights_regularizer + dropout_regularizer 46 | 47 | self.reg_acc.add_loss(regularization) 48 | 49 | input_dimensionality = x[0].numel() # Number of elements of first item in batch 50 | 51 | return out 52 | 53 | def _concrete_dropout(self, x, p): 54 | 55 | eps = 1e-7 56 | temp = 0.1 57 | 58 | unif_noise = torch.rand_like(x) 59 | 60 | drop_prob = (torch.log(p + eps) 61 | - torch.log(1 - p + eps) 62 | + torch.log(unif_noise + eps) 63 | - torch.log(1 - unif_noise + eps)) 64 | 65 | drop_prob = torch.sigmoid(drop_prob / temp) 66 | random_tensor = 1 - drop_prob 67 | retain_prob = 1 - p 68 | 69 | x = torch.mul(x, random_tensor) 70 | x /= retain_prob 71 | 72 | return x 73 | 74 | 75 | class RegularizationAccumulator: 76 | def __init__(self): 77 | self.i = 0 78 | self.size = 0 79 | 80 | def notify_loss(self, depth): 81 | self.size += depth 82 | 83 | def initialize(self, cuda): 84 | self.arr = torch.empty(self.size) 85 | if cuda: 86 | self.arr = self.arr.cuda() 87 | 88 | def add_loss(self, loss): 89 | self.arr[self.i] = loss 90 | self.i += 1 91 | 92 | def get_sum(self): 93 | sum = torch.sum(self.arr) 94 | 95 | # reset index and computational graph 96 | self.i = 0 97 | self.arr = self.arr.detach() 98 | 99 | return sum 100 | -------------------------------------------------------------------------------- /chemprop/data/scaler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import numpy as np 4 | 5 | 6 | class StandardScaler: 7 | """A StandardScaler normalizes a dataset. 8 | 9 | When fit on a dataset, the StandardScaler learns the mean and standard deviation across the 0th axis. 10 | When transforming a dataset, the StandardScaler subtracts the means and divides by the standard deviations. 11 | """ 12 | 13 | def __init__(self, means: np.ndarray = None, stds: np.ndarray = None, replace_nan_token: Any = None): 14 | """ 15 | Initialize StandardScaler, optionally with means and standard deviations precomputed. 16 | 17 | :param means: An optional 1D numpy array of precomputed means. 18 | :param stds: An optional 1D numpy array of precomputed standard deviations. 19 | :param replace_nan_token: The token to use in place of nans. 20 | """ 21 | self.means = means 22 | self.stds = stds 23 | self.replace_nan_token = replace_nan_token 24 | 25 | def fit(self, X: List[List[float]]) -> 'StandardScaler': 26 | """ 27 | Learns means and standard deviations across the 0th axis. 28 | 29 | :param X: A list of lists of floats. 30 | :return: The fitted StandardScaler. 
31 | """ 32 | X = np.array(X).astype(float) 33 | # self.means = np.nanmean(X, axis=0) 34 | self.stds = np.nanstd(X, axis=0) 35 | # self.means = np.where(np.isnan(self.means), np.zeros(self.means.shape), self.means) 36 | self.stds = np.where(np.isnan(self.stds), np.ones(self.stds.shape), self.stds) 37 | self.stds = np.where(self.stds == 0, np.ones(self.stds.shape), self.stds) 38 | return self 39 | 40 | def transform(self, X: List[List[float]]): 41 | """ 42 | Transforms the data by subtracting the means and dividing by the standard deviations. 43 | 44 | :param X: A list of lists of floats. 45 | :return: The transformed data. 46 | """ 47 | X = np.array(X).astype(float) 48 | transformed_with_nan = X / self.stds # - self.means 49 | transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan) 50 | 51 | return transformed_with_none 52 | 53 | def inverse_transform(self, X: List[List[float]]): 54 | """ 55 | Performs the inverse transformation by multiplying by the standard deviations and adding the means. 56 | 57 | :param X: A list of lists of floats. 58 | :return: The inverse transformed data. 59 | """ 60 | X = np.array(X).astype(float) 61 | transformed_with_nan = X * self.stds # + self.means 62 | transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan) 63 | 64 | return transformed_with_none 65 | 66 | def inverse_transform_variance(self, X: List[List[float]]): 67 | """ 68 | Performs the inverse transformation by the squares of the standard deviations. 69 | 70 | :param X: A list of lists of floats. 71 | :return: The inverse transformed data. 72 | """ 73 | X = np.array(X).astype(float) 74 | transformed_with_nan = X * (self.stds**2) 75 | transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan) 76 | 77 | return transformed_with_none 78 | -------------------------------------------------------------------------------- /chemprop/atom_plot/utils.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem import AllChem 2 | from rdkit import Chem 3 | 4 | def mol_with_atom_index(mol): 5 | for atom in mol.GetAtoms(): 6 | atom.SetAtomMapNum(atom.GetIdx()+1) 7 | return mol 8 | 9 | def highlight_substructure(mol, mol_unc, atomic_unc): 10 | hl_atoms = [] 11 | hl_bonds = [] 12 | avg_unc = mol_unc / mol.GetNumAtoms() 13 | print(f'mol_unc: {mol_unc:.3f}, mol.GetNumAtoms(): {mol.GetNumAtoms()}, avg_unc: {avg_unc:.3f}') 14 | 15 | for a, a_unc in zip(mol.GetAtoms(), atomic_unc): 16 | print(f'atom index: {a.GetIdx()}') 17 | if a_unc > avg_unc: 18 | hl_atoms.append(a.GetIdx()) 19 | for b in mol.GetBonds(): 20 | b1, b2 = b.GetBeginAtomIdx(), b.GetEndAtomIdx() 21 | print(f'bonds index: {b.GetIdx()}') 22 | if (b1 in hl_atoms) and (b2 in hl_atoms): 23 | hl_bonds.append(b.GetIdx()) 24 | return hl_atoms, hl_bonds 25 | 26 | 27 | def unsave_atomUnc_large(mol, atomic_unc): 28 | for a, a_unc in zip(mol.GetAtoms(), atomic_unc): 29 | if (a.GetSymbol() == 'N') and (a_unc == max(atomic_unc)): 30 | return False 31 | return True 32 | 33 | 34 | def mol_with_atom_index(mol, atomic_unc=None): 35 | 36 | for atom in mol.GetAtoms(): 37 | atom.SetAtomMapNum(atom.GetIdx()+1) 38 | if atomic_unc is not None: 39 | for atom, a_unc in zip(mol.GetAtoms(), atomic_unc): 40 | atom.SetProp('atomNote', f'{a_unc:.1f}') 41 | # atom.SetProp('_displayLabel', '') 42 | # for bond in mol.GetBonds(): 43 | # bond.SetProp('displayLabel', '') 44 | # 
bond.SetProp('displayLabelW', '') 45 | return mol 46 | 47 | 48 | def highlight_substructure(mol, mol_unc, atomic_unc): 49 | hl_atoms = [] 50 | hl_bonds = [] 51 | avg_unc = mol_unc / mol.GetNumAtoms() 52 | save = True 53 | for a, a_unc in zip(mol.GetAtoms(), atomic_unc): 54 | if a_unc > avg_unc: 55 | hl_atoms.append(a.GetIdx()) 56 | # if (a.GetSymbol() == 'O'): ## if highlight O then do not save 57 | # save = False 58 | # elif (a.GetSymbol() == 'N') and (a_unc < 1): 59 | # save = False 60 | # elif a.GetSymbol() == 'N': ## if N is not highlight then do not save 61 | # save = False 62 | # if (a_unc == max(atomic_unc)) and (a.GetSymbol() != 'N'): ## if max atomic unc is not N 63 | # save = False 64 | for b in mol.GetBonds(): 65 | b1, b2 = b.GetBeginAtomIdx(), b.GetEndAtomIdx() 66 | if (b1 in hl_atoms) and (b2 in hl_atoms): 67 | hl_bonds.append(b.GetIdx()) 68 | return hl_atoms, hl_bonds, save 69 | 70 | 71 | def titlePos(mol): 72 | min_x, min_y = 0, 0 73 | AllChem.EmbedMolecule(mol) 74 | mh_conf = mol.GetConformer() 75 | for atom in mol.GetAtoms(): 76 | pos = mh_conf.GetAtomPosition(atom.GetIdx()) 77 | min_x = pos.x if min_x > pos.x else min_x 78 | min_y = pos.y if min_y > pos.y else min_y 79 | print(f'atom.GetIdx(): {atom.GetIdx()}, {pos.x}, {pos.y}') 80 | return min_x, min_y 81 | 82 | 83 | def has_atom(smile): 84 | atomSymbol = 'N' 85 | for atom in Chem.MolFromSmiles(smile).GetAtoms(): 86 | if atom.GetSymbol() == atomSymbol: 87 | return True 88 | return False 89 | 90 | def atomsize(smile): 91 | mol = Chem.MolFromSmiles(smile) 92 | if mol.GetNumHeavyAtoms() < 9: 93 | return True 94 | else: 95 | return False -------------------------------------------------------------------------------- /utils/uncertainty_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Chu-I Yang 3 | Description: Functions for calculating the x- and y-values of uncertainty quantification evaluation plots. 4 | 5 | (1) Error-based calibration curve 6 | (2) Confidence-based calibration curve 7 | """ 8 | 9 | import scipy 10 | import pandas as pd 11 | import numpy as np 12 | from numpy.typing import NDArray 13 | from typing import List, Tuple 14 | from sklearn.metrics import mean_squared_error 15 | 16 | 17 | 18 | Array1D = NDArray[np.float64] 19 | 20 | def cal_error_based_calibration_metrics(true_arr: Array1D, 21 | pred_arr: Array1D, 22 | unc_arr: Array1D, 23 | n_bins:int = 100) -> Tuple[List, List]: 24 | """ 25 | The error-based calibration curve examines the consistency between 26 | the expected error (measured by mean squared error, MSE) and the predicted uncertainty 27 | under the assumption that the estimator is unbiased. 
28 | 29 | The error-based calibration curve is a parity plot between the root mean square error (RMSE) and the root mean uncertainty (RMU) 30 | 31 | - Parameters 32 | true_arr : 1D numpy array of true values in the dataset 33 | pred_arr : 1D numpy array of predicted values in the dataset 34 | unc_arr : 1D numpy array of the predicted uncertainty in the dataset (can be either epistemic, aleatoric or total uncertainty) 35 | n_bins : the number of bins for the dataset (default 100 bins, will calculate 100 points on the plot) 36 | 37 | - Returns 38 | rmu_bins : x-values of the error-based calibration curve 39 | rmse_bins : y-values of the error-based calibration curve 40 | """ 41 | absolute_error = np.abs(true_arr - pred_arr) 42 | 43 | sorted_matrix = np.vstack((true_arr, pred_arr, unc_arr, absolute_error)).T 44 | # sort data by uncertainty (predicted variance) 45 | sorted_matrix = sorted_matrix[np.argsort(sorted_matrix[:, 2].astype(float))] 46 | 47 | rmu_bins = [] 48 | rmse_bins = [] 49 | bin_size = len(true_arr) / n_bins 50 | 51 | for bin_i in range(n_bins): 52 | start = int(bin_i * bin_size) 53 | end = int((bin_i+1) * bin_size) 54 | rmse_bins.append(mean_squared_error(sorted_matrix[start:end, 0], sorted_matrix[start:end, 1], squared=False)) 55 | rmu_bins.append(np.sqrt(np.mean(sorted_matrix[start:end, 2]))) 56 | 57 | return rmu_bins, rmse_bins 58 | 59 | def cal_confidence_based_calibration_metrics(true_arr: Array1D, 60 | pred_arr: Array1D, 61 | unc_arr: Array1D, 62 | n_bins:int = 10) -> Tuple[List, List]: 63 | """ 64 | The confidence-based calibration curve examines the fraction of data that actually falls in each confidence level. 65 | 66 | - Parameters 67 | true_arr : 1D numpy array of true values in the dataset 68 | pred_arr : 1D numpy array of predicted values in the dataset 69 | unc_arr : 1D numpy array of the predicted uncertainty in the dataset (can be either epistemic, aleatoric or total uncertainty) 70 | n_bins : the number of bins for the dataset (default 10 bins) 71 | 72 | - Returns 73 | confidence_level : x-values of the error-based calibration curve 74 | fractions : y-values of the error-based calibration curve 75 | """ 76 | data_size = len(true_arr) 77 | confidence_level = np.linspace(0, 1, n_bins, endpoint=False) 78 | 79 | # the fraction of data that true value falls in the confidence interval 80 | fractions = [] 81 | for conf in confidence_level: 82 | count = 0 83 | for mean, var, true in zip(pred_arr, unc_arr, true_arr): 84 | lower_bound, upper_bound = scipy.stats.norm.interval(conf, loc=mean, scale=var**0.5) 85 | if lower_bound < true < upper_bound: 86 | count += 1 87 | fractions.append(count/data_size) 88 | 89 | return confidence_level, fractions 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /utils/hyperparameter_optimization.py: -------------------------------------------------------------------------------- 1 | """Optimizes hyperparameters using Bayesian optimization.""" 2 | 3 | from argparse import ArgumentParser, Namespace 4 | from copy import deepcopy 5 | import json 6 | from typing import Dict, Union 7 | import os 8 | 9 | from hyperopt import fmin, hp, tpe 10 | import numpy as np 11 | 12 | from chemprop.models import build_model 13 | from chemprop.nn_utils import param_count 14 | from chemprop.parsing import add_train_args, modify_train_args 15 | from chemprop.train import cross_validate 16 | from chemprop.utils import create_logger, makedirs 17 | 18 | SPACE = { 19 | 'depth': hp.quniform('depth', 
low=2, high=6, q=1), 20 | 'hidden_size': hp.quniform('hidden_size',low=100,high=2000,q=100) 21 | } 22 | INT_KEYS = ['depth','hidden_size'] 23 | 24 | '''SPACE = { 25 | 'hidden_size': hp.quniform('hidden_size', low=300, high=2400, q=100), 26 | 'depth': hp.quniform('depth', low=2, high=6, q=1), 27 | 'dropout': hp.quniform('dropout', low=0.0, high=0.4, q=0.05), 28 | } 29 | INT_KEYS = ['hidden_size', 'depth']''' 30 | 31 | 32 | def grid_search(args: Namespace): 33 | # Create loggers 34 | logger = create_logger(name='hyperparameter_optimization', save_dir=args.log_dir, quiet=True) 35 | train_logger = create_logger(name='train', save_dir=args.save_dir, quiet=args.quiet) 36 | 37 | # Run grid search 38 | results = [] 39 | 40 | # Define hyperparameter optimization 41 | def objective(hyperparams: Dict[str, Union[int, float]]) -> float: 42 | # Convert hyperparams from float to int when necessary 43 | for key in INT_KEYS: 44 | hyperparams[key] = int(hyperparams[key]) 45 | 46 | # Update args with hyperparams 47 | hyper_args = deepcopy(args) 48 | if args.save_dir is not None: 49 | folder_name = '_'.join(f'{key}_{value}' for key, value in hyperparams.items()) 50 | hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name) 51 | for key, value in hyperparams.items(): 52 | setattr(hyper_args, key, value) 53 | 54 | # Record hyperparameters 55 | logger.info(hyperparams) 56 | 57 | # Cross validate 58 | mean_score, std_score = cross_validate(hyper_args, train_logger) 59 | 60 | # Record results 61 | temp_model = build_model(hyper_args) 62 | num_params = param_count(temp_model) 63 | logger.info(f'num params: {num_params:,}') 64 | logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}') 65 | 66 | results.append({ 67 | 'mean_score': mean_score, 68 | 'std_score': std_score, 69 | 'hyperparams': hyperparams, 70 | 'num_params': num_params 71 | }) 72 | 73 | # Deal with nan 74 | if np.isnan(mean_score): 75 | if hyper_args.dataset_type == 'classification': 76 | mean_score = 0 77 | else: 78 | raise ValueError('Can\'t handle nan score for non-classification dataset.') 79 | 80 | return (1 if hyper_args.minimize_score else -1) * mean_score 81 | 82 | fmin(objective, SPACE, algo=tpe.suggest, max_evals=args.num_iters) 83 | 84 | # Report best result 85 | results = [result for result in results if not np.isnan(result['mean_score'])] 86 | best_result = min(results, key=lambda result: (1 if args.minimize_score else -1) * result['mean_score']) 87 | logger.info('best') 88 | logger.info(best_result['hyperparams']) 89 | logger.info(f'num params: {best_result["num_params"]:,}') 90 | logger.info(f'{best_result["mean_score"]} +/- {best_result["std_score"]} {args.metric}') 91 | 92 | # Save best hyperparameter settings as JSON config file 93 | makedirs(args.config_save_path, isfile=True) 94 | 95 | with open(args.config_save_path, 'w') as f: 96 | json.dump(best_result['hyperparams'], f, indent=4, sort_keys=True) 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = ArgumentParser() 101 | add_train_args(parser) 102 | parser.add_argument('--num_iters', type=int, default=20, 103 | help='Number of hyperparameter choices to try') 104 | parser.add_argument('--config_save_path', type=str, required=True, 105 | help='Path to .json file where best hyperparameter settings will be written') 106 | parser.add_argument('--log_dir', type=str, 107 | help='(Optional) Path to a directory where all results of the hyperparameter optimization will be written') 108 | args = parser.parse_args() 109 | modify_train_args(args) 110 | 111 | 
grid_search(args) 112 | -------------------------------------------------------------------------------- /chemprop/train/train.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import logging 3 | from typing import Callable, List, Union 4 | 5 | from tensorboardX import SummaryWriter 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim import Optimizer 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | import tqdm 11 | 12 | from chemprop.data import MoleculeDataset 13 | from chemprop.nn_utils import compute_gnorm, compute_pnorm, NoamLR 14 | 15 | 16 | def train(model: nn.Module, 17 | data: Union[MoleculeDataset, List[MoleculeDataset]], 18 | loss_func: Callable, 19 | optimizer: Optimizer, 20 | scheduler: _LRScheduler, 21 | args: Namespace, 22 | n_iter: int = 0, 23 | logger: logging.Logger = None, 24 | writer: SummaryWriter = None) -> int: 25 | """ 26 | Trains a model for an epoch. 27 | 28 | :param model: Model. 29 | :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe). 30 | :param loss_func: Loss function. 31 | :param optimizer: An Optimizer. 32 | :param scheduler: A learning rate scheduler. 33 | :param args: Arguments. 34 | :param n_iter: The number of iterations (training examples) trained on so far. 35 | :param logger: A logger for printing intermediate results. 36 | :param writer: A tensorboardX SummaryWriter. 37 | :return: The total number of iterations (training examples) trained on so far. 38 | """ 39 | debug = logger.debug if logger is not None else print 40 | # info = logger.info if logger is not None else print 41 | 42 | model.train() 43 | 44 | data.shuffle() 45 | 46 | loss_sum, iter_count = 0, 0 47 | 48 | num_iters = len(data) // args.batch_size * args.batch_size # don't use the last batch if it's small, for stability 49 | 50 | iter_size = args.batch_size 51 | 52 | for i in range(0, num_iters, iter_size): 53 | # print(f'iter: {i}') # , args.batch_size: {args.batch_size}, {len(data)} 54 | # Prepare batch 55 | if i + args.batch_size > len(data): 56 | break 57 | mol_batch = MoleculeDataset(data[i:i + args.batch_size]) 58 | smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets() 59 | batch = smiles_batch 60 | mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch]) 61 | targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch]) 62 | 63 | if next(model.parameters()).is_cuda: 64 | mask, targets = mask.cuda(), targets.cuda() 65 | 66 | class_weights = torch.ones(targets.shape) 67 | 68 | if args.cuda: 69 | class_weights = class_weights.cuda() 70 | 71 | # Run model 72 | model.zero_grad() 73 | 74 | if not args.aleatoric: 75 | preds = model(batch, features_batch) 76 | if args.dataset_type == 'multiclass': 77 | targets = targets.long() 78 | loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * class_weights * mask 79 | else: 80 | loss = loss_func(preds, targets) * class_weights * mask 81 | else: 82 | means, logvars, _, _ = model(batch, features_batch) 83 | loss = loss_func(targets, means, logvars) * class_weights * mask 84 | 85 | loss = loss.sum() / mask.sum() 86 | # print(f'iter_count: {iter_count}, loss: {loss}, batch: {batch[:3]}') 87 | 88 | 89 | if args.epistemic == 'mc_dropout': 90 | reg_loss = args.reg_acc.get_sum() 91 | loss += reg_loss 92 | 93 | loss_sum += loss.item() 94 | iter_count += 
len(mol_batch) 95 | 96 | loss.backward() 97 | optimizer.step() 98 | 99 | if isinstance(scheduler, NoamLR): # class 100 | scheduler.step() 101 | 102 | n_iter += len(mol_batch) 103 | 104 | # Log and/or add to tensorboard 105 | if (n_iter // args.batch_size) % args.log_frequency == 0: 106 | lrs = scheduler.get_lr() 107 | pnorm = compute_pnorm(model) 108 | gnorm = compute_gnorm(model) 109 | loss_avg = loss_sum / iter_count 110 | loss_sum, iter_count = 0, 0 111 | 112 | lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs)) 113 | debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}') 114 | # info(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}') 115 | 116 | if writer is not None: 117 | writer.add_scalar('train_loss', loss_avg, n_iter) 118 | writer.add_scalar('param_norm', pnorm, n_iter) 119 | writer.add_scalar('gradient_norm', gnorm, n_iter) 120 | for i, lr in enumerate(lrs): 121 | writer.add_scalar(f'learning_rate_{i}', lr, n_iter) 122 | return n_iter 123 | -------------------------------------------------------------------------------- /chemprop/train/train_multimodel.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import logging 3 | from typing import Callable, List, Union 4 | 5 | from tensorboardX import SummaryWriter 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim import Optimizer 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | import tqdm 11 | 12 | from chemprop.data import MoleculeDataset 13 | from chemprop.nn_utils import compute_gnorm, compute_pnorm, NoamLR, InverseLR 14 | 15 | 16 | def train_multimodel(models_dict: dict, 17 | data: Union[MoleculeDataset, List[MoleculeDataset]], 18 | loss_func: Callable, 19 | optimizer: Optimizer, 20 | scheduler: _LRScheduler, 21 | args: Namespace, 22 | n_iter: int = 0, 23 | logger: logging.Logger = None, 24 | writer: SummaryWriter = None) -> int: 25 | """ 26 | Trains a model for an epoch. 27 | 28 | :param model: Model. 29 | :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe). 30 | :param loss_func: Loss function. 31 | :param optimizer: An Optimizer. 32 | :param scheduler: A learning rate scheduler. 33 | :param args: Arguments. 34 | :param n_iter: The number of iterations (training examples) trained on so far. 35 | :param logger: A logger for printing intermediate results. 36 | :param writer: A tensorboardX SummaryWriter. 37 | :return: The total number of iterations (training examples) trained on so far. 
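    A minimal sketch of the expected call (illustrative only; train_data, loss_func, optimizer, scheduler, logger and writer are assumed to exist, and the dict keys must follow the 'model_{index}' pattern used in the loop below):

        from chemprop.models import build_model
        models_dict = {f'model_{i}': build_model(args) for i in range(args.ensemble_size)}
        n_iter = train_multimodel(models_dict, train_data, loss_func, optimizer, scheduler, args,
                                  n_iter=n_iter, logger=logger, writer=writer)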
38 | """ 39 | debug = logger.debug if logger is not None else print 40 | 41 | # model.train() 42 | 43 | data.shuffle() 44 | 45 | loss_sum, iter_count = 0, 0 46 | 47 | num_iters = len(data) // args.batch_size * args.batch_size # don't use the last batch if it's small, for stability 48 | 49 | iter_size = args.batch_size 50 | 51 | for i in range(0, num_iters, iter_size): 52 | # Prepare batch 53 | if i + args.batch_size > len(data): 54 | break 55 | mol_batch = MoleculeDataset(data[i:i + args.batch_size]) 56 | smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets() 57 | batch = smiles_batch 58 | mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch]) 59 | targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch]) 60 | means_multi = torch.zeros(targets.shape[0], 1, args.ensemble_size) 61 | logvars_multi = torch.zeros(targets.shape[0], 1, args.ensemble_size) 62 | if args.cuda: 63 | mask, targets, means_multi, logvars_multi = mask.cuda(), targets.cuda(), means_multi.cuda(), logvars_multi.cuda() 64 | 65 | class_weights = torch.ones(targets.shape) 66 | 67 | if args.cuda: 68 | class_weights = class_weights.cuda() 69 | 70 | # Run model 71 | for index in range(len(models_dict.keys())): 72 | model = models_dict[f'model_{index}'] 73 | model.train() 74 | # model.zero_grad() 75 | 76 | if not args.aleatoric: 77 | preds = model(batch, features_batch) 78 | if args.dataset_type == 'multiclass': 79 | targets = targets.long() 80 | loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * class_weights * mask 81 | else: 82 | loss = loss_func(preds, targets) * class_weights * mask 83 | else: 84 | means, logvars, _, _ = model(batch, features_batch) 85 | means_multi[:, :, index] = means 86 | logvars_multi[:, :, index] = logvars 87 | 88 | means = torch.mean(means_multi, 2) 89 | logvars = torch.mean(logvars_multi, 2) 90 | 91 | loss = loss_func(targets, means, logvars) * class_weights * mask 92 | 93 | loss = loss.sum() / mask.sum() 94 | 95 | if args.epistemic == 'mc_dropout': 96 | reg_loss = args.reg_acc.get_sum() 97 | loss += reg_loss 98 | 99 | loss_sum += loss.item() 100 | iter_count += len(mol_batch) 101 | 102 | loss.backward() 103 | optimizer.step() 104 | optimizer.zero_grad() 105 | 106 | if isinstance(scheduler, NoamLR): # class 107 | scheduler.step() 108 | 109 | n_iter += len(mol_batch) 110 | 111 | # Log and/or add to tensorboard 112 | if (n_iter // args.batch_size) % args.log_frequency == 0: 113 | lrs = scheduler.get_lr() 114 | pnorm = compute_pnorm(model) 115 | gnorm = compute_gnorm(model) 116 | loss_avg = loss_sum / iter_count 117 | loss_sum, iter_count = 0, 0 118 | 119 | lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs)) 120 | debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, lr = {lrs_str}') 121 | 122 | if writer is not None: 123 | writer.add_scalar('train_loss', loss_avg, n_iter) 124 | writer.add_scalar('param_norm', pnorm, n_iter) 125 | writer.add_scalar('gradient_norm', gnorm, n_iter) 126 | for i, lr in enumerate(lrs): 127 | writer.add_scalar(f'learning_rate_{i}', lr, n_iter) 128 | return n_iter 129 | -------------------------------------------------------------------------------- /chemprop/atom_plot/molecule_drawer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from matplotlib import cm 3 | from matplotlib.colors 
import LinearSegmentedColormap 4 | from rdkit import Chem, Geometry 5 | from rdkit.Chem.Draw import rdMolDraw2D 6 | from rdkit.Chem import Draw, rdDepictor 7 | 8 | """ 9 | Use rdkit to draw molecule. Higher uncertainty results in red marks. 10 | """ 11 | 12 | class MoleculeDrawer(): 13 | @staticmethod 14 | def _get_similarity_map_from_weights(mol, weights, colorMap=None, sigma=None, 15 | contourLines=10, draw2d=None, unc_type=None, **kwargs): 16 | """ 17 | Copied from Chem.Draw.SimilarityMaps GetSimilarityMapFromWeights 18 | 19 | Generates the similarity map for a molecule given the atomic weights. 20 | 21 | Parameters: 22 | mol -- the molecule of interest 23 | colorMap -- the matplotlib color map scheme, default is custom PiWG color map 24 | sigma -- the sigma for the Gaussians 25 | contourLines -- if integer number N: N contour lines are drawn 26 | if list(numbers): contour lines at these numbers are drawn 27 | alpha -- the alpha blending value for the contour lines 28 | unc_type -- if it's 'pred', do not draw color on atoms 29 | kwargs -- additional arguments for drawing 30 | """ 31 | if mol.GetNumAtoms() < 2: 32 | raise ValueError("too few atoms") 33 | 34 | if draw2d is not None: 35 | mol = rdMolDraw2D.PrepareMolForDrawing(mol, addChiralHs=False) 36 | if not mol.GetNumConformers(): 37 | rdDepictor.Compute2DCoords(mol) 38 | if sigma is None: 39 | if mol.GetNumBonds() > 0: 40 | bond = mol.GetBondWithIdx(0) 41 | idx1 = bond.GetBeginAtomIdx() 42 | idx2 = bond.GetEndAtomIdx() 43 | sigma = 0.3 * (mol.GetConformer().GetAtomPosition(idx1) - 44 | mol.GetConformer().GetAtomPosition(idx2)).Length() 45 | else: 46 | sigma = 0.3 * (mol.GetConformer().GetAtomPosition(0) - 47 | mol.GetConformer().GetAtomPosition(1)).Length() 48 | sigma = round(sigma, 2) 49 | 50 | sigmas = [sigma] * mol.GetNumAtoms() 51 | locs = [] 52 | max_atom_pos_y = 0 53 | for i in range(mol.GetNumAtoms()): 54 | p = mol.GetConformer().GetAtomPosition(i) 55 | locs.append(Geometry.Point2D(p.x, p.y)) 56 | max_atom_pos_y = p.y if p.y > max_atom_pos_y else max_atom_pos_y 57 | 58 | if unc_type == 'pred': 59 | draw2d.DrawMolecule(mol) 60 | return None, max_atom_pos_y 61 | 62 | draw2d.ClearDrawing() 63 | ps = Draw.ContourParams() 64 | ps.fillGrid = True 65 | ps.gridResolution = 0.03 66 | ps.extraGridPadding = 0.5 67 | 68 | if colorMap is not None: 69 | if cm is not None and isinstance(colorMap, type(cm.Blues)): 70 | # it's a matplotlib colormap: 71 | clrs = [tuple(x) for x in colorMap([0, 0.5, 1])] 72 | elif type(colorMap) == str: 73 | if cm is None: 74 | raise ValueError("cannot provide named colormaps unless matplotlib is installed") 75 | clrs = [tuple(x) for x in cm.get_cmap(colorMap)([0, 0.5, 1])] 76 | else: 77 | clrs = [colorMap[0], colorMap[1], colorMap[2]] 78 | ps.setColourMap(clrs) 79 | 80 | Draw.ContourAndDrawGaussians(draw2d, locs, weights, sigmas, nContours=contourLines, params=ps) 81 | draw2d.drawOptions().clearBackground = False 82 | draw2d.DrawMolecule(mol) 83 | return draw2d, max_atom_pos_y 84 | 85 | @staticmethod 86 | def draw_molecule_with_atom_notes(smiles: str, mol_note: float, atom_notes: List, unc_type: str, svg: bool=True): 87 | draw_opts = rdMolDraw2D.MolDrawOptions() 88 | draw_opts.addAtomIndices = False # We don't want to show default atom indices 89 | draw_opts.atomNoteFontSize = 16 # Set the font size for atom notes 90 | 91 | colors = [(1, 0.2, 0.2), (1, 1, 1), (1, 0.2, 0.2)] # pink 92 | cmap = LinearSegmentedColormap.from_list('self_define', colors, N=100) 93 | mol = Chem.MolFromSmiles(smiles) 94 | 95 | 
atom_notes = list(atom_notes[:mol.GetNumHeavyAtoms()]) 96 | 97 | if svg: 98 | drawer = rdMolDraw2D.MolDraw2DSVG(520, 550) # Set the size of the drawing 99 | else: 100 | drawer = rdMolDraw2D.MolDraw2DCairo(520, 550) # Specify the desired image size 101 | 102 | drawer.SetDrawOptions(draw_opts) 103 | for atom, note in zip(mol.GetAtoms(), atom_notes): 104 | atom.SetProp('atomNote', str(note)) 105 | atom.SetProp('atomLabel', atom.GetSymbol()) # forces all atoms, including carbons, to be labeled with their element symbols 106 | 107 | _, max_atom_pos_y = MoleculeDrawer._get_similarity_map_from_weights(mol, atom_notes, colorMap=cmap, contourLines=2, draw2d=drawer, alpha=3, sigma=0.25, unc_type=unc_type) #0.34 108 | max_atom_pos_y = max_atom_pos_y*1.7 109 | 110 | if unc_type == 'pred': 111 | drawer.DrawString(f'Prediction: {mol_note:.2f}', Geometry.Point2D(0, max_atom_pos_y)) 112 | 113 | elif unc_type == 'ale': 114 | drawer.DrawString(f'Aleatoric: {mol_note:.2f}', Geometry.Point2D(0, max_atom_pos_y)) 115 | elif unc_type == 'epi': 116 | drawer.DrawString(f'Epistemic: {mol_note:.2f}', Geometry.Point2D(0, max_atom_pos_y)) 117 | else: 118 | drawer.DrawString(f'Total: {mol_note:.2f}', Geometry.Point2D(0, max_atom_pos_y)) 119 | 120 | drawer.FinishDrawing() 121 | svg = drawer.GetDrawingText() 122 | return svg 123 | -------------------------------------------------------------------------------- /chemprop/train/evaluate_multimodel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, List 3 | 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from .predict import predict 8 | from chemprop.data import MoleculeDataset, StandardScaler 9 | from chemprop.utils import get_metric_func 10 | 11 | def evaluate_predictions_multimodel(preds: List[List[float]], 12 | targets: List[List[float]], 13 | ales: List[List[float]], 14 | num_tasks: int, 15 | metric_func: Callable, 16 | dataset_type: str, 17 | logger: logging.Logger = None) -> List[List[float]]: 18 | """ 19 | Evaluates predictions using a metric function and filtering out invalid targets. 20 | 21 | :param preds: A list of lists of shape (data_size, num_tasks) with model predictions. 22 | :param targets: A list of lists of shape (data_size, num_tasks) with targets. 23 | :param num_tasks: Number of tasks. 24 | :param metric_func: Metric function which takes in a list of targets and a list of predictions. 25 | :param dataset_type: Dataset type. 26 | :param logger: Logger. 27 | :return: A list with the score for each task based on `metric_func`. 
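    For orientation, a shape sketch with illustrative values (note that the function below returns a (results, rmses, maes) tuple):

        preds   = [[1.2], [0.7], [3.1]]     # data_size x num_tasks
        targets = [[1.0], [None], [2.9]]    # entries that are None are skipped
        ales    = [[0.10], [0.25], [0.15]]  # predicted aleatoric variances, same shape as preds
        # metric_func must accept (targets, preds, ales) for regression data; a trivial stand-in:
        metric = lambda t, p, a: float(np.sqrt(np.mean((np.array(t) - np.array(p)) ** 2)))
        results, rmses, maes = evaluate_predictions_multimodel(preds, targets, ales, num_tasks=1,
                                                               metric_func=metric,
                                                               dataset_type='regression')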
28 | """ 29 | info = logger.info if logger is not None else print 30 | rmse_function = get_metric_func(metric='rmse') 31 | mae_function = get_metric_func(metric='mae') 32 | 33 | 34 | if len(preds) == 0: 35 | return [float('nan')] * num_tasks 36 | 37 | # Filter out empty targets 38 | # valid_preds and valid_targets have shape (num_tasks, data_size) 39 | valid_preds = [[] for _ in range(num_tasks)] 40 | valid_targets = [[] for _ in range(num_tasks)] 41 | valid_ales = [[] for _ in range(num_tasks)] 42 | for i in range(num_tasks): 43 | for j in range(len(preds)): 44 | if targets[j][i] is not None: # Skip those without targets 45 | valid_preds[i].append(preds[j][i]) 46 | valid_targets[i].append(targets[j][i]) 47 | valid_ales[i].append(ales[j][i]) 48 | 49 | # Compute metric 50 | results = [] 51 | rmses = [] 52 | maes = [] 53 | for i in range(num_tasks): 54 | # # Skip if all targets or preds are identical, otherwise we'll crash during classification 55 | if dataset_type == 'classification': 56 | nan = False 57 | if all(target == 0 for target in valid_targets[i]) or all(target == 1 for target in valid_targets[i]): 58 | nan = True 59 | info('Warning: Found a task with targets all 0s or all 1s') 60 | if all(pred == 0 for pred in valid_preds[i]) or all(pred == 1 for pred in valid_preds[i]): 61 | nan = True 62 | info('Warning: Found a task with predictions all 0s or all 1s') 63 | 64 | if nan: 65 | results.append(float('nan')) 66 | continue 67 | 68 | if len(valid_targets[i]) == 0: 69 | continue 70 | 71 | if dataset_type == 'multiclass': 72 | results.append(metric_func(valid_targets[i], valid_preds[i], labels=list(range(len(valid_preds[i][0]))))) 73 | else: 74 | results.append(metric_func(valid_targets[i], valid_preds[i], valid_ales[i])) 75 | rmses.append(rmse_function(valid_targets[i], valid_preds[i])) 76 | maes.append(mae_function(valid_targets[i], valid_preds[i])) 77 | 78 | return results, rmses, maes 79 | 80 | 81 | def evaluate_multimodel(models_dict: dict, 82 | data: MoleculeDataset, 83 | num_tasks: int, 84 | metric_func: Callable, 85 | batch_size: int, 86 | dataset_type: str, 87 | sampling_size: int, 88 | fp_method: str, 89 | scaler: StandardScaler = None, 90 | logger: logging.Logger = None) -> List[float]: 91 | """ 92 | Evaluates an ensemble of models on a dataset. 93 | 94 | :param models_dict: A dictionary of models keyed as 'model_{i}'. 95 | :param data: A MoleculeDataset. 96 | :param num_tasks: Number of tasks. 97 | :param metric_func: Metric function which takes in a list of targets and a list of predictions. 98 | :param batch_size: Batch size. 99 | :param dataset_type: Dataset type. 100 | :param scaler: A StandardScaler object fit on the training targets. 101 | :param logger: Logger. 102 | :return: A tuple (scores, RMSEs, MAEs), each a list with one entry per task. 
103 | """ 104 | preds_multimodel = [] 105 | ales_multimodel = [] 106 | for i in range(len(models_dict.keys())): 107 | model = models_dict[f'model_{i}'] 108 | preds, ales, _, _, _ = predict( 109 | model=model, 110 | data=data, 111 | batch_size=batch_size, 112 | scaler=scaler, 113 | sampling_size=sampling_size, 114 | fp_method=fp_method 115 | ) 116 | preds_multimodel.append(preds) 117 | ales_multimodel.append(ales) 118 | 119 | preds = np.mean(np.array(preds_multimodel), 0).tolist() 120 | ales = np.mean(np.array(ales_multimodel), 0).tolist() 121 | 122 | targets = data.targets() 123 | 124 | assert len(targets) == len(preds) 125 | 126 | results, rmse, mae = evaluate_predictions_multimodel( 127 | preds=preds, 128 | targets=targets, 129 | ########## heteroscedastic loss ######### 130 | ales=ales, 131 | ########## heteroscedastic loss ######### 132 | num_tasks=num_tasks, 133 | metric_func=metric_func, 134 | dataset_type=dataset_type, 135 | logger=logger 136 | ) 137 | 138 | return results, rmse, mae 139 | -------------------------------------------------------------------------------- /chemprop/features/features_generators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Union 2 | 3 | import numpy as np 4 | from rdkit import Chem, DataStructs 5 | from rdkit.Chem import AllChem 6 | 7 | 8 | Molecule = Union[str, Chem.Mol] 9 | FeaturesGenerator = Callable[[Molecule], np.ndarray] 10 | 11 | 12 | FEATURES_GENERATOR_REGISTRY = {} 13 | 14 | 15 | def register_features_generator(features_generator_name: str) -> Callable[[FeaturesGenerator], FeaturesGenerator]: 16 | """ 17 | Registers a features generator. 18 | 19 | :param features_generator_name: The name to call the FeaturesGenerator. 20 | :return: A decorator which will add a FeaturesGenerator to the registry using the specified name. 21 | """ 22 | def decorator(features_generator: FeaturesGenerator) -> FeaturesGenerator: 23 | FEATURES_GENERATOR_REGISTRY[features_generator_name] = features_generator 24 | return features_generator 25 | 26 | return decorator 27 | 28 | 29 | def get_features_generator(features_generator_name: str) -> FeaturesGenerator: 30 | """ 31 | Gets a registered FeaturesGenerator by name. 32 | 33 | :param features_generator_name: The name of the FeaturesGenerator. 34 | :return: The desired FeaturesGenerator. 35 | """ 36 | if features_generator_name not in FEATURES_GENERATOR_REGISTRY: 37 | raise ValueError(f'Features generator "{features_generator_name}" could not be found. ' 38 | f'If this generator relies on rdkit features, you may need to install descriptastorus.') 39 | 40 | return FEATURES_GENERATOR_REGISTRY[features_generator_name] 41 | 42 | 43 | def get_available_features_generators() -> List[str]: 44 | """Returns the names of available features generators.""" 45 | return list(FEATURES_GENERATOR_REGISTRY.keys()) 46 | 47 | 48 | MORGAN_RADIUS = 2 49 | MORGAN_NUM_BITS = 2048 50 | 51 | 52 | @register_features_generator('morgan') 53 | def morgan_binary_features_generator(mol: Molecule, 54 | radius: int = MORGAN_RADIUS, 55 | num_bits: int = MORGAN_NUM_BITS) -> np.ndarray: 56 | """ 57 | Generates a binary Morgan fingerprint for a molecule. 58 | 59 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 60 | :param radius: Morgan fingerprint radius. 61 | :param num_bits: Number of bits in Morgan fingerprint. 62 | :return: A 1-D numpy array containing the binary Morgan fingerprint. 
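    Example (a small sketch; any valid SMILES string or RDKit Mol works):

        fp = morgan_binary_features_generator('CCO')   # ethanol
        fp.shape                                       # -> (2048,) with the defaults above (MORGAN_NUM_BITS)
        # equivalently, through the registry used by the --features_generator flag:
        fp = get_features_generator('morgan')('CCO')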
63 | """ 64 | mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol 65 | features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits) 66 | features = np.zeros((1,)) 67 | DataStructs.ConvertToNumpyArray(features_vec, features) 68 | 69 | return features 70 | 71 | 72 | @register_features_generator('morgan_count') 73 | def morgan_counts_features_generator(mol: Molecule, 74 | radius: int = MORGAN_RADIUS, 75 | num_bits: int = MORGAN_NUM_BITS) -> np.ndarray: 76 | """ 77 | Generates a counts-based Morgan fingerprint for a molecule. 78 | 79 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 80 | :param radius: Morgan fingerprint radius. 81 | :param num_bits: Number of bits in Morgan fingerprint. 82 | :return: A 1D numpy array containing the counts-based Morgan fingerprint. 83 | """ 84 | mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol 85 | features_vec = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits) 86 | features = np.zeros((1,)) 87 | DataStructs.ConvertToNumpyArray(features_vec, features) 88 | 89 | return features 90 | 91 | 92 | try: 93 | from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors 94 | 95 | @register_features_generator('rdkit_2d') 96 | def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray: 97 | """ 98 | Generates RDKit 2D features for a molecule. 99 | 100 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 101 | :return: A 1D numpy array containing the RDKit 2D features. 102 | """ 103 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 104 | generator = rdDescriptors.RDKit2D() 105 | features = generator.process(smiles)[1:] 106 | 107 | return features 108 | 109 | @register_features_generator('rdkit_2d_normalized') 110 | def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray: 111 | """ 112 | Generates RDKit 2D normalized features for a molecule. 113 | 114 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 115 | :return: A 1D numpy array containing the RDKit 2D normalized features. 116 | """ 117 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 118 | generator = rdNormalizedDescriptors.RDKit2DNormalized() 119 | features = generator.process(smiles)[1:] 120 | 121 | return features 122 | except ImportError: 123 | pass 124 | 125 | 126 | """ 127 | Custom features generator template. 128 | 129 | Note: The name you use to register the features generator is the name 130 | you will specify on the command line when using the --features_generator flag. 131 | Ex. python train.py ... --features_generator custom ... 
132 | 133 | @register_features_generator('custom') 134 | def custom_features_generator(mol: Molecule) -> np.ndarray: 135 | # If you want to use the SMILES string 136 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 137 | 138 | # If you want to use the RDKit molecule 139 | mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol 140 | 141 | # Replace this with code which generates features from the molecule 142 | features = np.array([0, 0, 1]) 143 | 144 | return features 145 | """ 146 | -------------------------------------------------------------------------------- /chemprop/train/predict.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import trange 6 | import numpy as np 7 | 8 | from chemprop.data import MoleculeDataset, StandardScaler 9 | 10 | 11 | def predict(model: nn.Module, 12 | data: MoleculeDataset, 13 | batch_size: int, 14 | sampling_size: int, 15 | fp_method: str, 16 | scaler: StandardScaler = None, 17 | atomic_unc: bool = False) -> Tuple[Union[List[List[float]], None], ...]: 18 | """ 19 | Makes predictions on a dataset using an ensemble of models. 20 | 21 | :param model: A model. 22 | :param data: A MoleculeDataset. 23 | :param batch_size: Batch size. 24 | :param scaler: A StandardScaler object fit on the training targets. 25 | :return: A list of lists of predictions. The outer list is examples 26 | while the inner list is tasks. 27 | """ 28 | model.eval() 29 | 30 | preds = [] 31 | ale_unc = [] 32 | epi_unc = [] 33 | 34 | atomic_pred = [] 35 | atomic_ales = [] 36 | 37 | aleatoric = model.aleatoric 38 | # if MC-Dropout 39 | mc_dropout = model.mc_dropout 40 | 41 | num_iters, iter_step = len(data), batch_size 42 | 43 | for i in range(0, num_iters, iter_step): 44 | # Prepare batch 45 | mol_batch = MoleculeDataset(data[i:i + batch_size]) 46 | smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features() 47 | 48 | # Run model 49 | batch = smiles_batch 50 | if not aleatoric and not mc_dropout: 51 | with torch.no_grad(): 52 | batch_preds = model(batch, features_batch) 53 | batch_preds = batch_preds.data.cpu().numpy() 54 | 55 | # Inverse scale if regression 56 | if scaler is not None: 57 | batch_preds = scaler.inverse_transform(batch_preds) 58 | 59 | # Collect vectors 60 | batch_preds = batch_preds.tolist() 61 | preds.extend(batch_preds) 62 | 63 | elif aleatoric and not mc_dropout: 64 | with torch.no_grad(): 65 | batch_preds, batch_var, batch_atomic_pred, batch_atomic_ales = model(batch, features_batch) 66 | if fp_method == 'molecular': 67 | batch_var = torch.exp(batch_var) # log_var in molecular fp_method 68 | batch_preds = batch_preds.data.cpu().numpy() 69 | batch_ale_unc = batch_var.data.cpu().numpy() 70 | 71 | # Inverse scale if regression 72 | if scaler is not None: 73 | batch_preds = scaler.inverse_transform(batch_preds) 74 | batch_ale_unc = scaler.inverse_transform_variance(batch_ale_unc) 75 | # Collect vectors 76 | batch_preds = batch_preds.tolist() 77 | batch_ale_unc = batch_ale_unc.tolist() 78 | preds.extend(batch_preds) 79 | ale_unc.extend(batch_ale_unc) 80 | 81 | if atomic_unc: 82 | batch_atomic_pred = batch_atomic_pred.data.cpu().numpy() 83 | batch_atomic_ales = batch_atomic_ales.data.cpu().numpy() 84 | if scaler is not None: 85 | batch_atomic_pred = scaler.inverse_transform(batch_atomic_pred) 86 | batch_atomic_ales = scaler.inverse_transform_variance(batch_atomic_ales) 87 | 
atomic_pred.extend(batch_atomic_pred) # bs x max_atom_size 88 | atomic_ales.extend(batch_atomic_ales) # bs x max_atom_size 89 | 90 | 91 | 92 | elif not aleatoric and mc_dropout: 93 | with torch.no_grad(): 94 | P_mean = [] 95 | 96 | for ss in range(sampling_size): 97 | batch_preds = model(batch, features_batch) 98 | P_mean.append(batch_preds) 99 | 100 | batch_preds = torch.mean(torch.stack(P_mean), 0) 101 | batch_epi_unc = torch.var(torch.stack(P_mean), 0) 102 | 103 | batch_preds = batch_preds.data.cpu().numpy() 104 | batch_epi_unc = batch_epi_unc.data.cpu().numpy() 105 | 106 | elif aleatoric and mc_dropout: 107 | with torch.no_grad(): 108 | P_mean = [] 109 | P_var = [] 110 | for ss in range(sampling_size): 111 | batch_preds, batch_var, batch_atomic_pred, batch_atomic_ales = model(batch, features_batch) 112 | if fp_method == 'molecular': 113 | batch_var = torch.exp(batch_var) # log_var in molecular fp_method 114 | P_mean.append(batch_preds) 115 | P_var.append(batch_var) 116 | 117 | batch_preds = torch.mean(torch.stack(P_mean), 0) 118 | batch_ale_unc = torch.mean(torch.stack(P_var), 0) 119 | batch_epi_unc = torch.var(torch.stack(P_mean), 0) 120 | 121 | batch_preds = batch_preds.data.cpu().numpy() 122 | batch_ale_unc = batch_ale_unc.data.cpu().numpy() 123 | batch_epi_unc = batch_epi_unc.data.cpu().numpy() 124 | 125 | # Inverse scale if regression 126 | if scaler is not None: 127 | batch_preds = scaler.inverse_transform(batch_preds) 128 | batch_ale_unc = scaler.inverse_transform_variance(batch_ale_unc) 129 | batch_epi_unc = scaler.inverse_transform_variance(batch_epi_unc) 130 | 131 | # Collect vectors 132 | batch_preds = batch_preds.tolist() 133 | batch_ale_unc = batch_ale_unc.tolist() 134 | batch_epi_unc = batch_epi_unc.tolist() 135 | 136 | preds.extend(batch_preds) 137 | ale_unc.extend(batch_ale_unc) 138 | epi_unc.extend(batch_epi_unc) 139 | 140 | if atomic_unc: 141 | atomic_pred = np.r_[atomic_pred] 142 | atomic_ales = np.r_[atomic_ales] 143 | print(f'predict.py | atomic_pred.shape: {atomic_pred.shape}, atomic_ales.shape: {atomic_ales.shape}') 144 | return preds, ale_unc, None, atomic_pred, atomic_ales 145 | 146 | if not aleatoric and not mc_dropout: 147 | return preds, None, None, None, None 148 | elif aleatoric and not mc_dropout: 149 | return preds, ale_unc, None, None, None 150 | elif not aleatoric and mc_dropout: 151 | return preds, None, epi_unc, None, None 152 | elif aleatoric and mc_dropout: 153 | return preds, ale_unc, epi_unc, None, None -------------------------------------------------------------------------------- /chemprop/random_forest.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from logging import Logger 3 | from pprint import pformat 4 | from typing import Callable, List, Tuple 5 | 6 | import numpy as np 7 | from sklearn.ensemble import RandomForestRegressor 8 | from sklearn.ensemble import RandomForestClassifier 9 | from tqdm import trange, tqdm 10 | 11 | from chemprop.data import MoleculeDataset 12 | from chemprop.data.utils import get_data, split_data 13 | from chemprop.features import get_features_generator 14 | from chemprop.train.evaluate import evaluate_predictions 15 | from chemprop.utils import get_metric_func 16 | 17 | 18 | def single_task_random_forest(train_data: MoleculeDataset, 19 | test_data: MoleculeDataset, 20 | metric_func: Callable, 21 | args: Namespace, 22 | logger: Logger = None) -> List[float]: 23 | scores = [] 24 | num_tasks = train_data.num_tasks() 25 | for task_num in 
trange(num_tasks): 26 | # Only get features and targets for molecules where target is not None 27 | train_features, train_targets = zip(*[(features, targets[task_num]) 28 | for features, targets in zip(train_data.features(), train_data.targets()) 29 | if targets[task_num] is not None]) 30 | test_features, test_targets = zip(*[(features, targets[task_num]) 31 | for features, targets in zip(test_data.features(), test_data.targets()) 32 | if targets[task_num] is not None]) 33 | 34 | if args.dataset_type == 'regression': 35 | model = RandomForestRegressor(n_estimators=args.num_trees, n_jobs=-1) 36 | elif args.dataset_type == 'classification': 37 | model = RandomForestClassifier(class_weight=args.class_weight, n_estimators=args.num_trees, n_jobs=-1) 38 | else: 39 | raise ValueError(f'dataset_type "{args.dataset_type}" not supported.') 40 | 41 | model.fit(train_features, train_targets) 42 | 43 | test_preds = model.predict(test_features) 44 | 45 | test_preds = [[pred] for pred in test_preds] 46 | test_targets = [[target] for target in test_targets] 47 | 48 | score = evaluate_predictions( 49 | preds=test_preds, 50 | targets=test_targets, 51 | num_tasks=1, 52 | metric_func=metric_func, 53 | dataset_type=args.dataset_type, 54 | logger=logger 55 | ) 56 | scores.append(score[0]) 57 | 58 | return scores 59 | 60 | 61 | def multi_task_random_forest(train_data: MoleculeDataset, 62 | test_data: MoleculeDataset, 63 | metric_func: Callable, 64 | args: Namespace, 65 | logger: Logger = None) -> List[float]: 66 | num_tasks = train_data.num_tasks() 67 | 68 | if args.dataset_type == 'regression': 69 | model = RandomForestRegressor(n_estimators=args.num_trees, n_jobs=-1) 70 | elif args.dataset_type == 'classification': 71 | model = RandomForestClassifier(n_estimators=args.num_trees, n_jobs=-1) 72 | else: 73 | raise ValueError(f'dataset_type "{args.dataset_type}" not supported.') 74 | 75 | train_targets = train_data.targets() 76 | if train_data.num_tasks() == 1: 77 | train_targets = [targets[0] for targets in train_targets] 78 | 79 | model.fit(train_data.features(), train_targets) 80 | 81 | test_preds = model.predict(test_data.features()) 82 | if num_tasks == 1: 83 | test_preds = [[pred] for pred in test_preds] 84 | 85 | scores = evaluate_predictions( 86 | preds=test_preds, 87 | targets=test_data.targets(), 88 | num_tasks=num_tasks, 89 | metric_func=metric_func, 90 | dataset_type=args.dataset_type, 91 | logger=logger 92 | ) 93 | 94 | return scores 95 | 96 | 97 | def run_random_forest(args: Namespace, logger: Logger = None) -> List[float]: 98 | if logger is not None: 99 | debug, info = logger.debug, logger.info 100 | else: 101 | debug = info = print 102 | 103 | debug(pformat(vars(args))) 104 | 105 | metric_func = get_metric_func(args.metric) 106 | 107 | debug('Loading data') 108 | data = get_data(path=args.data_path) 109 | 110 | debug(f'Splitting data with seed {args.seed}') 111 | # Need to have val set so that train and test sets are the same as when doing MPN 112 | train_data, _, test_data = split_data(data=data, split_type=args.split_type, seed=args.seed, args=args) 113 | 114 | debug(f'Total size = {len(data):,} | train size = {len(train_data):,} | test size = {len(test_data):,}') 115 | 116 | debug('Computing morgan fingerprints') 117 | morgan_fingerprint = get_features_generator('morgan') 118 | for dataset in [train_data, test_data]: 119 | for datapoint in tqdm(dataset, total=len(dataset)): 120 | datapoint.set_features(morgan_fingerprint(mol=datapoint.smiles, radius=args.radius, num_bits=args.num_bits)) 121 | 
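    # Note: each datapoint's features have now been set to an args.num_bits-long binary Morgan
    # fingerprint (radius args.radius), so the random-forest baseline below is fit on fixed-length
    # bit vectors rather than on the learned message-passing representation.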
122 | debug('Training') 123 | if args.single_task: 124 | scores = single_task_random_forest(train_data, test_data, metric_func, args, logger) 125 | else: 126 | scores = multi_task_random_forest(train_data, test_data, metric_func, args, logger) 127 | 128 | info(f'Test {args.metric} = {np.nanmean(scores)}') 129 | 130 | return scores 131 | 132 | 133 | def cross_validate_random_forest(args: Namespace, logger: Logger = None) -> Tuple[float, float]: 134 | info = logger.info if logger is not None else print 135 | init_seed = args.seed 136 | 137 | # Run training on different random seeds for each fold 138 | all_scores = [] 139 | for fold_num in range(args.num_folds): 140 | info(f'Fold {fold_num}') 141 | args.seed = init_seed + fold_num 142 | model_scores = run_random_forest(args, logger) 143 | all_scores.append(model_scores) 144 | all_scores = np.array(all_scores) 145 | 146 | # Report scores for each fold 147 | for fold_num, scores in enumerate(all_scores): 148 | info(f'Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}') 149 | 150 | # Report scores across folds 151 | avg_scores = np.nanmean(all_scores, axis=1) # average score for each model across tasks 152 | mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores) 153 | info(f'Overall test {args.metric} = {mean_score:.6f} +/- {std_score:.6f}') 154 | 155 | return mean_score, std_score 156 | -------------------------------------------------------------------------------- /chemprop/data/scaffold.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | import random 4 | from typing import Dict, List, Set, Tuple, Union 5 | 6 | from rdkit import Chem 7 | from rdkit.Chem.Scaffolds import MurckoScaffold 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | from .data import MoleculeDataset 12 | 13 | 14 | def generate_scaffold(mol: Union[str, Chem.Mol], include_chirality: bool = False) -> str: 15 | """ 16 | Compute the Bemis-Murcko scaffold for a SMILES string. 17 | 18 | :param mol: A smiles string or an RDKit molecule. 19 | :param include_chirality: Whether to include chirality. 20 | :return: 21 | """ 22 | mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol 23 | scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality) 24 | 25 | return scaffold 26 | 27 | 28 | def scaffold_to_smiles(mols: Union[List[str], List[Chem.Mol]], 29 | use_indices: bool = False) -> Dict[str, Union[Set[str], Set[int]]]: 30 | """ 31 | Computes scaffold for each smiles string and returns a mapping from scaffolds to sets of smiles. 32 | 33 | :param mols: A list of smiles strings or RDKit molecules. 34 | :param use_indices: Whether to map to the smiles' index in all_smiles rather than mapping 35 | to the smiles string itself. This is necessary if there are duplicate smiles. 36 | :return: A dictionary mapping each unique scaffold to all smiles (or smiles indices) which have that scaffold. 
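    Example (illustrative; both molecules reduce to the same benzene Bemis-Murcko scaffold):

        scaffold_to_smiles(['CCc1ccccc1', 'Oc1ccccc1'], use_indices=True)
        # -> defaultdict(set, {'c1ccccc1': {0, 1}})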
37 | """ 38 | scaffolds = defaultdict(set) 39 | for i, mol in enumerate(mols): 40 | scaffold = generate_scaffold(mol) 41 | if use_indices: 42 | scaffolds[scaffold].add(i) 43 | else: 44 | scaffolds[scaffold].add(mol) 45 | 46 | return scaffolds 47 | 48 | 49 | def scaffold_split(data: MoleculeDataset, 50 | sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1), 51 | balanced: bool = False, 52 | seed: int = 0, 53 | logger: logging.Logger = None) -> Tuple[MoleculeDataset, 54 | MoleculeDataset, 55 | MoleculeDataset]: 56 | """ 57 | Split a dataset by scaffold so that no molecules sharing a scaffold are in the same split. 58 | 59 | :param data: A MoleculeDataset. 60 | :param sizes: A length-3 tuple with the proportions of data in the 61 | train, validation, and test sets. 62 | :param balanced: Try to balance sizes of scaffolds in each set, rather than just putting smallest in test set. 63 | :param seed: Seed for shuffling when doing balanced splitting. 64 | :param logger: A logger. 65 | :return: A tuple containing the train, validation, and test splits of the data. 66 | """ 67 | assert sum(sizes) == 1 68 | 69 | # Split 70 | train_size, val_size, test_size = sizes[0] * len(data), sizes[1] * len(data), sizes[2] * len(data) 71 | train, val, test = [], [], [] 72 | train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0 73 | 74 | # Map from scaffold to index in the data 75 | scaffold_to_indices = scaffold_to_smiles(data.mols(), use_indices=True) 76 | 77 | if balanced: # Put stuff that's bigger than half the val/test size into train, rest just order randomly 78 | index_sets = list(scaffold_to_indices.values()) 79 | big_index_sets = [] 80 | small_index_sets = [] 81 | for index_set in index_sets: 82 | if len(index_set) > val_size / 2 or len(index_set) > test_size / 2: 83 | big_index_sets.append(index_set) 84 | else: 85 | small_index_sets.append(index_set) 86 | random.seed(seed) 87 | random.shuffle(big_index_sets) 88 | random.shuffle(small_index_sets) 89 | if logger is not None: 90 | logger.debug(f'len of big_index_sets: {len(big_index_sets)} len of small_index_sets: {len(small_index_sets)}') 91 | index_sets = big_index_sets + small_index_sets 92 | else: # Sort from largest to smallest scaffold sets 93 | index_sets = sorted(list(scaffold_to_indices.values()), 94 | key=lambda index_set: len(index_set), 95 | reverse=True) 96 | 97 | for index_set in index_sets: 98 | if len(train) + len(index_set) <= train_size: 99 | train += index_set 100 | train_scaffold_count += 1 101 | elif len(val) + len(index_set) <= val_size: 102 | val += index_set 103 | val_scaffold_count += 1 104 | else: 105 | test += index_set 106 | test_scaffold_count += 1 107 | 108 | if logger is not None: 109 | logger.debug(f'Total scaffolds = {len(scaffold_to_indices):,} | ' 110 | f'train scaffolds = {train_scaffold_count:,} | ' 111 | f'val scaffolds = {val_scaffold_count:,} | ' 112 | f'test scaffolds = {test_scaffold_count:,}') 113 | 114 | log_scaffold_stats(data, index_sets, logger=logger) 115 | 116 | # Map from indices to data 117 | train = [data[i] for i in train] 118 | val = [data[i] for i in val] 119 | test = [data[i] for i in test] 120 | 121 | return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test) 122 | 123 | 124 | def log_scaffold_stats(data: MoleculeDataset, 125 | index_sets: List[Set[int]], 126 | num_scaffolds: int = 10, 127 | num_labels: int = 20, 128 | logger: logging.Logger = None) -> List[Tuple[List[float], List[int]]]: 129 | """ 130 | Logs and returns statistics about counts and average target 
values in molecular scaffolds. 131 | 132 | :param data: A MoleculeDataset. 133 | :param index_sets: A list of sets of indices representing splits of the data. 134 | :param num_scaffolds: The number of scaffolds about which to display statistics. 135 | :param num_labels: The number of labels about which to display statistics. 136 | :param logger: A Logger. 137 | :return: A list of tuples where each tuple contains a list of average target values 138 | across the first num_labels labels and a list of the number of non-zero values for 139 | the first num_scaffolds scaffolds, sorted in decreasing order of scaffold frequency. 140 | """ 141 | # print some statistics about scaffolds 142 | target_avgs = [] 143 | counts = [] 144 | for index_set in index_sets: 145 | data_set = [data[i] for i in index_set] 146 | targets = [d.targets for d in data_set] 147 | targets = np.array(targets, dtype=np.float) 148 | target_avgs.append(np.nanmean(targets, axis=0)) 149 | counts.append(np.count_nonzero(~np.isnan(targets), axis=0)) 150 | stats = [(target_avgs[i][:num_labels], counts[i][:num_labels]) for i in range(min(num_scaffolds, len(target_avgs)))] 151 | 152 | if logger is not None: 153 | logger.debug('Label averages per scaffold, in decreasing order of scaffold frequency,' 154 | f'capped at {num_scaffolds} scaffolds and {num_labels} labels: {stats}') 155 | 156 | return stats 157 | -------------------------------------------------------------------------------- /chemprop/train/make_predictions_atomicUnc_multiMol.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from typing import List 4 | import logging 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | from pprint import pformat 9 | from argparse import Namespace 10 | 11 | from .predict import predict 12 | from chemprop.data import MoleculeDataset 13 | from chemprop.data.utils import get_data, get_data_from_smiles 14 | from chemprop.utils import load_args, load_checkpoint, load_scalers 15 | from chemprop.atom_plot.molecule_drawer import MoleculeDrawer 16 | 17 | 18 | def draw_and_save_molecule(i, smiles_i, mol_unc_i, atomic_unc_i, unc_t, args, svg=False, logger=None): 19 | smiles = smiles_i 20 | mol_unc = float(mol_unc_i) 21 | atom_uncs = [round(a, 2) for a in atomic_unc_i.astype(float)] 22 | try: 23 | pic_data = MoleculeDrawer.draw_molecule_with_atom_notes(smiles=smiles, mol_note=mol_unc, atom_notes=atom_uncs, unc_type=unc_t, svg=svg) 24 | except: 25 | if logger: 26 | logger.error(f'Cannot draw molecule {i}: {smiles_i}') 27 | else: 28 | print(f'[Error] Cannot draw molecule {i}: {smiles_i}') 29 | return False 30 | if svg: 31 | with open(os.path.join(args.unc_type_png_path, f'{i}_{unc_t}.svg'), 'w') as f: 32 | f.write(pic_data) 33 | else: 34 | with open(os.path.join(args.unc_type_png_path, f'{i}_{unc_t}.png'), 'wb') as f: 35 | f.write(pic_data) 36 | return True 37 | 38 | def make_predictions_atomicUnc_multiMol(args: Namespace, smiles: List[str] = None, logger: logging.Logger = None) -> None: 39 | """ 40 | Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on args.test_data. 41 | 42 | :param args: Arguments. 43 | :param smiles: Smiles to make predictions on. 44 | :return: None. 
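    The main side effect is a set of per-molecule images written under args.draw_mols_dir; a sketch of the layout produced by the directory logic below (i is the molecule index in the input file, filenames follow f'{i}_{unc_t}'):

        <draw_mols_dir>/
            pred/0_pred.png, 1_pred.png, ...   # per-atom predicted contributions, no uncertainty colouring
            ale/0_ale.png,  1_ale.png,  ...    # per-atom aleatoric uncertainty
            epi/0_epi.png,  1_epi.png,  ...    # per-atom epistemic uncertainty
        (.svg files are written instead of .png when args.high_resolution is set)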
45 | """ 46 | high_resolution = args.high_resolution 47 | if args.gpu is not None: 48 | torch.cuda.set_device(args.gpu) 49 | 50 | logger.info('Loading training args') 51 | scaler, features_scaler = load_scalers(args.checkpoint_paths[0]) 52 | train_args = load_args(args.checkpoint_paths[0]) 53 | 54 | # Update args with training arguments 55 | for key, value in vars(train_args).items(): 56 | if not hasattr(args, key): 57 | setattr(args, key, value) 58 | 59 | args.atomic_unc = True 60 | logger.info(pformat(vars(args))) 61 | logger.info('Loading data') 62 | if smiles is not None: 63 | test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False) 64 | else: 65 | if args.write_true_val: 66 | test_data, true_vals = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 67 | else: 68 | test_data = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 69 | 70 | logger.info('Validating SMILES') 71 | valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None] 72 | full_data = test_data 73 | test_data = MoleculeDataset([test_data[i] for i in valid_indices]) 74 | 75 | # Edge case if empty list of smiles is provided 76 | if len(test_data) == 0: 77 | return [None] * len(full_data) 78 | 79 | if args.use_compound_names: 80 | compound_names = test_data.compound_names() 81 | logger.info(f'Test size = {len(test_data):,}') 82 | 83 | # Normalize features 84 | if train_args.features_scaling: 85 | test_data.normalize_features(features_scaler) 86 | 87 | # max atom size check 88 | args.max_atom_size = 0 89 | logger.info(f'Checking testing data max HeavyAtom size') 90 | for test_mol in test_data.mols(): 91 | if test_mol.GetNumHeavyAtoms() > args.max_atom_size: 92 | args.max_atom_size = args.pred_max_atom_size = test_mol.GetNumHeavyAtoms() 93 | logger.info(f'Max heavy atom size = {args.max_atom_size}') 94 | if args.covariance_matrix_pred: 95 | args.scaler_stds = scaler.stds 96 | logger.info(f'covariance matrix pred: scaling factor = {args.scaler_stds}') 97 | # Predict with each model individually and sum predictions 98 | all_preds = np.zeros((len(test_data), len(args.checkpoint_paths))) 99 | all_ale_uncs = np.zeros((len(test_data), len(args.checkpoint_paths))) 100 | all_atomic_preds = np.zeros((len(test_data), args.max_atom_size, len(args.checkpoint_paths))) 101 | all_atomic_ales = np.zeros((len(test_data), args.max_atom_size, len(args.checkpoint_paths))) 102 | 103 | logger.info(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models') 104 | for index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths), disable=True)): 105 | # Load model 106 | model = load_checkpoint(checkpoint_path, current_args=args, cuda=args.cuda, logger=logger) 107 | model_preds, ale_uncs, _, atomic_preds, atomic_uncs = predict( 108 | model=model, 109 | data=test_data, 110 | batch_size=args.batch_size, 111 | scaler=scaler, 112 | sampling_size=args.sampling_size, 113 | fp_method=args.fp_method, 114 | atomic_unc=args.atomic_unc 115 | ) 116 | all_preds[:, index] = np.array(model_preds).squeeze() # (num_mols, 1) -> (num_mols) 117 | all_ale_uncs[:, index] = np.array(ale_uncs).squeeze() 118 | all_atomic_preds[:, :, index] = atomic_preds # num_mols x atom_preds x models 119 | all_atomic_ales[:, :, index] = atomic_uncs # num_mols x atom_preds x models 120 | 121 | # Ensemble predictions 122 | assert args.estimate_variance is not None 123 | avg_preds = 
(np.sum(all_preds, axis=1) / len(args.checkpoint_paths))[:, np.newaxis] 124 | avg_ale_uncs = (np.sum(all_ale_uncs, axis=1) / len(args.checkpoint_paths))[:, np.newaxis] 125 | avg_epi_uncs = np.var(all_preds, axis=1)[:, np.newaxis] 126 | avg_total_uncs = np.array(avg_ale_uncs) + np.array(avg_epi_uncs) 127 | avg_test_atomic_preds = np.mean(all_atomic_preds, axis=2) 128 | avg_test_atomic_ales = np.sum(all_atomic_ales, axis=2) / len(args.checkpoint_paths) 129 | avg_test_atomic_epis = np.var(all_atomic_preds, axis=2) 130 | avg_test_atomic_total = avg_test_atomic_ales + avg_test_atomic_epis # mol x max_atom_size 131 | # avg_test_atomic_max_ales = np.max(avg_test_atomic_ales, axis=1)[:, np.newaxis].tolist() # take max ale_unc of atoms in a molecule 132 | # avg_test_atomic_max_epis = np.max(avg_test_atomic_epis, axis=1)[:, np.newaxis].tolist() 133 | # avg_test_atomic_max_total = np.max(avg_test_atomic_total, axis=1)[:, np.newaxis].tolist() 134 | # del avg_test_atomic_ales, avg_test_atomic_epis, avg_test_atomic_total 135 | 136 | # Save predictions 137 | assert len(test_data) == len(avg_preds) 138 | assert len(test_data) == len(avg_ale_uncs) 139 | assert len(test_data) == len(avg_epi_uncs) 140 | test_smiles = full_data.smiles() 141 | 142 | 143 | args.png_path = os.path.join(args.draw_mols_dir) 144 | unc_type = ['epi', 'ale', 'pred'] 145 | 146 | logger.info(f'make image directory: {args.png_path}') 147 | os.makedirs(args.png_path, exist_ok=True) 148 | 149 | for unc_t in unc_type: 150 | args.unc_type_png_path = os.path.join(args.png_path, unc_t) 151 | os.makedirs(args.unc_type_png_path) if not os.path.isdir(args.unc_type_png_path) else None 152 | if unc_t == 'ale': 153 | mol_unc = avg_ale_uncs 154 | atomic_unc = avg_test_atomic_ales 155 | elif unc_t == 'epi': 156 | mol_unc = avg_epi_uncs 157 | atomic_unc = avg_test_atomic_epis 158 | elif unc_t == 'total': 159 | mol_unc = avg_total_uncs 160 | atomic_unc = avg_test_atomic_total 161 | elif unc_t == 'pred': 162 | mol_unc = avg_preds 163 | atomic_unc = avg_test_atomic_preds 164 | for i, (smiles_i, mol_unc_i, atomic_unc_i) in enumerate(zip(test_smiles, mol_unc, atomic_unc)): 165 | draw_and_save_molecule(i, smiles_i, mol_unc_i, atomic_unc_i, unc_t, args, svg=high_resolution, logger=logger) 166 | return avg_preds 167 | -------------------------------------------------------------------------------- /chemprop/train/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, List 3 | 4 | import torch.nn as nn 5 | 6 | from .predict import predict 7 | from chemprop.data import MoleculeDataset, StandardScaler 8 | from chemprop.utils import get_metric_func 9 | 10 | def evaluate_predictions(preds: List[List[float]], 11 | targets: List[List[float]], 12 | ales: List[List[float]], 13 | num_tasks: int, 14 | metric_func: Callable, 15 | dataset_type: str, 16 | logger: logging.Logger = None) -> List[List[float]]: 17 | """ 18 | Evaluates predictions using a metric function and filtering out invalid targets. 19 | 20 | :param preds: A list of lists of shape (data_size, num_tasks) with model predictions. 21 | :param targets: A list of lists of shape (data_size, num_tasks) with targets. 22 | :param num_tasks: Number of tasks. 23 | :param metric_func: Metric function which takes in a list of targets and a list of predictions. 24 | :param dataset_type: Dataset type. 25 | :param logger: Logger. 26 | :return: A list with the score for each task based on `metric_func`. 
27 | """ 28 | if ales is None: 29 | results, rmses, maes = evaluate_predictions_rmse(preds=preds, targets=targets, num_tasks=num_tasks, metric_func=metric_func, dataset_type=dataset_type, logger=logger) 30 | return results, rmses, maes 31 | 32 | 33 | info = logger.info if logger is not None else print 34 | rmse_function = get_metric_func(metric='rmse') 35 | mae_function = get_metric_func(metric='mae') 36 | 37 | 38 | if len(preds) == 0: 39 | return [float('nan')] * num_tasks 40 | 41 | # Filter out empty targets 42 | # valid_preds and valid_targets have shape (num_tasks, data_size) 43 | valid_preds = [[] for _ in range(num_tasks)] 44 | valid_targets = [[] for _ in range(num_tasks)] 45 | valid_ales = [[] for _ in range(num_tasks)] 46 | for i in range(num_tasks): 47 | for j in range(len(preds)): 48 | if targets[j][i] is not None: # Skip those without targets 49 | valid_preds[i].append(preds[j][i]) 50 | valid_targets[i].append(targets[j][i]) 51 | valid_ales[i].append(ales[j][i]) 52 | 53 | # Compute metric 54 | results = [] 55 | rmses = [] 56 | maes = [] 57 | for i in range(num_tasks): 58 | # # Skip if all targets or preds are identical, otherwise we'll crash during classification 59 | if dataset_type == 'classification': 60 | nan = False 61 | if all(target == 0 for target in valid_targets[i]) or all(target == 1 for target in valid_targets[i]): 62 | nan = True 63 | info('Warning: Found a task with targets all 0s or all 1s') 64 | if all(pred == 0 for pred in valid_preds[i]) or all(pred == 1 for pred in valid_preds[i]): 65 | nan = True 66 | info('Warning: Found a task with predictions all 0s or all 1s') 67 | 68 | if nan: 69 | results.append(float('nan')) 70 | continue 71 | 72 | if len(valid_targets[i]) == 0: 73 | continue 74 | 75 | if dataset_type == 'multiclass': 76 | results.append(metric_func(valid_targets[i], valid_preds[i], labels=list(range(len(valid_preds[i][0]))))) 77 | else: 78 | results.append(metric_func(valid_targets[i], valid_preds[i], valid_ales[i])) 79 | rmses.append(rmse_function(valid_targets[i], valid_preds[i])) 80 | maes.append(mae_function(valid_targets[i], valid_preds[i])) 81 | 82 | return results, rmses, maes 83 | 84 | def evaluate_predictions_rmse(preds: List[List[float]], 85 | targets: List[List[float]], 86 | num_tasks: int, 87 | metric_func: Callable, 88 | dataset_type: str, 89 | logger: logging.Logger = None) -> List[List[float]]: 90 | """ 91 | Evaluates predictions using a metric function and filtering out invalid targets. 92 | 93 | :param preds: A list of lists of shape (data_size, num_tasks) with model predictions. 94 | :param targets: A list of lists of shape (data_size, num_tasks) with targets. 95 | :param num_tasks: Number of tasks. 96 | :param metric_func: Metric function which takes in a list of targets and a list of predictions. 97 | :param dataset_type: Dataset type. 98 | :param logger: Logger. 99 | :return: A list with the score for each task based on `metric_func`. 
100 | """ 101 | info = logger.info if logger is not None else print 102 | rmse_function = get_metric_func(metric='rmse') 103 | mae_function = get_metric_func(metric='mae') 104 | 105 | 106 | if len(preds) == 0: 107 | return [float('nan')] * num_tasks 108 | 109 | # Filter out empty targets 110 | # valid_preds and valid_targets have shape (num_tasks, data_size) 111 | valid_preds = [[] for _ in range(num_tasks)] 112 | valid_targets = [[] for _ in range(num_tasks)] 113 | for i in range(num_tasks): 114 | for j in range(len(preds)): 115 | if targets[j][i] is not None: # Skip those without targets 116 | valid_preds[i].append(preds[j][i]) 117 | valid_targets[i].append(targets[j][i]) 118 | 119 | # Compute metric 120 | results = [] 121 | rmses = [] 122 | maes = [] 123 | for i in range(num_tasks): 124 | # # Skip if all targets or preds are identical, otherwise we'll crash during classification 125 | if dataset_type == 'classification': 126 | raise ValueError('classification not support (evaluate.py)') 127 | 128 | if len(valid_targets[i]) == 0: 129 | continue 130 | 131 | if dataset_type == 'multiclass': 132 | results.append(metric_func(valid_targets[i], valid_preds[i], labels=list(range(len(valid_preds[i][0]))))) 133 | else: 134 | results.append(metric_func(valid_targets[i], valid_preds[i])) 135 | rmses.append(rmse_function(valid_targets[i], valid_preds[i])) 136 | maes.append(mae_function(valid_targets[i], valid_preds[i])) 137 | 138 | return results, rmses, maes 139 | 140 | 141 | def evaluate(model: nn.Module, 142 | data: MoleculeDataset, 143 | num_tasks: int, 144 | metric_func: Callable, 145 | batch_size: int, 146 | dataset_type: str, 147 | sampling_size: int, 148 | fp_method: str, 149 | scaler: StandardScaler = None, 150 | logger: logging.Logger = None) -> List[float]: 151 | """ 152 | Evaluates an ensemble of models on a dataset. 153 | 154 | :param model: A model. 155 | :param data: A MoleculeDataset. 156 | :param num_tasks: Number of tasks. 157 | :param metric_func: Metric function which takes in a list of targets and a list of predictions. 158 | :param batch_size: Batch size. 159 | :param dataset_type: Dataset type. 160 | :param scaler: A StandardScaler object fit on the training targets. 161 | :param logger: Logger. 162 | :return: A list with the score for each task based on `metric_func`. 
163 | """ 164 | preds, ales, _, _, _ = predict( 165 | model=model, 166 | data=data, 167 | batch_size=batch_size, 168 | scaler=scaler, 169 | sampling_size=sampling_size, 170 | fp_method=fp_method 171 | ) 172 | 173 | targets = data.targets() 174 | 175 | if ales is None: 176 | results, rmse, mae = evaluate_predictions_rmse( 177 | preds=preds, 178 | targets=targets, 179 | num_tasks=num_tasks, 180 | metric_func=metric_func, 181 | dataset_type=dataset_type, 182 | logger=logger 183 | ) 184 | 185 | else: 186 | results, rmse, mae = evaluate_predictions( 187 | preds=preds, 188 | targets=targets, 189 | ales=ales, # heteroscedastic loss 190 | num_tasks=num_tasks, 191 | metric_func=metric_func, 192 | dataset_type=dataset_type, 193 | logger=logger 194 | ) 195 | 196 | return results, rmse, mae 197 | -------------------------------------------------------------------------------- /chemprop/train/make_predictions_atomic_unc_onemol.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import csv 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | from pprint import pformat 9 | 10 | from .predict import predict 11 | from .evaluate import evaluate_predictions 12 | from chemprop.data import MoleculeDataset 13 | from chemprop.data.utils import get_data, get_data_from_smiles 14 | from chemprop.utils import load_args, load_checkpoint, load_scalers, get_metric_func 15 | 16 | from rdkit import Chem 17 | 18 | def make_predictions_atomic_unc_onemol(args: Namespace, smiles: List[str] = None) -> List[Optional[List[float]]]: 19 | """ 20 | Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on args.test_data. 21 | 22 | :param args: Arguments. 23 | :param smiles: Smiles to make predictions on. 24 | :return: A list of lists of target predictions. 
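    In addition to returning the molecular predictions, the function writes one block of rows per molecule to `args.preds_path` (see the csv.writer calls below): a 'smiles' row listing the SMILES and its atoms, followed by 'preds', 'ale', 'epi', and 'total' rows holding the molecular value and the corresponding per-atom values.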
25 | """ 26 | if args.gpu is not None: 27 | torch.cuda.set_device(args.gpu) 28 | 29 | print('Loading training args') 30 | scaler, features_scaler = load_scalers(args.checkpoint_paths[0]) 31 | train_args = load_args(args.checkpoint_paths[0]) 32 | 33 | # Update args with training arguments 34 | for key, value in vars(train_args).items(): 35 | if not hasattr(args, key): 36 | setattr(args, key, value) 37 | 38 | args.atomic_unc = True 39 | print(pformat(vars(args))) 40 | print('Loading data') 41 | if smiles is not None: 42 | test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False) 43 | else: 44 | if args.write_true_val: 45 | test_data, true_vals = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 46 | else: 47 | test_data = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 48 | 49 | print('Validating SMILES') 50 | valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None] 51 | full_data = test_data 52 | test_data = MoleculeDataset([test_data[i] for i in valid_indices]) 53 | 54 | # Edge case if empty list of smiles is provided 55 | if len(test_data) == 0: 56 | return [None] * len(full_data) 57 | 58 | if args.use_compound_names: 59 | compound_names = test_data.compound_names() 60 | print(f'Test size = {len(test_data):,}') 61 | 62 | # Normalize features 63 | if train_args.features_scaling: 64 | test_data.normalize_features(features_scaler) 65 | # max atom size check 66 | args.max_atom_size = args.pred_max_atom_size = test_data.mols()[0].GetNumHeavyAtoms() 67 | print(f'args.max_atom_size = {args.max_atom_size}') 68 | if args.covariance_matrix_pred: 69 | args.scaler_stds = scaler.stds 70 | print(f'covariance matrix pred: scaling factor = {args.scaler_stds}') 71 | # Predict with each model individually and sum predictions 72 | all_preds = np.zeros((len(test_data), len(args.checkpoint_paths))) 73 | all_ale_uncs = np.zeros((len(test_data), len(args.checkpoint_paths))) 74 | all_atomic_preds = np.zeros((len(test_data), args.max_atom_size, len(args.checkpoint_paths))) 75 | all_atomic_ales = np.zeros((len(test_data), args.max_atom_size, len(args.checkpoint_paths))) 76 | 77 | print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models') 78 | for index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths), disable=True)): 79 | # Load model 80 | model = load_checkpoint(checkpoint_path, current_args=args, cuda=args.cuda) 81 | model_preds, ale_uncs, _, atomic_preds, atomic_uncs = predict( 82 | model=model, 83 | data=test_data, 84 | batch_size=args.batch_size, 85 | scaler=scaler, 86 | sampling_size=args.sampling_size, 87 | fp_method=args.fp_method, 88 | atomic_unc=args.atomic_unc 89 | ) 90 | all_preds[:, index] = np.array(model_preds).squeeze() # (num_mols, 1) -> (num_mols) 91 | all_ale_uncs[:, index] = np.array(ale_uncs).squeeze() 92 | all_atomic_preds[:, :, index] = atomic_preds # num_mols x atom_preds x models 93 | all_atomic_ales[:, :, index] = atomic_uncs 94 | 95 | # Ensemble predictions 96 | assert args.estimate_variance is not None 97 | avg_preds = (np.sum(all_preds, axis=1) / len(args.checkpoint_paths))[:, np.newaxis].tolist() 98 | avg_ale_uncs = (np.sum(all_ale_uncs, axis=1) / len(args.checkpoint_paths))[:, np.newaxis].tolist() 99 | avg_epi_uncs = np.var(all_preds, axis=1)[:, np.newaxis].tolist() 100 | avg_test_atomic_preds = np.mean(all_atomic_preds, axis=2) 101 | avg_test_atomic_ales = 
np.sum(all_atomic_ales, axis=2) / len(args.checkpoint_paths) 102 | avg_test_atomic_epis = np.var(all_atomic_preds, axis=2) 103 | avg_test_atomic_total = avg_test_atomic_ales + avg_test_atomic_epis # mol x max_atom_size 104 | # avg_test_atomic_max_ales = np.max(avg_test_atomic_ales, axis=1)[:, np.newaxis].tolist() # take max ale_unc of atoms in a molecule 105 | # avg_test_atomic_max_epis = np.max(avg_test_atomic_epis, axis=1)[:, np.newaxis].tolist() 106 | # avg_test_atomic_max_total = np.max(avg_test_atomic_total, axis=1)[:, np.newaxis].tolist() 107 | # del avg_test_atomic_ales, avg_test_atomic_epis, avg_test_atomic_total 108 | 109 | # Save predictions 110 | assert len(test_data) == len(avg_preds) 111 | assert len(test_data) == len(avg_ale_uncs) 112 | assert len(test_data) == len(avg_epi_uncs) 113 | 114 | print(f'Saving predictions to {args.preds_path}') 115 | 116 | # Put Nones for invalid smiles 117 | full_preds = [None] * len(full_data) 118 | full_ale_uncs = [None] * len(full_data) 119 | full_epi_uncs = [None] * len(full_data) 120 | full_atomic_max_ale = [None] * len(full_data) 121 | full_atomic_max_epi = [None] * len(full_data) 122 | full_atomic_max_total = [None] * len(full_data) 123 | 124 | for i, si in enumerate(valid_indices): 125 | full_preds[si] = avg_preds[i] 126 | full_ale_uncs[si] = avg_ale_uncs[i] 127 | full_epi_uncs[si] = avg_epi_uncs[i] 128 | # full_atomic_max_ale[si] = avg_test_atomic_max_ales[i] 129 | # full_atomic_max_epi[si] = avg_test_atomic_max_epis[i] 130 | # full_atomic_max_total[si] = avg_test_atomic_max_total[i] 131 | 132 | avg_preds = full_preds 133 | avg_ale_uncs = full_ale_uncs 134 | avg_epi_uncs = full_epi_uncs 135 | avg_total_uncs = np.array(avg_ale_uncs) + np.array(avg_epi_uncs) 136 | avg_atomic_max_ale = full_atomic_max_ale 137 | avg_atomic_max_epi = full_atomic_max_epi 138 | avg_atomic_max_total = full_atomic_max_total 139 | 140 | test_smiles = full_data.smiles() 141 | print(f'writing predictions number: {len(avg_preds)}') 142 | # Write predictions 143 | print(f'writing file in {args.preds_path}') 144 | with open(args.preds_path, 'w') as f: 145 | writer = csv.writer(f) 146 | 147 | for i in range(len(avg_preds)): 148 | smiles_row = [] 149 | mol = Chem.MolFromSmiles(test_smiles[i]) 150 | count = 0 151 | for atom in mol.GetAtoms(): 152 | print(atom.GetSmarts()) 153 | smiles_row.append(atom.GetSmarts()) 154 | count += 1 155 | print(f'atom number: {count}') 156 | print(smiles_row) 157 | writer.writerow(['smiles', test_smiles[i]]+ smiles_row) 158 | writer.writerow(['preds', avg_preds[i][0]] + list(avg_test_atomic_preds[i])) 159 | writer.writerow(['ale', avg_ale_uncs[i][0]] + list(avg_test_atomic_ales[i])) 160 | writer.writerow(['epi', avg_epi_uncs[i][0]] + list(avg_test_atomic_epis[i])) 161 | writer.writerow(['total', avg_total_uncs[i][0]] + list(avg_test_atomic_total[i])) 162 | 163 | print([test_smiles[i]] + list(test_smiles[i])) 164 | print([avg_preds[i]] + list(avg_test_atomic_preds[i])) 165 | print([avg_ale_uncs[i]] + list(avg_test_atomic_ales[i])) 166 | print([avg_epi_uncs[i]] + list(avg_test_atomic_epis[i])) 167 | print([avg_total_uncs[i]] + list(avg_test_atomic_total[i])) 168 | print(np.sum(avg_test_atomic_preds[i])) 169 | 170 | 171 | return avg_preds 172 | -------------------------------------------------------------------------------- /chemprop/train/make_predictions_atomic_unc.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import csv 3 | from typing import List, Optional 4 | 5 | 
import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from .predict import predict 10 | from .evaluate import evaluate_predictions 11 | from chemprop.data import MoleculeDataset 12 | from chemprop.data.utils import get_data, get_data_from_smiles 13 | from chemprop.utils import load_args, load_checkpoint, load_scalers, get_metric_func 14 | 15 | 16 | def make_predictions_atomic_unc(args: Namespace, smiles: List[str] = None) -> List[Optional[List[float]]]: 17 | """ 18 | Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on args.test_data. 19 | 20 | :param args: Arguments. 21 | :param smiles: Smiles to make predictions on. 22 | :return: A list of lists of target predictions. 23 | """ 24 | if args.gpu is not None: 25 | torch.cuda.set_device(args.gpu) 26 | 27 | print('Loading training args') 28 | scaler, features_scaler = load_scalers(args.checkpoint_paths[0]) 29 | train_args = load_args(args.checkpoint_paths[0]) 30 | 31 | # Update args with training arguments 32 | for key, value in vars(train_args).items(): 33 | if not hasattr(args, key): 34 | setattr(args, key, value) 35 | 36 | print('Loading data') 37 | if smiles is not None: 38 | test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False) 39 | else: 40 | if args.write_true_val: 41 | test_data, true_vals = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 42 | else: 43 | test_data = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 44 | 45 | print('Validating SMILES') 46 | valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None] 47 | full_data = test_data 48 | test_data = MoleculeDataset([test_data[i] for i in valid_indices]) 49 | 50 | # Edge case if empty list of smiles is provided 51 | if len(test_data) == 0: 52 | return [None] * len(full_data) 53 | 54 | if args.use_compound_names: 55 | compound_names = test_data.compound_names() 56 | print(f'Test size = {len(test_data):,}') 57 | 58 | # Normalize features 59 | if train_args.features_scaling: 60 | test_data.normalize_features(features_scaler) 61 | # max atom size check 62 | args.max_atom_size = args.pred_max_atom_size = test_data.mols()[0].GetNumHeavyAtoms() 63 | print(f'args.max_atom_size = {args.max_atom_size}') 64 | # Predict with each model individually and sum predictions 65 | all_preds = np.zeros((len(test_data), len(args.checkpoint_paths))) 66 | all_ale_uncs = np.zeros((len(test_data), len(args.checkpoint_paths))) 67 | all_atomic_preds = np.zeros((len(test_data), args.max_atom_size, len(args.checkpoint_paths))) 68 | all_atomic_ales = np.zeros((len(test_data), args.max_atom_size, len(args.checkpoint_paths))) 69 | 70 | print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models') 71 | for index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths), disable=True)): 72 | # Load model 73 | model = load_checkpoint(checkpoint_path, current_args=args, cuda=args.cuda) 74 | model_preds, ale_uncs, _, atomic_preds, atomic_uncs = predict( 75 | model=model, 76 | data=test_data, 77 | batch_size=args.batch_size, 78 | scaler=scaler, 79 | sampling_size=args.sampling_size, 80 | fp_method=args.fp_method, 81 | atomic_unc=True 82 | ) 83 | all_preds[:, index] = np.array(model_preds).squeeze() # (num_mols, 1) -> (num_mols) 84 | all_ale_uncs[:, index] = np.array(ale_uncs).squeeze() 85 | all_atomic_preds[:, :, index] = atomic_preds # 
num_mols x atom_preds x models 86 | all_atomic_ales[:, :, index] = atomic_uncs 87 | 88 | # Ensemble predictions 89 | assert args.estimate_variance is not None 90 | avg_preds = (np.sum(all_preds, axis=1) / len(args.checkpoint_paths))[:, np.newaxis].tolist() 91 | avg_ale_uncs = (np.sum(all_ale_uncs, axis=1) / len(args.checkpoint_paths))[:, np.newaxis].tolist() 92 | avg_epi_uncs = np.var(all_preds, axis=1)[:, np.newaxis].tolist() 93 | avg_test_atomic_ales = np.sum(all_atomic_ales, axis=2) / len(args.checkpoint_paths) 94 | avg_test_atomic_epis = np.var(all_atomic_preds, axis=2) 95 | avg_test_atomic_total = avg_test_atomic_ales + avg_test_atomic_epis # mol x max_atom_size 96 | avg_test_atomic_max_ales = np.max(avg_test_atomic_ales, axis=1)[:, np.newaxis].tolist() # take max ale_unc of atoms in a molecule 97 | avg_test_atomic_max_epis = np.max(avg_test_atomic_epis, axis=1)[:, np.newaxis].tolist() 98 | avg_test_atomic_max_total = np.max(avg_test_atomic_total, axis=1)[:, np.newaxis].tolist() 99 | del avg_test_atomic_ales, avg_test_atomic_epis, avg_test_atomic_total 100 | # Save predictions 101 | assert len(test_data) == len(avg_preds) 102 | assert len(test_data) == len(avg_ale_uncs) 103 | assert len(test_data) == len(avg_epi_uncs) 104 | 105 | print(f'Saving predictions to {args.preds_path}') 106 | 107 | # Put Nones for invalid smiles 108 | full_preds = [None] * len(full_data) 109 | full_ale_uncs = [None] * len(full_data) 110 | full_epi_uncs = [None] * len(full_data) 111 | full_atomic_max_ale = [None] * len(full_data) 112 | full_atomic_max_epi = [None] * len(full_data) 113 | full_atomic_max_total = [None] * len(full_data) 114 | 115 | for i, si in enumerate(valid_indices): 116 | full_preds[si] = avg_preds[i] 117 | full_ale_uncs[si] = avg_ale_uncs[i] 118 | full_epi_uncs[si] = avg_epi_uncs[i] 119 | full_atomic_max_ale[si] = avg_test_atomic_max_ales[i] 120 | full_atomic_max_epi[si] = avg_test_atomic_max_epis[i] 121 | full_atomic_max_total[si] = avg_test_atomic_max_total[i] 122 | 123 | avg_preds = full_preds 124 | avg_ale_uncs = full_ale_uncs 125 | avg_epi_uncs = full_epi_uncs 126 | avg_total_uncs = np.array(avg_ale_uncs) + np.array(avg_epi_uncs) 127 | avg_atomic_max_ale = full_atomic_max_ale 128 | avg_atomic_max_epi = full_atomic_max_epi 129 | avg_atomic_max_total = full_atomic_max_total 130 | 131 | test_smiles = full_data.smiles() 132 | ### For mixed model ### 133 | 134 | # Write predictions 135 | with open(args.preds_path, 'w') as f: 136 | writer = csv.writer(f) 137 | 138 | header = [] 139 | 140 | if args.use_compound_names: 141 | header.append('compound_names') 142 | 143 | header.append('smiles') 144 | 145 | if args.dataset_type == 'multiclass': 146 | for name in args.task_names: 147 | for i in range(args.multiclass_num_classes): 148 | header.append(name + '_class' + str(i)) 149 | else: 150 | if args.write_true_val: 151 | header.append('true_'+args.task_names[0]) 152 | header.append('preds_'+args.task_names[0]) 153 | header.extend([tn + "_ale_unc" for tn in args.task_names]) 154 | header.extend([tn + "_epi_unc" for tn in args.task_names]) 155 | header.extend([tn + "_total_unc" for tn in args.task_names]) 156 | header.extend([tn + "_max_atom_ale_unc" for tn in args.task_names]) 157 | header.extend([tn + "_max_atom_epi_unc" for tn in args.task_names]) 158 | header.extend([tn + "_max_atom_total_unc" for tn in args.task_names]) 159 | 160 | writer.writerow(header) 161 | 162 | for i in range(len(avg_preds)): 163 | row = [] 164 | 165 | if args.use_compound_names: 166 | row.append(compound_names[i]) 167 | 
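            # Row layout assembled below: [compound_name], smiles, [true value], prediction(s),
            # ale_unc, epi_unc, total_unc, and the per-molecule maxima of the atomic ale/epi/total
            # uncertainties (the bracketed columns appear only when the corresponding args are set).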
168 | row.append(test_smiles[i]) 169 | 170 | if args.write_true_val: 171 | row.append(true_vals[i]) 172 | 173 | if avg_preds[i] is not None: 174 | if args.dataset_type == 'multiclass': 175 | for task_probs in avg_preds[i]: 176 | row.extend(task_probs) 177 | else: 178 | row.extend(avg_preds[i]) 179 | row.extend(avg_ale_uncs[i]) 180 | row.extend(avg_epi_uncs[i]) 181 | row.extend(avg_total_uncs[i]) 182 | row.extend(avg_atomic_max_ale[i]) 183 | row.extend(avg_atomic_max_epi[i]) 184 | row.extend(avg_atomic_max_total[i]) 185 | 186 | else: 187 | if args.dataset_type == 'multiclass': 188 | row.extend([''] * args.num_tasks * args.multiclass_num_classes) 189 | else: 190 | # Both the prediction, the aleatoric uncertainty and the epistemic uncertainty are None 191 | row.extend([''] * 3 * args.num_tasks) 192 | 193 | writer.writerow(row) 194 | 195 | return avg_preds 196 | -------------------------------------------------------------------------------- /chemprop/train/make_predictions.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import csv 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | from pprint import pformat 9 | 10 | from .predict import predict 11 | from .evaluate import evaluate_predictions 12 | from chemprop.data import MoleculeDataset 13 | from chemprop.data.utils import get_data, get_data_from_smiles 14 | from chemprop.utils import load_args, load_checkpoint, load_scalers, get_metric_func 15 | 16 | 17 | def make_predictions(args: Namespace, smiles: List[str] = None) -> List[Optional[List[float]]]: 18 | """ 19 | Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on args.test_data. 20 | 21 | :param args: Arguments. 22 | :param smiles: Smiles to make predictions on. 23 | :return: A list of lists of target predictions. 
24 | """ 25 | if args.gpu is not None: 26 | torch.cuda.set_device(args.gpu) 27 | 28 | print('Loading training args') 29 | scaler, features_scaler = load_scalers(args.checkpoint_paths[0]) 30 | for i in range(len(args.checkpoint_paths)): 31 | scaler_i, feature_scaler_i = load_scalers(args.checkpoint_paths[i]) 32 | train_args = load_args(args.checkpoint_paths[0]) 33 | 34 | # Update args with training arguments 35 | for key, value in vars(train_args).items(): 36 | if not hasattr(args, key): 37 | setattr(args, key, value) 38 | 39 | print('Loading data') 40 | if smiles is not None: 41 | test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False) 42 | else: 43 | if args.write_true_val: 44 | test_data, true_vals = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 45 | else: 46 | test_data = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False) 47 | 48 | print('Validating SMILES') 49 | valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None] 50 | full_data = test_data 51 | test_data = MoleculeDataset([test_data[i] for i in valid_indices]) 52 | 53 | # Edge case if empty list of smiles is provided 54 | if len(test_data) == 0: 55 | return [None] * len(full_data) 56 | 57 | if args.use_compound_names: 58 | compound_names = test_data.compound_names() 59 | print(f'Test size = {len(test_data):,}') 60 | 61 | # Normalize features 62 | if train_args.features_scaling: 63 | test_data.normalize_features(features_scaler) 64 | # max atom size check 65 | if hasattr(args, 'pred_max_atom_size'): 66 | print(f'predict max heavy atom size: {args.pred_max_atom_size}') 67 | args.max_atom_size = args.pred_max_atom_size 68 | else: 69 | args.pred_max_atom_size = args.max_atom_size 70 | print(f'args.pred_max_atom_size is {args.pred_max_atom_size}') 71 | # Predict with each model individually and sum predictions 72 | if args.dataset_type == 'multiclass': 73 | sum_preds = np.zeros((len(test_data), args.num_tasks, args.multiclass_num_classes)) 74 | sum_ale_uncs = np.zeros((len(test_data), args.num_tasks, args.multiclass_num_classes)) 75 | sum_epi_uncs = np.zeros((len(test_data), args.num_tasks, args.multiclass_num_classes)) 76 | else: 77 | sum_preds = np.zeros((len(test_data), args.num_tasks)) 78 | sum_ale_uncs = np.zeros((len(test_data), args.num_tasks)) 79 | sum_epi_uncs = np.zeros((len(test_data), args.num_tasks)) 80 | 81 | # Partial results for variance robust calculation. 
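    # all_preds stores each ensemble member's prediction so that, when args.estimate_variance is set,
    # the epistemic uncertainty can later be taken as the variance across members (np.var over the
    # checkpoint axis), while the aleatoric uncertainty is the mean of the aleatoric estimates
    # predicted by the individual models.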
82 | all_preds = np.zeros((len(test_data), args.num_tasks, len(args.checkpoint_paths))) 83 | 84 | print(pformat(vars(args))) 85 | 86 | print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models') 87 | for index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths), disable=True)): 88 | # Load model 89 | model = load_checkpoint(checkpoint_path, current_args=args, cuda=args.cuda) 90 | model_preds, ale_uncs, epi_uncs, _, _ = predict( 91 | model=model, 92 | data=test_data, 93 | batch_size=args.batch_size, 94 | scaler=scaler, 95 | sampling_size=args.sampling_size, 96 | fp_method=args.fp_method 97 | ) 98 | sum_preds += np.array(model_preds) 99 | 100 | ### For mixed model ### 101 | if ale_uncs is not None: 102 | sum_ale_uncs += np.array(ale_uncs) 103 | if epi_uncs is not None: 104 | sum_epi_uncs += np.array(epi_uncs) 105 | if args.estimate_variance: 106 | all_preds[:, :, index] = model_preds 107 | ### For mixed model ### 108 | 109 | # Ensemble predictions 110 | ### For mixed model ### 111 | if args.estimate_variance: # not mc_dropout 112 | # Use ensemble variance to estimate uncertainty. This overwrites existing uncertainty estimates. 113 | # preds <- mean(preds), ale_uncs <- mean(ale_uncs), epi_uncs <- var(preds) 114 | avg_preds = sum_preds / len(args.checkpoint_paths) 115 | avg_preds = avg_preds.tolist() 116 | 117 | avg_ale_uncs = sum_ale_uncs / len(args.checkpoint_paths) 118 | avg_ale_uncs = avg_ale_uncs.tolist() 119 | 120 | avg_epi_uncs = np.var(all_preds, axis=2) 121 | avg_epi_uncs = avg_epi_uncs.tolist() 122 | else: # mc_dropout 123 | # Use another method to estimate uncertainty. 124 | # preds <- mean(preds), ale_uncs <- mean(ale_uncs), epi_uncs <- mean(epi_uncs) 125 | avg_preds = sum_preds / len(args.checkpoint_paths) 126 | avg_preds = avg_preds.tolist() 127 | 128 | avg_ale_uncs = sum_ale_uncs / len(args.checkpoint_paths) 129 | avg_ale_uncs = avg_ale_uncs.tolist() 130 | 131 | avg_epi_uncs = sum_epi_uncs / len(args.checkpoint_paths) 132 | avg_epi_uncs = avg_epi_uncs.tolist() 133 | ### For mixed model ### 134 | 135 | # Save predictions 136 | assert len(test_data) == len(avg_preds) 137 | assert len(test_data) == len(avg_ale_uncs) 138 | assert len(test_data) == len(avg_epi_uncs) 139 | 140 | print(f'Saving predictions to {args.preds_path}') 141 | 142 | # Put Nones for invalid smiles 143 | full_preds = [None] * len(full_data) 144 | full_ale_uncs = [None] * len(full_data) 145 | full_epi_uncs = [None] * len(full_data) 146 | 147 | for i, si in enumerate(valid_indices): 148 | full_preds[si] = avg_preds[i] 149 | full_ale_uncs[si] = avg_ale_uncs[i] 150 | full_epi_uncs[si] = avg_epi_uncs[i] 151 | 152 | avg_preds = full_preds 153 | avg_ale_uncs = full_ale_uncs 154 | avg_epi_uncs = full_epi_uncs 155 | avg_total_uncs = np.array(avg_ale_uncs) + np.array(avg_epi_uncs) 156 | 157 | test_smiles = full_data.smiles() 158 | ### For mixed model ### 159 | 160 | # Write predictions 161 | with open(args.preds_path, 'w') as f: 162 | writer = csv.writer(f) 163 | 164 | header = [] 165 | 166 | if args.use_compound_names: 167 | header.append('compound_names') 168 | 169 | header.append('smiles') 170 | 171 | if args.dataset_type == 'multiclass': 172 | for name in args.task_names: 173 | for i in range(args.multiclass_num_classes): 174 | header.append(name + '_class' + str(i)) 175 | else: 176 | if args.write_true_val: 177 | header.append('true_'+args.task_names[0]) 178 | header.append('pred_'+args.task_names[0]) 179 | header.extend([tn + "_ale_unc" for tn in 
args.task_names]) 180 | header.extend([tn + "_epi_unc" for tn in args.task_names]) 181 | header.extend([tn + "_total_unc" for tn in args.task_names]) 182 | 183 | writer.writerow(header) 184 | 185 | for i in range(len(avg_preds)): 186 | row = [] 187 | 188 | if args.use_compound_names: 189 | row.append(compound_names[i]) 190 | 191 | row.append(test_smiles[i]) 192 | 193 | if args.write_true_val: 194 | row.append(true_vals[i]) 195 | 196 | if avg_preds[i] is not None: 197 | if args.dataset_type == 'multiclass': 198 | for task_probs in avg_preds[i]: 199 | row.extend(task_probs) 200 | else: 201 | row.extend(avg_preds[i]) 202 | row.extend(avg_ale_uncs[i]) 203 | row.extend(avg_epi_uncs[i]) 204 | row.extend(avg_total_uncs[i]) 205 | 206 | else: 207 | if args.dataset_type == 'multiclass': 208 | row.extend([''] * args.num_tasks * args.multiclass_num_classes) 209 | else: 210 | # Both the prediction, the aleatoric uncertainty and the epistemic uncertainty are None 211 | row.extend([''] * 3 * args.num_tasks) 212 | 213 | writer.writerow(row) 214 | 215 | return avg_preds 216 | -------------------------------------------------------------------------------- /chemprop/data/data.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import random 3 | from typing import Callable, List, Union 4 | 5 | import numpy as np 6 | from torch.utils.data.dataset import Dataset 7 | from rdkit import Chem 8 | 9 | from .scaler import StandardScaler 10 | from chemprop.features import get_features_generator 11 | 12 | 13 | class MoleculeDatapoint: 14 | """A MoleculeDatapoint contains a single molecule and its associated features and targets.""" 15 | 16 | def __init__(self, 17 | line: List[str], 18 | args: Namespace = None, 19 | features: np.ndarray = None, 20 | use_compound_names: bool = False): 21 | """ 22 | Initializes a MoleculeDatapoint, which contains a single molecule. 23 | 24 | :param line: A list of strings generated by separating a line in a data CSV file by comma. 25 | :param args: Arguments. 26 | :param features: A numpy array containing additional features (ex. Morgan fingerprint). 27 | :param use_compound_names: Whether the data CSV includes the compound name on each line. 
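        For illustration, a minimal datapoint with one regression target could be built as
        `MoleculeDatapoint(line=['CCO', '-0.77'])`, where 'CCO' is the SMILES string and '-0.77'
        is parsed into the target list.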
28 | """ 29 | if args is not None: 30 | self.features_generator = args.features_generator 31 | self.args = args 32 | else: 33 | self.features_generator = self.args = None 34 | 35 | if features is not None and self.features_generator is not None: 36 | raise ValueError('Currently cannot provide both loaded features and a features generator.') 37 | 38 | self.features = features 39 | 40 | if use_compound_names: 41 | self.compound_name = line[0] # str 42 | line = line[1:] 43 | else: 44 | self.compound_name = None 45 | 46 | if line[0].startswith('InChI'): 47 | self.smiles = line[0] # inchi 48 | self.mol = Chem.MolFromInchi(self.smiles) 49 | else: 50 | self.smiles = line[0] # str 51 | self.mol = Chem.MolFromSmiles(self.smiles) 52 | 53 | 54 | # Generate additional features if given a generator 55 | if self.features_generator is not None: 56 | self.features = [] 57 | 58 | for fg in self.features_generator: 59 | features_generator = get_features_generator(fg) 60 | if self.mol is not None and self.mol.GetNumHeavyAtoms() > 0: 61 | self.features.extend(features_generator(self.mol)) 62 | 63 | self.features = np.array(self.features) 64 | 65 | # Fix nans in features 66 | if self.features is not None: 67 | replace_token = 0 68 | self.features = np.where(np.isnan(self.features), replace_token, self.features) 69 | 70 | # Create targets 71 | self.targets = [float(x) if x != '' else None for x in line[1:]] 72 | 73 | def set_features(self, features: np.ndarray): 74 | """ 75 | Sets the features of the molecule. 76 | 77 | :param features: A 1-D numpy array of features for the molecule. 78 | """ 79 | self.features = features 80 | 81 | def num_tasks(self) -> int: 82 | """ 83 | Returns the number of prediction tasks. 84 | 85 | :return: The number of tasks. 86 | """ 87 | return len(self.targets) 88 | 89 | def set_targets(self, targets: List[float]): 90 | """ 91 | Sets the targets of a molecule. 92 | 93 | :param targets: A list of floats containing the targets. 94 | """ 95 | self.targets = targets 96 | 97 | 98 | class MoleculeDataset(Dataset): 99 | """A MoleculeDataset contains a list of molecules and their associated features and targets.""" 100 | 101 | def __init__(self, data: List[MoleculeDatapoint]): 102 | """ 103 | Initializes a MoleculeDataset, which contains a list of MoleculeDatapoints (i.e. a list of molecules). 104 | 105 | :param data: A list of MoleculeDatapoints. 106 | """ 107 | self.data = data 108 | self.args = self.data[0].args if len(self.data) > 0 else None 109 | self.scaler = None 110 | 111 | def compound_names(self) -> List[str]: 112 | """ 113 | Returns the compound names associated with the molecule (if they exist). 114 | 115 | :return: A list of compound names or None if the dataset does not contain compound names. 116 | """ 117 | if len(self.data) == 0 or self.data[0].compound_name is None: 118 | return None 119 | 120 | return [d.compound_name for d in self.data] 121 | 122 | def smiles(self) -> List[str]: 123 | """ 124 | Returns the smiles strings associated with the molecules. 125 | 126 | :return: A list of smiles strings. 127 | """ 128 | return [d.smiles for d in self.data] 129 | 130 | def mols(self) -> List[Chem.Mol]: 131 | """ 132 | Returns the RDKit molecules associated with the molecules. 133 | 134 | :return: A list of RDKit Mols. 135 | """ 136 | return [d.mol for d in self.data] 137 | 138 | def features(self) -> List[np.ndarray]: 139 | """ 140 | Returns the features associated with each molecule (if they exist). 
141 | 142 | :return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features. 143 | """ 144 | if len(self.data) == 0 or self.data[0].features is None: 145 | return None 146 | 147 | return [d.features for d in self.data] 148 | 149 | def targets(self) -> List[List[float]]: 150 | """ 151 | Returns the targets associated with each molecule. 152 | 153 | :return: A list of lists of floats containing the targets. 154 | """ 155 | return [d.targets for d in self.data] 156 | 157 | def num_tasks(self) -> int: 158 | """ 159 | Returns the number of prediction tasks. 160 | 161 | :return: The number of tasks. 162 | """ 163 | return self.data[0].num_tasks() if len(self.data) > 0 else None 164 | 165 | def features_size(self) -> int: 166 | """ 167 | Returns the size of the features array associated with each molecule. 168 | 169 | :return: The size of the features. 170 | """ 171 | return len(self.data[0].features) if len(self.data) > 0 and self.data[0].features is not None else None 172 | 173 | def shuffle(self, seed: int = None): 174 | """ 175 | Shuffles the dataset. 176 | 177 | :param seed: Optional random seed. 178 | """ 179 | if seed is not None: 180 | random.seed(seed) 181 | random.shuffle(self.data) 182 | 183 | def normalize_features(self, scaler: StandardScaler = None, replace_nan_token: int = 0) -> StandardScaler: 184 | """ 185 | Normalizes the features of the dataset using a StandardScaler (subtract mean, divide by standard deviation). 186 | 187 | If a scaler is provided, uses that scaler to perform the normalization. Otherwise fits a scaler to the 188 | features in the dataset and then performs the normalization. 189 | 190 | :param scaler: A fitted StandardScaler. Used if provided. Otherwise a StandardScaler is fit on 191 | this dataset and is then used. 192 | :param replace_nan_token: What to replace nans with. 193 | :return: A fitted StandardScaler. If a scaler is provided, this is the same scaler. Otherwise, this is 194 | a scaler fit on this dataset. 195 | """ 196 | if len(self.data) == 0 or self.data[0].features is None: 197 | return None 198 | 199 | if scaler is not None: 200 | self.scaler = scaler 201 | 202 | elif self.scaler is None: 203 | features = np.vstack([d.features for d in self.data]) 204 | self.scaler = StandardScaler(replace_nan_token=replace_nan_token) 205 | self.scaler.fit(features) 206 | 207 | for d in self.data: 208 | d.set_features(self.scaler.transform(d.features.reshape(1, -1))[0]) 209 | 210 | return self.scaler 211 | 212 | def set_targets(self, targets: List[List[float]]): 213 | """ 214 | Sets the targets for each molecule in the dataset. Assumes the targets are aligned with the datapoints. 215 | 216 | :param targets: A list of lists of floats containing targets for each molecule. This must be the 217 | same length as the underlying dataset. 218 | """ 219 | assert len(self.data) == len(targets) 220 | for i in range(len(self.data)): 221 | self.data[i].set_targets(targets[i]) 222 | 223 | def sort(self, key: Callable): 224 | """ 225 | Sorts the dataset using the provided key. 226 | 227 | :param key: A function on a MoleculeDatapoint to determine the sorting order. 228 | """ 229 | self.data.sort(key=key) 230 | 231 | def __len__(self) -> int: 232 | """ 233 | Returns the length of the dataset (i.e. the number of molecules). 234 | 235 | :return: The length of the dataset. 
236 | """ 237 | return len(self.data) 238 | 239 | def __getitem__(self, item) -> Union[MoleculeDatapoint, List[MoleculeDatapoint]]: 240 | """ 241 | Gets one or more MoleculeDatapoints via an index or slice. 242 | 243 | :param item: An index (int) or a slice object. 244 | :return: A MoleculeDatapoint if an int is provided or a list of MoleculeDatapoints if a slice is provided. 245 | """ 246 | return self.data[item] 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Atom-Based Uncertainty Model 2 | 3 | ![image](https://github.com/chuiyang/atom-based_uncertainty_model/blob/main/images/TOC.jpeg) 4 | 5 | The atom-based uncertainty quantification method provides an extra layer of explainability to both aleatoric and epistemic uncertainties, i.e., one can analyze individual atomic uncertainty values to diagnose the chemical component that introduces the uncertainty in the prediction. 6 | 7 | Detailed content is available at Journal of Cheminformatics: 8 | 9 | [Explainable Uncertainty Quantifications for Deep Learning-Based Molecular Property Prediction](https://doi.org/10.1186/s13321-023-00682-3) 10 | 11 | 12 | 13 | The atom-based uncertainty model is modified from the architecture of a [molecule-based uncertainty model](https://github.com/gscalia/chemprop/tree/uncertainty) by Scalia G, Grambow CA, Pernici B et al (2020) 14 | 15 | > **Note**: 16 | Currently only **regression tasks** are supported. 17 | This repository is still under development. (14.02.2024) \ 18 | I'll try to deliver this repo biweekly, hopefully~ :) 19 | 20 | ## Table of Contents 21 | - [Train](#train) 22 | - [Evaluate](#evaluate) 23 | - [Molecular Property Prediction](#1-molecular-property-prediction) 24 | - [Draw Molecular Images](#2-draw-molecular-images) 25 | - [Post-hoc recalibration](#post-hoc-recalibration) 26 | 27 | 28 | ## Train 29 | ### Train **atom-based uncertainty model** by running: 30 | ```bash 31 | python train.py \ 32 | --data_path \ 33 | --save_dir \ 34 | --dataset_type regression \ 35 | --max_atom_size 9 \ 36 | --aleatoric \ 37 | --metric heteroscedastic \ 38 | ``` 39 | [Below is optional for atom-based uncertainty model] 40 | ```bash 41 | --fp_method atomic \ 42 | --corr_similarity_function pearson \ 43 | --epochs 150 \ 44 | --no_features_scaling \ 45 | --seed 20 \ 46 | --y_scaling \ 47 | --batch_size 50 \ 48 | --save_smiles_splits \ 49 | --max_lr 1e-3 \ 50 | --ensemble_size 1 51 | ``` 52 | * `` is the CSV file with columns name at the first row 53 | 54 | e.g. 55 | | smiles | [property name] | 56 | | :---: | :---: | 57 | | c1ccccc1 | -1.31 | 58 | | CCCO | 2.43 | 59 | | ... | ... | 60 | 61 | * `` is the path to save the checkpoints. e.g., 👉 ./result/folder1 62 | * `--max_atom_size` is to specify the largest size of molecule in the training data. 63 | e.g. the maximum number of atoms in a molecule is 9. 64 | * `--fp_method` should be specified as `atomic` for atomic predictive distributions. 
(default: atomic)
65 | 
66 | ### Train **molecule-based uncertainty model** by running: 
67 | ```bash 
68 | python train.py \ 
69 | --data_path \ 
70 | --save_dir \ 
71 | --dataset_type regression \ 
72 | --fp_method molecular \ 
73 | --aleatoric \ 
74 | --metric heteroscedastic \ 
75 | ``` 
76 | [Below is optional for molecule-based uncertainty model] 
77 | ```bash 
78 | --epochs 150 \ 
79 | --no_features_scaling \ 
80 | --seed 20 \ 
81 | --aggregation sum \ 
82 | --y_scaling \ 
83 | --batch_size 50 \ 
84 | --save_smiles_splits \ 
85 | --ensemble_size 30 \ 
86 | --max_lr 5e-4 
87 | ``` 
88 | * `` and `` are the same as above; there is no need to specify the largest molecule size in the training data when training a molecule-based uncertainty model. 
89 | * `--fp_method` should be specified as `molecular` to generate only the molecular predictive distribution. 
90 | 
91 | ## Evaluate 
92 | Currently, you can predict: 
93 | 1. **Input**: A CSV file with molecules in SMILES format \ 
94 | **Output**: Molecular (property/aleatoric uncertainty/epistemic uncertainty) predictions. (CSV file) 
95 | 
96 | 2. **Input**: A CSV file with molecules in SMILES format \ 
97 | **Output**: PNG/SVG images of molecules with atomic (contribution/aleatoric/epistemic) values labeled near the atoms. (a folder with PNG/SVG files) 
98 | 
99 | ### 1. Molecular Property Prediction 
100 | Run: 
101 | ```bash 
102 | python predict.py \ 
103 | --test_path \ 
104 | --checkpoint_dir \ 
105 | --preds_path \ 
106 | --estimate_variance 
107 | ``` 
108 | * `` is the CSV file path to evaluate. e.g., 👉 ./data/test_data.csv 
109 | * `` is the checkpoint directory path where the model is saved. It should be the same as the `` used when you trained the model. e.g., 👉 ./result/folder1 
110 | * `` is the CSV file path to save the output file after predicting the ``. e.g., 👉 ./data/test_data_pred.csv 
111 | * Note: if the maximum heavy atom size in the test_path is larger than that of the training data in checkpoint_dir, add the `--pred_max_atom_size ` tag (atom-based uncertainty model only, to be fixed). 
112 | 
113 | ### 2. Draw Molecular Images 
114 | Draw molecules with atomic information labeled near the atoms. \ 
115 | Run: 
116 | ```bash 
117 | python draw_predicted_molecules.py \ 
118 | --test_path \ 
119 | --checkpoint_dir \ 
120 | --draw_mols_dir \ 
121 | --high_resolution 
122 | ``` 
123 | * `` is the CSV file path to evaluate. e.g., 👉 ./data/test_data.csv 
124 | * `` is the checkpoint directory path where the model is saved. It should be the same as the `` used when you trained the model. e.g., 👉 ./result/folder1 
125 | * `` is the directory path where the images will be saved. 👉 ./molecule/test_data_image_folder 
126 | > A folder named `test_data_image_folder` contains `pred`, `ale`, and `epi` folders. \ 
127 | Three PNG/SVG images are generated per molecule: property prediction, aleatoric uncertainty, and epistemic uncertainty. \ 
128 | Each image is placed in the folder it belongs to (e.g., the aleatoric uncertainty image with atomic aleatoric uncertainties is in the `ale` folder). 
129 | * `--high_resolution`: add this tag to generate images in SVG format; otherwise, images are generated in PNG format. 
130 | 
131 | ![image](https://github.com/chuiyang/atom-based_uncertainty_model/blob/main/images/draw_predicted_molecule_images.png) 
132 | 
133 | ## Post-hoc recalibration 
134 | 
135 | To fine-tune the variance layer in either the atom- or molecule-based uncertainty model, run train_multimodel.py and add `--transfer_learning_freeze_GCNN`.
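Conceptually, freezing the GCNN for recalibration means that only the parameters of the variance output layer remain trainable while the message-passing encoder is kept fixed. A minimal PyTorch sketch of this idea is shown below; the attribute prefix `var_ffn` is purely illustrative and not the actual module name used in this repository.

```python
import torch.nn as nn

def freeze_all_but_variance_layer(model: nn.Module, variance_prefix: str = 'var_ffn') -> None:
    """Keep only the (illustrative) variance head trainable; freeze everything else."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(variance_prefix)
```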
136 | 
137 | In the following, the ensemble model before post-hoc calibration is referred to as "ens_model" and the ensemble model after post-hoc calibration is referred to as "post-hoc_ens_model". 
138 | 
139 | ### Train the post-hoc recalibration on ens_model by running: 
140 | ```bash 
141 | python train_multimodel.py \ 
142 | --data_path \ 
143 | --separate_val_path \ 
144 | --separate_test_path \ 
145 | --checkpoint_dir \ 
146 | --save_dir \ 
147 | --transfer_learning_freeze_GCNN \ 
148 | --fp_method molecular \ 
149 | --init_lr 1e-6 \ 
150 | --max_lr 1e-5 \ 
151 | --final_lr 8e-7 \ 
152 | --warmup_epochs 4 \ 
153 | --dataset_type regression \ 
154 | --epochs 150 \ 
155 | --no_features_scaling \ 
156 | --seed 20 \ 
157 | --aleatoric \ 
158 | --metric heteroscedastic \ 
159 | --aggregation sum \ 
160 | --ensemble_size 30 \ 
161 | --y_scaling 
162 | ``` 
163 | * ``, ``, and `` are the CSV file paths of the training/validation/test data that were used for ens_model. 
164 | * `` is the path to the saved checkpoint of ens_model. 
165 | * `` is the path to save the checkpoints of post-hoc_ens_model. 
166 | * `--transfer_learning_freeze_GCNN` freezes all weights that do not belong to the **variance layer**. 
167 | 
168 | ## Computational Cost 
169 | The computational cost depends on **the size of the training set** and **the number of epochs the machine runs**.
We give the user an idea of how long it takes to train the model.
The times shown below are for training an atom-based uncertainty model.
(If you want an ensemble of 5 models, expect roughly five times the training time unless the individual models are trained in parallel.) 
170 | 
171 | For **Delaney**, the size of the dataset is 1128 molecules. We split train:val:test into 8:1:1, set 150 epochs, and stop early if there is no improvement in 50 epochs. 
172 | | Epochs it ran | Time | 
173 | | ------------- | ------------- | 
174 | | 60 | 2 mins 33 secs | 
175 | | 65 | 2 mins 37 secs | 
176 | | 69 | 2 mins 55 secs | 
177 | | 88 | 3 mins 44 secs | 
178 | | 101 | 6 mins 08 secs | 
179 | 
180 | For **QM9**, the size of the dataset is 134k molecules. We split train:val:test into 8:1:1, set 150 epochs, and stop early if there is no improvement in 15 epochs. 
181 | | Epochs it ran | Time | 
182 | | ------------- | ------------- | 
183 | | 36 | 141 mins | 
184 | | 70 | 258 mins | 
185 | | 90 | 330 mins | 
186 | | 94 | 352 mins | 
187 | | 114 | 411 mins | 
188 | | 116 | 424 mins | 
189 | | 130 | 472 mins | 
190 | 
191 | All timings above were performed on 4 cores of a 2.0GHz AMD EPYC Rome 64-core processor machine. 
192 | 
193 | 
194 | For **Delaney**, the size of the dataset is 1128 molecules. We split train:val:test into 8:1:1, set 150 epochs, and stop early if there is no improvement in 50 epochs. 
195 | | Epochs it ran | Time | 
196 | | ------------- | ------------- | 
197 | | 74 | 2 mins 47 secs | 
198 | | 81 | 3 mins 07 secs | 
199 | | 78 | 3 mins 01 secs | 
200 | | 79 | 3 mins 04 secs | 
201 | | 104 | 3 mins 57 secs | 
202 | 
203 | For **QM9**, the size of the dataset is 134k molecules. We split train:val:test into 8:1:1, set 150 epochs, and stop early if there is no improvement in 15 epochs. 
204 | | Epochs it ran | Time | 
205 | | ------------- | ------------- | 
206 | | 31 | 117 mins | 
207 | | 36 | 134 mins | 
208 | | 61 | 211 mins | 
209 | | 85 | 286 mins | 
210 | | 108 | 366 mins | 
211 | 
212 | All timings above were performed on 8 cores of a 2.0GHz AMD EPYC Rome 64-core processor machine. -------------------------------------------------------------------------------- /chemprop/nn_utils.py: -------------------------------------------------------------------------------- 1 | import math 
2 | from typing import List, Union 
3 | # https://docs.python.org/zh-tw/3/library/typing.html 
4 | import numpy as np 
5 | import torch 
6 | import torch.nn as nn 
7 | from torch.optim import Optimizer 
8 | from torch.optim.lr_scheduler import _LRScheduler 
9 | from tqdm import trange 
10 | 
11 | from chemprop.data import MoleculeDataset 
12 | 
13 | 
14 | def compute_pnorm(model: nn.Module) -> float: 
15 |     """Computes the norm of the parameters of a model.""" 
16 |     return math.sqrt(sum([p.norm().item() ** 2 for p in model.parameters()])) 
17 | 
18 | 
19 | def compute_gnorm(model: nn.Module) -> float: 
20 |     """Computes the norm of the gradients of a model.""" 
21 |     return math.sqrt(sum([p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None])) 
22 | 
23 | 
24 | def param_count(model: nn.Module) -> int: 
25 |     """ 
26 |     Determines number of trainable parameters. 
27 | 
28 |     :param model: An nn.Module. 
29 |     :return: The number of trainable parameters. 
30 |     """ 
31 |     return sum(param.numel() for param in model.parameters() if param.requires_grad) 
32 | 
33 | 
34 | def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor: 
35 |     """ 
36 |     Selects the message features from source corresponding to the atom or bond indices in index. 
37 | 
38 |     :param source: A tensor of shape (num_bonds, hidden_size) containing message features.
39 | :param index: A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond 40 | indices to select from source. 41 | :return: A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message 42 | features corresponding to the atoms/bonds specified in index. 43 | """ 44 | index_size = index.size() # (num_atoms/num_bonds, max_num_bonds) 45 | suffix_dim = source.size()[1:] # (hidden_size,) 46 | final_size = index_size + suffix_dim # (num_atoms/num_bonds, max_num_bonds, hidden_size) 47 | 48 | target = source.index_select(dim=0, index=index.view(-1)) # (num_atoms/num_bonds * max_num_bonds, hidden_size) 49 | target = target.view(final_size) # (num_atoms/num_bonds, max_num_bonds, hidden_size) 50 | 51 | #print('index_size:', index_size) 52 | #print('suffix_dim:', suffix_dim) 53 | #print('final_size:', final_size) 54 | #print('source.index_select(dim=0, index=index.view(-1)):', source.index_select(dim=0, index=index.view(-1))) 55 | #print('source.index_select(dim=0, index=index.view(-1)).size():', source.index_select(dim=0, index=index.view(-1)).size()) 56 | #print('target:', target) 57 | #print('target.size():', target.size()) 58 | 59 | return target 60 | 61 | 62 | def get_activation_function(activation: str) -> nn.Module: 63 | """ 64 | Gets an activation function module given the name of the activation. 65 | 66 | :param activation: The name of the activation function. 67 | :return: The activation function module. 68 | """ 69 | if activation == 'ReLU': 70 | return nn.ReLU(inplace=False) 71 | elif activation == 'LeakyReLU': 72 | return nn.LeakyReLU(0.1) 73 | elif activation == 'PReLU': 74 | return nn.PReLU() 75 | elif activation == 'tanh': 76 | return nn.Tanh() 77 | elif activation == 'SELU': 78 | return nn.SELU() 79 | elif activation == 'ELU': 80 | return nn.ELU() 81 | else: 82 | raise ValueError(f'Activation "{activation}" not supported.') 83 | 84 | 85 | def initialize_weights(model: nn.Module): 86 | """ 87 | Initializes the weights of a model in place. 88 | 89 | :param model: An nn.Module. 90 | """ 91 | # torch.manual_seed(11) 92 | # torch.cuda.manual_seed(11) 93 | for param in model.parameters(): 94 | if param.dim() == 1: 95 | nn.init.constant_(param, 0) 96 | else: 97 | nn.init.xavier_normal_(param) 98 | 99 | 100 | def get_cc_dropout_hyper(num_train: int, regularization_scale: int): 101 | """ 102 | Compute the hyperparameters for Concrete Dropout 103 | 104 | :param num_train: Training size. 105 | :param regularization_scale: Regularization scale. 106 | :return: A tuple: (weight_regularizer, dropout_regularizer) 107 | 108 | """ 109 | wd = regularization_scale ** 2 / num_train 110 | dd = 2. / num_train 111 | return wd, dd 112 | 113 | 114 | class Identity(nn.Module): 115 | """Identity PyTorch module.""" 116 | def forward(self, x): 117 | return x 118 | 119 | 120 | def compute_molecule_vectors(model: nn.Module, data: MoleculeDataset, batch_size: int) -> List[np.ndarray]: 121 | """ 122 | Computes the molecule vectors output from the last layer of a MoleculeModel. 123 | 124 | :param model: A MoleculeModel. 125 | :param data: A MoleculeDataset. 126 | :param batch_size: Batch size. 127 | :return: A list of 1D numpy arrays of length hidden_size containing 128 | the molecule vectors generated by the model for each molecule provided. 
129 | """ 130 | model.eval() 131 | model.ffn[-1] = Identity() # Replace last linear layer with identity 132 | if hasattr(model, 'sigmoid'): 133 | model.sigmoid = Identity() 134 | 135 | vecs = [] 136 | 137 | num_iters, iter_step = len(data), batch_size 138 | 139 | for i in range(0, num_iters, iter_step): 140 | # Prepare batch 141 | mol_batch = MoleculeDataset(data[i:i + batch_size]) 142 | smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features() 143 | 144 | # Run model 145 | batch = smiles_batch 146 | 147 | with torch.no_grad(): 148 | batch_vecs = model(batch, features_batch) 149 | 150 | # Collect vectors 151 | batch_vecs = batch_vecs.data.cpu().numpy() 152 | vecs.extend(batch_vecs) 153 | 154 | return vecs 155 | 156 | 157 | class NoamLR(_LRScheduler): 158 | """ 159 | Noam learning rate scheduler with piecewise linear increase and exponential decay. 160 | 161 | The learning rate increases linearly from init_lr to max_lr over the course of 162 | the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch). 163 | Then the learning rate decreases exponentially from max_lr to final_lr over the 164 | course of the remaining total_steps - warmup_steps (where total_steps = 165 | total_epochs * steps_per_epoch). This is roughly based on the learning rate 166 | schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762). 167 | """ 168 | def __init__(self, 169 | optimizer: Optimizer, 170 | warmup_epochs: List[Union[float, int]], 171 | total_epochs: List[int], 172 | steps_per_epoch: int, 173 | init_lr: List[float], 174 | max_lr: List[float], 175 | final_lr: List[float]): 176 | """ 177 | Initializes the learning rate scheduler. 178 | 179 | :param optimizer: A PyTorch optimizer. 180 | :param warmup_epochs: The number of epochs during which to linearly increase the learning rate. 181 | :param total_epochs: The total number of epochs. 182 | :param steps_per_epoch: The number of steps (batches) per epoch. 183 | :param init_lr: The initial learning rate. 184 | :param max_lr: The maximum learning rate (achieved after warmup_epochs). 185 | :param final_lr: The final learning rate (achieved after total_epochs). 186 | """ 187 | 188 | assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \ 189 | len(max_lr) == len(final_lr) # origin 190 | 191 | 192 | self.num_lrs = len(optimizer.param_groups) 193 | 194 | self.optimizer = optimizer 195 | self.steps_per_epoch = steps_per_epoch 196 | 197 | self.warmup_epochs = np.array(warmup_epochs) # origin 198 | self.total_epochs = np.array(total_epochs) # origin 199 | self.init_lr = np.array(init_lr) # origin 200 | self.max_lr = np.array(max_lr) # origin 201 | self.final_lr = np.array(final_lr) # origin 202 | 203 | self.current_step = 0 204 | self.lr = init_lr # origin 205 | 206 | self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int) 207 | 208 | self.total_steps = self.total_epochs * self.steps_per_epoch 209 | self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps 210 | self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps)) 211 | 212 | super(NoamLR, self).__init__(optimizer) 213 | 214 | def get_lr(self) -> List[float]: 215 | """Gets a list of the current learning rates.""" 216 | return list(self.lr) 217 | 218 | def step(self, current_step: int = None): 219 | """ 220 | Updates the learning rate by taking a step. 
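        As implemented below, with warmup_steps = warmup_epochs * steps_per_epoch the schedule is
            lr(t) = init_lr + t * (max_lr - init_lr) / warmup_steps, for t <= warmup_steps
            lr(t) = max_lr * gamma ** (t - warmup_steps), where gamma = (final_lr / max_lr) ** (1 / (total_steps - warmup_steps)), for warmup_steps < t <= total_steps
        and lr(t) = final_lr once t exceeds total_steps.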
221 | 222 | :param current_step: Optionally specify what step to set the learning rate to. 223 | If None, current_step = self.current_step + 1. 224 | """ 225 | if current_step is not None: 226 | self.current_step = current_step 227 | else: 228 | self.current_step += 1 229 | 230 | for i in range(self.num_lrs): 231 | if self.current_step <= self.warmup_steps[i]: 232 | self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i] 233 | elif self.current_step <= self.total_steps[i]: 234 | self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i])) 235 | else: # theoretically this case should never be reached since training should stop at total_steps 236 | self.lr[i] = self.final_lr[i] 237 | 238 | self.optimizer.param_groups[i]['lr'] = self.lr[i] 239 | 240 | 241 | class InverseLR(_LRScheduler): 242 | """ 243 | Inverse learning rate scheduler with piecewise linear increase and exponential decay. 244 | 245 | The learning rate increases linearly from init_lr to max_lr over the course of 246 | the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch). 247 | Then the learning rate decreases exponentially from max_lr to final_lr by the inverse square 248 | root of steps. 249 | 250 | def lr_scheduler(step, stepsInAnEpoch, init_lr=0.001, max_lr=0.01): 251 | warm_up = 5 252 | p = max_lr*((warm_up*stepsInAnEpoch)**0.5) 253 | if step < warm_up*stepsInAnEpoch: # warm_up 254 | lr = (max_lr-init_lr)*(step/warm_up/stepsInAnEpoch) + init_lr 255 | else: # inverse square root of steps 256 | lr = 1/(step**0.5) * p 257 | # print(step, lr) 258 | return lr 259 | """ 260 | def __init__(self, 261 | optimizer: Optimizer, 262 | warmup_epochs: List[Union[float, int]], 263 | total_epochs: List[int], 264 | steps_per_epoch: int, 265 | init_lr: List[float], 266 | max_lr: List[float], 267 | final_lr: List[float]): 268 | """ 269 | Initializes the learning rate scheduler. 270 | 271 | :param optimizer: A PyTorch optimizer. 272 | :param warmup_epochs: The number of epochs during which to linearly increase the learning rate. 273 | :param total_epochs: The total number of epochs. 274 | :param steps_per_epoch: The number of steps (batches) per epoch. 275 | :param init_lr: The initial learning rate. 276 | :param max_lr: The maximum learning rate (achieved after warmup_epochs). 277 | :param final_lr: The final learning rate (achieved after total_epochs). 
278 | """ 279 | 280 | assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \ 281 | len(max_lr) == len(final_lr) # origin 282 | 283 | 284 | self.num_lrs = len(optimizer.param_groups) 285 | 286 | self.optimizer = optimizer 287 | self.steps_per_epoch = steps_per_epoch 288 | 289 | self.warmup_epochs = np.array(warmup_epochs) # origin 290 | self.total_epochs = np.array(total_epochs) # origin 291 | self.init_lr = np.array(init_lr) # origin 292 | self.max_lr = np.array(max_lr) # origin 293 | self.final_lr = np.array(final_lr) # origin 294 | 295 | self.current_step = 0 296 | self.lr = init_lr # origin 297 | 298 | self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int) 299 | 300 | self.total_steps = self.total_epochs * self.steps_per_epoch 301 | self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps 302 | self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps)) 303 | 304 | super(InverseLR, self).__init__(optimizer) 305 | 306 | def get_lr(self) -> List[float]: 307 | """Gets a list of the current learning rates.""" 308 | return list(self.lr) 309 | 310 | def step(self, current_step: int = None): 311 | """ 312 | Updates the learning rate by taking a step. 313 | 314 | :param current_step: Optionally specify what step to set the learning rate to. 315 | If None, current_step = self.current_step + 1. 316 | """ 317 | if current_step is not None: 318 | self.current_step = current_step 319 | else: 320 | self.current_step += 1 321 | 322 | for i in range(self.num_lrs): 323 | if self.current_step <= self.warmup_steps[i]: 324 | self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i] 325 | elif self.current_step <= self.total_steps[i]: 326 | """ 327 | p = max_lr*((warm_up*stepsInAnEpoch)**0.5) 328 | if step < warm_up*stepsInAnEpoch: # warm_up 329 | lr = (max_lr-init_lr)*(step/warm_up/stepsInAnEpoch) + init_lr 330 | else: # inverse square root of steps 331 | lr = 1/(step**0.5) * p 332 | """ 333 | self.lr[i] = (1/(self.current_step**0.5)) * (self.max_lr[i] * (self.warmup_steps[i]**0.5)) 334 | else: # theoretically this case should never be reached since training should stop at total_steps 335 | self.lr[i] = self.final_lr[i] 336 | 337 | self.optimizer.param_groups[i]['lr'] = self.lr[i] -------------------------------------------------------------------------------- /chemprop/data/utils.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import csv 3 | from logging import Logger 4 | import pickle 5 | import random 6 | from typing import List, Set, Tuple 7 | import os 8 | 9 | from rdkit import Chem 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from .data import MoleculeDatapoint, MoleculeDataset 14 | from .scaffold import log_scaffold_stats, scaffold_split 15 | from chemprop.features import load_features 16 | 17 | 18 | def get_task_names(path: str, use_compound_names: bool = False) -> List[str]: 19 | """ 20 | Gets the task names from a data CSV file. 21 | 22 | :param path: Path to a CSV file. 23 | :param use_compound_names: Whether file has compound names in addition to smiles strings. 24 | :return: A list of task names. 25 | """ 26 | index = 2 if use_compound_names else 1 27 | task_names = get_header(path)[index:] 28 | 29 | return task_names 30 | 31 | 32 | def get_header(path: str) -> List[str]: 33 | """ 34 | Returns the header of a data CSV file. 
35 | 36 | :param path: Path to a CSV file. 37 | :return: A list of strings containing the strings in the comma-separated header. 38 | """ 39 | with open(path) as f: 40 | header = next(csv.reader(f)) 41 | 42 | return header 43 | 44 | 45 | def get_num_tasks(path: str) -> int: 46 | """ 47 | Gets the number of tasks in a data CSV file. 48 | 49 | :param path: Path to a CSV file. 50 | :return: The number of tasks. 51 | """ 52 | return len(get_header(path)) - 1 53 | 54 | 55 | def get_smiles(path: str, header: bool = True) -> List[str]: 56 | """ 57 | Returns the smiles strings from a data CSV file (assuming the first line is a header). 58 | 59 | :param path: Path to a CSV file. 60 | :param header: Whether the CSV file contains a header (that will be skipped). 61 | :return: A list of smiles strings. 62 | """ 63 | with open(path) as f: 64 | reader = csv.reader(f) 65 | if header: 66 | next(reader) # Skip header 67 | smiles = [line[0] for line in reader] 68 | 69 | return smiles 70 | 71 | 72 | def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset: 73 | """ 74 | Filters out invalid SMILES. 75 | 76 | :param data: A MoleculeDataset. 77 | :return: A MoleculeDataset with only valid molecules. 78 | """ 79 | return MoleculeDataset([datapoint for datapoint in data 80 | if datapoint.smiles != '' and datapoint.mol is not None 81 | and datapoint.mol.GetNumHeavyAtoms() > 0]) 82 | 83 | 84 | def get_data(path: str, 85 | skip_invalid_smiles: bool = True, 86 | args: Namespace = None, 87 | features_path: List[str] = None, 88 | max_data_size: int = None, 89 | use_compound_names: bool = None, 90 | logger: Logger = None) -> MoleculeDataset: 91 | """ 92 | Gets smiles string and target values (and optionally compound names if provided) from a CSV file. 93 | 94 | :param path: Path to a CSV file. 95 | :param skip_invalid_smiles: Whether to skip and filter out invalid smiles. 96 | :param args: Arguments. 97 | :param features_path: A list of paths to files containing features. If provided, it is used 98 | in place of args.features_path. 99 | :param max_data_size: The maximum number of data points to load. 100 | :param use_compound_names: Whether file has compound names in addition to smiles strings. 101 | :param logger: Logger. 102 | :return: A MoleculeDataset containing smiles strings and target values along 103 | with other info such as additional features and compound names when desired. 
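`filter_invalid_smiles` keeps a datapoint only if its SMILES string is non-empty, RDKit can parse it, and the parsed molecule has at least one heavy atom. A small sketch of the same three checks applied to raw strings (the example strings are arbitrary):

from rdkit import Chem

candidates = ['CCO', '', 'not_a_smiles', '[HH]']  # '[HH]' parses but has no heavy atoms
valid = []
for smi in candidates:
    mol = Chem.MolFromSmiles(smi) if smi != '' else None
    if mol is not None and mol.GetNumHeavyAtoms() > 0:
        valid.append(smi)
print(valid)  # ['CCO']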
104 | """ 105 | debug = logger.debug if logger is not None else print 106 | 107 | if args is not None: 108 | # Prefer explicit function arguments but default to args if not provided 109 | features_path = features_path if features_path is not None else args.features_path 110 | max_data_size = max_data_size if max_data_size is not None else args.max_data_size 111 | use_compound_names = use_compound_names if use_compound_names is not None else args.use_compound_names 112 | else: 113 | use_compound_names = False 114 | 115 | max_data_size = max_data_size or float('inf') 116 | 117 | # Load features 118 | if features_path is not None: 119 | features_data = [] 120 | for feat_path in features_path: 121 | features_data.append(load_features(feat_path)) # each is num_data x num_features 122 | features_data = np.concatenate(features_data, axis=1) 123 | else: 124 | features_data = None 125 | 126 | skip_smiles = set() 127 | 128 | # Load data 129 | with open(path) as f: 130 | reader = csv.reader(f) 131 | next(reader) # skip header 132 | lines = [] 133 | true_vals=[] 134 | for line in reader: 135 | smiles = line[0] 136 | true_val=float(line[1]) 137 | 138 | if smiles in skip_smiles: 139 | continue 140 | 141 | lines.append(line) 142 | true_vals.append(true_val) 143 | 144 | if len(lines) >= max_data_size: 145 | break 146 | data = MoleculeDataset([ 147 | MoleculeDatapoint( 148 | line=line, 149 | args=args, 150 | features=features_data[i] if features_data is not None else None, 151 | use_compound_names=use_compound_names 152 | ) for i, line in enumerate(lines) 153 | ]) 154 | 155 | # Filter out invalid SMILES 156 | if skip_invalid_smiles: 157 | original_data_len = len(data) 158 | data = filter_invalid_smiles(data) 159 | 160 | if len(data) < original_data_len: 161 | debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.') 162 | 163 | if data.data[0].features is not None: 164 | args.features_dim = len(data.data[0].features) 165 | 166 | if args.write_true_val: 167 | return data, true_vals 168 | else: 169 | return data 170 | 171 | 172 | def get_data_from_smiles(smiles: List[str], skip_invalid_smiles: bool = True, logger: Logger = None) -> MoleculeDataset: 173 | """ 174 | Converts SMILES to a MoleculeDataset. 175 | 176 | :param smiles: A list of SMILES strings. 177 | :param skip_invalid_smiles: Whether to skip and filter out invalid smiles. 178 | :param logger: Logger. 179 | :return: A MoleculeDataset with all of the provided SMILES. 180 | """ 181 | debug = logger.debug if logger is not None else print 182 | 183 | data = MoleculeDataset([MoleculeDatapoint([smile]) for smile in smiles]) 184 | 185 | # Filter out invalid SMILES 186 | if skip_invalid_smiles: 187 | original_data_len = len(data) 188 | data = filter_invalid_smiles(data) 189 | 190 | if len(data) < original_data_len: 191 | debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.') 192 | 193 | return data 194 | 195 | 196 | def split_data(data: MoleculeDataset, 197 | split_type: str = 'random', 198 | sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1), 199 | seed: int = 0, 200 | args: Namespace = None, 201 | logger: Logger = None, ) -> Tuple[MoleculeDataset, 202 | MoleculeDataset, 203 | MoleculeDataset]: 204 | """ 205 | Splits data into training, validation, and test splits. 206 | 207 | :param data: A MoleculeDataset. 208 | :param split_type: Split type. 209 | :param sizes: A length-3 tuple with the proportions of data in the 210 | train, validation, and test sets. 211 | :param seed: The random seed to use before shuffling data. 
212 | :param args: Namespace of arguments. 213 | :param logger: A logger. 214 | :return: A tuple containing the train, validation, and test splits of the data. 215 | """ 216 | assert len(sizes) == 3 and sum(sizes) == 1 217 | 218 | if args is not None: 219 | folds_file, val_fold_index, test_fold_index = \ 220 | args.folds_file, args.val_fold_index, args.test_fold_index 221 | else: 222 | folds_file = val_fold_index = test_fold_index = None 223 | 224 | if split_type == 'crossval': 225 | index_set = args.crossval_index_sets[args.seed] 226 | data_split = [] 227 | for split in range(3): 228 | split_indices = [] 229 | for index in index_set[split]: 230 | with open(os.path.join(args.crossval_index_dir, f'{index}.pkl'), 'rb') as rf: 231 | split_indices.extend(pickle.load(rf)) 232 | data_split.append([data[i] for i in split_indices]) 233 | train, val, test = tuple(data_split) 234 | return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test) 235 | 236 | elif split_type == 'index_predetermined': 237 | split_indices = args.crossval_index_sets[args.seed] 238 | assert len(split_indices) == 3 239 | data_split = [] 240 | for split in range(3): 241 | data_split.append([data[i] for i in split_indices[split]]) 242 | train, val, test = tuple(data_split) 243 | return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test) 244 | 245 | elif split_type == 'predetermined': 246 | if not val_fold_index: 247 | assert sizes[2] == 0 # test set is created separately so use all of the other data for train and val 248 | assert folds_file is not None 249 | assert test_fold_index is not None 250 | 251 | try: 252 | with open(folds_file, 'rb') as f: 253 | all_fold_indices = pickle.load(f) 254 | except UnicodeDecodeError: 255 | with open(folds_file, 'rb') as f: 256 | all_fold_indices = pickle.load(f, encoding='latin1') # in case we're loading indices from python2 257 | # assert len(data) == sum([len(fold_indices) for fold_indices in all_fold_indices]) 258 | 259 | log_scaffold_stats(data, all_fold_indices, logger=logger) 260 | 261 | folds = [[data[i] for i in fold_indices] for fold_indices in all_fold_indices] 262 | 263 | test = folds[test_fold_index] 264 | if val_fold_index is not None: 265 | val = folds[val_fold_index] 266 | 267 | train_val = [] 268 | for i in range(len(folds)): 269 | if i != test_fold_index and (val_fold_index is None or i != val_fold_index): 270 | train_val.extend(folds[i]) 271 | 272 | if val_fold_index is not None: 273 | train = train_val 274 | else: 275 | random.seed(seed) 276 | random.shuffle(train_val) 277 | train_size = int(sizes[0] * len(train_val)) 278 | train = train_val[:train_size] 279 | val = train_val[train_size:] 280 | 281 | return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test) 282 | 283 | elif split_type == 'scaffold_balanced': 284 | return scaffold_split(data, sizes=sizes, balanced=True, seed=seed, logger=logger) 285 | 286 | elif split_type == 'random': 287 | data.shuffle(seed=seed) 288 | 289 | train_size = int(sizes[0] * len(data)) 290 | train_val_size = int((sizes[0] + sizes[1]) * len(data)) 291 | 292 | train = data[:train_size] 293 | val = data[train_size:train_val_size] 294 | test = data[train_val_size:] 295 | 296 | return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test) 297 | 298 | else: 299 | raise ValueError(f'split_type "{split_type}" not supported.') 300 | 301 | 302 | def get_class_sizes(data: MoleculeDataset) -> List[List[float]]: 303 | """ 304 | Determines the proportions of the different classes in the 
classification dataset. 305 | 306 | :param data: A classification dataset 307 | :return: A list of lists of class proportions. Each inner list contains the class proportions 308 | for a task. 309 | """ 310 | targets = data.targets() 311 | 312 | # Filter out Nones 313 | valid_targets = [[] for _ in range(data.num_tasks())] 314 | for i in range(len(targets)): 315 | for task_num in range(len(targets[i])): 316 | if targets[i][task_num] is not None: 317 | valid_targets[task_num].append(targets[i][task_num]) 318 | 319 | class_sizes = [] 320 | for task_targets in valid_targets: 321 | # Make sure we're dealing with a binary classification task 322 | assert set(np.unique(task_targets)) <= {0, 1} 323 | 324 | try: 325 | ones = np.count_nonzero(task_targets) / len(task_targets) 326 | except ZeroDivisionError: 327 | ones = float('nan') 328 | print('Warning: class has no targets') 329 | class_sizes.append([1 - ones, ones]) 330 | 331 | return class_sizes 332 | 333 | 334 | def validate_data(data_path: str) -> Set[str]: 335 | """ 336 | Validates a data CSV file, returning a set of errors. 337 | 338 | :param data_path: Path to a data CSV file. 339 | :return: A set of error messages. 340 | """ 341 | errors = set() 342 | 343 | header = get_header(data_path) 344 | 345 | with open(data_path) as f: 346 | reader = csv.reader(f) 347 | next(reader) # Skip header 348 | 349 | smiles, targets = [], [] 350 | for line in reader: 351 | smiles.append(line[0]) 352 | targets.append(line[1:]) 353 | 354 | # Validate header 355 | if len(header) == 0: 356 | errors.add('Empty header') 357 | elif len(header) < 2: 358 | errors.add('Header must include task names.') 359 | 360 | mol = Chem.MolFromSmiles(header[0]) 361 | if mol is not None: 362 | errors.add('First row is a SMILES string instead of a header.') 363 | 364 | # Validate smiles 365 | for smile in smiles: 366 | mol = Chem.MolFromSmiles(smile) 367 | if mol is None: 368 | errors.add('Data includes an invalid SMILES.') 369 | 370 | # Validate targets 371 | num_tasks_set = set(len(mol_targets) for mol_targets in targets) 372 | if len(num_tasks_set) != 1: 373 | errors.add('Inconsistent number of tasks for each molecule.') 374 | 375 | if len(num_tasks_set) == 1: 376 | num_tasks = num_tasks_set.pop() 377 | if num_tasks != len(header) - 1: 378 | errors.add('Number of tasks for each molecule doesn\'t match number of tasks in header.') 379 | 380 | unique_targets = set(np.unique([target for mol_targets in targets for target in mol_targets])) 381 | 382 | if unique_targets <= {''}: 383 | errors.add('All targets are missing.') 384 | 385 | for target in unique_targets - {''}: 386 | try: 387 | float(target) 388 | except ValueError: 389 | errors.add('Found a target which is not a number.') 390 | 391 | return errors 392 | -------------------------------------------------------------------------------- /chemprop/models/mpn.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from typing import List, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from chemprop.features import BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph 9 | from chemprop.nn_utils import index_select_ND, get_activation_function, get_cc_dropout_hyper 10 | from chemprop.models.concrete_dropout import ConcreteDropout 11 | 12 | import time 13 | 14 | torch.set_printoptions(edgeitems=7) 15 | 16 | 17 | class MPNEncoder(nn.Module): # for atomic_vecs_d2, atomic_vecs_final, mol_vecs 18 | """A message passing neural network 
for encoding a molecule.""" 19 | 20 | def __init__(self, args: Namespace, atom_fdim: int, bond_fdim: int): 21 | """Initializes the MPNEncoder. 22 | 23 | :param args: Arguments. 24 | :param atom_fdim: Atom features dimension. 25 | :param bond_fdim: Bond features dimension. 26 | """ 27 | super(MPNEncoder, self).__init__() 28 | self.atom_fdim = atom_fdim 29 | self.bond_fdim = bond_fdim 30 | self.hidden_size = args.hidden_size 31 | self.bias = args.bias 32 | self.depth = args.depth 33 | self.dropout = args.dropout 34 | self.layers_per_message = 1 35 | self.undirected = args.undirected 36 | self.atom_messages = args.atom_messages 37 | self.use_input_features = args.use_input_features 38 | self.max_atom_size = args.max_atom_size 39 | self.epistemic = args.epistemic 40 | self.mc_dropout = self.epistemic == 'mc_dropout' 41 | self.aggregation = args.aggregation 42 | self.aggregation_norm = args.aggregation_norm 43 | self.fp_method = args.fp_method 44 | self.corr_similarity_function = args.corr_similarity_function 45 | self.args = args 46 | 47 | 48 | # Dropout 49 | self.dropout_layer = nn.Dropout(p=self.dropout) 50 | 51 | # Activation 52 | self.act_func = get_activation_function(args.activation) 53 | 54 | # Cached zeros 55 | self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False) 56 | 57 | # Concrete Dropout for Bayesian NN 58 | wd, dd = get_cc_dropout_hyper(args.train_data_size, args.regularization_scale) 59 | 60 | # Input 61 | input_dim = self.atom_fdim if self.atom_messages else self.bond_fdim # self.bond_fdim=145 62 | 63 | # cosine similarity 64 | self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-8) 65 | 66 | if self.mc_dropout: 67 | self.W_i = ConcreteDropout(layer=nn.Linear(input_dim, self.hidden_size, bias=self.bias), reg_acc=args.reg_acc, weight_regularizer=wd, dropout_regularizer=dd) 68 | else: 69 | self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias) 70 | 71 | if self.atom_messages: 72 | w_h_input_size = self.hidden_size + self.bond_fdim 73 | else: # hidden_size 74 | w_h_input_size = self.hidden_size 75 | 76 | # Shared weight matrix across depths (default) 77 | if self.mc_dropout: 78 | self.W_h = ConcreteDropout(layer=nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias), reg_acc=args.reg_acc, weight_regularizer=wd, dropout_regularizer=dd, depth=self.depth - 1) 79 | self.W_o = ConcreteDropout(layer=nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size), reg_acc=args.reg_acc, weight_regularizer=wd, dropout_regularizer=dd) 80 | else: 81 | self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias) 82 | self.W_o = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size) 83 | 84 | # zero-padding tensor 85 | def padding(self, mol_vector): # chui: set max atom number per molecule to 10 86 | num_atoms_index = mol_vector.shape[0] 87 | num_features_index = mol_vector.shape[1] 88 | padding_tensor = torch.zeros((self.max_atom_size, num_features_index)) # default 10x300 89 | padding_tensor[:num_atoms_index, :] = mol_vector 90 | return padding_tensor 91 | 92 | def get_cov_index(self, atom_num): 93 | b = [] 94 | [b.extend(range(i * self.max_atom_size, i * self.max_atom_size + atom_num)) for i in range(atom_num)] 95 | return b 96 | 97 | def get_sign(self, val): 98 | val[val >= 0] = 1 99 | val[val < 0] = -1 100 | return val 101 | 102 | def cov_func_padding(self, mol_vector): # check input Kxy but only output variance 103 | num_atoms = mol_vector.size(0) 104 | first = mol_vector.repeat(mol_vector.size(0), 1) 105 | 
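        # `first` (above) tiles the whole atom matrix num_atoms times, while `second` (below)
        # repeats each row num_atoms times consecutively, so row k of the concatenated tensor
        # holds the feature vectors of atoms k % num_atoms and k // num_atoms -- i.e. every
        # ordered atom pair. The pairwise similarities computed from these pairs are later
        # scattered into a flattened max_atom_size x max_atom_size grid via get_cov_index;
        # for example, with max_atom_size = 10 and num_atoms = 3, get_cov_index(3) returns
        # [0, 1, 2, 10, 11, 12, 20, 21, 22], the top-left 3 x 3 block of the 10 x 10 matrix.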
second = mol_vector.unsqueeze(1).repeat(1, mol_vector.size(0), 1).view(-1, mol_vector.size(1)) 106 | output_tensor = torch.cat((first, second), dim=1) 107 | mol_dim = mol_vector.size(1) 108 | 109 | # inner product 110 | if self.corr_similarity_function == 'cos': 111 | # val = torch.sum(output_tensor[:, :mol_dim]*output_tensor[:, mol_dim:], axis=1) # inner product 112 | # cos = val / torch.sqrt(torch.sum(torch.pow(output_tensor[:, :mol_dim], 2), axis=1)) / torch.sqrt(torch.sum(torch.pow(output_tensor[:, mol_dim:], 2), axis=1)) 113 | cos = self.cosine_similarity(output_tensor[:, :mol_dim], output_tensor[:, mol_dim:]) 114 | absolute_tensor = cos.view(-1, 1) 115 | 116 | # RBF kernel 117 | elif self.corr_similarity_function == 'rbf': 118 | absolute_tensor = torch.exp(-torch.sum(output_tensor[:, :mol_dim]-output_tensor[:, mol_dim:], axis=1)**2/300).view(-1, 1) 119 | val = torch.sum(output_tensor[:, :mol_dim]*output_tensor[:, mol_dim:], axis=1) # inner product 120 | sign = self.get_sign(val).view(-1, 1) 121 | absolute_tensor = absolute_tensor*sign 122 | 123 | elif self.corr_similarity_function == 'pearson': 124 | pearson = self.cosine_similarity(output_tensor[:, :mol_dim] - output_tensor[:, :mol_dim].mean(dim=1, keepdim=True), output_tensor[:, mol_dim:] - output_tensor[:, mol_dim:].mean(dim=1, keepdim=True)) 125 | absolute_tensor = pearson.view(-1, 1) 126 | 127 | else: 128 | raise ValueError(f'atomic fingerprint similarity function {self.corr_similarity_function} is not supported.') 129 | 130 | padding_cov_tensor = torch.zeros((self.max_atom_size*self.max_atom_size, absolute_tensor.size(1))) 131 | place_index = self.get_cov_index(num_atoms) 132 | if self.args.cuda: 133 | padding_cov_tensor = padding_cov_tensor.cuda() 134 | padding_cov_tensor[place_index, :] = absolute_tensor[:, :] 135 | 136 | return padding_cov_tensor 137 | 138 | def forward(self, 139 | mol_graph: BatchMolGraph, 140 | features_batch: List[np.ndarray] = None) -> torch.FloatTensor: 141 | """ 142 | Encodes a batch of molecular graphs. 143 | 144 | :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. 145 | :param features_batch: A list of ndarrays containing additional features. 146 | :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. 
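The pairwise-similarity construction used in cov_func_padding can be reproduced outside the model in a few lines; a small sketch with made-up tensor sizes, showing the 'cos' option and the 'pearson' option (cosine similarity of the mean-centred vectors):

import torch

atom_vecs = torch.randn(3, 8)            # (num_atoms, hidden_size); sizes are arbitrary
n, d = atom_vecs.shape
first = atom_vecs.repeat(n, 1)                                # row k -> atom k % n
second = atom_vecs.unsqueeze(1).repeat(1, n, 1).view(-1, d)   # row k -> atom k // n

cos = torch.nn.CosineSimilarity(dim=1, eps=1e-8)
cos_sim = cos(first, second).view(n, n)
pearson = cos(first - first.mean(dim=1, keepdim=True),
              second - second.mean(dim=1, keepdim=True)).view(n, n)
print(cos_sim.shape, pearson.shape)      # torch.Size([3, 3]) torch.Size([3, 3])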
147 | """ 148 | if self.use_input_features: 149 | features_batch = torch.from_numpy(np.stack(features_batch)).float() 150 | 151 | if self.args.cuda: 152 | features_batch = features_batch.cuda() 153 | 154 | if self.features_only: 155 | return features_batch 156 | 157 | f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, conv_bt = mol_graph.get_components() # wei fix 158 | 159 | if self.atom_messages: 160 | a2a = mol_graph.get_a2a() 161 | 162 | if self.args.cuda or next(self.parameters()).is_cuda: 163 | f_atoms, f_bonds, a2b, b2a, b2revb, conv_bt = f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda(), conv_bt.cuda() # wei fix 164 | 165 | if self.atom_messages: 166 | a2a = a2a.cuda() 167 | 168 | # Input 169 | if self.atom_messages: # false 170 | input = self.W_i(f_atoms) # num_atoms x hidden_size 171 | else: 172 | input = self.W_i(f_bonds) # num_bonds x hidden_size 173 | message = self.act_func(input) # num_bonds x hidden_size 174 | 175 | # Message passing 176 | for depth in range(self.depth - 1): 177 | if self.undirected: 178 | message = (message + message[b2revb]) / 2 179 | 180 | if self.atom_messages: # False 181 | nei_a_message = index_select_ND(message, a2a) # num_atoms x max_num_bonds x hidden 182 | nei_f_bonds = index_select_ND(f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim 183 | nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim 184 | message = nei_message.sum(dim=1) # num_atoms x hidden + bond_fdim 185 | else: 186 | # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) 187 | # message a_message = sum(nei_a_message) rev_message 188 | nei_a_message = index_select_ND(message, a2b) # num_atoms x max_num_bonds x hidden 189 | a_message = nei_a_message.sum(dim=1) # num_atoms x hidden 190 | rev_message = message[b2revb] # num_bonds x hidden 191 | message = a_message[b2a] - rev_message # num_bonds x hidden 192 | 193 | message = self.W_h(message) 194 | message = self.act_func(input + message) # num_bonds x hidden_size 195 | message = self.dropout_layer(message) # num_bonds x hidden 196 | 197 | a2x = a2a if self.atom_messages else a2b 198 | nei_a_message = index_select_ND(message, a2x) # num_atoms x max_num_bonds x hidden 199 | a_message = nei_a_message.sum(dim=1) # num_atoms x hidden 200 | a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) 201 | ############ without relu ############## 202 | atom_hiddens = self.W_o(a_input) # num_atoms x hidden # norelu!!, self.act_func(self.W_o(a_input)) 203 | ############ without relu ############## 204 | atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden 205 | 206 | # Readout 207 | mol_vecs = [] 208 | cov_vecs = [] 209 | asize_vecs = [] 210 | for i, (a_start, a_size) in enumerate(a_scope): 211 | if a_size == 0: 212 | mol_vecs.append(self.cached_zero_vector) 213 | else: 214 | cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) 215 | mol_vec = cur_hiddens # (num_atoms, hidden_size) 216 | 217 | if self.fp_method != 'atomic': 218 | if self.aggregation == 'mean': 219 | mol_vec_molar = mol_vec.sum(dim=0) / a_size 220 | elif self.aggregation == 'sum': 221 | mol_vec_molar = mol_vec.sum(dim=0) 222 | elif self.aggregation == 'norm': 223 | mol_vec_molar = mol_vec.sum(dim=0) / self.aggregation_norm 224 | 225 | # mol_vec_molar is molecular fingerprint 226 | if self.fp_method == 'molecular': 227 | mol_vecs.append(mol_vec_molar) 228 | 229 | else: # atomic or hybrid method 230 | if self.fp_method == 'atomic': 231 | pass 232 | 
elif self.fp_method == 'hybrid_dim0': 233 | mol_vec = torch.cat((mol_vec, mol_vec_molar.view(1, -1)), dim=0) 234 | assert mol_vec.shape[0] == (a_size + 1) 235 | else: # 'hybrid_dim1' 236 | assert self.fp_method == 'hybrid_dim1' 237 | mol_vec_molar = mol_vec_molar.repeat(a_size, 1) 238 | assert mol_vec_molar.shape == mol_vec.shape 239 | mol_vec = torch.cat((mol_vec, mol_vec_molar), dim=1) 240 | # padding 241 | new_mol_vec = self.padding(mol_vec) 242 | new_cov_vec = self.cov_func_padding(mol_vec) 243 | 244 | mol_vecs.append(new_mol_vec) 245 | cov_vecs.append(new_cov_vec) 246 | if self.args.intensive_property: 247 | asize_vecs.append(torch.tensor(a_size)) 248 | 249 | mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, num_atoms, hidden_size) 250 | if self.args.cuda: 251 | mol_vecs = mol_vecs.cuda() 252 | if self.fp_method in ['atomic', 'hybrid_dim0', 'hybrid_dim1']: 253 | cov_vecs = torch.stack(cov_vecs, dim=0) 254 | if self.args.cuda: 255 | cov_vecs = cov_vecs.cuda() 256 | 257 | if self.args.intensive_property: 258 | asize_vecs = torch.stack(asize_vecs, dim=0).unsqueeze(1) 259 | if self.args.cuda: 260 | asize_vecs = asize_vecs.cuda() 261 | 262 | 263 | # if self.use_input_features: 264 | # features_batch = features_batch.to(mol_vecs) 265 | # if len(features_batch.shape) == 1: 266 | # features_batch = features_batch.view([1,features_batch.shape[0]]) 267 | # mol_vecs = torch.cat([mol_vecs, features_batch], dim=1) # (num_molecules, num_atoms, hidden_size) 268 | 269 | # num_molecules x num_atoms x hidden , num_molecules x num_atoms^2 x 1 , num_molecules x 1 270 | return mol_vecs, cov_vecs, asize_vecs 271 | 272 | 273 | class MPN(nn.Module): 274 | """A message passing neural network for encoding a molecule.""" 275 | 276 | def __init__(self, 277 | args: Namespace, 278 | atom_fdim: int = None, 279 | bond_fdim: int = None, 280 | graph_input: bool = False): 281 | """ 282 | Initializes the MPN. 283 | 284 | :param args: Arguments. 285 | :param atom_fdim: Atom features dimension. 286 | :param bond_fdim: Bond features dimension. 287 | :param graph_input: If true, expects BatchMolGraph as input. Otherwise expects a list of smiles strings as input. 288 | """ 289 | super(MPN, self).__init__() # equals to nn.Module.__init__() 290 | self.features_only = args.features_only 291 | self.args = args 292 | self.atom_fdim = atom_fdim or get_atom_fdim(args) 293 | self.bond_fdim = bond_fdim or get_bond_fdim(args) + (not args.atom_messages) * self.atom_fdim # self.bond_fdim=145 where bond_fdim=17, atom_fdim=128 294 | self.graph_input = graph_input 295 | self.encoder = MPNEncoder(self.args, self.atom_fdim, self.bond_fdim) 296 | 297 | if self.features_only: 298 | return 299 | 300 | def forward(self, 301 | batch: Union[List[str], BatchMolGraph], 302 | features_batch: List[np.ndarray] = None) -> torch.FloatTensor: 303 | """ 304 | Encodes a batch of molecular SMILES strings. 305 | 306 | :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input is True). 307 | :param features_batch: A list of ndarrays containing additional features. 308 | :return: A PyTorch tensor of shape (num_molecules, num_atoms, hidden_size) containing the encoding of each molecule. 
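The fixed-size batching used in the readout above relies on zero-padding every molecule's per-atom matrix to max_atom_size rows before stacking; a minimal sketch of that step (pad_atoms is an illustrative stand-in for MPNEncoder.padding, and the sizes are examples):

import torch

def pad_atoms(mol_vec: torch.Tensor, max_atom_size: int = 10) -> torch.Tensor:
    padded = torch.zeros(max_atom_size, mol_vec.size(1))
    padded[:mol_vec.size(0)] = mol_vec     # unused rows stay zero
    return padded

mols = [torch.randn(3, 300), torch.randn(7, 300)]       # a 3-atom and a 7-atom molecule
mol_vecs = torch.stack([pad_atoms(m) for m in mols])    # (num_molecules, 10, 300)
print(mol_vecs.shape)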
309 | """ 310 | if not self.graph_input: # if features only, batch won't even be used 311 | batch = mol2graph(batch, self.args) 312 | 313 | output = self.encoder.forward(batch, features_batch) 314 | return output 315 | 316 | 317 | 318 | -------------------------------------------------------------------------------- /chemprop/train/run_training.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import csv 3 | from logging import Logger 4 | import os 5 | from pprint import pformat 6 | from typing import List 7 | 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | import torch 11 | from tqdm import trange 12 | import pickle 13 | from torch.optim.lr_scheduler import ExponentialLR 14 | import random 15 | 16 | from .evaluate import evaluate, evaluate_predictions 17 | from .predict import predict 18 | from .train import train 19 | from chemprop.data import StandardScaler, MoleculeDataset 20 | from chemprop.data.utils import get_class_sizes, get_data, get_task_names, split_data 21 | from chemprop.models import build_model 22 | from chemprop.nn_utils import param_count 23 | from chemprop.utils import build_optimizer, build_lr_scheduler, get_loss_func, get_metric_func, load_checkpoint,\ 24 | makedirs, save_checkpoint, transfer_learning_check 25 | import pandas as pd 26 | 27 | 28 | def run_training(args: Namespace, logger: Logger = None) -> List[float]: 29 | """ 30 | Trains a model and returns test scores on the model checkpoint with the highest validation score. 31 | 32 | :param args: Arguments. 33 | :param logger: Logger. 34 | :return: A list of ensemble scores for each task. 35 | """ 36 | if logger is not None: 37 | debug, info = logger.debug, logger.info 38 | else: 39 | debug = info = print 40 | 41 | # Set GPU 42 | if args.gpu is not None: 43 | torch.cuda.set_device(args.gpu) 44 | 45 | # Print args 46 | debug(pformat(vars(args))) 47 | 48 | # Get data 49 | debug('Loading data') 50 | args.task_names = get_task_names(args.data_path) 51 | data = get_data(path=args.data_path, args=args, logger=logger) 52 | args.num_tasks = data.num_tasks() 53 | args.features_size = data.features_size() 54 | debug(f'Number of tasks = {args.num_tasks}') 55 | 56 | # Split data 57 | debug(f'Splitting data with seed {args.seed}') 58 | if args.separate_test_path: 59 | test_data = get_data(path=args.separate_test_path, args=args, features_path=args.separate_test_features_path, logger=logger) 60 | if args.separate_val_path: 61 | val_data = get_data(path=args.separate_val_path, args=args, features_path=args.separate_val_features_path, logger=logger) 62 | 63 | if args.separate_val_path and args.separate_test_path: 64 | train_data = data 65 | elif args.separate_val_path: 66 | train_data, _, test_data = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger) 67 | elif args.separate_test_path: 68 | train_data, val_data, _ = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger) 69 | else: 70 | train_data, val_data, test_data = split_data(data=data, split_type=args.split_type, sizes=args.split_sizes, seed=args.seed, args=args, logger=logger) 71 | 72 | if args.dataset_type == 'classification': 73 | raise ValueError('Classification is not supported.') 74 | 75 | if args.save_smiles_splits: 76 | with open(args.data_path, 'r') as f: 77 | reader = csv.reader(f) 78 | header = next(reader) 79 | 80 | all_split_indices = [] 81 | 
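        # The loop below writes, for each split, <name>_smiles.csv (SMILES only) and
        # <name>_full.csv (SMILES plus targets, using the original header), collects the
        # sorted [smiles, *targets] rows of every split, and finally pickles them all to
        # split_indices.pckl in args.save_dir. When a separate test file is supplied,
        # only the train and val splits are written.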
datasets = [train_data, val_data] if args.separate_test_path else [train_data, val_data, test_data] 82 | names = ['train', 'val'] if args.separate_test_path else ['train', 'val', 'test'] 83 | for dataset, name in zip(datasets, names): 84 | with open(os.path.join(args.save_dir, name + '_smiles.csv'), 'w') as f: 85 | writer = csv.writer(f) 86 | writer.writerow(['smiles']) 87 | for smiles in dataset.smiles(): 88 | writer.writerow([smiles]) 89 | with open(os.path.join(args.save_dir, name + '_full.csv'), 'w') as f: 90 | writer = csv.writer(f) 91 | writer.writerow(header) 92 | dataset_targets = dataset.targets() 93 | for i, smiles in enumerate(dataset.smiles()): 94 | writer.writerow([smiles] + dataset_targets[i]) 95 | split_indices = [] 96 | for i, smiles in enumerate(dataset.smiles()): 97 | split_indices.append([smiles] + dataset_targets[i]) 98 | split_indices = sorted(split_indices) 99 | all_split_indices.append(split_indices) 100 | with open(os.path.join(args.save_dir, 'split_indices.pckl'), 'wb') as f: 101 | pickle.dump(all_split_indices, f) 102 | 103 | if args.features_scaling: 104 | features_scaler = train_data.normalize_features(replace_nan_token=0) 105 | val_data.normalize_features(features_scaler) 106 | test_data.normalize_features(features_scaler) 107 | else: 108 | features_scaler = None 109 | 110 | args.train_data_size = len(train_data) 111 | 112 | debug(f'Total size = {len(data):,} | ' 113 | f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}') 114 | 115 | # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only) 116 | if args.dataset_type == 'regression': 117 | debug('Fitting scaler') 118 | train_smiles, train_targets = train_data.smiles(), train_data.targets() 119 | 120 | if args.y_scaling: # devide with training data standard deviation 121 | debug(f'{args.fp_method} scale y value') 122 | scaler = StandardScaler().fit(train_targets) 123 | scaled_targets = scaler.transform(train_targets).tolist() 124 | 125 | else: # without scaling 126 | debug(f'{args.fp_method} unscale y value') 127 | scaler = None 128 | scaled_targets = train_targets 129 | 130 | train_data.set_targets(scaled_targets) 131 | else: 132 | scaler = None 133 | 134 | # Get loss and metric functions 135 | loss_func = get_loss_func(args) 136 | metric_func = get_metric_func(metric=args.metric) 137 | 138 | # Set up test set evaluation 139 | test_smiles, test_targets = test_data.smiles(), test_data.targets() 140 | sum_test_preds = np.zeros((len(test_smiles), args.num_tasks)) 141 | sum_test_ales = np.zeros((len(test_smiles), args.num_tasks)) 142 | all_test_preds = np.zeros((len(test_data), args.num_tasks, args.ensemble_size)) 143 | 144 | 145 | # Train ensemble of models 146 | for model_idx in range(args.ensemble_size): 147 | # Tensorboard writer 148 | save_dir = os.path.join(args.save_dir, f'model_{model_idx}') 149 | makedirs(save_dir) 150 | try: 151 | writer = SummaryWriter(log_dir=save_dir) 152 | except: 153 | writer = SummaryWriter(logdir=save_dir) 154 | # Load/build model 155 | if args.checkpoint_paths is not None: 156 | debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}') 157 | model = load_checkpoint(args.checkpoint_paths[model_idx], current_args=args, logger=logger) 158 | if args.transfer_learning_freeze_GCNN: 159 | model = transfer_learning_check(model, args.transfer_learning_freeze_GCNN, logger=logger) # transfer learning, freeze GCNN layer. 
160 | else: 161 | debug(f'Building model {model_idx}') 162 | model = build_model(args) 163 | debug(model) 164 | debug(f'Number of parameters = {param_count(model):,}') 165 | 166 | #for param in model.parameters(): 167 | # print(param.requires_grad) 168 | 169 | if args.cuda: 170 | debug('Moving model to cuda') 171 | model = model.cuda() 172 | 173 | # Ensure that model is saved in correct location for evaluation if 0 epochs 174 | save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args) 175 | 176 | # Optimizers 177 | optimizer = build_optimizer(model, args) 178 | 179 | # Learning rate schedulers 180 | debug(f'train_data_size: {args.train_data_size}, args.batch_size: {args.batch_size}') 181 | scheduler = build_lr_scheduler(optimizer, args) 182 | 183 | # Run training 184 | best_score = float('inf') if args.minimize_score else -float('inf') 185 | best_epoch, n_iter = 0, 0 186 | early_stopping_step = 0 187 | early_stopping = args.early_stopping 188 | for epoch in range(args.epochs): 189 | debug(f'Epoch {epoch}') 190 | 191 | n_iter = train( 192 | model=model, 193 | data=train_data, 194 | loss_func=loss_func, 195 | optimizer=optimizer, 196 | scheduler=scheduler, 197 | args=args, 198 | n_iter=n_iter, 199 | logger=logger, 200 | writer=writer 201 | ) 202 | if isinstance(scheduler, ExponentialLR): 203 | scheduler.step() 204 | val_scores, val_rmses, val_maes = evaluate( 205 | model=model, 206 | data=val_data, 207 | num_tasks=args.num_tasks, 208 | metric_func=metric_func, 209 | batch_size=args.batch_size, 210 | dataset_type=args.dataset_type, 211 | scaler=scaler, 212 | logger=logger, 213 | sampling_size=args.sampling_size, 214 | fp_method=args.fp_method 215 | ) 216 | 217 | # Average validation score 218 | avg_val_score = np.nanmean(val_scores) 219 | avg_val_rmse = np.nanmean(val_rmses) 220 | avg_val_mae = np.nanmean(val_maes) 221 | debug(f'Validation {args.metric} = {avg_val_score:.6f} | Validation rmse = {avg_val_rmse:.6f} | Validation mae = {avg_val_mae:.6f}') 222 | writer.add_scalar(f'validation_{args.metric}', avg_val_score, n_iter) 223 | writer.add_scalar('validation_rmse', avg_val_rmse, n_iter) 224 | writer.add_scalar('validation_mae', avg_val_mae, n_iter) 225 | if args.show_individual_scores: 226 | # Individual validation scores 227 | for task_name, val_score in zip(args.task_names, val_scores): 228 | debug(f'Validation {args.metric} = {avg_val_score:.6f} | Validation rmse = {avg_val_rmse:.6f} | Validation mae = {avg_val_mae:.6f}') 229 | writer.add_scalar(f'validation_{task_name}_{args.metric}', val_score, n_iter) 230 | 231 | # Save model checkpoint if improved validation score 232 | if args.early_stopping_metric == 'heteroscedastic': 233 | save_val_score = avg_val_score 234 | elif args.early_stopping_metric == 'rmse': 235 | save_val_score = avg_val_rmse 236 | elif args.early_stopping_metric == 'mae': 237 | save_val_score = avg_val_mae 238 | else: 239 | raise ValueError(f'args.early_stopping_metric: {args.early_stopping_metric} not supported.') 240 | if args.minimize_score and save_val_score < best_score or \ 241 | not args.minimize_score and save_val_score > best_score: 242 | early_stopping_step = 0 243 | best_score, best_epoch = save_val_score, epoch 244 | save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args) 245 | else: 246 | early_stopping_step += 1 247 | 248 | # break if early stopping happens 249 | if early_stopping_step == early_stopping: 250 | debug(f'STOPPING CONDITION IS MET!! 
epoch:{epoch}') 251 | break 252 | 253 | # Evaluate on test set using model with best validation score 254 | info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}') 255 | model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda, logger=logger) 256 | 257 | 258 | test_preds, test_ales, _, _, _ = predict( 259 | model=model, 260 | data=test_data, 261 | batch_size=args.batch_size, 262 | scaler=scaler, 263 | sampling_size=args.sampling_size, 264 | fp_method=args.fp_method 265 | ) 266 | test_scores, test_rmse, test_mae = evaluate_predictions( 267 | preds=test_preds, 268 | targets=test_targets, 269 | ales=test_ales, 270 | num_tasks=args.num_tasks, 271 | metric_func=metric_func, 272 | dataset_type=args.dataset_type, 273 | logger=logger 274 | ) 275 | if len(test_preds) != 0: 276 | sum_test_preds += np.array(test_preds) 277 | all_test_preds[:, :, model_idx] = test_preds 278 | if args.aleatoric: 279 | sum_test_ales += np.array(test_ales) 280 | 281 | # Average test score 282 | avg_test_score = np.nanmean(test_scores) 283 | avg_test_rmse = np.nanmean(test_rmse) 284 | avg_test_mae = np.nanmean(test_mae) 285 | info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f} | test rmse = {avg_test_rmse:.6f} | test mae = {avg_test_mae:.6f}') 286 | writer.add_scalar(f'test_{args.metric}', avg_test_score, 0) 287 | writer.add_scalar(f'test_rmse', avg_test_rmse, 0) 288 | writer.add_scalar(f'test_mae', avg_test_mae, 0) 289 | if args.show_individual_scores: 290 | # Individual test scores 291 | for task_name, test_score in zip(args.task_names, test_scores): 292 | info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f} | test rmse = {avg_test_rmse:.6f} | test mae = {avg_test_mae:.6f}') 293 | writer.add_scalar(f'test_{task_name}_{args.metric}', test_score, n_iter) 294 | 295 | # Evaluate ensemble on test set 296 | avg_test_preds = (sum_test_preds / args.ensemble_size).tolist() 297 | avg_test_ales = (sum_test_ales / args.ensemble_size).tolist() 298 | avg_epi_uncs = np.var(all_test_preds, axis=2).tolist() 299 | 300 | avg_test_ales = avg_test_ales if args.aleatoric else None 301 | 302 | ensemble_scores, ensemble_rmse, ensemble_mae = evaluate_predictions( 303 | preds=avg_test_preds, 304 | targets=test_targets, 305 | ales=avg_test_ales, 306 | num_tasks=args.num_tasks, 307 | metric_func=metric_func, 308 | dataset_type=args.dataset_type, 309 | logger=logger 310 | ) 311 | 312 | # save testing results in current saved_model folder 313 | test_output = np.hstack((np.array(test_data.smiles())[:, np.newaxis], np.array(test_data.targets()), np.array(avg_test_preds), np.array(avg_test_ales), np.array(avg_epi_uncs), np.array(avg_test_ales) + np.array(avg_epi_uncs))) 314 | assert test_output.shape == (test_data.__len__(), 6) 315 | test_name = args.save_dir.split('/')[-2] 316 | test_name = '_'.join(test_name.split('_')[:2]) + '_test_' + '_'.join(test_name.split('_')[2:]) # qm9_130k_test_rbf_unscale_150_cano_stop_heter_cos_cpu 317 | pd.DataFrame(test_output, columns=['smiles', f'true_{args.task_names[0]}', f'pred_{args.task_names[0]}', f'{args.task_names[0]}_ale_unc', f'{args.task_names[0]}_epi_unc', f'{args.task_names[0]}_total_unc']).to_csv(os.path.join(args.save_dir, f'{test_name}.csv'), index=False) 318 | 319 | 320 | # Average ensemble score 321 | avg_ensemble_test_score = np.nanmean(ensemble_scores) 322 | avg_ensemble_test_rmse = np.nanmean(ensemble_rmse) 323 | avg_ensemble_test_mae = np.nanmean(ensemble_mae) 324 | info(f'Ensemble test {args.metric} = 
{avg_ensemble_test_score:.6f} | Ensemble test rmse = {avg_ensemble_test_rmse:.6f} | Ensemble test mae = {avg_ensemble_test_mae:.6f}') 325 | writer.add_scalar(f'ensemble_test_{args.metric}', avg_ensemble_test_score, 0) 326 | 327 | # Individual ensemble scores 328 | if args.show_individual_scores: 329 | for task_name, ensemble_score in zip(args.task_names, ensemble_scores): 330 | info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}') 331 | return avg_ensemble_test_score, avg_ensemble_test_rmse, avg_ensemble_test_mae 332 | --------------------------------------------------------------------------------
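The ensemble bookkeeping in run_training above reduces to a few array operations: the ensemble prediction is the mean over models, the epistemic uncertainty is the variance of the per-model predictions, the aleatoric uncertainty is the mean of the per-model aleatoric variances, and the total uncertainty is their sum. A small NumPy sketch with random stand-in values:

import numpy as np

all_preds = np.random.randn(5, 1, 3)           # (num_molecules, num_tasks, ensemble_size)
all_ales = np.abs(np.random.randn(5, 1, 3))    # per-model aleatoric variances

avg_preds = all_preds.mean(axis=2)             # ensemble prediction
epi_unc = all_preds.var(axis=2)                # spread across ensemble members
ale_unc = all_ales.mean(axis=2)
total_unc = ale_unc + epi_unc
print(avg_preds.shape, total_unc.shape)        # (5, 1) (5, 1)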