├── LICENSE ├── MANIFEST.in ├── README.md ├── docker └── Dockerfile ├── evaluate-v2.0.py ├── gpt2_squad.py ├── gpt2sqa ├── .DS_Store ├── __init__.py ├── __main__.py ├── file_utils.py ├── gpt2 │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── gpt2config.cpython-37.pyc │ │ ├── gpt2lmhead.cpython-37.pyc │ │ ├── gpt2model.cpython-37.pyc │ │ ├── gpt2pretrained.cpython-37.pyc │ │ ├── gptdoubleheads.cpython-37.pyc │ │ ├── layer_norm.cpython-37.pyc │ │ ├── modules.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── gpt2config.py │ ├── gpt2lmhead.py │ ├── gpt2model.py │ ├── gpt2pretrained.py │ ├── gptdoubleheads.py │ ├── layer_norm.py │ ├── modules.py │ └── utils.py ├── modeling_gpt2.py ├── optimization.py ├── squad │ ├── .DS_Store │ ├── __init__.py │ ├── squad_example.py │ └── utils.py └── tokenization.py ├── requirements.txt └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPT2sQA 2 | 3 | This repo includes an experiment of fine-tuning GPT-2 117M for Question Answering (QA). It also runs the model on Stanford Question Answering Dataset 2.0 (SQuAD). 
It uses Hugging Face Inc.'s PyTorch implementation of GPT-2 and adapts their BERT fine-tuning code for QA. 4 | 5 | SQuAD data can be downloaded from: https://github.com/rajpurkar/SQuAD-explorer/tree/master/dataset 6 | 7 | 8 | To train and validate the model: 9 | 10 | ``` 11 | python gpt2_squad.py --output_dir=output/ --train_file=data/train-v2.0.json --do_train --train_batch_size=32 --predict_file=data/dev-v2.0.json --do_predict 12 | 13 | ``` 14 | 15 | To evaluate: 16 | 17 | ``` 18 | 19 | python evaluate-v2.0.py data/dev-v2.0.json output/predictions.json 20 | 21 | ``` 22 | 23 | 24 | Fine-tuning experiments for GPT-2 345M will be uploaded soon, using datasets that exclusively target commonsense reasoning, in an attempt to provide insight into the reasoning abilities of GPT-2. Such insight could help improve Natural Language Understanding with language models in semi-supervised settings. 25 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext 4 | 5 | RUN pip install pytorch-pretrained-bert 6 | 7 | WORKDIR /workspace -------------------------------------------------------------------------------- /evaluate-v2.0.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable.
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def get_raw_scores(dataset, preds): 82 | exact_scores = {} 83 | f1_scores = {} 84 | for article in dataset: 85 | for p in article['paragraphs']: 86 | for qa in p['qas']: 87 | qid = qa['id'] 88 | gold_answers = [a['text'] for a in qa['answers'] 89 | if normalize_answer(a['text'])] 90 | if not gold_answers: 91 | # For unanswerable questions, only correct answer is empty string 92 | gold_answers = [''] 93 | if qid not in preds: 94 | print('Missing prediction for %s' % qid) 95 | continue 96 | a_pred = preds[qid] 97 | # Take max over all gold answers 98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 100 | return exact_scores, f1_scores 101 | 102 | def 
apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 103 | new_scores = {} 104 | for qid, s in scores.items(): 105 | pred_na = na_probs[qid] > na_prob_thresh 106 | if pred_na: 107 | new_scores[qid] = float(not qid_to_has_ans[qid]) 108 | else: 109 | new_scores[qid] = s 110 | return new_scores 111 | 112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 113 | if not qid_list: 114 | total = len(exact_scores) 115 | return collections.OrderedDict([ 116 | ('exact', 100.0 * sum(exact_scores.values()) / total), 117 | ('f1', 100.0 * sum(f1_scores.values()) / total), 118 | ('total', total), 119 | ]) 120 | else: 121 | total = len(qid_list) 122 | return collections.OrderedDict([ 123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 125 | ('total', total), 126 | ]) 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 131 | 132 | def plot_pr_curve(precisions, recalls, out_image, title): 133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 135 | plt.xlabel('Recall') 136 | plt.ylabel('Precision') 137 | plt.xlim([0.0, 1.05]) 138 | plt.ylim([0.0, 1.05]) 139 | plt.title(title) 140 | plt.savefig(out_image) 141 | plt.clf() 142 | 143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 144 | out_image=None, title=None): 145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 146 | true_pos = 0.0 147 | cur_p = 1.0 148 | cur_r = 0.0 149 | precisions = [1.0] 150 | recalls = [0.0] 151 | avg_prec = 0.0 152 | for i, qid in enumerate(qid_list): 153 | if qid_to_has_ans[qid]: 154 | true_pos += scores[qid] 155 | cur_p = true_pos / float(i+1) 156 | cur_r = true_pos / float(num_true_pos) 157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 158 | # i.e., if we can put a threshold after this point 159 | avg_prec += cur_p * (cur_r - recalls[-1]) 160 | precisions.append(cur_p) 161 | recalls.append(cur_r) 162 | if out_image: 163 | plot_pr_curve(precisions, recalls, out_image, title) 164 | return {'ap': 100.0 * avg_prec} 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 191 | if not qid_list: 192 | return 193 | x = [na_probs[k] for k in qid_list] 194 | weights = np.ones_like(x) / float(len(x)) 195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 196 | plt.xlabel('Model probability of no-answer') 197 | plt.ylabel('Proportion of dataset') 198 | plt.title('Histogram of no-answer probability: %s' % name) 199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 200 | plt.clf() 201 | 202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 204 | cur_score = num_no_ans 205 | best_score = cur_score 206 | best_thresh = 0.0 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | for i, qid in enumerate(qid_list): 209 | if qid not in scores: continue 210 | if qid_to_has_ans[qid]: 211 | diff = scores[qid] 212 | else: 213 | if preds[qid]: 214 | diff = -1 215 | else: 216 | diff = 0 217 | cur_score += diff 218 | if cur_score > best_score: 219 | best_score = cur_score 220 | best_thresh = na_probs[qid] 221 | return 100.0 * best_score / len(scores), best_thresh 222 | 223 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 224 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 225 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 226 | main_eval['best_exact'] = best_exact 227 | main_eval['best_exact_thresh'] = exact_thresh 228 | main_eval['best_f1'] = best_f1 229 | main_eval['best_f1_thresh'] = f1_thresh 230 | 231 | def main(): 232 | with open(OPTS.data_file) as f: 233 | dataset_json = json.load(f) 234 | dataset = dataset_json['data'] 235 | with open(OPTS.pred_file) as f: 236 | preds = json.load(f) 237 | if OPTS.na_prob_file: 238 | with open(OPTS.na_prob_file) as f: 239 | na_probs = json.load(f) 240 | else: 241 | na_probs = {k: 0.0 for k in preds} 242 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 243 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 244 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 245 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 246 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 247 | OPTS.na_prob_thresh) 248 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 249 | OPTS.na_prob_thresh) 250 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 251 | if has_ans_qids: 252 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 253 | merge_eval(out_eval, has_ans_eval, 'HasAns') 254 | if no_ans_qids: 255 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 256 | merge_eval(out_eval, no_ans_eval, 'NoAns') 257 | if OPTS.na_prob_file: 258 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 259 | if OPTS.na_prob_file and OPTS.out_image_dir: 260 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 261 | qid_to_has_ans, OPTS.out_image_dir) 262 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 263 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 264 | if OPTS.out_file: 265 | with open(OPTS.out_file, 'w') as f: 266 | json.dump(out_eval, f) 267 | else: 268 | 
print(json.dumps(out_eval, indent=2)) 269 | 270 | if __name__ == '__main__': 271 | OPTS = parse_args() 272 | if OPTS.out_image_dir: 273 | import matplotlib 274 | matplotlib.use('Agg') 275 | import matplotlib.pyplot as plt 276 | main() 277 | -------------------------------------------------------------------------------- /gpt2_squad.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Run GPT2 small on SQuAD.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import collections 22 | import json 23 | import logging 24 | import math 25 | import os 26 | import random 27 | import sys 28 | from io import open 29 | import pickle 30 | 31 | 32 | import numpy as np 33 | import torch 34 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 35 | TensorDataset) 36 | from torch.utils.data.distributed import DistributedSampler 37 | from tqdm import tqdm, trange 38 | 39 | from gpt2sqa.file_utils import PYTORCH_PRETRAINED_GPT2_CACHE, WEIGHTS_NAME, CONFIG_NAME 40 | from gpt2sqa.modeling_gpt2 import GPT2ModelForQuestionAnswering 41 | from gpt2sqa.optimization import GPT2Adam, WarmupLinearSchedule 42 | from gpt2sqa.tokenization import GPT2Tokenizer 43 | from gpt2sqa.squad.squad_example import InputFeatures 44 | from gpt2sqa.squad.utils import convert_examples_to_features, read_squad_examples, get_final_text, write_predictions, _check_is_max_context, _get_best_indexes, _compute_softmax, RawResult 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | 51 | # Required parameters 52 | parser.add_argument("--output_dir", default=None, type=str, required=True, 53 | help="The output directory where the model checkpoints and predictions will be written.") 54 | 55 | # Other parameters 56 | parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json") 57 | parser.add_argument("--predict_file", default=None, type=str, 58 | help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 59 | parser.add_argument("--max_seq_length", default=1000, type=int, 60 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 61 | "longer than this will be truncated, and sequences shorter than this will be padded.") 62 | parser.add_argument("--doc_stride", default=128, type=int, 63 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 64 | parser.add_argument("--max_query_length", default=64, type=int, 65 | help="The maximum number of tokens for the question. 
Questions longer than this will " 66 | "be truncated to this length.") 67 | parser.add_argument("--do_train", action='store_true', help="Whether to run training.") 68 | parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.") 69 | parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") 70 | parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") 71 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 72 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 73 | help="Total number of training epochs to perform.") 74 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 75 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% " 76 | "of training.") 77 | parser.add_argument("--n_best_size", default=20, type=int, 78 | help="The total number of n-best predictions to generate in the nbest_predictions.json " 79 | "output file.") 80 | parser.add_argument("--max_answer_length", default=30, type=int, 81 | help="The maximum length of an answer that can be generated. This is needed because the start " 82 | "and end predictions are not conditioned on one another.") 83 | parser.add_argument("--verbose_logging", action='store_true', 84 | help="If true, all of the warnings related to data processing will be printed. " 85 | "A number of warnings are expected for a normal SQuAD evaluation.") 86 | parser.add_argument("--no_cuda", 87 | action='store_true', 88 | help="Whether not to use CUDA when available") 89 | parser.add_argument('--seed', 90 | type=int, 91 | default=42, 92 | help="random seed for initialization") 93 | parser.add_argument('--gradient_accumulation_steps', 94 | type=int, 95 | default=1, 96 | help="Number of updates steps to accumulate before performing a backward/update pass.") 97 | parser.add_argument("--do_lower_case", 98 | action='store_true', 99 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 100 | parser.add_argument("--local_rank", 101 | type=int, 102 | default=-1, 103 | help="local_rank for distributed training on gpus") 104 | parser.add_argument('--loss_scale', 105 | type=float, default=0, 106 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 107 | "0 (default value): dynamic loss scaling.\n" 108 | "Positive power of 2: static loss scaling value.\n") 109 | parser.add_argument('--null_score_diff_threshold', 110 | type=float, default=0.0, 111 | help="If null_score - best_non_null is greater than the threshold predict null.") 112 | args = parser.parse_args() 113 | print(args) 114 | 115 | if args.local_rank == -1 or args.no_cuda: 116 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 117 | n_gpu = torch.cuda.device_count() 118 | else: 119 | torch.cuda.set_device(args.local_rank) 120 | device = torch.device("cuda", args.local_rank) 121 | n_gpu = 1 122 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 123 | torch.distributed.init_process_group(backend='nccl') 124 | 125 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 126 | datefmt='%m/%d/%Y %H:%M:%S', 127 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 128 | 129 | logger.info("device: {} n_gpu: {}, distributed training: {}".format( 130 | device, n_gpu, bool(args.local_rank != -1))) 131 | 132 | if args.gradient_accumulation_steps < 1: 133 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 134 | args.gradient_accumulation_steps)) 135 | 136 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 137 | 138 | random.seed(args.seed) 139 | np.random.seed(args.seed) 140 | torch.manual_seed(args.seed) 141 | if n_gpu > 0: 142 | torch.cuda.manual_seed_all(args.seed) 143 | 144 | if not args.do_train and not args.do_predict: 145 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 146 | 147 | if args.do_train: 148 | if not args.train_file: 149 | raise ValueError( 150 | "If `do_train` is True, then `train_file` must be specified.") 151 | if args.do_predict: 152 | if not args.predict_file: 153 | raise ValueError( 154 | "If `do_predict` is True, then `predict_file` must be specified.") 155 | 156 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 157 | raise ValueError("Output directory () already exists and is not empty.") 158 | if not os.path.exists(args.output_dir): 159 | os.makedirs(args.output_dir) 160 | 161 | tokenizer = GPT2Tokenizer.from_pretrained() 162 | 163 | train_examples = None 164 | num_train_optimization_steps = None 165 | if args.do_train: 166 | train_examples = read_squad_examples(input_file=args.train_file, is_training=True) 167 | num_train_optimization_steps = int(len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 168 | if args.local_rank != -1: 169 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 170 | 171 | # Prepare model 172 | model = GPT2ModelForQuestionAnswering.from_pretrained(cache_dir=os.path.join(str(PYTORCH_PRETRAINED_GPT2_CACHE), 'distributed_{}'.format(args.local_rank))) 173 | 174 | model.to(device) 175 | if args.local_rank != -1: 176 | try: 177 | from apex.parallel import DistributedDataParallel as DDP 178 | except ImportError: 179 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 180 | 181 | model = DDP(model) 182 | elif n_gpu > 1: 183 | model = torch.nn.DataParallel(model) 184 | 185 | # Prepare optimizer 186 | param_optimizer = list(model.named_parameters()) 187 | 188 | # hack to remove 
pooler, which is not used 189 | # thus it produce None grad that break apex 190 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 191 | 192 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 193 | optimizer_grouped_parameters = [ 194 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 195 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 196 | ] 197 | 198 | if num_train_optimization_steps is None: 199 | num_train_optimization_steps = 1 200 | 201 | optimizer = GPT2Adam(optimizer_grouped_parameters, 202 | lr=args.learning_rate, 203 | warmup=args.warmup_proportion, 204 | t_total=num_train_optimization_steps) 205 | 206 | global_step = 0 207 | if args.do_train: 208 | train_features = convert_examples_to_features( 209 | examples=train_examples, 210 | tokenizer=tokenizer, 211 | max_seq_length=args.max_seq_length, 212 | doc_stride=args.doc_stride, 213 | max_query_length=args.max_query_length, 214 | is_training=True) 215 | logger.info("***** Running training *****") 216 | logger.info(" Num orig examples = %d", len(train_examples)) 217 | logger.info(" Num split examples = %d", len(train_features)) 218 | logger.info(" Batch size = %d", args.train_batch_size) 219 | logger.info(" Num steps = %d", num_train_optimization_steps) 220 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 221 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 222 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 223 | all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) 224 | all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) 225 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 226 | all_start_positions, all_end_positions) 227 | if args.local_rank == -1: 228 | train_sampler = RandomSampler(train_data) 229 | else: 230 | train_sampler = DistributedSampler(train_data) 231 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 232 | 233 | model.train() 234 | total_loss = 0 235 | pbar = tqdm(train_dataloader, disable=args.local_rank not in [-1, 0]) 236 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 237 | for step, batch in enumerate(pbar): 238 | if n_gpu == 1: 239 | batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self 240 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch 241 | loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions) 242 | if n_gpu > 1: 243 | loss = loss.mean() # mean() to average on multi-gpu. 
244 | if args.gradient_accumulation_steps > 1: 245 | loss = loss / args.gradient_accumulation_steps 246 | total_loss += loss.item() 247 | 248 | loss.backward() 249 | pbar.update(1) 250 | if step % 10 == 0: 251 | pbar.set_description(desc=f'loss:{np.mean(total_loss)}') 252 | total_loss = 0 253 | if (step + 1) % args.gradient_accumulation_steps == 0: 254 | optimizer.step() 255 | optimizer.zero_grad() 256 | global_step += 1 257 | 258 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 259 | # Save a trained model, configuration and tokenizer 260 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 261 | 262 | # If we save using the predefined names, we can load using `from_pretrained` 263 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 264 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 265 | 266 | torch.save(model_to_save.state_dict(), output_model_file) 267 | model_to_save.config.to_json_file(output_config_file) 268 | tokenizer.save_vocabulary(args.output_dir) 269 | 270 | # Load a trained model and vocabulary that you have fine-tuned 271 | model = GPT2ModelForQuestionAnswering.from_pretrained(args.output_dir) 272 | tokenizer = GPT2Tokenizer.from_pretrained(args.output_dir) 273 | else: 274 | model = GPT2ModelForQuestionAnswering.from_pretrained() 275 | 276 | model.to(device) 277 | 278 | if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 279 | eval_examples = read_squad_examples( 280 | input_file=args.predict_file, is_training=False, ) 281 | eval_features = convert_examples_to_features( 282 | examples=eval_examples, 283 | tokenizer=tokenizer, 284 | max_seq_length=args.max_seq_length, 285 | doc_stride=args.doc_stride, 286 | max_query_length=args.max_query_length, 287 | is_training=False) 288 | 289 | logger.info("***** Running predictions *****") 290 | logger.info(" Num orig examples = %d", len(eval_examples)) 291 | logger.info(" Num split examples = %d", len(eval_features)) 292 | logger.info(" Batch size = %d", args.predict_batch_size) 293 | 294 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 295 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 296 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 297 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 298 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 299 | # Run prediction for full data 300 | eval_sampler = SequentialSampler(eval_data) 301 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) 302 | 303 | model.eval() 304 | all_results = [] 305 | logger.info("Start evaluating") 306 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): 307 | if len(all_results) % 1000 == 0: 308 | logger.info("Processing example: %d" % (len(all_results))) 309 | input_ids = input_ids.to(device) 310 | input_mask = input_mask.to(device) 311 | segment_ids = segment_ids.to(device) 312 | with torch.no_grad(): 313 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) 314 | for i, example_index in enumerate(example_indices): 315 | start_logits = batch_start_logits[i].detach().cpu().tolist() 316 | end_logits = batch_end_logits[i].detach().cpu().tolist() 317 | eval_feature = 
eval_features[example_index.item()] 318 | unique_id = int(eval_feature.unique_id) 319 | all_results.append(RawResult(unique_id=unique_id, 320 | start_logits=start_logits, 321 | end_logits=end_logits)) 322 | output_prediction_file = os.path.join(args.output_dir, "predictions.json") 323 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json") 324 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json") 325 | write_predictions(eval_examples, eval_features, all_results, 326 | args.n_best_size, args.max_answer_length, 327 | args.do_lower_case, output_prediction_file, 328 | output_nbest_file, output_null_log_odds_file, args.verbose_logging, 329 | True, args.null_score_diff_threshold) 330 | 331 | 332 | if __name__ == "__main__": 333 | main() 334 | -------------------------------------------------------------------------------- /gpt2sqa/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/.DS_Store -------------------------------------------------------------------------------- /gpt2sqa/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | 3 | from gpt2sqa.gpt2.utils import load_tf_weights_in_gpt2 4 | from gpt2sqa.gpt2.gpt2config import GPT2Config 5 | from gpt2sqa.gpt2.gpt2model import GPT2Model 6 | from gpt2sqa.gpt2.gptdoubleheads import GPT2DoubleHeadsModel 7 | from gpt2sqa.gpt2.gpt2lmhead import GPT2LMHead 8 | 9 | 10 | from .file_utils import PYTORCH_PRETRAINED_GPT2_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 11 | -------------------------------------------------------------------------------- /gpt2sqa/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. 
Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /gpt2sqa/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_GPT2_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_GPT2_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' + etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_GPT2_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_GPT2_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 
112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_GPT2_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 
193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 
259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__init__.py -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/gpt2config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/gpt2config.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/gpt2lmhead.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/gpt2lmhead.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/gpt2model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/gpt2model.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/gpt2pretrained.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/gpt2pretrained.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/gptdoubleheads.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/gptdoubleheads.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/layer_norm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/layer_norm.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/modules.cpython-37.pyc 
-------------------------------------------------------------------------------- /gpt2sqa/gpt2/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/gpt2/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /gpt2sqa/gpt2/gpt2config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import sys 4 | 5 | 6 | class GPT2Config(object): 7 | """Configuration class to store the configuration of a `GPT2Model`. 8 | """ 9 | 10 | def __init__( 11 | self, 12 | vocab_size_or_config_json_file=50257, 13 | n_positions=1024, 14 | n_ctx=1024, 15 | n_embd=768, 16 | n_layer=12, 17 | n_head=12, 18 | layer_norm_epsilon=1e-5, 19 | initializer_range=0.02, 20 | ): 21 | """Constructs GPT2Config. 22 | 23 | Args: 24 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 25 | n_positions: Number of positional embeddings. 26 | n_ctx: Size of the causal mask (usually same as n_positions). 27 | n_embd: Dimensionality of the embeddings and hidden states. 28 | n_layer: Number of hidden layers in the Transformer encoder. 29 | n_head: Number of attention heads for each attention layer in 30 | the Transformer encoder. 31 | layer_norm_epsilon: epsilon to use in the layer norm layers 32 | initializer_range: The sttdev of the truncated_normal_initializer for 33 | initializing all weight matrices. 34 | """ 35 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 36 | and isinstance(vocab_size_or_config_json_file, unicode)): 37 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 38 | json_config = json.loads(reader.read()) 39 | for key, value in json_config.items(): 40 | self.__dict__[key] = value 41 | elif isinstance(vocab_size_or_config_json_file, int): 42 | self.vocab_size = vocab_size_or_config_json_file 43 | self.n_ctx = n_ctx 44 | self.n_positions = n_positions 45 | self.n_embd = n_embd 46 | self.n_layer = n_layer 47 | self.n_head = n_head 48 | self.layer_norm_epsilon = layer_norm_epsilon 49 | self.initializer_range = initializer_range 50 | else: 51 | raise ValueError( 52 | "First argument must be either a vocabulary size (int)" 53 | "or the path to a pretrained model config file (str)" 54 | ) 55 | 56 | @classmethod 57 | def from_dict(cls, json_object): 58 | """Constructs a `GPT2Config` from a Python dictionary of parameters.""" 59 | config = GPT2Config(vocab_size_or_config_json_file=-1) 60 | for key, value in json_object.items(): 61 | config.__dict__[key] = value 62 | return config 63 | 64 | @classmethod 65 | def from_json_file(cls, json_file): 66 | """Constructs a `GPT2Config` from a json file of parameters.""" 67 | with open(json_file, "r", encoding="utf-8") as reader: 68 | text = reader.read() 69 | return cls.from_dict(json.loads(text)) 70 | 71 | def __repr__(self): 72 | return str(self.to_json_string()) 73 | 74 | def to_dict(self): 75 | """Serializes this instance to a Python dictionary.""" 76 | output = copy.deepcopy(self.__dict__) 77 | return output 78 | 79 | def to_json_string(self): 80 | """Serializes this instance to a JSON string.""" 81 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 82 | 83 | def to_json_file(self, json_file_path): 84 | """ Save this instance to a json file.""" 85 | with 
open(json_file_path, "w", encoding='utf-8') as writer: 86 | writer.write(self.to_json_string()) 87 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/gpt2lmhead.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class GPT2LMHead(nn.Module): 5 | """ Language Model Head for the transformer """ 6 | 7 | def __init__(self, model_embeddings_weights, config): 8 | super(GPT2LMHead, self).__init__() 9 | self.n_embd = config.n_embd 10 | self.set_embeddings_weights(model_embeddings_weights) 11 | 12 | def set_embeddings_weights(self, model_embeddings_weights): 13 | embed_shape = model_embeddings_weights.shape 14 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 15 | self.decoder.weight = model_embeddings_weights # Tied weights 16 | 17 | def forward(self, hidden_state): 18 | # Truncated Language modeling logits (we remove the last token) 19 | # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) 20 | lm_logits = self.decoder(hidden_state) 21 | return lm_logits 22 | 23 | 24 | class GPT2MultipleChoiceHead(nn.Module): 25 | """ Classifier Head for the transformer """ 26 | 27 | def __init__(self, config): 28 | super(GPT2MultipleChoiceHead, self).__init__() 29 | self.n_embd = config.n_embd 30 | self.linear = nn.Linear(config.n_embd, 1) 31 | 32 | nn.init.normal_(self.linear.weight, std=0.02) 33 | nn.init.normal_(self.linear.bias, 0) 34 | 35 | def forward(self, hidden_states, mc_token_ids): 36 | # Classification logits 37 | # hidden_state (bsz, num_choices, seq_length, hidden_size) 38 | # mc_token_ids (bsz, num_choices) 39 | mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) 40 | # (bsz, num_choices, 1, hidden_size) 41 | multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) 42 | # (bsz, num_choices, hidden_size) 43 | multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) 44 | # (bsz, num_choices) 45 | return multiple_choice_logits 46 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/gpt2model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import CrossEntropyLoss 6 | 7 | from gpt2_question_answering.gpt2.layer_norm import LayerNorm 8 | from gpt2_question_answering.gpt2.modules import * 9 | from gpt2_question_answering.gpt2.gpt2lmhead import GPT2LMHead 10 | from gpt2_question_answering.gpt2.gpt2pretrained import GPT2PreTrainedModel 11 | 12 | 13 | class GPT2Model(GPT2PreTrainedModel): 14 | """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners"). 15 | 16 | Params: 17 | config: a GPT2Config class instance with the configuration to build a new model 18 | 19 | Inputs: 20 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] 21 | were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ 22 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids 23 | with the position indices (selected in the range [0, config.n_positions - 1[. 24 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids 25 | You can use it to add a third type of embedding to each input token in the sequence 26 | (the previous two being the word and position embeddings). 
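The configuration and LM-head classes above are small enough to exercise directly. A hedged sketch, assuming the `gpt2sqa.gpt2` import paths from the project tree:

```python
import torch.nn as nn

from gpt2sqa.gpt2.gpt2config import GPT2Config
from gpt2sqa.gpt2.gpt2lmhead import GPT2LMHead

config = GPT2Config()                                # GPT-2 small defaults: 50257 vocab, 12 layers, 768 dims
round_trip = GPT2Config.from_dict(config.to_dict())  # survives a dict/JSON round trip
assert round_trip.n_embd == 768 and round_trip.n_head == 12

# The LM head ties its decoder weights to whatever embedding matrix it is handed.
wte = nn.Embedding(config.vocab_size, config.n_embd)
lm_head = GPT2LMHead(wte.weight, config)
assert lm_head.decoder.weight is wte.weight
```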
27 | The input, position and token_type embeddings are summed inside the Transformer before the first 28 | self-attention block. 29 | `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states 30 | (key and values in the attention blocks) to speed up sequential decoding 31 | (this is the presents output of the model, cf. below). 32 | 33 | Outputs a tuple consisting of: 34 | `hidden_states`: the encoded-hidden-states at the top of the model 35 | as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] 36 | (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids) 37 | `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as 38 | torch.FloatTensors. They can be reused to speed up sequential decoding. 39 | 40 | Example usage: 41 | ```python 42 | # Already been converted into BPE token ids 43 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 44 | 45 | config = modeling_gpt2.GPT2Config() 46 | 47 | model = modeling_gpt2.GPT2Model(config) 48 | hidden_states, presents = model(input_ids) 49 | ``` 50 | """ 51 | 52 | def __init__(self, config): 53 | super(GPT2Model, self).__init__(config) 54 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 55 | self.wpe = nn.Embedding(config.n_positions, config.n_embd) 56 | block = Block(config.n_ctx, config, scale=True) 57 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 58 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 59 | 60 | self.apply(self.init_weights) 61 | 62 | def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None): 63 | if past is None: 64 | past_length = 0 65 | past = [None] * len(self.h) 66 | else: 67 | past_length = past[0][0].size(-2) 68 | if position_ids is None: 69 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) 70 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 71 | 72 | input_shape = input_ids.size() 73 | input_ids = input_ids.view(-1, input_ids.size(-1)) 74 | position_ids = position_ids.view(-1, position_ids.size(-1)) 75 | 76 | inputs_embeds = self.wte(input_ids) 77 | position_embeds = self.wpe(position_ids) 78 | if token_type_ids is not None: 79 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 80 | token_type_embeds = self.wte(token_type_ids) 81 | else: 82 | token_type_embeds = 0 83 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 84 | presents = [] 85 | for block, layer_past in zip(self.h, past): 86 | hidden_states, present = block(hidden_states, layer_past) 87 | presents.append(present) 88 | hidden_states = self.ln_f(hidden_states) 89 | output_shape = input_shape + (hidden_states.size(-1),) 90 | return hidden_states.view(*output_shape), presents 91 | 92 | 93 | class GPT2LMHeadModel(GPT2PreTrainedModel): 94 | """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners"). 95 | 96 | Params: 97 | config: a GPT2Config class instance with the configuration to build a new model 98 | 99 | Inputs: 100 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] 101 | were d_1 ... 
d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ 102 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids 103 | with the position indices (selected in the range [0, config.n_positions - 1[. 104 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids 105 | You can use it to add a third type of embedding to each input token in the sequence 106 | (the previous two being the word and position embeddings). 107 | The input, position and token_type embeddings are summed inside the Transformer before the first 108 | self-attention block. 109 | `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 110 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 111 | is only computed for the labels set in [0, ..., vocab_size] 112 | `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states 113 | (key and values in the attention blocks) to speed up sequential decoding 114 | (this is the presents output of the model, cf. below). 115 | 116 | Outputs: 117 | if `lm_labels` is not `None`: 118 | Outputs the language modeling loss. 119 | else a tuple: 120 | `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size] 121 | (or more generally [d_1, ..., d_n, config.vocab_size] were d_1 ... d_n are the dimension of input_ids) 122 | `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as 123 | torch.FloatTensors. They can be reused to speed up sequential decoding. 124 | 125 | Example usage: 126 | ```python 127 | # Already been converted into BPE token ids 128 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 129 | 130 | config = modeling_gpt2.GPT2Config() 131 | 132 | model = modeling_gpt2.GPT2LMHeadModel(config) 133 | lm_logits, presents = model(input_ids) 134 | ``` 135 | """ 136 | 137 | def __init__(self, config): 138 | super(GPT2LMHeadModel, self).__init__(config) 139 | self.transformer = GPT2Model(config) 140 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) 141 | self.apply(self.init_weights) 142 | 143 | def set_tied(self): 144 | """ Make sure we are sharing the embeddings 145 | """ 146 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 147 | 148 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): 149 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 150 | lm_logits = self.lm_head(hidden_states) 151 | if lm_labels is not None: 152 | # Shift so that tokens < n predict n 153 | shift_logits = lm_logits[:, :-1].contiguous() 154 | shift_labels = lm_labels[:, 1:].contiguous() 155 | 156 | # Flatten the tokens 157 | loss_fct = CrossEntropyLoss(ignore_index=-1) 158 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 159 | shift_labels.view(-1)) 160 | return loss 161 | return lm_logits, presents 162 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/gpt2pretrained.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import os 5 | from torch import nn 6 | 7 | from gpt2_question_answering.gpt2.gpt2config import GPT2Config 8 | from gpt2_question_answering.file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME 9 | from 
gpt2_question_answering.gpt2.layer_norm import LayerNorm 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"} 16 | PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"} 17 | 18 | 19 | class GPT2PreTrainedModel(nn.Module): 20 | """ An abstract class to handle weights initialization and 21 | a simple interface for dowloading and loading pretrained models. 22 | """ 23 | 24 | def __init__(self, config, *inputs, **kwargs): 25 | super(GPT2PreTrainedModel, self).__init__() 26 | if not isinstance(config, GPT2Config): 27 | raise ValueError( 28 | "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " 29 | "To create a model from a pretrained model use " 30 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 31 | self.__class__.__name__, self.__class__.__name__ 32 | ) 33 | ) 34 | self.config = config 35 | 36 | def set_tied(self): 37 | pass 38 | 39 | def init_weights(self, module): 40 | """ Initialize the weights. 41 | """ 42 | if isinstance(module, (nn.Linear, nn.Embedding)): 43 | # Slightly different from the TF version which uses truncated_normal for initialization 44 | # cf https://github.com/pytorch/pytorch/pull/5617 45 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 46 | elif isinstance(module, LayerNorm): 47 | module.bias.data.zero_() 48 | module.weight.data.fill_(1.0) 49 | if isinstance(module, nn.Linear) and module.bias is not None: 50 | module.bias.data.zero_() 51 | 52 | @classmethod 53 | def from_pretrained( 54 | cls, pretrained_model_name_or_path='gpt2', state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs 55 | ): 56 | """ 57 | Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict. 58 | Download and cache the pre-trained model file if needed. 59 | 60 | Params: 61 | pretrained_model_name_or_path: either: 62 | - a str with the name of a pre-trained model to load selected in the list of: 63 | . `gpt2` 64 | - a path or url to a pretrained model archive containing: 65 | . `gpt2_config.json` a configuration file for the model 66 | . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance 67 | - a path or url to a pretrained model archive containing: 68 | . `gpt2_config.json` a configuration file for the model 69 | . a TensorFlow checkpoint with trained weights 70 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 71 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 
72 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 73 | *inputs, **kwargs: additional input for the specific GPT class 74 | """ 75 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 76 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] 77 | config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] 78 | else: 79 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 80 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 81 | # redirect to the cache, if necessary 82 | print(archive_file) 83 | try: 84 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 85 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir) 86 | except EnvironmentError: 87 | logger.error( 88 | "Model name '{}' was not found in model name list ({}). " 89 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 90 | "at this path or url.".format( 91 | pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, 92 | archive_file, config_file 93 | ) 94 | ) 95 | return None 96 | if resolved_archive_file == archive_file and resolved_config_file == config_file: 97 | logger.info("loading weights file {}".format(archive_file)) 98 | logger.info("loading configuration file {}".format(config_file)) 99 | else: 100 | logger.info("loading weights file {} from cache at {}".format( 101 | archive_file, resolved_archive_file)) 102 | logger.info("loading configuration file {} from cache at {}".format( 103 | config_file, resolved_config_file)) 104 | # Load config 105 | config = GPT2Config.from_json_file(resolved_config_file) 106 | logger.info("Model config {}".format(config)) 107 | # Instantiate model. 
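        # What follows below: instantiate the model, then load the checkpoint into it.
        #   1. torch.load the resolved archive on CPU, unless a state_dict was passed in.
        #   2. Rename TF-style parameter keys (".g" -> ".weight", ".b" -> ".bias", ".w" -> ".weight")
        #      so they match the PyTorch module attribute names.
        #   3. Recursively run _load_from_state_dict, starting at model.transformer when the
        #      checkpoint keys carry no "transformer." prefix, and log missing/unexpected keys.
        #   4. Re-tie the output embeddings to the input embeddings via model.set_tied().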
108 | model = cls(config, *inputs, **kwargs) 109 | if state_dict is None and not from_tf: 110 | state_dict = torch.load(resolved_archive_file, map_location='cpu') 111 | if from_tf: 112 | # Directly load from a TensorFlow checkpoint (stored as NumPy array) 113 | return load_tf_weights_in_gpt2(model, resolved_archive_file) 114 | old_keys = [] 115 | new_keys = [] 116 | for key in state_dict.keys(): 117 | new_key = None 118 | if key.endswith(".g"): 119 | new_key = key[:-2] + ".weight" 120 | elif key.endswith(".b"): 121 | new_key = key[:-2] + ".bias" 122 | elif key.endswith(".w"): 123 | new_key = key[:-2] + ".weight" 124 | if new_key: 125 | old_keys.append(key) 126 | new_keys.append(new_key) 127 | for old_key, new_key in zip(old_keys, new_keys): 128 | state_dict[new_key] = state_dict.pop(old_key) 129 | 130 | missing_keys = [] 131 | unexpected_keys = [] 132 | error_msgs = [] 133 | # copy state_dict so _load_from_state_dict can modify it 134 | metadata = getattr(state_dict, "_metadata", None) 135 | state_dict = state_dict.copy() 136 | if metadata is not None: 137 | state_dict._metadata = metadata 138 | 139 | def load(module, prefix=""): 140 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 141 | module._load_from_state_dict( 142 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs 143 | ) 144 | for name, child in module._modules.items(): 145 | if child is not None: 146 | load(child, prefix + name + ".") 147 | 148 | start_model = model 149 | if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()): 150 | start_model = model.transformer 151 | load(start_model, prefix="") 152 | 153 | if len(missing_keys) > 0: 154 | logger.info( 155 | "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys) 156 | ) 157 | if len(unexpected_keys) > 0: 158 | logger.info( 159 | "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys) 160 | ) 161 | if len(error_msgs) > 0: 162 | raise RuntimeError( 163 | "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) 164 | ) 165 | 166 | # Make sure we are still sharing the output and input embeddings after loading weights 167 | model.set_tied() 168 | return model 169 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/gptdoubleheads.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | 6 | from gpt2sqa.gpt2.gpt2lmhead import GPT2LMHead, GPT2MultipleChoiceHead 7 | from gpt2sqa.gpt2.gpt2pretrained import GPT2PreTrainedModel 8 | from gpt2sqa.gpt2.gpt2model import GPT2Model 9 | 10 | 11 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel): 12 | """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners"). 
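With the loader above in place, obtaining weights is a one-liner. A hedged sketch follows; note that `gpt2model.py` and `gpt2pretrained.py` import from a `gpt2_question_answering` package while the rest of the repo and the project tree use `gpt2sqa`, so the two import roots need to be reconciled before this runs as written:

```python
import torch

from gpt2sqa.gpt2.gpt2model import GPT2LMHeadModel

# Resolves gpt2-pytorch_model.bin and gpt2-config.json through cached_path on first use.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = torch.LongTensor([[31, 51, 99]])     # already BPE-encoded, as in the docstrings above
with torch.no_grad():
    lm_logits, presents = model(input_ids)       # [1, 3, 50257] logits plus per-layer key/value caches
```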
13 | 14 | Params: 15 | config: a GPT2Config class instance with the configuration to build a new model 16 | 17 | Inputs: 18 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token 19 | indices selected in the range [0, config.vocab_size[ 20 | `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from 21 | which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence) 22 | `position_ids`: an optional torch.LongTensor with the same shape as input_ids 23 | with the position indices (selected in the range [0, config.n_positions - 1[. 24 | `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids 25 | You can use it to add a third type of embedding to each input token in the sequence 26 | (the previous two being the word and position embeddings). 27 | The input, position and token_type embeddings are summed inside the Transformer before the first 28 | self-attention block. 29 | `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length] 30 | with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss 31 | is only computed for the labels set in [0, ..., config.vocab_size] 32 | `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] 33 | with indices selected in [0, ..., num_choices]. 34 | `past`: an optional list of torch.LongTensor that contains pre-computed hidden-states 35 | (key and values in the attention blocks) to speed up sequential decoding 36 | (this is the presents output of the model, cf. below). 37 | 38 | Outputs: 39 | if `lm_labels` and `multiple_choice_labels` are not `None`: 40 | Outputs a tuple of losses with the language modeling loss and the multiple choice loss. 41 | else: a tuple with 42 | `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size] 43 | `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices] 44 | `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as 45 | torch.FloatTensors. They can be reused to speed up sequential decoding. 
46 | 47 | Example usage: 48 | ```python 49 | # Already been converted into BPE token ids 50 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length) 51 | mc_token_ids = torch.LongTensor([[2], [1]]) # (bsz, number of choice) 52 | 53 | config = modeling_gpt2.GPT2Config() 54 | 55 | model = modeling_gpt2.GPT2LMHeadModel(config) 56 | lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids) 57 | ``` 58 | """ 59 | 60 | def __init__(self, config): 61 | super(GPT2DoubleHeadsModel, self).__init__(config) 62 | self.transformer = GPT2Model(config) 63 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) 64 | self.multiple_choice_head = GPT2MultipleChoiceHead(config) 65 | self.apply(self.init_weights) 66 | 67 | def set_tied(self): 68 | """ Make sure we are sharing the embeddings 69 | """ 70 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 71 | 72 | def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None): 73 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 74 | lm_logits = self.lm_head(hidden_states) 75 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) 76 | losses = [] 77 | if lm_labels is not None: 78 | shift_logits = lm_logits[:, :-1].contiguous() 79 | shift_labels = lm_labels[:, 1:].contiguous() 80 | loss_fct = CrossEntropyLoss(ignore_index=-1) 81 | losses.append(loss_fct(shift_logits.view(-1, 82 | shift_logits.size(-1)), shift_labels.view(-1))) 83 | if mc_labels is not None: 84 | loss_fct = CrossEntropyLoss() 85 | losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) 86 | if losses: 87 | return losses 88 | return lm_logits, mc_logits, presents 89 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LayerNorm(nn.Module): 6 | def __init__(self, hidden_size, eps=1e-12): 7 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
8 | """ 9 | super(LayerNorm, self).__init__() 10 | self.weight = nn.Parameter(torch.ones(hidden_size)) 11 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 12 | self.variance_epsilon = eps 13 | 14 | def forward(self, x): 15 | u = x.mean(-1, keepdim=True) 16 | s = (x - u).pow(2).mean(-1, keepdim=True) 17 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 18 | return self.weight * x + self.bias 19 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.parameter import Parameter 6 | 7 | from gpt2sqa.gpt2.layer_norm import LayerNorm 8 | from gpt2sqa.gpt2.utils import gelu 9 | 10 | 11 | class Conv1D(nn.Module): 12 | def __init__(self, nf, nx): 13 | super(Conv1D, self).__init__() 14 | self.nf = nf 15 | w = torch.empty(nx, nf) 16 | nn.init.normal_(w, std=0.02) 17 | self.weight = Parameter(w) 18 | self.bias = Parameter(torch.zeros(nf)) 19 | 20 | def forward(self, x): 21 | size_out = x.size()[:-1] + (self.nf,) 22 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 23 | x = x.view(*size_out) 24 | return x 25 | 26 | 27 | class Attention(nn.Module): 28 | def __init__(self, nx, n_ctx, config, scale=False): 29 | super(Attention, self).__init__() 30 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 31 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 32 | assert n_state % config.n_head == 0 33 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 34 | self.n_head = config.n_head 35 | self.split_size = n_state 36 | self.scale = scale 37 | self.c_attn = Conv1D(n_state * 3, nx) 38 | self.c_proj = Conv1D(n_state, nx) 39 | 40 | def _attn(self, q, k, v): 41 | w = torch.matmul(q, k) 42 | if self.scale: 43 | w = w / math.sqrt(v.size(-1)) 44 | nd, ns = w.size(-2), w.size(-1) 45 | b = self.bias[:, :, ns - nd:ns, :ns] 46 | w = w * b - 1e4 * (1 - b) 47 | 48 | w = nn.Softmax(dim=-1)(w) 49 | return torch.matmul(w, v) 50 | 51 | def merge_heads(self, x): 52 | x = x.permute(0, 2, 1, 3).contiguous() 53 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 54 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 55 | 56 | def split_heads(self, x, k=False): 57 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 58 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 59 | if k: 60 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 61 | else: 62 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 63 | 64 | def forward(self, x, layer_past=None): 65 | x = self.c_attn(x) 66 | query, key, value = x.split(self.split_size, dim=2) 67 | query = self.split_heads(query) 68 | key = self.split_heads(key, k=True) 69 | value = self.split_heads(value) 70 | if layer_past is not None: 71 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 72 | key = torch.cat((past_key, key), dim=-1) 73 | value = torch.cat((past_value, value), dim=-2) 74 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 75 | a = self._attn(query, key, value) 76 | a = self.merge_heads(a) 77 | a = self.c_proj(a) 78 | return a, present 79 | 80 | 81 | class MLP(nn.Module): 82 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 83 | super(MLP, 
self).__init__() 84 | nx = config.n_embd 85 | self.c_fc = Conv1D(n_state, nx) 86 | self.c_proj = Conv1D(nx, n_state) 87 | self.act = gelu 88 | 89 | def forward(self, x): 90 | h = self.act(self.c_fc(x)) 91 | h2 = self.c_proj(h) 92 | return h2 93 | 94 | 95 | class Block(nn.Module): 96 | def __init__(self, n_ctx, config, scale=False): 97 | super(Block, self).__init__() 98 | nx = config.n_embd 99 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) 100 | self.attn = Attention(nx, n_ctx, config, scale) 101 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 102 | self.mlp = MLP(4 * nx, config) 103 | 104 | def forward(self, x, layer_past=None): 105 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 106 | x = x + a 107 | m = self.mlp(self.ln_2(x)) 108 | x = x + m 109 | return x, present 110 | -------------------------------------------------------------------------------- /gpt2sqa/gpt2/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | 5 | 6 | def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): 7 | """ Load tf checkpoints in a pytorch model 8 | """ 9 | try: 10 | import re 11 | import numpy as np 12 | import tensorflow as tf 13 | except ImportError: 14 | print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " 15 | "https://www.tensorflow.org/install/ for installation instructions.") 16 | raise 17 | tf_path = os.path.abspath(gpt2_checkpoint_path) 18 | print("Converting TensorFlow checkpoint from {}".format(tf_path)) 19 | # Load weights from TF model 20 | init_vars = tf.train.list_variables(tf_path) 21 | names = [] 22 | arrays = [] 23 | for name, shape in init_vars: 24 | print("Loading TF weight {} with shape {}".format(name, shape)) 25 | array = tf.train.load_variable(tf_path, name) 26 | names.append(name) 27 | arrays.append(array.squeeze()) 28 | 29 | for name, array in zip(names, arrays): 30 | name = name[6:] # skip "model/" 31 | name = name.split('/') 32 | pointer = model 33 | for m_name in name: 34 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 35 | l = re.split(r'(\d+)', m_name) 36 | else: 37 | l = [m_name] 38 | if l[0] == 'w' or l[0] == 'g': 39 | pointer = getattr(pointer, 'weight') 40 | elif l[0] == 'b': 41 | pointer = getattr(pointer, 'bias') 42 | elif l[0] == 'wpe' or l[0] == 'wte': 43 | pointer = getattr(pointer, l[0]) 44 | pointer = getattr(pointer, 'weight') 45 | else: 46 | pointer = getattr(pointer, l[0]) 47 | if len(l) >= 2: 48 | num = int(l[1]) 49 | pointer = pointer[num] 50 | try: 51 | assert pointer.shape == array.shape 52 | except AssertionError as e: 53 | e.args += (pointer.shape, array.shape) 54 | raise 55 | print("Initialize PyTorch weight {}".format(name)) 56 | pointer.data = torch.from_numpy(array) 57 | return model 58 | 59 | 60 | def gelu(x): 61 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 62 | -------------------------------------------------------------------------------- /gpt2sqa/modeling_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """PyTorch OpenAI GPT-2 small model.""" 18 | 19 | from __future__ import absolute_import, division, print_function, unicode_literals 20 | 21 | import logging 22 | 23 | import torch 24 | import torch.nn as nn 25 | from torch.nn import CrossEntropyLoss 26 | from torch.nn.parameter import Parameter 27 | 28 | from gpt2sqa.gpt2.gpt2pretrained import GPT2PreTrainedModel 29 | from gpt2sqa.gpt2.gpt2model import GPT2Model 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class GPT2ModelForQuestionAnswering(GPT2PreTrainedModel): 35 | """A linear layer on top of pre-trained GPT-2 output that computes start_logits and end_logits 36 | 37 | Params: 38 | `config`: a BertConfig class instance with the configuration to build a new model. 39 | 40 | Inputs: 41 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 42 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 43 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 44 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 45 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 46 | a `sentence B` token (see BERT paper for more details). 47 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 48 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 49 | input sequence length in the current batch. It's the mask that we typically use for attention when 50 | a batch has varying length sentences. 51 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 52 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 53 | into account for computing the loss. 54 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 55 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 56 | into account for computing the loss. 57 | 58 | Outputs: 59 | if `start_positions` and `end_positions` are not `None`: 60 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 61 | if `start_positions` or `end_positions` is `None`: 62 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 63 | position tokens of shape [batch_size, sequence_length]. 
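A hedged post-processing sketch (not part of this repo) showing one standard way to turn those per-token start/end logits into a predicted answer span; it assumes only the `[batch_size, sequence_length]` shapes described above:

```python
import torch


def best_span(start_logits, end_logits, max_answer_length=30):
    """Pick the highest-scoring (start, end) pair with end >= start and a bounded span length."""
    scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)   # [bsz, seq, seq] pairwise sums
    seq_len = scores.size(-1)
    valid = torch.triu(torch.ones(seq_len, seq_len)) > 0           # keep end >= start
    valid &= torch.triu(torch.ones(seq_len, seq_len), diagonal=max_answer_length) == 0  # bound span length
    scores = scores.masked_fill(~valid, float("-inf"))
    flat = scores.view(scores.size(0), -1).argmax(dim=-1)
    return flat // seq_len, flat % seq_len                         # start and end token indices
```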
64 | """ 65 | 66 | def __init__(self, config): 67 | super(GPT2ModelForQuestionAnswering, self).__init__(config) 68 | self.gpt2 = GPT2Model(config) 69 | print(config) 70 | self.qa_outputs = nn.Linear(config.n_embd, 2) 71 | self.apply(self.init_weights) 72 | 73 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 74 | sequence_output, _ = self.gpt2(input_ids, None, token_type_ids) 75 | logits = self.qa_outputs(sequence_output) 76 | start_logits, end_logits = logits.split(1, dim=2) 77 | start_logits = start_logits.squeeze(-1) 78 | end_logits = end_logits.squeeze(-1) 79 | 80 | if start_positions is not None and end_positions is not None: 81 | if len(start_positions.size()) > 1: 82 | start_positions = start_positions.squeeze(-1) 83 | if len(end_positions.size()) > 1: 84 | end_positions = end_positions.squeeze(-1) 85 | ignored_index = start_logits.size(1) 86 | start_positions.clamp_(0, ignored_index) 87 | end_positions.clamp_(0, ignored_index) 88 | 89 | loss_function = CrossEntropyLoss(ignore_index=ignored_index) 90 | start_loss = loss_function(start_logits, start_positions) 91 | end_loss = loss_function(end_logits, end_positions) 92 | total_loss = (start_loss + end_loss) / 2 93 | return total_loss 94 | else: 95 | return start_logits, end_logits 96 | -------------------------------------------------------------------------------- /gpt2sqa/optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """PyTorch optimization for GPT2 model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | 39 | def __init__(self, warmup=0.002, t_total=-1, **kw): 40 | """ 41 | :param warmup: what fraction of t_total steps will be used for linear warmup 42 | :param t_total: how many training steps (updates) are planned 43 | :param kw: 44 | """ 45 | super(_LRSchedule, self).__init__(**kw) 46 | if t_total < 0: 47 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 48 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 49 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 50 | warmup = max(warmup, 0.) 
51 | self.warmup, self.t_total = float(warmup), float(t_total) 52 | self.warned_for_t_total_at_progress = -1 53 | 54 | def get_lr(self, step, nowarn=False): 55 | """ 56 | :param step: which of t_total steps we're on 57 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 58 | :return: learning rate multiplier for current update 59 | """ 60 | if self.t_total < 0: 61 | return 1. 62 | progress = float(step) / self.t_total 63 | ret = self.get_lr_(progress) 64 | # warning for exceeding t_total (only active with warmup_linear 65 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 66 | logger.warning( 67 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 68 | .format(ret, self.__class__.__name__)) 69 | self.warned_for_t_total_at_progress = progress 70 | # end warning 71 | return ret 72 | 73 | @abc.abstractmethod 74 | def get_lr_(self, progress): 75 | """ 76 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 77 | :return: learning rate multiplier for current update 78 | """ 79 | return 1. 80 | 81 | 82 | class ConstantLR(_LRSchedule): 83 | def get_lr_(self, progress): 84 | return 1. 85 | 86 | 87 | class WarmupCosineSchedule(_LRSchedule): 88 | """ 89 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 90 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 91 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 92 | """ 93 | warn_t_total = True 94 | 95 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 96 | """ 97 | :param warmup: see LRSchedule 98 | :param t_total: see LRSchedule 99 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 100 | :param kw: 101 | """ 102 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 103 | self.cycles = cycles 104 | 105 | def get_lr_(self, progress): 106 | if progress < self.warmup: 107 | return progress / self.warmup 108 | else: 109 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 110 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 111 | 112 | 113 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 114 | """ 115 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 116 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 117 | learning rate (with hard restarts). 118 | """ 119 | 120 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 121 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 122 | assert(cycles >= 1.) 123 | 124 | def get_lr_(self, progress): 125 | if progress < self.warmup: 126 | return progress / self.warmup 127 | else: 128 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 129 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 130 | return ret 131 | 132 | 133 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 134 | """ 135 | All training progress is divided in `cycles` (default=1.) parts of equal length. 
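The schedules defined above are plain functions of training progress, so their behaviour is easy to spot-check. A hedged sketch for `WarmupCosineSchedule`; the import path assumes `gpt2sqa/optimization.py` from the project tree, and the numbers follow directly from `get_lr` and `get_lr_`:

```python
from gpt2sqa.optimization import WarmupCosineSchedule

sched = WarmupCosineSchedule(warmup=0.1, t_total=1000)

sched.get_lr(50)      # 0.5 -> halfway through the linear warmup (progress 0.05 of warmup 0.1)
sched.get_lr(100)     # 1.0 -> warmup just finished, cosine decay starts from 1.0
sched.get_lr(1000)    # 0.0 -> with the default cycles=0.5 the multiplier reaches zero at t_total
```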
136 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 137 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 138 | """ 139 | 140 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 141 | assert(warmup * cycles < 1.) 142 | warmup = warmup * cycles if warmup >= 0 else warmup 143 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 144 | 145 | def get_lr_(self, progress): 146 | progress = progress * self.cycles % 1. 147 | if progress < self.warmup: 148 | return progress / self.warmup 149 | else: 150 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 151 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 152 | return ret 153 | 154 | 155 | class WarmupConstantSchedule(_LRSchedule): 156 | """ 157 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 158 | Keeps learning rate equal to 1. after warmup. 159 | """ 160 | 161 | def get_lr_(self, progress): 162 | if progress < self.warmup: 163 | return progress / self.warmup 164 | return 1. 165 | 166 | 167 | class WarmupLinearSchedule(_LRSchedule): 168 | """ 169 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 170 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 171 | """ 172 | warn_t_total = True 173 | 174 | def get_lr_(self, progress): 175 | if progress < self.warmup: 176 | return progress / self.warmup 177 | return max((progress - 1.) / (self.warmup - 1.), 0.) 178 | 179 | 180 | SCHEDULES = { 181 | None: ConstantLR, 182 | "none": ConstantLR, 183 | "warmup_cosine": WarmupCosineSchedule, 184 | "warmup_constant": WarmupConstantSchedule, 185 | "warmup_linear": WarmupLinearSchedule 186 | } 187 | 188 | 189 | class GPT2Adam(Optimizer): 190 | """Implements BERT version of Adam algorithm with weight decay fix. 191 | Params: 192 | lr: learning rate 193 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 194 | t_total: total number of training steps for the learning 195 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 196 | schedule: schedule to use for the warmup (see above). 197 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 198 | If `None` or `'none'`, learning rate is always kept constant. 199 | Default : `'warmup_linear'` 200 | b1: Adams b1. Default: 0.9 201 | b2: Adams b2. Default: 0.999 202 | e: Adams epsilon. Default: 1e-6 203 | weight_decay: Weight decay. Default: 0.01 204 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 205 | """ 206 | 207 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 208 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 209 | if lr is not required and lr < 0.0: 210 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 211 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 212 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 213 | if not 0.0 <= b1 < 1.0: 214 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 215 | if not 0.0 <= b2 < 1.0: 216 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 217 | if not e >= 0.0: 218 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 219 | # initialize schedule object 220 | if not isinstance(schedule, _LRSchedule): 221 | schedule_type = SCHEDULES[schedule] 222 | schedule = schedule_type(warmup=warmup, t_total=t_total) 223 | else: 224 | if warmup != -1 or t_total != -1: 225 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 226 | "Please specify custom warmup and t_total in _LRSchedule object.") 227 | defaults = dict(lr=lr, schedule=schedule, 228 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 229 | max_grad_norm=max_grad_norm) 230 | super(GPT2Adam, self).__init__(params, defaults) 231 | 232 | def get_lr(self): 233 | lr = [] 234 | for group in self.param_groups: 235 | for p in group['params']: 236 | state = self.state[p] 237 | if len(state) == 0: 238 | return [0] 239 | lr_scheduled = group['lr'] 240 | lr_scheduled *= group['schedule'].get_lr(state['step']) 241 | lr.append(lr_scheduled) 242 | return lr 243 | 244 | def step(self, closure=None): 245 | """Performs a single optimization step. 246 | 247 | Arguments: 248 | closure (callable, optional): A closure that reevaluates the model 249 | and returns the loss. 250 | """ 251 | loss = None 252 | if closure is not None: 253 | loss = closure() 254 | 255 | for group in self.param_groups: 256 | for p in group['params']: 257 | if p.grad is None: 258 | continue 259 | grad = p.grad.data 260 | if grad.is_sparse: 261 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 262 | 263 | state = self.state[p] 264 | 265 | # State initialization 266 | if len(state) == 0: 267 | state['step'] = 0 268 | # Exponential moving average of gradient values 269 | state['next_m'] = torch.zeros_like(p.data) 270 | # Exponential moving average of squared gradient values 271 | state['next_v'] = torch.zeros_like(p.data) 272 | 273 | next_m, next_v = state['next_m'], state['next_v'] 274 | beta1, beta2 = group['b1'], group['b2'] 275 | 276 | # Add grad clipping 277 | if group['max_grad_norm'] > 0: 278 | clip_grad_norm_(p, group['max_grad_norm']) 279 | 280 | # Decay the first and second moment running average coefficient 281 | # In-place operations to update the averages at the same time 282 | next_m.mul_(beta1).add_(1 - beta1, grad) 283 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 284 | update = next_m / (next_v.sqrt() + group['e']) 285 | 286 | # Just adding the square of the weights to the loss function is *not* 287 | # the correct way of using L2 regularization/weight decay with Adam, 288 | # since that will interact with the m and v parameters in strange ways. 289 | # 290 | # Instead we want to decay the weights in a manner that doesn't interact 291 | # with the m/v parameters. 
This is equivalent to adding the square 292 | # of the weights to the loss with plain (non-momentum) SGD. 293 | if group['weight_decay'] > 0.0: 294 | update += group['weight_decay'] * p.data 295 | 296 | lr_scheduled = group['lr'] 297 | lr_scheduled *= group['schedule'].get_lr(state['step']) 298 | 299 | update_with_lr = lr_scheduled * update 300 | p.data.add_(-update_with_lr) 301 | 302 | state['step'] += 1 303 | 304 | return loss 305 | -------------------------------------------------------------------------------- /gpt2sqa/squad/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/squad/.DS_Store -------------------------------------------------------------------------------- /gpt2sqa/squad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftarlaci/GPT2sQA/41cd86ef5c2051ad3fda224ac912d97d07f73f61/gpt2sqa/squad/__init__.py -------------------------------------------------------------------------------- /gpt2sqa/squad/squad_example.py: -------------------------------------------------------------------------------- 1 | class SquadExample(object): 2 | """ 3 | A single training/test example for the Squad dataset. 4 | For examples without an answer, the start and end position are -1. 5 | """ 6 | 7 | def __init__(self, 8 | qas_id, 9 | question_text, 10 | doc_tokens, 11 | orig_answer_text=None, 12 | start_position=None, 13 | end_position=None, 14 | is_impossible=None): 15 | self.qas_id = qas_id 16 | self.question_text = question_text 17 | self.doc_tokens = doc_tokens 18 | self.orig_answer_text = orig_answer_text 19 | self.start_position = start_position 20 | self.end_position = end_position 21 | self.is_impossible = is_impossible 22 | 23 | def __str__(self): 24 | return self.__repr__() 25 | 26 | def __repr__(self): 27 | s = "" 28 | s += "qas_id: %s" % (self.qas_id) 29 | s += ", question_text: %s" % ( 30 | self.question_text) 31 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 32 | if self.start_position: 33 | s += ", start_position: %d" % (self.start_position) 34 | if self.end_position: 35 | s += ", end_position: %d" % (self.end_position) 36 | if self.is_impossible: 37 | s += ", is_impossible: %r" % (self.is_impossible) 38 | return s 39 | 40 | class InputFeatures(object): 41 | """A single set of features of data.""" 42 | 43 | def __init__(self, 44 | unique_id, 45 | example_index, 46 | doc_span_index, 47 | tokens, 48 | token_to_orig_map, 49 | token_is_max_context, 50 | input_ids, 51 | input_mask, 52 | segment_ids, 53 | start_position=None, 54 | end_position=None, 55 | is_impossible=None): 56 | self.unique_id = unique_id 57 | self.example_index = example_index 58 | self.doc_span_index = doc_span_index 59 | self.tokens = tokens 60 | self.token_to_orig_map = token_to_orig_map 61 | self.token_is_max_context = token_is_max_context 62 | self.input_ids = input_ids 63 | self.input_mask = input_mask 64 | self.segment_ids = segment_ids 65 | self.start_position = start_position 66 | self.end_position = end_position 67 | self.is_impossible = is_impossible 68 | 69 | 70 | -------------------------------------------------------------------------------- /gpt2sqa/squad/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import collections 4 | import math 5 | 6 | from gpt2sqa.squad.squad_example import SquadExample, 
InputFeatures 7 | from gpt2sqa.tokenization import (whitespace_tokenize, BasicTokenizer) 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def read_squad_examples(input_file, is_training, version_2_with_negative=True): 14 | """Read a SQuAD json file into a list of SquadExample.""" 15 | with open(input_file, "r", encoding='utf-8') as reader: 16 | input_data = json.load(reader)["data"] 17 | 18 | def is_whitespace(c): 19 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 20 | return True 21 | return False 22 | 23 | examples = [] 24 | for entry in input_data: 25 | for paragraph in entry["paragraphs"]: 26 | paragraph_text = paragraph["context"] 27 | doc_tokens = [] 28 | char_to_word_offset = [] 29 | prev_is_whitespace = True 30 | for c in paragraph_text: 31 | if is_whitespace(c): 32 | prev_is_whitespace = True 33 | else: 34 | if prev_is_whitespace: 35 | doc_tokens.append(c) 36 | else: 37 | doc_tokens[-1] += c 38 | prev_is_whitespace = False 39 | char_to_word_offset.append(len(doc_tokens) - 1) 40 | 41 | for qa in paragraph["qas"]: 42 | qas_id = qa["id"] 43 | question_text = qa["question"] 44 | start_position = None 45 | end_position = None 46 | orig_answer_text = None 47 | is_impossible = False 48 | if is_training: 49 | if version_2_with_negative: 50 | is_impossible = qa["is_impossible"] 51 | if (len(qa["answers"]) != 1) and (not is_impossible): 52 | raise ValueError( 53 | "For training, each question should have exactly 1 answer.") 54 | if not is_impossible: 55 | answer = qa["answers"][0] 56 | orig_answer_text = answer["text"] 57 | answer_offset = answer["answer_start"] 58 | answer_length = len(orig_answer_text) 59 | start_position = char_to_word_offset[answer_offset] 60 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 61 | # Only add answers where the text can be exactly recovered from the 62 | # document. If this CAN'T happen it's likely due to weird Unicode 63 | # stuff so we will just skip the example. 64 | # 65 | # Note that this means for training mode, every example is NOT 66 | # guaranteed to be preserved. 67 | actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) 68 | cleaned_answer_text = " ".join( 69 | whitespace_tokenize(orig_answer_text)) 70 | if actual_text.find(cleaned_answer_text) == -1: 71 | logger.warning("Could not find answer: '%s' vs. 
'%s'", 72 | actual_text, cleaned_answer_text) 73 | continue 74 | else: 75 | start_position = -1 76 | end_position = -1 77 | orig_answer_text = "" 78 | 79 | example = SquadExample( 80 | qas_id=qas_id, 81 | question_text=question_text, 82 | doc_tokens=doc_tokens, 83 | orig_answer_text=orig_answer_text, 84 | start_position=start_position, 85 | end_position=end_position, 86 | is_impossible=is_impossible) 87 | examples.append(example) 88 | return examples 89 | 90 | 91 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 92 | doc_stride, max_query_length, is_training): 93 | """Loads a data file into a list of `InputBatch`s.""" 94 | 95 | unique_id = 1000000000 96 | 97 | features = [] 98 | 99 | for (example_index, example) in enumerate(examples): 100 | query_tokens = tokenizer.tokenize(example.question_text) 101 | 102 | if len(query_tokens) > max_query_length: 103 | query_tokens = query_tokens[0:max_query_length] 104 | 105 | tok_to_orig_index = [] 106 | orig_to_tok_index = [] 107 | all_doc_tokens = [] 108 | for (i, token) in enumerate(example.doc_tokens): 109 | orig_to_tok_index.append(len(all_doc_tokens)) 110 | sub_tokens = tokenizer.tokenize(token) 111 | for sub_token in sub_tokens: 112 | tok_to_orig_index.append(i) 113 | all_doc_tokens.append(sub_token) 114 | 115 | tok_start_position = None 116 | tok_end_position = None 117 | if is_training and example.is_impossible: 118 | tok_start_position = -1 119 | tok_end_position = -1 120 | if is_training and not example.is_impossible: 121 | tok_start_position = orig_to_tok_index[example.start_position] 122 | if example.end_position < len(example.doc_tokens) - 1: 123 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 124 | else: 125 | tok_end_position = len(all_doc_tokens) - 1 126 | (tok_start_position, tok_end_position) = _improve_answer_span( 127 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 128 | example.orig_answer_text) 129 | 130 | # The -3 accounts for [CLS], [SEP] and [SEP] 131 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 132 | 133 | # We can have documents that are longer than the maximum sequence length. 134 | # To deal with this we do a sliding window approach, where we take chunks 135 | # of the up to our max length with a stride of `doc_stride`. 
136 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 137 | "DocSpan", ["start", "length"]) 138 | doc_spans = [] 139 | start_offset = 0 140 | while start_offset < len(all_doc_tokens): 141 | length = len(all_doc_tokens) - start_offset 142 | if length > max_tokens_for_doc: 143 | total_missed+=1 144 | length = max_tokens_for_doc 145 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 146 | if start_offset + length == len(all_doc_tokens): 147 | break 148 | start_offset += min(length, doc_stride) 149 | for (doc_span_index, doc_span) in enumerate(doc_spans): 150 | tokens = [] 151 | token_to_orig_map = {} 152 | token_is_max_context = {} 153 | segment_ids = [] 154 | tokens.append("[CLS]") 155 | segment_ids.append(0) 156 | for token in query_tokens: 157 | tokens.append(token) 158 | segment_ids.append(0) 159 | tokens.append("[SEP]") 160 | segment_ids.append(0) 161 | 162 | for i in range(doc_span.length): 163 | split_token_index = doc_span.start + i 164 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 165 | 166 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 167 | split_token_index) 168 | token_is_max_context[len(tokens)] = is_max_context 169 | tokens.append(all_doc_tokens[split_token_index]) 170 | segment_ids.append(1) 171 | tokens.append("[SEP]") 172 | segment_ids.append(1) 173 | 174 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 175 | 176 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 177 | # tokens are attended to. 178 | input_mask = [1] * len(input_ids) 179 | 180 | # Zero-pad up to the sequence length. 181 | while len(input_ids) < max_seq_length: 182 | input_ids.append(0) 183 | input_mask.append(0) 184 | segment_ids.append(0) 185 | 186 | assert len(input_ids) == max_seq_length 187 | assert len(input_mask) == max_seq_length 188 | assert len(segment_ids) == max_seq_length 189 | 190 | start_position = None 191 | end_position = None 192 | if is_training and not example.is_impossible: 193 | # For training, if our document chunk does not contain an annotation 194 | # we throw it out, since there is nothing to predict. 
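# (Note: in this implementation the chunk is not actually discarded; an
# out-of-span training chunk is kept with start_position = end_position = 0
# and is counted in total_missed below.)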
195 | doc_start = doc_span.start 196 | doc_end = doc_span.start + doc_span.length - 1 197 | out_of_span = False 198 | if not (tok_start_position >= doc_start and 199 | tok_end_position <= doc_end): 200 | out_of_span = True 201 | if out_of_span: 202 | total_missed+=1 203 | start_position = 0 204 | end_position = 0 205 | else: 206 | doc_offset = len(query_tokens) + 2 207 | start_position = tok_start_position - doc_start + doc_offset 208 | end_position = tok_end_position - doc_start + doc_offset 209 | if is_training and example.is_impossible: 210 | start_position = 0 211 | end_position = 0 212 | if example_index < 20: 213 | logger.info("*** Example ***") 214 | logger.info("unique_id: %s" % (unique_id)) 215 | logger.info("example_index: %s" % (example_index)) 216 | logger.info("doc_span_index: %s" % (doc_span_index)) 217 | logger.info("tokens: %s" % " ".join(tokens)) 218 | logger.info("token_to_orig_map: %s" % " ".join([ 219 | "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])) 220 | logger.info("token_is_max_context: %s" % " ".join([ 221 | "%d:%s" % (x, y) for (x, y) in token_is_max_context.items() 222 | ])) 223 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 224 | logger.info( 225 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 226 | logger.info( 227 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 228 | if is_training and example.is_impossible: 229 | logger.info("impossible example") 230 | if is_training and not example.is_impossible: 231 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 232 | logger.info("start_position: %d" % (start_position)) 233 | logger.info("end_position: %d" % (end_position)) 234 | logger.info( 235 | "answer: %s" % (answer_text)) 236 | 237 | features.append( 238 | InputFeatures( 239 | unique_id=unique_id, 240 | example_index=example_index, 241 | doc_span_index=doc_span_index, 242 | tokens=tokens, 243 | token_to_orig_map=token_to_orig_map, 244 | token_is_max_context=token_is_max_context, 245 | input_ids=input_ids, 246 | input_mask=input_mask, 247 | segment_ids=segment_ids, 248 | start_position=start_position, 249 | end_position=end_position, 250 | is_impossible=example.is_impossible)) 251 | unique_id += 1 252 | print(f'total_missed:{total_missed}') 253 | return features 254 | 255 | 256 | 257 | 258 | 259 | 260 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 261 | orig_answer_text): 262 | """Returns tokenized answer spans that better match the annotated answer.""" 263 | 264 | # The SQuAD annotations are character based. We first project them to 265 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 266 | # often find a "better match". For example: 267 | # 268 | # Question: What year was John Smith born? 269 | # Context: The leader was John Smith (1895-1943). 270 | # Answer: 1895 271 | # 272 | # The original whitespace-tokenized answer will be "(1895-1943).". However 273 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 274 | # the exact answer, 1895. 275 | # 276 | # However, this is not always possible. Consider the following: 277 | # 278 | # Question: What country is the top exporter of electornics? 279 | # Context: The Japanese electronics industry is the lagest in the world. 280 | # Answer: Japan 281 | # 282 | # In this case, the annotator chose "Japan" as a character sub-span of 283 | # the word "Japanese". 
Since our WordPiece tokenizer does not split 284 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 285 | # in SQuAD, but does happen. 286 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 287 | 288 | for new_start in range(input_start, input_end + 1): 289 | for new_end in range(input_end, new_start - 1, -1): 290 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 291 | if text_span == tok_answer_text: 292 | return (new_start, new_end) 293 | 294 | return (input_start, input_end) 295 | 296 | 297 | def _check_is_max_context(doc_spans, cur_span_index, position): 298 | """Check if this is the 'max context' doc span for the token.""" 299 | 300 | # Because of the sliding window approach taken to scoring documents, a single 301 | # token can appear in multiple documents. E.g. 302 | # Doc: the man went to the store and bought a gallon of milk 303 | # Span A: the man went to the 304 | # Span B: to the store and bought 305 | # Span C: and bought a gallon of 306 | # ... 307 | # 308 | # Now the word 'bought' will have two scores from spans B and C. We only 309 | # want to consider the score with "maximum context", which we define as 310 | # the *minimum* of its left and right context (the *sum* of left and 311 | # right context will always be the same, of course). 312 | # 313 | # In the example the maximum context for 'bought' would be span C since 314 | # it has 1 left context and 3 right context, while span B has 4 left context 315 | # and 0 right context. 316 | best_score = None 317 | best_span_index = None 318 | for (span_index, doc_span) in enumerate(doc_spans): 319 | end = doc_span.start + doc_span.length - 1 320 | if position < doc_span.start: 321 | continue 322 | if position > end: 323 | continue 324 | num_left_context = position - doc_span.start 325 | num_right_context = end - position 326 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 327 | if best_score is None or score > best_score: 328 | best_score = score 329 | best_span_index = span_index 330 | 331 | return cur_span_index == best_span_index 332 | 333 | 334 | RawResult = collections.namedtuple("RawResult", 335 | ["unique_id", "start_logits", "end_logits"]) 336 | 337 | 338 | def write_predictions(all_examples, all_features, all_results, n_best_size, 339 | max_answer_length, do_lower_case, output_prediction_file, 340 | output_nbest_file, output_null_log_odds_file, verbose_logging, 341 | version_2_with_negative, null_score_diff_threshold): 342 | """Write final predictions to the json file and log-odds of null if needed.""" 343 | logger.info("Writing predictions to: %s" % (output_prediction_file)) 344 | logger.info("Writing nbest to: %s" % (output_nbest_file)) 345 | 346 | example_index_to_features = collections.defaultdict(list) 347 | for feature in all_features: 348 | example_index_to_features[feature.example_index].append(feature) 349 | 350 | unique_id_to_result = {} 351 | for result in all_results: 352 | unique_id_to_result[result.unique_id] = result 353 | 354 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 355 | "PrelimPrediction", 356 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 357 | 358 | all_predictions = collections.OrderedDict() 359 | all_nbest_json = collections.OrderedDict() 360 | scores_diff_json = collections.OrderedDict() 361 | 362 | for (example_index, example) in enumerate(all_examples): 363 | features = example_index_to_features[example_index] 364 | 365 | prelim_predictions = [] 366 | 
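# Overview of the candidate-generation loop below: for each feature (document
# chunk) of this example we (1) look up the model's raw start/end logits,
# (2) keep the n_best_size highest-scoring start and end indexes,
# (3) track the best "no answer" score, i.e. the sum of the logits at
# position 0, and (4) record every valid (start, end) pair as a preliminary
# prediction, discarding spans that fall outside the document chunk (e.g. in
# the question), run backwards, or exceed max_answer_length.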
# keep track of the minimum score of null start+end of position 0 367 | score_null = 1000000 # large and positive 368 | min_null_feature_index = 0 # the paragraph slice with min null score 369 | null_start_logit = 0 # the start logit at the slice with min null score 370 | null_end_logit = 0 # the end logit at the slice with min null score 371 | for (feature_index, feature) in enumerate(features): 372 | result = unique_id_to_result[feature.unique_id] 373 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 374 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 375 | # if we could have irrelevant answers, get the min score of irrelevant 376 | feature_null_score = result.start_logits[0] + result.end_logits[0] 377 | if feature_null_score < score_null: 378 | score_null = feature_null_score 379 | min_null_feature_index = feature_index 380 | null_start_logit = result.start_logits[0] 381 | null_end_logit = result.end_logits[0] 382 | for start_index in start_indexes: 383 | for end_index in end_indexes: 384 | # We could hypothetically create invalid predictions, e.g., predict 385 | # that the start of the span is in the question. We throw out all 386 | # invalid predictions. 387 | if start_index >= len(feature.tokens): 388 | continue 389 | if end_index >= len(feature.tokens): 390 | continue 391 | if start_index not in feature.token_to_orig_map: 392 | continue 393 | if end_index not in feature.token_to_orig_map: 394 | continue 395 | if not feature.token_is_max_context.get(start_index, False): 396 | continue 397 | if end_index < start_index: 398 | continue 399 | length = end_index - start_index + 1 400 | if length > max_answer_length: 401 | continue 402 | prelim_predictions.append( 403 | _PrelimPrediction( 404 | feature_index=feature_index, 405 | start_index=start_index, 406 | end_index=end_index, 407 | start_logit=result.start_logits[start_index], 408 | end_logit=result.end_logits[end_index])) 409 | prelim_predictions.append( 410 | _PrelimPrediction( 411 | feature_index=min_null_feature_index, 412 | start_index=0, 413 | end_index=0, 414 | start_logit=null_start_logit, 415 | end_logit=null_end_logit)) 416 | prelim_predictions = sorted( 417 | prelim_predictions, 418 | key=lambda x: (x.start_logit + x.end_logit), 419 | reverse=True) 420 | 421 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 422 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 423 | 424 | seen_predictions = {} 425 | nbest = [] 426 | for pred in prelim_predictions: 427 | if len(nbest) >= n_best_size: 428 | break 429 | feature = features[pred.feature_index] 430 | if pred.start_index > 0: # this is a non-null prediction 431 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 432 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 433 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 434 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 435 | tok_text = " ".join(tok_tokens) 436 | 437 | # De-tokenize WordPieces that have been split off. 
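# Illustrative example (WordPiece-style pieces, hypothetical tokens): the
# pieces ["jo", "##hans", "##son"] join to "jo ##hans ##son", and stripping
# the " ##" markers below yields "johansson"; the second replace cleans up a
# leading "##" that has no space in front of it. The byte-level BPE tokenizer
# used in this repo does not emit "##" pieces, so for GPT-2 tokens these
# replaces usually change nothing.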
438 | tok_text = tok_text.replace(" ##", "") 439 | tok_text = tok_text.replace("##", "") 440 | 441 | # Clean whitespace 442 | tok_text = tok_text.strip() 443 | tok_text = " ".join(tok_text.split()) 444 | orig_text = " ".join(orig_tokens) 445 | 446 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) 447 | if final_text in seen_predictions: 448 | continue 449 | 450 | seen_predictions[final_text] = True 451 | else: 452 | final_text = "" 453 | seen_predictions[final_text] = True 454 | 455 | nbest.append( 456 | _NbestPrediction( 457 | text=final_text, 458 | start_logit=pred.start_logit, 459 | end_logit=pred.end_logit)) 460 | # if we didn't include the empty option in the n-best, include it 461 | if "" not in seen_predictions: 462 | nbest.append( 463 | _NbestPrediction( 464 | text="", 465 | start_logit=null_start_logit, 466 | end_logit=null_end_logit)) 467 | 468 | # In very rare edge cases we could only have single null prediction. 469 | # So we just create a nonce prediction in this case to avoid failure. 470 | if len(nbest)==1: 471 | nbest.insert(0, 472 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 473 | 474 | # In very rare edge cases we could have no valid predictions. So we 475 | # just create a nonce prediction in this case to avoid failure. 476 | if not nbest: 477 | nbest.append( 478 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 479 | 480 | assert len(nbest) >= 1 481 | 482 | total_scores = [] 483 | best_non_null_entry = None 484 | for entry in nbest: 485 | total_scores.append(entry.start_logit + entry.end_logit) 486 | if not best_non_null_entry: 487 | if entry.text: 488 | best_non_null_entry = entry 489 | 490 | probs = _compute_softmax(total_scores) 491 | 492 | nbest_json = [] 493 | for (i, entry) in enumerate(nbest): 494 | output = collections.OrderedDict() 495 | output["text"] = entry.text 496 | output["probability"] = probs[i] 497 | output["start_logit"] = entry.start_logit 498 | output["end_logit"] = entry.end_logit 499 | nbest_json.append(output) 500 | 501 | assert len(nbest_json) >= 1 502 | 503 | all_predictions[example.qas_id] = nbest_json[0]["text"] 504 | 505 | with open(output_prediction_file, "w") as writer: 506 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 507 | 508 | with open(output_nbest_file, "w") as writer: 509 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 510 | 511 | if version_2_with_negative: 512 | with open(output_null_log_odds_file, "w") as writer: 513 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 514 | 515 | 516 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 517 | """Project the tokenized prediction back to the original text.""" 518 | 519 | # When we created the data, we kept track of the alignment between original 520 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 521 | # now `orig_text` contains the span of our original text corresponding to the 522 | # span that we predicted. 523 | # 524 | # However, `orig_text` may contain extra characters that we don't want in 525 | # our prediction. 526 | # 527 | # For example, let's say: 528 | # pred_text = steve smith 529 | # orig_text = Steve Smith's 530 | # 531 | # We don't want to return `orig_text` because it contains the extra "'s". 
532 | # 533 | # We don't want to return `pred_text` because it's already been normalized 534 | # (the SQuAD eval script also does punctuation stripping/lower casing but 535 | # our tokenizer does additional normalization like stripping accent 536 | # characters). 537 | # 538 | # What we really want to return is "Steve Smith". 539 | # 540 | # Therefore, we have to apply a semi-complicated alignment heuristic between 541 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 542 | # can fail in certain cases in which case we just return `orig_text`. 543 | 544 | def _strip_spaces(text): 545 | ns_chars = [] 546 | ns_to_s_map = collections.OrderedDict() 547 | for (i, c) in enumerate(text): 548 | if c == " ": 549 | continue 550 | ns_to_s_map[len(ns_chars)] = i 551 | ns_chars.append(c) 552 | ns_text = "".join(ns_chars) 553 | return (ns_text, ns_to_s_map) 554 | 555 | # We first tokenize `orig_text`, strip whitespace from the result 556 | # and `pred_text`, and check if they are the same length. If they are 557 | # NOT the same length, the heuristic has failed. If they are the same 558 | # length, we assume the characters are one-to-one aligned. 559 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 560 | 561 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 562 | 563 | start_position = tok_text.find(pred_text) 564 | if start_position == -1: 565 | if verbose_logging: 566 | logger.info( 567 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 568 | return orig_text 569 | end_position = start_position + len(pred_text) - 1 570 | 571 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 572 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 573 | 574 | if len(orig_ns_text) != len(tok_ns_text): 575 | if verbose_logging: 576 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'", 577 | orig_ns_text, tok_ns_text) 578 | return orig_text 579 | 580 | # We then project the characters in `pred_text` back to `orig_text` using 581 | # the character-to-character alignment. 
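# Continuing the example above: for orig_text "Steve Smith's", _strip_spaces
# returns ns_text "SteveSmith's" plus a map from stripped positions back to
# original positions (0->0, ..., 4->4, 5->6, 6->7, ...); stripped position 5
# (the "S" of "Smith") maps to original position 6 because the space between
# the words was removed. The code below inverts the tokenized-text map and
# walks pred_text's start and end through these maps to recover the original
# character span.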
582 | tok_s_to_ns_map = {} 583 | for (i, tok_index) in tok_ns_to_s_map.items(): 584 | tok_s_to_ns_map[tok_index] = i 585 | 586 | orig_start_position = None 587 | if start_position in tok_s_to_ns_map: 588 | ns_start_position = tok_s_to_ns_map[start_position] 589 | if ns_start_position in orig_ns_to_s_map: 590 | orig_start_position = orig_ns_to_s_map[ns_start_position] 591 | 592 | if orig_start_position is None: 593 | if verbose_logging: 594 | logger.info("Couldn't map start position") 595 | return orig_text 596 | 597 | orig_end_position = None 598 | if end_position in tok_s_to_ns_map: 599 | ns_end_position = tok_s_to_ns_map[end_position] 600 | if ns_end_position in orig_ns_to_s_map: 601 | orig_end_position = orig_ns_to_s_map[ns_end_position] 602 | 603 | if orig_end_position is None: 604 | if verbose_logging: 605 | logger.info("Couldn't map end position") 606 | return orig_text 607 | 608 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 609 | return output_text 610 | 611 | 612 | def _get_best_indexes(logits, n_best_size): 613 | """Get the n-best logits from a list.""" 614 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 615 | 616 | best_indexes = [] 617 | for i in range(len(index_and_score)): 618 | if i >= n_best_size: 619 | break 620 | best_indexes.append(index_and_score[i][0]) 621 | return best_indexes 622 | 623 | 624 | def _compute_softmax(scores): 625 | """Compute softmax probability over raw logits.""" 626 | if not scores: 627 | return [] 628 | 629 | max_score = None 630 | for score in scores: 631 | if max_score is None or score > max_score: 632 | max_score = score 633 | 634 | exp_scores = [] 635 | total_sum = 0.0 636 | for score in scores: 637 | x = math.exp(score - max_score) 638 | exp_scores.append(x) 639 | total_sum += x 640 | 641 | probs = [] 642 | for score in exp_scores: 643 | probs.append(score / total_sum) 644 | return probs 645 | 646 | -------------------------------------------------------------------------------- /gpt2sqa/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | from __future__ import (absolute_import, division, print_function, 20 | unicode_literals) 21 | import collections 22 | import regex as re 23 | import json 24 | import sys 25 | import logging 26 | import os 27 | import unicodedata 28 | from io import open 29 | 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def whitespace_tokenize(text): 36 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 37 | text = text.strip() 38 | if not text: 39 | return [] 40 | tokens = text.split() 41 | return tokens 42 | 43 | 44 | try: 45 | from functools import lru_cache 46 | except ImportError: 47 | # Just a dummy decorator to get the checks to run on python2 48 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 49 | def lru_cache(): 50 | return lambda func: func 51 | 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | PRETRAINED_VOCAB_ARCHIVE_MAP = {'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",} 56 | PRETRAINED_MERGES_ARCHIVE_MAP = {'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",} 57 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {'gpt2': 1024,} 58 | VOCAB_NAME = 'vocab.json' 59 | MERGES_NAME = 'merges.txt' 60 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 61 | 62 | 63 | @lru_cache() 64 | def get_pairs(word): 65 | """Return set of symbol pairs in a word. 66 | 67 | Word is represented as tuple of symbols (symbols being variable-length strings). 68 | """ 69 | pairs = set() 70 | prev_char = word[0] 71 | for char in word[1:]: 72 | pairs.add((prev_char, char)) 73 | prev_char = char 74 | return pairs 75 | 76 | 77 | class GPT2Tokenizer(object): 78 | """ 79 | GPT-2 BPE tokenizer. Peculiarities: 80 | - Byte-level BPE 81 | """ 82 | @classmethod 83 | def from_pretrained(cls, pretrained_model_name_or_path='gpt2', cache_dir=None, *inputs, **kwargs): 84 | """ 85 | Download and cache the pre-trained model file if needed. 86 | """ 87 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 88 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 90 | special_tokens_file = None 91 | else: 92 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 93 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 94 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 95 | if not os.path.exists(special_tokens_file): 96 | special_tokens_file = None 97 | else: 98 | logger.info("loading special tokens file {}".format(special_tokens_file)) 99 | # redirect to the cache, if necessary 100 | try: 101 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 102 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 103 | except EnvironmentError: 104 | logger.error( 105 | "Model name '{}' was not found in model name list ({}). 
" 106 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 107 | "at this path or url.".format( 108 | pretrained_model_name_or_path, 109 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 110 | pretrained_model_name_or_path, 111 | vocab_file, merges_file)) 112 | return None 113 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 114 | logger.info("loading vocabulary file {}".format(vocab_file)) 115 | logger.info("loading merges file {}".format(merges_file)) 116 | else: 117 | logger.info("loading vocabulary file {} from cache at {}".format( 118 | vocab_file, resolved_vocab_file)) 119 | logger.info("loading merges file {} from cache at {}".format( 120 | merges_file, resolved_merges_file)) 121 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 122 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 123 | # than the number of positional embeddings 124 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 125 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 126 | # Instantiate tokenizer. 127 | if special_tokens_file and 'special_tokens' not in kwargs: 128 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 129 | else: 130 | special_tokens = kwargs.pop('special_tokens', []) 131 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 132 | return tokenizer 133 | 134 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 135 | self.max_len = max_len if max_len is not None else int(1e12) 136 | self.encoder = json.load(open(vocab_file)) 137 | self.decoder = {v: k for k, v in self.encoder.items()} 138 | self.errors = errors # how to handle errors in decoding 139 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 140 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 141 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 142 | self.cache = {} 143 | 144 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 145 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 146 | 147 | self.special_tokens = {} 148 | self.special_tokens_decoder = {} 149 | self.set_special_tokens(special_tokens) 150 | 151 | def __len__(self): 152 | return len(self.encoder) + len(self.special_tokens) 153 | 154 | def set_special_tokens(self, special_tokens): 155 | """ Add a list of additional tokens to the encoder. 156 | The additional tokens are indexed starting from the last index of the 157 | current vocabulary in the order of the `special_tokens` list. 
158 | """ 159 | if not special_tokens: 160 | self.special_tokens = {} 161 | self.special_tokens_decoder = {} 162 | return 163 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 164 | self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()} 165 | logger.info("Special tokens {}".format(self.special_tokens)) 166 | 167 | def bpe(self, token): 168 | if token in self.cache: 169 | return self.cache[token] 170 | word = tuple(token) 171 | pairs = get_pairs(word) 172 | 173 | if not pairs: 174 | return token 175 | 176 | while True: 177 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 178 | if bigram not in self.bpe_ranks: 179 | break 180 | first, second = bigram 181 | new_word = [] 182 | i = 0 183 | while i < len(word): 184 | try: 185 | j = word.index(first, i) 186 | new_word.extend(word[i:j]) 187 | i = j 188 | except: 189 | new_word.extend(word[i:]) 190 | break 191 | 192 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 193 | new_word.append(first + second) 194 | i += 2 195 | else: 196 | new_word.append(word[i]) 197 | i += 1 198 | new_word = tuple(new_word) 199 | word = new_word 200 | if len(word) == 1: 201 | break 202 | else: 203 | pairs = get_pairs(word) 204 | word = ' '.join(word) 205 | self.cache[token] = word 206 | return word 207 | 208 | def tokenize(self, text): 209 | """ Tokenize a string. """ 210 | bpe_tokens = [] 211 | for token in re.findall(self.pat, text): 212 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 213 | return bpe_tokens 214 | 215 | def convert_tokens_to_ids(self, tokens): 216 | """ Converts a sequence of tokens into ids using the vocab. """ 217 | ids = [] 218 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 219 | if tokens in self.special_tokens: 220 | return self.special_tokens[tokens] 221 | else: 222 | return self.encoder.get(tokens, 0) 223 | for token in tokens: 224 | if token in self.special_tokens: 225 | ids.append(self.special_tokens[token]) 226 | else: 227 | ids.append(self.encoder.get(token, 0)) 228 | if len(ids) > self.max_len: 229 | logger.warning( 230 | "Token indices sequence length is longer than the specified maximum " 231 | " sequence length for this OpenAI GPT model ({} > {}). 
Running this" 232 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 233 | ) 234 | return ids 235 | 236 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 237 | """Converts a sequence of ids in BPE tokens using the vocab.""" 238 | tokens = [] 239 | for i in ids: 240 | if i in self.special_tokens_decoder: 241 | if not skip_special_tokens: 242 | tokens.append(self.special_tokens_decoder[i]) 243 | else: 244 | tokens.append(self.decoder[i]) 245 | return tokens 246 | 247 | def encode(self, text): 248 | return self.convert_tokens_to_ids(self.tokenize(text)) 249 | 250 | def decode(self, tokens): 251 | text = ''.join([self.decoder[token] for token in tokens]) 252 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 253 | return text 254 | 255 | def save_vocabulary(self, vocab_path): 256 | """Save the tokenizer vocabulary and merge files to a directory.""" 257 | if not os.path.isdir(vocab_path): 258 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 259 | return 260 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 261 | merge_file = os.path.join(vocab_path, MERGES_NAME) 262 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 263 | 264 | with open(vocab_file, 'w', encoding='utf-8') as f: 265 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 266 | 267 | index = 0 268 | with open(merge_file, "w", encoding="utf-8") as writer: 269 | writer.write(u'#version: 0.2\n') 270 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 271 | if index != token_index: 272 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 273 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 274 | index = token_index 275 | writer.write(' '.join(bpe_tokens) + u'\n') 276 | index += 1 277 | 278 | index = len(self.encoder) 279 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 280 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 281 | if index != token_index: 282 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 283 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 284 | index = token_index 285 | writer.write(token + u'\n') 286 | index += 1 287 | 288 | return vocab_file, merge_file, special_tokens_file 289 | 290 | 291 | class BasicTokenizer(object): 292 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 293 | 294 | def __init__(self, 295 | do_lower_case=True, 296 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 297 | """Constructs a BasicTokenizer. 298 | Args: 299 | do_lower_case: Whether to lower case the input. 
300 | """ 301 | self.do_lower_case = do_lower_case 302 | self.never_split = never_split 303 | 304 | def tokenize(self, text): 305 | """Tokenizes a piece of text.""" 306 | text = self._clean_text(text) 307 | text = self._tokenize_chinese_chars(text) 308 | orig_tokens = whitespace_tokenize(text) 309 | split_tokens = [] 310 | for token in orig_tokens: 311 | if self.do_lower_case and token not in self.never_split: 312 | token = token.lower() 313 | token = self._run_strip_accents(token) 314 | split_tokens.extend(self._run_split_on_punc(token)) 315 | 316 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 317 | return output_tokens 318 | 319 | def _run_strip_accents(self, text): 320 | """Strips accents from a piece of text.""" 321 | text = unicodedata.normalize("NFD", text) 322 | output = [] 323 | for char in text: 324 | cat = unicodedata.category(char) 325 | if cat == "Mn": 326 | continue 327 | output.append(char) 328 | return "".join(output) 329 | 330 | def _run_split_on_punc(self, text): 331 | """Splits punctuation on a piece of text.""" 332 | if text in self.never_split: 333 | return [text] 334 | chars = list(text) 335 | i = 0 336 | start_new_word = True 337 | output = [] 338 | while i < len(chars): 339 | char = chars[i] 340 | if _is_punctuation(char): 341 | output.append([char]) 342 | start_new_word = True 343 | else: 344 | if start_new_word: 345 | output.append([]) 346 | start_new_word = False 347 | output[-1].append(char) 348 | i += 1 349 | 350 | return ["".join(x) for x in output] 351 | 352 | def _tokenize_chinese_chars(self, text): 353 | """Adds whitespace around any CJK character.""" 354 | output = [] 355 | for char in text: 356 | cp = ord(char) 357 | if self._is_chinese_char(cp): 358 | output.append(" ") 359 | output.append(char) 360 | output.append(" ") 361 | else: 362 | output.append(char) 363 | return "".join(output) 364 | 365 | def _is_chinese_char(self, cp): 366 | """Checks whether CP is the codepoint of a CJK character.""" 367 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 368 | (cp >= 0x3400 and cp <= 0x4DBF) or # 369 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 370 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 371 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 372 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 373 | (cp >= 0xF900 and cp <= 0xFAFF) or # 374 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 375 | return True 376 | 377 | return False 378 | 379 | def _clean_text(self, text): 380 | """Performs invalid character removal and whitespace cleanup on text.""" 381 | output = [] 382 | for char in text: 383 | cp = ord(char) 384 | if cp == 0 or cp == 0xfffd or _is_control(char): 385 | continue 386 | if _is_whitespace(char): 387 | output.append(" ") 388 | else: 389 | output.append(char) 390 | return "".join(output) 391 | 392 | 393 | def _is_whitespace(char): 394 | """Checks whether `chars` is a whitespace character.""" 395 | # \t, \n, and \r are technically contorl characters but we treat them 396 | # as whitespace since they are generally considered as such. 397 | if char == " " or char == "\t" or char == "\n" or char == "\r": 398 | return True 399 | cat = unicodedata.category(char) 400 | if cat == "Zs": 401 | return True 402 | return False 403 | 404 | 405 | def _is_control(char): 406 | """Checks whether `chars` is a control character.""" 407 | # These are technically control characters but we count them as whitespace 408 | # characters. 
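# For example, _is_control("\t") is False (tabs are treated as whitespace
# below), while a BEL character "\x07" has Unicode category Cc and returns
# True.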
409 | if char == "\t" or char == "\n" or char == "\r": 410 | return False 411 | cat = unicodedata.category(char) 412 | if cat.startswith("C"): 413 | return True 414 | return False 415 | 416 | 417 | def _is_punctuation(char): 418 | """Checks whether `chars` is a punctuation character.""" 419 | cp = ord(char) 420 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 421 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 422 | return True 423 | cat = unicodedata.category(char) 424 | if cat.startswith("P"): 425 | return True 426 | return False 427 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | torch>=0.4.1 3 | # progress bars in model download and training scripts 4 | tqdm 5 | # Accessing files from S3 directly. 6 | boto3 7 | # Used for downloading models over HTTP 8 | requests 9 | # For OpenAI GPT 10 | regex -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py and setup.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. 21 | 22 | 5. Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (pypi suggest using twine as other methods upload files via plaintext.) 26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi allennlp 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 
34 | 35 | """ 36 | from io import open 37 | from setuptools import find_packages, setup 38 | 39 | setup( 40 | name="pytorch_pretrained_bert", 41 | version="0.6.2", 42 | author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors", 43 | author_email="thomas@huggingface.co", 44 | description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", 45 | long_description=open("README.md", "r", encoding='utf-8').read(), 46 | long_description_content_type="text/markdown", 47 | keywords='BERT NLP deep learning google', 48 | license='Apache', 49 | url="https://github.com/huggingface/pytorch-pretrained-BERT", 50 | packages=find_packages(exclude=["*.tests", "*.tests.*", 51 | "tests.*", "tests"]), 52 | install_requires=['torch>=0.4.1', 53 | 'numpy', 54 | 'boto3', 55 | 'requests', 56 | 'tqdm', 57 | 'regex'], 58 | entry_points={ 59 | 'console_scripts': [ 60 | "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main", 61 | ] 62 | }, 63 | # python_requires='>=3.5.0', 64 | tests_require=['pytest'], 65 | classifiers=[ 66 | 'Intended Audience :: Science/Research', 67 | 'License :: OSI Approved :: Apache Software License', 68 | 'Programming Language :: Python :: 3', 69 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 70 | ], 71 | ) 72 | --------------------------------------------------------------------------------