├── .gitignore
├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── pytorch-pretrained-BERT_annotation.iml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── MANIFEST.in
├── README.md
├── README_bert.md
├── docker
│   └── Dockerfile
├── download_data
│   ├── download_v1.1.sh
│   └── download_v2.0.sh
├── examples
│   ├── analysic_lic_data.py
│   ├── analysic_pred_zhidao.py
│   ├── analysic_pred_zhidao1.py
│   ├── analysic_squad_data.py
│   ├── clean_result.py
│   ├── create_submit_file.py
│   ├── evaluate-v1.1.py
│   ├── evaluate-v2.0.py
│   ├── extract_features.py
│   ├── run_classifier.py
│   ├── run_lm_finetuning.py
│   ├── run_squad.py
│   ├── run_squad2.py
│   ├── run_squad_zh.py
│   ├── run_swag.py
│   ├── softmax.py
│   ├── squad_v1.1_arch_sample.json
│   ├── test_BertForMaskedLM.py
│   ├── test_BertModel.py
│   ├── test_squad.py
│   ├── test_tokenization.py
│   └── valid_data.py
├── notebooks
│   ├── Comparing-TF-and-PT-models-MLM-NSP.ipynb
│   ├── Comparing-TF-and-PT-models-SQuAD.ipynb
│   └── Comparing-TF-and-PT-models.ipynb
├── pytorch_pretrained_bert
│   ├── __init__.py
│   ├── __main__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── file_utils.cpython-36.pyc
│   │   ├── modeling.cpython-36.pyc
│   │   ├── optimization.cpython-36.pyc
│   │   └── tokenization.cpython-36.pyc
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── file_utils.py
│   ├── modeling.py
│   ├── optimization.py
│   └── tokenization.py
├── requirements.txt
├── samples
│   ├── input.txt
│   └── sample_text.txt
├── setup.py
└── tests
    ├── modeling_test.py
    ├── optimization_test.py
    └── tokenization_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | examples/transfo_format_after_extract_dataset.py
2 | examples/transfo_format_after_extract_dataset_splitter.py
3 | examples/transfo_format.py
4 | examples/transfo_format_nospace.py
5 | pytorch_pretrained_bert/modeling_ks.py
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/pytorch-pretrained-BERT_annotation.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Pretrained Bert Annotation
2 |
3 | > This BERT annotation repo is for my personal study.
4 |
5 | - The raw README of PyTorch Pretrained Bert is [here](README_bert.md).
6 | - A very nice [PPT](https://nlp.stanford.edu/seminar/details/lkaiser.pdf) to help with understanding the model.
7 | - Synthetic Self-Training [PPT](https://nlp.stanford.edu/seminar/details/jdevlin.pdf?fbclid=IwAR2TBFCJOeZ9cGhxB-z5cJJ17vHN4W25oWsjI8NqJoTEmlYIYEKG7oh4tlY).
8 |
9 | ## Arch
10 |
11 | The architectures of BertModel and BertForMaskedLM, as printed from the PyTorch modules.
12 |
13 | #### BertModel Arch
14 | - BertEmbeddings
15 |   - word_embeddings: Embedding(30522, 768)
16 |   - position_embeddings: Embedding(512, 768)
17 |   - token_type_embeddings: Embedding(2, 768)
18 |   - LayerNorm: BertLayerNorm()
19 |   - dropout: Dropout(p=0.1)
20 | - BertEncoder
21 |   - BertLayer: (12 layers)
22 |     - BertAttention
23 |       - BertSelfAttention
24 |         - query: Linear(in_features=768, out_features=768, bias=True)
25 |         - key: Linear(in_features=768, out_features=768, bias=True)
26 |         - value: Linear(in_features=768, out_features=768, bias=True)
27 |         - dropout: Dropout(p=0.1)
28 |       - BertSelfOutput
29 |         - dense: Linear(in_features=768, out_features=768, bias=True)
30 |         - LayerNorm: BertLayerNorm()
31 |         - dropout: Dropout(p=0.1)
32 |     - BertIntermediate
33 |       - dense: Linear(in_features=768, out_features=3072, bias=True)
34 |       - activation: gelu
35 |     - BertOutput
36 |       - dense: Linear(in_features=3072, out_features=768, bias=True)
37 |       - LayerNorm: BertLayerNorm()
38 |       - dropout: Dropout(p=0.1)
39 | - BertPooler
40 |   - dense: Linear(in_features=768, out_features=768, bias=True)
41 |   - activation: Tanh()
42 |
43 | #### BertForMaskedLM Arch
44 | - BertModel
45 |   - BertEmbeddings
46 |     - word_embeddings: Embedding(30522, 768)
47 |     - position_embeddings: Embedding(512, 768)
48 |     - token_type_embeddings: Embedding(2, 768)
49 |     - LayerNorm: BertLayerNorm()
50 |     - dropout: Dropout(p=0.1)
51 |   - BertEncoder
52 |     - BertLayer: (12 layers)
53 |       - BertAttention
54 |         - BertSelfAttention
55 |           - query: Linear(in_features=768, out_features=768, bias=True)
56 |           - key: Linear(in_features=768, out_features=768, bias=True)
57 |           - value: Linear(in_features=768, out_features=768, bias=True)
58 |           - dropout: Dropout(p=0.1)
59 |         - BertSelfOutput
60 |           - dense: Linear(in_features=768, out_features=768, bias=True)
61 |           - LayerNorm: BertLayerNorm()
62 |           - dropout: Dropout(p=0.1)
63 |       - BertIntermediate
64 |         - dense: Linear(in_features=768, out_features=3072, bias=True)
65 |         - activation: gelu
66 |       - BertOutput
67 |         - dense: Linear(in_features=3072, out_features=768, bias=True)
68 |         - LayerNorm: BertLayerNorm()
69 |         - dropout: Dropout(p=0.1)
70 |   - BertPooler
71 |     - dense: Linear(in_features=768, out_features=768, bias=True)
72 |     - activation: Tanh()
73 | - BertOnlyMLMHead
74 |   - BertLMPredictionHead
75 |     - transform: BertPredictionHeadTransform
76 |       - dense: Linear(in_features=768, out_features=768, bias=True)
77 |       - LayerNorm: BertLayerNorm()
78 |     - decoder: Linear(in_features=768, out_features=30522, bias=False)
79 |
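80 | #### Quick check
81 |
82 | A minimal sketch (mirroring the upstream pytorch-pretrained-bert README example; assumes the package and pre-trained weights are available) that prints the module tree above and predicts a masked token:
83 |
84 | ```python
85 | import torch
86 | from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
87 |
88 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
89 | model = BertForMaskedLM.from_pretrained('bert-base-uncased')
90 | model.eval()
91 |
92 | print(model)  # prints the BertForMaskedLM arch listed above
93 |
94 | text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
95 | tokenized_text = tokenizer.tokenize(text)
96 | masked_index = 8
97 | tokenized_text[masked_index] = '[MASK]'  # mask the second "henson"
98 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
99 | segments_ids = [0] * 7 + [1] * 7  # sentence A / sentence B
100 |
101 | with torch.no_grad():
102 |     predictions = model(torch.tensor([indexed_tokens]), torch.tensor([segments_ids]))
103 |
104 | predicted_index = torch.argmax(predictions[0, masked_index]).item()
105 | print(tokenizer.convert_ids_to_tokens([predicted_index])[0])  # expected: "henson"
106 | ```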
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:latest
2 |
3 | RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext
4 |
5 | RUN pip install pytorch-pretrained-bert
6 |
7 | WORKDIR /workspace
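8 |
9 | # Example usage from the docker/ directory (image tag is illustrative):
10 | #   docker build -t bert-pytorch .
11 | #   docker run --runtime=nvidia -it bert-pytorch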
--------------------------------------------------------------------------------
/download_data/download_v1.1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Download the SQuAD 1.1 dataset and the official evaluation script
3 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
4 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
5 | wget https://raw.githubusercontent.com/allenai/bi-att-flow/master/squad/evaluate-v1.1.py
6 |
--------------------------------------------------------------------------------
/download_data/download_v2.0.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Download the SQuAD 2.0 dataset, the official evaluation script, and a sample prediction file
3 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
4 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
5 | wget https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ -O evaluate-v2.0.py
6 | wget https://worksheets.codalab.org/rest/bundles/0x8731effab84f41b7b874a070e40f61e2/contents/blob/ -O dev-evaluate-v2.0-in1
7 |
--------------------------------------------------------------------------------
/examples/analysic_lic_data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | BASE_PATH = "/home/wyb/PycharmProjects/DuReader/data/demo/"
4 |
5 |
6 | with open(BASE_PATH + "trainset/search.train.json", "r", encoding='utf-8') as reader:
7 | source = reader.readlines()
8 |
9 | # source = json.load(reader)
10 | # input_data = source["data"]
11 | # version = source["version"]
12 |
13 |
14 | # print(len(source))
15 | # print(type(source))  # <class 'list'>
16 |
17 | """
18 | keys (one document per line):
19 | documents
20 | answer_spans
21 | fake_answers
22 | question
23 | segmented_answers
24 | answers
25 | answer_docs
26 | segmented_question
27 | question_type
28 | question_id
29 | fact_or_opinion
30 | match_scores
31 | """
32 | line_json = json.loads(source[0])
33 | for i in line_json.keys():
34 | print(i)
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/examples/analysic_pred_zhidao.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/data/extracted/"
4 |
5 | with open(BASE_PATH + "dureader/predictions.json", "r", encoding="utf-8") as f:
6 | data = json.load(f)
7 |
8 | with open(BASE_PATH + "devset/zhidao.dev.json", "r", encoding='utf-8') as f1:
9 | lines = f1.readlines()
10 |
11 |
12 | def get_best_ans():
13 | """
14 | Compare the probabilities of the candidate answers and select the best answer for each question.
15 | """
16 | with open(BASE_PATH + "dureader/nbest_predictions.json", "r", encoding='utf-8') as f2:
17 | data_n = json.load(f2)
18 |
19 | nbest_para = []
20 | for k, v in data_n.items():
21 | para_dict = {}
22 | id = k.split("_")[0]
23 | prob = 0
24 | text = ""
25 | for sents in v:
26 | if sents["probability"] > prob:
27 | prob = sents["probability"]
28 | text = sents["text"]
29 | para_dict["id"] = id
30 | para_dict["prob"] = prob
31 | para_dict["text"] = text
32 |
33 | if nbest_para:
34 | for item in nbest_para:
35 | if id == item["id"]:
36 | if prob > item["prob"]:
37 | item["prob"] = prob
38 | item["text"] = text
39 | else:
40 | nbest_para.append(para_dict)
41 | else:
42 | nbest_para.append(para_dict)
43 |
44 | return nbest_para
45 |
46 |
47 | nbest_para = get_best_ans()
48 | print("===============> nbest_para completed!")
49 |
50 | for line in lines: # raw
51 | sample = json.loads(line)
52 | ans_list = []
53 | for k, v in data.items(): # pred
54 | if str(sample["question_id"]) == (str(k)).split("_")[0]:
55 | ans_list.append(v)
56 | print("------------------------------------------------------")
57 | print("question_id: " + (str(k)).split("_")[0])
58 | if sample["fake_answers"]:
59 | print("fake_answers: \n" + str(sample["fake_answers"][0]))
60 |
61 | print(" ")
62 |
63 | print("answer: count=" + str(len(sample["answers"])))
64 | for idx, ans_item in enumerate(sample["answers"]):
65 | print(str(idx) + "==> " + ans_item)
66 |
67 | print(" ")
68 |
69 | print("pred answer: count=" + str(len(ans_list)))
70 | answer_docs_id = -1
71 | if "answer_docs" in sample and sample["answer_docs"] and sample["answer_docs"][0] < len(ans_list):
72 | answer_docs_id = sample["answer_docs"][0]
73 |
74 | for idx, ans_item in enumerate(ans_list):
75 | state1 = ""  # flags the answer predicted from the doc that holds the fake_answers
76 | state2 = ""  # flags the predicted best answer
77 | for ans in nbest_para:
78 | if ans_item == ans["text"]:
79 | state2 = "(pred BEST answers)"
80 |
81 | if idx == answer_docs_id:
82 | state1 = "(has fake_answers)"
83 |
84 | print(str(idx) + "==>" + state1 + state2 + " " + ans_item)
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/examples/analysic_pred_zhidao1.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/data/extracted/"
4 |
5 |
6 | with open(BASE_PATH + "dureader/predictions.json", "r", encoding="utf-8") as f:
7 | data = json.load(f)
8 |
9 | with open(BASE_PATH + "dureader/predictions_filter.json", "r", encoding="utf-8") as f:
10 | data_filter = json.load(f)
11 |
12 | with open(BASE_PATH + "devset/zhidao.dev.json", "r", encoding='utf-8') as f1:
13 | lines = f1.readlines()
14 |
15 | # nbest_para = get_best_ans()
16 |
17 | for line in lines: # raw
18 | sample = json.loads(line)
19 | ans_list = []
20 | for k, v in data.items(): # pred
21 | if str(sample["question_id"]) == (str(k)).split("_")[0]:
22 | ans_list.append(v[0])
23 | print("------------------------------------------------------")
24 | print("question_id: " + (str(k)).split("_")[0])
25 | if sample["fake_answers"]:
26 | print("fake_answers: \n" + str(sample["fake_answers"][0]))
27 |
28 | print(" ")
29 |
30 | print("answer: count=" + str(len(sample["answers"])))
31 | for idx, ans_item in enumerate(sample["answers"]):
32 | print(str(idx) + "==> " + ans_item)
33 |
34 | print(" ")
35 |
36 | print("pred answer: count=" + str(len(ans_list)))
37 | answer_docs_id = -1
38 | if "answer_docs" in sample and sample["answer_docs"] and sample["answer_docs"][0] < len(ans_list):
39 | answer_docs_id = sample["answer_docs"][0]
40 |
41 | for idx, ans_item in enumerate(ans_list):
42 | state1 = ""  # flags the answer predicted from the doc that holds the fake_answers
43 | state2 = ""  # flags the predicted best answer
44 | for k1, v1 in data_filter.items():
45 | if v1 == ans_item:
46 | state2 = "(pred BEST answers)"
47 |
48 | if idx == answer_docs_id:
49 | state1 = "(has fake_answers)"
50 |
51 | print(str(idx) + "==>" + state1 + state2 + " " + ans_item)
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/examples/analysic_squad_data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | BASE_PATH = "/home/wyb/data/squad_v2.0/"
4 |
5 |
6 | with open(BASE_PATH + "train-v2.0.json", "r", encoding='utf-8') as reader:
7 | source = json.load(reader)
8 | input_data = source["data"]
9 | version = source["version"]
10 | #
11 | #
12 | # examples = []
13 | # for entry in input_data:
14 | # """
15 | # entry format:
16 | # {"title": xxx, "paragraphs": xxxx}
17 | # """
18 | # for paragraph in entry["paragraphs"]:
19 | #
20 | # paragraph_text = paragraph["context"]
21 |
22 | def is_whitespace(c):
23 |     # helper as defined in run_squad.py: space, tab, CR, LF and the
24 |     # narrow no-break space (0x202F) all count as whitespace
25 |     if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
26 |         return True
27 |     return False
28 |
29 |
30 | paragraph_text = 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say).'
31 | doc_tokens = []             # whitespace-split tokens
32 | char_to_word_offset = []    # maps every char position to its token index
33 | prev_is_whitespace = True
34 | for c in paragraph_text:  # by char
35 |     if is_whitespace(c):
36 |         prev_is_whitespace = True
37 |     else:
38 |         if prev_is_whitespace:
39 |             doc_tokens.append(c)   # start a new token
40 |         else:
41 |             doc_tokens[-1] += c    # extend the current token
42 |         prev_is_whitespace = False
43 |     char_to_word_offset.append(len(doc_tokens) - 1)
44 |
45 | print(doc_tokens)  # ['Beyoncé', 'Giselle', 'Knowles-Carter', '(/biːˈjɒnseɪ/', 'bee-YON-say).']
46 | print("----------------")
47 | print(char_to_word_offset)
48 |
--------------------------------------------------------------------------------
/examples/clean_result.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 |
4 | BASE_PATH = "/home/wyb/Downloads/"
5 |
6 |
7 | with open(BASE_PATH + "test_result.json", "r", encoding="utf-8") as f:
8 | res = f.readlines() # list, len=120000
9 |
10 | # text = "This is a \n file \r that \r hello\r!"
11 |
12 |
13 | def clean_sepc_char(text):
14 |     """Strip whitespace / markup fragments from an answer string.
15 |
16 |     NOTE: part of the original pattern list was lost when HTML-like tokens
17 |     were stripped from this dump; "<br>" below is a guessed reconstruction
18 |     of one missing entry.
19 |     """
20 |     replace_p = ["\t", "\n", "\r", "\u3000", "<br>", "/>", "\\x0a", " "]
21 |     for p in replace_p:
22 |         text = text.replace(p, "")
23 |
24 |     return text
25 |
26 |
27 | def E_trans_to_C(string):
28 |     """Map English punctuation to the corresponding Chinese punctuation."""
29 |     E_pun = u',.!?[]()<>"\''
30 |     C_pun = u',。!?【】()《》“‘'
31 |     table = {ord(f): ord(t) for f, t in zip(E_pun, C_pun)}
32 |
33 |     return string.translate(table)
34 |
35 |
36 |
37 |
38 |
39 | # for i in res:
40 | # data = json.loads(i)
41 | # if "&" in data["answers"][0]:
42 | # print(data)
43 |
44 |
45 | # cate_type = set() # {'DESCRIPTION', 'YES_NO', 'ENTITY'}
46 | # for i in res:
47 | # data = json.loads(i)
48 | # if data["question_type"] == "YES_NO":
49 | # print(data)
50 |
51 |
52 | # text = '小箭头。2.点击小箭头,则就是筛选。'  # sample input for remove_html(); the HTML tags embedded in the original string were lost from this dump
53 |
54 |
55 | def remove_html(text):
56 | reg = re.compile(r'<[^>]+>', re.S)
57 | text = reg.sub('', text)
58 |
59 | return text
60 |
61 |
62 | """
63 | {
64 | "question_id": 403770,
65 | "question_type": "YES_NO",
66 | "answers": ["我都是免费几分钟测试可以玩而已。"],
67 | "entity_answers": [[]],
68 | "yesno_answers": []
69 | }
70 | """
71 | json_list = []
72 | for i in res:
73 | item_dict = {}
74 |
75 | data = json.loads(i)
76 | text = data["answers"][0]
77 | text = E_trans_to_C(text)
78 | text = clean_sepc_char(text)
79 | text = remove_html(text)
80 |
81 | item_dict["question_id"] = data["question_id"]
82 | item_dict["question_type"] = data["question_type"]
83 | item_dict["answers"] = [text]
84 | item_dict["entity_answers"] = data["entity_answers"]
85 | item_dict["yesno_answers"] = data["yesno_answers"]
86 |
87 | json_list.append(item_dict)
88 |
89 | # ================ write to file ================
90 | with open(BASE_PATH + "test_result_rm.json", 'w') as fout:
91 | for pred_answer in json_list:
92 | fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/examples/create_submit_file.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/official_data/extracted/"
4 |
5 | """
6 | /DATA/disk1/wangyongbo/lic2019/DuReader/official_data/extracted/test1set
7 |
8 | {
9 | "question_id": 397032,
10 | "question_type": "ENTITY",
11 | "answers": ["浙江绿谷,秀山丽水。"],
12 | "entity_answers": [[]],
13 | "yesno_answers": []
14 | }
15 | """
16 | datasets = ["search", "zhidao"]
17 | for dataset in datasets:
18 | with open(BASE_PATH + "results/predictions_first_filter_" + dataset + ".json", "r") as f:
19 | data_n = json.load(f)
20 |
21 | with open(BASE_PATH + "test1set/" + dataset + ".test1.json", "r", encoding="utf-8") as f:
22 | lines = f.readlines()
23 |
24 | res = []
25 | test_ids = []
26 | pred_search_ids = []
27 | for line in lines:
28 | line_json = json.loads(line)
29 | # to avoid losing samples that are missing from the predictions, save all ids to a list.
30 | test_ids.append(int(line_json["question_id"]))
31 | for k, v in data_n.items():  # .items() is required; iterating a dict directly yields only its keys
32 | pred_search_ids.append(int(k))
33 | if str(line_json["question_id"]) == str(k):
34 | res_line = {}
35 | res_line["question_id"] = int(k)
36 | res_line["question_type"] = line_json["question_type"]
37 | res_line["answers"] = [v]
38 | res_line["entity_answers"] = [[]]
39 | res_line["yesno_answers"] = []
40 | res.append(res_line)
41 |
42 | if len(res) != 30000:
43 | # fill in missing samples with "" (no answer)
44 | for id in test_ids:
45 | if id not in pred_search_ids:
46 | for line in lines:
47 | line_json = json.loads(line)
48 | if str(line_json["question_id"]) == str(id):
49 | res_line = {}
50 | res_line["question_id"] = int(id)
51 | res_line["question_type"] = line_json["question_type"]
52 | res_line["answers"] = [""]
53 | res_line["entity_answers"] = [[]]
54 | res_line["yesno_answers"] = []
55 | res.append(res_line)
56 |
57 | with open(BASE_PATH + "results/test_result_" + dataset + ".json", 'w') as fout:
58 | for pred_answer in res:
59 | fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
60 |
--------------------------------------------------------------------------------
/examples/evaluate-v1.1.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import sys
9 | """
10 | Exec:
11 | python evaluate-v1.1.py ./dev-v1.1.json /tmp/debug_squad/predictions.json
12 |
13 | Results:
14 | {"exact_match": 80.49195837275307, "f1": 88.05701702878619}
15 | """
16 |
17 |
18 | def normalize_answer(s):
19 | """Lower text and remove punctuation, articles and extra whitespace."""
20 | def remove_articles(text):
21 | return re.sub(r'\b(a|an|the)\b', ' ', text)
22 |
23 | def white_space_fix(text):
24 | return ' '.join(text.split())
25 |
26 | def remove_punc(text):
27 | exclude = set(string.punctuation)
28 | return ''.join(ch for ch in text if ch not in exclude)
29 |
30 | def lower(text):
31 | return text.lower()
32 |
33 | return white_space_fix(remove_articles(remove_punc(lower(s))))
34 |
35 |
36 | def f1_score(prediction, ground_truth):
37 | prediction_tokens = normalize_answer(prediction).split()
38 | ground_truth_tokens = normalize_answer(ground_truth).split()
39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
40 | num_same = sum(common.values())
41 | if num_same == 0:
42 | return 0
43 | precision = 1.0 * num_same / len(prediction_tokens)
44 | recall = 1.0 * num_same / len(ground_truth_tokens)
45 | f1 = (2 * precision * recall) / (precision + recall)
46 | return f1
47 |
48 |
49 | def exact_match_score(prediction, ground_truth):
50 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
51 |
52 |
53 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
54 | scores_for_ground_truths = []
55 | for ground_truth in ground_truths:
56 | score = metric_fn(prediction, ground_truth)
57 | scores_for_ground_truths.append(score)
58 | return max(scores_for_ground_truths)
59 |
60 |
61 | def evaluate(dataset, predictions):
62 | f1 = exact_match = total = 0
63 | for article in dataset:
64 | for paragraph in article['paragraphs']:
65 | for qa in paragraph['qas']:
66 | total += 1
67 | if qa['id'] not in predictions:
68 | message = 'Unanswered question ' + qa['id'] + \
69 | ' will receive score 0.'
70 | print(message, file=sys.stderr)
71 | continue
72 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
73 | prediction = predictions[qa['id']]
74 | exact_match += metric_max_over_ground_truths(
75 | exact_match_score, prediction, ground_truths)
76 | f1 += metric_max_over_ground_truths(
77 | f1_score, prediction, ground_truths)
78 |
79 | exact_match = 100.0 * exact_match / total
80 | f1 = 100.0 * f1 / total
81 |
82 | return {'exact_match': exact_match, 'f1': f1}
83 |
84 |
85 | if __name__ == '__main__':
86 | expected_version = '1.1'
87 | parser = argparse.ArgumentParser(
88 | description='Evaluation for SQuAD ' + expected_version)
89 | parser.add_argument('dataset_file', help='Dataset file')
90 | parser.add_argument('prediction_file', help='Prediction File')
91 | args = parser.parse_args()
92 | with open(args.dataset_file) as dataset_file:
93 | dataset_json = json.load(dataset_file)
94 | if (dataset_json['version'] != expected_version):
95 | print('Evaluation expects v-' + expected_version +
96 | ', but got dataset with v-' + dataset_json['version'],
97 | file=sys.stderr)
98 | dataset = dataset_json['data']
99 | with open(args.prediction_file) as prediction_file:
100 | predictions = json.load(prediction_file)
101 | print(json.dumps(evaluate(dataset, predictions)))
--------------------------------------------------------------------------------
/examples/evaluate-v2.0.py:
--------------------------------------------------------------------------------
1 | """Official evaluation script for SQuAD version 2.0.
2 |
3 | In addition to basic functionality, we also compute additional statistics and
4 | plot precision-recall curves if an additional na_prob.json file is provided.
5 | This file is expected to map question ID's to the model's predicted probability
6 | that a question is unanswerable.
7 | """
8 | import argparse
9 | import collections
10 | import json
11 | import numpy as np
12 | import os
13 | import re
14 | import string
15 | import sys
16 |
17 | OPTS = None
18 |
19 |
20 | def parse_args():
21 | """
22 |
23 |
24 | Model: bert-base-uncased
25 |
26 | EXEC:
27 | python evaluate-v2.0.py ./dev-v2.0.json /tmp/debug_squad2/predictions.json
28 | RES:
29 | {
30 | "exact": 70.78244757011707,
31 | "f1": 74.11532024041503,
32 | "total": 11873,
33 | "HasAns_exact": 71.72739541160594,
34 | "HasAns_f1": 78.40269858543304,
35 | "HasAns_total": 5928,
36 | "NoAns_exact": 69.84020185029436,
37 | "NoAns_f1": 69.84020185029436,
38 | "NoAns_total": 5945
39 | }
40 | """
41 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
42 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
43 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
44 | parser.add_argument('--out-file', '-o', metavar='eval.json',
45 | help='Write accuracy metrics to file (default is stdout).')
46 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
47 | help='Model estimates of probability of no answer.')
48 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
49 | help='Predict "" if no-answer probability exceeds this (default = 1.0).')
50 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
51 | help='Save precision-recall curves to directory.')
52 | parser.add_argument('--verbose', '-v', action='store_true')
53 | if len(sys.argv) == 1:
54 | parser.print_help()
55 | sys.exit(1)
56 | return parser.parse_args()
57 |
58 |
59 | def make_qid_to_has_ans(dataset):
60 | qid_to_has_ans = {}
61 | for article in dataset:
62 | for p in article['paragraphs']:
63 | for qa in p['qas']:
64 | qid_to_has_ans[qa['id']] = bool(qa['answers'])
65 | return qid_to_has_ans
66 |
67 | def normalize_answer(s):
68 | """Lower text and remove punctuation, articles and extra whitespace."""
69 | def remove_articles(text):
70 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
71 | return re.sub(regex, ' ', text)
72 | def white_space_fix(text):
73 | return ' '.join(text.split())
74 | def remove_punc(text):
75 | exclude = set(string.punctuation)
76 | return ''.join(ch for ch in text if ch not in exclude)
77 | def lower(text):
78 | return text.lower()
79 | return white_space_fix(remove_articles(remove_punc(lower(s))))
80 |
81 | def get_tokens(s):
82 | if not s: return []
83 | return normalize_answer(s).split()
84 |
85 | def compute_exact(a_gold, a_pred):
86 | return int(normalize_answer(a_gold) == normalize_answer(a_pred))
87 |
88 | def compute_f1(a_gold, a_pred):
89 | gold_toks = get_tokens(a_gold)
90 | pred_toks = get_tokens(a_pred)
91 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
92 | num_same = sum(common.values())
93 | if len(gold_toks) == 0 or len(pred_toks) == 0:
94 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
95 | return int(gold_toks == pred_toks)
96 | if num_same == 0:
97 | return 0
98 | precision = 1.0 * num_same / len(pred_toks)
99 | recall = 1.0 * num_same / len(gold_toks)
100 | f1 = (2 * precision * recall) / (precision + recall)
101 | return f1
102 |
103 | def get_raw_scores(dataset, preds):
104 | exact_scores = {}
105 | f1_scores = {}
106 | for article in dataset:
107 | for p in article['paragraphs']:
108 | for qa in p['qas']:
109 | qid = qa['id']
110 | gold_answers = [a['text'] for a in qa['answers']
111 | if normalize_answer(a['text'])]
112 | if not gold_answers:
113 | # For unanswerable questions, only correct answer is empty string
114 | gold_answers = ['']
115 | if qid not in preds:
116 | print('Missing prediction for %s' % qid)
117 | continue
118 | a_pred = preds[qid]
119 | # Take max over all gold answers
120 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
121 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
122 | return exact_scores, f1_scores
123 |
124 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
125 | new_scores = {}
126 | for qid, s in scores.items():
127 | pred_na = na_probs[qid] > na_prob_thresh
128 | if pred_na:
129 | new_scores[qid] = float(not qid_to_has_ans[qid])
130 | else:
131 | new_scores[qid] = s
132 | return new_scores
133 |
134 | def make_eval_dict(exact_scores, f1_scores, qid_list=None):
135 | if not qid_list:
136 | total = len(exact_scores)
137 | return collections.OrderedDict([
138 | ('exact', 100.0 * sum(exact_scores.values()) / total),
139 | ('f1', 100.0 * sum(f1_scores.values()) / total),
140 | ('total', total),
141 | ])
142 | else:
143 | total = len(qid_list)
144 | return collections.OrderedDict([
145 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
146 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
147 | ('total', total),
148 | ])
149 |
150 | def merge_eval(main_eval, new_eval, prefix):
151 | for k in new_eval:
152 | main_eval['%s_%s' % (prefix, k)] = new_eval[k]
153 |
154 | def plot_pr_curve(precisions, recalls, out_image, title):
155 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
156 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
157 | plt.xlabel('Recall')
158 | plt.ylabel('Precision')
159 | plt.xlim([0.0, 1.05])
160 | plt.ylim([0.0, 1.05])
161 | plt.title(title)
162 | plt.savefig(out_image)
163 | plt.clf()
164 |
165 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
166 | out_image=None, title=None):
167 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
168 | true_pos = 0.0
169 | cur_p = 1.0
170 | cur_r = 0.0
171 | precisions = [1.0]
172 | recalls = [0.0]
173 | avg_prec = 0.0
174 | for i, qid in enumerate(qid_list):
175 | if qid_to_has_ans[qid]:
176 | true_pos += scores[qid]
177 | cur_p = true_pos / float(i+1)
178 | cur_r = true_pos / float(num_true_pos)
179 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
180 | # i.e., if we can put a threshold after this point
181 | avg_prec += cur_p * (cur_r - recalls[-1])
182 | precisions.append(cur_p)
183 | recalls.append(cur_r)
184 | if out_image:
185 | plot_pr_curve(precisions, recalls, out_image, title)
186 | return {'ap': 100.0 * avg_prec}
187 |
188 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
189 | qid_to_has_ans, out_image_dir):
190 | if out_image_dir and not os.path.exists(out_image_dir):
191 | os.makedirs(out_image_dir)
192 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
193 | if num_true_pos == 0:
194 | return
195 | pr_exact = make_precision_recall_eval(
196 | exact_raw, na_probs, num_true_pos, qid_to_has_ans,
197 | out_image=os.path.join(out_image_dir, 'pr_exact.png'),
198 | title='Precision-Recall curve for Exact Match score')
199 | pr_f1 = make_precision_recall_eval(
200 | f1_raw, na_probs, num_true_pos, qid_to_has_ans,
201 | out_image=os.path.join(out_image_dir, 'pr_f1.png'),
202 | title='Precision-Recall curve for F1 score')
203 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
204 | pr_oracle = make_precision_recall_eval(
205 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
206 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
207 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
208 | merge_eval(main_eval, pr_exact, 'pr_exact')
209 | merge_eval(main_eval, pr_f1, 'pr_f1')
210 | merge_eval(main_eval, pr_oracle, 'pr_oracle')
211 |
212 | def histogram_na_prob(na_probs, qid_list, image_dir, name):
213 | if not qid_list:
214 | return
215 | x = [na_probs[k] for k in qid_list]
216 | weights = np.ones_like(x) / float(len(x))
217 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
218 | plt.xlabel('Model probability of no-answer')
219 | plt.ylabel('Proportion of dataset')
220 | plt.title('Histogram of no-answer probability: %s' % name)
221 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
222 | plt.clf()
223 |
224 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
225 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
226 | cur_score = num_no_ans
227 | best_score = cur_score
228 | best_thresh = 0.0
229 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
230 | for i, qid in enumerate(qid_list):
231 | if qid not in scores: continue
232 | if qid_to_has_ans[qid]:
233 | diff = scores[qid]
234 | else:
235 | if preds[qid]:
236 | diff = -1
237 | else:
238 | diff = 0
239 | cur_score += diff
240 | if cur_score > best_score:
241 | best_score = cur_score
242 | best_thresh = na_probs[qid]
243 | return 100.0 * best_score / len(scores), best_thresh
244 |
245 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
246 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
247 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
248 | main_eval['best_exact'] = best_exact
249 | main_eval['best_exact_thresh'] = exact_thresh
250 | main_eval['best_f1'] = best_f1
251 | main_eval['best_f1_thresh'] = f1_thresh
252 |
253 | def main():
254 | with open(OPTS.data_file) as f:
255 | dataset_json = json.load(f)
256 | dataset = dataset_json['data']
257 | with open(OPTS.pred_file) as f:
258 | preds = json.load(f)
259 | if OPTS.na_prob_file:
260 | with open(OPTS.na_prob_file) as f:
261 | na_probs = json.load(f)
262 | else:
263 | na_probs = {k: 0.0 for k in preds}
264 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
265 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
266 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
267 | exact_raw, f1_raw = get_raw_scores(dataset, preds)
268 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
269 | OPTS.na_prob_thresh)
270 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
271 | OPTS.na_prob_thresh)
272 | out_eval = make_eval_dict(exact_thresh, f1_thresh)
273 | if has_ans_qids:
274 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
275 | merge_eval(out_eval, has_ans_eval, 'HasAns')
276 | if no_ans_qids:
277 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
278 | merge_eval(out_eval, no_ans_eval, 'NoAns')
279 | if OPTS.na_prob_file:
280 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
281 | if OPTS.na_prob_file and OPTS.out_image_dir:
282 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
283 | qid_to_has_ans, OPTS.out_image_dir)
284 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
285 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
286 | if OPTS.out_file:
287 | with open(OPTS.out_file, 'w') as f:
288 | json.dump(out_eval, f)
289 | else:
290 | print(json.dumps(out_eval, indent=2))
291 |
292 | if __name__ == '__main__':
293 | OPTS = parse_args()
294 | if OPTS.out_image_dir:
295 | import matplotlib
296 | matplotlib.use('Agg')
297 | import matplotlib.pyplot as plt
298 | main()
--------------------------------------------------------------------------------
/examples/extract_features.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Extract pre-computed feature vectors from a PyTorch BERT model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import argparse
22 | import collections
23 | import logging
24 | import json
25 | import re
26 |
27 | import torch
28 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
29 | from torch.utils.data.distributed import DistributedSampler
30 |
31 | from pytorch_pretrained_bert.tokenization import BertTokenizer
32 | from pytorch_pretrained_bert.modeling import BertModel
33 |
34 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
35 | datefmt = '%m/%d/%Y %H:%M:%S',
36 | level = logging.INFO)
37 | logger = logging.getLogger(__name__)
38 |
39 |
40 | class InputExample(object):
41 |
42 | def __init__(self, unique_id, text_a, text_b):
43 | self.unique_id = unique_id
44 | self.text_a = text_a
45 | self.text_b = text_b
46 |
47 |
48 | class InputFeatures(object):
49 | """A single set of features of data."""
50 |
51 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
52 | self.unique_id = unique_id
53 | self.tokens = tokens
54 | self.input_ids = input_ids
55 | self.input_mask = input_mask
56 | self.input_type_ids = input_type_ids
57 |
58 |
59 | def convert_examples_to_features(examples, seq_length, tokenizer):
60 | """Loads a data file into a list of `InputBatch`s."""
61 |
62 | features = []
63 | for (ex_index, example) in enumerate(examples):
64 | tokens_a = tokenizer.tokenize(example.text_a)
65 |
66 | tokens_b = None
67 | if example.text_b:
68 | tokens_b = tokenizer.tokenize(example.text_b)
69 |
70 | if tokens_b:
71 | # Modifies `tokens_a` and `tokens_b` in place so that the total
72 | # length is less than the specified length.
73 | # Account for [CLS], [SEP], [SEP] with "- 3"
74 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
75 | else:
76 | # Account for [CLS] and [SEP] with "- 2"
77 | if len(tokens_a) > seq_length - 2:
78 | tokens_a = tokens_a[0:(seq_length - 2)]
79 |
80 | # The convention in BERT is:
81 | # (a) For sequence pairs:
82 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
83 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
84 | # (b) For single sequences:
85 | # tokens: [CLS] the dog is hairy . [SEP]
86 | # type_ids: 0 0 0 0 0 0 0
87 | #
88 | # Where "type_ids" are used to indicate whether this is the first
89 | # sequence or the second sequence. The embedding vectors for `type=0` and
90 | # `type=1` were learned during pre-training and are added to the wordpiece
91 | # embedding vector (and position vector). This is not *strictly* necessary
92 | # since the [SEP] token unambigiously separates the sequences, but it makes
93 | # it easier for the model to learn the concept of sequences.
94 | #
95 | # For classification tasks, the first vector (corresponding to [CLS]) is
96 | # used as as the "sentence vector". Note that this only makes sense because
97 | # the entire model is fine-tuned.
98 | tokens = []
99 | input_type_ids = []
100 | tokens.append("[CLS]")
101 | input_type_ids.append(0)
102 | for token in tokens_a:
103 | tokens.append(token)
104 | input_type_ids.append(0)
105 | tokens.append("[SEP]")
106 | input_type_ids.append(0)
107 |
108 | if tokens_b:
109 | for token in tokens_b:
110 | tokens.append(token)
111 | input_type_ids.append(1)
112 | tokens.append("[SEP]")
113 | input_type_ids.append(1)
114 |
115 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
116 |
117 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
118 | # tokens are attended to.
119 | input_mask = [1] * len(input_ids)
120 |
121 | # Zero-pad up to the sequence length.
122 | while len(input_ids) < seq_length:
123 | input_ids.append(0)
124 | input_mask.append(0)
125 | input_type_ids.append(0)
126 |
127 | assert len(input_ids) == seq_length
128 | assert len(input_mask) == seq_length
129 | assert len(input_type_ids) == seq_length
130 |
131 | if ex_index < 5:
132 | logger.info("*** Example ***")
133 | logger.info("unique_id: %s" % (example.unique_id))
134 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
135 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
136 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
137 | logger.info(
138 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
139 |
140 | features.append(
141 | InputFeatures(
142 | unique_id=example.unique_id,
143 | tokens=tokens,
144 | input_ids=input_ids,
145 | input_mask=input_mask,
146 | input_type_ids=input_type_ids))
147 | return features
148 |
149 |
150 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
151 | """Truncates a sequence pair in place to the maximum length."""
152 |
153 | # This is a simple heuristic which will always truncate the longer sequence
154 | # one token at a time. This makes more sense than truncating an equal percent
155 | # of tokens from each, since if one sequence is very short then each token
156 | # that's truncated likely contains more information than a longer sequence.
157 | while True:
158 | total_length = len(tokens_a) + len(tokens_b)
159 | if total_length <= max_length:
160 | break
161 | if len(tokens_a) > len(tokens_b):
162 | tokens_a.pop()
163 | else:
164 | tokens_b.pop()
165 |
166 |
167 | def read_examples(input_file):
168 | """Read a list of `InputExample`s from an input file."""
169 | examples = []
170 | unique_id = 0
171 | with open(input_file, "r", encoding='utf-8') as reader:
172 | while True:
173 | line = reader.readline()
174 | if not line:
175 | break
176 | line = line.strip()
177 | text_a = None
178 | text_b = None
179 | m = re.match(r"^(.*) \|\|\| (.*)$", line)
180 | if m is None:
181 | text_a = line
182 | else:
183 | text_a = m.group(1)
184 | text_b = m.group(2)
185 | examples.append(
186 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
187 | unique_id += 1
188 | return examples
189 |
190 |
191 | def main():
192 | parser = argparse.ArgumentParser()
193 |
194 | ## Required parameters
195 | parser.add_argument("--input_file", default=None, type=str, required=True)
196 | parser.add_argument("--output_file", default=None, type=str, required=True)
197 | parser.add_argument("--bert_model", default=None, type=str, required=True,
198 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
199 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
200 |
201 | ## Other parameters
202 | parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
203 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
204 | parser.add_argument("--max_seq_length", default=128, type=int,
205 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
206 | "than this will be truncated, and sequences shorter than this will be padded.")
207 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
208 | parser.add_argument("--local_rank",
209 | type=int,
210 | default=-1,
211 | help = "local_rank for distributed training on gpus")
212 | parser.add_argument("--no_cuda",
213 | action='store_true',
214 | help="Whether not to use CUDA when available")
215 |
216 | args = parser.parse_args()
217 |
218 | if args.local_rank == -1 or args.no_cuda:
219 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
220 | n_gpu = torch.cuda.device_count()
221 | else:
222 | device = torch.device("cuda", args.local_rank)
223 | n_gpu = 1
224 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
225 | torch.distributed.init_process_group(backend='nccl')
226 | logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))
227 |
228 | layer_indexes = [int(x) for x in args.layers.split(",")]
229 |
230 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
231 |
232 | examples = read_examples(args.input_file)
233 |
234 | features = convert_examples_to_features(
235 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)
236 |
237 | unique_id_to_feature = {}
238 | for feature in features:
239 | unique_id_to_feature[feature.unique_id] = feature
240 |
241 | model = BertModel.from_pretrained(args.bert_model)
242 | model.to(device)
243 |
244 | if args.local_rank != -1:
245 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
246 | output_device=args.local_rank)
247 | elif n_gpu > 1:
248 | model = torch.nn.DataParallel(model)
249 |
250 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
251 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
252 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
253 |
254 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
255 | if args.local_rank == -1:
256 | eval_sampler = SequentialSampler(eval_data)
257 | else:
258 | eval_sampler = DistributedSampler(eval_data)
259 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
260 |
261 | model.eval()
262 | with open(args.output_file, "w", encoding='utf-8') as writer:
263 | for input_ids, input_mask, example_indices in eval_dataloader:
264 | input_ids = input_ids.to(device)
265 | input_mask = input_mask.to(device)
266 |
267 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
268 |
269 |
270 | for b, example_index in enumerate(example_indices):
271 | feature = features[example_index.item()]
272 | unique_id = int(feature.unique_id)
273 | # feature = unique_id_to_feature[unique_id]
274 | output_json = collections.OrderedDict()
275 | output_json["linex_index"] = unique_id
276 | all_out_features = []
277 | for (i, token) in enumerate(feature.tokens):
278 | all_layers = []
279 | for (j, layer_index) in enumerate(layer_indexes):
280 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
281 | layer_output = layer_output[b]
282 | layers = collections.OrderedDict()
283 | layers["index"] = layer_index
284 | layers["values"] = [
285 | round(x.item(), 6) for x in layer_output[i]
286 | ]
287 | all_layers.append(layers)
288 | out_features = collections.OrderedDict()
289 | out_features["token"] = token
290 | out_features["layers"] = all_layers
291 | all_out_features.append(out_features)
292 | output_json["features"] = all_out_features
293 | writer.write(json.dumps(output_json) + "\n")
294 |
295 |
296 | if __name__ == "__main__":
297 | main()
298 |
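# For reference, each line written to --output_file is one JSON object per input
# example, shaped roughly as follows (the token and values below are made-up
# placeholders, not real model output):
#
# {"linex_index": 0,
#  "features": [
#     {"token": "[CLS]",
#      "layers": [{"index": -1, "values": [0.123456, ...]},
#                 {"index": -2, "values": [...]},
#                 ...]},
#     ...]}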
--------------------------------------------------------------------------------
/examples/run_lm_finetuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """BERT finetuning runner."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import logging
24 | import argparse
25 | from tqdm import tqdm, trange
26 |
27 | import numpy as np
28 | import torch
29 | from torch.utils.data import DataLoader, RandomSampler
30 | from torch.utils.data.distributed import DistributedSampler
31 |
32 | from pytorch_pretrained_bert.tokenization import BertTokenizer
33 | from pytorch_pretrained_bert.modeling import BertForPreTraining
34 | from pytorch_pretrained_bert.optimization import BertAdam
35 |
36 | from torch.utils.data import Dataset
37 | import random
38 |
39 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
40 | datefmt='%m/%d/%Y %H:%M:%S',
41 | level=logging.INFO)
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | def warmup_linear(x, warmup=0.002):
46 | if x < warmup:
47 | return x/warmup
48 | return 1.0 - x
49 |
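# A few hand-checked values of the schedule above (default warmup=0.002, where
# x is the fraction of total training steps already taken):
#   warmup_linear(0.001) == 0.5      # still warming up: 0.001 / 0.002
#   warmup_linear(0.002) == 0.998    # just past warmup: 1.0 - 0.002
#   warmup_linear(0.5)   == 0.5      # linear decay: 1.0 - 0.5
#   warmup_linear(1.0)   == 0.0      # end of training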
50 |
51 | class BERTDataset(Dataset):
52 | def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
53 | self.vocab = tokenizer.vocab
54 | self.tokenizer = tokenizer
55 | self.seq_len = seq_len
56 | self.on_memory = on_memory
57 | self.corpus_lines = corpus_lines # number of non-empty lines in input corpus
58 | self.corpus_path = corpus_path
59 | self.encoding = encoding
60 | self.current_doc = 0 # to avoid random sentence from same doc
61 |
62 | # for loading samples directly from file
63 | self.sample_counter = 0 # used to keep track of full epochs on file
64 | self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair
65 |
66 | # for loading samples in memory
67 | self.current_random_doc = 0
68 | self.num_docs = 0
69 | self.sample_to_doc = [] # map sample index to doc and line
70 |
71 | # load samples into memory
72 | if on_memory:
73 | self.all_docs = []
74 | doc = []
75 | self.corpus_lines = 0
76 | with open(corpus_path, "r", encoding=encoding) as f:
77 | for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
78 | line = line.strip()
79 | if line == "":
80 | self.all_docs.append(doc)
81 | doc = []
82 | #remove last added sample because there won't be a subsequent line anymore in the doc
83 | self.sample_to_doc.pop()
84 | else:
85 | #store as one sample
86 | sample = {"doc_id": len(self.all_docs),
87 | "line": len(doc)}
88 | self.sample_to_doc.append(sample)
89 | doc.append(line)
90 | self.corpus_lines = self.corpus_lines + 1
91 |
92 | # if last row in file is not empty
93 | if self.all_docs[-1] != doc:
94 | self.all_docs.append(doc)
95 | self.sample_to_doc.pop()
96 |
97 | self.num_docs = len(self.all_docs)
98 |
99 | # load samples later lazily from disk
100 | else:
101 | if self.corpus_lines is None:
102 | with open(corpus_path, "r", encoding=encoding) as f:
103 | self.corpus_lines = 0
104 | for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
105 | if line.strip() == "":
106 | self.num_docs += 1
107 | else:
108 | self.corpus_lines += 1
109 |
110 | # if doc does not end with empty line
111 | if line.strip() != "":
112 | self.num_docs += 1
113 |
114 | self.file = open(corpus_path, "r", encoding=encoding)
115 | self.random_file = open(corpus_path, "r", encoding=encoding)
116 |
117 | def __len__(self):
118 | # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
119 | return self.corpus_lines - self.num_docs - 1
120 |
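    # Worked example for __len__: a corpus of two docs with 3 and 2 non-empty
    # lines has corpus_lines=5 and num_docs=2, so len() = 5 - 2 - 1 = 2. Each
    # doc of n lines yields n-1 adjacent pairs (here 2 + 1 = 3); the extra
    # "- 1" drops one more sample, as the comment in __len__ notes.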
121 | def __getitem__(self, item):
122 | cur_id = self.sample_counter
123 | self.sample_counter += 1
124 | if not self.on_memory:
125 | # after one epoch we start again from beginning of file
126 | if cur_id != 0 and (cur_id % len(self) == 0):
127 | self.file.close()
128 | self.file = open(self.corpus_path, "r", encoding=self.encoding)
129 |
130 | t1, t2, is_next_label = self.random_sent(item)
131 |
132 | # tokenize
133 | tokens_a = self.tokenizer.tokenize(t1)
134 | tokens_b = self.tokenizer.tokenize(t2)
135 |
136 | # combine to one sample
137 | cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)
138 |
139 | # transform sample to features
140 | cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
141 |
142 | cur_tensors = (torch.tensor(cur_features.input_ids),
143 | torch.tensor(cur_features.input_mask),
144 | torch.tensor(cur_features.segment_ids),
145 | torch.tensor(cur_features.lm_label_ids),
146 | torch.tensor(cur_features.is_next))
147 |
148 | return cur_tensors
149 |
150 | def random_sent(self, index):
151 | """
152 | Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
153 | from one doc. With 50% the second sentence will be a random one from another doc.
154 | :param index: int, index of sample.
155 | :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
156 | """
157 | t1, t2 = self.get_corpus_line(index)
158 | if random.random() > 0.5:
159 | label = 0
160 | else:
161 | t2 = self.get_random_line()
162 | label = 1
163 |
164 | assert len(t1) > 0
165 | assert len(t2) > 0
166 | return t1, t2, label
167 |
168 | def get_corpus_line(self, item):
169 | """
170 | Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
171 | :param item: int, index of sample.
172 | :return: (str, str), two subsequent sentences from corpus
173 | """
174 | t1 = ""
175 | t2 = ""
176 | assert item < self.corpus_lines
177 | if self.on_memory:
178 | sample = self.sample_to_doc[item]
179 | t1 = self.all_docs[sample["doc_id"]][sample["line"]]
180 | t2 = self.all_docs[sample["doc_id"]][sample["line"]+1]
181 | # used later to avoid random nextSentence from same doc
182 | self.current_doc = sample["doc_id"]
183 | return t1, t2
184 | else:
185 | if self.line_buffer is None:
186 | # read first non-empty line of file
187 |             while t1 == "":
188 | t1 = self.file.__next__().strip()
189 | t2 = self.file.__next__().strip()
190 | else:
191 | # use t2 from previous iteration as new t1
192 | t1 = self.line_buffer
193 | t2 = self.file.__next__().strip()
194 | # skip empty rows that are used for separating documents and keep track of current doc id
195 | while t2 == "" or t1 == "":
196 | t1 = self.file.__next__().strip()
197 | t2 = self.file.__next__().strip()
198 | self.current_doc = self.current_doc+1
199 | self.line_buffer = t2
200 |
201 | assert t1 != ""
202 | assert t2 != ""
203 | return t1, t2
204 |
205 | def get_random_line(self):
206 | """
207 | Get random line from another document for nextSentence task.
208 | :return: str, content of one line
209 | """
210 | # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
211 | # corpora. However, just to be careful, we try to make sure that
212 | # the random document is not the same as the document we're processing.
213 | for _ in range(10):
214 | if self.on_memory:
215 | rand_doc_idx = random.randint(0, len(self.all_docs)-1)
216 | rand_doc = self.all_docs[rand_doc_idx]
217 | line = rand_doc[random.randrange(len(rand_doc))]
218 | else:
219 | rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
220 | #pick random line
221 | for _ in range(rand_index):
222 | line = self.get_next_line()
223 | #check if our picked random line is really from another doc like we want it to be
224 | if self.current_random_doc != self.current_doc:
225 | break
226 | return line
227 |
228 | def get_next_line(self):
229 | """ Gets next line of random_file and starts over when reaching end of file"""
230 | try:
231 | line = self.random_file.__next__().strip()
232 | #keep track of which document we are currently looking at to later avoid having the same doc as t1
233 | if line == "":
234 | self.current_random_doc = self.current_random_doc + 1
235 | line = self.random_file.__next__().strip()
236 | except StopIteration:
237 | self.random_file.close()
238 | self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
239 | line = self.random_file.__next__().strip()
240 | return line
241 |
242 |
243 | class InputExample(object):
244 | """A single training/test example for the language model."""
245 |
246 | def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
247 |         """Constructs an InputExample.
248 |
249 | Args:
250 | guid: Unique id for the example.
251 | tokens_a: string. The untokenized text of the first sequence. For single
252 | sequence tasks, only this sequence must be specified.
253 | tokens_b: (Optional) string. The untokenized text of the second sequence.
254 | Only must be specified for sequence pair tasks.
255 |             is_next: (Optional) int. Next-sentence label: 0 if tokens_b follows
256 |                 tokens_a in the corpus, 1 if tokens_b is a random sentence.
257 | """
258 | self.guid = guid
259 | self.tokens_a = tokens_a
260 | self.tokens_b = tokens_b
261 | self.is_next = is_next # nextSentence
262 | self.lm_labels = lm_labels # masked words for language model
263 |
264 |
265 | class InputFeatures(object):
266 | """A single set of features of data."""
267 |
268 | def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids):
269 | self.input_ids = input_ids
270 | self.input_mask = input_mask
271 | self.segment_ids = segment_ids
272 | self.is_next = is_next
273 | self.lm_label_ids = lm_label_ids
274 |
275 |
276 | def random_word(tokens, tokenizer):
277 | """
278 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
279 | :param tokens: list of str, tokenized sentence.
280 |     :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
281 | :return: (list of str, list of int), masked tokens and related labels for LM prediction
282 | """
283 | output_label = []
284 |
285 | for i, token in enumerate(tokens):
286 | prob = random.random()
287 | # mask token with 15% probability
288 | if prob < 0.15:
289 | prob /= 0.15
290 |
291 | # 80% randomly change token to mask token
292 | if prob < 0.8:
293 | tokens[i] = "[MASK]"
294 |
295 | # 10% randomly change token to random token
296 | elif prob < 0.9:
297 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
298 |
299 | # -> rest 10% randomly keep current token
300 |
301 | # append current token to output (we will predict these later)
302 | try:
303 | output_label.append(tokenizer.vocab[token])
304 | except KeyError:
305 | # For unknown words (should not occur with BPE vocab)
306 | output_label.append(tokenizer.vocab["[UNK]"])
307 |                 logger.warning("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
308 | else:
309 | # no masking token (will be ignored by loss function later)
310 | output_label.append(-1)
311 |
312 | return tokens, output_label
313 |
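# A minimal sketch of what random_word produces, assuming a toy run in which
# only the third token happened to fall in the 15% masking draw (the real
# choice is random, and a chosen slot can also become a random or unchanged
# token with 10% probability each):
#
#   tokens_in  = ['jim', 'henson', 'was', 'a', 'puppet', '##eer']
#   tokens_out = ['jim', 'henson', '[MASK]', 'a', 'puppet', '##eer']
#   labels     = [-1, -1, tokenizer.vocab['was'], -1, -1, -1]
#
# Only masked positions carry a real vocab id; -1 positions are ignored by the
# loss function later.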
314 |
315 | def convert_example_to_features(example, max_seq_length, tokenizer):
316 | """
317 | Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
318 | IDs, LM labels, input_mask, CLS and SEP tokens etc.
319 | :param example: InputExample, containing sentence input as strings and is_next label
320 | :param max_seq_length: int, maximum length of sequence.
321 | :param tokenizer: Tokenizer
322 | :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
323 | """
324 | tokens_a = example.tokens_a
325 | tokens_b = example.tokens_b
326 | # Modifies `tokens_a` and `tokens_b` in place so that the total
327 | # length is less than the specified length.
328 | # Account for [CLS], [SEP], [SEP] with "- 3"
329 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
330 |
331 | t1_random, t1_label = random_word(tokens_a, tokenizer)
332 | t2_random, t2_label = random_word(tokens_b, tokenizer)
333 | # concatenate lm labels and account for CLS, SEP, SEP
334 | lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
335 |
336 | # The convention in BERT is:
337 | # (a) For sequence pairs:
338 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
339 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
340 | # (b) For single sequences:
341 | # tokens: [CLS] the dog is hairy . [SEP]
342 | # type_ids: 0 0 0 0 0 0 0
343 | #
344 | # Where "type_ids" are used to indicate whether this is the first
345 | # sequence or the second sequence. The embedding vectors for `type=0` and
346 | # `type=1` were learned during pre-training and are added to the wordpiece
347 | # embedding vector (and position vector). This is not *strictly* necessary
348 |     # since the [SEP] token unambiguously separates the sequences, but it makes
349 | # it easier for the model to learn the concept of sequences.
350 | #
351 | # For classification tasks, the first vector (corresponding to [CLS]) is
352 |     # used as the "sentence vector". Note that this only makes sense because
353 | # the entire model is fine-tuned.
354 | tokens = []
355 | segment_ids = []
356 | tokens.append("[CLS]")
357 | segment_ids.append(0)
358 | for token in tokens_a:
359 | tokens.append(token)
360 | segment_ids.append(0)
361 | tokens.append("[SEP]")
362 | segment_ids.append(0)
363 |
364 | assert len(tokens_b) > 0
365 | for token in tokens_b:
366 | tokens.append(token)
367 | segment_ids.append(1)
368 | tokens.append("[SEP]")
369 | segment_ids.append(1)
370 |
371 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
372 |
373 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
374 | # tokens are attended to.
375 | input_mask = [1] * len(input_ids)
376 |
377 | # Zero-pad up to the sequence length.
378 | while len(input_ids) < max_seq_length:
379 | input_ids.append(0)
380 | input_mask.append(0)
381 | segment_ids.append(0)
382 | lm_label_ids.append(-1)
383 |
384 | assert len(input_ids) == max_seq_length
385 | assert len(input_mask) == max_seq_length
386 | assert len(segment_ids) == max_seq_length
387 | assert len(lm_label_ids) == max_seq_length
388 |
389 | if example.guid < 5:
390 | logger.info("*** Example ***")
391 | logger.info("guid: %s" % (example.guid))
392 | logger.info("tokens: %s" % " ".join(
393 | [str(x) for x in tokens]))
394 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
395 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
396 | logger.info(
397 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
398 | logger.info("LM label: %s " % (lm_label_ids))
399 | logger.info("Is next sentence label: %s " % (example.is_next))
400 |
401 | features = InputFeatures(input_ids=input_ids,
402 | input_mask=input_mask,
403 | segment_ids=segment_ids,
404 | lm_label_ids=lm_label_ids,
405 | is_next=example.is_next)
406 | return features
407 |
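# Sketch of the layout produced above for a toy pair, assuming (hypothetically)
# tokens_a=['how', 'are', 'you'], tokens_b=['i', 'am', 'fine'], no tokens were
# masked, and max_seq_length=11:
#
#   tokens       = [CLS] how are you [SEP]  i  am fine [SEP] (pad) (pad)
#   segment_ids  =   0    0   0   0    0    1   1   1    1     0     0
#   input_mask   =   1    1   1   1    1    1   1   1    1     0     0
#   lm_label_ids =  -1   -1  -1  -1   -1   -1  -1  -1   -1    -1    -1
#
# With masking, lm_label_ids would hold the original vocab id at each masked
# position instead of -1.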
408 |
409 | def main():
410 | """
411 | python run_lm_finetuning.py \
412 | --bert_model bert-base-uncased \
413 | --do_lower_case \
414 | --do_train \
415 | --train_file ../samples/sample_text.txt \
416 | --output_dir models \
417 | --num_train_epochs 5.0 \
418 | --learning_rate 3e-5 \
419 | --train_batch_size 32 \
420 | --max_seq_length 128 \
421 | """
422 | parser = argparse.ArgumentParser()
423 |
424 | # Required parameters
425 | parser.add_argument("--train_file",
426 | default=None,
427 | type=str,
428 | required=True,
429 | help="The input train corpus.")
430 | parser.add_argument("--bert_model", default=None, type=str, required=True,
431 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
432 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
433 | parser.add_argument("--output_dir",
434 | default=None,
435 | type=str,
436 | required=True,
437 | help="The output directory where the model checkpoints will be written.")
438 |
439 | # Other parameters
440 | parser.add_argument("--max_seq_length",
441 | default=128,
442 | type=int,
443 | help="The maximum total input sequence length after WordPiece tokenization. \n"
444 | "Sequences longer than this will be truncated, and sequences shorter \n"
445 | "than this will be padded.")
446 | parser.add_argument("--do_train",
447 | action='store_true',
448 | help="Whether to run training.")
449 | parser.add_argument("--train_batch_size",
450 | default=32,
451 | type=int,
452 | help="Total batch size for training.")
453 | parser.add_argument("--eval_batch_size",
454 | default=8,
455 | type=int,
456 | help="Total batch size for eval.")
457 | parser.add_argument("--learning_rate",
458 | default=3e-5,
459 | type=float,
460 | help="The initial learning rate for Adam.")
461 | parser.add_argument("--num_train_epochs",
462 | default=3.0,
463 | type=float,
464 | help="Total number of training epochs to perform.")
465 | parser.add_argument("--warmup_proportion",
466 | default=0.1,
467 | type=float,
468 | help="Proportion of training to perform linear learning rate warmup for. "
469 | "E.g., 0.1 = 10%% of training.")
470 | parser.add_argument("--no_cuda",
471 | action='store_true',
472 |                         help="Disable CUDA even when it is available")
473 | parser.add_argument("--on_memory",
474 | action='store_true',
475 | help="Whether to load train samples into memory or use disk")
476 | parser.add_argument("--do_lower_case",
477 | action='store_true',
478 | help="Whether to lower case the input text. True for uncased models, False for cased models.")
479 | parser.add_argument("--local_rank",
480 | type=int,
481 | default=-1,
482 | help="local_rank for distributed training on gpus")
483 | parser.add_argument('--seed',
484 | type=int,
485 | default=42,
486 | help="random seed for initialization")
487 | parser.add_argument('--gradient_accumulation_steps',
488 | type=int,
489 | default=1,
490 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
491 | parser.add_argument('--fp16',
492 | action='store_true',
493 | help="Whether to use 16-bit float precision instead of 32-bit")
494 | parser.add_argument('--loss_scale',
495 |                         type=float, default=0,
496 |                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
497 | "0 (default value): dynamic loss scaling.\n"
498 | "Positive power of 2: static loss scaling value.\n")
499 |
500 | args = parser.parse_args()
501 |
502 | if args.local_rank == -1 or args.no_cuda:
503 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
504 | n_gpu = torch.cuda.device_count()
505 | else:
506 | torch.cuda.set_device(args.local_rank)
507 | device = torch.device("cuda", args.local_rank)
508 | n_gpu = 1
509 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
510 | torch.distributed.init_process_group(backend='nccl')
511 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
512 | device, n_gpu, bool(args.local_rank != -1), args.fp16))
513 |
514 | if args.gradient_accumulation_steps < 1:
515 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
516 | args.gradient_accumulation_steps))
517 |
518 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
519 |
520 | random.seed(args.seed)
521 | np.random.seed(args.seed)
522 | torch.manual_seed(args.seed)
523 | if n_gpu > 0:
524 | torch.cuda.manual_seed_all(args.seed)
525 |
526 |     if not args.do_train:
527 |         raise ValueError("`do_train` must be True: this script only runs training (it defines no `--do_eval` flag).")
528 |
529 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
530 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
531 | os.makedirs(args.output_dir, exist_ok=True)
532 |
533 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
534 |
535 | #train_examples = None
536 | num_train_steps = None
537 | if args.do_train:
538 | print("Loading Train Dataset", args.train_file)
539 |         # train_dataset yields, per sample, the tensors
540 |         # (input_ids, input_mask, segment_ids, lm_label_ids, is_next),
541 |         # built from adjacent sentence pairs (see BERTDataset.__getitem__),
542 |         # loaded either into memory or lazily from disk (--on_memory).
543 | train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
544 | corpus_lines=None, on_memory=args.on_memory)
545 | num_train_steps = int(
546 | len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
547 |
548 | # Prepare model
549 | model = BertForPreTraining.from_pretrained(args.bert_model)
550 | if args.fp16:
551 | model.half()
552 | model.to(device)
553 | if args.local_rank != -1:
554 | try:
555 | from apex.parallel import DistributedDataParallel as DDP
556 | except ImportError:
557 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
558 | model = DDP(model)
559 | elif n_gpu > 1:
560 | model = torch.nn.DataParallel(model)
561 |
562 | # Prepare optimizer
563 | param_optimizer = list(model.named_parameters())
564 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
565 | optimizer_grouped_parameters = [
566 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
567 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
568 | ]
569 | if args.fp16:
570 | try:
571 | from apex.optimizers import FP16_Optimizer
572 | from apex.optimizers import FusedAdam
573 | except ImportError:
574 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
575 |
576 | optimizer = FusedAdam(optimizer_grouped_parameters,
577 | lr=args.learning_rate,
578 | bias_correction=False,
579 | max_grad_norm=1.0)
580 | if args.loss_scale == 0:
581 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
582 | else:
583 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
584 |
585 | else:
586 | optimizer = BertAdam(optimizer_grouped_parameters,
587 | lr=args.learning_rate,
588 | warmup=args.warmup_proportion,
589 | t_total=num_train_steps)
590 |
591 | global_step = 0
592 | if args.do_train:
593 | logger.info("***** Running training *****")
594 | logger.info(" Num examples = %d", len(train_dataset))
595 | logger.info(" Batch size = %d", args.train_batch_size)
596 | logger.info(" Num steps = %d", num_train_steps)
597 |
598 | if args.local_rank == -1:
599 | train_sampler = RandomSampler(train_dataset)
600 | else:
601 | #TODO: check if this works with current data generator from disk that relies on file.__next__
602 | # (it doesn't return item back by index)
603 | train_sampler = DistributedSampler(train_dataset)
604 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
605 |
606 | model.train()
607 | for _ in trange(int(args.num_train_epochs), desc="Epoch"):
608 | tr_loss = 0
609 | nb_tr_examples, nb_tr_steps = 0, 0
610 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
611 | batch = tuple(t.to(device) for t in batch)
612 | input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
613 | loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
614 | if n_gpu > 1:
615 | loss = loss.mean() # mean() to average on multi-gpu.
616 | if args.gradient_accumulation_steps > 1:
617 | loss = loss / args.gradient_accumulation_steps
618 | if args.fp16:
619 | optimizer.backward(loss)
620 | else:
621 | loss.backward()
622 | tr_loss += loss.item()
623 | nb_tr_examples += input_ids.size(0)
624 | nb_tr_steps += 1
625 | if (step + 1) % args.gradient_accumulation_steps == 0:
626 | # modify learning rate with special warm up BERT uses
627 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion)
628 | for param_group in optimizer.param_groups:
629 | param_group['lr'] = lr_this_step
630 | optimizer.step()
631 | optimizer.zero_grad()
632 | global_step += 1
633 |
634 | # Save a trained model
635 |     logger.info("***** Saving fine-tuned model *****")
636 |     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
637 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
638 | if args.do_train:
639 | torch.save(model_to_save.state_dict(), output_model_file)
640 |
641 |
642 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
643 | """Truncates a sequence pair in place to the maximum length."""
644 |
645 | # This is a simple heuristic which will always truncate the longer sequence
646 | # one token at a time. This makes more sense than truncating an equal percent
647 | # of tokens from each, since if one sequence is very short then each token
648 | # that's truncated likely contains more information than a longer sequence.
649 | while True:
650 | total_length = len(tokens_a) + len(tokens_b)
651 | if total_length <= max_length:
652 | break
653 | if len(tokens_a) > len(tokens_b):
654 | tokens_a.pop()
655 | else:
656 | tokens_b.pop()
657 |
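# Worked example of the truncation heuristic above: with max_length=6 and
#   tokens_a = ['a1', 'a2', 'a3', 'a4', 'a5'], tokens_b = ['b1', 'b2', 'b3']
# tokens are popped from the longer list one at a time, ending with
#   tokens_a = ['a1', 'a2', 'a3'], tokens_b = ['b1', 'b2', 'b3']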
658 |
659 | def accuracy(out, labels):
660 | outputs = np.argmax(out, axis=1)
661 | return np.sum(outputs == labels)
662 |
663 |
664 | if __name__ == "__main__":
665 | main()
--------------------------------------------------------------------------------
/examples/run_swag.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """BERT finetuning runner."""
17 |
18 | import logging
19 | import os
20 | import argparse
21 | import random
22 | from tqdm import tqdm, trange
23 | import csv
24 |
25 | import numpy as np
26 | import torch
27 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
28 | from torch.utils.data.distributed import DistributedSampler
29 |
30 | from pytorch_pretrained_bert.tokenization import BertTokenizer
31 | from pytorch_pretrained_bert.modeling import BertForMultipleChoice
32 | from pytorch_pretrained_bert.optimization import BertAdam
33 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
34 |
35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
36 | datefmt = '%m/%d/%Y %H:%M:%S',
37 | level = logging.INFO)
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | class SwagExample(object):
42 | """A single training/test example for the SWAG dataset."""
43 | def __init__(self,
44 | swag_id,
45 | context_sentence,
46 | start_ending,
47 | ending_0,
48 | ending_1,
49 | ending_2,
50 | ending_3,
51 | label = None):
52 | self.swag_id = swag_id
53 | self.context_sentence = context_sentence
54 | self.start_ending = start_ending
55 | self.endings = [
56 | ending_0,
57 | ending_1,
58 | ending_2,
59 | ending_3,
60 | ]
61 | self.label = label
62 |
63 | def __str__(self):
64 | return self.__repr__()
65 |
66 | def __repr__(self):
67 | l = [
68 | f"swag_id: {self.swag_id}",
69 | f"context_sentence: {self.context_sentence}",
70 | f"start_ending: {self.start_ending}",
71 | f"ending_0: {self.endings[0]}",
72 | f"ending_1: {self.endings[1]}",
73 | f"ending_2: {self.endings[2]}",
74 | f"ending_3: {self.endings[3]}",
75 | ]
76 |
77 | if self.label is not None:
78 | l.append(f"label: {self.label}")
79 |
80 | return ", ".join(l)
81 |
82 |
83 | class InputFeatures(object):
84 | def __init__(self,
85 | example_id,
86 | choices_features,
87 | label
88 |
89 | ):
90 | self.example_id = example_id
91 | self.choices_features = [
92 | {
93 | 'input_ids': input_ids,
94 | 'input_mask': input_mask,
95 | 'segment_ids': segment_ids
96 | }
97 | for _, input_ids, input_mask, segment_ids in choices_features
98 | ]
99 | self.label = label
100 |
101 |
102 | def read_swag_examples(input_file, is_training):
103 | with open(input_file, 'r', encoding='utf-8') as f:
104 | reader = csv.reader(f)
105 | lines = list(reader)
106 |
107 | if is_training and lines[0][-1] != 'label':
108 | raise ValueError(
109 | "For training, the input file must contain a label column."
110 | )
111 |
112 | examples = [
113 | SwagExample(
114 | swag_id = line[2],
115 | context_sentence = line[4],
116 | start_ending = line[5], # in the swag dataset, the
117 | # common beginning of each
118 | # choice is stored in "sent2".
119 | ending_0 = line[7],
120 | ending_1 = line[8],
121 | ending_2 = line[9],
122 | ending_3 = line[10],
123 | label = int(line[11]) if is_training else None
124 | ) for line in lines[1:] # we skip the line with the column names
125 | ]
126 |
127 | return examples
128 |
129 | def convert_examples_to_features(examples, tokenizer, max_seq_length,
130 | is_training):
131 | """Loads a data file into a list of `InputBatch`s."""
132 |
133 | # Swag is a multiple choice task. To perform this task using Bert,
134 | # we will use the formatting proposed in "Improving Language
135 | # Understanding by Generative Pre-Training" and suggested by
136 | # @jacobdevlin-google in this issue
137 | # https://github.com/google-research/bert/issues/38.
138 | #
139 | # Each choice will correspond to a sample on which we run the
140 | # inference. For a given Swag example, we will create the 4
141 | # following inputs:
142 | # - [CLS] context [SEP] choice_1 [SEP]
143 | # - [CLS] context [SEP] choice_2 [SEP]
144 | # - [CLS] context [SEP] choice_3 [SEP]
145 | # - [CLS] context [SEP] choice_4 [SEP]
146 | # The model will output a single value for each input. To get the
147 | # final decision of the model, we will run a softmax over these 4
148 | # outputs.
149 | features = []
150 | for example_index, example in enumerate(examples):
151 | context_tokens = tokenizer.tokenize(example.context_sentence)
152 | start_ending_tokens = tokenizer.tokenize(example.start_ending)
153 |
154 | choices_features = []
155 | for ending_index, ending in enumerate(example.endings):
156 | # We create a copy of the context tokens in order to be
157 | # able to shrink it according to ending_tokens
158 | context_tokens_choice = context_tokens[:]
159 | ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
160 | # Modifies `context_tokens_choice` and `ending_tokens` in
161 | # place so that the total length is less than the
162 | # specified length. Account for [CLS], [SEP], [SEP] with
163 | # "- 3"
164 | _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
165 |
166 | tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
167 | segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
168 |
169 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
170 | input_mask = [1] * len(input_ids)
171 |
172 | # Zero-pad up to the sequence length.
173 | padding = [0] * (max_seq_length - len(input_ids))
174 | input_ids += padding
175 | input_mask += padding
176 | segment_ids += padding
177 |
178 | assert len(input_ids) == max_seq_length
179 | assert len(input_mask) == max_seq_length
180 | assert len(segment_ids) == max_seq_length
181 |
182 | choices_features.append((tokens, input_ids, input_mask, segment_ids))
183 |
184 | label = example.label
185 | if example_index < 5:
186 | logger.info("*** Example ***")
187 | logger.info(f"swag_id: {example.swag_id}")
188 | for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
189 | logger.info(f"choice: {choice_idx}")
190 | logger.info(f"tokens: {' '.join(tokens)}")
191 | logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
192 | logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
193 | logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
194 | if is_training:
195 | logger.info(f"label: {label}")
196 |
197 | features.append(
198 | InputFeatures(
199 | example_id = example.swag_id,
200 | choices_features = choices_features,
201 | label = label
202 | )
203 | )
204 |
205 | return features
206 |
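# A minimal sketch of the four inputs built above for one SWAG example; the
# model produces one score per row, and a softmax over the four scores picks
# the predicted ending:
#
#   [CLS] context tokens [SEP] start_ending + ending_0 tokens [SEP]
#   [CLS] context tokens [SEP] start_ending + ending_1 tokens [SEP]
#   [CLS] context tokens [SEP] start_ending + ending_2 tokens [SEP]
#   [CLS] context tokens [SEP] start_ending + ending_3 tokens [SEP]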
207 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
208 | """Truncates a sequence pair in place to the maximum length."""
209 |
210 | # This is a simple heuristic which will always truncate the longer sequence
211 | # one token at a time. This makes more sense than truncating an equal percent
212 | # of tokens from each, since if one sequence is very short then each token
213 | # that's truncated likely contains more information than a longer sequence.
214 | while True:
215 | total_length = len(tokens_a) + len(tokens_b)
216 | if total_length <= max_length:
217 | break
218 | if len(tokens_a) > len(tokens_b):
219 | tokens_a.pop()
220 | else:
221 | tokens_b.pop()
222 |
223 | def accuracy(out, labels):
224 | outputs = np.argmax(out, axis=1)
225 | return np.sum(outputs == labels)
226 |
227 | def select_field(features, field):
228 | return [
229 | [
230 | choice[field]
231 | for choice in feature.choices_features
232 | ]
233 | for feature in features
234 | ]
235 |
236 | def warmup_linear(x, warmup=0.002):
237 | if x < warmup:
238 | return x/warmup
239 | return 1.0 - x
240 |
241 | def main():
242 | parser = argparse.ArgumentParser()
243 |
244 | ## Required parameters
245 | parser.add_argument("--data_dir",
246 | default=None,
247 | type=str,
248 | required=True,
249 | help="The input data dir. Should contain the .csv files (or other data files) for the task.")
250 | parser.add_argument("--bert_model", default=None, type=str, required=True,
251 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
252 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
253 | "bert-base-multilingual-cased, bert-base-chinese.")
254 | parser.add_argument("--output_dir",
255 | default=None,
256 | type=str,
257 | required=True,
258 | help="The output directory where the model checkpoints will be written.")
259 |
260 | ## Other parameters
261 | parser.add_argument("--max_seq_length",
262 | default=128,
263 | type=int,
264 | help="The maximum total input sequence length after WordPiece tokenization. \n"
265 | "Sequences longer than this will be truncated, and sequences shorter \n"
266 | "than this will be padded.")
267 | parser.add_argument("--do_train",
268 | action='store_true',
269 | help="Whether to run training.")
270 | parser.add_argument("--do_eval",
271 | action='store_true',
272 | help="Whether to run eval on the dev set.")
273 | parser.add_argument("--do_lower_case",
274 | action='store_true',
275 | help="Set this flag if you are using an uncased model.")
276 | parser.add_argument("--train_batch_size",
277 | default=32,
278 | type=int,
279 | help="Total batch size for training.")
280 | parser.add_argument("--eval_batch_size",
281 | default=8,
282 | type=int,
283 | help="Total batch size for eval.")
284 | parser.add_argument("--learning_rate",
285 | default=5e-5,
286 | type=float,
287 | help="The initial learning rate for Adam.")
288 | parser.add_argument("--num_train_epochs",
289 | default=3.0,
290 | type=float,
291 | help="Total number of training epochs to perform.")
292 | parser.add_argument("--warmup_proportion",
293 | default=0.1,
294 | type=float,
295 | help="Proportion of training to perform linear learning rate warmup for. "
296 | "E.g., 0.1 = 10%% of training.")
297 | parser.add_argument("--no_cuda",
298 | action='store_true',
299 | help="Whether not to use CUDA when available")
300 | parser.add_argument("--local_rank",
301 | type=int,
302 | default=-1,
303 | help="local_rank for distributed training on gpus")
304 | parser.add_argument('--seed',
305 | type=int,
306 | default=42,
307 | help="random seed for initialization")
308 | parser.add_argument('--gradient_accumulation_steps',
309 | type=int,
310 | default=1,
311 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
312 | parser.add_argument('--fp16',
313 | action='store_true',
314 | help="Whether to use 16-bit float precision instead of 32-bit")
315 | parser.add_argument('--loss_scale',
316 | type=float, default=0,
317 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
318 | "0 (default value): dynamic loss scaling.\n"
319 | "Positive power of 2: static loss scaling value.\n")
320 |
321 | args = parser.parse_args()
322 |
323 | if args.local_rank == -1 or args.no_cuda:
324 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
325 | n_gpu = torch.cuda.device_count()
326 | else:
327 | torch.cuda.set_device(args.local_rank)
328 | device = torch.device("cuda", args.local_rank)
329 | n_gpu = 1
330 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
331 | torch.distributed.init_process_group(backend='nccl')
332 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
333 | device, n_gpu, bool(args.local_rank != -1), args.fp16))
334 |
335 | if args.gradient_accumulation_steps < 1:
336 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
337 | args.gradient_accumulation_steps))
338 |
339 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
340 |
341 | random.seed(args.seed)
342 | np.random.seed(args.seed)
343 | torch.manual_seed(args.seed)
344 | if n_gpu > 0:
345 | torch.cuda.manual_seed_all(args.seed)
346 |
347 | if not args.do_train and not args.do_eval:
348 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
349 |
350 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
351 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
352 | os.makedirs(args.output_dir, exist_ok=True)
353 |
354 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
355 |
356 | train_examples = None
357 | num_train_steps = None
358 | if args.do_train:
359 | train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
360 | num_train_steps = int(
361 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
362 |
363 | # Prepare model
364 | model = BertForMultipleChoice.from_pretrained(args.bert_model,
365 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
366 | num_choices=4)
367 | if args.fp16:
368 | model.half()
369 | model.to(device)
370 | if args.local_rank != -1:
371 | try:
372 | from apex.parallel import DistributedDataParallel as DDP
373 | except ImportError:
374 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
375 |
376 | model = DDP(model)
377 | elif n_gpu > 1:
378 | model = torch.nn.DataParallel(model)
379 |
380 | # Prepare optimizer
381 | param_optimizer = list(model.named_parameters())
382 |
383 | # hack to remove pooler, which is not used
384 |     # thus it produces a None grad that breaks apex
385 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
386 |
387 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
388 | optimizer_grouped_parameters = [
389 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
390 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
391 | ]
392 | t_total = num_train_steps
393 | if args.local_rank != -1:
394 | t_total = t_total // torch.distributed.get_world_size()
395 | if args.fp16:
396 | try:
397 | from apex.optimizers import FP16_Optimizer
398 | from apex.optimizers import FusedAdam
399 | except ImportError:
400 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
401 |
402 | optimizer = FusedAdam(optimizer_grouped_parameters,
403 | lr=args.learning_rate,
404 | bias_correction=False,
405 | max_grad_norm=1.0)
406 | if args.loss_scale == 0:
407 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
408 | else:
409 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
410 | else:
411 | optimizer = BertAdam(optimizer_grouped_parameters,
412 | lr=args.learning_rate,
413 | warmup=args.warmup_proportion,
414 | t_total=t_total)
415 |
416 | global_step = 0
417 | if args.do_train:
418 | train_features = convert_examples_to_features(
419 | train_examples, tokenizer, args.max_seq_length, True)
420 | logger.info("***** Running training *****")
421 | logger.info(" Num examples = %d", len(train_examples))
422 | logger.info(" Batch size = %d", args.train_batch_size)
423 | logger.info(" Num steps = %d", num_train_steps)
424 | all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
425 | all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
426 | all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
427 | all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
428 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
429 | if args.local_rank == -1:
430 | train_sampler = RandomSampler(train_data)
431 | else:
432 | train_sampler = DistributedSampler(train_data)
433 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
434 |
435 | model.train()
436 | for _ in trange(int(args.num_train_epochs), desc="Epoch"):
437 | tr_loss = 0
438 | nb_tr_examples, nb_tr_steps = 0, 0
439 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
440 | batch = tuple(t.to(device) for t in batch)
441 | input_ids, input_mask, segment_ids, label_ids = batch
442 | loss = model(input_ids, segment_ids, input_mask, label_ids)
443 | if n_gpu > 1:
444 | loss = loss.mean() # mean() to average on multi-gpu.
445 | if args.fp16 and args.loss_scale != 1.0:
446 | # rescale loss for fp16 training
447 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
448 | loss = loss * args.loss_scale
449 | if args.gradient_accumulation_steps > 1:
450 | loss = loss / args.gradient_accumulation_steps
451 | tr_loss += loss.item()
452 | nb_tr_examples += input_ids.size(0)
453 | nb_tr_steps += 1
454 |
455 | if args.fp16:
456 | optimizer.backward(loss)
457 | else:
458 | loss.backward()
459 | if (step + 1) % args.gradient_accumulation_steps == 0:
460 | # modify learning rate with special warm up BERT uses
461 | lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
462 | for param_group in optimizer.param_groups:
463 | param_group['lr'] = lr_this_step
464 | optimizer.step()
465 | optimizer.zero_grad()
466 | global_step += 1
467 |
468 | # Save a trained model
469 |     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
470 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
471 | torch.save(model_to_save.state_dict(), output_model_file)
472 |
473 | # Load a trained model that you have fine-tuned
474 | model_state_dict = torch.load(output_model_file)
475 | model = BertForMultipleChoice.from_pretrained(args.bert_model,
476 | state_dict=model_state_dict,
477 | num_choices=4)
478 | model.to(device)
479 |
480 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
481 | eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
482 | eval_features = convert_examples_to_features(
483 | eval_examples, tokenizer, args.max_seq_length, True)
484 | logger.info("***** Running evaluation *****")
485 | logger.info(" Num examples = %d", len(eval_examples))
486 | logger.info(" Batch size = %d", args.eval_batch_size)
487 | all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
488 | all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
489 | all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
490 | all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
491 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
492 | # Run prediction for full data
493 | eval_sampler = SequentialSampler(eval_data)
494 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
495 |
496 | model.eval()
497 | eval_loss, eval_accuracy = 0, 0
498 | nb_eval_steps, nb_eval_examples = 0, 0
499 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
500 | input_ids = input_ids.to(device)
501 | input_mask = input_mask.to(device)
502 | segment_ids = segment_ids.to(device)
503 | label_ids = label_ids.to(device)
504 |
505 | with torch.no_grad():
506 | tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
507 | logits = model(input_ids, segment_ids, input_mask)
508 |
509 | logits = logits.detach().cpu().numpy()
510 | label_ids = label_ids.to('cpu').numpy()
511 | tmp_eval_accuracy = accuracy(logits, label_ids)
512 |
513 | eval_loss += tmp_eval_loss.mean().item()
514 | eval_accuracy += tmp_eval_accuracy
515 |
516 | nb_eval_examples += input_ids.size(0)
517 | nb_eval_steps += 1
518 |
519 | eval_loss = eval_loss / nb_eval_steps
520 | eval_accuracy = eval_accuracy / nb_eval_examples
521 |
522 | result = {'eval_loss': eval_loss,
523 | 'eval_accuracy': eval_accuracy,
524 | 'global_step': global_step,
525 |                   'loss': tr_loss/nb_tr_steps if args.do_train else None}
526 |
527 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
528 | with open(output_eval_file, "w") as writer:
529 | logger.info("***** Eval results *****")
530 | for key in sorted(result.keys()):
531 | logger.info(" %s = %s", key, str(result[key]))
532 | writer.write("%s = %s\n" % (key, str(result[key])))
533 |
534 |
535 | if __name__ == "__main__":
536 | main()
537 |
--------------------------------------------------------------------------------
/examples/softmax.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
5 | datefmt='%m/%d/%Y %H:%M:%S',
6 | level=logging.INFO)
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def _compute_softmax(scores):
11 | """Compute softmax probability over raw logits."""
12 | if not scores:
13 | return []
14 |
15 | max_score = None
16 | for score in scores:
17 | if max_score is None or score > max_score:
18 | max_score = score
19 |
20 | exp_scores = []
21 | total_sum = 0.0
22 | for score in scores:
23 | x = math.exp(score - max_score)
24 | exp_scores.append(x)
25 | total_sum += x
26 |
27 | probs = []
28 | for score in exp_scores:
29 | probs.append(score / total_sum)
30 | return probs
31 |
32 |
33 | probs = _compute_softmax([1, 2, 3])
34 | logger.info("probs: %s" % probs)
35 | # logger.info("Test: %s", probs)
36 | # logger.info(sum(probs))
37 |
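# Hand-checked arithmetic for the call above: with scores [1, 2, 3] and
# max_score=3, exp([-2, -1, 0]) ~= [0.1353, 0.3679, 1.0], total_sum ~= 1.5032,
# so probs ~= [0.0900, 0.2447, 0.6652] (and they sum to 1.0).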
--------------------------------------------------------------------------------
/examples/test_BertForMaskedLM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
3 |
4 | # Load pre-trained model tokenizer (vocabulary)
5 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
6 |
7 | # Tokenized input
8 | text = "Who was Jim Henson ? Jim Henson was a puppeteer"
9 | tokenized_text = tokenizer.tokenize(text)
10 |
11 | # Mask a token that we will try to predict back with `BertForMaskedLM`
12 | masked_index = 6
13 | tokenized_text[masked_index] = '[MASK]'
14 | assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']
15 |
16 | # Convert token to vocabulary indices
17 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
18 | # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
19 | segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
20 |
21 | # Convert inputs to PyTorch tensors
22 | tokens_tensor = torch.tensor([indexed_tokens])
23 | segments_tensors = torch.tensor([segments_ids])
24 |
25 | # ========================= BertForMaskedLM ==============================
26 | # Load pre-trained model (weights)
27 | model = BertForMaskedLM.from_pretrained('bert-base-uncased')
28 | model.eval()
29 |
30 | """
31 | predictions.size():
32 | torch.Size([1, 11, 30522])
33 |
34 | predictions[0, masked_index]:
35 | tensor([-7.8384, -7.8162, -7.8893, ..., -6.9924, -6.1897, -4.5417],
36 |        grad_fn=<SelectBackward>)
37 |
38 | predictions[0, masked_index].size():
39 | torch.Size([30522])
40 | """
41 | # Predict all tokens
42 | predictions = model(tokens_tensor, segments_tensors)
43 |
44 | # confirm we were able to predict 'henson'
45 | predicted_index = torch.argmax(predictions[0, masked_index]).item()
46 | predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
47 | assert predicted_token == 'henson'
48 |
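# Optional follow-up sketch: inspect the five highest-scoring candidates for
# the masked position instead of only the argmax.
top_scores, top_indices = torch.topk(predictions[0, masked_index], k=5)
top_tokens = tokenizer.convert_ids_to_tokens(top_indices.tolist())
print(list(zip(top_tokens, top_scores.tolist())))  # 'henson' should rank first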
--------------------------------------------------------------------------------
/examples/test_BertModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
3 |
4 | """
5 | Let's see how to use BertModel to get hidden states
6 | """
7 |
8 | # Load pre-trained model tokenizer (vocabulary)
9 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10 |
11 | # Tokenized input
12 | text = "Who was Jim Henson ? Jim Henson was a puppeteer"
13 | tokenized_text = tokenizer.tokenize(text)
14 |
15 | # Mask a token that we will try to predict back with `BertForMaskedLM`
16 | masked_index = 6
17 | tokenized_text[masked_index] = '[MASK]'
18 | assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']
19 |
20 | # Convert token to vocabulary indices
21 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
22 | # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
23 | segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
24 |
25 | # Convert inputs to PyTorch tensors
26 | tokens_tensor = torch.tensor([indexed_tokens])
27 | segments_tensors = torch.tensor([segments_ids])
28 |
29 | # ========================= BertModel to get hidden states ==============================
30 | # Load pre-trained model (weights)
31 | model = BertModel.from_pretrained('bert-base-uncased')
32 | model.eval()
33 |
34 | # # If you have a GPU, put everything on cuda
35 | # tokens_tensor = tokens_tensor.to('cuda')
36 | # segments_tensors = segments_tensors.to('cuda')
37 | # model.to('cuda')
38 |
39 | # Predict hidden states features for each layer
40 | with torch.no_grad():
41 | encoded_layers, _ = model(tokens_tensor, segments_tensors)
42 | # We have one hidden-states tensor for each of the 12 layers in model bert-base-uncased
43 | assert len(encoded_layers) == 12
44 |
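# Each element of encoded_layers has shape [batch_size, sequence_length,
# hidden_size], i.e. [1, 11, 768] here for bert-base-uncased.
assert encoded_layers[0].shape == (1, len(tokenized_text), 768)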
45 |
--------------------------------------------------------------------------------
/examples/test_squad.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | base_path = "/home/wyb/data/squad_v1.1/"
4 | train_file_name = "train-v1.1.json"
5 | dev_file_name = "dev-v1.1.json"
6 | input_file = base_path + train_file_name
7 |
8 | with open(input_file, "r", encoding='utf-8') as reader:
9 | input_data = json.load(reader)["data"]
10 |
11 | # dic = {'a': 1, 'b': 2, 'c': 3}
12 | # js = json.dumps(input_file, sort_keys=True, indent=4, separators=(',', ':'))
13 | # print(js)
14 |
15 | print(len(input_data))
16 | print(input_data[0])
17 | print(input_data[1])
18 | print(type(input_data[1]))
--------------------------------------------------------------------------------
/examples/test_tokenization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
3 |
4 | # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
5 | import logging
6 | logging.basicConfig(level=logging.INFO)
7 |
8 | # Load pre-trained model tokenizer (vocabulary)
9 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10 |
11 | # Tokenized input
12 | text = "[CLS] Who 这是一个测试 was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" # test version
13 | # text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
14 |
15 | """
16 | tokenized_text:
17 | ['[CLS]', 'who', '[UNK]', '[UNK]', '一', '[UNK]', '[UNK]', '[UNK]', 'was', 'jim', 'henson', '?', '[SEP]',
18 | 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]']
19 | """
20 | tokenized_text = tokenizer.tokenize(text)
21 |
22 | # Mask a token that we will try to predict back with `BertForMaskedLM`
23 | masked_index = 8
24 | tokenized_text[masked_index] = '[MASK]'
25 | # assert tokenized_text == \
26 | # ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a',
27 | # 'puppet', '##eer', '[SEP]']
28 |
29 | # test version
30 | assert tokenized_text == \
31 | ['[CLS]', 'who', '[UNK]', '[UNK]', '一', '[UNK]', '[UNK]', '[UNK]', '[MASK]', 'jim', 'henson', '?',
32 | '[SEP]', 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]']
33 |
34 | # Convert token to vocabulary indices
35 | """
36 | indexed_tokens:
37 | [101, 2040, 100, 100, 1740, 100, 100, 100, 103, 3958, 27227, 1029, 102, 3958, 27227, 2001, 1037, 13997, 11510, 102]
38 | """
39 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
40 | # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
41 | # segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
42 |
43 | # test version
44 | segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
45 |
46 | """
47 | tokens_tensor:
48 | tensor([[ 101, 2040, 100, 100, 1740, 100, 100, 100, 103, 3958,
49 | 27227, 1029, 102, 3958, 27227, 2001, 1037, 13997, 11510, 102]])
50 |
51 | segments_tensors:
52 | tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]])
53 | """
54 | # Convert inputs to PyTorch tensors
55 | tokens_tensor = torch.tensor([indexed_tokens])
56 | segments_tensors = torch.tensor([segments_ids])
57 |
58 |
--------------------------------------------------------------------------------
/examples/valid_data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | # BASE_PATH = "/home/wyb/PycharmProjects/DuReader/data/demo/"
4 | BASE_PATH = "/DATA/disk1/wangyongbo/lic2019/DuReader/data/preprocessed/"
5 |
6 | with open(BASE_PATH + "trainset/search.train_bert.json", "r", encoding='utf-8') as reader:
7 | source = json.load(reader)
8 | input_data = source["data"]
9 |
10 | cou_equal = 0
11 | cou_total = 0
12 | for entry in input_data:
13 | for paragraph in entry["paragraphs"]:
14 | paragraph_text = paragraph["context"]
15 |
16 | for qa in paragraph["qas"]:
17 | cou_total += 1
18 |
19 | """
20 | {
21 | 'text': 'in the late 1990s',
22 | 'answer_start': 269 # by char
23 | }
24 | """
25 | answer_dict = qa["answers"][0]
26 | answer = answer_dict["text"]
27 | start_position = answer_dict["answer_start"] # char offset (checked by the slice below)
28 | end_position = answer_dict["answer_end"] # char offset, inclusive
29 |
30 | if paragraph_text[start_position:(end_position+1)] == answer.strip():
31 | cou_equal += 1
32 |
33 | print("cou_equal / cou_total = ", cou_equal, " / ", cou_total)
34 |
--------------------------------------------------------------------------------
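Annotation note: valid_data.py checks one direction only -- that the stored (answer_start, answer_end) character span reproduces the answer string. A minimal sketch of the complementary check (recovering a character span from the answer text with str.find; an illustrative helper, not part of this repo):

def find_answer_span(context, answer_text):
    """Return (start, end) inclusive char offsets of answer_text in context, or None."""
    start = context.find(answer_text)
    if start == -1:
        return None
    return start, start + len(answer_text) - 1

--------------------------------------------------------------------------------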
/pytorch_pretrained_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.4.0"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
4 | BertForMaskedLM, BertForNextSentencePrediction,
5 | BertForSequenceClassification, BertForMultipleChoice,
6 | BertForTokenClassification, BertForQuestionAnswering)
7 | from .optimization import BertAdam
8 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
9 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | try:
5 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
6 | except ModuleNotFoundError:
7 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
8 | "In that case, it requires TensorFlow to be installed. Please see "
9 | "https://www.tensorflow.org/install/ for installation instructions.")
10 | raise
11 |
12 | if len(sys.argv) != 5:
13 | # pylint: disable=line-too-long
14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
15 | else:
16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
17 | TF_CONFIG = sys.argv.pop()
18 | TF_CHECKPOINT = sys.argv.pop()
19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
20 |
21 | if __name__ == '__main__':
22 | main()
23 |
--------------------------------------------------------------------------------
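Annotation note: besides the `pytorch_pretrained_bert` console script checked above (argv length 5 = script name + subcommand + three paths), the converter can be called directly from Python. A minimal sketch with placeholder paths; like the CLI path, this import requires TensorFlow to be installed:

from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch)

convert_tf_checkpoint_to_pytorch(
    "uncased_L-12_H-768_A-12/bert_model.ckpt",   # TF_CHECKPOINT
    "uncased_L-12_H-768_A-12/bert_config.json",  # TF_CONFIG
    "pytorch_model.bin")                         # PYTORCH_DUMP_OUTPUT

--------------------------------------------------------------------------------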
/pytorch_pretrained_bert/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/file_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/file_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/modeling.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/modeling.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/optimization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/optimization.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/__pycache__/tokenization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongbowin/pytorch-pretrained-BERT_annotation/dabfc0941fbeaac931c78ce7d55b15f9f51d62a8/pytorch_pretrained_bert/__pycache__/tokenization.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 |
28 | from .modeling import BertConfig, BertForPreTraining
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | config_path = os.path.abspath(bert_config_file)
32 | tf_path = os.path.abspath(tf_checkpoint_path)
33 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
34 | # Load weights from TF model
35 | init_vars = tf.train.list_variables(tf_path)
36 | names = []
37 | arrays = []
38 | for name, shape in init_vars:
39 | print("Loading TF weight {} with shape {}".format(name, shape))
40 | array = tf.train.load_variable(tf_path, name)
41 | names.append(name)
42 | arrays.append(array)
43 |
44 | # Initialise PyTorch model
45 | config = BertConfig.from_json_file(bert_config_file)
46 | print("Building PyTorch model from configuration: {}".format(str(config)))
47 | model = BertForPreTraining(config)
48 |
49 | for name, array in zip(names, arrays):
50 | name = name.split('/')
51 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate the m and v
52 | # running averages; they are not required for using the pretrained model
53 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
54 | print("Skipping {}".format("/".join(name)))
55 | continue
56 | pointer = model
57 | for m_name in name:
58 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
59 | l = re.split(r'_(\d+)', m_name)
60 | else:
61 | l = [m_name]
62 | if l[0] == 'kernel' or l[0] == 'gamma':
63 | pointer = getattr(pointer, 'weight')
64 | elif l[0] == 'output_bias' or l[0] == 'beta':
65 | pointer = getattr(pointer, 'bias')
66 | elif l[0] == 'output_weights':
67 | pointer = getattr(pointer, 'weight')
68 | else:
69 | pointer = getattr(pointer, l[0])
70 | if len(l) >= 2:
71 | num = int(l[1])
72 | pointer = pointer[num]
73 | if m_name[-11:] == '_embeddings':
74 | pointer = getattr(pointer, 'weight')
75 | elif m_name == 'kernel':
76 | array = np.transpose(array)
77 | try:
78 | assert pointer.shape == array.shape
79 | except AssertionError as e:
80 | e.args += (pointer.shape, array.shape)
81 | raise
82 | print("Initialize PyTorch weight {}".format(name))
83 | pointer.data = torch.from_numpy(array)
84 |
85 | # Save pytorch-model
86 | print("Save PyTorch model to {}".format(pytorch_dump_path))
87 | torch.save(model.state_dict(), pytorch_dump_path)
88 |
89 |
90 | if __name__ == "__main__":
91 | parser = argparse.ArgumentParser()
92 | ## Required parameters
93 | parser.add_argument("--tf_checkpoint_path",
94 | default = None,
95 | type = str,
96 | required = True,
97 | help = "Path the TensorFlow checkpoint path.")
98 | parser.add_argument("--bert_config_file",
99 | default = None,
100 | type = str,
101 | required = True,
102 | help = "The config json file corresponding to the pre-trained BERT model. \n"
103 | "This specifies the model architecture.")
104 | parser.add_argument("--pytorch_dump_path",
105 | default = None,
106 | type = str,
107 | required = True,
108 | help = "Path to the output PyTorch model.")
109 | args = parser.parse_args()
110 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
111 | args.bert_config_file,
112 | args.pytorch_dump_path)
113 |
--------------------------------------------------------------------------------
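Annotation note: the name-resolution loop above walks each '/'-separated TF variable name down the PyTorch module tree. A standalone sketch of the same splitting logic (the variable name is a typical BERT checkpoint entry; 'kernel' maps to 'weight' and the array is transposed because TF stores dense kernels as [in, out] while PyTorch Linear weights are [out, in]):

import re

name = "bert/encoder/layer_0/attention/self/query/kernel"
for m_name in name.split('/'):
    if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
        scope, num = re.split(r'_(\d+)', m_name)[:2]  # 'layer_0' -> 'layer', '0'
        print("attr:", scope, "index:", int(num))
    elif m_name == 'kernel':
        print("attr: weight (array transposed)")
    else:
        print("attr:", m_name)

--------------------------------------------------------------------------------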
/pytorch_pretrained_bert/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 |
7 | import os
8 | import logging
9 | import shutil
10 | import tempfile
11 | import json
12 | from urllib.parse import urlparse
13 | from pathlib import Path
14 | from typing import Optional, Tuple, Union, IO, Callable, Set
15 | from hashlib import sha256
16 | from functools import wraps
17 |
18 | from tqdm import tqdm
19 |
20 | import boto3
21 | from botocore.exceptions import ClientError
22 | import requests
23 |
24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
25 |
26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
27 | Path.home() / '.pytorch_pretrained_bert'))
28 |
29 |
30 | def url_to_filename(url: str, etag: str = None) -> str:
31 | """
32 | Convert `url` into a hashed filename in a repeatable way.
33 | If `etag` is specified, append its hash to the url's, delimited
34 | by a period.
35 | """
36 | url_bytes = url.encode('utf-8')
37 | url_hash = sha256(url_bytes)
38 | filename = url_hash.hexdigest()
39 |
40 | if etag:
41 | etag_bytes = etag.encode('utf-8')
42 | etag_hash = sha256(etag_bytes)
43 | filename += '.' + etag_hash.hexdigest()
44 |
45 | return filename
46 |
47 |
48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
49 | """
50 | Return the url and etag (which may be ``None``) stored for `filename`.
51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
52 | """
53 | if cache_dir is None:
54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
55 | if isinstance(cache_dir, Path):
56 | cache_dir = str(cache_dir)
57 |
58 | cache_path = os.path.join(cache_dir, filename)
59 | if not os.path.exists(cache_path):
60 | raise FileNotFoundError("file {} not found".format(cache_path))
61 |
62 | meta_path = cache_path + '.json'
63 | if not os.path.exists(meta_path):
64 | raise FileNotFoundError("file {} not found".format(meta_path))
65 |
66 | with open(meta_path) as meta_file:
67 | metadata = json.load(meta_file)
68 | url = metadata['url']
69 | etag = metadata['etag']
70 |
71 | return url, etag
72 |
73 |
74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
75 | """
76 | Given something that might be a URL (or might be a local path),
77 | determine which. If it's a URL, download the file and cache it, and
78 | return the path to the cached file. If it's already a local path,
79 | make sure the file exists and then return the path.
80 | """
81 | if cache_dir is None:
82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
83 | if isinstance(url_or_filename, Path):
84 | url_or_filename = str(url_or_filename)
85 | if isinstance(cache_dir, Path):
86 | cache_dir = str(cache_dir)
87 |
88 | parsed = urlparse(url_or_filename)
89 |
90 | if parsed.scheme in ('http', 'https', 's3'):
91 | # URL, so get it from the cache (downloading if necessary)
92 | return get_from_cache(url_or_filename, cache_dir)
93 | elif os.path.exists(url_or_filename):
94 | # File, and it exists.
95 | return url_or_filename
96 | elif parsed.scheme == '':
97 | # File, but it doesn't exist.
98 | raise FileNotFoundError("file {} not found".format(url_or_filename))
99 | else:
100 | # Something unknown
101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
102 |
103 |
104 | def split_s3_path(url: str) -> Tuple[str, str]:
105 | """Split a full s3 path into the bucket name and path."""
106 | parsed = urlparse(url)
107 | if not parsed.netloc or not parsed.path:
108 | raise ValueError("bad s3 path {}".format(url))
109 | bucket_name = parsed.netloc
110 | s3_path = parsed.path
111 | # Remove '/' at beginning of path.
112 | if s3_path.startswith("/"):
113 | s3_path = s3_path[1:]
114 | return bucket_name, s3_path
115 |
116 |
117 | def s3_request(func: Callable):
118 | """
119 | Wrapper function for s3 requests in order to create more helpful error
120 | messages.
121 | """
122 |
123 | @wraps(func)
124 | def wrapper(url: str, *args, **kwargs):
125 | try:
126 | return func(url, *args, **kwargs)
127 | except ClientError as exc:
128 | if int(exc.response["Error"]["Code"]) == 404:
129 | raise FileNotFoundError("file {} not found".format(url))
130 | else:
131 | raise
132 |
133 | return wrapper
134 |
135 |
136 | @s3_request
137 | def s3_etag(url: str) -> Optional[str]:
138 | """Check ETag on S3 object."""
139 | s3_resource = boto3.resource("s3")
140 | bucket_name, s3_path = split_s3_path(url)
141 | s3_object = s3_resource.Object(bucket_name, s3_path)
142 | return s3_object.e_tag
143 |
144 |
145 | @s3_request
146 | def s3_get(url: str, temp_file: IO) -> None:
147 | """Pull a file directly from S3."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
151 |
152 |
153 | def http_get(url: str, temp_file: IO) -> None:
154 | req = requests.get(url, stream=True)
155 | content_length = req.headers.get('Content-Length')
156 | total = int(content_length) if content_length is not None else None
157 | progress = tqdm(unit="B", total=total)
158 | for chunk in req.iter_content(chunk_size=1024):
159 | if chunk: # filter out keep-alive new chunks
160 | progress.update(len(chunk))
161 | temp_file.write(chunk)
162 | progress.close()
163 |
164 |
165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
166 | """
167 | Given a URL, look for the corresponding dataset in the local cache.
168 | If it's not there, download it. Then return the path to the cached file.
169 | """
170 | if cache_dir is None:
171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
172 | if isinstance(cache_dir, Path):
173 | cache_dir = str(cache_dir)
174 |
175 | os.makedirs(cache_dir, exist_ok=True)
176 |
177 | # Get eTag to add to filename, if it exists.
178 | if url.startswith("s3://"):
179 | etag = s3_etag(url)
180 | else:
181 | response = requests.head(url, allow_redirects=True)
182 | if response.status_code != 200:
183 | raise IOError("HEAD request failed for url {} with status code {}"
184 | .format(url, response.status_code))
185 | etag = response.headers.get("ETag")
186 |
187 | filename = url_to_filename(url, etag)
188 |
189 | # get cache path to put the file
190 | cache_path = os.path.join(cache_dir, filename)
191 |
192 | if not os.path.exists(cache_path):
193 | # Download to temporary file, then copy to cache dir once finished.
194 | # Otherwise you get corrupt cache entries if the download gets interrupted.
195 | with tempfile.NamedTemporaryFile() as temp_file:
196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
197 |
198 | # GET file object
199 | if url.startswith("s3://"):
200 | s3_get(url, temp_file)
201 | else:
202 | http_get(url, temp_file)
203 |
204 | # we are copying the file before closing it, so flush to avoid truncation
205 | temp_file.flush()
206 | # shutil.copyfileobj() starts at the current position, so go to the start
207 | temp_file.seek(0)
208 |
209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
210 | with open(cache_path, 'wb') as cache_file:
211 | shutil.copyfileobj(temp_file, cache_file)
212 |
213 | logger.info("creating metadata file for %s", cache_path)
214 | meta = {'url': url, 'etag': etag}
215 | meta_path = cache_path + '.json'
216 | with open(meta_path, 'w') as meta_file:
217 | json.dump(meta, meta_file)
218 |
219 | logger.info("removing temp file %s", temp_file.name)
220 |
221 | return cache_path
222 |
223 |
224 | def read_set_from_file(filename: str) -> Set[str]:
225 | '''
226 | Extract a de-duped collection (set) of text from a file.
227 | Expected file format is one item per line.
228 | '''
229 | collection = set()
230 | with open(filename, 'r', encoding='utf-8') as file_:
231 | for line in file_:
232 | collection.add(line.rstrip())
233 | return collection
234 |
235 |
236 | def get_file_extension(path: str, dot=True, lower: bool = True):
237 | ext = os.path.splitext(path)[1]
238 | ext = ext if dot else ext[1:]
239 | return ext.lower() if lower else ext
240 |
--------------------------------------------------------------------------------
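Annotation note: a minimal sketch of how the cache helpers above fit together. The URL is one of the real vocab files listed in tokenization.py; the file is hashed with url_to_filename() and stored under ~/.pytorch_pretrained_bert by default:

import os
from pytorch_pretrained_bert.file_utils import cached_path, filename_to_url

vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"
local_path = cached_path(vocab_url)   # downloads on first call, hits the cache afterwards
print(local_path)
print(filename_to_url(os.path.basename(local_path)))  # -> (url, etag)

--------------------------------------------------------------------------------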
/pytorch_pretrained_bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 |
23 | def warmup_cosine(x, warmup=0.002):
24 | if x < warmup:
25 | return x/warmup
26 | return 0.5 * (1.0 + math.cos(math.pi * x))
27 |
28 | def warmup_constant(x, warmup=0.002):
29 | if x < warmup:
30 | return x/warmup
31 | return 1.0
32 |
33 | def warmup_linear(x, warmup=0.002):
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0 - x
37 |
38 | SCHEDULES = {
39 | 'warmup_cosine':warmup_cosine,
40 | 'warmup_constant':warmup_constant,
41 | 'warmup_linear':warmup_linear,
42 | }
43 |
44 |
45 | class BertAdam(Optimizer):
46 | """Implements BERT version of Adam algorithm with weight decay fix.
47 | Params:
48 | lr: learning rate
49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
50 | t_total: total number of training steps for the learning
51 | rate schedule, -1 means constant learning rate. Default: -1
52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
53 | b1: Adams b1. Default: 0.9
54 | b2: Adams b2. Default: 0.999
55 | e: Adams epsilon. Default: 1e-6
56 | weight_decay: Weight decay. Default: 0.01
57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
58 | """
59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
61 | max_grad_norm=1.0):
62 | if lr is not required and lr < 0.0:
63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
64 | if schedule not in SCHEDULES:
65 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
66 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
68 | if not 0.0 <= b1 < 1.0:
69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
70 | if not 0.0 <= b2 < 1.0:
71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
72 | if not e >= 0.0:
73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
76 | max_grad_norm=max_grad_norm)
77 | super(BertAdam, self).__init__(params, defaults)
78 |
79 | def get_lr(self):
80 | lr = []
81 | for group in self.param_groups:
82 | for p in group['params']:
83 | state = self.state[p]
84 | if len(state) == 0:
85 | return [0]
86 | if group['t_total'] != -1:
87 | schedule_fct = SCHEDULES[group['schedule']]
88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
89 | else:
90 | lr_scheduled = group['lr']
91 | lr.append(lr_scheduled)
92 | return lr
93 |
94 | def step(self, closure=None):
95 | """Performs a single optimization step.
96 |
97 | Arguments:
98 | closure (callable, optional): A closure that reevaluates the model
99 | and returns the loss.
100 | """
101 | loss = None
102 | if closure is not None:
103 | loss = closure()
104 |
105 | for group in self.param_groups:
106 | for p in group['params']:
107 | if p.grad is None:
108 | continue
109 | grad = p.grad.data
110 | if grad.is_sparse:
111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
112 |
113 | state = self.state[p]
114 |
115 | # State initialization
116 | if len(state) == 0:
117 | state['step'] = 0
118 | # Exponential moving average of gradient values
119 | state['next_m'] = torch.zeros_like(p.data)
120 | # Exponential moving average of squared gradient values
121 | state['next_v'] = torch.zeros_like(p.data)
122 |
123 | next_m, next_v = state['next_m'], state['next_v']
124 | beta1, beta2 = group['b1'], group['b2']
125 |
126 | # Add grad clipping
127 | if group['max_grad_norm'] > 0:
128 | clip_grad_norm_(p, group['max_grad_norm'])
129 |
130 | # Decay the first and second moment running average coefficient
131 | # In-place operations to update the averages at the same time
132 | next_m.mul_(beta1).add_(1 - beta1, grad)
133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
134 | update = next_m / (next_v.sqrt() + group['e'])
135 |
136 | # Just adding the square of the weights to the loss function is *not*
137 | # the correct way of using L2 regularization/weight decay with Adam,
138 | # since that will interact with the m and v parameters in strange ways.
139 | #
140 | # Instead we want to decay the weights in a manner that doesn't interact
141 | # with the m/v parameters. This is equivalent to adding the square
142 | # of the weights to the loss with plain (non-momentum) SGD.
143 | if group['weight_decay'] > 0.0:
144 | update += group['weight_decay'] * p.data
145 |
146 | if group['t_total'] != -1:
147 | schedule_fct = SCHEDULES[group['schedule']]
148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
149 | else:
150 | lr_scheduled = group['lr']
151 |
152 | update_with_lr = lr_scheduled * update
153 | p.data.add_(-update_with_lr)
154 |
155 | state['step'] += 1
156 |
157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
158 | # No bias correction
159 | # bias_correction1 = 1 - beta1 ** state['step']
160 | # bias_correction2 = 1 - beta2 ** state['step']
161 |
162 | return loss
163 |
--------------------------------------------------------------------------------
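Annotation note: a minimal sketch of wiring BertAdam to a model, following the grouped weight-decay pattern used by the example scripts in this repo ('gamma'/'beta' are the LayerNorm parameter names in this version's modeling.py; `model` and `num_train_steps` are assumed to exist):

from pytorch_pretrained_bert.optimization import BertAdam

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']  # no weight decay for biases / LayerNorm params
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}]
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=5e-5,                  # peak learning rate
                     warmup=0.1,               # ramp up over the first 10% of steps
                     t_total=num_train_steps)  # total number of optimizer steps

--------------------------------------------------------------------------------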
/pytorch_pretrained_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import os
24 | import logging
25 |
26 | from .file_utils import cached_path
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | """Uncased means that the text has been lowercased before WordPiece tokenization.
31 | """
32 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
33 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40 | }
41 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
42 | 'bert-base-uncased': 512,
43 | 'bert-large-uncased': 512,
44 | 'bert-base-cased': 512,
45 | 'bert-large-cased': 512,
46 | 'bert-base-multilingual-uncased': 512,
47 | 'bert-base-multilingual-cased': 512,
48 | 'bert-base-chinese': 512,
49 | }
50 | VOCAB_NAME = 'vocab.txt'
51 |
52 |
53 | def load_vocab(vocab_file):
54 | """Loads a vocabulary file into a dictionary."""
55 | """The size of vocab is 30522
56 | """
57 | vocab = collections.OrderedDict()
58 | index = 0
59 | with open(vocab_file, "r", encoding="utf-8") as reader:
60 | while True:
61 | token = reader.readline()
62 | if not token:
63 | break
64 | token = token.strip()
65 | vocab[token] = index
66 | index += 1
67 | return vocab
68 |
69 |
70 | def whitespace_tokenize(text):
71 | """Runs basic whitespace cleaning and splitting on a peice of text."""
72 | text = text.strip()
73 | if not text:
74 | return []
75 | tokens = text.split() # split on whitespace and return a word list.
76 | return tokens
77 |
78 |
79 | class BertTokenizer(object):
80 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
81 |
82 | def __init__(self, vocab_file, do_lower_case=True, max_len=None,
83 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", "[unused1]")):
84 | if not os.path.isfile(vocab_file):
85 | raise ValueError(
86 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
87 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
88 | self.vocab = load_vocab(vocab_file)
89 | self.ids_to_tokens = collections.OrderedDict(
90 | [(ids, tok) for tok, ids in self.vocab.items()])
91 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
92 | never_split=never_split)
93 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
94 | self.max_len = max_len if max_len is not None else int(1e12) # 512 for the pretrained models (set in from_pretrained)
95 |
96 | def tokenize(self, text):
97 | split_tokens = []
98 | for token in self.basic_tokenizer.tokenize(text):
99 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
100 | split_tokens.append(sub_token)
101 | """
102 | split_tokens:
103 | ['[CLS]', 'who', '[UNK]', '[UNK]', '一', '[UNK]', '[UNK]', '[UNK]', 'was', 'jim', 'henson', '?', '[SEP]',
104 | 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]']
105 | """
106 | return split_tokens
107 |
108 | def convert_tokens_to_ids(self, tokens):
109 | """Converts a sequence of tokens into ids using the vocab."""
110 | ids = []
111 | for token in tokens:
112 | ids.append(self.vocab[token])
113 | if len(ids) > self.max_len:
114 | raise ValueError(
115 | "Token indices sequence length is longer than the specified maximum "
116 | " sequence length for this BERT model ({} > {}). Running this"
117 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
118 | )
119 | return ids
120 |
121 | def convert_ids_to_tokens(self, ids):
122 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
123 | tokens = []
124 | for i in ids:
125 | tokens.append(self.ids_to_tokens[i])
126 | return tokens
127 |
128 | @classmethod
129 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
130 | """
131 | Instantiate a PreTrainedBertModel from a pre-trained model file.
132 | Download and cache the pre-trained model file if needed.
133 | """
134 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
135 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
136 | else:
137 | vocab_file = pretrained_model_name
138 | if os.path.isdir(vocab_file):
139 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
140 | # redirect to the cache, if necessary
141 | try:
142 | """Save file to "~/.pytorch_pretrained_bert".
143 | """
144 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
145 | except FileNotFoundError:
146 | logger.error(
147 | "Model name '{}' was not found in model name list ({}). "
148 | "We assumed '{}' was a path or url but couldn't find any file "
149 | "associated to this path or url.".format(
150 | pretrained_model_name,
151 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
152 | vocab_file))
153 | return None
154 | if resolved_vocab_file == vocab_file:
155 | logger.info("loading vocabulary file {}".format(vocab_file))
156 | else:
157 | logger.info("loading vocabulary file {} from cache at {}".format(
158 | vocab_file, resolved_vocab_file))
159 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
160 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
161 | # than the number of positional embeddings
162 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
163 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
164 | # Instantiate tokenizer.
165 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
166 | return tokenizer
167 |
168 |
169 | class BasicTokenizer(object):
170 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
171 |
172 | def __init__(self,
173 | do_lower_case=True,
174 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", "[unused1]")):
175 | """Constructs a BasicTokenizer.
176 |
177 | Args:
178 | do_lower_case: Whether to lower case the input.
179 | """
180 | self.do_lower_case = do_lower_case
181 | self.never_split = never_split
182 |
183 | def tokenize(self, text):
184 | """Tokenizes a piece of text."""
185 | text = self._clean_text(text)
186 | # This was added on November 1st, 2018 for the multilingual and Chinese
187 | # models. This is also applied to the English models now, but it doesn't
188 | # matter since the English models were not trained on any Chinese data
189 | # and generally don't have any Chinese data in them (there are Chinese
190 | # characters in the vocabulary because Wikipedia does have some Chinese
191 | # words in the English Wikipedia.).
192 | text = self._tokenize_chinese_chars(text)
193 | orig_tokens = whitespace_tokenize(text)
194 | split_tokens = []
195 | for token in orig_tokens:
196 | if self.do_lower_case and token not in self.never_split:
197 | token = token.lower()
198 | token = self._run_strip_accents(token)
199 | split_tokens.extend(self._run_split_on_punc(token))
200 |
201 | """
202 | output_tokens:
203 | ['[CLS]', 'who', '这', '是', '一', '个', '测', '试', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', 'henson',
204 | 'was', 'a', 'puppeteer', '[SEP]']
205 | """
206 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
207 | return output_tokens
208 |
209 | def _run_strip_accents(self, text):
210 | """Strips accents from a piece of text."""
211 | """
212 | Stripping accents means the following:
213 | input: "Málaga"
214 | output: "Malaga"
215 |
216 | output:
217 | ['M', 'a', 'l', 'a', 'g', 'a']
218 | "".join(output)
219 | Malaga
220 | """
221 | text = unicodedata.normalize("NFD", text)
222 | output = []
223 | for char in text:
224 | cat = unicodedata.category(char)
225 | if cat == "Mn":
226 | continue
227 | output.append(char)
228 | return "".join(output)
229 |
230 | def _run_split_on_punc(self, text):
231 | """Splits punctuation on a piece of text."""
232 | """
233 | If text is one of "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", return it directly without splitting on "[" or "]"
234 | """
235 | if text in self.never_split:
236 | return [text]
237 | chars = list(text)
238 | i = 0
239 | start_new_word = True
240 | output = []
241 | while i < len(chars):
242 | char = chars[i]
243 | if _is_punctuation(char):
244 | output.append([char])
245 | start_new_word = True
246 | else:
247 | if start_new_word:
248 | output.append([])
249 | start_new_word = False
250 | output[-1].append(char)
251 | i += 1
252 |
253 | return ["".join(x) for x in output]
254 |
255 | def _tokenize_chinese_chars(self, text):
256 | """Adds whitespace around any CJK character."""
257 | """
258 | CJK:
259 | Chinese, Japanese, Korean
260 |
261 | text = "[CLS] Who 这是一个测试 was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
262 |
263 | ord():
264 | Return the Unicode code point for a one-character string.
265 |
266 | cp = 91
267 | cp = 67
268 | cp = 76
269 | ...
270 |
271 | "".join(output):
272 | [CLS] Who 这 是 一 个 测 试 was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]
273 |
274 | output:
275 | ['[', 'C', 'L', 'S', ']', ' ', 'W', 'h', 'o', ' ', ' ', '这', ' ', ' ', '是', ' ', ' ', '一', ' ', ' ', '个',
276 | ' ', ' ', '测', ' ', ' ', '试', ' ', ' ', 'w', 'a', 's', ' ', 'J', 'i', 'm', ' ', 'H', 'e', 'n', 's',
277 | 'o', 'n', ' ', '?', ' ', '[', 'S', 'E', 'P', ']', ' ', 'J', 'i', 'm', ' ', 'H', 'e', 'n', 's', 'o', 'n',
278 | ' ', 'w', 'a', 's', ' ', 'a', ' ', 'p', 'u', 'p', 'p', 'e', 't', 'e', 'e', 'r', ' ', '[', 'S', 'E', 'P',
279 | ']']
280 | """
281 | output = []
282 | for char in text:
283 | cp = ord(char) #
284 | if self._is_chinese_char(cp):
285 | output.append(" ")
286 | output.append(char)
287 | output.append(" ")
288 | else:
289 | output.append(char)
290 | return "".join(output)
291 |
292 | def _is_chinese_char(self, cp):
293 | """Checks whether CP is the codepoint of a CJK character."""
294 | # This defines a "chinese character" as anything in the CJK Unicode block:
295 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
296 | #
297 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
298 | # despite its name. The modern Korean Hangul alphabet is a different block,
299 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
300 | # space-separated words, so they are not treated specially and handled
301 | # like all of the other languages.
302 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
303 | (cp >= 0x3400 and cp <= 0x4DBF) or #
304 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
305 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
306 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
307 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
308 | (cp >= 0xF900 and cp <= 0xFAFF) or #
309 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
310 | return True
311 |
312 | return False
313 |
314 | def _clean_text(self, text):
315 | """Performs invalid character removal and whitespace cleanup on text."""
316 | output = []
317 | for char in text:
318 | cp = ord(char)
319 | if cp == 0 or cp == 0xfffd or _is_control(char):
320 | continue
321 | if _is_whitespace(char):
322 | output.append(" ")
323 | else:
324 | output.append(char)
325 | return "".join(output)
326 |
327 |
328 | class WordpieceTokenizer(object):
329 | """Runs WordPiece tokenization."""
330 |
331 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
332 | self.vocab = vocab
333 | self.unk_token = unk_token
334 | self.max_input_chars_per_word = max_input_chars_per_word
335 |
336 | def tokenize(self, text):
337 | """Tokenizes a piece of text into its word pieces.
338 |
339 | This uses a greedy longest-match-first algorithm to perform tokenization
340 | using the given vocabulary.
341 |
342 | For example:
343 | input = "unaffable"
344 | output = ["un", "##aff", "##able"]
345 |
346 | Args:
347 | text: A single token or whitespace separated tokens. This should have
348 | already been passed through `BasicTokenizer`.
349 |
350 | Returns:
351 | A list of wordpiece tokens.
352 |
353 | Papers: (for wordPiece)
354 | [1] https://arxiv.org/pdf/1508.07909.pdf
355 | [2] https://arxiv.org/pdf/1609.08144.pdf
356 |
357 | whitespace_tokenize():
58 | splits on whitespace and returns a word list, i.e. ["word1", "word2", ..., "wordn"]
359 | """
360 |
361 | output_tokens = []
362 | for token in whitespace_tokenize(text): # text is a word
363 | chars = list(token)
364 | if len(chars) > self.max_input_chars_per_word:
365 | output_tokens.append(self.unk_token) # unknown
366 | continue
367 |
368 | is_bad = False
369 | start = 0
370 | sub_tokens = []
371 | while start < len(chars):
372 | end = len(chars)
373 | cur_substr = None
374 | while start < end:
375 | substr = "".join(chars[start:end])
376 | if start > 0:
377 | substr = "##" + substr
378 | if substr in self.vocab:
379 | cur_substr = substr
380 | break
381 | end -= 1
382 | if cur_substr is None:
383 | is_bad = True
384 | break
385 | sub_tokens.append(cur_substr) # longest matching prefix found in the vocab; keep it.
386 | start = end # continue matching from the end of the piece just matched.
387 |
388 | if is_bad:
389 | output_tokens.append(self.unk_token)
390 | else:
391 | output_tokens.extend(sub_tokens)
392 | return output_tokens
393 |
394 |
395 | def _is_whitespace(char):
396 | """Checks whether `chars` is a whitespace character."""
397 | # \t, \n, and \r are technically control characters but we treat them
398 | # as whitespace since they are generally considered as such.
399 | if char == " " or char == "\t" or char == "\n" or char == "\r":
400 | return True
401 | cat = unicodedata.category(char)
402 | if cat == "Zs":
403 | return True
404 | return False
405 |
406 |
407 | def _is_control(char):
408 | """Checks whether `chars` is a control character."""
409 | # These are technically control characters but we count them as whitespace
410 | # characters.
411 | if char == "\t" or char == "\n" or char == "\r":
412 | return False
413 | cat = unicodedata.category(char)
414 | if cat.startswith("C"):
415 | return True
416 | return False
417 |
418 |
419 | def _is_punctuation(char):
420 | """Checks whether `chars` is a punctuation character."""
421 | cp = ord(char)
422 | # We treat all non-letter/number ASCII as punctuation.
423 | # Characters such as "^", "$", and "`" are not in the Unicode
424 | # Punctuation class but we treat them as punctuation anyways, for
425 | # consistency.
426 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
427 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
428 | return True
429 | cat = unicodedata.category(char)
430 | if cat.startswith("P"):
431 | return True
432 | return False
433 |
--------------------------------------------------------------------------------
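Annotation note: a toy run of the greedy longest-match-first WordPiece algorithm implemented above, on a hand-made vocabulary (the vocab contents are illustrative; a real vocab comes from load_vocab()):

from pytorch_pretrained_bert.tokenization import WordpieceTokenizer

toy_vocab = {"[UNK]": 0, "un": 1, "##aff": 2, "##able": 3, "runn": 4, "##ing": 5}
wp = WordpieceTokenizer(vocab=toy_vocab)
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("running"))    # ['runn', '##ing']
print(wp.tokenize("xyz"))        # ['[UNK]'] -- no prefix of 'xyz' is in the vocab

--------------------------------------------------------------------------------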
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | torch>=0.4.1
3 | # progress bars in model download and training scripts
4 | tqdm
5 | # Accessing files from S3 directly.
6 | boto3
7 | # Used for downloading models over HTTP
8 | requests
--------------------------------------------------------------------------------
/samples/input.txt:
--------------------------------------------------------------------------------
1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer
2 |
--------------------------------------------------------------------------------
/samples/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Gutenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
8 | "Cass" Beard had risen early that morning, but not with a view to discovery.
9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
11 | This was nearly opposite.
12 | Mr. Cassius crossed the highway, and stopped suddenly.
13 | Something glittered in the nearest red pool before him.
14 | Gold, surely!
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
17 | Like most of his fellow gold-seekers, Cass was superstitious.
18 |
19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
21 | From my youth I felt in me a soul above the matter-entangled herd.
22 | She revealed to me the glorious fact, that I am a spark of Divinity itself.
23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
24 | There is a philosophic pleasure in opening one's treasures to the modest young.
25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
31 | At last they reached the quay at the opposite end of the street;
32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
34 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Run 'python setup.py build' to build this package from the source tree.
3 | """
4 |
5 | """
6 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
7 |
8 | To create the package for pypi.
9 |
10 | 1. Change the version in __init__.py and setup.py.
11 |
12 | 2. Commit these changes with the message: "Release: VERSION"
13 |
14 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
15 | Push the tag to git: git push --tags origin master
16 |
17 | 4. Build both the sources and the wheel. Do not change anything in setup.py between
18 | creating the wheel and the source distribution (obviously).
19 |
20 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory.
21 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x).
22 |
23 | For the sources, run: "python setup.py sdist"
24 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp.
25 |
26 | 5. Check that everything looks correct by uploading the package to the pypi test server:
27 |
28 | twine upload dist/* -r pypitest
29 | (pypi suggest using twine as other methods upload files via plaintext.)
30 |
31 | Check that you can install it in a virtualenv by running:
32 | pip install -i https://testpypi.python.org/pypi allennlp
33 |
34 | 6. Upload the final version to actual pypi:
35 | twine upload dist/* -r pypi
36 |
37 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
38 |
39 | """
40 | from setuptools import find_packages, setup
41 |
42 | setup(
43 | name="pytorch_pretrained_bert",
44 | version="0.4.0",
45 | author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
46 | author_email="thomas@huggingface.co",
47 | description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
48 | long_description=open("README.md", "r", encoding='utf-8').read(),
49 | long_description_content_type="text/markdown",
50 | keywords='BERT NLP deep learning google',
51 | license='Apache',
52 | url="https://github.com/huggingface/pytorch-pretrained-BERT",
53 | packages=find_packages(exclude=["*.tests", "*.tests.*",
54 | "tests.*", "tests"]),
55 | install_requires=['torch>=0.4.1',
56 | 'numpy',
57 | 'boto3',
58 | 'requests',
59 | 'tqdm'],
60 | entry_points={
61 | 'console_scripts': [
62 | "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main"
63 | ]
64 | },
65 | python_requires='>=3.5.0',
66 | tests_require=['pytest'],
67 | classifiers=[
68 | 'Intended Audience :: Science/Research',
69 | 'License :: OSI Approved :: Apache Software License',
70 | 'Programming Language :: Python :: 3',
71 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
72 | ],
73 | )
74 |
--------------------------------------------------------------------------------
/tests/modeling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import json
21 | import random
22 |
23 | import torch
24 |
25 | from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
26 | BertForNextSentencePrediction, BertForPreTraining,
27 | BertForQuestionAnswering, BertForSequenceClassification,
28 | BertForTokenClassification)
29 |
30 |
31 | class BertModelTest(unittest.TestCase):
32 | class BertModelTester(object):
33 |
34 | def __init__(self,
35 | parent,
36 | batch_size=13,
37 | seq_length=7,
38 | is_training=True,
39 | use_input_mask=True,
40 | use_token_type_ids=True,
41 | use_labels=True,
42 | vocab_size=99,
43 | hidden_size=32,
44 | num_hidden_layers=5,
45 | num_attention_heads=4,
46 | intermediate_size=37,
47 | hidden_act="gelu",
48 | hidden_dropout_prob=0.1,
49 | attention_probs_dropout_prob=0.1,
50 | max_position_embeddings=512,
51 | type_vocab_size=16,
52 | type_sequence_label_size=2,
53 | initializer_range=0.02,
54 | num_labels=3,
55 | scope=None):
56 | self.parent = parent
57 | self.batch_size = batch_size
58 | self.seq_length = seq_length
59 | self.is_training = is_training
60 | self.use_input_mask = use_input_mask
61 | self.use_token_type_ids = use_token_type_ids
62 | self.use_labels = use_labels
63 | self.vocab_size = vocab_size
64 | self.hidden_size = hidden_size
65 | self.num_hidden_layers = num_hidden_layers
66 | self.num_attention_heads = num_attention_heads
67 | self.intermediate_size = intermediate_size
68 | self.hidden_act = hidden_act
69 | self.hidden_dropout_prob = hidden_dropout_prob
70 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
71 | self.max_position_embeddings = max_position_embeddings
72 | self.type_vocab_size = type_vocab_size
73 | self.type_sequence_label_size = type_sequence_label_size
74 | self.initializer_range = initializer_range
75 | self.num_labels = num_labels
76 | self.scope = scope
77 |
78 | def prepare_config_and_inputs(self):
79 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
80 |
81 | input_mask = None
82 | if self.use_input_mask:
83 | input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
84 |
85 | token_type_ids = None
86 | if self.use_token_type_ids:
87 | token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
88 |
89 | sequence_labels = None
90 | token_labels = None
91 | if self.use_labels:
92 | sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
93 | token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
94 |
95 | config = BertConfig(
96 | vocab_size_or_config_json_file=self.vocab_size,
97 | hidden_size=self.hidden_size,
98 | num_hidden_layers=self.num_hidden_layers,
99 | num_attention_heads=self.num_attention_heads,
100 | intermediate_size=self.intermediate_size,
101 | hidden_act=self.hidden_act,
102 | hidden_dropout_prob=self.hidden_dropout_prob,
103 | attention_probs_dropout_prob=self.attention_probs_dropout_prob,
104 | max_position_embeddings=self.max_position_embeddings,
105 | type_vocab_size=self.type_vocab_size,
106 | initializer_range=self.initializer_range)
107 |
108 | return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
109 |
110 | def check_loss_output(self, result):
111 | self.parent.assertListEqual(
112 | list(result["loss"].size()),
113 | [])
114 |
115 | def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
116 | model = BertModel(config=config)
117 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
118 | outputs = {
119 | "sequence_output": all_encoder_layers[-1],
120 | "pooled_output": pooled_output,
121 | "all_encoder_layers": all_encoder_layers,
122 | }
123 | return outputs
124 |
125 | def check_bert_model_output(self, result):
126 | self.parent.assertListEqual(
127 | [size for layer in result["all_encoder_layers"] for size in layer.size()],
128 | [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
129 | self.parent.assertListEqual(
130 | list(result["sequence_output"].size()),
131 | [self.batch_size, self.seq_length, self.hidden_size])
132 | self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
133 |
134 |
135 | def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
136 | model = BertForMaskedLM(config=config)
137 | loss = model(input_ids, token_type_ids, input_mask, token_labels)
138 | prediction_scores = model(input_ids, token_type_ids, input_mask)
139 | outputs = {
140 | "loss": loss,
141 | "prediction_scores": prediction_scores,
142 | }
143 | return outputs
144 |
145 | def check_bert_for_masked_lm_output(self, result):
146 | self.parent.assertListEqual(
147 | list(result["prediction_scores"].size()),
148 | [self.batch_size, self.seq_length, self.vocab_size])
149 |
150 | def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
151 | model = BertForNextSentencePrediction(config=config)
152 | loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
153 | seq_relationship_score = model(input_ids, token_type_ids, input_mask)
154 | outputs = {
155 | "loss": loss,
156 | "seq_relationship_score": seq_relationship_score,
157 | }
158 | return outputs
159 |
160 | def check_bert_for_next_sequence_prediction_output(self, result):
161 | self.parent.assertListEqual(
162 | list(result["seq_relationship_score"].size()),
163 | [self.batch_size, 2])
164 |
165 |
166 | def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
167 | model = BertForPreTraining(config=config)
168 | loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
169 | prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
170 | outputs = {
171 | "loss": loss,
172 | "prediction_scores": prediction_scores,
173 | "seq_relationship_score": seq_relationship_score,
174 | }
175 | return outputs
176 |
177 | def check_bert_for_pretraining_output(self, result):
178 | self.parent.assertListEqual(
179 | list(result["prediction_scores"].size()),
180 | [self.batch_size, self.seq_length, self.vocab_size])
181 | self.parent.assertListEqual(
182 | list(result["seq_relationship_score"].size()),
183 | [self.batch_size, 2])
184 |
185 |
186 | def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
187 | model = BertForQuestionAnswering(config=config)
188 | loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
189 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
190 | outputs = {
191 | "loss": loss,
192 | "start_logits": start_logits,
193 | "end_logits": end_logits,
194 | }
195 | return outputs
196 |
197 | def check_bert_for_question_answering_output(self, result):
198 | self.parent.assertListEqual(
199 | list(result["start_logits"].size()),
200 | [self.batch_size, self.seq_length])
201 | self.parent.assertListEqual(
202 | list(result["end_logits"].size()),
203 | [self.batch_size, self.seq_length])
204 |
205 |
206 | def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
207 | model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
208 | loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
209 | logits = model(input_ids, token_type_ids, input_mask)
210 | outputs = {
211 | "loss": loss,
212 | "logits": logits,
213 | }
214 | return outputs
215 |
216 | def check_bert_for_sequence_classification_output(self, result):
217 | self.parent.assertListEqual(
218 | list(result["logits"].size()),
219 | [self.batch_size, self.num_labels])
220 |
221 |
222 | def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
223 | model = BertForTokenClassification(config=config, num_labels=self.num_labels)
224 | loss = model(input_ids, token_type_ids, input_mask, token_labels)
225 | logits = model(input_ids, token_type_ids, input_mask)
226 | outputs = {
227 | "loss": loss,
228 | "logits": logits,
229 | }
230 | return outputs
231 |
232 | def check_bert_for_token_classification_output(self, result):
233 | self.parent.assertListEqual(
234 | list(result["logits"].size()),
235 | [self.batch_size, self.seq_length, self.num_labels])
236 |
237 |
238 | def test_default(self):
239 | self.run_tester(BertModelTest.BertModelTester(self))
240 |
241 | def test_config_to_json_string(self):
242 | config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
243 | obj = json.loads(config.to_json_string())
244 | self.assertEqual(obj["vocab_size"], 99)
245 | self.assertEqual(obj["hidden_size"], 37)
246 |
247 | def run_tester(self, tester):
248 | config_and_inputs = tester.prepare_config_and_inputs()
249 | output_result = tester.create_bert_model(*config_and_inputs)
250 | tester.check_bert_model_output(output_result)
251 |
252 | output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
253 | tester.check_bert_for_masked_lm_output(output_result)
254 | tester.check_loss_output(output_result)
255 |
256 | output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
257 | tester.check_bert_for_next_sequence_prediction_output(output_result)
258 | tester.check_loss_output(output_result)
259 |
260 | output_result = tester.create_bert_for_pretraining(*config_and_inputs)
261 | tester.check_bert_for_pretraining_output(output_result)
262 | tester.check_loss_output(output_result)
263 |
264 | output_result = tester.create_bert_for_question_answering(*config_and_inputs)
265 | tester.check_bert_for_question_answering_output(output_result)
266 | tester.check_loss_output(output_result)
267 |
268 | output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
269 | tester.check_bert_for_sequence_classification_output(output_result)
270 | tester.check_loss_output(output_result)
271 |
272 | output_result = tester.create_bert_for_token_classification(*config_and_inputs)
273 | tester.check_bert_for_token_classification_output(output_result)
274 | tester.check_loss_output(output_result)
275 |
276 | @classmethod
277 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
278 |         """Creates a random torch.long tensor of the given shape, with values in [0, vocab_size)."""
279 | if rng is None:
280 | rng = random.Random()
281 |
282 | total_dims = 1
283 | for dim in shape:
284 | total_dims *= dim
285 |
286 | values = []
287 | for _ in range(total_dims):
288 | values.append(rng.randint(0, vocab_size - 1))
289 |
290 | return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
291 |
292 |
293 | if __name__ == "__main__":
294 | unittest.main()
295 |
--------------------------------------------------------------------------------
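
Note: each create_* helper above calls its model twice because, in this version of the library, the head models return the loss when labels are passed and the raw scores otherwise. A minimal standalone sketch of that two-call convention (the tiny config values below are illustrative, not taken from the tests):

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForMaskedLM

# Tiny throwaway config so the sketch runs quickly on CPU.
config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=37)
model = BertForMaskedLM(config=config)
model.eval()  # disable dropout for a deterministic forward pass

input_ids = torch.randint(0, 99, (2, 7), dtype=torch.long)
labels = torch.randint(0, 99, (2, 7), dtype=torch.long)

loss = model(input_ids, masked_lm_labels=labels)  # labels passed -> scalar loss
scores = model(input_ids)                         # no labels -> [2, 7, 99] prediction scores
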
/tests/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 |
21 | import torch
22 |
23 | from pytorch_pretrained_bert import BertAdam
24 |
25 | class OptimizationTest(unittest.TestCase):
26 |
27 | def assertListAlmostEqual(self, list1, list2, tol):
28 | self.assertEqual(len(list1), len(list2))
29 | for a, b in zip(list1, list2):
30 | self.assertAlmostEqual(a, b, delta=tol)
31 |
32 | def test_adam(self):
33 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
34 | target = torch.tensor([0.4, 0.2, -0.5])
35 | criterion = torch.nn.MSELoss()
36 | # No warmup, constant schedule, no gradient clipping
37 | optimizer = BertAdam(params=[w], lr=2e-1,
38 | weight_decay=0.0,
39 | max_grad_norm=-1)
40 | for _ in range(100):
41 | loss = criterion(w, target)
42 | loss.backward()
43 | optimizer.step()
44 |             w.grad.detach_()  # Plain tensors have no zero_grad(); we reset the gradient ourselves.
45 | w.grad.zero_()
46 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
47 |
48 |
49 | if __name__ == "__main__":
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
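
The test above deliberately disables BertAdam's extras (constant schedule, no weight decay, no gradient clipping) so that plain Adam convergence is all that is checked. For reference, a minimal sketch of the warmup_linear schedule more typical in the examples (the step counts here are illustrative assumptions):

import torch
from pytorch_pretrained_bert import BertAdam

w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()

# lr ramps up linearly over the first 10% of the 1000 steps, then decays linearly toward 0.
optimizer = BertAdam(params=[w], lr=2e-1, warmup=0.1, t_total=1000)
for _ in range(100):
    loss = criterion(w, target)
    loss.backward()
    optimizer.step()
    w.grad.detach_()  # plain tensors have no zero_grad(); reset manually
    w.grad.zero_()
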
/tests/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import unittest
21 |
22 | from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
23 | _is_whitespace, _is_control, _is_punctuation)
24 |
25 |
26 | class TokenizationTest(unittest.TestCase):
27 |
28 | def test_full_tokenizer(self):
29 | vocab_tokens = [
30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
31 | "##ing", ","
32 | ]
33 |         vocab_file = "/tmp/bert_tokenizer_test.txt"
34 |         with open(vocab_file, "w") as vocab_writer:
35 |             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
36 |
37 |
38 | tokenizer = BertTokenizer(vocab_file)
39 | os.remove(vocab_file)
40 |
41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
42 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
43 |
44 | self.assertListEqual(
45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
46 |
47 | def test_full_tokenizer_raises_error_for_long_sequences(self):
48 | vocab_tokens = [
49 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
50 | "##ing", ","
51 | ]
52 |         vocab_file = "/tmp/bert_tokenizer_test.txt"
53 |         with open(vocab_file, "w") as vocab_writer:
54 |             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
55 |
56 | tokenizer = BertTokenizer(vocab_file, max_len=10)
57 | os.remove(vocab_file)
58 | tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
59 | indices = tokenizer.convert_tokens_to_ids(tokens)
60 | self.assertListEqual(indices, [0 for _ in range(10)])
61 |
62 | tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
63 | self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
64 |
65 | def test_chinese(self):
66 | tokenizer = BasicTokenizer()
67 |
68 | self.assertListEqual(
69 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
70 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
71 |
72 | def test_basic_tokenizer_lower(self):
73 | tokenizer = BasicTokenizer(do_lower_case=True)
74 |
75 | self.assertListEqual(
76 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
77 | ["hello", "!", "how", "are", "you", "?"])
78 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
79 |
80 | def test_basic_tokenizer_no_lower(self):
81 | tokenizer = BasicTokenizer(do_lower_case=False)
82 |
83 | self.assertListEqual(
84 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
85 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
86 |
87 | def test_wordpiece_tokenizer(self):
88 | vocab_tokens = [
89 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
90 | "##ing"
91 | ]
92 |
93 | vocab = {}
94 | for (i, token) in enumerate(vocab_tokens):
95 | vocab[token] = i
96 | tokenizer = WordpieceTokenizer(vocab=vocab)
97 |
98 | self.assertListEqual(tokenizer.tokenize(""), [])
99 |
100 | self.assertListEqual(
101 | tokenizer.tokenize("unwanted running"),
102 | ["un", "##want", "##ed", "runn", "##ing"])
103 |
104 | self.assertListEqual(
105 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
106 |
107 | def test_is_whitespace(self):
108 | self.assertTrue(_is_whitespace(u" "))
109 | self.assertTrue(_is_whitespace(u"\t"))
110 | self.assertTrue(_is_whitespace(u"\r"))
111 | self.assertTrue(_is_whitespace(u"\n"))
112 | self.assertTrue(_is_whitespace(u"\u00A0"))
113 |
114 | self.assertFalse(_is_whitespace(u"A"))
115 | self.assertFalse(_is_whitespace(u"-"))
116 |
117 | def test_is_control(self):
118 | self.assertTrue(_is_control(u"\u0005"))
119 |
120 | self.assertFalse(_is_control(u"A"))
121 | self.assertFalse(_is_control(u" "))
122 | self.assertFalse(_is_control(u"\t"))
123 | self.assertFalse(_is_control(u"\r"))
124 |
125 | def test_is_punctuation(self):
126 | self.assertTrue(_is_punctuation(u"-"))
127 | self.assertTrue(_is_punctuation(u"$"))
128 | self.assertTrue(_is_punctuation(u"`"))
129 | self.assertTrue(_is_punctuation(u"."))
130 |
131 | self.assertFalse(_is_punctuation(u"A"))
132 | self.assertFalse(_is_punctuation(u" "))
133 |
134 |
135 | if __name__ == '__main__':
136 | unittest.main()
137 |
--------------------------------------------------------------------------------
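
Taken together, the tests above cover the full tokenization pipeline: BasicTokenizer handles whitespace/punctuation splitting, lower-casing, and per-character splitting of CJK text, and WordpieceTokenizer then applies greedy longest-match-first subword lookup against the vocab. A minimal sketch with a pretrained vocabulary (assumes network access to download the 'bert-base-uncased' vocab on first use):

from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")  # basic + wordpiece; the exact split depends on the vocab
ids = tokenizer.convert_tokens_to_ids(tokens)
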