├── testdata
│   ├── test_14_record.gz.info
│   ├── test_14_record.inchikey.txt
│   ├── test_14_record.gz
│   ├── test_14.spectra_library.npy
│   ├── test_dataset_config_file.json
│   ├── README
│   └── test_2_mend.sdf
├── .gitignore
├── neims_toc.jpeg
├── __init__.py
├── test_utils.py
├── CONTRIBUTING.md
├── feature_map_constants.py
├── Model_Retrain_Quickstart.md
├── make_spectra_prediction.py
├── gather_similarites.py
├── README.md
├── mass_spec_constants.py
├── examples
│   └── pentachlorobenzene.sdf
├── feature_utils_test.py
├── spectra_predictor_test.py
├── similarity.py
├── util_test.py
├── make_predictions.py
├── make_predictions_from_tfrecord.py
├── train_test_split_utils.py
├── molecule_estimator_test.py
├── util.py
├── spectra_predictor.py
├── dataset_setup_constants.py
├── make_train_test_split_test.py
├── LICENSE
├── feature_utils.py
├── molecule_estimator.py
├── plot_spectra_utils.py
├── library_matching_test.py
└── make_train_test_split.py

/testdata/test_14_record.gz.info:
--------------------------------------------------------------------------------
1 | 12
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | .ipynb_checkpoints/*
3 | __pycache__/*
--------------------------------------------------------------------------------
/testdata/test_14_record.inchikey.txt:
--------------------------------------------------------------------------------
1 | YXHKONLOYHBTNS-UHFFFAOYSA-N
2 |
--------------------------------------------------------------------------------
/neims_toc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/deep-molecular-massspec/HEAD/neims_toc.jpeg
--------------------------------------------------------------------------------
/testdata/test_14_record.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/deep-molecular-massspec/HEAD/testdata/test_14_record.gz
--------------------------------------------------------------------------------
/testdata/test_14.spectra_library.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/deep-molecular-massspec/HEAD/testdata/test_14.spectra_library.npy
--------------------------------------------------------------------------------
/testdata/test_dataset_config_file.json:
--------------------------------------------------------------------------------
1 | {"LIBRARY_MATCHING_PREDICTED": ["test_14_record.gz"], "SPECTRUM_PREDICTION_TEST": ["test_14_record.gz"], "mainlib_train_spectra_library_file": "test_14.spectra_library.npy", "SPECTRUM_PREDICTION_TRAIN": ["test_14_record.gz"], "LIBRARY_MATCHING_QUERY": ["test_14_record.gz"], "LIBRARY_MATCHING_OBSERVED": ["test_14_record.gz"]}
2 |
--------------------------------------------------------------------------------
/testdata/README:
--------------------------------------------------------------------------------
1 | test_2_mend.sdf and test_14_mend.sdf contain molblocks from mainlib.sdf, an SDF
2 | file from NIST. The actual mass spectral data has been removed from these SDF
3 | files and replaced with fake data.
4 |
5 | test_14_record.gz is a TFRecord with the features 'molecule_weight',
6 | 'atom_weights', 'circular_fp_1024', and 'dense_mass_spec'.
7 |
8 | NB: Since 2 molecules will not pass the filters in get_sdf_to_mol, there are
9 | only 12 molecules in the TFRecord file.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/test_utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for tests."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | import os
7 | from absl import flags
8 |
9 |
10 | def test_dir(relative_path):
11 |   """Gets the path to a testdata file at the given relative path.
12 |
13 |   Args:
14 |     relative_path: a directory path relative to the base directory of this
15 |       module.
16 |   Returns:
17 |     The absolute path to a testdata file.
18 |   """
19 |
20 |   return os.path.join(flags.FLAGS.test_srcdir,
21 |                       os.path.split(os.path.abspath(__file__))[0],
22 |                       relative_path)
23 |
24 |
25 | def encode(value, is_py3_flag):
26 |   """A helper function for wrapping strings as bytes for py3."""
27 |   if is_py3_flag:
28 |     return value.encode('utf-8')
29 |   else:
30 |     return value
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/feature_map_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module containing the feature names used in the TFRecords of this repo."""
16 |
17 | # Overall Tasks
18 | LIBRARY_MATCHING = 'library_matching'
19 | SPECTRUM_PREDICTION = 'spectrum_prediction'
20 |
21 | # Feature names
22 | ATOM_WEIGHTS = 'atom_weights'
23 | MOLECULE_WEIGHT = 'molecule_weight'
24 | CIRCULAR_FP_BASENAME = 'circular_fp'
25 | COUNTING_CIRCULAR_FP_BASENAME = 'counting_circular_fp'
26 | INDEX_TO_GROUND_TRUTH_ARRAY = 'index_in_library'
27 |
28 | FP_TYPE_LIST = [
29 |     CIRCULAR_FP_BASENAME, COUNTING_CIRCULAR_FP_BASENAME
30 | ]
31 |
32 | DENSE_MASS_SPEC = 'dense_mass_spec'
33 | INCHIKEY = 'inchikey'
34 | NAME = 'name'
35 | SMILES = 'smiles'
36 | MOLECULAR_FORMULA = 'molecular_formula'
37 | ADJACENCY_MATRIX = 'adjacency_matrix'
38 | ATOM_IDS = 'atom_id'
39 | SMILES_TOKEN_LIST_LENGTH = 'smiles_sequence_len'
--------------------------------------------------------------------------------
/Model_Retrain_Quickstart.md:
--------------------------------------------------------------------------------
1 | ## Quickstart
2 |
3 | The following is a tutorial for training a new model.
4 |
5 | Set up a directory for predictions:
6 |
7 | ```
8 | $ TARGET_PATH_NAME=/tmp/massspec_predictions
9 | ```
10 |
11 | To convert an SDF file into a TFRecord and make the training and test splits:
12 |
13 | ```
14 | $ python make_train_test_split.py --main_sdf_name=testdata/test_14_mend.sdf \
15 |     --replicates_sdf_name=testdata/test_2_mend.sdf \
16 |     --output_master_dir=$TARGET_PATH_NAME/spectra_tf_records
17 | ```
18 |
19 | To train a model:
20 |
21 | ```
22 | python molecule_estimator.py \
23 |   --dataset_config_file=$TARGET_PATH_NAME/spectra_tf_records/query_replicates_val_predicted_replicates_val.json \
24 |   --train_steps=1000 \
25 |   --model_dir=$TARGET_PATH_NAME/models/output --hparams=make_spectra_plots=True \
26 |   --alsologtostderr
27 | ```
28 |
29 | The aggregate training results will be logged to stdout. The final library
30 | matching results can also be found in
31 | $TARGET_PATH_NAME/models/output/predictions. These results report the query
32 | inchikey, the matched inchikey, and the rank assigned to the true spectrum.
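The prediction step at the end of this guide writes its output with `np.save` as a single archive mapping each InChIKey to a dense predicted spectrum (this format, and the `.item()` unwrapping, follow the docstring of `make_predictions_from_tfrecord.py`). Below is a minimal sketch for loading the results afterwards; the exact path is an assumption based on the flags used in this guide (`np.save` appends a `.npy` suffix), and `allow_pickle=True` is only needed on NumPy versions newer than this repo targets:

```python
import numpy as np

# np.load returns a 0-D object array wrapping the dict; .item() unwraps it.
predictions = np.load(
    '/tmp/massspec_predictions/models/output_predictions.npy',
    allow_pickle=True).item()

for inchikey, spectrum in predictions.items():
  # One dense 1-D intensity array per molecule, indexed by m/z bin.
  print(inchikey, spectrum.shape)
```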
33 | 34 | It is also possible to view these results in tensorboard: \ 35 | tensorboard --logdir=path/to/log-directory 36 | 37 | To predict spectra for new data given a model, run: 38 | 39 | ``` 40 | python make_predictions_from_tfrecord.py \ 41 | --input_file=testdata/test_14_record.gz \ 42 | --output_file=$TARGET_PATH_NAME/models/output_predictions \ 43 | --model_checkpoint_path=$TARGET_PATH_NAME/models/output/ \ 44 | --hparams=eval_batch_size=16 \ 45 | --alsologtostderr 46 | ``` 47 | -------------------------------------------------------------------------------- /make_spectra_prediction.py: -------------------------------------------------------------------------------- 1 | r"""Makes spectra prediction using model and writes predictions to SDF. 2 | 3 | Make predictions using our trained model. Example of how to run: 4 | 5 | # Save weights to a models directory 6 | $ MODEL_WEIGHTS_DIR=/tmp/neims_model 7 | $ cd $MODEL_WEIGHTS_DIR 8 | $ wget https://storage.googleapis.com/deep-molecular-massspec/massspec_weights/massspec_weights.zip # pylint: disable=line-too-long 9 | $ unzip massspec_weights.zip 10 | 11 | $ python make_spectra_prediction.py \ 12 | --input_file=examples/pentachlorobenzene.sdf \ 13 | --output_file=/tmp/neims_model/annotated.sdf \ 14 | --weights_dir=$MODEL_WEIGHTS_DIR/massspec_weights 15 | """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | 25 | import spectra_predictor 26 | 27 | FLAGS = flags.FLAGS 28 | flags.DEFINE_string('input_file', 'input.sdf', 29 | 'Name of input file for predictions.') 30 | flags.DEFINE_string('weights_dir', 31 | '/usr/local/massspec_weights', 32 | 'Name of directory that stores model weights.') 33 | flags.DEFINE_string('output_file', 'annotated.sdf', 34 | 'Name of output file for predictions.') 35 | 36 | 37 | def main(_): 38 | logging.info('Loading weights from %s', FLAGS.weights_dir) 39 | predictor = spectra_predictor.NeimsSpectraPredictor( 40 | model_checkpoint_dir=FLAGS.weights_dir) 41 | 42 | logging.info('Loading molecules from %s', FLAGS.input_file) 43 | mols_from_file = spectra_predictor.get_mol_list_from_sdf( 44 | FLAGS.input_file) 45 | fingerprints, mol_weights = predictor.get_inputs_for_model_from_mol_list( 46 | mols_from_file) 47 | 48 | logging.info('Making predictions ...') 49 | spectra_predictions = predictor.make_spectra_prediction( 50 | fingerprints, mol_weights) 51 | 52 | logging.info('Updating molecules in place with predictions.') 53 | spectra_predictor.update_mols_with_spectra(mols_from_file, 54 | spectra_predictions) 55 | 56 | logging.info('Writing predictions to %s', FLAGS.output_file) 57 | with open(FLAGS.output_file, 'w') as f: 58 | spectra_predictor.write_rdkit_mols_to_sdf(mols_from_file, f) 59 | 60 | 61 | if __name__ == '__main__': 62 | app.run(main) 63 | -------------------------------------------------------------------------------- /gather_similarites.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import parse_sdf_utils 6 | import train_test_split_utils 7 | import feature_utils 8 | import mass_spec_constants as ms_constants 9 | import similarity as similarity_lib 10 | 11 | 12 | def make_spectra_array(mol_list): 13 | """Grab spectra pertaining to same molecule in one np.array. 14 | Args: 15 | mol_list: list of rdkit.Mol objects. 
Each Mol should contain
16 |       information about the spectra, as stored in NIST.
17 |   Returns:
18 |     np.array of spectra of shape (number of spectra, max spectra length)
19 |   """
20 |   mass_spec_spectra = np.zeros((len(mol_list), ms_constants.MAX_PEAK_LOC))
21 |   for idx, mol in enumerate(mol_list):
22 |     spectra_str = mol.GetProp(ms_constants.SDF_TAG_MASS_SPEC_PEAKS)
23 |     spectral_locs, spectral_intensities = feature_utils.parse_peaks(spectra_str)
24 |     dense_mass_spec = feature_utils.make_dense_mass_spectra(
25 |         spectral_locs, spectral_intensities, ms_constants.MAX_PEAK_LOC)
26 |
27 |     mass_spec_spectra[idx, :] = dense_mass_spec
28 |
29 |   return mass_spec_spectra
30 |
31 |
32 | def get_similarities(raw_spectra_array):
33 |   """Preprocess spectra and then calculate similarity between spectra.
34 |
35 |   Args:
36 |     raw_spectra_array: np.array containing unprocessed spectra
37 |   Returns:
38 |     np.array of shape (len(raw_spectra_array), len(raw_spectra_array))
39 |       giving pairwise similarities between spectra.
40 |   """
41 |   spec_array_var = tf.constant(raw_spectra_array)
42 |
43 |   # Adjusting intensity to match the default in molecule_predictors.
44 |   intensity_adjusted_spectra = tf.pow(spec_array_var, 0.5)
45 |
46 |   hparams = tf.contrib.training.HParams(mass_power=1.)
47 |
48 |   cos_similarity = similarity_lib.GeneralizedCosineSimilarityProvider(hparams)
49 |   norm_spectra = cos_similarity._normalize_rows(intensity_adjusted_spectra)
50 |   similarity = cos_similarity.compute_similarity(norm_spectra, norm_spectra)
51 |
52 |   with tf.Session() as sess:
53 |     sess.run(tf.global_variables_initializer())
54 |     dist = sess.run(similarity)
55 |
56 |   return dist
57 |
58 |
59 | def main():
60 |   mol_list = parse_sdf_utils.get_sdf_to_mol(
61 |       '/mnt/storage/NIST_zipped/NIST17/replib_mend.sdf')
62 |   inchikey_dict = train_test_split_utils.make_inchikey_dict(mol_list)
63 |
64 |   spectra_for_one_mol = make_spectra_array(
65 |       inchikey_dict['PDACHFOTOFNHBT-UHFFFAOYSA-N'])
66 |   distance_matrix = get_similarities(spectra_for_one_mol)
67 |   print('distance for spectra in PDACHFOTOFNHBT-UHFFFAOYSA-N', distance_matrix)
68 |
69 |
70 | if __name__ == '__main__':
71 |   main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep learning for Electron Ionization mass spectrometry for organic molecules
2 | ![TOC](https://github.com/brain-research/deep-molecular-massspec/blob/master/neims_toc.jpeg?raw=true)
3 |
4 | This repository accompanies
5 |
6 | [Rapid Prediction of Electron–Ionization Mass Spectrometry Using Neural Networks](https://pubs.acs.org/doi/10.1021/acscentsci.9b00085)\
7 | [Jennifer N. Wei](https://ai.google/research/people/JenniferNWei),
8 | [David Belanger](https://davidbelanger.github.io/), [Ryan P. Adams](https://www.cs.princeton.edu/~rpa/),
9 | and [D. Sculley](https://www.eecs.tufts.edu/~dsculley/)\
10 | ACS Central Science 2019 5 (4), 700-708\
11 | DOI: 10.1021/acscentsci.9b00085
12 |
13 |
14 | ## Introduction
15 |
16 | We predict the mass spectrometry spectra of molecules using deep learning
17 | techniques applied to various molecule representations. Performance is
18 | evaluated with a custom-made library matching task, in which we identify
19 | molecules by matching their spectra against a library of labeled spectra. As a
20 | baseline, this library contains all of the molecules in the NIST main library,
21 | which mimics the approach currently used by experimental chemists.
To test our
22 | predictions, we replace portions of the library with spectra predictions from
23 | our model. This task is described in more detail below.
24 |
25 | ## Required Packages:
26 |
27 | It is recommended to use [Anaconda](https://www.anaconda.com/distribution/) with a Python 3.6 environment to install these packages.
28 | - [RDKit](https://www.rdkit.org/docs/Install.html)
29 | - [Tensorflow](https://www.tensorflow.org/install)
30 |
31 | Most of the packages required here can be installed with `conda install tensorflow=1.13.2 rdkit matplotlib` and `pip install absl-py`.
32 |
33 | ## Quickstart Guide for Making Model Predictions
34 |
35 | 1. Create a directory and download the weights for the model.
36 |
37 | ```
38 | $ MODEL_WEIGHTS_DIR=/home/path/to/model
39 | $ mkdir $MODEL_WEIGHTS_DIR
40 | $ pushd $MODEL_WEIGHTS_DIR
41 | $ curl -O https://storage.googleapis.com/deep-molecular-massspec/massspec_weights/massspec_weights.zip
42 | $ unzip massspec_weights.zip
43 | $ popd
44 | ```
45 |
46 | 2. Run the model prediction on the example molecule:
47 |
48 | ```
49 | $ python make_spectra_prediction.py \
50 |   --input_file=examples/pentachlorobenzene.sdf \
51 |   --output_file=/tmp/annotated.sdf \
52 |   --weights_dir=$MODEL_WEIGHTS_DIR/massspec_weights
53 | ```
54 |
55 | ## Training splits for benchmarking purposes
56 | The molecules used for the training, validation, and test sets can be found under the
57 | directory *training_splits*. The molecules are provided in
58 | InChIKey and SMILES format.
59 |
60 |
61 | ## To cite this work:
62 |
63 | @article{doi:10.1021/acscentsci.9b00085,\
64 | author = {Wei, Jennifer N. and Belanger, David and Adams, Ryan P. and Sculley, D.},\
65 | title = {Rapid Prediction of Electron–Ionization Mass Spectrometry Using Neural Networks},\
66 | journal = {ACS Central Science},\
67 | volume = {5},\
68 | number = {4},\
69 | pages = {700-708},\
70 | year = {2019},\
71 | doi = {10.1021/acscentsci.9b00085},\
72 | URL = {https://doi.org/10.1021/acscentsci.9b00085},\
73 | }
--------------------------------------------------------------------------------
/mass_spec_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module containing all commonly used variables in this repo."""
16 |
17 | from collections import namedtuple
18 |
19 |
20 | class CircularFingerprintKey(
21 |     namedtuple('CircularFingerprintKey', ['fp_type', 'fp_len', 'radius'])):
22 |   """Helper class for labeling fingerprint keys."""
23 |
24 |   def __str__(self):
25 |     return self.fp_type + '_' + str(self.fp_len) + '_' + str(self.radius)
26 |
27 |
28 | # Constants for SDF tags found in NIST sdf files:
29 | SDF_TAG_MASS_SPEC_PEAKS = 'MASS SPECTRAL PEAKS'
30 | SDF_TAG_INCHIKEY = 'INCHIKEY'
31 | SDF_TAG_NAME = 'NAME'
32 | SDF_TAG_MOLECULE_MASS = 'EXACT MASS'
33 |
34 | # Constants for fields in TFRecords
35 | MAX_MZ_WEIGHT_RATIO = 3.0
36 | MAX_PEAK_LOC = 1000
37 | MAX_ATOMS = 100
38 | MAX_ATOM_ID = 100
39 | MAX_TOKEN_LIST_LENGTH = 230
40 |
41 | CIRCULAR_FP_RADII_LIST = [2, 4, 6]
42 | NUM_CIRCULAR_FP_BITS_LIST = [1024, 2048, 4096]
43 | ADD_HS_TO_MOLECULES = False
44 |
45 | TWO_LETTER_TOKEN_NAMES = [
46 |     'Al', 'Ce', 'Co', 'Ge', 'Gd', 'Cs', 'Th', 'Cd', 'As', 'Na', 'Nb', 'Li',
47 |     'Ni', 'Se', 'Sc', 'Sb', 'Sn', 'Hf', 'Hg', 'Si', 'Be', 'Cl', 'Rb', 'Fe',
48 |     'Bi', 'Br', 'Ag', 'Ru', 'Zn', 'Te', 'Mo', 'Pt', 'Mn', 'Os', 'Tl', 'In',
49 |     'Cu', 'Mg', 'Ti', 'Pb', 'Re', 'Pd', 'Ir', 'Rh', 'Zr', 'Cr', '@@', 'se',
50 |     'si', 'te'
51 | ]
52 |
53 | METAL_ATOM_SYMBOLS = [
54 |     'As', 'Cr', 'Cs', 'Cu', 'Be', 'Ag', 'Co', 'Al', 'Cd', 'Ce', 'Si', 'Sn',
55 |     'Os', 'Sb', 'Sc', 'In', 'Se', 'Ni', 'Th', 'Hg', 'Hf', 'Li', 'Nb', 'U', 'Y',
56 |     'V', 'W', 'Tl', 'Na', 'Fe', 'K', 'Zr', 'B', 'Pb', 'Pd', 'Rh', 'Re', 'Gd',
57 |     'Ge', 'Ir', 'Rb', 'Ti', 'Pt', 'Mn', 'Mg', 'Ru', 'Bi', 'Zn', 'Te', 'Mo'
58 | ]
59 |
60 | SMILES_TOKEN_NAMES = [
61 |     '#', '%', '(', ')', '+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6',
62 |     '7', '8', '9', '=', '@', '@@', 'Ag', 'Al', 'As', 'B', 'Be', 'Bi', 'Br', 'C',
63 |     'Cd', 'Ce', 'Cl', 'Co', 'Cr', 'Cs', 'Cu', 'F', 'Fe', 'Gd', 'Ge', 'H', 'Hf',
64 |     'Hg', 'I', 'In', 'Ir', 'K', 'Li', 'Mg', 'Mn', 'Mo', 'N', 'Na', 'Nb', 'Ni',
65 |     'O', 'Os', 'P', 'Pb', 'Pd', 'Pt', 'Rb', 'Re', 'Rh', 'Ru', 'S', 'Sb', 'Sc',
66 |     'Se', 'Si', 'Sn', 'Te', 'Th', 'Ti', 'Tl', 'U', 'V', 'W', 'Y', 'Zn', 'Zr',
67 |     '[', '\\', ']', 'c', 'n', 'o', 'p', 's'
68 | ]
69 |
70 | SMILES_TOKEN_NAME_TO_INDEX = {
71 |     name: idx for idx, name in enumerate(SMILES_TOKEN_NAMES)
72 | }
73 |
74 | # Add 3 elements which also have lowercase representations in SMILES strings.
75 | # We want these to have the same index as the capitalized version.
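76 | # E.g. '[Te][te]' tokenizes to [81, 71, 83, 81, 71, 83]: 'te' and 'Te' share index 71 (see feature_utils_test.py).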
77 | SMILES_TOKEN_NAME_TO_INDEX['se'] = SMILES_TOKEN_NAME_TO_INDEX['Se'] 78 | SMILES_TOKEN_NAME_TO_INDEX['si'] = SMILES_TOKEN_NAME_TO_INDEX['Si'] 79 | SMILES_TOKEN_NAME_TO_INDEX['te'] = SMILES_TOKEN_NAME_TO_INDEX['Te'] 80 | 81 | # Bond order master list: 82 | BOND_ORDER_TO_INTS_DICT = {1.0: 1, 2.0: 2, 3.0: 3, 1.5: 4} 83 | 84 | TRUE_SPECTRA_SCALING_FACTOR = 0.1 85 | -------------------------------------------------------------------------------- /examples/pentachlorobenzene.sdf: -------------------------------------------------------------------------------- 1 | 11855 2 | -OEChem-08201916472D 3 | 4 | 12 12 0 0 0 0 0 0 0999 V2000 5 | 3.7320 1.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0 6 | 5.4641 0.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0 7 | 2.0000 0.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0 8 | 2.0000 -1.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0 9 | 5.4641 -1.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0 10 | 3.7320 0.5000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 11 | 4.5981 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 12 | 2.8660 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 13 | 2.8660 -1.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 14 | 4.5981 -1.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 15 | 3.7320 -1.5000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 16 | 3.7320 -2.1200 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0 17 | 1 6 1 0 0 0 0 18 | 2 7 1 0 0 0 0 19 | 3 8 1 0 0 0 0 20 | 4 9 1 0 0 0 0 21 | 5 10 1 0 0 0 0 22 | 6 7 2 0 0 0 0 23 | 6 8 1 0 0 0 0 24 | 7 10 1 0 0 0 0 25 | 8 9 2 0 0 0 0 26 | 9 11 1 0 0 0 0 27 | 10 11 2 0 0 0 0 28 | 11 12 1 0 0 0 0 29 | M END 30 | > 31 | 11855 32 | 33 | > 34 | 1 35 | 36 | > 37 | 125 38 | 39 | > 40 | 0 41 | 42 | > 43 | 0 44 | 45 | > 46 | 0 47 | 48 | > 49 | AAADcQBgAAAHAAAAAAAAAAAAAAAAAAAAAAAwAAAAAAAAAAABAAAAGAIAAAAACAKAECAwAIAAAACAACBCAAACAAAgBQAAikAAAogIICKBEhCAIAAggAAIiAcAAAAAAAAQAACAAAQAACAAAQAACAAAAAAAAA== 50 | 51 | > 52 | 1,2,3,4,5-pentachlorobenzene 53 | 54 | > 55 | 1,2,3,4,5-pentachlorobenzene 56 | 57 | > 58 | 1,2,3,4,5-pentachlorobenzene 59 | 60 | > 61 | 1,2,3,4,5-pentachlorobenzene 62 | 63 | > 64 | 1,2,3,4,5-pentakis(chloranyl)benzene 65 | 66 | > 67 | 1,2,3,4,5-pentachlorobenzene 68 | 69 | > 70 | InChI=1S/C6HCl5/c7-2-1-3(8)5(10)6(11)4(2)9/h1H 71 | 72 | > 73 | CEOCDNVZRAIOQZ-UHFFFAOYSA-N 74 | 75 | > 76 | 5.2 77 | 78 | > 79 | 249.849138 80 | 81 | > 82 | C6HCl5 83 | 84 | > 85 | 250.3 86 | 87 | > 88 | C1=C(C(=C(C(=C1Cl)Cl)Cl)Cl)Cl 89 | 90 | > 91 | C1=C(C(=C(C(=C1Cl)Cl)Cl)Cl)Cl 92 | 93 | > 94 | 0 95 | 96 | > 97 | 247.852089 98 | 99 | > 100 | 0 101 | 102 | > 103 | 11 104 | 105 | > 106 | 0 107 | 108 | > 109 | 0 110 | 111 | > 112 | 0 113 | 114 | > 115 | 0 116 | 117 | > 118 | 0 119 | 120 | > 121 | 1 122 | 123 | > 124 | -1 125 | 126 | > 127 | 1 128 | 5 129 | 255 130 | 131 | > 132 | 10 11 8 133 | 6 7 8 134 | 6 8 8 135 | 7 10 8 136 | 8 9 8 137 | 9 11 8 138 | 139 | $$$$ 140 | -------------------------------------------------------------------------------- /feature_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for feature_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | import feature_utils 23 | import mass_spec_constants as ms_constants 24 | import numpy as np 25 | from rdkit import Chem 26 | import tensorflow as tf 27 | 28 | 29 | class FeatureUtilsTest(tf.test.TestCase): 30 | 31 | def setUp(self): 32 | test_mol_smis = [ 33 | 'CCCC', 'CC(=CC(=O)C)O.CC(=CC(=O)C)O.[Cu]', 'CCCCl', 34 | ('C[C@H](CCCC(C)C)[C@H]1CC[C@@H]2[C@@]1' 35 | '(CC[C@H]3[C@H]2CC=C4[C@@]3(CC[C@@H](C4)O)C)C') 36 | ] 37 | self.test_mols = [ 38 | Chem.MolFromSmiles(mol_str) for mol_str in test_mol_smis 39 | ] 40 | 41 | def _validate_smiles_string_tokenization(self, smiles_string, 42 | expected_token_list): 43 | token_list = feature_utils.tokenize_smiles(np.array([smiles_string])) 44 | self.assertAllEqual(token_list, expected_token_list) 45 | 46 | def test_tokenize_smiles_string(self): 47 | self._validate_smiles_string_tokenization('CCCC', [28, 28, 28, 28]) 48 | self._validate_smiles_string_tokenization('ClCCCC', [31, 28, 28, 28, 28]) 49 | self._validate_smiles_string_tokenization('CCClCC', [28, 28, 31, 28, 28]) 50 | self._validate_smiles_string_tokenization('CCCCCl', [28, 28, 28, 28, 31]) 51 | self._validate_smiles_string_tokenization('ClC(CC)CCl', 52 | [31, 28, 2, 28, 28, 3, 28, 31]) 53 | self._validate_smiles_string_tokenization( 54 | 'ClC(CCCl)CCl', [31, 28, 2, 28, 28, 31, 3, 28, 31]) 55 | self._validate_smiles_string_tokenization('BrCCCCCl', 56 | [27, 28, 28, 28, 28, 31]) 57 | self._validate_smiles_string_tokenization('ClCCCCBr', 58 | [31, 28, 28, 28, 28, 27]) 59 | self._validate_smiles_string_tokenization('[Te][te]', 60 | [81, 71, 83, 81, 71, 83]) 61 | 62 | def test_check_mol_only_has_atoms(self): 63 | result = [ 64 | feature_utils.check_mol_only_has_atoms(mol, ['C']) 65 | for mol in self.test_mols 66 | ] 67 | self.assertAllEqual(result, [True, False, False, False]) 68 | 69 | def test_check_mol_does_not_have_atoms(self): 70 | result = [ 71 | feature_utils.check_mol_does_not_have_atoms( 72 | mol, ms_constants.METAL_ATOM_SYMBOLS) for mol in self.test_mols 73 | ] 74 | self.assertAllEqual(result, [True, False, True, True]) 75 | 76 | def test_make_filter_by_substructure(self): 77 | filter_fn = feature_utils.make_filter_by_substructure('steroid') 78 | result = [filter_fn(mol) for mol in self.test_mols] 79 | self.assertAllEqual(result, [False, False, False, True]) 80 | 81 | def test_convert_spectrum_array_to_string(self): 82 | spectra_array = np.zeros((2, 1000)) 83 | spectra_array[0, 3] = 100 84 | spectra_array[1, 39] = 100 85 | spectra_array[1, 21] = 60 86 | 87 | expected_spectra_strings = ['3 100', '21 60\n39 100'] 88 | result_spectra_strings = [] 89 | for idx in range(np.shape(spectra_array)[0]): 90 | result_spectra_strings.append( 91 | feature_utils.convert_spectrum_array_to_string(spectra_array[idx, :])) 92 | 93 | self.assertAllEqual(expected_spectra_strings, result_spectra_strings) 94 | 95 | 96 | if __name__ == '__main__': 97 | tf.test.main() 98 | -------------------------------------------------------------------------------- /testdata/test_2_mend.sdf: -------------------------------------------------------------------------------- 1 | METHANE, DIAZO- 2 | -MSLIB32-03140511252D 1 1.0 0.0 57 3 | CAS rn = 334883, Library ID = 4 4 | 3 2 0 0 0 0 0 0 0 0 0 5 | 20.0000 15.0000 0.0000 C 0 0 0 0 0 0 0 
0 0 6 | 0.0000 15.0000 0.0000 N 0 0 0 0 5 0 0 0 0 7 | -20.0000 15.0000 0.0000 N 0 0 0 0 3 0 0 0 0 8 | 1 2 2 0 0 0 0 9 | 2 3 3 0 0 0 0 10 | M END 11 | > 12 | Methane, diazo- 13 | 14 | > 15 | Azimethylene 16 | Diazomethane 17 | Acomethylene 18 | Diazirine 19 | Diazonium methylide 20 | 21 | > 22 | YXHKONLOYHBTNS-UHFFFAOYSA-N 23 | 24 | > 25 | CH2N2 26 | 27 | > 28 | 42 29 | 30 | > 31 | 42.0217981 32 | 33 | > 34 | 22 110 35 | 23 220 36 | 24 999 37 | 25 25 38 | 26 12 39 | 27 58 40 | 28 179 41 | 30 22 42 | 31 110 43 | 32 425 44 | 45 | $$$$ 46 | (4-(4-CHLORPHENYL)-3-MORPHOLINO-PYRROL-2-YL)-BUTENEDIOIC ACID, DIMETHYL ESTER 47 | -MSLIB32-03140511252D 1 1.0 0.0 286647 48 | CAS rn = 164221378, Library ID = 6 49 | 28 30 0 0 0 0 0 0 0 0 0 50 | -29.0000 41.0000 0.0000 C 0 0 0 0 0 0 0 0 0 51 | -49.0000 39.0000 0.0000 C 0 0 0 0 0 0 0 0 0 52 | 9.0000 6.0000 0.0000 C 0 0 0 0 0 0 0 0 0 53 | -8.0000 -6.0000 0.0000 N 0 0 0 0 3 0 0 0 0 54 | 38.0000 73.0000 0.0000 O 0 0 0 0 2 0 0 0 0 55 | 46.0000 55.0000 0.0000 C 0 0 0 0 0 0 0 0 0 56 | -21.0000 59.0000 0.0000 C 0 0 0 0 0 0 0 0 0 57 | -18.0000 25.0000 0.0000 C 0 0 0 0 0 0 0 0 0 58 | -61.0000 55.0000 0.0000 C 0 0 0 0 0 0 0 0 0 59 | -33.0000 75.0000 0.0000 C 0 0 0 0 0 0 0 0 0 60 | 2.0000 25.0000 0.0000 C 0 0 0 0 0 0 0 0 0 61 | -24.0000 6.0000 0.0000 C 0 0 0 0 0 0 0 0 0 62 | -53.0000 73.0000 0.0000 C 0 0 0 0 0 0 0 0 0 63 | -65.0000 89.0000 0.0000 Cl 0 0 0 0 1 0 0 0 0 64 | 28.0000 -0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 65 | 14.0000 41.0000 0.0000 N 0 0 0 0 3 0 0 0 0 66 | 32.0000 -20.0000 0.0000 C 0 0 0 0 0 0 0 0 0 67 | 43.0000 13.0000 0.0000 C 0 0 0 0 0 0 0 0 0 68 | 34.0000 39.0000 0.0000 C 0 0 0 0 0 0 0 0 0 69 | 6.0000 59.0000 0.0000 C 0 0 0 0 0 0 0 0 0 70 | 18.0000 75.0000 0.0000 C 0 0 0 0 0 0 0 0 0 71 | 17.0000 -33.0000 0.0000 C 0 0 0 0 0 0 0 0 0 72 | 21.0000 -53.0000 0.0000 O 0 0 0 0 2 0 0 0 0 73 | -2.0000 -27.0000 0.0000 O 0 0 0 0 2 0 0 0 0 74 | 62.0000 7.0000 0.0000 O 0 0 0 0 2 0 0 0 0 75 | 38.0000 33.0000 0.0000 O 0 0 0 0 2 0 0 0 0 76 | 66.0000 -13.0000 0.0000 C 0 0 0 0 0 0 0 0 0 77 | 40.0000 -59.0000 0.0000 C 0 0 0 0 0 0 0 0 0 78 | 1 2 2 0 0 0 0 79 | 1 7 1 0 0 0 0 80 | 1 8 1 0 0 0 0 81 | 2 9 1 0 0 0 0 82 | 3 4 1 0 0 0 0 83 | 3 15 1 0 0 0 0 84 | 3 11 2 0 0 0 0 85 | 4 12 1 0 0 0 0 86 | 5 6 1 0 0 0 0 87 | 5 21 1 0 0 0 0 88 | 6 19 1 0 0 0 0 89 | 7 10 2 0 0 0 0 90 | 8 11 1 0 0 0 0 91 | 8 12 2 0 0 0 0 92 | 9 13 2 0 0 0 0 93 | 10 13 1 0 0 0 0 94 | 11 16 1 0 0 0 0 95 | 13 14 1 0 0 0 0 96 | 15 17 2 0 0 0 0 97 | 15 18 1 0 0 0 0 98 | 16 19 1 0 0 0 0 99 | 16 20 1 0 0 0 0 100 | 17 22 1 0 0 0 0 101 | 18 25 1 0 0 0 0 102 | 18 26 2 0 0 0 0 103 | 20 21 1 0 0 0 0 104 | 22 23 1 0 0 0 0 105 | 22 24 2 0 0 0 0 106 | 23 28 1 0 0 0 0 107 | 25 27 1 0 0 0 0 108 | M END 109 | > 110 | (4-(4-Chlorphenyl)-3-morpholino-pyrrol-2-yl)-butenedioic acid, dimethyl ester 111 | 112 | > 113 | Dimethyl (2E)-2-[4-(4-chlorophenyl)-3-(4-morpholinyl)-1H-pyrrol-2-yl]-2-butenedioate # 114 | 115 | > 116 | PNYUDNYAXSEACV-RVDMUPIBSA-N 117 | 118 | > 119 | C20H21ClN2O5 120 | 121 | > 122 | 404 123 | 124 | > 125 | 404.1139 126 | 127 | > 128 | 32 12 129 | 33 7 130 | 34 28 131 | 35 999 132 | 36 57 133 | 37 302 134 | 38 975 135 | 39 8 136 | 40 53 137 | 41 176 138 | 42 99 139 | 43 122 140 | 44 117 141 | 45 155 142 | 46 9 143 | 47 7 144 | 49 6 145 | 50 28 146 | 51 59 147 | 148 | $$$$ 149 | -------------------------------------------------------------------------------- /spectra_predictor_test.py: -------------------------------------------------------------------------------- 1 | """Tests for .spectra_predictor.""" 2 | 3 | 
from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | import os 7 | import tempfile 8 | from absl.testing import absltest 9 | 10 | import feature_utils 11 | import mass_spec_constants as ms_constants 12 | import spectra_predictor 13 | import test_utils 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | 19 | class DummySpectraPredictor(spectra_predictor.SpectraPredictor): 20 | """A test class that returns the mol weight input as the spectra prediction.""" 21 | 22 | def _setup_prediction_op(self): 23 | fingerprint_input_op = tf.placeholder(tf.float32, (None, 4096)) 24 | mol_weight_input_op = tf.placeholder(tf.float32, (None, 1)) 25 | 26 | feature_dict = { 27 | self.fingerprint_input_key: fingerprint_input_op, 28 | self.molecular_weight_key: mol_weight_input_op 29 | } 30 | 31 | predict_op = tf.multiply(fingerprint_input_op, mol_weight_input_op) 32 | return feature_dict, predict_op 33 | 34 | 35 | class SpectraPredictorTest(tf.test.TestCase): 36 | 37 | def setUp(self): 38 | super(SpectraPredictorTest, self).setUp() 39 | self.np_fingerprint_input = np.ones((2, 4096)) 40 | self.np_mol_weight_input = np.reshape(np.array([18., 16.]), (2, 1)) 41 | self.test_data_directory = test_utils.test_dir("testdata/") 42 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir()) 43 | self.test_file_short = os.path.join(self.test_data_directory, 44 | "test_2_mend.sdf") 45 | 46 | def tearDown(self): 47 | tf.reset_default_graph() 48 | tf.io.gfile.rmtree(self.temp_dir) 49 | super(SpectraPredictorTest, self).tearDown() 50 | 51 | def test_make_dummy_spectra_prediction(self): 52 | """Tests the dummy predictor.""" 53 | predictor = DummySpectraPredictor() 54 | 55 | spectra_predictions = predictor.make_spectra_prediction( 56 | self.np_fingerprint_input, self.np_mol_weight_input) 57 | expected_value = np.multiply( 58 | self.np_fingerprint_input, self.np_mol_weight_input) 59 | expected_value = ( 60 | expected_value / np.max(expected_value, axis=1, keepdims=True) * 61 | spectra_predictor.SCALE_FACTOR_FOR_LARGEST_INTENSITY) 62 | self.assertAllEqual(spectra_predictions, expected_value) 63 | 64 | def test_make_neims_spectra_prediction_without_weights(self): 65 | """Tests loading the parameters for the neims model without weights.""" 66 | predictor = spectra_predictor.NeimsSpectraPredictor(model_checkpoint_dir="") 67 | 68 | spectra_predictions = predictor.make_spectra_prediction( 69 | self.np_fingerprint_input, self.np_mol_weight_input) 70 | 71 | self.assertEqual( 72 | np.shape(spectra_predictions), 73 | (np.shape(self.np_fingerprint_input)[0], ms_constants.MAX_PEAK_LOC)) 74 | 75 | def test_load_fingerprints_from_sdf(self): 76 | """Tests loading fingerprints from an sdf file.""" 77 | predictor = spectra_predictor.NeimsSpectraPredictor(model_checkpoint_dir="") 78 | 79 | mols_from_file = spectra_predictor.get_mol_list_from_sdf( 80 | self.test_file_short) 81 | fingerprints_from_file = predictor.get_fingerprints_from_mol_list( 82 | mols_from_file) 83 | 84 | self.assertEqual(np.shape(fingerprints_from_file), (2, 4096)) 85 | 86 | def test_write_spectra_to_sdf(self): 87 | """Tests predicting and writing spectra to file.""" 88 | predictor = spectra_predictor.NeimsSpectraPredictor(model_checkpoint_dir="") 89 | 90 | mols_from_file = spectra_predictor.get_mol_list_from_sdf( 91 | self.test_file_short) 92 | fingerprints, mol_weights = predictor.get_inputs_for_model_from_mol_list( 93 | mols_from_file) 94 | 95 | spectra_predictions = 
predictor.make_spectra_prediction( 96 | fingerprints, mol_weights) 97 | spectra_predictor.update_mols_with_spectra(mols_from_file, 98 | spectra_predictions) 99 | 100 | _, fpath = tempfile.mkstemp(dir=self.temp_dir) 101 | spectra_predictor.write_rdkit_mols_to_sdf(mols_from_file, fpath) 102 | 103 | # Test contents of newly written file: 104 | new_mol_list = spectra_predictor.get_mol_list_from_sdf(fpath) 105 | 106 | for idx, mol in enumerate(new_mol_list): 107 | peak_string_from_file = mol.GetProp( 108 | spectra_predictor.PREDICTED_SPECTRA_PROP_NAME) 109 | peak_locs, peak_intensities = feature_utils.parse_peaks( 110 | peak_string_from_file) 111 | dense_mass_spectra = feature_utils.make_dense_mass_spectra( 112 | peak_locs, peak_intensities, ms_constants.MAX_PEAK_LOC) 113 | self.assertSequenceAlmostEqual( 114 | dense_mass_spectra, spectra_predictions[idx, :].astype(np.int32), 115 | delta=1.) 116 | 117 | 118 | if __name__ == "__main__": 119 | tf.test.main() 120 | -------------------------------------------------------------------------------- /similarity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for similarity computation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import abc 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | # When computing cosine similarity, the denominator is constrained to be no 26 | # smaller than this. 
27 | EPSILON = 1e-6 28 | 29 | 30 | class SimilarityProvider(object): 31 | """Abstract class of helpers for similarity-based library matching.""" 32 | __metaclass__ = abc.ABCMeta 33 | 34 | def __init__(self, hparams=None): 35 | self.hparams = hparams 36 | 37 | @abc.abstractmethod 38 | def preprocess_library(self, library): 39 | """Perform normalization of [num_library_elements, feature_dim] library.""" 40 | 41 | @abc.abstractmethod 42 | def undo_library_preprocessing(self, library): 43 | """Undo the effect of preprocess_library(), up to a scaling constant.""" 44 | 45 | @abc.abstractmethod 46 | def preprocess_queries(self, queries): 47 | """Perform normalization of [num_query_elements, feature_dim] queries.""" 48 | 49 | @abc.abstractmethod 50 | def compute_similarity(self, library, queries): 51 | """Compute [num_library_elements, num_query_elements] similarities.""" 52 | 53 | @abc.abstractmethod 54 | def make_training_loss(self, true_tensor, predicted_tensor): 55 | """Create training loss that is consistent with the similarity.""" 56 | 57 | 58 | class CosineSimilarityProvider(SimilarityProvider): 59 | """Cosine similarity.""" 60 | 61 | def _normalize_rows(self, tensor): 62 | return tf.nn.l2_normalize(tensor, axis=1) 63 | 64 | def preprocess_library(self, library): 65 | return self._normalize_rows(library) 66 | 67 | def undo_library_preprocessing(self, library): 68 | return library 69 | 70 | def preprocess_queries(self, queries): 71 | return self._normalize_rows(queries) 72 | 73 | def compute_similarity(self, library, queries): 74 | similarities = tf.matmul(library, queries, transpose_b=True) 75 | return tf.transpose(similarities) 76 | 77 | def make_training_loss(self, true_tensor, predicted_tensor): 78 | return tf.reduce_mean( 79 | tf.losses.mean_squared_error(true_tensor, predicted_tensor)) 80 | 81 | 82 | class GeneralizedCosineSimilarityProvider(CosineSimilarityProvider): 83 | """Custom cosine similarity that is popular for massspec matching.""" 84 | 85 | def _make_weights(self, tensor): 86 | num_bins = tensor.shape[1].value 87 | weights = np.power(np.arange(1, num_bins + 1), 88 | self.hparams.mass_power)[np.newaxis, :] 89 | return weights / np.sum(weights) 90 | 91 | def _normalize_rows(self, tensor): 92 | if self.hparams.mass_power != 0: 93 | tensor *= self._make_weights(tensor) 94 | 95 | return super(GeneralizedCosineSimilarityProvider, 96 | self)._normalize_rows(tensor) 97 | 98 | def undo_library_preprocessing(self, library): 99 | return library / self._make_weights(library) 100 | 101 | def compute_similarity(self, library, queries): 102 | similarities = tf.matmul(library, queries, transpose_b=True) 103 | return tf.transpose(similarities) 104 | 105 | def make_training_loss(self, true_tensor, predicted_tensor): 106 | if self.hparams.mass_power != 0: 107 | weights = self._make_weights(true_tensor) 108 | weighted_squared_error = weights * tf.square(true_tensor - 109 | predicted_tensor) 110 | return tf.reduce_mean(weighted_squared_error) 111 | else: 112 | return tf.reduce_mean( 113 | tf.losses.mean_squared_error(true_tensor, predicted_tensor)) 114 | 115 | 116 | def max_margin_ranking_loss(predictions, target_indices, library, 117 | similarity_provider, margin): 118 | """Max-margin ranking loss. 119 | 120 | loss = (1/batch_size) * sum_i w_i sum_j max(0, 121 | similarities[i, j] 122 | - similarities[i, ti] + margin), 123 | where similarities = similarity_provider.compute_similarity(library, 124 | predictions) 125 | and ti = target_indices[i]. 
Here, w_i is a weight placed on each element of
126 |   the batch. Without w_i, our loss would be the standard Crammer-Singer
127 |   multiclass SVM. Instead, we set w_i so that the total contribution to the
128 |   parameter gradient from each batch element is equal. Therefore, we set w_i
129 |   equal to 1 / (the number of margin violations for element i).
130 |
131 |   Args:
132 |     predictions: [batch_size, prediction_dim] float Tensor
133 |     target_indices: [batch_size] int Tensor
134 |     library: [num_library_elements, prediction_dim] constant Tensor
135 |     similarity_provider: a SimilarityProvider instance
136 |     margin: float
137 |   Returns:
138 |     loss
139 |   """
140 |   library = similarity_provider.preprocess_library(library)
141 |   predictions = similarity_provider.preprocess_queries(predictions)
142 |   similarities = similarity_provider.compute_similarity(library, predictions)
143 |
144 |   batch_size = tf.shape(predictions)[0]
145 |
146 |   target_indices = tf.squeeze(target_indices, axis=1)
147 |   row_indices = tf.range(0, batch_size)
148 |   indices = tf.stack([row_indices, tf.cast(target_indices, tf.int32)], axis=1)
149 |   ground_truth_similarities = tf.gather_nd(similarities, indices)
150 |
151 |   margin_violations = tf.nn.relu(-ground_truth_similarities[..., tf.newaxis] +
152 |                                  similarities + margin)
153 |
154 |   margin_violators = tf.cast(margin_violations > 0, tf.int32)
155 |   margin_violators_per_batch_element = tf.to_float(
156 |       tf.reduce_sum(margin_violators, axis=1, keep_dims=True))
157 |   margin_violators_per_batch_element = tf.maximum(
158 |       margin_violators_per_batch_element, 1.)
159 |   margin_violators_per_batch_element = tf.stop_gradient(
160 |       margin_violators_per_batch_element)
161 |   tf.summary.scalar('num_margin_violations',
162 |                     tf.reduce_mean(margin_violators_per_batch_element))
163 |   weighted_margin_violations = (
164 |       margin_violations / margin_violators_per_batch_element)
165 |   return tf.reduce_sum(weighted_margin_violations) / tf.maximum(
166 |       tf.to_float(batch_size), 1.)
--------------------------------------------------------------------------------
/util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for utils.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import tempfile 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | import util 25 | import numpy as np 26 | import tensorflow as tf 27 | 28 | 29 | class UtilTest(tf.test.TestCase, parameterized.TestCase): 30 | 31 | def setUp(self): 32 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir()) 33 | 34 | def tearDown(self): 35 | tf.gfile.DeleteRecursively(self.temp_dir) 36 | 37 | def _make_model(self, batch_size, num_batches, variable_initializer_value): 38 | np_inputs = np.arange(batch_size * num_batches) 39 | np_inputs = np.float32(np_inputs) 40 | inputs = tf.data.Dataset.from_tensor_slices(np_inputs) 41 | inputs = inputs.batch(batch_size).make_one_shot_iterator().get_next() 42 | scale = tf.get_variable( 43 | name='scale', dtype=tf.float32, initializer=variable_initializer_value, 44 | trainable=True) 45 | output = inputs * scale 46 | return output 47 | 48 | def test_run_graph_and_process_results(self): 49 | 50 | batch_size = 3 51 | num_batches = 5 52 | 53 | # Make a graph that contains a Variable and save it to checkpoint. 54 | with tf.Graph().as_default(): 55 | _ = self._make_model( 56 | batch_size=batch_size, num_batches=num_batches, 57 | variable_initializer_value=2.0) 58 | saver = tf.train.Saver(tf.trainable_variables()) 59 | with tf.Session() as sess: 60 | sess.run(tf.global_variables_initializer()) 61 | saver.save(sess, self.temp_dir + '/model-') 62 | 63 | # Make another copy of the graph, and process data using this one. 64 | with tf.Graph().as_default(): 65 | # We intentionally make this graph have a different value for its Variable 66 | # than the graph above. When we restore from checkpoint, we will grab the 67 | # value from the first graph. This helps test that the Variables are 68 | # being properly restored from checkpoint. 69 | ops_to_fetch = self._make_model( 70 | batch_size=batch_size, num_batches=num_batches, 71 | variable_initializer_value=3.0 72 | ) 73 | 74 | results = [] 75 | def process_fetched_values_fn(np_array): 76 | results.append(np_array) 77 | 78 | model_checkpoint_path = self.temp_dir 79 | 80 | util.run_graph_and_process_results(ops_to_fetch, model_checkpoint_path, 81 | process_fetched_values_fn) 82 | 83 | results = np.concatenate(results, axis=0) 84 | expected_results = np.arange(num_batches * batch_size) * 2.0 85 | 86 | self.assertAllEqual(results, expected_results) 87 | 88 | @parameterized.parameters((7), (10), (65)) 89 | def test_map_predictor(self, sub_batch_size): 90 | input_op = { 91 | 'a': tf.random_normal(shape=(50, 5)), 92 | 'b': tf.random_normal(shape=(50, 5)) 93 | } 94 | 95 | def predictor_fn(data): 96 | return data['a'] + data['b'] 97 | 98 | mapped_prediction = util.map_predictor( 99 | input_op, predictor_fn, sub_batch_size=sub_batch_size) 100 | unmapped_prediction = predictor_fn(input_op) 101 | difference = tf.reduce_mean( 102 | tf.squared_difference(mapped_prediction, unmapped_prediction)) 103 | with tf.Session() as sess: 104 | self.assertLess( 105 | sess.run(difference), 1e-6, 106 | 'The output of _map_predictor does not match a direct ' 107 | 'application of predictor_fn.') 108 | 109 | def test_value_op_with_initializer(self): 110 | """Test correctness of library_matching.value_op_with_initializer.""" 111 | 112 | base_value_op = tf.get_variable('value', initializer=0.) 
113 | 114 | def make_value_op(): 115 | return base_value_op 116 | 117 | def make_init_op(value): 118 | # This is a simple assignment that could have been achieved by changing 119 | # the initializer above. However, in other use cases of 120 | # value_op_with_initializer, the contructed value requires 121 | # data-dependent computation that can't be done via an initializer. 122 | return value.assign(tf.ones_like(value)) 123 | 124 | value_op = util.value_op_with_initializer(make_value_op, make_init_op) 125 | 126 | # Check that the value of the Variable generated by make_value_op() 127 | # is the value constructed by make_init_op, not the value given 128 | # the initializer given to the Variable's constructor. 129 | with tf.Session() as sess: 130 | sess.run(tf.global_variables_initializer()) 131 | sess.run(tf.local_variables_initializer()) 132 | self.assertAllEqual(sess.run(base_value_op), 0.0) 133 | self.assertAllEqual(sess.run(value_op), 1.0) 134 | self.assertAllEqual(sess.run(base_value_op), 1.0) 135 | 136 | def test_scatter_by_anchor_indices(self): 137 | 138 | def _validate(anchor_indices, data, index_shift, expected_output): 139 | with tf.Graph().as_default(): 140 | output = util.scatter_by_anchor_indices(anchor_indices, data, 141 | index_shift) 142 | with tf.Session() as sess: 143 | actual_output = sess.run(output) 144 | self.assertAllClose( 145 | np.array(expected_output, dtype=np.float32), actual_output) 146 | 147 | data = [[1, 2, 3], [4, 5, 6]] 148 | 149 | anchor_indices = [1, 1] 150 | index_shift = 0 151 | expected_output = [[2, 1, 0], [5, 4, 0]] 152 | _validate(anchor_indices, data, index_shift, expected_output) 153 | 154 | anchor_indices = [2, 2] 155 | index_shift = 0 156 | expected_output = [[3, 2, 1], [6, 5, 4]] 157 | _validate(anchor_indices, data, index_shift, expected_output) 158 | 159 | anchor_indices = [1, 1] 160 | index_shift = 1 161 | expected_output = [[3, 2, 1], [6, 5, 4]] 162 | _validate(anchor_indices, data, index_shift, expected_output) 163 | 164 | anchor_indices = [0, 1] 165 | index_shift = 1 166 | expected_output = [[2, 1, 0], [6, 5, 4]] 167 | _validate(anchor_indices, data, index_shift, expected_output) 168 | 169 | data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] 170 | anchor_indices = [1, 2, 3] 171 | index_shift = 0 172 | expected_output = [[2, 1, 0, 0], [7, 6, 5, 0], [12, 11, 10, 9]] 173 | _validate(anchor_indices, data, index_shift, expected_output) 174 | 175 | data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] 176 | anchor_indices = [0, 1, 2] 177 | index_shift = 1 178 | expected_output = [[2, 1, 0, 0], [7, 6, 5, 0], [12, 11, 10, 9]] 179 | _validate(anchor_indices, data, index_shift, expected_output) 180 | 181 | if __name__ == '__main__': 182 | tf.test.main() 183 | -------------------------------------------------------------------------------- /make_predictions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""Run massspec model on data and write out predictions.
15 |
16 | Example usage:
17 | blaze-bin/third_party/py/deep_molecular_massspec/make_predictions \
18 |   --alsologtostderr --input_file=testdata/test_14_record.gz \
19 |   --output_file=/tmp/models/output_predictions \
20 |   --model_checkpoint_path=/tmp/models/output/ \
21 |   --hparams=eval_batch_size=16
22 |
23 | This saves a numpy archive to FLAGS.output_file that contains a dictionary
24 | where the keys are inchikeys and values are 1D np arrays for spectra.
25 |
26 | You should load this dict downstream using:
27 | data_dict = np.load(data_file).item()
28 |
29 | (Note that .item() is necessary because np.load returns a 0-D array,
30 | where the first element is the desired dictionary.)
31 | """
32 |
33 | from __future__ import print_function
34 | import json
35 | import os
36 | import tempfile
37 |
38 |
39 | import dataset_setup_constants as ds_constants
40 | import feature_map_constants as fmap_constants
41 |
42 | # Note that many FLAGS are inherited from molecule_estimator
43 | import molecule_estimator
44 |
45 | import molecule_predictors
46 | import plot_spectra_utils
47 | import util
48 |
49 | import numpy as np
50 | import tensorflow as tf
51 | import matplotlib.pyplot as plt
52 |
53 | FLAGS = tf.app.flags.FLAGS
54 |
55 | tf.flags.DEFINE_string(
56 |     'input_file', None, 'Input TFRecord file or a '
57 |     'glob pattern for TFRecord files')
58 | tf.flags.DEFINE_string(
59 |     'model_checkpoint_path', None,
60 |     'Path to model checkpoint. If a directory, the most '
61 |     'recent model checkpoint in this directory will be used. If a file, it '
62 |     'should be of the form /.../name-of-the-file.ckpt-10000')
63 | tf.flags.DEFINE_bool(
64 |     'save_spectra_plots', True,
65 |     'Make plots of true and predicted spectra for each query molecule.'
66 | ) 67 | tf.flags.DEFINE_string('output_file', None, 68 | 'Location where outputs will be written.') 69 | 70 | 71 | def _make_features_and_labels_from_tfrecord(input_file_pattern, hparams, 72 | features_to_load): 73 | """Construct features and labels Tensors to be consumed by model_fn.""" 74 | 75 | def _make_tmp_dataset_config_file(input_filenames): 76 | """Construct a temporary config file that points to input_filename.""" 77 | 78 | _, tmp_file = tempfile.mkstemp() 79 | dataset_config = { 80 | ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY: input_filenames 81 | } 82 | 83 | with tf.gfile.Open(tmp_file, 'w') as f: 84 | json.dump(dataset_config, f) 85 | 86 | return tmp_file 87 | 88 | input_files = tf.gfile.Glob(input_file_pattern) 89 | if not input_files: 90 | raise ValueError('No files found matching %s' % input_file_pattern) 91 | 92 | data_dir, _ = os.path.split(input_files[0]) 93 | data_basenames = [os.path.split(filename)[1] for filename in input_files] 94 | dataset_config_file = _make_tmp_dataset_config_file(data_basenames) 95 | 96 | mode = tf.estimator.ModeKeys.PREDICT 97 | input_fn = molecule_estimator.make_input_fn( 98 | dataset_config_file=dataset_config_file, 99 | hparams=hparams, 100 | mode=mode, 101 | features_to_load=features_to_load, 102 | data_dir=data_dir, 103 | load_library_matching_data=False) 104 | tf.gfile.Remove(dataset_config_file) 105 | return input_fn() 106 | 107 | 108 | def _make_features_labels_and_estimator(model_type, hparam_string, input_file): 109 | """Construct input ops and EstimatorSpec for massspec model.""" 110 | 111 | prediction_helper = molecule_predictors.get_prediction_helper(model_type) 112 | hparams = prediction_helper.get_default_hparams() 113 | hparams.parse(hparam_string) 114 | 115 | model_fn = molecule_estimator.make_model_fn( 116 | prediction_helper, dataset_config_file=None, model_dir=None) 117 | 118 | features_to_load = prediction_helper.features_to_load(hparams) 119 | features, labels = _make_features_and_labels_from_tfrecord( 120 | input_file, hparams, features_to_load) 121 | 122 | estimator_spec = model_fn( 123 | features, labels, hparams, mode=tf.estimator.ModeKeys.PREDICT) 124 | 125 | return features, labels, estimator_spec 126 | 127 | 128 | def _save_plot_figure(key, prediction, true_spectrum, results_dir): 129 | """A helper function that makes and saves plots of true and predicted spectra.""" 130 | spectra_plot_file_name = plot_spectra_utils.name_plot_file( 131 | plot_spectra_utils.PlotModeKeys.PREDICTED_SPECTRUM, key, file_type='png') 132 | 133 | # Rescale the true/predicted spectra 134 | true_spectrum = true_spectrum / true_spectrum.max() * plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA 135 | prediction = prediction / prediction.max() * plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA 136 | 137 | plot_spectra_utils.plot_true_and_predicted_spectra( 138 | true_spectrum, prediction, 139 | output_filename=os.path.join(results_dir,spectra_plot_file_name), 140 | rescale_mz_axis=True 141 | ) 142 | 143 | 144 | def main(_): 145 | 146 | features, labels, estimator_spec = _make_features_labels_and_estimator( 147 | FLAGS.model_type, FLAGS.hparams, FLAGS.input_file) 148 | del labels # Unused 149 | 150 | pred_op = estimator_spec.predictions 151 | inchikey_op = features[fmap_constants.SPECTRUM_PREDICTION][ 152 | fmap_constants.INCHIKEY] 153 | ops_to_fetch = [inchikey_op, pred_op] 154 | if FLAGS.save_spectra_plots: 155 | true_spectra_op = features[fmap_constants.SPECTRUM_PREDICTION][fmap_constants.DENSE_MASS_SPEC] 156 | ops_to_fetch.append(true_spectra_op) 
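157 |   # ops_to_fetch now holds [inchikey_op, pred_op], plus the true-spectra op
158 |   # when --save_spectra_plots is set; each fetch yields one batch of values.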
157 | 158 | results = {} 159 | results_dir = os.path.dirname(FLAGS.output_file) 160 | tf.gfile.MakeDirs(results_dir) 161 | 162 | def process_fetched_values_fn(fetched_values): 163 | if FLAGS.save_spectra_plots: 164 | keys, predictions, true_spectra = fetched_values 165 | for key, prediction, true_spectrum in zip(keys, predictions, true_spectra): 166 | # Dereference the singleton np string array to get the actual string. 167 | key = key[0] 168 | results[key] = prediction 169 | _save_plot_figure(key, prediction, true_spectrum, results_dir) 170 | else: 171 | keys, predictions = fetched_values 172 | for key, prediction in zip(keys, predictions): 173 | # Dereference the singleton np string array to get the actual string. 174 | key = key[0] 175 | results[key] = prediction 176 | 177 | util.run_graph_and_process_results(ops_to_fetch, FLAGS.model_checkpoint_path, 178 | process_fetched_values_fn) 179 | 180 | np.save(FLAGS.output_file, results) 181 | 182 | 183 | if __name__ == '__main__': 184 | for flag in ['input_file', 'model_checkpoint_path', 'output_file']: 185 | tf.app.flags.mark_flag_as_required(flag) 186 | 187 | tf.app.run(main) 188 | -------------------------------------------------------------------------------- /make_predictions_from_tfrecord.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""Run massspec model on data and write out predictions. 15 | 16 | Example usage: 17 | blaze-bin/third_party/py/deep_molecular_massspec/make_predictions \ 18 | --alsologtostderr --input_file=testdata/test_14_record.gz \ 19 | --output_file=/tmp/models/output_predictions \ 20 | --model_checkpoint_path=/tmp/models/output/ \ 21 | --hparams=eval_batch_size=16 22 | 23 | This saves a numpy archive to FLAGS.output_file that contains a dictionary 24 | where the keys are inchikeys and values are 1D np arrays for spectra. 25 | 26 | You should load this dict downstream using: 27 | data_dict = np.load(data_file).item() 28 | 29 | (Note that .item() is necessary because np.load returns a 0-D array, 30 | where the first element is the desired dictionary.) 31 | """ 32 | 33 | from __future__ import print_function 34 | import json 35 | import os 36 | import tempfile 37 | 38 | 39 | from absl import flags 40 | 41 | import dataset_setup_constants as ds_constants 42 | import feature_map_constants as fmap_constants 43 | 44 | # Note that many FLAGS are inherited from molecule_estimator 45 | import molecule_estimator 46 | 47 | import molecule_predictors 48 | import plot_spectra_utils 49 | import util 50 | 51 | import numpy as np 52 | import tensorflow as tf 53 | 54 | FLAGS = flags.FLAGS 55 | 56 | flags.DEFINE_string( 57 | 'input_file', None, 'Input TFRecord file or a ' 58 | 'globbable file pattern for TFRecord files') 59 | flags.DEFINE_string( 60 | 'model_checkpoint_path', None, 61 | 'Path to model checkpoint. 
If a directory, the most ' 62 | 'recent model checkpoint in this directory will be used. If a file, it ' 63 | 'should be of the form /.../name-of-the-file.ckpt-10000') 64 | flags.DEFINE_bool( 65 | 'save_spectra_plots', True, 66 | 'Make plots of true and predicted spectra for each query molecule.') 67 | flags.DEFINE_string('output_file', None, 68 | 'Location where outputs will be written.') 69 | 70 | 71 | def _make_features_and_labels_from_tfrecord(input_file_pattern, hparams, 72 | features_to_load): 73 | """Construct features and labels Tensors to be consumed by model_fn.""" 74 | 75 | def _make_tmp_dataset_config_file(input_filenames): 76 | """Construct a temporary config file that points to input_filename.""" 77 | 78 | _, tmp_file = tempfile.mkstemp() 79 | dataset_config = { 80 | ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY: input_filenames 81 | } 82 | 83 | # TODO(b/135189673): Replace this with tf.gfile.Open once the type 84 | # issue is fixed. 85 | with open(tmp_file, 'w') as f: 86 | json.dump(dataset_config, f) 87 | 88 | return tmp_file 89 | 90 | input_files = tf.gfile.Glob(input_file_pattern) 91 | if not input_files: 92 | raise ValueError('No files found matching %s' % input_file_pattern) 93 | 94 | data_dir, _ = os.path.split(input_files[0]) 95 | data_basenames = [os.path.split(filename)[1] for filename in input_files] 96 | dataset_config_file = _make_tmp_dataset_config_file(data_basenames) 97 | 98 | mode = tf.estimator.ModeKeys.PREDICT 99 | input_fn = molecule_estimator.make_input_fn( 100 | dataset_config_file=dataset_config_file, 101 | hparams=hparams, 102 | mode=mode, 103 | features_to_load=features_to_load, 104 | data_dir=data_dir, 105 | load_library_matching_data=False) 106 | tf.gfile.Remove(dataset_config_file) 107 | return input_fn() 108 | 109 | 110 | def _make_features_labels_and_estimator(model_type, hparam_string, input_file): 111 | """Construct input ops and EstimatorSpec for massspec model.""" 112 | 113 | prediction_helper = molecule_predictors.get_prediction_helper(model_type) 114 | hparams = prediction_helper.get_default_hparams() 115 | hparams.parse(hparam_string) 116 | 117 | model_fn = molecule_estimator.make_model_fn( 118 | prediction_helper, dataset_config_file=None, model_dir=None) 119 | 120 | features_to_load = prediction_helper.features_to_load(hparams) 121 | features, labels = _make_features_and_labels_from_tfrecord( 122 | input_file, hparams, features_to_load) 123 | 124 | estimator_spec = model_fn( 125 | features, labels, hparams, mode=tf.estimator.ModeKeys.PREDICT) 126 | 127 | return features, labels, estimator_spec 128 | 129 | 130 | def _save_plot_figure(key, prediction, true_spectrum, results_dir): 131 | """A helper function that makes and saves plots of true and predicted spectra.""" 132 | spectra_plot_file_name = plot_spectra_utils.name_plot_file( 133 | plot_spectra_utils.PlotModeKeys.PREDICTED_SPECTRUM, key, file_type='png') 134 | 135 | # Rescale the true/predicted spectra 136 | true_spectrum = ( 137 | true_spectrum / true_spectrum.max() * 138 | plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA) 139 | prediction = ( 140 | prediction / prediction.max() * 141 | plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA) 142 | 143 | plot_spectra_utils.plot_true_and_predicted_spectra( 144 | true_spectrum, 145 | prediction, 146 | output_filename=os.path.join(results_dir, spectra_plot_file_name), 147 | rescale_mz_axis=True) 148 | 149 | 150 | def main(_): 151 | 152 | features, labels, estimator_spec = _make_features_labels_and_estimator( 153 | FLAGS.model_type, FLAGS.hparams, 
FLAGS.input_file) 154 | del labels # Unused 155 | 156 | pred_op = estimator_spec.predictions 157 | inchikey_op = features[fmap_constants.SPECTRUM_PREDICTION][ 158 | fmap_constants.INCHIKEY] 159 | ops_to_fetch = [inchikey_op, pred_op] 160 | if FLAGS.save_spectra_plots: 161 | true_spectra_op = features[fmap_constants.SPECTRUM_PREDICTION][ 162 | fmap_constants.DENSE_MASS_SPEC] 163 | ops_to_fetch.append(true_spectra_op) 164 | 165 | results = {} 166 | results_dir = os.path.dirname(FLAGS.output_file) 167 | tf.gfile.MakeDirs(results_dir) 168 | 169 | def process_fetched_values_fn(fetched_values): 170 | """Processes output values from estimator.""" 171 | if FLAGS.save_spectra_plots: 172 | keys, predictions, true_spectra = fetched_values 173 | for key, prediction, true_spectrum in zip(keys, predictions, 174 | true_spectra): 175 | # Dereference the singleton np string array to get the actual string. 176 | key = key[0] 177 | results[key] = prediction 178 | # Spectra plots are saved to file. 179 | _save_plot_figure(key, prediction, true_spectrum, results_dir) 180 | else: 181 | keys, predictions = fetched_values 182 | for key, prediction in zip(keys, predictions): 183 | # Dereference the singleton np string array to get the actual string. 184 | key = key[0] 185 | results[key] = prediction 186 | 187 | util.run_graph_and_process_results(ops_to_fetch, FLAGS.model_checkpoint_path, 188 | process_fetched_values_fn) 189 | 190 | tf.gfile.MakeDirs(os.path.dirname(FLAGS.output_file)) 191 | np.save(FLAGS.output_file, results) 192 | 193 | 194 | if __name__ == '__main__': 195 | for flag in ['input_file', 'model_checkpoint_path', 'output_file']: 196 | flags.mark_flag_as_required(flag) 197 | 198 | tf.app.run(main) 199 | -------------------------------------------------------------------------------- /train_test_split_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for making train test split for mass spectra datasets. 16 | 17 | Contains TrainValTestFractions namedtuple for passing a 3-tuple of train, validation, 18 | and test fractions to use for a dataset. Also contains TrainValTestInchikeys 19 | namedtuple for a 3-tuple of lists of inchikeys to put into train, validation, and 20 | test splits. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | from collections import namedtuple 27 | import random 28 | 29 | import feature_utils 30 | import mass_spec_constants as ms_constants 31 | import numpy as np 32 | 33 | # Helper class for storing train, validation, and test fractions. 34 | TrainValTestFractions = namedtuple('TrainValTestFractions', 35 | ['train', 'validation', 'test']) 36 | 37 | # Helper class for storing train, validation, and test inchikeys after split. 
38 | TrainValTestInchikeys = namedtuple('TrainValTestInchikeys', 39 | ['train', 'validation', 'test']) 40 | 41 | 42 | def assert_all_lists_mutally_exclusive(list_of_lists): 43 | """Check if any lists within a list of lists contain identical items.""" 44 | for idx, list1 in enumerate(list_of_lists): 45 | for list2 in list_of_lists[idx + 1:]: 46 | if any(elem in list2 for elem in list1): 47 | raise ValueError( 48 | 'found matching items between two lists: \n {}\n {}'.format( 49 | ', '.join(list1), 50 | ', '.join(list2), 51 | )) 52 | 53 | 54 | def make_inchikey_dict(mol_list): 55 | """Converts rdkit.Mol list into dict of lists of Mols keyed by inchikey.""" 56 | inchikey_dict = {} 57 | for mol in mol_list: 58 | inchikey = mol.GetProp(ms_constants.SDF_TAG_INCHIKEY) 59 | if inchikey not in inchikey_dict: 60 | inchikey_dict[inchikey] = [mol] 61 | else: 62 | inchikey_dict[inchikey].append(mol) 63 | return inchikey_dict 64 | 65 | 66 | def get_random_inchikeys(inchikey_list, train_val_test_split_fractions): 67 | """Splits a given inchikey list into 3 lists for train/val/test sets.""" 68 | random.shuffle(inchikey_list) 69 | 70 | train_num = int(train_val_test_split_fractions.train * len(inchikey_list)) 71 | val_num = int(train_val_test_split_fractions.validation * len(inchikey_list)) 72 | 73 | return TrainValTestInchikeys(inchikey_list[:train_num], 74 | inchikey_list[train_num:train_num + val_num], 75 | inchikey_list[train_num + val_num:]) 76 | 77 | 78 | def get_inchikeys_by_family(inchikey_list, 79 | inchikey_dict, 80 | train_val_test_split_fractions, 81 | family_name='steroid', 82 | exclude_from_train=True): 83 | """Creates train/val/test split based on presence of a structural family. 84 | 85 | Filters molecules according to whether they have the substructure specified 86 | by family_name. All molecules passing the filter will be placed in 87 | validation/test datasets or into the train set, according to 88 | exclude_from_train. The molecules are assigned to the validation/test splits 89 | according to the relative ratio between the validation and test fractions. 90 | 91 | If the validation and test fractions are both equal to 0.0, these values 92 | will be overwritten to 0.5 and 0.5. 93 | 94 | Args: 95 | inchikey_list: List of inchikeys to partition into train/val/test sets 96 | inchikey_dict: dict of inchikeys, [rdkit.Mol objects]. 97 | Must contain inchikey_list in its keys. 98 | train_val_test_split_fractions: a TrainValTestFractions tuple 99 | family_name: str, a key in feature_utils.FAMILY_DICT 100 | exclude_from_train: indicates whether to include/exclude the family's 101 | molecules from the training set. If excluded from the training set, test 102 | and validation sets will consist only of these molecules. 
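  For example (illustrative): with fractions (0.8, 0.1, 0.1) and exclude_from_train=True, every family molecule is shuffled into validation/test and split 50/50 between them (since 0.1 / (0.1 + 0.1) = 0.5), while every non-family molecule goes to train.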
103 | Returns: 104 | TrainValTestInchikeys object 105 | """ 106 | _, val_fraction, test_fraction = train_val_test_split_fractions 107 | if val_fraction == 0.0 and test_fraction == 0.0: 108 | val_fraction = 0.5 109 | test_fraction = 0.5 110 | 111 | substructure_filter_fn = feature_utils.make_filter_by_substructure( 112 | family_name) 113 | family_inchikeys = [] 114 | nonfamily_inchikeys = [] 115 | 116 | for ikey in inchikey_list: 117 | if substructure_filter_fn(inchikey_dict[ikey][0]): 118 | family_inchikeys.append(ikey) 119 | else: 120 | nonfamily_inchikeys.append(ikey) 121 | 122 | if exclude_from_train: 123 | val_test_inchikeys, train_inchikeys = (family_inchikeys, 124 | nonfamily_inchikeys) 125 | else: 126 | train_inchikeys, val_test_inchikeys = (family_inchikeys, 127 | nonfamily_inchikeys) 128 | 129 | random.shuffle(val_test_inchikeys) 130 | val_num = int( 131 | val_fraction / (val_fraction + test_fraction) * len(val_test_inchikeys)) 132 | return TrainValTestInchikeys(train_inchikeys, val_test_inchikeys[:val_num], 133 | val_test_inchikeys[val_num:]) 134 | 135 | 136 | def make_train_val_test_split_inchikey_lists(train_inchikey_list, 137 | train_inchikey_dict, 138 | train_val_test_split_fractions, 139 | holdout_inchikey_list=None, 140 | splitting_type='random'): 141 | """Given inchikey lists, returns lists to use for train/val/test sets. 142 | 143 | If holdout_inchikey_list is given, the inchikeys in this list will be excluded 144 | from the returned train/validation/test lists. 145 | 146 | Args: 147 | train_inchikey_list : List of inchikeys to use for train/val/test sets 148 | train_inchikey_dict : Main dict keyed by inchikeys, values are lists of 149 | rdkit.Mol. Note that train_inchikey_dict.keys() != train_inchikey_list; 150 | train_inchikey_dict may have many more keys than are in the list. 151 | train_val_test_split_fractions : a TrainValTestFractions tuple 152 | holdout_inchikey_list : List of inchikeys to exclude from train/val/test 153 | sets. 154 | splitting_type : method of splitting molecules into train/val/test sets. 155 | Returns: 156 | A TrainValTestInchikeys namedtuple 157 | Raises: 158 | ValueError : if train_val_test_split_fractions does not sum to 1.0, 159 | or if splitting_type names a structure family that is not 160 | supported. 161 | """ 162 | if not np.isclose([sum(train_val_test_split_fractions)], [1.0]): 163 | raise ValueError('Must specify train_val_test_split that sums to 1.0') 164 | 165 | if holdout_inchikey_list: 166 | # filter out those inchikeys that are in the holdout set. 167 | train_inchikey_list = [ 168 | ikey for ikey in train_inchikey_list 169 | if ikey not in holdout_inchikey_list 170 | ] 171 | 172 | if splitting_type == 'random': 173 | return get_random_inchikeys(train_inchikey_list, 174 | train_val_test_split_fractions) 175 | else: 176 | # Assume that splitting_type is the name of a structure family. 177 | # get_inchikeys_by_family will throw an error if this is not supported. 178 | return get_inchikeys_by_family( 179 | train_inchikey_list, 180 | train_inchikey_dict, 181 | train_val_test_split_fractions, 182 | family_name=splitting_type, 183 | exclude_from_train=True) 184 | 185 | 186 | def make_mol_list_from_inchikey_dict(inchikey_dict, inchikey_list): 187 | """Return a list of rdkit.Mols given a list of inchikeys. 188 | 189 | Args: 190 | inchikey_dict : a dict of lists of rdkit.Mol objects keyed by inchikey 191 | inchikey_list : List of inchikeys of molecules we want in a list. 
192 | Returns: 193 | A list of rdkit.Mols corresponding to inchikeys in inchikey_list. 194 | """ 195 | mol_list = [] 196 | for inchikey in inchikey_list: 197 | mol_list.extend(inchikey_dict[inchikey]) 198 | 199 | return mol_list 200 | -------------------------------------------------------------------------------- /molecule_estimator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for molecule_estimator.""" 16 | 17 | from __future__ import print_function 18 | import json 19 | import os 20 | import tempfile 21 | 22 | 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | 26 | import dataset_setup_constants as ds_constants 27 | import mass_spec_constants as ms_constants 28 | import molecule_estimator 29 | import molecule_predictors 30 | import parse_sdf_utils 31 | import plot_spectra_utils 32 | import test_utils 33 | import numpy as np 34 | import tensorflow as tf 35 | 36 | 37 | class MoleculeEstimatorTest(tf.test.TestCase, parameterized.TestCase): 38 | 39 | def setUp(self): 40 | """Sets up a dataset json for regular, baseline, and all_predicted cases.""" 41 | super(MoleculeEstimatorTest, self).setUp() 42 | self.test_data_directory = test_utils.test_dir('testdata/') 43 | record_file = os.path.join(self.test_data_directory, 'test_14_record.gz') 44 | 45 | self.num_eval_examples = parse_sdf_utils.parse_info_file(record_file)[ 46 | 'num_examples'] 47 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir()) 48 | self.default_dataset_config_file = os.path.join(self.temp_dir, 49 | 'dataset_config.json') 50 | self.baseline_dataset_config_file = os.path.join( 51 | self.temp_dir, 'baseline_dataset_config.json') 52 | self.all_predicted_dataset_config_file = os.path.join( 53 | self.temp_dir, 'all_predicted_dataset_config.json') 54 | 55 | dataset_names = [ 56 | ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY, 57 | ds_constants.SPECTRUM_PREDICTION_TEST_KEY, 58 | ds_constants.LIBRARY_MATCHING_OBSERVED_KEY, 59 | ds_constants.LIBRARY_MATCHING_PREDICTED_KEY, 60 | ds_constants.LIBRARY_MATCHING_QUERY_KEY 61 | ] 62 | 63 | default_dataset_config = {key: [record_file] for key in dataset_names} 64 | default_dataset_config[ 65 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = os.path.join( 66 | self.test_data_directory, 'test_14.spectra_library.npy') 67 | with tf.gfile.Open(self.default_dataset_config_file, 'w') as f: 68 | json.dump(default_dataset_config, f) 69 | 70 | # Test estimator behavior when predicted set is empty 71 | baseline_dataset_config = dict( 72 | [(key, [record_file]) 73 | if key != ds_constants.LIBRARY_MATCHING_PREDICTED_KEY else (key, []) 74 | for key in dataset_names]) 75 | baseline_dataset_config[ 76 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = os.path.join( 77 | self.test_data_directory, 'test_14.spectra_library.npy') 78 | with tf.gfile.Open(self.baseline_dataset_config_file, 'w') 
as f: 79 | json.dump(baseline_dataset_config, f) 80 | 81 | # Test estimator behavior when observed set is empty 82 | all_predicted_dataset_config = dict( 83 | [(key, [record_file]) 84 | if key != ds_constants.LIBRARY_MATCHING_OBSERVED_KEY else (key, []) 85 | for key in dataset_names]) 86 | all_predicted_dataset_config[ 87 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = os.path.join( 88 | self.test_data_directory, 'test_14.spectra_library.npy') 89 | with tf.gfile.Open(self.all_predicted_dataset_config_file, 'w') as f: 90 | json.dump(all_predicted_dataset_config, f) 91 | 92 | def tearDown(self): 93 | tf.gfile.DeleteRecursively(self.temp_dir) 94 | super(MoleculeEstimatorTest, self).tearDown() 95 | 96 | def _get_loss_history(self, checkpoint_dir): 97 | """Get list of train losses from events file.""" 98 | losses = [] 99 | for event_file in tf.gfile.Glob( 100 | os.path.join(checkpoint_dir, 'events.out.tfevents.*')): 101 | for event in tf.train.summary_iterator(event_file): 102 | for v in event.summary.value: 103 | if v.tag == 'loss': 104 | losses.append(v.simple_value) 105 | return losses 106 | 107 | def _run_estimator(self, prediction_helper, get_hparams, dataset_config_file): 108 | """Helper function for running molecule_estimator.""" 109 | checkpoint_dir = self.temp_dir 110 | config = tf.contrib.learn.RunConfig( 111 | model_dir=checkpoint_dir, save_summary_steps=1) 112 | (estimator, train_spec, 113 | eval_spec) = molecule_estimator.make_estimator_and_inputs( 114 | config, 115 | get_hparams(), 116 | prediction_helper, 117 | dataset_config_file, 118 | train_steps=10, 119 | model_dir=self.temp_dir) 120 | 121 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 122 | loss_history = self._get_loss_history(checkpoint_dir) 123 | 124 | init_loss = loss_history[0] 125 | loss = loss_history[-1] 126 | if not np.isfinite(loss): 127 | raise ValueError('Final loss is not finite: %f' % loss) 128 | 129 | tf.logging.info('initial loss : {} final loss : {}'.format(init_loss, loss)) 130 | 131 | self.assertNotEqual(loss, init_loss, 132 | ('Loss did not change after brief testing: ' 133 | 'init = %f, final = %f.') % (init_loss, loss)) 134 | 135 | @parameterized.parameters( 136 | ('linear', 0, 'generalized_mse', False), 137 | ('mlp', 1, 'generalized_mse', True, True), 138 | ('linear', 0, 'cross_entropy', True), 139 | ('mlp', 2, 'generalized_mse', False), 140 | ('linear', 0, 'max_margin', True), 141 | ('smiles_rnn', 0, 'generalized_mse', True, True)) 142 | def test_run_estimator(self, model_type, num_hidden_layers, loss_type, 143 | do_library_matching, bidirectional_prediction=False): 144 | """Integration test for molecule_estimator.""" 145 | prediction_helper = molecule_predictors.get_prediction_helper(model_type) 146 | 147 | def get_hparams(): 148 | hparams = prediction_helper.get_default_hparams() 149 | hparams.set_hparam('loss', loss_type) 150 | hparams.set_hparam('do_library_matching', do_library_matching) 151 | hparams.set_hparam('bidirectional_prediction', bidirectional_prediction) 152 | 153 | # To test batching and padding in library matching, set the 154 | # eval_batch_size such that it does not divide the number of examples 155 | # in the test set. 
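      # With the 12-example test record this yields eval_batch_size = 5, and 12 % 5 != 0, so evaluation also exercises a final partial batch.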
156 | eval_batch_size = np.int32(np.floor(self.num_eval_examples / 2) - 1) 157 | assert eval_batch_size > 0, ('The evaluation data is not big enough to ' 158 | 'support using multiple batches, where the ' 159 | 'batch size does not divide the total ' 160 | 'number of examples.') 161 | hparams.set_hparam('eval_batch_size', eval_batch_size) 162 | 163 | if model_type == 'mlp': 164 | hparams.set_hparam('num_hidden_layers', num_hidden_layers) 165 | return hparams 166 | 167 | self._run_estimator(prediction_helper, get_hparams, 168 | self.default_dataset_config_file) 169 | 170 | def test_run_estimator_on_baseline(self): 171 | prediction_helper = molecule_predictors.get_prediction_helper('baseline') 172 | self._run_estimator(prediction_helper, 173 | prediction_helper.get_default_hparams, 174 | self.baseline_dataset_config_file) 175 | 176 | def test_run_estimator_on_all_predicted(self): 177 | prediction_helper = molecule_predictors.get_prediction_helper('mlp') 178 | self._run_estimator(prediction_helper, 179 | prediction_helper.get_default_hparams, 180 | self.all_predicted_dataset_config_file) 181 | 182 | def test_plot_true_and_predicted_spectra(self): 183 | """Test if plot is successfully generated given two spectra.""" 184 | max_mass_spec_peak_loc = ms_constants.MAX_PEAK_LOC 185 | true_spectra = np.zeros(max_mass_spec_peak_loc) 186 | predicted_spectra = np.zeros(max_mass_spec_peak_loc) 187 | true_spectra[3:6] = 60 188 | predicted_spectra[300] = 999 189 | true_spectra[200] = 780 190 | 191 | test_figure_path_name = os.path.join(self.temp_dir, 'test.png') 192 | generated_plot = plot_spectra_utils.plot_true_and_predicted_spectra( 193 | true_spectra, predicted_spectra, output_filename=test_figure_path_name) 194 | 195 | self.assertEqual( 196 | np.shape(generated_plot), 197 | plot_spectra_utils.SPECTRA_PLOT_DIMENSIONS_RGB) 198 | self.assertTrue(os.path.exists(test_figure_path_name)) 199 | 200 | 201 | if __name__ == '__main__': 202 | tf.test.main() 203 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Some general-purpose helper functions.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | 24 | def _get_ckpt_from_path(path): 25 | ckpt = tf.train.latest_checkpoint(path) 26 | if ckpt is None: 27 | raise ValueError('No checkpoint found in %s' % path) 28 | tf.logging.info('Reading from checkpoint %s', ckpt) 29 | return ckpt 30 | 31 | 32 | def run_graph_and_process_results(ops_to_fetch, 33 | model_checkpoint_path, 34 | process_fetched_values_fn, 35 | feed_dict=None, 36 | logging_frequency=10): 37 | """Run a graph repeatedly and use the fetched values. 
38 | 39 | Args: 40 | ops_to_fetch: a single Tensor or nested structure of Tensors. The graph will 41 | be run, and process_fetched_values_fn will be called, until a 42 | tf.errors.OutOfRangeError is caught. This is thrown when a tf.data.Dataset 43 | runs out of data. 44 | model_checkpoint_path: Path to model checkpoint. If a directory, the most 45 | recent model checkpoint in this directory will be used. 46 | process_fetched_values_fn: A callable, potentially with side-effects, that 47 | takes as input the output of sess.run(ops_to_fetch). 48 | feed_dict: a feed_dict to be included in sess.run calls. 49 | logging_frequency: after this many batches have been processed, a logging 50 | message will be printed. 51 | """ 52 | ckpt = _get_ckpt_from_path(model_checkpoint_path) 53 | saver = tf.train.Saver() 54 | 55 | with tf.Session() as sess: 56 | saver.restore(sess, ckpt) 57 | counter = 0 58 | while True: 59 | try: 60 | fetched_values = sess.run(ops_to_fetch, feed_dict=feed_dict) 61 | process_fetched_values_fn(fetched_values) 62 | counter += 1 63 | if counter % logging_frequency == 0: 64 | tf.logging.info('Total batches processed so far: %d', counter) 65 | except tf.errors.OutOfRangeError: 66 | tf.logging.info('Finished processing data. Processed %d batches', 67 | counter) 68 | break 69 | 70 | 71 | def map_predictor(input_op, predictor_fn, sub_batch_size): 72 | """Wrapper for tf.map_fn to do batched computation within each map step.""" 73 | 74 | num_elements = tf.contrib.framework.nest.flatten(input_op)[0].shape[0].value 75 | 76 | # Only chop the batch dim into sub-batches if the input data is big. 77 | if num_elements < sub_batch_size: 78 | return predictor_fn(input_op) 79 | 80 | pad_amount = -num_elements % sub_batch_size 81 | 82 | def reshape(tensor): 83 | """Reshape into batches of sub-batches.""" 84 | pad_shape = tensor.shape.as_list() 85 | pad_shape[0] = pad_amount 86 | padding = tf.zeros(shape=pad_shape, dtype=tensor.dtype) 87 | tensor = tf.concat([tensor, padding], axis=0) 88 | if tensor.shape[0].value % sub_batch_size != 0: 89 | raise ValueError('Incorrect padding size: %d does not ' 90 | 'divide %d' % (sub_batch_size, tensor.shape[0].value)) 91 | 92 | shape = tensor.shape.as_list() 93 | output_shape = [-1, sub_batch_size] + shape[1:] 94 | return tf.reshape(tensor, shape=output_shape) 95 | 96 | reshaped_inputs = tf.contrib.framework.nest.map_structure(reshape, input_op) 97 | 98 | mapped_prediction = tf.map_fn( 99 | predictor_fn, 100 | reshaped_inputs, 101 | parallel_iterations=1, 102 | back_prop=False, 103 | name=None, 104 | dtype=tf.float32) 105 | 106 | output_shape = [-1] + mapped_prediction.shape.as_list()[2:] 107 | reshaped_output = tf.reshape(mapped_prediction, shape=output_shape) 108 | 109 | # If padding was required for the input data, strip off the output of the 110 | # predictor on this padding. 111 | if pad_amount > 0: 112 | reshaped_output = reshaped_output[0:(-pad_amount), ...] 113 | 114 | return reshaped_output 115 | 116 | 
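# Worked example (illustrative): for a [10, d] input with sub_batch_size=4, map_predictor pads by -10 % 4 = 2 rows to [12, d], reshapes to [3, 4, d], maps predictor_fn over the 3 sub-batches, reshapes back, and strips the 2 padded rows from the output.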
117 | def get_static_shape_without_adding_ops(inputs, fn): 118 | """Get the shape of fn(inputs) without adding ops to the default graph. 119 | 120 | Operationally equivalent to fn(inputs).shape.as_list(), except that no 121 | ops are added to the default graph. 122 | 123 | In order to get the shape of fn(inputs) without adding ops to the graph 124 | we make a new graph, make placeholders with the right shape, construct 125 | fn(placeholders) in that graph, get the shape, and then delete the graph. 126 | 127 | Note that using this function may have unintended consequences if fn() has 128 | side effects. 129 | 130 | Args: 131 | inputs: a (nested) structure where the leaf elements are either Tensors or 132 | None. 133 | fn: a function that can be applied to inputs and returns a single Tensor. 134 | Returns: 135 | A Python list containing the static shape of fn(inputs). 136 | 137 | """ 138 | g = tf.Graph() 139 | with g.as_default(): 140 | def make_placeholder(tensor): 141 | if tensor is None: 142 | return None 143 | else: 144 | return tf.placeholder(shape=tensor.shape, dtype=tensor.dtype) 145 | 146 | placeholders = tf.contrib.framework.nest.map_structure(make_placeholder, 147 | inputs) 148 | output_shape = fn(placeholders).shape.as_list() 149 | 150 | del g 151 | return output_shape 152 | 153 | 154 | def value_op_with_initializer(value_op_fn, init_op_fn): 155 | """Make value_op that gets set by idempotent init_op on first invocation.""" 156 | 157 | init_has_been_run = tf.get_local_variable( 158 | 'has_been_run', 159 | initializer=np.zeros(shape=(), dtype=np.bool), 160 | dtype=tf.bool) 161 | 162 | value_op = value_op_fn() 163 | 164 | def run_init_and_toggle(): 165 | init_op = init_op_fn(value_op) 166 | 167 | with tf.control_dependencies([init_op]): 168 | assign_op = init_has_been_run.assign(True) 169 | 170 | with tf.control_dependencies([assign_op]): 171 | return tf.identity(value_op) 172 | 173 | return tf.cond(init_has_been_run, lambda: value_op, run_init_and_toggle) 174 | 175 | 176 | def scatter_by_anchor_indices(anchor_indices, data, index_shift): 177 | """Shift data such that it is indexed relative to anchor_indices. 178 | 179 | For each row of the data array, we flip it horizontally and then shift it 180 | so that the output at (anchor_index + index_shift) is the leftmost column 181 | of the input. Namely: 182 | 183 | output[i][j] = data[i][anchor_indices[i] - j + index_shift] 184 | 185 | Args: 186 | anchor_indices: [batch_size] int Tensor or np array 187 | data: [batch_size, num_columns]: float Tensor or np array 188 | index_shift: int 189 | Returns: 190 | [batch_size, num_columns] Tensor 191 | """ 192 | anchor_indices = tf.convert_to_tensor(anchor_indices) 193 | data = tf.convert_to_tensor(data) 194 | 195 | num_data_columns = data.shape[-1].value 196 | indices = np.arange(num_data_columns)[np.newaxis, ...] 
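  # For intuition (illustrative): with anchor_indices=[2], index_shift=0, and num_data_columns=4, output[0] = [data[0][2], data[0][1], data[0][0], 0]; shifted indices that fall outside the row are masked out below and contribute zeros.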
197 | shifted_indices = anchor_indices[..., tf.newaxis] - indices + index_shift 198 | valid_indices = shifted_indices >= 0 199 | 200 | batch_size = tf.shape(data)[0] 201 | 202 | batch_indices = tf.tile( 203 | tf.range(batch_size)[..., tf.newaxis], [1, num_data_columns]) 204 | shifted_indices += batch_indices * num_data_columns 205 | 206 | shifted_indices = tf.reshape(shifted_indices, [-1]) 207 | num_elements = tf.shape(data)[0] * tf.shape(data)[1] 208 | row_indices = tf.range(num_elements) 209 | stacked_indices = tf.stack([row_indices, shifted_indices], axis=1) 210 | 211 | lower_batch_boundaries = tf.reshape(batch_indices * num_data_columns, [-1]) 212 | upper_batch_boundaries = tf.reshape(((batch_indices + 1) * num_data_columns), 213 | [-1]) 214 | valid_indices = tf.logical_and(shifted_indices >= lower_batch_boundaries, 215 | shifted_indices < upper_batch_boundaries) 216 | stacked_indices = tf.boolean_mask( 217 | stacked_indices, 218 | valid_indices, 219 | ) 220 | 221 | dense_shape = tf.cast(tf.tile(num_elements[..., tf.newaxis], [2]), tf.int64) 222 | 223 | scattering_matrix = tf.SparseTensor( 224 | indices=tf.cast(stacked_indices, tf.int64), 225 | values=tf.ones_like(stacked_indices[:, 0], dtype=data.dtype), 226 | dense_shape=dense_shape) 227 | 228 | flattened_data = tf.reshape(data, [-1])[..., tf.newaxis] 229 | flattened_output = tf.sparse_tensor_dense_matmul( 230 | scattering_matrix, 231 | flattened_data, 232 | adjoint_a=False, 233 | adjoint_b=False, 234 | name=None) 235 | 236 | return tf.reshape( 237 | tf.transpose(flattened_output, [0, 1]), [-1, num_data_columns]) 238 | -------------------------------------------------------------------------------- /spectra_predictor.py: -------------------------------------------------------------------------------- 1 | """Helpers for generating spectra prediction from trained models.""" 2 | 3 | import abc 4 | 5 | import feature_map_constants as fmap_constants 6 | import feature_utils 7 | import mass_spec_constants as ms_constants 8 | import molecule_predictors 9 | 10 | import numpy as np 11 | from rdkit import Chem 12 | from rdkit.Chem import AllChem 13 | import six 14 | import tensorflow as tf 15 | 16 | _DEFAULT_HPARAMS = { 17 | "radius": 2, 18 | "mass_power": 1.0, 19 | "gate_bidirectional_predictions": True, 20 | "include_atom_mass": True, 21 | "init_bias": "default", 22 | "reverse_prediction": True, 23 | "max_mass_spec_peak_loc": 1000, 24 | "num_hidden_units": 2000, 25 | "use_counting_fp": True, 26 | "max_atoms": 100, 27 | "intensity_power": 0.5, 28 | "max_prediction_above_molecule_mass": 5, 29 | "fp_length": 4096, 30 | "bidirectional_prediction": True, 31 | "resnet_bottleneck_factor": 0.5, 32 | "max_atom_type": 100, 33 | "hidden_layer_activation": "relu", 34 | "init_weights": "default", 35 | "num_hidden_layers": 7 36 | } 37 | _DEFAULT_HPARAMS_STR = ",".join( 38 | "{}={}".format(k, v) for k, v in six.iteritems(_DEFAULT_HPARAMS)) 39 | PREDICTED_SPECTRA_PROP_NAME = "PREDICTED SPECTRUM" 40 | 41 | # Predictions from the model are normalized by default. 42 | # This factor is used to rescale the predictions so the highest intensity has 43 | # this value. 44 | SCALE_FACTOR_FOR_LARGEST_INTENSITY = 999. 
45 | 46 | 47 | def fingerprints_to_use(hparams): 48 | """Given tf.HParams, return a ms_constants.CircularFingerprintKey.""" 49 | if hparams.use_counting_fp: 50 | key = fmap_constants.COUNTING_CIRCULAR_FP_BASENAME 51 | else: 52 | key = fmap_constants.CIRCULAR_FP_BASENAME 53 | 54 | return ms_constants.CircularFingerprintKey(key, hparams.fp_length, 55 | hparams.radius) 56 | 57 | 58 | def get_mol_weights_from_mol_list(mol_list): 59 | """Given a list of rdkit.Mols, return weights for each mol.""" 60 | return np.array([Chem.rdMolDescriptors.CalcExactMolWt(m) for m in mol_list]) 61 | 62 | 63 | def get_mol_list_from_sdf(sdf_fname): 64 | """Reads an sdf file and returns a list of molecules. 65 | 66 | Note: rdkit's Chem.SDMolSupplier only accepts filenames as inputs. As such, 67 | this code only supports files on the local filesystem. 68 | 69 | Args: 70 | sdf_fname: Path to sdf file. 71 | 72 | Returns: 73 | List of rdkit.Mol objects. 74 | 75 | Raises: 76 | ValueError: if a molblock in the SDF cannot be parsed. 77 | """ 78 | suppl = Chem.SDMolSupplier(sdf_fname) 79 | mols = [] 80 | 81 | for idx, mol in enumerate(suppl): 82 | if mol is not None: 83 | mols.append(mol) 84 | else: 85 | fail_sdf_block = suppl.GetItemText(idx) 86 | raise ValueError("Unable to parse the following mol block %s" % 87 | fail_sdf_block) 88 | return mols 89 | 90 | 91 | def update_mols_with_spectra(mol_list, spectra_array): 92 | """Writes a predicted spectrum for each RDKit.mol object. 93 | 94 | Args: 95 | mol_list: List of rdkit.Mol objects. 96 | spectra_array: np.array of spectra. 97 | 98 | Returns: 99 | Updated list of rdkit.Mol objects where each molecule contains a predicted 100 | spectrum. 101 | """ 102 | if len(mol_list) != np.shape(spectra_array)[0]: 103 | raise ValueError("Number of mols in mol list %d is not equal to number of " 104 | "spectra found %d." % 105 | (len(mol_list), np.shape(spectra_array)[0])) 106 | for mol, spectrum in zip(mol_list, spectra_array): 107 | spec_array_text = feature_utils.convert_spectrum_array_to_string(spectrum) 108 | mol.SetProp(PREDICTED_SPECTRA_PROP_NAME, spec_array_text) 109 | return mol_list 110 | 111 | 112 | def write_rdkit_mols_to_sdf(mol_list, out_sdf_name): 113 | """Writes a series of rdkit.Mol to SDF. 114 | 115 | Args: 116 | mol_list: List of rdkit.Mol objects. 117 | out_sdf_name: Output file path for molecules. 118 | """ 119 | writer = AllChem.SDWriter(out_sdf_name) 120 | 121 | for mol in mol_list: 122 | writer.write(mol) 123 | writer.close() 124 | 125 | 
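# Example usage sketch (illustrative; the checkpoint and output paths are hypothetical): mols = get_mol_list_from_sdf('examples/pentachlorobenzene.sdf'); predictor = NeimsSpectraPredictor(model_checkpoint_dir='/tmp/neims_model'); fps, weights = predictor.get_inputs_for_model_from_mol_list(mols); spectra = predictor.make_spectra_prediction(fps, weights); write_rdkit_mols_to_sdf(update_mols_with_spectra(mols, spectra), '/tmp/annotated.sdf')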
126 | class SpectraPredictor(object): 127 | """Helper for generating a computational graph for making predictions.""" 128 | __metaclass__ = abc.ABCMeta 129 | 130 | def __init__(self, hparams_str=""): 131 | """Sets up graph, session, and input and output ops for prediction. 132 | 133 | Args: 134 | hparams_str (str): String containing hyperparameter settings. 135 | """ 136 | 137 | self._prediction_helper = molecule_predictors.get_prediction_helper("mlp") 138 | self._hparams = self._prediction_helper.get_default_hparams() 139 | self._hparams.parse(hparams_str) 140 | self._fingerprint_key = fingerprints_to_use(self._hparams) 141 | self.fingerprint_input_key = str(self._fingerprint_key) 142 | self.molecular_weight_key = fmap_constants.MOLECULE_WEIGHT 143 | 144 | self._graph = tf.Graph() 145 | self._sess = tf.Session(graph=self._graph) 146 | with self._graph.as_default(): 147 | (self._placeholder_dict, self._predict_op) = self._setup_prediction_op() 148 | assert set(self._placeholder_dict) == set( 149 | [self.fingerprint_input_key, self.molecular_weight_key]) 150 | 151 | @abc.abstractmethod 152 | def _setup_prediction_op(self): 153 | """Sets up prediction operation. 154 | 155 | Returns: 156 | placeholder_dict: Dict with self.fingerprint_input_key and 157 | self.molecular_weight_key as keys and values which are tf.placeholder 158 | for predicted spectra. 159 | predict_op: tf.Tensor for predicted spectra. 160 | """ 161 | 162 | def make_spectra_prediction(self, fingerprint_array, molecule_weight_array): 163 | """Makes spectra prediction. 164 | 165 | Args: 166 | fingerprint_array (np.array): Contains molecule fingerprints. 167 | molecule_weight_array (np.array): Contains molecular weights. Should have 168 | same batch dimension as fingerprint_array. 169 | 170 | Returns: 171 | np.array of predictions. 172 | """ 173 | molecule_weight_array = np.reshape(molecule_weight_array, (-1, 1)) 174 | with self._graph.as_default(): 175 | prediction = self._sess.run( 176 | self._predict_op, 177 | feed_dict={ 178 | self._placeholder_dict[self.fingerprint_input_key]: 179 | fingerprint_array, 180 | self._placeholder_dict[self.molecular_weight_key]: 181 | molecule_weight_array 182 | }) 183 | 184 | prediction = prediction / np.max( 185 | prediction, axis=1, keepdims=True) * SCALE_FACTOR_FOR_LARGEST_INTENSITY 186 | return prediction 187 | 188 | def get_fingerprints_from_mol_list(self, mol_list): 189 | """Converts a list of rdkit.Mol objects into circular fingerprints. 190 | 191 | Args: 192 | mol_list: List of rdkit.Mol objects. 193 | 194 | Returns: 195 | np.array of fingerprints for prediction. 196 | """ 197 | 198 | fingerprints = [ 199 | feature_utils.make_circular_fingerprint(mol, self._fingerprint_key) 200 | for mol in mol_list 201 | ] 202 | 203 | return np.array(fingerprints) 204 | 205 | def get_inputs_for_model_from_mol_list(self, mol_list): 206 | """Grabs fingerprints and molecular weights for the prediction model.""" 207 | fingerprints = self.get_fingerprints_from_mol_list(mol_list) 208 | weights = get_mol_weights_from_mol_list(mol_list) 209 | return fingerprints, weights 210 | 211 | 212 | class NeimsSpectraPredictor(SpectraPredictor): 213 | """Helper for making spectra predictions using the trained NEIMS model.""" 214 | 215 | def __init__(self, model_checkpoint_dir, hparams_str=_DEFAULT_HPARAMS_STR): 216 | """Initializes the predictor with the weights and hyperparameters. 217 | 218 | Args: 219 | model_checkpoint_dir (str): Path to checkpoint weights. 220 | hparams_str (str): String that contains hyperparameters for model. 
221 | """ 222 | super(NeimsSpectraPredictor, self).__init__(hparams_str) 223 | self.restore_from_checkpoint(model_checkpoint_dir) 224 | 225 | def _setup_prediction_op(self): 226 | """Sets up prediction operation and inputs for model.""" 227 | fp_length = self._hparams.fp_length 228 | 229 | fingerprint_input_op = tf.placeholder(tf.float32, (None, fp_length)) 230 | mol_weight_input_op = tf.placeholder(tf.float32, (None, 1)) 231 | 232 | feature_dict = { 233 | self.fingerprint_input_key: fingerprint_input_op, 234 | self.molecular_weight_key: mol_weight_input_op 235 | } 236 | 237 | predict_op, _ = self._prediction_helper.make_prediction_ops( 238 | feature_dict, 239 | self._hparams, 240 | mode=tf.estimator.ModeKeys.PREDICT, 241 | reuse=False) 242 | 243 | return feature_dict, predict_op 244 | 245 | def restore_from_checkpoint(self, model_checkpoint_dir): 246 | """Restores model parameters from checkpoint directory. 247 | 248 | Args: 249 | model_checkpoint_dir (str): Path to the directory containing checkpoint 250 | weights. If empty, the model will be initialized with random weights. 251 | """ 252 | with self._graph.as_default(): 253 | if model_checkpoint_dir: 254 | saver = tf.train.Saver() 255 | saver.restore(self._sess, 256 | tf.train.latest_checkpoint(model_checkpoint_dir)) 257 | else: 258 | tf.logging.warn("No model checkpoint directory given," 259 | " reinitializing model.") 260 | self._sess.run(tf.global_variables_initializer()) 261 | -------------------------------------------------------------------------------- /dataset_setup_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Constants for creating dataset for mass spectrometry experiment.""" 16 | 17 | from typing import NamedTuple 18 | 19 | # List of components of the main NIST library and the replicates library. 20 | # The main NIST library is divided into a train/validation/test set. 21 | # The replicates library is divided into validation/test sets. 22 | MAINLIB_TRAIN_BASENAME = 'mainlib_train' 23 | MAINLIB_VALIDATION_BASENAME = 'mainlib_validation' 24 | MAINLIB_TEST_BASENAME = 'mainlib_test' 25 | REPLICATES_TRAIN_BASENAME = 'replicates_train' 26 | REPLICATES_VALIDATION_BASENAME = 'replicates_validation' 27 | REPLICATES_TEST_BASENAME = 'replicates_test' 28 | 29 | # Key names of main datasets required for each experiment. 
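# These keys also form the schema of the dataset config json files consumed by molecule_estimator; see testdata/test_dataset_config_file.json for a small example.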
30 | SPECTRUM_PREDICTION_TRAIN_KEY = 'SPECTRUM_PREDICTION_TRAIN' 31 | SPECTRUM_PREDICTION_TEST_KEY = 'SPECTRUM_PREDICTION_TEST' 32 | LIBRARY_MATCHING_OBSERVED_KEY = 'LIBRARY_MATCHING_OBSERVED' 33 | LIBRARY_MATCHING_PREDICTED_KEY = 'LIBRARY_MATCHING_PREDICTED' 34 | LIBRARY_MATCHING_QUERY_KEY = 'LIBRARY_MATCHING_QUERY' 35 | 36 | TRAINING_SPECTRA_ARRAY_KEY = MAINLIB_TRAIN_BASENAME + '_spectra_library_file' 37 | 38 | 39 | class ExperimentSetup( 40 | NamedTuple('ExperimentSetup', 41 | [('json_name', str), ('data_to_get_from_mainlib', list), 42 | ('data_to_get_from_replicates', list), 43 | ('experiment_setup_dataset_dict', dict)])): 44 | """Stores information related to the various experiment setups. 45 | 46 | Attributes: 47 | json_name: name of the json file to store locations of the datasets 48 | data_to_get_from_mainlib: List of dataset keys to grab from the mainlib 49 | library. 50 | data_to_get_from_replicates: List of dataset keys to grab from the 51 | replicates library. 52 | experiment_setup_dataset_dict: Dict which matches the experiment keys to 53 | lists of the component datasets, matching the basename keys above. 54 | """ 55 | 56 | def __new__(cls, json_name, data_to_get_from_mainlib, 57 | data_to_get_from_replicates, experiment_setup_dataset_dict): 58 | assert (experiment_setup_dataset_dict[LIBRARY_MATCHING_QUERY_KEY] == 59 | experiment_setup_dataset_dict[SPECTRUM_PREDICTION_TEST_KEY]), ( 60 | 'In json {}, library query list did not match' 61 | ' spectrum prediction list.'.format(json_name)) 62 | assert (experiment_setup_dataset_dict[SPECTRUM_PREDICTION_TRAIN_KEY] == 63 | [MAINLIB_TRAIN_BASENAME]), ( 64 | 'In json {}, spectra prediction dataset is not mainlib_train,' 65 | ' which is currently not supported.'.format(json_name)) 66 | return super(ExperimentSetup, cls).__new__( 67 | cls, json_name, data_to_get_from_mainlib, data_to_get_from_replicates, 68 | experiment_setup_dataset_dict) 69 | 70 | 71 | # Experiment setups: 72 | QUERY_MAINLIB_VAL_PRED_MAINLIB_VAL = ExperimentSetup( 73 | 'query_mainlib_val_predicted_mainlib_val.json', [ 74 | SPECTRUM_PREDICTION_TRAIN_KEY, 75 | SPECTRUM_PREDICTION_TEST_KEY, 76 | LIBRARY_MATCHING_QUERY_KEY, 77 | LIBRARY_MATCHING_OBSERVED_KEY, 78 | LIBRARY_MATCHING_PREDICTED_KEY, 79 | ], [], { 80 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 81 | SPECTRUM_PREDICTION_TEST_KEY: [MAINLIB_VALIDATION_BASENAME], 82 | LIBRARY_MATCHING_OBSERVED_KEY: [ 83 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME, 84 | REPLICATES_TRAIN_BASENAME, REPLICATES_VALIDATION_BASENAME, 85 | REPLICATES_TEST_BASENAME 86 | ], 87 | LIBRARY_MATCHING_PREDICTED_KEY: [MAINLIB_VALIDATION_BASENAME], 88 | LIBRARY_MATCHING_QUERY_KEY: [MAINLIB_VALIDATION_BASENAME], 89 | }) 90 | 91 | QUERY_REPLICATES_VAL_PRED_REPLICATES_VAL = ExperimentSetup( 92 | 'query_replicates_val_predicted_replicates_val.json', [ 93 | SPECTRUM_PREDICTION_TRAIN_KEY, 94 | SPECTRUM_PREDICTION_TEST_KEY, 95 | LIBRARY_MATCHING_OBSERVED_KEY, 96 | LIBRARY_MATCHING_PREDICTED_KEY, 97 | ], [LIBRARY_MATCHING_QUERY_KEY], { 98 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 99 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_VALIDATION_BASENAME], 100 | LIBRARY_MATCHING_OBSERVED_KEY: [ 101 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME, 102 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME, 103 | REPLICATES_TEST_BASENAME 104 | ], 105 | LIBRARY_MATCHING_PREDICTED_KEY: [REPLICATES_VALIDATION_BASENAME], 106 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_VALIDATION_BASENAME], 107 | }) 108 | 109 | 
QUERY_REPLICATES_TEST_PRED_REPLICATES_TEST = ExperimentSetup( 110 | 'query_replicates_test_predicted_replicates_test.json', [ 111 | SPECTRUM_PREDICTION_TRAIN_KEY, 112 | SPECTRUM_PREDICTION_TEST_KEY, 113 | LIBRARY_MATCHING_OBSERVED_KEY, 114 | LIBRARY_MATCHING_PREDICTED_KEY, 115 | ], [LIBRARY_MATCHING_QUERY_KEY], { 116 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 117 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_TEST_BASENAME], 118 | LIBRARY_MATCHING_OBSERVED_KEY: [ 119 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME, 120 | REPLICATES_TRAIN_BASENAME, MAINLIB_VALIDATION_BASENAME, 121 | REPLICATES_VALIDATION_BASENAME 122 | ], 123 | LIBRARY_MATCHING_PREDICTED_KEY: [REPLICATES_TEST_BASENAME], 124 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_TEST_BASENAME], 125 | }) 126 | 127 | QUERY_REPLICATES_VAL_PRED_NONE = ExperimentSetup( 128 | 'query_replicates_val_predicted_none.json', [ 129 | SPECTRUM_PREDICTION_TRAIN_KEY, 130 | SPECTRUM_PREDICTION_TEST_KEY, 131 | LIBRARY_MATCHING_OBSERVED_KEY, 132 | LIBRARY_MATCHING_PREDICTED_KEY, 133 | ], [LIBRARY_MATCHING_QUERY_KEY], { 134 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 135 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_VALIDATION_BASENAME], 136 | LIBRARY_MATCHING_OBSERVED_KEY: [ 137 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME, 138 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME, 139 | REPLICATES_TEST_BASENAME, REPLICATES_VALIDATION_BASENAME 140 | ], 141 | LIBRARY_MATCHING_PREDICTED_KEY: [], 142 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_VALIDATION_BASENAME], 143 | }) 144 | 145 | QUERY_REPLICATES_TEST_PRED_NONE = ExperimentSetup( 146 | 'query_replicates_test_predicted_none.json', [ 147 | SPECTRUM_PREDICTION_TRAIN_KEY, 148 | SPECTRUM_PREDICTION_TEST_KEY, 149 | LIBRARY_MATCHING_OBSERVED_KEY, 150 | LIBRARY_MATCHING_PREDICTED_KEY, 151 | ], [LIBRARY_MATCHING_QUERY_KEY], { 152 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 153 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_TEST_BASENAME], 154 | LIBRARY_MATCHING_OBSERVED_KEY: [ 155 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME, 156 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME, 157 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME 158 | ], 159 | LIBRARY_MATCHING_PREDICTED_KEY: [], 160 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_TEST_BASENAME], 161 | }) 162 | 163 | QUERY_REPLICATES_ALL_PRED_NONE = ExperimentSetup( 164 | 'query_replicates_all_predicted_none.json', [ 165 | SPECTRUM_PREDICTION_TRAIN_KEY, 166 | SPECTRUM_PREDICTION_TEST_KEY, 167 | LIBRARY_MATCHING_OBSERVED_KEY, 168 | LIBRARY_MATCHING_PREDICTED_KEY, 169 | ], [LIBRARY_MATCHING_QUERY_KEY], { 170 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 171 | SPECTRUM_PREDICTION_TEST_KEY: [ 172 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME 173 | ], 174 | LIBRARY_MATCHING_OBSERVED_KEY: [ 175 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME, 176 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME, 177 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME 178 | ], 179 | LIBRARY_MATCHING_PREDICTED_KEY: [], 180 | LIBRARY_MATCHING_QUERY_KEY: [ 181 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME 182 | ], 183 | }) 184 | 185 | # An overfitting setup for sanity checks 186 | QUERY_MAINLIB_TRAIN_PRED_MAINLIB_TRAIN = ExperimentSetup( 187 | 'query_mainlib_train_predicted_mainlib_train.json', [ 188 | SPECTRUM_PREDICTION_TRAIN_KEY, SPECTRUM_PREDICTION_TEST_KEY, 189 | LIBRARY_MATCHING_OBSERVED_KEY, LIBRARY_MATCHING_PREDICTED_KEY, 190 | LIBRARY_MATCHING_QUERY_KEY 191 | ], [], { 
192 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME], 193 | SPECTRUM_PREDICTION_TEST_KEY: [MAINLIB_TRAIN_BASENAME], 194 | LIBRARY_MATCHING_OBSERVED_KEY: [ 195 | MAINLIB_VALIDATION_BASENAME, 196 | MAINLIB_TEST_BASENAME, 197 | REPLICATES_TRAIN_BASENAME, 198 | REPLICATES_VALIDATION_BASENAME, 199 | REPLICATES_TEST_BASENAME, 200 | ], 201 | LIBRARY_MATCHING_PREDICTED_KEY: [MAINLIB_TRAIN_BASENAME], 202 | LIBRARY_MATCHING_QUERY_KEY: [MAINLIB_TRAIN_BASENAME], 203 | }) 204 | 205 | EXPERIMENT_SETUPS_LIST = [ 206 | QUERY_MAINLIB_VAL_PRED_MAINLIB_VAL, 207 | QUERY_REPLICATES_VAL_PRED_REPLICATES_VAL, 208 | QUERY_REPLICATES_TEST_PRED_REPLICATES_TEST, 209 | QUERY_REPLICATES_VAL_PRED_NONE, 210 | QUERY_REPLICATES_TEST_PRED_NONE, 211 | QUERY_MAINLIB_TRAIN_PRED_MAINLIB_TRAIN, 212 | ] 213 | -------------------------------------------------------------------------------- /make_train_test_split_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for make_train_test_split.py and train_test_split_utils.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import json 21 | import os 22 | import tempfile 23 | 24 | from absl.testing import absltest 25 | from absl.testing import parameterized 26 | 27 | import dataset_setup_constants as ds_constants 28 | import feature_map_constants as fmap_constants 29 | import make_train_test_split 30 | import mass_spec_constants as ms_constants 31 | import parse_sdf_utils 32 | import test_utils 33 | import train_test_split_utils 34 | import six 35 | import tensorflow as tf 36 | 37 | 38 | class MakeTrainTestSplitTest(tf.test.TestCase, parameterized.TestCase): 39 | 40 | def setUp(self): 41 | super(MakeTrainTestSplitTest, self).setUp() 42 | test_data_directory = test_utils.test_dir('testdata/') 43 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir()) 44 | test_sdf_file_large = os.path.join(test_data_directory, 'test_14_mend.sdf') 45 | test_sdf_file_small = os.path.join(test_data_directory, 'test_2_mend.sdf') 46 | 47 | max_atoms = ms_constants.MAX_ATOMS 48 | self.mol_list_large = parse_sdf_utils.get_sdf_to_mol( 49 | test_sdf_file_large, max_atoms=max_atoms) 50 | self.mol_list_small = parse_sdf_utils.get_sdf_to_mol( 51 | test_sdf_file_small, max_atoms=max_atoms) 52 | self.inchikey_dict_large = train_test_split_utils.make_inchikey_dict( 53 | self.mol_list_large) 54 | self.inchikey_dict_small = train_test_split_utils.make_inchikey_dict( 55 | self.mol_list_small) 56 | self.inchikey_list_large = list(self.inchikey_dict_large.keys()) 57 | self.inchikey_list_small = list(self.inchikey_dict_small.keys()) 58 | 59 | def tearDown(self): 60 | tf.gfile.DeleteRecursively(self.temp_dir) 61 | super(MakeTrainTestSplitTest, self).tearDown() 62 | 63 | def encode(self, value): 64 | """Wrapper function for encoding strings 
in Python 3.""" 65 | return test_utils.encode(value, six.PY3) 66 | 67 | def test_all_lists_mutually_exclusive(self): 68 | list1 = ['1', '2', '3'] 69 | list2 = ['2', '3', '4'] 70 | # assert_all_lists_mutally_exclusive should raise ValueError when two of 71 | # the lists share elements. 72 | with self.assertRaises(ValueError): 73 | train_test_split_utils.assert_all_lists_mutally_exclusive( 74 | [list1, list2]) 75 | 76 | def test_make_inchikey_dict(self): 77 | self.assertLen(self.inchikey_dict_large, 11) 78 | self.assertLen(self.inchikey_dict_small, 2) 79 | 80 | def test_make_mol_list_from_inchikey_dict(self): 81 | mol_list = train_test_split_utils.make_mol_list_from_inchikey_dict( 82 | self.inchikey_dict_large, self.inchikey_list_large) 83 | self.assertCountEqual(mol_list, self.mol_list_large) 84 | 85 | def test_make_train_val_test_split_mol_lists(self): 86 | main_train_test_split = train_test_split_utils.TrainValTestFractions( 87 | 0.5, 0.25, 0.25) 88 | 89 | inchikey_list_of_lists = ( 90 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 91 | self.inchikey_list_large, self.inchikey_dict_large, 92 | main_train_test_split)) 93 | 94 | expected_lengths_of_inchikey_lists = [5, 2, 4] 95 | 96 | for expected_length, inchikey_list in zip( 97 | expected_lengths_of_inchikey_lists, inchikey_list_of_lists): 98 | self.assertLen(inchikey_list, expected_length) 99 | 100 | train_test_split_utils.assert_all_lists_mutally_exclusive( 101 | inchikey_list_of_lists) 102 | 103 | trunc_inchikey_list_large = self.inchikey_list_large[:6] 104 | inchikey_list_of_lists = ( 105 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 106 | trunc_inchikey_list_large, self.inchikey_dict_large, 107 | main_train_test_split)) 108 | # With 6 inchikeys and fractions (0.5, 0.25, 0.25), expect sizes 3/1/2. 109 | 110 | expected_lengths_of_inchikey_lists = [3, 1, 2] 111 | for expected_length, inchikey_list in zip( 112 | expected_lengths_of_inchikey_lists, inchikey_list_of_lists): 113 | self.assertLen(inchikey_list, expected_length) 114 | 115 | train_test_split_utils.assert_all_lists_mutally_exclusive( 116 | inchikey_list_of_lists) 117 | 118 | def test_make_train_val_test_split_mol_lists_holdout(self): 119 | main_train_test_split = train_test_split_utils.TrainValTestFractions( 120 | 0.5, 0.25, 0.25) 121 | holdout_inchikey_list_of_lists = ( 122 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 123 | self.inchikey_list_large, 124 | self.inchikey_dict_large, 125 | main_train_test_split, 126 | holdout_inchikey_list=self.inchikey_list_small)) 127 | 128 | expected_lengths_of_inchikey_lists = [4, 2, 3] 129 | for expected_length, inchikey_list in zip( 130 | expected_lengths_of_inchikey_lists, holdout_inchikey_list_of_lists): 131 | self.assertLen(inchikey_list, expected_length) 132 | 133 | train_test_split_utils.assert_all_lists_mutally_exclusive( 134 | holdout_inchikey_list_of_lists) 135 | 136 | def test_make_train_val_test_split_mol_lists_family(self): 137 | train_test_split = train_test_split_utils.TrainValTestFractions( 138 | 0.5, 0.25, 0.25) 139 | train_inchikeys, val_inchikeys, test_inchikeys = ( 140 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 141 | self.inchikey_list_large, 142 | self.inchikey_dict_large, 143 | train_test_split, 144 | holdout_inchikey_list=self.inchikey_list_small, 145 | splitting_type='diazo')) 146 | 147 | self.assertCountEqual(train_inchikeys, [ 148 | 'UFHFLCQGNIYNRP-UHFFFAOYSA-N', 'CCGKOQOJPYTBIH-UHFFFAOYSA-N', 149 | 'ASTNYHRQIBTGNO-UHFFFAOYSA-N', 'UFHFLCQGNIYNRP-VVKOMZTBSA-N', 150 | 'PVVBOXUQVSZBMK-UHFFFAOYSA-N' 151 | ]) 152 | 153 | 
self.assertCountEqual(val_inchikeys + test_inchikeys, [ 154 | 'OWKPLCCVKXABQF-UHFFFAOYSA-N', 'COVPJOWITGLAKX-UHFFFAOYSA-N', 155 | 'GKVDXUXIAHWQIK-UHFFFAOYSA-N', 'UCIXUAPVXAZYDQ-VMPITWQZSA-N' 156 | ]) 157 | 158 | replicate_train_inchikeys, _, replicate_test_inchikeys = ( 159 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 160 | self.inchikey_list_small, 161 | self.inchikey_dict_small, 162 | train_test_split, 163 | splitting_type='diazo')) 164 | 165 | self.assertEqual(replicate_train_inchikeys[0], 166 | 'PNYUDNYAXSEACV-RVDMUPIBSA-N') 167 | self.assertEqual(replicate_test_inchikeys[0], 'YXHKONLOYHBTNS-UHFFFAOYSA-N') 168 | 169 | @parameterized.parameters('random', 'diazo') 170 | def test_make_train_test_split(self, splitting_type): 171 | """An integration test on a small dataset.""" 172 | 173 | fpath = self.temp_dir 174 | 175 | # Create component datasets from two library files. 176 | main_train_val_test_fractions = ( 177 | train_test_split_utils.TrainValTestFractions(0.5, 0.25, 0.25)) 178 | replicates_val_test_fractions = ( 179 | train_test_split_utils.TrainValTestFractions(0.0, 0.5, 0.5)) 180 | 181 | (mainlib_inchikey_dict, replicates_inchikey_dict, 182 | component_inchikey_dict) = ( 183 | make_train_test_split.make_mainlib_replicates_train_test_split( 184 | self.mol_list_large, self.mol_list_small, splitting_type, 185 | main_train_val_test_fractions, replicates_val_test_fractions)) 186 | 187 | make_train_test_split.write_mainlib_split_datasets( 188 | component_inchikey_dict, mainlib_inchikey_dict, fpath, 189 | ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC) 190 | 191 | make_train_test_split.write_replicates_split_datasets( 192 | component_inchikey_dict, replicates_inchikey_dict, fpath, 193 | ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC) 194 | 195 | for experiment_setup in ds_constants.EXPERIMENT_SETUPS_LIST: 196 | # Create experiment json files 197 | tf.logging.info('Writing experiment setup for %s', 198 | experiment_setup.json_name) 199 | make_train_test_split.check_experiment_setup( 200 | experiment_setup.experiment_setup_dataset_dict, 201 | component_inchikey_dict) 202 | 203 | make_train_test_split.write_json_for_experiment(experiment_setup, fpath) 204 | 205 | # Check that physical files for library matching contain all inchikeys 206 | dict_from_json = json.load( 207 | tf.gfile.Open(os.path.join(fpath, experiment_setup.json_name))) 208 | 209 | tf.logging.info(dict_from_json) 210 | library_files = ( 211 | dict_from_json[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY] + 212 | dict_from_json[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY]) 213 | library_files = [os.path.join(fpath, fname) for fname in library_files] 214 | 215 | hparams = tf.contrib.training.HParams( 216 | max_atoms=ms_constants.MAX_ATOMS, 217 | max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC, 218 | intensity_power=1.0, 219 | batch_size=5) 220 | 221 | parse_sdf_utils.validate_spectra_array_contents( 222 | os.path.join( 223 | fpath, 224 | dict_from_json[ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY][0]), 225 | hparams, 226 | os.path.join(fpath, 227 | dict_from_json[ds_constants.TRAINING_SPECTRA_ARRAY_KEY])) 228 | 229 | dataset = parse_sdf_utils.get_dataset_from_record( 230 | library_files, 231 | hparams, 232 | mode=tf.estimator.ModeKeys.EVAL, 233 | all_data_in_one_batch=True) 234 | 235 | feature_names = [fmap_constants.INCHIKEY] 236 | label_names = [fmap_constants.ATOM_WEIGHTS] 237 | 238 | features, labels = parse_sdf_utils.make_features_and_labels( 239 | dataset, feature_names, label_names, 
mode=tf.estimator.ModeKeys.EVAL) 240 | 241 | with tf.Session() as sess: 242 | feature_values, _ = sess.run([features, labels]) 243 | 244 | inchikeys_from_file = [ 245 | ikey[0] for ikey in feature_values[fmap_constants.INCHIKEY].tolist() 246 | ] 247 | 248 | length_from_info_file = sum([ 249 | parse_sdf_utils.parse_info_file(library_fname)['num_examples'] 250 | for library_fname in library_files 251 | ]) 252 | # Check that info file has the correct length for the file. 253 | self.assertLen(inchikeys_from_file, length_from_info_file) 254 | # Check that the TF.Record contains all of the inchikeys in our list. 255 | inchikey_list_large = [ 256 | self.encode(ikey) for ikey in self.inchikey_list_large 257 | ] 258 | self.assertSetEqual(set(inchikeys_from_file), set(inchikey_list_large)) 259 | 260 | 261 | if __name__ == '__main__': 262 | tf.test.main() 263 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2018 Google LLC 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /feature_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for getting molecule features from NIST sdf data. 16 | 17 | This module includes functions to help with parsing sdf files and generating 18 | features such as atom weight lists and adjacency matrices. Also contains a 19 | function to parse mass spectra peaks from their string format in the NIST sdf 20 | files. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import feature_map_constants as fmap_constants 28 | import mass_spec_constants as ms_constants 29 | import numpy as np 30 | from rdkit import Chem 31 | from rdkit import DataStructs 32 | from rdkit.Chem import AllChem 33 | 34 | FILTER_DICT = { 35 | 'steroid': 36 | Chem.MolFromSmarts( 37 | '[#6]1~[#6]2~[#6](~[#6]~[#6]~[#6]~1)' 38 | '~[#6]~[#6]~[#6]1~[#6]~2~[#6]~[#6]~[#6]2~[#6]~1~[#6]~[#6]~[#6]~2' 39 | ), 40 | 'diazo': 41 | Chem.MolFromSmarts( 42 | '[#7]~[#7]' 43 | ), 44 | } 45 | 46 | 47 | def get_smiles_string(mol): 48 | """Make canonicalized smiles from rdkit.Mol.""" 49 | return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True) 50 | 51 | 52 | def get_molecular_formula(mol): 53 | """Makes string of molecular formula from rdkit.Mol.""" 54 | return AllChem.CalcMolFormula(mol) 55 | 56 | 57 | def parse_peaks(pk_str): 58 | r"""Helper function for converting peak string into vector form. 59 | 60 | Args: 61 | pk_str : String from NIST MS data of format 62 | "peak1_loc peak1_int\npeak2_loc peak2_int" 63 | 64 | Returns: 65 | A tuple containing two arrays: (peak_locs, peak_ints) 66 | peak_locs : int list of the location of peaks 67 | peak_intensities : float list of the intensity of the peaks 68 | """ 69 | all_peaks = pk_str.split('\n') 70 | 71 | peak_locs = [] 72 | peak_intensities = [] 73 | 74 | for peak in all_peaks: 75 | loc, intensity = peak.split() 76 | peak_locs.append(int(loc)) 77 | peak_intensities.append(float(intensity)) 78 | 79 | return peak_locs, peak_intensities 80 | 81 | 82 | def convert_spectrum_array_to_string(spectrum_array): 83 | """Write a spectrum array to string. 84 | 85 | Args: 86 | spectrum_array : np.array of shape (1000) 87 | 88 | Returns: 89 | string representing the peaks of the spectra. 90 | """ 91 | mz_peak_locations = np.nonzero(spectrum_array)[0].tolist() 92 | mass_peak_strings = [ 93 | '%d %d' % (p, spectrum_array[p]) for p in mz_peak_locations 94 | ] 95 | return '\n'.join(mass_peak_strings) 96 | 97 | 98 | def get_largest_mass_spec_peak_loc(mol): 99 | """Returns largest ms peak location from an rdkit.Mol object.""" 100 | return parse_peaks(mol.GetProp(ms_constants.SDF_TAG_MASS_SPEC_PEAKS))[0][-1] 101 | 102 | 103 | def make_dense_mass_spectra(peak_locs, peak_intensities, max_peak_loc): 104 | """Make a dense np.array of the mass spectra. 105 | 106 | Args: 107 | peak_locs : int list of the location of peaks 108 | peak_intensities : float list of the intensity of the peaks 109 | max_peak_loc : maximum number of peaks bins 110 | 111 | Returns: 112 | np.array of the mass spectra data as a dense vector. 
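
  Illustrative example (hypothetical peak values, not taken from the test
  data): make_dense_mass_spectra([1, 3], [50., 999.], 5) returns
  array([0., 50., 0., 999., 0.]).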
113 | """ 114 | dense_spectrum = np.zeros(max_peak_loc) 115 | dense_spectrum[peak_locs] = peak_intensities 116 | 117 | return dense_spectrum 118 | 119 | 120 | def get_padded_atom_weights(mol, max_atoms): 121 | """Make a padded list of atoms of length max_atoms given rdkit.Mol object. 122 | 123 | Note: Returns atoms in the same order as the input rdkit.Mol. 124 | If you want the atoms in canonical order, you should canonicalize 125 | the molecule first. 126 | 127 | Args: 128 | mol : a rdkit.Mol object 129 | max_atoms : maximum number of atoms to consider 130 | Returns: 131 | np array of atoms by atomic mass of shape (max_atoms) 132 | Raises: 133 | ValueError : If rdkit.Mol object had more atoms than max_atoms. 134 | """ 135 | 136 | if max_atoms < len(mol.GetAtoms()): 137 | raise ValueError( 138 | 'molecule contains {} atoms, more than max_atoms {}'.format( 139 | len(mol.GetAtoms()), max_atoms)) 140 | 141 | atom_list = np.array([at.GetMass() for at in mol.GetAtoms()]) 142 | atom_list = np.pad(atom_list, ((0, max_atoms - len(atom_list))), 'constant') 143 | return atom_list 144 | 145 | 146 | def get_padded_atom_ids(mol, max_atoms): 147 | """Make a padded list of atoms of length max_atoms given rdkit.Mol object. 148 | 149 | Args: 150 | mol : a rdkit.Mol object 151 | max_atoms : maximum number of atoms to consider 152 | Returns: 153 | np array of atoms by atomic number of shape (max_atoms) 154 | Note: function returns atoms in the same order as the input rdkit.Mol. 155 | If you want the atoms in canonical order, you should canonicalize 156 | the molecule first. 157 | Raises: 158 | ValueError : rdkit.Mol object is too big, had more atoms than max_atoms. 159 | """ 160 | if max_atoms < len(mol.GetAtoms()): 161 | raise ValueError( 162 | 'molecule contains {} atoms, more than max_atoms {}'.format( 163 | len(mol.GetAtoms()), max_atoms)) 164 | atom_list = np.array([at.GetAtomicNum() for at in mol.GetAtoms()]) 165 | atom_list = atom_list.astype('int32') 166 | atom_list = np.pad(atom_list, ((0, max_atoms - len(atom_list))), 'constant') 167 | 168 | return atom_list 169 | 170 | 171 | def get_padded_adjacency_matrix(mol, max_atoms, add_hs_to_molecule=False): 172 | """Make a matrix with shape (max_atoms, max_atoms) given rdkit.Mol object. 173 | 174 | Args: 175 | mol: a rdkit.Mol object 176 | max_atoms : maximum number of atoms to consider 177 | add_hs_to_molecule : whether or not to add hydrogens to the molecule. 178 | Returns: 179 | np.array of floats of a flattened adjacency matrix of length 180 | (max_atoms * max_atoms) 181 | The values will be the index of the bond order in the alphabet 182 | Raises: 183 | ValueError : rdkit.Mol object is too big, had more atoms than max_atoms. 
184 | """ 185 | # Add hydrogens to atoms: 186 | if add_hs_to_molecule: 187 | mol = Chem.rdmolops.AddHs(mol) 188 | 189 | num_atoms_in_mol = len(mol.GetAtoms()) 190 | if max_atoms < num_atoms_in_mol: 191 | raise ValueError( 192 | 'molecule contains {} atoms, more than max_atoms {}'.format( 193 | len(mol.GetAtoms()), max_atoms)) 194 | 195 | adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True) 196 | 197 | for i in range(np.shape(adj_matrix)[0]): 198 | for j in range(np.shape(adj_matrix)[1]): 199 | if adj_matrix[i, j] != 0: 200 | adj_matrix[i, j] = ms_constants.BOND_ORDER_TO_INTS_DICT[adj_matrix[i, 201 | j]] 202 | 203 | padded_adjacency_matrix = np.zeros((max_atoms, max_atoms)) 204 | 205 | padded_adjacency_matrix[:num_atoms_in_mol, :num_atoms_in_mol] = adj_matrix 206 | padded_adjacency_matrix = padded_adjacency_matrix.astype('int32') 207 | 208 | return np.reshape(padded_adjacency_matrix, (max_atoms * max_atoms)) 209 | 210 | 211 | def make_circular_fingerprint(mol, circular_fp_key): 212 | """Returns circular fingerprint for a mol given its circular_fp_key. 213 | 214 | Args: 215 | mol : rdkit.Mol 216 | circular_fp_key : A ms_constants.CircularFingerprintKey object 217 | Returns: 218 | np.array of len circular_fp_key.fp_len 219 | """ 220 | # A dictionary to record rdkit functions to base names 221 | fp_methods_dict = { 222 | fmap_constants.CIRCULAR_FP_BASENAME: 223 | AllChem.GetMorganFingerprintAsBitVect, 224 | fmap_constants.COUNTING_CIRCULAR_FP_BASENAME: 225 | AllChem.GetHashedMorganFingerprint 226 | } 227 | 228 | fp = fp_methods_dict[circular_fp_key.fp_type]( 229 | mol, circular_fp_key.radius, nBits=circular_fp_key.fp_len) 230 | fp_arr = np.zeros(1) 231 | DataStructs.ConvertToNumpyArray(fp, fp_arr) 232 | return fp_arr 233 | 234 | 235 | def all_circular_fingerprints_to_dict(mol): 236 | """Creates all circular fingerprints from list of lengths and radii. 237 | 238 | Based on lists of fingerprint lengths and fingerprint radii inside 239 | mass_spec_constants. 240 | 241 | Args: 242 | mol : rdkit.Mol 243 | Returns: 244 | a dict. 
The keys are CircularFingerprintKey instances and the values are
245 |     the corresponding fingerprints
246 |   """
247 |   fp_dict = {}
248 |   for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST:
249 |     for rad in ms_constants.CIRCULAR_FP_RADII_LIST:
250 |       for fp_type in fmap_constants.FP_TYPE_LIST:
251 |         circular_fp_key = ms_constants.CircularFingerprintKey(
252 |             fp_type, fp_len, rad)
253 |         fp_dict[circular_fp_key] = make_circular_fingerprint(
254 |             mol, circular_fp_key)
255 |   return fp_dict
256 | 
257 | 
258 | def check_mol_has_non_empty_smiles(mol):
259 |   """Checks that the SMILES string of the rdkit.Mol is non-empty."""
260 |   return bool(get_smiles_string(mol))
261 | 
262 | 
263 | def check_mol_has_non_empty_mass_spec_peak_tag(mol):
264 |   """Checks if mass spec sdf tag is in properties of rdkit.Mol."""
265 |   return ms_constants.SDF_TAG_MASS_SPEC_PEAKS in mol.GetPropNames()
266 | 
267 | 
268 | def check_mol_only_has_atoms(mol, accept_atom_list):
269 |   """Checks if rdkit.Mol only contains atoms from accept_atom_list."""
270 |   atom_symbol_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
271 |   return all(atom in accept_atom_list for atom in atom_symbol_list)
272 | 
273 | 
274 | def check_mol_does_not_have_atoms(mol, exclude_atom_list):
275 |   """Checks that rdkit.Mol contains no atoms from exclude_atom_list."""
276 |   atom_symbol_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
277 |   return all(atom not in atom_symbol_list for atom in exclude_atom_list)
278 | 
279 | 
280 | def check_mol_has_substructure(mol, substructure_mol):
281 |   """Checks if rdkit.Mol has substructure.
282 | 
283 |   Args:
284 |     mol : rdkit.Mol, representing query
285 |     substructure_mol: rdkit.Mol, representing substructure family
286 |   Returns:
287 |     Boolean, True if substructure found in molecule.
288 |   """
289 |   return mol.HasSubstructMatch(substructure_mol)
290 | 
291 | 
292 | def make_filter_by_substructure(family_name):
293 |   """Returns a filter function according to the family_name."""
294 |   if family_name not in FILTER_DICT:
295 |     raise ValueError('%s is not supported for family splitting' % family_name)
296 |   return lambda mol: check_mol_has_substructure(mol, FILTER_DICT[family_name])
297 | 
298 | 
299 | def tokenize_smiles(smiles_string_arr):
300 |   """Creates a list of tokens from a SMILES string.
301 | 
302 |   Two-letter atom symbols are treated as a single token.
303 |   All two letter tokens observed in this dataset are recorded in
304 |   ms_constants.TWO_LETTER_TOKEN_NAMES.
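  For example (illustrative), 'ClC=O' is split into the tokens
  ['Cl', 'C', '=', 'O'], and each token is then mapped to its integer index
  via ms_constants.SMILES_TOKEN_NAME_TO_INDEX.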
305 | 306 | Args: 307 | smiles_string_arr: np.array of dtype str and shape (1, ) 308 | Returns: 309 | A np.array of ints corresponding with the tokens 310 | """ 311 | 312 | smiles_str = smiles_string_arr[0] 313 | if isinstance(smiles_str, bytes): 314 | smiles_str = smiles_str.decode('utf-8') 315 | 316 | token_list = [] 317 | ptr = 0 318 | 319 | while ptr < len(smiles_str): 320 | if smiles_str[ptr:ptr + 2] in ms_constants.TWO_LETTER_TOKEN_NAMES: 321 | token_list.append( 322 | ms_constants.SMILES_TOKEN_NAME_TO_INDEX[smiles_str[ptr:ptr + 2]]) 323 | ptr += 2 324 | else: 325 | token_list.append( 326 | ms_constants.SMILES_TOKEN_NAME_TO_INDEX[smiles_str[ptr]]) 327 | ptr += 1 328 | 329 | return np.array(token_list, dtype=np.int64) 330 | -------------------------------------------------------------------------------- /molecule_estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""Train and evaluate massspec model. 15 | 16 | Example usage: 17 | molecule_estimator.py --train_steps=1000 --model_dir='/tmp/models/output' \ 18 | --dataset_config_file=testdata/test_dataset_config_file.json --alsologtostderr 19 | """ 20 | 21 | from __future__ import print_function 22 | import json 23 | import os 24 | 25 | from absl import flags 26 | import dataset_setup_constants as ds_constants 27 | import feature_map_constants as fmap_constants 28 | import library_matching 29 | import molecule_predictors 30 | import parse_sdf_utils 31 | import six 32 | import tensorflow as tf 33 | 34 | FLAGS = flags.FLAGS 35 | flags.DEFINE_string( 36 | 'dataset_config_file', None, 37 | 'JSON file specifying the various filenames necessary for training and ' 38 | 'evaluation. See make_input_fn() for more details.') 39 | flags.DEFINE_string( 40 | 'hparams', '', 'Hyperparameter values to override the defaults.' 41 | 'Format: params1=value1,params2=value2, ...' 42 | 'Possible parameters: max_atoms, max_mass_spec_peak_loc,' 43 | 'batch_size, epochs, do_linear_regression,' 44 | 'get_mass_spec_features, init_weights, init_bias') 45 | flags.DEFINE_integer('train_steps', None, 46 | 'The number of steps to run training for.') 47 | flags.DEFINE_integer( 48 | 'train_steps_per_iteration', 1000, 49 | 'how frequently to evaluate (only used when schedule ==' 50 | ' continuous_train_and_eval') 51 | 52 | flags.DEFINE_string('model_dir', '', 53 | 'output directory for checkpoints and events files') 54 | flags.DEFINE_string('warm_start_dir', None, 55 | 'directory to warm start model from') 56 | 57 | flags.DEFINE_enum('model_type', 'mlp', 58 | molecule_predictors.MODEL_REGISTRY.keys(), 59 | 'Type of model to use.') 60 | OUTPUT_HPARAMS_CONFIG_FILE_BASE = 'command_line_arguments.txt' 61 | 62 | 63 | def make_input_fn(dataset_config_file, 64 | hparams, 65 | mode, 66 | features_to_load, 67 | load_library_matching_data, 68 | data_dir=None): 69 | """Make input functions returning features and labels. 
70 | 
71 |   In our downstream code, it is advantageous to put both features
72 |   and labels together in the same nested structure. However, tf.estimator
73 |   requires input_fn() to return features and labels, so here our input_fn
74 |   returns dummy labels that will not be used.
75 | 
76 |   Args:
77 |     dataset_config_file: filename of JSON file containing a dict mapping dataset
78 |       names to data files. The required keys are:
79 |       'SPECTRUM_PREDICTION_TRAIN': training data
80 |       'SPECTRUM_PREDICTION_TEST': eval data on which we evaluate the same loss
81 |         function that is used for training
82 |       'LIBRARY_MATCHING_OBSERVED': data for library matching where we use
83 |         ground-truth spectra
84 |       'LIBRARY_MATCHING_PREDICTED': data for library matching where we use
85 |         predicted spectra
86 |       'LIBRARY_MATCHING_QUERY': data with observed spectra used for queries
87 |         in the library matching task
88 | 
89 |       For each data file with name <name>, we read high-level information
90 |       about the data from a separate file with the name <name>.info. See
91 |       parse_sdf_utils.parse_info_file() for the expected format of that
92 |       file.
93 | 
94 |     hparams: hparams required for parsing; includes features such as max_atoms,
95 |       max_mass_spec_peak_loc, and batch_size
96 |     mode: a tf.estimator.ModeKeys value; sets whether this is training or test
97 |     features_to_load: list of keys to load from the input data
98 |     load_library_matching_data: whether to load library matching data.
99 |     data_dir: The directory containing the file names referred to in
100 |       dataset_config_file. If None (the default) then this is assumed to be the
101 |       directory containing dataset_config_file.
102 |   Returns:
103 |     A function for creating features and labels from a dataset.
104 |   """
105 |   with tf.gfile.Open(dataset_config_file, 'r') as f:
106 |     filenames = json.load(f)
107 | 
108 |   if data_dir is None:
109 |     data_dir = os.path.dirname(dataset_config_file)
110 | 
111 |   def _input_fn(record_fnames,
112 |                 all_data_in_one_batch,
113 |                 load_training_spectrum_library=False):
114 |     """Reads TFRecord from a list of record file names."""
115 |     if not record_fnames:
116 |       return None
117 | 
118 |     record_fnames = [os.path.join(data_dir, r_name) for r_name in record_fnames]
119 |     dataset = parse_sdf_utils.get_dataset_from_record(
120 |         record_fnames,
121 |         hparams,
122 |         mode=mode,
123 |         features_to_load=(features_to_load + hparams.label_names),
124 |         all_data_in_one_batch=all_data_in_one_batch)
125 |     dict_to_return = parse_sdf_utils.make_features_and_labels(
126 |         dataset, features_to_load, hparams.label_names, mode=mode)[0]
127 | 
128 |     if load_training_spectrum_library:
129 |       library_file = os.path.join(
130 |           '/readahead/128M/',
131 |           filenames[ds_constants.TRAINING_SPECTRA_ARRAY_KEY])
132 |       train_library = parse_sdf_utils.load_training_spectra_array(library_file)
133 |       train_library = tf.convert_to_tensor(train_library, dtype=tf.float32)
134 | 
135 |       dict_to_return['SPECTRUM_PREDICTION_LIBRARY'] = train_library
136 | 
137 |     return dict_to_return
138 | 
139 |   load_training_spectrum_library = hparams.loss == 'max_margin'
140 | 
141 |   if load_library_matching_data:
142 | 
143 |     def _wrapped_input_fn():
144 |       """Construct data for various eval tasks."""
145 | 
146 |       data_to_return = {
147 |           fmap_constants.SPECTRUM_PREDICTION:
148 |               _input_fn(
149 |                   filenames[ds_constants.SPECTRUM_PREDICTION_TEST_KEY],
150 |                   all_data_in_one_batch=False,
151 |                   load_training_spectrum_library=load_training_spectrum_library)
152 |       }
153 | 
154 |       if hparams.do_library_matching:
155 |         observed = _input_fn(
156 | 
filenames[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY], 157 | all_data_in_one_batch=True) 158 | predicted = _input_fn( 159 | filenames[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY], 160 | all_data_in_one_batch=True) 161 | query = _input_fn( 162 | filenames[ds_constants.LIBRARY_MATCHING_QUERY_KEY], 163 | all_data_in_one_batch=False) 164 | data_to_return[fmap_constants. 165 | LIBRARY_MATCHING] = library_matching.LibraryMatchingData( 166 | observed=observed, predicted=predicted, query=query) 167 | 168 | return data_to_return, tf.constant(0) 169 | else: 170 | 171 | def _wrapped_input_fn(): 172 | # See the above comment about why we use dummy labels. 173 | return { 174 | fmap_constants.SPECTRUM_PREDICTION: 175 | _input_fn( 176 | filenames[ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY], 177 | all_data_in_one_batch=False, 178 | load_training_spectrum_library=load_training_spectrum_library) 179 | }, tf.constant(0) 180 | 181 | return _wrapped_input_fn 182 | 183 | 184 | def _log_command_line_string(model_type, model_dir, hparams): 185 | """Log command line args to replicate hparam configuration.""" 186 | 187 | config_string = '--model_type=%s ' % (model_type) 188 | 189 | # Note that the rendered string will not be able to be parsed using 190 | # hparams.parse() if any of the hparam values have commas or '=' signs. 191 | hparams_string = ','.join( 192 | ['%s=%s' % (key, value) for key, value in six.iteritems( 193 | hparams.values())]) 194 | 195 | config_string += ' --hparams=%s\n' % hparams_string 196 | output_file = os.path.join(model_dir, OUTPUT_HPARAMS_CONFIG_FILE_BASE) 197 | tf.gfile.MakeDirs(model_dir) 198 | tf.logging.info('Writing command line config string to %s.' % output_file) 199 | 200 | with tf.gfile.Open(output_file, 'w') as f: 201 | f.write(config_string) 202 | 203 | 204 | def make_model_fn(prediction_helper, dataset_config_file, model_dir): 205 | """Returns a model function for estimator given prediction base class. 206 | 207 | Args: 208 | prediction_helper : Helper class containing prediction, loss, and evaluation 209 | metrics 210 | dataset_config_file: see make_input_fn. 211 | model_dir : directory for writing output files. If model dir is not None, 212 | a file containing all of the necessary command line flags to rehydrate 213 | the model will be written in model_dir. 214 | Returns: 215 | A function that returns a tf.EstimatorSpec 216 | """ 217 | 218 | def _model_fn(features, labels, params, mode=None): 219 | """Returns tf.EstimatorSpec.""" 220 | 221 | # Input labels are ignored. All data are in features. 
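    # Per the make_input_fn docstring above, the real label data already
    # travels inside the features structure, and the input_fn's labels are a
    # constant dummy, so nothing is lost by dropping them here.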
222 | del labels 223 | 224 | if model_dir is not None: 225 | _log_command_line_string(prediction_helper.model_type, model_dir, params) 226 | 227 | pred_op, pred_op_for_loss = prediction_helper.make_prediction_ops( 228 | features[fmap_constants.SPECTRUM_PREDICTION], params, mode) 229 | loss_op = prediction_helper.make_loss( 230 | pred_op_for_loss, features[fmap_constants.SPECTRUM_PREDICTION], params) 231 | 232 | if mode == tf.estimator.ModeKeys.TRAIN: 233 | train_op = tf.contrib.layers.optimize_loss( 234 | loss=loss_op, 235 | global_step=tf.train.get_global_step(), 236 | clip_gradients=params.gradient_clip, 237 | learning_rate=params.learning_rate, 238 | optimizer='Adam') 239 | eval_op = None 240 | elif mode == tf.estimator.ModeKeys.PREDICT: 241 | train_op = None 242 | eval_op = None 243 | elif mode == tf.estimator.ModeKeys.EVAL: 244 | train_op = None 245 | eval_op = prediction_helper.make_evaluation_metrics( 246 | features, params, dataset_config_file, output_dir=model_dir) 247 | else: 248 | raise ValueError('Invalid mode. Must be ' 249 | 'tf.estimator.ModeKeys.{TRAIN,PREDICT,EVAL}.') 250 | return tf.estimator.EstimatorSpec( 251 | mode=mode, 252 | predictions=pred_op, 253 | loss=loss_op, 254 | train_op=train_op, 255 | eval_metric_ops=eval_op) 256 | 257 | return _model_fn 258 | 259 | 260 | def make_estimator_and_inputs(run_config, 261 | hparams, 262 | prediction_helper, 263 | dataset_config_file, 264 | train_steps, 265 | model_dir, 266 | warm_start_dir=None): 267 | """Make Estimator-compatible Estimator and input_fn for train and eval.""" 268 | 269 | model_fn = make_model_fn(prediction_helper, dataset_config_file, model_dir) 270 | 271 | train_input_fn = make_input_fn( 272 | dataset_config_file=dataset_config_file, 273 | hparams=hparams, 274 | mode=tf.estimator.ModeKeys.TRAIN, 275 | features_to_load=prediction_helper.features_to_load(hparams), 276 | load_library_matching_data=False) 277 | 278 | eval_input_fn = make_input_fn( 279 | dataset_config_file=dataset_config_file, 280 | hparams=hparams, 281 | mode=tf.estimator.ModeKeys.EVAL, 282 | features_to_load=prediction_helper.features_to_load(hparams), 283 | load_library_matching_data=True) 284 | 285 | estimator = tf.estimator.Estimator( 286 | model_fn=model_fn, 287 | params=hparams, 288 | config=run_config, 289 | warm_start_from=warm_start_dir) 290 | 291 | train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=train_steps) 292 | eval_spec = tf.estimator.EvalSpec(eval_input_fn, steps=None) 293 | 294 | return estimator, train_spec, eval_spec 295 | 296 | 297 | def main(_): 298 | prediction_helper = molecule_predictors.get_prediction_helper( 299 | FLAGS.model_type) 300 | 301 | hparams = prediction_helper.get_default_hparams() 302 | hparams.parse(FLAGS.hparams) 303 | 304 | config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir) 305 | 306 | (estimator, train_spec, eval_spec) = make_estimator_and_inputs( 307 | config, hparams, prediction_helper, FLAGS.dataset_config_file, 308 | FLAGS.train_steps, FLAGS.model_dir, FLAGS.warm_start_dir) 309 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 310 | 311 | 312 | if __name__ == '__main__': 313 | tf.app.run(main) 314 | -------------------------------------------------------------------------------- /plot_spectra_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Functions to add an evaluation metric that generates spectra plots.""" 15 | from __future__ import print_function 16 | 17 | import json 18 | import os 19 | 20 | import dataset_setup_constants as ds_constants 21 | import mass_spec_constants as ms_constants 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from PIL import Image as PilImage 25 | import six 26 | import tensorflow as tf 27 | 28 | IMAGE_SUBDIR_FOR_SPECTRA_PLOTS = 'images' 29 | 30 | SPECTRA_PLOT_BACKGROUND_COLOR = 'white' 31 | SPECTRA_PLOT_FIGURE_SIZE = (10, 10) 32 | SPECTRA_PLOT_GRID_COLOR = 'black' 33 | SPECTRA_PLOT_TRUE_SPECTRA_COLOR = 'blue' 34 | SPECTRA_PLOT_PREDICTED_SPECTRA_COLOR = 'red' 35 | SPECTRA_PLOT_PEAK_LOC_LIMIT = ms_constants.MAX_PEAK_LOC 36 | SPECTRA_PLOT_MZ_MAX_OFFSET = 10 37 | SPECTRA_PLOT_INTENSITY_LIMIT = 1200 38 | SPECTRA_PLOT_DPI = 300 39 | SPECTRA_PLOT_BAR_LINE_WIDTH = 0.8 40 | SPECTRA_PLOT_BAR_GRID_LINE_WIDTH = 0.1 41 | SPECTRA_PLOT_ACTUAL_SPECTRA_LEGEND_TEXT = 'True Mass Spectrum' 42 | SPECTRA_PLOT_PREDICTED_SPECTRA_LEGEND_TEXT = 'Predicted Mass Spectrum' 43 | SPECTRA_PLOT_QUERY_SPECTRA_LEGEND_TEXT = 'Query Mass Spectrum' 44 | SPECTRA_PLOT_LIBRARY_MATCH_SPECTRA_LEGEND_TEXT = 'Library Matched Mass Spectrum' 45 | SPECTRA_PLOT_X_AXIS_LABEL = 'mass/charge ratio' 46 | SPECTRA_PLOT_Y_AXIS_LABEL = 'relative intensity' 47 | SPECTRA_PLOT_PLACE_LEGEND_ABOVE_CHART_KWARGS = {'ncol': 2} 48 | SPECTRA_PLOT_IMAGE_DIR_NAME = 'images' 49 | SPECTRA_PLOT_DIMENSIONS_RGB = (3000, 3000, 3) 50 | FIGURES_TO_SUMMARIZE_PER_BATCH = 2 51 | MAX_VALUE_OF_TRUE_SPECTRA = 999. 52 | 53 | 54 | class PlotModeKeys(object): 55 | """Helper class containing the two supported plotting modes. 56 | 57 | The following keys are defined: 58 | PREDICTED_SPECTRUM : For plotting the spectrum predicted by the algorithm 59 | against the true spectrum. 60 | LIBRARY_MATCHED_SPECTRUM : For plotting the spectrum that was the closest 61 | match to the true spectrum against the true spectrum. 62 | """ 63 | PREDICTED_SPECTRUM = 'predicted_spectrum' 64 | LIBRARY_MATCHED_SPECTRUM = 'library_match_spectrum' 65 | 66 | 67 | def name_plot_file(mode, query_inchikey, matched_inchikey=None, 68 | file_type='png'): 69 | """Generates name for spectra plot files.""" 70 | if mode == PlotModeKeys.PREDICTED_SPECTRUM: 71 | return '{}.{}'.format(query_inchikey, file_type) 72 | elif mode == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM: 73 | return '{}_matched_to_{}.{}'.format(query_inchikey, matched_inchikey, 74 | file_type) 75 | 76 | 77 | def name_metric(mode, inchikey): 78 | return '{}_plot_{}'.format(mode, inchikey) 79 | 80 | 81 | def plot_true_and_predicted_spectra( 82 | true_dense_spectra, 83 | generated_dense_spectra, 84 | plot_mode_key=PlotModeKeys.PREDICTED_SPECTRUM, 85 | output_filename='', 86 | rescale_mz_axis=False): 87 | """Generates a plot comparing a true and predicted mass spec spectra. 88 | 89 | If output_filename given, saves a png file of the spectra, with the 90 | true spectrum above and predicted spectrum below. 
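
  A minimal usage sketch (hypothetical arrays and filename; not taken from
  this repo's tests):

    true_arr = np.zeros(ms_constants.MAX_PEAK_LOC)
    pred_arr = np.zeros(ms_constants.MAX_PEAK_LOC)
    true_arr[43], pred_arr[43] = 999., 800.
    img_arr = plot_true_and_predicted_spectra(
        true_arr, pred_arr, PlotModeKeys.PREDICTED_SPECTRUM, '/tmp/plot.png')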
91 | 92 | Args: 93 | true_dense_spectra : np.array representing the true mass spectra 94 | generated_dense_spectra : np.array representing the predicted mass spectra 95 | plot_mode_key: a PlotModeKeys instance 96 | output_filename : str path for saving generated image. 97 | rescale_mz_axis: Setting to rescale m/z axis according to highest m/z peak 98 | location. 99 | 100 | Returns: 101 | np.array of the bits of the generated matplotlib plot. 102 | """ 103 | 104 | if not rescale_mz_axis: 105 | x_array = np.arange(SPECTRA_PLOT_PEAK_LOC_LIMIT) 106 | bar_width = SPECTRA_PLOT_BAR_LINE_WIDTH 107 | mz_max = SPECTRA_PLOT_PEAK_LOC_LIMIT 108 | else: 109 | mz_max = max( 110 | max(np.nonzero(true_dense_spectra)[0]), 111 | max(np.nonzero(generated_dense_spectra)[0])) 112 | if mz_max + SPECTRA_PLOT_MZ_MAX_OFFSET < ms_constants.MAX_PEAK_LOC: 113 | mz_max += SPECTRA_PLOT_MZ_MAX_OFFSET 114 | else: 115 | mz_max = ms_constants.MAX_PEAK_LOC 116 | x_array = np.arange(mz_max) 117 | true_dense_spectra = true_dense_spectra[:mz_max] 118 | generated_dense_spectra = generated_dense_spectra[:mz_max] 119 | bar_width = SPECTRA_PLOT_BAR_LINE_WIDTH * mz_max / ms_constants.MAX_PEAK_LOC 120 | 121 | figure = plt.figure(figsize=SPECTRA_PLOT_FIGURE_SIZE, dpi=300) 122 | 123 | # Adding extra subplot so both plots have common x-axis and y-axis labels 124 | ax_main = figure.add_subplot(111, frameon=False) 125 | ax_main.tick_params( 126 | labelcolor='none', top='off', bottom='off', left='off', right='off') 127 | 128 | ax_main.set_xlabel(SPECTRA_PLOT_X_AXIS_LABEL) 129 | ax_main.set_ylabel(SPECTRA_PLOT_Y_AXIS_LABEL) 130 | 131 | if six.PY2: 132 | ax_top = figure.add_subplot(211, axisbg=SPECTRA_PLOT_BACKGROUND_COLOR) 133 | else: 134 | ax_top = figure.add_subplot(211, facecolor=SPECTRA_PLOT_BACKGROUND_COLOR) 135 | 136 | bar_top = ax_top.bar( 137 | x_array, 138 | true_dense_spectra, 139 | bar_width, 140 | color=SPECTRA_PLOT_TRUE_SPECTRA_COLOR, 141 | edgecolor=SPECTRA_PLOT_TRUE_SPECTRA_COLOR, 142 | ) 143 | 144 | ax_top.set_ylim((0, SPECTRA_PLOT_INTENSITY_LIMIT)) 145 | plt.setp(ax_top.get_xticklabels(), visible=False) 146 | ax_top.grid( 147 | color=SPECTRA_PLOT_GRID_COLOR, linewidth=SPECTRA_PLOT_BAR_GRID_LINE_WIDTH) 148 | 149 | if six.PY2: 150 | ax_bottom = figure.add_subplot(212, axisbg=SPECTRA_PLOT_BACKGROUND_COLOR) 151 | else: 152 | ax_bottom = figure.add_subplot(212, facecolor=SPECTRA_PLOT_BACKGROUND_COLOR) 153 | figure.subplots_adjust(hspace=0.0) 154 | 155 | bar_bottom = ax_bottom.bar( 156 | x_array, 157 | generated_dense_spectra, 158 | bar_width, 159 | color=SPECTRA_PLOT_PREDICTED_SPECTRA_COLOR, 160 | edgecolor=SPECTRA_PLOT_PREDICTED_SPECTRA_COLOR, 161 | ) 162 | 163 | # Invert the direction of y-axis ticks for bottom graph. 
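  # Plotting with ylim (limit, 0) makes the predicted spectrum hang downward,
  # giving the conventional head-to-tail layout for comparing two spectra.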
164 |   ax_bottom.set_ylim((SPECTRA_PLOT_INTENSITY_LIMIT, 0))
165 | 
166 |   ax_bottom.set_xlim(0, mz_max)
167 |   # Remove overlapping 0's from middle of y-axis
168 |   yticks_bottom = ax_bottom.yaxis.get_major_ticks()
169 |   yticks_bottom[0].label1.set_visible(False)
170 | 
171 |   ax_bottom.grid(
172 |       color=SPECTRA_PLOT_GRID_COLOR, linewidth=SPECTRA_PLOT_BAR_GRID_LINE_WIDTH)
173 | 
174 |   for ax in [ax_top, ax_bottom]:
175 |     ax.minorticks_on()
176 |     ax.tick_params(axis='y', which='minor', left='off')
177 |     ax.tick_params(axis='y', which='minor', right='off')
178 | 
179 |   ax_top.tick_params(axis='x', which='minor', top='off')
180 | 
181 |   if plot_mode_key == PlotModeKeys.PREDICTED_SPECTRUM:
182 |     ax_top.legend((bar_top, bar_bottom),
183 |                   (SPECTRA_PLOT_ACTUAL_SPECTRA_LEGEND_TEXT,
184 |                    SPECTRA_PLOT_PREDICTED_SPECTRA_LEGEND_TEXT),
185 |                   **SPECTRA_PLOT_PLACE_LEGEND_ABOVE_CHART_KWARGS)
186 |   elif plot_mode_key == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM:
187 |     ax_top.legend((bar_top, bar_bottom),
188 |                   (SPECTRA_PLOT_QUERY_SPECTRA_LEGEND_TEXT,
189 |                    SPECTRA_PLOT_LIBRARY_MATCH_SPECTRA_LEGEND_TEXT),
190 |                   **SPECTRA_PLOT_PLACE_LEGEND_ABOVE_CHART_KWARGS)
191 | 
192 |   figure.canvas.draw()
193 |   data = np.fromstring(figure.canvas.tostring_rgb(), dtype=np.uint8, sep='')
194 | 
195 |   try:
196 |     data = np.reshape(data, SPECTRA_PLOT_DIMENSIONS_RGB)
197 |   except ValueError:
198 |     raise ValueError(
199 |         'The shape of the np.array generated from the data does '
200 |         'not match the values in '
201 |         'SPECTRA_PLOT_DIMENSIONS_RGB : {}'.format(SPECTRA_PLOT_DIMENSIONS_RGB))
202 | 
203 |   if output_filename:
204 |     # We can't call plt.savefig(output_filename) because plt does not
205 |     # communicate with the filesystem through gfile. In some scenarios, this
206 |     # will prevent us from writing out data. Instead, we use PIL to help us
207 |     # efficiently save the nparray of the image as a png file.
208 |     if not (output_filename.endswith('.png') or output_filename.endswith('.eps')):
209 |       output_filename += '.png'
210 | 
211 |     with tf.gfile.GFile(output_filename, 'wb') as out:
212 |       image = PilImage.fromarray(data).convert('RGB')
213 |       image.save(out, dpi=(SPECTRA_PLOT_DPI, SPECTRA_PLOT_DPI))
214 | 
215 |     tf.logging.info('Shape of spectra plot data {} '.format(np.shape(data)))
216 | 
217 |   plt.close(figure)
218 | 
219 |   return data
220 | 
221 | 
222 | def make_plot(inchikey,
223 |               plot_mode_key,
224 |               update_img_flag,
225 |               inchikeys_batch,
226 |               true_spectra_batch,
227 |               predictions,
228 |               image_directory=None,
229 |               library_match_inchikeys=None):
230 |   """Makes plots comparing the true and predicted spectra in a dataset.
231 | 
232 |   This function only performs the expensive step of generating the spectrum
233 |   plot if the target inchikey is in the current batch.
234 | 
235 |   Args:
236 |     inchikey: Inchikey of query that we want to make plots with.
237 |     plot_mode_key: A PlotModeKeys instance.
238 |     update_img_flag: Boolean flag for whether to generate a spectrum plot.
239 |     inchikeys_batch: inchikeys from the current batch
240 |     true_spectra_batch: np.array of all the true spectra from the current batch.
241 |     predictions: np.array of all predicted spectra from the current batch.
242 |     image_directory: Directory in which to save the plot image, if set.
243 |     library_match_inchikeys: np.array of strings, the best-matched inchikey
244 |         for each query from the library matching task.
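
  Note: before plotting, each row of predictions is rescaled so that its
  largest peak equals MAX_VALUE_OF_TRUE_SPECTRA, putting the predicted and
  true spectra on a common intensity scale.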
245 | 246 | Returns: 247 | if update_img_flag: np.array 248 | [see return value of plot_true_and_predicted_spectra] 249 | Otherwise, returns a zero np.array of shape SPECTRA_PLOT_DIMENSIONS_RGB. 250 | Also saves a file at image_directory if this value is non-zero. 251 | Raises: 252 | ValueError: library_match_inchikeys needs to be set if given image_directory 253 | and using PlotModeKeys.LIBRARY_MATCHED_SPECTRUM. 254 | """ 255 | if update_img_flag: 256 | flattened_inchikeys_batch = [ikey[0].strip() for ikey in inchikeys_batch] 257 | inchikey_idx = flattened_inchikeys_batch.index(inchikey) 258 | predictions = predictions / np.amax( 259 | predictions, axis=1, keepdims=True) * MAX_VALUE_OF_TRUE_SPECTRA 260 | predicted_spectra_to_plot = predictions[inchikey_idx, :] 261 | true_spectra_to_plot = true_spectra_batch[inchikey_idx, :] 262 | if image_directory: 263 | if plot_mode_key == PlotModeKeys.PREDICTED_SPECTRUM: 264 | img_filename = name_plot_file(plot_mode_key, inchikey) 265 | elif plot_mode_key == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM: 266 | best_library_match_inchikey = library_match_inchikeys[inchikey_idx, :] 267 | img_filename = name_plot_file(plot_mode_key, inchikey, 268 | best_library_match_inchikey[0]) 269 | 270 | img_pathname = os.path.join(image_directory, img_filename) 271 | spectra_arr = plot_true_and_predicted_spectra(true_spectra_to_plot, 272 | predicted_spectra_to_plot, 273 | plot_mode_key, img_pathname) 274 | else: 275 | spectra_arr = plot_true_and_predicted_spectra(true_spectra_to_plot, 276 | predicted_spectra_to_plot, 277 | plot_mode_key) 278 | return spectra_arr 279 | else: 280 | return np.zeros(SPECTRA_PLOT_DIMENSIONS_RGB, dtype=np.uint8) 281 | 282 | 283 | def spectra_plot_summary_op(inchikey_list, 284 | true_spectra, 285 | prediction_batch, 286 | inchikey_to_plot, 287 | plot_mode_key=PlotModeKeys.PREDICTED_SPECTRUM, 288 | library_match_inchikeys=None, 289 | image_directory=None): 290 | """Wrapper for plotting mass spectra for labels and predictions. 291 | 292 | Plots predicted and true spectra for a given inchikey. If image_directory is 293 | set, will save the plots as files in addition to making the image summary. 294 | 295 | Args: 296 | inchikey_list : tf Tensor of inchikey strings for a batch 297 | true_spectra : tf Tensor array with true spectra for a batch 298 | prediction_batch: tf Tensor array of predicted spectra for a single batch. 299 | inchikey_to_plot: string InChI key contained in test set (but perhaps not in 300 | a particular batch). 301 | plot_mode_key: A PlotModeKeys instance. 302 | library_match_inchikeys: tf Tensor of strings corresponding to the inchikeys 303 | top match from the library matching task. 304 | image_directory: string of dir name to save plots 305 | 306 | Returns: 307 | tf.summary.image of the operation, and an update operator indicating if the 308 | summary has been updated or not. 309 | """ 310 | 311 | def _should_update_image(inchikeys_batch): 312 | """Tests whether to indicate if target inchikey is in batch.""" 313 | flattened_inchikeys_batch = [ikey[0].strip() for ikey in inchikeys_batch] 314 | return inchikey_to_plot in flattened_inchikeys_batch 315 | 316 | metric_namescope = 'spectrum_{}_plot_{}'.format(plot_mode_key, 317 | inchikey_to_plot) 318 | spectra_variable_name = 'spectrum_{}_plot_{}'.format(plot_mode_key, 319 | inchikey_to_plot) 320 | with tf.name_scope(metric_namescope): 321 | # Whether the inchikey_to_plot is in the current batch. 
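    # tf.py_func evaluates the Python predicate inside the graph; the boolean
    # it returns drives the tf.cond below, so the expensive matplotlib work in
    # make_plot only happens when the target inchikey is in the batch.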
322 | update_image_bool = tf.py_func(_should_update_image, [inchikey_list], 323 | tf.bool) 324 | 325 | if plot_mode_key == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM: 326 | spectra_plot = tf.py_func(make_plot, [ 327 | inchikey_to_plot, plot_mode_key, update_image_bool, inchikey_list, 328 | true_spectra, prediction_batch, image_directory, 329 | library_match_inchikeys 330 | ], tf.uint8) 331 | elif plot_mode_key == PlotModeKeys.PREDICTED_SPECTRUM: 332 | spectra_plot = tf.py_func(make_plot, [ 333 | inchikey_to_plot, plot_mode_key, update_image_bool, inchikey_list, 334 | true_spectra, prediction_batch, image_directory 335 | ], tf.uint8) 336 | 337 | # Container for the plot. this value will only be assigned to something 338 | # new if the target inchikey is in the input batch. 339 | spectra_plot_variable = tf.get_local_variable( 340 | spectra_variable_name, 341 | shape=((1,) + SPECTRA_PLOT_DIMENSIONS_RGB), 342 | initializer=tf.constant_initializer(128), 343 | dtype=tf.uint8) 344 | 345 | # A function that add the spectra plot as metric. 346 | def update_function(): 347 | assign_op = spectra_plot_variable.assign(spectra_plot[tf.newaxis, ...]) 348 | with tf.control_dependencies([assign_op]): 349 | return tf.identity(spectra_plot_variable) 350 | 351 | # We only want to update the metric if the inchikey_to_plot 352 | # is in the batch. update_image_bool serves as a flag to tf.cond 353 | # to use the real update function if inchikey_to_plot is in the batch 354 | # and a fake one if not. 355 | final_spectra_plot = tf.cond(update_image_bool, 356 | update_function, lambda: spectra_plot_variable) 357 | 358 | update_op = final_spectra_plot 359 | 360 | return (tf.summary.image( 361 | spectra_variable_name, spectra_plot_variable, 362 | collections=None), update_op) 363 | 364 | 365 | def inchikeys_for_plotting(dataset_config_file, num_inchikeys_to_read, 366 | eval_batch_size): 367 | """Return inchikeys from spectrum prediction data file. 368 | 369 | Selects one inchikey per eval batch for plotting. This will avoid the 370 | threading issue seen at evaluation time. 371 | 372 | Args: 373 | dataset_config_file: dataset configuration file for experiment. Contains 374 | filename of spectrum prediction inchikey text file. 375 | num_inchikeys_to_read: Number of inchikeys to use for plotting 376 | eval_batch_size: Number of inchikeys to skip before appending the next 377 | inchikey from the text file. 378 | 379 | Returns: 380 | list [num_inchikeys_to_read] containing inchikey strings. 381 | """ 382 | dataset_config_file_dir = os.path.split(dataset_config_file)[0] 383 | with tf.gfile.Open(dataset_config_file, 'r') as f: 384 | line = f.read() 385 | filenames = json.loads(line) 386 | test_inchikey_list_name = os.path.splitext(filenames[ 387 | ds_constants.SPECTRUM_PREDICTION_TEST_KEY][0])[0] + '.inchikey.txt' 388 | 389 | inchikey_list_for_plotting = [] 390 | 391 | with tf.gfile.Open( 392 | os.path.join(dataset_config_file_dir, test_inchikey_list_name)) as f: 393 | for line_idx, line in enumerate(f): 394 | if line_idx % eval_batch_size == 0: 395 | inchikey_list_for_plotting.append(line.strip('\n')) 396 | if len(inchikey_list_for_plotting) == num_inchikeys_to_read: 397 | break 398 | 399 | if len(inchikey_list_for_plotting) < num_inchikeys_to_read: 400 | tf.logging.warn('Dataset specified by {} has fewer than' 401 | '{} inchikeys. 
Returning {} for plotting'.format(
402 |                     dataset_config_file,
403 |                     num_inchikeys_to_read * eval_batch_size,
404 |                     len(inchikey_list_for_plotting)))
405 |   return inchikey_list_for_plotting
406 | 
--------------------------------------------------------------------------------
/library_matching_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for deep_molecular_massspec.library_matching."""
15 | 
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | 
21 | import feature_map_constants as fmap_constants
22 | import library_matching
23 | import similarity as similarity_lib
24 | import numpy as np
25 | import tensorflow as tf
26 | 
27 | PREDICTOR_INPUT_KEY = 'INPUT'
28 | 
29 | 
30 | class LibraryMatchingTest(tf.test.TestCase):
31 | 
32 |   def testCosineSimilarityProviderMatching(self):
33 |     """Check correctness for querying the library with a library element."""
34 | 
35 |     num_examples = 20
36 |     num_trials = 10
37 |     data_dim = 5
38 |     similarity = similarity_lib.CosineSimilarityProvider()
39 |     library = np.float32(np.random.normal(size=(num_examples, data_dim)))
40 |     library = tf.constant(library)
41 |     library = similarity.preprocess_library(library)
42 |     query_idx = tf.placeholder(shape=(), dtype=tf.int32)
43 |     query = library[query_idx][np.newaxis, ...]
44 |     (match_idx_op, match_similarity_op, _, _,
45 |      _) = library_matching._max_similarity_match(library, query, similarity)
46 | 
47 |     # Use queries that are rows of the library. This means that the maximum
48 |     # cosine similarity is 1.0 and is achieved by the row index of the query
49 |     # in the library.
50 |     with tf.Session() as sess:
51 |       for _ in range(num_trials):
52 |         idx = np.random.randint(0, high=num_examples)
53 |         match_idx, match_similarity = sess.run(
54 |             [match_idx_op, match_similarity_op], feed_dict={query_idx: idx})
55 |         # Fail if the match_idx != idx, and the similarity of match_idx
57 | if match_idx != idx: 58 | self.assertClose(match_similarity, 1.0) 59 | 60 | def testFindQueryPositions(self): 61 | """Test library_matching._find_query_rank_helper.""" 62 | 63 | library_keys = np.array(['a', 'b', 'c', 'd', 'a', 'b']) 64 | query_keys = np.array(['a', 'b', 'c', 'a', 'b', 'c']) 65 | 66 | similarities = np.array( 67 | [[3., 4., 6., 0., 0.1, 2.], [1., -1., 5., 3, 2., 1.1], 68 | [-5., 0., 2., 0.1, 3., 1.], [0.2, 0.4, 0.6, 0.32, 0.3, 0.9], 69 | [0.2, 0.9, 0.65, 0.18, 0.3, 0.99], [0.8, 0.6, 0.5, 0.4, 0.9, 0.89]]) 70 | 71 | (highest_query_ranks, lowest_query_ranks, avg_query_ranks, 72 | query_similarities) = library_matching._find_query_rank_helper( 73 | similarities, library_keys, query_keys) 74 | 75 | expected_highest_query_ranks = [4, 5, 1, 5, 1, 4] 76 | expected_lowest_query_ranks = [2, 3, 1, 4, 0, 4] 77 | expected_avg_query_ranks = [3, 4, 1, 4.5, 0.5, 4] 78 | expected_query_similarities = [3., 1.1, 2., 0.3, 0.99, 0.5] 79 | 80 | self.assertAllEqual(expected_highest_query_ranks, highest_query_ranks) 81 | self.assertAllEqual(expected_lowest_query_ranks, lowest_query_ranks) 82 | self.assertAllEqual(expected_avg_query_ranks, avg_query_ranks) 83 | 84 | self.assertAllEqual(expected_query_similarities, query_similarities) 85 | 86 | def testInvertPermutation(self): 87 | """Test library_matching._invert_permutation().""" 88 | 89 | batch_size = 5 90 | num_trials = 10 91 | permutation_length = 6 92 | 93 | def _validate_permutation(perm1, perm2): 94 | ordered_indices = np.arange(perm1.shape[0]) 95 | self.assertAllEqual(perm1[perm2], ordered_indices) 96 | self.assertAllEqual(perm2[perm1], ordered_indices) 97 | 98 | for _ in range(num_trials): 99 | perms = np.stack( 100 | [ 101 | np.random.permutation(permutation_length) 102 | for _ in range(batch_size) 103 | ], 104 | axis=0) 105 | 106 | inverse = library_matching._invert_permutation(perms) 107 | 108 | for j in range(batch_size): 109 | _validate_permutation(perms[j], inverse[j]) 110 | 111 | def np_normalize_rows(self, d): 112 | return d / np.maximum( 113 | np.linalg.norm(d, axis=1)[..., np.newaxis], similarity_lib.EPSILON) 114 | 115 | def make_ids(self, num_ids, prefix=''): 116 | if prefix: 117 | prefix += '-' 118 | 119 | return [('%s%d' % (prefix, uid)).encode('utf-8') for uid in range(num_ids)] 120 | 121 | def _random_fingerprint(self, num_elements): 122 | return tf.to_float(tf.random_uniform(shape=(num_elements, 1024)) > 0.5) 123 | 124 | def _package_data(self, ids, spectrum, masses): 125 | 126 | def convert(t): 127 | if t is None: 128 | return t 129 | else: 130 | return tf.convert_to_tensor(t) 131 | 132 | num_elements = len(ids) 133 | fingerprints = self._random_fingerprint(num_elements) 134 | return { 135 | fmap_constants.DENSE_MASS_SPEC: convert(spectrum), 136 | fmap_constants.INCHIKEY: convert(ids), 137 | library_matching.FP_NAME_FOR_JACCARD_SIMILARITY: fingerprints, 138 | fmap_constants.MOLECULE_WEIGHT: convert(masses) 139 | } 140 | 141 | def make_x_data(self, num_examples, x_dim): 142 | return np.float32(np.random.uniform(size=(num_examples, x_dim))) 143 | 144 | def np_library_matching(self, ids_predicted, ids_observed, y_predicted, 145 | y_observed, y_query): 146 | ids_library = np.concatenate([ids_predicted, ids_observed]) 147 | np_library = self.np_normalize_rows( 148 | np.concatenate([y_predicted, y_observed], axis=0)) 149 | np_similarities = np.dot(np_library, np.transpose(y_query)) 150 | np_predictions = np.argmax(np_similarities, axis=0) 151 | np_predicted_ids = [ids_library[i] for i in np_predictions] 152 | return 
np_predicted_ids 153 | 154 | def perform_matching(self, ids_observed, ids_predicted, ids_query, 155 | masses_observed, masses_predicted, masses_query, 156 | y_observed, y_query, x_predicted, tf_transform, 157 | mass_tolerance): 158 | 159 | query_data = self._package_data( 160 | ids=ids_query, spectrum=y_query, masses=masses_query) 161 | 162 | predicted_data = self._package_data( 163 | ids=ids_predicted, spectrum=None, masses=masses_predicted) 164 | predicted_data[PREDICTOR_INPUT_KEY] = tf.constant(x_predicted) 165 | 166 | observed_data = self._package_data( 167 | ids=ids_observed, spectrum=y_observed, masses=masses_observed) 168 | 169 | library_matching_data = library_matching.LibraryMatchingData( 170 | query=query_data, observed=observed_data, predicted=predicted_data) 171 | 172 | predictor_fn = lambda d: tf_transform(d[PREDICTOR_INPUT_KEY]) 173 | similarity = similarity_lib.CosineSimilarityProvider() 174 | true_data, predicted_data, _, _ = ( 175 | library_matching.library_matching(library_matching_data, predictor_fn, 176 | similarity, mass_tolerance, 10)) 177 | 178 | with tf.Session() as sess: 179 | sess.run(tf.local_variables_initializer()) 180 | return sess.run([predicted_data, true_data]) 181 | 182 | def tf_vs_np_library_matching_test_helper(self, 183 | num_observed, 184 | num_predicted, 185 | query_source, 186 | num_queries=5): 187 | """Helper for asserting TF and NP give same library matching output.""" 188 | 189 | x_dim = 5 190 | 191 | np_transform = lambda d: np.sqrt(d + 4) 192 | tf_transform = lambda t: tf.sqrt(t + 4) 193 | 194 | x_observed = self.make_x_data(num_observed, x_dim) 195 | x_predicted = self.make_x_data(num_predicted, x_dim) 196 | 197 | y_observed = np_transform(x_observed) 198 | y_predicted = np_transform(x_predicted) 199 | 200 | if query_source == 'random': 201 | # Use queries from the same generating process as the observed and 202 | # predicted data. 203 | x_query = self.make_x_data(num_queries, x_dim) 204 | y_query = np_transform(x_query) 205 | elif query_source == 'observed': 206 | # Copy the observed data to use as queries. 207 | y_query = y_observed 208 | elif query_source == 'zero': 209 | # Use the zero vector as the queries. 210 | y_query = np.zeros(shape=(num_queries, x_dim), dtype=np.float32) 211 | else: 212 | raise ValueError('Invalid query_source: %s' % query_source) 213 | 214 | ids_observed = self.make_ids(num_observed) 215 | ids_predicted = self.make_ids(num_predicted) 216 | ids_query = self.make_ids(num_queries) 217 | masses_observed = np.ones([num_observed, 1], dtype=np.float32) 218 | masses_predicted = np.ones([num_predicted, 1], dtype=np.float32) 219 | masses_query = np.ones([num_queries, 1], dtype=np.float32) 220 | 221 | (predicted_data, true_data) = self.perform_matching( 222 | ids_observed, 223 | ids_predicted, 224 | ids_query, 225 | masses_observed, 226 | masses_predicted, 227 | masses_query, 228 | y_observed, 229 | y_query, 230 | x_predicted, 231 | tf_transform, 232 | mass_tolerance=3) 233 | np_predicted_ids = self.np_library_matching( 234 | ids_predicted, ids_observed, y_predicted, y_observed, y_query) 235 | 236 | # Assert correctness of the ids of the library matches found by TF. 237 | self.assertAllEqual(np_predicted_ids, 238 | predicted_data[fmap_constants.INCHIKEY]) 239 | 240 | # Assert correctness of the ground truth ids output extracted by TF. 
241 | self.assertAllEqual(true_data[fmap_constants.INCHIKEY], ids_query) 242 | 243 | # Assert that a query spectrum that is in the observed set is matched 244 | # to the corresponding element in the observed set. 245 | if query_source == 'observed': 246 | self.assertAllEqual(ids_observed, predicted_data[fmap_constants.INCHIKEY]) 247 | 248 | def testLibraryMatchingTFvsNP(self): 249 | return self.tf_vs_np_library_matching_test_helper( 250 | num_observed=10, num_predicted=5, query_source='random') 251 | 252 | def testLibraryMatchingTFvsNPZeroObserved(self): 253 | return self.tf_vs_np_library_matching_test_helper( 254 | num_observed=0, num_predicted=5, query_source='random') 255 | 256 | def testLibraryMatchingTFvsZeroPredicted(self): 257 | return self.tf_vs_np_library_matching_test_helper( 258 | num_observed=10, num_predicted=0, query_source='random') 259 | 260 | def testLibraryMatchingTFvsNPQueriesObserved(self): 261 | return self.tf_vs_np_library_matching_test_helper( 262 | num_observed=10, 263 | num_predicted=5, 264 | query_source='observed', 265 | num_queries=10) 266 | 267 | def testLibraryMatchingTFvsNPZeroQueries(self): 268 | return self.tf_vs_np_library_matching_test_helper( 269 | num_observed=10, num_predicted=5, query_source='zero') 270 | 271 | def testLibraryMatchingHardcoded(self): 272 | """Test library_matching using hardcoded values.""" 273 | 274 | tf_transform = lambda t: t + 2 275 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32) 276 | 277 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32) 278 | y_query = np.array([[2, 5], [2, 1], [-1.5, -1.1]], dtype=np.float32) 279 | 280 | ids_observed = self.make_ids(3, 'obs') 281 | ids_predicted = self.make_ids(2, 'pred') 282 | ids_query = np.array([b'obs-0', b'obs-1', b'pred-1']) 283 | masses_observed = np.ones([3, 1], dtype=np.float32) 284 | masses_predicted = np.ones([2, 1], dtype=np.float32) 285 | masses_query = np.ones([3, 1], dtype=np.float32) 286 | 287 | expected_predicted_ids = ids_query.tolist() 288 | 289 | predicted_data, _ = self.perform_matching( 290 | ids_observed, 291 | ids_predicted, 292 | ids_query, 293 | masses_observed, 294 | masses_predicted, 295 | masses_query, 296 | y_observed, 297 | y_query, 298 | x_predicted, 299 | tf_transform, 300 | mass_tolerance=3) 301 | 302 | self.assertAllEqual(expected_predicted_ids, 303 | predicted_data[fmap_constants.INCHIKEY]) 304 | 305 | def testLibraryMatchingHardcodedMassFiltered(self): 306 | """Test library_matching using hardcoded values when filtering by mass.""" 307 | 308 | tf_transform = lambda t: t + 2 309 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32) 310 | 311 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32) 312 | y_query = np.array([[2, 5], [2, 1], [-1.5, -1.1]], dtype=np.float32) 313 | 314 | ids_observed = self.make_ids(3, 'obs') 315 | ids_predicted = self.make_ids(2, 'pred') 316 | ids_query = np.array([b'pred-0', b'obs-1', b'obs-2']) 317 | masses_observed = np.ones([3, 1], dtype=np.float32) 318 | masses_predicted = 2 * np.ones([2, 1], dtype=np.float32) 319 | masses_query = np.array([3, 1.5, 0], dtype=np.float32)[..., np.newaxis] 320 | 321 | expected_predicted_ids = ids_query.tolist() 322 | 323 | predicted_data, _ = self.perform_matching( 324 | ids_observed, 325 | ids_predicted, 326 | ids_query, 327 | masses_observed, 328 | masses_predicted, 329 | masses_query, 330 | y_observed, 331 | y_query, 332 | x_predicted, 333 | tf_transform, 334 | mass_tolerance=1) 335 | 336 |
self.assertAllEqual(expected_predicted_ids, 337 | predicted_data[fmap_constants.INCHIKEY]) 338 | 339 | def testMassFilterRaisesError(self): 340 | """Test case where mass filtering removes everything.""" 341 | 342 | tf_transform = lambda t: t + 2 343 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32) 344 | 345 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32) 346 | y_query = np.array([[2, 5]], dtype=np.float32) 347 | 348 | ids_observed = self.make_ids(3, 'obs') 349 | ids_predicted = self.make_ids(2, 'pred') 350 | ids_query = np.array(['pred-0']) 351 | masses_observed = np.ones([3, 1], dtype=np.float32) 352 | masses_predicted = 2 * np.ones([2, 1], dtype=np.float32) 353 | masses_query = np.array([5], dtype=np.float32)[..., np.newaxis] 354 | 355 | with self.assertRaises(tf.errors.InvalidArgumentError): 356 | self.perform_matching( 357 | ids_observed, 358 | ids_predicted, 359 | ids_query, 360 | masses_observed, 361 | masses_predicted, 362 | masses_query, 363 | y_observed, 364 | y_query, 365 | x_predicted, 366 | tf_transform, 367 | mass_tolerance=1) 368 | 369 | def testLibraryMatchingNoPredictions(self): 370 | """Test library_matching using hardcoded values with no predicted data.""" 371 | 372 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32) 373 | y_query = np.array([[2, 5], [-3, 1], [0, 0]], dtype=np.float32) 374 | 375 | ids_observed = self.make_ids(3) 376 | ids_query = self.make_ids(3) 377 | 378 | expected_predictions = [b'0', b'2', b'0'] 379 | 380 | masses_query = np.ones([3, 1], dtype=np.float32) 381 | query_data = self._package_data( 382 | ids=ids_query, spectrum=y_query, masses=masses_query) 383 | masses_observed = np.ones([3, 1], dtype=np.float32) 384 | observed_data = self._package_data( 385 | ids=ids_observed, spectrum=y_observed, masses=masses_observed) 386 | predicted_data = None 387 | library_matching_data = library_matching.LibraryMatchingData( 388 | query=query_data, observed=observed_data, predicted=predicted_data) 389 | 390 | _, predicted_data, _, _ = library_matching.library_matching( 391 | library_matching_data, 392 | predictor_fn=None, 393 | similarity_provider=similarity_lib.CosineSimilarityProvider(), 394 | mass_tolerance=3.0) 395 | 396 | with tf.Session() as sess: 397 | sess.run(tf.local_variables_initializer()) 398 | predictions = sess.run(predicted_data[fmap_constants.INCHIKEY]) 399 | 400 | self.assertAllEqual(expected_predictions, predictions) 401 | 402 | def testLibraryMatchingNoObserved(self): 403 | """Test library_matching using hardcoded values with no observed data.""" 404 | 405 | tf_transform = lambda t: t + 2 406 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32) 407 | y_query = np.array([[2, 5], [-3, 1], [0, 0]], dtype=np.float32) 408 | 409 | ids_predicted = self.make_ids(2) 410 | ids_query = self.make_ids(3) 411 | 412 | expected_predictions = [b'0', b'1', b'0'] 413 | 414 | masses_query = np.ones([3, 1], dtype=np.float32) 415 | query_data = self._package_data( 416 | ids=ids_query, spectrum=y_query, masses=masses_query) 417 | masses_predicted = np.ones([2, 1], dtype=np.float32) 418 | predicted_data = self._package_data( 419 | ids=ids_predicted, spectrum=None, masses=masses_predicted) 420 | predicted_data[PREDICTOR_INPUT_KEY] = tf.constant(x_predicted) 421 | 422 | observed_data = None 423 | library_matching_data = library_matching.LibraryMatchingData( 424 | query=query_data, observed=observed_data, predicted=predicted_data) 425 | 426 | predictor_fn = lambda d: tf_transform(d[PREDICTOR_INPUT_KEY]) 427 | 
428 | _, predicted_data, _, _ = library_matching.library_matching( 429 | library_matching_data, 430 | predictor_fn=predictor_fn, 431 | similarity_provider=similarity_lib.CosineSimilarityProvider(), 432 | mass_tolerance=3.0) 433 | 434 | with tf.Session() as sess: 435 | sess.run(tf.local_variables_initializer()) 436 | predictions = sess.run(predicted_data[fmap_constants.INCHIKEY]) 437 | 438 | self.assertAllEqual(expected_predictions, predictions) 439 | 440 | 441 | if __name__ == '__main__': 442 | tf.test.main() 443 | -------------------------------------------------------------------------------- /make_train_test_split.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Creates datasets from the NIST sdf files and makes experiment setup jsons. 16 | 17 | This module first breaks up the main NIST library dataset into a train/ 18 | validation/test set, and the replicates library into a validation and test set. 19 | As all the molecules in the replicates file are also in the main NIST library, 20 | the mainlib datasets will exclude inchikeys from the replicates library. All the 21 | molecules in both datasets are to be included in one of these datasets, unless 22 | an argument is passed for mainlib_maximum_num_molecules_to_use or 23 | replicates_maximum_num_molecules_to_use. 24 | 25 | The component datasets are saved as TFRecords, by the names defined in 26 | dataset_setup_constants and the library from which the data came 27 | (e.g. mainlib_train_from_mainlib.tfrecord). This will result in 7 TFRecord files 28 | total, one each for the train/validation/test splits from the main library, and 29 | two each for the replicates validation/test splits, one with its data from the 30 | mainlib NIST file, and the other from the replicates file. 31 | 32 | For each experiment setup included in 33 | dataset_setup_constants.EXPERIMENT_SETUPS_LIST, a json file is written. This 34 | json file names the files to be used for each part of the experiment, i.e. 35 | library matching, spectra prediction. 36 | 37 | Note: Reading sdf files from cns is currently not supported. 38 | 39 | Example usage: 40 | make_train_test_split.py \ 41 | --main_sdf_name=testdata/test_14_mend.sdf \ 42 | --replicates_sdf_name=testdata/test_2_mend.sdf \ 43 | --output_master_dir= 44 | 45 | """ 46 |
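# As a rough sketch of the expected output layout (the exact basenames come
# from dataset_setup_constants, so the names below are assumptions
# extrapolated from the docstring's example), a run should leave files like
# these under --output_master_dir, each TFRecord accompanied by its
# .inchikey.txt and .info files:
#
#   mainlib_train_from_mainlib.tfrecord (+ .spectra_library.npy)
#   mainlib_validation_from_mainlib.tfrecord
#   mainlib_test_from_mainlib.tfrecord
#   replicates_validation_from_mainlib.tfrecord
#   replicates_test_from_mainlib.tfrecord
#   replicates_validation_from_replicates.tfrecord
#   replicates_test_from_replicates.tfrecord
#   one json per entry in dataset_setup_constants.EXPERIMENT_SETUPS_LIST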
47 | from __future__ import absolute_import 48 | from __future__ import division 49 | from __future__ import print_function 50 | import json 51 | import os 52 | import random 53 | 54 | from absl import app 55 | from absl import flags 56 | import dataset_setup_constants as ds_constants 57 | import mass_spec_constants as ms_constants 58 | import parse_sdf_utils 59 | import train_test_split_utils 60 | import six 61 | import tensorflow as tf 62 | 63 | FLAGS = flags.FLAGS 64 | flags.DEFINE_string( 65 | 'main_sdf_name', 'testdata/test_14_mend.sdf', 66 | 'specify full path of sdf file to parse, to be used for' 67 | ' training sets, and validation/test sets') 68 | flags.DEFINE_string( 69 | 'replicates_sdf_name', 70 | 'testdata/test_2_mend.sdf', 71 | 'specify full path of a second sdf file to parse, to be' 72 | ' used for the validation/test set. Molecules in this sdf' 73 | ' will be excluded from the main train/val/test sets.') 74 | # Note: For family based splitting, all molecules passing the filter will be 75 | # placed in validation/test datasets, and then split according to the relative 76 | # ratio between the validation/test fractions. If these are both equal to 0.0, 77 | # these values will be overwritten to 0.5 and 0.5. 78 | flags.DEFINE_list( 79 | 'main_train_val_test_fractions', '1.0,0.0,0.0', 80 | 'specify how large to make the train, val, and test sets' 81 | ' as a fraction of the whole dataset.') 82 | flags.DEFINE_integer('mainlib_maximum_num_molecules_to_use', None, 83 | 'specify how many total samples to use for parsing') 84 | flags.DEFINE_integer('replicates_maximum_num_molecules_to_use', None, 85 | 'specify how many total samples to use for parsing') 86 | flags.DEFINE_list( 87 | 'replicates_train_val_test_fractions', '0.0,0.5,0.5', 88 | 'specify fraction of replicates molecules to use' 89 | ' for the three replicates sample files.') 90 | flags.DEFINE_enum( 91 | 'splitting_type', 'random', ['random', 'steroid', 'diazo'], 92 | 'specify splitting method to use for creating ' 93 | 'training/validation/test sets') 94 | flags.DEFINE_string('output_master_dir', '/tmp/output_dataset_dir', 95 | 'specify directory to save records') 96 | flags.DEFINE_integer('max_atoms', ms_constants.MAX_ATOMS, 97 | 'specify maximum number of atoms to allow') 98 | flags.DEFINE_integer('max_mass_spec_peak_loc', ms_constants.MAX_PEAK_LOC, 99 | 'specify greatest m/z spectrum peak to allow') 100 | 101 | INCHIKEY_FILENAME_END = '.inchikey.txt' 102 | TFRECORD_FILENAME_END = '.tfrecord' 103 | NP_LIBRARY_ARRAY_END = '.spectra_library.npy' 104 | FROM_MAINLIB_FILENAME_MODIFIER = '_from_mainlib' 105 | FROM_REPLICATES_FILENAME_MODIFIER = '_from_replicates' 106 | 107 | 108 | def make_mainlib_replicates_train_test_split( 109 | mainlib_mol_list, 110 | replicates_mol_list, 111 | splitting_type, 112 | mainlib_fractions, 113 | replicates_fractions, 114 | mainlib_maximum_num_molecules_to_use=None, 115 | replicates_maximum_num_molecules_to_use=None, 116 | rseed=42): 117 | """Makes train/validation/test inchikey lists from two lists of rdkit.Mol. 118 | 119 | Args: 120 | mainlib_mol_list : list of molecules from main library 121 | replicates_mol_list : list of molecules from replicates library 122 | splitting_type : type of splitting to use for validation splits.
123 | mainlib_fractions : TrainValTestFractions namedtuple 124 | holding desired fractions for train/val/test split of mainlib 125 | replicates_fractions : TrainValTestFractions namedtuple 126 | holding desired fractions for train/val/test split of replicates. 127 | For the replicates set, the train fraction should be set to 0. 128 | mainlib_maximum_num_molecules_to_use : Largest number of molecules to use 129 | when making datasets from mainlib 130 | replicates_maximum_num_molecules_to_use : Largest number of molecules to use 131 | when making datasets from replicates 132 | rseed : random seed for shuffling 133 | 134 | Returns: 135 | main_inchikey_dict : Dict that is keyed by inchikey, containing a list of 136 | rdkit.Mol objects corresponding to that inchikey from the mainlib 137 | replicates_inchikey_dict : Dict that is keyed by inchikey, containing a list 138 | of rdkit.Mol objects corresponding to that inchikey from the replicates 139 | library 140 | main_replicates_split_inchikey_lists_dict : dict with keys : 141 | 'mainlib_train', 'mainlib_validation', 'mainlib_test', 142 | 'replicates_train', 'replicates_validation', 'replicates_test' 143 | Values are lists of inchikeys corresponding to each dataset. 144 | 145 | """ 146 | random.seed(rseed) 147 | main_inchikey_dict = train_test_split_utils.make_inchikey_dict( 148 | mainlib_mol_list) 149 | main_inchikey_list = main_inchikey_dict.keys() 150 | 151 | if six.PY3: 152 | main_inchikey_list = list(main_inchikey_list) 153 | 154 | if mainlib_maximum_num_molecules_to_use is not None: 155 | main_inchikey_list = random.sample(main_inchikey_list, 156 | mainlib_maximum_num_molecules_to_use) 157 | 158 | replicates_inchikey_dict = train_test_split_utils.make_inchikey_dict( 159 | replicates_mol_list) 160 | replicates_inchikey_list = replicates_inchikey_dict.keys() 161 | 162 | if six.PY3: 163 | replicates_inchikey_list = list(replicates_inchikey_list) 164 | 165 | if replicates_maximum_num_molecules_to_use is not None: 166 | replicates_inchikey_list = random.sample( 167 | replicates_inchikey_list, replicates_maximum_num_molecules_to_use) 168 | 169 | # Make train/val/test splits for main dataset. 170 | main_train_validation_test_inchikeys = ( 171 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 172 | main_inchikey_list, 173 | main_inchikey_dict, 174 | mainlib_fractions, 175 | holdout_inchikey_list=replicates_inchikey_list, 176 | splitting_type=splitting_type)) 177 | 178 | # Make train/val/test splits for replicates dataset. 
179 | replicates_validation_test_inchikeys = ( 180 | train_test_split_utils.make_train_val_test_split_inchikey_lists( 181 | replicates_inchikey_list, 182 | replicates_inchikey_dict, 183 | replicates_fractions, 184 | splitting_type=splitting_type)) 185 | 186 | component_inchikey_dict = { 187 | ds_constants.MAINLIB_TRAIN_BASENAME: 188 | main_train_validation_test_inchikeys.train, 189 | ds_constants.MAINLIB_VALIDATION_BASENAME: 190 | main_train_validation_test_inchikeys.validation, 191 | ds_constants.MAINLIB_TEST_BASENAME: 192 | main_train_validation_test_inchikeys.test, 193 | ds_constants.REPLICATES_TRAIN_BASENAME: 194 | replicates_validation_test_inchikeys.train, 195 | ds_constants.REPLICATES_VALIDATION_BASENAME: 196 | replicates_validation_test_inchikeys.validation, 197 | ds_constants.REPLICATES_TEST_BASENAME: 198 | replicates_validation_test_inchikeys.test 199 | } 200 | 201 | train_test_split_utils.assert_all_lists_mutally_exclusive( 202 | list(component_inchikey_dict.values())) 203 | # Test that the union of the component inchikey lists is equal to the set of 204 | # inchikeys in the main and replicates libraries. 205 | all_inchikeys_in_components = [] 206 | for ikey_list in list(component_inchikey_dict.values()): 207 | for ikey in ikey_list: 208 | all_inchikeys_in_components.append(ikey) 209 | 210 | assert set(main_inchikey_list + replicates_inchikey_list) == set( 211 | all_inchikeys_in_components 212 | ), ('The inchikeys in the original inchikey dictionary are not all included' 213 | ' in the train/val/test component libraries') 214 | 215 | return (main_inchikey_dict, replicates_inchikey_dict, component_inchikey_dict) 216 | 217 | 218 | def write_list_of_inchikeys(inchikey_list, base_name, output_dir): 219 | """Write list of inchikeys as a text file.""" 220 | inchikey_list_name = base_name + INCHIKEY_FILENAME_END 221 | 222 | with tf.gfile.Open(os.path.join(output_dir, inchikey_list_name), 223 | 'w') as writer: 224 | for inchikey in inchikey_list: 225 | writer.write('%s\n' % inchikey) 226 | 227 | 228 | def write_all_dataset_files(inchikey_dict, 229 | inchikey_list, 230 | base_name, 231 | output_dir, 232 | max_atoms, 233 | max_mass_spec_peak_loc, 234 | make_library_array=False): 235 | """Helper function for writing all the files associated with a TFRecord. 236 | 237 | Args: 238 | inchikey_dict : Full dictionary keyed by inchikey containing lists of 239 | rdkit.Mol objects 240 | inchikey_list : List of inchikeys to include in dataset 241 | base_name : Base name for the dataset 242 | output_dir : Path for saving all TFRecord files 243 | max_atoms : Maximum number of atoms to include for a given molecule 244 | max_mass_spec_peak_loc : Largest m/z peak to include in a spectrum. 245 | make_library_array : Flag for whether to make library array 246 | Returns: 247 | Saves 3 files: 248 | basename.tfrecord : a TFRecord file, 249 | basename.inchikey.txt : a text file with all the inchikeys in the dataset 250 | basename.tfrecord.info: a text file with one line describing
252 | Also saves if make_library_array is set: 253 | basename.spectra_library.npy : see parse_sdf_utils.write_dicts_to_example 254 | """ 255 | record_name = base_name + TFRECORD_FILENAME_END 256 | 257 | mol_list = train_test_split_utils.make_mol_list_from_inchikey_dict( 258 | inchikey_dict, inchikey_list) 259 | 260 | if make_library_array: 261 | library_array_pathname = base_name + NP_LIBRARY_ARRAY_END 262 | parse_sdf_utils.write_dicts_to_example( 263 | mol_list, os.path.join(output_dir, record_name), 264 | max_atoms, max_mass_spec_peak_loc, 265 | os.path.join(output_dir, library_array_pathname)) 266 | else: 267 | parse_sdf_utils.write_dicts_to_example( 268 | mol_list, os.path.join(output_dir, record_name), max_atoms, 269 | max_mass_spec_peak_loc) 270 | write_list_of_inchikeys(inchikey_list, base_name, output_dir) 271 | parse_sdf_utils.write_info_file(mol_list, os.path.join( 272 | output_dir, record_name)) 273 | 274 | 275 | def write_mainlib_split_datasets(component_inchikey_dict, mainlib_inchikey_dict, 276 | output_dir, max_atoms, max_mass_spec_peak_loc): 277 | """Write all train/val/test set TFRecords from main NIST sdf file.""" 278 | for component_kwarg in component_inchikey_dict.keys(): 279 | component_mainlib_filename = ( 280 | component_kwarg + FROM_MAINLIB_FILENAME_MODIFIER) 281 | if component_kwarg == ds_constants.MAINLIB_TRAIN_BASENAME: 282 | write_all_dataset_files( 283 | mainlib_inchikey_dict, 284 | component_inchikey_dict[component_kwarg], 285 | component_mainlib_filename, 286 | output_dir, 287 | max_atoms, 288 | max_mass_spec_peak_loc, 289 | make_library_array=True) 290 | else: 291 | write_all_dataset_files(mainlib_inchikey_dict, 292 | component_inchikey_dict[component_kwarg], 293 | component_mainlib_filename, output_dir, max_atoms, 294 | max_mass_spec_peak_loc) 295 | 296 | 297 | def write_replicates_split_datasets(component_inchikey_dict, 298 | replicates_inchikey_dict, output_dir, 299 | max_atoms, max_mass_spec_peak_loc): 300 | """Write replicates val/test set TFRecords from replicates sdf file.""" 301 | for component_kwarg in [ 302 | ds_constants.REPLICATES_VALIDATION_BASENAME, 303 | ds_constants.REPLICATES_TEST_BASENAME 304 | ]: 305 | component_replicates_filename = ( 306 | component_kwarg + FROM_REPLICATES_FILENAME_MODIFIER) 307 | write_all_dataset_files(replicates_inchikey_dict, 308 | component_inchikey_dict[component_kwarg], 309 | component_replicates_filename, output_dir, 310 | max_atoms, max_mass_spec_peak_loc) 311 | 312 | 313 | def combine_inchikey_sets(dataset_subdivision_list, dataset_split_dict): 314 | """A function to combine lists of inchikeys that are values from a dict. 315 | 316 | Args: 317 | dataset_subdivision_list: List of keys in dataset_split_dict to combine 318 | into one list 319 | dataset_split_dict: dict containing keys in dataset_subdivision_list, with 320 | lists of inchikeys as values. 321 | Returns: 322 | A list of inchikeys.
323 | """ 324 | dataset_inchikey_list = [] 325 | for dataset_subdivision_name in dataset_subdivision_list: 326 | dataset_inchikey_list.extend(dataset_split_dict[dataset_subdivision_name]) 327 | return dataset_inchikey_list 328 | 329 | 330 | def check_experiment_setup(experiment_setup_dict, component_inchikey_dict): 331 | """Validates experiment setup for given lists of inchikeys.""" 332 | 333 | # Check that the union of the library matching observed and library 334 | # matching predicted sets are equal to the set of inchikeys in the 335 | # mainlib_inchikey_dict 336 | all_inchikeys_in_library = ( 337 | combine_inchikey_sets( 338 | experiment_setup_dict[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY], 339 | component_inchikey_dict) + 340 | combine_inchikey_sets( 341 | experiment_setup_dict[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY], 342 | component_inchikey_dict)) 343 | 344 | all_inchikeys_in_use = [] 345 | for kwarg in component_inchikey_dict.keys(): 346 | all_inchikeys_in_use.extend(component_inchikey_dict[kwarg]) 347 | 348 | assert set(all_inchikeys_in_use) == set(all_inchikeys_in_library), ( 349 | 'Inchikeys in library for library matching does not match full dataset.') 350 | 351 | # Check that all inchikeys in query are found in full library of inchikeys. 352 | assert set( 353 | combine_inchikey_sets( 354 | experiment_setup_dict[ds_constants.LIBRARY_MATCHING_QUERY_KEY], 355 | component_inchikey_dict)).issubset(set(all_inchikeys_in_library)), ( 356 | 'Inchikeys in query set for library matching not' 357 | 'found in library.') 358 | 359 | 360 | def write_json_for_experiment(experiment_setup, output_dir): 361 | """Writes json for experiment, recording relevant files for each component. 362 | 363 | Writes a json containing a list of TFRecord file names to read 364 | for each experiment component, i.e. spectrum_prediction, library_matching. 365 | 366 | Args: 367 | experiment_setup: A dataset_setup_constants.ExperimentSetup tuple 368 | output_dir: directory to write json 369 | Returns: 370 | Writes json recording which files to load for each component 371 | of the experiment 372 | Raises: 373 | ValueError: if the experiment component is not specified to be taken from 374 | either the main NIST library or the replicates library. 
375 | 376 | """ 377 | experiment_json_dict = {} 378 | for dataset_kwarg in experiment_setup.experiment_setup_dataset_dict: 379 | if dataset_kwarg in experiment_setup.data_to_get_from_mainlib: 380 | experiment_json_dict[dataset_kwarg] = [ 381 | (component_basename + FROM_MAINLIB_FILENAME_MODIFIER + 382 | TFRECORD_FILENAME_END) for component_basename in 383 | experiment_setup.experiment_setup_dataset_dict[dataset_kwarg] 384 | ] 385 | elif dataset_kwarg in experiment_setup.data_to_get_from_replicates: 386 | experiment_json_dict[dataset_kwarg] = [ 387 | (component_basename + FROM_REPLICATES_FILENAME_MODIFIER + 388 | TFRECORD_FILENAME_END) for component_basename in 389 | experiment_setup.experiment_setup_dataset_dict[dataset_kwarg] 390 | ] 391 | else: 392 | raise ValueError('Did not specify origin for {}.'.format(dataset_kwarg)) 393 | 394 | training_spectra_filename = ( 395 | ds_constants.MAINLIB_TRAIN_BASENAME + FROM_MAINLIB_FILENAME_MODIFIER + 396 | NP_LIBRARY_ARRAY_END) 397 | experiment_json_dict[ 398 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = training_spectra_filename 399 | 400 | with tf.gfile.Open(os.path.join(output_dir, experiment_setup.json_name), 401 | 'w') as writer: 402 | experiment_json = json.dumps(experiment_json_dict) 403 | writer.write(experiment_json) 404 | 405 | 406 | def main(_): 407 | tf.gfile.MkDir(FLAGS.output_master_dir) 408 | 409 | main_train_val_test_fractions_tuple = tuple( 410 | [float(elem) for elem in FLAGS.main_train_val_test_fractions]) 411 | main_train_val_test_fractions = train_test_split_utils.TrainValTestFractions( 412 | *main_train_val_test_fractions_tuple) 413 | 414 | replicates_train_val_test_fractions_tuple = tuple( 415 | [float(elem) for elem in FLAGS.replicates_train_val_test_fractions]) 416 | replicates_train_val_test_fractions = ( 417 | train_test_split_utils.TrainValTestFractions( 418 | *replicates_train_val_test_fractions_tuple)) 419 | 420 | mainlib_mol_list = parse_sdf_utils.get_sdf_to_mol( 421 | FLAGS.main_sdf_name, max_atoms=FLAGS.max_atoms) 422 | replicates_mol_list = parse_sdf_utils.get_sdf_to_mol( 423 | FLAGS.replicates_sdf_name, max_atoms=FLAGS.max_atoms) 424 | 425 | # Breaks the inchikeys lists into train/validation/test splits. 426 | (mainlib_inchikey_dict, replicates_inchikey_dict, component_inchikey_dict) = ( 427 | make_mainlib_replicates_train_test_split( 428 | mainlib_mol_list, 429 | replicates_mol_list, 430 | FLAGS.splitting_type, 431 | main_train_val_test_fractions, 432 | replicates_train_val_test_fractions, 433 | mainlib_maximum_num_molecules_to_use=FLAGS. 434 | mainlib_maximum_num_molecules_to_use, 435 | replicates_maximum_num_molecules_to_use=FLAGS. 436 | replicates_maximum_num_molecules_to_use)) 437 | 438 | # Writes TFRecords for each component using info from the main library file 439 | write_mainlib_split_datasets(component_inchikey_dict, mainlib_inchikey_dict, 440 | FLAGS.output_master_dir, FLAGS.max_atoms, 441 | FLAGS.max_mass_spec_peak_loc) 442 | 443 | # Writes TFRecords for each component using info from the replicates file 444 | write_replicates_split_datasets( 445 | component_inchikey_dict, replicates_inchikey_dict, 446 | FLAGS.output_master_dir, FLAGS.max_atoms, FLAGS.max_mass_spec_peak_loc) 447 | 448 | for experiment_setup in ds_constants.EXPERIMENT_SETUPS_LIST: 449 | # Check that experiment setup is valid. 450 | check_experiment_setup(experiment_setup.experiment_setup_dataset_dict, 451 | component_inchikey_dict) 452 | 453 | # Write a json for the experiment setups, pointing to local files. 
454 | write_json_for_experiment(experiment_setup, FLAGS.output_master_dir) 455 | 456 | 457 | if __name__ == '__main__': 458 | app.run(main) 459 | --------------------------------------------------------------------------------