├── .gitignore ├── .idea ├── misc.xml ├── modules.xml ├── pytorch-pretrained-BERT_annotation.iml ├── vcs.xml └── workspace.xml ├── LICENSE ├── MANIFEST.in ├── README.md ├── README_bert.md ├── docker └── Dockerfile ├── download_data ├── download_v1.1.sh └── download_v2.0.sh ├── examples ├── analysic_lic_data.py ├── analysic_pred_zhidao.py ├── analysic_pred_zhidao1.py ├── analysic_squad_data.py ├── clean_result.py ├── create_submit_file.py ├── evaluate-v1.1.py ├── evaluate-v2.0.py ├── extract_features.py ├── run_classifier.py ├── run_lm_finetuning.py ├── run_squad.py ├── run_squad2.py ├── run_squad_zh.py ├── run_swag.py ├── softmax.py ├── squad_v1.1_arch_sample.json ├── test_BertForMaskedLM.py ├── test_BertModel.py ├── test_squad.py ├── test_tokenization.py └── valid_data.py ├── notebooks ├── Comparing-TF-and-PT-models-MLM-NSP.ipynb ├── Comparing-TF-and-PT-models-SQuAD.ipynb └── Comparing-TF-and-PT-models.ipynb ├── pytorch_pretrained_bert ├── __init__.py ├── __main__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── file_utils.cpython-36.pyc │ ├── modeling.cpython-36.pyc │ ├── optimization.cpython-36.pyc │ └── tokenization.cpython-36.pyc ├── convert_tf_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling.py ├── optimization.py └── tokenization.py ├── requirements.txt ├── samples ├── input.txt └── sample_text.txt ├── setup.py └── tests ├── modeling_test.py ├── optimization_test.py └── tokenization_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | examples/transfo_format_after_extract_dataset.py 2 | examples/transfo_format_after_extract_dataset_splitter.py 3 | examples/transfo_format_after_extract_dataset_splitter.py 4 | examples/transfo_format.py 5 | examples/transfo_format_nospace.py 6 | pytorch_pretrained_bert/modeling_ks.py 7 | -------------------------------------------------------------------------------- /.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pytorch-pretrained-BERT_annotation.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Pretrained Bert Annotation 2 | 3 | > This BERT annotation repo is for my personal study. 4 | 5 | - The raw README of PyTorch Pretrained Bert is [here](README_bert.md). 6 | - A very nice [PPT](https://nlp.stanford.edu/seminar/details/lkaiser.pdf) to help understanding. 7 | - Synthetic Self-Training [PPT](https://nlp.stanford.edu/seminar/details/jdevlin.pdf?fbclid=IwAR2TBFCJOeZ9cGhxB-z5cJJ17vHN4W25oWsjI8NqJoTEmlYIYEKG7oh4tlY). 8 | 9 | ## Arch 10 | 11 | The BertModel and BertForMaskedLM arch. 
12 | 13 | #### BertModel Arch 14 | - BertEmbeddings 15 | - word_embeddings: Embedding(30522, 768) 16 | - position_embeddings: Embedding(512, 768) 17 | - token_type_embeddings: Embedding(2, 768) 18 | - LayerNorm: BertLayerNorm() 19 | - dropout: Dropout(p=0.1) 20 | - BertEncoder 21 | - BertLayer: (12 layers) 22 | - BertAttention 23 | - BertSelfAttention 24 | - query: Linear(in_features=768, out_features=768, bias=True) 25 | - key: Linear(in_features=768, out_features=768, bias=True) 26 | - value: Linear(in_features=768, out_features=768, bias=True) 27 | - dropout: Dropout(p=0.1) 28 | - BertSelfOutput 29 | - dense: Linear(in_features=768, out_features=768, bias=True) 30 | - LayerNorm: BertLayerNorm() 31 | - dropout: Dropout(p=0.1) 32 | - BertIntermediate 33 | - dense: Linear(in_features=768, out_features=3072, bias=True) 34 | - activation: gelu 35 | - BertOutput 36 | - dense: Linear(in_features=3072, out_features=768, bias=True) 37 | - LayerNorm: BertLayerNorm() 38 | - dropout: Dropout(p=0.1) 39 | - BertPooler 40 | - dense: Linear(in_features=768, out_features=768, bias=True) 41 | - activation: Tanh() 42 | 43 | #### BertForMaskedLM Arch 44 | - BertModel 45 | - BertEmbeddings 46 | - word_embeddings: Embedding(30522, 768) 47 | - position_embeddings: Embedding(512, 768) 48 | - token_type_embeddings: Embedding(2, 768) 49 | - LayerNorm: BertLayerNorm() 50 | - dropout: Dropout(p=0.1) 51 | - BertEncoder 52 | - BertLayer: (12 layers) 53 | - BertAttention 54 | - BertSelfAttention 55 | - query: Linear(in_features=768, out_features=768, bias=True) 56 | - key: Linear(in_features=768, out_features=768, bias=True) 57 | - value: Linear(in_features=768, out_features=768, bias=True) 58 | - dropout: Dropout(p=0.1) 59 | - BertSelfOutput 60 | - dense: Linear(in_features=768, out_features=768, bias=True) 61 | - LayerNorm: BertLayerNorm() 62 | - dropout: Dropout(p=0.1) 63 | - BertIntermediate 64 | - dense: Linear(in_features=768, out_features=3072, bias=True) 65 | - activation: gelu 66 | - 
BertOutput 67 | - dense: Linear(in_features=3072, out_features=768, bias=True) 68 | - LayerNorm: BertLayerNorm() 69 | - dropout: Dropout(p=0.1) 70 | - BertPooler 71 | - dense: Linear(in_features=768, out_features=768, bias=True) 72 | - activation: Tanh() 73 | - BertOnlyMLMHead 74 | - BertLMPredictionHead 75 | - transform: BertPredictionHeadTransform 76 | - dense: Linear(in_features=768, out_features=768, bias=True) 77 | - LayerNorm: BertLayerNorm() 78 | - decoder: Linear(in_features=768, out_features=30522, bias=False) 79 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext 4 | 5 | RUN pip install pytorch-pretrained-bert 6 | 7 | WORKDIR /workspace -------------------------------------------------------------------------------- /download_data/download_v1.1.sh: -------------------------------------------------------------------------------- 1 | # Download SQuAD1.1 Data 2 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json 3 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json 4 | wget https://raw.githubusercontent.com/allenai/bi-att-flow/master/squad/evaluate-v1.1.py 5 | -------------------------------------------------------------------------------- /download_data/download_v2.0.sh: -------------------------------------------------------------------------------- 1 | # Download the SQuAD2.0 dataset 2 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json 3 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json 4 | wget https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ -O evaluate-v2.0.py 5 | wget 
https://worksheets.codalab.org/rest/bundles/0x8731effab84f41b7b874a070e40f61e2/contents/blob/ -O dev-evaluate-v2.0-in1 6 | -------------------------------------------------------------------------------- /examples/analysic_lic_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | BASE_PATH = "/home/wyb/PycharmProjects/DuReader/data/demo/" 4 | 5 | 6 | with open(BASE_PATH + "trainset/search.train.json", "r", encoding='utf-8') as reader: 7 | source = reader.readlines() 8 | 9 | # source = json.load(reader) 10 | # input_data = source["data"] 11 | # version = source["version"] 12 | 13 | 14 | # print(len(source)) 15 | # print(type(source)) # 16 | 17 | """ 18 | keys: (one documents) 19 | documents 20 | answer_spans 21 | fake_answers 22 | question 23 | segmented_answers 24 | answers 25 | answer_docs 26 | segmented_question 27 | question_type 28 | question_id 29 | fact_or_opinion 30 | match_scores 31 | """ 32 | line_json = json.loads(source[0]) 33 | for i in line_json.keys(): 34 | print(i) 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /examples/analysic_pred_zhidao.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/data/extracted/" 4 | 5 | with open(BASE_PATH + "dureader/predictions.json", "r", encoding="utf-8") as f: 6 | data = json.load(f) 7 | 8 | with open(BASE_PATH + "devset/zhidao.dev.json", "r", encoding='utf-8') as f1: 9 | lines = f1.readlines() 10 | 11 | 12 | def get_best_ans(): 13 | """ 14 | To compare the prob of sents, select the best answers. 
15 | """ 16 | with open(BASE_PATH + "dureader/nbest_predictions.json", "r", encoding='utf-8') as f2: 17 | data_n = json.load(f2) 18 | 19 | nbest_para = [] 20 | for k, v in data_n.items(): 21 | para_dict = {} 22 | id = k.split("_")[0] 23 | prob = 0 24 | text = "" 25 | for sents in v: 26 | if sents["probability"] > prob: 27 | prob = sents["probability"] 28 | text = sents["text"] 29 | para_dict["id"] = id 30 | para_dict["prob"] = prob 31 | para_dict["text"] = text 32 | 33 | if nbest_para: 34 | for item in nbest_para: 35 | if id == item["id"]: 36 | if prob > item["prob"]: 37 | item["prob"] = prob 38 | item["text"] = text 39 | else: 40 | nbest_para.append(para_dict) 41 | else: 42 | nbest_para.append(para_dict) 43 | 44 | return nbest_para 45 | 46 | 47 | nbest_para = get_best_ans() 48 | print("===============> nbest_para completed!") 49 | 50 | for line in lines: # raw 51 | sample = json.loads(line) 52 | ans_list = [] 53 | for k, v in data.items(): # pred 54 | if str(sample["question_id"]) == (str(k)).split("_")[0]: 55 | ans_list.append(v) 56 | print("------------------------------------------------------") 57 | print("question_id: " + (str(k)).split("_")[0]) 58 | if sample["fake_answers"]: 59 | print("fake_answers: \n" + str(sample["fake_answers"][0])) 60 | 61 | print(" ") 62 | 63 | print("answer: count=" + str(len(sample["answers"]))) 64 | for idx,ans_item in enumerate(sample["answers"]): 65 | print(str(idx) + "==> " + ans_item) 66 | 67 | print(" ") 68 | 69 | print("pred answer: count=" + str(len(ans_list))) 70 | answer_docs_id = -1 71 | if "answer_docs" in sample and sample["answer_docs"] and sample["answer_docs"][0] < len(ans_list): 72 | answer_docs_id = sample["answer_docs"][0] 73 | 74 | for idx, ans_item in enumerate(ans_list): 75 | state1 = "" # flag of 'has fake_answers' 76 | state2 = "" # flag of 'best answers' 77 | for ans in nbest_para: 78 | if ans_item == ans["text"]: 79 | state2 = "(pred BEST answers)" 80 | 81 | if idx == answer_docs_id: 82 | state1 = "(has 
fake_answers)" 83 | 84 | print(str(idx) + "==>" + state1 + state2 + " " + ans_item) 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /examples/analysic_pred_zhidao1.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/data/extracted/" 4 | 5 | 6 | with open(BASE_PATH + "dureader/predictions.json", "r", encoding="utf-8") as f: 7 | data = json.load(f) 8 | 9 | with open(BASE_PATH + "dureader/predictions_filter.json", "r", encoding="utf-8") as f: 10 | data_filter = json.load(f) 11 | 12 | with open(BASE_PATH + "devset/zhidao.dev.json", "r", encoding='utf-8') as f1: 13 | lines = f1.readlines() 14 | 15 | # nbest_para = get_best_ans() 16 | 17 | for line in lines: # raw 18 | sample = json.loads(line) 19 | ans_list = [] 20 | for k, v in data.items(): # pred 21 | if str(sample["question_id"]) == (str(k)).split("_")[0]: 22 | ans_list.append(v[0]) 23 | print("------------------------------------------------------") 24 | print("question_id: " + (str(k)).split("_")[0]) 25 | if sample["fake_answers"]: 26 | print("fake_answers: \n" + str(sample["fake_answers"][0])) 27 | 28 | print(" ") 29 | 30 | print("answer: count=" + str(len(sample["answers"]))) 31 | for idx,ans_item in enumerate(sample["answers"]): 32 | print(str(idx) + "==> " + ans_item) 33 | 34 | print(" ") 35 | 36 | print("pred answer: count=" + str(len(ans_list))) 37 | answer_docs_id = -1 38 | if "answer_docs" in sample and sample["answer_docs"] and sample["answer_docs"][0] < len(ans_list): 39 | answer_docs_id = sample["answer_docs"][0] 40 | 41 | for idx, ans_item in enumerate(ans_list): 42 | state1 = "" # flag of 'has fake_answers' 43 | state2 = "" # flag of 'best answers' 44 | for k1, v1 in data_filter.items(): 45 | if v1 == ans_item: 46 | state2 = "(pred BEST answers)" 47 | 48 | if idx == answer_docs_id: 49 | state1 = "(has fake_answers)" 
50 | 51 | print(str(idx) + "==>" + state1 + state2 + " " + ans_item) 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /examples/analysic_squad_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | BASE_PATH = "/home/wyb/data/squad_v2.0/" 4 | 5 | 6 | with open(BASE_PATH + "train-v2.0.json", "r", encoding='utf-8') as reader: 7 | source = json.load(reader) 8 | input_data = source["data"] 9 | version = source["version"] 10 | # 11 | # 12 | # examples = [] 13 | # for entry in input_data: 14 | # """ 15 | # entry format: 16 | # {"title": xxx, "paragraphs": xxxx} 17 | # """ 18 | # for paragraph in entry["paragraphs"]: 19 | # 20 | # paragraph_text = paragraph["context"] 21 | 22 | 23 | paragraph_text = 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say).' 24 | doc_tokens = [] 25 | char_to_word_offset = [] 26 | prev_is_whitespace = True 27 | for c in paragraph_text: # by char 28 | if is_whitespace(c): 29 | prev_is_whitespace = True 30 | else: 31 | if prev_is_whitespace: 32 | doc_tokens.append(c) 33 | else: 34 | doc_tokens[-1] += c 35 | prev_is_whitespace = False 36 | char_to_word_offset.append(len(doc_tokens) - 1) 37 | 38 | print(doc_tokens) 39 | print("----------------") 40 | print(char_to_word_offset) 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /examples/clean_result.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | 4 | BASE_PATH = "/home/wyb/Downloads/" 5 | 6 | 7 | with open(BASE_PATH + "test_result.json", "r", encoding="utf-8") as f: 8 | res = f.readlines() # list, len=120000 9 | 10 | # text = "This is a \n file \r that \r hello\r!" 
11 | 12 | 13 | def clean_sepc_char(text): 14 | replace_p = ["\t", "\n", "\r", "\u3000", "", "/>", "\\x0a", ""\'' 33 | C_pun = u',。!?【】()《》“‘' 34 | table = {ord(f): ord(t) for f, t in zip(E_pun, C_pun)} 35 | 36 | return string.translate(table) 37 | 38 | 39 | # for i in res: 40 | # data = json.loads(i) 41 | # if "&" in data["answers"][0]: 42 | # print(data) 43 | 44 | 45 | # cate_type = set() # {'DESCRIPTION', 'YES_NO', 'ENTITY'} 46 | # for i in res: 47 | # data = json.loads(i) 48 | # if data["question_type"] == "YES_NO": 49 | # print(data) 50 | 51 | 52 | # text = '小箭头。

iiiiiiiiiiiiiiiiiiiii

2.点击小箭头,则就是筛选。' 53 | 54 | 55 | def remove_html(text): 56 | reg = re.compile(r'<[^>]+>', re.S) 57 | text = reg.sub('', text) 58 | 59 | return text 60 | 61 | 62 | """ 63 | { 64 | "question_id": 403770, 65 | "question_type": "YES_NO", 66 | "answers": ["我都是免费几分钟测试可以玩而已。"], 67 | "entity_answers": [[]], 68 | "yesno_answers": [] 69 | } 70 | """ 71 | json_list = [] 72 | for i in res: 73 | item_dict = {} 74 | 75 | data = json.loads(i) 76 | text = data["answers"][0] 77 | text = E_trans_to_C(text) 78 | text = clean_sepc_char(text) 79 | text = remove_html(text) 80 | 81 | item_dict["question_id"] = data["question_id"] 82 | item_dict["question_type"] = data["question_type"] 83 | item_dict["answers"] = [text] 84 | item_dict["entity_answers"] = data["entity_answers"] 85 | item_dict["yesno_answers"] = data["yesno_answers"] 86 | 87 | json_list.append(item_dict) 88 | 89 | # ================ write to file ================ 90 | with open(BASE_PATH + "test_result_rm.json", 'w') as fout: 91 | for pred_answer in json_list: 92 | fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n') 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /examples/create_submit_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/official_data/extracted/" 4 | 5 | """ 6 | /DATA/disk1/wangyongbo/lic2019/DuReader/official_data/extracted/test1set 7 | 8 | { 9 | "question_id": 397032, 10 | "question_type": "ENTITY", 11 | "answers": ["浙江绿谷,秀山丽水。"], 12 | "entity_answers": [[]], 13 | "yesno_answers": [] 14 | } 15 | """ 16 | datasets = ["search", "zhidao"] 17 | for dataset in datasets: 18 | with open(BASE_PATH + "results/predictions_first_filter_" + dataset + ".json", "r") as f: 19 | data_n = json.load(f) 20 | 21 | with open(BASE_PATH + "test1set/" + dataset + ".test1.json", "r", encoding="utf-8") as f: 22 | lines = f.readlines() 
23 | 24 | res = [] 25 | test_ids = [] 26 | pred_search_ids = [] 27 | for line in lines: 28 | line_json = json.loads(line) 29 | # avoid loss sample in predictions, save all ids to a list. 30 | test_ids.append(int(line_json["question_id"])) 31 | for k, v in data_n: 32 | pred_search_ids.append(int(k)) 33 | if str(line_json["question_id"]) == str(k): 34 | res_line = {} 35 | res_line["question_id"] = int(k) 36 | res_line["question_type"] = line_json["question_type"] 37 | res_line["answers"] = [v] 38 | res_line["entity_answers"] = [[]] 39 | res_line["yesno_answers"] = [] 40 | res.append(res_line) 41 | 42 | if len(res) != 30000: 43 | # fill in loss sample with "" (no answer) 44 | for id in test_ids: 45 | if id not in pred_search_ids: 46 | for line in lines: 47 | line_json = json.loads(line) 48 | if str(line_json["question_id"]) == str(id): 49 | res_line = {} 50 | res_line["question_id"] = int(id) 51 | res_line["question_type"] = line_json["question_type"] 52 | res_line["answers"] = [""] 53 | res_line["entity_answers"] = [[]] 54 | res_line["yesno_answers"] = [] 55 | res.append(res_line) 56 | 57 | with open(BASE_PATH + "results/test_result_" + dataset + ".json", 'w') as fout: 58 | for pred_answer in res: 59 | fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n') 60 | -------------------------------------------------------------------------------- /examples/evaluate-v1.1.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. 
""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | """ 10 | Exec: 11 | python evaluate-v1.1.py ./dev-v1.1.json /tmp/debug_squad/predictions.json 12 | 13 | Results: 14 | {"exact_match": 80.49195837275307, "f1": 88.05701702878619} 15 | """ 16 | 17 | 18 | def normalize_answer(s): 19 | """Lower text and remove punctuation, articles and extra whitespace.""" 20 | def remove_articles(text): 21 | return re.sub(r'\b(a|an|the)\b', ' ', text) 22 | 23 | def white_space_fix(text): 24 | return ' '.join(text.split()) 25 | 26 | def remove_punc(text): 27 | exclude = set(string.punctuation) 28 | return ''.join(ch for ch in text if ch not in exclude) 29 | 30 | def lower(text): 31 | return text.lower() 32 | 33 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 34 | 35 | 36 | def f1_score(prediction, ground_truth): 37 | prediction_tokens = normalize_answer(prediction).split() 38 | ground_truth_tokens = normalize_answer(ground_truth).split() 39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 40 | num_same = sum(common.values()) 41 | if num_same == 0: 42 | return 0 43 | precision = 1.0 * num_same / len(prediction_tokens) 44 | recall = 1.0 * num_same / len(ground_truth_tokens) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1 47 | 48 | 49 | def exact_match_score(prediction, ground_truth): 50 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 51 | 52 | 53 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 54 | scores_for_ground_truths = [] 55 | for ground_truth in ground_truths: 56 | score = metric_fn(prediction, ground_truth) 57 | scores_for_ground_truths.append(score) 58 | return max(scores_for_ground_truths) 59 | 60 | 61 | def evaluate(dataset, predictions): 62 | f1 = exact_match = total = 0 63 | for article in dataset: 64 | for paragraph in article['paragraphs']: 65 
| for qa in paragraph['qas']: 66 | total += 1 67 | if qa['id'] not in predictions: 68 | message = 'Unanswered question ' + qa['id'] + \ 69 | ' will receive score 0.' 70 | print(message, file=sys.stderr) 71 | continue 72 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 73 | prediction = predictions[qa['id']] 74 | exact_match += metric_max_over_ground_truths( 75 | exact_match_score, prediction, ground_truths) 76 | f1 += metric_max_over_ground_truths( 77 | f1_score, prediction, ground_truths) 78 | 79 | exact_match = 100.0 * exact_match / total 80 | f1 = 100.0 * f1 / total 81 | 82 | return {'exact_match': exact_match, 'f1': f1} 83 | 84 | 85 | if __name__ == '__main__': 86 | expected_version = '1.1' 87 | parser = argparse.ArgumentParser( 88 | description='Evaluation for SQuAD ' + expected_version) 89 | parser.add_argument('dataset_file', help='Dataset file') 90 | parser.add_argument('prediction_file', help='Prediction File') 91 | args = parser.parse_args() 92 | with open(args.dataset_file) as dataset_file: 93 | dataset_json = json.load(dataset_file) 94 | if (dataset_json['version'] != expected_version): 95 | print('Evaluation expects v-' + expected_version + 96 | ', but got dataset with v-' + dataset_json['version'], 97 | file=sys.stderr) 98 | dataset = dataset_json['data'] 99 | with open(args.prediction_file) as prediction_file: 100 | predictions = json.load(prediction_file) 101 | print(json.dumps(evaluate(dataset, predictions))) -------------------------------------------------------------------------------- /examples/evaluate-v2.0.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | 20 | def parse_args(): 21 | """ 22 | python evaluate-v2.0.py 23 | 24 | bert-base-uncased 25 | 26 | EXEC: 27 | python evaluate-v2.0.py ./dev-v2.0.json /tmp/debug_squad2/predictions.json 28 | RES: 29 | { 30 | "exact": 70.78244757011707, 31 | "f1": 74.11532024041503, 32 | "total": 11873, 33 | "HasAns_exact": 71.72739541160594, 34 | "HasAns_f1": 78.40269858543304, 35 | "HasAns_total": 5928, 36 | "NoAns_exact": 69.84020185029436, 37 | "NoAns_f1": 69.84020185029436, 38 | "NoAns_total": 5945 39 | } 40 | """ 41 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 42 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 43 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 44 | parser.add_argument('--out-file', '-o', metavar='eval.json', 45 | help='Write accuracy metrics to file (default is stdout).') 46 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 47 | help='Model estimates of probability of no answer.') 48 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 49 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 50 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 51 | help='Save precision-recall curves to directory.') 52 | parser.add_argument('--verbose', '-v', action='store_true') 53 | if len(sys.argv) == 1: 54 | parser.print_help() 55 | sys.exit(1) 56 | return parser.parse_args() 57 | 58 | 59 | def make_qid_to_has_ans(dataset): 60 | qid_to_has_ans = {} 61 | for article in dataset: 62 | for p in article['paragraphs']: 63 | for qa in p['qas']: 64 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 65 | return qid_to_has_ans 66 | 67 | def normalize_answer(s): 68 | """Lower text and remove 
punctuation, articles and extra whitespace.""" 69 | def remove_articles(text): 70 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 71 | return re.sub(regex, ' ', text) 72 | def white_space_fix(text): 73 | return ' '.join(text.split()) 74 | def remove_punc(text): 75 | exclude = set(string.punctuation) 76 | return ''.join(ch for ch in text if ch not in exclude) 77 | def lower(text): 78 | return text.lower() 79 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 80 | 81 | def get_tokens(s): 82 | if not s: return [] 83 | return normalize_answer(s).split() 84 | 85 | def compute_exact(a_gold, a_pred): 86 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 87 | 88 | def compute_f1(a_gold, a_pred): 89 | gold_toks = get_tokens(a_gold) 90 | pred_toks = get_tokens(a_pred) 91 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 92 | num_same = sum(common.values()) 93 | if len(gold_toks) == 0 or len(pred_toks) == 0: 94 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 95 | return int(gold_toks == pred_toks) 96 | if num_same == 0: 97 | return 0 98 | precision = 1.0 * num_same / len(pred_toks) 99 | recall = 1.0 * num_same / len(gold_toks) 100 | f1 = (2 * precision * recall) / (precision + recall) 101 | return f1 102 | 103 | def get_raw_scores(dataset, preds): 104 | exact_scores = {} 105 | f1_scores = {} 106 | for article in dataset: 107 | for p in article['paragraphs']: 108 | for qa in p['qas']: 109 | qid = qa['id'] 110 | gold_answers = [a['text'] for a in qa['answers'] 111 | if normalize_answer(a['text'])] 112 | if not gold_answers: 113 | # For unanswerable questions, only correct answer is empty string 114 | gold_answers = [''] 115 | if qid not in preds: 116 | print('Missing prediction for %s' % qid) 117 | continue 118 | a_pred = preds[qid] 119 | # Take max over all gold answers 120 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 121 | f1_scores[qid] = max(compute_f1(a, a_pred) for 
a in gold_answers) 122 | return exact_scores, f1_scores 123 | 124 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 125 | new_scores = {} 126 | for qid, s in scores.items(): 127 | pred_na = na_probs[qid] > na_prob_thresh 128 | if pred_na: 129 | new_scores[qid] = float(not qid_to_has_ans[qid]) 130 | else: 131 | new_scores[qid] = s 132 | return new_scores 133 | 134 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 135 | if not qid_list: 136 | total = len(exact_scores) 137 | return collections.OrderedDict([ 138 | ('exact', 100.0 * sum(exact_scores.values()) / total), 139 | ('f1', 100.0 * sum(f1_scores.values()) / total), 140 | ('total', total), 141 | ]) 142 | else: 143 | total = len(qid_list) 144 | return collections.OrderedDict([ 145 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 146 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 147 | ('total', total), 148 | ]) 149 | 150 | def merge_eval(main_eval, new_eval, prefix): 151 | for k in new_eval: 152 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 153 | 154 | def plot_pr_curve(precisions, recalls, out_image, title): 155 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 156 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 157 | plt.xlabel('Recall') 158 | plt.ylabel('Precision') 159 | plt.xlim([0.0, 1.05]) 160 | plt.ylim([0.0, 1.05]) 161 | plt.title(title) 162 | plt.savefig(out_image) 163 | plt.clf() 164 | 165 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 166 | out_image=None, title=None): 167 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 168 | true_pos = 0.0 169 | cur_p = 1.0 170 | cur_r = 0.0 171 | precisions = [1.0] 172 | recalls = [0.0] 173 | avg_prec = 0.0 174 | for i, qid in enumerate(qid_list): 175 | if qid_to_has_ans[qid]: 176 | true_pos += scores[qid] 177 | cur_p = true_pos / float(i+1) 178 | cur_r = true_pos / float(num_true_pos) 179 | if i == 
len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 180 | # i.e., if we can put a threshold after this point 181 | avg_prec += cur_p * (cur_r - recalls[-1]) 182 | precisions.append(cur_p) 183 | recalls.append(cur_r) 184 | if out_image: 185 | plot_pr_curve(precisions, recalls, out_image, title) 186 | return {'ap': 100.0 * avg_prec} 187 | 188 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 189 | qid_to_has_ans, out_image_dir): 190 | if out_image_dir and not os.path.exists(out_image_dir): 191 | os.makedirs(out_image_dir) 192 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 193 | if num_true_pos == 0: 194 | return 195 | pr_exact = make_precision_recall_eval( 196 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 197 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 198 | title='Precision-Recall curve for Exact Match score') 199 | pr_f1 = make_precision_recall_eval( 200 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 201 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 202 | title='Precision-Recall curve for F1 score') 203 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 204 | pr_oracle = make_precision_recall_eval( 205 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 206 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 207 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 208 | merge_eval(main_eval, pr_exact, 'pr_exact') 209 | merge_eval(main_eval, pr_f1, 'pr_f1') 210 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 211 | 212 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 213 | if not qid_list: 214 | return 215 | x = [na_probs[k] for k in qid_list] 216 | weights = np.ones_like(x) / float(len(x)) 217 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 218 | plt.xlabel('Model probability of no-answer') 219 | plt.ylabel('Proportion of dataset') 220 | plt.title('Histogram of no-answer probability: %s' % name) 221 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 222 | plt.clf() 223 | 224 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 225 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 226 | cur_score = num_no_ans 227 | best_score = cur_score 228 | best_thresh = 0.0 229 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 230 | for i, qid in enumerate(qid_list): 231 | if qid not in scores: continue 232 | if qid_to_has_ans[qid]: 233 | diff = scores[qid] 234 | else: 235 | if preds[qid]: 236 | diff = -1 237 | else: 238 | diff = 0 239 | cur_score += diff 240 | if cur_score > best_score: 241 | best_score = cur_score 242 | best_thresh = na_probs[qid] 243 | return 100.0 * best_score / len(scores), best_thresh 244 | 245 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 246 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 247 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 248 | main_eval['best_exact'] = best_exact 249 | main_eval['best_exact_thresh'] = exact_thresh 250 | main_eval['best_f1'] = best_f1 251 | main_eval['best_f1_thresh'] = f1_thresh 252 | 253 | def main(): 254 | with open(OPTS.data_file) as f: 255 | dataset_json = json.load(f) 256 | dataset = dataset_json['data'] 257 | with open(OPTS.pred_file) as f: 258 | preds = json.load(f) 
259 | if OPTS.na_prob_file: 260 | with open(OPTS.na_prob_file) as f: 261 | na_probs = json.load(f) 262 | else: 263 | na_probs = {k: 0.0 for k in preds} 264 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 265 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 266 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 267 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 268 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 269 | OPTS.na_prob_thresh) 270 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 271 | OPTS.na_prob_thresh) 272 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 273 | if has_ans_qids: 274 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 275 | merge_eval(out_eval, has_ans_eval, 'HasAns') 276 | if no_ans_qids: 277 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 278 | merge_eval(out_eval, no_ans_eval, 'NoAns') 279 | if OPTS.na_prob_file: 280 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 281 | if OPTS.na_prob_file and OPTS.out_image_dir: 282 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 283 | qid_to_has_ans, OPTS.out_image_dir) 284 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 285 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 286 | if OPTS.out_file: 287 | with open(OPTS.out_file, 'w') as f: 288 | json.dump(out_eval, f) 289 | else: 290 | print(json.dumps(out_eval, indent=2)) 291 | 292 | if __name__ == '__main__': 293 | OPTS = parse_args() 294 | if OPTS.out_image_dir: 295 | import matplotlib 296 | matplotlib.use('Agg') 297 | import matplotlib.pyplot as plt 298 | main() -------------------------------------------------------------------------------- /examples/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import collections 23 | import logging 24 | import json 25 | import re 26 | 27 | import torch 28 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 29 | from torch.utils.data.distributed import DistributedSampler 30 | 31 | from pytorch_pretrained_bert.tokenization import BertTokenizer 32 | from pytorch_pretrained_bert.modeling import BertModel 33 | 34 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 35 | datefmt = '%m/%d/%Y %H:%M:%S', 36 | level = logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | class InputExample(object): 41 | 42 | def __init__(self, unique_id, text_a, text_b): 43 | self.unique_id = unique_id 44 | self.text_a = text_a 45 | self.text_b = text_b 46 | 47 | 48 | class InputFeatures(object): 49 | """A single set of features of data.""" 50 | 51 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 52 | self.unique_id = unique_id 53 | self.tokens = tokens 54 | self.input_ids = input_ids 55 | self.input_mask = input_mask 56 | self.input_type_ids = 
input_type_ids 57 | 58 | 59 | def convert_examples_to_features(examples, seq_length, tokenizer): 60 | """Loads a data file into a list of `InputBatch`s.""" 61 | 62 | features = [] 63 | for (ex_index, example) in enumerate(examples): 64 | tokens_a = tokenizer.tokenize(example.text_a) 65 | 66 | tokens_b = None 67 | if example.text_b: 68 | tokens_b = tokenizer.tokenize(example.text_b) 69 | 70 | if tokens_b: 71 | # Modifies `tokens_a` and `tokens_b` in place so that the total 72 | # length is less than the specified length. 73 | # Account for [CLS], [SEP], [SEP] with "- 3" 74 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 75 | else: 76 | # Account for [CLS] and [SEP] with "- 2" 77 | if len(tokens_a) > seq_length - 2: 78 | tokens_a = tokens_a[0:(seq_length - 2)] 79 | 80 | # The convention in BERT is: 81 | # (a) For sequence pairs: 82 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 83 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 84 | # (b) For single sequences: 85 | # tokens: [CLS] the dog is hairy . [SEP] 86 | # type_ids: 0 0 0 0 0 0 0 87 | # 88 | # Where "type_ids" are used to indicate whether this is the first 89 | # sequence or the second sequence. The embedding vectors for `type=0` and 90 | # `type=1` were learned during pre-training and are added to the wordpiece 91 | # embedding vector (and position vector). This is not *strictly* necessary 92 | # since the [SEP] token unambigiously separates the sequences, but it makes 93 | # it easier for the model to learn the concept of sequences. 94 | # 95 | # For classification tasks, the first vector (corresponding to [CLS]) is 96 | # used as as the "sentence vector". Note that this only makes sense because 97 | # the entire model is fine-tuned. 
98 | tokens = [] 99 | input_type_ids = [] 100 | tokens.append("[CLS]") 101 | input_type_ids.append(0) 102 | for token in tokens_a: 103 | tokens.append(token) 104 | input_type_ids.append(0) 105 | tokens.append("[SEP]") 106 | input_type_ids.append(0) 107 | 108 | if tokens_b: 109 | for token in tokens_b: 110 | tokens.append(token) 111 | input_type_ids.append(1) 112 | tokens.append("[SEP]") 113 | input_type_ids.append(1) 114 | 115 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 116 | 117 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 118 | # tokens are attended to. 119 | input_mask = [1] * len(input_ids) 120 | 121 | # Zero-pad up to the sequence length. 122 | while len(input_ids) < seq_length: 123 | input_ids.append(0) 124 | input_mask.append(0) 125 | input_type_ids.append(0) 126 | 127 | assert len(input_ids) == seq_length 128 | assert len(input_mask) == seq_length 129 | assert len(input_type_ids) == seq_length 130 | 131 | if ex_index < 5: 132 | logger.info("*** Example ***") 133 | logger.info("unique_id: %s" % (example.unique_id)) 134 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 135 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 136 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 137 | logger.info( 138 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 139 | 140 | features.append( 141 | InputFeatures( 142 | unique_id=example.unique_id, 143 | tokens=tokens, 144 | input_ids=input_ids, 145 | input_mask=input_mask, 146 | input_type_ids=input_type_ids)) 147 | return features 148 | 149 | 150 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 151 | """Truncates a sequence pair in place to the maximum length.""" 152 | 153 | # This is a simple heuristic which will always truncate the longer sequence 154 | # one token at a time. 
This makes more sense than truncating an equal percent 155 | # of tokens from each, since if one sequence is very short then each token 156 | # that's truncated likely contains more information than a longer sequence. 157 | while True: 158 | total_length = len(tokens_a) + len(tokens_b) 159 | if total_length <= max_length: 160 | break 161 | if len(tokens_a) > len(tokens_b): 162 | tokens_a.pop() 163 | else: 164 | tokens_b.pop() 165 | 166 | 167 | def read_examples(input_file): 168 | """Read a list of `InputExample`s from an input file.""" 169 | examples = [] 170 | unique_id = 0 171 | with open(input_file, "r", encoding='utf-8') as reader: 172 | while True: 173 | line = reader.readline() 174 | if not line: 175 | break 176 | line = line.strip() 177 | text_a = None 178 | text_b = None 179 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 180 | if m is None: 181 | text_a = line 182 | else: 183 | text_a = m.group(1) 184 | text_b = m.group(2) 185 | examples.append( 186 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 187 | unique_id += 1 188 | return examples 189 | 190 | 191 | def main(): 192 | parser = argparse.ArgumentParser() 193 | 194 | ## Required parameters 195 | parser.add_argument("--input_file", default=None, type=str, required=True) 196 | parser.add_argument("--output_file", default=None, type=str, required=True) 197 | parser.add_argument("--bert_model", default=None, type=str, required=True, 198 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 199 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 200 | 201 | ## Other parameters 202 | parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") 203 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) 204 | parser.add_argument("--max_seq_length", default=128, type=int, 205 | help="The maximum total input sequence length after WordPiece tokenization. 
Sequences longer " 206 | "than this will be truncated, and sequences shorter than this will be padded.") 207 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") 208 | parser.add_argument("--local_rank", 209 | type=int, 210 | default=-1, 211 | help = "local_rank for distributed training on gpus") 212 | parser.add_argument("--no_cuda", 213 | action='store_true', 214 | help="Whether not to use CUDA when available") 215 | 216 | args = parser.parse_args() 217 | 218 | if args.local_rank == -1 or args.no_cuda: 219 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 220 | n_gpu = torch.cuda.device_count() 221 | else: 222 | device = torch.device("cuda", args.local_rank) 223 | n_gpu = 1 224 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 225 | torch.distributed.init_process_group(backend='nccl') 226 | logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1))) 227 | 228 | layer_indexes = [int(x) for x in args.layers.split(",")] 229 | 230 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 231 | 232 | examples = read_examples(args.input_file) 233 | 234 | features = convert_examples_to_features( 235 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 236 | 237 | unique_id_to_feature = {} 238 | for feature in features: 239 | unique_id_to_feature[feature.unique_id] = feature 240 | 241 | model = BertModel.from_pretrained(args.bert_model) 242 | model.to(device) 243 | 244 | if args.local_rank != -1: 245 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 246 | output_device=args.local_rank) 247 | elif n_gpu > 1: 248 | model = torch.nn.DataParallel(model) 249 | 250 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 251 | all_input_mask = torch.tensor([f.input_mask for f in 
features], dtype=torch.long) 252 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 253 | 254 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 255 | if args.local_rank == -1: 256 | eval_sampler = SequentialSampler(eval_data) 257 | else: 258 | eval_sampler = DistributedSampler(eval_data) 259 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 260 | 261 | model.eval() 262 | with open(args.output_file, "w", encoding='utf-8') as writer: 263 | for input_ids, input_mask, example_indices in eval_dataloader: 264 | input_ids = input_ids.to(device) 265 | input_mask = input_mask.to(device) 266 | 267 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 268 | all_encoder_layers = all_encoder_layers 269 | 270 | for b, example_index in enumerate(example_indices): 271 | feature = features[example_index.item()] 272 | unique_id = int(feature.unique_id) 273 | # feature = unique_id_to_feature[unique_id] 274 | output_json = collections.OrderedDict() 275 | output_json["linex_index"] = unique_id 276 | all_out_features = [] 277 | for (i, token) in enumerate(feature.tokens): 278 | all_layers = [] 279 | for (j, layer_index) in enumerate(layer_indexes): 280 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() 281 | layer_output = layer_output[b] 282 | layers = collections.OrderedDict() 283 | layers["index"] = layer_index 284 | layers["values"] = [ 285 | round(x.item(), 6) for x in layer_output[i] 286 | ] 287 | all_layers.append(layers) 288 | out_features = collections.OrderedDict() 289 | out_features["token"] = token 290 | out_features["layers"] = all_layers 291 | all_out_features.append(out_features) 292 | output_json["features"] = all_out_features 293 | writer.write(json.dumps(output_json) + "\n") 294 | 295 | 296 | if __name__ == "__main__": 297 | main() 298 | 
-------------------------------------------------------------------------------- /examples/run_lm_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import logging 24 | import argparse 25 | from tqdm import tqdm, trange 26 | 27 | import numpy as np 28 | import torch 29 | from torch.utils.data import DataLoader, RandomSampler 30 | from torch.utils.data.distributed import DistributedSampler 31 | 32 | from pytorch_pretrained_bert.tokenization import BertTokenizer 33 | from pytorch_pretrained_bert.modeling import BertForPreTraining 34 | from pytorch_pretrained_bert.optimization import BertAdam 35 | 36 | from torch.utils.data import Dataset 37 | import random 38 | 39 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 40 | datefmt='%m/%d/%Y %H:%M:%S', 41 | level=logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | def warmup_linear(x, warmup=0.002): 46 | if x < warmup: 47 | return x/warmup 48 | return 1.0 - x 49 | 50 | 51 | class 
BERTDataset(Dataset): 52 | def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): 53 | self.vocab = tokenizer.vocab 54 | self.tokenizer = tokenizer 55 | self.seq_len = seq_len 56 | self.on_memory = on_memory 57 | self.corpus_lines = corpus_lines # number of non-empty lines in input corpus 58 | self.corpus_path = corpus_path 59 | self.encoding = encoding 60 | self.current_doc = 0 # to avoid random sentence from same doc 61 | 62 | # for loading samples directly from file 63 | self.sample_counter = 0 # used to keep track of full epochs on file 64 | self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair 65 | 66 | # for loading samples in memory 67 | self.current_random_doc = 0 68 | self.num_docs = 0 69 | self.sample_to_doc = [] # map sample index to doc and line 70 | 71 | # load samples into memory 72 | if on_memory: 73 | self.all_docs = [] 74 | doc = [] 75 | self.corpus_lines = 0 76 | with open(corpus_path, "r", encoding=encoding) as f: 77 | for line in tqdm(f, desc="Loading Dataset", total=corpus_lines): 78 | line = line.strip() 79 | if line == "": 80 | self.all_docs.append(doc) 81 | doc = [] 82 | #remove last added sample because there won't be a subsequent line anymore in the doc 83 | self.sample_to_doc.pop() 84 | else: 85 | #store as one sample 86 | sample = {"doc_id": len(self.all_docs), 87 | "line": len(doc)} 88 | self.sample_to_doc.append(sample) 89 | doc.append(line) 90 | self.corpus_lines = self.corpus_lines + 1 91 | 92 | # if last row in file is not empty 93 | if self.all_docs[-1] != doc: 94 | self.all_docs.append(doc) 95 | self.sample_to_doc.pop() 96 | 97 | self.num_docs = len(self.all_docs) 98 | 99 | # load samples later lazily from disk 100 | else: 101 | if self.corpus_lines is None: 102 | with open(corpus_path, "r", encoding=encoding) as f: 103 | self.corpus_lines = 0 104 | for line in tqdm(f, desc="Loading Dataset", total=corpus_lines): 105 | if 
line.strip() == "": 106 | self.num_docs += 1 107 | else: 108 | self.corpus_lines += 1 109 | 110 | # if doc does not end with empty line 111 | if line.strip() != "": 112 | self.num_docs += 1 113 | 114 | self.file = open(corpus_path, "r", encoding=encoding) 115 | self.random_file = open(corpus_path, "r", encoding=encoding) 116 | 117 | def __len__(self): 118 | # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0. 119 | return self.corpus_lines - self.num_docs - 1 120 | 121 | def __getitem__(self, item): 122 | cur_id = self.sample_counter 123 | self.sample_counter += 1 124 | if not self.on_memory: 125 | # after one epoch we start again from beginning of file 126 | if cur_id != 0 and (cur_id % len(self) == 0): 127 | self.file.close() 128 | self.file = open(self.corpus_path, "r", encoding=self.encoding) 129 | 130 | t1, t2, is_next_label = self.random_sent(item) 131 | 132 | # tokenize 133 | tokens_a = self.tokenizer.tokenize(t1) 134 | tokens_b = self.tokenizer.tokenize(t2) 135 | 136 | # combine to one sample 137 | cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label) 138 | 139 | # transform sample to features 140 | cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer) 141 | 142 | cur_tensors = (torch.tensor(cur_features.input_ids), 143 | torch.tensor(cur_features.input_mask), 144 | torch.tensor(cur_features.segment_ids), 145 | torch.tensor(cur_features.lm_label_ids), 146 | torch.tensor(cur_features.is_next)) 147 | 148 | return cur_tensors 149 | 150 | def random_sent(self, index): 151 | """ 152 | Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences 153 | from one doc. With 50% the second sentence will be a random one from another doc. 154 | :param index: int, index of sample. 
155 | :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label 156 | """ 157 | t1, t2 = self.get_corpus_line(index) 158 | if random.random() > 0.5: 159 | label = 0 160 | else: 161 | t2 = self.get_random_line() 162 | label = 1 163 | 164 | assert len(t1) > 0 165 | assert len(t2) > 0 166 | return t1, t2, label 167 | 168 | def get_corpus_line(self, item): 169 | """ 170 | Get one sample from corpus consisting of a pair of two subsequent lines from the same doc. 171 | :param item: int, index of sample. 172 | :return: (str, str), two subsequent sentences from corpus 173 | """ 174 | t1 = "" 175 | t2 = "" 176 | assert item < self.corpus_lines 177 | if self.on_memory: 178 | sample = self.sample_to_doc[item] 179 | t1 = self.all_docs[sample["doc_id"]][sample["line"]] 180 | t2 = self.all_docs[sample["doc_id"]][sample["line"]+1] 181 | # used later to avoid random nextSentence from same doc 182 | self.current_doc = sample["doc_id"] 183 | return t1, t2 184 | else: 185 | if self.line_buffer is None: 186 | # read first non-empty line of file 187 | while t1 == "" : 188 | t1 = self.file.__next__().strip() 189 | t2 = self.file.__next__().strip() 190 | else: 191 | # use t2 from previous iteration as new t1 192 | t1 = self.line_buffer 193 | t2 = self.file.__next__().strip() 194 | # skip empty rows that are used for separating documents and keep track of current doc id 195 | while t2 == "" or t1 == "": 196 | t1 = self.file.__next__().strip() 197 | t2 = self.file.__next__().strip() 198 | self.current_doc = self.current_doc+1 199 | self.line_buffer = t2 200 | 201 | assert t1 != "" 202 | assert t2 != "" 203 | return t1, t2 204 | 205 | def get_random_line(self): 206 | """ 207 | Get random line from another document for nextSentence task. 208 | :return: str, content of one line 209 | """ 210 | # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large 211 | # corpora. 
However, just to be careful, we try to make sure that 212 | # the random document is not the same as the document we're processing. 213 | for _ in range(10): 214 | if self.on_memory: 215 | rand_doc_idx = random.randint(0, len(self.all_docs)-1) 216 | rand_doc = self.all_docs[rand_doc_idx] 217 | line = rand_doc[random.randrange(len(rand_doc))] 218 | else: 219 | rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000) 220 | #pick random line 221 | for _ in range(rand_index): 222 | line = self.get_next_line() 223 | #check if our picked random line is really from another doc like we want it to be 224 | if self.current_random_doc != self.current_doc: 225 | break 226 | return line 227 | 228 | def get_next_line(self): 229 | """ Gets next line of random_file and starts over when reaching end of file""" 230 | try: 231 | line = self.random_file.__next__().strip() 232 | #keep track of which document we are currently looking at to later avoid having the same doc as t1 233 | if line == "": 234 | self.current_random_doc = self.current_random_doc + 1 235 | line = self.random_file.__next__().strip() 236 | except StopIteration: 237 | self.random_file.close() 238 | self.random_file = open(self.corpus_path, "r", encoding=self.encoding) 239 | line = self.random_file.__next__().strip() 240 | return line 241 | 242 | 243 | class InputExample(object): 244 | """A single training/test example for the language model.""" 245 | 246 | def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None): 247 | """Constructs a InputExample. 248 | 249 | Args: 250 | guid: Unique id for the example. 251 | tokens_a: string. The untokenized text of the first sequence. For single 252 | sequence tasks, only this sequence must be specified. 253 | tokens_b: (Optional) string. The untokenized text of the second sequence. 254 | Only must be specified for sequence pair tasks. 255 | label: (Optional) string. The label of the example. 
This should be 256 | specified for train and dev examples, but not for test examples. 257 | """ 258 | self.guid = guid 259 | self.tokens_a = tokens_a 260 | self.tokens_b = tokens_b 261 | self.is_next = is_next # nextSentence 262 | self.lm_labels = lm_labels # masked words for language model 263 | 264 | 265 | class InputFeatures(object): 266 | """A single set of features of data.""" 267 | 268 | def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids): 269 | self.input_ids = input_ids 270 | self.input_mask = input_mask 271 | self.segment_ids = segment_ids 272 | self.is_next = is_next 273 | self.lm_label_ids = lm_label_ids 274 | 275 | 276 | def random_word(tokens, tokenizer): 277 | """ 278 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper. 279 | :param tokens: list of str, tokenized sentence. 280 | :param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here) 281 | :return: (list of str, list of int), masked tokens and related labels for LM prediction 282 | """ 283 | output_label = [] 284 | 285 | for i, token in enumerate(tokens): 286 | prob = random.random() 287 | # mask token with 15% probability 288 | if prob < 0.15: 289 | prob /= 0.15 290 | 291 | # 80% randomly change token to mask token 292 | if prob < 0.8: 293 | tokens[i] = "[MASK]" 294 | 295 | # 10% randomly change token to random token 296 | elif prob < 0.9: 297 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0] 298 | 299 | # -> rest 10% randomly keep current token 300 | 301 | # append current token to output (we will predict these later) 302 | try: 303 | output_label.append(tokenizer.vocab[token]) 304 | except KeyError: 305 | # For unknown words (should not occur with BPE vocab) 306 | output_label.append(tokenizer.vocab["[UNK]"]) 307 | logger.warning("Cannot find token '{}' in vocab. 
Using [UNK] insetad".format(token)) 308 | else: 309 | # no masking token (will be ignored by loss function later) 310 | output_label.append(-1) 311 | 312 | return tokens, output_label 313 | 314 | 315 | def convert_example_to_features(example, max_seq_length, tokenizer): 316 | """ 317 | Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with 318 | IDs, LM labels, input_mask, CLS and SEP tokens etc. 319 | :param example: InputExample, containing sentence input as strings and is_next label 320 | :param max_seq_length: int, maximum length of sequence. 321 | :param tokenizer: Tokenizer 322 | :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training) 323 | """ 324 | tokens_a = example.tokens_a 325 | tokens_b = example.tokens_b 326 | # Modifies `tokens_a` and `tokens_b` in place so that the total 327 | # length is less than the specified length. 328 | # Account for [CLS], [SEP], [SEP] with "- 3" 329 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 330 | 331 | t1_random, t1_label = random_word(tokens_a, tokenizer) 332 | t2_random, t2_label = random_word(tokens_b, tokenizer) 333 | # concatenate lm labels and account for CLS, SEP, SEP 334 | lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1]) 335 | 336 | # The convention in BERT is: 337 | # (a) For sequence pairs: 338 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 339 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 340 | # (b) For single sequences: 341 | # tokens: [CLS] the dog is hairy . [SEP] 342 | # type_ids: 0 0 0 0 0 0 0 343 | # 344 | # Where "type_ids" are used to indicate whether this is the first 345 | # sequence or the second sequence. The embedding vectors for `type=0` and 346 | # `type=1` were learned during pre-training and are added to the wordpiece 347 | # embedding vector (and position vector). 
This is not *strictly* necessary 348 | # since the [SEP] token unambigiously separates the sequences, but it makes 349 | # it easier for the model to learn the concept of sequences. 350 | # 351 | # For classification tasks, the first vector (corresponding to [CLS]) is 352 | # used as as the "sentence vector". Note that this only makes sense because 353 | # the entire model is fine-tuned. 354 | tokens = [] 355 | segment_ids = [] 356 | tokens.append("[CLS]") 357 | segment_ids.append(0) 358 | for token in tokens_a: 359 | tokens.append(token) 360 | segment_ids.append(0) 361 | tokens.append("[SEP]") 362 | segment_ids.append(0) 363 | 364 | assert len(tokens_b) > 0 365 | for token in tokens_b: 366 | tokens.append(token) 367 | segment_ids.append(1) 368 | tokens.append("[SEP]") 369 | segment_ids.append(1) 370 | 371 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 372 | 373 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 374 | # tokens are attended to. 375 | input_mask = [1] * len(input_ids) 376 | 377 | # Zero-pad up to the sequence length. 
378 | while len(input_ids) < max_seq_length: 379 | input_ids.append(0) 380 | input_mask.append(0) 381 | segment_ids.append(0) 382 | lm_label_ids.append(-1) 383 | 384 | assert len(input_ids) == max_seq_length 385 | assert len(input_mask) == max_seq_length 386 | assert len(segment_ids) == max_seq_length 387 | assert len(lm_label_ids) == max_seq_length 388 | 389 | if example.guid < 5: 390 | logger.info("*** Example ***") 391 | logger.info("guid: %s" % (example.guid)) 392 | logger.info("tokens: %s" % " ".join( 393 | [str(x) for x in tokens])) 394 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 395 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 396 | logger.info( 397 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 398 | logger.info("LM label: %s " % (lm_label_ids)) 399 | logger.info("Is next sentence label: %s " % (example.is_next)) 400 | 401 | features = InputFeatures(input_ids=input_ids, 402 | input_mask=input_mask, 403 | segment_ids=segment_ids, 404 | lm_label_ids=lm_label_ids, 405 | is_next=example.is_next) 406 | return features 407 | 408 | 409 | def main(): 410 | """ 411 | python run_lm_finetuning.py \ 412 | --bert_model bert-base-uncased \ 413 | --do_lower_case \ 414 | --do_train \ 415 | --train_file ../samples/sample_text.txt \ 416 | --output_dir models \ 417 | --num_train_epochs 5.0 \ 418 | --learning_rate 3e-5 \ 419 | --train_batch_size 32 \ 420 | --max_seq_length 128 \ 421 | """ 422 | parser = argparse.ArgumentParser() 423 | 424 | # Required parameters 425 | parser.add_argument("--train_file", 426 | default=None, 427 | type=str, 428 | required=True, 429 | help="The input train corpus.") 430 | parser.add_argument("--bert_model", default=None, type=str, required=True, 431 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 432 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 433 | parser.add_argument("--output_dir", 434 | default=None, 435 | 
type=str, 436 | required=True, 437 | help="The output directory where the model checkpoints will be written.") 438 | 439 | # Other parameters 440 | parser.add_argument("--max_seq_length", 441 | default=128, 442 | type=int, 443 | help="The maximum total input sequence length after WordPiece tokenization. \n" 444 | "Sequences longer than this will be truncated, and sequences shorter \n" 445 | "than this will be padded.") 446 | parser.add_argument("--do_train", 447 | action='store_true', 448 | help="Whether to run training.") 449 | parser.add_argument("--train_batch_size", 450 | default=32, 451 | type=int, 452 | help="Total batch size for training.") 453 | parser.add_argument("--eval_batch_size", 454 | default=8, 455 | type=int, 456 | help="Total batch size for eval.") 457 | parser.add_argument("--learning_rate", 458 | default=3e-5, 459 | type=float, 460 | help="The initial learning rate for Adam.") 461 | parser.add_argument("--num_train_epochs", 462 | default=3.0, 463 | type=float, 464 | help="Total number of training epochs to perform.") 465 | parser.add_argument("--warmup_proportion", 466 | default=0.1, 467 | type=float, 468 | help="Proportion of training to perform linear learning rate warmup for. " 469 | "E.g., 0.1 = 10%% of training.") 470 | parser.add_argument("--no_cuda", 471 | action='store_true', 472 | help="Whether not to use CUDA when available") 473 | parser.add_argument("--on_memory", 474 | action='store_true', 475 | help="Whether to load train samples into memory or use disk") 476 | parser.add_argument("--do_lower_case", 477 | action='store_true', 478 | help="Whether to lower case the input text. 
True for uncased models, False for cased models.") 479 | parser.add_argument("--local_rank", 480 | type=int, 481 | default=-1, 482 | help="local_rank for distributed training on gpus") 483 | parser.add_argument('--seed', 484 | type=int, 485 | default=42, 486 | help="random seed for initialization") 487 | parser.add_argument('--gradient_accumulation_steps', 488 | type=int, 489 | default=1, 490 | help="Number of updates steps to accumualte before performing a backward/update pass.") 491 | parser.add_argument('--fp16', 492 | action='store_true', 493 | help="Whether to use 16-bit float precision instead of 32-bit") 494 | parser.add_argument('--loss_scale', 495 | type = float, default = 0, 496 | help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 497 | "0 (default value): dynamic loss scaling.\n" 498 | "Positive power of 2: static loss scaling value.\n") 499 | 500 | args = parser.parse_args() 501 | 502 | if args.local_rank == -1 or args.no_cuda: 503 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 504 | n_gpu = torch.cuda.device_count() 505 | else: 506 | torch.cuda.set_device(args.local_rank) 507 | device = torch.device("cuda", args.local_rank) 508 | n_gpu = 1 509 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 510 | torch.distributed.init_process_group(backend='nccl') 511 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 512 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 513 | 514 | if args.gradient_accumulation_steps < 1: 515 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 516 | args.gradient_accumulation_steps)) 517 | 518 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 519 | 520 | random.seed(args.seed) 521 | np.random.seed(args.seed) 522 | torch.manual_seed(args.seed) 523 | if n_gpu > 0: 524 | 
torch.cuda.manual_seed_all(args.seed) 525 | 526 | if not args.do_train and not args.do_eval: 527 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 528 | 529 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 530 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 531 | os.makedirs(args.output_dir, exist_ok=True) 532 | 533 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 534 | 535 | #train_examples = None 536 | num_train_steps = None 537 | if args.do_train: 538 | print("Loading Train Dataset", args.train_file) 539 | """ 540 | train_dataset: 541 | TODO: 542 | """ 543 | train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length, 544 | corpus_lines=None, on_memory=args.on_memory) 545 | num_train_steps = int( 546 | len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 547 | 548 | # Prepare model 549 | model = BertForPreTraining.from_pretrained(args.bert_model) 550 | if args.fp16: 551 | model.half() 552 | model.to(device) 553 | if args.local_rank != -1: 554 | try: 555 | from apex.parallel import DistributedDataParallel as DDP 556 | except ImportError: 557 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 558 | model = DDP(model) 559 | elif n_gpu > 1: 560 | model = torch.nn.DataParallel(model) 561 | 562 | # Prepare optimizer 563 | param_optimizer = list(model.named_parameters()) 564 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 565 | optimizer_grouped_parameters = [ 566 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 567 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 568 | ] 569 | if args.fp16: 570 | try: 571 | from apex.optimizers import FP16_Optimizer 572 | from 
apex.optimizers import FusedAdam 573 | except ImportError: 574 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 575 | 576 | optimizer = FusedAdam(optimizer_grouped_parameters, 577 | lr=args.learning_rate, 578 | bias_correction=False, 579 | max_grad_norm=1.0) 580 | if args.loss_scale == 0: 581 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 582 | else: 583 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 584 | 585 | else: 586 | optimizer = BertAdam(optimizer_grouped_parameters, 587 | lr=args.learning_rate, 588 | warmup=args.warmup_proportion, 589 | t_total=num_train_steps) 590 | 591 | global_step = 0 592 | if args.do_train: 593 | logger.info("***** Running training *****") 594 | logger.info(" Num examples = %d", len(train_dataset)) 595 | logger.info(" Batch size = %d", args.train_batch_size) 596 | logger.info(" Num steps = %d", num_train_steps) 597 | 598 | if args.local_rank == -1: 599 | train_sampler = RandomSampler(train_dataset) 600 | else: 601 | #TODO: check if this works with current data generator from disk that relies on file.__next__ 602 | # (it doesn't return item back by index) 603 | train_sampler = DistributedSampler(train_dataset) 604 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 605 | 606 | model.train() 607 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 608 | tr_loss = 0 609 | nb_tr_examples, nb_tr_steps = 0, 0 610 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 611 | batch = tuple(t.to(device) for t in batch) 612 | input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch 613 | loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) 614 | if n_gpu > 1: 615 | loss = loss.mean() # mean() to average on multi-gpu. 
616 | if args.gradient_accumulation_steps > 1: 617 | loss = loss / args.gradient_accumulation_steps 618 | if args.fp16: 619 | optimizer.backward(loss) 620 | else: 621 | loss.backward() 622 | tr_loss += loss.item() 623 | nb_tr_examples += input_ids.size(0) 624 | nb_tr_steps += 1 625 | if (step + 1) % args.gradient_accumulation_steps == 0: 626 | # modify learning rate with special warm up BERT uses 627 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion) 628 | for param_group in optimizer.param_groups: 629 | param_group['lr'] = lr_this_step 630 | optimizer.step() 631 | optimizer.zero_grad() 632 | global_step += 1 633 | 634 | # Save a trained model 635 | logger.info("** ** * Saving fine - tuned model ** ** * ") 636 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 637 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 638 | if args.do_train: 639 | torch.save(model_to_save.state_dict(), output_model_file) 640 | 641 | 642 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 643 | """Truncates a sequence pair in place to the maximum length.""" 644 | 645 | # This is a simple heuristic which will always truncate the longer sequence 646 | # one token at a time. This makes more sense than truncating an equal percent 647 | # of tokens from each, since if one sequence is very short then each token 648 | # that's truncated likely contains more information than a longer sequence. 
649 | while True: 650 | total_length = len(tokens_a) + len(tokens_b) 651 | if total_length <= max_length: 652 | break 653 | if len(tokens_a) > len(tokens_b): 654 | tokens_a.pop() 655 | else: 656 | tokens_b.pop() 657 | 658 | 659 | def accuracy(out, labels): 660 | outputs = np.argmax(out, axis=1) 661 | return np.sum(outputs == labels) 662 | 663 | 664 | if __name__ == "__main__": 665 | main() -------------------------------------------------------------------------------- /examples/run_swag.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """BERT finetuning runner.""" 17 | 18 | import logging 19 | import os 20 | import argparse 21 | import random 22 | from tqdm import tqdm, trange 23 | import csv 24 | 25 | import numpy as np 26 | import torch 27 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 28 | from torch.utils.data.distributed import DistributedSampler 29 | 30 | from pytorch_pretrained_bert.tokenization import BertTokenizer 31 | from pytorch_pretrained_bert.modeling import BertForMultipleChoice 32 | from pytorch_pretrained_bert.optimization import BertAdam 33 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 34 | 35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 36 | datefmt = '%m/%d/%Y %H:%M:%S', 37 | level = logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | class SwagExample(object): 42 | """A single training/test example for the SWAG dataset.""" 43 | def __init__(self, 44 | swag_id, 45 | context_sentence, 46 | start_ending, 47 | ending_0, 48 | ending_1, 49 | ending_2, 50 | ending_3, 51 | label = None): 52 | self.swag_id = swag_id 53 | self.context_sentence = context_sentence 54 | self.start_ending = start_ending 55 | self.endings = [ 56 | ending_0, 57 | ending_1, 58 | ending_2, 59 | ending_3, 60 | ] 61 | self.label = label 62 | 63 | def __str__(self): 64 | return self.__repr__() 65 | 66 | def __repr__(self): 67 | l = [ 68 | f"swag_id: {self.swag_id}", 69 | f"context_sentence: {self.context_sentence}", 70 | f"start_ending: {self.start_ending}", 71 | f"ending_0: {self.endings[0]}", 72 | f"ending_1: {self.endings[1]}", 73 | f"ending_2: {self.endings[2]}", 74 | f"ending_3: {self.endings[3]}", 75 | ] 76 | 77 | if self.label is not None: 78 | l.append(f"label: {self.label}") 79 | 80 | return ", ".join(l) 81 | 82 | 83 | class InputFeatures(object): 84 | def __init__(self, 85 | example_id, 86 | choices_features, 87 | label 88 | 89 | ): 90 | self.example_id = 
example_id 91 | self.choices_features = [ 92 | { 93 | 'input_ids': input_ids, 94 | 'input_mask': input_mask, 95 | 'segment_ids': segment_ids 96 | } 97 | for _, input_ids, input_mask, segment_ids in choices_features 98 | ] 99 | self.label = label 100 | 101 | 102 | def read_swag_examples(input_file, is_training): 103 | with open(input_file, 'r', encoding='utf-8') as f: 104 | reader = csv.reader(f) 105 | lines = list(reader) 106 | 107 | if is_training and lines[0][-1] != 'label': 108 | raise ValueError( 109 | "For training, the input file must contain a label column." 110 | ) 111 | 112 | examples = [ 113 | SwagExample( 114 | swag_id = line[2], 115 | context_sentence = line[4], 116 | start_ending = line[5], # in the swag dataset, the 117 | # common beginning of each 118 | # choice is stored in "sent2". 119 | ending_0 = line[7], 120 | ending_1 = line[8], 121 | ending_2 = line[9], 122 | ending_3 = line[10], 123 | label = int(line[11]) if is_training else None 124 | ) for line in lines[1:] # we skip the line with the column names 125 | ] 126 | 127 | return examples 128 | 129 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 130 | is_training): 131 | """Loads a data file into a list of `InputBatch`s.""" 132 | 133 | # Swag is a multiple choice task. To perform this task using Bert, 134 | # we will use the formatting proposed in "Improving Language 135 | # Understanding by Generative Pre-Training" and suggested by 136 | # @jacobdevlin-google in this issue 137 | # https://github.com/google-research/bert/issues/38. 138 | # 139 | # Each choice will correspond to a sample on which we run the 140 | # inference. For a given Swag example, we will create the 4 141 | # following inputs: 142 | # - [CLS] context [SEP] choice_1 [SEP] 143 | # - [CLS] context [SEP] choice_2 [SEP] 144 | # - [CLS] context [SEP] choice_3 [SEP] 145 | # - [CLS] context [SEP] choice_4 [SEP] 146 | # The model will output a single value for each input. 
To get the 147 | # final decision of the model, we will run a softmax over these 4 148 | # outputs. 149 | features = [] 150 | for example_index, example in enumerate(examples): 151 | context_tokens = tokenizer.tokenize(example.context_sentence) 152 | start_ending_tokens = tokenizer.tokenize(example.start_ending) 153 | 154 | choices_features = [] 155 | for ending_index, ending in enumerate(example.endings): 156 | # We create a copy of the context tokens in order to be 157 | # able to shrink it according to ending_tokens 158 | context_tokens_choice = context_tokens[:] 159 | ending_tokens = start_ending_tokens + tokenizer.tokenize(ending) 160 | # Modifies `context_tokens_choice` and `ending_tokens` in 161 | # place so that the total length is less than the 162 | # specified length. Account for [CLS], [SEP], [SEP] with 163 | # "- 3" 164 | _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3) 165 | 166 | tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"] 167 | segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1) 168 | 169 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 170 | input_mask = [1] * len(input_ids) 171 | 172 | # Zero-pad up to the sequence length. 
173 | padding = [0] * (max_seq_length - len(input_ids)) 174 | input_ids += padding 175 | input_mask += padding 176 | segment_ids += padding 177 | 178 | assert len(input_ids) == max_seq_length 179 | assert len(input_mask) == max_seq_length 180 | assert len(segment_ids) == max_seq_length 181 | 182 | choices_features.append((tokens, input_ids, input_mask, segment_ids)) 183 | 184 | label = example.label 185 | if example_index < 5: 186 | logger.info("*** Example ***") 187 | logger.info(f"swag_id: {example.swag_id}") 188 | for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features): 189 | logger.info(f"choice: {choice_idx}") 190 | logger.info(f"tokens: {' '.join(tokens)}") 191 | logger.info(f"input_ids: {' '.join(map(str, input_ids))}") 192 | logger.info(f"input_mask: {' '.join(map(str, input_mask))}") 193 | logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}") 194 | if is_training: 195 | logger.info(f"label: {label}") 196 | 197 | features.append( 198 | InputFeatures( 199 | example_id = example.swag_id, 200 | choices_features = choices_features, 201 | label = label 202 | ) 203 | ) 204 | 205 | return features 206 | 207 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 208 | """Truncates a sequence pair in place to the maximum length.""" 209 | 210 | # This is a simple heuristic which will always truncate the longer sequence 211 | # one token at a time. This makes more sense than truncating an equal percent 212 | # of tokens from each, since if one sequence is very short then each token 213 | # that's truncated likely contains more information than a longer sequence. 
214 | while True: 215 | total_length = len(tokens_a) + len(tokens_b) 216 | if total_length <= max_length: 217 | break 218 | if len(tokens_a) > len(tokens_b): 219 | tokens_a.pop() 220 | else: 221 | tokens_b.pop() 222 | 223 | def accuracy(out, labels): 224 | outputs = np.argmax(out, axis=1) 225 | return np.sum(outputs == labels) 226 | 227 | def select_field(features, field): 228 | return [ 229 | [ 230 | choice[field] 231 | for choice in feature.choices_features 232 | ] 233 | for feature in features 234 | ] 235 | 236 | def warmup_linear(x, warmup=0.002): 237 | if x < warmup: 238 | return x/warmup 239 | return 1.0 - x 240 | 241 | def main(): 242 | parser = argparse.ArgumentParser() 243 | 244 | ## Required parameters 245 | parser.add_argument("--data_dir", 246 | default=None, 247 | type=str, 248 | required=True, 249 | help="The input data dir. Should contain the .csv files (or other data files) for the task.") 250 | parser.add_argument("--bert_model", default=None, type=str, required=True, 251 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 252 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 253 | "bert-base-multilingual-cased, bert-base-chinese.") 254 | parser.add_argument("--output_dir", 255 | default=None, 256 | type=str, 257 | required=True, 258 | help="The output directory where the model checkpoints will be written.") 259 | 260 | ## Other parameters 261 | parser.add_argument("--max_seq_length", 262 | default=128, 263 | type=int, 264 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 265 | "Sequences longer than this will be truncated, and sequences shorter \n" 266 | "than this will be padded.") 267 | parser.add_argument("--do_train", 268 | action='store_true', 269 | help="Whether to run training.") 270 | parser.add_argument("--do_eval", 271 | action='store_true', 272 | help="Whether to run eval on the dev set.") 273 | parser.add_argument("--do_lower_case", 274 | action='store_true', 275 | help="Set this flag if you are using an uncased model.") 276 | parser.add_argument("--train_batch_size", 277 | default=32, 278 | type=int, 279 | help="Total batch size for training.") 280 | parser.add_argument("--eval_batch_size", 281 | default=8, 282 | type=int, 283 | help="Total batch size for eval.") 284 | parser.add_argument("--learning_rate", 285 | default=5e-5, 286 | type=float, 287 | help="The initial learning rate for Adam.") 288 | parser.add_argument("--num_train_epochs", 289 | default=3.0, 290 | type=float, 291 | help="Total number of training epochs to perform.") 292 | parser.add_argument("--warmup_proportion", 293 | default=0.1, 294 | type=float, 295 | help="Proportion of training to perform linear learning rate warmup for. 
" 296 | "E.g., 0.1 = 10%% of training.") 297 | parser.add_argument("--no_cuda", 298 | action='store_true', 299 | help="Whether not to use CUDA when available") 300 | parser.add_argument("--local_rank", 301 | type=int, 302 | default=-1, 303 | help="local_rank for distributed training on gpus") 304 | parser.add_argument('--seed', 305 | type=int, 306 | default=42, 307 | help="random seed for initialization") 308 | parser.add_argument('--gradient_accumulation_steps', 309 | type=int, 310 | default=1, 311 | help="Number of updates steps to accumulate before performing a backward/update pass.") 312 | parser.add_argument('--fp16', 313 | action='store_true', 314 | help="Whether to use 16-bit float precision instead of 32-bit") 315 | parser.add_argument('--loss_scale', 316 | type=float, default=0, 317 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 318 | "0 (default value): dynamic loss scaling.\n" 319 | "Positive power of 2: static loss scaling value.\n") 320 | 321 | args = parser.parse_args() 322 | 323 | if args.local_rank == -1 or args.no_cuda: 324 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 325 | n_gpu = torch.cuda.device_count() 326 | else: 327 | torch.cuda.set_device(args.local_rank) 328 | device = torch.device("cuda", args.local_rank) 329 | n_gpu = 1 330 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 331 | torch.distributed.init_process_group(backend='nccl') 332 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 333 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 334 | 335 | if args.gradient_accumulation_steps < 1: 336 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 337 | args.gradient_accumulation_steps)) 338 | 339 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 340 | 341 | random.seed(args.seed) 
342 | np.random.seed(args.seed) 343 | torch.manual_seed(args.seed) 344 | if n_gpu > 0: 345 | torch.cuda.manual_seed_all(args.seed) 346 | 347 | if not args.do_train and not args.do_eval: 348 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 349 | 350 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 351 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 352 | os.makedirs(args.output_dir, exist_ok=True) 353 | 354 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 355 | 356 | train_examples = None 357 | num_train_steps = None 358 | if args.do_train: 359 | train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True) 360 | num_train_steps = int( 361 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 362 | 363 | # Prepare model 364 | model = BertForMultipleChoice.from_pretrained(args.bert_model, 365 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), 366 | num_choices=4) 367 | if args.fp16: 368 | model.half() 369 | model.to(device) 370 | if args.local_rank != -1: 371 | try: 372 | from apex.parallel import DistributedDataParallel as DDP 373 | except ImportError: 374 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 375 | 376 | model = DDP(model) 377 | elif n_gpu > 1: 378 | model = torch.nn.DataParallel(model) 379 | 380 | # Prepare optimizer 381 | param_optimizer = list(model.named_parameters()) 382 | 383 | # hack to remove pooler, which is not used 384 | # thus it produce None grad that break apex 385 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 386 | 387 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 388 | optimizer_grouped_parameters = [ 389 | {'params': [p for n, p in param_optimizer if not any(nd 
in n for nd in no_decay)], 'weight_decay': 0.01}, 390 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 391 | ] 392 | t_total = num_train_steps 393 | if args.local_rank != -1: 394 | t_total = t_total // torch.distributed.get_world_size() 395 | if args.fp16: 396 | try: 397 | from apex.optimizers import FP16_Optimizer 398 | from apex.optimizers import FusedAdam 399 | except ImportError: 400 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 401 | 402 | optimizer = FusedAdam(optimizer_grouped_parameters, 403 | lr=args.learning_rate, 404 | bias_correction=False, 405 | max_grad_norm=1.0) 406 | if args.loss_scale == 0: 407 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 408 | else: 409 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 410 | else: 411 | optimizer = BertAdam(optimizer_grouped_parameters, 412 | lr=args.learning_rate, 413 | warmup=args.warmup_proportion, 414 | t_total=t_total) 415 | 416 | global_step = 0 417 | if args.do_train: 418 | train_features = convert_examples_to_features( 419 | train_examples, tokenizer, args.max_seq_length, True) 420 | logger.info("***** Running training *****") 421 | logger.info(" Num examples = %d", len(train_examples)) 422 | logger.info(" Batch size = %d", args.train_batch_size) 423 | logger.info(" Num steps = %d", num_train_steps) 424 | all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long) 425 | all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long) 426 | all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long) 427 | all_label = torch.tensor([f.label for f in train_features], dtype=torch.long) 428 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label) 429 | if args.local_rank == -1: 430 | train_sampler = 
RandomSampler(train_data) 431 | else: 432 | train_sampler = DistributedSampler(train_data) 433 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 434 | 435 | model.train() 436 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 437 | tr_loss = 0 438 | nb_tr_examples, nb_tr_steps = 0, 0 439 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 440 | batch = tuple(t.to(device) for t in batch) 441 | input_ids, input_mask, segment_ids, label_ids = batch 442 | loss = model(input_ids, segment_ids, input_mask, label_ids) 443 | if n_gpu > 1: 444 | loss = loss.mean() # mean() to average on multi-gpu. 445 | if args.fp16 and args.loss_scale != 1.0: 446 | # rescale loss for fp16 training 447 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 448 | loss = loss * args.loss_scale 449 | if args.gradient_accumulation_steps > 1: 450 | loss = loss / args.gradient_accumulation_steps 451 | tr_loss += loss.item() 452 | nb_tr_examples += input_ids.size(0) 453 | nb_tr_steps += 1 454 | 455 | if args.fp16: 456 | optimizer.backward(loss) 457 | else: 458 | loss.backward() 459 | if (step + 1) % args.gradient_accumulation_steps == 0: 460 | # modify learning rate with special warm up BERT uses 461 | lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion) 462 | for param_group in optimizer.param_groups: 463 | param_group['lr'] = lr_this_step 464 | optimizer.step() 465 | optimizer.zero_grad() 466 | global_step += 1 467 | 468 | # Save a trained model 469 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 470 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 471 | torch.save(model_to_save.state_dict(), output_model_file) 472 | 473 | # Load a trained model that you have fine-tuned 474 | model_state_dict = torch.load(output_model_file) 475 | model = 
BertForMultipleChoice.from_pretrained(args.bert_model, 476 | state_dict=model_state_dict, 477 | num_choices=4) 478 | model.to(device) 479 | 480 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 481 | eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True) 482 | eval_features = convert_examples_to_features( 483 | eval_examples, tokenizer, args.max_seq_length, True) 484 | logger.info("***** Running evaluation *****") 485 | logger.info(" Num examples = %d", len(eval_examples)) 486 | logger.info(" Batch size = %d", args.eval_batch_size) 487 | all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long) 488 | all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long) 489 | all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long) 490 | all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long) 491 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label) 492 | # Run prediction for full data 493 | eval_sampler = SequentialSampler(eval_data) 494 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 495 | 496 | model.eval() 497 | eval_loss, eval_accuracy = 0, 0 498 | nb_eval_steps, nb_eval_examples = 0, 0 499 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: 500 | input_ids = input_ids.to(device) 501 | input_mask = input_mask.to(device) 502 | segment_ids = segment_ids.to(device) 503 | label_ids = label_ids.to(device) 504 | 505 | with torch.no_grad(): 506 | tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) 507 | logits = model(input_ids, segment_ids, input_mask) 508 | 509 | logits = logits.detach().cpu().numpy() 510 | label_ids = label_ids.to('cpu').numpy() 511 | tmp_eval_accuracy = accuracy(logits, label_ids) 512 | 513 | eval_loss += tmp_eval_loss.mean().item() 514 | 
eval_accuracy += tmp_eval_accuracy 515 | 516 | nb_eval_examples += input_ids.size(0) 517 | nb_eval_steps += 1 518 | 519 | eval_loss = eval_loss / nb_eval_steps 520 | eval_accuracy = eval_accuracy / nb_eval_examples 521 | 522 | result = {'eval_loss': eval_loss, 523 | 'eval_accuracy': eval_accuracy, 524 | 'global_step': global_step, 525 | 'loss': tr_loss/nb_tr_steps} 526 | 527 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 528 | with open(output_eval_file, "w") as writer: 529 | logger.info("***** Eval results *****") 530 | for key in sorted(result.keys()): 531 | logger.info(" %s = %s", key, str(result[key])) 532 | writer.write("%s = %s\n" % (key, str(result[key]))) 533 | 534 | 535 | if __name__ == "__main__": 536 | main() 537 | -------------------------------------------------------------------------------- /examples/softmax.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 5 | datefmt='%m/%d/%Y %H:%M:%S', 6 | level=logging.INFO) 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def _compute_softmax(scores): 11 | """Compute softmax probability over raw logits.""" 12 | if not scores: 13 | return [] 14 | 15 | max_score = None 16 | for score in scores: 17 | if max_score is None or score > max_score: 18 | max_score = score 19 | 20 | exp_scores = [] 21 | total_sum = 0.0 22 | for score in scores: 23 | x = math.exp(score - max_score) 24 | exp_scores.append(x) 25 | total_sum += x 26 | 27 | probs = [] 28 | for score in exp_scores: 29 | probs.append(score / total_sum) 30 | return probs 31 | 32 | 33 | probs = _compute_softmax([1, 2, 3]) 34 | logger.info("probs: %s" % probs) 35 | # logger.info("Test: %s", probs) 36 | # logger.info(sum(probs)) 37 | -------------------------------------------------------------------------------- /examples/test_BertForMaskedLM.py: 
"""Smoke test: recover a masked token with the pretrained ``BertForMaskedLM`` head.

Tokenizes a two-sentence example with the ``bert-base-uncased`` vocabulary,
replaces the token at ``masked_index`` with ``[MASK]`` and asserts that the
model's highest-scoring prediction at that position is ``'henson'``.
Requires the pretrained vocabulary and weights (obtained via ``from_pretrained``).
"""
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 6
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors (a batch holding one sequence)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# ========================= BertForMaskedLM ==============================
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

# Predict all tokens. This is inference only, so do not build an autograd
# graph (an earlier run without no_grad() returned tensors carrying grad_fn).
# Reference values from that run:
#   predictions.size()                  -> torch.Size([1, 11, 30522])
#   predictions[0, masked_index]        -> tensor([-7.8384, -7.8162, -7.8893, ..., -6.9924, -6.1897, -4.5417])
#   predictions[0, masked_index].size() -> torch.Size([30522])
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
# ---- /examples/test_BertModel.py ----
"""Example: extract hidden states from a pretrained ``BertModel`` encoder.

A two-sentence input is tokenized, one token is replaced by ``[MASK]``, the
ids are wrapped into single-sequence batches and fed through the encoder
without gradient tracking; the script then checks that one hidden-state
tensor is returned per encoder layer.
"""
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# WordPiece tokenizer built from the uncased base vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Position 6 holds the second 'henson'; swap it for the mask token.
masked_index = 6
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == [
    'who', 'was', 'jim', 'henson', '?',
    'jim', '[MASK]', 'was', 'a', 'puppet', '##eer',
]

indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Segment ids: the first five tokens are sentence A, the last six sentence B.
segments_ids = [0] * 5 + [1] * 6

# Each id list becomes a batch containing exactly one sequence.
tokens_tensor, segments_tensors = [torch.tensor([ids]) for ids in (indexed_tokens, segments_ids)]

# ========================= BertModel to get hidden states ==============================
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()
# To run on a GPU, move tokens_tensor, segments_tensors and model with .to('cuda').

# Inference only: no autograd bookkeeping needed.
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)

# bert-base-uncased has 12 layers and yields one hidden-state tensor per layer.
assert len(encoded_layers) == 12
# ---- /examples/test_squad.py ----
"""Peek at the SQuAD v1.1 training file: article count, first two articles, item type."""
import json

# Dataset location on the author's machine; only the training split is read.
base_path = "/home/wyb/data/squad_v1.1/"
train_file_name = "train-v1.1.json"
dev_file_name = "dev-v1.1.json"
input_file = "".join([base_path, train_file_name])

with open(input_file, "r", encoding='utf-8') as reader:
    input_data = json.load(reader)["data"]

print(len(input_data))
for idx in (0, 1):
    print(input_data[idx])
print(type(input_data[1]))

# ---- /examples/test_tokenization.py ----
# Tokenization walkthrough: text -> WordPiece tokens -> vocabulary ids ->
# tensors, using an input that mixes English words and Chinese characters.
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: show the library's INFO-level log messages.
import logging
logging.basicConfig(level=logging.INFO)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# "Test version" of the input. The English-only variant used previously was:
#   "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
text = "[CLS] Who 这是一个测试 was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Put '[MASK]' at position 8 ('was' in the first sentence).
masked_index = 8
tokenized_text[masked_index] = '[MASK]'

# Expected tokens (recorded run): the six Chinese characters become six
# single-character tokens, all '[UNK]' except '一'.
# For the English-only variant the expectation would be:
#   ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a',
#    'puppet', '##eer', '[SEP]']
assert tokenized_text == [
    '[CLS]', 'who', '[UNK]', '[UNK]', '一', '[UNK]', '[UNK]', '[UNK]', '[MASK]',
    'jim', 'henson', '?', '[SEP]',
    'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]',
]

# Map tokens to vocabulary ids. Recorded result:
#   [101, 2040, 100, 100, 1740, 100, 100, 100, 103, 3958, 27227, 1029, 102,
#    3958, 27227, 2001, 1037, 13997, 11510, 102]
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

# Segment ids: the 13 tokens up to and including the first '[SEP]' are
# sentence A, the remaining 7 are sentence B
# (English-only variant: [0] * 7 + [1] * 7).
segments_ids = [0] * 13 + [1] * 7

# One-sequence batches, e.g. tokens_tensor and segments_tensors both have
# 1 row of 20 ids.
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# ---- /examples/valid_data.py ----
import json

# Sanity check of a preprocessed DuReader file: count how many questions have
# a first answer whose recorded span slices back to the answer text.
# BASE_PATH = "/home/wyb/PycharmProjects/DuReader/data/demo/"
BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/data/preprocessed/"

with open(BASE_PATH + "trainset/search.train_bert.json", "r", encoding='utf-8') as reader:
    source = json.load(reader)
    input_data = source["data"]

cou_equal = 0
cou_total = 0
for entry in input_data:
    for paragraph in entry["paragraphs"]:
        paragraph_text = paragraph["context"]

        for qa in paragraph["qas"]:
            cou_total += 1

            # Only the first answer is checked. SQuAD-style example:
            #   {'text': 'in the late 1990s', 'answer_start': 269}
            # NOTE(review): earlier comments disagreed on whether
            # answer_start/answer_end count characters or words. The slice
            # below indexes the context *string* with an inclusive end, so it
            # is only meaningful for character offsets -- confirm against the
            # preprocessing that produced search.train_bert.json.
            answer_dict = qa["answers"][0]
            answer = answer_dict["text"]
            start_position = answer_dict["answer_start"]
            end_position = answer_dict["answer_end"]

            if paragraph_text[start_position:(end_position + 1)] == answer.strip():
                cou_equal += 1

print("cou_equal / cou_total = ", cou_equal, " / ", cou_total)

# ---- /pytorch_pretrained_bert/__init__.py ----
__version__ = "0.4.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
                       BertForMaskedLM, BertForNextSentencePrediction,
                       BertForSequenceClassification, BertForMultipleChoice,
                       BertForTokenClassification, BertForQuestionAnswering)
from .optimization import BertAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE

# ---- /pytorch_pretrained_bert/__main__.py ----
# coding: utf8
def main():
    """Command-line entry point: convert a TensorFlow BERT checkpoint to PyTorch.

    Expected invocation::

        pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT

    Prints the usage line and converts nothing when the argument count or the
    subcommand does not match. Arguments are validated before the converter
    module (which imports TensorFlow) is loaded, so the usage message works
    without TensorFlow installed.

    Raises:
        ModuleNotFoundError: re-raised, after printing install hints, when the
            converter module cannot be imported.
    """
    import sys
    if len(sys.argv) != 5 or sys.argv[1] != "convert_tf_checkpoint_to_pytorch":
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
        return

    try:
        from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
    except ModuleNotFoundError:
        print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
              "In that case, it requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise

    # argv = [program, subcommand, checkpoint, config, output]
    TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT = sys.argv[2:5]
    convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)

if __name__ == '__main__':
    main()

# ---- /pytorch_pretrained_bert/__pycache__/__init__.cpython-36.pyc ----
# https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/__init__.cpython-36.pyc
# ---- /pytorch_pretrained_bert/__pycache__/file_utils.cpython-36.pyc ----
# https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/file_utils.cpython-36.pyc
# ---- /pytorch_pretrained_bert/__pycache__/modeling.cpython-36.pyc ----
https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/modeling.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/tokenization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/tokenization.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
"""Convert BERT checkpoint."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import argparse
import tensorflow as tf
import torch
import numpy as np

from .modeling import BertConfig, BertForPreTraining


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a PyTorch ``state_dict`` file.

    Every checkpoint variable except Adam slot variables and ``global_step``
    is copied onto the matching attribute of a freshly built
    ``BertForPreTraining``; the model's ``state_dict`` is then saved.

    Args:
        tf_checkpoint_path: path of the TensorFlow checkpoint to read.
        bert_config_file: JSON file describing the model architecture.
        pytorch_dump_path: output file written with ``torch.save``.

    Raises:
        AttributeError: if a TF variable name has no counterpart on the model.
        AssertionError: if a TF array's shape differs from the PyTorch
            parameter it is loaded into; ``args`` are ``(pytorch_shape, tf_shape)``.
    """
    config_path = os.path.abspath(bert_config_file)
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # Scope names such as "layer_3" address element 3 of a module list.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # m_name is the last path component here.
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF dense kernels are stored transposed relative to nn.Linear.weight.
            array = np.transpose(array)
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a mismatched array be assigned below
        # without any error.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default = None,
                        type = str,
                        required = True,
                        help = "Path the TensorFlow checkpoint path.")
    parser.add_argument("--bert_config_file",
                        default = None,
                        type = str,
                        required = True,
                        help = "The config json file corresponding to the pre-trained BERT model. \n"
                            "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default = None,
                        type = str,
                        required = True,
                        help = "Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.bert_config_file,
                                     args.pytorch_dump_path)

# ---- /pytorch_pretrained_bert/file_utils.py ----
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""

import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps

from tqdm import tqdm

import boto3
from botocore.exceptions import ClientError
import requests

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                               Path.home() / '.pytorch_pretrained_bert'))


def url_to_filename(url: str, etag: str = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise FileNotFoundError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise FileNotFoundError("file {} not found".format(meta_path))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url: str) -> Tuple[str, str]:
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func: Callable):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.

    A ``ClientError`` whose error code is "404" is turned into
    ``FileNotFoundError``; every other ``ClientError`` is re-raised unchanged.
    """

    @wraps(func)
    def wrapper(url: str, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # botocore error codes are strings and may be symbolic
            # (e.g. "AccessDenied"); int() on those would raise ValueError
            # and hide the original error, so compare as text instead.
            if str(exc.response["Error"]["Code"]) == "404":
                raise FileNotFoundError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url: str) -> Optional[str]:
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url: str, temp_file: IO) -> None:
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url: str, temp_file: IO) -> None:
    """Stream the body of ``url`` into ``temp_file`` with a byte progress bar.

    Raises:
        requests.HTTPError: if the GET response has an error status; nothing
            is written to ``temp_file`` in that case.
    """
    req = requests.get(url, stream=True)
    # Without this check an error page would be written to the temp file and
    # then copied into the cache as if it were the requested asset.
    req.raise_for_status()
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Raises ``IOError`` when the HEAD request for an http(s) URL does not
    return status 200.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename: str) -> Set[str]:
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path: str, dot=True, lower: bool = True):
    """Return the extension of ``path`` (with its leading dot unless ``dot`` is False),
    lower-cased unless ``lower`` is False."""
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext

# ---- /pytorch_pretrained_bert/optimization.py ----
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adams b1. Default: 0.9 54 | b2: Adams b2. Default: 0.999 55 | e: Adams epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 
100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | """Uncased means that the text has been lowercased before WordPiece tokenization. 31 | """ 32 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 33 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 34 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 35 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 36 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 37 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 38 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 39 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 40 | } 41 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { # max sequence length (number of positional embeddings) per known model 42 | 'bert-base-uncased': 512, 43 | 'bert-large-uncased': 512, 44 | 'bert-base-cased': 512, 45 | 'bert-large-cased': 512, 46 | 'bert-base-multilingual-uncased': 512, 47 | 'bert-base-multilingual-cased': 512, 48 | 'bert-base-chinese': 512, 49 | } 50 | VOCAB_NAME = 'vocab.txt' 51 | 52 | 53 | def load_vocab(vocab_file): 54 | """Loads a vocabulary file (one token per line) into an OrderedDict mapping token -> 0-based line index.""" 55 | """The size of vocab is 30522 56 | """ 57 | vocab = collections.OrderedDict() 58 | index = 0 59 | with open(vocab_file, "r", encoding="utf-8") as reader: 60 | while True: 61 | token = reader.readline() 62 | if not token: 63 | break 64 | token = token.strip() 65 |
vocab[token] = index 66 | index += 1 67 | return vocab 68 | 69 | 70 | def whitespace_tokenize(text): 71 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 72 | text = text.strip() 73 | if not text: 74 | return [] 75 | tokens = text.split() # no-argument split() splits on any run of whitespace and returns a word list 76 | return tokens 77 | 78 | 79 | class BertTokenizer(object): 80 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 81 | 82 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 83 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", "[unused1]")): 84 | if not os.path.isfile(vocab_file): 85 | raise ValueError( 86 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 87 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 88 | self.vocab = load_vocab(vocab_file) 89 | self.ids_to_tokens = collections.OrderedDict( 90 | [(ids, tok) for tok, ids in self.vocab.items()]) 91 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 92 | never_split=never_split) 93 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 94 | self.max_len = max_len if max_len is not None else int(1e12) # effectively unlimited; from_pretrained() caps it at 512 when loading a named pretrained model 95 | 96 | def tokenize(self, text): # BasicTokenizer first, then WordPiece on each resulting token 97 | split_tokens = [] 98 | for token in self.basic_tokenizer.tokenize(text): 99 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 100 | split_tokens.append(sub_token) 101 | """ 102 | split_tokens: 103 | ['[CLS]', 'who', '[UNK]', '[UNK]', '一', '[UNK]', '[UNK]', '[UNK]', 'was', 'jim', 'henson', '?', '[SEP]', 104 | 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]'] 105 | """ 106 | return split_tokens 107 | 108 | def convert_tokens_to_ids(self, tokens): 109 | """Converts a sequence of tokens into ids using the vocab.""" 110 | ids = [] 111 | for token in tokens: 112 | ids.append(self.vocab[token]) # raises KeyError for a token that is not in the vocab 113 | if len(ids) > self.max_len: 114 | raise ValueError( 115 | "Token indices sequence
length is longer than the specified maximum " 116 | " sequence length for this BERT model ({} > {}). Running this" 117 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 118 | ) 119 | return ids 120 | 121 | def convert_ids_to_tokens(self, ids): 122 | """Converts a sequence of ids into wordpiece tokens using the vocab.""" 123 | tokens = [] 124 | for i in ids: 125 | tokens.append(self.ids_to_tokens[i]) 126 | return tokens 127 | 128 | @classmethod 129 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 130 | """ 131 | Instantiate a tokenizer from a known model name, a vocabulary file path/URL, or a directory containing vocab.txt. 132 | Download and cache the vocabulary file if needed. Returns None (after logging an error) when cached_path raises FileNotFoundError. 133 | """ 134 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 135 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 136 | else: 137 | vocab_file = pretrained_model_name 138 | if os.path.isdir(vocab_file): 139 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 140 | # redirect to the cache, if necessary 141 | try: 142 | """Save file to "~/.pytorch_pretrained_bert". 143 | """ 144 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 145 | except FileNotFoundError: # NOTE(review): only FileNotFoundError is handled; other IOErrors (e.g. the HEAD-request failure raised in get_from_cache) presumably propagate to the caller 146 | logger.error( 147 | "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file " 149 | "associated to this path or url.".format( 150 | pretrained_model_name, 151 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 152 | vocab_file)) 153 | return None 154 | if resolved_vocab_file == vocab_file: 155 | logger.info("loading vocabulary file {}".format(vocab_file)) 156 | else: 157 | logger.info("loading vocabulary file {} from cache at {}".format( 158 | vocab_file, resolved_vocab_file)) 159 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 160 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer 161 | # than the number of positional embeddings 162 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 163 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 164 | # Instantiate tokenizer. 165 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 166 | return tokenizer 167 | 168 | 169 | class BasicTokenizer(object): 170 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 171 | 172 | def __init__(self, 173 | do_lower_case=True, 174 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", "[unused1]")): 175 | """Constructs a BasicTokenizer. 176 | 177 | Args: 178 | do_lower_case: Whether to lower case (and accent-strip) the input. never_split: Tokens kept verbatim: not lower-cased, not accent-stripped, not split on punctuation. 179 | """ 180 | self.do_lower_case = do_lower_case 181 | self.never_split = never_split 182 | 183 | def tokenize(self, text): 184 | """Tokenizes a piece of text: clean it, space out CJK characters, optionally lower-case/strip accents, then split on punctuation.""" 185 | text = self._clean_text(text) 186 | # This was added on November 1st, 2018 for the multilingual and Chinese 187 | # models. This is also applied to the English models now, but it doesn't 188 | # matter since the English models were not trained on any Chinese data 189 | # and generally don't have any Chinese data in them (there are Chinese 190 | # characters in the vocabulary because Wikipedia does have some Chinese 191 | # words in the English Wikipedia.).
192 | text = self._tokenize_chinese_chars(text) 193 | orig_tokens = whitespace_tokenize(text) 194 | split_tokens = [] 195 | for token in orig_tokens: 196 | if self.do_lower_case and token not in self.never_split: 197 | token = token.lower() 198 | token = self._run_strip_accents(token) 199 | split_tokens.extend(self._run_split_on_punc(token)) 200 | 201 | """ 202 | output_tokens: 203 | ['[CLS]', 'who', '这', '是', '一', '个', '测', '试', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', 'henson', 204 | 'was', 'a', 'puppeteer', '[SEP]'] 205 | """ 206 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 207 | return output_tokens 208 | 209 | def _run_strip_accents(self, text): 210 | """Strips accents from a piece of text (NFD-decomposes it, then drops combining marks of Unicode category Mn).""" 211 | """ 212 | Strips accents mean the following, 213 | input: "Málaga" 214 | output: "Malaga" 215 | 216 | output: 217 | ['M', 'a', 'l', 'a', 'g', 'a'] 218 | "".join(output) 219 | Malaga 220 | """ 221 | text = unicodedata.normalize("NFD", text) 222 | output = [] 223 | for char in text: 224 | cat = unicodedata.category(char) 225 | if cat == "Mn": 226 | continue 227 | output.append(char) 228 | return "".join(output) 229 | 230 | def _run_split_on_punc(self, text): 231 | """Splits punctuation on a piece of text: each punctuation character becomes its own token.""" 232 | """ 233 | If text is one of "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", return directly without removing "[" or "]" 234 | """ 235 | if text in self.never_split: 236 | return [text] 237 | chars = list(text) 238 | i = 0 239 | start_new_word = True 240 | output = [] 241 | while i < len(chars): 242 | char = chars[i] 243 | if _is_punctuation(char): 244 | output.append([char]) 245 | start_new_word = True 246 | else: 247 | if start_new_word: 248 | output.append([]) 249 | start_new_word = False 250 | output[-1].append(char) 251 | i += 1 252 | 253 | return ["".join(x) for x in output] 254 | 255 | def _tokenize_chinese_chars(self, text): 256 | """Adds whitespace around any CJK character.""" 257 | """ 258 | CJK: 259 | Chinese, Japanese, Korean
260 | 261 | text = "[CLS] Who 这是一个测试 was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" 262 | 263 | ord(): 264 | Return the Unicode code point for a one-character string. 265 | 266 | cp = 91 267 | cp = 67 268 | cp = 76 269 | ... 270 | 271 | "".join(output): 272 | [CLS] Who 这 是 一 个 测 试 was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP] 273 | 274 | output: 275 | ['[', 'C', 'L', 'S', ']', ' ', 'W', 'h', 'o', ' ', ' ', '这', ' ', ' ', '是', ' ', ' ', '一', ' ', ' ', '个', 276 | ' ', ' ', '测', ' ', ' ', '试', ' ', ' ', 'w', 'a', 's', ' ', 'J', 'i', 'm', ' ', 'H', 'e', 'n', 's', 277 | 'o', 'n', ' ', '?', ' ', '[', 'S', 'E', 'P', ']', ' ', 'J', 'i', 'm', ' ', 'H', 'e', 'n', 's', 'o', 'n', 278 | ' ', 'w', 'a', 's', ' ', 'a', ' ', 'p', 'u', 'p', 'p', 'e', 't', 'e', 'e', 'r', ' ', '[', 'S', 'E', 'P', 279 | ']'] 280 | """ 281 | output = [] 282 | for char in text: 283 | cp = ord(char) # Unicode code point of this character 284 | if self._is_chinese_char(cp): 285 | output.append(" ") 286 | output.append(char) 287 | output.append(" ") 288 | else: 289 | output.append(char) 290 | return "".join(output) 291 | 292 | def _is_chinese_char(self, cp): 293 | """Checks whether CP is the codepoint of a CJK character.""" 294 | # This defines a "chinese character" as anything in the CJK Unicode block: 295 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 296 | # 297 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 298 | # despite its name. The modern Korean Hangul alphabet is a different block, 299 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 300 | # space-separated words, so they are not treated specially and handled 301 | # like all of the other languages.
302 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # CJK Unified Ideographs 303 | (cp >= 0x3400 and cp <= 0x4DBF) or # CJK Unified Ideographs Extension A 304 | (cp >= 0x20000 and cp <= 0x2A6DF) or # CJK Unified Ideographs Extension B 305 | (cp >= 0x2A700 and cp <= 0x2B73F) or # CJK Unified Ideographs Extension C 306 | (cp >= 0x2B740 and cp <= 0x2B81F) or # CJK Unified Ideographs Extension D 307 | (cp >= 0x2B820 and cp <= 0x2CEAF) or # CJK Unified Ideographs Extension E 308 | (cp >= 0xF900 and cp <= 0xFAFF) or # CJK Compatibility Ideographs 309 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # CJK Compatibility Ideographs Supplement 310 | return True 311 | 312 | return False 313 | 314 | def _clean_text(self, text): 315 | """Performs invalid character removal and whitespace cleanup on text: drops NUL, U+FFFD and control characters, and maps whitespace characters to a plain space.""" 316 | output = [] 317 | for char in text: 318 | cp = ord(char) 319 | if cp == 0 or cp == 0xfffd or _is_control(char): 320 | continue 321 | if _is_whitespace(char): 322 | output.append(" ") 323 | else: 324 | output.append(char) 325 | return "".join(output) 326 | 327 | 328 | class WordpieceTokenizer(object): 329 | """Runs WordPiece tokenization.""" 330 | 331 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 332 | self.vocab = vocab 333 | self.unk_token = unk_token 334 | self.max_input_chars_per_word = max_input_chars_per_word 335 | 336 | def tokenize(self, text): 337 | """Tokenizes a piece of text into its word pieces. 338 | 339 | This uses a greedy longest-match-first algorithm to perform tokenization 340 | using the given vocabulary. 341 | 342 | For example: 343 | input = "unaffable" 344 | output = ["un", "##aff", "##able"] 345 | 346 | Args: 347 | text: A single token or whitespace separated tokens. This should have 348 | already been passed through `BasicTokenizer`. 349 | 350 | Returns: 351 | A list of wordpiece tokens. 352 | 353 | Papers: (for wordPiece) 354 | [1] https://arxiv.org/pdf/1508.07909.pdf 355 | [2] https://arxiv.org/pdf/1609.08144.pdf 356 | 357 | whitespace_tokenize(): 358 | splits on any whitespace and returns a word list, i.e.
["word1", "word2", ..., "wordn"] 359 | """ 360 | 361 | output_tokens = [] 362 | for token in whitespace_tokenize(text): # usually a single word from BasicTokenizer; whitespace-separated text is also accepted 363 | chars = list(token) 364 | if len(chars) > self.max_input_chars_per_word: 365 | output_tokens.append(self.unk_token) # over-long word: emit a single unknown token 366 | continue 367 | 368 | is_bad = False 369 | start = 0 370 | sub_tokens = [] 371 | while start < len(chars): 372 | end = len(chars) 373 | cur_substr = None 374 | while start < end: 375 | substr = "".join(chars[start:end]) 376 | if start > 0: 377 | substr = "##" + substr 378 | if substr in self.vocab: 379 | cur_substr = substr 380 | break 381 | end -= 1 382 | if cur_substr is None: # no vocab entry matches any prefix of the remainder: the whole word becomes unk_token 383 | is_bad = True 384 | break 385 | sub_tokens.append(cur_substr) # longest vocab match starting at `start` 386 | start = end # continue matching right after this piece 387 | 388 | if is_bad: 389 | output_tokens.append(self.unk_token) 390 | else: 391 | output_tokens.extend(sub_tokens) 392 | return output_tokens 393 | 394 | 395 | def _is_whitespace(char): 396 | """Checks whether `char` is a whitespace character.""" 397 | # \t, \n, and \r are technically control characters but we treat them 398 | # as whitespace since they are generally considered as such. 399 | if char == " " or char == "\t" or char == "\n" or char == "\r": 400 | return True 401 | cat = unicodedata.category(char) 402 | if cat == "Zs": 403 | return True 404 | return False 405 | 406 | 407 | def _is_control(char): 408 | """Checks whether `char` is a control character.""" 409 | # These are technically control characters but we count them as whitespace 410 | # characters.
411 | if char == "\t" or char == "\n" or char == "\r": 412 | return False 413 | cat = unicodedata.category(char) 414 | if cat.startswith("C"): 415 | return True 416 | return False 417 | 418 | 419 | def _is_punctuation(char): 420 | """Checks whether `char` is a punctuation character.""" 421 | cp = ord(char) 422 | # We treat all non-letter/number ASCII as punctuation. 423 | # Characters such as "^", "$", and "`" are not in the Unicode 424 | # Punctuation class but we treat them as punctuation anyways, for 425 | # consistency. 426 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 427 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 428 | return True 429 | cat = unicodedata.category(char) 430 | if cat.startswith("P"): 431 | return True 432 | return False 433 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | torch>=0.4.1 3 | # progress bars in model download and training scripts 4 | tqdm 5 | # Accessing files from S3 directly. 6 | boto3 7 | # Used for downloading models over HTTP 8 | requests -------------------------------------------------------------------------------- /samples/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /samples/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg.
4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 
21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 
31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run 'python setup.py build' to build the pytorch_pretrained_bert package from this source tree. 3 | """ 4 | 5 | """ 6 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 7 | 8 | To create the package for pypi. 9 | 10 | 1. Change the version in __init__.py and setup.py. 11 | 12 | 2. Commit these changes with the message: "Release: VERSION" 13 | 14 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 15 | Push the tag to git: git push --tags origin master 16 | 17 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 18 | creating the wheel and the source distribution (obviously). 19 | 20 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. 21 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 22 | 23 | For the sources, run: "python setup.py sdist" 24 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. 25 | 26 | 5. Check that everything looks correct by uploading the package to the pypi test server: 27 | 28 | twine upload dist/* -r pypitest 29 | (pypi suggest using twine as other methods upload files via plaintext.) 30 | 31 | Check that you can install it in a virtualenv by running: 32 | pip install -i https://testpypi.python.org/pypi allennlp 33 | 34 | 6.
Upload the final version to actual pypi: 35 | twine upload dist/* -r pypi 36 | 37 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 38 | 39 | """ 40 | from setuptools import find_packages, setup 41 | 42 | setup( 43 | name="pytorch_pretrained_bert", 44 | version="0.4.0", 45 | author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors", 46 | author_email="thomas@huggingface.co", 47 | description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", 48 | long_description=open("README.md", "r", encoding='utf-8').read(), # NOTE(review): this file object is never explicitly closed; consider reading README.md in a with-block before calling setup() 49 | long_description_content_type="text/markdown", 50 | keywords='BERT NLP deep learning google', 51 | license='Apache', 52 | url="https://github.com/huggingface/pytorch-pretrained-BERT", 53 | packages=find_packages(exclude=["*.tests", "*.tests.*", 54 | "tests.*", "tests"]), 55 | install_requires=['torch>=0.4.1', 56 | 'numpy', 57 | 'boto3', 58 | 'requests', 59 | 'tqdm'], 60 | entry_points={ 61 | 'console_scripts': [ 62 | "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main" 63 | ] 64 | }, 65 | python_requires='>=3.5.0', 66 | tests_require=['pytest'], 67 | classifiers=[ 68 | 'Intended Audience :: Science/Research', 69 | 'License :: OSI Approved :: Apache Software License', 70 | 'Programming Language :: Python :: 3', 71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 72 | ], 73 | ) 74 | -------------------------------------------------------------------------------- /tests/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import json
import random

import torch

from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                     BertForNextSentencePrediction, BertForPreTraining,
                                     BertForQuestionAnswering, BertForSequenceClassification,
                                     BertForTokenClassification)


class BertModelTest(unittest.TestCase):
    """Smoke tests: build every BERT head on a tiny random config and check output shapes."""

    class BertModelTester(object):
        """Holds a small BERT configuration plus helpers that build and check each head.

        ``parent`` is the owning ``unittest.TestCase``; every assertion is routed
        through it so failures are reported against that test case.
        """

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            # NOTE(review): is_training and scope are stored but never read in this file.
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.scope = scope

        def prepare_config_and_inputs(self):
            """Return ``(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels)``.

            Each optional tensor is ``None`` when its ``use_*`` flag is off.
            """
            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                # Random 0/1 mask with the same shape as input_ids.
                input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            if self.use_labels:
                sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)

            config = BertConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels

        def check_loss_output(self, result):
            """Assert that ``result["loss"]`` is a scalar (0-dim) tensor."""
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run the bare ``BertModel`` and collect its encoder outputs (labels are unused)."""
            model = BertModel(config=config)
            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "sequence_output": all_encoder_layers[-1],
                "pooled_output": pooled_output,
                "all_encoder_layers": all_encoder_layers,
            }
            return outputs

        def check_bert_model_output(self, result):
            """Check per-layer, last-layer and pooled output shapes of the bare model."""
            # Flatten every layer's size into one list: each layer must be (batch, seq, hidden).
            self.parent.assertListEqual(
                [size for layer in result["all_encoder_layers"] for size in layer.size()],
                [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run ``BertForMaskedLM`` once with token labels (loss) and once without (scores)."""
            model = BertForMaskedLM(config=config)
            loss = model(input_ids, token_type_ids, input_mask, token_labels)
            prediction_scores = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
            return outputs

        def check_bert_for_masked_lm_output(self, result):
            """MLM scores must be (batch, seq, vocab)."""
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])

        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run ``BertForNextSentencePrediction`` with sequence labels (loss) and without (scores)."""
            model = BertForNextSentencePrediction(config=config)
            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
            seq_relationship_score = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "seq_relationship_score": seq_relationship_score,
            }
            return outputs

        def check_bert_for_next_sequence_prediction_output(self, result):
            """NSP scores must be (batch, 2)."""
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])

        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run ``BertForPreTraining`` with both label sets (loss) and without (both score heads)."""
            model = BertForPreTraining(config=config)
            loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
            prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "prediction_scores": prediction_scores,
                "seq_relationship_score": seq_relationship_score,
            }
            return outputs

        def check_bert_for_pretraining_output(self, result):
            """Pre-training heads: MLM scores (batch, seq, vocab) and NSP scores (batch, 2)."""
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])

        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run ``BertForQuestionAnswering`` with span positions (loss) and without (logits)."""
            model = BertForQuestionAnswering(config=config)
            # sequence_labels double as both start and end positions; with the default
            # sizes (values < type_sequence_label_size=2, seq_length=7) they are valid indices.
            loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            return outputs

        def check_bert_for_question_answering_output(self, result):
            """Start and end logits must both be (batch, seq)."""
            self.parent.assertListEqual(
                list(result["start_logits"].size()),
                [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(result["end_logits"].size()),
                [self.batch_size, self.seq_length])

        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run ``BertForSequenceClassification`` with sequence labels (loss) and without (logits)."""
            model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
            logits = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "logits": logits,
            }
            return outputs

        def check_bert_for_sequence_classification_output(self, result):
            """Sequence-classification logits must be (batch, num_labels)."""
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_labels])

        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
            """Run ``BertForTokenClassification`` with token labels (loss) and without (logits)."""
            model = BertForTokenClassification(config=config, num_labels=self.num_labels)
            loss = model(input_ids, token_type_ids, input_mask, token_labels)
            logits = model(input_ids, token_type_ids, input_mask)
            outputs = {
                "loss": loss,
                "logits": logits,
            }
            return outputs

        def check_bert_for_token_classification_output(self, result):
            """Token-classification logits must be (batch, seq, num_labels)."""
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.num_labels])

    def test_default(self):
        self.run_tester(BertModelTest.BertModelTester(self))

    def test_config_to_json_string(self):
        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["hidden_size"], 37)

    def run_tester(self, tester):
        """Run the bare model and every task head through ``tester``'s shape checks.

        All heads share one set of random inputs. Each task head is also called
        with labels, so its loss is checked to be a scalar as well.
        """
        config_and_inputs = tester.prepare_config_and_inputs()

        output_result = tester.create_bert_model(*config_and_inputs)
        tester.check_bert_model_output(output_result)

        # (builder, shape checker) pairs, run in this order.
        heads_with_loss = [
            (tester.create_bert_for_masked_lm,
             tester.check_bert_for_masked_lm_output),
            (tester.create_bert_for_next_sequence_prediction,
             tester.check_bert_for_next_sequence_prediction_output),
            (tester.create_bert_for_pretraining,
             tester.check_bert_for_pretraining_output),
            (tester.create_bert_for_question_answering,
             tester.check_bert_for_question_answering_output),
            (tester.create_bert_for_sequence_classification,
             tester.check_bert_for_sequence_classification_output),
            (tester.create_bert_for_token_classification,
             tester.check_bert_for_token_classification_output),
        ]
        for create_head, check_head in heads_with_loss:
            output_result = create_head(*config_and_inputs)
            check_head(output_result)
            tester.check_loss_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Create a random ``torch.long`` (int64) tensor with values in ``[0, vocab_size)``.

        Args:
            shape: sequence of ints giving the tensor dimensions.
            vocab_size: exclusive upper bound for the sampled ids.
            rng: optional ``random.Random``; a fresh unseeded one is used when omitted.
            name: unused; accepted only so existing callers keep working.

        Returns:
            A contiguous ``torch.long`` tensor of the requested ``shape``.
        """
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()
# --------------------------------------------------------------------------------
# /tests/optimization_test.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import torch

from pytorch_pretrained_bert import BertAdam

class OptimizationTest(unittest.TestCase):
    """Convergence check for the ``BertAdam`` optimizer on a tiny MSE problem."""

    def assertListAlmostEqual(self, list1, list2, tol):
        """Fail unless both lists have the same length and agree element-wise within ``tol``."""
        self.assertEqual(len(list1), len(list2))
        for actual, expected in zip(list1, list2):
            self.assertAlmostEqual(actual, expected, delta=tol)

    def test_adam(self):
        """100 BertAdam steps should pull a 3-element weight vector onto its target."""
        goal = [0.4, 0.2, -0.5]
        weights = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor(goal)
        mse = torch.nn.MSELoss()
        # No warmup, constant schedule, no gradient clipping.
        optimizer = BertAdam(params=[weights], lr=2e-1,
                             weight_decay=0.0,
                             max_grad_norm=-1)
        for _ in range(100):
            mse(weights, target).backward()
            optimizer.step()
            # A bare tensor has no zero_grad(), so reset its gradient by hand.
            weights.grad.detach_()
            weights.grad.zero_()
        self.assertListAlmostEqual(weights.tolist(), goal, tol=1e-2)


if __name__ == "__main__":
    unittest.main()
# --------------------------------------------------------------------------------
# /tests/tokenization_test.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile
import unittest

from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
                                                  _is_whitespace, _is_control, _is_punctuation)


class TokenizationTest(unittest.TestCase):
    """Unit tests for the basic, wordpiece and full BERT tokenizers and their char helpers."""

    # Vocabulary shared by the full-tokenizer tests; a token's id is its index here.
    VOCAB_TOKENS = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]

    def _make_tokenizer(self, vocab_tokens, **kwargs):
        """Build a ``BertTokenizer`` from ``vocab_tokens`` written one per line to a temp file.

        A unique file from ``tempfile.mkstemp`` replaces the old fixed ``/tmp`` path,
        so the test runs on any platform and concurrent runs cannot clobber each
        other. The file is removed even if constructing the tokenizer raises.
        Extra keyword arguments are forwarded to ``BertTokenizer``.
        """
        fd, vocab_file = tempfile.mkstemp(prefix="bert_tokenizer_test_", suffix=".txt")
        try:
            with os.fdopen(fd, "w") as vocab_writer:
                vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
            return BertTokenizer(vocab_file, **kwargs)
        finally:
            os.remove(vocab_file)

    def test_full_tokenizer(self):
        tokenizer = self._make_tokenizer(self.VOCAB_TOKENS)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    def test_full_tokenizer_raises_error_for_long_sequences(self):
        tokenizer = self._make_tokenizer(self.VOCAB_TOKENS, max_len=10)

        # Exactly max_len tokens, none of them in the vocabulary, so every id is 0 ([UNK]).
        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
        indices = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(indices, [0 for _ in range(10)])

        # One token more than max_len must be rejected.
        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
        self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)

    def test_chinese(self):
        tokenizer = BasicTokenizer()

        self.assertListEqual(
            tokenizer.tokenize(u"ah\u535A\u63A8zz"),
            [u"ah", u"\u535A", u"\u63A8", u"zz"])

    def test_basic_tokenizer_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["hello", "!", "how", "are", "you", "?"])
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_no_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])

    def test_wordpiece_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing"
        ]
        vocab = {token: i for i, token in enumerate(vocab_tokens)}
        tokenizer = WordpieceTokenizer(vocab=vocab)

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(
            tokenizer.tokenize("unwanted running"),
            ["un", "##want", "##ed", "runn", "##ing"])

        # A word that cannot be fully split into known pieces becomes a single [UNK].
        self.assertListEqual(
            tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

    def test_is_whitespace(self):
        self.assertTrue(_is_whitespace(u" "))
        self.assertTrue(_is_whitespace(u"\t"))
        self.assertTrue(_is_whitespace(u"\r"))
        self.assertTrue(_is_whitespace(u"\n"))
        self.assertTrue(_is_whitespace(u"\u00A0"))

        self.assertFalse(_is_whitespace(u"A"))
        self.assertFalse(_is_whitespace(u"-"))

    def test_is_control(self):
        self.assertTrue(_is_control(u"\u0005"))

        self.assertFalse(_is_control(u"A"))
        self.assertFalse(_is_control(u" "))
        self.assertFalse(_is_control(u"\t"))
        self.assertFalse(_is_control(u"\r"))

    def test_is_punctuation(self):
        self.assertTrue(_is_punctuation(u"-"))
        self.assertTrue(_is_punctuation(u"$"))
        self.assertTrue(_is_punctuation(u"`"))
        self.assertTrue(_is_punctuation(u"."))

        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))


if __name__ == '__main__':
    unittest.main()
# --------------------------------------------------------------------------------