├── testdata
│   ├── test_14_record.gz.info
│   ├── test_14_record.inchikey.txt
│   ├── test_14_record.gz
│   ├── test_14.spectra_library.npy
│   ├── test_dataset_config_file.json
│   ├── README
│   └── test_2_mend.sdf
├── .gitignore
├── neims_toc.jpeg
├── __init__.py
├── test_utils.py
├── CONTRIBUTING.md
├── feature_map_constants.py
├── Model_Retrain_Quickstart.md
├── make_spectra_prediction.py
├── gather_similarites.py
├── README.md
├── mass_spec_constants.py
├── examples
│   └── pentachlorobenzene.sdf
├── feature_utils_test.py
├── spectra_predictor_test.py
├── similarity.py
├── util_test.py
├── make_predictions.py
├── make_predictions_from_tfrecord.py
├── train_test_split_utils.py
├── molecule_estimator_test.py
├── util.py
├── spectra_predictor.py
├── dataset_setup_constants.py
├── make_train_test_split_test.py
├── LICENSE
├── feature_utils.py
├── molecule_estimator.py
├── plot_spectra_utils.py
├── library_matching_test.py
└── make_train_test_split.py
/testdata/test_14_record.gz.info:
--------------------------------------------------------------------------------
1 | 12
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | .ipynb_checkpoints/*
3 | __pycache__/*
4 |
--------------------------------------------------------------------------------
/testdata/test_14_record.inchikey.txt:
--------------------------------------------------------------------------------
1 | YXHKONLOYHBTNS-UHFFFAOYSA-N
2 |
--------------------------------------------------------------------------------
/neims_toc.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/deep-molecular-massspec/HEAD/neims_toc.jpeg
--------------------------------------------------------------------------------
/testdata/test_14_record.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/deep-molecular-massspec/HEAD/testdata/test_14_record.gz
--------------------------------------------------------------------------------
/testdata/test_14.spectra_library.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/deep-molecular-massspec/HEAD/testdata/test_14.spectra_library.npy
--------------------------------------------------------------------------------
/testdata/test_dataset_config_file.json:
--------------------------------------------------------------------------------
1 | {"LIBRARY_MATCHING_PREDICTED": ["test_14_record.gz"], "SPECTRUM_PREDICTION_TEST": ["test_14_record.gz"], "mainlib_train_spectra_library_file": "test_14.spectra_library.npy", "SPECTRUM_PREDICTION_TRAIN": ["test_14_record.gz"], "LIBRARY_MATCHING_QUERY": ["test_14_record.gz"], "LIBRARY_MATCHING_OBSERVED": ["test_14_record.gz"]}
2 |
--------------------------------------------------------------------------------
/testdata/README:
--------------------------------------------------------------------------------
1 | test_2_mend.sdf and test_14_mend.sdf contain molblocks from mainlib.sdf, an SDF
2 | file from NIST. The actual mass spectral data have been removed from these SDF
3 | files and replaced with fake data.
4 |
5 | test_14_record.gz is a TFRecord with the features 'molecule_weight',
6 | 'atom_weights', 'circular_fp_1024', and 'dense_mass_spec'.
7 |
8 | NB: Since 2 of the molecules do not pass the filters in get_sdf_to_mol, only
9 | 12 molecules are written to the TFRecord file.
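
A minimal sketch for inspecting the record's features (an illustration, not
part of the original test setup; assumes TensorFlow 1.x and that the file is
GZIP-compressed, as the .gz suffix suggests):

  import tensorflow as tf

  options = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)
  for raw_record in tf.python_io.tf_record_iterator('test_14_record.gz', options):
    example = tf.train.Example.FromString(raw_record)
    print(sorted(example.features.feature.keys()))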
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/test_utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for tests."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | import os
7 | from absl import flags
8 |
9 |
10 | def test_dir(relative_path):
11 | """Gets the path to a testdata file in genomics at relative path.
12 |
13 | Args:
14 | relative_path: a directory path relative to base directory of this module.
15 | Returns:
16 | The absolute path to a testdata file.
17 | """
18 |
19 | return os.path.join(flags.FLAGS.test_srcdir,
20 | os.path.split(os.path.abspath(__file__))[0],
21 | relative_path)
22 |
23 |
24 | def encode(value, is_py3_flag):
25 | """A helper function for wrapping strings as bytes for py3."""
26 | if is_py3_flag:
27 | return value.encode('utf-8')
28 | else:
29 | return value
30 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/feature_map_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module containing the feature names used in the TFRecords of this repo."""
16 |
17 | # Overall Tasks
18 | LIBRARY_MATCHING = 'library_matching'
19 | SPECTRUM_PREDICTION = 'spectrum_prediction'
20 |
21 | # Feature names
22 | ATOM_WEIGHTS = 'atom_weights'
23 | MOLECULE_WEIGHT = 'molecule_weight'
24 | CIRCULAR_FP_BASENAME = 'circular_fp'
25 | COUNTING_CIRCULAR_FP_BASENAME = 'counting_circular_fp'
26 | INDEX_TO_GROUND_TRUTH_ARRAY = 'index_in_library'
27 |
28 | FP_TYPE_LIST = [
29 | CIRCULAR_FP_BASENAME, COUNTING_CIRCULAR_FP_BASENAME
30 | ]
31 |
32 | DENSE_MASS_SPEC = 'dense_mass_spec'
33 | INCHIKEY = 'inchikey'
34 | NAME = 'name'
35 | SMILES = 'smiles'
36 | MOLECULAR_FORMULA = 'molecular_formula'
37 | ADJACENCY_MATRIX = 'adjacency_matrix'
38 | ATOM_IDS = 'atom_id'
39 | SMILES_TOKEN_LIST_LENGTH = 'smiles_sequence_len'
40 |
--------------------------------------------------------------------------------
/Model_Retrain_Quickstart.md:
--------------------------------------------------------------------------------
1 | ## Quickstart
2 |
3 | The following is a tutorial for training a new model.
4 |
5 | Set up a directory for predictions:
6 |
7 | ```
8 | $ TARGET_PATH_NAME=/tmp/massspec_predictions
9 | ```
10 |
11 | To convert an SDF file into a TFRecord and make the training and test splits:
12 |
13 | ```
14 | $ python make_train_test_split.py --main_sdf_name=testdata/test_14_mend.sdf \
15 |     --replicates_sdf_name=testdata/test_2_mend.sdf \
16 |     --output_master_dir=$TARGET_PATH_NAME/spectra_tf_records
17 | ```
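
This writes the TFRecords plus one or more dataset-configuration JSON files
into $TARGET_PATH_NAME/spectra_tf_records (the training step below points at
one of them). Each config maps dataset roles to record files;
testdata/test_dataset_config_file.json shows the shape, reformatted here for
readability:

```
{"SPECTRUM_PREDICTION_TRAIN": ["test_14_record.gz"],
 "SPECTRUM_PREDICTION_TEST": ["test_14_record.gz"],
 "LIBRARY_MATCHING_OBSERVED": ["test_14_record.gz"],
 "LIBRARY_MATCHING_PREDICTED": ["test_14_record.gz"],
 "LIBRARY_MATCHING_QUERY": ["test_14_record.gz"],
 "mainlib_train_spectra_library_file": "test_14.spectra_library.npy"}
```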
18 |
19 | To train a model:
20 |
21 | ```
22 | python molecule_estimator.py \
23 |     --dataset_config_file=$TARGET_PATH_NAME/spectra_tf_records/query_replicates_val_predicted_replicates_val.json \
24 |     --train_steps=1000 \
25 |     --model_dir=$TARGET_PATH_NAME/models/output --hparams=make_spectra_plots=True \
26 |     --alsologtostderr
27 | ```
28 |
29 | The aggregate training results will be logged to stdout. The final library
30 | matching results can also be found in
31 | $TARGET_PATH_NAME/models/output/predictions. These results report the query
32 | inchikey, the matched inchikey, and the rank assigned to the true spectrum.
33 |
34 | It is also possible to view these results in TensorBoard:
35 | `tensorboard --logdir=path/to/log-directory`
36 |
37 | To predict spectra for new data given a model, run:
38 |
39 | ```
40 | python make_predictions_from_tfrecord.py \
41 | --input_file=testdata/test_14_record.gz \
42 | --output_file=$TARGET_PATH_NAME/models/output_predictions \
43 | --model_checkpoint_path=$TARGET_PATH_NAME/models/output/ \
44 | --hparams=eval_batch_size=16 \
45 | --alsologtostderr
46 | ```
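
The output is a NumPy archive holding a dict that maps each InChIKey to its
predicted spectrum. A minimal sketch for inspecting it (the exact .npy path is
an assumption based on the --output_file flag above):

```
import numpy as np

# np.load returns a 0-D object array; .item() unwraps the dict.
# allow_pickle=True is required on newer NumPy versions.
predictions = np.load(
    '/tmp/massspec_predictions/models/output_predictions.npy',
    allow_pickle=True).item()
for inchikey, spectrum in predictions.items():
  print(inchikey, spectrum.shape)
```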
47 |
--------------------------------------------------------------------------------
/make_spectra_prediction.py:
--------------------------------------------------------------------------------
1 | r"""Makes spectra prediction using model and writes predictions to SDF.
2 |
3 | Make predictions using our trained model. Example of how to run:
4 |
5 | # Save weights to a models directory
6 | $ MODEL_WEIGHTS_DIR=/tmp/neims_model
7 | $ cd $MODEL_WEIGHTS_DIR
8 | $ wget https://storage.googleapis.com/deep-molecular-massspec/massspec_weights/massspec_weights.zip # pylint: disable=line-too-long
9 | $ unzip massspec_weights.zip
10 |
11 | $ python make_spectra_prediction.py \
12 | --input_file=examples/pentachlorobenzene.sdf \
13 | --output_file=/tmp/neims_model/annotated.sdf \
14 | --weights_dir=$MODEL_WEIGHTS_DIR/massspec_weights
15 | """
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl import app
22 | from absl import flags
23 | from absl import logging
24 |
25 | import spectra_predictor
26 |
27 | FLAGS = flags.FLAGS
28 | flags.DEFINE_string('input_file', 'input.sdf',
29 | 'Name of input file for predictions.')
30 | flags.DEFINE_string('weights_dir',
31 | '/usr/local/massspec_weights',
32 | 'Name of directory that stores model weights.')
33 | flags.DEFINE_string('output_file', 'annotated.sdf',
34 | 'Name of output file for predictions.')
35 |
36 |
37 | def main(_):
38 | logging.info('Loading weights from %s', FLAGS.weights_dir)
39 | predictor = spectra_predictor.NeimsSpectraPredictor(
40 | model_checkpoint_dir=FLAGS.weights_dir)
41 |
42 | logging.info('Loading molecules from %s', FLAGS.input_file)
43 | mols_from_file = spectra_predictor.get_mol_list_from_sdf(
44 | FLAGS.input_file)
45 | fingerprints, mol_weights = predictor.get_inputs_for_model_from_mol_list(
46 | mols_from_file)
47 |
48 | logging.info('Making predictions ...')
49 | spectra_predictions = predictor.make_spectra_prediction(
50 | fingerprints, mol_weights)
51 |
52 | logging.info('Updating molecules in place with predictions.')
53 | spectra_predictor.update_mols_with_spectra(mols_from_file,
54 | spectra_predictions)
55 |
56 | logging.info('Writing predictions to %s', FLAGS.output_file)
57 | with open(FLAGS.output_file, 'w') as f:
58 | spectra_predictor.write_rdkit_mols_to_sdf(mols_from_file, f)
59 |
60 |
61 | if __name__ == '__main__':
62 | app.run(main)
63 |
--------------------------------------------------------------------------------
/gather_similarites.py:
--------------------------------------------------------------------------------
1 | """Computes pairwise similarities between replicate spectra of a molecule."""
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | import parse_sdf_utils
6 | import train_test_split_utils
7 | import feature_utils
8 | import mass_spec_constants as ms_constants
9 | import similarity as similarity_lib
10 |
11 |
12 | def make_spectra_array(mol_list):
13 | """Grab spectra pertaining to same molecule in one np.array.
14 | Args:
15 | mol_list: list of rdkit.Mol objects. Each Mol should contain
16 | information about the spectra, as stored in NIST.
17 |   Returns:
18 | np.array of spectra of shape (number of spectra, max spectra length)
19 | """
20 |   mass_spec_spectra = np.zeros((len(mol_list), ms_constants.MAX_PEAK_LOC))
21 | for idx, mol in enumerate(mol_list):
22 | spectra_str = mol.GetProp(ms_constants.SDF_TAG_MASS_SPEC_PEAKS)
23 | spectral_locs, spectral_intensities = feature_utils.parse_peaks(spectra_str)
24 | dense_mass_spec = feature_utils.make_dense_mass_spectra(
25 | spectral_locs, spectral_intensities, ms_constants.MAX_PEAK_LOC)
26 |
27 | mass_spec_spectra[idx, :] = dense_mass_spec
28 |
29 | return mass_spec_spectra
30 |
31 |
32 | def get_similarities(raw_spectra_array):
33 | """Preprocess spectra and then calculate similarity between spectra.
34 | Args:
35 | raw_spectra_array: np.array containing unprocessed spectra
36 |   Returns:
37 | np.array of shape (len(raw_spectra_array), len(raw_spectra_array))
38 | reflects distances between spectra.
39 | """
40 | spec_array_var = tf.constant(raw_spectra_array)
41 |
42 | # Adjusting intensity to match default in molecule_predictors
43 | intensity_adjusted_spectra = tf.pow(spec_array_var, 0.5)
44 |
45 | hparams = tf.contrib.training.HParams(
46 | mass_power=1.,
47 | )
48 |
49 | cos_similarity = similarity_lib.GeneralizedCosineSimilarityProvider(hparams)
50 | norm_spectra = cos_similarity._normalize_rows(intensity_adjusted_spectra)
51 | similarity = cos_similarity.compute_similarity(norm_spectra, norm_spectra)
52 |
53 | with tf.Session() as sess:
54 | sess.run(tf.global_variables_initializer())
55 | dist = sess.run(similarity)
56 |
57 | return dist
58 |
59 |
60 | def main():
61 | mol_list = parse_sdf_utils.get_sdf_to_mol('/mnt/storage/NIST_zipped/NIST17/replib_mend.sdf')
62 | inchikey_dict = train_test_split_utils.make_inchikey_dict(mol_list)
63 |
64 | spectra_for_one_mol = make_spectra_array(inchikey_dict['PDACHFOTOFNHBT-UHFFFAOYSA-N'])
65 | distance_matrix = get_similarities(spectra_for_one_mol)
66 | print('distance for spectra in PDACHFOTOFNHBT-UHFFFAOYSA-N', distance_matrix)
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep learning for Electron Ionization mass spectrometry for organic molecules
2 | 
3 |
4 | This repository accompanies
5 |
6 | [Rapid Prediction of Electron–Ionization Mass Spectrometry Using Neural Networks](https://pubs.acs.org/doi/10.1021/acscentsci.9b00085)\
7 | [Jennifer N. Wei](https://ai.google/research/people/JenniferNWei),
8 | [David Belanger](https://davidbelanger.github.io/), [Ryan P. Adams](https://www.cs.princeton.edu/~rpa/),
9 | and [D. Sculley](https://www.eecs.tufts.edu/~dsculley/)\
10 | ACS Central Science 2019 5 (4), 700-708\
11 | DOI: 10.1021/acscentsci.9b00085
12 |
13 |
14 | ## Introduction
15 |
16 | We predict the mass spectra of molecules using deep learning techniques
17 | applied to various molecule representations. Performance is evaluated with a
18 | custom-made library matching task, in which we identify a molecule by matching
19 | its spectrum to a library of labeled spectra. As a baseline, this library
20 | contains all of the molecules in the NIST main library, which mimics the
21 | workflow currently used by experimental chemists. To test our predictions, we
22 | replace portions of the library with spectra predicted by our model. This task
23 | is described in more detail below.
24 |
25 | ## Required Packages:
26 |
27 | It is recommended to use [Anaconda](https://www.anaconda.com/distribution/) with a Python 3.6 environment to install these packages.
28 | - [RDKit](https://www.rdkit.org/docs/Install.html)
29 | - [Tensorflow](https://www.tensorflow.org/install)
30 |
31 | Most of the packages required here can be installed with `conda install tensorflow=1.13.2 rdkit matplotlib` and `pip install absl-py`.
32 |
33 | ## Quickstart Guide for Making Model Predictions
34 |
35 | 1. Create a directory and download the weights for the model.
36 |
37 | ```
38 | $ MODEL_WEIGHTS_DIR=/home/path/to/model
39 | $ mkdir $MODEL_WEIGHTS_DIR
40 | $ pushd $MODEL_WEIGHTS_DIR
41 | $ curl -O https://storage.googleapis.com/deep-molecular-massspec/massspec_weights/massspec_weights.zip
42 | $ unzip massspec_weights.zip
43 | $ popd
44 | ```
45 |
46 | 2. Run the model prediction on the example molecule
47 |
48 | ```
49 | $ python make_spectra_prediction.py \
50 | --input_file=examples/pentachlorobenzene.sdf \
51 | --output_file=/tmp/annotated.sdf \
52 | --weights_dir=$MODEL_WEIGHTS_DIR/massspec_weights
53 | ```
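
The annotated SDF can then be inspected with RDKit. A minimal sketch (an
illustration, not part of the repository; the exact tag holding the predicted
peaks is whatever spectra_predictor.PREDICTED_SPECTRA_PROP_NAME is set to):

```
from rdkit import Chem

for mol in Chem.SDMolSupplier('/tmp/annotated.sdf'):
  if mol is None:
    continue
  print(Chem.MolToSmiles(mol))
  print(list(mol.GetPropNames()))  # tags include the predicted spectrum
```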
54 |
55 | ## Training splits for benchmarking purposes
56 | The molecules used for the training, validation, and test sets can be found under the
57 | directory *training_splits*. The molecules are provided in
58 | InChIKey and SMILES format.
59 |
60 |
61 | ## To cite this work:
62 |
63 | @article{doi:10.1021/acscentsci.9b00085,\
64 | author = {Wei, Jennifer N. and Belanger, David and Adams, Ryan P. and Sculley, D.},\
65 | title = {Rapid Prediction of Electron–Ionization Mass Spectrometry Using Neural Networks},\
66 | journal = {ACS Central Science},\
67 | volume = {5},\
68 | number = {4},\
69 | pages = {700-708},\
70 | year = {2019},\
71 | doi = {10.1021/acscentsci.9b00085},\
72 | URL = {https://doi.org/10.1021/acscentsci.9b00085},\
73 | }
74 |
--------------------------------------------------------------------------------
/mass_spec_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module containing all commonly used variables in this repo."""
16 |
17 | from collections import namedtuple
18 |
19 |
20 | class CircularFingerprintKey(
21 | namedtuple('CircularFingerprintKey', ['fp_type', 'fp_len', 'radius'])):
22 | """Helper function for labeling fingerprint keys."""
23 |
24 | def __str__(self):
25 | return self.fp_type + '_' + str(self.fp_len) + '_' + str(self.radius)
26 |
27 |
28 | # Constants for SDF tags found in NIST sdf files:
29 | SDF_TAG_MASS_SPEC_PEAKS = 'MASS SPECTRAL PEAKS'
30 | SDF_TAG_INCHIKEY = 'INCHIKEY'
31 | SDF_TAG_NAME = 'NAME'
32 | SDF_TAG_MOLECULE_MASS = 'EXACT MASS'
34 |
35 | # Constants for fields in TFRecords
36 | MAX_MZ_WEIGHT_RATIO = 3.0
37 | MAX_PEAK_LOC = 1000
38 | MAX_ATOMS = 100
39 | MAX_ATOM_ID = 100
40 | MAX_TOKEN_LIST_LENGTH = 230
41 |
42 | CIRCULAR_FP_RADII_LIST = [2, 4, 6]
43 | NUM_CIRCULAR_FP_BITS_LIST = [1024, 2048, 4096]
44 | ADD_HS_TO_MOLECULES = False
45 |
46 | TWO_LETTER_TOKEN_NAMES = [
47 | 'Al', 'Ce', 'Co', 'Ge', 'Gd', 'Cs', 'Th', 'Cd', 'As', 'Na', 'Nb', 'Li',
48 | 'Ni', 'Se', 'Sc', 'Sb', 'Sn', 'Hf', 'Hg', 'Si', 'Be', 'Cl', 'Rb', 'Fe',
49 | 'Bi', 'Br', 'Ag', 'Ru', 'Zn', 'Te', 'Mo', 'Pt', 'Mn', 'Os', 'Tl', 'In',
50 | 'Cu', 'Mg', 'Ti', 'Pb', 'Re', 'Pd', 'Ir', 'Rh', 'Zr', 'Cr', '@@', 'se',
51 | 'si', 'te'
52 | ]
53 |
54 | METAL_ATOM_SYMBOLS = [
55 | 'As', 'Cr', 'Cs', 'Cu', 'Be', 'Ag', 'Co', 'Al', 'Cd', 'Ce', 'Si', 'Sn',
56 | 'Os', 'Sb', 'Sc', 'In', 'Se', 'Ni', 'Th', 'Hg', 'Hf', 'Li', 'Nb', 'U', 'Y',
57 | 'V', 'W', 'Tl', 'Na', 'Fe', 'K', 'Zr', 'B', 'Pb', 'Pd', 'Rh', 'Re', 'Gd',
58 | 'Ge', 'Ir', 'Rb', 'Ti', 'Pt', 'Mn', 'Mg', 'Ru', 'Bi', 'Zn', 'Te', 'Mo'
59 | ]
60 |
61 | SMILES_TOKEN_NAMES = [
62 | '#', '%', '(', ')', '+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6',
63 | '7', '8', '9', '=', '@', '@@', 'Ag', 'Al', 'As', 'B', 'Be', 'Bi', 'Br', 'C',
64 | 'Cd', 'Ce', 'Cl', 'Co', 'Cr', 'Cs', 'Cu', 'F', 'Fe', 'Gd', 'Ge', 'H', 'Hf',
65 | 'Hg', 'I', 'In', 'Ir', 'K', 'Li', 'Mg', 'Mn', 'Mo', 'N', 'Na', 'Nb', 'Ni',
66 | 'O', 'Os', 'P', 'Pb', 'Pd', 'Pt', 'Rb', 'Re', 'Rh', 'Ru', 'S', 'Sb', 'Sc',
67 | 'Se', 'Si', 'Sn', 'Te', 'Th', 'Ti', 'Tl', 'U', 'V', 'W', 'Y', 'Zn', 'Zr',
68 | '[', '\\', ']', 'c', 'n', 'o', 'p', 's'
69 | ]
70 |
71 | SMILES_TOKEN_NAME_TO_INDEX = {
72 | name: idx for idx, name in enumerate(SMILES_TOKEN_NAMES)
73 | }
74 |
75 | # Add 3 elements which also have lowercase representations in SMILES string.
76 | # We want these to have the same index as the upper-lower case version.
77 | SMILES_TOKEN_NAME_TO_INDEX['se'] = SMILES_TOKEN_NAME_TO_INDEX['Se']
78 | SMILES_TOKEN_NAME_TO_INDEX['si'] = SMILES_TOKEN_NAME_TO_INDEX['Si']
79 | SMILES_TOKEN_NAME_TO_INDEX['te'] = SMILES_TOKEN_NAME_TO_INDEX['Te']
80 |
81 | # Bond order master list:
82 | BOND_ORDER_TO_INTS_DICT = {1.0: 1, 2.0: 2, 3.0: 3, 1.5: 4}
83 |
84 | TRUE_SPECTRA_SCALING_FACTOR = 0.1
85 |
--------------------------------------------------------------------------------
/examples/pentachlorobenzene.sdf:
--------------------------------------------------------------------------------
1 | 11855
2 | -OEChem-08201916472D
3 |
4 | 12 12 0 0 0 0 0 0 0999 V2000
5 | 3.7320 1.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0
6 | 5.4641 0.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0
7 | 2.0000 0.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0
8 | 2.0000 -1.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0
9 | 5.4641 -1.5000 0.0000 Cl 0 0 0 0 0 0 0 0 0 0 0 0
10 | 3.7320 0.5000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
11 | 4.5981 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
12 | 2.8660 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
13 | 2.8660 -1.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
14 | 4.5981 -1.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
15 | 3.7320 -1.5000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
16 | 3.7320 -2.1200 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
17 | 1 6 1 0 0 0 0
18 | 2 7 1 0 0 0 0
19 | 3 8 1 0 0 0 0
20 | 4 9 1 0 0 0 0
21 | 5 10 1 0 0 0 0
22 | 6 7 2 0 0 0 0
23 | 6 8 1 0 0 0 0
24 | 7 10 1 0 0 0 0
25 | 8 9 2 0 0 0 0
26 | 9 11 1 0 0 0 0
27 | 10 11 2 0 0 0 0
28 | 11 12 1 0 0 0 0
29 | M END
30 | > <PUBCHEM_COMPOUND_CID>
31 | 11855
32 |
33 | > <PUBCHEM_COMPOUND_CANONICALIZED>
34 | 1
35 |
36 | > <PUBCHEM_CACTVS_COMPLEXITY>
37 | 125
38 |
39 | > <PUBCHEM_CACTVS_HBOND_ACCEPTOR>
40 | 0
41 |
42 | > <PUBCHEM_CACTVS_HBOND_DONOR>
43 | 0
44 |
45 | > <PUBCHEM_CACTVS_ROTATABLE_BOND>
46 | 0
47 |
48 | > <PUBCHEM_CACTVS_SUBSKEYS>
49 | AAADcQBgAAAHAAAAAAAAAAAAAAAAAAAAAAAwAAAAAAAAAAABAAAAGAIAAAAACAKAECAwAIAAAACAACBCAAACAAAgBQAAikAAAogIICKBEhCAIAAggAAIiAcAAAAAAAAQAACAAAQAACAAAQAACAAAAAAAAA==
50 |
51 | > <PUBCHEM_IUPAC_OPENEYE_NAME>
52 | 1,2,3,4,5-pentachlorobenzene
53 |
54 | > <PUBCHEM_IUPAC_CAS_NAME>
55 | 1,2,3,4,5-pentachlorobenzene
56 |
57 | > <PUBCHEM_IUPAC_NAME_MARKUP>
58 | 1,2,3,4,5-pentachlorobenzene
59 |
60 | > <PUBCHEM_IUPAC_NAME>
61 | 1,2,3,4,5-pentachlorobenzene
62 |
63 | > <PUBCHEM_IUPAC_SYSTEMATIC_NAME>
64 | 1,2,3,4,5-pentakis(chloranyl)benzene
65 |
66 | > <PUBCHEM_IUPAC_TRADITIONAL_NAME>
67 | 1,2,3,4,5-pentachlorobenzene
68 |
69 | > <PUBCHEM_IUPAC_INCHI>
70 | InChI=1S/C6HCl5/c7-2-1-3(8)5(10)6(11)4(2)9/h1H
71 |
72 | > <PUBCHEM_IUPAC_INCHIKEY>
73 | CEOCDNVZRAIOQZ-UHFFFAOYSA-N
74 |
75 | > <PUBCHEM_XLOGP3>
76 | 5.2
77 |
78 | > <PUBCHEM_EXACT_MASS>
79 | 249.849138
80 |
81 | > <PUBCHEM_MOLECULAR_FORMULA>
82 | C6HCl5
83 |
84 | > <PUBCHEM_MOLECULAR_WEIGHT>
85 | 250.3
86 |
87 | > <PUBCHEM_OPENEYE_CAN_SMILES>
88 | C1=C(C(=C(C(=C1Cl)Cl)Cl)Cl)Cl
89 |
90 | > <PUBCHEM_OPENEYE_ISO_SMILES>
91 | C1=C(C(=C(C(=C1Cl)Cl)Cl)Cl)Cl
92 |
93 | > <PUBCHEM_CACTVS_TPSA>
94 | 0
95 |
96 | > <PUBCHEM_MONOISOTOPIC_WEIGHT>
97 | 247.852089
98 |
99 | > <PUBCHEM_TOTAL_CHARGE>
100 | 0
101 |
102 | > <PUBCHEM_HEAVY_ATOM_COUNT>
103 | 11
104 |
105 | > <PUBCHEM_ATOM_DEF_STEREO_COUNT>
106 | 0
107 |
108 | > <PUBCHEM_ATOM_UDEF_STEREO_COUNT>
109 | 0
110 |
111 | > <PUBCHEM_BOND_DEF_STEREO_COUNT>
112 | 0
113 |
114 | > <PUBCHEM_BOND_UDEF_STEREO_COUNT>
115 | 0
116 |
117 | > <PUBCHEM_ISOTOPIC_ATOM_COUNT>
118 | 0
119 |
120 | > <PUBCHEM_COMPONENT_COUNT>
121 | 1
122 |
123 | > <PUBCHEM_CACTVS_TAUTO_COUNT>
124 | -1
125 |
126 | > <PUBCHEM_COORDINATE_TYPE>
127 | 1
128 | 5
129 | 255
130 |
131 | > <PUBCHEM_BONDANNOTATIONS>
132 | 10 11 8
133 | 6 7 8
134 | 6 8 8
135 | 7 10 8
136 | 8 9 8
137 | 9 11 8
138 |
139 | $$$$
140 |
--------------------------------------------------------------------------------
/feature_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for feature_utils."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | import feature_utils
23 | import mass_spec_constants as ms_constants
24 | import numpy as np
25 | from rdkit import Chem
26 | import tensorflow as tf
27 |
28 |
29 | class FeatureUtilsTest(tf.test.TestCase):
30 |
31 | def setUp(self):
32 | test_mol_smis = [
33 | 'CCCC', 'CC(=CC(=O)C)O.CC(=CC(=O)C)O.[Cu]', 'CCCCl',
34 | ('C[C@H](CCCC(C)C)[C@H]1CC[C@@H]2[C@@]1'
35 | '(CC[C@H]3[C@H]2CC=C4[C@@]3(CC[C@@H](C4)O)C)C')
36 | ]
37 | self.test_mols = [
38 | Chem.MolFromSmiles(mol_str) for mol_str in test_mol_smis
39 | ]
40 |
41 | def _validate_smiles_string_tokenization(self, smiles_string,
42 | expected_token_list):
43 | token_list = feature_utils.tokenize_smiles(np.array([smiles_string]))
44 | self.assertAllEqual(token_list, expected_token_list)
45 |
46 | def test_tokenize_smiles_string(self):
47 | self._validate_smiles_string_tokenization('CCCC', [28, 28, 28, 28])
48 | self._validate_smiles_string_tokenization('ClCCCC', [31, 28, 28, 28, 28])
49 | self._validate_smiles_string_tokenization('CCClCC', [28, 28, 31, 28, 28])
50 | self._validate_smiles_string_tokenization('CCCCCl', [28, 28, 28, 28, 31])
51 | self._validate_smiles_string_tokenization('ClC(CC)CCl',
52 | [31, 28, 2, 28, 28, 3, 28, 31])
53 | self._validate_smiles_string_tokenization(
54 | 'ClC(CCCl)CCl', [31, 28, 2, 28, 28, 31, 3, 28, 31])
55 | self._validate_smiles_string_tokenization('BrCCCCCl',
56 | [27, 28, 28, 28, 28, 31])
57 | self._validate_smiles_string_tokenization('ClCCCCBr',
58 | [31, 28, 28, 28, 28, 27])
59 | self._validate_smiles_string_tokenization('[Te][te]',
60 | [81, 71, 83, 81, 71, 83])
61 |
62 | def test_check_mol_only_has_atoms(self):
63 | result = [
64 | feature_utils.check_mol_only_has_atoms(mol, ['C'])
65 | for mol in self.test_mols
66 | ]
67 | self.assertAllEqual(result, [True, False, False, False])
68 |
69 | def test_check_mol_does_not_have_atoms(self):
70 | result = [
71 | feature_utils.check_mol_does_not_have_atoms(
72 | mol, ms_constants.METAL_ATOM_SYMBOLS) for mol in self.test_mols
73 | ]
74 | self.assertAllEqual(result, [True, False, True, True])
75 |
76 | def test_make_filter_by_substructure(self):
77 | filter_fn = feature_utils.make_filter_by_substructure('steroid')
78 | result = [filter_fn(mol) for mol in self.test_mols]
79 | self.assertAllEqual(result, [False, False, False, True])
80 |
81 | def test_convert_spectrum_array_to_string(self):
82 | spectra_array = np.zeros((2, 1000))
83 | spectra_array[0, 3] = 100
84 | spectra_array[1, 39] = 100
85 | spectra_array[1, 21] = 60
86 |
87 | expected_spectra_strings = ['3 100', '21 60\n39 100']
88 | result_spectra_strings = []
89 | for idx in range(np.shape(spectra_array)[0]):
90 | result_spectra_strings.append(
91 | feature_utils.convert_spectrum_array_to_string(spectra_array[idx, :]))
92 |
93 | self.assertAllEqual(expected_spectra_strings, result_spectra_strings)
94 |
95 |
96 | if __name__ == '__main__':
97 | tf.test.main()
98 |
--------------------------------------------------------------------------------
/testdata/test_2_mend.sdf:
--------------------------------------------------------------------------------
1 | METHANE, DIAZO-
2 | -MSLIB32-03140511252D 1 1.0 0.0 57
3 | CAS rn = 334883, Library ID = 4
4 | 3 2 0 0 0 0 0 0 0 0 0
5 | 20.0000 15.0000 0.0000 C 0 0 0 0 0 0 0 0 0
6 | 0.0000 15.0000 0.0000 N 0 0 0 0 5 0 0 0 0
7 | -20.0000 15.0000 0.0000 N 0 0 0 0 3 0 0 0 0
8 | 1 2 2 0 0 0 0
9 | 2 3 3 0 0 0 0
10 | M END
11 | > <NAME>
12 | Methane, diazo-
13 |
14 | > <SYNONYMS>
15 | Azimethylene
16 | Diazomethane
17 | Acomethylene
18 | Diazirine
19 | Diazonium methylide
20 |
21 | > <INCHIKEY>
22 | YXHKONLOYHBTNS-UHFFFAOYSA-N
23 |
24 | > <FORMULA>
25 | CH2N2
26 |
27 | > <MW>
28 | 42
29 |
30 | > <EXACT MASS>
31 | 42.0217981
32 |
33 | > <MASS SPECTRAL PEAKS>
34 | 22 110
35 | 23 220
36 | 24 999
37 | 25 25
38 | 26 12
39 | 27 58
40 | 28 179
41 | 30 22
42 | 31 110
43 | 32 425
44 |
45 | $$$$
46 | (4-(4-CHLORPHENYL)-3-MORPHOLINO-PYRROL-2-YL)-BUTENEDIOIC ACID, DIMETHYL ESTER
47 | -MSLIB32-03140511252D 1 1.0 0.0 286647
48 | CAS rn = 164221378, Library ID = 6
49 | 28 30 0 0 0 0 0 0 0 0 0
50 | -29.0000 41.0000 0.0000 C 0 0 0 0 0 0 0 0 0
51 | -49.0000 39.0000 0.0000 C 0 0 0 0 0 0 0 0 0
52 | 9.0000 6.0000 0.0000 C 0 0 0 0 0 0 0 0 0
53 | -8.0000 -6.0000 0.0000 N 0 0 0 0 3 0 0 0 0
54 | 38.0000 73.0000 0.0000 O 0 0 0 0 2 0 0 0 0
55 | 46.0000 55.0000 0.0000 C 0 0 0 0 0 0 0 0 0
56 | -21.0000 59.0000 0.0000 C 0 0 0 0 0 0 0 0 0
57 | -18.0000 25.0000 0.0000 C 0 0 0 0 0 0 0 0 0
58 | -61.0000 55.0000 0.0000 C 0 0 0 0 0 0 0 0 0
59 | -33.0000 75.0000 0.0000 C 0 0 0 0 0 0 0 0 0
60 | 2.0000 25.0000 0.0000 C 0 0 0 0 0 0 0 0 0
61 | -24.0000 6.0000 0.0000 C 0 0 0 0 0 0 0 0 0
62 | -53.0000 73.0000 0.0000 C 0 0 0 0 0 0 0 0 0
63 | -65.0000 89.0000 0.0000 Cl 0 0 0 0 1 0 0 0 0
64 | 28.0000 -0.0000 0.0000 C 0 0 0 0 0 0 0 0 0
65 | 14.0000 41.0000 0.0000 N 0 0 0 0 3 0 0 0 0
66 | 32.0000 -20.0000 0.0000 C 0 0 0 0 0 0 0 0 0
67 | 43.0000 13.0000 0.0000 C 0 0 0 0 0 0 0 0 0
68 | 34.0000 39.0000 0.0000 C 0 0 0 0 0 0 0 0 0
69 | 6.0000 59.0000 0.0000 C 0 0 0 0 0 0 0 0 0
70 | 18.0000 75.0000 0.0000 C 0 0 0 0 0 0 0 0 0
71 | 17.0000 -33.0000 0.0000 C 0 0 0 0 0 0 0 0 0
72 | 21.0000 -53.0000 0.0000 O 0 0 0 0 2 0 0 0 0
73 | -2.0000 -27.0000 0.0000 O 0 0 0 0 2 0 0 0 0
74 | 62.0000 7.0000 0.0000 O 0 0 0 0 2 0 0 0 0
75 | 38.0000 33.0000 0.0000 O 0 0 0 0 2 0 0 0 0
76 | 66.0000 -13.0000 0.0000 C 0 0 0 0 0 0 0 0 0
77 | 40.0000 -59.0000 0.0000 C 0 0 0 0 0 0 0 0 0
78 | 1 2 2 0 0 0 0
79 | 1 7 1 0 0 0 0
80 | 1 8 1 0 0 0 0
81 | 2 9 1 0 0 0 0
82 | 3 4 1 0 0 0 0
83 | 3 15 1 0 0 0 0
84 | 3 11 2 0 0 0 0
85 | 4 12 1 0 0 0 0
86 | 5 6 1 0 0 0 0
87 | 5 21 1 0 0 0 0
88 | 6 19 1 0 0 0 0
89 | 7 10 2 0 0 0 0
90 | 8 11 1 0 0 0 0
91 | 8 12 2 0 0 0 0
92 | 9 13 2 0 0 0 0
93 | 10 13 1 0 0 0 0
94 | 11 16 1 0 0 0 0
95 | 13 14 1 0 0 0 0
96 | 15 17 2 0 0 0 0
97 | 15 18 1 0 0 0 0
98 | 16 19 1 0 0 0 0
99 | 16 20 1 0 0 0 0
100 | 17 22 1 0 0 0 0
101 | 18 25 1 0 0 0 0
102 | 18 26 2 0 0 0 0
103 | 20 21 1 0 0 0 0
104 | 22 23 1 0 0 0 0
105 | 22 24 2 0 0 0 0
106 | 23 28 1 0 0 0 0
107 | 25 27 1 0 0 0 0
108 | M END
109 | > <NAME>
110 | (4-(4-Chlorphenyl)-3-morpholino-pyrrol-2-yl)-butenedioic acid, dimethyl ester
111 |
112 | > <SYNONYMS>
113 | Dimethyl (2E)-2-[4-(4-chlorophenyl)-3-(4-morpholinyl)-1H-pyrrol-2-yl]-2-butenedioate #
114 |
115 | > <INCHIKEY>
116 | PNYUDNYAXSEACV-RVDMUPIBSA-N
117 |
118 | > <FORMULA>
119 | C20H21ClN2O5
120 |
121 | > <MW>
122 | 404
123 |
124 | > <EXACT MASS>
125 | 404.1139
126 |
127 | > <MASS SPECTRAL PEAKS>
128 | 32 12
129 | 33 7
130 | 34 28
131 | 35 999
132 | 36 57
133 | 37 302
134 | 38 975
135 | 39 8
136 | 40 53
137 | 41 176
138 | 42 99
139 | 43 122
140 | 44 117
141 | 45 155
142 | 46 9
143 | 47 7
144 | 49 6
145 | 50 28
146 | 51 59
147 |
148 | $$$$
149 |
--------------------------------------------------------------------------------
/spectra_predictor_test.py:
--------------------------------------------------------------------------------
1 | """Tests for .spectra_predictor."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | import os
7 | import tempfile
8 | from absl.testing import absltest
9 |
10 | import feature_utils
11 | import mass_spec_constants as ms_constants
12 | import spectra_predictor
13 | import test_utils
14 |
15 | import numpy as np
16 | import tensorflow as tf
17 |
18 |
19 | class DummySpectraPredictor(spectra_predictor.SpectraPredictor):
20 | """A test class that returns the mol weight input as the spectra prediction."""
21 |
22 | def _setup_prediction_op(self):
23 | fingerprint_input_op = tf.placeholder(tf.float32, (None, 4096))
24 | mol_weight_input_op = tf.placeholder(tf.float32, (None, 1))
25 |
26 | feature_dict = {
27 | self.fingerprint_input_key: fingerprint_input_op,
28 | self.molecular_weight_key: mol_weight_input_op
29 | }
30 |
31 | predict_op = tf.multiply(fingerprint_input_op, mol_weight_input_op)
32 | return feature_dict, predict_op
33 |
34 |
35 | class SpectraPredictorTest(tf.test.TestCase):
36 |
37 | def setUp(self):
38 | super(SpectraPredictorTest, self).setUp()
39 | self.np_fingerprint_input = np.ones((2, 4096))
40 | self.np_mol_weight_input = np.reshape(np.array([18., 16.]), (2, 1))
41 | self.test_data_directory = test_utils.test_dir("testdata/")
42 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir())
43 | self.test_file_short = os.path.join(self.test_data_directory,
44 | "test_2_mend.sdf")
45 |
46 | def tearDown(self):
47 | tf.reset_default_graph()
48 | tf.io.gfile.rmtree(self.temp_dir)
49 | super(SpectraPredictorTest, self).tearDown()
50 |
51 | def test_make_dummy_spectra_prediction(self):
52 | """Tests the dummy predictor."""
53 | predictor = DummySpectraPredictor()
54 |
55 | spectra_predictions = predictor.make_spectra_prediction(
56 | self.np_fingerprint_input, self.np_mol_weight_input)
57 | expected_value = np.multiply(
58 | self.np_fingerprint_input, self.np_mol_weight_input)
59 | expected_value = (
60 | expected_value / np.max(expected_value, axis=1, keepdims=True) *
61 | spectra_predictor.SCALE_FACTOR_FOR_LARGEST_INTENSITY)
62 | self.assertAllEqual(spectra_predictions, expected_value)
63 |
64 | def test_make_neims_spectra_prediction_without_weights(self):
65 | """Tests loading the parameters for the neims model without weights."""
66 | predictor = spectra_predictor.NeimsSpectraPredictor(model_checkpoint_dir="")
67 |
68 | spectra_predictions = predictor.make_spectra_prediction(
69 | self.np_fingerprint_input, self.np_mol_weight_input)
70 |
71 | self.assertEqual(
72 | np.shape(spectra_predictions),
73 | (np.shape(self.np_fingerprint_input)[0], ms_constants.MAX_PEAK_LOC))
74 |
75 | def test_load_fingerprints_from_sdf(self):
76 | """Tests loading fingerprints from an sdf file."""
77 | predictor = spectra_predictor.NeimsSpectraPredictor(model_checkpoint_dir="")
78 |
79 | mols_from_file = spectra_predictor.get_mol_list_from_sdf(
80 | self.test_file_short)
81 | fingerprints_from_file = predictor.get_fingerprints_from_mol_list(
82 | mols_from_file)
83 |
84 | self.assertEqual(np.shape(fingerprints_from_file), (2, 4096))
85 |
86 | def test_write_spectra_to_sdf(self):
87 | """Tests predicting and writing spectra to file."""
88 | predictor = spectra_predictor.NeimsSpectraPredictor(model_checkpoint_dir="")
89 |
90 | mols_from_file = spectra_predictor.get_mol_list_from_sdf(
91 | self.test_file_short)
92 | fingerprints, mol_weights = predictor.get_inputs_for_model_from_mol_list(
93 | mols_from_file)
94 |
95 | spectra_predictions = predictor.make_spectra_prediction(
96 | fingerprints, mol_weights)
97 | spectra_predictor.update_mols_with_spectra(mols_from_file,
98 | spectra_predictions)
99 |
100 | _, fpath = tempfile.mkstemp(dir=self.temp_dir)
101 | spectra_predictor.write_rdkit_mols_to_sdf(mols_from_file, fpath)
102 |
103 | # Test contents of newly written file:
104 | new_mol_list = spectra_predictor.get_mol_list_from_sdf(fpath)
105 |
106 | for idx, mol in enumerate(new_mol_list):
107 | peak_string_from_file = mol.GetProp(
108 | spectra_predictor.PREDICTED_SPECTRA_PROP_NAME)
109 | peak_locs, peak_intensities = feature_utils.parse_peaks(
110 | peak_string_from_file)
111 | dense_mass_spectra = feature_utils.make_dense_mass_spectra(
112 | peak_locs, peak_intensities, ms_constants.MAX_PEAK_LOC)
113 | self.assertSequenceAlmostEqual(
114 | dense_mass_spectra, spectra_predictions[idx, :].astype(np.int32),
115 | delta=1.)
116 |
117 |
118 | if __name__ == "__main__":
119 | tf.test.main()
120 |
--------------------------------------------------------------------------------
/similarity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper functions for similarity computation."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import abc
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 | # When computing cosine similarity, the denominator is constrained to be no
26 | # smaller than this.
27 | EPSILON = 1e-6
28 |
29 |
30 | class SimilarityProvider(object):
31 | """Abstract class of helpers for similarity-based library matching."""
32 | __metaclass__ = abc.ABCMeta
33 |
34 | def __init__(self, hparams=None):
35 | self.hparams = hparams
36 |
37 | @abc.abstractmethod
38 | def preprocess_library(self, library):
39 | """Perform normalization of [num_library_elements, feature_dim] library."""
40 |
41 | @abc.abstractmethod
42 | def undo_library_preprocessing(self, library):
43 | """Undo the effect of preprocess_library(), up to a scaling constant."""
44 |
45 | @abc.abstractmethod
46 | def preprocess_queries(self, queries):
47 | """Perform normalization of [num_query_elements, feature_dim] queries."""
48 |
49 | @abc.abstractmethod
50 | def compute_similarity(self, library, queries):
51 | """Compute [num_library_elements, num_query_elements] similarities."""
52 |
53 | @abc.abstractmethod
54 | def make_training_loss(self, true_tensor, predicted_tensor):
55 | """Create training loss that is consistent with the similarity."""
56 |
57 |
58 | class CosineSimilarityProvider(SimilarityProvider):
59 | """Cosine similarity."""
60 |
61 | def _normalize_rows(self, tensor):
62 | return tf.nn.l2_normalize(tensor, axis=1)
63 |
64 | def preprocess_library(self, library):
65 | return self._normalize_rows(library)
66 |
67 | def undo_library_preprocessing(self, library):
68 | return library
69 |
70 | def preprocess_queries(self, queries):
71 | return self._normalize_rows(queries)
72 |
73 | def compute_similarity(self, library, queries):
74 | similarities = tf.matmul(library, queries, transpose_b=True)
75 | return tf.transpose(similarities)
76 |
77 | def make_training_loss(self, true_tensor, predicted_tensor):
78 | return tf.reduce_mean(
79 | tf.losses.mean_squared_error(true_tensor, predicted_tensor))
80 |
81 |
82 | class GeneralizedCosineSimilarityProvider(CosineSimilarityProvider):
83 | """Custom cosine similarity that is popular for massspec matching."""
84 |
85 | def _make_weights(self, tensor):
86 | num_bins = tensor.shape[1].value
87 | weights = np.power(np.arange(1, num_bins + 1),
88 | self.hparams.mass_power)[np.newaxis, :]
89 | return weights / np.sum(weights)
90 |
91 | def _normalize_rows(self, tensor):
92 | if self.hparams.mass_power != 0:
93 | tensor *= self._make_weights(tensor)
94 |
95 | return super(GeneralizedCosineSimilarityProvider,
96 | self)._normalize_rows(tensor)
97 |
98 | def undo_library_preprocessing(self, library):
99 | return library / self._make_weights(library)
100 |
101 | def compute_similarity(self, library, queries):
102 | similarities = tf.matmul(library, queries, transpose_b=True)
103 | return tf.transpose(similarities)
104 |
105 | def make_training_loss(self, true_tensor, predicted_tensor):
106 | if self.hparams.mass_power != 0:
107 | weights = self._make_weights(true_tensor)
108 | weighted_squared_error = weights * tf.square(true_tensor -
109 | predicted_tensor)
110 | return tf.reduce_mean(weighted_squared_error)
111 | else:
112 | return tf.reduce_mean(
113 | tf.losses.mean_squared_error(true_tensor, predicted_tensor))
114 |
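# For reference, the generalized cosine similarity above corresponds to the
# following NumPy sketch for two dense spectra x and y with mass_power p (the
# weight-normalization constant cancels inside the cosine):
#
#   w = np.arange(1, len(x) + 1) ** p
#   xw, yw = x * w, y * w
#   similarity = np.dot(xw, yw) / (np.linalg.norm(xw) * np.linalg.norm(yw))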
115 |
116 | def max_margin_ranking_loss(predictions, target_indices, library,
117 | similarity_provider, margin):
118 | """Max-margin ranking loss.
119 |
120 | loss = (1/batch_size) * sum_i w_i sum_j max(0,
121 | similarities[i, j]
122 | - similarities[i, ti] + margin),
123 | where similarities = similarity_provider.compute_similarity(library,
124 | predictions)
125 | and ti = target_indices[i]. Here, w_i is a weight placed on each element of
126 | the batch. Without w_i, our loss would be the standard Crammer-Singer
127 | multiclass SVM. Instead, we set w_i so that the total contribution to the
128 | parameter gradient from each batch element is equal. Therefore, we set w_i
129 | equal to 1 / (the number of margin violations for element i).
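
  For example, if element i has four margin violations, each of its four hinge
  terms is scaled by w_i = 1/4, so element i contributes the same total weight
  to the gradient as an element with a single violation.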
130 |
131 | Args:
132 | predictions: [batch_size, prediction_dim] float Tensor
133 | target_indices: [batch_size] int Tensor
134 | library: [num_library_elements, prediction_dim] constant Tensor
135 | similarity_provider: a SimilarityProvider instance
136 | margin: float
137 | Returns:
138 | loss
139 | """
140 | library = similarity_provider.preprocess_library(library)
141 | predictions = similarity_provider.preprocess_queries(predictions)
142 | similarities = similarity_provider.compute_similarity(library, predictions)
143 |
144 | batch_size = tf.shape(predictions)[0]
145 |
146 | target_indices = tf.squeeze(target_indices, axis=1)
147 | row_indices = tf.range(0, batch_size)
148 | indices = tf.stack([row_indices, tf.cast(target_indices, tf.int32)], axis=1)
149 | ground_truth_similarities = tf.gather_nd(similarities, indices)
150 |
151 | margin_violations = tf.nn.relu(-ground_truth_similarities[..., tf.newaxis] +
152 | similarities + margin)
153 |
154 | margin_violators = tf.cast(margin_violations > 0, tf.int32)
155 | margin_violators_per_batch_element = tf.to_float(
156 | tf.reduce_sum(margin_violators, axis=1, keep_dims=True))
157 | margin_violators_per_batch_element = tf.maximum(
158 | margin_violators_per_batch_element, 1.)
159 | margin_violators_per_batch_element = tf.stop_gradient(
160 | margin_violators_per_batch_element)
161 | tf.summary.scalar('num_margin_violations',
162 | tf.reduce_mean(margin_violators_per_batch_element))
163 | weighted_margin_violations = (
164 | margin_violations / margin_violators_per_batch_element)
165 | return tf.reduce_sum(weighted_margin_violations) / tf.maximum(
166 | tf.to_float(batch_size), 1.)
167 |
--------------------------------------------------------------------------------
/util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for utils."""
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | import tempfile
20 |
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 |
24 | import util
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 |
29 | class UtilTest(tf.test.TestCase, parameterized.TestCase):
30 |
31 | def setUp(self):
32 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir())
33 |
34 | def tearDown(self):
35 | tf.gfile.DeleteRecursively(self.temp_dir)
36 |
37 | def _make_model(self, batch_size, num_batches, variable_initializer_value):
38 | np_inputs = np.arange(batch_size * num_batches)
39 | np_inputs = np.float32(np_inputs)
40 | inputs = tf.data.Dataset.from_tensor_slices(np_inputs)
41 | inputs = inputs.batch(batch_size).make_one_shot_iterator().get_next()
42 | scale = tf.get_variable(
43 | name='scale', dtype=tf.float32, initializer=variable_initializer_value,
44 | trainable=True)
45 | output = inputs * scale
46 | return output
47 |
48 | def test_run_graph_and_process_results(self):
49 |
50 | batch_size = 3
51 | num_batches = 5
52 |
53 | # Make a graph that contains a Variable and save it to checkpoint.
54 | with tf.Graph().as_default():
55 | _ = self._make_model(
56 | batch_size=batch_size, num_batches=num_batches,
57 | variable_initializer_value=2.0)
58 | saver = tf.train.Saver(tf.trainable_variables())
59 | with tf.Session() as sess:
60 | sess.run(tf.global_variables_initializer())
61 | saver.save(sess, self.temp_dir + '/model-')
62 |
63 | # Make another copy of the graph, and process data using this one.
64 | with tf.Graph().as_default():
65 | # We intentionally make this graph have a different value for its Variable
66 | # than the graph above. When we restore from checkpoint, we will grab the
67 | # value from the first graph. This helps test that the Variables are
68 | # being properly restored from checkpoint.
69 | ops_to_fetch = self._make_model(
70 | batch_size=batch_size, num_batches=num_batches,
71 | variable_initializer_value=3.0
72 | )
73 |
74 | results = []
75 | def process_fetched_values_fn(np_array):
76 | results.append(np_array)
77 |
78 | model_checkpoint_path = self.temp_dir
79 |
80 | util.run_graph_and_process_results(ops_to_fetch, model_checkpoint_path,
81 | process_fetched_values_fn)
82 |
83 | results = np.concatenate(results, axis=0)
84 | expected_results = np.arange(num_batches * batch_size) * 2.0
85 |
86 | self.assertAllEqual(results, expected_results)
87 |
88 | @parameterized.parameters((7), (10), (65))
89 | def test_map_predictor(self, sub_batch_size):
90 | input_op = {
91 | 'a': tf.random_normal(shape=(50, 5)),
92 | 'b': tf.random_normal(shape=(50, 5))
93 | }
94 |
95 | def predictor_fn(data):
96 | return data['a'] + data['b']
97 |
98 | mapped_prediction = util.map_predictor(
99 | input_op, predictor_fn, sub_batch_size=sub_batch_size)
100 | unmapped_prediction = predictor_fn(input_op)
101 | difference = tf.reduce_mean(
102 | tf.squared_difference(mapped_prediction, unmapped_prediction))
103 | with tf.Session() as sess:
104 | self.assertLess(
105 | sess.run(difference), 1e-6,
106 |           'The output of map_predictor does not match a direct '
107 | 'application of predictor_fn.')
108 |
109 | def test_value_op_with_initializer(self):
110 | """Test correctness of library_matching.value_op_with_initializer."""
111 |
112 | base_value_op = tf.get_variable('value', initializer=0.)
113 |
114 | def make_value_op():
115 | return base_value_op
116 |
117 | def make_init_op(value):
118 | # This is a simple assignment that could have been achieved by changing
119 | # the initializer above. However, in other use cases of
120 | # value_op_with_initializer, the contructed value requires
121 | # data-dependent computation that can't be done via an initializer.
122 | return value.assign(tf.ones_like(value))
123 |
124 | value_op = util.value_op_with_initializer(make_value_op, make_init_op)
125 |
126 | # Check that the value of the Variable generated by make_value_op()
127 | # is the value constructed by make_init_op, not the value given
128 | # the initializer given to the Variable's constructor.
129 | with tf.Session() as sess:
130 | sess.run(tf.global_variables_initializer())
131 | sess.run(tf.local_variables_initializer())
132 | self.assertAllEqual(sess.run(base_value_op), 0.0)
133 | self.assertAllEqual(sess.run(value_op), 1.0)
134 | self.assertAllEqual(sess.run(base_value_op), 1.0)
135 |
136 | def test_scatter_by_anchor_indices(self):
137 |
138 | def _validate(anchor_indices, data, index_shift, expected_output):
139 | with tf.Graph().as_default():
140 | output = util.scatter_by_anchor_indices(anchor_indices, data,
141 | index_shift)
142 | with tf.Session() as sess:
143 | actual_output = sess.run(output)
144 | self.assertAllClose(
145 | np.array(expected_output, dtype=np.float32), actual_output)
146 |
147 | data = [[1, 2, 3], [4, 5, 6]]
148 |
149 | anchor_indices = [1, 1]
150 | index_shift = 0
151 | expected_output = [[2, 1, 0], [5, 4, 0]]
152 | _validate(anchor_indices, data, index_shift, expected_output)
153 |
154 | anchor_indices = [2, 2]
155 | index_shift = 0
156 | expected_output = [[3, 2, 1], [6, 5, 4]]
157 | _validate(anchor_indices, data, index_shift, expected_output)
158 |
159 | anchor_indices = [1, 1]
160 | index_shift = 1
161 | expected_output = [[3, 2, 1], [6, 5, 4]]
162 | _validate(anchor_indices, data, index_shift, expected_output)
163 |
164 | anchor_indices = [0, 1]
165 | index_shift = 1
166 | expected_output = [[2, 1, 0], [6, 5, 4]]
167 | _validate(anchor_indices, data, index_shift, expected_output)
168 |
169 | data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
170 | anchor_indices = [1, 2, 3]
171 | index_shift = 0
172 | expected_output = [[2, 1, 0, 0], [7, 6, 5, 0], [12, 11, 10, 9]]
173 | _validate(anchor_indices, data, index_shift, expected_output)
174 |
175 | data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
176 | anchor_indices = [0, 1, 2]
177 | index_shift = 1
178 | expected_output = [[2, 1, 0, 0], [7, 6, 5, 0], [12, 11, 10, 9]]
179 | _validate(anchor_indices, data, index_shift, expected_output)
180 |
181 | if __name__ == '__main__':
182 | tf.test.main()
183 |
--------------------------------------------------------------------------------
/make_predictions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""Run massspec model on data and write out predictions.
15 |
16 | Example usage:
17 | blaze-bin/third_party/py/deep_molecular_massspec/make_predictions \
18 | --alsologtostderr --input_file=testdata/test_14_record.gz \
19 | --output_file=/tmp/models/output_predictions \
20 | --model_checkpoint_path=/tmp/models/output/ \
21 | --hparams=eval_batch_size=16
22 |
23 | This saves a numpy archive to FLAGS.output_file that contains a dictionary
24 | where the keys are inchikeys and values are 1D np arrays for spectra.
25 |
26 | You should load this dict downstream using:
27 | data_dict = np.load(data_file).item()
28 |
29 | (Note that .item() is necessary because np.load returns a 0-D array,
30 | where the first element is the desired dictionary.)
31 | """
32 |
33 | from __future__ import print_function
34 | import json
35 | import os
36 | import tempfile
37 |
38 |
39 | import dataset_setup_constants as ds_constants
40 | import feature_map_constants as fmap_constants
41 |
42 | # Note that many FLAGS are inherited from molecule_estimator
43 | import molecule_estimator
44 |
45 | import molecule_predictors
46 | import plot_spectra_utils
47 | import util
48 |
49 | import numpy as np
50 | import tensorflow as tf
51 | import matplotlib.pyplot as plt
52 |
53 | FLAGS = tf.app.flags.FLAGS
54 |
55 | tf.flags.DEFINE_string(
56 | 'input_file', None, 'Input TFRecord file or a '
57 |     'glob pattern for TFRecord files.')
58 | tf.flags.DEFINE_string(
59 | 'model_checkpoint_path', None,
60 | 'Path to model checkpoint. If a directory, the most '
61 | 'recent model checkpoint in this directory will be used. If a file, it '
62 | 'should be of the form /.../name-of-the-file.ckpt-10000')
63 | tf.flags.DEFINE_bool(
64 | 'save_spectra_plots', True,
65 | 'Make plots of true and predicted spectra for each query molecule.'
66 | )
67 | tf.flags.DEFINE_string('output_file', None,
68 | 'Location where outputs will be written.')
69 |
70 |
71 | def _make_features_and_labels_from_tfrecord(input_file_pattern, hparams,
72 | features_to_load):
73 | """Construct features and labels Tensors to be consumed by model_fn."""
74 |
75 | def _make_tmp_dataset_config_file(input_filenames):
76 | """Construct a temporary config file that points to input_filename."""
77 |
78 | _, tmp_file = tempfile.mkstemp()
79 | dataset_config = {
80 | ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY: input_filenames
81 | }
82 |
83 | with tf.gfile.Open(tmp_file, 'w') as f:
84 | json.dump(dataset_config, f)
85 |
86 | return tmp_file
87 |
88 | input_files = tf.gfile.Glob(input_file_pattern)
89 | if not input_files:
90 | raise ValueError('No files found matching %s' % input_file_pattern)
91 |
92 | data_dir, _ = os.path.split(input_files[0])
93 | data_basenames = [os.path.split(filename)[1] for filename in input_files]
94 | dataset_config_file = _make_tmp_dataset_config_file(data_basenames)
95 |
96 | mode = tf.estimator.ModeKeys.PREDICT
97 | input_fn = molecule_estimator.make_input_fn(
98 | dataset_config_file=dataset_config_file,
99 | hparams=hparams,
100 | mode=mode,
101 | features_to_load=features_to_load,
102 | data_dir=data_dir,
103 | load_library_matching_data=False)
104 | tf.gfile.Remove(dataset_config_file)
105 | return input_fn()
106 |
107 |
108 | def _make_features_labels_and_estimator(model_type, hparam_string, input_file):
109 | """Construct input ops and EstimatorSpec for massspec model."""
110 |
111 | prediction_helper = molecule_predictors.get_prediction_helper(model_type)
112 | hparams = prediction_helper.get_default_hparams()
113 | hparams.parse(hparam_string)
114 |
115 | model_fn = molecule_estimator.make_model_fn(
116 | prediction_helper, dataset_config_file=None, model_dir=None)
117 |
118 | features_to_load = prediction_helper.features_to_load(hparams)
119 | features, labels = _make_features_and_labels_from_tfrecord(
120 | input_file, hparams, features_to_load)
121 |
122 | estimator_spec = model_fn(
123 | features, labels, hparams, mode=tf.estimator.ModeKeys.PREDICT)
124 |
125 | return features, labels, estimator_spec
126 |
127 |
128 | def _save_plot_figure(key, prediction, true_spectrum, results_dir):
129 | """A helper function that makes and saves plots of true and predicted spectra."""
130 | spectra_plot_file_name = plot_spectra_utils.name_plot_file(
131 | plot_spectra_utils.PlotModeKeys.PREDICTED_SPECTRUM, key, file_type='png')
132 |
133 | # Rescale the true/predicted spectra
134 | true_spectrum = true_spectrum / true_spectrum.max() * plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA
135 | prediction = prediction / prediction.max() * plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA
136 |
137 | plot_spectra_utils.plot_true_and_predicted_spectra(
138 | true_spectrum, prediction,
139 | output_filename=os.path.join(results_dir,spectra_plot_file_name),
140 | rescale_mz_axis=True
141 | )
142 |
143 |
144 | def main(_):
145 |
146 | features, labels, estimator_spec = _make_features_labels_and_estimator(
147 | FLAGS.model_type, FLAGS.hparams, FLAGS.input_file)
148 | del labels # Unused
149 |
150 | pred_op = estimator_spec.predictions
151 | inchikey_op = features[fmap_constants.SPECTRUM_PREDICTION][
152 | fmap_constants.INCHIKEY]
153 | ops_to_fetch = [inchikey_op, pred_op]
154 | if FLAGS.save_spectra_plots:
155 | true_spectra_op = features[fmap_constants.SPECTRUM_PREDICTION][fmap_constants.DENSE_MASS_SPEC]
156 | ops_to_fetch.append(true_spectra_op)
157 |
158 | results = {}
159 | results_dir = os.path.dirname(FLAGS.output_file)
160 | tf.gfile.MakeDirs(results_dir)
161 |
162 | def process_fetched_values_fn(fetched_values):
163 | if FLAGS.save_spectra_plots:
164 | keys, predictions, true_spectra = fetched_values
165 | for key, prediction, true_spectrum in zip(keys, predictions, true_spectra):
166 | # Dereference the singleton np string array to get the actual string.
167 | key = key[0]
168 | results[key] = prediction
169 | _save_plot_figure(key, prediction, true_spectrum, results_dir)
170 | else:
171 | keys, predictions = fetched_values
172 | for key, prediction in zip(keys, predictions):
173 | # Dereference the singleton np string array to get the actual string.
174 | key = key[0]
175 | results[key] = prediction
176 |
177 | util.run_graph_and_process_results(ops_to_fetch, FLAGS.model_checkpoint_path,
178 | process_fetched_values_fn)
179 |
180 | np.save(FLAGS.output_file, results)
181 |
182 |
183 | if __name__ == '__main__':
184 | for flag in ['input_file', 'model_checkpoint_path', 'output_file']:
185 | tf.app.flags.mark_flag_as_required(flag)
186 |
187 | tf.app.run(main)
188 |
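189 | # Minimal sketch of reading the saved results back (the path below is an
190 | # assumption). np.save pickles the results dict into a 0-D object array, so
191 | # .item() recovers the dict of {inchikey: 1-D spectrum np.array}; newer
192 | # NumPy releases also require allow_pickle=True for this load.
193 | #
194 | #   results = np.load('/tmp/models/output_predictions.npy',
195 | #                     allow_pickle=True).item()
196 | #   spectrum = results[next(iter(results))]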
--------------------------------------------------------------------------------
/make_predictions_from_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""Run massspec model on data and write out predictions.
15 |
16 | Example usage:
17 | blaze-bin/third_party/py/deep_molecular_massspec/make_predictions_from_tfrecord \
18 | --alsologtostderr --input_file=testdata/test_14_record.gz \
19 | --output_file=/tmp/models/output_predictions \
20 | --model_checkpoint_path=/tmp/models/output/ \
21 | --hparams=eval_batch_size=16
22 |
23 | This saves a numpy archive to FLAGS.output_file that contains a dictionary
24 | where the keys are inchikeys and values are 1D np arrays for spectra.
25 |
26 | You should load this dict downstream using:
27 | data_dict = np.load(data_file).item()
28 |
29 | (Note that .item() is necessary because np.load returns a 0-D array
30 | whose single element is the desired dictionary.)
31 | """
32 |
33 | from __future__ import print_function
34 | import json
35 | import os
36 | import tempfile
37 |
38 |
39 | from absl import flags
40 |
41 | import dataset_setup_constants as ds_constants
42 | import feature_map_constants as fmap_constants
43 |
44 | # Note that many FLAGS are inherited from molecule_estimator
45 | import molecule_estimator
46 |
47 | import molecule_predictors
48 | import plot_spectra_utils
49 | import util
50 |
51 | import numpy as np
52 | import tensorflow as tf
53 |
54 | FLAGS = flags.FLAGS
55 |
56 | flags.DEFINE_string(
57 | 'input_file', None, 'Input TFRecord file or a '
58 |     'globbable file pattern for TFRecord files')
59 | flags.DEFINE_string(
60 | 'model_checkpoint_path', None,
61 | 'Path to model checkpoint. If a directory, the most '
62 | 'recent model checkpoint in this directory will be used. If a file, it '
63 | 'should be of the form /.../name-of-the-file.ckpt-10000')
64 | flags.DEFINE_bool(
65 | 'save_spectra_plots', True,
66 | 'Make plots of true and predicted spectra for each query molecule.')
67 | flags.DEFINE_string('output_file', None,
68 | 'Location where outputs will be written.')
69 |
70 |
71 | def _make_features_and_labels_from_tfrecord(input_file_pattern, hparams,
72 | features_to_load):
73 | """Construct features and labels Tensors to be consumed by model_fn."""
74 |
75 | def _make_tmp_dataset_config_file(input_filenames):
76 |     """Construct a temporary config file that points to input_filenames."""
77 |
78 | _, tmp_file = tempfile.mkstemp()
79 | dataset_config = {
80 | ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY: input_filenames
81 | }
82 |
83 | # TODO(b/135189673): Replace this with tf.gfile.Open once the type
84 | # issue is fixed.
85 | with open(tmp_file, 'w') as f:
86 | json.dump(dataset_config, f)
87 |
88 | return tmp_file
89 |
90 | input_files = tf.gfile.Glob(input_file_pattern)
91 | if not input_files:
92 | raise ValueError('No files found matching %s' % input_file_pattern)
93 |
94 | data_dir, _ = os.path.split(input_files[0])
95 | data_basenames = [os.path.split(filename)[1] for filename in input_files]
96 | dataset_config_file = _make_tmp_dataset_config_file(data_basenames)
97 |
98 | mode = tf.estimator.ModeKeys.PREDICT
99 | input_fn = molecule_estimator.make_input_fn(
100 | dataset_config_file=dataset_config_file,
101 | hparams=hparams,
102 | mode=mode,
103 | features_to_load=features_to_load,
104 | data_dir=data_dir,
105 | load_library_matching_data=False)
106 | tf.gfile.Remove(dataset_config_file)
107 | return input_fn()
108 |
109 |
110 | def _make_features_labels_and_estimator(model_type, hparam_string, input_file):
111 | """Construct input ops and EstimatorSpec for massspec model."""
112 |
113 | prediction_helper = molecule_predictors.get_prediction_helper(model_type)
114 | hparams = prediction_helper.get_default_hparams()
115 | hparams.parse(hparam_string)
116 |
117 | model_fn = molecule_estimator.make_model_fn(
118 | prediction_helper, dataset_config_file=None, model_dir=None)
119 |
120 | features_to_load = prediction_helper.features_to_load(hparams)
121 | features, labels = _make_features_and_labels_from_tfrecord(
122 | input_file, hparams, features_to_load)
123 |
124 | estimator_spec = model_fn(
125 | features, labels, hparams, mode=tf.estimator.ModeKeys.PREDICT)
126 |
127 | return features, labels, estimator_spec
128 |
129 |
130 | def _save_plot_figure(key, prediction, true_spectrum, results_dir):
131 |   """Makes and saves plots of true and predicted spectra."""
132 | spectra_plot_file_name = plot_spectra_utils.name_plot_file(
133 | plot_spectra_utils.PlotModeKeys.PREDICTED_SPECTRUM, key, file_type='png')
134 |
135 | # Rescale the true/predicted spectra
136 | true_spectrum = (
137 | true_spectrum / true_spectrum.max() *
138 | plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA)
139 | prediction = (
140 | prediction / prediction.max() *
141 | plot_spectra_utils.MAX_VALUE_OF_TRUE_SPECTRA)
142 |
143 | plot_spectra_utils.plot_true_and_predicted_spectra(
144 | true_spectrum,
145 | prediction,
146 | output_filename=os.path.join(results_dir, spectra_plot_file_name),
147 | rescale_mz_axis=True)
148 |
149 |
150 | def main(_):
151 |
152 | features, labels, estimator_spec = _make_features_labels_and_estimator(
153 | FLAGS.model_type, FLAGS.hparams, FLAGS.input_file)
154 | del labels # Unused
155 |
156 | pred_op = estimator_spec.predictions
157 | inchikey_op = features[fmap_constants.SPECTRUM_PREDICTION][
158 | fmap_constants.INCHIKEY]
159 | ops_to_fetch = [inchikey_op, pred_op]
160 | if FLAGS.save_spectra_plots:
161 | true_spectra_op = features[fmap_constants.SPECTRUM_PREDICTION][
162 | fmap_constants.DENSE_MASS_SPEC]
163 | ops_to_fetch.append(true_spectra_op)
164 |
165 | results = {}
166 | results_dir = os.path.dirname(FLAGS.output_file)
167 | tf.gfile.MakeDirs(results_dir)
168 |
169 | def process_fetched_values_fn(fetched_values):
170 | """Processes output values from estimator."""
171 | if FLAGS.save_spectra_plots:
172 | keys, predictions, true_spectra = fetched_values
173 | for key, prediction, true_spectrum in zip(keys, predictions,
174 | true_spectra):
175 | # Dereference the singleton np string array to get the actual string.
176 | key = key[0]
177 | results[key] = prediction
178 | # Spectra plots are saved to file.
179 | _save_plot_figure(key, prediction, true_spectrum, results_dir)
180 | else:
181 | keys, predictions = fetched_values
182 | for key, prediction in zip(keys, predictions):
183 | # Dereference the singleton np string array to get the actual string.
184 | key = key[0]
185 | results[key] = prediction
186 |
187 | util.run_graph_and_process_results(ops_to_fetch, FLAGS.model_checkpoint_path,
188 | process_fetched_values_fn)
189 |
190 | tf.gfile.MakeDirs(os.path.dirname(FLAGS.output_file))
191 | np.save(FLAGS.output_file, results)
192 |
193 |
194 | if __name__ == '__main__':
195 | for flag in ['input_file', 'model_checkpoint_path', 'output_file']:
196 | flags.mark_flag_as_required(flag)
197 |
198 | tf.app.run(main)
199 |
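200 | # For reference, the temporary dataset config written by
201 | # _make_tmp_dataset_config_file for a single input TFRecord is a one-key
202 | # json (the basename here is illustrative):
203 | #
204 | #   {"SPECTRUM_PREDICTION_TRAIN": ["test_14_record.gz"]}
205 | #
206 | # i.e. the inputs are exposed to molecule_estimator.make_input_fn as if
207 | # they were the spectrum-prediction training set.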
--------------------------------------------------------------------------------
/train_test_split_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for making train test split for mass spectra datsets.
16 |
17 | Contains TrainValFractions namedtuple for passing 3-tuple of train, validation,
18 | and test fractions to use of a datasets. Also contains TrainValTestInchikeys
19 | namedtuple for 3-tuple of lists of inchikeys to put into train, validation, and
20 | test splits.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 | from collections import namedtuple
27 | import random
28 |
29 | import feature_utils
30 | import mass_spec_constants as ms_constants
31 | import numpy as np
32 |
33 | # Helper class for storing train, validation, and test fractions.
34 | TrainValTestFractions = namedtuple('TrainValTestFractions',
35 | ['train', 'validation', 'test'])
36 |
37 | # Helper class for storing train, validation, and test inchikeys after split.
38 | TrainValTestInchikeys = namedtuple('TrainValTestInchikeys',
39 | ['train', 'validation', 'test'])
40 |
41 |
42 | def assert_all_lists_mutally_exclusive(list_of_lists):
43 |   """Raises ValueError if any two lists in a list of lists share an item."""
44 |   for idx, list1 in enumerate(list_of_lists):
45 | for list2 in list_of_lists[idx + 1:]:
46 | if any(elem in list2 for elem in list1):
47 | raise ValueError(
48 | 'found matching items between two lists: \n {}\n {}'.format(
49 | ', '.join(list1),
50 | ', '.join(list2),
51 | ))
52 |
53 |
54 | def make_inchikey_dict(mol_list):
55 | """Converts rdkit.Mol list into dict of lists of Mols keyed by inchikey."""
56 | inchikey_dict = {}
57 | for mol in mol_list:
58 | inchikey = mol.GetProp(ms_constants.SDF_TAG_INCHIKEY)
59 | if inchikey not in inchikey_dict:
60 | inchikey_dict[inchikey] = [mol]
61 | else:
62 | inchikey_dict[inchikey].append(mol)
63 | return inchikey_dict
64 |
65 |
66 | def get_random_inchikeys(inchikey_list, train_val_test_split_fractions):
67 |   """Splits a given inchikey list into 3 lists for train/val/test sets."""
68 | random.shuffle(inchikey_list)
69 |
70 | train_num = int(train_val_test_split_fractions.train * len(inchikey_list))
71 | val_num = int(train_val_test_split_fractions.validation * len(inchikey_list))
72 |
73 | return TrainValTestInchikeys(inchikey_list[:train_num],
74 | inchikey_list[train_num:train_num + val_num],
75 | inchikey_list[train_num + val_num:])
76 |
77 |
78 | def get_inchikeys_by_family(inchikey_list,
79 | inchikey_dict,
80 | train_val_test_split_fractions,
81 | family_name='steroid',
82 | exclude_from_train=True):
83 |   """Creates a train/val/test split based on presence of a structure family.
84 |
85 | Filters molecules according to whether they have the substructure specified
86 |   by family_name. All molecules passing the filter are placed either in the
87 |   validation/test sets or in the train set, according to exclude_from_train.
88 |   The molecules are assigned to the validation/test splits according to the
89 |   relative ratio between the validation and test fractions.
90 |
91 |   If the validation and test fractions are both equal to 0.0, these values
92 |   will be overwritten to 0.5 and 0.5.
93 |
94 | Args:
95 | inchikey_list: List of inchikeys to partition into train/val/test sets
96 |     inchikey_dict: dict mapping inchikeys to lists of rdkit.Mol objects.
97 |       Must contain every inchikey in inchikey_list among its keys.
98 | train_val_test_split_fractions: a TrainValTestFractions tuple
99 | family_name: str, a key in feature_utils.FAMILY_DICT
100 |     exclude_from_train: indicates whether to exclude family molecules from
101 |       the training set. If they are excluded, the test and validation sets
102 |       will consist only of these molecules.
103 | Returns:
104 | TrainValTestInchikeys object
105 | """
106 | _, val_fraction, test_fraction = train_val_test_split_fractions
107 | if val_fraction == 0.0 and test_fraction == 0.0:
108 | val_fraction = 0.5
109 | test_fraction = 0.5
110 |
111 | substructure_filter_fn = feature_utils.make_filter_by_substructure(
112 | family_name)
113 | family_inchikeys = []
114 | nonfamily_inchikeys = []
115 |
116 | for ikey in inchikey_list:
117 | if substructure_filter_fn(inchikey_dict[ikey][0]):
118 | family_inchikeys.append(ikey)
119 | else:
120 | nonfamily_inchikeys.append(ikey)
121 |
122 | if exclude_from_train:
123 | val_test_inchikeys, train_inchikeys = (family_inchikeys,
124 | nonfamily_inchikeys)
125 | else:
126 | train_inchikeys, val_test_inchikeys = (family_inchikeys,
127 | nonfamily_inchikeys)
128 |
129 | random.shuffle(val_test_inchikeys)
130 | val_num = int(
131 | val_fraction / (val_fraction + test_fraction) * len(val_test_inchikeys))
132 | return TrainValTestInchikeys(train_inchikeys, val_test_inchikeys[:val_num],
133 | val_test_inchikeys[val_num:])
134 |
135 |
136 | def make_train_val_test_split_inchikey_lists(train_inchikey_list,
137 | train_inchikey_dict,
138 | train_val_test_split_fractions,
139 | holdout_inchikey_list=None,
140 | splitting_type='random'):
141 | """Given inchikey lists, returns lists to use for train/val/test sets.
142 |
143 | If holdout_inchikey_list is given, the inchikeys in this list will be excluded
144 | from the returned train/validation/test lists.
145 |
146 | Args:
147 | train_inchikey_list : List of inchikeys to use for train/val/test sets
148 | train_inchikey_dict : Main dict keyed by inchikeys, values are lists of
149 |       rdkit.Mol. Note that train_inchikey_dict.keys() != train_inchikey_list;
150 |       train_inchikey_dict may have many more keys than are in the list.
151 | train_val_test_split_fractions : a TrainValTestFractions tuple
152 | holdout_inchikey_list : List of inchikeys to exclude from train/val/test
153 | sets.
154 | splitting_type : method of splitting molecules into train/val/test sets.
155 | Returns:
156 | A TrainValTestInchikeys namedtuple
157 | Raises:
158 |     ValueError : if the values in train_val_test_split_fractions do not
159 |       sum to 1.0, or if a splitting_type is specified that is not
160 |       supported.
161 | """
162 | if not np.isclose([sum(train_val_test_split_fractions)], [1.0]):
163 | raise ValueError('Must specify train_val_test_split that sums to 1.0')
164 |
165 | if holdout_inchikey_list:
166 | # filter out those inchikeys that are in the holdout set.
167 | train_inchikey_list = [
168 | ikey for ikey in train_inchikey_list
169 | if ikey not in holdout_inchikey_list
170 | ]
171 |
172 | if splitting_type == 'random':
173 | return get_random_inchikeys(train_inchikey_list,
174 | train_val_test_split_fractions)
175 | else:
176 | # Assume that splitting_type is the name of a structure family.
177 | # get_inchikeys_by_family will throw an error if this is not supported.
178 | return get_inchikeys_by_family(
179 | train_inchikey_list,
180 | train_inchikey_dict,
181 | train_val_test_split_fractions,
182 | family_name=splitting_type,
183 | exclude_from_train=True)
184 |
185 |
186 | def make_mol_list_from_inchikey_dict(inchikey_dict, inchikey_list):
187 | """Return a list of rdkit.Mols given a list of inchikeys.
188 |
189 | Args:
190 | inchikey_dict : a dict of lists of rdkit.Mol objects keyed by inchikey
191 |     inchikey_list : List of inchikeys for the molecules to return.
192 | Returns:
193 | A list of rdkit.Mols corresponding to inchikeys in inchikey_list.
194 | """
195 | mol_list = []
196 | for inchikey in inchikey_list:
197 | mol_list.extend(inchikey_dict[inchikey])
198 |
199 | return mol_list
200 |
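201 | # Minimal usage sketch (mol_list is assumed to be a list of rdkit.Mol
202 | # objects parsed elsewhere; the fractions must sum to 1.0):
203 | #
204 | #   fractions = TrainValTestFractions(train=0.8, validation=0.1, test=0.1)
205 | #   inchikey_dict = make_inchikey_dict(mol_list)
206 | #   splits = make_train_val_test_split_inchikey_lists(
207 | #       list(inchikey_dict.keys()), inchikey_dict, fractions,
208 | #       splitting_type='random')
209 | #   train_mols = make_mol_list_from_inchikey_dict(inchikey_dict, splits.train)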
--------------------------------------------------------------------------------
/molecule_estimator_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for molecule_estimator."""
16 |
17 | from __future__ import print_function
18 | import json
19 | import os
20 | import tempfile
21 |
22 |
23 | from absl.testing import absltest
24 | from absl.testing import parameterized
25 |
26 | import dataset_setup_constants as ds_constants
27 | import mass_spec_constants as ms_constants
28 | import molecule_estimator
29 | import molecule_predictors
30 | import parse_sdf_utils
31 | import plot_spectra_utils
32 | import test_utils
33 | import numpy as np
34 | import tensorflow as tf
35 |
36 |
37 | class MoleculeEstimatorTest(tf.test.TestCase, parameterized.TestCase):
38 |
39 | def setUp(self):
40 | """Sets up a dataset json for regular, baseline, and all_predicted cases."""
41 | super(MoleculeEstimatorTest, self).setUp()
42 | self.test_data_directory = test_utils.test_dir('testdata/')
43 | record_file = os.path.join(self.test_data_directory, 'test_14_record.gz')
44 |
45 | self.num_eval_examples = parse_sdf_utils.parse_info_file(record_file)[
46 | 'num_examples']
47 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir())
48 | self.default_dataset_config_file = os.path.join(self.temp_dir,
49 | 'dataset_config.json')
50 | self.baseline_dataset_config_file = os.path.join(
51 | self.temp_dir, 'baseline_dataset_config.json')
52 | self.all_predicted_dataset_config_file = os.path.join(
53 | self.temp_dir, 'all_predicted_dataset_config.json')
54 |
55 | dataset_names = [
56 | ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY,
57 | ds_constants.SPECTRUM_PREDICTION_TEST_KEY,
58 | ds_constants.LIBRARY_MATCHING_OBSERVED_KEY,
59 | ds_constants.LIBRARY_MATCHING_PREDICTED_KEY,
60 | ds_constants.LIBRARY_MATCHING_QUERY_KEY
61 | ]
62 |
63 | default_dataset_config = {key: [record_file] for key in dataset_names}
64 | default_dataset_config[
65 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = os.path.join(
66 | self.test_data_directory, 'test_14.spectra_library.npy')
67 | with tf.gfile.Open(self.default_dataset_config_file, 'w') as f:
68 | json.dump(default_dataset_config, f)
69 |
70 | # Test estimator behavior when predicted set is empty
71 | baseline_dataset_config = dict(
72 | [(key, [record_file])
73 | if key != ds_constants.LIBRARY_MATCHING_PREDICTED_KEY else (key, [])
74 | for key in dataset_names])
75 | baseline_dataset_config[
76 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = os.path.join(
77 | self.test_data_directory, 'test_14.spectra_library.npy')
78 | with tf.gfile.Open(self.baseline_dataset_config_file, 'w') as f:
79 | json.dump(baseline_dataset_config, f)
80 |
81 | # Test estimator behavior when observed set is empty
82 | all_predicted_dataset_config = dict(
83 | [(key, [record_file])
84 | if key != ds_constants.LIBRARY_MATCHING_OBSERVED_KEY else (key, [])
85 | for key in dataset_names])
86 | all_predicted_dataset_config[
87 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = os.path.join(
88 | self.test_data_directory, 'test_14.spectra_library.npy')
89 | with tf.gfile.Open(self.all_predicted_dataset_config_file, 'w') as f:
90 | json.dump(all_predicted_dataset_config, f)
91 |
92 | def tearDown(self):
93 | tf.gfile.DeleteRecursively(self.temp_dir)
94 | super(MoleculeEstimatorTest, self).tearDown()
95 |
96 | def _get_loss_history(self, checkpoint_dir):
97 | """Get list of train losses from events file."""
98 | losses = []
99 | for event_file in tf.gfile.Glob(
100 | os.path.join(checkpoint_dir, 'events.out.tfevents.*')):
101 | for event in tf.train.summary_iterator(event_file):
102 | for v in event.summary.value:
103 | if v.tag == 'loss':
104 | losses.append(v.simple_value)
105 | return losses
106 |
107 | def _run_estimator(self, prediction_helper, get_hparams, dataset_config_file):
108 | """Helper function for running molecule_estimator."""
109 | checkpoint_dir = self.temp_dir
110 | config = tf.contrib.learn.RunConfig(
111 | model_dir=checkpoint_dir, save_summary_steps=1)
112 | (estimator, train_spec,
113 | eval_spec) = molecule_estimator.make_estimator_and_inputs(
114 | config,
115 | get_hparams(),
116 | prediction_helper,
117 | dataset_config_file,
118 | train_steps=10,
119 | model_dir=self.temp_dir)
120 |
121 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
122 | loss_history = self._get_loss_history(checkpoint_dir)
123 |
124 | init_loss = loss_history[0]
125 | loss = loss_history[-1]
126 | if not np.isfinite(loss):
127 | raise ValueError('Final loss is not finite: %f' % loss)
128 |
129 | tf.logging.info('initial loss : {} final loss : {}'.format(init_loss, loss))
130 |
131 | self.assertNotEqual(loss, init_loss,
132 | ('Loss did not change after brief testing: '
133 | 'init = %f, final = %f.') % (init_loss, loss))
134 |
135 | @parameterized.parameters(
136 | ('linear', 0, 'generalized_mse', False),
137 | ('mlp', 1, 'generalized_mse', True, True),
138 | ('linear', 0, 'cross_entropy', True),
139 | ('mlp', 2, 'generalized_mse', False),
140 | ('linear', 0, 'max_margin', True),
141 | ('smiles_rnn', 0, 'generalized_mse', True, True))
142 | def test_run_estimator(self, model_type, num_hidden_layers, loss_type,
143 | do_library_matching, bidirectional_prediction=False):
144 | """Integration test for molecule_estimator."""
145 | prediction_helper = molecule_predictors.get_prediction_helper(model_type)
146 |
147 | def get_hparams():
148 | hparams = prediction_helper.get_default_hparams()
149 | hparams.set_hparam('loss', loss_type)
150 | hparams.set_hparam('do_library_matching', do_library_matching)
151 | hparams.set_hparam('bidirectional_prediction', bidirectional_prediction)
152 |
153 | # To test batching and padding in library matching, set the
154 | # eval_batch_size such that it does not divide the number of examples
155 | # in the test set.
156 | eval_batch_size = np.int32(np.floor(self.num_eval_examples / 2) - 1)
157 | assert eval_batch_size > 0, ('The evaluation data is not big enough to '
158 | 'support using multiple batches, where the '
159 | 'batch size does not divide the total '
160 | 'number of examples.')
161 | hparams.set_hparam('eval_batch_size', eval_batch_size)
162 |
163 | if model_type == 'mlp':
164 | hparams.set_hparam('num_hidden_layers', num_hidden_layers)
165 | return hparams
166 |
167 | self._run_estimator(prediction_helper, get_hparams,
168 | self.default_dataset_config_file)
169 |
170 | def test_run_estimator_on_baseline(self):
171 | prediction_helper = molecule_predictors.get_prediction_helper('baseline')
172 | self._run_estimator(prediction_helper,
173 | prediction_helper.get_default_hparams,
174 | self.baseline_dataset_config_file)
175 |
176 | def test_run_estimator_on_all_predicted(self):
177 | prediction_helper = molecule_predictors.get_prediction_helper('mlp')
178 | self._run_estimator(prediction_helper,
179 | prediction_helper.get_default_hparams,
180 | self.all_predicted_dataset_config_file)
181 |
182 | def test_plot_true_and_predicted_spectra(self):
183 | """Test if plot is successfully generated given two spectra."""
184 | max_mass_spec_peak_loc = ms_constants.MAX_PEAK_LOC
185 | true_spectra = np.zeros(max_mass_spec_peak_loc)
186 | predicted_spectra = np.zeros(max_mass_spec_peak_loc)
187 | true_spectra[3:6] = 60
188 | predicted_spectra[300] = 999
189 | true_spectra[200] = 780
190 |
191 | test_figure_path_name = os.path.join(self.temp_dir, 'test.png')
192 | generated_plot = plot_spectra_utils.plot_true_and_predicted_spectra(
193 | true_spectra, predicted_spectra, output_filename=test_figure_path_name)
194 |
195 | self.assertEqual(
196 | np.shape(generated_plot),
197 | plot_spectra_utils.SPECTRA_PLOT_DIMENSIONS_RGB)
198 | self.assertTrue(os.path.exists(test_figure_path_name))
199 |
200 |
201 | if __name__ == '__main__':
202 | tf.test.main()
203 |
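204 | # Sanity-check arithmetic for the eval_batch_size choice above (a sketch,
205 | # assuming the test TFRecord holds 12 examples): floor(12 / 2) - 1 = 5, and
206 | # 12 % 5 != 0, so the final library-matching batch is smaller than the rest
207 | # and exercises the padding path.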
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Some general-purpose helper functions."""
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import numpy as np
21 | import tensorflow as tf
22 |
23 |
24 | def _get_ckpt_from_path(path):
25 | ckpt = tf.train.latest_checkpoint(path)
26 | if ckpt is None:
27 | raise ValueError('No checkpoint found in %s' % path)
28 | tf.logging.info('Reading from checkpoint %s', ckpt)
29 | return ckpt
30 |
31 |
32 | def run_graph_and_process_results(ops_to_fetch,
33 | model_checkpoint_path,
34 | process_fetched_values_fn,
35 | feed_dict=None,
36 | logging_frequency=10):
37 | """Run a graph repeatedly and use the fetched values.
38 |
39 | Args:
40 | ops_to_fetch: a single Tensor or nested structure of Tensors. The graph will
41 |       be run, and process_fetched_values_fn will be called, until a
42 | tf.errors.OutOfRangeError is caught. This is thrown when a tf.data.Dataset
43 | runs out of data.
44 | model_checkpoint_path: Path to model checkpoint. If a directory, the most
45 | recent model checkpoint in this directory will be used.
46 | process_fetched_values_fn: A callable, potentially with side-effects, that
47 | takes as input the output of sess.run(ops_to_fetch).
48 | feed_dict: a feed_dict to be included in sess.run calls.
49 | logging_frequency: after this many batches have been processed, a logging
50 | message will be printed.
51 | """
52 | ckpt = _get_ckpt_from_path(model_checkpoint_path)
53 | saver = tf.train.Saver()
54 |
55 | with tf.Session() as sess:
56 | saver.restore(sess, ckpt)
57 | counter = 0
58 | while True:
59 | try:
60 | fetched_values = sess.run(ops_to_fetch, feed_dict=feed_dict)
61 | process_fetched_values_fn(fetched_values)
62 | counter += 1
63 | if counter % logging_frequency == 0:
64 |           tf.logging.info('Total batches processed so far: %d', counter)
65 | except tf.errors.OutOfRangeError:
66 | tf.logging.info('Finished processing data. Processed %d batches',
67 | counter)
68 | break
69 |
70 |
71 | def map_predictor(input_op, predictor_fn, sub_batch_size):
72 | """Wrapper for tf.map_fn to do batched computation within each map step."""
73 |
74 | num_elements = tf.contrib.framework.nest.flatten(input_op)[0].shape[0].value
75 |
76 | # Only chop the batch dim into sub-batches if the input data is big.
77 | if num_elements < sub_batch_size:
78 | return predictor_fn(input_op)
79 |
80 | pad_amount = -num_elements % sub_batch_size
81 |
82 | def reshape(tensor):
83 | """Reshape into batches of sub-batches."""
84 | pad_shape = tensor.shape.as_list()
85 | pad_shape[0] = pad_amount
86 | padding = tf.zeros(shape=pad_shape, dtype=tensor.dtype)
87 | tensor = tf.concat([tensor, padding], axis=0)
88 | if tensor.shape[0].value % sub_batch_size != 0:
89 |       raise ValueError('Incorrect padding size: %d does not '
90 | 'divide %d' % (sub_batch_size, tensor.shape[0].value))
91 |
92 | shape = tensor.shape.as_list()
93 | output_shape = [-1, sub_batch_size] + shape[1:]
94 | return tf.reshape(tensor, shape=output_shape)
95 |
96 | reshaped_inputs = tf.contrib.framework.nest.map_structure(reshape, input_op)
97 |
98 | mapped_prediction = tf.map_fn(
99 | predictor_fn,
100 | reshaped_inputs,
101 | parallel_iterations=1,
102 | back_prop=False,
103 | name=None,
104 | dtype=tf.float32)
105 |
106 | output_shape = [-1] + mapped_prediction.shape.as_list()[2:]
107 | reshaped_output = tf.reshape(mapped_prediction, shape=output_shape)
108 |
109 | # If padding was required for the input data, strip off the output of the
110 | # predictor on this padding.
111 | if pad_amount > 0:
112 | reshaped_output = reshaped_output[0:(-pad_amount), ...]
113 |
114 | return reshaped_output
115 |
116 |
117 | def get_static_shape_without_adding_ops(inputs, fn):
118 | """Get the shape of fn(inputs) without adding ops to the default graph.
119 |
120 | Operationally equivalent to fn(inputs).shape.as_list(), except that no
121 | ops are added to the default graph.
122 |
123 | In order to get the shape of fn(inputs) without adding ops to the graph
124 | we make a new graph, make placeholders with the right shape, construct
125 | fn(placeholders) in that graph, get the shape, and then delete the graph.
126 |
127 | Note that using this function may have unintended consequences if fn() has
128 | side effects.
129 |
130 | Args:
131 | inputs: a (nested) structure where the leaf elements are either Tensors or
132 | None.
133 | fn: a function that can be applied to inputs and returns a single Tensor.
134 | Returns:
135 | a python list containing the static shape of fn(inputs).
136 |
137 | """
138 | g = tf.Graph()
139 | with g.as_default():
140 | def make_placeholder(tensor):
141 | if tensor is None:
142 | return None
143 | else:
144 | return tf.placeholder(shape=tensor.shape, dtype=tensor.dtype)
145 |
146 | placeholders = tf.contrib.framework.nest.map_structure(make_placeholder,
147 | inputs)
148 | output_shape = fn(placeholders).shape.as_list()
149 |
150 | del g
151 | return output_shape
152 |
153 |
154 | def value_op_with_initializer(value_op_fn, init_op_fn):
155 | """Make value_op that gets set by idempotent init_op on first invocation."""
156 |
157 | init_has_been_run = tf.get_local_variable(
158 | 'has_been_run',
159 | initializer=np.zeros(shape=(), dtype=np.bool),
160 | dtype=tf.bool)
161 |
162 | value_op = value_op_fn()
163 |
164 | def run_init_and_toggle():
165 | init_op = init_op_fn(value_op)
166 |
167 | with tf.control_dependencies([init_op]):
168 | assign_op = init_has_been_run.assign(True)
169 |
170 | with tf.control_dependencies([assign_op]):
171 | return tf.identity(value_op)
172 |
173 | return tf.cond(init_has_been_run, lambda: value_op, run_init_and_toggle)
174 |
175 |
176 | def scatter_by_anchor_indices(anchor_indices, data, index_shift):
177 | """Shift data such that it is indexed relative to anchor_indices.
178 |
179 | For each row of the data array, we flip it horizontally and then shift it
180 | so that the output at (anchor_index + index_shift) is the leftmost column
181 | of the input. Namely:
182 |
183 | output[i][j] = data[i][anchor_indices[i] - j + index_shift]
184 |
185 | Args:
186 | anchor_indices: [batch_size] int Tensor or np array
187 | data: [batch_size, num_columns]: float Tensor or np array
188 | index_shift: int
189 | Returns:
190 | [batch_size, num_columns] Tensor
191 | """
192 | anchor_indices = tf.convert_to_tensor(anchor_indices)
193 | data = tf.convert_to_tensor(data)
194 |
195 | num_data_columns = data.shape[-1].value
196 | indices = np.arange(num_data_columns)[np.newaxis, ...]
197 | shifted_indices = anchor_indices[..., tf.newaxis] - indices + index_shift
198 | valid_indices = shifted_indices >= 0
199 |
200 | batch_size = tf.shape(data)[0]
201 |
202 | batch_indices = tf.tile(
203 | tf.range(batch_size)[..., tf.newaxis], [1, num_data_columns])
204 | shifted_indices += batch_indices * num_data_columns
205 |
206 | shifted_indices = tf.reshape(shifted_indices, [-1])
207 | num_elements = tf.shape(data)[0] * tf.shape(data)[1]
208 | row_indices = tf.range(num_elements)
209 | stacked_indices = tf.stack([row_indices, shifted_indices], axis=1)
210 |
211 | lower_batch_boundaries = tf.reshape(batch_indices * num_data_columns, [-1])
212 | upper_batch_boundaries = tf.reshape(((batch_indices + 1) * num_data_columns),
213 | [-1])
214 | valid_indices = tf.logical_and(shifted_indices >= lower_batch_boundaries,
215 | shifted_indices < upper_batch_boundaries)
216 | stacked_indices = tf.boolean_mask(
217 | stacked_indices,
218 | valid_indices,
219 | )
220 |
221 | dense_shape = tf.cast(tf.tile(num_elements[..., tf.newaxis], [2]), tf.int64)
222 |
223 | scattering_matrix = tf.SparseTensor(
224 | indices=tf.cast(stacked_indices, tf.int64),
225 | values=tf.ones_like(stacked_indices[:, 0], dtype=data.dtype),
226 | dense_shape=dense_shape)
227 |
228 | flattened_data = tf.reshape(data, [-1])[..., tf.newaxis]
229 | flattened_output = tf.sparse_tensor_dense_matmul(
230 | scattering_matrix,
231 | flattened_data,
232 | adjoint_a=False,
233 | adjoint_b=False,
234 | name=None)
235 |
236 | return tf.reshape(
237 | tf.transpose(flattened_output, [0, 1]), [-1, num_data_columns])
238 |
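239 | # Worked example for scatter_by_anchor_indices (values chosen for
240 | # illustration): with anchor_indices=[2], data=[[10., 20., 30., 40.]] and
241 | # index_shift=0, output[0][j] = data[0][2 - j], giving
242 | # [[30., 20., 10., 0.]]; out-of-range source indices contribute zeros.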
--------------------------------------------------------------------------------
/spectra_predictor.py:
--------------------------------------------------------------------------------
1 | """Helpers for generating spectra prediction from trained models."""
2 |
3 | import abc
4 |
5 | import feature_map_constants as fmap_constants
6 | import feature_utils
7 | import mass_spec_constants as ms_constants
8 | import molecule_predictors
9 |
10 | import numpy as np
11 | from rdkit import Chem
12 | from rdkit.Chem import AllChem
13 | import six
14 | import tensorflow as tf
15 |
16 | _DEFAULT_HPARAMS = {
17 | "radius": 2,
18 | "mass_power": 1.0,
19 | "gate_bidirectional_predictions": True,
20 | "include_atom_mass": True,
21 | "init_bias": "default",
22 | "reverse_prediction": True,
23 | "max_mass_spec_peak_loc": 1000,
24 | "num_hidden_units": 2000,
25 | "use_counting_fp": True,
26 | "max_atoms": 100,
27 | "intensity_power": 0.5,
28 | "max_prediction_above_molecule_mass": 5,
29 | "fp_length": 4096,
30 | "bidirectional_prediction": True,
31 | "resnet_bottleneck_factor": 0.5,
32 | "max_atom_type": 100,
33 | "hidden_layer_activation": "relu",
34 | "init_weights": "default",
35 | "num_hidden_layers": 7
36 | }
37 | _DEFAULT_HPARAMS_STR = ",".join(
38 | "{}={}".format(k, v) for k, v in six.iteritems(_DEFAULT_HPARAMS))
39 | PREDICTED_SPECTRA_PROP_NAME = "PREDICTED SPECTRUM"
40 |
41 | # Predictions from the model are normalized by default.
42 | # This factor is used to rescale the predictions so the highest intensity has
43 | # this value.
44 | SCALE_FACTOR_FOR_LARGEST_INTENSITY = 999.
45 |
46 |
47 | def fingerprints_to_use(hparams):
48 | """Given tf.HParams, return a ms_constants.CircularFingerprintKey."""
49 | if hparams.use_counting_fp:
50 | key = fmap_constants.COUNTING_CIRCULAR_FP_BASENAME
51 | else:
52 | key = fmap_constants.CIRCULAR_FP_BASENAME
53 |
54 | return ms_constants.CircularFingerprintKey(key, hparams.fp_length,
55 | hparams.radius)
56 |
57 |
58 | def get_mol_weights_from_mol_list(mol_list):
59 | """Given a list of rdkit.Mols, return weights for each mol."""
60 | return np.array([Chem.rdMolDescriptors.CalcExactMolWt(m) for m in mol_list])
61 |
62 |
63 | def get_mol_list_from_sdf(sdf_fname):
64 |   """Reads an SDF file and returns a list of molecules.
65 |
66 |   Note: rdkit's Chem.SDMolSupplier only accepts filenames as inputs, so this
67 |   code only supports files on the local filesystem.
68 |
69 | Args:
70 | sdf_fname: Path to sdf file.
71 |
72 | Returns:
73 | List of rdkit.Mol objects.
74 |
75 | Raises:
76 | ValueError if a molblock in the SDF cannot be parsed.
77 | """
78 | suppl = Chem.SDMolSupplier(sdf_fname)
79 | mols = []
80 |
81 | for idx, mol in enumerate(suppl):
82 | if mol is not None:
83 | mols.append(mol)
84 | else:
85 | fail_sdf_block = suppl.GetItemText(idx)
86 | raise ValueError("Unable to parse the following mol block %s" %
87 | fail_sdf_block)
88 | return mols
89 |
90 |
91 | def update_mols_with_spectra(mol_list, spectra_array):
92 | """Writes a predicted spectrum for each RDKit.mol object.
93 |
94 | Args:
95 | mol_list: List of rdkit.Mol objects.
96 | spectra_array: np.array of spectra.
97 |
98 | Returns:
99 | Updated list of rdkit.Mol objects where each molecule contains a predicted
100 | spectrum.
101 | """
102 | if len(mol_list) != np.shape(spectra_array)[0]:
103 | raise ValueError("Number of mols in mol list %d is not equal to number of "
104 | "spectra found %d." %
105 | (len(mol_list), np.shape(spectra_array)[0]))
106 | for mol, spectrum in zip(mol_list, spectra_array):
107 | spec_array_text = feature_utils.convert_spectrum_array_to_string(spectrum)
108 | mol.SetProp(PREDICTED_SPECTRA_PROP_NAME, spec_array_text)
109 | return mol_list
110 |
111 |
112 | def write_rdkit_mols_to_sdf(mol_list, out_sdf_name):
113 | """Writes a series of rdkit.Mol to SDF.
114 |
115 | Args:
116 | mol_list: List of rdkit.Mol objects.
117 | out_sdf_name: Output file path for molecules.
118 | """
119 | writer = AllChem.SDWriter(out_sdf_name)
120 |
121 | for mol in mol_list:
122 | writer.write(mol)
123 | writer.close()
124 |
125 |
126 | class SpectraPredictor(object):
127 | """Helper for generating a computational graph for making predictions."""
128 | __metaclass__ = abc.ABCMeta
129 |
130 | def __init__(self, hparams_str=""):
131 | """Sets up graph, session, and input and output ops for prediction.
132 |
133 | Args:
134 | hparams_str (str): String containing hyperparameter settings.
135 | """
136 |
137 | self._prediction_helper = molecule_predictors.get_prediction_helper("mlp")
138 | self._hparams = self._prediction_helper.get_default_hparams()
139 | self._hparams.parse(hparams_str)
140 | self._fingerprint_key = fingerprints_to_use(self._hparams)
141 | self.fingerprint_input_key = str(self._fingerprint_key)
142 | self.molecular_weight_key = fmap_constants.MOLECULE_WEIGHT
143 |
144 | self._graph = tf.Graph()
145 | self._sess = tf.Session(graph=self._graph)
146 | with self._graph.as_default():
147 | (self._placeholder_dict, self._predict_op) = self._setup_prediction_op()
148 | assert set(self._placeholder_dict) == set(
149 | [self.fingerprint_input_key, self.molecular_weight_key])
150 |
151 | @abc.abstractmethod
152 | def _setup_prediction_op(self):
153 | """Sets up prediction operation.
154 |
155 | Returns:
156 | placeholder_dict: Dict with self.fingerprint_input_key and
157 | self.molecular_weight_key as keys and values which are tf.placeholder
158 | for predicted spectra.
159 | predict_op: tf.Tensor for predicted spectra.
160 | """
161 |
162 | def make_spectra_prediction(self, fingerprint_array, molecule_weight_array):
163 | """Makes spectra prediction.
164 |
165 | Args:
166 |       fingerprint_array (np.array): Contains molecule fingerprints.
167 | molecule_weight_array (np.array): Contains molecular weights. Should have
168 | same batch dimension as fingerprint_array.
169 |
170 | Returns:
171 | np.array of predictions.
172 | """
173 | molecule_weight_array = np.reshape(molecule_weight_array, (-1, 1))
174 | with self._graph.as_default():
175 | prediction = self._sess.run(
176 | self._predict_op,
177 | feed_dict={
178 | self._placeholder_dict[self.fingerprint_input_key]:
179 | fingerprint_array,
180 | self._placeholder_dict[self.molecular_weight_key]:
181 | molecule_weight_array
182 | })
183 |
184 | prediction = prediction / np.max(
185 | prediction, axis=1, keepdims=True) * SCALE_FACTOR_FOR_LARGEST_INTENSITY
186 | return prediction
187 |
188 | def get_fingerprints_from_mol_list(self, mol_list):
189 | """Converts a list of rdkit.Mol objects into circular fingerprints.
190 |
191 | Args:
192 | mol_list: List of rdkit.Mol objects.
193 |
194 | Returns:
195 | np.array of fingerprints for prediction.
196 | """
197 |
198 | fingerprints = [
199 | feature_utils.make_circular_fingerprint(mol, self._fingerprint_key)
200 | for mol in mol_list
201 | ]
202 |
203 | return np.array(fingerprints)
204 |
205 | def get_inputs_for_model_from_mol_list(self, mol_list):
206 | """Grabs fingerprints and molecular weights for the prediction model."""
207 | fingerprints = self.get_fingerprints_from_mol_list(mol_list)
208 | weights = get_mol_weights_from_mol_list(mol_list)
209 | return fingerprints, weights
210 |
211 |
212 | class NeimsSpectraPredictor(SpectraPredictor):
213 | """Helper for making spectra predictions using the trained NEIMS model."""
214 |
215 | def __init__(self, model_checkpoint_dir, hparams_str=_DEFAULT_HPARAMS_STR):
216 | """Initializes the predictor with the weights and hyperparameters.
217 |
218 | Args:
219 | model_checkpoint_dir (str): Path to checkpoint weights.
220 | hparams_str (str): String that contains hyperparameters for model.
221 | """
222 | super(NeimsSpectraPredictor, self).__init__(hparams_str)
223 | self.restore_from_checkpoint(model_checkpoint_dir)
224 |
225 | def _setup_prediction_op(self):
226 | """Sets up prediction operation and inputs for model."""
227 | fp_length = self._hparams.fp_length
228 |
229 | fingerprint_input_op = tf.placeholder(tf.float32, (None, fp_length))
230 | mol_weight_input_op = tf.placeholder(tf.float32, (None, 1))
231 |
232 | feature_dict = {
233 | self.fingerprint_input_key: fingerprint_input_op,
234 | self.molecular_weight_key: mol_weight_input_op
235 | }
236 |
237 | predict_op, _ = self._prediction_helper.make_prediction_ops(
238 | feature_dict,
239 | self._hparams,
240 | mode=tf.estimator.ModeKeys.PREDICT,
241 | reuse=False)
242 |
243 | return feature_dict, predict_op
244 |
245 | def restore_from_checkpoint(self, model_checkpoint_dir):
246 | """Restores model parameters from checkpoint directory.
247 |
248 | Args:
249 |       model_checkpoint_dir (str): Path to the directory containing checkpoint
250 |         weights. If empty, the model will be initialized with random weights.
251 | """
252 | with self._graph.as_default():
253 | if model_checkpoint_dir:
254 | saver = tf.train.Saver()
255 | saver.restore(self._sess,
256 | tf.train.latest_checkpoint(model_checkpoint_dir))
257 | else:
258 | tf.logging.warn("No model checkpoint directory given,"
259 | " reinitializing model.")
260 | self._sess.run(tf.global_variables_initializer())
261 |
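262 | # Minimal end-to-end sketch (the SDF and checkpoint paths are assumptions):
263 | #
264 | #   predictor = NeimsSpectraPredictor(model_checkpoint_dir='/tmp/neims_ckpt')
265 | #   mols = get_mol_list_from_sdf('/tmp/molecules.sdf')
266 | #   fingerprints, weights = predictor.get_inputs_for_model_from_mol_list(mols)
267 | #   spectra = predictor.make_spectra_prediction(fingerprints, weights)
268 | #   mols = update_mols_with_spectra(mols, spectra)
269 | #   write_rdkit_mols_to_sdf(mols, '/tmp/molecules_with_spectra.sdf')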
--------------------------------------------------------------------------------
/dataset_setup_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Constants for creating dataset for mass spectrometry experiment."""
16 |
17 | from typing import NamedTuple
18 |
19 | # List of components of the main NIST library and the replicates library.
20 | # The main NIST library is divided into a train/validation/test set
21 | # The replicates library is divided into validation/test sets.
22 | MAINLIB_TRAIN_BASENAME = 'mainlib_train'
23 | MAINLIB_VALIDATION_BASENAME = 'mainlib_validation'
24 | MAINLIB_TEST_BASENAME = 'mainlib_test'
25 | REPLICATES_TRAIN_BASENAME = 'replicates_train'
26 | REPLICATES_VALIDATION_BASENAME = 'replicates_validation'
27 | REPLICATES_TEST_BASENAME = 'replicates_test'
28 |
29 | # Key names of main datasets required for each experiment.
30 | SPECTRUM_PREDICTION_TRAIN_KEY = 'SPECTRUM_PREDICTION_TRAIN'
31 | SPECTRUM_PREDICTION_TEST_KEY = 'SPECTRUM_PREDICTION_TEST'
32 | LIBRARY_MATCHING_OBSERVED_KEY = 'LIBRARY_MATCHING_OBSERVED'
33 | LIBRARY_MATCHING_PREDICTED_KEY = 'LIBRARY_MATCHING_PREDICTED'
34 | LIBRARY_MATCHING_QUERY_KEY = 'LIBRARY_MATCHING_QUERY'
35 |
36 | TRAINING_SPECTRA_ARRAY_KEY = MAINLIB_TRAIN_BASENAME + '_spectra_library_file'
37 |
38 |
39 | class ExperimentSetup(
40 | NamedTuple('ExperimentSetup',
41 | [('json_name', str), ('data_to_get_from_mainlib', list),
42 | ('data_to_get_from_replicates', list),
43 | ('experiment_setup_dataset_dict', dict)])):
44 | """Stores information related to the various experiment setups.
45 |
46 | Attributes:
47 |     json_name : name of the json file to store locations of the datasets
48 | data_to_get_from_mainlib: List of dataset keys to grab from the mainlib
49 | library.
50 | data_to_get_from_replicates: List of dataset keys to grab from the
51 | replicates library.
52 | experiment_setup_dataset_dict: Dict which matches the experiment keys to
53 | lists of the component datasets, matching the basename keys above.
54 | """
55 |
56 | def __new__(cls, json_name, data_to_get_from_mainlib,
57 | data_to_get_from_replicates, experiment_setup_dataset_dict):
58 | assert (experiment_setup_dataset_dict[LIBRARY_MATCHING_QUERY_KEY] ==
59 | experiment_setup_dataset_dict[SPECTRUM_PREDICTION_TEST_KEY]), (
60 | 'In json {}, library query list did not match'
61 | ' spectrum prediction list.'.format(json_name))
62 | assert (experiment_setup_dataset_dict[SPECTRUM_PREDICTION_TRAIN_KEY] ==
63 | [MAINLIB_TRAIN_BASENAME]), (
64 | 'In json {}, spectra prediction dataset is not mainlib_train,'
65 | ' which is currently not supported.'.format(json_name))
66 | return super(ExperimentSetup, cls).__new__(
67 | cls, json_name, data_to_get_from_mainlib, data_to_get_from_replicates,
68 | experiment_setup_dataset_dict)
69 |
70 |
71 | # Experiment setups:
72 | QUERY_MAINLIB_VAL_PRED_MAINLIB_VAL = ExperimentSetup(
73 | 'query_mainlib_val_predicted_mainlib_val.json', [
74 | SPECTRUM_PREDICTION_TRAIN_KEY,
75 | SPECTRUM_PREDICTION_TEST_KEY,
76 | LIBRARY_MATCHING_QUERY_KEY,
77 | LIBRARY_MATCHING_OBSERVED_KEY,
78 | LIBRARY_MATCHING_PREDICTED_KEY,
79 | ], [], {
80 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
81 | SPECTRUM_PREDICTION_TEST_KEY: [MAINLIB_VALIDATION_BASENAME],
82 | LIBRARY_MATCHING_OBSERVED_KEY: [
83 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME,
84 | REPLICATES_TRAIN_BASENAME, REPLICATES_VALIDATION_BASENAME,
85 | REPLICATES_TEST_BASENAME
86 | ],
87 | LIBRARY_MATCHING_PREDICTED_KEY: [MAINLIB_VALIDATION_BASENAME],
88 | LIBRARY_MATCHING_QUERY_KEY: [MAINLIB_VALIDATION_BASENAME],
89 | })
90 |
91 | QUERY_REPLICATES_VAL_PRED_REPLICATES_VAL = ExperimentSetup(
92 | 'query_replicates_val_predicted_replicates_val.json', [
93 | SPECTRUM_PREDICTION_TRAIN_KEY,
94 | SPECTRUM_PREDICTION_TEST_KEY,
95 | LIBRARY_MATCHING_OBSERVED_KEY,
96 | LIBRARY_MATCHING_PREDICTED_KEY,
97 | ], [LIBRARY_MATCHING_QUERY_KEY], {
98 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
99 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_VALIDATION_BASENAME],
100 | LIBRARY_MATCHING_OBSERVED_KEY: [
101 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME,
102 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME,
103 | REPLICATES_TEST_BASENAME
104 | ],
105 | LIBRARY_MATCHING_PREDICTED_KEY: [REPLICATES_VALIDATION_BASENAME],
106 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_VALIDATION_BASENAME],
107 | })
108 |
109 | QUERY_REPLICATES_TEST_PRED_REPLICATES_TEST = ExperimentSetup(
110 | 'query_replicates_test_predicted_replicates_test.json', [
111 | SPECTRUM_PREDICTION_TRAIN_KEY,
112 | SPECTRUM_PREDICTION_TEST_KEY,
113 | LIBRARY_MATCHING_OBSERVED_KEY,
114 | LIBRARY_MATCHING_PREDICTED_KEY,
115 | ], [LIBRARY_MATCHING_QUERY_KEY], {
116 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
117 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_TEST_BASENAME],
118 | LIBRARY_MATCHING_OBSERVED_KEY: [
119 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME,
120 | REPLICATES_TRAIN_BASENAME, MAINLIB_VALIDATION_BASENAME,
121 | REPLICATES_VALIDATION_BASENAME
122 | ],
123 | LIBRARY_MATCHING_PREDICTED_KEY: [REPLICATES_TEST_BASENAME],
124 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_TEST_BASENAME],
125 | })
126 |
127 | QUERY_REPLICATES_VAL_PRED_NONE = ExperimentSetup(
128 | 'query_replicates_val_predicted_none.json', [
129 | SPECTRUM_PREDICTION_TRAIN_KEY,
130 | SPECTRUM_PREDICTION_TEST_KEY,
131 | LIBRARY_MATCHING_OBSERVED_KEY,
132 | LIBRARY_MATCHING_PREDICTED_KEY,
133 | ], [LIBRARY_MATCHING_QUERY_KEY], {
134 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
135 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_VALIDATION_BASENAME],
136 | LIBRARY_MATCHING_OBSERVED_KEY: [
137 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME,
138 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME,
139 | REPLICATES_TEST_BASENAME, REPLICATES_VALIDATION_BASENAME
140 | ],
141 | LIBRARY_MATCHING_PREDICTED_KEY: [],
142 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_VALIDATION_BASENAME],
143 | })
144 |
145 | QUERY_REPLICATES_TEST_PRED_NONE = ExperimentSetup(
146 | 'query_replicates_test_predicted_none.json', [
147 | SPECTRUM_PREDICTION_TRAIN_KEY,
148 | SPECTRUM_PREDICTION_TEST_KEY,
149 | LIBRARY_MATCHING_OBSERVED_KEY,
150 | LIBRARY_MATCHING_PREDICTED_KEY,
151 | ], [LIBRARY_MATCHING_QUERY_KEY], {
152 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
153 | SPECTRUM_PREDICTION_TEST_KEY: [REPLICATES_TEST_BASENAME],
154 | LIBRARY_MATCHING_OBSERVED_KEY: [
155 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME,
156 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME,
157 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME
158 | ],
159 | LIBRARY_MATCHING_PREDICTED_KEY: [],
160 | LIBRARY_MATCHING_QUERY_KEY: [REPLICATES_TEST_BASENAME],
161 | })
162 |
163 | QUERY_REPLICATES_ALL_PRED_NONE = ExperimentSetup(
164 | 'query_replicates_all_predicted_none.json', [
165 | SPECTRUM_PREDICTION_TRAIN_KEY,
166 | SPECTRUM_PREDICTION_TEST_KEY,
167 | LIBRARY_MATCHING_OBSERVED_KEY,
168 | LIBRARY_MATCHING_PREDICTED_KEY,
169 | ], [LIBRARY_MATCHING_QUERY_KEY], {
170 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
171 | SPECTRUM_PREDICTION_TEST_KEY: [
172 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME
173 | ],
174 | LIBRARY_MATCHING_OBSERVED_KEY: [
175 | MAINLIB_TRAIN_BASENAME, MAINLIB_TEST_BASENAME,
176 | MAINLIB_VALIDATION_BASENAME, REPLICATES_TRAIN_BASENAME,
177 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME
178 | ],
179 | LIBRARY_MATCHING_PREDICTED_KEY: [],
180 | LIBRARY_MATCHING_QUERY_KEY: [
181 | REPLICATES_VALIDATION_BASENAME, REPLICATES_TEST_BASENAME
182 | ],
183 | })
184 |
185 | # An overfitting setup for sanity checks
186 | QUERY_MAINLIB_TRAIN_PRED_MAINLIB_TRAIN = ExperimentSetup(
187 | 'query_mainlib_train_predicted_mainlib_train.json', [
188 | SPECTRUM_PREDICTION_TRAIN_KEY, SPECTRUM_PREDICTION_TEST_KEY,
189 | LIBRARY_MATCHING_OBSERVED_KEY, LIBRARY_MATCHING_PREDICTED_KEY,
190 | LIBRARY_MATCHING_QUERY_KEY
191 | ], [], {
192 | SPECTRUM_PREDICTION_TRAIN_KEY: [MAINLIB_TRAIN_BASENAME],
193 | SPECTRUM_PREDICTION_TEST_KEY: [MAINLIB_TRAIN_BASENAME],
194 | LIBRARY_MATCHING_OBSERVED_KEY: [
195 | MAINLIB_VALIDATION_BASENAME,
196 | MAINLIB_TEST_BASENAME,
197 | REPLICATES_TRAIN_BASENAME,
198 | REPLICATES_VALIDATION_BASENAME,
199 | REPLICATES_TEST_BASENAME,
200 | ],
201 | LIBRARY_MATCHING_PREDICTED_KEY: [MAINLIB_TRAIN_BASENAME],
202 | LIBRARY_MATCHING_QUERY_KEY: [MAINLIB_TRAIN_BASENAME],
203 | })
204 |
205 | EXPERIMENT_SETUPS_LIST = [
206 | QUERY_MAINLIB_VAL_PRED_MAINLIB_VAL,
207 | QUERY_REPLICATES_VAL_PRED_REPLICATES_VAL,
208 | QUERY_REPLICATES_TEST_PRED_REPLICATES_TEST,
209 | QUERY_REPLICATES_VAL_PRED_NONE,
210 | QUERY_REPLICATES_TEST_PRED_NONE,
211 | QUERY_MAINLIB_TRAIN_PRED_MAINLIB_TRAIN,
212 | ]
213 |
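214 | # Illustrative lookup (the value follows directly from the setup above):
215 | #
216 | #   setup = QUERY_MAINLIB_VAL_PRED_MAINLIB_VAL
217 | #   setup.experiment_setup_dataset_dict[LIBRARY_MATCHING_QUERY_KEY]
218 | #   # -> ['mainlib_validation']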
--------------------------------------------------------------------------------
/make_train_test_split_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for make_train_test_split.py and train_test_split_utils.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import json
21 | import os
22 | import tempfile
23 |
24 | from absl.testing import absltest
25 | from absl.testing import parameterized
26 |
27 | import dataset_setup_constants as ds_constants
28 | import feature_map_constants as fmap_constants
29 | import make_train_test_split
30 | import mass_spec_constants as ms_constants
31 | import parse_sdf_utils
32 | import test_utils
33 | import train_test_split_utils
34 | import six
35 | import tensorflow as tf
36 |
37 |
38 | class MakeTrainTestSplitTest(tf.test.TestCase, parameterized.TestCase):
39 |
40 | def setUp(self):
41 | super(MakeTrainTestSplitTest, self).setUp()
42 | test_data_directory = test_utils.test_dir('testdata/')
43 | self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir())
44 | test_sdf_file_large = os.path.join(test_data_directory, 'test_14_mend.sdf')
45 | test_sdf_file_small = os.path.join(test_data_directory, 'test_2_mend.sdf')
46 |
47 | max_atoms = ms_constants.MAX_ATOMS
48 | self.mol_list_large = parse_sdf_utils.get_sdf_to_mol(
49 | test_sdf_file_large, max_atoms=max_atoms)
50 | self.mol_list_small = parse_sdf_utils.get_sdf_to_mol(
51 | test_sdf_file_small, max_atoms=max_atoms)
52 | self.inchikey_dict_large = train_test_split_utils.make_inchikey_dict(
53 | self.mol_list_large)
54 | self.inchikey_dict_small = train_test_split_utils.make_inchikey_dict(
55 | self.mol_list_small)
56 | self.inchikey_list_large = list(self.inchikey_dict_large.keys())
57 | self.inchikey_list_small = list(self.inchikey_dict_small.keys())
58 |
59 | def tearDown(self):
60 | tf.gfile.DeleteRecursively(self.temp_dir)
61 | super(MakeTrainTestSplitTest, self).tearDown()
62 |
63 | def encode(self, value):
64 | """Wrapper function for encoding strings in python 3."""
65 | return test_utils.encode(value, six.PY3)
66 |
67 | def test_all_lists_mutually_exclusive(self):
68 | list1 = ['1', '2', '3']
69 | list2 = ['2', '3', '4']
70 |     # Use assertRaises so that a missing ValueError fails the test; the
71 |     # old try/except pattern swallowed its own failure signal.
72 |     with self.assertRaises(ValueError):
73 |       train_test_split_utils.assert_all_lists_mutally_exclusive(
74 |           [list1, list2])
75 |
76 | def test_make_inchikey_dict(self):
77 | self.assertLen(self.inchikey_dict_large, 11)
78 | self.assertLen(self.inchikey_dict_small, 2)
79 |
80 | def test_make_mol_list_from_inchikey_dict(self):
81 | mol_list = train_test_split_utils.make_mol_list_from_inchikey_dict(
82 | self.inchikey_dict_large, self.inchikey_list_large)
83 | self.assertCountEqual(mol_list, self.mol_list_large)
84 |
85 | def test_make_train_val_test_split_mol_lists(self):
86 | main_train_test_split = train_test_split_utils.TrainValTestFractions(
87 | 0.5, 0.25, 0.25)
88 |
89 | inchikey_list_of_lists = (
90 | train_test_split_utils.make_train_val_test_split_inchikey_lists(
91 | self.inchikey_list_large, self.inchikey_dict_large,
92 | main_train_test_split))
93 |
94 | expected_lengths_of_inchikey_lists = [5, 2, 4]
95 |
96 | for expected_length, inchikey_list in zip(
97 | expected_lengths_of_inchikey_lists, inchikey_list_of_lists):
98 | self.assertLen(inchikey_list, expected_length)
99 |
100 | train_test_split_utils.assert_all_lists_mutally_exclusive(
101 | inchikey_list_of_lists)
102 |
103 | trunc_inchikey_list_large = self.inchikey_list_large[:6]
104 |     inchikey_list_of_lists = (
105 |         train_test_split_utils.make_train_val_test_split_inchikey_lists(
106 |             trunc_inchikey_list_large, self.inchikey_dict_large,
107 |             main_train_test_split))
108 |     # The call returns a (train, val, test) triple of inchikey lists.
109 |
110 | expected_lengths_of_inchikey_lists = [3, 1, 2]
111 | for expected_length, inchikey_list in zip(
112 | expected_lengths_of_inchikey_lists, inchikey_list_of_lists):
113 | self.assertLen(inchikey_list, expected_length)
114 |
115 | train_test_split_utils.assert_all_lists_mutally_exclusive(
116 | inchikey_list_of_lists)
117 |
118 | def test_make_train_val_test_split_mol_lists_holdout(self):
119 | main_train_test_split = train_test_split_utils.TrainValTestFractions(
120 | 0.5, 0.25, 0.25)
121 | holdout_inchikey_list_of_lists = (
122 | train_test_split_utils.make_train_val_test_split_inchikey_lists(
123 | self.inchikey_list_large,
124 | self.inchikey_dict_large,
125 | main_train_test_split,
126 | holdout_inchikey_list=self.inchikey_list_small))
127 |
128 | expected_lengths_of_inchikey_lists = [4, 2, 3]
129 | for expected_length, inchikey_list in zip(
130 | expected_lengths_of_inchikey_lists, holdout_inchikey_list_of_lists):
131 | self.assertLen(inchikey_list, expected_length)
132 |
133 | train_test_split_utils.assert_all_lists_mutally_exclusive(
134 | holdout_inchikey_list_of_lists)
135 |
136 | def test_make_train_val_test_split_mol_lists_family(self):
137 | train_test_split = train_test_split_utils.TrainValTestFractions(
138 | 0.5, 0.25, 0.25)
139 | train_inchikeys, val_inchikeys, test_inchikeys = (
140 | train_test_split_utils.make_train_val_test_split_inchikey_lists(
141 | self.inchikey_list_large,
142 | self.inchikey_dict_large,
143 | train_test_split,
144 | holdout_inchikey_list=self.inchikey_list_small,
145 | splitting_type='diazo'))
146 |
147 | self.assertCountEqual(train_inchikeys, [
148 | 'UFHFLCQGNIYNRP-UHFFFAOYSA-N', 'CCGKOQOJPYTBIH-UHFFFAOYSA-N',
149 | 'ASTNYHRQIBTGNO-UHFFFAOYSA-N', 'UFHFLCQGNIYNRP-VVKOMZTBSA-N',
150 | 'PVVBOXUQVSZBMK-UHFFFAOYSA-N'
151 | ])
152 |
153 | self.assertCountEqual(val_inchikeys + test_inchikeys, [
154 | 'OWKPLCCVKXABQF-UHFFFAOYSA-N', 'COVPJOWITGLAKX-UHFFFAOYSA-N',
155 | 'GKVDXUXIAHWQIK-UHFFFAOYSA-N', 'UCIXUAPVXAZYDQ-VMPITWQZSA-N'
156 | ])
157 |
158 | replicate_train_inchikeys, _, replicate_test_inchikeys = (
159 | train_test_split_utils.make_train_val_test_split_inchikey_lists(
160 | self.inchikey_list_small,
161 | self.inchikey_dict_small,
162 | train_test_split,
163 | splitting_type='diazo'))
164 |
165 | self.assertEqual(replicate_train_inchikeys[0],
166 | 'PNYUDNYAXSEACV-RVDMUPIBSA-N')
167 | self.assertEqual(replicate_test_inchikeys[0], 'YXHKONLOYHBTNS-UHFFFAOYSA-N')
168 |
169 | @parameterized.parameters('random', 'diazo')
170 | def test_make_train_test_split(self, splitting_type):
171 | """An integration test on a small dataset."""
172 |
173 | fpath = self.temp_dir
174 |
175 | # Create component datasets from two library files.
176 | main_train_val_test_fractions = (
177 | train_test_split_utils.TrainValTestFractions(0.5, 0.25, 0.25))
178 | replicates_val_test_fractions = (
179 | train_test_split_utils.TrainValTestFractions(0.0, 0.5, 0.5))
180 |
181 | (mainlib_inchikey_dict, replicates_inchikey_dict,
182 | component_inchikey_dict) = (
183 | make_train_test_split.make_mainlib_replicates_train_test_split(
184 | self.mol_list_large, self.mol_list_small, splitting_type,
185 | main_train_val_test_fractions, replicates_val_test_fractions))
186 |
187 | make_train_test_split.write_mainlib_split_datasets(
188 | component_inchikey_dict, mainlib_inchikey_dict, fpath,
189 | ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC)
190 |
191 | make_train_test_split.write_replicates_split_datasets(
192 | component_inchikey_dict, replicates_inchikey_dict, fpath,
193 | ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC)
194 |
195 | for experiment_setup in ds_constants.EXPERIMENT_SETUPS_LIST:
196 | # Create experiment json files
197 | tf.logging.info('Writing experiment setup for %s',
198 | experiment_setup.json_name)
199 | make_train_test_split.check_experiment_setup(
200 | experiment_setup.experiment_setup_dataset_dict,
201 | component_inchikey_dict)
202 |
203 | make_train_test_split.write_json_for_experiment(experiment_setup, fpath)
204 |
205 | # Check that physical files for library matching contain all inchikeys
206 | dict_from_json = json.load(
207 | tf.gfile.Open(os.path.join(fpath, experiment_setup.json_name)))
208 |
209 | tf.logging.info(dict_from_json)
210 | library_files = (
211 | dict_from_json[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY] +
212 | dict_from_json[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY])
213 | library_files = [os.path.join(fpath, fname) for fname in library_files]
214 |
215 | hparams = tf.contrib.training.HParams(
216 | max_atoms=ms_constants.MAX_ATOMS,
217 | max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC,
218 | intensity_power=1.0,
219 | batch_size=5)
220 |
221 | parse_sdf_utils.validate_spectra_array_contents(
222 | os.path.join(
223 | fpath,
224 | dict_from_json[ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY][0]),
225 | hparams,
226 | os.path.join(fpath,
227 | dict_from_json[ds_constants.TRAINING_SPECTRA_ARRAY_KEY]))
228 |
229 | dataset = parse_sdf_utils.get_dataset_from_record(
230 | library_files,
231 | hparams,
232 | mode=tf.estimator.ModeKeys.EVAL,
233 | all_data_in_one_batch=True)
234 |
235 | feature_names = [fmap_constants.INCHIKEY]
236 | label_names = [fmap_constants.ATOM_WEIGHTS]
237 |
238 | features, labels = parse_sdf_utils.make_features_and_labels(
239 | dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)
240 |
241 | with tf.Session() as sess:
242 | feature_values, _ = sess.run([features, labels])
243 |
244 | inchikeys_from_file = [
245 | ikey[0] for ikey in feature_values[fmap_constants.INCHIKEY].tolist()
246 | ]
247 |
248 | length_from_info_file = sum([
249 | parse_sdf_utils.parse_info_file(library_fname)['num_examples']
250 | for library_fname in library_files
251 | ])
252 | # Check that info file has the correct length for the file.
253 | self.assertLen(inchikeys_from_file, length_from_info_file)
254 | # Check that the TF.Record contains all of the inchikeys in our list.
255 | inchikey_list_large = [
256 | self.encode(ikey) for ikey in self.inchikey_list_large
257 | ]
258 | self.assertSetEqual(set(inchikeys_from_file), set(inchikey_list_large))
259 |
260 |
261 | if __name__ == '__main__':
262 | tf.test.main()
263 |
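# Illustrative note: the [5, 2, 4] and [3, 1, 2] lengths asserted above
# are consistent with a split that rounds the train and validation counts
# down and gives the remainder to test. A minimal sketch of that rounding
# scheme (an assumption about train_test_split_utils, whose source is not
# shown here):
#
#   def split_sizes(num_keys, train_frac, val_frac):
#     num_train = int(num_keys * train_frac)
#     num_val = int(num_keys * val_frac)
#     return num_train, num_val, num_keys - num_train - num_val
#
#   split_sizes(11, 0.5, 0.25) == (5, 2, 4)
#   split_sizes(6, 0.5, 0.25) == (3, 1, 2)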
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2018 Google LLC
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/feature_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for getting molecule features from NIST sdf data.
16 |
17 | This module includes functions to help with parsing sdf files and generating
18 | features such as atom weight lists and adjacency matrices. Also contains a
19 | function to parse mass spectra peaks from their string format in the NIST sdf
20 | files.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import feature_map_constants as fmap_constants
28 | import mass_spec_constants as ms_constants
29 | import numpy as np
30 | from rdkit import Chem
31 | from rdkit import DataStructs
32 | from rdkit.Chem import AllChem
33 |
34 | FILTER_DICT = {
35 | 'steroid':
36 | Chem.MolFromSmarts(
37 | '[#6]1~[#6]2~[#6](~[#6]~[#6]~[#6]~1)'
38 | '~[#6]~[#6]~[#6]1~[#6]~2~[#6]~[#6]~[#6]2~[#6]~1~[#6]~[#6]~[#6]~2'
39 | ),
40 | 'diazo':
41 | Chem.MolFromSmarts(
42 | '[#7]~[#7]'
43 | ),
44 | }
45 |
46 |
47 | def get_smiles_string(mol):
48 | """Make canonicalized smiles from rdkit.Mol."""
49 | return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
50 |
51 |
52 | def get_molecular_formula(mol):
53 | """Makes string of molecular formula from rdkit.Mol."""
54 | return AllChem.CalcMolFormula(mol)
55 |
56 |
57 | def parse_peaks(pk_str):
58 | r"""Helper function for converting peak string into vector form.
59 |
60 | Args:
61 | pk_str : String from NIST MS data of format
62 | "peak1_loc peak1_int\npeak2_loc peak2_int"
63 |
64 | Returns:
65 |     A tuple of two lists: (peak_locs, peak_intensities)
66 |     peak_locs : int list of the locations of the peaks
67 |     peak_intensities : float list of the intensities of the peaks
68 | """
69 | all_peaks = pk_str.split('\n')
70 |
71 | peak_locs = []
72 | peak_intensities = []
73 |
74 | for peak in all_peaks:
75 | loc, intensity = peak.split()
76 | peak_locs.append(int(loc))
77 | peak_intensities.append(float(intensity))
78 |
79 | return peak_locs, peak_intensities
80 |
81 |
82 | def convert_spectrum_array_to_string(spectrum_array):
83 | """Write a spectrum array to string.
84 |
85 | Args:
86 | spectrum_array : np.array of shape (1000)
87 |
88 | Returns:
89 |     string representing the peaks of the spectrum.
90 | """
91 | mz_peak_locations = np.nonzero(spectrum_array)[0].tolist()
92 | mass_peak_strings = [
93 | '%d %d' % (p, spectrum_array[p]) for p in mz_peak_locations
94 | ]
95 | return '\n'.join(mass_peak_strings)
96 |
97 |
98 | def get_largest_mass_spec_peak_loc(mol):
99 | """Returns largest ms peak location from an rdkit.Mol object."""
100 | return parse_peaks(mol.GetProp(ms_constants.SDF_TAG_MASS_SPEC_PEAKS))[0][-1]
101 |
102 |
103 | def make_dense_mass_spectra(peak_locs, peak_intensities, max_peak_loc):
104 | """Make a dense np.array of the mass spectra.
105 |
106 | Args:
107 | peak_locs : int list of the location of peaks
108 | peak_intensities : float list of the intensity of the peaks
109 |     max_peak_loc : maximum number of peak bins (length of the output vector)
110 |
111 | Returns:
112 | np.array of the mass spectra data as a dense vector.
113 | """
114 | dense_spectrum = np.zeros(max_peak_loc)
115 | dense_spectrum[peak_locs] = peak_intensities
116 |
117 | return dense_spectrum
118 |
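# Illustrative round trip through the helpers above (values hypothetical):
#   parse_peaks('15 20\n27 35')                 -> ([15, 27], [20.0, 35.0])
#   make_dense_mass_spectra([15, 27], [20.0, 35.0], 40)
#       -> length-40 vector, nonzero only at indices 15 and 27
#   convert_spectrum_array_to_string(that_vector) -> '15 20\n27 35'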
119 |
120 | def get_padded_atom_weights(mol, max_atoms):
121 | """Make a padded list of atoms of length max_atoms given rdkit.Mol object.
122 |
123 | Note: Returns atoms in the same order as the input rdkit.Mol.
124 | If you want the atoms in canonical order, you should canonicalize
125 | the molecule first.
126 |
127 | Args:
128 | mol : a rdkit.Mol object
129 | max_atoms : maximum number of atoms to consider
130 | Returns:
131 | np array of atoms by atomic mass of shape (max_atoms)
132 | Raises:
133 |     ValueError : if the rdkit.Mol has more atoms than max_atoms.
134 | """
135 |
136 | if max_atoms < len(mol.GetAtoms()):
137 | raise ValueError(
138 | 'molecule contains {} atoms, more than max_atoms {}'.format(
139 | len(mol.GetAtoms()), max_atoms))
140 |
141 | atom_list = np.array([at.GetMass() for at in mol.GetAtoms()])
142 | atom_list = np.pad(atom_list, ((0, max_atoms - len(atom_list))), 'constant')
143 | return atom_list
144 |
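# Illustrative example: for an rdkit.Mol built from SMILES 'CCO' (heavy
# atoms C, C, O) with max_atoms=4, this returns approximately
# [12.011, 12.011, 15.999, 0.0] -- atomic masses, zero-padded to length 4.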
145 |
146 | def get_padded_atom_ids(mol, max_atoms):
147 | """Make a padded list of atoms of length max_atoms given rdkit.Mol object.
148 |
149 | Args:
150 | mol : a rdkit.Mol object
151 | max_atoms : maximum number of atoms to consider
152 | Returns:
153 | np array of atoms by atomic number of shape (max_atoms)
154 | Note: function returns atoms in the same order as the input rdkit.Mol.
155 | If you want the atoms in canonical order, you should canonicalize
156 | the molecule first.
157 | Raises:
158 |     ValueError : if the rdkit.Mol has more atoms than max_atoms.
159 | """
160 | if max_atoms < len(mol.GetAtoms()):
161 | raise ValueError(
162 | 'molecule contains {} atoms, more than max_atoms {}'.format(
163 | len(mol.GetAtoms()), max_atoms))
164 | atom_list = np.array([at.GetAtomicNum() for at in mol.GetAtoms()])
165 | atom_list = atom_list.astype('int32')
166 | atom_list = np.pad(atom_list, ((0, max_atoms - len(atom_list))), 'constant')
167 |
168 | return atom_list
169 |
170 |
171 | def get_padded_adjacency_matrix(mol, max_atoms, add_hs_to_molecule=False):
172 | """Make a matrix with shape (max_atoms, max_atoms) given rdkit.Mol object.
173 |
174 | Args:
175 | mol: a rdkit.Mol object
176 | max_atoms : maximum number of atoms to consider
177 | add_hs_to_molecule : whether or not to add hydrogens to the molecule.
178 | Returns:
179 | np.array of floats of a flattened adjacency matrix of length
180 | (max_atoms * max_atoms)
181 | The values will be the index of the bond order in the alphabet
182 | Raises:
183 |     ValueError : if the rdkit.Mol has more atoms than max_atoms.
184 | """
185 | # Add hydrogens to atoms:
186 | if add_hs_to_molecule:
187 | mol = Chem.rdmolops.AddHs(mol)
188 |
189 | num_atoms_in_mol = len(mol.GetAtoms())
190 | if max_atoms < num_atoms_in_mol:
191 | raise ValueError(
192 | 'molecule contains {} atoms, more than max_atoms {}'.format(
193 | len(mol.GetAtoms()), max_atoms))
194 |
195 | adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True)
196 |
197 |   # Map each nonzero bond order to its integer code.
198 |   for i in range(np.shape(adj_matrix)[0]):
199 |     for j in range(np.shape(adj_matrix)[1]):
200 |       if adj_matrix[i, j] != 0:
201 |         adj_matrix[i, j] = ms_constants.BOND_ORDER_TO_INTS_DICT[adj_matrix[i, j]]
202 |
203 | padded_adjacency_matrix = np.zeros((max_atoms, max_atoms))
204 |
205 | padded_adjacency_matrix[:num_atoms_in_mol, :num_atoms_in_mol] = adj_matrix
206 | padded_adjacency_matrix = padded_adjacency_matrix.astype('int32')
207 |
208 | return np.reshape(padded_adjacency_matrix, (max_atoms * max_atoms))
209 |
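# Illustrative example: for ethane ('CC', hydrogens not added) with
# max_atoms=3, the 3x3 matrix holds the integer code for a single bond
# (via BOND_ORDER_TO_INTS_DICT) at positions (0, 1) and (1, 0), zeros
# elsewhere, and is returned flattened to length 9.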
210 |
211 | def make_circular_fingerprint(mol, circular_fp_key):
212 | """Returns circular fingerprint for a mol given its circular_fp_key.
213 |
214 | Args:
215 | mol : rdkit.Mol
216 | circular_fp_key : A ms_constants.CircularFingerprintKey object
217 | Returns:
218 | np.array of len circular_fp_key.fp_len
219 | """
220 | # A dictionary to record rdkit functions to base names
221 | fp_methods_dict = {
222 | fmap_constants.CIRCULAR_FP_BASENAME:
223 | AllChem.GetMorganFingerprintAsBitVect,
224 | fmap_constants.COUNTING_CIRCULAR_FP_BASENAME:
225 | AllChem.GetHashedMorganFingerprint
226 | }
227 |
228 | fp = fp_methods_dict[circular_fp_key.fp_type](
229 | mol, circular_fp_key.radius, nBits=circular_fp_key.fp_len)
230 | fp_arr = np.zeros(1)
231 | DataStructs.ConvertToNumpyArray(fp, fp_arr)
232 | return fp_arr
233 |
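# Illustrative usage (key values hypothetical): a counting Morgan
# fingerprint of radius 2 hashed to 1024 bits could be requested as
#   key = ms_constants.CircularFingerprintKey(
#       fmap_constants.COUNTING_CIRCULAR_FP_BASENAME, 1024, 2)
#   fp = make_circular_fingerprint(mol, key)   # np.array of length 1024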
234 |
235 | def all_circular_fingerprints_to_dict(mol):
236 | """Creates all circular fingerprints from list of lengths and radii.
237 |
238 | Based on lists of fingerprint lengths and fingerprint radii inside
239 | mass_spec_constants.
240 |
241 | Args:
242 | mol : rdkit.Mol
243 | Returns:
244 | a dict. The keys are CircularFingerprintKey instances and the values are
245 | the corresponding fingerprints
246 | """
247 | fp_dict = {}
248 | for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST:
249 | for rad in ms_constants.CIRCULAR_FP_RADII_LIST:
250 | for fp_type in fmap_constants.FP_TYPE_LIST:
251 | circular_fp_key = ms_constants.CircularFingerprintKey(
252 | fp_type, fp_len, rad)
253 | fp_dict[circular_fp_key] = make_circular_fingerprint(
254 | mol, circular_fp_key)
255 | return fp_dict
256 |
257 |
258 | def check_mol_has_non_empty_smiles(mol):
259 | """Checks if smiles string of rdkit.Mol is an empty string."""
260 | return bool(get_smiles_string(mol))
261 |
262 |
263 | def check_mol_has_non_empty_mass_spec_peak_tag(mol):
264 | """Checks if mass spec sdf tag is in properties of rdkit.Mol."""
265 | return ms_constants.SDF_TAG_MASS_SPEC_PEAKS in mol.GetPropNames()
266 |
267 |
268 | def check_mol_only_has_atoms(mol, accept_atom_list):
269 | """Checks if rdkit.Mol only contains atoms from accept_atom_list."""
270 | atom_symbol_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
271 | return all(atom in accept_atom_list for atom in atom_symbol_list)
272 |
273 |
274 | def check_mol_does_not_have_atoms(mol, exclude_atom_list):
275 | """Checks if rdkit.Mol contains any molecule from exclude_atom_list."""
276 | atom_symbol_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
277 | return all(atom not in atom_symbol_list for atom in exclude_atom_list)
278 |
279 |
280 | def check_mol_has_substructure(mol, substructure_mol):
281 | """Checks if rdkit.Mol has substructure.
282 |
283 | Args:
284 | mol : rdkit.Mol, representing query
285 | substructure_mol: rdkit.Mol, representing substructure family
286 | Returns:
287 | Boolean, True if substructure found in molecule.
288 | """
289 | return mol.HasSubstructMatch(substructure_mol)
290 |
291 |
292 | def make_filter_by_substructure(family_name):
293 | """Returns a filter function according to the family_name."""
294 | if family_name not in FILTER_DICT.keys():
295 | raise ValueError('%s is not supported for family splitting' % family_name)
296 | return lambda mol: check_mol_has_substructure(mol, FILTER_DICT[family_name])
297 |
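# Illustrative usage: make_filter_by_substructure('diazo') returns a
# predicate that is True for molecules containing an N~N substructure
# (see FILTER_DICT above); unknown family names raise ValueError.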
298 |
299 | def tokenize_smiles(smiles_string_arr):
300 | """Creates a list of tokens from a smiles string.
301 |
302 | Two letter atom characters are considered to be a single token.
303 | All two letter tokens observed in this dataset are recorded in
304 | ms_constants.TWO_LETTER_TOKEN_NAMES.
305 |
306 | Args:
307 | smiles_string_arr: np.array of dtype str and shape (1, )
308 | Returns:
309 | A np.array of ints corresponding with the tokens
310 | """
311 |
312 | smiles_str = smiles_string_arr[0]
313 | if isinstance(smiles_str, bytes):
314 | smiles_str = smiles_str.decode('utf-8')
315 |
316 | token_list = []
317 | ptr = 0
318 |
319 | while ptr < len(smiles_str):
320 | if smiles_str[ptr:ptr + 2] in ms_constants.TWO_LETTER_TOKEN_NAMES:
321 | token_list.append(
322 | ms_constants.SMILES_TOKEN_NAME_TO_INDEX[smiles_str[ptr:ptr + 2]])
323 | ptr += 2
324 | else:
325 | token_list.append(
326 | ms_constants.SMILES_TOKEN_NAME_TO_INDEX[smiles_str[ptr]])
327 | ptr += 1
328 |
329 | return np.array(token_list, dtype=np.int64)
330 |
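# A self-contained restatement of the greedy two-character-first scan
# above, with a stand-in token table (an assumption for illustration; the
# real tables in mass_spec_constants map tokens to integer indices):
_EDITOR_TWO_LETTER_TOKENS = frozenset(['Cl', 'Br'])


def _tokenize_sketch(smiles):
  """Greedy scan: try a two-letter token first, else take one character."""
  tokens = []
  ptr = 0
  while ptr < len(smiles):
    if smiles[ptr:ptr + 2] in _EDITOR_TWO_LETTER_TOKENS:
      tokens.append(smiles[ptr:ptr + 2])  # e.g. 'Cl' stays one token
      ptr += 2
    else:
      tokens.append(smiles[ptr])
      ptr += 1
  return tokens

# _tokenize_sketch('ClC(Cl)Br') == ['Cl', 'C', '(', 'Cl', ')', 'Br']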
--------------------------------------------------------------------------------
/molecule_estimator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""Train and evaluate massspec model.
15 |
16 | Example usage:
17 | molecule_estimator.py --train_steps=1000 --model_dir='/tmp/models/output' \
18 | --dataset_config_file=testdata/test_dataset_config_file.json --alsologtostderr
19 | """
20 |
21 | from __future__ import print_function
22 | import json
23 | import os
24 |
25 | from absl import flags
26 | import dataset_setup_constants as ds_constants
27 | import feature_map_constants as fmap_constants
28 | import library_matching
29 | import molecule_predictors
30 | import parse_sdf_utils
31 | import six
32 | import tensorflow as tf
33 |
34 | FLAGS = flags.FLAGS
35 | flags.DEFINE_string(
36 | 'dataset_config_file', None,
37 | 'JSON file specifying the various filenames necessary for training and '
38 | 'evaluation. See make_input_fn() for more details.')
39 | flags.DEFINE_string(
40 |     'hparams', '', 'Hyperparameter values to override the defaults. '
41 |     'Format: params1=value1,params2=value2, ... '
42 |     'Possible parameters: max_atoms, max_mass_spec_peak_loc, '
43 |     'batch_size, epochs, do_linear_regression, '
44 |     'get_mass_spec_features, init_weights, init_bias')
45 | flags.DEFINE_integer('train_steps', None,
46 | 'The number of steps to run training for.')
47 | flags.DEFINE_integer(
48 | 'train_steps_per_iteration', 1000,
49 |     'How frequently to evaluate (only used when schedule == '
50 |     'continuous_train_and_eval).')
51 |
52 | flags.DEFINE_string('model_dir', '',
53 | 'output directory for checkpoints and events files')
54 | flags.DEFINE_string('warm_start_dir', None,
55 | 'directory to warm start model from')
56 |
57 | flags.DEFINE_enum('model_type', 'mlp',
58 | molecule_predictors.MODEL_REGISTRY.keys(),
59 | 'Type of model to use.')
60 | OUTPUT_HPARAMS_CONFIG_FILE_BASE = 'command_line_arguments.txt'
61 |
62 |
63 | def make_input_fn(dataset_config_file,
64 | hparams,
65 | mode,
66 | features_to_load,
67 | load_library_matching_data,
68 | data_dir=None):
69 | """Make input functions returning features and labels.
70 |
71 | In our downstream code, it is advantageous to put both features
72 | and labels together in the same nested structure. However, tf.estimator
73 | requires input_fn() to return features and labels, so here our input_fn
74 | returns dummy labels that will not be used.
75 |
76 | Args:
77 | dataset_config_file: filename of JSON file containing a dict mapping dataset
78 | names to data files. The required keys are:
79 | 'SPECTRUM_PREDICTION_TRAIN': training data
80 | 'SPECTRUM_PREDICTION_TEST': eval data on which we evaluate the same loss
81 | function that is used for training
82 | 'LIBRARY_MATCHING_OBSERVED': data for library matching where we use
83 | ground-truth spectra
84 | 'LIBRARY_MATCHING_PREDICTED': data for library matching where we use
85 | predicted spectra
86 |       'LIBRARY_MATCHING_QUERY': data with observed spectra used as queries
87 |         in the library matching task
88 |
89 |       For each data file named in the config, we read high-level information
90 |       about the data from a companion file with the same name plus a '.info'
91 |       extension (for instance, a record file 'x.gz' pairs with 'x.gz.info').
92 |       See parse_sdf_utils.parse_info_file() for the expected format.
93 |
94 | hparams: hparams required for parsing; includes features such as max_atoms,
95 | max_mass_spec_peak_loc, and batch_size
96 | mode: Set whether training or test mode
97 | features_to_load: list of keys to load from the input data
98 | load_library_matching_data: whether to load library matching data.
99 | data_dir: The directory containing the file names referred to in
100 | dataset_config_file. If None (the default) then this is assumed to be the
101 | directory containing dataset_config_file.
102 | Returns:
103 | A function for creating features and labels from a dataset.
104 | """
105 | with tf.gfile.Open(dataset_config_file, 'r') as f:
106 | filenames = json.load(f)
107 |
108 | if data_dir is None:
109 | data_dir = os.path.dirname(dataset_config_file)
110 |
111 | def _input_fn(record_fnames,
112 | all_data_in_one_batch,
113 | load_training_spectrum_library=False):
114 | """Reads TFRecord from a list of record file names."""
115 | if not record_fnames:
116 | return None
117 |
118 | record_fnames = [os.path.join(data_dir, r_name) for r_name in record_fnames]
119 | dataset = parse_sdf_utils.get_dataset_from_record(
120 | record_fnames,
121 | hparams,
122 | mode=mode,
123 | features_to_load=(features_to_load + hparams.label_names),
124 | all_data_in_one_batch=all_data_in_one_batch)
125 | dict_to_return = parse_sdf_utils.make_features_and_labels(
126 | dataset, features_to_load, hparams.label_names, mode=mode)[0]
127 |
128 | if load_training_spectrum_library:
129 |       # The spectra library lives next to the record files in data_dir.
130 |       library_file = os.path.join(
131 |           data_dir, filenames[ds_constants.TRAINING_SPECTRA_ARRAY_KEY])
132 | train_library = parse_sdf_utils.load_training_spectra_array(library_file)
133 | train_library = tf.convert_to_tensor(train_library, dtype=tf.float32)
134 |
135 | dict_to_return['SPECTRUM_PREDICTION_LIBRARY'] = train_library
136 |
137 | return dict_to_return
138 |
139 | load_training_spectrum_library = hparams.loss == 'max_margin'
140 |
141 | if load_library_matching_data:
142 |
143 | def _wrapped_input_fn():
144 | """Construct data for various eval tasks."""
145 |
146 | data_to_return = {
147 | fmap_constants.SPECTRUM_PREDICTION:
148 | _input_fn(
149 | filenames[ds_constants.SPECTRUM_PREDICTION_TEST_KEY],
150 | all_data_in_one_batch=False,
151 | load_training_spectrum_library=load_training_spectrum_library)
152 | }
153 |
154 | if hparams.do_library_matching:
155 | observed = _input_fn(
156 | filenames[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY],
157 | all_data_in_one_batch=True)
158 | predicted = _input_fn(
159 | filenames[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY],
160 | all_data_in_one_batch=True)
161 | query = _input_fn(
162 | filenames[ds_constants.LIBRARY_MATCHING_QUERY_KEY],
163 | all_data_in_one_batch=False)
164 |         data_to_return[fmap_constants.LIBRARY_MATCHING] = (
165 |             library_matching.LibraryMatchingData(
166 |                 observed=observed, predicted=predicted, query=query))
167 |
168 | return data_to_return, tf.constant(0)
169 | else:
170 |
171 | def _wrapped_input_fn():
172 | # See the above comment about why we use dummy labels.
173 | return {
174 | fmap_constants.SPECTRUM_PREDICTION:
175 | _input_fn(
176 | filenames[ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY],
177 | all_data_in_one_batch=False,
178 | load_training_spectrum_library=load_training_spectrum_library)
179 | }, tf.constant(0)
180 |
181 | return _wrapped_input_fn
182 |
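# Illustrative dataset_config_file contents, using the keys documented in
# make_input_fn above (filenames are hypothetical):
#   {
#     "SPECTRUM_PREDICTION_TRAIN": ["train_record.gz"],
#     "SPECTRUM_PREDICTION_TEST": ["val_record.gz"],
#     "LIBRARY_MATCHING_OBSERVED": ["library_record.gz"],
#     "LIBRARY_MATCHING_PREDICTED": ["library_record.gz"],
#     "LIBRARY_MATCHING_QUERY": ["query_record.gz"]
#   }
# When hparams.loss == 'max_margin', the config additionally needs the
# spectra-array key named by ds_constants.TRAINING_SPECTRA_ARRAY_KEY.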
183 |
184 | def _log_command_line_string(model_type, model_dir, hparams):
185 | """Log command line args to replicate hparam configuration."""
186 |
187 | config_string = '--model_type=%s ' % (model_type)
188 |
189 |   # Note that the rendered string cannot be parsed back with
190 |   # hparams.parse() if any of the hparam values contain commas or '=' signs.
191 | hparams_string = ','.join(
192 | ['%s=%s' % (key, value) for key, value in six.iteritems(
193 | hparams.values())])
194 |
195 | config_string += ' --hparams=%s\n' % hparams_string
196 | output_file = os.path.join(model_dir, OUTPUT_HPARAMS_CONFIG_FILE_BASE)
197 | tf.gfile.MakeDirs(model_dir)
198 | tf.logging.info('Writing command line config string to %s.' % output_file)
199 |
200 | with tf.gfile.Open(output_file, 'w') as f:
201 | f.write(config_string)
202 |
203 |
204 | def make_model_fn(prediction_helper, dataset_config_file, model_dir):
205 | """Returns a model function for estimator given prediction base class.
206 |
207 | Args:
208 | prediction_helper : Helper class containing prediction, loss, and evaluation
209 | metrics
210 | dataset_config_file: see make_input_fn.
211 | model_dir : directory for writing output files. If model dir is not None,
212 | a file containing all of the necessary command line flags to rehydrate
213 | the model will be written in model_dir.
214 | Returns:
215 | A function that returns a tf.EstimatorSpec
216 | """
217 |
218 | def _model_fn(features, labels, params, mode=None):
219 | """Returns tf.EstimatorSpec."""
220 |
221 | # Input labels are ignored. All data are in features.
222 | del labels
223 |
224 | if model_dir is not None:
225 | _log_command_line_string(prediction_helper.model_type, model_dir, params)
226 |
227 | pred_op, pred_op_for_loss = prediction_helper.make_prediction_ops(
228 | features[fmap_constants.SPECTRUM_PREDICTION], params, mode)
229 | loss_op = prediction_helper.make_loss(
230 | pred_op_for_loss, features[fmap_constants.SPECTRUM_PREDICTION], params)
231 |
232 | if mode == tf.estimator.ModeKeys.TRAIN:
233 | train_op = tf.contrib.layers.optimize_loss(
234 | loss=loss_op,
235 | global_step=tf.train.get_global_step(),
236 | clip_gradients=params.gradient_clip,
237 | learning_rate=params.learning_rate,
238 | optimizer='Adam')
239 | eval_op = None
240 | elif mode == tf.estimator.ModeKeys.PREDICT:
241 | train_op = None
242 | eval_op = None
243 | elif mode == tf.estimator.ModeKeys.EVAL:
244 | train_op = None
245 | eval_op = prediction_helper.make_evaluation_metrics(
246 | features, params, dataset_config_file, output_dir=model_dir)
247 | else:
248 | raise ValueError('Invalid mode. Must be '
249 | 'tf.estimator.ModeKeys.{TRAIN,PREDICT,EVAL}.')
250 | return tf.estimator.EstimatorSpec(
251 | mode=mode,
252 | predictions=pred_op,
253 | loss=loss_op,
254 | train_op=train_op,
255 | eval_metric_ops=eval_op)
256 |
257 | return _model_fn
258 |
259 |
260 | def make_estimator_and_inputs(run_config,
261 | hparams,
262 | prediction_helper,
263 | dataset_config_file,
264 | train_steps,
265 | model_dir,
266 | warm_start_dir=None):
267 | """Make Estimator-compatible Estimator and input_fn for train and eval."""
268 |
269 | model_fn = make_model_fn(prediction_helper, dataset_config_file, model_dir)
270 |
271 | train_input_fn = make_input_fn(
272 | dataset_config_file=dataset_config_file,
273 | hparams=hparams,
274 | mode=tf.estimator.ModeKeys.TRAIN,
275 | features_to_load=prediction_helper.features_to_load(hparams),
276 | load_library_matching_data=False)
277 |
278 | eval_input_fn = make_input_fn(
279 | dataset_config_file=dataset_config_file,
280 | hparams=hparams,
281 | mode=tf.estimator.ModeKeys.EVAL,
282 | features_to_load=prediction_helper.features_to_load(hparams),
283 | load_library_matching_data=True)
284 |
285 | estimator = tf.estimator.Estimator(
286 | model_fn=model_fn,
287 | params=hparams,
288 | config=run_config,
289 | warm_start_from=warm_start_dir)
290 |
291 | train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=train_steps)
292 | eval_spec = tf.estimator.EvalSpec(eval_input_fn, steps=None)
293 |
294 | return estimator, train_spec, eval_spec
295 |
296 |
297 | def main(_):
298 | prediction_helper = molecule_predictors.get_prediction_helper(
299 | FLAGS.model_type)
300 |
301 | hparams = prediction_helper.get_default_hparams()
302 | hparams.parse(FLAGS.hparams)
303 |
304 | config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir)
305 |
306 | (estimator, train_spec, eval_spec) = make_estimator_and_inputs(
307 | config, hparams, prediction_helper, FLAGS.dataset_config_file,
308 | FLAGS.train_steps, FLAGS.model_dir, FLAGS.warm_start_dir)
309 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
310 |
311 |
312 | if __name__ == '__main__':
313 | tf.app.run(main)
314 |
--------------------------------------------------------------------------------
/plot_spectra_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Functions to add an evaluation metric that generates spectra plots."""
15 | from __future__ import print_function
16 |
17 | import json
18 | import os
19 |
20 | import dataset_setup_constants as ds_constants
21 | import mass_spec_constants as ms_constants
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | from PIL import Image as PilImage
25 | import six
26 | import tensorflow as tf
27 |
28 | IMAGE_SUBDIR_FOR_SPECTRA_PLOTS = 'images'
29 |
30 | SPECTRA_PLOT_BACKGROUND_COLOR = 'white'
31 | SPECTRA_PLOT_FIGURE_SIZE = (10, 10)
32 | SPECTRA_PLOT_GRID_COLOR = 'black'
33 | SPECTRA_PLOT_TRUE_SPECTRA_COLOR = 'blue'
34 | SPECTRA_PLOT_PREDICTED_SPECTRA_COLOR = 'red'
35 | SPECTRA_PLOT_PEAK_LOC_LIMIT = ms_constants.MAX_PEAK_LOC
36 | SPECTRA_PLOT_MZ_MAX_OFFSET = 10
37 | SPECTRA_PLOT_INTENSITY_LIMIT = 1200
38 | SPECTRA_PLOT_DPI = 300
39 | SPECTRA_PLOT_BAR_LINE_WIDTH = 0.8
40 | SPECTRA_PLOT_BAR_GRID_LINE_WIDTH = 0.1
41 | SPECTRA_PLOT_ACTUAL_SPECTRA_LEGEND_TEXT = 'True Mass Spectrum'
42 | SPECTRA_PLOT_PREDICTED_SPECTRA_LEGEND_TEXT = 'Predicted Mass Spectrum'
43 | SPECTRA_PLOT_QUERY_SPECTRA_LEGEND_TEXT = 'Query Mass Spectrum'
44 | SPECTRA_PLOT_LIBRARY_MATCH_SPECTRA_LEGEND_TEXT = 'Library Matched Mass Spectrum'
45 | SPECTRA_PLOT_X_AXIS_LABEL = 'mass/charge ratio'
46 | SPECTRA_PLOT_Y_AXIS_LABEL = 'relative intensity'
47 | SPECTRA_PLOT_PLACE_LEGEND_ABOVE_CHART_KWARGS = {'ncol': 2}
48 | SPECTRA_PLOT_IMAGE_DIR_NAME = 'images'
49 | SPECTRA_PLOT_DIMENSIONS_RGB = (3000, 3000, 3)
50 | FIGURES_TO_SUMMARIZE_PER_BATCH = 2
51 | MAX_VALUE_OF_TRUE_SPECTRA = 999.
52 |
53 |
54 | class PlotModeKeys(object):
55 | """Helper class containing the two supported plotting modes.
56 |
57 | The following keys are defined:
58 | PREDICTED_SPECTRUM : For plotting the spectrum predicted by the algorithm
59 | against the true spectrum.
60 | LIBRARY_MATCHED_SPECTRUM : For plotting the spectrum that was the closest
61 | match to the true spectrum against the true spectrum.
62 | """
63 | PREDICTED_SPECTRUM = 'predicted_spectrum'
64 | LIBRARY_MATCHED_SPECTRUM = 'library_match_spectrum'
65 |
66 |
67 | def name_plot_file(mode, query_inchikey, matched_inchikey=None,
68 | file_type='png'):
69 | """Generates name for spectra plot files."""
70 | if mode == PlotModeKeys.PREDICTED_SPECTRUM:
71 | return '{}.{}'.format(query_inchikey, file_type)
72 | elif mode == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM:
73 | return '{}_matched_to_{}.{}'.format(query_inchikey, matched_inchikey,
74 | file_type)
75 |
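# Illustrative example (hypothetical inchikeys): in LIBRARY_MATCHED_SPECTRUM
# mode, name_plot_file(mode, 'QUERYKEY', 'MATCHKEY') returns
# 'QUERYKEY_matched_to_MATCHKEY.png'; in PREDICTED_SPECTRUM mode the
# result is just 'QUERYKEY.png'.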
76 |
77 | def name_metric(mode, inchikey):
78 | return '{}_plot_{}'.format(mode, inchikey)
79 |
80 |
81 | def plot_true_and_predicted_spectra(
82 | true_dense_spectra,
83 | generated_dense_spectra,
84 | plot_mode_key=PlotModeKeys.PREDICTED_SPECTRUM,
85 | output_filename='',
86 | rescale_mz_axis=False):
87 | """Generates a plot comparing a true and predicted mass spec spectra.
88 |
89 | If output_filename given, saves a png file of the spectra, with the
90 | true spectrum above and predicted spectrum below.
91 |
92 | Args:
93 | true_dense_spectra : np.array representing the true mass spectra
94 | generated_dense_spectra : np.array representing the predicted mass spectra
95 | plot_mode_key: a PlotModeKeys instance
96 | output_filename : str path for saving generated image.
97 | rescale_mz_axis: Setting to rescale m/z axis according to highest m/z peak
98 | location.
99 |
100 | Returns:
101 | np.array of the bits of the generated matplotlib plot.
102 | """
103 |
104 | if not rescale_mz_axis:
105 | x_array = np.arange(SPECTRA_PLOT_PEAK_LOC_LIMIT)
106 | bar_width = SPECTRA_PLOT_BAR_LINE_WIDTH
107 | mz_max = SPECTRA_PLOT_PEAK_LOC_LIMIT
108 | else:
109 | mz_max = max(
110 | max(np.nonzero(true_dense_spectra)[0]),
111 | max(np.nonzero(generated_dense_spectra)[0]))
112 | if mz_max + SPECTRA_PLOT_MZ_MAX_OFFSET < ms_constants.MAX_PEAK_LOC:
113 | mz_max += SPECTRA_PLOT_MZ_MAX_OFFSET
114 | else:
115 | mz_max = ms_constants.MAX_PEAK_LOC
116 | x_array = np.arange(mz_max)
117 | true_dense_spectra = true_dense_spectra[:mz_max]
118 | generated_dense_spectra = generated_dense_spectra[:mz_max]
119 | bar_width = SPECTRA_PLOT_BAR_LINE_WIDTH * mz_max / ms_constants.MAX_PEAK_LOC
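    # Illustrative numbers: if the highest nonzero bin across both spectra
    # is 120 and MAX_PEAK_LOC is well above 130, the x-axis runs to m/z 130
    # (120 + SPECTRA_PLOT_MZ_MAX_OFFSET) and the bar width shrinks by the
    # same 130 / MAX_PEAK_LOC factor.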
120 |
121 |   figure = plt.figure(figsize=SPECTRA_PLOT_FIGURE_SIZE, dpi=SPECTRA_PLOT_DPI)
122 |
123 | # Adding extra subplot so both plots have common x-axis and y-axis labels
124 | ax_main = figure.add_subplot(111, frameon=False)
125 | ax_main.tick_params(
126 | labelcolor='none', top='off', bottom='off', left='off', right='off')
127 |
128 | ax_main.set_xlabel(SPECTRA_PLOT_X_AXIS_LABEL)
129 | ax_main.set_ylabel(SPECTRA_PLOT_Y_AXIS_LABEL)
130 |
131 | if six.PY2:
132 | ax_top = figure.add_subplot(211, axisbg=SPECTRA_PLOT_BACKGROUND_COLOR)
133 | else:
134 | ax_top = figure.add_subplot(211, facecolor=SPECTRA_PLOT_BACKGROUND_COLOR)
135 |
136 | bar_top = ax_top.bar(
137 | x_array,
138 | true_dense_spectra,
139 | bar_width,
140 | color=SPECTRA_PLOT_TRUE_SPECTRA_COLOR,
141 | edgecolor=SPECTRA_PLOT_TRUE_SPECTRA_COLOR,
142 | )
143 |
144 | ax_top.set_ylim((0, SPECTRA_PLOT_INTENSITY_LIMIT))
145 | plt.setp(ax_top.get_xticklabels(), visible=False)
146 | ax_top.grid(
147 | color=SPECTRA_PLOT_GRID_COLOR, linewidth=SPECTRA_PLOT_BAR_GRID_LINE_WIDTH)
148 |
149 | if six.PY2:
150 | ax_bottom = figure.add_subplot(212, axisbg=SPECTRA_PLOT_BACKGROUND_COLOR)
151 | else:
152 | ax_bottom = figure.add_subplot(212, facecolor=SPECTRA_PLOT_BACKGROUND_COLOR)
153 | figure.subplots_adjust(hspace=0.0)
154 |
155 | bar_bottom = ax_bottom.bar(
156 | x_array,
157 | generated_dense_spectra,
158 | bar_width,
159 | color=SPECTRA_PLOT_PREDICTED_SPECTRA_COLOR,
160 | edgecolor=SPECTRA_PLOT_PREDICTED_SPECTRA_COLOR,
161 | )
162 |
163 | # Invert the direction of y-axis ticks for bottom graph.
164 | ax_bottom.set_ylim((SPECTRA_PLOT_INTENSITY_LIMIT, 0))
165 |
166 | ax_bottom.set_xlim(0, mz_max)
167 | # Remove overlapping 0's from middle of y-axis
168 | yticks_bottom = ax_bottom.yaxis.get_major_ticks()
169 | yticks_bottom[0].label1.set_visible(False)
170 |
171 | ax_bottom.grid(
172 | color=SPECTRA_PLOT_GRID_COLOR, linewidth=SPECTRA_PLOT_BAR_GRID_LINE_WIDTH)
173 |
174 | for ax in [ax_top, ax_bottom]:
175 | ax.minorticks_on()
176 | ax.tick_params(axis='y', which='minor', left='off')
177 | ax.tick_params(axis='y', which='minor', right='off')
178 |
179 | ax_top.tick_params(axis='x', which='minor', top='off')
180 |
181 | if plot_mode_key == PlotModeKeys.PREDICTED_SPECTRUM:
182 | ax_top.legend((bar_top, bar_bottom),
183 | (SPECTRA_PLOT_ACTUAL_SPECTRA_LEGEND_TEXT,
184 | SPECTRA_PLOT_PREDICTED_SPECTRA_LEGEND_TEXT),
185 | **SPECTRA_PLOT_PLACE_LEGEND_ABOVE_CHART_KWARGS)
186 | elif plot_mode_key == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM:
187 | ax_top.legend((bar_top, bar_bottom),
188 | (SPECTRA_PLOT_QUERY_SPECTRA_LEGEND_TEXT,
189 | SPECTRA_PLOT_LIBRARY_MATCH_SPECTRA_LEGEND_TEXT),
190 | **SPECTRA_PLOT_PLACE_LEGEND_ABOVE_CHART_KWARGS)
191 |
192 | figure.canvas.draw()
193 |   data = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
194 |
195 | try:
196 | data = np.reshape(data, SPECTRA_PLOT_DIMENSIONS_RGB)
197 | except ValueError:
198 | raise ValueError(
199 | 'The shape of the np.array generated from the data does '
200 | 'not match the values in '
201 | 'SPECTRA_PLOT_DIMENSIONS_RGB : {}'.format(SPECTRA_PLOT_DIMENSIONS_RGB))
202 |
203 | if output_filename:
204 | # We can't call plt.savefig(output_filename) because plt does not
205 | # communicate with the filesystem through gfile. In some scenarios, this
206 | # will prevent us from writing out data. Instead, we use PIL to help us
207 | # efficiently save the nparray of the image as a png file.
208 |     if not (output_filename.endswith('.png') or output_filename.endswith('.eps')):
209 | output_filename += '.png'
210 |
211 | with tf.gfile.GFile(output_filename, 'wb') as out:
212 | image = PilImage.fromarray(data).convert('RGB')
213 | image.save(out, dpi=(SPECTRA_PLOT_DPI, SPECTRA_PLOT_DPI))
214 |
215 | tf.logging.info('Shape of spectra plot data {} '.format(np.shape(data)))
216 |
217 | plt.close(figure)
218 |
219 | return data
220 |
221 |
222 | def make_plot(inchikey,
223 | plot_mode_key,
224 | update_img_flag,
225 | inchikeys_batch,
226 | true_spectra_batch,
227 | predictions,
228 | image_directory=None,
229 | library_match_inchikeys=None):
230 | """Makes plots comparing the true and predicted spectra in a dataset.
231 |
232 | This function only performs the expensive step of generating the spectrum
233 | plot if the target inchikey is in the current batch.
234 |
235 | Args:
236 | inchikey: Inchikey of query that we want to make plots with.
237 | plot_mode_key: A PlotModeKeys instance.
238 | update_img_flag: Boolean flag for whether to generate a spectra plot
239 | inchikeys_batch: inchikeys from the current batch
240 | true_spectra_batch: np.array of all the true spectra from the current batch.
241 | predictions: np.array of all predicted spectra from the current batch.
242 | image_directory: Location to save image directory, if set.
243 | library_match_inchikeys: np.array of strings, corresponding to inchikeys
244 | that were the best matched from the library inchikey task.
245 |
246 | Returns:
247 | if update_img_flag: np.array
248 | [see return value of plot_true_and_predicted_spectra]
249 | Otherwise, returns a zero np.array of shape SPECTRA_PLOT_DIMENSIONS_RGB.
250 | Also saves a file at image_directory if this value is non-zero.
251 | Raises:
252 | ValueError: library_match_inchikeys needs to be set if given image_directory
253 | and using PlotModeKeys.LIBRARY_MATCHED_SPECTRUM.
254 | """
255 | if update_img_flag:
256 | flattened_inchikeys_batch = [ikey[0].strip() for ikey in inchikeys_batch]
257 | inchikey_idx = flattened_inchikeys_batch.index(inchikey)
258 | predictions = predictions / np.amax(
259 | predictions, axis=1, keepdims=True) * MAX_VALUE_OF_TRUE_SPECTRA
260 | predicted_spectra_to_plot = predictions[inchikey_idx, :]
261 | true_spectra_to_plot = true_spectra_batch[inchikey_idx, :]
262 | if image_directory:
263 | if plot_mode_key == PlotModeKeys.PREDICTED_SPECTRUM:
264 | img_filename = name_plot_file(plot_mode_key, inchikey)
265 | elif plot_mode_key == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM:
266 | best_library_match_inchikey = library_match_inchikeys[inchikey_idx, :]
267 | img_filename = name_plot_file(plot_mode_key, inchikey,
268 | best_library_match_inchikey[0])
269 |
270 | img_pathname = os.path.join(image_directory, img_filename)
271 | spectra_arr = plot_true_and_predicted_spectra(true_spectra_to_plot,
272 | predicted_spectra_to_plot,
273 | plot_mode_key, img_pathname)
274 | else:
275 | spectra_arr = plot_true_and_predicted_spectra(true_spectra_to_plot,
276 | predicted_spectra_to_plot,
277 | plot_mode_key)
278 | return spectra_arr
279 | else:
280 | return np.zeros(SPECTRA_PLOT_DIMENSIONS_RGB, dtype=np.uint8)
281 |
282 |
283 | def spectra_plot_summary_op(inchikey_list,
284 | true_spectra,
285 | prediction_batch,
286 | inchikey_to_plot,
287 | plot_mode_key=PlotModeKeys.PREDICTED_SPECTRUM,
288 | library_match_inchikeys=None,
289 | image_directory=None):
290 | """Wrapper for plotting mass spectra for labels and predictions.
291 |
292 | Plots predicted and true spectra for a given inchikey. If image_directory is
293 | set, will save the plots as files in addition to making the image summary.
294 |
295 | Args:
296 | inchikey_list : tf Tensor of inchikey strings for a batch
297 | true_spectra : tf Tensor array with true spectra for a batch
298 | prediction_batch: tf Tensor array of predicted spectra for a single batch.
299 | inchikey_to_plot: string InChI key contained in test set (but perhaps not in
300 | a particular batch).
301 | plot_mode_key: A PlotModeKeys instance.
302 | library_match_inchikeys: tf Tensor of strings corresponding to the inchikeys
303 | top match from the library matching task.
304 | image_directory: string of dir name to save plots
305 |
306 | Returns:
307 | tf.summary.image of the operation, and an update operator indicating if the
308 | summary has been updated or not.
309 | """
310 |
311 | def _should_update_image(inchikeys_batch):
312 | """Tests whether to indicate if target inchikey is in batch."""
313 | flattened_inchikeys_batch = [ikey[0].strip() for ikey in inchikeys_batch]
314 | return inchikey_to_plot in flattened_inchikeys_batch
315 |
316 | metric_namescope = 'spectrum_{}_plot_{}'.format(plot_mode_key,
317 | inchikey_to_plot)
318 | spectra_variable_name = 'spectrum_{}_plot_{}'.format(plot_mode_key,
319 | inchikey_to_plot)
320 | with tf.name_scope(metric_namescope):
321 | # Whether the inchikey_to_plot is in the current batch.
322 | update_image_bool = tf.py_func(_should_update_image, [inchikey_list],
323 | tf.bool)
324 |
325 | if plot_mode_key == PlotModeKeys.LIBRARY_MATCHED_SPECTRUM:
326 | spectra_plot = tf.py_func(make_plot, [
327 | inchikey_to_plot, plot_mode_key, update_image_bool, inchikey_list,
328 | true_spectra, prediction_batch, image_directory,
329 | library_match_inchikeys
330 | ], tf.uint8)
331 | elif plot_mode_key == PlotModeKeys.PREDICTED_SPECTRUM:
332 | spectra_plot = tf.py_func(make_plot, [
333 | inchikey_to_plot, plot_mode_key, update_image_bool, inchikey_list,
334 | true_spectra, prediction_batch, image_directory
335 | ], tf.uint8)
336 |
337 | # Container for the plot. this value will only be assigned to something
338 | # new if the target inchikey is in the input batch.
339 | spectra_plot_variable = tf.get_local_variable(
340 | spectra_variable_name,
341 | shape=((1,) + SPECTRA_PLOT_DIMENSIONS_RGB),
342 | initializer=tf.constant_initializer(128),
343 | dtype=tf.uint8)
344 |
345 | # A function that add the spectra plot as metric.
346 | def update_function():
347 | assign_op = spectra_plot_variable.assign(spectra_plot[tf.newaxis, ...])
348 | with tf.control_dependencies([assign_op]):
349 | return tf.identity(spectra_plot_variable)
350 |
351 | # We only want to update the metric if the inchikey_to_plot is in the
352 | # batch. update_image_bool serves as a flag telling tf.cond to use the
353 | # real update function when it is, and a no-op branch returning the
354 | # stored plot when it is not.
355 | final_spectra_plot = tf.cond(update_image_bool,
356 | update_function, lambda: spectra_plot_variable)
357 |
358 | update_op = final_spectra_plot
359 |
360 | return (tf.summary.image(
361 | spectra_variable_name, spectra_plot_variable,
362 | collections=None), update_op)
363 |
364 |
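# A minimal, self-contained sketch (an illustration, not part of the original
# module) of the tf.cond/local-variable pattern used in spectra_plot_summary_op
# above. It relies on the module's existing `tensorflow as tf` import; the
# variable name and image shape here are arbitrary.
def _conditional_image_update_sketch(flag):
  """Returns an op that refreshes a stored image only when `flag` is True."""
  new_image = tf.random_uniform((1, 4, 4, 3), maxval=255.0)
  image_var = tf.get_local_variable(
      'sketch_last_image',
      shape=(1, 4, 4, 3),
      initializer=tf.constant_initializer(128.0))

  def _update():
    # Assign first, then read back, so the returned tensor reflects the new
    # value within a single session.run call.
    assign_op = image_var.assign(new_image)
    with tf.control_dependencies([assign_op]):
      return tf.identity(image_var)

  # When flag is False the variable keeps whatever was stored last, so the
  # image summary continues to show the most recent real plot.
  return tf.cond(flag, _update, lambda: image_var)
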
365 | def inchikeys_for_plotting(dataset_config_file, num_inchikeys_to_read,
366 | eval_batch_size):
367 | """Return inchikeys from spectrum prediction data file.
368 |
369 | Selects one inchikey per eval batch for plotting. This avoids the
370 | threading issue seen at evaluation time.
371 |
372 | Args:
373 | dataset_config_file: dataset configuration file for experiment. Contains
374 | filename of spectrum prediction inchikey text file.
375 | num_inchikeys_to_read: Number of inchikeys to use for plotting.
376 | eval_batch_size: Number of inchikeys to skip before appending the next
377 | inchikey from the text file.
378 |
379 | Returns:
380 | list [num_inchikeys_to_read] containing inchikey strings.
381 | """
382 | dataset_config_file_dir = os.path.split(dataset_config_file)[0]
383 | with tf.gfile.Open(dataset_config_file, 'r') as f:
384 | line = f.read()
385 | filenames = json.loads(line)
386 | test_inchikey_list_name = os.path.splitext(filenames[
387 | ds_constants.SPECTRUM_PREDICTION_TEST_KEY][0])[0] + '.inchikey.txt'
388 |
389 | inchikey_list_for_plotting = []
390 |
391 | with tf.gfile.Open(
392 | os.path.join(dataset_config_file_dir, test_inchikey_list_name)) as f:
393 | for line_idx, line in enumerate(f):
394 | if line_idx % eval_batch_size == 0:
395 | inchikey_list_for_plotting.append(line.strip('\n'))
396 | if len(inchikey_list_for_plotting) == num_inchikeys_to_read:
397 | break
398 |
399 | if len(inchikey_list_for_plotting) < num_inchikeys_to_read:
400 | tf.logging.warn('Dataset specified by {} has fewer than '
401 | '{} inchikeys. Returning {} for plotting.'.format(
402 | dataset_config_file,
403 | num_inchikeys_to_read * eval_batch_size,
404 | len(inchikey_list_for_plotting)))
405 | return inchikey_list_for_plotting
406 |
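# Illustrative usage (a sketch with hypothetical paths; assumes the dataset
# config json and its companion .inchikey.txt file were written by
# make_train_test_split.py):
def _example_inchikeys_for_plotting():
  """Sketch only: picks at most one inchikey per eval batch of 100 examples."""
  return inchikeys_for_plotting(
      '/tmp/output_dataset_dir/spectrum_prediction.json',
      num_inchikeys_to_read=3,
      eval_batch_size=100)
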
--------------------------------------------------------------------------------
/library_matching_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for deep_molecular_massspec.library_matching."""
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 |
21 | import feature_map_constants as fmap_constants
22 | import library_matching
23 | import similarity as similarity_lib
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 | PREDICTOR_INPUT_KEY = 'INPUT'
28 |
29 |
30 | class LibraryMatchingTest(tf.test.TestCase):
31 |
32 | def testCosineSimilarityProviderMatching(self):
33 | """Check correctness for querying the library with a library element."""
34 |
35 | num_examples = 20
36 | num_trials = 10
37 | data_dim = 5
38 | similarity = similarity_lib.CosineSimilarityProvider()
39 | library = np.float32(np.random.normal(size=(num_examples, data_dim)))
40 | library = tf.constant(library)
41 | library = similarity.preprocess_library(library)
42 | query_idx = tf.placeholder(shape=(), dtype=tf.int32)
43 | query = library[query_idx][np.newaxis, ...]
44 | (match_idx_op, match_similarity_op, _, _,
45 | _) = library_matching._max_similarity_match(library, query, similarity)
46 |
47 | # Use queries that are rows of the library. This means that the maximum
48 | # cosine similarity is 1.0 and is achieved by the row index of the query
49 | # in the library.
50 | with tf.Session() as sess:
51 | for _ in range(num_trials):
52 | idx = np.random.randint(0, high=num_examples)
53 | match_idx, match_similarity = sess.run(
54 | [match_idx_op, match_similarity_op], feed_dict={query_idx: idx})
55 | # Fail only if match_idx != idx and the similarity of match_idx is
56 | # not tied with the argmax (which is 1.0 by construction).
57 | if match_idx != idx:
58 | self.assertAllClose(match_similarity, 1.0)
59 |
60 | def testFindQueryPositions(self):
61 | """Test library_matching._find_query_rank_helper."""
62 |
63 | library_keys = np.array(['a', 'b', 'c', 'd', 'a', 'b'])
64 | query_keys = np.array(['a', 'b', 'c', 'a', 'b', 'c'])
65 |
66 | similarities = np.array(
67 | [[3., 4., 6., 0., 0.1, 2.], [1., -1., 5., 3, 2., 1.1],
68 | [-5., 0., 2., 0.1, 3., 1.], [0.2, 0.4, 0.6, 0.32, 0.3, 0.9],
69 | [0.2, 0.9, 0.65, 0.18, 0.3, 0.99], [0.8, 0.6, 0.5, 0.4, 0.9, 0.89]])
70 |
71 | (highest_query_ranks, lowest_query_ranks, avg_query_ranks,
72 | query_similarities) = library_matching._find_query_rank_helper(
73 | similarities, library_keys, query_keys)
74 |
75 | expected_highest_query_ranks = [4, 5, 1, 5, 1, 4]
76 | expected_lowest_query_ranks = [2, 3, 1, 4, 0, 4]
77 | expected_avg_query_ranks = [3, 4, 1, 4.5, 0.5, 4]
78 | expected_query_similarities = [3., 1.1, 2., 0.3, 0.99, 0.5]
79 |
80 | self.assertAllEqual(expected_highest_query_ranks, highest_query_ranks)
81 | self.assertAllEqual(expected_lowest_query_ranks, lowest_query_ranks)
82 | self.assertAllEqual(expected_avg_query_ranks, avg_query_ranks)
83 |
84 | self.assertAllEqual(expected_query_similarities, query_similarities)
85 |
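# The expected values above follow a simple rule. The numpy sketch below (an
# assumption about the semantics, not the implementation under test) defines a
# query's rank as the number of library entries scoring strictly higher than
# the entry for its true key; when the true key appears several times in the
# library, the best-scoring copy gives the lowest rank, the worst-scoring copy
# the highest rank, and their midpoint the average rank. It relies on this
# file's existing `numpy as np` import.
def np_query_rank_sketch(similarities, library_keys, query_keys):
  highest, lowest, avg, best_sims = [], [], [], []
  for i, key in enumerate(query_keys):
    row = similarities[i]
    copies = row[library_keys == key]  # scores of all copies of the true key
    lowest.append(np.sum(row > copies.max()))   # rank of best-scoring copy
    highest.append(np.sum(row > copies.min()))  # rank of worst-scoring copy
    avg.append((highest[-1] + lowest[-1]) / 2.0)
    best_sims.append(copies.max())
  return highest, lowest, avg, best_sims
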
86 | def testInvertPermutation(self):
87 | """Test library_matching._invert_permutation()."""
88 |
89 | batch_size = 5
90 | num_trials = 10
91 | permutation_length = 6
92 |
93 | def _validate_permutation(perm1, perm2):
94 | ordered_indices = np.arange(perm1.shape[0])
95 | self.assertAllEqual(perm1[perm2], ordered_indices)
96 | self.assertAllEqual(perm2[perm1], ordered_indices)
97 |
98 | for _ in range(num_trials):
99 | perms = np.stack(
100 | [
101 | np.random.permutation(permutation_length)
102 | for _ in range(batch_size)
103 | ],
104 | axis=0)
105 |
106 | inverse = library_matching._invert_permutation(perms)
107 |
108 | for j in range(batch_size):
109 | _validate_permutation(perms[j], inverse[j])
110 |
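# For intuition: a batch of permutations can be inverted with a single argsort,
# since argsort(p)[p[i]] == i for any permutation p. (A sketch of one possible
# implementation; library_matching._invert_permutation is the code actually
# under test and need not be written this way.)
def np_invert_permutation_sketch(perms):
  return np.argsort(perms, axis=-1)
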
111 | def np_normalize_rows(self, d):
112 | return d / np.maximum(
113 | np.linalg.norm(d, axis=1)[..., np.newaxis], similarity_lib.EPSILON)
114 |
115 | def make_ids(self, num_ids, prefix=''):
116 | if prefix:
117 | prefix += '-'
118 |
119 | return [('%s%d' % (prefix, uid)).encode('utf-8') for uid in range(num_ids)]
120 |
121 | def _random_fingerprint(self, num_elements):
122 | return tf.to_float(tf.random_uniform(shape=(num_elements, 1024)) > 0.5)
123 |
124 | def _package_data(self, ids, spectrum, masses):
125 |
126 | def convert(t):
127 | if t is None:
128 | return t
129 | else:
130 | return tf.convert_to_tensor(t)
131 |
132 | num_elements = len(ids)
133 | fingerprints = self._random_fingerprint(num_elements)
134 | return {
135 | fmap_constants.DENSE_MASS_SPEC: convert(spectrum),
136 | fmap_constants.INCHIKEY: convert(ids),
137 | library_matching.FP_NAME_FOR_JACCARD_SIMILARITY: fingerprints,
138 | fmap_constants.MOLECULE_WEIGHT: convert(masses)
139 | }
140 |
141 | def make_x_data(self, num_examples, x_dim):
142 | return np.float32(np.random.uniform(size=(num_examples, x_dim)))
143 |
144 | def np_library_matching(self, ids_predicted, ids_observed, y_predicted,
145 | y_observed, y_query):
146 | ids_library = np.concatenate([ids_predicted, ids_observed])
147 | np_library = self.np_normalize_rows(
148 | np.concatenate([y_predicted, y_observed], axis=0))
149 | np_similarities = np.dot(np_library, np.transpose(y_query))
150 | np_predictions = np.argmax(np_similarities, axis=0)
151 | np_predicted_ids = [ids_library[i] for i in np_predictions]
152 | return np_predicted_ids
153 |
154 | def perform_matching(self, ids_observed, ids_predicted, ids_query,
155 | masses_observed, masses_predicted, masses_query,
156 | y_observed, y_query, x_predicted, tf_transform,
157 | mass_tolerance):
158 |
159 | query_data = self._package_data(
160 | ids=ids_query, spectrum=y_query, masses=masses_query)
161 |
162 | predicted_data = self._package_data(
163 | ids=ids_predicted, spectrum=None, masses=masses_predicted)
164 | predicted_data[PREDICTOR_INPUT_KEY] = tf.constant(x_predicted)
165 |
166 | observed_data = self._package_data(
167 | ids=ids_observed, spectrum=y_observed, masses=masses_observed)
168 |
169 | library_matching_data = library_matching.LibraryMatchingData(
170 | query=query_data, observed=observed_data, predicted=predicted_data)
171 |
172 | predictor_fn = lambda d: tf_transform(d[PREDICTOR_INPUT_KEY])
173 | similarity = similarity_lib.CosineSimilarityProvider()
174 | true_data, predicted_data, _, _ = (
175 | library_matching.library_matching(library_matching_data, predictor_fn,
176 | similarity, mass_tolerance, 10))
177 |
178 | with tf.Session() as sess:
179 | sess.run(tf.local_variables_initializer())
180 | return sess.run([predicted_data, true_data])
181 |
182 | def tf_vs_np_library_matching_test_helper(self,
183 | num_observed,
184 | num_predicted,
185 | query_source,
186 | num_queries=5):
187 | """Helper for asserting TF and NP give same library matching output."""
188 |
189 | x_dim = 5
190 |
191 | np_transform = lambda d: np.sqrt(d + 4)
192 | tf_transform = lambda t: tf.sqrt(t + 4)
193 |
194 | x_observed = self.make_x_data(num_observed, x_dim)
195 | x_predicted = self.make_x_data(num_predicted, x_dim)
196 |
197 | y_observed = np_transform(x_observed)
198 | y_predicted = np_transform(x_predicted)
199 |
200 | if query_source == 'random':
201 | # Use queries from the same generating process as the observed and
202 | # predicted data.
203 | x_query = self.make_x_data(num_queries, x_dim)
204 | y_query = np_transform(x_query)
205 | elif query_source == 'observed':
206 | # Copy the observed data to use as queries.
207 | y_query = y_observed
208 | elif query_source == 'zero':
209 | # Use the zero vector as the queries.
210 | y_query = np.zeros(shape=(num_queries, x_dim), dtype=np.float32)
211 | else:
212 | raise ValueError('Invalid query_source: %s' % query_source)
213 |
214 | ids_observed = self.make_ids(num_observed)
215 | ids_predicted = self.make_ids(num_predicted)
216 | ids_query = self.make_ids(num_queries)
217 | masses_observed = np.ones([num_observed, 1], dtype=np.float32)
218 | masses_predicted = np.ones([num_predicted, 1], dtype=np.float32)
219 | masses_query = np.ones([num_queries, 1], dtype=np.float32)
220 |
221 | (predicted_data, true_data) = self.perform_matching(
222 | ids_observed,
223 | ids_predicted,
224 | ids_query,
225 | masses_observed,
226 | masses_predicted,
227 | masses_query,
228 | y_observed,
229 | y_query,
230 | x_predicted,
231 | tf_transform,
232 | mass_tolerance=3)
233 | np_predicted_ids = self.np_library_matching(
234 | ids_predicted, ids_observed, y_predicted, y_observed, y_query)
235 |
236 | # Assert correctness of the ids of the library matches found by TF.
237 | self.assertAllEqual(np_predicted_ids,
238 | predicted_data[fmap_constants.INCHIKEY])
239 |
240 | # Assert correctness of the ground truth ids output extracted by TF.
241 | self.assertAllEqual(true_data[fmap_constants.INCHIKEY], ids_query)
242 |
243 | # Assert that a query spectrum that is in the observed set is matched
244 | # to the corresponding element in the observed set.
245 | if query_source == 'observed':
246 | self.assertAllEqual(ids_observed, predicted_data[fmap_constants.INCHIKEY])
247 |
248 | def testLibraryMatchingTFvsNP(self):
249 | return self.tf_vs_np_library_matching_test_helper(
250 | num_observed=10, num_predicted=5, query_source='random')
251 |
252 | def testLibraryMatchingTFvsNPZeroObserved(self):
253 | return self.tf_vs_np_library_matching_test_helper(
254 | num_observed=0, num_predicted=5, query_source='random')
255 |
256 | def testLibraryMatchingTFvsZeroPredicted(self):
257 | return self.tf_vs_np_library_matching_test_helper(
258 | num_observed=10, num_predicted=0, query_source='random')
259 |
260 | def testLibraryMatchingTFvsNPQueriesObserved(self):
261 | return self.tf_vs_np_library_matching_test_helper(
262 | num_observed=10,
263 | num_predicted=5,
264 | query_source='observed',
265 | num_queries=10)
266 |
267 | def testLibraryMatchingTFvsNPZeroQueries(self):
268 | return self.tf_vs_np_library_matching_test_helper(
269 | num_observed=10, num_predicted=5, query_source='zero')
270 |
271 | def testLibraryMatchingHardcoded(self):
272 | """Test library_matching using hardcoded values."""
273 |
274 | tf_transform = lambda t: t + 2
275 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32)
276 |
277 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32)
278 | y_query = np.array([[2, 5], [2, 1], [-1.5, -1.1]], dtype=np.float32)
279 |
280 | ids_observed = self.make_ids(3, 'obs')
281 | ids_predicted = self.make_ids(2, 'pred')
282 | ids_query = np.array([b'obs-0', b'obs-1', b'pred-1'])
283 | masses_observed = np.ones([3, 1], dtype=np.float32)
284 | masses_predicted = np.ones([2, 1], dtype=np.float32)
285 | masses_query = np.ones([3, 1], dtype=np.float32)
286 |
287 | expected_predicted_ids = ids_query.tolist()
288 |
289 | predicted_data, _ = self.perform_matching(
290 | ids_observed,
291 | ids_predicted,
292 | ids_query,
293 | masses_observed,
294 | masses_predicted,
295 | masses_query,
296 | y_observed,
297 | y_query,
298 | x_predicted,
299 | tf_transform,
300 | mass_tolerance=3)
301 |
302 | self.assertAllEqual(expected_predicted_ids,
303 | predicted_data[fmap_constants.INCHIKEY])
304 |
305 | def testLibraryMatchingHardcodedMassFiltered(self):
306 | """Test library_matching using hardcoded values when filtering by mass."""
307 |
308 | tf_transform = lambda t: t + 2
309 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32)
310 |
311 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32)
312 | y_query = np.array([[2, 5], [2, 1], [-1.5, -1.1]], dtype=np.float32)
313 |
314 | ids_observed = self.make_ids(3, 'obs')
315 | ids_predicted = self.make_ids(2, 'pred')
316 | ids_query = np.array([b'pred-0', b'obs-1', b'obs-2'])
317 | masses_observed = np.ones([3, 1], dtype=np.float32)
318 | masses_predicted = 2 * np.ones([2, 1], dtype=np.float32)
319 | masses_query = np.array([3, 1.5, 0], dtype=np.float32)[..., np.newaxis]
320 |
321 | expected_predicted_ids = ids_query.tolist()
322 |
323 | predicted_data, _ = self.perform_matching(
324 | ids_observed,
325 | ids_predicted,
326 | ids_query,
327 | masses_observed,
328 | masses_predicted,
329 | masses_query,
330 | y_observed,
331 | y_query,
332 | x_predicted,
333 | tf_transform,
334 | mass_tolerance=1)
335 |
336 | self.assertAllEqual(expected_predicted_ids,
337 | predicted_data[fmap_constants.INCHIKEY])
338 |
339 | def testMassFilterRaisesError(self):
340 | """Test case where mass filtering removes everything."""
341 |
342 | tf_transform = lambda t: t + 2
343 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32)
344 |
345 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32)
346 | y_query = np.array([[2, 5]], dtype=np.float32)
347 |
348 | ids_observed = self.make_ids(3, 'obs')
349 | ids_predicted = self.make_ids(2, 'pred')
350 | ids_query = np.array(['pred-0'])
351 | masses_observed = np.ones([3, 1], dtype=np.float32)
352 | masses_predicted = 2 * np.ones([2, 1], dtype=np.float32)
353 | masses_query = np.array([5], dtype=np.float32)[..., np.newaxis]
354 |
355 | with self.assertRaises(tf.errors.InvalidArgumentError):
356 | self.perform_matching(
357 | ids_observed,
358 | ids_predicted,
359 | ids_query,
360 | masses_observed,
361 | masses_predicted,
362 | masses_query,
363 | y_observed,
364 | y_query,
365 | x_predicted,
366 | tf_transform,
367 | mass_tolerance=1)
368 |
369 | def testLibraryMatchingNoPredictions(self):
370 | """Test library_matching using hardcoded values with no predicted data."""
371 |
372 | y_observed = np.array([[1, 2], [2, 1], [0, 0]], dtype=np.float32)
373 | y_query = np.array([[2, 5], [-3, 1], [0, 0]], dtype=np.float32)
374 |
375 | ids_observed = self.make_ids(3)
376 | ids_query = self.make_ids(3)
377 |
378 | expected_predictions = [b'0', b'2', b'0']
379 |
380 | masses_query = np.ones([3, 1], dtype=np.float32)
381 | query_data = self._package_data(
382 | ids=ids_query, spectrum=y_query, masses=masses_query)
383 | masses_observed = np.ones([3, 1], dtype=np.float32)
384 | observed_data = self._package_data(
385 | ids=ids_observed, spectrum=y_observed, masses=masses_observed)
386 | predicted_data = None
387 | library_matching_data = library_matching.LibraryMatchingData(
388 | query=query_data, observed=observed_data, predicted=predicted_data)
389 |
390 | _, predicted_data, _, _ = library_matching.library_matching(
391 | library_matching_data,
392 | predictor_fn=None,
393 | similarity_provider=similarity_lib.CosineSimilarityProvider(),
394 | mass_tolerance=3.0)
395 |
396 | with tf.Session() as sess:
397 | sess.run(tf.local_variables_initializer())
398 | predictions = sess.run(predicted_data[fmap_constants.INCHIKEY])
399 |
400 | self.assertAllEqual(expected_predictions, predictions)
401 |
402 | def testLibraryMatchingNoObserved(self):
403 | """Test library_matching using hardcoded values with no observed data."""
404 |
405 | tf_transform = lambda t: t + 2
406 | x_predicted = np.array([[1, 1], [-3, -2]], dtype=np.float32)
407 | y_query = np.array([[2, 5], [-3, 1], [0, 0]], dtype=np.float32)
408 |
409 | ids_predicted = self.make_ids(2)
410 | ids_query = self.make_ids(3)
411 |
412 | expected_predictions = [b'0', b'1', b'0']
413 |
414 | masses_query = np.ones([3, 1], dtype=np.float32)
415 | query_data = self._package_data(
416 | ids=ids_query, spectrum=y_query, masses=masses_query)
417 | masses_predicted = np.ones([2, 1], dtype=np.float32)
418 | predicted_data = self._package_data(
419 | ids=ids_predicted, spectrum=None, masses=masses_predicted)
420 | predicted_data[PREDICTOR_INPUT_KEY] = tf.constant(x_predicted)
421 |
422 | observed_data = None
423 | library_matching_data = library_matching.LibraryMatchingData(
424 | query=query_data, observed=observed_data, predicted=predicted_data)
425 |
426 | predictor_fn = lambda d: tf_transform(d[PREDICTOR_INPUT_KEY])
427 |
428 | _, predicted_data, _, _ = library_matching.library_matching(
429 | library_matching_data,
430 | predictor_fn=predictor_fn,
431 | similarity_provider=similarity_lib.CosineSimilarityProvider(),
432 | mass_tolerance=3.0)
433 |
434 | with tf.Session() as sess:
435 | sess.run(tf.local_variables_initializer())
436 | predictions = sess.run(predicted_data[fmap_constants.INCHIKEY])
437 |
438 | self.assertAllEqual(expected_predictions, predictions)
439 |
440 |
441 | if __name__ == '__main__':
442 | tf.test.main()
443 |
--------------------------------------------------------------------------------
/make_train_test_split.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Creates datasets from the NIST sdf files and makes experiment setup jsons.
16 |
17 | This module first breaks up the main NIST library dataset into a train/
18 | validation/test set, and the replicates library into a validation and test set.
19 | As all the molecules in the replicates file are also in the main NIST library,
20 | the mainlib datasets will exclude inchikeys from the replicates library. All the
21 | molecules in both datasets are to be included in one of these datasets, unless
22 | an argument is passed for mainlib_maximum_num_molecules_to_use or
23 | replicates_maximum_num_molecules_to_use.
24 |
25 | The component datasets are saved as TFRecords, by the names defined in
26 | dataset_setup_constants and the library from which the data came
27 | (e.g. mainlib_train_from_mainlib.tfrecord). This will result in 7 TFRecord files
28 | total, one each for the train/validation/test splits from the main library, and
29 | two each for the replicates validation/test splits, one with its data from the
30 | mainlib NIST file, and the other from the replicates file.
31 |
32 | For each experiment setup included in
33 | dataset_setup_constants.EXPERIMENT_SETUPS_LIST, a json file is written. This
34 | json file names the files to be used for each part of the experiment, i.e.
35 | library matching and spectra prediction.
36 |
37 | Note: Reading sdf files from cns is currently not supported.
38 |
39 | Example usage:
40 | make_train_test_split.py \
41 | --main_sdf_name=testdata/test_14_mend.sdf \
42 | --replicates_sdf_name=testdata/test_2_mend.sdf \
43 | --output_master_dir=
44 |
45 | """
46 |
47 | from __future__ import absolute_import
48 | from __future__ import division
49 | from __future__ import print_function
50 | import json
51 | import os
52 | import random
53 |
54 | from absl import app
55 | from absl import flags
56 | import dataset_setup_constants as ds_constants
57 | import mass_spec_constants as ms_constants
58 | import parse_sdf_utils
59 | import train_test_split_utils
60 | import six
61 | import tensorflow as tf
62 |
63 | FLAGS = flags.FLAGS
64 | flags.DEFINE_string(
65 | 'main_sdf_name', 'testdata/test_14_mend.sdf',
66 | 'specify full path of sdf file to parse, to be used for'
67 | ' training sets, and validation/test sets')
68 | flags.DEFINE_string(
69 | 'replicates_sdf_name',
70 | 'testdata/test_2_mend.sdf',
71 | 'specify full path of a second sdf file to parse, to be'
72 | ' used for the validation/test set. Molecules in this sdf'
73 | ' will be excluded from the main train/val/test sets.')
74 | # Note: For family based splitting, all molecules passing the filter will be
75 | # placed in validation/test datasets, and then split according to the relative
76 | # ratio between the validation/test fractions. If these are both equal to 0.0,
77 | # these values will be overwritten to 0.5 and 0.5.
78 | flags.DEFINE_list(
79 | 'main_train_val_test_fractions', '1.0,0.0,0.0',
80 | 'specify how large to make the train, val, and test sets'
81 | ' as a fraction of the whole dataset.')
82 | flags.DEFINE_integer('mainlib_maximum_num_molecules_to_use', None,
83 | 'specify how many total samples to use for parsing')
84 | flags.DEFINE_integer('replicates_maximum_num_molecules_to_use', None,
85 | 'specify how many total samples to use for parsing')
86 | flags.DEFINE_list(
87 | 'replicates_train_val_test_fractions', '0.0,0.5,0.5',
88 | 'specify fraction of replicates molecules to use'
89 | ' for the three replicates sample files.')
90 | flags.DEFINE_enum(
91 | 'splitting_type', 'random', ['random', 'steroid', 'diazo'],
92 | 'specify splitting method to use for creating '
93 | 'training/validation/test sets')
94 | flags.DEFINE_string('output_master_dir', '/tmp/output_dataset_dir',
95 | 'specify directory to save records')
96 | flags.DEFINE_integer('max_atoms', ms_constants.MAX_ATOMS,
97 | 'specify maximum number of atoms to allow')
98 | flags.DEFINE_integer('max_mass_spec_peak_loc', ms_constants.MAX_PEAK_LOC,
99 | 'specify greatest m/z spectrum peak to allow')
100 |
101 | INCHIKEY_FILENAME_END = '.inchikey.txt'
102 | TFRECORD_FILENAME_END = '.tfrecord'
103 | NP_LIBRARY_ARRAY_END = '.spectra_library.npy'
104 | FROM_MAINLIB_FILENAME_MODIFIER = '_from_mainlib'
105 | FROM_REPLICATES_FILENAME_MODIFIER = '_from_replicates'
106 |
107 |
108 | def make_mainlib_replicates_train_test_split(
109 | mainlib_mol_list,
110 | replicates_mol_list,
111 | splitting_type,
112 | mainlib_fractions,
113 | replicates_fractions,
114 | mainlib_maximum_num_molecules_to_use=None,
115 | replicates_maximum_num_molecules_to_use=None,
116 | rseed=42):
117 | """Makes train/validation/test inchikey lists from two lists of rdkit.Mol.
118 |
119 | Args:
120 | mainlib_mol_list : list of molecules from main library
121 | replicates_mol_list : list of molecules from replicates library
122 | splitting_type : type of splitting to use for validation splits.
123 | mainlib_fractions : TrainValTestFractions namedtuple
124 | holding desired fractions for train/val/test split of mainlib
125 | replicates_fractions : TrainValTestFractions namedtuple
126 | holding desired fractions for train/val/test split of replicates.
127 | For the replicates set, the train fraction should be set to 0.
128 | mainlib_maximum_num_molecules_to_use : Largest number of molecules to use
129 | when making datasets from mainlib
130 | replicates_maximum_num_molecules_to_use : Largest number of molecules to use
131 | when making datasets from replicates
132 | rseed : random seed for shuffling
133 |
134 | Returns:
135 | main_inchikey_dict : Dict that is keyed by inchikey, containing a list of
136 | rdkit.Mol objects corresponding to that inchikey from the mainlib
137 | replicates_inchikey_dict : Dict that is keyed by inchikey, containing a list
138 | of rdkit.Mol objects corresponding to that inchikey from the replicates
139 | library
140 | component_inchikey_dict : dict with keys :
141 | 'mainlib_train', 'mainlib_validation', 'mainlib_test',
142 | 'replicates_train', 'replicates_validation', 'replicates_test'
143 | Values are lists of inchikeys corresponding to each dataset.
144 |
145 | """
146 | random.seed(rseed)
147 | main_inchikey_dict = train_test_split_utils.make_inchikey_dict(
148 | mainlib_mol_list)
149 | main_inchikey_list = main_inchikey_dict.keys()
150 |
151 | if six.PY3:
152 | main_inchikey_list = list(main_inchikey_list)
153 |
154 | if mainlib_maximum_num_molecules_to_use is not None:
155 | main_inchikey_list = random.sample(main_inchikey_list,
156 | mainlib_maximum_num_molecules_to_use)
157 |
158 | replicates_inchikey_dict = train_test_split_utils.make_inchikey_dict(
159 | replicates_mol_list)
160 | replicates_inchikey_list = replicates_inchikey_dict.keys()
161 |
162 | if six.PY3:
163 | replicates_inchikey_list = list(replicates_inchikey_list)
164 |
165 | if replicates_maximum_num_molecules_to_use is not None:
166 | replicates_inchikey_list = random.sample(
167 | replicates_inchikey_list, replicates_maximum_num_molecules_to_use)
168 |
169 | # Make train/val/test splits for main dataset.
170 | main_train_validation_test_inchikeys = (
171 | train_test_split_utils.make_train_val_test_split_inchikey_lists(
172 | main_inchikey_list,
173 | main_inchikey_dict,
174 | mainlib_fractions,
175 | holdout_inchikey_list=replicates_inchikey_list,
176 | splitting_type=splitting_type))
177 |
178 | # Make train/val/test splits for replicates dataset.
179 | replicates_validation_test_inchikeys = (
180 | train_test_split_utils.make_train_val_test_split_inchikey_lists(
181 | replicates_inchikey_list,
182 | replicates_inchikey_dict,
183 | replicates_fractions,
184 | splitting_type=splitting_type))
185 |
186 | component_inchikey_dict = {
187 | ds_constants.MAINLIB_TRAIN_BASENAME:
188 | main_train_validation_test_inchikeys.train,
189 | ds_constants.MAINLIB_VALIDATION_BASENAME:
190 | main_train_validation_test_inchikeys.validation,
191 | ds_constants.MAINLIB_TEST_BASENAME:
192 | main_train_validation_test_inchikeys.test,
193 | ds_constants.REPLICATES_TRAIN_BASENAME:
194 | replicates_validation_test_inchikeys.train,
195 | ds_constants.REPLICATES_VALIDATION_BASENAME:
196 | replicates_validation_test_inchikeys.validation,
197 | ds_constants.REPLICATES_TEST_BASENAME:
198 | replicates_validation_test_inchikeys.test
199 | }
200 |
201 | train_test_split_utils.assert_all_lists_mutally_exclusive(
202 | list(component_inchikey_dict.values()))
203 | # Test that the union of the component inchikey lists is equal to the set of
204 | # inchikeys in the main and replicates libraries.
205 | all_inchikeys_in_components = []
206 | for ikey_list in list(component_inchikey_dict.values()):
207 | for ikey in ikey_list:
208 | all_inchikeys_in_components.append(ikey)
209 |
210 | assert set(main_inchikey_list + replicates_inchikey_list) == set(
211 | all_inchikeys_in_components
212 | ), ('The inchikeys in the original inchikey dictionary are not all included'
213 | ' in the train/val/test component libraries')
214 |
215 | return (main_inchikey_dict, replicates_inchikey_dict, component_inchikey_dict)
216 |
217 |
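# Illustrative driver (a sketch, not part of the pipeline): wiring the split
# function above to the test SDFs with the default flag values, mirroring what
# main() below does.
def _example_make_split():
  mainlib_mols = parse_sdf_utils.get_sdf_to_mol(
      'testdata/test_14_mend.sdf', max_atoms=ms_constants.MAX_ATOMS)
  replicates_mols = parse_sdf_utils.get_sdf_to_mol(
      'testdata/test_2_mend.sdf', max_atoms=ms_constants.MAX_ATOMS)
  return make_mainlib_replicates_train_test_split(
      mainlib_mols,
      replicates_mols,
      splitting_type='random',
      mainlib_fractions=train_test_split_utils.TrainValTestFractions(
          1.0, 0.0, 0.0),
      replicates_fractions=train_test_split_utils.TrainValTestFractions(
          0.0, 0.5, 0.5))
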
218 | def write_list_of_inchikeys(inchikey_list, base_name, output_dir):
219 | """Write list of inchikeys as a text file."""
220 | inchikey_list_name = base_name + INCHIKEY_FILENAME_END
221 |
222 | with tf.gfile.Open(os.path.join(output_dir, inchikey_list_name),
223 | 'w') as writer:
224 | for inchikey in inchikey_list:
225 | writer.write('%s\n' % inchikey)
226 |
227 |
228 | def write_all_dataset_files(inchikey_dict,
229 | inchikey_list,
230 | base_name,
231 | output_dir,
232 | max_atoms,
233 | max_mass_spec_peak_loc,
234 | make_library_array=False):
235 | """Helper function for writing all the files associated with a TFRecord.
236 |
237 | Args:
238 | inchikey_dict : Full dictionary keyed by inchikey containing lists of
239 | rdkit.Mol objects
240 | inchikey_list : List of inchikeys to include in dataset
241 | base_name : Base name for the dataset
242 | output_dir : Path for saving all TFRecord files
243 | max_atoms : Maximum number of atoms to include for a given molecule
244 | max_mass_spec_peak_loc : Largest m/z peak to include in a spectra.
245 | make_library_array : Flag for whether to make library array
246 | Returns:
247 | Saves 3 files:
248 | basename.tfrecord : a TFRecord file,
249 | basename.inchikey.txt : a text file with all the inchikeys in the dataset
250 | basename.tfrecord.info: a text file with one line describing
251 | the length of the TFRecord file.
252 | Also saves if make_library_array is set:
253 | basename.spectra_library.npy : see parse_sdf_utils.write_dicts_to_example
254 | """
255 | record_name = base_name + TFRECORD_FILENAME_END
256 |
257 | mol_list = train_test_split_utils.make_mol_list_from_inchikey_dict(
258 | inchikey_dict, inchikey_list)
259 |
260 | if make_library_array:
261 | library_array_pathname = base_name + NP_LIBRARY_ARRAY_END
262 | parse_sdf_utils.write_dicts_to_example(
263 | mol_list, os.path.join(output_dir, record_name),
264 | max_atoms, max_mass_spec_peak_loc,
265 | os.path.join(output_dir, library_array_pathname))
266 | else:
267 | parse_sdf_utils.write_dicts_to_example(
268 | mol_list, os.path.join(output_dir, record_name), max_atoms,
269 | max_mass_spec_peak_loc)
270 | write_list_of_inchikeys(inchikey_list, base_name, output_dir)
271 | parse_sdf_utils.write_info_file(mol_list, os.path.join(
272 | output_dir, record_name))
273 |
274 |
275 | def write_mainlib_split_datasets(component_inchikey_dict, mainlib_inchikey_dict,
276 | output_dir, max_atoms, max_mass_spec_peak_loc):
277 | """Write all train/val/test set TFRecords from main NIST sdf file."""
278 | for component_kwarg in component_inchikey_dict.keys():
279 | component_mainlib_filename = (
280 | component_kwarg + FROM_MAINLIB_FILENAME_MODIFIER)
281 | if component_kwarg == ds_constants.MAINLIB_TRAIN_BASENAME:
282 | write_all_dataset_files(
283 | mainlib_inchikey_dict,
284 | component_inchikey_dict[component_kwarg],
285 | component_mainlib_filename,
286 | output_dir,
287 | max_atoms,
288 | max_mass_spec_peak_loc,
289 | make_library_array=True)
290 | else:
291 | write_all_dataset_files(mainlib_inchikey_dict,
292 | component_inchikey_dict[component_kwarg],
293 | component_mainlib_filename, output_dir, max_atoms,
294 | max_mass_spec_peak_loc)
295 |
296 |
297 | def write_replicates_split_datasets(component_inchikey_dict,
298 | replicates_inchikey_dict, output_dir,
299 | max_atoms, max_mass_spec_peak_loc):
300 | """Write replicates val/test set TFRecords from replicates sdf file."""
301 | for component_kwarg in [
302 | ds_constants.REPLICATES_VALIDATION_BASENAME,
303 | ds_constants.REPLICATES_TEST_BASENAME
304 | ]:
305 | component_replicates_filename = (
306 | component_kwarg + FROM_REPLICATES_FILENAME_MODIFIER)
307 | write_all_dataset_files(replicates_inchikey_dict,
308 | component_inchikey_dict[component_kwarg],
309 | component_replicates_filename, output_dir,
310 | max_atoms, max_mass_spec_peak_loc)
311 |
312 |
313 | def combine_inchikey_sets(dataset_subdivision_list, dataset_split_dict):
314 | """A function to combine lists of inchikeys that are values from a dict.
315 |
316 | Args:
317 | dataset_subdivision_list: List of keys in dataset_split_dict to combine
318 | into one list
319 | dataset_split_dict: dict containing keys in dataset_subdivision_list, with
320 | lists of inchikeys as values.
321 | Returns:
322 | A list of inchikeys.
323 | """
324 | dataset_inchikey_list = []
325 | for dataset_subdivision_name in dataset_subdivision_list:
326 | dataset_inchikey_list.extend(dataset_split_dict[dataset_subdivision_name])
327 | return dataset_inchikey_list
328 |
329 |
330 | def check_experiment_setup(experiment_setup_dict, component_inchikey_dict):
331 | """Validates experiment setup for given lists of inchikeys."""
332 |
333 | # Check that the union of the library matching observed and library
334 | # matching predicted sets are equal to the set of inchikeys in the
335 | # mainlib_inchikey_dict
336 | all_inchikeys_in_library = (
337 | combine_inchikey_sets(
338 | experiment_setup_dict[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY],
339 | component_inchikey_dict) +
340 | combine_inchikey_sets(
341 | experiment_setup_dict[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY],
342 | component_inchikey_dict))
343 |
344 | all_inchikeys_in_use = []
345 | for kwarg in component_inchikey_dict.keys():
346 | all_inchikeys_in_use.extend(component_inchikey_dict[kwarg])
347 |
348 | assert set(all_inchikeys_in_use) == set(all_inchikeys_in_library), (
349 | 'Inchikeys in library for library matching do not match full dataset.')
350 |
351 | # Check that all inchikeys in query are found in full library of inchikeys.
352 | assert set(
353 | combine_inchikey_sets(
354 | experiment_setup_dict[ds_constants.LIBRARY_MATCHING_QUERY_KEY],
355 | component_inchikey_dict)).issubset(set(all_inchikeys_in_library)), (
356 | 'Inchikeys in query set for library matching not '
357 | 'found in library.')
358 |
359 |
360 | def write_json_for_experiment(experiment_setup, output_dir):
361 | """Writes json for experiment, recording relevant files for each component.
362 |
363 | Writes a json containing a list of TFRecord file names to read
364 | for each experiment component, i.e. spectrum_prediction, library_matching.
365 |
366 | Args:
367 | experiment_setup: A dataset_setup_constants.ExperimentSetup tuple
368 | output_dir: directory to write json
369 | Returns:
370 | Writes json recording which files to load for each component
371 | of the experiment
372 | Raises:
373 | ValueError: if the experiment component is not specified to be taken from
374 | either the main NIST library or the replicates library.
375 |
376 | """
377 | experiment_json_dict = {}
378 | for dataset_kwarg in experiment_setup.experiment_setup_dataset_dict:
379 | if dataset_kwarg in experiment_setup.data_to_get_from_mainlib:
380 | experiment_json_dict[dataset_kwarg] = [
381 | (component_basename + FROM_MAINLIB_FILENAME_MODIFIER +
382 | TFRECORD_FILENAME_END) for component_basename in
383 | experiment_setup.experiment_setup_dataset_dict[dataset_kwarg]
384 | ]
385 | elif dataset_kwarg in experiment_setup.data_to_get_from_replicates:
386 | experiment_json_dict[dataset_kwarg] = [
387 | (component_basename + FROM_REPLICATES_FILENAME_MODIFIER +
388 | TFRECORD_FILENAME_END) for component_basename in
389 | experiment_setup.experiment_setup_dataset_dict[dataset_kwarg]
390 | ]
391 | else:
392 | raise ValueError('Did not specify origin for {}.'.format(dataset_kwarg))
393 |
394 | training_spectra_filename = (
395 | ds_constants.MAINLIB_TRAIN_BASENAME + FROM_MAINLIB_FILENAME_MODIFIER +
396 | NP_LIBRARY_ARRAY_END)
397 | experiment_json_dict[
398 | ds_constants.TRAINING_SPECTRA_ARRAY_KEY] = training_spectra_filename
399 |
400 | with tf.gfile.Open(os.path.join(output_dir, experiment_setup.json_name),
401 | 'w') as writer:
402 | experiment_json = json.dumps(experiment_json_dict)
403 | writer.write(experiment_json)
404 |
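# For orientation, the json written above maps each experiment component to a
# list of TFRecord basenames, plus one entry for the training spectra array.
# A sketch (the real keys and basenames come from dataset_setup_constants):
#
#   {"SPECTRUM_PREDICTION_TRAIN": ["mainlib_train_from_mainlib.tfrecord"],
#    "LIBRARY_MATCHING_QUERY": ["replicates_validation_from_replicates.tfrecord"],
#    "mainlib_train_spectra_library_file":
#        "mainlib_train_from_mainlib.spectra_library.npy"}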
405 |
406 | def main(_):
407 | tf.gfile.MkDir(FLAGS.output_master_dir)
408 |
409 | main_train_val_test_fractions_tuple = tuple(
410 | [float(elem) for elem in FLAGS.main_train_val_test_fractions])
411 | main_train_val_test_fractions = train_test_split_utils.TrainValTestFractions(
412 | *main_train_val_test_fractions_tuple)
413 |
414 | replicates_train_val_test_fractions_tuple = tuple(
415 | [float(elem) for elem in FLAGS.replicates_train_val_test_fractions])
416 | replicates_train_val_test_fractions = (
417 | train_test_split_utils.TrainValTestFractions(
418 | *replicates_train_val_test_fractions_tuple))
419 |
420 | mainlib_mol_list = parse_sdf_utils.get_sdf_to_mol(
421 | FLAGS.main_sdf_name, max_atoms=FLAGS.max_atoms)
422 | replicates_mol_list = parse_sdf_utils.get_sdf_to_mol(
423 | FLAGS.replicates_sdf_name, max_atoms=FLAGS.max_atoms)
424 |
425 | # Breaks the inchikeys lists into train/validation/test splits.
426 | (mainlib_inchikey_dict, replicates_inchikey_dict, component_inchikey_dict) = (
427 | make_mainlib_replicates_train_test_split(
428 | mainlib_mol_list,
429 | replicates_mol_list,
430 | FLAGS.splitting_type,
431 | main_train_val_test_fractions,
432 | replicates_train_val_test_fractions,
433 | mainlib_maximum_num_molecules_to_use=FLAGS.
434 | mainlib_maximum_num_molecules_to_use,
435 | replicates_maximum_num_molecules_to_use=FLAGS.
436 | replicates_maximum_num_molecules_to_use))
437 |
438 | # Writes TFRecords for each component using info from the main library file
439 | write_mainlib_split_datasets(component_inchikey_dict, mainlib_inchikey_dict,
440 | FLAGS.output_master_dir, FLAGS.max_atoms,
441 | FLAGS.max_mass_spec_peak_loc)
442 |
443 | # Writes TFRecords for each component using info from the replicates file
444 | write_replicates_split_datasets(
445 | component_inchikey_dict, replicates_inchikey_dict,
446 | FLAGS.output_master_dir, FLAGS.max_atoms, FLAGS.max_mass_spec_peak_loc)
447 |
448 | for experiment_setup in ds_constants.EXPERIMENT_SETUPS_LIST:
449 | # Check that experiment setup is valid.
450 | check_experiment_setup(experiment_setup.experiment_setup_dataset_dict,
451 | component_inchikey_dict)
452 |
453 | # Write a json for the experiment setups, pointing to local files.
454 | write_json_for_experiment(experiment_setup, FLAGS.output_master_dir)
455 |
456 |
457 | if __name__ == '__main__':
458 | app.run(main)
459 |
--------------------------------------------------------------------------------