├── .gitignore ├── LICENSE.txt ├── README.md ├── __init__.py ├── chemprop ├── __init__.py ├── data │ ├── __init__.py │ ├── data.py │ ├── scaffold.py │ ├── scaler.py │ ├── similarity.py │ ├── unsupervised_cluster.py │ ├── utils.py │ └── vocab.py ├── data_processing │ ├── __init__.py │ ├── avg_dups.py │ ├── plot_distribution.py │ ├── process_zinc.py │ └── resplit.py ├── features │ ├── __init__.py │ ├── async_featurization.py │ ├── descriptors.py │ ├── featurization.py │ ├── functional_groups.py │ ├── kernels.py │ ├── morgan_fingerprint.py │ ├── rdkit_features.py │ ├── rdkit_normalized_features.py │ ├── smarts.txt │ └── utils.py ├── models │ ├── __init__.py │ ├── gan.py │ ├── jtnn.py │ ├── learned_kernel.py │ ├── model.py │ ├── moe.py │ └── mpn.py ├── nn_utils.py ├── parsing.py ├── random_forest.py ├── train │ ├── __init__.py │ ├── cross_validate.py │ ├── evaluate.py │ ├── make_predictions.py │ ├── predict.py │ ├── run_training.py │ └── train.py └── utils.py ├── data ├── all_smiles.csv.gz ├── bace.csv ├── bbbp.csv ├── clintox.csv ├── delaney.csv ├── freesolv.csv ├── hiv.csv.gz ├── lipo.csv ├── muv.csv.gz ├── pcba.csv.gz ├── pdbbind_core.csv ├── pdbbind_full.csv ├── pdbbind_refined.csv ├── qm7.csv.gz ├── qm8.csv.gz ├── qm9.csv.gz ├── sanitize.py ├── sider.csv ├── tox21.csv └── toxcast.csv.gz ├── distributions ├── chembl │ └── chembl_labels_log.png ├── delaney │ └── delaney_logp.png ├── freesolv │ └── freesolv_freesolv.png ├── lipo │ └── lipo_lipo.png ├── qm8 │ ├── qm8_E1-CAM.png │ ├── qm8_E1-CC2.png │ ├── qm8_E1-PBE0.png │ ├── qm8_E2-CAM.png │ ├── qm8_E2-CC2.png │ ├── qm8_E2-PBE0.png │ ├── qm8_f1-CAM.png │ ├── qm8_f1-CC2.png │ ├── qm8_f1-PBE0.png │ ├── qm8_f2-CAM.png │ ├── qm8_f2-CC2.png │ └── qm8_f2-PBE0.png └── qm9 │ ├── qm9_alpha.png │ ├── qm9_cv.png │ ├── qm9_g298.png │ ├── qm9_g298_atom.png │ ├── qm9_gap.png │ ├── qm9_h298.png │ ├── qm9_h298_atom.png │ ├── qm9_homo.png │ ├── qm9_lumo.png │ ├── qm9_mu.png │ ├── qm9_r2.png │ ├── qm9_u0.png │ ├── qm9_u0_atom.png │ 
├── qm9_u298.png │ ├── qm9_u298_atom.png │ └── qm9_zpve.png ├── hyperparameter_optimization.py ├── model_comparison.py ├── predict.py ├── random_forest.py ├── requirements.txt ├── scripts ├── __init__.py ├── avg_dups.py ├── filter_by_scaffold.py ├── overlap.py ├── plot_distribution.py ├── resplit_data.py ├── save_features.py ├── similarity.py ├── visualize_encoding_property_space.py ├── viz_attention.py └── vocab.py ├── setup.cfg ├── setup.py ├── static ├── jsme │ ├── .directory │ ├── JME_to_JSME.html │ ├── JME_to_JSME_simple.html │ ├── JSME.html │ ├── JSME_SVG.html │ ├── JSME_atom_highlight_demo.html │ ├── JSME_autoresize.html │ ├── JSME_callback_and_star.html │ ├── JSME_chemical_resolver_demo.html │ ├── JSME_depict.html │ ├── JSME_depict_action.html │ ├── JSME_depict_action_callback.html │ ├── JSME_depict_edit_toggle.html │ ├── JSME_depict_smiles.html │ ├── JSME_dnd_demo.html │ ├── JSME_editor_plus_SVG.html │ ├── JSME_hidden.html │ ├── JSME_minimal.html │ ├── JSME_minimal2.html │ ├── JSME_minimal3.html │ ├── JSME_parent_and_metabolites.html │ ├── JSME_resize.html │ ├── JSME_smiles_atom_highlight.html │ ├── JSME_template.html │ ├── JSME_test.html │ ├── api_javadoc │ │ ├── allclasses-frame.html │ │ ├── allclasses-noframe.html │ │ ├── constant-values.html │ │ ├── deprecated-list.html │ │ ├── export │ │ │ └── client │ │ │ │ ├── JSME.html │ │ │ │ ├── Utils.html │ │ │ │ ├── package-frame.html │ │ │ │ ├── package-summary.html │ │ │ │ └── package-tree.html │ │ ├── help-doc.html │ │ ├── index-all.html │ │ ├── index.html │ │ ├── overview-tree.html │ │ ├── package-list │ │ ├── script.js │ │ └── stylesheet.css │ ├── bootstrap.css │ ├── bootstrap.html │ ├── doc.css │ ├── doc.html │ ├── index.html │ ├── jme_examples │ │ ├── jme_example1.html │ │ ├── jme_example2.html │ │ ├── jme_example3.html │ │ ├── jme_example4.html │ │ └── jme_window.html │ ├── jsme │ │ ├── 222ADBFEC322C2723C6ED2C4FB31B217.cache.js │ │ ├── 293DFEFA807A962F28C09E358B34A434.cache.js │ │ ├── 
396F806CD63ABD414BFBB9D57429F05B.cache.png │ │ ├── 40BAF81124143A595056A9CCA0E9DBBA.cache.png │ │ ├── 4841BDE9DC293BA35F7762B4D8EFD236.cache.png │ │ ├── 61B683D3493CAED438D5743A0404863D.cache.js │ │ ├── 6ABB8447ACAB1353A478923AC9C0550B.cache.js │ │ ├── 7A65B607B90DE29D7EA26AA83BF69D4F.cache.js │ │ ├── 8816D61E367E34DBCFA53666849E21D8.cache.js │ │ ├── 8BDB7ED57B756F8D50277056A0D59DA8.cache.js │ │ ├── 96E40B969193BD74B8A621486920E79C.cache.js │ │ ├── A2384E54F71557BAEA414A43D47F17EA.cache.js │ │ ├── A6DBDE07E3A8F66E8959A4F32505E16B.cache.png │ │ ├── C8A71BD2E1367E9BB43A1B9C25871BEE.cache.js │ │ ├── C9EEF554958AACEE6A060F620375E4FA.cache.js │ │ ├── D4DF9EC9DD21B943E35F3D5696D5D2A1.cache.js │ │ ├── D9A64F1634E29088B910B3E0D4621E49.cache.js │ │ ├── DF7764EEC1903CD03C9545B354D8D8E4.cache.png │ │ ├── clear.cache.gif │ │ ├── compilation-mappings.txt │ │ ├── deferredjs │ │ │ ├── 222ADBFEC322C2723C6ED2C4FB31B217 │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── 293DFEFA807A962F28C09E358B34A434 │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── 61B683D3493CAED438D5743A0404863D │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── 6ABB8447ACAB1353A478923AC9C0550B │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── 7A65B607B90DE29D7EA26AA83BF69D4F │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 
4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── 8816D61E367E34DBCFA53666849E21D8 │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── 8BDB7ED57B756F8D50277056A0D59DA8 │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── A2384E54F71557BAEA414A43D47F17EA │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── C8A71BD2E1367E9BB43A1B9C25871BEE │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── C9EEF554958AACEE6A060F620375E4FA │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ ├── D4DF9EC9DD21B943E35F3D5696D5D2A1 │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ │ └── D9A64F1634E29088B910B3E0D4621E49 │ │ │ │ ├── 1.cache.js │ │ │ │ ├── 2.cache.js │ │ │ │ ├── 3.cache.js │ │ │ │ ├── 4.cache.js │ │ │ │ ├── 5.cache.js │ │ │ │ ├── 6.cache.js │ │ │ │ ├── 7.cache.js │ │ │ │ ├── 8.cache.js │ │ │ │ └── 9.cache.js │ │ ├── gwt │ │ │ └── chrome │ │ │ │ ├── chrome.css │ │ │ │ ├── chrome_rtl.css │ │ │ │ ├── images │ │ │ │ ├── button │ │ │ 
│ │ ├── menu-button-arrow-disabled.png │ │ │ │ │ ├── menu-button-arrow.png │ │ │ │ │ ├── split-button-arrow-active.png │ │ │ │ │ ├── split-button-arrow-disabled.png │ │ │ │ │ ├── split-button-arrow-focus.png │ │ │ │ │ ├── split-button-arrow-hover.png │ │ │ │ │ └── split-button-arrow.png │ │ │ │ ├── combobox │ │ │ │ │ ├── arrow-down-disabled.png │ │ │ │ │ ├── arrow-down.png │ │ │ │ │ ├── ellipsis-disabled.png │ │ │ │ │ └── ellipsis.png │ │ │ │ ├── corner.png │ │ │ │ ├── corner_ie6.png │ │ │ │ ├── fastree │ │ │ │ │ ├── selectionBar.gif │ │ │ │ │ ├── treeClosed.gif │ │ │ │ │ ├── treeLoading.gif │ │ │ │ │ └── treeOpen.gif │ │ │ │ ├── glasspanel │ │ │ │ │ └── blue_ridge.png │ │ │ │ ├── hborder.png │ │ │ │ ├── hborder_ie6.png │ │ │ │ ├── ie6 │ │ │ │ │ ├── corner_dialog_topleft.png │ │ │ │ │ ├── corner_dialog_topright.png │ │ │ │ │ ├── hborder_blue_shadow.png │ │ │ │ │ ├── hborder_gray_shadow.png │ │ │ │ │ ├── vborder_blue_shadow.png │ │ │ │ │ └── vborder_gray_shadow.png │ │ │ │ ├── scrolltable │ │ │ │ │ └── bg_header_gradient.gif │ │ │ │ ├── splitPanelThumb.png │ │ │ │ ├── valuespinner │ │ │ │ │ └── bg_textbox.png │ │ │ │ ├── vborder.png │ │ │ │ └── vborder_ie6.png │ │ │ │ ├── mosaic.css │ │ │ │ └── mosaic_rtl.css │ │ ├── jsa.css │ │ ├── jsme.devmode.js │ │ └── jsme.nocache.js │ ├── license.txt │ ├── release_notes.txt │ └── test_depict_many_smiles_in_table.html ├── message_passing.png └── stylesheets │ └── style.css ├── templates ├── checkpoints.html ├── data.html ├── home.html ├── layout.html ├── macros.html ├── predict.html └── train.html ├── train.py └── web.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | deepchem-test/jtnn 3 | deepchem-test/data 4 | *.idea 5 | *.DS_Store 6 | *.vscode 7 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 
Wengong Jin, Kyle Swanson, Kevin Yang, Regina Barzilay, Tommi Jaakkola 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
class StandardScaler:
    """
    Standardizes data column-wise by subtracting means and dividing by standard deviations.

    NaN entries (including None, which becomes NaN when cast to float) are ignored when fitting.
    When transforming, any NaN in the result is replaced by replace_nan_token.
    """

    def __init__(self, means: np.ndarray = None, stds: np.ndarray = None, replace_nan_token=None):
        """
        Initialize StandardScaler, optionally with means and standard deviations precomputed.

        :param means: Optional precomputed means.
        :param stds: Optional precomputed standard deviations.
        :param replace_nan_token: Value substituted for NaN entries in transformed data.
        """
        self.means = means
        self.stds = stds
        self.replace_nan_token = replace_nan_token

    def fit(self, X: List[List[float]]) -> 'StandardScaler':
        """
        Learns means and standard deviations across the 0-th axis.

        Columns without any valid value get mean 0 and standard deviation 1.
        Columns with zero standard deviation get standard deviation 1 to avoid division by zero.

        :param X: A list of lists of floats.
        :return: The fitted StandardScaler.
        """
        import warnings  # local import: only needed to silence the expected warnings below

        X = np.array(X).astype(float)

        # All-NaN columns make nanmean/nanstd emit RuntimeWarnings; those columns are
        # explicitly repaired right after, so the warnings are pure noise here.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=RuntimeWarning)
            means = np.nanmean(X, axis=0)
            stds = np.nanstd(X, axis=0)

        self.means = np.where(np.isnan(means), 0.0, means)
        stds = np.where(np.isnan(stds), 1.0, stds)
        self.stds = np.where(stds == 0, 1.0, stds)

        return self

    def _replace_nan(self, values: np.ndarray) -> np.ndarray:
        """Returns a copy of values with every NaN entry replaced by replace_nan_token."""
        return np.where(np.isnan(values), self.replace_nan_token, values)

    def transform(self, X: List[List[float]]):
        """
        Transforms the data by subtracting the means and dividing by the standard deviations.

        :param X: A list of lists of floats.
        :return: The transformed data, with NaN entries replaced by replace_nan_token.
        """
        X = np.array(X).astype(float)

        return self._replace_nan((X - self.means) / self.stds)

    def inverse_transform(self, X: List[List[float]]):
        """
        Performs the inverse transformation by multiplying by the standard deviations and adding the means.

        :param X: A list of lists of floats.
        :return: The inverse transformed data, with NaN entries replaced by replace_nan_token.
        """
        X = np.array(X).astype(float)

        return self._replace_nan(X * self.stds + self.means)
def average_duplicates(args):
    """
    Averages the targets of data points that share a SMILES string and saves the result.

    Each SMILES is written once, in order of first appearance. For every task, the saved
    target is the mean of the non-missing values among the duplicates, or empty if all
    of them are missing.

    :param args: Arguments providing data_path (CSV to read) and save_path (CSV to write).
    """
    print('Loading data')
    header = get_header(args.data_path)
    data = get_data(path=args.data_path)
    print(f'Data size = {len(data):,}')

    # Map SMILES string to lists of targets, remembering first-seen order
    smiles_in_order = []
    smiles_to_targets = defaultdict(list)
    for smiles, targets in zip(data.smiles(), data.targets()):
        if smiles not in smiles_to_targets:
            smiles_in_order.append(smiles)
        smiles_to_targets[smiles].append(targets)

    # Find duplicates
    duplicate_count = 0
    stds = []
    new_data = []
    for smiles in smiles_in_order:
        all_targets = smiles_to_targets[smiles]
        duplicate_count += len(all_targets) - 1
        num_tasks = len(all_targets[0])

        # Group the non-missing values of each task across all duplicates
        targets_by_task = [
            [targets[task] for targets in all_targets if targets[task] is not None]
            for task in range(num_tasks)
        ]

        stds.append([np.std(task_targets) if len(task_targets) > 0 else 0.0 for task_targets in targets_by_task])
        means = [np.mean(task_targets) if len(task_targets) > 0 else None for task_targets in targets_by_task]
        new_data.append((smiles, means))

    # With no data there is nothing to average (np.mean of an empty list is not iterable)
    avg_stds = np.mean(stds, axis=0) if len(stds) > 0 else []

    print(f'Number of duplicates = {duplicate_count:,}')
    print(f'Duplicate standard deviation per task = {", ".join(f"{std:.4e}" for std in avg_stds)}')
    print(f'New data size = {len(new_data):,}')

    # Save new data
    with open(args.save_path, 'w') as f:
        f.write(','.join(header) + '\n')

        for smiles, avg_targets in new_data:
            f.write(smiles + ',' + ','.join(str(value) if value is not None else '' for value in avg_targets) + '\n')
def _read_header_and_rows(path: str):
    """Reads a CSV file and returns its stripped header line and its stripped, non-blank data lines."""
    with open(path, 'r') as f:
        header = f.readline().strip()
        rows = [line.strip() for line in f]

    # Blank lines are not data points; dropping them keeps them out of the size computation
    return header, [row for row in rows if row]


def _chronological_split(rows, train_frac: float):
    """Splits rows into the leading train_frac portion and the remaining portion."""
    cutoff = train_frac * len(rows)
    early = [row for i, row in enumerate(rows) if i < cutoff]
    late = [row for i, row in enumerate(rows) if not i < cutoff]

    return early, late


def resplit(args: Namespace):
    """
    Resplits the train and validation data chronologically.

    Assumes that the data at train_path and val_path are sorted
    chronologically within each file but have been split randomly
    between the two files. This function puts the first (1 - val_frac)
    of both the train and validation data in the new train_save file
    and puts the remaining val_frac of both in the new val_save file.
    That way, the new validation data comes chronologically after
    the new training data.

    Both input files are read completely before any output file is opened,
    so saving over one of the input paths does not truncate data mid-read.
    The header of the train file is used for both output files.

    :param args: Arguments providing train_path, val_path, train_save, val_save and val_frac.
    """
    train_frac = 1 - args.val_frac

    # Read everything up front (inputs are fully consumed before outputs are opened)
    header, train_rows = _read_header_and_rows(args.train_path)
    _, val_rows = _read_header_and_rows(args.val_path)  # validation header is skipped

    train_early, train_late = _chronological_split(train_rows, train_frac)
    val_early, val_late = _chronological_split(val_rows, train_frac)

    # Resplit data: early portions form the new train set, late portions the new validation set
    with open(args.train_save, 'w') as wtf, open(args.val_save, 'w') as wvf:
        wtf.write(header + '\n')
        wvf.write(header + '\n')

        wtf.writelines(row + '\n' for row in train_early + val_early)
        wvf.writelines(row + '\n' for row in train_late + val_late)
def mol2graph_helper(pair):
    """
    Featurizes a single batch.

    :param pair: A (batch, args) tuple, packed together so the function can be used with Pool.map.
    :return: A tuple of the batch and the result of mol2graph on the batch's SMILES.
    """
    batch, args = pair
    return batch, mol2graph(batch.smiles(), args)


def async_mol2graph(q: Queue,
                    data: MoleculeDataset,
                    args: Namespace,
                    num_iters: int,
                    iter_size: int,
                    exit_q: Queue,
                    last_batch: bool=False):
    """
    Featurizes batches of data in worker processes and puts the results on a queue.

    Batches are processed in groups of args.batches_per_queue_group; each group is put
    on the queue as one list of (batch, graph) pairs. After all groups are queued, the
    function blocks on exit_q until something is put there.

    :param q: Queue receiving lists of processed batches.
    :param data: The data to slice into batches of size args.batch_size.
    :param args: Arguments providing batch_size and batches_per_queue_group (also passed to mol2graph).
    :param num_iters: Exclusive upper bound of the batch start offsets.
    :param iter_size: Step between consecutive batch start offsets.
    :param exit_q: Queue read once at the end to delay exiting.
    :param last_batch: Whether to keep a final batch that extends past the end of the data.
    """
    def process_group(group):
        # a whole group at a time, since synchronization is expensive
        with Pool() as pool:
            processed_batches = pool.map(mol2graph_helper, [(batch, args) for batch in group])
            q.put(processed_batches)

    pending = []
    for start in range(0, num_iters, iter_size):  # will only go up to max size of queue, then yield
        stop = start + args.batch_size
        if stop > len(data) and not last_batch:
            break
        pending.append(MoleculeDataset(data[start:stop]))
        if len(pending) == args.batches_per_queue_group:
            process_group(pending)
            pending = []

    # Leftover batches that did not fill a complete group
    if pending:
        process_group(pending)

    # prevent from exiting until main process tells it to; otherwise we apparently can't read the end of the queue and crash
    exit_q.get()
class FunctionalGroupFeaturizer:
    """
    Class for extracting feature vector of indicators for atoms being parts of functional groups.
    """
    def __init__(self, args: Namespace):
        """
        Loads the functional group patterns.

        :param args: Arguments providing functional_group_smarts, the path to a text file
        containing one SMARTS pattern per line.
        """
        with open(args.functional_group_smarts, 'r') as f:
            self.smarts = [Chem.MolFromSmarts(line.strip()) for line in f]

    def featurize(self, smiles: Union[Chem.Mol, str]) -> List[List[int]]:
        """
        Given a molecule in SMILES form, return a feature vector of indicators for each atom,
        indicating whether the atom is part of each functional group.
        Can also directly accept a Chem molecule.
        Searches through the functional groups loaded in the constructor.

        :param smiles: A smiles string representing a molecule, or a Chem molecule.
        :return: A list of lists of shape num_atoms x num_features (num functional groups),
        with entry 1.0 where the atom is part of the functional group and 0.0 otherwise.
        """
        # isinstance rather than an exact type comparison, so that str subclasses
        # (e.g. numpy.str_) are parsed as SMILES instead of being treated as molecules
        if isinstance(smiles, str):
            mol = Chem.MolFromSmiles(smiles)  # turns out rdkit knows to match even without adding Hs
        else:
            mol = smiles

        features = np.zeros((mol.GetNumAtoms(), len(self.smarts)))  # num atoms (without Hs) x num features
        for i, smarts in enumerate(self.smarts):
            for group in mol.GetSubstructMatches(smarts):
                for idx in group:
                    features[idx, i] = 1

        return features.tolist()
15 | """ 16 | if type(smiles) == str: 17 | mol = Chem.MolFromSmiles(smiles) 18 | else: 19 | mol = smiles 20 | if use_counts: 21 | fp_vect = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits) 22 | else: 23 | fp_vect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits) 24 | fp = np.zeros((1,)) 25 | DataStructs.ConvertToNumpyArray(fp_vect, fp) 26 | 27 | return fp 28 | -------------------------------------------------------------------------------- /chemprop/features/rdkit_features.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import numpy as np 3 | 4 | from rdkit import Chem 5 | # from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges 6 | from rdkit.Chem.GraphDescriptors import BalabanJ, BertzCT, Ipc, Chi0, Chi0n, Chi0v, Chi1, Chi1n, Chi1v, Chi2n, Chi2v, \ 7 | Chi3n, Chi3v, Chi4n, Chi4v, HallKierAlpha, Kappa1, Kappa2, Kappa3 8 | from rdkit.Chem.Crippen import MolLogP, MolMR 9 | from rdkit.Chem.Descriptors import MolWt, ExactMolWt, HeavyAtomMolWt, NumValenceElectrons 10 | from rdkit.Chem.Lipinski import HeavyAtomCount, NHOHCount, NOCount, NumHAcceptors, NumHDonors, NumHeteroatoms, \ 11 | NumRotatableBonds, NumAromaticRings, NumSaturatedRings, NumAliphaticRings, NumAromaticHeterocycles, \ 12 | NumAromaticCarbocycles, NumSaturatedHeterocycles, NumSaturatedCarbocycles, NumAliphaticHeterocycles, \ 13 | NumAliphaticCarbocycles, RingCount, FractionCSP3 14 | from rdkit.Chem.rdMolDescriptors import CalcNumSpiroAtoms, CalcNumBridgeheadAtoms, CalcTPSA, CalcLabuteASA, PEOE_VSA_, \ 15 | SMR_VSA_, SlogP_VSA_, MQNs_, CalcAUTOCORR2D, CalcNumAmideBonds 16 | from rdkit.Chem.Fragments import fr_ketone_Topliss 17 | from rdkit.Chem.EState.EState_VSA import EState_VSA1, EState_VSA2, EState_VSA3, EState_VSA4, EState_VSA5, EState_VSA6, \ 18 | EState_VSA7, EState_VSA8, EState_VSA9, EState_VSA10, EState_VSA11, VSA_EState1, VSA_EState2, VSA_EState3, \ 19 | VSA_EState4, 
def rdkit_2d_features(smiles: str, args: Namespace):
    """
    Computes a vector of RDKit 2D descriptors followed by functional group features for a molecule.

    Every function in FEATURE_FUNCTIONS is applied to the molecule. Functions returning a list
    contribute all of their elements, all other return values contribute a single element.
    The functional group features produced by FunctionalGroupFeaturizer are summed over axis 0
    and appended at the end.

    :param smiles: A SMILES string. Any non-string value is used directly as the molecule object.
    :param args: Arguments passed to the FunctionalGroupFeaturizer constructor.
    :return: A 1D numpy array of features, passed through np.nan_to_num and clipped to [-100, 100].
    """
    mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else smiles

    features = []
    fallback_mol = None  # parsed lazily, at most once per call
    for feature_function in FEATURE_FUNCTIONS:
        try:
            feature = feature_function(mol)
        except Exception:  # very very rarely, something like BalabanJ crashes
            # Substitute the value computed on a fixed dummy molecule so the feature vector keeps its length.
            # Only Exception is caught so that KeyboardInterrupt / SystemExit still propagate.
            if fallback_mol is None:
                fallback_mol = Chem.MolFromSmiles('c1ccc2cc(CC3=NCCN3)ccc2c1')
            feature = feature_function(fallback_mol)

        if isinstance(feature, list):
            features.extend(feature)
        else:
            features.append(feature)

    # TODO: these following ones take about half of the computation time, can we make it more efficient somehow?
    # not sure tho
    fg_featurizer = FunctionalGroupFeaturizer(args)
    fg_features = fg_featurizer.featurize(mol)
    features += np.array(fg_features).sum(axis=0).tolist()

    # Replace nan/inf and bound the magnitude of every feature
    features = np.clip(np.nan_to_num(np.array(features)), -1e2, 1e2)

    return features
directory of .pckl files. 19 | 20 | If path is a directory, assumes features are saved in files named 0.pckl, 1.pckl, ... 21 | 22 | :param path: Path to a .pckl file or a directory of .pckl files named as above. 23 | :return: A list of numpy arrays containing the features. 24 | """ 25 | if os.path.isfile(path): 26 | with open(path, 'rb') as f: 27 | features = pickle.load(f) 28 | features = [np.squeeze(np.array(feat.todense())) for feat in features] 29 | else: 30 | features = [] 31 | features_num = 0 32 | features_path = os.path.join(path, f'{features_num}.pckl') 33 | 34 | while os.path.exists(features_path): 35 | with open(features_path, 'rb') as f: 36 | feats = pickle.load(f) 37 | features.extend([np.squeeze(np.array(feat.todense())) for feat in feats]) 38 | 39 | features_num += 1 40 | features_path = os.path.join(path, f'{features_num}.pckl') 41 | 42 | return features 43 | 44 | 45 | def get_features_func(features_generator: str, 46 | args: Namespace = None) -> Union[Callable[[Chem.Mol], np.ndarray], 47 | partial]: 48 | if features_generator == 'morgan': 49 | return partial(morgan_fingerprint, use_counts=False) 50 | 51 | if features_generator == 'morgan_count': 52 | return partial(morgan_fingerprint, use_counts=True) 53 | 54 | if features_generator == 'rdkit_2d': 55 | assert args is not None 56 | assert hasattr(args, 'functional_group_smarts') # TODO: handle this in a better way 57 | return partial(rdkit_2d_features, args=args) 58 | 59 | if features_generator == "rdkit_2d_normalized": 60 | from .rdkit_normalized_features import rdkit_2d_normalized_features 61 | if rdkit_2d_normalized_features is None: 62 | logging.getLogger(__name__).warning("Please install descriptastorus for normalized descriptors") 63 | raise ValueError(f'feature_generator type "{features_generator}" not installed') 64 | 65 | return rdkit_2d_normalized_features 66 | 67 | if features_generator == 'mordred': 68 | return mordred_features 69 | 70 | raise ValueError(f'features_generator type 
class LearnedKernel(nn.Module):
    """Bilinear kernel which scores a pair of encodings as sum(A(second) * first) with a learned linear map A."""

    def __init__(self, args: Namespace):
        """
        Initializes the kernel.

        :param args: Arguments; args.ffn_hidden_size is used as both the input and output size of the linear map.
        """
        super().__init__()
        self.A = nn.Linear(args.ffn_hidden_size, args.ffn_hidden_size)

    def forward(self, encodings: torch.Tensor):
        """
        Computes the kernel value for each pair of encodings.

        :param encodings: A tensor of shape (num_pairs, 2, ffn_hidden_size).
        :return: A tensor of shape (num_pairs, 1) containing one kernel value per pair.
        """
        # Integer indexing already removes the pair dimension, so no squeeze is needed.
        # A squeeze(1) here would wrongly drop the feature dimension when ffn_hidden_size == 1.
        first = encodings[:, 0, :]
        second = encodings[:, 1, :]

        return (self.A(second) * first).sum(dim=1, keepdim=True)
"""k-fold cross validation""" 15 | info = logger.info if logger is not None else print 16 | 17 | # Initialize relevant variables 18 | init_seed = args.seed 19 | save_dir = args.save_dir 20 | task_names = get_task_names(args.data_path) 21 | desired_labels = get_desired_labels(args, task_names) 22 | 23 | # Run training on different random seeds for each fold 24 | all_scores = [] 25 | for fold_num in range(args.num_folds): 26 | info(f'Fold {fold_num}') 27 | args.seed = init_seed + fold_num 28 | args.save_dir = os.path.join(save_dir, f'fold_{fold_num}') 29 | os.makedirs(args.save_dir, exist_ok=True) 30 | model_scores = run_training(args, logger) 31 | all_scores.append(model_scores) 32 | all_scores = np.array(all_scores) 33 | 34 | # Report results 35 | info(f'{args.num_folds}-fold cross validation') 36 | 37 | # Report scores for each fold 38 | for fold_num, scores in enumerate(all_scores): 39 | info(f'Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}') 40 | 41 | if args.show_individual_scores: 42 | for task_name, score in zip(task_names, scores): 43 | if task_name in desired_labels: 44 | info(f'Seed {init_seed + fold_num} ==> test {task_name} {args.metric} = {score:.6f}') 45 | 46 | # Report scores across models 47 | avg_scores = np.nanmean(all_scores, axis=1) # average score for each model across tasks 48 | mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores) 49 | info(f'Overall test {args.metric} = {mean_score:.6f} +/- {std_score:.6f}') 50 | 51 | if args.show_individual_scores: 52 | for task_num, task_name in enumerate(task_names): 53 | if task_name in desired_labels: 54 | info(f'Overall test {task_name} {args.metric} = ' 55 | f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}') 56 | 57 | if args.num_chunks > 1: 58 | shutil.rmtree(args.chunk_temp_dir) 59 | 60 | return mean_score, std_score 61 | -------------------------------------------------------------------------------- 
def make_predictions(args: Namespace, smiles: List[str] = None) -> List[Optional[List[float]]]:
    """
    Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on args.test_path.

    The predictions of every checkpoint in args.checkpoint_paths are averaged and written as a CSV
    to args.preds_path. Datapoints whose molecule is None (invalid SMILES) receive a None prediction
    and empty cells in the CSV.

    :param args: Arguments.
    :param smiles: Smiles to make predictions on.
    :return: A list with one entry per input datapoint: a list of target predictions, or None for invalid SMILES.
    """
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    print('Loading training args')
    scaler, features_scaler = load_scalers(args.checkpoint_paths[0])
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments (values already present on args take precedence)
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
    else:
        full_data = get_data(path=args.test_path, args=args, use_compound_names=args.compound_names, skip_invalid_smiles=False)

    print('Validating SMILES')
    valid_indices = [i for i in range(len(full_data)) if full_data[i].mol is not None]
    test_data = MoleculeDataset([full_data[i] for i in valid_indices])

    # Edge case if empty list of smiles is provided (or no smiles is valid)
    if len(test_data) == 0:
        return [None] * len(full_data)

    # SMILES and compound names are taken from the full dataset because the rows written below
    # are indexed over the full dataset, not over the valid subset
    full_smiles = full_data.smiles()
    if args.compound_names:
        compound_names = full_data.compound_names()

    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if train_args.features_scaling:
        test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    num_models = len(args.checkpoint_paths)
    sum_preds = np.zeros((len(test_data), args.num_tasks))
    print(f'Predicting with an ensemble of {num_models} models')
    for checkpoint_path in tqdm(args.checkpoint_paths, total=num_models):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda)
        model_preds = predict(
            model=model,
            data=test_data,
            args=args,
            scaler=scaler
        )
        sum_preds += np.array(model_preds)

    # Ensemble predictions: divide by the number of models that were actually summed
    valid_preds = (sum_preds / num_models).tolist()
    assert len(test_data) == len(valid_preds)

    # Put Nones for invalid smiles
    avg_preds = [None] * len(full_data)
    for pred, full_index in zip(valid_preds, valid_indices):
        avg_preds[full_index] = pred

    # Write predictions
    print(f'Saving predictions to {args.preds_path}')
    with open(args.preds_path, 'w', newline='') as f:  # newline='' as required by the csv module
        writer = csv.writer(f)

        header = []
        if args.write_smiles:
            header.append('smiles')
        if args.compound_names:
            header.append('compound_names')
        header.extend(args.task_names)
        writer.writerow(header)

        for i, preds in enumerate(avg_preds):
            row = []

            if args.write_smiles:
                row.append(full_smiles[i])
            if args.compound_names:
                row.append(compound_names[i])

            if preds is not None:
                row.extend(preds)
            else:
                row.extend([''] * args.num_tasks)

            writer.writerow(row)

    return avg_preds
def sanitize(data_path: str, save_path: str):
    """
    Copies a CSV file, keeping only the rows whose first column is a SMILES string that RDKit can parse.

    The header row is copied unchanged. Rows that are empty, have an empty first column, or for which
    Chem.MolFromSmiles returns None are dropped.

    :param data_path: Path to the CSV file to sanitize.
    :param save_path: Path to the CSV file where the sanitized data is written.
    """
    with open(data_path, newline='') as f:
        reader = csv.reader(f)
        header = next(reader)
        # `line and ...` guards against blank rows, which csv.reader yields as empty lists
        lines = [line for line in reader
                 if line and line[0] != '' and Chem.MolFromSmiles(line[0]) is not None]

    # The file must be opened for writing; newline='' is required by the csv module
    with open(save_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(lines)
https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/delaney/delaney_logp.png -------------------------------------------------------------------------------- /distributions/freesolv/freesolv_freesolv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/freesolv/freesolv_freesolv.png -------------------------------------------------------------------------------- /distributions/lipo/lipo_lipo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/lipo/lipo_lipo.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_E1-CAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_E1-CAM.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_E1-CC2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_E1-CC2.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_E1-PBE0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_E1-PBE0.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_E2-CAM.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_E2-CAM.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_E2-CC2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_E2-CC2.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_E2-PBE0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_E2-PBE0.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_f1-CAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_f1-CAM.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_f1-CC2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_f1-CC2.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_f1-PBE0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_f1-PBE0.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_f2-CAM.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_f2-CAM.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_f2-CC2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_f2-CC2.png -------------------------------------------------------------------------------- /distributions/qm8/qm8_f2-PBE0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm8/qm8_f2-PBE0.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_alpha.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_cv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_cv.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_g298.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_g298.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_g298_atom.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_g298_atom.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_gap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_gap.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_h298.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_h298.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_h298_atom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_h298_atom.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_homo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_homo.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_lumo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_lumo.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_mu.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_mu.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_r2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_r2.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_u0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_u0.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_u0_atom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_u0_atom.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_u298.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_u298.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_u298_atom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wengong-jin/chemprop/3ad3577367d8a53f28aade0be41b56b1f25b6125/distributions/qm9/qm9_u298_atom.png -------------------------------------------------------------------------------- /distributions/qm9/qm9_zpve.png: -------------------------------------------------------------------------------- 
def grid_search(args: Namespace):
    """
    Runs hyperparameter optimization with hyperopt and saves the best hyperparameters as a JSON file.

    Each trial samples hyperparameters from SPACE, cross validates a model with them, and records the
    mean and standard deviation of the score together with the model's parameter count. The best trial
    (lowest score if args.minimize_score, otherwise highest) is logged and its hyperparameters are
    written to args.config_save_path.

    :param args: Arguments. In addition to the training arguments, num_iters, config_save_path and
    log_path are read here.
    :raises ValueError: If a trial of a non-classification dataset has a nan score, or if no trial
    produced a non-nan score.
    """
    # Create logger for dataset
    logger = create_logger(name='hyperparameter_optimization', save_path=args.log_path)

    # Results of all trials, filled in by the objective
    results = []

    # Define hyperparameter optimization
    def objective(hyperparams: Dict[str, Union[int, float]]) -> float:
        # Convert hyperparams from float to int when necessary
        for key in INT_KEYS:
            hyperparams[key] = int(hyperparams[key])

        # Update args with hyperparams; each trial gets its own copy and its own save directory
        hyper_args = deepcopy(args)
        if args.save_dir is not None:
            folder_name = '_'.join(f'{key}_{value}' for key, value in hyperparams.items())
            hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
        for key, value in hyperparams.items():
            setattr(hyper_args, key, value)

        # Record hyperparameters
        logger.info(hyperparams)

        # Cross validate
        mean_score, std_score = cross_validate(hyper_args, TRAIN_LOGGER)

        # Record results
        temp_model = build_model(hyper_args)
        num_params = param_count(temp_model)
        logger.info(f'num params: {num_params:,}')
        logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')

        results.append({
            'mean_score': mean_score,
            'std_score': std_score,
            'hyperparams': hyperparams,
            'num_params': num_params
        })

        # Deal with nan
        if np.isnan(mean_score):
            if hyper_args.dataset_type == 'classification':
                mean_score = 0
            else:
                raise ValueError('Can\'t handle nan score for non-classification dataset.')

        # hyperopt minimizes the objective, so negate the score when a higher score is better
        return (1 if hyper_args.minimize_score else -1) * mean_score

    fmin(objective, SPACE, algo=tpe.suggest, max_evals=args.num_iters)

    # Report best result
    results = [result for result in results if not np.isnan(result['mean_score'])]
    if not results:
        # Without this check min() would fail with an uninformative "empty sequence" error
        raise ValueError('All hyperparameter trials produced a nan score; there is no best result to report.')
    best_result = min(results, key=lambda result: (1 if args.minimize_score else -1) * result['mean_score'])
    logger.info('best')
    logger.info(best_result['hyperparams'])
    logger.info(f'num params: {best_result["num_params"]:,}')
    logger.info(f'{best_result["mean_score"]} +/- {best_result["std_score"]} {args.metric}')

    # Save best hyperparameter settings as JSON config file
    with open(args.config_save_path, 'w') as f:
        json.dump(best_result['hyperparams'], f, indent=4, sort_keys=True)
# Dedicated logger for the random forest baseline; handlers are attached by set_logger below
logger = logging.getLogger('random_forest')
logger.setLevel(logging.DEBUG)
logger.propagate = False


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments) pairs
    cli_options = [
        ('--data_path', dict(type=str, required=True,
                             help='Path to data CSV')),
        ('--dataset_type', dict(type=str, required=True,
                                choices=['regression', 'classification'],
                                help='Dataset type')),
        ('--metric', dict(type=str,
                          choices=['auc', 'prc-auc', 'rmse', 'mae'],
                          help='Metric to use during evaluation.')),
        ('--split_type', dict(type=str, default='random',
                              choices=['random', 'scaffold_balanced'],
                              help='Split type')),
        ('--class_weight', dict(type=str,
                                choices=['balanced'],
                                help='How to weight classes (None means no class balance)')),
        ('--single_task', dict(action='store_true', default=False,
                               help='Whether to run each task separately (needed when dataset has null entries)')),
        ('--num_folds', dict(type=int, default=1,
                             help='Number of folds of cross validation')),
        ('--seed', dict(type=int, default=0,
                        help='Random seed')),
        ('--radius', dict(type=int, default=2,
                          help='Morgan fingerprint radius')),
        ('--num_bits', dict(type=int, default=2048,
                            help='Number of bits in morgan fingerprint')),
        ('--num_trees', dict(type=int, default=500,
                             help='Number of random forest trees')),
        ('--quiet', dict(action='store_true', default=False,
                         help='Control verbosity level')),
    ]

    parser = ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    set_logger(logger, quiet=args.quiet)

    # Fall back to the default metric of the dataset type when none was given
    if args.metric is None:
        default_metrics = {'regression': 'rmse', 'classification': 'auc'}
        if args.dataset_type not in default_metrics:
            raise ValueError(f'Default metric not supported for dataset_type "{args.dataset_type}"')
        args.metric = default_metrics[args.dataset_type]

    cross_validate_random_forest(args, logger)
sys.path.append('../') 4 | 5 | from chemprop.data_processing import average_duplicates 6 | 7 | if __name__ == '__main__': 8 | parser = ArgumentParser() 9 | parser.add_argument('--data_path', type=str, 10 | help='Path to data CSV file') 11 | parser.add_argument('--save_path', type=str, 12 | help='Path where average data CSV file will be saved') 13 | args = parser.parse_args() 14 | 15 | average_duplicates(args) 16 | -------------------------------------------------------------------------------- /scripts/filter_by_scaffold.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from collections import namedtuple 3 | import sys 4 | sys.path.append('../') 5 | from typing import List 6 | 7 | from tqdm import tqdm 8 | 9 | from chemprop.data.scaffold import generate_scaffold 10 | 11 | 12 | Datapoint = namedtuple('Datapoint', ['smiles', 'line']) 13 | 14 | 15 | def get_header(path: str) -> str: 16 | with open(path) as f: 17 | header = f.readline() 18 | 19 | return header 20 | 21 | 22 | def get_data(path: str) -> List[Datapoint]: 23 | with open(path) as f: 24 | f.readline() # skip header 25 | data = [] 26 | for line in f.readlines(): 27 | smiles = line[:line.index(',')] 28 | data.append(Datapoint(smiles=smiles, line=line)) 29 | 30 | return data 31 | 32 | 33 | def filter_by_scaffold(args: Namespace): 34 | print('Loading data') 35 | header = get_header(args.data_path) 36 | data = get_data(path=args.data_path) 37 | scaffold_data = get_data(path=args.scaffold_data_path) 38 | 39 | print('Generating scaffolds') 40 | smiles_to_scaffold = {d.smiles: generate_scaffold(d.smiles) for d in tqdm(data, total=len(data))} 41 | scaffolds_to_keep = {generate_scaffold(d.smiles) for d in tqdm(scaffold_data, total=len(scaffold_data))} 42 | 43 | print('Filtering data') 44 | filtered_data = [d for d in data if smiles_to_scaffold[d.smiles] in scaffolds_to_keep] 45 | 46 | print(f'Filtered data from {len(data):,} to 
{len(filtered_data):,} molecules') 47 | 48 | print('Saving data') 49 | with open(args.save_path, 'w') as f: 50 | f.write(header) 51 | for d in filtered_data: 52 | f.write(d.line) 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = ArgumentParser() 57 | parser.add_argument('--data_path', type=str, required=True, 58 | help='Path to dataset which will be limited to those molecules sharing' 59 | 'a scaffold with a molecule in the scaffold_data_path dataset') 60 | parser.add_argument('--scaffold_data_path', type=str, required=True, 61 | help='Path to the dataset whose scaffolds will be used to limit the' 62 | 'molecules in data_path') 63 | parser.add_argument('--save_path', type=str, required=True, 64 | help='Path where the filtered version of data_path will be saved') 65 | args = parser.parse_args() 66 | 67 | filter_by_scaffold(args) 68 | -------------------------------------------------------------------------------- /scripts/overlap.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import csv 3 | import sys 4 | sys.path.append('../') 5 | 6 | from chemprop.data.utils import get_data 7 | 8 | if __name__ == '__main__': 9 | parser = ArgumentParser() 10 | parser.add_argument('--data_path_1', type=str, required=True, 11 | help='Path to first data CSV file') 12 | parser.add_argument('--data_path_2', type=str, required=True, 13 | help='Path to second data CSV file') 14 | parser.add_argument('--compound_names_1', action='store_true', default=False, 15 | help='Whether data_path_1 has compound names in addition to smiles') 16 | parser.add_argument('--compound_names_2', action='store_true', default=False, 17 | help='Whether data_path_2 has compound names in addition to smiles') 18 | parser.add_argument('--save_intersection_path', type=str, default=None, 19 | help='Path to save intersection at; labeled with data_path 1 header') 20 | parser.add_argument('--save_difference_path', type=str, 
default=None, 21 | help='Path to save molecules in dataset 1 that are not in dataset 2; labeled with data_path 1 header') 22 | args = parser.parse_args() 23 | 24 | data_1 = get_data(path=args.data_path_1, use_compound_names=args.compound_names_1) 25 | data_2 = get_data(path=args.data_path_2, use_compound_names=args.compound_names_2) 26 | 27 | smiles1 = set(data_1.smiles()) 28 | smiles2 = set(data_2.smiles()) 29 | size_1, size_2 = len(smiles1), len(smiles2) 30 | intersection = smiles1.intersection(smiles2) 31 | size_intersect = len(intersection) 32 | print(f'Size of dataset 1: {size_1}') 33 | print(f'Size of dataset 2: {size_2}') 34 | print(f'Size of intersection: {size_intersect}') 35 | print(f'Size of intersection as frac of dataset 1: {size_intersect / size_1}') 36 | print(f'Size of intersection as frac of dataset 2: {size_intersect / size_2}') 37 | 38 | if args.save_intersection_path is not None: 39 | with open(args.data_path_1, 'r') as rf, open(args.save_intersection_path, 'w') as wf: 40 | reader, writer = csv.reader(rf), csv.writer(wf) 41 | header = next(reader) 42 | writer.writerow(header) 43 | for line in reader: 44 | if line[0] in intersection: 45 | writer.writerow(line) 46 | 47 | if args.save_difference_path is not None: 48 | with open(args.data_path_1, 'r') as rf, open(args.save_difference_path, 'w') as wf: 49 | reader, writer = csv.reader(rf), csv.writer(wf) 50 | header = next(reader) 51 | writer.writerow(header) 52 | for line in reader(): 53 | if line[0] not in intersection: 54 | writer.writerow(line) 55 | -------------------------------------------------------------------------------- /scripts/plot_distribution.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import sys 4 | sys.path.append('../') 5 | 6 | from chemprop.data_processing import plot_distribution 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | 
parser.add_argument('--data_path', type=str, 12 | help='Path to data CSV file') 13 | parser.add_argument('--save_dir', type=str, 14 | help='Directory where plot PNGs will be saved') 15 | parser.add_argument('--bins', type=int, default=50, 16 | help='Number of bins in histogram.') 17 | args = parser.parse_args() 18 | 19 | os.makedirs(args.save_dir, exist_ok=True) 20 | 21 | plot_distribution(args.data_path, args.save_dir, args.bins) 22 | -------------------------------------------------------------------------------- /scripts/resplit_data.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import sys 4 | sys.path.append('../') 5 | 6 | from chemprop.data_processing import resplit 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | parser.add_argument('--train_path', type=str, required=True, 12 | help='Path to CSV file containing training data') 13 | parser.add_argument('--val_path', type=str, required=True, 14 | help='Path to CSV file containing val data') 15 | parser.add_argument('--train_save', type=str, required=True, 16 | help='Path to CSV file for new train data') 17 | parser.add_argument('--val_save', type=str, required=True, 18 | help='Path to CSV file for new val data') 19 | parser.add_argument('--val_frac', type=float, default=0.2, 20 | help='frac of data to use for validation') 21 | args = parser.parse_args() 22 | 23 | # Create directory for save_path 24 | for path in [args.train_save, args.val_save]: 25 | save_dir = os.path.dirname(path) 26 | if save_dir != '': 27 | os.makedirs(save_dir, exist_ok=True) 28 | 29 | resplit(args) 30 | -------------------------------------------------------------------------------- /scripts/similarity.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import sys 3 | sys.path.append('../') 4 | 5 | from chemprop.data.utils import get_data 6 | 
from chemprop.data import morgan_similarity, scaffold_similarity 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | parser.add_argument('--data_path_1', type=str, required=True, 12 | help='Path to first data CSV file') 13 | parser.add_argument('--data_path_2', type=str, required=True, 14 | help='Path to second data CSV file') 15 | parser.add_argument('--compound_names_1', action='store_true', default=False, 16 | help='Whether data_path_1 has compound names in addition to smiles') 17 | parser.add_argument('--compound_names_2', action='store_true', default=False, 18 | help='Whether data_path_2 has compound names in addition to smiles') 19 | parser.add_argument('--similarity_measure', type=str, required=True, choices=['scaffold', 'morgan'], 20 | help='Similarity measure to use to compare the two datasets') 21 | parser.add_argument('--radius', type=int, default=3, 22 | help='Radius of Morgan fingerprint') 23 | parser.add_argument('--sample_rate', type=float, default=1.0, 24 | help='Rate at which to sample pairs of molecules for Morgan similarity (to reduce time)') 25 | args = parser.parse_args() 26 | 27 | data_1 = get_data(path=args.data_path_1, use_compound_names=args.compound_names_1) 28 | data_2 = get_data(path=args.data_path_2, use_compound_names=args.compound_names_2) 29 | 30 | if args.similarity_measure == 'scaffold': 31 | scaffold_similarity(data_1.smiles(), data_2.smiles()) 32 | elif args.similarity_measure == 'morgan': 33 | morgan_similarity(data_1.smiles(), data_2.smiles(), args.radius, args.sample_rate) 34 | else: 35 | raise ValueError(f'Similarity measure "{args.similarity_measure}" not supported.') 36 | -------------------------------------------------------------------------------- /scripts/visualize_encoding_property_space.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | import os 3 | import random 4 | import sys 5 | sys.path.append('../') 6 
| from typing import List 7 | 8 | import numpy as np 9 | import ternary 10 | from tqdm import trange 11 | 12 | from chemprop.data import MoleculeDataset 13 | from chemprop.data.utils import get_data 14 | from chemprop.utils import load_checkpoint, load_scalers 15 | 16 | 17 | def visualize_encoding_property_space(args: Namespace): 18 | # Load data 19 | data = get_data(path=args.data_path) 20 | 21 | # Sort according to similarity measure 22 | if args.similarity_measure == 'property': 23 | data.sort(key=lambda d: d.targets[args.task_index]) 24 | elif args.similarity_measure == 'random': 25 | data.shuffle(args.seed) 26 | else: 27 | raise ValueError(f'similarity_measure "{args.similarity_measure}" not supported or not implemented yet.') 28 | 29 | # Load model and scalers 30 | model = load_checkpoint(args.checkpoint_path) 31 | scaler, features_scaler = load_scalers(args.checkpoint_path) 32 | data.normalize_features(features_scaler) 33 | 34 | # Random seed 35 | if args.seed is not None: 36 | random.seed(args.seed) 37 | 38 | # Generate visualizations 39 | for i in trange(args.num_examples): 40 | # Get random three molecules with similar properties 41 | index = random.randint(1, len(data) - 2) 42 | molecules = MoleculeDataset(data[index - 1:index + 2]) 43 | molecule_targets = [t[args.task_index] for t in molecules.targets()] 44 | 45 | # Encode three molecules 46 | molecule_encodings = model.encoder(molecules.smiles()) 47 | 48 | # Define interpolation 49 | def predict_property(point: List[int]) -> float: 50 | # Return true value on endpoints of triangle 51 | argmax = np.argmax(point) 52 | if point[argmax] == 1: 53 | return molecule_targets[argmax] 54 | 55 | # Interpolate and predict task value 56 | encoding = sum(point[j] * molecule_encodings[j] for j in range(len(molecule_encodings))) 57 | pred = model.ffn(encoding).data.cpu().numpy() 58 | pred = scaler.inverse_transform(pred) 59 | pred = pred.item() 60 | 61 | return pred 62 | 63 | # Create visualization 64 | scale = 20 65 
| fontsize = 6 66 | 67 | figure, tax = ternary.figure(scale=scale) 68 | tax.heatmapf(predict_property, boundary=True, style="hexagonal") 69 | tax.set_title("Property Prediction") 70 | tax.right_axis_label(f'{molecules[0].smiles} ({molecules[0].targets[args.task_index]:.6f}) -->', 71 | fontsize=fontsize) 72 | tax.left_axis_label(f'{molecules[1].smiles} ({molecules[1].targets[args.task_index]:.6f}) -->', 73 | fontsize=fontsize) 74 | tax.bottom_axis_label(f'<-- {molecules[2].smiles} ({molecules[2].targets[args.task_index]:.6f})', 75 | fontsize=fontsize) 76 | 77 | tax.savefig(os.path.join(args.save_dir, f'{i}.png')) 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = ArgumentParser() 82 | parser.add_argument('--data_path', type=str, required=True, 83 | help='Path to regression dataset .csv') 84 | parser.add_argument('--checkpoint_path', type=str, required=True, 85 | help='Path to a model checkpoint .pt file') 86 | parser.add_argument('--similarity_measure', type=str, default='random', 87 | choices=['random', 'random'], 88 | help='Similarity measure to use when choosing the three molecules for each visualization') 89 | parser.add_argument('--task_index', type=int, default=0, 90 | help='Index of the task (property) in the dataset to use') 91 | parser.add_argument('--num_examples', type=int, default=10, 92 | help='Number of visualizations to generate') 93 | parser.add_argument('--save_dir', type=str, required=True, 94 | help='Directory where visualizations will be saved') 95 | parser.add_argument('--seed', type=int, default=None, 96 | help='Random seed for choosing three similar molecules') 97 | args = parser.parse_args() 98 | 99 | os.makedirs(args.save_dir, exist_ok=True) 100 | 101 | visualize_encoding_property_space(args) 102 | -------------------------------------------------------------------------------- /scripts/viz_attention.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 
| import os 3 | import sys 4 | sys.path.append('../') 5 | 6 | import torch 7 | from tqdm import trange 8 | 9 | from chemprop.data.utils import get_data 10 | from chemprop.utils import load_checkpoint 11 | 12 | 13 | def visualize_attention(args: Namespace): 14 | """Visualizes attention weights.""" 15 | print('Loading data') 16 | data = get_data(path=args.data_path) 17 | smiles = data.smiles() 18 | print(f'Data size = {len(smiles):,}') 19 | 20 | print(f'Loading model from "{args.checkpoint_path}"') 21 | model = load_checkpoint(args.checkpoint_path, cuda=args.cuda) 22 | mpn = model[0] 23 | 24 | for i in trange(0, len(smiles), args.batch_size): 25 | smiles_batch = smiles[i:i + args.batch_size] 26 | mpn.viz_attention(smiles_batch, viz_dir=args.viz_dir) 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = ArgumentParser() 31 | parser.add_argument('--data_path', type=str, required=True, 32 | help='Path to data CSV file') 33 | parser.add_argument('--viz_dir', type=str, required=True, 34 | help='Path where attention PNGs will be saved') 35 | parser.add_argument('--checkpoint_path', type=str, required=True, 36 | help='Path to a model checkpoint') 37 | parser.add_argument('--batch_size', type=int, default=50, 38 | help='Batch size') 39 | parser.add_argument('--no_cuda', action='store_true', default=False, 40 | help='Turn off cuda') 41 | args = parser.parse_args() 42 | 43 | # Cuda 44 | args.cuda = not args.no_cuda and torch.cuda.is_available() 45 | del args.no_cuda 46 | 47 | # Create directory for preds path 48 | os.makedirs(args.viz_dir, exist_ok=True) 49 | 50 | visualize_attention(args) 51 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open('requirements.txt') as f: 4 | requirements = [line.strip() for line in f if not line.startswith('git+http')] 5 | 6 | setup( 7 | name='chemprop', 8 | version='1.0', 9 | author='Wengong Jin, Kyle Swanson, Kevin Yang', 10 | author_email='wengong@csail.mit.edu, swansonk@mit.edu, yangk@mit.edu', 11 | description='Molecular Property Prediction with Message Passing Neural Networks', 12 | url='https://github.com/wengong-jin/chemprop', 13 | license='MIT', 14 | packages=find_packages(), 15 | install_requires=requirements, 16 | keywords=['chemistry', 'machine learning', 'property prediction', 'message passing neural network'], 17 | ) 18 | -------------------------------------------------------------------------------- /static/jsme/.directory: -------------------------------------------------------------------------------- 1 | [Dolphin] 2 | SortOrder=1 3 | Timestamp=2015,5,31,16,7,38 4 | Version=3 5 | ViewMode=1 6 | -------------------------------------------------------------------------------- /static/jsme/JME_to_JSME.html: -------------------------------------------------------------------------------- 1 | 2 |
3 |53 | |
67 | 68 | 69 | | 70 |
74 | Atom highlighted by mouse over: 75 | | 76 | 77 |84 | Atom colors: 85 | 86 | 87 | | 88 | 89 | 90 |
104 | Note: All atom and molecule indices start at 1. 105 |
106 | 107 |
108 | public void setAtomToHighLight(int molIndex, int atomIndex)
109 | The highlight is temporary.
110 |
113 | public void setNotifyAtomHighLightChangeJSfunction(String notifyAtomHighLightJSfunction)
114 | Receive a notification when JSME detects a mouse over one atom. The argument is the name of a JavaScript function that receives two arguments: the molecule index and the atom index.
115 | To cancel, set the argument to null.
116 |
120 | public void setAtomBackgroundColors(int molIndex, String atomAndColorCSV)
121 | atomAndColorCSV must be an integer between 0 and 6. 0 means no background color.
122 |
125 | public void resetAtomColors(int molIndex) 126 |
127 | 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /static/jsme/JSME_autoresize.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 15 | 16 | 17 | 18 |41 | Draw a structure, click on the star icon of the editor and click on one or more atoms 42 |
43 | 44 | 47 |33 | The atom that should act as a connection point (root) of the template should be marked by :1 (for example the oxygen serving as root should have atomic symbol O:1). 34 |
35 |
58 |
59 | See this page for real example. 60 | 61 | 62 | -------------------------------------------------------------------------------- /static/jsme/api_javadoc/allclasses-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | 6 |40 | Draw a structure, click on the star icon of the editor and click on one or more atoms 41 |
42 | 43 | 44 |23 | Files contained in the current directory should help you to implement JSME easily:
24 |
25 | doc.html provides basic information about JSME implementation into web page - start with this document
26 |
27 | 28 |
29 | Example pages:
30 |
48 | And some other files:
49 |
54 | If you are using JSME, please cite the following article:
55 | B. Bienfait and P. Ertl, JSME: a free molecule editor in JavaScript, J. Cheminformatics 5:24 (2013)
56 |
57 | Happy molecule editing! 58 | 59 | -------------------------------------------------------------------------------- /static/jsme/jme_examples/jme_example2.html: -------------------------------------------------------------------------------- 1 | 2 |
3 |49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /static/jsme/jme_examples/jme_example4.html: -------------------------------------------------------------------------------- 1 | 2 |
3 | 5 | 6 | 7 | 8 |53 |
62 |
76 |
80 | | 96 | |
JME 105 | Editor courtesy of Peter Ertl, Novartis | 106 |JME help | 110 |
{{checkpoint}} | 18 |19 | 20 | 21 | | 22 |
---|
{{dataset}} | 18 |19 | 20 | 21 | | 22 |
---|
This website can be used to predict molecular properties using a Message Passing Neural Network (MPNN). In order to make predictions, an MPNN first needs to be trained on a dataset containing molecules along with known property values for each molecule. Once the MPNN is trained, it can be used to predict those same properties on any new molecules.
10 | 11 |To train an MPNN, go to the Train page, upload a dataset or select a dataset which has already been uploaded, set the desired parameters, name the model, and then click "Train".
13 | 14 |To make property predictions using a trained model, go to the Predict page, select the trained model checkpoint you want to use, upload or paste the molecules you would like to make predictions on, and then click "Predict".
16 | 17 |To upload, view, download, or delete datasets and model checkpoints, go to the Data and Checkpoints pages, respectively.
19 | 20 |If GPUs are available on your machine, you will see a dropdown menu on the Train and Predict pages which will allow you to select a GPU to use. If you select "None" then only CPUs will be used.
24 | 25 |If you wish to train or predict on a remote server, you can use SSH port-forwarding to run training/predicting on the remote server while viewing the website locally. To do so, follow these instructions:
27 |ssh <remote_user>@<remote_host>
chemprop
directory.chemprop
requirements: source activate <environment_name>
python web.py
ssh -N -L 5000:localhost:5000 <remote_user>@<remote_host>
Overall: {{ mean_score }} {{ metric }}
114 | 115 |{{ task_names[i] }}: {{ task_scores[i] }} {{ metric }}
118 | {% endfor %} 119 | 120 | {% endif %} 121 | 122 | {% endblock %} 123 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from chemprop.parsing import parse_train_args 4 | from chemprop.train import cross_validate 5 | from chemprop.utils import set_logger 6 | 7 | 8 | # Initialize logger 9 | logger = logging.getLogger('train') 10 | logger.setLevel(logging.DEBUG) 11 | logger.propagate = False 12 | 13 | if __name__ == '__main__': 14 | args = parse_train_args() 15 | set_logger(logger, args.save_dir, args.quiet) 16 | cross_validate(args, logger) 17 | --------------------------------------------------------------------------------