├── LICENSE ├── README.md ├── bert ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── create_pretraining_data.py ├── extract_features.py ├── modeling.py ├── modeling_test.py ├── multilingual.md ├── optimization.py ├── optimization_test.py ├── predicting_movie_reviews_with_bert_on_tf_hub.ipynb ├── requirements.txt ├── run_classifier.py ├── run_classifier_with_tfhub.py ├── run_pretraining.py ├── run_squad.py ├── sample_text.txt ├── tokenization.py └── tokenization_test.py ├── calculate_model_score.py ├── calculating_model_score ├── calculate_atis_intent.py ├── calculate_atis_slot.py ├── calculate_model_score.py ├── calculate_snips_intent_and_slot.py ├── calculate_snips_intent_and_slot_new.py ├── calculate_snips_slot.py ├── calculate_snpis_intent.py ├── log.txt ├── sklearn_metrics_function.py └── tf_metrics.py ├── data ├── CoNLL2003_NER │ ├── conll03_raw_data_to_stand_file.py │ ├── dev.txt │ ├── test.txt │ ├── test │ │ ├── seq.in │ │ └── seq.out │ ├── train.txt │ ├── train │ │ ├── seq.in │ │ └── seq.out │ └── valid │ │ ├── seq.in │ │ └── seq.out ├── atis_Intent_Detection_and_Slot_Filling │ ├── test │ │ ├── label │ │ ├── seq.in │ │ └── seq.out │ ├── train │ │ ├── check_train_raw_data.py │ │ ├── label │ │ ├── seq.in │ │ └── seq.out │ └── valid │ │ ├── label │ │ ├── seq.in │ │ └── seq.out └── snips_Intent_Detection_and_Slot_Filling │ ├── test │ ├── label │ ├── seq.in │ └── seq.out │ ├── train │ ├── label │ ├── seq.in │ └── seq.out │ └── valid │ ├── label │ ├── seq.in │ └── seq.out ├── output_model_prediction ├── atis_join_task_LSTM_epoch30_ckpt4198 │ ├── .ipynb_checkpoints │ │ └── model_score_log-checkpoint.txt │ ├── intent_label2id.pkl │ ├── intent_prediction_test_results.txt │ ├── model_score_log.txt │ ├── predict.tf_record │ ├── slot_filling_test_results.txt │ └── slot_label2id.pkl ├── atis_join_task_epoch30_ckpt4198 │ ├── intent_label2id.pkl │ ├── intent_prediction_test_results.txt │ ├── model_score_log.txt │ ├── predict.tf_record │ ├── slot_filling_test_results.txt │ └── slot_label2id.pkl ├── conll2003ner_epoch3_test653ckpt │ ├── label2id.pkl │ ├── predict.tf_record │ ├── slot_filling_test_results.txt │ └── token_test.txt ├── model_score_summarization.txt ├── score_summarization.py └── snips_join_task_epoch10_test4088ckpt │ ├── intent_label2id.pkl │ ├── intent_prediction_test_results.txt │ ├── model_score_log.txt │ ├── predict.tf_record │ ├── slot_filling_test_results.txt │ └── slot_label2id.pkl ├── predefined_task_usage.md ├── pretrained_model └── uncased_L-12_H-768_A-12 │ ├── bert_config.json │ └── vocab.txt ├── requirements.txt ├── run_sequence_labeling.py ├── run_sequence_labeling_and_text_classification.py ├── run_slot_intent_join_task_LSTM.py ├── run_text_classification.py └── store_fine_tuned_model └── download_url.md /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Template Code: BERT-for-Sequence-Labeling-and-Text-Classification 2 | This repository provides template code for using BERT for sequence labeling and text classification, making it easy to apply BERT to new tasks. The code has been tested on the SNIPS (intent detection and slot filling), ATIS (intent detection and slot filling), and CoNLL-2003 (named entity recognition) datasets. You are welcome to use this BERT template to solve more NLP tasks and then share your results and code here. 3 | 4 | 5 | 6 | ![](https://yuanxiaosc.github.io/2019/03/18/%E6%A7%BD%E5%A1%AB%E5%85%85%E5%92%8C%E6%84%8F%E5%9B%BE%E8%AF%86%E5%88%AB%E4%BB%BB%E5%8A%A1%E7%9A%84%E5%9F%BA%E6%9C%AC%E6%A6%82%E5%BF%B5/1.png) 7 | 8 | ## Task and Dataset 9 | The data has already been downloaded for you. You are welcome to add new datasets. 10 | 11 | |task name|dataset name|data source| 12 | |-|-|-| 13 | |CoNLL-2003 named entity recognition|conll2003ner|https://www.clips.uantwerpen.be/conll2003/ner/ | 14 | |Atis Joint Slot Filling and Intent Prediction|atis|https://github.com/MiuLab/SlotGated-SLU/tree/master/data/atis | 15 | |Snips Joint Slot Filling and Intent Prediction|snips|https://github.com/MiuLab/SlotGated-SLU/tree/master/data/snips | 16 | 17 | 18 | ## Environment Requirements 19 | Use `pip install -r requirements.txt` to install dependencies quickly. 20 | + python 3.6+ 21 | + TensorFlow 1.12.0+ 22 | + sklearn 23 | 24 | ## Template Code Usage 25 | 26 | ### Using a pre-trained, fine-tuned model directly 27 | > For example: Atis Joint Slot Filling and Intent Prediction 28 | 29 | 1. Download the model weights [atis_join_task_LSTM_epoch30_simple.zip](https://pan.baidu.com/s/1SZkQXP8NrOtZKVEMfDE4bw) and unzip them into the `store_fine_tuned_model` directory; 30 | 2. Run the code! You can change `task_name` and `output_dir`. 31 | ```bash 32 | python run_slot_intent_join_task_LSTM.py \ 33 | --task_name=Atis \ 34 | --do_predict=true \ 35 | --data_dir=data/atis_Intent_Detection_and_Slot_Filling \ 36 | --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \ 37 | --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \ 38 | --init_checkpoint=store_fine_tuned_model/atis_join_task_LSTM_epoch30_simple/model.ckpt-4198 \ 39 | --max_seq_length=128 \ 40 | --output_dir=./output_model_predict/atis_join_task_LSTM_epoch30_simple_ckpt4198 41 | ``` 42 | 43 | You can find the model's prediction files and their scores in `output_dir` (see the Model Scores section below). 44 | 45 | 46 | ### Quick start (model training and prediction) 47 | > See [predefined_task_usage.md](predefined_task_usage.md) for more predefined task usage examples. 48 | 49 | 1.
Copy Google's [BERT code](https://github.com/google-research/bert) into the `bert` directory (a copy has already been prepared for you); 50 | 2. Download Google's [pre-trained BERT model](https://github.com/google-research/bert) and unzip it into the `pretrained_model` directory; 51 | 3. Run the code! You can change `task_name` and `output_dir`. 52 | 53 | **model training** 54 | ```bash 55 | python run_sequence_labeling_and_text_classification.py \ 56 | --task_name=snips \ 57 | --do_train=true \ 58 | --do_eval=true \ 59 | --data_dir=data/snips_Intent_Detection_and_Slot_Filling \ 60 | --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \ 61 | --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \ 62 | --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \ 63 | --num_train_epochs=3.0 \ 64 | --output_dir=./store_fine_tuned_model/snips_join_task_epoch3/ 65 | ``` 66 | 67 | Then you can find the fine-tuned model in the `output_dir=./store_fine_tuned_model/snips_join_task_epoch3/` folder. 68 | 69 | 70 | **model prediction** 71 | ```bash 72 | python run_sequence_labeling_and_text_classification.py \ 73 | --task_name=Snips \ 74 | --do_predict=true \ 75 | --data_dir=data/snips_Intent_Detection_and_Slot_Filling \ 76 | --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \ 77 | --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \ 78 | --init_checkpoint=output_model/snips_join_task_epoch3/model.ckpt-1000 \ 79 | --max_seq_length=128 \ 80 | --output_dir=./output_model_prediction/snips_join_task_epoch3_ckpt1000 81 | ``` 82 | 83 | Then you can find the model's predicted output and the test results (accuracy, recall, F1 score, etc.) in the `output_dir=./output_model_prediction/snips_join_task_epoch3_ckpt1000` folder. 84 | 85 | 86 | ## File Structure 87 | 88 | |name|function| 89 | |-|-| 90 | | bert | stores Google's [BERT code](https://github.com/google-research/bert) | 91 | | data | stores the raw task datasets | 92 | | output_model_prediction | stores model predictions | 93 | | store_fine_tuned_model | stores fine-tuned models | 94 | | calculating_model_score | stores scripts for calculating model scores | 95 | | pretrained_model | stores the pre-trained [BERT model](https://github.com/google-research/bert) | 96 | | run_sequence_labeling.py | for sequence labeling tasks | 97 | | run_text_classification.py | for text classification tasks | 98 | | run_sequence_labeling_and_text_classification.py | for joint tasks | 99 | | calculate_model_score.py | for model evaluation | 100 | 101 | 102 | ## Model Scores 103 | 104 | **The following scores were obtained without careful tuning of the model parameters; that is to say, they can still be improved!** 105 | 106 | ### CoNLL-2003 named entity recognition 107 | eval_f = 0.926 108 | eval_precision = 0.925 109 | eval_recall = 0.928 110 | 111 | ### Atis Joint Slot Filling and Intent Prediction 112 | Intent Prediction 113 | Accuracy: 0.976 114 | Precision: 0.976 115 | Recall: 0.976 116 | F1-score: 0.976 117 | 118 | Slot Filling 119 | Accuracy: 0.955 120 | Precision: 0.955 121 | Recall: 0.955 122 | F1-score: 0.955 123 | 124 | ## How to add a new task 125 | 126 | Just write a small piece of code following the existing template! 127 | 128 | ### Data 129 | For example, suppose you have a new classification task: [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs).
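For reference, the `QqpProcessor` in the Code section below assumes each row of `train.tsv`/`dev.tsv` has exactly six tab-separated fields, with the two questions at (0-indexed) columns 3 and 4 and the binary label at column 5; rows of any other width are skipped. A minimal sketch of one parsed row (the values here are hypothetical, following the usual QQP column layout of id, qid1, qid2, question1, question2, is_duplicate):

```python
# One row of train.tsv after _read_tsv splits it on tabs (hypothetical values):
line = ["1", "101", "102",
        "How do I learn Python?",
        "What is the best way to learn Python?",
        "1"]
text_a = line[3]  # first question
text_b = line[4]  # second question
label = line[5]   # "1" = duplicate pair, "0" = not a duplicate
```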
130 | 131 | Before running this example you must download the [GLUE data](https://gluebenchmark.com/tasks) by running [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e). 132 | 133 | ### Code 134 | Now write the code: 135 | 136 | ```python 137 | class QqpProcessor(DataProcessor): 138 | """Processor for the QQP data set.""" 139 | 140 | def get_train_examples(self, data_dir): 141 | """See base class.""" 142 | return self._create_examples( 143 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 144 | 145 | def get_dev_examples(self, data_dir): 146 | """See base class.""" 147 | return self._create_examples( 148 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 149 | 150 | def get_test_examples(self, data_dir): 151 | """See base class.""" 152 | return self._create_examples( 153 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 154 | 155 | def get_labels(self): 156 | """See base class.""" 157 | return ["0", "1"] 158 | 159 | def _create_examples(self, lines, set_type): 160 | """Creates examples for the training, dev, and test sets.""" 161 | examples = [] 162 | for (i, line) in enumerate(lines): 163 | if i == 0 or len(line) != 6:  # skip the header row and malformed rows 164 | continue 165 | guid = "%s-%s" % (set_type, i) 166 | text_a = tokenization.convert_to_unicode(line[3]) 167 | text_b = tokenization.convert_to_unicode(line[4]) 168 | if set_type == "test": 169 | label = "1"  # dummy label; test labels are not provided 170 | else: 171 | label = tokenization.convert_to_unicode(line[5]) 172 | examples.append( 173 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 174 | return examples 175 | ``` 176 | 177 | Then register the task in `main`: 178 | 179 | ```python 180 | def main(_): 181 | tf.logging.set_verbosity(tf.logging.INFO) 182 | processors = { 183 | "qqp": QqpProcessor, 184 | } 185 | ``` 186 | 187 | ### Run 188 | ```bash 189 | python run_text_classification.py \ 190 | --task_name=qqp \ 191 | --do_train=true \ 192 | --do_eval=true \ 193 | --data_dir=glue_data/QQP \ 194 | --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \ 195 | --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \ 196 | --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \ 197 | --max_seq_length=128 \ 198 | --train_batch_size=32 \ 199 | --learning_rate=2e-5 \ 200 | --num_train_epochs=3.0 \ 201 | --output_dir=./output/qqp/ 202 | ``` 203 | -------------------------------------------------------------------------------- /bert/.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | -------------------------------------------------------------------------------- /bert/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributions, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to <https://cla.developers.google.com/> to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /bert/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, 
self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | "^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | for input_name in 
op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 | 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /bert/multilingual.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | There are two multilingual models currently available. We do not plan to release 4 | more single-language models, but we may release `BERT-Large` versions of these 5 | two in the future: 6 | 7 | * **[`BERT-Base, Multilingual Cased (New, recommended)`](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip)**: 8 | 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 9 | * **[`BERT-Base, Multilingual Uncased (Orig, not recommended)`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**: 10 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 11 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: 12 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M 13 | parameters 14 | 15 | **The `Multilingual Cased (New)` model also fixes normalization issues in many 16 | languages, so it is recommended in languages with non-Latin alphabets (and is 17 | often better for most languages with Latin alphabets). When using this model, 18 | make sure to pass `--do_lower_case=false` to `run_pretraining.py` and other 19 | scripts.** 20 | 21 | See the [list of languages](#list-of-languages) that the Multilingual model 22 | supports. The Multilingual model does include Chinese (and English), but if your 23 | fine-tuning data is Chinese-only, then the Chinese model will likely produce 24 | better results. 25 | 26 | ## Results 27 | 28 | To evaluate these systems, we use the 29 | [XNLI dataset](https://github.com/facebookresearch/XNLI), which is a 30 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the 31 | dev and test sets have been translated (by humans) into 15 languages. Note that 32 | the training set was *machine* translated (we used the translations provided by 33 | XNLI, not Google NMT).
For clarity, we only report on 6 languages below: 34 | 35 | 36 | 37 | | System | English | Chinese | Spanish | German | Arabic | Urdu | 38 | | --------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | 39 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 | 40 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 | 41 | | BERT - Translate Train Cased | **81.9** | **76.6** | **77.8** | **75.9** | **70.7** | 61.6 | 42 | | BERT - Translate Train Uncased | 81.4 | 74.2 | 77.3 | 75.2 | 70.5 | 61.7 | 43 | | BERT - Translate Test Uncased | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** | 44 | | BERT - Zero Shot Uncased | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 | 45 | 46 | 47 | 48 | The first two rows are baselines from the XNLI paper and the last four rows are 49 | our results with BERT. 50 | 51 | **Translate Train** means that the MultiNLI training set was machine translated 52 | from English into the foreign language. So training and evaluation were both 53 | done in the foreign language. Unfortunately, training was done on 54 | machine-translated data, so it is impossible to quantify how much of the lower 55 | accuracy (compared to English) is due to the quality of the machine translation 56 | vs. the quality of the pre-trained model. 57 | 58 | **Translate Test** means that the XNLI test set was machine translated from the 59 | foreign language into English. So training and evaluation were both done on 60 | English. However, test evaluation was done on machine-translated English, so the 61 | accuracy depends on the quality of the machine translation system. 62 | 63 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English 64 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case, 65 | machine translation was not involved at all in either the pre-training or 66 | fine-tuning. 67 | 68 | Note that the English result is worse than the 84.2 MultiNLI baseline because 69 | this training used Multilingual BERT rather than English-only BERT. This implies 70 | that for high-resource languages, the Multilingual model is somewhat worse than 71 | a single-language model. However, it is not feasible for us to train and 72 | maintain dozens of single-language models. Therefore, if your goal is to maximize 73 | performance with a language other than English or Chinese, you might find it 74 | beneficial to run pre-training for additional steps starting from our 75 | Multilingual model on data from your language of interest. 76 | 77 | Here is a comparison of training Chinese models with the Multilingual 78 | `BERT-Base` and Chinese-only `BERT-Base`: 79 | 80 | System | Chinese 81 | ----------------------- | ------- 82 | XNLI Baseline | 67.0 83 | BERT Multilingual Model | 74.2 84 | BERT Chinese-only Model | 77.2 85 | 86 | Similar to English, the single-language model does 3% better than the 87 | Multilingual model. 88 | 89 | ## Fine-tuning Example 90 | 91 | The multilingual model does **not** require any special consideration or API 92 | changes. We did update the implementation of `BasicTokenizer` in 93 | `tokenization.py` to support Chinese character tokenization, so please update if 94 | you forked it. However, we did not change the tokenization API. 95 | 96 | To test the new models, we did modify `run_classifier.py` to add support for the 97 | [XNLI dataset](https://github.com/facebookresearch/XNLI).
This is a 15-language 98 | version of MultiNLI where the dev/test sets have been human-translated, and the 99 | training set has been machine-translated. 100 | 101 | To run the fine-tuning code, please download the 102 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the 103 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip) 104 | and then unpack both .zip files into some directory `$XNLI_DIR`. 105 | 106 | Note that the language is hard-coded into `run_classifier.py` 107 | (Chinese by default), so please modify `XnliProcessor` if you want to run on 108 | another language. 109 | 110 | This is a large dataset, so training will take a few hours on a GPU 111 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for 112 | debugging, just set `num_train_epochs` to a small value like `0.1`. 113 | 114 | ```shell 115 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12 116 | export XNLI_DIR=/path/to/xnli 117 | 118 | python run_classifier.py \ 119 | --task_name=XNLI \ 120 | --do_train=true \ 121 | --do_eval=true \ 122 | --data_dir=$XNLI_DIR \ 123 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 124 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 125 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 126 | --max_seq_length=128 \ 127 | --train_batch_size=32 \ 128 | --learning_rate=5e-5 \ 129 | --num_train_epochs=2.0 \ 130 | --output_dir=/tmp/xnli_output/ 131 | ``` 132 | 133 | With the Chinese-only model, the results should look something like this: 134 | 135 | ``` 136 | ***** Eval results ***** 137 | eval_accuracy = 0.774116 138 | eval_loss = 0.83554 139 | global_step = 24543 140 | loss = 0.74603 141 | ``` 142 | 143 | ## Details 144 | 145 | ### Data Source and Sampling 146 | 147 | The languages chosen were the 148 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias). 149 | The entire Wikipedia dump for each language (excluding user and talk pages) was 150 | taken as the training data for each language. 151 | 152 | However, the size of the Wikipedia for a given language varies greatly, and 153 | therefore low-resource languages may be "under-represented" in terms of the 154 | neural network model (under the assumption that languages are "competing" for 155 | limited model capacity to some extent). 156 | 157 | However, the size of a Wikipedia also correlates with the number of speakers of 158 | a language, and we also don't want to overfit the model by performing thousands 159 | of epochs over a tiny Wikipedia for a particular language. 160 | 161 | To balance these two factors, we performed exponentially smoothed weighting of 162 | the data during pre-training data creation (and WordPiece vocab creation). In 163 | other words, let's say that the probability of a language is *P(L)*, e.g., 164 | *P(English) = 0.21* means that after concatenating all of the Wikipedias 165 | together, 21% of our data is English. We exponentiate each probability by some 166 | factor *S* and then re-normalize, and sample from that distribution. In our case 167 | we use *S=0.7*. So, high-resource languages like English will be under-sampled, 168 | and low-resource languages like Icelandic will be over-sampled. E.g., in the 169 | original distribution English would be sampled 1000x more than Icelandic, but 170 | after smoothing it's only sampled 100x more.
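As a concrete illustration, here is a minimal sketch of that smoothing step in plain Python (the corpus sizes are made up, and this is not the actual pre-training data pipeline):

```python
def smoothed_probs(corpus_sizes, s=0.7):
  """Exponentiate each language's probability by s, then re-normalize."""
  total = float(sum(corpus_sizes.values()))
  probs = {lang: size / total for lang, size in corpus_sizes.items()}
  powered = {lang: p ** s for lang, p in probs.items()}
  norm = sum(powered.values())
  return {lang: p / norm for lang, p in powered.items()}

# Hypothetical sizes: an English corpus 1000x larger than an Icelandic one.
smoothed = smoothed_probs({"en": 1000000, "is": 1000})
print(smoothed["en"] / smoothed["is"])  # 1000**0.7 ~= 126, roughly the "100x" above
```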
171 | 172 | ### Tokenization 173 | 174 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are 175 | weighted the same way as the data, so low-resource languages are upweighted by 176 | some factor. We intentionally do *not* use any marker to denote the input 177 | language (so that zero-shot training can work). 178 | 179 | Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace 180 | characters, we add spaces around every character in the 181 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\)) 182 | before applying WordPiece. This means that Chinese is effectively 183 | character-tokenized. Note that the CJK Unicode block only includes 184 | Chinese-origin characters and does *not* include Hangul Korean or 185 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like 186 | all other languages. 187 | 188 | For all other languages, we apply the 189 | [same recipe as English](https://github.com/google-research/bert#tokenization): 190 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace 191 | tokenization. We understand that accent markers have substantial meaning in some 192 | languages, but felt that the benefits of reducing the effective vocabulary make 193 | up for this. Generally the strong contextual models of BERT should make up for 194 | any ambiguity introduced by stripping accent markers. 195 | 196 | ### List of Languages 197 | 198 | The multilingual model supports the following languages. These languages were 199 | chosen because they are the top 100 languages with the largest Wikipedias: 200 | 201 | * Afrikaans 202 | * Albanian 203 | * Arabic 204 | * Aragonese 205 | * Armenian 206 | * Asturian 207 | * Azerbaijani 208 | * Bashkir 209 | * Basque 210 | * Bavarian 211 | * Belarusian 212 | * Bengali 213 | * Bishnupriya Manipuri 214 | * Bosnian 215 | * Breton 216 | * Bulgarian 217 | * Burmese 218 | * Catalan 219 | * Cebuano 220 | * Chechen 221 | * Chinese (Simplified) 222 | * Chinese (Traditional) 223 | * Chuvash 224 | * Croatian 225 | * Czech 226 | * Danish 227 | * Dutch 228 | * English 229 | * Estonian 230 | * Finnish 231 | * French 232 | * Galician 233 | * Georgian 234 | * German 235 | * Greek 236 | * Gujarati 237 | * Haitian 238 | * Hebrew 239 | * Hindi 240 | * Hungarian 241 | * Icelandic 242 | * Ido 243 | * Indonesian 244 | * Irish 245 | * Italian 246 | * Japanese 247 | * Javanese 248 | * Kannada 249 | * Kazakh 250 | * Kirghiz 251 | * Korean 252 | * Latin 253 | * Latvian 254 | * Lithuanian 255 | * Lombard 256 | * Low Saxon 257 | * Luxembourgish 258 | * Macedonian 259 | * Malagasy 260 | * Malay 261 | * Malayalam 262 | * Marathi 263 | * Minangkabau 264 | * Nepali 265 | * Newar 266 | * Norwegian (Bokmal) 267 | * Norwegian (Nynorsk) 268 | * Occitan 269 | * Persian (Farsi) 270 | * Piedmontese 271 | * Polish 272 | * Portuguese 273 | * Punjabi 274 | * Romanian 275 | * Russian 276 | * Scots 277 | * Serbian 278 | * Serbo-Croatian 279 | * Sicilian 280 | * Slovak 281 | * Slovenian 282 | * South Azerbaijani 283 | * Spanish 284 | * Sundanese 285 | * Swahili 286 | * Swedish 287 | * Tagalog 288 | * Tajik 289 | * Tamil 290 | * Tatar 291 | * Telugu 292 | * Turkish 293 | * Ukrainian 294 | * Urdu 295 | * Uzbek 296 | * Vietnamese 297 | * Volapük 298 | * Waray-Waray 299 | * Welsh 300 | * West Frisian 301 | * Western Punjabi 302 | * Yoruba 303 | 304 | The **Multilingual Cased (New)** release contains additionally **Thai** and 305 | **Mongolian**, which were not 
included in the original release. 306 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
82 |   new_global_step = global_step + 1
83 |   train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 |   return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 |   """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 |   def __init__(self,
91 |                learning_rate,
92 |                weight_decay_rate=0.0,
93 |                beta_1=0.9,
94 |                beta_2=0.999,
95 |                epsilon=1e-6,
96 |                exclude_from_weight_decay=None,
97 |                name="AdamWeightDecayOptimizer"):
98 |     """Constructs an AdamWeightDecayOptimizer."""
99 |     super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 |     self.learning_rate = learning_rate
102 |     self.weight_decay_rate = weight_decay_rate
103 |     self.beta_1 = beta_1
104 |     self.beta_2 = beta_2
105 |     self.epsilon = epsilon
106 |     self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 |   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 |     """See base class."""
110 |     assignments = []
111 |     for (grad, param) in grads_and_vars:
112 |       if grad is None or param is None:
113 |         continue
114 |
115 |       param_name = self._get_variable_name(param.name)
116 |
117 |       m = tf.get_variable(
118 |           name=param_name + "/adam_m",
119 |           shape=param.shape.as_list(),
120 |           dtype=tf.float32,
121 |           trainable=False,
122 |           initializer=tf.zeros_initializer())
123 |       v = tf.get_variable(
124 |           name=param_name + "/adam_v",
125 |           shape=param.shape.as_list(),
126 |           dtype=tf.float32,
127 |           trainable=False,
128 |           initializer=tf.zeros_initializer())
129 |
130 |       # Standard Adam update.
131 |       next_m = (
132 |           tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 |       next_v = (
134 |           tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 |                                                     tf.square(grad)))
136 |
137 |       update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 |       # Just adding the square of the weights to the loss function is *not*
140 |       # the correct way of using L2 regularization/weight decay with Adam,
141 |       # since that will interact with the m and v parameters in strange ways.
142 |       #
143 |       # Instead we want to decay the weights in a manner that doesn't interact
144 |       # with the m/v parameters. This is equivalent to adding the square
145 |       # of the weights to the loss with plain (non-momentum) SGD.
146 |       if self._do_use_weight_decay(param_name):
147 |         update += self.weight_decay_rate * param
148 |
149 |       update_with_lr = self.learning_rate * update
150 |
151 |       next_param = param - update_with_lr
152 |
153 |       assignments.extend(
154 |           [param.assign(next_param),
155 |            m.assign(next_m),
156 |            v.assign(next_v)])
157 |     return tf.group(*assignments, name=name)
158 |
159 |   def _do_use_weight_decay(self, param_name):
160 |     """Whether to use L2 weight decay for `param_name`."""
161 |     if not self.weight_decay_rate:
162 |       return False
163 |     if self.exclude_from_weight_decay:
164 |       for r in self.exclude_from_weight_decay:
165 |         if re.search(r, param_name) is not None:
166 |           return False
167 |     return True
168 |
169 |   def _get_variable_name(self, param_name):
170 |     """Get the variable name from the tensor name."""
171 |     m = re.match("^(.*):\\d+$", param_name)
172 |     if m is not None:
173 |       param_name = m.group(1)
174 |     return param_name
175 |
-------------------------------------------------------------------------------- /bert/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /bert/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 3 | -------------------------------------------------------------------------------- /bert/run_classifier_with_tfhub.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """BERT finetuning runner with TF-Hub.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import optimization 23 | import run_classifier 24 | import tokenization 25 | import tensorflow as tf 26 | import tensorflow_hub as hub 27 | 28 | flags = tf.flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_string( 33 | "bert_hub_module_handle", None, 34 | "Handle for the BERT TF-Hub module.") 35 | 36 | 37 | def create_model(is_training, input_ids, input_mask, segment_ids, labels, 38 | num_labels): 39 | """Creates a classification model.""" 40 | tags = set() 41 | if is_training: 42 | tags.add("train") 43 | bert_module = hub.Module( 44 | FLAGS.bert_hub_module_handle, 45 | tags=tags, 46 | trainable=True) 47 | bert_inputs = dict( 48 | input_ids=input_ids, 49 | input_mask=input_mask, 50 | segment_ids=segment_ids) 51 | bert_outputs = bert_module( 52 | inputs=bert_inputs, 53 | signature="tokens", 54 | as_dict=True) 55 | 56 | # In the demo, we are doing a simple classification task on the entire 57 | # segment. 58 | # 59 | # If you want to use the token-level output, use 60 | # bert_outputs["sequence_output"] instead. 61 | output_layer = bert_outputs["pooled_output"] 62 | 63 | hidden_size = output_layer.shape[-1].value 64 | 65 | output_weights = tf.get_variable( 66 | "output_weights", [num_labels, hidden_size], 67 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 68 | 69 | output_bias = tf.get_variable( 70 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 71 | 72 | with tf.variable_scope("loss"): 73 | if is_training: 74 | # I.e., 0.1 dropout 75 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 76 | 77 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 78 | logits = tf.nn.bias_add(logits, output_bias) 79 | log_probs = tf.nn.log_softmax(logits, axis=-1) 80 | 81 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 82 | 83 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 84 | loss = tf.reduce_mean(per_example_loss) 85 | 86 | return (loss, per_example_loss, logits) 87 | 88 | 89 | def model_fn_builder(num_labels, learning_rate, num_train_steps, 90 | num_warmup_steps, use_tpu): 91 | """Returns `model_fn` closure for TPUEstimator.""" 92 | 93 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 94 | """The `model_fn` for TPUEstimator.""" 95 | 96 | tf.logging.info("*** Features ***") 97 | for name in sorted(features.keys()): 98 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 99 | 100 | input_ids = features["input_ids"] 101 | input_mask = features["input_mask"] 102 | segment_ids = features["segment_ids"] 103 | label_ids = features["label_ids"] 104 | 105 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 106 | 107 | (total_loss, per_example_loss, logits) = create_model( 108 | is_training, input_ids, input_mask, segment_ids, label_ids, num_labels) 109 | 110 | output_spec = None 111 | if mode == tf.estimator.ModeKeys.TRAIN: 112 | train_op = optimization.create_optimizer( 113 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 114 | 115 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 116 | mode=mode, 117 | loss=total_loss, 118 | train_op=train_op) 119 | elif mode == tf.estimator.ModeKeys.EVAL: 120 | 121 | def metric_fn(per_example_loss, label_ids, logits): 122 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 123 
| accuracy = tf.metrics.accuracy(label_ids, predictions) 124 | loss = tf.metrics.mean(per_example_loss) 125 | return { 126 | "eval_accuracy": accuracy, 127 | "eval_loss": loss, 128 | } 129 | 130 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) 131 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 132 | mode=mode, 133 | loss=total_loss, 134 | eval_metrics=eval_metrics) 135 | else: 136 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 137 | 138 | return output_spec 139 | 140 | return model_fn 141 | 142 | 143 | def create_tokenizer_from_hub_module(): 144 | """Get the vocab file and casing info from the Hub module.""" 145 | with tf.Graph().as_default(): 146 | bert_module = hub.Module(FLAGS.bert_hub_module_handle) 147 | tokenization_info = bert_module(signature="tokenization_info", as_dict=True) 148 | with tf.Session() as sess: 149 | vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"], 150 | tokenization_info["do_lower_case"]]) 151 | return tokenization.FullTokenizer( 152 | vocab_file=vocab_file, do_lower_case=do_lower_case) 153 | 154 | 155 | def main(_): 156 | tf.logging.set_verbosity(tf.logging.INFO) 157 | 158 | processors = { 159 | "cola": run_classifier.ColaProcessor, 160 | "mnli": run_classifier.MnliProcessor, 161 | "mrpc": run_classifier.MrpcProcessor, 162 | } 163 | 164 | if not FLAGS.do_train and not FLAGS.do_eval: 165 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 166 | 167 | tf.gfile.MakeDirs(FLAGS.output_dir) 168 | 169 | task_name = FLAGS.task_name.lower() 170 | 171 | if task_name not in processors: 172 | raise ValueError("Task not found: %s" % (task_name)) 173 | 174 | processor = processors[task_name]() 175 | 176 | label_list = processor.get_labels() 177 | 178 | tokenizer = create_tokenizer_from_hub_module() 179 | 180 | tpu_cluster_resolver = None 181 | if FLAGS.use_tpu and FLAGS.tpu_name: 182 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 183 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 184 | 185 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 186 | run_config = tf.contrib.tpu.RunConfig( 187 | cluster=tpu_cluster_resolver, 188 | master=FLAGS.master, 189 | model_dir=FLAGS.output_dir, 190 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 191 | tpu_config=tf.contrib.tpu.TPUConfig( 192 | iterations_per_loop=FLAGS.iterations_per_loop, 193 | num_shards=FLAGS.num_tpu_cores, 194 | per_host_input_for_training=is_per_host)) 195 | 196 | train_examples = None 197 | num_train_steps = None 198 | num_warmup_steps = None 199 | if FLAGS.do_train: 200 | train_examples = processor.get_train_examples(FLAGS.data_dir) 201 | num_train_steps = int( 202 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 203 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 204 | 205 | model_fn = model_fn_builder( 206 | num_labels=len(label_list), 207 | learning_rate=FLAGS.learning_rate, 208 | num_train_steps=num_train_steps, 209 | num_warmup_steps=num_warmup_steps, 210 | use_tpu=FLAGS.use_tpu) 211 | 212 | # If TPU is not available, this will fall back to normal Estimator on CPU 213 | # or GPU. 
214 |   estimator = tf.contrib.tpu.TPUEstimator(
215 |       use_tpu=FLAGS.use_tpu,
216 |       model_fn=model_fn,
217 |       config=run_config,
218 |       train_batch_size=FLAGS.train_batch_size,
219 |       eval_batch_size=FLAGS.eval_batch_size)
220 |
221 |   if FLAGS.do_train:
222 |     train_features = run_classifier.convert_examples_to_features(
223 |         train_examples, label_list, FLAGS.max_seq_length, tokenizer)
224 |     tf.logging.info("***** Running training *****")
225 |     tf.logging.info("  Num examples = %d", len(train_examples))
226 |     tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
227 |     tf.logging.info("  Num steps = %d", num_train_steps)
228 |     train_input_fn = run_classifier.input_fn_builder(
229 |         features=train_features,
230 |         seq_length=FLAGS.max_seq_length,
231 |         is_training=True,
232 |         drop_remainder=True)
233 |     estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
234 |
235 |   if FLAGS.do_eval:
236 |     eval_examples = processor.get_dev_examples(FLAGS.data_dir)
237 |     eval_features = run_classifier.convert_examples_to_features(
238 |         eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
239 |
240 |     tf.logging.info("***** Running evaluation *****")
241 |     tf.logging.info("  Num examples = %d", len(eval_examples))
242 |     tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
243 |
244 |     # This tells the estimator to run through the entire set.
245 |     eval_steps = None
246 |     # However, if running eval on the TPU, you will need to specify the
247 |     # number of steps.
248 |     if FLAGS.use_tpu:
249 |       # Eval will be slightly WRONG on the TPU because it will truncate
250 |       # the last batch.
251 |       eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
252 |
253 |     eval_drop_remainder = True if FLAGS.use_tpu else False
254 |     eval_input_fn = run_classifier.input_fn_builder(
255 |         features=eval_features,
256 |         seq_length=FLAGS.max_seq_length,
257 |         is_training=False,
258 |         drop_remainder=eval_drop_remainder)
259 |
260 |     result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
261 |
262 |     output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
263 |     with tf.gfile.GFile(output_eval_file, "w") as writer:
264 |       tf.logging.info("***** Eval results *****")
265 |       for key in sorted(result.keys()):
266 |         tf.logging.info("  %s = %s", key, str(result[key]))
267 |         writer.write("%s = %s\n" % (key, str(result[key])))
268 |
269 |
270 | if __name__ == "__main__":
271 |   flags.mark_flag_as_required("data_dir")
272 |   flags.mark_flag_as_required("task_name")
273 |   flags.mark_flag_as_required("bert_hub_module_handle")
274 |   flags.mark_flag_as_required("output_dir")
275 |   tf.app.run()
276 |
-------------------------------------------------------------------------------- /bert/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Gutenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import tokenization 22 | import six 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /calculating_model_score/calculate_atis_intent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn_metrics_function import show_metrics 4 | 5 | ATIS_intent_label = ['atis_abbreviation', 'atis_aircraft', 'atis_aircraft#atis_flight#atis_flight_no', 6 | 'atis_airfare', 'atis_airfare#atis_flight', 'atis_airfare#atis_flight_time', 7 | 'atis_airline', 'atis_airline#atis_flight_no', 'atis_airport', 'atis_capacity', 8 | 'atis_cheapest', 'atis_city', 'atis_day_name', 'atis_distance', 'atis_flight', 9 | 'atis_flight#atis_airfare', 'atis_flight#atis_airline', 'atis_flight_no', 10 | 'atis_flight_no#atis_airline', 'atis_flight_time', 'atis_ground_fare', 11 | 'atis_ground_service', 'atis_ground_service#atis_ground_fare', 'atis_meal', 12 | 'atis_quantity', 'atis_restriction'] 13 | 14 | with open(os.path.join("ATIS_Intent", "label")) as label_f: 15 | label_list = [label.replace("\n", "") for label in label_f.readlines()] 16 | #print(len(label_list), label_list) 17 | 18 | predit_label_value = np.fromfile(os.path.join("ATIS_Intent", "test_results.tsv"), sep="\t") 19 
| predit_label_value = predit_label_value.reshape(-1, len(ATIS_intent_label)) 20 | predit_label_value = np.argmax(predit_label_value, axis=1) 21 | predit_label = [ATIS_intent_label[label_index] for label_index in predit_label_value] 22 | 23 | #print(len(predit_label), predit_label) 24 | 25 | 26 | show_metrics(y_test=label_list, y_predict=predit_label, labels=ATIS_intent_label) -------------------------------------------------------------------------------- /calculating_model_score/calculate_atis_slot.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sklearn_metrics_function import show_metrics,delete_both_sides_is_O_word 3 | 4 | ATIS_slot_label = ['[Padding]', '[##WordPiece]', '[CLS]', '[SEP]', 'B-aircraft_code', 'B-airline_code', 'B-airline_name', 'B-airport_code', 'B-airport_name', 'B-arrive_date.date_relative', 'B-arrive_date.day_name', 'B-arrive_date.day_number', 'B-arrive_date.month_name', 'B-arrive_date.today_relative', 'B-arrive_time.end_time', 'B-arrive_time.period_mod', 'B-arrive_time.period_of_day', 'B-arrive_time.start_time', 'B-arrive_time.time', 'B-arrive_time.time_relative', 'B-booking_class', 'B-city_name', 'B-class_type', 'B-compartment', 'B-connect', 'B-cost_relative', 'B-day_name', 'B-day_number', 'B-days_code', 'B-depart_date.date_relative', 'B-depart_date.day_name', 'B-depart_date.day_number', 'B-depart_date.month_name', 'B-depart_date.today_relative', 'B-depart_date.year', 'B-depart_time.end_time', 'B-depart_time.period_mod', 'B-depart_time.period_of_day', 'B-depart_time.start_time', 'B-depart_time.time', 'B-depart_time.time_relative', 'B-economy', 'B-fare_amount', 'B-fare_basis_code', 'B-flight', 'B-flight_days', 'B-flight_mod', 'B-flight_number', 'B-flight_stop', 'B-flight_time', 'B-fromloc.airport_code', 'B-fromloc.airport_name', 'B-fromloc.city_name', 'B-fromloc.state_code', 'B-fromloc.state_name', 'B-meal', 'B-meal_code', 'B-meal_description', 'B-mod', 'B-month_name', 'B-or', 'B-period_of_day', 'B-restriction_code', 'B-return_date.date_relative', 'B-return_date.day_name', 'B-return_date.day_number', 'B-return_date.month_name', 'B-return_date.today_relative', 'B-return_time.period_mod', 'B-return_time.period_of_day', 'B-round_trip', 'B-state_code', 'B-state_name', 'B-stoploc.airport_code', 'B-stoploc.airport_name', 'B-stoploc.city_name', 'B-stoploc.state_code', 'B-time', 'B-time_relative', 'B-today_relative', 'B-toloc.airport_code', 'B-toloc.airport_name', 'B-toloc.city_name', 'B-toloc.country_name', 'B-toloc.state_code', 'B-toloc.state_name', 'B-transport_type', 'I-airline_name', 'I-airport_name', 'I-arrive_date.day_number', 'I-arrive_time.end_time', 'I-arrive_time.period_of_day', 'I-arrive_time.start_time', 'I-arrive_time.time', 'I-arrive_time.time_relative', 'I-city_name', 'I-class_type', 'I-cost_relative', 'I-depart_date.day_number', 'I-depart_date.today_relative', 'I-depart_time.end_time', 'I-depart_time.period_of_day', 'I-depart_time.start_time', 'I-depart_time.time', 'I-depart_time.time_relative', 'I-economy', 'I-fare_amount', 'I-fare_basis_code', 'I-flight_mod', 'I-flight_number', 'I-flight_stop', 'I-flight_time', 'I-fromloc.airport_name', 'I-fromloc.city_name', 'I-fromloc.state_name', 'I-meal_code', 'I-meal_description', 'I-restriction_code', 'I-return_date.date_relative', 'I-return_date.day_number', 'I-return_date.today_relative', 'I-round_trip', 'I-state_name', 'I-stoploc.city_name', 'I-time', 'I-today_relative', 'I-toloc.airport_name', 'I-toloc.city_name', 'I-toloc.state_name', 
'I-transport_type', 'O'] 5 | ATIS_slot_effective_label = ['B-aircraft_code', 'B-airline_code', 'B-airline_name', 'B-airport_code', 'B-airport_name', 'B-arrive_date.date_relative', 'B-arrive_date.day_name', 'B-arrive_date.day_number', 'B-arrive_date.month_name', 'B-arrive_date.today_relative', 'B-arrive_time.end_time', 'B-arrive_time.period_mod', 'B-arrive_time.period_of_day', 'B-arrive_time.start_time', 'B-arrive_time.time', 'B-arrive_time.time_relative', 'B-booking_class', 'B-city_name', 'B-class_type', 'B-compartment', 'B-connect', 'B-cost_relative', 'B-day_name', 'B-day_number', 'B-days_code', 'B-depart_date.date_relative', 'B-depart_date.day_name', 'B-depart_date.day_number', 'B-depart_date.month_name', 'B-depart_date.today_relative', 'B-depart_date.year', 'B-depart_time.end_time', 'B-depart_time.period_mod', 'B-depart_time.period_of_day', 'B-depart_time.start_time', 'B-depart_time.time', 'B-depart_time.time_relative', 'B-economy', 'B-fare_amount', 'B-fare_basis_code', 'B-flight', 'B-flight_days', 'B-flight_mod', 'B-flight_number', 'B-flight_stop', 'B-flight_time', 'B-fromloc.airport_code', 'B-fromloc.airport_name', 'B-fromloc.city_name', 'B-fromloc.state_code', 'B-fromloc.state_name', 'B-meal', 'B-meal_code', 'B-meal_description', 'B-mod', 'B-month_name', 'B-or', 'B-period_of_day', 'B-restriction_code', 'B-return_date.date_relative', 'B-return_date.day_name', 'B-return_date.day_number', 'B-return_date.month_name', 'B-return_date.today_relative', 'B-return_time.period_mod', 'B-return_time.period_of_day', 'B-round_trip', 'B-state_code', 'B-state_name', 'B-stoploc.airport_code', 'B-stoploc.airport_name', 'B-stoploc.city_name', 'B-stoploc.state_code', 'B-time', 'B-time_relative', 'B-today_relative', 'B-toloc.airport_code', 'B-toloc.airport_name', 'B-toloc.city_name', 'B-toloc.country_name', 'B-toloc.state_code', 'B-toloc.state_name', 'B-transport_type', 'I-airline_name', 'I-airport_name', 'I-arrive_date.day_number', 'I-arrive_time.end_time', 'I-arrive_time.period_of_day', 'I-arrive_time.start_time', 'I-arrive_time.time', 'I-arrive_time.time_relative', 'I-city_name', 'I-class_type', 'I-cost_relative', 'I-depart_date.day_number', 'I-depart_date.today_relative', 'I-depart_time.end_time', 'I-depart_time.period_of_day', 'I-depart_time.start_time', 'I-depart_time.time', 'I-depart_time.time_relative', 'I-economy', 'I-fare_amount', 'I-fare_basis_code', 'I-flight_mod', 'I-flight_number', 'I-flight_stop', 'I-flight_time', 'I-fromloc.airport_name', 'I-fromloc.city_name', 'I-fromloc.state_name', 'I-meal_code', 'I-meal_description', 'I-restriction_code', 'I-return_date.date_relative', 'I-return_date.day_number', 'I-return_date.today_relative', 'I-round_trip', 'I-state_name', 'I-stoploc.city_name', 'I-time', 'I-today_relative', 'I-toloc.airport_name', 'I-toloc.city_name', 'I-toloc.state_name', 'I-transport_type', 'O'] 6 | ATIS_slot_effective_label2 = ['B-aircraft_code', 'B-airline_code', 'B-airline_name', 'B-airport_code', 'B-airport_name', 'B-arrive_date.date_relative', 'B-arrive_date.day_name', 'B-arrive_date.day_number', 'B-arrive_date.month_name', 'B-arrive_date.today_relative', 'B-arrive_time.end_time', 'B-arrive_time.period_mod', 'B-arrive_time.period_of_day', 'B-arrive_time.start_time', 'B-arrive_time.time', 'B-arrive_time.time_relative', 'B-booking_class', 'B-city_name', 'B-class_type', 'B-compartment', 'B-connect', 'B-cost_relative', 'B-day_name', 'B-day_number', 'B-days_code', 'B-depart_date.date_relative', 'B-depart_date.day_name', 'B-depart_date.day_number', 'B-depart_date.month_name', 
'B-depart_date.today_relative', 'B-depart_date.year', 'B-depart_time.end_time', 'B-depart_time.period_mod', 'B-depart_time.period_of_day', 'B-depart_time.start_time', 'B-depart_time.time', 'B-depart_time.time_relative', 'B-economy', 'B-fare_amount', 'B-fare_basis_code', 'B-flight', 'B-flight_days', 'B-flight_mod', 'B-flight_number', 'B-flight_stop', 'B-flight_time', 'B-fromloc.airport_code', 'B-fromloc.airport_name', 'B-fromloc.city_name', 'B-fromloc.state_code', 'B-fromloc.state_name', 'B-meal', 'B-meal_code', 'B-meal_description', 'B-mod', 'B-month_name', 'B-or', 'B-period_of_day', 'B-restriction_code', 'B-return_date.date_relative', 'B-return_date.day_name', 'B-return_date.day_number', 'B-return_date.month_name', 'B-return_date.today_relative', 'B-return_time.period_mod', 'B-return_time.period_of_day', 'B-round_trip', 'B-state_code', 'B-state_name', 'B-stoploc.airport_code', 'B-stoploc.airport_name', 'B-stoploc.city_name', 'B-stoploc.state_code', 'B-time', 'B-time_relative', 'B-today_relative', 'B-toloc.airport_code', 'B-toloc.airport_name', 'B-toloc.city_name', 'B-toloc.country_name', 'B-toloc.state_code', 'B-toloc.state_name', 'B-transport_type', 'I-airline_name', 'I-airport_name', 'I-arrive_date.day_number', 'I-arrive_time.end_time', 'I-arrive_time.period_of_day', 'I-arrive_time.start_time', 'I-arrive_time.time', 'I-arrive_time.time_relative', 'I-city_name', 'I-class_type', 'I-cost_relative', 'I-depart_date.day_number', 'I-depart_date.today_relative', 'I-depart_time.end_time', 'I-depart_time.period_of_day', 'I-depart_time.start_time', 'I-depart_time.time', 'I-depart_time.time_relative', 'I-economy', 'I-fare_amount', 'I-fare_basis_code', 'I-flight_mod', 'I-flight_number', 'I-flight_stop', 'I-flight_time', 'I-fromloc.airport_name', 'I-fromloc.city_name', 'I-fromloc.state_name', 'I-meal_code', 'I-meal_description', 'I-restriction_code', 'I-return_date.date_relative', 'I-return_date.day_number', 'I-return_date.today_relative', 'I-round_trip', 'I-state_name', 'I-stoploc.city_name', 'I-time', 'I-today_relative', 'I-toloc.airport_name', 'I-toloc.city_name', 'I-toloc.state_name', 'I-transport_type'] 7 | 8 | 9 | 10 | with open(os.path.join("ATIS_solt", "seq.out")) as label_f: 11 | label_list = [label.replace("\n", "") for label in label_f.readlines()] 12 | label_list = [seq.split() for seq in label_list] 13 | #print(len(label_list), label_list) 14 | 15 | with open(os.path.join("ATIS_solt", "label_test.txt")) as predict_f: 16 | predict_list = [predict_label.replace("\n", "") for predict_label in predict_f.readlines()] 17 | #print(len(predict_list), predict_list) 18 | predict_sentence_list = [] 19 | for word in predict_list: 20 | if "[CLS]" == word: 21 | a_sentence = [] 22 | a_sentence.append(word) 23 | if "[SEP]" == word: 24 | predict_sentence_list.append(a_sentence) 25 | #print(len(predict_sentence_list), predict_sentence_list) 26 | 27 | y_test_list = [] 28 | clean_y_predict_list = [] 29 | assert len(label_list)==len(predict_sentence_list) 30 | for y_test, y_predict in zip(label_list, predict_sentence_list): 31 | y_predict.remove('[CLS]') 32 | y_predict.remove('[SEP]') 33 | while '[Padding]' in y_predict: 34 | y_predict.remove('[Padding]') 35 | while '[##WordPiece]' in y_predict: 36 | y_predict.remove('[##WordPiece]') 37 | if len(y_predict)!=len(y_test): 38 | print(y_predict) 39 | print(y_test) 40 | print("~"*100) 41 | y_test_list.extend(y_test) 42 | clean_y_predict_list.extend(y_predict) 43 | 44 | assert len(y_test_list)==len(clean_y_predict_list) 45 | 46 | 47 | 48 | y_test_list, 
clean_y_predict_list = delete_both_sides_is_O_word(y_test_list, clean_y_predict_list)
49 |
50 |
51 | show_metrics(y_test=y_test_list, y_predict=clean_y_predict_list, labels=ATIS_slot_effective_label2)
52 |
53 |
-------------------------------------------------------------------------------- /calculating_model_score/calculate_model_score.py:
--------------------------------------------------------------------------------
1 | from sklearn import metrics
2 | import os
3 | import sys
4 | import time
5 | #import numpy
6 | #numpy.set_printoptions(threshold=numpy.nan)
7 |
8 |
9 | class Logger(object):
10 |     """store log to txt file"""
11 |     def __init__(self, filename="Default.log"):
12 |         self.terminal = sys.stdout
13 |         self.log = open(filename, "a")
14 |
15 |     def write(self, message):
16 |         self.terminal.write(message)
17 |         self.log.write(message)
18 |
19 |     def flush(self):
20 |         pass
21 |
22 | class Sequence_Labeling_and_Text_Classification_Calculate(object):
23 |
24 |     def get_slot_labels(self):
25 |         """for Sequence_Labeling labels"""
26 |         raise NotImplementedError()
27 |
28 |     def get_intent_labels(self):
29 |         """for Text_Classification labels"""
30 |         raise NotImplementedError()
31 |
32 |
33 |     @classmethod
34 |     def show_metrics(cls, y_test_list, y_predict_list, label_list):
35 |         print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
36 |         print('Accuracy:', metrics.accuracy_score(y_test_list, y_predict_list))
37 |
38 |         print('Macro-averaged precision:', metrics.precision_score(y_test_list, y_predict_list, average='macro'))
39 |         print('Micro-averaged precision:', metrics.precision_score(y_test_list, y_predict_list, average='micro'))
40 |         print('Weighted-average precision:', metrics.precision_score(y_test_list, y_predict_list, average='weighted'))
41 |
42 |         print('Macro-averaged recall:', metrics.recall_score(y_test_list, y_predict_list, average='macro'))
43 |         print('Micro-averaged recall:', metrics.recall_score(y_test_list, y_predict_list, average='micro'))
44 |         print('Weighted-average recall:', metrics.recall_score(y_test_list, y_predict_list, average='weighted'))
45 |
46 |         print('Macro-averaged F1-score:', metrics.f1_score(y_test_list, y_predict_list, labels=label_list, average='macro'))
47 |         print('Micro-averaged F1-score:', metrics.f1_score(y_test_list, y_predict_list, labels=label_list, average='micro'))
48 |         print('Weighted-average F1-score:',
49 |               metrics.f1_score(y_test_list, y_predict_list, labels=label_list, average='weighted'))
50 |         a_confusion_matrix = metrics.confusion_matrix(y_test_list, y_predict_list)
51 |
52 |         print('Confusion matrix:\n', a_confusion_matrix)
53 |         print('Classification report:\n', metrics.classification_report(y_test_list, y_predict_list))
54 |         print("\n")
55 |
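# Note on the helper below: positions where both the gold tag and the
# prediction are "O" are dropped before scoring, so that the dominant
# outside tag does not inflate the slot-filling metrics printed by
# show_metrics.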
56 |     @classmethod
57 |     def delete_both_sides_is_O_word(cls, y_test_list, clean_y_predict_list):
58 |         new_y_test_list, new_clean_y_predict_list = [], []
59 |         for test, pred in zip(y_test_list, clean_y_predict_list):
60 |             if test == "O" and pred == "O":
61 |                 continue
62 |             new_y_test_list.append(test)
63 |             new_clean_y_predict_list.append(pred)
64 |         assert len(new_y_test_list) == len(new_clean_y_predict_list)
65 |         return new_y_test_list, new_clean_y_predict_list
66 |
67 | class Snips_Slot_Filling_and_Intent_Detection_Calculate(Sequence_Labeling_and_Text_Classification_Calculate):
68 |
69 |     def __init__(self, path_to_label_file=None, path_to_predict_label_file=None, log_out_file=None):
70 |         if path_to_label_file is None and path_to_predict_label_file is None:
71 |             raise Exception("At least `path_to_label_file` must be provided")
72 |         self.path_to_label_file = path_to_label_file
73 |         if path_to_predict_label_file is not None:
74 |             self.path_to_predict_label_file = path_to_predict_label_file
75 |         else:
76 |             self.path_to_predict_label_file = path_to_label_file
77 |         if log_out_file is None:
78 |             self.log_out_file = os.getcwd()
79 |         else:
80 |             if not os.path.exists(log_out_file):
81 |                 os.makedirs(log_out_file)
82 |             self.log_out_file = log_out_file
83 |
84 |     def get_intent_label_list(self, path_to_intent_label_file):
85 |         with open(path_to_intent_label_file) as label_f:
86 |             intent_label_list = [label.replace("\n", "") for label in label_f.readlines()]
87 |         return intent_label_list
88 |
89 |     def get_predict_intent_label_list(self, path_to_predict_intent_label_file):
90 |         with open(path_to_predict_intent_label_file) as intent_f:
91 |             predict_intent_label_list = [label.replace("\n", "") for label in intent_f.readlines()]
92 |         return predict_intent_label_list
93 |
94 |     def _get_slot_sententce_list(self, path_to_slot_sentence_file):
95 |         with open(path_to_slot_sentence_file) as slot_f:
96 |             slot_sententce_list = [sententce.split() for sententce in slot_f.readlines()]
97 |         return slot_sententce_list
98 |
99 |     def _get_predict_slot_sentence_list(self, path_to_slot_filling_test_results_file):
100 |         with open(path_to_slot_filling_test_results_file) as slot_predict_f:
101 |             predict_slot_sentence_list = [predict_label.split() for predict_label in slot_predict_f.readlines()]
102 |         return predict_slot_sentence_list
103 |
104 |     def producte_slot_list(self):
105 |         """Read seq.out (gold labels) and slot_filling_test_results.txt (predictions);
106 |         return (slot_test_list, clean_predict_slot_list).
107 |         """
108 |         path_to_slot_sentence_file = os.path.join(self.path_to_label_file, "seq.out")
109 |         slot_sententce_list = self._get_slot_sententce_list(path_to_slot_sentence_file)
110 |         path_to_slot_filling_test_results_file = os.path.join(self.path_to_predict_label_file, "slot_filling_test_results.txt")
111 |         predict_slot_sentence_list = self._get_predict_slot_sentence_list(path_to_slot_filling_test_results_file)
112 |         slot_test_list = []
113 |         clean_predict_slot_list = []
114 |         seqence_length_dont_match_index = 0
115 |         for y_test, y_predict in zip(slot_sententce_list, predict_slot_sentence_list):
116 |             y_predict.remove('[CLS]')
117 |             y_predict.remove('[SEP]')
118 |             while '[Padding]' in y_predict:
119 |                 y_predict.remove('[Padding]')
120 |             while '[##WordPiece]' in y_predict:
121 |                 y_predict.remove('[##WordPiece]')
122 |             if len(y_predict) > len(y_test):
123 |                 #print(seqence_length_dont_match_index)
124 |                 #print(y_predict)
125 |                 #print(y_test)
126 |                 #print("~" * 100)
127 |                 seqence_length_dont_match_index += 1
128 |                 y_predict = y_predict[0:len(y_test)]
129 |             elif len(y_predict) < len(y_test):
130 |                 #print(seqence_length_dont_match_index)
131 |                 #print(y_predict)
132 |                 #print(y_test)
133 |                 #print("~" * 100)
134 |                 y_predict = y_predict + ["O"] * (len(y_test) - len(y_predict))
135 |                 seqence_length_dont_match_index += 1
136 |             assert len(y_predict) == len(y_test)
137 |             slot_test_list.extend(y_test)
138 |             clean_predict_slot_list.extend(y_predict)
139 |         #print("seqence_length_dont_match numbers", seqence_length_dont_match_index)
140 |         return slot_test_list, clean_predict_slot_list
141 |
142 |     def get_slot_model_labels(self):
143 |         """contain ['[Padding]', '[##WordPiece]', '[CLS]', '[SEP]', 'O'] + Task labels"""
144 |         return ['[Padding]', '[##WordPiece]', '[CLS]', '[SEP]', 'B-album', 'B-artist', 'B-best_rating', 'B-city',
145 |                 'B-condition_description',
'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 146 | 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 147 | 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 148 | 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 149 | 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 150 | 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 151 | 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 152 | 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 153 | 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 154 | 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 155 | 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 156 | 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 157 | 'I-timeRange', 'I-track', 'O'] 158 | 159 | def get_slot_labels(self): 160 | """only contain Task labels""" 161 | return ['B-album', 'B-artist', 'B-best_rating', 'B-city', 162 | 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 163 | 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 164 | 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 165 | 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 166 | 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 167 | 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 168 | 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 169 | 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 170 | 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 171 | 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 172 | 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 173 | 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 174 | 'I-timeRange', 'I-track'] 175 | 176 | def get_intent_labels(self): 177 | return ['AddToPlaylist', 'BookRestaurant', 'GetWeather', 'PlayMusic', 178 | 'RateBook', 'SearchCreativeWork', 'SearchScreeningEvent'] 179 | 180 | 181 | def show_intent_prediction_report(self): 182 | sys.stdout = Logger(os.path.join(self.log_out_file, "log.txt")) 183 | path_to_intent_label_lfile = os.path.join(self.path_to_label_file, "label") 184 | intent_label_list = self.get_intent_label_list(path_to_intent_label_lfile) 185 | path_to_predict_intent_label_file = os.path.join(self.path_to_predict_label_file, "intent_prediction_test_results.txt") 186 | predict_intent_label_list = self.get_predict_intent_label_list(path_to_predict_intent_label_file) 187 | labels = self.get_intent_labels() 188 | print("---show_intent_prediction_report---") 189 | self.show_metrics(intent_label_list, predict_intent_label_list, labels) 190 | print("--"*30) 191 | 192 | def show_slot_filling_report(self): 193 | sys.stdout = 
Logger(os.path.join(self.log_out_file, "log.txt")) 194 | slot_test_list, clean_predict_slot_list = self.producte_slot_list() 195 | slot_test_list, clean_predict_slot_list = self.delete_both_sides_is_O_word(slot_test_list, clean_predict_slot_list) 196 | labels = self.get_slot_labels() 197 | print("---show_slot_filling_report---") 198 | self.show_metrics(slot_test_list, clean_predict_slot_list, labels) 199 | print("--"*30) 200 | 201 | 202 | if __name__=='__main__': 203 | path_to_label_file = "/home/b418/jupyter_workspace/yuanxiao/" \ 204 | "BERT-for-Sequence-Labeling-and-Text-Classification/" \ 205 | "data/snips_Intent_Detection_and_Slot_Filling/test/" 206 | 207 | path_to_predict_label_file = "snips_join_task_epoch10_test4088ckpt" 208 | log_out_file = "snips_join_task_epoch10_test4088ckpt" 209 | intent_slot_reports = Snips_Slot_Filling_and_Intent_Detection_Calculate( 210 | path_to_label_file, path_to_predict_label_file, log_out_file) 211 | 212 | intent_slot_reports.show_intent_prediction_report() 213 | intent_slot_reports.show_slot_filling_report() 214 | 215 | -------------------------------------------------------------------------------- /calculating_model_score/calculate_snips_intent_and_slot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn_metrics_function import show_metrics,delete_both_sides_is_O_word 4 | 5 | print("-"*100) 6 | print("Slot Intent Task Report") 7 | print("-"*100) 8 | SNIPS_intent_label = ['AddToPlaylist', 'BookRestaurant', 'GetWeather', 'PlayMusic', 9 | 'RateBook', 'SearchCreativeWork', 'SearchScreeningEvent'] 10 | 11 | with open(os.path.join("SNIPS_Intent_and_Slot", "label")) as label_f: 12 | label_list = [label.replace("\n", "") for label in label_f.readlines()] 13 | #print(len(label_list), label_list) 14 | 15 | predit_label_value = np.fromfile(os.path.join("SNIPS_Intent_and_Slot", "intent_prediction_test_results.tsv"), sep="\t") 16 | predit_label_value = predit_label_value.reshape(-1, len(SNIPS_intent_label)) 17 | predit_label_value = np.argmax(predit_label_value, axis=1) 18 | predit_label = [SNIPS_intent_label[label_index] for label_index in predit_label_value] 19 | 20 | #print(len(predit_label), predit_label) 21 | 22 | 23 | show_metrics(y_test=label_list, y_predict=predit_label, labels=SNIPS_intent_label) 24 | 25 | print("-"*100) 26 | print("Slot Filling Task Report") 27 | print("-"*100) 28 | 29 | SNIPS_slot_label = ['[Padding]', '[##WordPiece]', '[CLS]', '[SEP]', 'B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 
'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track', 'O'] 30 | SNIPS_slot_effective_label = ['B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track', 'O'] 31 | SNIPS_slot_effective_label2 = ['B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track'] 32 | 33 | with open(os.path.join("SNIPS_Intent_and_Slot", "seq.out")) as label_f: 34 | label_list = [label.replace("\n", "") for label in label_f.readlines()] 35 | label_list = [seq.split() for seq in label_list] 36 | #print(len(label_list), label_list) 37 | 38 | 39 | with open(os.path.join("SNIPS_Intent_and_Slot", "slot_filling_test_results.txt")) as predict_f: 40 | predict_list = [predict_label.replace("\n", "") for predict_label in predict_f.readlines()] 41 | #print(len(predict_list), predict_list) 42 | predict_sentence_list = [] 43 | for word in predict_list: 44 | if "[CLS]" == word: 45 | a_sentence = [] 46 | a_sentence.append(word) 47 | if "[SEP]" == word: 48 | predict_sentence_list.append(a_sentence) 49 | #print(len(predict_sentence_list), predict_sentence_list) 50 | 51 | y_test_list = [] 52 | clean_y_predict_list = [] 53 | assert 
len(label_list)==len(predict_sentence_list) 54 | 55 | seqence_length_dont_match_index = 0 56 | for y_test, y_predict in zip(label_list, predict_sentence_list): 57 | y_predict.remove('[CLS]') 58 | y_predict.remove('[SEP]') 59 | while '[Padding]' in y_predict: 60 | y_predict.remove('[Padding]') 61 | while '[##WordPiece]' in y_predict: 62 | y_predict.remove('[##WordPiece]') 63 | if len(y_predict) > len(y_test): 64 | print(seqence_length_dont_match_index) 65 | print(y_predict) 66 | print(y_test) 67 | print("~"*100) 68 | y_predict = y_predict[0:len(y_test)] 69 | elif len(y_predict) < len(y_test): 70 | print(seqence_length_dont_match_index) 71 | print(y_predict) 72 | print(y_test) 73 | print("~"*100) 74 | y_predict = y_predict + ["O"] * (len(y_test)-len(y_predict)) 75 | assert len(y_predict)==len(y_test) 76 | y_test_list.extend(y_test) 77 | clean_y_predict_list.extend(y_predict) 78 | seqence_length_dont_match_index +=1 79 | 80 | assert len(y_test_list)==len(clean_y_predict_list) 81 | 82 | y_test_list, clean_y_predict_list = delete_both_sides_is_O_word(y_test_list, clean_y_predict_list) 83 | 84 | show_metrics(y_test=y_test_list, y_predict=clean_y_predict_list, labels=SNIPS_slot_effective_label) -------------------------------------------------------------------------------- /calculating_model_score/calculate_snips_intent_and_slot_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn_metrics_function import show_metrics,delete_both_sides_is_O_word 4 | 5 | print("-"*100) 6 | print("Slot Intent Task Report") 7 | print("-"*100) 8 | SNIPS_intent_label = ['AddToPlaylist', 'BookRestaurant', 'GetWeather', 'PlayMusic', 9 | 'RateBook', 'SearchCreativeWork', 'SearchScreeningEvent'] 10 | 11 | with open(os.path.join("join_task", "label")) as label_f: 12 | intent_label_list = [label.replace("\n", "") for label in label_f.readlines()] 13 | 14 | with open(os.path.join("join_task", "intent_prediction_test_results.txt")) as intent_f: 15 | predict_intent_label_list = [label.replace("\n", "") for label in intent_f.readlines()] 16 | 17 | assert len(intent_label_list)==len(predict_intent_label_list) 18 | 19 | show_metrics(y_test=intent_label_list, y_predict=predict_intent_label_list, labels=SNIPS_intent_label) 20 | 21 | print("-"*100) 22 | print("Slot Filling Task Report") 23 | print("-"*100) 24 | 25 | SNIPS_slot_label = ['[Padding]', '[##WordPiece]', '[CLS]', '[SEP]', 'B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 
'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track', 'O'] 26 | SNIPS_slot_effective_label = ['B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track', 'O'] 27 | SNIPS_slot_effective_label2 = ['B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track'] 28 | 29 | with open(os.path.join("join_task", "seq_out")) as slot_f: 30 | slot_label_list = [label.split() for label in slot_f.readlines()] 31 | 32 | with open(os.path.join("join_task", "slot_filling_test_results.txt")) as slot_predict_f: 33 | predict_slot_sentence_list = [predict_label.split() for predict_label in slot_predict_f.readlines()] 34 | 35 | assert len(slot_label_list)==len(predict_slot_sentence_list) 36 | 37 | slot_test_list = [] 38 | clean_predict_slot_list = [] 39 | 40 | seqence_length_dont_match_index = 0 41 | for y_test, y_predict in zip(slot_label_list, predict_slot_sentence_list): 42 | y_predict.remove('[CLS]') 43 | y_predict.remove('[SEP]') 44 | while '[Padding]' in y_predict: 45 | y_predict.remove('[Padding]') 46 | while '[##WordPiece]' in y_predict: 47 | y_predict.remove('[##WordPiece]') 48 | if len(y_predict) > len(y_test): 49 | print(seqence_length_dont_match_index) 50 | 
print(y_predict) 51 | print(y_test) 52 | print("~"*100) 53 | y_predict = y_predict[0:len(y_test)] 54 | elif len(y_predict) < len(y_test): 55 | print(seqence_length_dont_match_index) 56 | print(y_predict) 57 | print(y_test) 58 | print("~"*100) 59 | y_predict = y_predict + ["O"] * (len(y_test)-len(y_predict)) 60 | assert len(y_predict)==len(y_test) 61 | slot_test_list.extend(y_test) 62 | clean_predict_slot_list.extend(y_predict) 63 | seqence_length_dont_match_index +=1 64 | 65 | assert len(slot_test_list) == len(clean_predict_slot_list) 66 | 67 | slot_test_list, clean_predict_slot_list = delete_both_sides_is_O_word(slot_test_list, clean_predict_slot_list) 68 | 69 | show_metrics(y_test=slot_test_list, y_predict=clean_predict_slot_list, labels=SNIPS_slot_effective_label) -------------------------------------------------------------------------------- /calculating_model_score/calculate_snips_slot.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sklearn_metrics_function import show_metrics,delete_both_sides_is_O_word 3 | 4 | SNIPS_slot_label = ['[Padding]', '[##WordPiece]', '[CLS]', '[SEP]', 'B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track', 'O'] 5 | SNIPS_slot_effective_label = ['B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 
'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track', 'O'] 6 | SNIPS_slot_effective_label2 = ['B-album', 'B-artist', 'B-best_rating', 'B-city', 'B-condition_description', 'B-condition_temperature', 'B-country', 'B-cuisine', 'B-current_location', 'B-entity_name', 'B-facility', 'B-genre', 'B-geographic_poi', 'B-location_name', 'B-movie_name', 'B-movie_type', 'B-music_item', 'B-object_location_type', 'B-object_name', 'B-object_part_of_series_type', 'B-object_select', 'B-object_type', 'B-party_size_description', 'B-party_size_number', 'B-playlist', 'B-playlist_owner', 'B-poi', 'B-rating_unit', 'B-rating_value', 'B-restaurant_name', 'B-restaurant_type', 'B-served_dish', 'B-service', 'B-sort', 'B-spatial_relation', 'B-state', 'B-timeRange', 'B-track', 'B-year', 'I-album', 'I-artist', 'I-city', 'I-country', 'I-cuisine', 'I-current_location', 'I-entity_name', 'I-facility', 'I-genre', 'I-geographic_poi', 'I-location_name', 'I-movie_name', 'I-movie_type', 'I-music_item', 'I-object_location_type', 'I-object_name', 'I-object_part_of_series_type', 'I-object_select', 'I-object_type', 'I-party_size_description', 'I-playlist', 'I-playlist_owner', 'I-poi', 'I-restaurant_name', 'I-restaurant_type', 'I-served_dish', 'I-service', 'I-sort', 'I-spatial_relation', 'I-state', 'I-timeRange', 'I-track'] 7 | 8 | 9 | 10 | with open(os.path.join("SNIPS_slot", "seq.out")) as label_f: 11 | label_list = [label.replace("\n", "") for label in label_f.readlines()] 12 | label_list = [seq.split() for seq in label_list] 13 | #print(len(label_list), label_list) 14 | 15 | with open(os.path.join("SNIPS_slot", "label_test.txt")) as predict_f: 16 | predict_list = [predict_label.replace("\n", "") for predict_label in predict_f.readlines()] 17 | #print(len(predict_list), predict_list) 18 | predict_sentence_list = [] 19 | for word in predict_list: 20 | if "[CLS]" == word: 21 | a_sentence = [] 22 | a_sentence.append(word) 23 | if "[SEP]" == word: 24 | predict_sentence_list.append(a_sentence) 25 | #print(len(predict_sentence_list), predict_sentence_list) 26 | 27 | y_test_list = [] 28 | clean_y_predict_list = [] 29 | assert len(label_list)==len(predict_sentence_list) 30 | for y_test, y_predict in zip(label_list, predict_sentence_list): 31 | y_predict.remove('[CLS]') 32 | y_predict.remove('[SEP]') 33 | while '[Padding]' in y_predict: 34 | y_predict.remove('[Padding]') 35 | while '[##WordPiece]' in y_predict: 36 | y_predict.remove('[##WordPiece]') 37 | if len(y_predict)!=len(y_test): 38 | print(y_predict) 39 | print(y_test) 40 | print("~"*100) 41 | y_test_list.extend(y_test) 42 | clean_y_predict_list.extend(y_predict) 43 | 44 | assert len(y_test_list)==len(clean_y_predict_list) 45 | 46 | y_test_list, clean_y_predict_list = delete_both_sides_is_O_word(y_test_list, clean_y_predict_list) 47 | 48 | show_metrics(y_test=y_test_list, y_predict=clean_y_predict_list, labels=SNIPS_slot_effective_label) 49 | 50 | -------------------------------------------------------------------------------- /calculating_model_score/calculate_snpis_intent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn_metrics_function import show_metrics 4 | 5 | SNIPS_intent_label = ['AddToPlaylist', 'BookRestaurant', 'GetWeather', 'PlayMusic', 6 | 'RateBook', 'SearchCreativeWork', 'SearchScreeningEvent'] 7 | 8 | with open(os.path.join("SNIPS_Intent", "label")) as label_f: 9 | label_list = [label.replace("\n", "") for label in label_f.readlines()] 10 | #print(len(label_list), label_list) 
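# Decode check (a sketch, not part of the original script): test_results.tsv is
# assumed to hold one row of len(SNIPS_intent_label) tab-separated classifier
# scores per test example, in the label order defined above, so the argmax
# decoding below is only meaningful if the score file and the gold label file
# agree on the number of examples.
raw_scores = np.fromfile(os.path.join("SNIPS_Intent", "test_results.tsv"), sep="\t")
assert raw_scores.size == len(label_list) * len(SNIPS_intent_label), \
    "test_results.tsv does not contain one score row per gold label"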
11 |
12 | predict_label_value = np.fromfile(os.path.join("SNIPS_Intent", "test_results.tsv"), sep="\t")
13 | predict_label_value = predict_label_value.reshape(-1, len(SNIPS_intent_label))
14 | predict_label_value = np.argmax(predict_label_value, axis=1)
15 | predict_label = [SNIPS_intent_label[label_index] for label_index in predict_label_value]
16 |
17 | #print(len(predict_label), predict_label)
18 |
19 |
20 | show_metrics(y_test=label_list, y_predict=predict_label, labels=SNIPS_intent_label)
-------------------------------------------------------------------------------- /calculating_model_score/sklearn_metrics_function.py: --------------------------------------------------------------------------------
1 | from sklearn import metrics
2 |
3 | y_test = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4]
4 | y_predict = [1, 1, 1, 0, 0, 2, 2, 3, 3, 3, 4, 3, 4, 3]
5 |
6 | def show_metrics(y_test, y_predict, labels):
7 |     print('Accuracy:', metrics.accuracy_score(y_test, y_predict))
8 |
9 |     print('Macro-average precision:', metrics.precision_score(y_test, y_predict, average='macro'))
10 |     print('Micro-average precision:', metrics.precision_score(y_test, y_predict, average='micro'))
11 |     print('Weighted-average precision:', metrics.precision_score(y_test, y_predict, average='weighted'))
12 |
13 |     print('Macro-average recall:', metrics.recall_score(y_test, y_predict, average='macro'))
14 |     print('Micro-average recall:', metrics.recall_score(y_test, y_predict, average='micro'))
15 |     print('Weighted-average recall:', metrics.recall_score(y_test, y_predict, average='weighted'))
16 |
17 |     print('Macro-average F1-score:', metrics.f1_score(y_test, y_predict, labels=labels, average='macro'))
18 |     print('Micro-average F1-score:', metrics.f1_score(y_test, y_predict, labels=labels, average='micro'))
19 |     print('Weighted-average F1-score:', metrics.f1_score(y_test, y_predict, labels=labels, average='weighted'))
20 |
21 |     print('Confusion matrix:\n', metrics.confusion_matrix(y_test, y_predict))
22 |     print('Classification report:\n', metrics.classification_report(y_test, y_predict))
23 |
24 | #show_metrics(y_test=y_test, y_predict=y_predict, labels=[1,2,3,4])
25 | def delete_both_sides_is_O_word(y_test_list, clean_y_predict_list):
26 |     new_y_test_list, new_clean_y_predict_list = [], []
27 |     for test, pred in zip(y_test_list, clean_y_predict_list):
28 |         if test == "O" and pred == "O":
29 |             continue
30 |         new_y_test_list.append(test)
31 |         new_clean_y_predict_list.append(pred)
32 |     assert len(new_y_test_list) == len(new_clean_y_predict_list)
33 |     return new_y_test_list, new_clean_y_predict_list
-------------------------------------------------------------------------------- /calculating_model_score/tf_metrics.py: --------------------------------------------------------------------------------
1 | """
2 | Multiclass
3 | from:
4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py
5 |
6 | """
7 |
8 | __author__ = "Guillaume Genthial"
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix
13 |
14 |
15 | def precision(labels, predictions, num_classes, pos_indices=None,
16 |               weights=None, average='micro'):
17 |     """Multi-class precision metric for Tensorflow
18 |     Parameters
19 |     ----------
20 |     labels : Tensor of tf.int32 or tf.int64
21 |         The true labels
22 |     predictions : Tensor of tf.int32 or tf.int64
23 |         The predictions, same shape as labels
24 |     num_classes : int
25 |         The number of classes
26 |     pos_indices : list of int, optional
27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 
116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. Will not account for class 118 | imbalance. 119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 122 | beta : int, optional 123 | Weight of precision in harmonic mean 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1. + beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F1 from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 
175 |     pos_indices : list of int, optional
176 |         The indices of the positive classes
177 |     beta : int, optional
178 |         Weight of precision in harmonic mean
179 |     average : str, optional
180 |         'micro', 'macro' or 'weighted'
181 |     """
182 |     num_classes = cm.shape[0]
183 |     if pos_indices is None:
184 |         pos_indices = [i for i in range(num_classes)]
185 |
186 |     if average == 'micro':
187 |         return pr_re_fbeta(cm, pos_indices, beta)
188 |     elif average in {'macro', 'weighted'}:
189 |         precisions, recalls, fbetas, n_golds = [], [], [], []
190 |         for idx in pos_indices:
191 |             pr, re, fbeta = pr_re_fbeta(cm, [idx], beta)
192 |             precisions.append(pr)
193 |             recalls.append(re)
194 |             fbetas.append(fbeta)
195 |             cm_mask = np.zeros([num_classes, num_classes])
196 |             cm_mask[idx, :] = 1
197 |             n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))
198 |
199 |         if average == 'macro':
200 |             pr = tf.reduce_mean(precisions)
201 |             re = tf.reduce_mean(recalls)
202 |             fbeta = tf.reduce_mean(fbetas)
203 |             return pr, re, fbeta
204 |         if average == 'weighted':
205 |             n_gold = tf.reduce_sum(n_golds)
206 |             pr_sum = sum(p * n for p, n in zip(precisions, n_golds))
207 |             pr = safe_div(pr_sum, n_gold)
208 |             re_sum = sum(r * n for r, n in zip(recalls, n_golds))
209 |             re = safe_div(re_sum, n_gold)
210 |             fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds))
211 |             fbeta = safe_div(fbeta_sum, n_gold)
212 |             return pr, re, fbeta
213 |
214 |     else:
215 |         raise NotImplementedError()
-------------------------------------------------------------------------------- /data/CoNLL2003_NER/conll03_raw_data_to_stand_file.py: --------------------------------------------------------------------------------
1 | import os
2 |
3 | def get_examples(input_file):
4 |     """Read a BIO-format data file and group its tokens and labels into sentences."""
5 |     with open(input_file) as f:
6 |         lines = []
7 |         words = []
8 |         labels = []
9 |         for line in f:
10 |             contends = line.strip()
11 |             word = line.strip().split(' ')[0]
12 |             label = line.strip().split(' ')[-1]
13 |             if contends.startswith("-DOCSTART-"):
14 |                 words.append('')
15 |                 continue
16 |             if len(contends) == 0 and words[-1] == '.':
17 |                 l = ' '.join([label for label in labels if len(label) > 0])
18 |                 w = ' '.join([word for word in words if len(word) > 0])
19 |                 lines.append([l, w])
20 |                 words = []
21 |                 labels = []
22 |                 continue
23 |             words.append(word)
24 |             labels.append(label)
25 |         seq_in = [line[1] for line in lines]
26 |         seq_out = [line[0] for line in lines]
27 |     return seq_in, seq_out
28 |
29 | def conll03_raw_data_to_stand(path_to_raw_file=None):
30 |     os.makedirs("train", exist_ok=True)
31 |     os.makedirs("valid", exist_ok=True)
32 |     os.makedirs("test", exist_ok=True)
33 |     for file_type in ["train", "dev", "test"]:
34 |         raw_file = file_type + ".txt"
35 |         seq_in, seq_out = get_examples(raw_file)
36 |         if file_type == "dev":
37 |             file_type = "valid"
38 |         with open(os.path.join(file_type, "seq.in"), "w") as seq_in_f:
39 |             with open(os.path.join(file_type, "seq.out"), "w") as seq_out_f:
40 |                 for seq in seq_in:
41 |                     seq_in_f.write(seq + "\n")
42 |                 for seq in seq_out:
43 |                     seq_out_f.write(seq + "\n")
44 |
45 |
46 |
47 | conll03_raw_data_to_stand()
48 |
-------------------------------------------------------------------------------- /data/atis_Intent_Detection_and_Slot_Filling/train/check_train_raw_data.py: --------------------------------------------------------------------------------
1 | label_data = open("label", encoding='utf-8').readlines()
2 | label_data = [x.strip() for x in label_data]
3 | print(len(label_data))
4 | label_kinds = set(label_data)
5 | print(label_kinds)
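# A small extension (a sketch, not part of the original script): beyond printing
# the set of distinct intents, tally how often each one occurs. The ATIS data is
# heavily skewed toward atis_flight, which is worth keeping in mind when reading
# the per-class precision/recall tables logged elsewhere in this repo.
from collections import Counter

label_counts = Counter(label_data)
for label, count in label_counts.most_common():
    print(label, count)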
-------------------------------------------------------------------------------- /data/atis_Intent_Detection_and_Slot_Filling/valid/label: -------------------------------------------------------------------------------- 1 | atis_flight 2 | atis_flight 3 | atis_flight 4 | atis_flight 5 | atis_flight 6 | atis_flight 7 | atis_flight 8 | atis_flight 9 | atis_flight 10 | atis_flight 11 | atis_flight 12 | atis_airfare 13 | atis_flight 14 | atis_airfare 15 | atis_flight 16 | atis_flight 17 | atis_flight 18 | atis_flight 19 | atis_restriction 20 | atis_ground_service 21 | atis_abbreviation 22 | atis_flight 23 | atis_flight 24 | atis_flight 25 | atis_flight 26 | atis_flight 27 | atis_flight 28 | atis_flight 29 | atis_flight 30 | atis_flight 31 | atis_flight 32 | atis_aircraft 33 | atis_flight 34 | atis_flight 35 | atis_flight 36 | atis_flight 37 | atis_airfare 38 | atis_flight 39 | atis_flight 40 | atis_flight 41 | atis_flight 42 | atis_airline 43 | atis_flight 44 | atis_flight 45 | atis_flight 46 | atis_flight 47 | atis_flight 48 | atis_flight 49 | atis_flight 50 | atis_quantity 51 | atis_flight_time 52 | atis_flight 53 | atis_flight 54 | atis_flight 55 | atis_flight 56 | atis_ground_service 57 | atis_flight 58 | atis_flight 59 | atis_flight 60 | atis_flight_time 61 | atis_flight 62 | atis_flight_time 63 | atis_distance 64 | atis_aircraft 65 | atis_flight 66 | atis_flight#atis_airfare 67 | atis_flight 68 | atis_flight 69 | atis_airfare 70 | atis_flight 71 | atis_airfare 72 | atis_flight 73 | atis_flight 74 | atis_flight 75 | atis_flight 76 | atis_quantity 77 | atis_flight_time 78 | atis_ground_service 79 | atis_flight 80 | atis_flight 81 | atis_flight 82 | atis_flight 83 | atis_flight 84 | atis_ground_service 85 | atis_flight 86 | atis_flight 87 | atis_flight 88 | atis_ground_fare 89 | atis_flight 90 | atis_flight_time 91 | atis_flight 92 | atis_capacity 93 | atis_flight 94 | atis_flight 95 | atis_flight_time 96 | atis_flight 97 | atis_flight 98 | atis_airfare 99 | atis_airfare 100 | atis_flight 101 | atis_flight 102 | atis_flight 103 | atis_flight 104 | atis_flight 105 | atis_flight 106 | atis_flight 107 | atis_flight 108 | atis_flight 109 | atis_flight 110 | atis_flight 111 | atis_flight 112 | atis_flight 113 | atis_flight 114 | atis_flight 115 | atis_flight 116 | atis_flight 117 | atis_flight 118 | atis_flight 119 | atis_airfare 120 | atis_flight 121 | atis_flight 122 | atis_flight 123 | atis_flight 124 | atis_flight 125 | atis_flight 126 | atis_flight 127 | atis_aircraft 128 | atis_abbreviation 129 | atis_airfare 130 | atis_flight 131 | atis_flight 132 | atis_flight 133 | atis_flight 134 | atis_flight 135 | atis_airline 136 | atis_flight 137 | atis_abbreviation 138 | atis_flight 139 | atis_airfare 140 | atis_airfare 141 | atis_airfare 142 | atis_flight 143 | atis_airfare 144 | atis_airfare 145 | atis_flight 146 | atis_flight 147 | atis_airfare 148 | atis_flight 149 | atis_flight 150 | atis_airline 151 | atis_airfare 152 | atis_flight 153 | atis_flight 154 | atis_flight 155 | atis_flight 156 | atis_flight 157 | atis_flight 158 | atis_flight 159 | atis_flight 160 | atis_flight 161 | atis_flight 162 | atis_airfare 163 | atis_airline 164 | atis_quantity 165 | atis_flight 166 | atis_flight 167 | atis_flight 168 | atis_aircraft 169 | atis_flight 170 | atis_flight 171 | atis_city 172 | atis_quantity 173 | atis_flight 174 | atis_flight 175 | atis_flight 176 | atis_flight 177 | atis_flight 178 | atis_flight 179 | atis_flight 180 | atis_airfare 181 | atis_airfare 182 | atis_flight 183 | atis_flight 184 
| atis_flight 185 | atis_flight 186 | atis_flight 187 | atis_flight 188 | atis_flight 189 | atis_airline 190 | atis_airfare 191 | atis_flight 192 | atis_flight 193 | atis_flight 194 | atis_quantity 195 | atis_flight 196 | atis_flight 197 | atis_flight 198 | atis_flight 199 | atis_flight 200 | atis_flight 201 | atis_flight 202 | atis_flight 203 | atis_flight 204 | atis_flight 205 | atis_flight 206 | atis_flight 207 | atis_flight 208 | atis_flight 209 | atis_flight 210 | atis_airfare 211 | atis_aircraft 212 | atis_flight 213 | atis_flight 214 | atis_ground_service 215 | atis_airport 216 | atis_flight 217 | atis_flight 218 | atis_airfare 219 | atis_flight 220 | atis_flight 221 | atis_flight 222 | atis_abbreviation 223 | atis_flight 224 | atis_flight_time 225 | atis_airline 226 | atis_quantity 227 | atis_flight 228 | atis_flight 229 | atis_flight 230 | atis_flight 231 | atis_flight 232 | atis_flight 233 | atis_flight 234 | atis_flight 235 | atis_flight 236 | atis_flight 237 | atis_airfare 238 | atis_flight 239 | atis_ground_service 240 | atis_flight 241 | atis_flight 242 | atis_flight 243 | atis_flight 244 | atis_ground_service 245 | atis_flight 246 | atis_flight 247 | atis_flight 248 | atis_flight 249 | atis_abbreviation 250 | atis_flight 251 | atis_flight_time 252 | atis_flight 253 | atis_flight 254 | atis_abbreviation 255 | atis_aircraft 256 | atis_flight 257 | atis_flight 258 | atis_flight 259 | atis_flight 260 | atis_airfare 261 | atis_airline 262 | atis_flight 263 | atis_flight 264 | atis_aircraft 265 | atis_flight 266 | atis_ground_service 267 | atis_flight 268 | atis_flight 269 | atis_flight 270 | atis_flight 271 | atis_flight 272 | atis_flight_time 273 | atis_flight 274 | atis_flight 275 | atis_ground_service 276 | atis_ground_service 277 | atis_airfare 278 | atis_distance 279 | atis_flight 280 | atis_flight 281 | atis_ground_service 282 | atis_airfare 283 | atis_ground_service 284 | atis_flight 285 | atis_flight 286 | atis_flight 287 | atis_flight 288 | atis_flight 289 | atis_airline 290 | atis_flight 291 | atis_flight 292 | atis_flight 293 | atis_ground_service 294 | atis_abbreviation 295 | atis_flight 296 | atis_flight 297 | atis_flight 298 | atis_flight 299 | atis_flight 300 | atis_aircraft 301 | atis_flight 302 | atis_flight 303 | atis_flight 304 | atis_flight 305 | atis_flight 306 | atis_flight 307 | atis_airfare 308 | atis_flight 309 | atis_flight 310 | atis_flight 311 | atis_flight 312 | atis_flight 313 | atis_flight 314 | atis_flight 315 | atis_flight 316 | atis_flight 317 | atis_flight 318 | atis_flight 319 | atis_flight 320 | atis_flight 321 | atis_flight 322 | atis_flight 323 | atis_abbreviation 324 | atis_flight 325 | atis_airfare 326 | atis_flight 327 | atis_abbreviation 328 | atis_flight 329 | atis_abbreviation 330 | atis_flight 331 | atis_flight 332 | atis_flight 333 | atis_flight 334 | atis_quantity 335 | atis_flight 336 | atis_airfare 337 | atis_airfare 338 | atis_flight 339 | atis_flight 340 | atis_flight 341 | atis_flight 342 | atis_abbreviation 343 | atis_flight 344 | atis_flight 345 | atis_flight 346 | atis_flight 347 | atis_flight 348 | atis_flight 349 | atis_flight 350 | atis_flight 351 | atis_airfare 352 | atis_flight 353 | atis_flight 354 | atis_flight 355 | atis_flight 356 | atis_flight 357 | atis_flight 358 | atis_flight 359 | atis_flight 360 | atis_airline 361 | atis_flight 362 | atis_flight 363 | atis_ground_service 364 | atis_flight 365 | atis_flight 366 | atis_flight 367 | atis_flight 368 | atis_flight 369 | atis_flight 370 | atis_flight 371 | 
atis_flight 372 | atis_flight 373 | atis_flight 374 | atis_flight 375 | atis_flight 376 | atis_airfare 377 | atis_flight 378 | atis_flight 379 | atis_ground_service 380 | atis_airline 381 | atis_flight 382 | atis_ground_service 383 | atis_flight 384 | atis_aircraft 385 | atis_flight 386 | atis_abbreviation 387 | atis_flight 388 | atis_flight 389 | atis_ground_service 390 | atis_flight 391 | atis_airfare 392 | atis_flight 393 | atis_abbreviation 394 | atis_airport 395 | atis_flight 396 | atis_flight 397 | atis_ground_service 398 | atis_flight 399 | atis_abbreviation 400 | atis_flight 401 | atis_ground_service 402 | atis_flight 403 | atis_airline 404 | atis_flight 405 | atis_airline 406 | atis_quantity 407 | atis_flight 408 | atis_flight 409 | atis_flight 410 | atis_flight 411 | atis_abbreviation 412 | atis_flight 413 | atis_airline 414 | atis_airfare 415 | atis_quantity 416 | atis_flight 417 | atis_flight 418 | atis_airfare#atis_flight_time 419 | atis_airline 420 | atis_ground_service 421 | atis_distance 422 | atis_flight 423 | atis_flight 424 | atis_ground_service 425 | atis_flight 426 | atis_flight 427 | atis_flight 428 | atis_flight 429 | atis_flight 430 | atis_flight 431 | atis_abbreviation 432 | atis_flight 433 | atis_flight 434 | atis_flight 435 | atis_flight 436 | atis_flight 437 | atis_flight 438 | atis_flight 439 | atis_flight 440 | atis_flight 441 | atis_flight 442 | atis_airfare 443 | atis_flight 444 | atis_airline 445 | atis_flight 446 | atis_airfare 447 | atis_flight 448 | atis_flight 449 | atis_airfare 450 | atis_flight 451 | atis_airport 452 | atis_flight 453 | atis_flight 454 | atis_flight 455 | atis_flight#atis_airfare 456 | atis_airline 457 | atis_flight 458 | atis_ground_service 459 | atis_flight 460 | atis_flight 461 | atis_flight 462 | atis_flight 463 | atis_flight 464 | atis_flight 465 | atis_flight 466 | atis_flight 467 | atis_ground_service 468 | atis_ground_service 469 | atis_flight 470 | atis_abbreviation 471 | atis_airline 472 | atis_ground_fare 473 | atis_flight 474 | atis_flight 475 | atis_flight 476 | atis_flight 477 | atis_flight 478 | atis_airline 479 | atis_flight 480 | atis_flight 481 | atis_flight 482 | atis_aircraft 483 | atis_flight 484 | atis_ground_fare 485 | atis_aircraft 486 | atis_flight 487 | atis_flight 488 | atis_flight 489 | atis_flight 490 | atis_flight 491 | atis_airfare 492 | atis_quantity 493 | atis_flight 494 | atis_flight 495 | atis_flight 496 | atis_flight 497 | atis_flight 498 | atis_flight 499 | atis_ground_service 500 | atis_flight 501 | -------------------------------------------------------------------------------- /data/snips_Intent_Detection_and_Slot_Filling/valid/label: -------------------------------------------------------------------------------- 1 | AddToPlaylist 2 | AddToPlaylist 3 | AddToPlaylist 4 | AddToPlaylist 5 | AddToPlaylist 6 | AddToPlaylist 7 | AddToPlaylist 8 | AddToPlaylist 9 | AddToPlaylist 10 | AddToPlaylist 11 | AddToPlaylist 12 | AddToPlaylist 13 | AddToPlaylist 14 | AddToPlaylist 15 | AddToPlaylist 16 | AddToPlaylist 17 | AddToPlaylist 18 | AddToPlaylist 19 | AddToPlaylist 20 | AddToPlaylist 21 | AddToPlaylist 22 | AddToPlaylist 23 | AddToPlaylist 24 | AddToPlaylist 25 | AddToPlaylist 26 | AddToPlaylist 27 | AddToPlaylist 28 | AddToPlaylist 29 | AddToPlaylist 30 | AddToPlaylist 31 | AddToPlaylist 32 | AddToPlaylist 33 | AddToPlaylist 34 | AddToPlaylist 35 | AddToPlaylist 36 | AddToPlaylist 37 | AddToPlaylist 38 | AddToPlaylist 39 | AddToPlaylist 40 | AddToPlaylist 41 | AddToPlaylist 42 | AddToPlaylist 43 | 
AddToPlaylist 44 | AddToPlaylist 45 | AddToPlaylist 46 | AddToPlaylist 47 | AddToPlaylist 48 | AddToPlaylist 49 | AddToPlaylist 50 | AddToPlaylist 51 | AddToPlaylist 52 | AddToPlaylist 53 | AddToPlaylist 54 | AddToPlaylist 55 | AddToPlaylist 56 | AddToPlaylist 57 | AddToPlaylist 58 | AddToPlaylist 59 | AddToPlaylist 60 | AddToPlaylist 61 | AddToPlaylist 62 | AddToPlaylist 63 | AddToPlaylist 64 | AddToPlaylist 65 | AddToPlaylist 66 | AddToPlaylist 67 | AddToPlaylist 68 | AddToPlaylist 69 | AddToPlaylist 70 | AddToPlaylist 71 | AddToPlaylist 72 | AddToPlaylist 73 | AddToPlaylist 74 | AddToPlaylist 75 | AddToPlaylist 76 | AddToPlaylist 77 | AddToPlaylist 78 | AddToPlaylist 79 | AddToPlaylist 80 | AddToPlaylist 81 | AddToPlaylist 82 | AddToPlaylist 83 | AddToPlaylist 84 | AddToPlaylist 85 | AddToPlaylist 86 | AddToPlaylist 87 | AddToPlaylist 88 | AddToPlaylist 89 | AddToPlaylist 90 | AddToPlaylist 91 | AddToPlaylist 92 | AddToPlaylist 93 | AddToPlaylist 94 | AddToPlaylist 95 | AddToPlaylist 96 | AddToPlaylist 97 | AddToPlaylist 98 | AddToPlaylist 99 | AddToPlaylist 100 | AddToPlaylist 101 | BookRestaurant 102 | BookRestaurant 103 | BookRestaurant 104 | BookRestaurant 105 | BookRestaurant 106 | BookRestaurant 107 | BookRestaurant 108 | BookRestaurant 109 | BookRestaurant 110 | BookRestaurant 111 | BookRestaurant 112 | BookRestaurant 113 | BookRestaurant 114 | BookRestaurant 115 | BookRestaurant 116 | BookRestaurant 117 | BookRestaurant 118 | BookRestaurant 119 | BookRestaurant 120 | BookRestaurant 121 | BookRestaurant 122 | BookRestaurant 123 | BookRestaurant 124 | BookRestaurant 125 | BookRestaurant 126 | BookRestaurant 127 | BookRestaurant 128 | BookRestaurant 129 | BookRestaurant 130 | BookRestaurant 131 | BookRestaurant 132 | BookRestaurant 133 | BookRestaurant 134 | BookRestaurant 135 | BookRestaurant 136 | BookRestaurant 137 | BookRestaurant 138 | BookRestaurant 139 | BookRestaurant 140 | BookRestaurant 141 | BookRestaurant 142 | BookRestaurant 143 | BookRestaurant 144 | BookRestaurant 145 | BookRestaurant 146 | BookRestaurant 147 | BookRestaurant 148 | BookRestaurant 149 | BookRestaurant 150 | BookRestaurant 151 | BookRestaurant 152 | BookRestaurant 153 | BookRestaurant 154 | BookRestaurant 155 | BookRestaurant 156 | BookRestaurant 157 | BookRestaurant 158 | BookRestaurant 159 | BookRestaurant 160 | BookRestaurant 161 | BookRestaurant 162 | BookRestaurant 163 | BookRestaurant 164 | BookRestaurant 165 | BookRestaurant 166 | BookRestaurant 167 | BookRestaurant 168 | BookRestaurant 169 | BookRestaurant 170 | BookRestaurant 171 | BookRestaurant 172 | BookRestaurant 173 | BookRestaurant 174 | BookRestaurant 175 | BookRestaurant 176 | BookRestaurant 177 | BookRestaurant 178 | BookRestaurant 179 | BookRestaurant 180 | BookRestaurant 181 | BookRestaurant 182 | BookRestaurant 183 | BookRestaurant 184 | BookRestaurant 185 | BookRestaurant 186 | BookRestaurant 187 | BookRestaurant 188 | BookRestaurant 189 | BookRestaurant 190 | BookRestaurant 191 | BookRestaurant 192 | BookRestaurant 193 | BookRestaurant 194 | BookRestaurant 195 | BookRestaurant 196 | BookRestaurant 197 | BookRestaurant 198 | BookRestaurant 199 | BookRestaurant 200 | BookRestaurant 201 | GetWeather 202 | GetWeather 203 | GetWeather 204 | GetWeather 205 | GetWeather 206 | GetWeather 207 | GetWeather 208 | GetWeather 209 | GetWeather 210 | GetWeather 211 | GetWeather 212 | GetWeather 213 | GetWeather 214 | GetWeather 215 | GetWeather 216 | GetWeather 217 | GetWeather 218 | GetWeather 219 | GetWeather 220 | GetWeather 221 | GetWeather 
222 | GetWeather 223 | GetWeather 224 | GetWeather 225 | GetWeather 226 | GetWeather 227 | GetWeather 228 | GetWeather 229 | GetWeather 230 | GetWeather 231 | GetWeather 232 | GetWeather 233 | GetWeather 234 | GetWeather 235 | GetWeather 236 | GetWeather 237 | GetWeather 238 | GetWeather 239 | GetWeather 240 | GetWeather 241 | GetWeather 242 | GetWeather 243 | GetWeather 244 | GetWeather 245 | GetWeather 246 | GetWeather 247 | GetWeather 248 | GetWeather 249 | GetWeather 250 | GetWeather 251 | GetWeather 252 | GetWeather 253 | GetWeather 254 | GetWeather 255 | GetWeather 256 | GetWeather 257 | GetWeather 258 | GetWeather 259 | GetWeather 260 | GetWeather 261 | GetWeather 262 | GetWeather 263 | GetWeather 264 | GetWeather 265 | GetWeather 266 | GetWeather 267 | GetWeather 268 | GetWeather 269 | GetWeather 270 | GetWeather 271 | GetWeather 272 | GetWeather 273 | GetWeather 274 | GetWeather 275 | GetWeather 276 | GetWeather 277 | GetWeather 278 | GetWeather 279 | GetWeather 280 | GetWeather 281 | GetWeather 282 | GetWeather 283 | GetWeather 284 | GetWeather 285 | GetWeather 286 | GetWeather 287 | GetWeather 288 | GetWeather 289 | GetWeather 290 | GetWeather 291 | GetWeather 292 | GetWeather 293 | GetWeather 294 | GetWeather 295 | GetWeather 296 | GetWeather 297 | GetWeather 298 | GetWeather 299 | GetWeather 300 | GetWeather 301 | PlayMusic 302 | PlayMusic 303 | PlayMusic 304 | PlayMusic 305 | PlayMusic 306 | PlayMusic 307 | PlayMusic 308 | PlayMusic 309 | PlayMusic 310 | PlayMusic 311 | PlayMusic 312 | PlayMusic 313 | PlayMusic 314 | PlayMusic 315 | PlayMusic 316 | PlayMusic 317 | PlayMusic 318 | PlayMusic 319 | PlayMusic 320 | PlayMusic 321 | PlayMusic 322 | PlayMusic 323 | PlayMusic 324 | PlayMusic 325 | PlayMusic 326 | PlayMusic 327 | PlayMusic 328 | PlayMusic 329 | PlayMusic 330 | PlayMusic 331 | PlayMusic 332 | PlayMusic 333 | PlayMusic 334 | PlayMusic 335 | PlayMusic 336 | PlayMusic 337 | PlayMusic 338 | PlayMusic 339 | PlayMusic 340 | PlayMusic 341 | PlayMusic 342 | PlayMusic 343 | PlayMusic 344 | PlayMusic 345 | PlayMusic 346 | PlayMusic 347 | PlayMusic 348 | PlayMusic 349 | PlayMusic 350 | PlayMusic 351 | PlayMusic 352 | PlayMusic 353 | PlayMusic 354 | PlayMusic 355 | PlayMusic 356 | PlayMusic 357 | PlayMusic 358 | PlayMusic 359 | PlayMusic 360 | PlayMusic 361 | PlayMusic 362 | PlayMusic 363 | PlayMusic 364 | PlayMusic 365 | PlayMusic 366 | PlayMusic 367 | PlayMusic 368 | PlayMusic 369 | PlayMusic 370 | PlayMusic 371 | PlayMusic 372 | PlayMusic 373 | PlayMusic 374 | PlayMusic 375 | PlayMusic 376 | PlayMusic 377 | PlayMusic 378 | PlayMusic 379 | PlayMusic 380 | PlayMusic 381 | PlayMusic 382 | PlayMusic 383 | PlayMusic 384 | PlayMusic 385 | PlayMusic 386 | PlayMusic 387 | PlayMusic 388 | PlayMusic 389 | PlayMusic 390 | PlayMusic 391 | PlayMusic 392 | PlayMusic 393 | PlayMusic 394 | PlayMusic 395 | PlayMusic 396 | PlayMusic 397 | PlayMusic 398 | PlayMusic 399 | PlayMusic 400 | PlayMusic 401 | RateBook 402 | RateBook 403 | RateBook 404 | RateBook 405 | RateBook 406 | RateBook 407 | RateBook 408 | RateBook 409 | RateBook 410 | RateBook 411 | RateBook 412 | RateBook 413 | RateBook 414 | RateBook 415 | RateBook 416 | RateBook 417 | RateBook 418 | RateBook 419 | RateBook 420 | RateBook 421 | RateBook 422 | RateBook 423 | RateBook 424 | RateBook 425 | RateBook 426 | RateBook 427 | RateBook 428 | RateBook 429 | RateBook 430 | RateBook 431 | RateBook 432 | RateBook 433 | RateBook 434 | RateBook 435 | RateBook 436 | RateBook 437 | RateBook 438 | RateBook 439 | RateBook 440 | RateBook 441 | 
RateBook 442 | RateBook 443 | RateBook 444 | RateBook 445 | RateBook 446 | RateBook 447 | RateBook 448 | RateBook 449 | RateBook 450 | RateBook 451 | RateBook 452 | RateBook 453 | RateBook 454 | RateBook 455 | RateBook 456 | RateBook 457 | RateBook 458 | RateBook 459 | RateBook 460 | RateBook 461 | RateBook 462 | RateBook 463 | RateBook 464 | RateBook 465 | RateBook 466 | RateBook 467 | RateBook 468 | RateBook 469 | RateBook 470 | RateBook 471 | RateBook 472 | RateBook 473 | RateBook 474 | RateBook 475 | RateBook 476 | RateBook 477 | RateBook 478 | RateBook 479 | RateBook 480 | RateBook 481 | RateBook 482 | RateBook 483 | RateBook 484 | RateBook 485 | RateBook 486 | RateBook 487 | RateBook 488 | RateBook 489 | RateBook 490 | RateBook 491 | RateBook 492 | RateBook 493 | RateBook 494 | RateBook 495 | RateBook 496 | RateBook 497 | RateBook 498 | RateBook 499 | RateBook 500 | RateBook 501 | SearchCreativeWork 502 | SearchCreativeWork 503 | SearchCreativeWork 504 | SearchCreativeWork 505 | SearchCreativeWork 506 | SearchCreativeWork 507 | SearchCreativeWork 508 | SearchCreativeWork 509 | SearchCreativeWork 510 | SearchCreativeWork 511 | SearchCreativeWork 512 | SearchCreativeWork 513 | SearchCreativeWork 514 | SearchCreativeWork 515 | SearchCreativeWork 516 | SearchCreativeWork 517 | SearchCreativeWork 518 | SearchCreativeWork 519 | SearchCreativeWork 520 | SearchCreativeWork 521 | SearchCreativeWork 522 | SearchCreativeWork 523 | SearchCreativeWork 524 | SearchCreativeWork 525 | SearchCreativeWork 526 | SearchCreativeWork 527 | SearchCreativeWork 528 | SearchCreativeWork 529 | SearchCreativeWork 530 | SearchCreativeWork 531 | SearchCreativeWork 532 | SearchCreativeWork 533 | SearchCreativeWork 534 | SearchCreativeWork 535 | SearchCreativeWork 536 | SearchCreativeWork 537 | SearchCreativeWork 538 | SearchCreativeWork 539 | SearchCreativeWork 540 | SearchCreativeWork 541 | SearchCreativeWork 542 | SearchCreativeWork 543 | SearchCreativeWork 544 | SearchCreativeWork 545 | SearchCreativeWork 546 | SearchCreativeWork 547 | SearchCreativeWork 548 | SearchCreativeWork 549 | SearchCreativeWork 550 | SearchCreativeWork 551 | SearchCreativeWork 552 | SearchCreativeWork 553 | SearchCreativeWork 554 | SearchCreativeWork 555 | SearchCreativeWork 556 | SearchCreativeWork 557 | SearchCreativeWork 558 | SearchCreativeWork 559 | SearchCreativeWork 560 | SearchCreativeWork 561 | SearchCreativeWork 562 | SearchCreativeWork 563 | SearchCreativeWork 564 | SearchCreativeWork 565 | SearchCreativeWork 566 | SearchCreativeWork 567 | SearchCreativeWork 568 | SearchCreativeWork 569 | SearchCreativeWork 570 | SearchCreativeWork 571 | SearchCreativeWork 572 | SearchCreativeWork 573 | SearchCreativeWork 574 | SearchCreativeWork 575 | SearchCreativeWork 576 | SearchCreativeWork 577 | SearchCreativeWork 578 | SearchCreativeWork 579 | SearchCreativeWork 580 | SearchCreativeWork 581 | SearchCreativeWork 582 | SearchCreativeWork 583 | SearchCreativeWork 584 | SearchCreativeWork 585 | SearchCreativeWork 586 | SearchCreativeWork 587 | SearchCreativeWork 588 | SearchCreativeWork 589 | SearchCreativeWork 590 | SearchCreativeWork 591 | SearchCreativeWork 592 | SearchCreativeWork 593 | SearchCreativeWork 594 | SearchCreativeWork 595 | SearchCreativeWork 596 | SearchCreativeWork 597 | SearchCreativeWork 598 | SearchCreativeWork 599 | SearchCreativeWork 600 | SearchCreativeWork 601 | SearchScreeningEvent 602 | SearchScreeningEvent 603 | SearchScreeningEvent 604 | SearchScreeningEvent 605 | SearchScreeningEvent 606 | 
SearchScreeningEvent 607 | SearchScreeningEvent 608 | SearchScreeningEvent 609 | SearchScreeningEvent 610 | SearchScreeningEvent 611 | SearchScreeningEvent 612 | SearchScreeningEvent 613 | SearchScreeningEvent 614 | SearchScreeningEvent 615 | SearchScreeningEvent 616 | SearchScreeningEvent 617 | SearchScreeningEvent 618 | SearchScreeningEvent 619 | SearchScreeningEvent 620 | SearchScreeningEvent 621 | SearchScreeningEvent 622 | SearchScreeningEvent 623 | SearchScreeningEvent 624 | SearchScreeningEvent 625 | SearchScreeningEvent 626 | SearchScreeningEvent 627 | SearchScreeningEvent 628 | SearchScreeningEvent 629 | SearchScreeningEvent 630 | SearchScreeningEvent 631 | SearchScreeningEvent 632 | SearchScreeningEvent 633 | SearchScreeningEvent 634 | SearchScreeningEvent 635 | SearchScreeningEvent 636 | SearchScreeningEvent 637 | SearchScreeningEvent 638 | SearchScreeningEvent 639 | SearchScreeningEvent 640 | SearchScreeningEvent 641 | SearchScreeningEvent 642 | SearchScreeningEvent 643 | SearchScreeningEvent 644 | SearchScreeningEvent 645 | SearchScreeningEvent 646 | SearchScreeningEvent 647 | SearchScreeningEvent 648 | SearchScreeningEvent 649 | SearchScreeningEvent 650 | SearchScreeningEvent 651 | SearchScreeningEvent 652 | SearchScreeningEvent 653 | SearchScreeningEvent 654 | SearchScreeningEvent 655 | SearchScreeningEvent 656 | SearchScreeningEvent 657 | SearchScreeningEvent 658 | SearchScreeningEvent 659 | SearchScreeningEvent 660 | SearchScreeningEvent 661 | SearchScreeningEvent 662 | SearchScreeningEvent 663 | SearchScreeningEvent 664 | SearchScreeningEvent 665 | SearchScreeningEvent 666 | SearchScreeningEvent 667 | SearchScreeningEvent 668 | SearchScreeningEvent 669 | SearchScreeningEvent 670 | SearchScreeningEvent 671 | SearchScreeningEvent 672 | SearchScreeningEvent 673 | SearchScreeningEvent 674 | SearchScreeningEvent 675 | SearchScreeningEvent 676 | SearchScreeningEvent 677 | SearchScreeningEvent 678 | SearchScreeningEvent 679 | SearchScreeningEvent 680 | SearchScreeningEvent 681 | SearchScreeningEvent 682 | SearchScreeningEvent 683 | SearchScreeningEvent 684 | SearchScreeningEvent 685 | SearchScreeningEvent 686 | SearchScreeningEvent 687 | SearchScreeningEvent 688 | SearchScreeningEvent 689 | SearchScreeningEvent 690 | SearchScreeningEvent 691 | SearchScreeningEvent 692 | SearchScreeningEvent 693 | SearchScreeningEvent 694 | SearchScreeningEvent 695 | SearchScreeningEvent 696 | SearchScreeningEvent 697 | SearchScreeningEvent 698 | SearchScreeningEvent 699 | SearchScreeningEvent 700 | SearchScreeningEvent 701 | -------------------------------------------------------------------------------- /output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/.ipynb_checkpoints/model_score_log-checkpoint.txt: -------------------------------------------------------------------------------- 1 | 时间: 2019-12-03 15:18:04 2 | 准确率: 0.9764837625979843 3 | 宏平均精确率: 0.6948830139834943 4 | 微平均精确率: 0.9764837625979843 5 | 加权平均精确率: 0.9751264947939539 6 | 宏平均召回率: 0.713644670245032 7 | 微平均召回率: 0.9764837625979843 8 | 加权平均召回率: 0.9764837625979843 9 | 宏平均F1-score: 0.5568218399380129 10 | 微平均F1-score: 0.9764837625979843 11 | 加权平均F1-score: 0.9741401155452318 12 | 13 | 混淆矩阵输出: 14 | [[ 33 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 | 0 0 0] 16 | [ 0 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 17 | 0 0 1] 18 | [ 0 0 48 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 19 | 0 0 0] 20 | [ 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21 | 0 0 0] 22 | [ 0 0 0 0 38 0 0 0 0 0 0 0 0 0 0 0 0 0 23 | 0 0 0] 24 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 25 | 0 0 0] 26 | [ 0 
/output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/intent_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/intent_label2id.pkl
--------------------------------------------------------------------------------
/output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/model_score_log.txt:
--------------------------------------------------------------------------------
1 | Time: 2019-12-03 15:18:04
2 | Accuracy: 0.9764837625979843
3 | Macro-average precision: 0.6948830139834943
4 | Micro-average precision: 0.9764837625979843
5 | Weighted-average precision: 0.9751264947939539
6 | Macro-average recall: 0.713644670245032
7 | Micro-average recall: 0.9764837625979843
8 | Weighted-average recall: 0.9764837625979843
9 | Macro-average F1-score: 0.5568218399380129
10 | Micro-average F1-score: 0.9764837625979843
11 | Weighted-average F1-score: 0.9741401155452318
12 |
13 | Confusion matrix output:
14 | [[ 33 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 | 0 0 0]
16 | [ 0 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 17 | 0 0 1]
18 | [ 0 0 48 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 19 | 0 0 0]
20 | [ 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21 | 0 0 0]
22 | [ 0 0 0 0 38 0 0 0 0 0 0 0 0 0 0 0 0 0 23 | 0 0 0]
24 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 25 | 0 0 0]
26 | [ 0 0 0 0 0 0 18 0 0 0 0 0 0 0 0 0 0 0 27 | 0 0 0]
28 | [ 0 0 0 0 0 0 0 21 0 0 0 0 0 0 0 0 0 0 29 | 0 0 0]
30 | [ 0 0 0 0 0 0 0 0 6 0 0 0 0 0 0 0 0 0 31 | 0 0 0]
32 | [ 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 33 | 0 0 0]
34 | [ 0 0 0 0 0 0 0 0 0 0 10 0 0 0 0 0 0 0 35 | 0 0 0]
36 | [ 0 0 0 0 0 0 1 0 0 0 0 626 1 0 0 0 0 0 37 | 0 0 4]
38 | [ 0 0 0 0 0 0 0 0 0 0 0 7 5 0 0 0 0 0 39 | 0 0 0]
40 | [ 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 41 | 0 0 0]
42 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 0 0 0 43 | 0 0 0]
44 | [ 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 45 | 0 0 0]
46 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 47 | 0 0 0]
48 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 49 | 1 0 0]
50 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 51 | 36 0 0]
52 | [ 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 53 | 0 5 0]
54 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 55 | 0 0 3]]
56 | Classification report:
57 | precision recall f1-score support
58 |
59 | atis_abbreviation 1.00 1.00 1.00 33
60 | atis_aircraft 1.00 0.89 0.94 9
61 | atis_airfare 0.98 1.00 0.99 48
62 | atis_airfare#atis_flight 0.00 0.00 0.00 1
63 | atis_airline 1.00 1.00 1.00 38
64 | atis_airline#atis_flight_no 0.00 0.00 0.00 0
65 | atis_airport 0.95 1.00 0.97 18
66 | atis_capacity 1.00 1.00 1.00 21
67 | atis_city 1.00 1.00 1.00 6
68 | atis_day_name 0.00 0.00 0.00 2
69 | atis_distance 1.00 1.00 1.00 10
70 | atis_flight 0.98 0.99 0.99 632
71 | atis_flight#atis_airfare 0.83 0.42 0.56 12
72 | atis_flight#atis_airline 0.00 0.00 0.00 1
73 | atis_flight_no 1.00 1.00 1.00 8
74 | atis_flight_no#atis_airline 0.00 0.00 0.00 1
75 | atis_flight_time 0.50 1.00 0.67 1
76 | atis_ground_fare 1.00 0.86 0.92 7
77 | atis_ground_service 0.97 1.00 0.99 36
78 | atis_meal 1.00 0.83 0.91 6
79 | atis_quantity 0.38 1.00 0.55 3
80 |
81 | accuracy 0.98 893
82 | macro avg 0.69 0.71 0.69 893
83 | weighted avg 0.98 0.98 0.97 893
84 |
85 |
86 |
87 | Time: 2019-12-03 15:18:04
88 | Accuracy: 0.9551752241238793
89 | Macro-average precision: 0.8168164504259258
90 | Micro-average precision: 0.9551752241238793
91 | Weighted-average precision: 0.9533279446565879
92 | Macro-average recall: 0.8119494819792376
93 | Micro-average recall: 0.9551752241238793
94 | Weighted-average recall: 0.9551752241238793
95 | Macro-average F1-score: 0.0
96 | Micro-average F1-score: 0.0
97 | Weighted-average F1-score: 0
98 |
99 | Confusion matrix output:
100 | [[ 31 1 0 ... 0 0 0]
101 | [ 0 32 0 ... 0 0 0]
102 | [ 0 0 101 ... 0 0 0]
103 | ...
104 | [ 0 0 0 ... 1 0 0]
105 | [ 0 0 0 ... 0 0 1]
106 | [ 0 4 0 ... 0 0 0]]
107 | Classification report:
108 | precision recall f1-score support
109 |
110 | B-aircraft_code 1.00 0.94 0.97 33
111 | B-airline_code 0.86 0.94 0.90 34
112 | B-airline_name 1.00 1.00 1.00 101
113 | B-airport_code 0.80 0.44 0.57 9
114 | B-airport_name 0.82 0.43 0.56 21
115 | B-arrive_date.date_relative 0.67 1.00 0.80 2
116 | B-arrive_date.day_name 0.79 1.00 0.88 11
117 | B-arrive_date.day_number 0.71 0.83 0.77 6
118 | B-arrive_date.month_name 0.71 0.83 0.77 6
119 | B-arrive_time.end_time 1.00 1.00 1.00 8
120 | B-arrive_time.period_of_day 0.75 1.00 0.86 6
121 | B-arrive_time.start_time 0.89 1.00 0.94 8
122 | B-arrive_time.time 0.94 0.97 0.96 34
123 | B-arrive_time.time_relative 0.94 0.94 0.94 31
124 | B-booking_class 0.00 0.00 0.00 1
125 | B-city_name 0.85 0.58 0.69 57
126 | B-class_type 0.96 1.00 0.98 24
127 | B-compartment 0.00 0.00 0.00 1
128 | B-connect 1.00 1.00 1.00 6
129 | B-cost_relative 1.00 0.97 0.99 37
130 | B-day_name 1.00 0.50 0.67 2
131 | B-days_code 1.00 1.00 1.00 1
132 | B-depart_date.date_relative 1.00 1.00 1.00 17
133 | B-depart_date.day_name 1.00 0.99 0.99 212
134 | B-depart_date.day_number 0.98 0.96 0.97 55
135 | B-depart_date.month_name 0.98 0.96 0.97 56
136 | B-depart_date.today_relative 1.00 0.89 0.94 9
137 | B-depart_date.year 1.00 1.00 1.00 3
138 | B-depart_time.end_time 1.00 1.00 1.00 3
139 | B-depart_time.period_mod 1.00 1.00 1.00 5
140 | B-depart_time.period_of_day 1.00 0.91 0.95 130
141 | B-depart_time.start_time 1.00 0.67 0.80 3
142 | B-depart_time.time 0.88 1.00 0.93 57
143 | B-depart_time.time_relative 0.97 0.98 0.98 65
144 | B-economy 1.00 1.00 1.00 6
145 | B-fare_amount 1.00 1.00 1.00 2
146 | B-fare_basis_code 0.85 1.00 0.92 17
147 | B-flight 0.00 0.00 0.00 1
148 | B-flight_days 1.00 1.00 1.00 10
149 | B-flight_mod 0.83 1.00 0.91 24
150 | B-flight_number 0.85 1.00 0.92 11
151 | B-flight_stop 1.00 1.00 1.00 21
152 | B-flight_time 0.50 1.00 0.67 1
153 | B-fromloc.airport_code 0.50 1.00 0.67 5
154 | B-fromloc.airport_name 0.48 1.00 0.65 12
155 | B-fromloc.city_name 0.99 1.00 0.99 704
156 | B-fromloc.state_code 1.00 1.00 1.00 23
157 | B-fromloc.state_name 0.94 1.00 0.97 17
158 | B-meal 0.94 1.00 0.97 16
159 | B-meal_code 1.00 1.00 1.00 1
160 | B-meal_description 1.00 1.00 1.00 10
161 | B-mod 1.00 0.50 0.67 2
162 | B-or 0.38 1.00 0.55 3
163 | B-period_of_day 1.00 0.75 0.86 4
164 | B-restriction_code 1.00 1.00 1.00 4
165 |
B-return_date.date_relative 0.50 0.33 0.40 3 166 | B-return_date.day_name 1.00 0.50 0.67 2 167 | B-round_trip 1.00 0.97 0.99 73 168 | B-state_code 1.00 1.00 1.00 1 169 | B-state_name 0.00 0.00 0.00 9 170 | B-stoploc.airport_code 0.00 0.00 0.00 1 171 | B-stoploc.city_name 1.00 1.00 1.00 20 172 | B-toloc.airport_code 1.00 0.75 0.86 4 173 | B-toloc.airport_name 1.00 1.00 1.00 3 174 | B-toloc.city_name 0.97 0.99 0.98 716 175 | B-toloc.country_name 1.00 1.00 1.00 1 176 | B-toloc.state_code 1.00 1.00 1.00 18 177 | B-toloc.state_name 0.90 1.00 0.95 28 178 | B-transport_type 1.00 1.00 1.00 10 179 | I-airline_name 1.00 1.00 1.00 65 180 | I-airport_name 0.86 0.41 0.56 29 181 | I-arrive_date.day_number 0.00 0.00 0.00 0 182 | I-arrive_time.end_time 0.89 1.00 0.94 8 183 | I-arrive_time.start_time 1.00 1.00 1.00 1 184 | I-arrive_time.time 1.00 0.97 0.99 35 185 | I-arrive_time.time_relative 1.00 1.00 1.00 4 186 | I-city_name 0.88 0.47 0.61 30 187 | I-class_type 1.00 1.00 1.00 17 188 | I-cost_relative 1.00 0.67 0.80 3 189 | I-depart_date.day_number 1.00 0.93 0.97 15 190 | I-depart_time.end_time 1.00 0.67 0.80 3 191 | I-depart_time.period_of_day 1.00 1.00 1.00 1 192 | I-depart_time.start_time 1.00 1.00 1.00 1 193 | I-depart_time.time 0.93 1.00 0.96 52 194 | I-depart_time.time_relative 0.00 0.00 0.00 1 195 | I-fare_amount 1.00 1.00 1.00 2 196 | I-flight_mod 0.50 0.17 0.25 6 197 | I-flight_number 0.00 0.00 0.00 1 198 | I-flight_time 1.00 1.00 1.00 1 199 | I-fromloc.airport_name 0.45 1.00 0.62 15 200 | I-fromloc.city_name 0.98 1.00 0.99 177 201 | I-fromloc.state_name 1.00 1.00 1.00 1 202 | I-restriction_code 1.00 1.00 1.00 3 203 | I-return_date.date_relative 0.75 1.00 0.86 3 204 | I-round_trip 1.00 1.00 1.00 71 205 | I-state_name 0.00 0.00 0.00 1 206 | I-stoploc.city_name 1.00 1.00 1.00 10 207 | I-toloc.airport_name 1.00 1.00 1.00 3 208 | I-toloc.city_name 0.96 0.99 0.97 265 209 | I-toloc.state_name 1.00 1.00 1.00 1 210 | I-transport_type 0.00 0.00 0.00 1 211 | O 0.00 0.00 0.00 18 212 | 213 | accuracy 0.96 3681 214 | macro avg 0.82 0.81 0.80 3681 215 | weighted avg 0.95 0.96 0.95 3681 216 | 217 | 218 | 219 | -------------------------------------------------------------------------------- /output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/predict.tf_record: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/predict.tf_record -------------------------------------------------------------------------------- /output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/slot_label2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/atis_join_task_LSTM_epoch30_ckpt4198/slot_label2id.pkl -------------------------------------------------------------------------------- /output_model_prediction/atis_join_task_epoch30_ckpt4198/intent_label2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/atis_join_task_epoch30_ckpt4198/intent_label2id.pkl 
--------------------------------------------------------------------------------
/output_model_prediction/atis_join_task_epoch30_ckpt4198/model_score_log.txt:
--------------------------------------------------------------------------------
1 | Time: 2019-12-03 14:05:28
2 | Accuracy: 0.9776035834266518
3 | Macro-average precision: 0.7271219400703015
4 | Micro-average precision: 0.9776035834266518
5 | Weighted-average precision: 0.9743475364586045
6 | Macro-average recall: 0.7548824593128391
7 | Micro-average recall: 0.9776035834266518
8 | Weighted-average recall: 0.9776035834266518
9 | Macro-average F1-score: 0.5590581937752302
10 | Micro-average F1-score: 0.9776035834266518
11 | Weighted-average F1-score: 0.9743986419753211
12 |
13 | Confusion matrix output:
14 | [[ 33 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 | 0 0]
16 | [ 0 9 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 17 | 0 0]
18 | [ 0 0 48 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 19 | 0 0]
20 | [ 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21 | 0 0]
22 | [ 0 0 0 0 38 0 0 0 0 0 0 0 0 0 0 0 0 0 23 | 0 0]
24 | [ 0 0 0 0 0 18 0 0 0 0 0 0 0 0 0 0 0 0 25 | 0 0]
26 | [ 0 0 0 0 0 0 21 0 0 0 0 0 0 0 0 0 0 0 27 | 0 0]
28 | [ 0 0 0 0 0 0 0 6 0 0 0 0 0 0 0 0 0 0 29 | 0 0]
30 | [ 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 31 | 0 0]
32 | [ 0 0 0 0 0 0 0 0 0 10 0 0 0 0 0 0 0 0 33 | 0 0]
34 | [ 0 0 0 0 0 1 0 0 0 0 626 1 0 0 0 0 0 0 35 | 0 4]
36 | [ 0 0 0 0 0 0 0 0 0 0 7 5 0 0 0 0 0 0 37 | 0 0]
38 | [ 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 39 | 0 0]
40 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 8 0 0 0 0 41 | 0 0]
42 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 43 | 0 0]
44 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 45 | 0 0]
46 | [ 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 6 0 47 | 0 0]
48 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 36 49 | 0 0]
50 | [ 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 51 | 5 0]
52 | [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 53 | 0 3]]
54 | Classification report:
55 | precision recall f1-score support
56 |
57 | atis_abbreviation 1.00 1.00 1.00 33
58 | atis_aircraft 1.00 1.00 1.00 9
59 | atis_airfare 0.96 1.00 0.98 48
60 | atis_airfare#atis_flight 0.00 0.00 0.00 1
61 | atis_airline 1.00 1.00 1.00 38
62 | atis_airport 0.95 1.00 0.97 18
63 | atis_capacity 1.00 1.00 1.00 21
64 | atis_city 1.00 1.00 1.00 6
65 | atis_day_name 0.00 0.00 0.00 2
66 | atis_distance 1.00 1.00 1.00 10
67 | atis_flight 0.98 0.99 0.99 632
68 | atis_flight#atis_airfare 0.83 0.42 0.56 12
69 | atis_flight#atis_airline 0.00 0.00 0.00 1
70 | atis_flight_no 0.89 1.00 0.94 8
71 | atis_flight_no#atis_airline 0.00 0.00 0.00 1
72 | atis_flight_time 0.50 1.00 0.67 1
73 | atis_ground_fare 1.00 0.86 0.92 7
74 | atis_ground_service 1.00 1.00 1.00 36
75 | atis_meal 1.00 0.83 0.91 6
76 | atis_quantity 0.43 1.00 0.60 3
77 |
78 | accuracy 0.98 893
79 | macro avg 0.73 0.75 0.73 893
80 | weighted avg 0.97 0.98 0.97 893
81 |
82 |
83 |
84 | Time: 2019-12-03 14:05:28
85 | Accuracy: 0.9559183673469388
86 | Macro-average precision: 0.8178759021037678
87 | Micro-average precision: 0.9559183673469388
88 | Weighted-average precision: 0.9602153329387616
89 | Macro-average recall: 0.7879331267814973
90 | Micro-average recall: 0.9559183673469388
91 | Weighted-average recall: 0.9559183673469388
92 | Macro-average F1-score: 0.0
93 | Micro-average F1-score: 0.0
94 | Weighted-average F1-score: 0
95 |
96 | Confusion matrix output:
97 | [[ 32 0 0 ... 0 0 0]
98 | [ 0 32 0 ... 0 0 0]
99 | [ 0 0 101 ... 0 0 0]
100 | ...
101 | [ 0 0 0 ... 1 0 0]
102 | [ 0 0 0 ... 0 0 1]
103 | [ 0 3 0 ... 0 0 0]]
104 | Classification report:
105 | precision recall f1-score support
106 |
107 | B-aircraft_code 1.00 0.97 0.98 33
108 | B-airline_code 0.86 0.94 0.90 34
109 | B-airline_name 1.00 1.00 1.00 101
110 | B-airport_code 1.00 0.56 0.71 9
111 | B-airport_name 0.80 0.38 0.52 21
112 | B-arrive_date.date_relative 1.00 1.00 1.00 2
113 | B-arrive_date.day_name 0.79 1.00 0.88 11
114 | B-arrive_date.day_number 0.71 0.83 0.77 6
115 | B-arrive_date.month_name 0.71 0.83 0.77 6
116 | B-arrive_time.end_time 0.89 1.00 0.94 8
117 | B-arrive_time.period_of_day 0.75 1.00 0.86 6
118 | B-arrive_time.start_time 0.89 1.00 0.94 8
119 | B-arrive_time.time 0.94 0.97 0.96 34
120 | B-arrive_time.time_relative 0.97 0.97 0.97 31
121 | B-booking_class 0.00 0.00 0.00 1
122 | B-city_name 0.89 0.58 0.70 57
123 | B-class_type 0.96 1.00 0.98 24
124 | B-compartment 0.00 0.00 0.00 1
125 | B-connect 1.00 1.00 1.00 6
126 | B-cost_relative 1.00 0.97 0.99 37
127 | B-day_name 1.00 0.50 0.67 2
128 | B-days_code 1.00 1.00 1.00 1
129 | B-depart_date.date_relative 0.94 1.00 0.97 17
130 | B-depart_date.day_name 1.00 0.99 0.99 212
131 | B-depart_date.day_number 0.98 0.96 0.97 55
132 | B-depart_date.month_name 0.98 0.96 0.97 56
133 | B-depart_date.today_relative 1.00 0.89 0.94 9
134 | B-depart_date.year 1.00 1.00 1.00 3
135 | B-depart_time.end_time 1.00 0.67 0.80 3
136 | B-depart_time.period_mod 0.83 1.00 0.91 5
137 | B-depart_time.period_of_day 1.00 0.92 0.96 130
138 | B-depart_time.start_time 1.00 0.67 0.80 3
139 | B-depart_time.time 0.88 1.00 0.93 57
140 | B-depart_time.time_relative 0.98 0.98 0.98 65
141 | B-economy 1.00 1.00 1.00 6
142 | B-fare_amount 1.00 1.00 1.00 2
143 | B-fare_basis_code 0.85 1.00 0.92 17
144 | B-flight 0.00 0.00 0.00 1
145 | B-flight_days 1.00 1.00 1.00 10
146 | B-flight_mod 1.00 1.00 1.00 24
147 | B-flight_number 0.85 1.00 0.92 11
148 | B-flight_stop 1.00 1.00 1.00 21
149 | B-flight_time 0.50 1.00 0.67 1
150 | B-fromloc.airport_code 0.56 1.00 0.71 5
151 | B-fromloc.airport_name 0.46 1.00 0.63 12
152 | B-fromloc.city_name 0.99 1.00 0.99 704
153 | B-fromloc.state_code 1.00 1.00 1.00 23
154 | B-fromloc.state_name 0.94 1.00 0.97 17
155 | B-meal 0.94 1.00 0.97 16
156 | B-meal_code 0.00 0.00 0.00 1
157 | B-meal_description 1.00 1.00 1.00 10
158 | B-mod 1.00 0.50 0.67 2
159 | B-or 0.50 1.00 0.67 3
160 | B-period_of_day 1.00 0.50 0.67 4
161 | B-restriction_code 1.00 1.00 1.00 4
162 | B-return_date.date_relative 0.50 0.33 0.40 3
163 | B-return_date.day_name 1.00 0.50 0.67 2
164 | B-round_trip 1.00 0.97 0.99 73
165 | B-state_code 1.00 1.00 1.00 1
166 | B-state_name 1.00 0.22 0.36 9
167 | B-stoploc.airport_code 0.00 0.00 0.00 1
168 | B-stoploc.city_name 0.91 1.00 0.95 20
169 | B-toloc.airport_code 1.00 0.75 0.86 4
170 | B-toloc.airport_name 1.00 1.00 1.00 3
171 | B-toloc.city_name 0.97 0.99 0.98 716
172 | B-toloc.country_name 1.00 1.00 1.00 1
173 | B-toloc.state_code 1.00 1.00 1.00 18
174 | B-toloc.state_name 0.93 0.93 0.93 28
175 | B-transport_type 1.00 1.00 1.00 10
176 | I-airline_name 1.00 1.00 1.00 65
177 | I-airport_name 0.85 0.38 0.52 29
178 | I-arrive_date.day_number 0.00 0.00 0.00 0
179 | I-arrive_time.end_time 0.89 1.00 0.94 8
180 | I-arrive_time.start_time 1.00 1.00 1.00 1
181 | I-arrive_time.time 1.00 0.97 0.99 35
182 | I-arrive_time.time_relative 1.00 1.00 1.00 4
183 | I-city_name 1.00 0.47 0.64 30
184 | I-class_type 1.00 1.00 1.00 17
185 | I-cost_relative 1.00 0.67 0.80 3
186 | I-depart_date.day_number 1.00 0.93 0.97 15
187 | I-depart_date.today_relative 0.00 0.00 0.00 0
188 | I-depart_time.end_time 1.00 0.67 0.80 3
189 | I-depart_time.period_of_day 1.00 1.00 1.00 1
190 | I-depart_time.start_time 1.00 1.00 1.00 1
191 | I-depart_time.time 0.96 1.00 0.98 52
192 | I-depart_time.time_relative 0.00 0.00 0.00 1
193 | I-fare_amount 1.00 1.00 1.00 2
194 | I-flight_mod 1.00 0.17 0.29 6
195 | I-flight_number 0.00 0.00 0.00 1
196 | I-flight_time 1.00 1.00 1.00 1
197 | I-fromloc.airport_name 0.44 1.00 0.61 15
198 | I-fromloc.city_name 0.98 1.00 0.99 177
199 | I-fromloc.state_name 1.00 1.00 1.00 1
200 | I-restriction_code 1.00 1.00 1.00 3
201 | I-return_date.date_relative 0.67 0.67 0.67 3
202 | I-round_trip 1.00 1.00 1.00 71
203 | I-state_name 0.00 0.00 0.00 1
204 | I-stoploc.city_name 0.83 1.00 0.91 10
205 | I-toloc.airport_name 1.00 1.00 1.00 3
206 | I-toloc.city_name 0.96 0.99 0.97 265
207 | I-toloc.state_name 1.00 1.00 1.00 1
208 | I-transport_type 0.00 0.00 0.00 1
209 | O 0.00 0.00 0.00 12
210 |
211 | accuracy 0.96 3675
212 | macro avg 0.82 0.79 0.78 3675
213 | weighted avg 0.96 0.96 0.95 3675
214 |
215 |
216 |
217 |
--------------------------------------------------------------------------------
/output_model_prediction/atis_join_task_epoch30_ckpt4198/predict.tf_record:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/atis_join_task_epoch30_ckpt4198/predict.tf_record
--------------------------------------------------------------------------------
/output_model_prediction/atis_join_task_epoch30_ckpt4198/slot_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/atis_join_task_epoch30_ckpt4198/slot_label2id.pkl
--------------------------------------------------------------------------------
/output_model_prediction/conll2003ner_epoch3_test653ckpt/label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/conll2003ner_epoch3_test653ckpt/label2id.pkl
--------------------------------------------------------------------------------
/output_model_prediction/conll2003ner_epoch3_test653ckpt/predict.tf_record:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/conll2003ner_epoch3_test653ckpt/predict.tf_record
--------------------------------------------------------------------------------
/output_model_prediction/score_summarization.py:
--------------------------------------------------------------------------------
import os

# Collect every entry in the current directory; the per-model result folders
# are named like "atis_join_task_epoch30_ckpt4198", i.e. they end in "ckpt".
file_names = os.listdir()
print(file_names)

with open("model_score_summarization.txt", "w") as score_summarization_f:
    for file_name in file_names:
        if file_name.endswith("ckpt"):
            log_file_path = os.path.join(file_name, "model_score_log.txt")
            if os.path.exists(log_file_path):
                # Write a banner naming the model, then copy its score log verbatim.
                score_summarization_f.write("*" * 100 + "\n")
                score_summarization_f.write("*" * 28 + file_name + "*" * 28 + "\n")
                score_summarization_f.write("*" * 100 + "\n")
                with open(log_file_path) as log_f:
                    for line in log_f:
                        score_summarization_f.write(line)
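# Example usage: this script sits in output_model_prediction/ beside the
# per-model result folders (e.g. atis_join_task_epoch30_ckpt4198). Running
#   python score_summarization.py
# from inside that directory gathers every model_score_log.txt into
# model_score_summarization.txt, one banner-separated section per model.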
--------------------------------------------------------------------------------
/output_model_prediction/snips_join_task_epoch10_test4088ckpt/intent_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/snips_join_task_epoch10_test4088ckpt/intent_label2id.pkl
--------------------------------------------------------------------------------
/output_model_prediction/snips_join_task_epoch10_test4088ckpt/predict.tf_record:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/snips_join_task_epoch10_test4088ckpt/predict.tf_record
--------------------------------------------------------------------------------
/output_model_prediction/snips_join_task_epoch10_test4088ckpt/slot_label2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/BERT-for-Sequence-Labeling-and-Text-Classification/2a6d2f9c732a362458030643e131540e7d1cdcca/output_model_prediction/snips_join_task_epoch10_test4088ckpt/slot_label2id.pkl
--------------------------------------------------------------------------------
/predefined_task_usage.md:
--------------------------------------------------------------------------------
# BERT-for-Sequence-Labeling-and-Text-Classification

## BERT information

Take uncased_L-12_H-768_A-12 as an example; it consists of the following three files (a download sketch follows the list):
+ uncased_L-12_H-768_A-12/vocab.txt
+ uncased_L-12_H-768_A-12/bert_config.json
+ uncased_L-12_H-768_A-12/bert_model.ckpt
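If the checkpoint is not under pretrained_model/ yet, a minimal fetch-and-unpack sketch looks like this (the URL is the one published in the official BERT release for BERT-Base, Uncased; adjust it for other model sizes):

```python
import urllib.request
import zipfile

# Official release archive for uncased_L-12_H-768_A-12 (BERT-Base, Uncased).
url = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
urllib.request.urlretrieve(url, "uncased_L-12_H-768_A-12.zip")

# Unpack into pretrained_model/ so that the --vocab_file, --bert_config_file,
# and --init_checkpoint flags used in the commands below resolve correctly.
with zipfile.ZipFile("uncased_L-12_H-768_A-12.zip") as archive:
    archive.extractall("pretrained_model")
```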
## Sequence Labeling Task

> Examples of model training usage

### ATIS
python run_sequence_labeling.py \
--task_name="atis" \
--do_train=True \
--do_eval=True \
--do_predict=True \
--data_dir=data/atis_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=./output_model/atis_Slot_Filling_epoch3/
### SNIPS
python run_sequence_labeling.py \
--task_name="snips" \
--do_train=True \
--do_eval=True \
--do_predict=True \
--data_dir=data/snips_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=./output_model/snips_Slot_Filling_epochs3/
### CoNLL2003NER
python run_sequence_labeling.py \
--task_name="conll2003ner" \
--do_train=True \
--do_eval=True \
--do_predict=True \
--data_dir=data/CoNLL2003_NER \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=./output_model/conll2003ner_epoch3/
## Sequence Labeling Task Prediction
python run_sequence_labeling.py \
--task_name="conll2003ner" \
--do_predict=True \
--data_dir=data/CoNLL2003_NER \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=output_model/conll2003ner_epoch3/model.ckpt-653 \
--output_dir=./output_predict/conll2003ner_epoch3_ckpt653/
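Each prediction run also writes its label mapping into --output_dir as label2id.pkl (intent_label2id.pkl and slot_label2id.pkl for the joint tasks; see output_model_prediction/ for examples). A minimal sketch for turning predicted ids back into label strings, assuming the pickle holds a plain label-to-id dict as the file name suggests:

```python
import pickle

# Path produced by the prediction command above (hypothetical run).
with open("output_predict/conll2003ner_epoch3_ckpt653/label2id.pkl", "rb") as f:
    label2id = pickle.load(f)

# Invert the mapping so that model output ids can be decoded into labels.
id2label = {idx: label for label, idx in label2id.items()}
print(sorted(id2label.items())[:5])
```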
## Text Classification Training

### ATIS Train
python run_text_classification.py \
--task_name=atis \
--do_train=true \
--do_eval=true \
--data_dir=data/atis_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=./output_model/atis_Intent_Detection_epochs3/
### ATIS Prediction
python run_text_classification.py \
--task_name=atis \
--do_predict=true \
--data_dir=data/atis_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=output_model/atis_Intent_Detection_epochs3/model.ckpt-419 \
--max_seq_length=128 \
--output_dir=./output_predict/atis_Intent_Detection_epoch3_ckpt419
### SNIPS Prediction
python run_text_classification.py \
--task_name=Snips \
--do_predict=true \
--data_dir=data/snips_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=output_model/snips_Intent_Detection_epochs3/model.ckpt-1226 \
--max_seq_length=128 \
--output_dir=./output_predict/snips_Intent_Detection_epoch3_ckpt1226/
## Joint Task Training

### SNIPS Train
python run_sequence_labeling_and_text_classification.py \
--task_name=snips \
--do_train=true \
--do_eval=true \
--data_dir=data/snips_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \
--num_train_epochs=3.0 \
--output_dir=./output_model/snips_join_task_epoch3/
### ATIS Train
python run_sequence_labeling_and_text_classification.py \
--task_name=Atis \
--do_train=true \
--do_eval=true \
--data_dir=data/atis_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \
--num_train_epochs=3.0 \
--output_dir=./output_model/atis_join_task_epoch3/
### ATIS Continued Training (resume from an already fine-tuned checkpoint)
python run_sequence_labeling_and_text_classification.py \
--task_name=Atis \
--do_train=true \
--do_eval=true \
--data_dir=data/atis_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=output_model/atis_join_task_epoch3/model.ckpt-1399 \
--num_train_epochs=3.0 \
--output_dir=./output_model/atis_join_task_epoch6/
## Joint Task Prediction

### SNIPS Prediction
python run_sequence_labeling_and_text_classification.py \
--task_name=Snips \
--do_predict=true \
--data_dir=data/snips_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=output_model/snips_join_task_epoch3/model.ckpt-1000 \
--max_seq_length=128 \
--output_dir=./output_predict/snips_join_task_epoch3_ckpt1000
### ATIS Prediction
python run_sequence_labeling_and_text_classification.py \
--task_name=Atis \
--do_predict=true \
--data_dir=data/atis_Intent_Detection_and_Slot_Filling \
--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \
--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=output_model/atis_join_task_epoch3/model.ckpt-1000 \
--max_seq_length=128 \
--output_dir=./output_predict/atis_join_task_epoch30_ckpt1000
--------------------------------------------------------------------------------
/pretrained_model/uncased_L-12_H-768_A-12/bert_config.json:
--------------------------------------------------------------------------------
1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.7.1 2 | alabaster==0.7.12 3 | allennlp==0.9.0 4 | anaconda-client==1.6.14 5 | anaconda-navigator==1.8.7 6 | anaconda-project==0.8.2 7 | aniso8601==7.0.0 8 | argcomplete==1.9.4 9 | asn1crypto==0.22.0 10 | astor==0.7.1 11 | astroid==2.0.4 12 | astropy==3.0.5 13 | atomicwrites==1.2.1 14 | attrdict==2.0.0 15 | attrs==18.2.0 16 | Automat==0.7.0 17 | aws-xray-sdk==2.2.0 18 | awscli==1.16.39 19 | Babel==2.6.0 20 | backcall==0.1.0 21 | backports.shutil-get-terminal-size==1.0.0 22 | beautifulsoup4==4.6.3 23 | bert==2.2.0 24 | bert-serving-client==1.6.9 25 | bert-serving-server==1.6.9 26 | bert-tensorflow==1.0.1 27 | biscuits==0.1.1 28 | bitarray==0.8.3 29 | bkcharts==0.2 30 | blaze==0.11.3 31 | bleach==1.5.0 32 | blis==0.2.4 33 | bokeh==0.13.0 34 | boto==2.48.0 35 | boto3==1.9.29 36 | botocore==1.12.29 37 | Bottleneck==1.2.1 38 | bx-python==0.8.2 39 | bz2file==0.98 40 | catboost==0.10.4.1 41 | category-encoders==1.3.0 42 |
certifi==2019.9.11 43 | certificates==1.0.4 44 | cffi==1.11.5 45 | chardet==3.0.4 46 | chat==1.0.7.dev136 47 | Click==7.0 48 | cloudpickle==0.6.1 49 | clyent==1.2.2 50 | colorama==0.3.9 51 | conda==4.7.12 52 | conda-build==3.10.5 53 | conda-package-handling==1.6.0 54 | conda-verify==2.0.0 55 | conllu==1.3.1 56 | constantly==15.1.0 57 | contextlib2==0.5.5 58 | cookies==2.2.1 59 | cryptography==2.3.1 60 | cssselect==1.0.3 61 | cycler==0.10.0 62 | cymem==2.0.2 63 | cysignals==1.9.0 64 | Cython==0.29 65 | cytoolz==0.9.0.1 66 | dask==0.19.4 67 | dataclasses==0.7 68 | datashape==0.5.4 69 | de-core-news-sm==2.0.0 70 | decorator==4.3.0 71 | defusedxml==0.5.0 72 | dill==0.2.8.2 73 | distributed==1.23.3 74 | Django==2.1.7 75 | docker==3.5.1 76 | docker-pycreds==0.3.0 77 | docopt==0.6.2 78 | docutils==0.14 79 | e==1.4.5 80 | EasyProcess==0.2.7 81 | ecdsa==0.13 82 | editdistance==0.5.2 83 | en-core-web-lg==2.0.0 84 | en-core-web-sm==2.0.0 85 | entrypoints==0.2.3 86 | enum34==1.1.6 87 | erlastic==2.0.0 88 | et-xmlfile==1.0.1 89 | fast-bert==1.4.4 90 | fastai==1.0.59 91 | fastcache==1.0.2 92 | fastprogress==0.1.21 93 | featuretools==0.3.1 94 | filelock==3.0.4 95 | flaky==3.4.0 96 | Flask==1.1.1 97 | Flask-Cors==3.0.8 98 | Flask-RESTful==0.3.6 99 | Flask-Uploads==0.2.1 100 | ftfy==5.5.1 101 | future==0.16.0 102 | gast==0.2.0 103 | gensim==3.7.2 104 | gevent==1.4.0 105 | gitdb2==2.0.5 106 | GitPython==2.1.11 107 | glob2==0.6 108 | gmpy2==2.0.8 109 | google-pasta==0.1.7 110 | GPUtil==1.4.0 111 | graphviz==0.10 112 | greenlet==0.4.15 113 | grpcio==1.15.0 114 | h5py==2.8.0 115 | heapdict==1.0.0 116 | html5lib==0.9999999 117 | humanize==0.5.1 118 | hyperlink==18.0.0 119 | hyperopt==0.2.1 120 | icc-rt==2019.0 121 | idna==2.7 122 | imageio==2.3.0 123 | imagesize==1.1.0 124 | importlib-metadata==0.6 125 | incremental==17.5.0 126 | intel-openmp==2019.0 127 | ipaddress==1.0.22 128 | ipykernel==5.1.0 129 | ipython==7.0.1 130 | ipython-genutils==0.2.0 131 | ipywidgets==7.4.2 132 | isort==4.3.4 133 | ItsDangerous==1.0.0 134 | jdcal==1.4 135 | jedi==0.13.1 136 | jeepney==0.4 137 | jieba==0.39 138 | Jinja2==2.10.3 139 | jmespath==0.9.3 140 | joblib==0.13.2 141 | json5==0.8.5 142 | jsondiff==1.1.2 143 | jsonnet==0.11.2 144 | jsonpickle==1.0 145 | jsonschema==2.6.0 146 | jupyter==1.0.0 147 | jupyter-client==5.2.3 148 | jupyter-console==6.0.0 149 | jupyter-core==4.4.0 150 | jupyterlab==0.35.2 151 | jupyterlab-launcher==0.13.1 152 | jupyterlab-server==0.2.0 153 | Keras==2.2.4 154 | Keras-Applications==1.0.6 155 | Keras-Preprocessing==1.0.5 156 | keyring==15.1.0 157 | kitchen==1.2.5 158 | kiwisolver==1.0.1 159 | lazy-object-proxy==1.3.1 160 | lightgbm==2.2.1 161 | llvmlite==0.23.1 162 | locket==0.2.0 163 | lxml==4.2.5 164 | Markdown==3.0.1 165 | MarkupSafe==1.0 166 | matchzoo-py==1.0 167 | matplotlib==3.0.0 168 | mccabe==0.6.1 169 | mistune==0.8.4 170 | mkl==2019.0 171 | mkl-fft==1.0.6 172 | mkl-random==1.0.1.1 173 | mock==2.0.0 174 | more-itertools==4.3.0 175 | moto==1.3.6 176 | mpmath==1.0.0 177 | msgpack==0.5.6 178 | msgpack-numpy==0.4.3.2 179 | msgpack-python==0.5.6 180 | multipledispatch==0.6.0 181 | murmurhash==1.0.1 182 | navigator-updater==0.2.1 183 | nbconvert==5.4.0 184 | nbformat==4.4.0 185 | neo4j-driver==1.6.2 186 | neobolt==1.7.4 187 | neotime==1.7.4 188 | neptune-cli==2.8.16 189 | networkx==2.2 190 | nltk==3.4.1 191 | nose==1.3.7 192 | notebook==5.7.0 193 | numba==0.38.0 194 | numexpr==2.6.8 195 | numpy==1.16.3 196 | numpydoc==0.8.0 197 | nvidia-ml-py3==7.352.0 198 | oauthlib==2.1.0 199 | odo==0.5.1 200 | 
olefile==0.45.1 201 | OpenNMT-tf==1.10.1 202 | openpyxl==2.5.9 203 | overrides==1.9 204 | packaging==18.0 205 | pandas==0.24.1 206 | pandocfilters==1.4.2 207 | parsel==1.5.1 208 | parsimonious==0.8.1 209 | parso==0.3.1 210 | partd==0.3.9 211 | path.py==11.5.0 212 | pathlib2==2.3.0 213 | patsy==0.5.0 214 | pbr==3.1.1 215 | pep8==1.7.1 216 | pexpect==4.6.0 217 | pickleshare==0.7.5 218 | Pillow==5.3.0 219 | pinyin==0.4.0 220 | pkginfo==1.4.2 221 | plac==0.9.6 222 | pluggy==0.8.0 223 | ply==3.11 224 | preshed==2.0.1 225 | prometheus-client==0.4.2 226 | prompt-toolkit==2.0.4 227 | protobuf==3.6.1 228 | psutil==5.4.7 229 | ptyprocess==0.5.2 230 | py==1.7.0 231 | py2neo==3.1.2 232 | pyaml==17.12.1 233 | pyasn1==0.4.4 234 | pyasn1-modules==0.2.4 235 | pyBigWig==0.3.11 236 | pycodestyle==2.4.0 237 | pycosat==0.6.3 238 | pycparser==2.19 239 | pycrypto==2.6.1 240 | pycryptodome==3.6.6 241 | pycurl==7.43.0.1 242 | PyDispatcher==2.0.5 243 | pydot==1.2.4 244 | pydot-ng==2.0.0 245 | pyenchant==2.0.0 246 | pyfasttext==0.4.5 247 | pyflakes==2.0.0 248 | Pygments==2.3.1 249 | PyHamcrest==1.9.0 250 | pyhocon==0.3.51 251 | PyJWT==1.6.4 252 | pykwalify==1.5.2 253 | pylint==2.1.1 254 | pymongo==3.7.2 255 | Pympler==0.6 256 | PyMySQL==0.9.3 257 | pyodbc==4.0.24 258 | pyonmttok==1.10.1 259 | pyOpenSSL==17.0.0 260 | pyparsing==2.2.2 261 | pypinyin==0.22.0 262 | PyQt5==5.9.2 263 | PyQt5-sip==4.19.13 264 | pysam==0.15.2 265 | PySocks==1.6.6 266 | pytest==3.9.2 267 | pytest-arraydiff==0.2 268 | pytest-astropy==0.4.0 269 | pytest-doctestplus==0.1.3 270 | pytest-openfiles==0.3.0 271 | pytest-remotedata==0.3.0 272 | python-dateutil==2.7.3 273 | python-jose==2.0.2 274 | python-Levenshtein==0.12.0 275 | pytorch-lamb==1.0.0 276 | pytorch-pretrained-bert==0.6.2 277 | pytorch-transformers==1.1.0 278 | pytz==2018.5 279 | PyVirtualDisplay==0.2.4 280 | PyWavelets==1.0.1 281 | PyYAML==3.12 282 | pyzmq==17.1.2 283 | QtAwesome==0.5.1 284 | qtconsole==4.3.1 285 | QtPy==1.5.2 286 | queuelib==1.5.0 287 | raven==6.9.0 288 | recordtype==1.3 289 | regex==2018.1.10 290 | requests==2.22.0 291 | requests-oauthlib==1.0.0 292 | responses==0.10.1 293 | rope==0.11.0 294 | rouge==0.3.1 295 | rsa==3.4.2 296 | RSeQC==3.0.0 297 | s3fs==0.1.6 298 | s3transfer==0.1.13 299 | sacremoses==0.0.35 300 | scapy==2.4.0 301 | scikit-image==0.14.1 302 | scikit-learn==0.20.3 303 | scipy==1.1.0 304 | Scrapy==1.6.0 305 | seaborn==0.9.0 306 | SecretStorage==3.1.0 307 | selenium==3.141.0 308 | Send2Trash==1.5.0 309 | sentencepiece==0.1.83 310 | service-identity==18.1.0 311 | simplegeneric==0.8.1 312 | singledispatch==3.4.0.3 313 | sip==4.19.8 314 | six==1.10.0 315 | sklearn==0.0 316 | smart-open==1.7.1 317 | smmap2==2.0.5 318 | snowballstemmer==1.2.1 319 | sortedcollections==1.0.1 320 | sortedcontainers==2.0.5 321 | spacy==2.1.9 322 | Sphinx==1.8.1 323 | sphinxcontrib-websupport==1.1.0 324 | spyder==3.2.8 325 | spyder-kernels==1.1.0 326 | SQLAlchemy==1.2.12 327 | sqlparse==0.2.4 328 | srsly==0.2.0 329 | statsmodels==0.9.0 330 | steppy==0.1.15 331 | steppy-toolkit==0.1.13 332 | sympy==1.1.1 333 | tables==3.4.4 334 | tb-nightly==1.14.0a20190603 335 | tbb==2019.0 336 | tbb4py==2019.0 337 | tblib==1.3.2 338 | tensorboard==1.8.0 339 | tensorboardX==1.4 340 | tensorflow==1.8.0 341 | tensorflow-hub==0.1.1 342 | tensorflow-probability==0.4.0 343 | termcolor==1.1.0 344 | terminado==0.8.1 345 | terminaltables==2.1.0 346 | testpath==0.4.2 347 | textblob==0.15.1 348 | tf-estimator-nightly==1.14.0.dev2019060501 349 | thinc==7.0.8 350 | thulac==0.2.0 351 | toolz==0.9.0 352 | 
torch==1.2.0 353 | torchtext==0.4.0 354 | torchvision==0.3.0 355 | tornado==5.1.1 356 | tqdm==4.31.1 357 | traitlets==4.3.2 358 | transformers==2.1.1 359 | Twisted==18.9.0 360 | typed-ast==1.1.0 361 | typing==3.6.6 362 | ujson==1.35 363 | unicodecsv==0.14.1 364 | Unidecode==1.1.1 365 | urllib3==1.22 366 | voluptuous==0.11.5 367 | w3lib==1.20.0 368 | wasabi==0.4.0 369 | wcwidth==0.1.7 370 | webencodings==0.5.1 371 | websocket-client==0.53.0 372 | Werkzeug==0.16.0 373 | wget==3.2 374 | widgetsnbextension==3.4.2 375 | word2number==1.1 376 | wrapt==1.10.11 377 | xgboost==0.80 378 | xlrd==1.0.0 379 | XlsxWriter==1.1.2 380 | xlwt==1.2.0 381 | xmltodict==0.11.0 382 | zhon==1.1.5 383 | zict==0.1.3 384 | zope.interface==4.6.0 385 |
--------------------------------------------------------------------------------
/store_fine_tuned_model/download_url.md:
--------------------------------------------------------------------------------
1 | Download the fine-tuned models from: https://pan.baidu.com/s/1SZkQXP8NrOtZKVEMfDE4bw 2 |
--------------------------------------------------------------------------------