├── .gitignore ├── LICENSE ├── README.md ├── nmt ├── __init__.py ├── attention_model.py ├── gnmt_model.py ├── inference.py ├── input_config.py ├── ion_beam_search_decoder.py ├── model.py ├── model_helper.py ├── post_process.py ├── scripts │ ├── __init__.py │ ├── bleu.py │ └── rouge.py ├── train.py ├── utils │ ├── __init__.py │ ├── evaluation_utils.py │ ├── file_utils.py │ ├── iterator_utils.py │ ├── mask_utils.py │ ├── misc_utils.py │ ├── nmt_utils.py │ ├── peaks_utils.py │ ├── standard_hparams_utils.py │ └── vocab_utils.py └── vocab │ ├── mass.txt │ ├── vocab.txt │ ├── vocab_m.txt │ └── vocab_nomod.txt ├── requirements.txt ├── run.py ├── utils_data ├── clean_msms_data.py └── convert_mgf_to_csv.py └── utils_masking ├── SMSNet_final_database_search.py ├── append_decoded_peptides.py └── create_denovo_report.py /.gitignore: -------------------------------------------------------------------------------- 1 | data* 2 | log* 3 | *.zip 4 | 5 | 6 | # Tensorflow 7 | .ipynb_checkpoints 8 | node_modules 9 | /.bazelrc 10 | /.tf_configure.bazelrc 11 | /bazel-* 12 | /bazel_pip 13 | /tools/python_bin_path.sh 14 | /tensorflow/tools/git/gen 15 | /pip_test 16 | /_python_build 17 | *.pyc 18 | __pycache__ 19 | *.swp 20 | .vscode/ 21 | cmake_build/ 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SMSNet 2 | 3 | ## Sequence-Mask-Search framework and SMSNet model for de novo peptide sequencing. 4 | 5 | The pre-trained model can be downloaded from [FigShare](https://figshare.com/articles/SMSNet_s_trained_models/8259122). 6 | 7 | SMSNet's predicted amino acid sequences for a public HLA peptidome dataset (MassIVE accession MSV000080527) and phosphoproteome dataset (PRIDE accession PXD009227) can be found on [FigShare](https://figshare.com/articles/SMSNet_s_predictions_for_HLA_peptidome_and_human_phosphoproteome_datasets/8259134). 8 | 9 | The preprint can be found on [bioRxiv](http://biorxiv.org/cgi/content/short/667527v1). 10 | 11 | ## Dependencies 12 | This project uses Python 3.5.2, with the following lib dependencies: 13 | * [Tensorflow 1.4](https://www.tensorflow.org/) (compatible up to 1.11) 14 | * [Keras 2.2.4](https://keras.io/) 15 | 16 | A list of all python packages can be found in ```requirement.txt``` 17 | 18 | 19 | ## Instructions 20 | ### Decode 21 | ``` 22 | python run.py --model_dir --inference_input_file --rescore 23 | ``` 24 | `````` is the directory of the model (can be downloaded from the link above). 25 | 26 | `````` is the path to the input file. 27 | 28 | Using --rescore flag will generate another probability file with suffix “_rescore” in the same directory. The output will be in “_output/”. 29 | 30 | Other options can be found in "run.py". 31 | 32 | Model parameters (including possible amino acids) can be found in "nmt/input_config.py". 33 | 34 | ### Note 35 | * In order to generate the report file, the ```TITLE``` lines in .mgf file must end with "scan=". 36 | * To switch between m-mod and p-mod model, the following changes are needed (default: p-mod): 37 | * ```tgt_vocab_size``` (24 for m_mod / 27 for p-mod) and ```tgt_vocab_file``` in ```run.py``` line 61-62. 38 | * Comment/uncomment the possible vocab in ```inverse_vocab``` in ```nmt/input_config.py``` accordingly ('s', 't', 'y' at line 65, 67, 71). 39 | * Select the corresponding ```AA_weight_table``` in function ```create_aa_tables()``` in ```nmt/input_config.py``` (by comment/uncomment line 169-174 or 176-180). 40 | 41 | 42 | ### Outputs 43 | * For each input file, three output files will be generated in the output directory: ``````, ```_prob``` and, ```_rescore```. They are the output sequences, probabilities of each amino acid, and probabilities of each amino acid after rescoring, respectively. 44 | * The report summarizing the outputs in .tsv format will be in the same parent directory as the input. 45 | 46 | ## Example 47 | Decoding “test_decode/test_file.mgf” with a model in "model/m_mod/". 48 | ``` 49 | 50 | python run.py --model_dir model/m_mod/ --inference_input_file test_decode/test_file.mgf --rescore 51 | ``` 52 | The report will be in test_decode_output. 53 | 54 | The model will also produce three files: 55 | ``` 56 | test_decode_output/test_file (sequence) 57 | test_decode_output/test_file_prob (score) 58 | test_decode_output/test_file_rescore (score after rescoring with post-processing model) 59 | ``` 60 | 61 | ## Database search 62 | For database searching, change the file name in ```utils_masking/SMSNet_final_database_search.py```, then run ```python utils_masking/SMSNet_final_database_search.py``` 63 | 64 | 65 | ## Acknowledgement 66 | This code is based on TensorFlow Neural Machine Translation [GNMT](https://github.com/tensorflow/nmt). 67 | -------------------------------------------------------------------------------- /nmt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmb-chula/SMSNet/facfaf441d0ef286d062f69530f0a298aba78edc/nmt/__init__.py -------------------------------------------------------------------------------- /nmt/attention_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Attention-based sequence-to-sequence model with dynamic RNN support.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from . import model 23 | from . import model_helper 24 | 25 | __all__ = ["AttentionModel"] 26 | 27 | 28 | class AttentionModel(model.Model): 29 | """Sequence-to-sequence dynamic model with attention. 30 | 31 | This class implements a multi-layer recurrent neural network as encoder, 32 | and an attention-based decoder. This is the same as the model described in 33 | (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. 34 | This class also allows to use GRU cells in addition to LSTM cells with 35 | support for dropout. 36 | """ 37 | 38 | def __init__(self, 39 | hparams, 40 | mode, 41 | iterator, 42 | source_vocab_table, 43 | target_vocab_table, 44 | reverse_target_vocab_table=None, 45 | scope=None, 46 | extra_args=None): 47 | # Set attention_mechanism_fn 48 | if extra_args and extra_args.attention_mechanism_fn: 49 | self.attention_mechanism_fn = extra_args.attention_mechanism_fn 50 | else: 51 | self.attention_mechanism_fn = create_attention_mechanism 52 | 53 | super(AttentionModel, self).__init__( 54 | hparams=hparams, 55 | mode=mode, 56 | iterator=iterator, 57 | source_vocab_table=source_vocab_table, 58 | target_vocab_table=target_vocab_table, 59 | reverse_target_vocab_table=reverse_target_vocab_table, 60 | scope=scope, 61 | extra_args=extra_args) 62 | 63 | if self.mode == tf.contrib.learn.ModeKeys.INFER: 64 | self.infer_summary = self._get_infer_summary(hparams) 65 | 66 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 67 | source_sequence_length): 68 | """Build a RNN cell with attention mechanism that can be used by decoder.""" 69 | attention_option = hparams.attention 70 | attention_architecture = hparams.attention_architecture 71 | 72 | if attention_architecture != "standard": 73 | raise ValueError( 74 | "Unknown attention architecture %s" % attention_architecture) 75 | 76 | num_units = hparams.num_units 77 | num_layers = self.num_decoder_layers 78 | num_residual_layers = self.num_decoder_residual_layers 79 | beam_width = hparams.beam_width 80 | 81 | dtype = tf.float32 82 | 83 | # Ensure memory is batch-major 84 | if self.time_major: 85 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 86 | else: 87 | memory = encoder_outputs 88 | 89 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 90 | memory = tf.contrib.seq2seq.tile_batch( 91 | memory, multiplier=beam_width) 92 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 93 | source_sequence_length, multiplier=beam_width) 94 | encoder_state = tf.contrib.seq2seq.tile_batch( 95 | encoder_state, multiplier=beam_width) 96 | batch_size = self.batch_size * beam_width 97 | else: 98 | batch_size = self.batch_size 99 | 100 | attention_mechanism = self.attention_mechanism_fn( 101 | attention_option, num_units, memory, source_sequence_length, self.mode) 102 | 103 | cell = model_helper.create_rnn_cell( 104 | unit_type=hparams.unit_type, 105 | num_units=num_units, 106 | num_layers=num_layers, 107 | num_residual_layers=num_residual_layers, 108 | forget_bias=hparams.forget_bias, 109 | dropout=hparams.dropout, 110 | num_gpus=self.num_gpus, 111 | mode=self.mode, 112 | single_cell_fn=self.single_cell_fn) 113 | 114 | # Only generate alignment in greedy INFER mode. 115 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 116 | beam_width == 0) 117 | cell = tf.contrib.seq2seq.AttentionWrapper( 118 | cell, 119 | attention_mechanism, 120 | attention_layer_size=num_units, 121 | alignment_history=alignment_history, 122 | output_attention=hparams.output_attention, 123 | name="attention") 124 | 125 | # TODO(thangluong): do we need num_layers, num_gpus? 126 | cell = tf.contrib.rnn.DeviceWrapper(cell, 127 | model_helper.get_device_str( 128 | num_layers - 1, self.num_gpus)) 129 | 130 | if hparams.pass_hidden_state: 131 | decoder_initial_state = cell.zero_state(batch_size, dtype).clone( 132 | cell_state=encoder_state) 133 | else: 134 | decoder_initial_state = cell.zero_state(batch_size, dtype) 135 | 136 | return cell, decoder_initial_state 137 | 138 | def _get_infer_summary(self, hparams): 139 | if hparams.beam_width > 0: 140 | return tf.no_op() 141 | return _create_attention_images_summary(self.final_context_state) 142 | 143 | 144 | def create_attention_mechanism(attention_option, num_units, memory, 145 | source_sequence_length, mode): 146 | """Create attention mechanism based on the attention_option.""" 147 | del mode # unused 148 | 149 | # Mechanism 150 | if attention_option == "luong": 151 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 152 | num_units, memory, memory_sequence_length=source_sequence_length) 153 | elif attention_option == "scaled_luong": 154 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 155 | num_units, 156 | memory, 157 | memory_sequence_length=source_sequence_length, 158 | scale=True) 159 | elif attention_option == "bahdanau": 160 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 161 | num_units, memory, memory_sequence_length=source_sequence_length) 162 | elif attention_option == "normed_bahdanau": 163 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 164 | num_units, 165 | memory, 166 | memory_sequence_length=source_sequence_length, 167 | normalize=True) 168 | else: 169 | raise ValueError("Unknown attention option %s" % attention_option) 170 | 171 | return attention_mechanism 172 | 173 | 174 | def _create_attention_images_summary(final_context_state): 175 | """create attention image and attention summary.""" 176 | attention_images = (final_context_state.alignment_history.stack()) 177 | # Reshape to (batch, src_seq_len, tgt_seq_len,1) 178 | attention_images = tf.expand_dims( 179 | tf.transpose(attention_images, [1, 2, 0]), -1) 180 | # Scale to range [0, 255] 181 | attention_images *= 255 182 | attention_summary = tf.summary.image("attention_images", attention_images) 183 | return attention_summary 184 | -------------------------------------------------------------------------------- /nmt/gnmt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """GNMT attention sequence-to-sequence model with dynamic RNN support.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | # TODO(rzhao): Use tf.contrib.framework.nest once 1.3 is out. 24 | from tensorflow.python.util import nest 25 | 26 | from . import attention_model 27 | from . import model_helper 28 | from .utils import misc_utils as utils 29 | 30 | __all__ = ["GNMTModel"] 31 | 32 | 33 | class GNMTModel(attention_model.AttentionModel): 34 | """Sequence-to-sequence dynamic model with GNMT attention architecture. 35 | """ 36 | 37 | def __init__(self, 38 | hparams, 39 | mode, 40 | iterator, 41 | source_vocab_table, 42 | target_vocab_table, 43 | reverse_target_vocab_table=None, 44 | scope=None, 45 | extra_args=None): 46 | super(GNMTModel, self).__init__( 47 | hparams=hparams, 48 | mode=mode, 49 | iterator=iterator, 50 | source_vocab_table=source_vocab_table, 51 | target_vocab_table=target_vocab_table, 52 | reverse_target_vocab_table=reverse_target_vocab_table, 53 | scope=scope, 54 | extra_args=extra_args) 55 | 56 | def _build_encoder(self, hparams): 57 | """Build a GNMT encoder.""" 58 | if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": 59 | return super(GNMTModel, self)._build_encoder(hparams) 60 | 61 | if hparams.encoder_type != "gnmt": 62 | raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) 63 | 64 | # Build GNMT encoder. 65 | num_bi_layers = 1 66 | num_uni_layers = self.num_encoder_layers - num_bi_layers 67 | utils.print_out(" num_bi_layers = %d" % num_bi_layers) 68 | utils.print_out(" num_uni_layers = %d" % num_uni_layers) 69 | 70 | iterator = self.iterator 71 | source = iterator.source 72 | if self.time_major: 73 | source = tf.transpose(source) 74 | 75 | with tf.variable_scope("encoder") as scope: 76 | dtype = scope.dtype 77 | 78 | # Look up embedding, emp_inp: [max_time, batch_size, num_units] 79 | # when time_major = True 80 | encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder, 81 | source) 82 | 83 | # Execute _build_bidirectional_rnn from Model class 84 | bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( 85 | inputs=encoder_emb_inp, 86 | sequence_length=iterator.source_sequence_length, 87 | dtype=dtype, 88 | hparams=hparams, 89 | num_bi_layers=num_bi_layers, 90 | num_bi_residual_layers=0, # no residual connection 91 | ) 92 | 93 | uni_cell = model_helper.create_rnn_cell( 94 | unit_type=hparams.unit_type, 95 | num_units=hparams.num_units, 96 | num_layers=num_uni_layers, 97 | num_residual_layers=self.num_encoder_residual_layers, 98 | forget_bias=hparams.forget_bias, 99 | dropout=hparams.dropout, 100 | num_gpus=self.num_gpus, 101 | base_gpu=1, 102 | mode=self.mode, 103 | single_cell_fn=self.single_cell_fn) 104 | 105 | # encoder_outputs: size [max_time, batch_size, num_units] 106 | # when time_major = True 107 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn( 108 | uni_cell, 109 | bi_encoder_outputs, 110 | dtype=dtype, 111 | sequence_length=iterator.source_sequence_length, 112 | time_major=self.time_major) 113 | 114 | # Pass all encoder state except the first bi-directional layer's state to 115 | # decoder. 116 | encoder_state = (bi_encoder_state[1],) + ( 117 | (encoder_state,) if num_uni_layers == 1 else encoder_state) 118 | 119 | return encoder_outputs, encoder_state 120 | 121 | def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, 122 | source_sequence_length): 123 | """Build a RNN cell with GNMT attention architecture.""" 124 | # Standard attention 125 | if hparams.attention_architecture == "standard": 126 | return super(GNMTModel, self)._build_decoder_cell( 127 | hparams, encoder_outputs, encoder_state, source_sequence_length) 128 | 129 | # GNMT attention 130 | attention_option = hparams.attention 131 | attention_architecture = hparams.attention_architecture 132 | num_units = hparams.num_units 133 | beam_width = hparams.beam_width 134 | 135 | dtype = tf.float32 136 | 137 | if self.time_major: 138 | memory = tf.transpose(encoder_outputs, [1, 0, 2]) 139 | else: 140 | memory = encoder_outputs 141 | 142 | if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0: 143 | memory = tf.contrib.seq2seq.tile_batch( 144 | memory, multiplier=beam_width) 145 | source_sequence_length = tf.contrib.seq2seq.tile_batch( 146 | source_sequence_length, multiplier=beam_width) 147 | encoder_state = tf.contrib.seq2seq.tile_batch( 148 | encoder_state, multiplier=beam_width) 149 | batch_size = self.batch_size * beam_width 150 | else: 151 | batch_size = self.batch_size 152 | 153 | attention_mechanism = self.attention_mechanism_fn( 154 | attention_option, num_units, memory, source_sequence_length, self.mode) 155 | 156 | cell_list = model_helper._cell_list( # pylint: disable=protected-access 157 | unit_type=hparams.unit_type, 158 | num_units=num_units, 159 | num_layers=self.num_decoder_layers, 160 | num_residual_layers=self.num_decoder_residual_layers, 161 | forget_bias=hparams.forget_bias, 162 | dropout=hparams.dropout, 163 | num_gpus=self.num_gpus, 164 | mode=self.mode, 165 | single_cell_fn=self.single_cell_fn, 166 | residual_fn=gnmt_residual_fn 167 | ) 168 | 169 | # Only wrap the bottom layer with the attention mechanism. 170 | attention_cell = cell_list.pop(0) 171 | 172 | # Only generate alignment in greedy INFER mode. 173 | alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and 174 | beam_width == 0) 175 | attention_cell = tf.contrib.seq2seq.AttentionWrapper( 176 | attention_cell, 177 | attention_mechanism, 178 | attention_layer_size=None, # don't use attention layer. 179 | output_attention=False, 180 | alignment_history=alignment_history, 181 | name="attention") 182 | 183 | if attention_architecture == "gnmt": 184 | cell = GNMTAttentionMultiCell( 185 | attention_cell, cell_list) 186 | elif attention_architecture == "gnmt_v2": 187 | cell = GNMTAttentionMultiCell( 188 | attention_cell, cell_list, use_new_attention=True) 189 | else: 190 | raise ValueError( 191 | "Unknown attention_architecture %s" % attention_architecture) 192 | 193 | if hparams.pass_hidden_state: 194 | decoder_initial_state = tuple( 195 | zs.clone(cell_state=es) 196 | if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es 197 | for zs, es in zip( 198 | cell.zero_state(batch_size, dtype), encoder_state)) 199 | else: 200 | decoder_initial_state = cell.zero_state(batch_size, dtype) 201 | 202 | return cell, decoder_initial_state 203 | 204 | def _get_infer_summary(self, hparams): 205 | # Standard attention 206 | if hparams.attention_architecture == "standard": 207 | return super(GNMTModel, self)._get_infer_summary(hparams) 208 | 209 | # GNMT attention 210 | if hparams.beam_width > 0: 211 | return tf.no_op() 212 | return attention_model._create_attention_images_summary( 213 | self.final_context_state[0]) 214 | 215 | 216 | class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): 217 | """A MultiCell with GNMT attention style.""" 218 | 219 | def __init__(self, attention_cell, cells, use_new_attention=False): 220 | """Creates a GNMTAttentionMultiCell. 221 | 222 | Args: 223 | attention_cell: An instance of AttentionWrapper. 224 | cells: A list of RNNCell wrapped with AttentionInputWrapper. 225 | use_new_attention: Whether to use the attention generated from current 226 | step bottom layer's output. Default is False. 227 | """ 228 | cells = [attention_cell] + cells 229 | self.use_new_attention = use_new_attention 230 | super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) 231 | 232 | def __call__(self, inputs, state, scope=None): 233 | """Run the cell with bottom layer's attention copied to all upper layers.""" 234 | if not nest.is_sequence(state): 235 | raise ValueError( 236 | "Expected state to be a tuple of length %d, but received: %s" 237 | % (len(self.state_size), state)) 238 | 239 | with tf.variable_scope(scope or "multi_rnn_cell"): 240 | new_states = [] 241 | 242 | with tf.variable_scope("cell_0_attention"): 243 | attention_cell = self._cells[0] 244 | attention_state = state[0] 245 | cur_inp, new_attention_state = attention_cell(inputs, attention_state) 246 | new_states.append(new_attention_state) 247 | 248 | for i in range(1, len(self._cells)): 249 | with tf.variable_scope("cell_%d" % i): 250 | 251 | cell = self._cells[i] 252 | cur_state = state[i] 253 | 254 | if self.use_new_attention: 255 | cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) 256 | else: 257 | cur_inp = tf.concat([cur_inp, attention_state.attention], -1) 258 | 259 | cur_inp, new_state = cell(cur_inp, cur_state) 260 | new_states.append(new_state) 261 | 262 | return cur_inp, tuple(new_states) 263 | 264 | 265 | def gnmt_residual_fn(inputs, outputs): 266 | """Residual function that handles different inputs and outputs inner dims. 267 | 268 | Args: 269 | inputs: cell inputs, this is actual inputs concatenated with the attention 270 | vector. 271 | outputs: cell outputs 272 | 273 | Returns: 274 | outputs + actual inputs 275 | """ 276 | def split_input(inp, out): 277 | out_dim = out.get_shape().as_list()[-1] 278 | inp_dim = inp.get_shape().as_list()[-1] 279 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) 280 | actual_inputs, _ = nest.map_structure(split_input, inputs, outputs) 281 | def assert_shape_match(inp, out): 282 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 283 | nest.assert_same_structure(actual_inputs, outputs) 284 | nest.map_structure(assert_shape_match, actual_inputs, outputs) 285 | return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs) 286 | -------------------------------------------------------------------------------- /nmt/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """To perform inference on test set given a trained model.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | 22 | import tensorflow as tf 23 | 24 | from . import attention_model 25 | from . import gnmt_model 26 | from . import model as nmt_model 27 | from . import model_helper 28 | from .utils import misc_utils as utils 29 | from .utils import nmt_utils 30 | 31 | __all__ = ["load_data", "inference", 32 | "single_worker_inference", "multi_worker_inference"] 33 | 34 | 35 | def _decode_inference_indices(model, sess, output_infer, 36 | output_infer_summary_prefix, 37 | inference_indices, 38 | tgt_eos, 39 | subword_option): 40 | """Decoding only a specific set of sentences.""" 41 | utils.print_out(" decoding to output %s , num sents %d." % 42 | (output_infer, len(inference_indices))) 43 | start_time = time.time() 44 | with codecs.getwriter("utf-8")( 45 | tf.gfile.GFile(output_infer, mode="wb")) as trans_f: 46 | trans_f.write("") # Write empty string to ensure file is created. 47 | for decode_id in inference_indices: 48 | nmt_outputs, infer_summary = model.decode(sess) 49 | 50 | # get text translation 51 | assert nmt_outputs.shape[0] == 1 52 | translation = nmt_utils.get_translation( 53 | nmt_outputs, 54 | sent_id=0, 55 | tgt_eos=tgt_eos, 56 | subword_option=subword_option) 57 | 58 | if infer_summary is not None: # Attention models 59 | image_file = output_infer_summary_prefix + str(decode_id) + ".png" 60 | utils.print_out(" save attention image to %s*" % image_file) 61 | image_summ = tf.Summary() 62 | image_summ.ParseFromString(infer_summary) 63 | with tf.gfile.GFile(image_file, mode="w") as img_f: 64 | img_f.write(image_summ.value[0].image.encoded_image_string) 65 | 66 | trans_f.write("%s\n" % translation) 67 | utils.print_out(translation + b"\n") 68 | utils.print_time(" done", start_time) 69 | 70 | 71 | def load_data(inference_input_file, hparams=None): 72 | """Load inference data.""" 73 | with codecs.getreader("utf-8")( 74 | tf.gfile.GFile(inference_input_file, mode="rb")) as f: 75 | inference_data = f.read().splitlines() 76 | 77 | if hparams and hparams.inference_indices: 78 | inference_data = [inference_data[i] for i in hparams.inference_indices] 79 | 80 | return inference_data 81 | 82 | 83 | def inference(ckpt, 84 | inference_input_file, 85 | inference_output_file, 86 | hparams, 87 | num_workers=1, 88 | jobid=0, 89 | scope=None): 90 | """Perform translation.""" 91 | if hparams.inference_indices: 92 | assert num_workers == 1 93 | 94 | if not hparams.attention: 95 | model_creator = nmt_model.Model 96 | elif hparams.attention_architecture == "standard": 97 | model_creator = attention_model.AttentionModel 98 | elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]: 99 | model_creator = gnmt_model.GNMTModel 100 | else: 101 | raise ValueError("Unknown model architecture") 102 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope) 103 | 104 | if num_workers == 1: 105 | single_worker_inference( 106 | infer_model, 107 | ckpt, 108 | inference_input_file, 109 | inference_output_file, 110 | hparams) 111 | else: 112 | multi_worker_inference( 113 | infer_model, 114 | ckpt, 115 | inference_input_file, 116 | inference_output_file, 117 | hparams, 118 | num_workers=num_workers, 119 | jobid=jobid) 120 | 121 | 122 | def single_worker_inference(infer_model, 123 | ckpt, 124 | inference_input_file, 125 | inference_output_file, 126 | hparams): 127 | """Inference with a single worker.""" 128 | output_infer = inference_output_file 129 | 130 | # Read data 131 | infer_data = load_data(inference_input_file, hparams) 132 | 133 | with tf.Session( 134 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 135 | loaded_infer_model = model_helper.load_model( 136 | infer_model.model, ckpt, sess, "infer") 137 | sess.run( 138 | infer_model.iterator.initializer, 139 | feed_dict={ 140 | infer_model.src_placeholder: infer_data, 141 | infer_model.batch_size_placeholder: hparams.infer_batch_size 142 | }) 143 | # Decode 144 | utils.print_out("# Start decoding") 145 | if hparams.inference_indices: 146 | _decode_inference_indices( 147 | loaded_infer_model, 148 | sess, 149 | output_infer=output_infer, 150 | output_infer_summary_prefix=output_infer, 151 | inference_indices=hparams.inference_indices, 152 | tgt_eos=hparams.eos, 153 | subword_option=hparams.subword_option) 154 | else: 155 | nmt_utils.decode_and_evaluate( 156 | "infer", 157 | loaded_infer_model, 158 | sess, 159 | output_infer, 160 | ref_file=None, 161 | metrics=hparams.metrics, 162 | subword_option=hparams.subword_option, 163 | beam_width=hparams.beam_width, 164 | tgt_eos=hparams.eos, 165 | num_translations_per_input=hparams.num_translations_per_input) 166 | 167 | 168 | def multi_worker_inference(infer_model, 169 | ckpt, 170 | inference_input_file, 171 | inference_output_file, 172 | hparams, 173 | num_workers, 174 | jobid): 175 | """Inference using multiple workers.""" 176 | assert num_workers > 1 177 | 178 | final_output_infer = inference_output_file 179 | output_infer = "%s_%d" % (inference_output_file, jobid) 180 | output_infer_done = "%s_done_%d" % (inference_output_file, jobid) 181 | 182 | # Read data 183 | infer_data = load_data(inference_input_file, hparams) 184 | 185 | # Split data to multiple workers 186 | total_load = len(infer_data) 187 | load_per_worker = int((total_load - 1) / num_workers) + 1 188 | start_position = jobid * load_per_worker 189 | end_position = min(start_position + load_per_worker, total_load) 190 | infer_data = infer_data[start_position:end_position] 191 | 192 | with tf.Session( 193 | graph=infer_model.graph, config=utils.get_config_proto()) as sess: 194 | loaded_infer_model = model_helper.load_model( 195 | infer_model.model, ckpt, sess, "infer") 196 | sess.run(infer_model.iterator.initializer, 197 | { 198 | infer_model.src_placeholder: infer_data, 199 | infer_model.batch_size_placeholder: hparams.infer_batch_size 200 | }) 201 | # Decode 202 | utils.print_out("# Start decoding") 203 | nmt_utils.decode_and_evaluate( 204 | "infer", 205 | loaded_infer_model, 206 | sess, 207 | output_infer, 208 | ref_file=None, 209 | metrics=hparams.metrics, 210 | subword_option=hparams.subword_option, 211 | beam_width=hparams.beam_width, 212 | tgt_eos=hparams.eos, 213 | num_translations_per_input=hparams.num_translations_per_input) 214 | 215 | # Change file name to indicate the file writing is completed. 216 | tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) 217 | 218 | # Job 0 is responsible for the clean up. 219 | if jobid != 0: return 220 | 221 | # Now write all translations 222 | with codecs.getwriter("utf-8")( 223 | tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: 224 | for worker_id in range(num_workers): 225 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 226 | while not tf.gfile.Exists(worker_infer_done): 227 | utils.print_out(" waitting job %d to complete." % worker_id) 228 | time.sleep(10) 229 | 230 | with codecs.getreader("utf-8")( 231 | tf.gfile.GFile(worker_infer_done, mode="rb")) as f: 232 | for translation in f: 233 | final_f.write("%s" % translation) 234 | 235 | for worker_id in range(num_workers): 236 | worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) 237 | tf.gfile.Remove(worker_infer_done) 238 | -------------------------------------------------------------------------------- /nmt/input_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Korrawe Karunratanakul 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | from tensorflow.python.ops import lookup_ops 20 | 21 | # ============================================================================== 22 | # MODEL HYPERPARAMETERS 23 | # ============================================================================== 24 | 25 | bin_step = 0.01 # 0.01 26 | inv_bin_step = 1.0/bin_step 27 | max_posi = 5000 28 | max_spec_length = int(max_posi * inv_bin_step) 29 | max_spec_length_cnn = max_spec_length//10 30 | 31 | aa_input_window_size = 0.2 #1.0 32 | # full_aa_window = int(aa_input_window_size * inv_bin_step) 33 | half_aa_window = int(aa_input_window_size * inv_bin_step/2) 34 | full_aa_window = half_aa_window * 2 35 | 36 | # knapsack parameters 37 | max_dp_sz = 1500 38 | dp_resolution = 0.0005 39 | inv_dp_resolution = 1.0/dp_resolution 40 | max_dp_array_size = int(round(max_dp_sz * inv_dp_resolution)) 41 | 42 | # ============================================================================== 43 | # GLOBAL VARIABLES for VOCABULARY 44 | # ============================================================================== 45 | 46 | 47 | # Special vocabulary symbols - we always put them at the start. 48 | _PAD = "_PAD" 49 | _GO = "_GO" 50 | _EOS = "_EOS" 51 | _START_VOCAB = [_PAD, _GO, _EOS] 52 | 53 | PAD_ID = 0 54 | GO_ID = 1 55 | EOS_ID = 2 56 | 57 | vocab_reverse = ['', 58 | '', 59 | '', 60 | 'A', 61 | 'C', 62 | # 'Cmod', 63 | 'D', 64 | 'E', 65 | 'F', 66 | 'G', 67 | 'H', 68 | 'I', 69 | 'K', 70 | 'L', 71 | 'M', 72 | 'm', # 'Mmod', 73 | 'N', 74 | # 'n', # 'Nmod', 75 | 'P', 76 | 'Q', 77 | # 'q', # 'Qmod', 78 | 'R', 79 | 'S', 80 | 's', 81 | 'T', 82 | 't', 83 | 'V', 84 | 'W', 85 | 'Y', 86 | 'y', 87 | ] 88 | 89 | # vocab_reverse = _START_VOCAB + vocab_reverse 90 | # print("vocab_reverse ", vocab_reverse) # 91 | vocab_reverse_np = np.array(vocab_reverse) 92 | # vocab = dict([(x, y) for (y, x) in enumerate(vocab_reverse)]) 93 | # print("vocab ", vocab) 94 | 95 | vocab_size_with_eos = len(vocab_reverse) 96 | vocab_size = len(vocab_reverse) - 3 97 | # print("vocab_size ", vocab_size) # 98 | 99 | 100 | # ============================================================================== 101 | # GLOBAL VARIABLES for THEORETICAL MASS 102 | # ============================================================================== 103 | 104 | 105 | mass_H = 1.007825 106 | mass_H2O = 18.0106 107 | mass_NH3 = 17.0265 108 | mass_N_terminus = 1.007825 109 | mass_C_terminus = 17.00274 110 | mass_CO = 27.9949 111 | 112 | mass_AA = {# '_PAD': 0.0, 113 | # '_GO': mass_N_terminus-mass_H, 114 | # '': mass_C_terminus+mass_H, 115 | '': 0.0, 116 | '': 0.0, 117 | '': 0.0, 118 | 'A': 71.03711, # 0 119 | # 'C': 103.00919, # 4 120 | 'C': 160.03065, # C(+57.02) 121 | #~ 'Cmod': 161.01919, # C(+58.01) # orbi 122 | 'R': 156.10111, # 1 123 | 'N': 114.04293, # 2 124 | # 'n': 115.02695, # N mod, N(+.98) 125 | 'D': 115.02694, # 3 126 | 127 | 'E': 129.04259, # 5 128 | 'Q': 128.05858, # 6 129 | # 'q': 129.0426, # Q mod, Q(+.98) 130 | 'G': 57.02146, # 7 131 | 'H': 137.05891, # 8 132 | 'I': 113.08406, # 9 133 | 'L': 113.08406, # 10 134 | 'K': 128.09496, # 11 135 | 'M': 131.04049, # 12 136 | 'm': 147.0354, # M mod, M(+15.99) 137 | 'F': 147.06841, # 13 138 | 'P': 97.05276, # 14 139 | 'S': 87.03203, # 15 140 | 's': 166.99836, # S mod, S(ph), S + 79.96633 141 | 'T': 101.04768, # 16 142 | 't': 181.01401, # T mod, T(ph), T + 79.96633 143 | 'W': 186.07931, # 17 144 | 'Y': 163.06333, # 18 145 | 'y': 243.02966, # Y mod, y(ph), Y + 79.96633 146 | 'V': 99.06841, # 19 147 | } 148 | 149 | vocab_ID = np.array(list(range(vocab_size_with_eos)), dtype=np.int32) 150 | # print(vocab_ID) # 151 | mass_ID = [mass_AA[vocab_reverse[x]] for x in range(vocab_size_with_eos)] 152 | mass_ID_np_with_eos = np.array(mass_ID, dtype=np.float32) 153 | # print(mass_ID_np_with_eos) # 154 | mass_ID_np = mass_ID_np_with_eos[3:] 155 | # print(mass_ID_np) # 156 | 157 | mass_AA_min = mass_AA["G"] # 57.02146 158 | 159 | def create_aa_tables(): 160 | """Creates amino acid tables.""" 161 | # mass_ID_tf = tf.convert_to_tensor(mass_ID_np_with_eos, dtype=tf.float32) 162 | # vocab_tf = tf.convert_to_tensor(vocab_ID, dtype=tf.int64)#int32) 163 | 164 | # AA_weight_table = lookup_ops.HashTable( 165 | # lookup_ops.KeyValueTensorInitializer(vocab_tf, mass_ID_tf), -1) 166 | with tf.device("/gpu:0"): 167 | # AA_weight_table = lookup_ops.index_to_string_table_from_file( 168 | # "nmt/vocab/mass.txt", default_value="0.0") 169 | AA_weight_table = tf.constant([0.0, 0.0, 0.0, 71.03711, 160.03065, 170 | 115.02694, 129.04259, 147.06841, 57.02146, 137.05891, 171 | 113.08406, 128.09496, 113.08406, 131.04049, 147.0354, 172 | 114.04293, 97.05276, 128.05858, 156.10111, 87.03203, 173 | 166.99836, 101.04768, 181.01401, 99.06841, 186.07931, 174 | 163.06333, 243.02966]) # M mod + ph mod 175 | 176 | # AA_weight_table = tf.constant([0.0, 0.0, 0.0, 71.03711, 160.03065, 177 | # 115.02694, 129.04259, 147.06841, 57.02146, 137.05891, 178 | # 113.08406, 128.09496, 113.08406, 131.04049, 147.0354, 179 | # 114.04293, 97.05276, 128.05858, 156.10111, 87.03203, 180 | # 101.04768, 99.06841, 186.07931, 163.06333]) # M mod 181 | 182 | # AA_weight_table = tf.constant([0.0, 0.0, 0.0, 71.03711, 160.03065, 183 | # 115.02694, 129.04259, 147.06841, 57.02146, 137.05891, 184 | # 113.08406, 128.09496, 113.08406, 131.04049, 185 | # 114.04293, 97.05276, 128.05858, 156.10111, 87.03203, 186 | # 101.04768, 99.06841, 186.07931, 163.06333]) # No mod 187 | return AA_weight_table 188 | -------------------------------------------------------------------------------- /nmt/model_helper.py: -------------------------------------------------------------------------------- 1 | """Utility functions for building models.""" 2 | from __future__ import print_function 3 | 4 | import collections 5 | import six 6 | import os 7 | import time 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from tensorflow.python.ops import lookup_ops 13 | 14 | from .utils import iterator_utils 15 | from .utils import misc_utils as utils 16 | from .utils import vocab_utils 17 | 18 | from . import input_config 19 | 20 | 21 | __all__ = [ 22 | "get_initializer", "get_device_str", "create_train_model", 23 | "create_eval_model", "create_infer_model", 24 | "create_emb_for_encoder_and_decoder", "create_rnn_cell", "gradient_clip", 25 | "create_or_load_model", "load_model", "avg_checkpoints", 26 | "compute_perplexity" 27 | ] 28 | 29 | # If a vocab size is greater than this value, put the embedding on cpu instead 30 | VOCAB_SIZE_THRESHOLD_CPU = 50000 31 | 32 | 33 | def get_initializer(init_op, seed=None, init_weight=None): 34 | """Create an initializer. init_weight is only for uniform.""" 35 | if init_op == "uniform": 36 | assert init_weight 37 | return tf.random_uniform_initializer( 38 | -init_weight, init_weight, seed=seed) 39 | elif init_op == "glorot_normal": 40 | return tf.keras.initializers.glorot_normal( 41 | seed=seed) 42 | elif init_op == "glorot_uniform": 43 | return tf.keras.initializers.glorot_uniform( 44 | seed=seed) 45 | else: 46 | raise ValueError("Unknown init_op %s" % init_op) 47 | 48 | 49 | def get_device_str(device_id, num_gpus): 50 | """Return a device string for multi-GPU setup.""" 51 | if num_gpus == 0: 52 | return "/cpu:0" 53 | device_str_output = "/gpu:%d" % (device_id % num_gpus) 54 | return device_str_output 55 | 56 | 57 | class ExtraArgs(collections.namedtuple( 58 | "ExtraArgs", ("single_cell_fn", "model_device_fn", 59 | "attention_mechanism_fn"))): 60 | pass 61 | 62 | 63 | class TrainModel( 64 | collections.namedtuple("TrainModel", ("graph", "model", "iterator", 65 | "skip_count_placeholder"))): 66 | pass 67 | 68 | 69 | def create_train_model( 70 | model_creator, hparams, scope=None, num_workers=1, jobid=0, 71 | extra_args=None): 72 | """Create train graph, model, and iterator.""" 73 | src_file = [] 74 | for file_name in hparams.src: 75 | src_file.append("%s%s" % (file_name, hparams.src_suffix )) 76 | tgt_vocab_file = hparams.tgt_vocab_file 77 | 78 | graph = tf.Graph() 79 | 80 | with graph.as_default(), tf.container(scope or "train"): 81 | tgt_vocab_table = vocab_utils.create_vocab_tables(tgt_vocab_file) 82 | aa_weight_table = input_config.create_aa_tables() 83 | 84 | skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) 85 | 86 | iterator = iterator_utils.get_iterator( 87 | src_file, 88 | tgt_vocab_table, 89 | batch_size=hparams.batch_size, 90 | sos=hparams.sos, 91 | eos=hparams.eos, 92 | aa_weight_table=aa_weight_table, 93 | random_seed=hparams.random_seed, 94 | num_buckets=hparams.num_buckets, 95 | src_max_len=hparams.src_max_len, 96 | tgt_max_len=hparams.tgt_max_len, 97 | skip_count=skip_count_placeholder, 98 | num_shards=num_workers, 99 | shard_index=jobid) 100 | 101 | model_device_fn = None 102 | if extra_args: model_device_fn = extra_args.model_device_fn 103 | with tf.device(model_device_fn): 104 | model = model_creator( 105 | hparams, 106 | iterator=iterator, 107 | mode=tf.contrib.learn.ModeKeys.TRAIN, 108 | aa_weight_table=aa_weight_table, 109 | target_vocab_table=tgt_vocab_table, 110 | scope=scope, 111 | extra_args=extra_args) 112 | 113 | return TrainModel( 114 | graph=graph, 115 | model=model, 116 | iterator=iterator, 117 | skip_count_placeholder=skip_count_placeholder) 118 | 119 | 120 | class EvalModel( 121 | collections.namedtuple("EvalModel", 122 | ("graph", "model", "src_file_placeholder", 123 | "iterator"))): 124 | pass 125 | 126 | 127 | def create_eval_model(model_creator, hparams, scope=None, extra_args=None): 128 | """Create train graph, model, src/tgt file holders, and iterator.""" 129 | tgt_vocab_file = hparams.tgt_vocab_file 130 | graph = tf.Graph() 131 | 132 | with graph.as_default(), tf.container(scope or "eval"): 133 | tgt_vocab_table = vocab_utils.create_vocab_tables(tgt_vocab_file) 134 | aa_weight_table = input_config.create_aa_tables() 135 | src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) ### 136 | iterator = iterator_utils.get_iterator( 137 | [src_file_placeholder], 138 | tgt_vocab_table, 139 | batch_size=hparams.batch_size, 140 | sos=hparams.sos, 141 | eos=hparams.eos, 142 | aa_weight_table=aa_weight_table, 143 | random_seed=hparams.random_seed, 144 | num_buckets=hparams.num_buckets, 145 | src_max_len=hparams.src_max_len, 146 | tgt_max_len=hparams.tgt_max_len_infer) 147 | model = model_creator( 148 | hparams, 149 | iterator=iterator, 150 | mode=tf.contrib.learn.ModeKeys.EVAL, 151 | aa_weight_table=aa_weight_table, 152 | target_vocab_table=tgt_vocab_table, 153 | scope=scope, 154 | extra_args=extra_args) 155 | return EvalModel( 156 | graph=graph, 157 | model=model, 158 | src_file_placeholder=src_file_placeholder, 159 | iterator=iterator) 160 | 161 | 162 | class InferModel( 163 | collections.namedtuple("InferModel", 164 | ("graph", "model", "src_placeholder", 165 | "batch_size_placeholder", "iterator"))): 166 | pass 167 | 168 | 169 | def create_infer_model(model_creator, hparams, scope=None, extra_args=None): 170 | """Create inference model.""" 171 | graph = tf.Graph() 172 | tgt_vocab_file = hparams.tgt_vocab_file 173 | 174 | with graph.as_default(), tf.container(scope or "infer"): 175 | tgt_vocab_table = vocab_utils.create_vocab_tables(tgt_vocab_file) 176 | aa_weight_table = input_config.create_aa_tables() 177 | reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( 178 | tgt_vocab_file, default_value=vocab_utils.UNK) 179 | 180 | src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) 181 | batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) 182 | 183 | src_dataset = tf.data.Dataset.from_tensor_slices( 184 | src_placeholder) 185 | iterator = iterator_utils.get_infer_iterator( 186 | src_dataset, 187 | tgt_vocab_table, 188 | batch_size=batch_size_placeholder, 189 | aa_weight_table=aa_weight_table, 190 | ) 191 | model = model_creator( 192 | hparams, 193 | iterator=iterator, 194 | mode=tf.contrib.learn.ModeKeys.INFER, 195 | aa_weight_table=aa_weight_table, 196 | target_vocab_table=tgt_vocab_table, 197 | reverse_target_vocab_table=reverse_tgt_vocab_table, 198 | scope=scope, 199 | extra_args=extra_args) 200 | return InferModel( 201 | graph=graph, 202 | model=model, 203 | src_placeholder=src_placeholder, 204 | batch_size_placeholder=batch_size_placeholder, 205 | iterator=iterator) 206 | 207 | 208 | def _get_embed_device(vocab_size): 209 | """Decide on which device to place an embed matrix given its vocab size.""" 210 | if vocab_size > VOCAB_SIZE_THRESHOLD_CPU: 211 | return "/cpu:0" 212 | else: 213 | return "/gpu:0" 214 | 215 | 216 | def _create_pretrained_emb_from_txt( 217 | vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32, 218 | scope=None): 219 | """Load pretrain embeding from embed_file, and return an embedding matrix. 220 | 221 | Args: 222 | embed_file: Path to a Glove formated embedding txt file. 223 | num_trainable_tokens: Make the first n tokens in the vocab file as trainable 224 | variables. Default is 3, which is "", "" and "". 225 | """ 226 | vocab, _ = vocab_utils.load_vocab(vocab_file) 227 | trainable_tokens = vocab[:num_trainable_tokens] 228 | 229 | utils.print_out("# Using pretrained embedding: %s." % embed_file) 230 | utils.print_out(" with trainable tokens: ") 231 | 232 | emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file) 233 | for token in trainable_tokens: 234 | utils.print_out(" %s" % token) 235 | if token not in emb_dict: 236 | emb_dict[token] = [0.0] * emb_size 237 | 238 | emb_mat = np.array( 239 | [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype()) 240 | emb_mat = tf.constant(emb_mat) 241 | emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1]) 242 | with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope: 243 | with tf.device(_get_embed_device(num_trainable_tokens)): 244 | emb_mat_var = tf.get_variable( 245 | "emb_mat_var", [num_trainable_tokens, emb_size]) 246 | return tf.concat([emb_mat_var, emb_mat_const], 0) 247 | 248 | 249 | def _create_or_load_embed(embed_name, vocab_file, embed_file, 250 | vocab_size, embed_size, dtype): 251 | """Create a new or load an existing embedding matrix.""" 252 | if vocab_file and embed_file: 253 | embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file) 254 | else: 255 | with tf.device(_get_embed_device(vocab_size)): 256 | embedding = tf.get_variable( 257 | embed_name, [vocab_size, embed_size], dtype) 258 | return embedding 259 | 260 | 261 | def create_emb_for_decoder(#share_vocab, 262 | #src_vocab_size, 263 | tgt_vocab_size, 264 | #src_embed_size, 265 | tgt_embed_size, 266 | dtype=tf.float32, 267 | num_partitions=0, 268 | #src_vocab_file=None, 269 | tgt_vocab_file=None, 270 | #src_embed_file=None, 271 | tgt_embed_file=None, 272 | scope=None): 273 | """Create embedding matrix for both encoder and decoder. 274 | 275 | Args: 276 | share_vocab: A boolean. Whether to share embedding matrix for both 277 | encoder and decoder. 278 | src_vocab_size: An integer. The source vocab size. 279 | tgt_vocab_size: An integer. The target vocab size. 280 | src_embed_size: An integer. The embedding dimension for the encoder's 281 | embedding. 282 | tgt_embed_size: An integer. The embedding dimension for the decoder's 283 | embedding. 284 | dtype: dtype of the embedding matrix. Default to float32. 285 | num_partitions: number of partitions used for the embedding vars. 286 | scope: VariableScope for the created subgraph. Default to "embedding". 287 | 288 | Returns: 289 | embedding_encoder: Encoder's embedding matrix. 290 | embedding_decoder: Decoder's embedding matrix. 291 | 292 | Raises: 293 | ValueError: if use share_vocab but source and target have different vocab 294 | size. 295 | """ 296 | 297 | if num_partitions <= 1: 298 | partitioner = None 299 | else: 300 | # Note: num_partitions > 1 is required for distributed training due to 301 | # embedding_lookup tries to colocate single partition-ed embedding variable 302 | # with lookup ops. This may cause embedding variables being placed on worker 303 | # jobs. 304 | partitioner = tf.fixed_size_partitioner(num_partitions) 305 | 306 | if tgt_embed_file and partitioner: 307 | raise ValueError( 308 | "Can't set num_partitions > 1 when using pretrained embedding") 309 | 310 | with tf.variable_scope("decoder", partitioner=partitioner): 311 | embedding_decoder = _create_or_load_embed( 312 | "embedding_decoder", tgt_vocab_file, tgt_embed_file, 313 | tgt_vocab_size, tgt_embed_size, dtype) 314 | 315 | return embedding_decoder 316 | 317 | 318 | def _single_cell(unit_type, num_units, forget_bias, dropout, mode, 319 | residual_connection=False, device_str=None, residual_fn=None): 320 | """Create an instance of a single RNN cell.""" 321 | # dropout (= 1 - keep_prob) is set to 0 during eval and infer 322 | dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0 323 | 324 | # Cell Type 325 | if unit_type == "lstm": 326 | utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False) 327 | single_cell = tf.contrib.rnn.BasicLSTMCell( 328 | num_units, 329 | forget_bias=forget_bias) 330 | elif unit_type == "gru": 331 | utils.print_out(" GRU", new_line=False) 332 | single_cell = tf.contrib.rnn.GRUCell(num_units) 333 | elif unit_type == "layer_norm_lstm": 334 | utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias, 335 | new_line=False) 336 | single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell( 337 | num_units, 338 | forget_bias=forget_bias, 339 | layer_norm=True) 340 | elif unit_type == "nas": 341 | utils.print_out(" NASCell", new_line=False) 342 | single_cell = tf.contrib.rnn.NASCell(num_units) 343 | else: 344 | raise ValueError("Unknown unit type %s!" % unit_type) 345 | 346 | # Dropout (= 1 - keep_prob) 347 | if dropout > 0.0: 348 | single_cell = tf.contrib.rnn.DropoutWrapper( 349 | cell=single_cell, input_keep_prob=(1.0 - dropout)) 350 | utils.print_out(" %s, dropout=%g " %(type(single_cell).__name__, dropout), 351 | new_line=False) 352 | 353 | # Residual 354 | if residual_connection: 355 | single_cell = tf.contrib.rnn.ResidualWrapper( 356 | single_cell, residual_fn=residual_fn) 357 | utils.print_out(" %s" % type(single_cell).__name__, new_line=False) 358 | 359 | # Device Wrapper 360 | if device_str: 361 | single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str) 362 | utils.print_out(" %s, device=%s" % 363 | (type(single_cell).__name__, device_str), new_line=False) 364 | 365 | return single_cell 366 | 367 | 368 | def _cell_list(unit_type, num_units, num_layers, num_residual_layers, 369 | forget_bias, dropout, mode, num_gpus, base_gpu=0, 370 | single_cell_fn=None, residual_fn=None): 371 | """Create a list of RNN cells.""" 372 | if not single_cell_fn: 373 | single_cell_fn = _single_cell 374 | 375 | # Multi-GPU 376 | cell_list = [] 377 | for i in range(num_layers): 378 | utils.print_out(" cell %d" % i, new_line=False) 379 | single_cell = single_cell_fn( 380 | unit_type=unit_type, 381 | num_units=num_units, 382 | forget_bias=forget_bias, 383 | dropout=dropout, 384 | mode=mode, 385 | residual_connection=(i >= num_layers - num_residual_layers), 386 | device_str=get_device_str(i + base_gpu, num_gpus), 387 | residual_fn=residual_fn 388 | ) 389 | utils.print_out("") 390 | cell_list.append(single_cell) 391 | 392 | return cell_list 393 | 394 | 395 | def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers, 396 | forget_bias, dropout, mode, num_gpus, base_gpu=0, 397 | single_cell_fn=None): 398 | """Create multi-layer RNN cell. 399 | 400 | Args: 401 | unit_type: string representing the unit type, i.e. "lstm". 402 | num_units: the depth of each unit. 403 | num_layers: number of cells. 404 | num_residual_layers: Number of residual layers from top to bottom. For 405 | example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN 406 | cells in the returned list will be wrapped with `ResidualWrapper`. 407 | forget_bias: the initial forget bias of the RNNCell(s). 408 | dropout: floating point value between 0.0 and 1.0: 409 | the probability of dropout. this is ignored if `mode != TRAIN`. 410 | mode: either tf.contrib.learn.TRAIN/EVAL/INFER 411 | num_gpus: The number of gpus to use when performing round-robin 412 | placement of layers. 413 | base_gpu: The gpu device id to use for the first RNN cell in the 414 | returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus` 415 | as its device id. 416 | single_cell_fn: allow for adding customized cell. 417 | When not specified, we default to model_helper._single_cell 418 | Returns: 419 | An `RNNCell` instance. 420 | """ 421 | cell_list = _cell_list(unit_type=unit_type, 422 | num_units=num_units, 423 | num_layers=num_layers, 424 | num_residual_layers=num_residual_layers, 425 | forget_bias=forget_bias, 426 | dropout=dropout, 427 | mode=mode, 428 | num_gpus=num_gpus, 429 | base_gpu=base_gpu, 430 | single_cell_fn=single_cell_fn) 431 | 432 | if len(cell_list) == 1: # Single layer. 433 | return cell_list[0] 434 | else: # Multi layers 435 | return tf.contrib.rnn.MultiRNNCell(cell_list) 436 | 437 | 438 | def gradient_clip(gradients, max_gradient_norm): 439 | """Clipping gradients of a model.""" 440 | clipped_gradients, gradient_norm = tf.clip_by_global_norm( 441 | gradients, max_gradient_norm) 442 | gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)] 443 | gradient_norm_summary.append( 444 | tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients))) 445 | 446 | return clipped_gradients, gradient_norm_summary, gradient_norm 447 | 448 | 449 | def load_model(model, ckpt, session, name): 450 | start_time = time.time() 451 | model.saver.restore(session, ckpt) 452 | session.run(tf.tables_initializer()) 453 | utils.print_out( 454 | " loaded %s model parameters from %s, time %.2fs" % 455 | (name, ckpt, time.time() - start_time)) 456 | return model 457 | 458 | 459 | def avg_checkpoints(model_dir, num_last_checkpoints, global_step, 460 | global_step_name): 461 | """Average the last N checkpoints in the model_dir.""" 462 | checkpoint_state = tf.train.get_checkpoint_state(model_dir) 463 | if not checkpoint_state: 464 | utils.print_out("# No checkpoint file found in directory: %s" % model_dir) 465 | return None 466 | 467 | # Checkpoints are ordered from oldest to newest. 468 | checkpoints = ( 469 | checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:]) 470 | 471 | if len(checkpoints) < num_last_checkpoints: 472 | utils.print_out( 473 | "# Skipping averaging checkpoints because not enough checkpoints is " 474 | "avaliable." 475 | ) 476 | return None 477 | 478 | avg_model_dir = os.path.join(model_dir, "avg_checkpoints") 479 | if not tf.gfile.Exists(avg_model_dir): 480 | utils.print_out( 481 | "# Creating new directory %s for saving averaged checkpoints." % 482 | avg_model_dir) 483 | tf.gfile.MakeDirs(avg_model_dir) 484 | 485 | utils.print_out("# Reading and averaging variables in checkpoints:") 486 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 487 | var_values, var_dtypes = {}, {} 488 | for (name, shape) in var_list: 489 | if name != global_step_name: 490 | var_values[name] = np.zeros(shape) 491 | 492 | for checkpoint in checkpoints: 493 | utils.print_out(" %s" % checkpoint) 494 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 495 | for name in var_values: 496 | tensor = reader.get_tensor(name) 497 | var_dtypes[name] = tensor.dtype 498 | var_values[name] += tensor 499 | 500 | for name in var_values: 501 | var_values[name] /= len(checkpoints) 502 | 503 | # Build a graph with same variables in the checkpoints, and save the averaged 504 | # variables into the avg_model_dir. 505 | with tf.Graph().as_default(): 506 | tf_vars = [ 507 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name]) 508 | for v in var_values 509 | ] 510 | 511 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 512 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 513 | global_step_var = tf.Variable( 514 | global_step, name=global_step_name, trainable=False) 515 | saver = tf.train.Saver(tf.all_variables()) 516 | 517 | with tf.Session() as sess: 518 | sess.run(tf.initialize_all_variables()) 519 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 520 | six.iteritems(var_values)): 521 | sess.run(assign_op, {p: value}) 522 | 523 | # Use the built saver to save the averaged checkpoint. Only keep 1 524 | # checkpoint and the best checkpoint will be moved to avg_best_metric_dir. 525 | saver.save( 526 | sess, 527 | os.path.join(avg_model_dir, "translate.ckpt")) 528 | 529 | return avg_model_dir 530 | 531 | 532 | def create_or_load_model(model, model_dir, session, name): 533 | """Create translation model and initialize or load parameters in session.""" 534 | latest_ckpt = tf.train.latest_checkpoint(model_dir) 535 | if latest_ckpt: 536 | model = load_model(model, latest_ckpt, session, name) 537 | else: 538 | start_time = time.time() 539 | session.run(tf.global_variables_initializer()) 540 | session.run(tf.tables_initializer()) 541 | compute_dp_array(model.suffix_dp_table).run(session=session) 542 | utils.print_out(" created %s model with fresh parameters, time %.2fs" % 543 | (name, time.time() - start_time)) 544 | 545 | global_step = model.global_step.eval(session=session) 546 | return model, global_step 547 | 548 | 549 | def _refine_dp_array(original_dp_array, max_dp_array_size, inv_dp_resolution, error_window): 550 | utils.print_out(" Refining DP array to allow %f Da error window" % (error_window)) 551 | 552 | refined_dp_array = np.zeros(max_dp_array_size, dtype=bool) 553 | # error_window = 0.2 554 | error_window_size = int(round(error_window * inv_dp_resolution)) 555 | 556 | for i in range(max_dp_array_size): 557 | start = max(0, i - error_window_size) 558 | end = min(max_dp_array_size, i + error_window_size) 559 | if np.sum(original_dp_array[start:end]) > 0: 560 | refined_dp_array[i] = True 561 | 562 | return refined_dp_array 563 | 564 | def compute_dp_array(dp_array_tensor): 565 | start_time = time.time() 566 | utils.print_out(" Computing DP array of size %f with resolution %f" % (input_config.max_dp_sz, input_config.dp_resolution)) 567 | max_dp_array_size = input_config.max_dp_array_size 568 | inv_dp_resolution = input_config.inv_dp_resolution 569 | 570 | dp = np.zeros(max_dp_array_size, dtype=bool) 571 | dp[0] = True 572 | 573 | for mass in input_config.mass_ID_np: 574 | mass = int(round(mass * inv_dp_resolution)) 575 | # print(mass) 576 | for i in range(max_dp_array_size): 577 | if dp[i] and i + mass < max_dp_array_size: 578 | dp[i + mass] = True 579 | 580 | # Refine with 0.2 Da possible error 581 | error_window = 0.2 582 | dp_table_value = _refine_dp_array(dp, max_dp_array_size, inv_dp_resolution, error_window) 583 | 584 | utils.print_out(" created suffix DP array , time %.2fs" % (time.time() - start_time)) 585 | # dp_array_tensor.assign(dp_table_value).op.run() 586 | return dp_array_tensor.assign(dp_table_value).op 587 | 588 | 589 | def compute_perplexity(model, sess, name): 590 | """Compute perplexity of the output of the model. 591 | 592 | Args: 593 | model: model for compute perplexity. 594 | sess: tensorflow session to use. 595 | name: name of the batch. 596 | 597 | Returns: 598 | The perplexity of the eval outputs. 599 | Total loss of the eval outputs. 600 | """ 601 | total_loss = 0 602 | total_predict_count = 0 603 | start_time = time.time() 604 | 605 | while True: 606 | try: 607 | loss, predict_count, batch_size = model.eval(sess) 608 | total_loss += loss * batch_size 609 | total_predict_count += predict_count 610 | except tf.errors.OutOfRangeError: 611 | break 612 | 613 | perplexity = utils.safe_exp(total_loss / total_predict_count) 614 | utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity), 615 | start_time) 616 | return perplexity -------------------------------------------------------------------------------- /nmt/post_process.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Korrawe Karunratanakul 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import argparse 17 | import os 18 | import numpy as np 19 | 20 | from nmt import input_config 21 | from nmt.utils import file_utils 22 | from sklearn.model_selection import train_test_split 23 | from keras.models import Sequential, Model 24 | from keras.layers import Embedding, Reshape, Activation, Input, Dense, Reshape, Dropout, Flatten 25 | from keras.optimizers import Adam 26 | from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau 27 | 28 | 29 | def cal_mass(line): 30 | if not line: return [] 31 | # print(line) 32 | AAs = line.split(" ") 33 | aa_mass = [] 34 | for aa in AAs: 35 | if aa in input_config.mass_AA and not (aa == '' or aa == '' or aa == ''): 36 | aa_mass.append(input_config.mass_AA[aa]) 37 | else: 38 | return [] 39 | 40 | return aa_mass 41 | 42 | 43 | def compare_mass_without_mask(nmt_mass, ref_mass, label='nmt'): 44 | nmt_len = len(nmt_mass) 45 | ref_len = len(ref_mass) 46 | if abs(sum(nmt_mass) - sum(ref_mass)) > 0.05: 47 | if label == 'nmt': 48 | return [0] * nmt_len 49 | else: 50 | return [0] * ref_len 51 | 52 | pred_label = [] 53 | 54 | i, j = 0, 0 55 | sum_nmt = 0.0 56 | sum_ref = 0.0 57 | while i < nmt_len and j < ref_len: 58 | if abs(sum_nmt - sum_ref) < 0.03 and abs(nmt_mass[i] - ref_mass[j]) < 0.0001: 59 | pred_label.append(1) 60 | sum_nmt += nmt_mass[i] 61 | sum_ref += ref_mass[j] 62 | i += 1 63 | j += 1 64 | else: 65 | if sum_nmt < sum_ref: 66 | sum_nmt += nmt_mass[i] 67 | if label == 'nmt': 68 | pred_label.append(0) 69 | i += 1 70 | else: 71 | sum_ref += ref_mass[j] 72 | if label == 'ref': 73 | pred_label.append(0) 74 | j += 1 75 | 76 | if label == 'nmt': 77 | while i < nmt_len: 78 | pred_label.append(0) 79 | i += 1 80 | 81 | if label == 'ref': 82 | while j < ref_len: 83 | pred_label.append(0) 84 | j += 1 85 | 86 | return pred_label 87 | 88 | 89 | def prepare_label(file_data): 90 | labels = [] 91 | for idx in range(file_data["length"]): 92 | ref_mass = cal_mass(file_data["ref_seqs"][idx]) 93 | nmt_mass = cal_mass(file_data["nmt_seqs"][idx]) 94 | if not ref_mass: 95 | labels.append([]) 96 | continue 97 | result = compare_mass_without_mask(nmt_mass, ref_mass, label='nmt') 98 | labels.append(result) 99 | return labels 100 | 101 | 102 | def prepare_feature(file_data, log_dir, labels_list=None): 103 | features = [] 104 | labels = [] 105 | positions = [] 106 | for idx in range(file_data["length"]): 107 | probs = file_data["probs"][idx] 108 | if (labels_list and len(labels_list[idx]) <= 0 ) or probs[0] == '': 109 | continue 110 | 111 | probs = [np.exp(float(prob)) for prob in probs] 112 | 113 | # seq features 114 | seq_feat = [] 115 | # num of aa in seq 116 | seq_feat.append(len(probs)) 117 | # num of probs higher than 0.8 and 0.9 118 | seq_feat.append(sum([1 if prob > 0.7 else 0 for prob in probs])) 119 | seq_feat.append(sum([1 if prob > 0.8 else 0 for prob in probs])) 120 | seq_feat.append(sum([1 if prob > 0.9 else 0 for prob in probs])) 121 | # geometric mean of seq 122 | seq_feat.append(np.exp(np.sum(np.log(probs))/len(probs))) 123 | 124 | for posi in range(len(probs)): 125 | feat = seq_feat.copy() 126 | feat.append(posi/len(probs)) 127 | # aa features 128 | aa_feat = [probs[posi+i] if (posi+i >= 0 and posi+i < len(probs)) else 1.0 for i in range(-2,2)] 129 | feat.extend(aa_feat) 130 | 131 | features.append(feat) 132 | if labels_list: 133 | labels.append(labels_list[idx][posi]) 134 | else: 135 | positions.append((idx, posi)) 136 | 137 | features = np.array(features) 138 | labels = np.array(labels) 139 | 140 | if labels_list: 141 | X_train, X_val, y_train, y_val = train_test_split(features, labels, test_size=0.1, random_state=42) 142 | feature_mean = np.mean(X_train, axis=0) 143 | feature_std = np.std(X_train, axis=0) 144 | np.savetxt(os.path.join(log_dir, 'feature_mean.txt'), feature_mean, fmt='%f') 145 | np.savetxt(os.path.join(log_dir,'feature_std.txt'), feature_std, fmt='%f') 146 | 147 | X_train = (X_train - feature_mean) / feature_std 148 | X_val = (X_val - feature_mean) / feature_std 149 | 150 | data_dict = {"train_x": X_train, "train_y": y_train, 151 | "val_x": X_val, "val_y": y_val} 152 | 153 | else: 154 | feature_mean = np.loadtxt(os.path.join(log_dir, 'feature_mean.txt'), dtype=float) 155 | feature_std = np.loadtxt(os.path.join(log_dir, 'feature_std.txt'), dtype=float) 156 | features = (features - feature_mean) / feature_std 157 | data_dict = {"features": features, "positions": positions} 158 | 159 | return data_dict 160 | 161 | 162 | def prepare_data_infer(file_data, log_dir): 163 | data_dict = prepare_feature(file_data, log_dir) 164 | return data_dict["features"], data_dict["positions"] 165 | 166 | 167 | def prepare_data(file_data, log_dir): 168 | labels_list = prepare_label(file_data) 169 | data_dict = prepare_feature(file_data, log_dir, labels_list) 170 | 171 | return data_dict["train_x"], data_dict["val_x"], data_dict["train_y"], data_dict["val_y"] 172 | 173 | 174 | def get_model(input_shape): 175 | input1 = Input(shape=(input_shape,)) 176 | x = Dense(64, activation='relu')(input1) 177 | x = Dense(64, activation='relu')(x) 178 | out = Dense(1, activation='sigmoid')(x) 179 | 180 | model = Model(inputs=input1, outputs=out) 181 | model.compile(optimizer=Adam(), 182 | loss='binary_crossentropy', 183 | metrics=['acc']) 184 | return model 185 | 186 | 187 | def print_predicted_prob(test_probs, test_position_list, output_filename): 188 | with open(output_filename, 'w') as output_file: 189 | current_row = 0 190 | test_probs = np.squeeze(test_probs) 191 | out_str = "" 192 | for i, prob in enumerate(test_probs): 193 | current_out_row = test_position_list[i][0] 194 | while(current_row < current_out_row): 195 | current_row += 1 196 | output_file.write(out_str.strip() + "\n") 197 | out_str = "" 198 | 199 | out_str += str(np.log(prob)) + " " 200 | output_file.write(out_str.strip() + "\n") 201 | 202 | 203 | def train(output_filename, prob_filename, tgt_filename, spectrum_filename, log_dir="log_post_process"): 204 | if not log_dir: 205 | log_dir="post_process" 206 | data_content = file_utils.read_output_file(output_filename, prob_filename, 207 | tgt_filename, spectrum_filename) 208 | X_train, X_val, y_train, y_val = prepare_data(data_content, log_dir=log_dir) 209 | input_shape = X_train.shape[1] 210 | model = get_model(input_shape) 211 | model.summary() 212 | 213 | callbacks_list = [ 214 | ModelCheckpoint( 215 | log_dir + "/post_processing_model_weight.h5", 216 | monitor = "val_loss", 217 | mode = 'min', 218 | verbose = 1, 219 | save_best_only = True, 220 | save_weights_only = True, 221 | ), 222 | ] 223 | model.fit(X_train, y_train, epochs=30, batch_size=256, verbose=1, validation_data=(X_val, y_val), callbacks=callbacks_list) 224 | 225 | 226 | 227 | def rescore(output_filename, prob_filename, spectrum_filename, 228 | log_dir="post_process", output_path=None): 229 | if not log_dir: 230 | log_dir="post_process" 231 | if not output_path: 232 | output_path = os.path.join(log_dir, "rescore_prob") 233 | test_data_content = file_utils.read_output_file(output_filename, prob_filename, 234 | None, spectrum_filename) 235 | test_x, test_positions = prepare_data_infer(test_data_content, log_dir=log_dir) 236 | 237 | input_shape = test_x.shape[1] 238 | model = get_model(input_shape) 239 | model.load_weights(log_dir + "/post_processing_model_weight.h5") 240 | test_probs = model.predict(test_x,verbose=1) 241 | 242 | print_predicted_prob(test_probs, test_positions, output_path) 243 | 244 | 245 | if __name__ == "__main__": 246 | parser = argparse.ArgumentParser() 247 | parser.register("type", "bool", lambda v: v.lower() == "true") 248 | 249 | parser.add_argument("--train", type="bool", nargs="?", const=True, 250 | default=False, 251 | help="Train new model.") 252 | parser.add_argument("--rescore", type="bool", nargs="?", const=True, 253 | default=False, 254 | help="Rescore with previously trained model.") 255 | parser.add_argument("--output_file", type=str, default=None, 256 | help="Predition file from main model.") 257 | parser.add_argument("--prob_file", type=str, default=None, 258 | help="Prob file from main model.") 259 | parser.add_argument("--tgt_file", type=str, default=None, 260 | help="Target file for training.") 261 | parser.add_argument("--spectrum_file", type=str, default=None, 262 | help="Source spectrum.") 263 | parser.add_argument("--logdir", type=str, default=None, 264 | help="Directory to save or load model.") 265 | parser.add_argument("--output", type=str, default=None, 266 | help="Output file path.") 267 | 268 | args = parser.parse_args() 269 | print(args) 270 | if args.train: 271 | print(" Training...") 272 | train(args.output_file, args.prob_file, args.tgt_file, args.spectrum_file, args.logdir) 273 | 274 | elif args.rescore: 275 | print(" Rescoring..") 276 | rescore(args.output_file, args.prob_file, args.spectrum_file, args.logdir, args.output) 277 | print(" Done") 278 | -------------------------------------------------------------------------------- /nmt/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmb-chula/SMSNet/facfaf441d0ef286d062f69530f0a298aba78edc/nmt/scripts/__init__.py -------------------------------------------------------------------------------- /nmt/scripts/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams upto a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | methods. 35 | 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. / ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 113 | -------------------------------------------------------------------------------- /nmt/scripts/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | 3 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py. 4 | This is a modified and slightly extended verison of 5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import itertools 14 | import numpy as np 15 | 16 | #pylint: disable=C0103 17 | 18 | 19 | def _get_ngrams(n, text): 20 | """Calcualtes n-grams. 21 | 22 | Args: 23 | n: which n-grams to calculate 24 | text: An array of tokens 25 | 26 | Returns: 27 | A set of n-grams 28 | """ 29 | ngram_set = set() 30 | text_length = len(text) 31 | max_index_ngram_start = text_length - n 32 | for i in range(max_index_ngram_start + 1): 33 | ngram_set.add(tuple(text[i:i + n])) 34 | return ngram_set 35 | 36 | 37 | def _split_into_words(sentences): 38 | """Splits multiple sentences into words and flattens the result""" 39 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 40 | 41 | 42 | def _get_word_ngrams(n, sentences): 43 | """Calculates word n-grams for multiple sentences. 44 | """ 45 | assert len(sentences) > 0 46 | assert n > 0 47 | 48 | words = _split_into_words(sentences) 49 | return _get_ngrams(n, words) 50 | 51 | 52 | def _len_lcs(x, y): 53 | """ 54 | Returns the length of the Longest Common Subsequence between sequences x 55 | and y. 56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 57 | 58 | Args: 59 | x: sequence of words 60 | y: sequence of words 61 | 62 | Returns 63 | integer: Length of LCS between x and y 64 | """ 65 | table = _lcs(x, y) 66 | n, m = len(x), len(y) 67 | return table[n, m] 68 | 69 | 70 | def _lcs(x, y): 71 | """ 72 | Computes the length of the longest common subsequence (lcs) between two 73 | strings. The implementation below uses a DP programming algorithm and runs 74 | in O(nm) time where n = len(x) and m = len(y). 75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 76 | 77 | Args: 78 | x: collection of words 79 | y: collection of words 80 | 81 | Returns: 82 | Table of dictionary of coord and len lcs 83 | """ 84 | n, m = len(x), len(y) 85 | table = dict() 86 | for i in range(n + 1): 87 | for j in range(m + 1): 88 | if i == 0 or j == 0: 89 | table[i, j] = 0 90 | elif x[i - 1] == y[j - 1]: 91 | table[i, j] = table[i - 1, j - 1] + 1 92 | else: 93 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 94 | return table 95 | 96 | 97 | def _recon_lcs(x, y): 98 | """ 99 | Returns the Longest Subsequence between x and y. 100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 101 | 102 | Args: 103 | x: sequence of words 104 | y: sequence of words 105 | 106 | Returns: 107 | sequence: LCS of x and y 108 | """ 109 | i, j = len(x), len(y) 110 | table = _lcs(x, y) 111 | 112 | def _recon(i, j): 113 | """private recon calculation""" 114 | if i == 0 or j == 0: 115 | return [] 116 | elif x[i - 1] == y[j - 1]: 117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 118 | elif table[i - 1, j] > table[i, j - 1]: 119 | return _recon(i - 1, j) 120 | else: 121 | return _recon(i, j - 1) 122 | 123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 124 | return recon_tuple 125 | 126 | 127 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 128 | """ 129 | Computes ROUGE-N of two text collections of sentences. 130 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 131 | papers/rouge-working-note-v1.3.1.pdf 132 | 133 | Args: 134 | evaluated_sentences: The sentences that have been picked by the summarizer 135 | reference_sentences: The sentences from the referene set 136 | n: Size of ngram. Defaults to 2. 137 | 138 | Returns: 139 | A tuple (f1, precision, recall) for ROUGE-N 140 | 141 | Raises: 142 | ValueError: raises exception if a param has len <= 0 143 | """ 144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 145 | raise ValueError("Collections must contain at least 1 sentence.") 146 | 147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 148 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 149 | reference_count = len(reference_ngrams) 150 | evaluated_count = len(evaluated_ngrams) 151 | 152 | # Gets the overlapping ngrams between evaluated and reference 153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 154 | overlapping_count = len(overlapping_ngrams) 155 | 156 | # Handle edge case. This isn't mathematically correct, but it's good enough 157 | if evaluated_count == 0: 158 | precision = 0.0 159 | else: 160 | precision = overlapping_count / evaluated_count 161 | 162 | if reference_count == 0: 163 | recall = 0.0 164 | else: 165 | recall = overlapping_count / reference_count 166 | 167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 168 | 169 | # return overlapping_count / reference_count 170 | return f1_score, precision, recall 171 | 172 | 173 | def _f_p_r_lcs(llcs, m, n): 174 | """ 175 | Computes the LCS-based F-measure score 176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 177 | rouge-working-note-v1.3.1.pdf 178 | 179 | Args: 180 | llcs: Length of LCS 181 | m: number of words in reference summary 182 | n: number of words in candidate summary 183 | 184 | Returns: 185 | Float. LCS-based F-measure score 186 | """ 187 | r_lcs = llcs / m 188 | p_lcs = llcs / n 189 | beta = p_lcs / (r_lcs + 1e-12) 190 | num = (1 + (beta**2)) * r_lcs * p_lcs 191 | denom = r_lcs + ((beta**2) * p_lcs) 192 | f_lcs = num / (denom + 1e-12) 193 | return f_lcs, p_lcs, r_lcs 194 | 195 | 196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 197 | """ 198 | Computes ROUGE-L (sentence level) of two text collections of sentences. 199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 200 | rouge-working-note-v1.3.1.pdf 201 | 202 | Calculated according to: 203 | R_lcs = LCS(X,Y)/m 204 | P_lcs = LCS(X,Y)/n 205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 206 | 207 | where: 208 | X = reference summary 209 | Y = Candidate summary 210 | m = length of reference summary 211 | n = length of candidate summary 212 | 213 | Args: 214 | evaluated_sentences: The sentences that have been picked by the summarizer 215 | reference_sentences: The sentences from the referene set 216 | 217 | Returns: 218 | A float: F_lcs 219 | 220 | Raises: 221 | ValueError: raises exception if a param has len <= 0 222 | """ 223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 224 | raise ValueError("Collections must contain at least 1 sentence.") 225 | reference_words = _split_into_words(reference_sentences) 226 | evaluated_words = _split_into_words(evaluated_sentences) 227 | m = len(reference_words) 228 | n = len(evaluated_words) 229 | lcs = _len_lcs(evaluated_words, reference_words) 230 | return _f_p_r_lcs(lcs, m, n) 231 | 232 | 233 | def _union_lcs(evaluated_sentences, reference_sentence): 234 | """ 235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 236 | subsequence between reference sentence ri and candidate summary C. For example 237 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 241 | LCS_u(r_i, C) = 4/5. 242 | 243 | Args: 244 | evaluated_sentences: The sentences that have been picked by the summarizer 245 | reference_sentence: One of the sentences in the reference summaries 246 | 247 | Returns: 248 | float: LCS_u(r_i, C) 249 | 250 | ValueError: 251 | Raises exception if a param has len <= 0 252 | """ 253 | if len(evaluated_sentences) <= 0: 254 | raise ValueError("Collections must contain at least 1 sentence.") 255 | 256 | lcs_union = set() 257 | reference_words = _split_into_words([reference_sentence]) 258 | combined_lcs_length = 0 259 | for eval_s in evaluated_sentences: 260 | evaluated_words = _split_into_words([eval_s]) 261 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 262 | combined_lcs_length += len(lcs) 263 | lcs_union = lcs_union.union(lcs) 264 | 265 | union_lcs_count = len(lcs_union) 266 | union_lcs_value = union_lcs_count / combined_lcs_length 267 | return union_lcs_value 268 | 269 | 270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 271 | """ 272 | Computes ROUGE-L (summary level) of two text collections of sentences. 273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 274 | rouge-working-note-v1.3.1.pdf 275 | 276 | Calculated according to: 277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 280 | 281 | where: 282 | SUM(i,u) = SUM from i through u 283 | u = number of sentences in reference summary 284 | C = Candidate summary made up of v sentences 285 | m = number of words in reference summary 286 | n = number of words in candidate summary 287 | 288 | Args: 289 | evaluated_sentences: The sentences that have been picked by the summarizer 290 | reference_sentence: One of the sentences in the reference summaries 291 | 292 | Returns: 293 | A float: F_lcs 294 | 295 | Raises: 296 | ValueError: raises exception if a param has len <= 0 297 | """ 298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 299 | raise ValueError("Collections must contain at least 1 sentence.") 300 | 301 | # total number of words in reference sentences 302 | m = len(_split_into_words(reference_sentences)) 303 | 304 | # total number of words in evaluated sentences 305 | n = len(_split_into_words(evaluated_sentences)) 306 | 307 | union_lcs_sum_across_all_references = 0 308 | for ref_s in reference_sentences: 309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 310 | ref_s) 311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 312 | 313 | 314 | def rouge(hypotheses, references): 315 | """Calculates average rouge scores for a list of hypotheses and 316 | references""" 317 | 318 | # Filter out hyps that are of 0 length 319 | # hyps_and_refs = zip(hypotheses, references) 320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 321 | # hypotheses, references = zip(*hyps_and_refs) 322 | 323 | # Calculate ROUGE-1 F1, precision, recall scores 324 | rouge_1 = [ 325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 326 | ] 327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 328 | 329 | # Calculate ROUGE-2 F1, precision, recall scores 330 | rouge_2 = [ 331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 332 | ] 333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 334 | 335 | # Calculate ROUGE-L F1, precision, recall scores 336 | rouge_l = [ 337 | rouge_l_sentence_level([hyp], [ref]) 338 | for hyp, ref in zip(hypotheses, references) 339 | ] 340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 341 | 342 | return { 343 | "rouge_1/f_score": rouge_1_f, 344 | "rouge_1/r_score": rouge_1_r, 345 | "rouge_1/p_score": rouge_1_p, 346 | "rouge_2/f_score": rouge_2_f, 347 | "rouge_2/r_score": rouge_2_r, 348 | "rouge_2/p_score": rouge_2_p, 349 | "rouge_l/f_score": rouge_l_f, 350 | "rouge_l/r_score": rouge_l_r, 351 | "rouge_l/p_score": rouge_l_p, 352 | } 353 | -------------------------------------------------------------------------------- /nmt/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """For training NMT models.""" 16 | from __future__ import print_function 17 | 18 | import math 19 | import os 20 | import random 21 | import time 22 | 23 | import tensorflow as tf 24 | 25 | from . import attention_model 26 | from . import gnmt_model 27 | from . import inference 28 | from . import model as nmt_model 29 | from . import model_helper 30 | from .utils import misc_utils as utils 31 | from .utils import nmt_utils 32 | 33 | utils.check_tensorflow_version() 34 | 35 | __all__ = [ 36 | "run_sample_decode", "run_internal_eval", "run_external_eval", 37 | "run_avg_external_eval", "run_full_eval", "init_stats", "update_stats", 38 | "print_step_info", "process_stats", "train" 39 | ] 40 | 41 | 42 | def run_sample_decode(infer_model, infer_sess, model_dir, hparams, 43 | summary_writer, src_data, tgt_data): 44 | """Sample decode a random sentence from src_data.""" 45 | with infer_model.graph.as_default(): 46 | loaded_infer_model, global_step = model_helper.create_or_load_model( 47 | infer_model.model, model_dir, infer_sess, "infer") 48 | 49 | _sample_decode(loaded_infer_model, global_step, infer_sess, hparams, 50 | infer_model.iterator, src_data, tgt_data, 51 | infer_model.src_placeholder, 52 | infer_model.batch_size_placeholder, summary_writer) 53 | 54 | 55 | def run_internal_eval( 56 | eval_model, eval_sess, model_dir, hparams, summary_writer, 57 | use_test_set=True): 58 | """Compute internal evaluation (perplexity) for both dev / test.""" 59 | with eval_model.graph.as_default(): 60 | loaded_eval_model, global_step = model_helper.create_or_load_model( 61 | eval_model.model, model_dir, eval_sess, "eval") 62 | 63 | dev_src_file = "%s%s" % (hparams.dev, hparams.src_suffix) 64 | dev_eval_iterator_feed_dict = { 65 | eval_model.src_file_placeholder: dev_src_file, 66 | } 67 | 68 | dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, 69 | eval_model.iterator, dev_eval_iterator_feed_dict, 70 | summary_writer, "dev") 71 | test_ppl = None 72 | if use_test_set and hparams.test: 73 | test_src_file = "%s%s" % (hparams.test, hparams.src_suffix) 74 | test_eval_iterator_feed_dict = { 75 | eval_model.src_file_placeholder: test_src_file, 76 | } 77 | test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, 78 | eval_model.iterator, test_eval_iterator_feed_dict, 79 | summary_writer, "test") 80 | return dev_ppl, test_ppl 81 | 82 | 83 | def run_external_eval(infer_model, infer_sess, model_dir, hparams, 84 | summary_writer, save_best_dev=True, use_test_set=True, 85 | avg_ckpts=False): 86 | """Compute external evaluation (bleu, rouge, etc.) for both dev / test.""" 87 | with infer_model.graph.as_default(): 88 | loaded_infer_model, global_step = model_helper.create_or_load_model( 89 | infer_model.model, model_dir, infer_sess, "infer") 90 | 91 | dev_src_file = "%s%s" % (hparams.dev, hparams.src_suffix) 92 | dev_tgt_file = "%s%s" % (hparams.dev, hparams.tgt) 93 | dev_infer_iterator_feed_dict = { 94 | infer_model.src_placeholder: inference.load_data(dev_src_file), 95 | infer_model.batch_size_placeholder: hparams.infer_batch_size, 96 | } 97 | dev_scores = _external_eval( 98 | loaded_infer_model, 99 | global_step, 100 | infer_sess, 101 | hparams, 102 | infer_model.iterator, 103 | dev_infer_iterator_feed_dict, 104 | dev_tgt_file, 105 | "dev", 106 | summary_writer, 107 | save_on_best=save_best_dev, 108 | avg_ckpts=avg_ckpts) 109 | 110 | test_scores = None 111 | if use_test_set and hparams.test: 112 | test_src_file = "%s%s" % (hparams.test, hparams.src_suffix) 113 | test_tgt_file = "%s%s" % (hparams.test, hparams.tgt) 114 | test_infer_iterator_feed_dict = { 115 | infer_model.src_placeholder: inference.load_data(test_src_file), 116 | infer_model.batch_size_placeholder: hparams.infer_batch_size, 117 | } 118 | test_scores = _external_eval( 119 | loaded_infer_model, 120 | global_step, 121 | infer_sess, 122 | hparams, 123 | infer_model.iterator, 124 | test_infer_iterator_feed_dict, 125 | test_tgt_file, 126 | "test", 127 | summary_writer, 128 | save_on_best=False, 129 | avg_ckpts=avg_ckpts) 130 | return dev_scores, test_scores, global_step 131 | 132 | 133 | def run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, 134 | summary_writer, global_step): 135 | """Creates an averaged checkpoint and run external eval with it.""" 136 | avg_dev_scores, avg_test_scores = None, None 137 | if hparams.avg_ckpts: 138 | # Convert VariableName:0 to VariableName. 139 | global_step_name = infer_model.model.global_step.name.split(":")[0] 140 | avg_model_dir = model_helper.avg_checkpoints( 141 | model_dir, hparams.num_keep_ckpts, global_step, global_step_name) 142 | 143 | if avg_model_dir: 144 | avg_dev_scores, avg_test_scores, _ = run_external_eval( 145 | infer_model, 146 | infer_sess, 147 | avg_model_dir, 148 | hparams, 149 | summary_writer, 150 | avg_ckpts=True) 151 | 152 | return avg_dev_scores, avg_test_scores 153 | 154 | 155 | def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess, 156 | hparams, summary_writer, sample_src_data, sample_tgt_data, 157 | avg_ckpts=False): 158 | """Wrapper for running sample_decode, internal_eval and external_eval.""" 159 | run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, 160 | sample_src_data, sample_tgt_data) 161 | dev_ppl, test_ppl = run_internal_eval( 162 | eval_model, eval_sess, model_dir, hparams, summary_writer) 163 | dev_scores, test_scores, global_step = run_external_eval( 164 | infer_model, infer_sess, model_dir, hparams, summary_writer) 165 | 166 | metrics = { 167 | "dev_ppl": dev_ppl, 168 | "test_ppl": test_ppl, 169 | "dev_scores": dev_scores, 170 | "test_scores": test_scores, 171 | } 172 | 173 | avg_dev_scores, avg_test_scores = None, None 174 | if avg_ckpts: 175 | avg_dev_scores, avg_test_scores = run_avg_external_eval( 176 | infer_model, infer_sess, model_dir, hparams, summary_writer, 177 | global_step) 178 | metrics["avg_dev_scores"] = avg_dev_scores 179 | metrics["avg_test_scores"] = avg_test_scores 180 | 181 | result_summary = _format_results("dev", dev_ppl, dev_scores, hparams.metrics) 182 | if avg_dev_scores: 183 | result_summary += ", " + _format_results("avg_dev", None, avg_dev_scores, 184 | hparams.metrics) 185 | if hparams.test: 186 | result_summary += ", " + _format_results("test", test_ppl, test_scores, 187 | hparams.metrics) 188 | if avg_test_scores: 189 | result_summary += ", " + _format_results("avg_test", None, 190 | avg_test_scores, hparams.metrics) 191 | 192 | return result_summary, global_step, metrics 193 | 194 | 195 | def init_stats(): 196 | """Initialize statistics that we want to accumulate.""" 197 | return {"step_time": 0.0, "loss": 0.0, "predict_count": 0.0, 198 | "total_count": 0.0, "grad_norm": 0.0} 199 | 200 | 201 | def update_stats(stats, start_time, step_result): 202 | """Update stats: write summary and accumulate statistics.""" 203 | (_, step_loss, step_predict_count, step_summary, global_step, 204 | step_word_count, batch_size, grad_norm, learning_rate) = step_result 205 | 206 | # Update statistics 207 | stats["step_time"] += (time.time() - start_time) 208 | stats["loss"] += (step_loss * batch_size) 209 | stats["predict_count"] += step_predict_count 210 | stats["total_count"] += float(step_word_count) 211 | stats["grad_norm"] += grad_norm 212 | 213 | return global_step, learning_rate, step_summary 214 | 215 | 216 | def print_step_info(prefix, global_step, info, result_summary, log_f): 217 | """Print all info at the current global step.""" 218 | utils.print_out( 219 | "%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s, %s" % 220 | (prefix, global_step, info["learning_rate"], info["avg_step_time"], 221 | info["speed"], info["train_ppl"], info["avg_grad_norm"], result_summary, 222 | time.ctime()), 223 | log_f) 224 | 225 | 226 | def process_stats(stats, info, global_step, steps_per_stats, log_f): 227 | """Update info and check for overflow.""" 228 | # Update info 229 | info["avg_step_time"] = stats["step_time"] / steps_per_stats 230 | info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats 231 | info["train_ppl"] = utils.safe_exp(stats["loss"] / stats["predict_count"]) 232 | info["speed"] = stats["total_count"] / (1000 * stats["step_time"]) 233 | 234 | # Check for overflow 235 | is_overflow = False 236 | train_ppl = info["train_ppl"] 237 | if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20: 238 | utils.print_out(" step %d overflow, stop early" % global_step, 239 | log_f) 240 | is_overflow = True 241 | 242 | return is_overflow 243 | 244 | 245 | def before_train(loaded_train_model, train_model, train_sess, global_step, 246 | hparams, log_f): 247 | """Misc tasks to do before training.""" 248 | stats = init_stats() 249 | info = {"train_ppl": 0.0, "speed": 0.0, "avg_step_time": 0.0, 250 | "avg_grad_norm": 0.0, 251 | "learning_rate": loaded_train_model.learning_rate.eval( 252 | session=train_sess)} 253 | start_train_time = time.time() 254 | utils.print_out("# Start step %d, lr %g, %s" % 255 | (global_step, info["learning_rate"], time.ctime()), log_f) 256 | 257 | # Initialize all of the iterators 258 | skip_count = hparams.batch_size * hparams.epoch_step 259 | utils.print_out("# Init train iterator, skipping %d elements" % skip_count) 260 | train_sess.run( 261 | train_model.iterator.initializer, 262 | feed_dict={train_model.skip_count_placeholder: skip_count}) 263 | 264 | return stats, info, start_train_time 265 | 266 | 267 | def train(hparams, scope=None, target_session=""): 268 | """Train a translation model.""" 269 | log_device_placement = hparams.log_device_placement 270 | out_dir = hparams.out_dir 271 | num_train_steps = hparams.num_train_steps 272 | steps_per_stats = hparams.steps_per_stats 273 | steps_per_external_eval = hparams.steps_per_external_eval 274 | steps_per_eval = 10 * steps_per_stats 275 | avg_ckpts = hparams.avg_ckpts 276 | 277 | if not steps_per_external_eval: 278 | steps_per_external_eval = 5 * steps_per_eval 279 | 280 | if not hparams.attention: 281 | model_creator = nmt_model.Model 282 | else: # Attention 283 | if (hparams.encoder_type == "gnmt" or 284 | hparams.attention_architecture in ["gnmt", "gnmt_v2"]): 285 | model_creator = gnmt_model.GNMTModel 286 | elif hparams.attention_architecture == "standard": 287 | model_creator = attention_model.AttentionModel 288 | else: 289 | raise ValueError("Unknown attention architecture %s" % 290 | hparams.attention_architecture) 291 | 292 | train_model = model_helper.create_train_model(model_creator, hparams, scope) 293 | eval_model = model_helper.create_eval_model(model_creator, hparams, scope) 294 | infer_model = model_helper.create_infer_model(model_creator, hparams, scope) 295 | 296 | # Preload data for sample decoding. 297 | dev_src_file = "%s%s" % (hparams.dev, hparams.src_suffix) 298 | dev_tgt_file = "%s%s" % (hparams.dev, hparams.tgt) 299 | sample_src_data = inference.load_data(dev_src_file) 300 | sample_tgt_data = inference.load_data(dev_tgt_file) 301 | 302 | summary_name = "train_log" 303 | model_dir = hparams.out_dir 304 | 305 | # Log and output files 306 | log_file = os.path.join(out_dir, "log_%d" % time.time()) 307 | log_f = tf.gfile.GFile(log_file, mode="a") 308 | utils.print_out("# log_file=%s" % log_file, log_f) 309 | 310 | # TensorFlow model 311 | config_proto = utils.get_config_proto( 312 | log_device_placement=log_device_placement, 313 | num_intra_threads=hparams.num_intra_threads, 314 | num_inter_threads=hparams.num_inter_threads) 315 | train_sess = tf.Session( 316 | target=target_session, config=config_proto, graph=train_model.graph) 317 | eval_sess = tf.Session( 318 | target=target_session, config=config_proto, graph=eval_model.graph) 319 | infer_sess = tf.Session( 320 | target=target_session, config=config_proto, graph=infer_model.graph) 321 | 322 | with train_model.graph.as_default(): 323 | loaded_train_model, global_step = model_helper.create_or_load_model( 324 | train_model.model, model_dir, train_sess, "train") 325 | 326 | # Summary writer 327 | summary_writer = tf.summary.FileWriter( 328 | os.path.join(out_dir, summary_name), train_model.graph) 329 | 330 | # First evaluation 331 | run_full_eval( 332 | model_dir, infer_model, infer_sess, 333 | eval_model, eval_sess, hparams, 334 | summary_writer, sample_src_data, 335 | sample_tgt_data, avg_ckpts) 336 | 337 | # For stoping inference run after the training was already finished 338 | if hparams.train_fin: 339 | raise SystemExit 340 | 341 | last_stats_step = global_step 342 | last_eval_step = global_step 343 | last_external_eval_step = global_step 344 | 345 | # This is the training loop. 346 | stats, info, start_train_time = before_train( 347 | loaded_train_model, train_model, train_sess, global_step, hparams, log_f) 348 | while global_step < num_train_steps: 349 | ### Run a step ### 350 | start_time = time.time() 351 | try: 352 | step_result = loaded_train_model.train(train_sess) 353 | hparams.epoch_step += 1 354 | except tf.errors.OutOfRangeError: 355 | # Finished going through the training dataset. Go to next epoch. 356 | hparams.epoch_step = 0 357 | utils.print_out( 358 | "# Finished an epoch, step %d. Perform external evaluation" % 359 | global_step) 360 | run_sample_decode(infer_model, infer_sess, model_dir, hparams, 361 | summary_writer, sample_src_data, sample_tgt_data) 362 | run_external_eval(infer_model, infer_sess, model_dir, hparams, 363 | summary_writer) 364 | if avg_ckpts: 365 | run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, 366 | summary_writer, global_step) 367 | 368 | train_sess.run( 369 | train_model.iterator.initializer, 370 | feed_dict={train_model.skip_count_placeholder: 0}) 371 | continue 372 | 373 | # Process step_result, accumulate stats, and write summary 374 | global_step, info["learning_rate"], step_summary = update_stats( 375 | stats, start_time, step_result) 376 | summary_writer.add_summary(step_summary, global_step) 377 | 378 | # Once in a while, we print statistics. 379 | if global_step - last_stats_step >= steps_per_stats: 380 | last_stats_step = global_step 381 | is_overflow = process_stats( 382 | stats, info, global_step, steps_per_stats, log_f) 383 | print_step_info(" ", global_step, info, _get_best_results(hparams), 384 | log_f) 385 | if is_overflow: 386 | break 387 | 388 | # Reset statistics 389 | stats = init_stats() 390 | 391 | if global_step - last_eval_step >= steps_per_eval: 392 | last_eval_step = global_step 393 | utils.print_out("# Save eval, global step %d" % global_step) 394 | utils.add_summary(summary_writer, global_step, "train_ppl", 395 | info["train_ppl"]) 396 | 397 | # Save checkpoint 398 | loaded_train_model.saver.save( 399 | train_sess, 400 | os.path.join(out_dir, "translate.ckpt"), 401 | global_step=global_step) 402 | 403 | # Evaluate on dev/test 404 | run_sample_decode(infer_model, infer_sess, 405 | model_dir, hparams, summary_writer, sample_src_data, 406 | sample_tgt_data) 407 | run_internal_eval( 408 | eval_model, eval_sess, model_dir, hparams, summary_writer) 409 | 410 | if global_step - last_external_eval_step >= steps_per_external_eval: 411 | last_external_eval_step = global_step 412 | 413 | # Save checkpoint 414 | loaded_train_model.saver.save( 415 | train_sess, 416 | os.path.join(out_dir, "translate.ckpt"), 417 | global_step=global_step) 418 | run_sample_decode(infer_model, infer_sess, 419 | model_dir, hparams, summary_writer, sample_src_data, 420 | sample_tgt_data) 421 | run_external_eval( 422 | infer_model, infer_sess, model_dir, 423 | hparams, summary_writer) 424 | 425 | if avg_ckpts: 426 | run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, 427 | summary_writer, global_step) 428 | 429 | # Done training 430 | loaded_train_model.saver.save( 431 | train_sess, 432 | os.path.join(out_dir, "translate.ckpt"), 433 | global_step=global_step) 434 | 435 | (result_summary, _, final_eval_metrics) = ( 436 | run_full_eval( 437 | model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, 438 | summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)) 439 | print_step_info("# Final, ", global_step, info, result_summary, log_f) 440 | utils.print_time("# Done training!", start_train_time) 441 | 442 | summary_writer.close() 443 | 444 | utils.print_out("# Start evaluating saved best models.") 445 | for metric in hparams.metrics: 446 | best_model_dir = getattr(hparams, "best_" + metric + "_dir") 447 | summary_writer = tf.summary.FileWriter( 448 | os.path.join(best_model_dir, summary_name), infer_model.graph) 449 | result_summary, best_global_step, _ = run_full_eval( 450 | best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, 451 | summary_writer, sample_src_data, sample_tgt_data) 452 | print_step_info("# Best %s, " % metric, best_global_step, info, 453 | result_summary, log_f) 454 | summary_writer.close() 455 | 456 | if avg_ckpts: 457 | best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir") 458 | summary_writer = tf.summary.FileWriter( 459 | os.path.join(best_model_dir, summary_name), infer_model.graph) 460 | result_summary, best_global_step, _ = run_full_eval( 461 | best_model_dir, infer_model, infer_sess, eval_model, eval_sess, 462 | hparams, summary_writer, sample_src_data, sample_tgt_data) 463 | print_step_info("# Averaged Best %s, " % metric, best_global_step, info, 464 | result_summary, log_f) 465 | summary_writer.close() 466 | 467 | return final_eval_metrics, global_step 468 | 469 | 470 | def _format_results(name, ppl, scores, metrics): 471 | """Format results.""" 472 | result_str = "" 473 | if ppl: 474 | result_str = "%s ppl %.2f" % (name, ppl) 475 | if scores: 476 | for metric in metrics: 477 | if result_str: 478 | result_str += ", %s %s %.1f" % (name, metric, scores[metric]) 479 | else: 480 | result_str = "%s %s %.1f" % (name, metric, scores[metric]) 481 | return result_str 482 | 483 | 484 | def _get_best_results(hparams): 485 | """Summary of the current best results.""" 486 | tokens = [] 487 | for metric in hparams.metrics: 488 | tokens.append("%s %.2f" % (metric, getattr(hparams, "best_" + metric))) 489 | return ", ".join(tokens) 490 | 491 | 492 | def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict, 493 | summary_writer, label): 494 | """Computing perplexity.""" 495 | sess.run(iterator.initializer, feed_dict=iterator_feed_dict) 496 | ppl = model_helper.compute_perplexity(model, sess, label) 497 | 498 | utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl) 499 | return ppl 500 | 501 | 502 | def _sample_decode(model, global_step, sess, hparams, iterator, src_data, 503 | tgt_data, iterator_src_placeholder, 504 | iterator_batch_size_placeholder, summary_writer): 505 | """Pick a sentence and decode.""" 506 | """Pick 2 sentences and decode.""" 507 | for i in range(2): 508 | decode_id = random.randint(0, len(src_data) - 1) 509 | utils.print_out(" # %d" % decode_id) 510 | 511 | iterator_feed_dict = { 512 | iterator_src_placeholder: [src_data[decode_id]], 513 | iterator_batch_size_placeholder: 1, 514 | } 515 | 516 | sess.run(iterator.initializer, feed_dict=iterator_feed_dict) 517 | 518 | nmt_outputs, steps_probs, attention_summary = model.decode(sess, 1) 519 | if hparams.beam_width > 0: 520 | # get the top translation. 521 | nmt_outputs = nmt_outputs[0] 522 | steps_probs = steps_probs[0] 523 | 524 | translation = nmt_utils.get_translation( 525 | nmt_outputs, 526 | sent_id=0, 527 | tgt_eos=hparams.eos, 528 | subword_option=hparams.subword_option) 529 | 530 | trans_step_probs = nmt_utils.get_log_probs( 531 | steps_probs, 532 | sent_id=0, 533 | output_len=int((len(translation)+1)/2) 534 | ) 535 | utils.print_out(" ref: %s" % tgt_data[decode_id]) 536 | utils.print_out(b" nmt: " + translation) 537 | utils.print_out(b" prob: " + trans_step_probs) 538 | 539 | utils.print_out("--------") 540 | # Summary 541 | if attention_summary is not None: 542 | summary_writer.add_summary(attention_summary, global_step) 543 | 544 | 545 | def _external_eval(model, global_step, sess, hparams, iterator, 546 | iterator_feed_dict, tgt_file, label, summary_writer, 547 | save_on_best, avg_ckpts=False): 548 | """External evaluation such as BLEU and ROUGE scores.""" 549 | out_dir = hparams.out_dir 550 | decode = global_step > 0 551 | 552 | if avg_ckpts: 553 | label = "avg_" + label 554 | 555 | if decode: 556 | utils.print_out("# External evaluation, global step %d" % global_step) 557 | 558 | sess.run(iterator.initializer, feed_dict=iterator_feed_dict) 559 | 560 | output = os.path.join(out_dir, "output_%s" % label) 561 | scores = nmt_utils.decode_and_evaluate( 562 | label, 563 | model, 564 | sess, 565 | output, 566 | ref_file=tgt_file, 567 | metrics=hparams.metrics, 568 | subword_option=hparams.subword_option, 569 | beam_width=hparams.beam_width, 570 | tgt_eos=hparams.eos, 571 | decode=decode) 572 | # Save on best metrics 573 | if decode: 574 | for metric in hparams.metrics: 575 | if avg_ckpts: 576 | best_metric_label = "avg_best_" + metric 577 | else: 578 | best_metric_label = "best_" + metric 579 | 580 | utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric), 581 | scores[metric]) 582 | # metric: larger is better 583 | if save_on_best and scores[metric] > getattr(hparams, best_metric_label): 584 | setattr(hparams, best_metric_label, scores[metric]) 585 | model.saver.save( 586 | sess, 587 | os.path.join( 588 | getattr(hparams, best_metric_label + "_dir"), "translate.ckpt"), 589 | global_step=model.global_step) 590 | utils.save_hparams(out_dir, hparams) 591 | return scores 592 | -------------------------------------------------------------------------------- /nmt/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmb-chula/SMSNet/facfaf441d0ef286d062f69530f0a298aba78edc/nmt/utils/__init__.py -------------------------------------------------------------------------------- /nmt/utils/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # Copyright 2019 Korrawe Karunratanakul 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Utility for evaluating various tasks, e.g., translation & summarization.""" 18 | import codecs 19 | import os 20 | import re 21 | import subprocess 22 | 23 | import tensorflow as tf 24 | 25 | from ..scripts import bleu 26 | from ..scripts import rouge 27 | from .. import input_config 28 | 29 | 30 | __all__ = ["evaluate"] 31 | 32 | 33 | def evaluate(ref_file, trans_file, metric, subword_option=None): 34 | """Pick a metric and evaluate depending on task.""" 35 | # BLEU scores for translation task 36 | if metric.lower() == "bleu": 37 | evaluation_score = _bleu(ref_file, trans_file, 38 | subword_option=subword_option) 39 | # ROUGE scores for summarization tasks 40 | elif metric.lower() == "rouge": 41 | evaluation_score = _rouge(ref_file, trans_file, 42 | subword_option=subword_option) 43 | elif metric.lower() == "accuracy": 44 | evaluation_score = _accuracy(ref_file, trans_file) 45 | elif metric.lower() == "word_accuracy": 46 | evaluation_score = _word_accuracy(ref_file, trans_file) 47 | elif metric.lower() == "amino_acid_accuracy": 48 | evaluation_score = _amino_acid_accuracy(ref_file, trans_file) 49 | else: 50 | raise ValueError("Unknown metric %s" % metric) 51 | 52 | return evaluation_score 53 | 54 | 55 | def _clean(sentence, subword_option): 56 | """Clean and handle BPE or SPM outputs.""" 57 | sentence = sentence.strip() 58 | 59 | # BPE 60 | if subword_option == "bpe": 61 | sentence = re.sub("@@ ", "", sentence) 62 | 63 | # SPM 64 | elif subword_option == "spm": 65 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip() 66 | 67 | return sentence 68 | 69 | 70 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py 71 | def _bleu(ref_file, trans_file, subword_option=None): 72 | """Compute BLEU scores and handling BPE.""" 73 | max_order = 4 74 | smooth = False 75 | 76 | ref_files = [ref_file] 77 | reference_text = [] 78 | for reference_filename in ref_files: 79 | with codecs.getreader("utf-8")( 80 | tf.gfile.GFile(reference_filename, "rb")) as fh: 81 | reference_text.append(fh.readlines()) 82 | 83 | per_segment_references = [] 84 | for references in zip(*reference_text): 85 | reference_list = [] 86 | for reference in references: 87 | reference = _clean(reference, subword_option) 88 | reference_list.append(reference.split(" ")) 89 | per_segment_references.append(reference_list) 90 | 91 | translations = [] 92 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh: 93 | for line in fh: 94 | line = _clean(line, subword_option=None) 95 | translations.append(line.split(" ")) 96 | 97 | # bleu_score, precisions, bp, ratio, translation_length, reference_length 98 | bleu_score, _, _, _, _, _ = bleu.compute_bleu( 99 | per_segment_references, translations, max_order, smooth) 100 | return 100 * bleu_score 101 | 102 | 103 | def _rouge(ref_file, summarization_file, subword_option=None): 104 | """Compute ROUGE scores and handling BPE.""" 105 | 106 | references = [] 107 | with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh: 108 | for line in fh: 109 | references.append(_clean(line, subword_option)) 110 | 111 | hypotheses = [] 112 | with codecs.getreader("utf-8")( 113 | tf.gfile.GFile(summarization_file, "rb")) as fh: 114 | for line in fh: 115 | hypotheses.append(_clean(line, subword_option=None)) 116 | 117 | rouge_score_map = rouge.rouge(hypotheses, references) 118 | return 100 * rouge_score_map["rouge_l/f_score"] 119 | 120 | 121 | def _accuracy(label_file, pred_file): 122 | """Compute accuracy, each line contains a label.""" 123 | 124 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 125 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 126 | count = 0.0 127 | match = 0.0 128 | for label in label_fh: 129 | label = label.strip() 130 | label = label.replace('L', 'I') ##### <<<<< change L in true label to I 131 | pred = pred_fh.readline().strip() 132 | if label == pred: 133 | match += 1 134 | count += 1 135 | return 100 * match / count 136 | 137 | 138 | def _word_accuracy(label_file, pred_file): 139 | """Compute accuracy on per word basis.""" 140 | 141 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 142 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 143 | total_acc, total_count = 0., 0. 144 | for sentence in label_fh: 145 | labels = sentence.strip().split(" ") 146 | preds = pred_fh.readline().strip().split(" ") 147 | match = 0.0 148 | for pos in range(min(len(labels), len(preds))): 149 | label = labels[pos] 150 | pred = preds[pos] 151 | if label == pred: 152 | match += 1 153 | total_acc += 100 * match / max(len(labels), len(preds)) 154 | total_count += 1 155 | return total_acc / total_count 156 | 157 | 158 | def _amino_acid_accuracy(label_file, pred_file): 159 | """Compute amino acid accuracy.""" 160 | 161 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 162 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 163 | total_match, total_count = 0., 0. 164 | for sentence in label_fh: 165 | labels = sentence.strip().split(" ") 166 | labels = [input_config.mass_AA[aa] for aa in labels] 167 | # print(labels) 168 | preds = pred_fh.readline().strip().split(" ") 169 | if not preds[0]: 170 | preds = [''] 171 | preds = [input_config.mass_AA[aa] for aa in preds] 172 | # print(preds) 173 | pred_len, label_len = len(preds), len(labels) 174 | pos_pred, pos_label = 0, 0 175 | sum_pred, sum_label = 0., 0. 176 | match = 0.0 177 | while pos_pred < pred_len and pos_label < label_len: 178 | if abs(sum_pred - sum_label) < 0.03 and abs(preds[pos_pred] - labels[pos_label]) < 0.0001: 179 | match += 1 180 | sum_pred += preds[pos_pred] 181 | sum_label += labels[pos_label] 182 | pos_pred += 1 183 | pos_label += 1 184 | else: 185 | if sum_pred < sum_label: 186 | sum_pred += preds[pos_pred] 187 | pos_pred += 1 188 | else: 189 | sum_label += labels[pos_label] 190 | pos_label += 1 191 | total_match += match 192 | total_count += len(labels) 193 | return 100 * total_match / total_count 194 | 195 | 196 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None): 197 | """Compute BLEU scores using Moses multi-bleu.perl script.""" 198 | 199 | # TODO(thangluong): perform rewrite using python 200 | # BPE 201 | if subword_option == "bpe": 202 | debpe_tgt_test = tgt_test + ".debpe" 203 | if not os.path.exists(debpe_tgt_test): 204 | # TODO(thangluong): not use shell=True, can be a security hazard 205 | subprocess.call("cp %s %s" % (tgt_test, debpe_tgt_test), shell=True) 206 | subprocess.call("sed s/@@ //g %s" % (debpe_tgt_test), 207 | shell=True) 208 | tgt_test = debpe_tgt_test 209 | elif subword_option == "spm": 210 | despm_tgt_test = tgt_test + ".despm" 211 | if not os.path.exists(despm_tgt_test): 212 | subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test)) 213 | subprocess.call("sed s/ //g %s" % (despm_tgt_test)) 214 | subprocess.call(u"sed s/^\u2581/g %s" % (despm_tgt_test)) 215 | subprocess.call(u"sed s/\u2581/ /g %s" % (despm_tgt_test)) 216 | tgt_test = despm_tgt_test 217 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file) 218 | 219 | # subprocess 220 | # TODO(thangluong): not use shell=True, can be a security hazard 221 | bleu_output = subprocess.check_output(cmd, shell=True) 222 | 223 | # extract BLEU score 224 | m = re.search("BLEU = (.+?),", bleu_output) 225 | bleu_score = float(m.group(1)) 226 | 227 | return bleu_score 228 | -------------------------------------------------------------------------------- /nmt/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from nmt import input_config 4 | 5 | 6 | def read_empirical_mass(file_data, source_file_name): 7 | with open(source_file_name, 'r') as source_file: 8 | file_data["ref_empirical_mass"] = [] 9 | for row in source_file: 10 | row = row.strip() 11 | data = row.split('|') 12 | 13 | charge = int(float(data[1])) 14 | mass_p_charge = float(data[2]) 15 | peptide_mass = charge * mass_p_charge + (2.0 - charge) * input_config.mass_H 16 | sum_mass = peptide_mass - 2*input_config.mass_H - input_config.mass_H2O 17 | file_data["ref_empirical_mass"].append(sum_mass) 18 | 19 | 20 | def read_output_file(input_filename, prob_filename, ref_filename=None, mass_spec_filename=None): 21 | file_entry = {"ref_seqs": [], 22 | "nmt_seqs": [], 23 | "ref_weights": [], 24 | "nmt_weights": [], 25 | "probs": [], 26 | "length": 0} 27 | if ref_filename: 28 | with open(ref_filename, "r") as ref_file: 29 | for line in ref_file: 30 | ref_seq = line.strip() 31 | file_entry["ref_seqs"].append(ref_seq) 32 | 33 | with open(input_filename, "r") as input_file: 34 | for line in input_file: 35 | nmt_seq = line.strip() 36 | file_entry["nmt_seqs"].append(nmt_seq) 37 | 38 | with open(prob_filename, "r") as prob_file: 39 | for line in prob_file: 40 | probs = line.strip() 41 | probs = probs.split(" ") 42 | # if probs[0] != '': 43 | # probs = [np.exp(float(prob)) for prob in probs] 44 | # else: 45 | # probs = [] 46 | file_entry["probs"].append(probs) 47 | file_entry["length"] += 1 48 | 49 | read_empirical_mass(file_entry, mass_spec_filename) 50 | return file_entry 51 | 52 | 53 | def read_compare_file(input_filename): 54 | file_entry = {"ref_seqs": [], 55 | "nmt_seqs": [], 56 | "ref_weights": [], 57 | "nmt_weights": [], 58 | "probs": [], 59 | "length": 0} 60 | 61 | with open(input_filename, "r") as input_file: 62 | while True: 63 | ref_seq = input_file.readline() 64 | if not ref_seq: break 65 | 66 | nmt_seq = input_file.readline() 67 | nmt_seq = nmt_seq.strip() 68 | ref_seq = ref_seq.strip() 69 | 70 | total_weight_ref = input_file.readline().strip() 71 | total_weight_nmt = input_file.readline().strip() 72 | 73 | probs = input_file.readline().strip() 74 | probs = probs.split(" ") 75 | tmp = input_file.readline() 76 | 77 | file_entry["ref_seqs"].append(ref_seq) 78 | file_entry["nmt_seqs"].append(nmt_seq) 79 | file_entry["ref_weights"].append(float(total_weight_ref)) 80 | file_entry["nmt_weights"].append(float(total_weight_nmt)) 81 | file_entry["probs"].append(probs) 82 | file_entry["length"] += 1 83 | return file_entry 84 | 85 | 86 | def read_deepnovo_file(input_filename, ref_filename, mass_spec_filename): 87 | file_entry = {"ref_seqs": [], 88 | "nmt_seqs": [], 89 | "ref_weights": [], 90 | "nmt_weights": [], 91 | "probs": [], 92 | "length": 0} 93 | with open(ref_filename, "r") as ref_file: 94 | last = 0 95 | for line in ref_file: 96 | ref_seq = line.strip() 97 | file_entry["ref_seqs"].append(ref_seq) 98 | 99 | # scan predicted_sequence predicted_score predicted_position_score 100 | # 0 Y,E,E,I,Q,I,T,Q,R -0.45 -1.24,-0.01,-0.00,-0.01,-0.05,-0.01,-2.70,-0.05 101 | with open(input_filename, "r") as input_file: 102 | header = input_file.readline() 103 | last = 0 104 | for line in input_file: 105 | line = line.strip().split('\t') 106 | scan_num = int(line[0]) 107 | 108 | while scan_num > last: 109 | file_entry["nmt_seqs"].append('') 110 | file_entry["probs"].append('') 111 | file_entry["length"] += 1 112 | last += 1 113 | 114 | deepnovo_seq = line[1] 115 | deepnovo_seq = deepnovo_seq.replace('Cmod', 'C') 116 | deepnovo_seq = deepnovo_seq.replace('Mmod', 'm') 117 | deepnovo_seq = deepnovo_seq.replace('Qmod', 'q') 118 | deepnovo_seq = deepnovo_seq.replace('Nmod', 'n') 119 | deepnovo_seq = deepnovo_seq.replace(',', ' ') 120 | # print(deepnovo_seq) 121 | # print(file_entry["ref_seqs"][last]) 122 | # print('--------') 123 | 124 | if len(line[1]) > 0: 125 | probs = line[3].split(",") 126 | else: 127 | probs = [] 128 | probs.append('0') 129 | 130 | file_entry["nmt_seqs"].append(deepnovo_seq) 131 | file_entry["probs"].append(probs) 132 | file_entry["length"] += 1 133 | last += 1 134 | 135 | read_empirical_mass(file_entry, mass_spec_filename) 136 | return file_entry -------------------------------------------------------------------------------- /nmt/utils/mask_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmb-chula/SMSNet/facfaf441d0ef286d062f69530f0a298aba78edc/nmt/utils/mask_utils.py -------------------------------------------------------------------------------- /nmt/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Generally useful utility functions.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import collections 21 | import json 22 | import math 23 | import os 24 | import sys 25 | import time 26 | 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | 31 | def check_tensorflow_version(): 32 | min_tf_version = "1.4.0-dev20171024" 33 | #if tf.__version__ < min_tf_version: 34 | # raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version) 35 | 36 | 37 | def safe_exp(value): 38 | """Exponentiation with catching of overflow error.""" 39 | try: 40 | ans = math.exp(value) 41 | except OverflowError: 42 | ans = float("inf") 43 | return ans 44 | 45 | 46 | def print_time(s, start_time): 47 | """Take a start time, print elapsed duration, and return a new time.""" 48 | print("%s, time %ds, %s." % (s, (time.time() - start_time), time.ctime())) 49 | sys.stdout.flush() 50 | return time.time() 51 | 52 | 53 | def print_out(s, f=None, new_line=True): 54 | """Similar to print but with support to flush and output to a file.""" 55 | if isinstance(s, bytes): 56 | s = s.decode("utf-8") 57 | 58 | if f: 59 | f.write(s.encode("utf-8")) 60 | if new_line: 61 | f.write(b"\n") 62 | 63 | # stdout 64 | out_s = s.encode("utf-8") 65 | if not isinstance(out_s, str): 66 | out_s = out_s.decode("utf-8") 67 | print(out_s, end="", file=sys.stdout) 68 | 69 | if new_line: 70 | sys.stdout.write("\n") 71 | sys.stdout.flush() 72 | 73 | 74 | def print_hparams(hparams, skip_patterns=None, header=None): 75 | """Print hparams, can skip keys based on pattern.""" 76 | if header: print_out("%s" % header) 77 | values = hparams.values() 78 | for key in sorted(values.keys()): 79 | if not skip_patterns or all( 80 | [skip_pattern not in key for skip_pattern in skip_patterns]): 81 | print_out(" %s=%s" % (key, str(values[key]))) 82 | 83 | 84 | def load_hparams(model_dir): 85 | """Load hparams from an existing model directory.""" 86 | hparams_file = os.path.join(model_dir, "hparams") 87 | if tf.gfile.Exists(hparams_file): 88 | print_out("# Loading hparams from %s" % hparams_file) 89 | with codecs.getreader("utf-8")(tf.gfile.GFile(hparams_file, "rb")) as f: 90 | try: 91 | hparams_values = json.load(f) 92 | hparams = tf.contrib.training.HParams(**hparams_values) 93 | except ValueError: 94 | print_out(" can't load hparams file") 95 | return None 96 | return hparams 97 | else: 98 | return None 99 | 100 | 101 | def maybe_parse_standard_hparams(hparams, hparams_path): 102 | """Override hparams values with existing standard hparams config.""" 103 | if not hparams_path: 104 | return hparams 105 | 106 | if tf.gfile.Exists(hparams_path): 107 | print_out("# Loading standard hparams from %s" % hparams_path) 108 | with tf.gfile.GFile(hparams_path, "r") as f: 109 | hparams.parse_json(f.read()) 110 | 111 | return hparams 112 | 113 | 114 | def save_hparams(out_dir, hparams): 115 | """Save hparams.""" 116 | hparams_file = os.path.join(out_dir, "hparams") 117 | print_out(" saving hparams to %s" % hparams_file) 118 | with codecs.getwriter("utf-8")(tf.gfile.GFile(hparams_file, "wb")) as f: 119 | f.write(hparams.to_json()) 120 | 121 | 122 | def debug_tensor(s, msg=None, summarize=10): 123 | """Print the shape and value of a tensor at test time. Return a new tensor.""" 124 | if not msg: 125 | msg = s.name 126 | return tf.Print(s, [tf.shape(s), s], msg + " ", summarize=summarize) 127 | 128 | 129 | def add_summary(summary_writer, global_step, tag, value): 130 | """Add a new summary to the current summary_writer. 131 | Useful to log things that are not part of the training graph, e.g., tag=BLEU. 132 | """ 133 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 134 | summary_writer.add_summary(summary, global_step) 135 | 136 | 137 | def get_config_proto(log_device_placement=False, allow_soft_placement=True, 138 | num_intra_threads=0, num_inter_threads=0): 139 | # GPU options: 140 | # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html 141 | config_proto = tf.ConfigProto( 142 | log_device_placement=log_device_placement, 143 | allow_soft_placement=allow_soft_placement) 144 | config_proto.gpu_options.allow_growth = True 145 | 146 | # CPU threads options 147 | if num_intra_threads: 148 | config_proto.intra_op_parallelism_threads = num_intra_threads 149 | if num_inter_threads: 150 | config_proto.inter_op_parallelism_threads = num_inter_threads 151 | 152 | return config_proto 153 | 154 | 155 | def format_text(words): 156 | """Convert a sequence words into sentence.""" 157 | if (not hasattr(words, "__len__") and # for numpy array 158 | not isinstance(words, collections.Iterable)): 159 | words = [words] 160 | return b" ".join(words) 161 | 162 | 163 | def format_bpe_text(symbols, delimiter=b"@@"): 164 | """Convert a sequence of bpe words into sentence.""" 165 | words = [] 166 | word = b"" 167 | if isinstance(symbols, str): 168 | symbols = symbols.encode() 169 | delimiter_len = len(delimiter) 170 | for symbol in symbols: 171 | if len(symbol) >= delimiter_len and symbol[-delimiter_len:] == delimiter: 172 | word += symbol[:-delimiter_len] 173 | else: # end of a word 174 | word += symbol 175 | words.append(word) 176 | word = b"" 177 | return b" ".join(words) 178 | 179 | 180 | def format_spm_text(symbols): 181 | """Decode a text in SPM (https://github.com/google/sentencepiece) format.""" 182 | return u"".join(format_text(symbols).decode("utf-8").split()).replace( 183 | u"\u2581", u" ").strip().encode("utf-8") 184 | -------------------------------------------------------------------------------- /nmt/utils/nmt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions specifically for NMT.""" 17 | from __future__ import print_function 18 | 19 | import codecs 20 | import time 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from ..utils import evaluation_utils 25 | from ..utils import misc_utils as utils 26 | 27 | __all__ = ["decode_and_evaluate", "get_translation"] 28 | 29 | 30 | def decode_and_evaluate(name, 31 | model, 32 | sess, 33 | # summary_writer, 34 | trans_file, 35 | ref_file, 36 | metrics, 37 | subword_option, 38 | beam_width, 39 | tgt_eos, 40 | num_translations_per_input=1, 41 | decode=True): 42 | """Decode a test set and compute a score according to the evaluation task.""" 43 | # Decode 44 | if decode: 45 | utils.print_out(" decoding to output %s." % trans_file) 46 | 47 | start_time = time.time() 48 | num_sentences = 0 49 | with codecs.getwriter("utf-8")( 50 | tf.gfile.GFile(trans_file, mode="wb")) as trans_f, codecs.getwriter("utf-8")( 51 | tf.gfile.GFile(trans_file + "_prob", mode="wb")) as probs_f: 52 | trans_f.write("") # Write empty string to ensure file is created. 53 | probs_f.write("") 54 | 55 | num_translations_per_input = max( 56 | min(num_translations_per_input, beam_width), 1) 57 | 58 | step = 0 59 | while True: 60 | try: 61 | nmt_outputs, step_probs, _ = model.decode(sess, step) #, summary_writer, step) 62 | step = step + 1 63 | if beam_width == 0: 64 | nmt_outputs = np.expand_dims(nmt_outputs, 0) 65 | step_probs = np.expand_dims(step_probs, 0) 66 | 67 | batch_size = nmt_outputs.shape[1] 68 | num_sentences += batch_size 69 | 70 | for sent_id in range(batch_size): 71 | for beam_id in range(num_translations_per_input): 72 | translation = get_translation( 73 | nmt_outputs[beam_id], 74 | sent_id, 75 | tgt_eos=tgt_eos, 76 | subword_option=subword_option) 77 | trans_f.write((translation + b"\n").decode("utf-8")) 78 | 79 | trans_step_probs = get_log_probs( 80 | step_probs[beam_id], 81 | sent_id, 82 | output_len=int((len(translation)+1)/2)) 83 | probs_f.write((trans_step_probs + b"\n").decode("utf-8")) 84 | except tf.errors.OutOfRangeError: 85 | utils.print_time( 86 | " done, num sentences %d, num translations per input %d" % 87 | (num_sentences, num_translations_per_input), start_time) 88 | break 89 | 90 | # Evaluation 91 | evaluation_scores = {} 92 | if ref_file and tf.gfile.Exists(trans_file): 93 | for metric in metrics: 94 | score = evaluation_utils.evaluate( 95 | ref_file, 96 | trans_file, 97 | metric, 98 | subword_option=subword_option) 99 | evaluation_scores[metric] = score 100 | utils.print_out(" %s %s: %.1f" % (metric, name, score)) 101 | 102 | return evaluation_scores 103 | 104 | 105 | def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option): 106 | """Given batch decoding outputs, select a sentence and turn to text.""" 107 | if tgt_eos: tgt_eos = tgt_eos.encode("utf-8") 108 | # Select a sentence 109 | output = nmt_outputs[sent_id, :].tolist() 110 | 111 | # If there is an eos symbol in outputs, cut them at that point. 112 | if tgt_eos and tgt_eos in output: 113 | output = output[:output.index(tgt_eos)] 114 | 115 | if subword_option == "bpe": # BPE 116 | translation = utils.format_bpe_text(output) 117 | elif subword_option == "spm": # SPM 118 | translation = utils.format_spm_text(output) 119 | else: 120 | translation = utils.format_text(output) 121 | 122 | return translation 123 | 124 | 125 | def get_log_probs(step_log_probs, sent_id, output_len): 126 | """Given batch decoding outputs, select a sentence and return 127 | log prob for each character.""" 128 | # Select a sentence 129 | output = step_log_probs[sent_id, :].tolist() 130 | # Slecct only first 'output_len' elements 131 | output = output[:output_len] 132 | 133 | # trans_step_probs = " ".join(map(str, output)) 134 | trans_step_probs = " ".join("%.6f" % x for x in output) 135 | 136 | return trans_step_probs.encode("utf-8") 137 | -------------------------------------------------------------------------------- /nmt/utils/peaks_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Korrawe Karunratanakul 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import collections 17 | import multiprocessing 18 | import numpy as np 19 | import sys 20 | import time 21 | 22 | # sys.path.append('..') 23 | from nmt import input_config 24 | 25 | def parse_row(input_row): 26 | '''Parse input string to data entry''' 27 | # header = ["sequence", "charge", "mass/charge", "score1", "score2", "spectrum"] 28 | input_row = input_row.strip() 29 | data = input_row.split('|') 30 | 31 | charge = int(float(data[1])) 32 | mass_p_charge = float(data[2]) 33 | peptide_mass = charge * mass_p_charge + (2.0 - charge) * input_config.mass_H 34 | 35 | spec = data[-1].split(',') 36 | 37 | positions = spec[::2] 38 | values = spec[1::2] 39 | seq = data[0] 40 | 41 | positions = [float(posi) for posi in positions] 42 | values = [float(val) for val in values] 43 | 44 | return seq, charge, mass_p_charge, peptide_mass, positions, values 45 | 46 | 47 | def find_nearest_peak(array, value): 48 | idx = (np.abs(array - value)).argmin() 49 | return array[idx] 50 | 51 | 52 | def has_peak_within_threshold(relevant_peaks, value, threshold=0.1): 53 | return (np.abs(relevant_peaks - value)).min() < threshold 54 | 55 | 56 | def get_sorted_relevant_peak_masses(sequence, peptide_mass, charge): 57 | '''Compute relevant peak masses from the input sequence 58 | Input: 59 | sequence: A list of strings. Each entry contain a string representing an amino acid 60 | peptide_mass: Float. total mass of the peptide 61 | charge: Int. 62 | 63 | Output: 64 | A list of relevant masses. For each amino acid, 8 locations are considered: b,b(2+),b-H2O,b-NH3, 65 | y,y(2+),y-H2O,y-NH3. 66 | ''' 67 | amino_acid_masses = [] 68 | 69 | for amino_acid in sequence: 70 | assert amino_acid in input_config.mass_AA 71 | amino_acid_masses.append(input_config.mass_AA[amino_acid]) 72 | 73 | amino_acid_masses = np.array(amino_acid_masses) 74 | cum_sum = np.cumsum(amino_acid_masses) 75 | 76 | b_ion = cum_sum + input_config.mass_N_terminus 77 | y_ion = peptide_mass - b_ion 78 | 79 | b_H2O = b_ion - input_config.mass_H2O 80 | b_NH3 = b_ion - input_config.mass_NH3 81 | b_plus2_charge1 = (b_ion + input_config.mass_H) / 2 82 | 83 | y_H2O = y_ion - input_config.mass_H2O 84 | y_NH3 = y_ion - input_config.mass_NH3 85 | y_plus2_charge1 = (y_ion + input_config.mass_H) / 2 86 | 87 | relevant_masses = np.concatenate((b_ion, b_H2O, b_NH3, b_plus2_charge1, 88 | y_ion, y_H2O, y_NH3, y_plus2_charge1), axis=0) 89 | sorted_relevant_masses = np.sort(relevant_masses) 90 | return sorted_relevant_masses 91 | 92 | 93 | def get_relevant_peak_mass_dict(sequence, peptide_mass, charge): 94 | '''Compute relevant peak masses from the input sequence 95 | Input: 96 | sequence: A list of strings. Each entry contain a string representing an amino acid 97 | peptide_mass: Float. total mass of the peptide 98 | charge: Int. 99 | 100 | Output: 101 | A dictionary of lists of relevant masses. For each amino acid, 8 locations are considered: 102 | b,b(2+),b-H2O,b-NH3,y,y(2+),y-H2O,y-NH3. 103 | ''' 104 | amino_acid_masses = [] 105 | 106 | for amino_acid in sequence: 107 | assert amino_acid in input_config.mass_AA 108 | amino_acid_masses.append(input_config.mass_AA[amino_acid]) 109 | 110 | amino_acid_masses = np.array(amino_acid_masses) 111 | cum_sum = np.cumsum(amino_acid_masses) 112 | 113 | relevant_peak_dict = {} 114 | b_ion = cum_sum + input_config.mass_N_terminus 115 | y_ion = peptide_mass - b_ion 116 | relevant_peak_dict['b_ion'] = b_ion 117 | relevant_peak_dict['y_ion'] = y_ion 118 | 119 | relevant_peak_dict['b_H2O'] = b_ion - input_config.mass_H2O 120 | relevant_peak_dict['b_NH3'] = b_ion - input_config.mass_NH3 121 | relevant_peak_dict['b_plus2_charge1'] = (b_ion + input_config.mass_H) / 2 122 | 123 | relevant_peak_dict['y_H2O'] = y_ion - input_config.mass_H2O 124 | relevant_peak_dict['y_NH3'] = y_ion - input_config.mass_NH3 125 | relevant_peak_dict['y_plus2_charge1'] = (y_ion + input_config.mass_H) / 2 126 | return relevant_peak_dict 127 | 128 | 129 | def compute_evidence_correlation(source_file_name, result_file_name): 130 | '''Compute correlation between amino acid eviden and correctness of prediction 131 | Args: 132 | source_file_name: data file 133 | result_file_name: result file, same number of line as data file. 134 | each line consisting of 1 or 0 indicating the correctness of the predictions 135 | ''' 136 | with open(source_file_name, 'r') as source_file, open(result_file_name, 'r') as result_file: 137 | count = 0 138 | amino_acid_count, amino_acid_with_evidence_count = 0., 0. 139 | has_evidence_correct_count, has_evidence_incorrect_count = 0., 0. 140 | no_evidence_correct_count, no_evidence_incorrect_count = 0., 0. 141 | for row in source_file: 142 | result = result_file.readline().strip() 143 | result = result.split(',') 144 | 145 | (seq, charge, mass_p_charge, peptide_mass, positions, values) = parse_row(row) 146 | # print(seq, charge, mass_p_charge, peptide_mass, positions, values) 147 | assert len(seq) == len(result) 148 | 149 | relevant_peak_dict = get_relevant_peak_mass_dict(seq, peptide_mass, charge) 150 | # considered_ion = ['b_ion','y_ion'] 151 | considered_ion = ['b_ion','y_ion','b_plus2_charge1','y_plus2_charge1'] 152 | # considered_ion = ['b_ion','y_ion','b_H2O','y_H2O','b_NH3','y_NH3','b_plus2_charge1','y_plus2_charge1'] 153 | 154 | # positions must be sorted 155 | positions = np.array([float(posi) for posi in positions]) 156 | amino_acid_count += len(seq) 157 | has_evidence_in_previous_step = True 158 | for pos in range(len(seq)): 159 | has_evidence = False 160 | for ion_type in considered_ion: 161 | has_evidence |= has_peak_within_threshold(positions, relevant_peak_dict[ion_type][pos]) 162 | evidence_for_next_step = has_evidence 163 | has_evidence = has_evidence_in_previous_step and has_evidence 164 | 165 | if result[pos] == '1' and has_evidence: 166 | has_evidence_correct_count += 1 167 | elif result[pos] == '1' and not has_evidence: 168 | no_evidence_correct_count += 1 169 | elif result[pos] == '0' and has_evidence: 170 | has_evidence_incorrect_count += 1 171 | elif result[pos] == '0' and not has_evidence: 172 | no_evidence_incorrect_count += 1 173 | else: 174 | print("Something is worng here") 175 | 176 | if has_evidence: 177 | amino_acid_with_evidence_count += 1 178 | has_evidence_in_previous_step = evidence_for_next_step 179 | count += 1 180 | if count % 5000 == 0: print(count) # break 181 | 182 | print("Amino acid with evidence: %.2f %%" % (100 * amino_acid_with_evidence_count/amino_acid_count)) 183 | print("Amino acid with evidence: %.0f, Total amino acid: %0.f" % (amino_acid_with_evidence_count,amino_acid_count)) 184 | print("Has evidence and correct: %.0f (%.2f %%), Has evidence but incorrect: %.0f (%.2f %%)" % 185 | (has_evidence_correct_count, 100 * has_evidence_correct_count/amino_acid_count, 186 | has_evidence_incorrect_count, 100 * has_evidence_incorrect_count/amino_acid_count)) 187 | print("No evidence but correct: %.0f (%.2f %%), No evidence and incorrect: %.0f (%.2f %%)" % 188 | (no_evidence_correct_count, 100 * no_evidence_correct_count/amino_acid_count, 189 | no_evidence_incorrect_count, 100 * no_evidence_incorrect_count/amino_acid_count)) 190 | print("Of all incorrect amino acids, %.2f %% has no evidence" % 191 | (100 * no_evidence_incorrect_count / (no_evidence_incorrect_count + has_evidence_incorrect_count))) 192 | return 193 | -------------------------------------------------------------------------------- /nmt/utils/standard_hparams_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """standard hparams utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def create_standard_hparams(): 26 | return tf.contrib.training.HParams( 27 | # Data 28 | src="", 29 | tgt="", 30 | train_prefix="", 31 | dev_prefix="", 32 | test_prefix="", 33 | vocab_prefix="", 34 | embed_prefix="", 35 | out_dir="", 36 | 37 | # Networks 38 | num_units=512, 39 | num_layers=2, 40 | num_encoder_layers=2, 41 | num_decoder_layers=2, 42 | dropout=0.2, 43 | unit_type="lstm", 44 | encoder_type="bi", 45 | residual=False, 46 | time_major=True, 47 | num_embeddings_partitions=0, 48 | 49 | # Attention mechanisms 50 | attention="scaled_luong", 51 | attention_architecture="standard", 52 | output_attention=True, 53 | pass_hidden_state=True, 54 | 55 | # Train 56 | optimizer="sgd", 57 | batch_size=128, 58 | init_op="uniform", 59 | init_weight=0.1, 60 | max_gradient_norm=5.0, 61 | learning_rate=1.0, 62 | warmup_steps=0, 63 | warmup_scheme="t2t", 64 | decay_scheme="luong234", 65 | colocate_gradients_with_ops=True, 66 | num_train_steps=12000, 67 | 68 | # Data constraints 69 | num_buckets=5, 70 | max_train=0, 71 | src_max_len=50, 72 | tgt_max_len=50, 73 | src_max_len_infer=0, 74 | tgt_max_len_infer=0, 75 | 76 | # Data format 77 | sos="", 78 | eos="", 79 | subword_option="", 80 | check_special_token=True, 81 | 82 | # Misc 83 | forget_bias=1.0, 84 | num_gpus=1, 85 | epoch_step=0, # record where we were within an epoch. 86 | steps_per_stats=100, 87 | steps_per_external_eval=0, 88 | share_vocab=False, 89 | metrics=["bleu"], 90 | log_device_placement=False, 91 | random_seed=None, 92 | # only enable beam search during inference when beam_width > 0. 93 | beam_width=0, 94 | length_penalty_weight=0.0, 95 | override_loaded_hparams=True, 96 | num_keep_ckpts=5, 97 | avg_ckpts=False, 98 | 99 | # For inference 100 | inference_indices=None, 101 | infer_batch_size=32, 102 | sampling_temperature=0.0, 103 | num_translations_per_input=1, 104 | ) 105 | -------------------------------------------------------------------------------- /nmt/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility to handle vocabularies.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import codecs 23 | import os 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import lookup_ops 27 | 28 | from ..utils import misc_utils as utils 29 | 30 | 31 | UNK = "" 32 | SOS = "" 33 | EOS = "" 34 | UNK_ID = 0 35 | 36 | 37 | def load_vocab(vocab_file): 38 | vocab = [] 39 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f: 40 | vocab_size = 0 41 | for word in f: 42 | vocab_size += 1 43 | vocab.append(word.strip()) 44 | return vocab, vocab_size 45 | 46 | 47 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None, 48 | eos=None, unk=None): 49 | """Check if vocab_file doesn't exist, create from corpus_file.""" 50 | if tf.gfile.Exists(vocab_file): 51 | utils.print_out("# Vocab file %s exists" % vocab_file) 52 | vocab, vocab_size = load_vocab(vocab_file) 53 | if check_special_token: 54 | # Verify if the vocab starts with unk, sos, eos 55 | # If not, prepend those tokens & generate a new vocab file 56 | if not unk: unk = UNK 57 | if not sos: sos = SOS 58 | if not eos: eos = EOS 59 | assert len(vocab) >= 3 60 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos: 61 | utils.print_out("The first 3 vocab words [%s, %s, %s]" 62 | " are not [%s, %s, %s]" % 63 | (vocab[0], vocab[1], vocab[2], unk, sos, eos)) 64 | vocab = [unk, sos, eos] + vocab 65 | vocab_size += 3 66 | new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file)) 67 | with codecs.getwriter("utf-8")( 68 | tf.gfile.GFile(new_vocab_file, "wb")) as f: 69 | for word in vocab: 70 | f.write("%s\n" % word) 71 | vocab_file = new_vocab_file 72 | else: 73 | raise ValueError("vocab_file '%s' does not exist." % vocab_file) 74 | 75 | vocab_size = len(vocab) 76 | return vocab_size, vocab_file 77 | 78 | 79 | def create_vocab_tables(tgt_vocab_file): 80 | """Creates vocab tables for tgt_vocab_file.""" 81 | tgt_vocab_table = lookup_ops.index_table_from_file( 82 | tgt_vocab_file, default_value=UNK_ID) 83 | 84 | return tgt_vocab_table 85 | 86 | 87 | def load_embed_txt(embed_file): 88 | """Load embed_file into a python dictionary. 89 | 90 | Note: the embed_file should be a Glove formated txt file. Assuming 91 | embed_size=5, for example: 92 | 93 | the -0.071549 0.093459 0.023738 -0.090339 0.056123 94 | to 0.57346 0.5417 -0.23477 -0.3624 0.4037 95 | and 0.20327 0.47348 0.050877 0.002103 0.060547 96 | 97 | Args: 98 | embed_file: file path to the embedding file. 99 | Returns: 100 | a dictionary that maps word to vector, and the size of embedding dimensions. 101 | """ 102 | emb_dict = dict() 103 | emb_size = None 104 | with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, 'rb')) as f: 105 | for line in f: 106 | tokens = line.strip().split(" ") 107 | word = tokens[0] 108 | vec = list(map(float, tokens[1:])) 109 | emb_dict[word] = vec 110 | if emb_size: 111 | assert emb_size == len(vec), "All embedding size should be same." 112 | else: 113 | emb_size = len(vec) 114 | return emb_dict, emb_size 115 | -------------------------------------------------------------------------------- /nmt/vocab/mass.txt: -------------------------------------------------------------------------------- 1 | 0.0 2 | 0.0 3 | 0.0 4 | 71.03711 5 | 160.03065 6 | 115.02694 7 | 129.04259 8 | 147.06841 9 | 57.02146 10 | 137.05891 11 | 113.08406 12 | 128.09496 13 | 113.08406 14 | 131.04049 15 | 147.0354 16 | 114.04293 17 | 97.05276 18 | 128.05858 19 | 156.10111 20 | 87.03203 21 | 166.99836 22 | 101.04768 23 | 181.01401 24 | 99.06841 25 | 186.07931 26 | 163.06333 27 | 243.02966 -------------------------------------------------------------------------------- /nmt/vocab/vocab.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | A 5 | C 6 | D 7 | E 8 | F 9 | G 10 | H 11 | I 12 | K 13 | L 14 | M 15 | m 16 | N 17 | P 18 | Q 19 | R 20 | S 21 | s 22 | T 23 | t 24 | V 25 | W 26 | Y 27 | y -------------------------------------------------------------------------------- /nmt/vocab/vocab_m.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | A 5 | C 6 | D 7 | E 8 | F 9 | G 10 | H 11 | I 12 | K 13 | L 14 | M 15 | m 16 | N 17 | P 18 | Q 19 | R 20 | S 21 | T 22 | V 23 | W 24 | Y -------------------------------------------------------------------------------- /nmt/vocab/vocab_nomod.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | A 5 | C 6 | D 7 | E 8 | F 9 | G 10 | H 11 | I 12 | K 13 | L 14 | M 15 | N 16 | P 17 | Q 18 | R 19 | S 20 | T 21 | V 22 | W 23 | Y -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | astor==0.7.1 3 | backcall==0.1.0 4 | bleach==3.0.0 5 | cycler==0.10.0 6 | decorator==4.3.0 7 | defusedxml==0.5.0 8 | entrypoints==0.2.3 9 | gast==0.2.0 10 | grpcio==1.15.0 11 | h5py==2.8.0 12 | ipykernel==5.0.0 13 | ipython==7.0.1 14 | ipython-genutils==0.2.0 15 | ipywidgets==7.4.2 16 | jedi==0.13.1 17 | Jinja2==2.10 18 | jsonschema==2.6.0 19 | jupyter==1.0.0 20 | jupyter-client==5.2.3 21 | jupyter-console==6.0.0 22 | jupyter-core==4.4.0 23 | jupyterlab==0.35.0 24 | jupyterlab-server==0.2.0 25 | Keras==2.2.4 26 | Keras-Applications==1.0.6 27 | Keras-Preprocessing==1.0.5 28 | kiwisolver==1.0.1 29 | Markdown==3.0.1 30 | MarkupSafe==1.0 31 | matplotlib==3.0.2 32 | mistune==0.8.3 33 | nbconvert==5.4.0 34 | nbformat==4.4.0 35 | notebook==5.7.0 36 | numpy==1.15.2 37 | pandas==0.24.1 38 | pandocfilters==1.4.2 39 | parso==0.3.1 40 | pexpect==4.6.0 41 | pickleshare==0.7.5 42 | prometheus-client==0.4.0 43 | prompt-toolkit==2.0.5 44 | protobuf==3.6.1 45 | ptyprocess==0.6.0 46 | Pygments==2.2.0 47 | pyparsing==2.3.1 48 | python-dateutil==2.7.3 49 | pytz==2018.9 50 | PyYAML==3.13 51 | pyzmq==17.1.2 52 | qtconsole==4.4.1 53 | scikit-learn==0.20.2 54 | scipy==1.2.1 55 | seaborn==0.9.0 56 | Send2Trash==1.5.0 57 | simplegeneric==0.8.1 58 | six==1.11.0 59 | sklearn==0.0 60 | tensorboard==1.11.0 61 | tensorflow-gpu==1.11.0 62 | termcolor==1.1.0 63 | terminado==0.8.1 64 | testpath==0.4.2 65 | tornado==5.1.1 66 | traitlets==4.3.2 67 | wcwidth==0.1.7 68 | webencodings==0.5.1 69 | Werkzeug==0.14.1 70 | widgetsnbextension==3.4.2 71 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Korrawe Karunratanakul 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import os 20 | import random 21 | import sys 22 | 23 | # import matplotlib.image as mpimg 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | from nmt import inference 28 | from nmt import train 29 | from nmt.utils import evaluation_utils 30 | from nmt.utils import misc_utils as utils 31 | from nmt.utils import vocab_utils 32 | from nmt.utils import iterator_utils 33 | from utils_data import convert_mgf_to_csv as mgf2csv 34 | from utils_masking import create_denovo_report as report_utils 35 | 36 | utils.check_tensorflow_version() 37 | 38 | FLAGS = None 39 | 40 | from nmt import model as nmt_model 41 | from nmt import model_helper 42 | from nmt import post_process 43 | 44 | input_files = [] 45 | # input_files = ['data_deepnovo/peaks_db'] 46 | for i in range(22): #453,1): # 490,1): # 418,1): 47 | input_files.append("data_best_compat/train_" + str(i)) 48 | # print(input_files) 49 | 50 | hparams = tf.contrib.training.HParams( 51 | # Data 52 | src=input_files, # "nmt/processed_data/evidence_16", 53 | # src="nmt/processed_data/evidence_16", 54 | tgt='_tgt.csv', 55 | #train_prefix=flags.train_prefix, 56 | #dev_prefix=flags.dev_prefix, 57 | #test_prefix=flags.test_prefix, 58 | #vocab_prefix=flags.vocab_prefix, 59 | #embed_prefix=flags.embed_prefix, 60 | out_dir="log_ablation/log_best_700k...",# encoder_to_decoder_lookahead", # "log_nomod_full", #"test_one", 61 | tgt_vocab_file="nmt/vocab/vocab.txt",# "nmt/vocab/vocab_m.txt", 62 | tgt_vocab_size=27, # 27, 63 | tgt_embed_file="", 64 | 65 | # dev="data_mod_m/val_no_dup", # "data_nomod_full/val_no_dup", 66 | # test="data_mod_m/test_no_dup", # "data_nomod_full/test_no_dup", 67 | dev="data_best_compat/val_no_dup", # "data_mod_m/val_no_dup_1000", # "data_nomod_full/val_no_dup", 68 | test="data_best_compat/test_no_dup", # "data_mod_m/val_no_dup_relevant", #"data_mod_m/test_no_dup_1000", # "data_nomod_full/test_no_dup", 69 | src_suffix=".csv", 70 | train_fin=True, ### 71 | 72 | # Networks 73 | num_units=512, 74 | num_layers=2, # Compatible 75 | #num_encoder_layers=1, #(flags.num_encoder_layers or flags.num_layers), 76 | num_decoder_layers=2 ,#(flags.num_decoder_layers or flags.num_layers), 77 | dropout=0.1, 78 | unit_type="layer_norm_lstm", # "lstm", 79 | # encoder_type=None, #flags.encoder_type, 80 | residual=True, #flags.residual, 81 | num_decoder_residual_layers = 1, ###### 82 | time_major=True, 83 | num_embeddings_partitions=0, 84 | 85 | # Attention mechanisms 86 | attention="", 87 | #attention_architecture=flags.attention_architecture, 88 | #output_attention=flags.output_attention, 89 | #pass_hidden_state=flags.pass_hidden_state, 90 | 91 | # Train 92 | optimizer="sgd", 93 | num_train_steps=700000, 94 | batch_size=32, 95 | init_op="uniform", 96 | init_weight=0.1, 97 | max_gradient_norm=5.0, 98 | learning_rate=0.01, 99 | warmup_steps=0, #flags.warmup_steps, 100 | warmup_scheme="t2t", #flags.warmup_scheme, 101 | decay_scheme="luong234", #flags.decay_scheme, 102 | colocate_gradients_with_ops=True, 103 | 104 | # Data constraints 105 | num_buckets=2, 106 | max_train=0, 107 | src_max_len=None, 108 | tgt_max_len=None, 109 | 110 | # Inference 111 | src_max_len_infer=None, #flags.src_max_len_infer,+ 112 | tgt_max_len_infer=50, #flags.tgt_max_len_infer, 113 | infer_batch_size=8, 114 | beam_width=20, ############################################### 115 | length_penalty_weight=1.0, 116 | sampling_temperature=0.0, 117 | num_translations_per_input=1, #flags.num_translations_per_input, 118 | 119 | # Vocab 120 | sos='', #flags.sos if flags.sos else vocab_utils.SOS, 121 | eos='', #flags.eos if flags.eos else vocab_utils.EOS, 122 | subword_option=None, #flags.subword_option, 123 | check_special_token=None, #flags.check_special_token, 124 | embed_size=32, # 5, # 64 125 | 126 | # Misc 127 | forget_bias=1.0, 128 | num_gpus=1, 129 | epoch_step=0, # record where we were within an epoch. 130 | steps_per_stats=200, 131 | steps_per_external_eval=20000, #None, 132 | share_vocab=None, #flags.share_vocab, 133 | metrics=["bleu","accuracy","amino_acid_accuracy"], # rouge 134 | log_device_placement=False, 135 | random_seed=48, #flags.random_seed, 136 | override_loaded_hparams=False, 137 | num_keep_ckpts=5, 138 | avg_ckpts=False, 139 | num_intra_threads=None,# flags.num_intra_threads, 140 | num_inter_threads=None 141 | ) 142 | 143 | # Evaluation 144 | for metric in hparams.metrics: 145 | hparams.add_hparam("best_" + metric, 0) # larger is better 146 | best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric) 147 | hparams.add_hparam("best_" + metric + "_dir", best_metric_dir) 148 | tf.gfile.MakeDirs(best_metric_dir) 149 | 150 | if hparams.avg_ckpts: 151 | hparams.add_hparam("avg_best_" + metric, 0) # larger is better 152 | best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric) 153 | hparams.add_hparam("avg_best_" + metric + "_dir", best_metric_dir) 154 | tf.gfile.MakeDirs(best_metric_dir) 155 | 156 | 157 | if __name__ == "__main__": 158 | parser = argparse.ArgumentParser() 159 | parser.register("type", "bool", lambda v: v.lower() == "true") 160 | parser.add_argument("--inference_input_file", type=str, default=None, 161 | help="Set to the text to decode.") 162 | # parser.add_argument("--inference_output_file", type=str, default=None, 163 | # help="Output file to store decoding results.") 164 | parser.add_argument("--ckpt", type=str, default="", 165 | help="Checkpoint file to load a model for inference.") 166 | parser.add_argument("--model_dir", type=str, default="", 167 | help="Directory to load a model for inference.") 168 | 169 | parser.add_argument("--rescore", type="bool", nargs="?", const=True, 170 | default=False, 171 | help="Rescore with previously trained model.") 172 | parser.add_argument("--rescore_logdir", type=str, default=None, 173 | help="Directory to save or load model for rescoring.") 174 | args = parser.parse_args() 175 | print(args) 176 | if args.inference_input_file: 177 | infer_input_file = args.inference_input_file 178 | # Inference 179 | hparams.inference_indices = None 180 | print(infer_input_file) 181 | 182 | # trans_file = args.inference_output_file 183 | source_filename = os.path.basename(infer_input_file)[:-4] # no ".mgf" 184 | input_dir = os.path.dirname(infer_input_file) 185 | 186 | trans_dir = input_dir + '_output/' 187 | trans_file = os.path.join(trans_dir, source_filename) 188 | # print(trans_path, trans_file) 189 | if not os.path.exists(trans_dir): 190 | os.mkdir(trans_dir) 191 | 192 | # convert to csv format if nessesary to speed-up inference 193 | if infer_input_file[-3:] == 'mgf': 194 | mgf2csv.main([trans_dir, infer_input_file]) 195 | infer_input_file = os.path.join(trans_dir, source_filename + '.csv') 196 | del_temp_file = True 197 | else: 198 | del_temp_file = False 199 | 200 | 201 | # check model path 202 | ckpt = args.ckpt 203 | if not ckpt: 204 | model_dir = hparams.out_dir 205 | if args.model_dir: 206 | model_dir = args.model_dir 207 | ckpt = tf.train.latest_checkpoint(model_dir) 208 | 209 | # decode 210 | inference.inference(ckpt, infer_input_file, trans_file, hparams) 211 | 212 | if args.rescore: 213 | if not args.rescore_logdir: 214 | rescore_dir = os.path.join(model_dir, "post_process") 215 | post_process.rescore(trans_file, trans_file + "_prob", infer_input_file, 216 | rescore_dir, trans_dir + source_filename + "_rescore") 217 | print("Done") 218 | 219 | # Create report if m_mod(21 AAs + 3 tokens) or p_mod (24 AAs + 3 tokens) 220 | if hparams.tgt_vocab_size == 24: 221 | report_utils.main(trans_dir, input_dir, 'm-mod') 222 | elif hparams.tgt_vocab_size == 27: 223 | report_utils.main(trans_dir, input_dir, 'p-mod') 224 | 225 | if del_temp_file: 226 | os.remove(infer_input_file) 227 | 228 | else: 229 | print('training') 230 | train.train(hparams) 231 | -------------------------------------------------------------------------------- /utils_data/clean_msms_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Korrawe Karunratanakul 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | 18 | ### CONSTANTS 19 | aa_list = 'GASPVTCLINDQKEMHFRYWmnq' 20 | aa_mass_list = [57.02146,71.03711,87.03203, 21 | 97.05276,99.06841,101.04768, 22 | 103.00919 + 57.02146,113.08406,113.08406, ## C are C mod 23 | 114.04293,115.02694,128.05858, 24 | 128.09496,129.04259,131.04049, 25 | 137.05891,147.06841,156.10111, 26 | 163.06333,186.07931,131.04049 + 15.99491, ## m = M(ox) 27 | 114.04293 + 0.98402,128.05858 + 0.98402] ## n = N(de), q = Q(de) 28 | 29 | aa_mass = {} 30 | for i in range(len(aa_list)): 31 | aa_mass[aa_list[i]] = aa_mass_list[i] 32 | 33 | proton = 1.007276 34 | water = 18.010565 35 | ammonia = 17.026549 36 | aion = 27.994915 37 | mass_tol = 20 * 0.000001 # ppm 38 | 39 | ############################################ 40 | ### FUNCTION FOR COMPUTING EVIDENCE 41 | ### MAJOR = +1, +2 42 | ### MINOR = -H2O, -NH3, -CO 43 | def find_evidence(seq, spectrum, tol): 44 | evidence_major = [0] * len(seq) 45 | evidence_minor = [0] * len(seq) 46 | 47 | ladder = seq_to_ladder(seq) 48 | ion_major = [0] * len(ladder) 49 | ion_minor = [0] * len(ladder) 50 | 51 | for pos in range(len(ladder)): 52 | found_major, found_minor = find_evidence_helper(ladder, pos, spectrum, tol) 53 | 54 | if found_major: 55 | ion_major[pos] = 1 56 | ion_minor[pos] = 1 57 | elif found_minor: 58 | ion_minor[pos] = 1 59 | 60 | for i in range(len(seq)): 61 | if ion_major[i] == 1 and ion_major[i + 1] == 1: 62 | evidence_major[i] = 1 63 | if ion_minor[i] == 1 and ion_minor[i + 1] == 1: 64 | evidence_minor[i] = 1 65 | 66 | return evidence_major, evidence_minor 67 | 68 | def find_evidence_helper(ladder, pos, spectrum, tol): # look for evidence for a ladder position in a spectrum 69 | if pos == 0 or pos == len(ladder) - 1: 70 | return True, True # always return True for 0 and total mass 71 | 72 | found_major = False 73 | found_minor = False 74 | 75 | major_masses, minor_masses = get_fragment_ion_masses(ladder[pos], ladder[-1]) 76 | 77 | for m in major_masses: 78 | matched_loc, matched_error = mass_matching_init(m, spectrum[:, 0], tol) 79 | 80 | if len(matched_loc) > 0: 81 | found_major = True 82 | break 83 | 84 | if not found_major: 85 | for m in minor_masses: 86 | matched_loc, matched_error = mass_matching_init(m, spectrum[:, 0], tol) 87 | 88 | if len(matched_loc) > 0: 89 | found_minor = True 90 | break 91 | 92 | return found_major, found_minor 93 | 94 | def seq_to_ladder(seq): # convert amino acid sequence to mass ladder 95 | ladder = [0] * (len(seq) + 1) 96 | ladder[1] = aa_mass[seq[0]] 97 | 98 | for i in range(2, len(ladder)): 99 | ladder[i] = ladder[i - 1] + aa_mass[seq[i - 1]] 100 | 101 | return ladder 102 | 103 | def get_fragment_ion_masses(aa_mass, total_mass): # compute b- and y-ion masses from sum of aa mass 104 | major_masses = sorted([aa_mass + proton, aa_mass / 2.0 + proton, \ 105 | total_mass - aa_mass + water + proton, (total_mass - aa_mass + water) / 2.0 + proton]) 106 | minor_masses = sorted([major_masses[0] - water, major_masses[0] - ammonia, major_masses[0] - aion, \ 107 | major_masses[2] - water, major_masses[2] - ammonia]) 108 | return major_masses, minor_masses 109 | 110 | def get_mass_error(observed, expected): 111 | if expected == 0: 112 | return observed - expected 113 | else: 114 | return (observed - expected) * 1000000.0 / expected 115 | 116 | def mass_matching_init(target, mass_list, tol): 117 | return mass_matching(target, mass_list, tol, 0, len(mass_list)) 118 | 119 | def mass_matching(target, mass_list, tol, start_loc, end_loc): 120 | if target * (1.0 + tol) < mass_list[start_loc] or target * (1.0 - tol) > mass_list[end_loc - 1]: 121 | return [], [] 122 | 123 | if end_loc - start_loc == 1: 124 | return [start_loc], [get_mass_error(mass_list[start_loc], target)] 125 | 126 | mid_loc = int((start_loc + end_loc) / 2) 127 | 128 | if target * (1.0 + tol) < mass_list[mid_loc]: 129 | return mass_matching(target, mass_list, tol, start_loc, mid_loc) 130 | 131 | if target * (1.0 - tol) > mass_list[mid_loc]: 132 | if mid_loc == len(mass_list) - 1: 133 | return False 134 | return mass_matching(target, mass_list, tol, mid_loc + 1, end_loc) 135 | 136 | hit_loc = [mid_loc] 137 | hit_error = [get_mass_error(mass_list[mid_loc], target)] 138 | next_loc = mid_loc + 1 139 | 140 | while next_loc < end_loc and mass_list[next_loc] <= target * (1.0 + tol): 141 | hit_loc.append(next_loc) 142 | hit_error.append(get_mass_error(mass_list[next_loc], target)) 143 | next_loc += 1 144 | 145 | next_loc = mid_loc - 1 146 | 147 | while next_loc >= start_loc and mass_list[next_loc] >= target * (1.0 - tol): 148 | hit_loc.append(next_loc) 149 | hit_error.append(get_mass_error(mass_list[next_loc], target)) 150 | next_loc -= 1 151 | 152 | return hit_loc, hit_error 153 | 154 | ############################################ 155 | ### MAIN 156 | input_files = [] 157 | for i in range(22): 158 | input_files.append("train_" + str(i)) 159 | # print(input_files) 160 | input_files = ["test_no_dup", "val_no_dup"] 161 | 162 | # data_path = '../data_best_compat/val_no_dup' 163 | total_seq_count = 0 164 | keep_count = 0 165 | 166 | for input_file in input_files: 167 | print(input_file) 168 | data_path = '../data_best_compat/' + input_file 169 | output_path = '../data_best_compat_clean/' + input_file 170 | tgt_output_path = '../data_best_compat_clean/' + input_file 171 | 172 | with open(data_path + '.csv', 'rt') as fin, open(output_path + '.csv', 'w') as fout, open(output_path + '_tgt.csv', 'w') as tgt_out: 173 | line = fin.readline() 174 | 175 | i = 0 176 | for line in fin: 177 | content = line.rstrip('\n').split('|') 178 | seq = content[0] 179 | heder = content[:5] 180 | # print(seq) 181 | if 'U' in seq: 182 | continue 183 | spectrum = np.reshape([float(x) for x in content[-1].split(',')], (-1, 2)) 184 | evidence_major, evidence_minor = find_evidence(seq, spectrum, mass_tol) 185 | # print('--') 186 | ### DO WHATEVER YOU WANT WITH EVIDENCE DATA HERE 187 | if np.sum(evidence_minor) / len(evidence_minor) >= 0.4: 188 | keep_count += 1 189 | # print(np.sum(evidence_minor), len(evidence_minor)) 190 | # print(' '.join(seq)) 191 | fout.write(line) 192 | tgt_out.write(' '.join(seq) + '\n') 193 | 194 | 195 | total_seq_count += 1 196 | # i += 1 197 | # if i > 5: break 198 | 199 | print('keep: %d, total: %d' % (keep_count, total_seq_count)) 200 | -------------------------------------------------------------------------------- /utils_data/convert_mgf_to_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Korrawe Karunratanakul 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import sys 18 | 19 | # file_names = [] # 'evidence_16.txt' 20 | # for i in range(0, 28, 1): # range(0,212,1): 21 | # strr = "train_" + str(i) # + '.csv' 22 | # file_names.append(strr) 23 | 24 | # file_names = [ 25 | # '/data/deepnovo/high_all.mgf', 26 | # ] 27 | # # file_names = ['test_MGF/DS13HipH_RP_CE27_EqMass_1_formatted.mgf'] 28 | # # file_names = ['wu_peptidome/' + f for f in file_names] 29 | # file_names = ['test_no_dup_1000.mgf'] 30 | # print(file_names) 31 | 32 | 33 | 34 | def main(args): 35 | # args = [output_dir, file1, file2, file3, ...] 36 | out_dir = args[0] 37 | # header = ["sequence", "charge", "mass/charge", "score1", "score2", "spectrum"] 38 | i = 0 39 | m_mod_count = 0 40 | n_mod_count = 0 41 | q_mod_count = 0 42 | for file_name in args[1:]: 43 | print(file_name) 44 | out_filename = os.path.basename(file_name)[:-4] + '.csv' 45 | out_filename = os.path.join(out_dir, out_filename) 46 | # file_name = file_name 47 | with open(file_name, 'r') as input_file, open(out_filename, 'w') as output_file: 48 | 49 | for row in input_file: 50 | row = row.strip() 51 | if row == 'BEGIN IONS': 52 | # print('begin') 53 | mass_p_charge = '' 54 | seq = '' 55 | charge = '' 56 | spectrum = [] 57 | 58 | elif row == 'END IONS': 59 | if seq == '': 60 | seq = 'SEQ' 61 | spectrum_str = ','.join(spectrum) 62 | output_str = '|'.join([seq, charge, mass_p_charge, spectrum_str]) 63 | # print(output_str) 64 | output_file.write(output_str + '\n') 65 | i = i + 1 66 | # if i % 1000 == 0: print(i) 67 | 68 | elif row.startswith('TITLE'): 69 | # print('title') 70 | pass 71 | elif row.startswith('PEPMASS'): 72 | mass_p_charge = row.split('=')[1] 73 | # for wu_peptidome data, the pepmass contains a pair of space-seperated mass and sum of abundance 74 | mass_p_charge = mass_p_charge.split(' ')[0] 75 | # print(mass_p_charge) 76 | elif row.startswith('CHARGE'): 77 | charge = row.split('=')[1][0] 78 | # print(charge) 79 | elif row.startswith('SCANS'): 80 | pass 81 | elif row.startswith('RTINSECONDS'): 82 | pass 83 | elif row.startswith('SEQ'): 84 | seq = row.split('=')[1] 85 | 86 | # Edit modification 87 | # C mod +57 -> normal C 88 | seq = seq.replace('C(+57.02)', 'C') 89 | # M mod -> m 90 | seq = seq.replace('M(+15.99)', 'm') 91 | # Q mod -> q 92 | seq = seq.replace('Q(+.98)', 'q') 93 | # N mod -> n 94 | seq = seq.replace('N(+.98)', 'n') 95 | 96 | if '(' in seq: 97 | print(seq) 98 | break 99 | if 'm' in seq: m_mod_count += 1 100 | if 'q' in seq: q_mod_count += 1 101 | if 'n' in seq: n_mod_count += 1 102 | 103 | else: 104 | spectrum += row.split(' ') 105 | # i = i + 1 106 | # if i >= 2: 107 | # break 108 | 109 | 110 | print('total seq:', i) 111 | print('M mod:', m_mod_count, 100.*m_mod_count/i) 112 | print('N mod:', n_mod_count, 100.*n_mod_count/i) 113 | print('Q mod:', q_mod_count, 100.*q_mod_count/i) 114 | # BEGIN IONS 115 | # TITLE=DS13HipH_RP_CE27_EqMass_1.7.7.2 File:"DS13HipH_RP_CE27_EqMass_1.raw", NativeID:"controllerType=0 controllerNumber=1 scan=7" 116 | # PEPMASS=469.424499511719 117 | # CHARGE=2+ 118 | # SCANS=7 119 | # RTINSECONDS=2.44445604 120 | # SEQ= 121 | # 65.09106445 7391.3022460938 122 | 123 | if __name__ == "__main__": 124 | main(sys.argv[1:]) 125 | -------------------------------------------------------------------------------- /utils_masking/SMSNet_final_database_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Sira Sriswasdi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os, re, gzip, math, time 18 | from multiprocessing.pool import ThreadPool, Pool 19 | 20 | num_thread = 10 ## number of CPU thread to use 21 | denovo_mass_tol = 50 ## mass error for filtering de novo results, in ppm 22 | pep_mass_tol = 20 ## mass error for comparing candidate peptide mass to theoretical mass, in ppm 23 | 24 | input_path = '/data/users/kppt/denovo/' ## location where de novo report file is located 25 | db_path = '/data/users/kppt/protein_db/' ## location where protein database is located 26 | 27 | db_name = 'human_all_isoform' ## name of protein database (.fasta format) 28 | 29 | exp_name = 'PXD009227_EGF_stimulate' ## name of de novo report file: [exp_name]_[mode_tag]_[fdr_tag].tsv) 30 | mode_tag = 'p-mod' 31 | fdr_tag = 'fdr5' 32 | 33 | if mode_tag == 'p-mod': 34 | mod_aas = 'MSTY' 35 | elif mode_tag == 'm-mod': 36 | mod_aas = 'M' 37 | 38 | mod_aas_lower = mod_aas.lower() 39 | 40 | mass_tag_cutoff = 3 41 | pep_mass_tol_abs = pep_mass_tol / 1000000 42 | mass_tag_re = re.compile('(\([^\)]*\))') 43 | 44 | aa_list = 'GASPVTCLINDQKEMHFRYWmstyUXZB' 45 | aa_mass_list = [57.02146,71.03711,87.03203, 46 | 97.05276,99.06841,101.04768, 47 | 103.00919 + 57.02146,113.08406,113.08406, ## C are C mod 48 | 114.04293,115.02694,128.05858, 49 | 128.09496,129.04259,131.04049, 50 | 137.05891,147.06841,156.10111, 51 | 163.06333,186.07931,131.04049 + 15.99491, ## m = M(ox) 52 | 87.03203 + 79.96633,101.04768 + 79.96633,163.06333 + 79.96633, ## s = S(ph), t = T(ph), y = Y(ph) 53 | 99999, 99999, 99999, 99999] ## very large masses for ambiguous amino acids 54 | 55 | aa_mass = {} 56 | for i in range(len(aa_list)): 57 | aa_mass[aa_list[i]] = aa_mass_list[i] 58 | 59 | aa_mod_delta_mass = {} 60 | for i in range(len(aa_list)): 61 | if aa_list[i].lower() in aa_list: 62 | aa_mod_delta_mass[aa_list[i].upper()] = aa_mass[aa_list[i].lower()] - aa_mass[aa_list[i].upper()] 63 | 64 | proton = 1.007276 65 | water = 18.010565 66 | 67 | min_aa_mass = min(aa_mass_list) 68 | max_aa_mass = max(aa_mass_list) 69 | 70 | all_results = [] 71 | unique_peptides = set() 72 | 73 | with open(os.path.join(input_path, exp_name + '_' + mode_tag + '_' + fdr_tag + '.tsv'), 'rt') as fin: 74 | header = fin.readline().rstrip('\n').split('\t') ## header 75 | mass_tol_col = header.index('MassError(ppm)') + 1 76 | peptide_col = header.index('Prediction') + 1 77 | 78 | current_id = 0 79 | 80 | for line in fin.readlines(): 81 | content = [current_id] 82 | content.extend(line.rstrip('\n').split('\t')) 83 | 84 | if abs(float(content[mass_tol_col])) <= denovo_mass_tol: ## filter by mass tolerance 85 | all_results.append(content) 86 | unique_peptides.add(content[peptide_col]) 87 | 88 | current_id += 1 89 | 90 | all_proteins = {} 91 | protein_info = {} 92 | 93 | with gzip.open(os.path.join(db_path, db_name + '.fasta.gz'), 'rt') as fin: 94 | line = fin.readline() 95 | 96 | while line: 97 | if line.startswith('>'): 98 | temp = line.split('|') 99 | uniprot_id = temp[1] 100 | primary_uniprot_id = uniprot_id.split('-')[0] 101 | 102 | if not 'OS=' in temp[2]: 103 | protein_name = temp[2] 104 | species_name = 'Unknown' 105 | else: 106 | temp = temp[2].split(' OS=') 107 | protein_name = temp[0] 108 | 109 | temp = temp[1].split('=')[0].split() 110 | species_name = ' '.join(temp[:-1]) 111 | 112 | line = fin.readline() 113 | seq = '' 114 | 115 | while line and not line.startswith('>'): 116 | seq += line.strip() 117 | line = fin.readline() 118 | 119 | #if not 'U' in seq and not 'X' in seq and not 'Z' in seq and not 'B' in seq: 120 | if not species_name in all_proteins: 121 | all_proteins[species_name] = {} 122 | 123 | if not primary_uniprot_id in all_proteins[species_name]: 124 | all_proteins[species_name][primary_uniprot_id] = [[uniprot_id, seq, seq.replace('L', 'I')]] 125 | else: 126 | all_proteins[species_name][primary_uniprot_id].append([uniprot_id, seq, seq.replace('L', 'I')]) 127 | 128 | if not uniprot_id in protein_info: 129 | protein_info[uniprot_id] = [species_name, primary_uniprot_id, protein_name] 130 | else: 131 | line = fin.readline() 132 | 133 | def seq_to_mass(sequence): 134 | return [aa_mass[x] for x in sequence] 135 | 136 | ################################### 137 | all_proteins_in_masses = {} 138 | 139 | for sp in all_proteins: 140 | all_proteins_in_masses[sp] = {} 141 | 142 | for pri_id in all_proteins[sp]: 143 | all_proteins_in_masses[sp][pri_id] = [] 144 | 145 | for entry in all_proteins[sp][pri_id]: 146 | all_proteins_in_masses[sp][pri_id].append(seq_to_mass(entry[1])) 147 | 148 | def get_signature(peptide): ## the regular expression will ignore the first mass-tag 149 | content = re.split(mass_tag_re, peptide) 150 | 151 | if content[0] == '': 152 | content = content[1:] 153 | 154 | sig = [] 155 | sig_case = [] 156 | sig_rev = [] 157 | sig_rev_case = [] 158 | seed = ['', -1] 159 | seed_case = ['', -1] 160 | tag = [] 161 | sig_flag = [] 162 | 163 | for i in range(len(content)): 164 | if content[i].startswith('('): 165 | mass = float(content[i][1:-1]) 166 | max_length = math.ceil(mass / min_aa_mass) 167 | min_length = math.ceil(mass / max_aa_mass) 168 | 169 | temp = [mass * (1 - pep_mass_tol_abs), mass * (1 + pep_mass_tol_abs)] 170 | sig.append(temp) 171 | sig_case.append(temp) 172 | sig_rev.append(temp) 173 | sig_rev_case.append(temp) 174 | sig_flag.append(False) 175 | elif not content[i] == '': 176 | sig.append(content[i].upper()) 177 | sig_case.append(content[i]) 178 | sig_rev.append(content[i][::-1].upper()) 179 | sig_rev_case.append(content[i][::-1]) 180 | sig_flag.append(True) 181 | 182 | if len(content[i]) > len(seed_case[0]): 183 | seed_case[0] = content[i] 184 | seed_case[1] = i 185 | 186 | if len(content[i]) >= mass_tag_cutoff: 187 | tag.append(content[i].upper()) 188 | 189 | seed[0] = seed_case[0].upper() 190 | seed[1] = seed_case[1] 191 | 192 | max_prefix_len = 0 193 | 194 | for i in range(seed[1]): 195 | if isinstance(sig[i], str): 196 | max_prefix_len += len(sig[i]) 197 | else: 198 | max_prefix_len += math.ceil(sig[i][1] / min_aa_mass) 199 | 200 | max_suffix_len = 0 201 | 202 | for i in range(seed[1] + 1, len(sig)): 203 | if isinstance(sig[i], str): 204 | max_suffix_len += len(sig[i]) 205 | else: 206 | max_suffix_len += math.ceil(sig[i][1] / min_aa_mass) 207 | 208 | #print(sig, sig_rev, tag, seed, max_prefix_len, max_suffix_len) 209 | return sig, sig_case, sig_rev, sig_rev_case, sig_flag, tag, seed, seed_case, max_prefix_len, max_suffix_len 210 | 211 | ## use lower case from pep but L from prot 212 | def merge_seq_info(pep, prot): 213 | template = list(prot) 214 | 215 | for i in range(len(pep)): 216 | if pep[i] in mod_aas_lower: 217 | template[i] = pep[i] 218 | 219 | return ''.join(template) 220 | 221 | ## compare candidate protein section against peptide signatures (mass or seq) 222 | ## pep_sig_flag is True for string, False for mass tag 223 | def search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, prot_seq, prot_seq_noIL, prot_mass, current_prefix, mass_offset, current_pep_pos, current_prot_pos): 224 | # print('comparing:', pep_sig, current_pep_pos, 'and', prot_seq, current_prot_pos, current_prefix, mass_offset) 225 | if current_pep_pos == len(pep_sig): ## matched until the end 226 | return [current_prefix] 227 | elif current_prot_pos < len(prot_seq): ## there are some protein section left 228 | if pep_sig_flag[current_pep_pos]: ## string matching 229 | if prot_seq_noIL[current_prot_pos:].startswith(pep_sig[current_pep_pos]): ## matched 230 | return search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, prot_seq, prot_seq_noIL, prot_mass, \ 231 | current_prefix + merge_seq_info(pep_sig_case[current_pep_pos], prot_seq[current_prot_pos:(current_prot_pos + len(pep_sig[current_pep_pos]))]), \ 232 | 0, current_pep_pos + 1, current_prot_pos + len(pep_sig[current_pep_pos])) 233 | else: ## mismatched 234 | return None 235 | else: ## mass matching 236 | current_mass = mass_offset 237 | current_index = current_prot_pos 238 | mod_flag = prot_seq[current_index] in mod_aas 239 | 240 | while current_mass < pep_sig[current_pep_pos][0] and current_index < len(prot_seq) - 1 and not mod_flag: ## keep adding more mass 241 | current_mass += prot_mass[current_index] 242 | current_index += 1 243 | mod_flag = prot_seq[current_index] in mod_aas 244 | 245 | if current_mass >= pep_sig[current_pep_pos][0]: ## exceeded the lower bound of mass 246 | if current_mass <= pep_sig[current_pep_pos][1]: ## the right amount of mass was achieved 247 | return search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, prot_seq, prot_seq_noIL, prot_mass, \ 248 | current_prefix + prot_seq[current_prot_pos:current_index], 0, \ 249 | current_pep_pos + 1, current_index) 250 | 251 | elif current_index == len(prot_seq) - 1: ## arrived at the end of protein section, but the mass is still too low 252 | current_mass += prot_mass[current_index] 253 | 254 | if current_mass >= pep_sig[current_pep_pos][0] and current_mass <= pep_sig[current_pep_pos][1]: ## the right amount of mass was achieved 255 | return search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, '', '', [], \ 256 | current_prefix + prot_seq[current_prot_pos:(current_index + 1)], 0, \ 257 | current_pep_pos + 1, current_index + 1) 258 | 259 | if mod_flag: ## the next amino acid can be modified 260 | current_mass += aa_mod_delta_mass[prot_seq[current_index]] ## try adding delta mass 261 | 262 | if current_mass >= pep_sig[current_pep_pos][0] and current_mass <= pep_sig[current_pep_pos][1]: ## the right amount of mass was achieved 263 | return search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, '', '', [], \ 264 | current_prefix + prot_seq[current_prot_pos:(current_index + 1)], 0, \ 265 | current_pep_pos + 1, current_index + 1) 266 | 267 | else: ## must have reached a modifiable position that is not at the end of protein section, mass is also still too low 268 | current_mass += prot_mass[current_index] 269 | 270 | future_nomod = search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, prot_seq, prot_seq_noIL, prot_mass, \ 271 | current_prefix + prot_seq[current_prot_pos:(current_index + 1)], current_mass, \ 272 | current_pep_pos, current_index + 1) 273 | 274 | future_mod = search_hybrid(pep_sig, pep_sig_case, pep_sig_flag, prot_seq, prot_seq_noIL, prot_mass, \ 275 | current_prefix + prot_seq[current_prot_pos:current_index] + prot_seq[current_index].lower(), \ 276 | current_mass + aa_mod_delta_mass[prot_seq[current_index]], \ 277 | current_pep_pos, current_index + 1) 278 | 279 | if future_nomod is None: 280 | return future_mod 281 | elif future_mod is None: 282 | return future_nomod 283 | else: 284 | future_nomod.extend(future_mod) 285 | return future_nomod 286 | 287 | return None ## return None for any situation not caught above 288 | 289 | def search_main(peptide): 290 | return search(peptide, all_proteins, all_proteins_in_masses) 291 | 292 | def search(peptide, proteins, masses): 293 | hits = [] 294 | pep_sig, pep_sig_case, pep_sig_rev, pep_sig_rev_case, pep_sig_flag, pep_tag, pep_seed, pep_seed_case, pep_max_prefix_len, pep_max_suffix_len = get_signature(peptide) 295 | 296 | for sp in proteins: 297 | for pri_id in proteins[sp]: 298 | for i in range(len(proteins[sp][pri_id])): 299 | prot_info = proteins[sp][pri_id][i] 300 | prot_mass = masses[sp][pri_id][i] 301 | 302 | matched_tag_flag = True 303 | 304 | for tag in pep_tag: 305 | if not tag in prot_info[2]: ## compare against no-IL version 306 | matched_tag_flag = False 307 | break 308 | 309 | if matched_tag_flag: ## all tags can be found 310 | start = prot_info[2].find(pep_seed[0], 0) 311 | 312 | while start > -1: ## continue while 'seed' can be found 313 | updated_seed = merge_seq_info(pep_seed_case[0], prot_info[1][start:(start + len(pep_seed[0]))]) 314 | 315 | if pep_seed[1] < len(pep_sig): 316 | L = len(pep_seed[0]) 317 | forward_hit = search_hybrid(pep_sig[(pep_seed[1] + 1):], pep_sig_case[(pep_seed[1] + 1):], \ 318 | pep_sig_flag[(pep_seed[1] + 1):], \ 319 | prot_info[1][(start + L):(start + L + pep_max_suffix_len)], \ 320 | prot_info[2][(start + L):(start + L + pep_max_suffix_len)], \ 321 | prot_mass[(start + L):(start + L + pep_max_suffix_len)], '', 0, 0, 0) 322 | else: 323 | forward_hit = [''] 324 | 325 | if not forward_hit is None: 326 | if pep_seed[1] > 0: 327 | L = start - pep_max_prefix_len - 1 328 | reverse_hit = search_hybrid(pep_sig_rev[(pep_seed[1] - 1)::-1], pep_sig_rev_case[(pep_seed[1] - 1)::-1], 329 | pep_sig_flag[(pep_seed[1] - 1)::-1], \ 330 | prot_info[1][(start - 1):L:-1], prot_info[2][(start - 1):L:-1], \ 331 | prot_mass[(start - 1):L:-1], '', 0, 0, 0) 332 | else: 333 | reverse_hit = [''] 334 | 335 | if not reverse_hit is None: ## success 336 | for prefix_rev in reverse_hit: 337 | prefix = prefix_rev[::-1] 338 | actual_start = start + 1 - len(prefix) 339 | 340 | for suffix in forward_hit: 341 | hits.append([sp, pri_id, prot_info[0], str(actual_start), prefix + updated_seed + suffix]) 342 | 343 | start = prot_info[2].find(pep_seed[0], start + 1) 344 | 345 | return [peptide, hits] 346 | 347 | pool = Pool(processes = num_thread) 348 | #begin = time.time() 349 | map_results = pool.map(search_main, unique_peptides) 350 | 351 | pool.close() 352 | pool.join() 353 | 354 | #print(time.time() - begin) 355 | 356 | with open(os.path.join(input_path, exp_name + '_' + mode_tag + '_' + fdr_tag + '_against_' + db_name + '.tsv'), 'w') as fout: 357 | for res in map_results: 358 | for entry in res[1]: 359 | fout.write(res[0] + '\t' + '\t'.join(entry) + '\n') 360 | -------------------------------------------------------------------------------- /utils_masking/append_decoded_peptides.py: -------------------------------------------------------------------------------- 1 | fname = '../MS002 - Peptide de novo sequencing/peptidome validation/Wu_peptidome_p-mod_fdr5' 2 | suffix = '_vs_human_all_isoform_v4' 3 | peptide_loc = 9 ## column containing peptide sequence in the report 4 | 5 | aa_list = 'GASPVTCLINDQKEMHFRYWmsty' 6 | aa_mass_list = [57.02146,71.03711,87.03203, 7 | 97.05276,99.06841,101.04768, 8 | 103.00919 + 57.02146,113.08406,113.08406, ## C are C mod 9 | 114.04293,115.02694,128.05858, 10 | 128.09496,129.04259,131.04049, 11 | 137.05891,147.06841,156.10111, 12 | 163.06333,186.07931,131.04049 + 15.99491, ## m = M(ox) 13 | 87.03203 + 79.96633,101.04768 + 79.96633,163.06333 + 79.96633] ## s = S(ph), t = T(ph), y = Y(ph) 14 | 15 | proton = 1.007276 16 | 17 | aa_mass = {} 18 | 19 | for i in range(len(aa_list)): 20 | aa_mass[aa_list[i]] = aa_mass_list[i] 21 | 22 | trivial_map = {} 23 | 24 | for i in range(len(aa_list)): 25 | trivial_map['(' + str(aa_mass_list[i]) + ')'] = aa_list[i] 26 | 27 | decode_map = {} 28 | 29 | with open(fname + suffix + '.tsv', 'rt') as fin: 30 | for line in fin.readlines(): 31 | content = line.rstrip('\n').split('\t') 32 | 33 | if not content[0] in decode_map: 34 | decode_map[content[0]] = {} 35 | 36 | if not content[5] in decode_map[content[0]]: 37 | decode_map[content[0]][content[5]] = [] 38 | 39 | decode_map[content[0]][content[5]].append(content[3]) 40 | 41 | mod_seq_map = {} 42 | base_seq_map = {} 43 | 44 | for peptide in decode_map: 45 | candidates = set([x.upper() for x in decode_map[peptide]]) 46 | 47 | if len(candidates) == 1: 48 | base_seq_map[peptide] = ', '.join(candidates) 49 | mod_seq_map[peptide] =', '.join(decode_map[peptide].keys()) 50 | else: 51 | base_seq_map[peptide] = 'Ambiguous' 52 | mod_seq_map[peptide] = 'Ambiguous' 53 | 54 | with open(fname + '.tsv', 'rt') as fin, open(fname + '_extended.tsv', 'w') as fout: 55 | header = fin.readline().strip() 56 | 57 | fout.write(header + '\tBaseSeq\tModSeq\n') # ProteinMapDetails\n') 58 | 59 | for line in fin.readlines(): 60 | fout.write(line.strip()) 61 | content = line.strip().split('\t')[peptide_loc] 62 | 63 | if not content in decode_map: 64 | for tag in trivial_map: 65 | content = content.replace(tag, trivial_map[tag]) 66 | 67 | if not '(' in content: ## no mass tag 68 | if not 'I' in content: ## no I/L ambiguity 69 | fout.write('\t' + content + '\t' + content + '\tUnMapped\n') 70 | else: 71 | fout.write('\tAmbiguous\tAmbiguous\tUnMapped\n') 72 | else: 73 | fout.write('\tAmbiguous\tAmbiguous\tUnMapped\n') 74 | else: 75 | fout.write('\t' + base_seq_map[content] + '\t' + mod_seq_map[content] + '\n') # + str(decode_map[content]) + '\n') 76 | -------------------------------------------------------------------------------- /utils_masking/create_denovo_report.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Sira Sriswasdi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | ### FDR THRESHOLD 17 | fdr_target = 5 ## choices are 5 or 10, right now 18 | min_predicted_aa = 4 19 | min_predicted_frac = 0 20 | model = 'm-mod' 21 | 22 | output_folder = 'CUSB_antibody_May2019_output' 23 | mgf_folder = 'CUSB_antibody_May2019' 24 | 25 | ########################################################## 26 | ### LIBRARY 27 | import os, math 28 | 29 | ########################################################## 30 | ### INTERNAL PARAMETERS AND CONSTANTS 31 | fdr_threshold_map = {'m-mod': {5: 0.83, 10: 0.56}, 32 | 'p-mod': {5: 0.81, 10: 0.57}} 33 | 34 | aa_list = 'GASPVTCLINDQKEMHFRYWmsty' 35 | aa_mass_list = [57.02146,71.03711,87.03203, 36 | 97.05276,99.06841,101.04768, 37 | 103.00919 + 57.02146,113.08406,113.08406, ## C are C mod 38 | 114.04293,115.02694,128.05858, 39 | 128.09496,129.04259,131.04049, 40 | 137.05891,147.06841,156.10111, 41 | 163.06333,186.07931,131.04049 + 15.99491, ## m = M(ox) 42 | 87.03203 + 79.96633,101.04768 + 79.96633,163.06333 + 79.96633] ## s = S(ph), t = T(ph), y = Y(ph) 43 | 44 | aa_mass = {} 45 | for i in range(len(aa_list)): 46 | aa_mass[aa_list[i]] = aa_mass_list[i] 47 | 48 | proton = 1.007276 49 | water = 18.010565 50 | 51 | ########################################################## 52 | ### MAIN SCRIPT 53 | def main(output_folder, mgf_folder, model): 54 | rescore_threshold = fdr_threshold_map[model][fdr_target] 55 | predictions = {} 56 | 57 | spectra_count = 0 58 | predicted_seq_count = 0 59 | predicted_full_seq_count = 0 60 | predicted_aa_count = 0 61 | predicted_mask_count = 0 62 | 63 | for f in os.listdir(output_folder): 64 | if f.endswith('_rescore'): 65 | fname = f[:-8] 66 | predictions[fname] = {} 67 | 68 | with open(os.path.join(output_folder, fname), 'rt') as seq_in, open(os.path.join(output_folder, f), 'rt') as score_in: 69 | seq_line = seq_in.readline() 70 | score_line = score_in.readline() 71 | current_id = 0 ## use line number to map between MGF and OUTPUT 72 | 73 | while seq_line and score_line: 74 | seq_list = seq_line.strip().split() 75 | score_list = score_line.strip().split() 76 | 77 | if not '' in seq_list and not '' in seq_list and len(seq_list) == len(score_list): ## valid prediction 78 | score_list = [math.exp(float(x)) for x in score_list] 79 | mask_list = ['Y'] * len(seq_list) 80 | mask_count = 0 81 | 82 | for i in range(len(score_list)): 83 | if score_list[i] < rescore_threshold: 84 | mask_list[i] = 'N' 85 | mask_count += 1 86 | 87 | temp = str(score_list[i]).split('.') 88 | score_list[i] = temp[0] + '.' + temp[1][:2] 89 | 90 | ## count as predicted only if all thresholds are satisfied 91 | if len(seq_list) - mask_count >= min_predicted_aa and (len(seq_list) - mask_count) / len(seq_list) >= min_predicted_frac: 92 | predicted_seq_count += 1 93 | predicted_aa_count += len(seq_list) 94 | predicted_mask_count += mask_count 95 | 96 | predicted_seq = '' 97 | total_mass = 0 98 | unknown_mass = 0 99 | 100 | for i in range(len(score_list)): 101 | total_mass += aa_mass[seq_list[i]] 102 | 103 | if mask_list[i] == 'Y': 104 | if unknown_mass > 0: ## preceded by masked positions 105 | temp = str(unknown_mass).split('.') 106 | predicted_seq += '(' + temp[0] + '.' + temp[1][:min(5, len(temp[1]))] + ')' 107 | unknown_mass = 0 108 | 109 | predicted_seq += seq_list[i] 110 | else: 111 | unknown_mass += aa_mass[seq_list[i]] 112 | 113 | if unknown_mass > 0: ## ends with unknown mass 114 | temp = str(unknown_mass).split('.') 115 | predicted_seq += '(' + temp[0] + '.' + temp[1][:min(5, len(temp[1]))] + ')' 116 | 117 | if mask_count == 0: 118 | predicted_full_seq_count += 1 119 | 120 | theoretical_mhp = str(total_mass + proton + water) 121 | predictions[fname][current_id] = [predicted_seq, ';'.join([str(x) for x in score_list]), theoretical_mhp] 122 | 123 | current_id += 1 ## update ID 124 | seq_line = seq_in.readline() 125 | score_line = score_in.readline() 126 | 127 | spectra_count += current_id 128 | 129 | print('total spectra: ', spectra_count) 130 | print('predicted sequences: ', predicted_seq_count) 131 | print('predicted full sequences: ', predicted_full_seq_count) 132 | print('amino acids in predicted sequences:', predicted_aa_count) 133 | print('amino acids after masking: ', predicted_aa_count - predicted_mask_count) 134 | print('masks: ', predicted_mask_count) 135 | 136 | for f in os.listdir(mgf_folder): 137 | if f.endswith('.mgf'): 138 | fname = f[:-4] 139 | 140 | if fname in predictions: 141 | with open(os.path.join(mgf_folder, f), 'rt') as fin: 142 | current_id = 0 143 | line = fin.readline() 144 | 145 | while line: 146 | if line.startswith('BEGIN'): 147 | if current_id in predictions[fname]: 148 | scan_num = 'UNK' 149 | charge = 'UNK' 150 | ret_time = 'UNK' 151 | precursor_mass = 'UNK' 152 | precursor_mhp = 'UNK' 153 | precursor_int = 'UNK' 154 | mass_error = 'UNK' 155 | line = fin.readline() 156 | 157 | while not line.startswith('END'): 158 | if line.startswith('TITLE'): 159 | scan_num = line.strip().split('scan=')[1].replace('"', '') 160 | 161 | elif line.startswith('RTINSECONDS'): 162 | ret_time = str(float(line.strip().split('=')[1]) / 60.0) 163 | elif line.startswith('PEPMASS'): 164 | content = line.strip().split('=')[1].split(' ') 165 | 166 | if len(content) == 2: 167 | precursor_mass = content[0] 168 | precursor_int = content[1] 169 | elif len(content) == 1: 170 | precursor_mass = content[0] 171 | 172 | elif line.startswith('CHARGE'): 173 | charge = line.strip().split('=')[1][:-1] 174 | 175 | try: 176 | precursor_mhp = float(precursor_mass) * float(charge) - (float(charge) - 1) * proton 177 | mass_error = str((precursor_mhp - float(predictions[fname][current_id][2])) * 1000000.0 / float(predictions[fname][current_id][2])) 178 | precursor_mhp = str(precursor_mhp) 179 | except: 180 | pass 181 | 182 | line = fin.readline() 183 | 184 | predictions[fname][current_id].extend([fname, scan_num, charge, ret_time, precursor_mass, precursor_mhp, mass_error, precursor_int]) 185 | 186 | current_id += 1 187 | 188 | line = fin.readline() 189 | 190 | with open('_'.join([mgf_folder, model, 'fdr' + str(fdr_target)]) + '.tsv', 'w') as fout: 191 | fout.write('\t'.join(['MS File', 'ScanNum', 'Charge', 'RT(min)', 'ObservedM/Z', 'ObservedM+H', 'TheoreticalM+H', 'MassError(ppm)', 'ObservedInt', 'Prediction', 'Scores']) + '\n') 192 | 193 | for fname in predictions: 194 | for spectrum in predictions[fname]: 195 | if not len(predictions[fname][spectrum]) == 11: 196 | print(fname, spectrum, predictions[fname][spectrum]) 197 | else: 198 | fout.write('\t'.join([predictions[fname][spectrum][i] for i in [3, 4, 5, 6, 7, 8, 2, 9, 10, 0, 1]]) + '\n') 199 | 200 | 201 | if __name__ == "__main__": 202 | main(sys.argv[1:]) 203 | --------------------------------------------------------------------------------