├── running.gif
├── static
│   ├── main.js
│   ├── chat.html
│   ├── chat_main.js
│   ├── index.html
│   └── chat_main.css
├── README.md
├── server.py
├── infer.py
├── infer_utils.py
└── modeling.py

/running.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arijitx/QnA-Bot/HEAD/running.gif
--------------------------------------------------------------------------------
/static/main.js:
--------------------------------------------------------------------------------
1 | $(document).ready(function() {
2 |     $("#create_bot").click(function(){
3 |         var bot_id = $("#bot_id").val();
4 |         var context = $("#context_ta").val();
5 |         var bot_im_url = $('#bot_im_url').val();
6 |         $.post("/create_bot", {"id": bot_id, "context": context, "bot_im_url": bot_im_url}, function(data, status){
7 |             alert("Successfully created!");
8 |             $('#bot_url').text("Deployed at: " + data);
9 |         });
10 |     });
11 | });
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QnA-Bot
2 | A QnA bot powered by CoQA + BERT, built in PyTorch. This project was done as a course project for EE 763, jointly with [Rakesh](https://github.com/RKhobrag) and [Kuber](https://github.com/kuberg1/).
3 | 
4 | ## Installation
5 | 
6 | Tested with pytorch-pretrained-BERT version 0.6.1
7 | and torch version 1.0.1.post2:
8 | 
9 |     pip3 install torch
10 |     pip3 install flask
11 |     pip3 install git+https://github.com/huggingface/pytorch-pretrained-BERT.git
12 | 
13 | Download the pretrained model from: https://drive.google.com/file/d/15HOJmRizBrgoPPVDHKpvSO2tNf0k-d8f/view?usp=sharing
14 | 
15 |     python3 server.py
16 | 
17 | Google Colab [Demo](https://colab.research.google.com/drive/1Alz7NMENYc1S28EqUDwbe7hf1M8HpYSO#scrollTo=vT40VTBxJMlT)
18 | 
19 | ## Running
20 | 
21 | ![alt text](https://raw.githubusercontent.com/arijitx/QnA-Bot/master/running.gif)
22 | 
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import string
3 | import random
4 | from flask import Flask, request, render_template
5 | from flask import send_file, send_from_directory, session, jsonify
6 | from infer import *
7 | 
8 | 
9 | def id_generator(size=4, chars=string.ascii_lowercase):
10 |     return ''.join(random.choice(chars) for _ in range(size))
11 | 
12 | app = Flask(__name__, template_folder='static')
13 | 
14 | @app.route('/static/<path:filename>')
15 | def serve_static(filename):
16 |     root_dir = os.path.dirname(os.getcwd())
17 |     return send_from_directory(os.path.join(root_dir, 'static'), filename)
18 | 
19 | @app.route('/')
20 | def home():
21 |     return send_file('static/index.html')
22 | 
23 | @app.route('/chat/<bot_id>', methods=['GET', 'POST'])
24 | def chat(bot_id):
25 |     if request.method == 'POST':
26 |         if bot_id in table:
27 |             context = table[bot_id]['context']
28 |             question = request.form.get('ques')
29 |             prev_q = request.form.get('prev_q')
30 |             prev_a = request.form.get('prev_a')
31 |             answer = iq.predict(context, question, prev_q, prev_a)
32 |             return answer
33 |     if request.method == 'GET':
34 |         if bot_id not in table:
35 |             bot_id = "Oops! Bot not found!"
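            # Unknown bot id: fall through and render the page with this
            # placeholder title and an empty avatar URL.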
36 |             bot_im = ""
37 |         else:
38 |             bot_im = table[bot_id]["im_url"]
39 |         return render_template('chat.html', bot=bot_id, bot_im=bot_im)
40 | 
41 | @app.route('/create_bot', methods=['GET', 'POST'])
42 | def create_bot():
43 |     bot_id = request.form.get('id')
44 |     context = request.form.get('context')
45 |     bot_im = request.form.get('bot_im_url')
46 |     table[bot_id] = {"context": context, "bot_name": bot_id, "im_url": bot_im}
47 |     return SERVER + '/chat/' + bot_id
48 | 
49 | 
50 | SERVER = "10.129.6.41:5000"
51 | table = {}
52 | iq = InferCoQA('model')
53 | print('done loading model ..')
54 | 
55 | app.run(host='0.0.0.0', debug=True)
--------------------------------------------------------------------------------
/static/chat.html:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | QnA Chat Bot
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
14 | 
15 | 
16 | 
17 | 
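<!--
Usage note: the Flask routes in server.py above can be driven directly over
HTTP. A minimal client sketch in Python, assuming the server is reachable at
http://localhost:5000 and the `requests` package is installed; the bot id,
context and question below are illustrative:

    import requests

    BASE = "http://localhost:5000"

    # Create a bot backed by a context passage (field names match create_bot()).
    r = requests.post(BASE + "/create_bot", data={
        "id": "demo",
        "context": "The Eiffel Tower is in Paris. It was completed in 1889.",
        "bot_im_url": "",
    })
    print("deployed at:", r.text)

    # Ask a question; prev_q and prev_a carry the previous turn of the dialog.
    ans = requests.post(BASE + "/chat/demo", data={
        "ques": "Where is the Eiffel Tower?",
        "prev_q": "",
        "prev_a": "",
    }).text
    print("answer:", ans)

This chat.html template is rendered by chat() via render_template with two
Jinja2 variables: bot (the bot id, shown as the chat title and sender name)
and bot_im (the avatar image URL).
-->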
18 |
19 |
20 | 21 | 22 |
23 |
24 |
25 |
26 |
27 | 28 |

{{ bot }}

29 |







30 | 31 | QnA Bot
32 | Made with at IIT Bombay 33 |

34 | 35 | 36 |
37 |
38 | 39 |
40 |
41 |
42 |
43 | 44 | 45 |
46 |
47 | 48 |
49 |
50 |
51 | 52 | {{ bot }} 53 |
54 |
55 |
56 |
57 |
58 | 59 |
60 |
61 | 62 | 63 |
64 |
65 |
66 |
67 |
68 |
69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /static/chat_main.js: -------------------------------------------------------------------------------- 1 | function formatAMPM(date) { 2 | var hours = date.getHours(); 3 | var minutes = date.getMinutes(); 4 | var ampm = hours >= 12 ? 'PM' : 'AM'; 5 | hours = hours % 12; 6 | hours = hours ? hours : 12; // the hour '0' should be '12' 7 | minutes = minutes < 10 ? '0'+minutes : minutes; 8 | var strTime = hours + ':' + minutes + ' ' + ampm; 9 | return strTime; 10 | } 11 | 12 | var prev_q = ""; 13 | var prev_a = ""; 14 | var bot_id = ""; 15 | var bot_im_url = ""; 16 | 17 | $(document).ready(function() { 18 | bot_id = $('#bot_id').text(); 19 | bot_im_url = $('#bot_im_url').text(); 20 | }); 21 | 22 | $(document).keypress(function(e) { 23 | if(e.which == 13) { 24 | var ms = { 25 | username:'arijit', 26 | name: 'Arijit', 27 | avatar: 'https://bootdey.com/img/Content/avatar/avatar2.png', 28 | text: $("#msg").val(), 29 | ago : '' 30 | }; 31 | position = 'right'; 32 | htmldiv = `
33 |
34 | ${ms.name} 35 |
36 |
37 |
38 | ${ms.text} 39 |
40 |
`+formatAMPM(new Date)+`
41 |
`; 42 | 43 | $( "div#chat-messages" ).append(htmldiv); 44 | $("#chat-messages").animate({ scrollTop: $('#chat-messages').prop("scrollHeight")}, 1000); 45 | 46 | $.post("/chat/"+$("#bot_id").text(),{"ques":$("#msg").val(),"prev_q":prev_q,"prev_a":prev_a},function(data, status){ 47 | prev_a = data; 48 | prev_q = $("#msg").val(); 49 | var ms = { 50 | username:'bot', 51 | name: '$("#bot_id").text()', 52 | avatar: bot_im_url, 53 | text: prev_a, 54 | ago : '' 55 | }; 56 | position = 'left'; 57 | htmldiv = `
58 |
59 | ${ms.name} 60 |
61 |
62 |
63 | ${ms.text} 64 |
65 |
`+formatAMPM(new Date)+`
66 |
`; 67 | 68 | $( "div#chat-messages" ).append(htmldiv); 69 | $("#chat-messages").animate({ scrollTop: $('#chat-messages').prop("scrollHeight")}, 1000); 70 | 71 | $("#msg").val(""); 72 | }); 73 | 74 | } 75 | 76 | }); -------------------------------------------------------------------------------- /static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | QnA Chat Bot 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 |
19 |
20 |
21 | 22 |









23 | QnA Bot
24 | Made with at IIT Bombay 25 |

26 | 27 | 28 |
29 |
30 |
31 |
32 |
33 | 34 | 35 |
36 |
37 | 38 | 39 |
40 |
41 | 42 | 46 |
47 |
48 | 49 | 50 |
51 |
52 | 53 |
54 |
55 |





56 |

57 |
58 |
59 |
60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import logging 5 | import math 6 | import os 7 | import random 8 | import sys 9 | from io import open 10 | import time 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,TensorDataset) 15 | from torch.utils.data.distributed import DistributedSampler 16 | from tqdm import tqdm, trange 17 | 18 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 19 | from modeling import BertForQuestionAnswering, BertConfig 20 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 21 | from pytorch_pretrained_bert.tokenization import (BasicTokenizer,BertTokenizer,whitespace_tokenize) 22 | 23 | if sys.version_info[0] == 2: 24 | import cPickle as pickle 25 | else: 26 | import pickle 27 | 28 | from infer_utils import * 29 | import spacy 30 | nlp = spacy.load('en_core_web_md') 31 | 32 | def is_whitespace(c): 33 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 34 | return True 35 | return False 36 | 37 | def is_punc(c): 38 | if c in '?,.!()[]-_\'"': 39 | return True 40 | return False 41 | 42 | def punc_sep(s): 43 | tokens = [] 44 | is_prev_white = True 45 | for c in s: 46 | if is_whitespace(c): 47 | is_prev_white = True 48 | else: 49 | if is_punc(c): 50 | tokens.append(c) 51 | is_prev_white = True 52 | else: 53 | if is_prev_white: 54 | is_prev_white = False 55 | tokens.append(c) 56 | else: 57 | tokens[-1]+=c 58 | return ' '.join(tokens) 59 | 60 | def str_to_coqa_example(contenxt, question, prev_ques, prev_answ): 61 | paragraph_text = contenxt 62 | doc_tokens = [] 63 | char_to_word_offset = [] 64 | prev_is_whitespace = True 65 | for c in paragraph_text: 66 | if is_whitespace(c): 67 | prev_is_whitespace = True 68 | else: 69 | if prev_is_whitespace: 70 | doc_tokens.append(c) 71 | prev_is_whitespace = False 72 | else: 73 | doc_tokens[-1] += c 74 | 75 | char_to_word_offset.append(len(doc_tokens) - 1) 76 | 77 | question_text = question 78 | 79 | example = CoQAExample( 80 | qas_id='random', 81 | question_text=question_text, 82 | doc_tokens=doc_tokens, 83 | orig_answer_text="", 84 | start_position=0, 85 | end_position=0, 86 | is_impossible=False, 87 | is_yes= False, 88 | is_no=False, 89 | answer_span="", 90 | prev_ques=prev_ques, 91 | prev_answ=prev_answ) 92 | return example 93 | 94 | class InferCoQA(): 95 | def __init__(self, model_path, lower_case = True): 96 | self.model_path = model_path 97 | self.tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=lower_case) 98 | self.model = BertForQuestionAnswering.from_pretrained(model_path) 99 | self.model.cuda() 100 | self.model.eval() 101 | 102 | def predict(self, contenxt, question, prev_ques, prev_answ): 103 | t = time.time() 104 | coqa_example = str_to_coqa_example(contenxt, question, prev_ques, prev_answ) 105 | coqa_features = convert_examples_to_features([coqa_example], self.tokenizer, max_seq_length=512,doc_stride=128, max_query_length=100, is_training=False) 106 | 107 | all_input_ids = torch.tensor([f.input_ids for f in coqa_features], dtype=torch.long) 108 | all_input_mask = torch.tensor([f.input_mask for f in coqa_features], dtype=torch.long) 109 | all_segment_ids = torch.tensor([f.segment_ids for f in coqa_features], dtype=torch.long) 110 | 
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 111 | coqa_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 112 | 113 | coqa_sampler = SequentialSampler(coqa_data) 114 | coqa_dataloader = DataLoader(coqa_data, sampler=coqa_sampler, batch_size=1) 115 | all_results = [] 116 | for input_ids, input_mask, segment_ids, example_indices in coqa_dataloader: 117 | input_ids = input_ids.cuda() 118 | input_mask = input_mask.cuda() 119 | segment_ids = segment_ids.cuda() 120 | 121 | 122 | with torch.no_grad(): 123 | score = self.model(input_ids, segment_ids, input_mask) 124 | 125 | coqa_feature = coqa_features[example_indices[0].item()] 126 | unique_id = int(coqa_feature.unique_id) 127 | all_results.append(RawResult(unique_id=unique_id,score=score[0].cpu(),length=input_ids.size(1))) 128 | 129 | output_prediction_file = "predictions.json" 130 | output_nbest_file = "nbest_predictions.json" 131 | output_null_log_odds_file = "null_odds.json" 132 | write_predictions([coqa_example], coqa_features, all_results, 133 | 1, 100, 134 | True, output_prediction_file, 135 | output_nbest_file, output_null_log_odds_file, False, 136 | False, 0.0) 137 | os.remove(output_nbest_file) 138 | res = json.loads(open(output_prediction_file).read())['random'] 139 | os.remove(output_prediction_file) 140 | print('inference time :',time.time() - t ) 141 | return res 142 | 143 | # iq = InferCoQA('coqa_ynu_history_1') 144 | # print('done loading model ..') 145 | # context = input("Context : ") 146 | 147 | 148 | # prev_q = "" 149 | # prev_a = "" 150 | # while True: 151 | # q = input("Question : ") 152 | # a = iq.predict(context,q,prev_q,prev_a) 153 | # print("Answer :",a) 154 | # prev_q = q 155 | # prev_a = a 156 | 157 | 158 | -------------------------------------------------------------------------------- /static/chat_main.css: -------------------------------------------------------------------------------- 1 | body{ 2 | margin-top:20px; 3 | background:#eee; 4 | font: 200 14px/20px "Raleway", sans-serif; 5 | } 6 | #chat-head{ 7 | clear: both; 8 | position: relative; 9 | margin: -20px -20px -20px; 10 | padding: 10px; 11 | height: 9%; 12 | background: #f74860; 13 | } 14 | #bot_name{ 15 | color: #fff; 16 | display: inline-block; 17 | font-size: 25px; 18 | text-align: right; 19 | height: 40px; 20 | } 21 | 22 | #chat-messages{ 23 | top: 20px; 24 | } 25 | .row.row-broken { 26 | padding-bottom: 0; 27 | } 28 | .col-inside-lg { 29 | padding: 20px; 30 | } 31 | .chat { 32 | height: calc(100vh - 180px); 33 | } 34 | .decor-default { 35 | background-color: #ffffff; 36 | } 37 | .chat-users h6 { 38 | font-size: 20px; 39 | margin: 0 0 20px; 40 | } 41 | .chat-users .user { 42 | position: relative; 43 | padding: 0 0 0 50px; 44 | display: block; 45 | cursor: pointer; 46 | margin: 0 0 20px; 47 | } 48 | .chat-users .user .avatar { 49 | top: 0; 50 | left: 0; 51 | } 52 | .chat .avatar { 53 | width: 40px; 54 | height: 40px; 55 | position: absolute; 56 | } 57 | .chat .avatar img { 58 | display: block; 59 | border-radius: 20px; 60 | height: 100%; 61 | } 62 | .chat .avatar .status.off { 63 | border: 1px solid #5a5a5a; 64 | background: #ffffff; 65 | } 66 | .chat .avatar .status.online { 67 | background: #4caf50; 68 | } 69 | .chat .avatar .status.busy { 70 | background: #ffc107; 71 | } 72 | .chat .avatar .status.offline { 73 | background: #ed4e6e; 74 | } 75 | .chat-users .user .status { 76 | bottom: 0; 77 | left: 28px; 78 | } 79 | .chat .avatar .status { 80 | width: 10px; 81 | height: 10px; 82 | 
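    /* a 10px square with a 5px corner radius: the presence dot renders as a circle */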
border-radius: 5px; 83 | position: absolute; 84 | } 85 | .chat-users .user .name { 86 | font-size: 14px; 87 | font-weight: bold; 88 | line-height: 20px; 89 | white-space: nowrap; 90 | overflow: hidden; 91 | text-overflow: ellipsis; 92 | } 93 | .chat-users .user .mood { 94 | font: 200 14px/20px "Raleway", sans-serif; 95 | white-space: nowrap; 96 | overflow: hidden; 97 | text-overflow: ellipsis; 98 | } 99 | 100 | /*****************CHAT BODY *******************/ 101 | .chat-body h6 { 102 | font-size: 20px; 103 | margin: 0 0 20px; 104 | } 105 | .chat-body .answer.left { 106 | padding: 0 0 0 58px; 107 | text-align: left; 108 | float: left; 109 | } 110 | .chat-body .answer { 111 | position: relative; 112 | max-width: 600px; 113 | overflow: hidden; 114 | clear: both; 115 | } 116 | .chat-body .answer.left .avatar { 117 | left: 0; 118 | } 119 | .chat-body .answer .avatar { 120 | top:5px; 121 | } 122 | .chat .avatar { 123 | width: 40px; 124 | height: 40px; 125 | position: absolute; 126 | 127 | } 128 | .chat .avatar img { 129 | display: block; 130 | border-radius: 20px; 131 | height: 40px; 132 | width: 40px; 133 | } 134 | .chat-body .answer .name { 135 | font-size: 14px; 136 | line-height: 36px; 137 | } 138 | .chat-body .answer.left .avatar .status { 139 | right: 4px; 140 | } 141 | .chat-body .answer .avatar .status { 142 | bottom: 0; 143 | } 144 | .chat-body .answer.left .text { 145 | background: #ebebeb; 146 | color: #333333; 147 | border-radius: 8px 8px 8px 0; 148 | 149 | } 150 | .chat-body .answer .text { 151 | padding: 12px; 152 | font-size: 16px; 153 | line-height: 26px; 154 | position: relative; 155 | font-weight: bolder; 156 | } 157 | .chat-body .answer.left .text:before { 158 | left: -30px; 159 | border-right-color: #ebebeb; 160 | border-right-width: 12px; 161 | } 162 | .chat-body .answer .text:before { 163 | content: ''; 164 | display: block; 165 | position: absolute; 166 | bottom: 0; 167 | border: 18px solid transparent; 168 | border-bottom-width: 0; 169 | } 170 | .chat-body .answer.left .time { 171 | padding-left: 12px; 172 | color: #333333; 173 | } 174 | .chat-body .answer .time { 175 | font-size: 16px; 176 | line-height: 36px; 177 | position: relative; 178 | padding-bottom: 1px; 179 | } 180 | /*RIGHT*/ 181 | .chat-body .answer.right { 182 | padding: 0 58px 0 0; 183 | text-align: right; 184 | float: right; 185 | } 186 | 187 | .chat-body .answer.right .avatar { 188 | right: 0; 189 | } 190 | .chat-body .answer.right .avatar .status { 191 | left: 4px; 192 | } 193 | .chat-body .answer.right .text { 194 | background: #7266ba; 195 | color: #ffffff; 196 | border-radius: 8px 8px 0 8px; 197 | } 198 | .chat-body .answer.right .text:before { 199 | right: -30px; 200 | border-left-color: #7266ba; 201 | border-left-width: 12px; 202 | } 203 | .chat-body .answer.right .time { 204 | padding-right: 12px; 205 | color: #333333; 206 | } 207 | 208 | /**************ADD FORM ***************/ 209 | .chat-body .answer-add { 210 | clear: both; 211 | position: relative; 212 | margin: 20px -20px -20px; 213 | padding: 20px; 214 | background: #46be8a; 215 | } 216 | .chat-body .answer-add input { 217 | border: none; 218 | background: none; 219 | display: block; 220 | width: 100%; 221 | font-size: 16px; 222 | line-height: 20px; 223 | 224 | padding: 0; 225 | color: #ffffff; 226 | font-weight: bolder; 227 | } 228 | .chat input { 229 | -webkit-appearance: none; 230 | border-radius: 0; 231 | } 232 | .chat-body .answer-add .answer-btn-1 { 233 | background: url("http://91.234.35.26/iwiki-admin/v1.0.0/admin/img/icon-40.png") 
50% 50% no-repeat; 234 | right: 56px; 235 | } 236 | .chat-body .answer-add .answer-btn { 237 | display: block; 238 | cursor: pointer; 239 | width: 36px; 240 | height: 36px; 241 | position: absolute; 242 | top: 50%; 243 | margin-top: -18px; 244 | } 245 | .chat-body .answer-add .answer-btn-2 { 246 | background: url("http://91.234.35.26/iwiki-admin/v1.0.0/admin/img/icon-41.png") 50% 50% no-repeat; 247 | right: 20px; 248 | } 249 | .chat input::-webkit-input-placeholder { 250 | color: #fff; 251 | } 252 | 253 | .chat input:-moz-placeholder { /* Firefox 18- */ 254 | color: #fff; 255 | } 256 | 257 | .chat input::-moz-placeholder { /* Firefox 19+ */ 258 | color: #fff; 259 | } 260 | 261 | .chat input:-ms-input-placeholder { 262 | color: #fff; 263 | } 264 | .chat input { 265 | -webkit-appearance: none; 266 | border-radius: 0; 267 | } 268 | -------------------------------------------------------------------------------- /infer_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import logging 7 | import math 8 | import os 9 | import random 10 | import sys 11 | from io import open 12 | 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,TensorDataset) 16 | from torch.utils.data.distributed import DistributedSampler 17 | from tqdm import tqdm, trange 18 | 19 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 20 | from modeling import BertForQuestionAnswering, BertConfig 21 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 22 | from pytorch_pretrained_bert.tokenization import (BasicTokenizer,BertTokenizer,whitespace_tokenize) 23 | 24 | if sys.version_info[0] == 2: 25 | import cPickle as pickle 26 | else: 27 | import pickle 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | class CoQAExample(object): 32 | """ 33 | A single training/test example for the CoQA dataset. 34 | For examples without an answer, the start and end position are -1. 
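    Yes/no/unknown answers are marked with sentinel positions as well:
    (start, end) = (-1, 0) for yes, (0, -1) for no and (-1, -1) for unknown
    (see read_coqa_examples below).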
35 | """ 36 | def __init__(self, 37 | qas_id, 38 | question_text, 39 | doc_tokens, 40 | orig_answer_text=None, 41 | start_position=None, 42 | end_position=None, 43 | is_impossible=None, 44 | answer_span=None, 45 | is_yes=None, 46 | is_no=None, 47 | prev_ques=None, 48 | prev_answ=None): 49 | self.qas_id = qas_id 50 | self.question_text = question_text 51 | self.doc_tokens = doc_tokens 52 | self.orig_answer_text = orig_answer_text 53 | self.start_position = start_position 54 | self.end_position = end_position 55 | self.is_impossible = is_impossible 56 | self.is_yes = is_yes 57 | self.is_no = is_no 58 | self.answer_span = answer_span 59 | self.prev_ques = prev_ques 60 | self.prev_answ = prev_answ 61 | 62 | def __str__(self): 63 | return self.__repr__() 64 | 65 | def __repr__(self): 66 | s = "" 67 | s += "qas_id: %s" % (self.qas_id) 68 | s += '\n' 69 | s += "question_text: %s" % ( 70 | self.question_text) 71 | s += '\n' 72 | s += "answer_span: "+self.answer_span 73 | s += '\n' 74 | s += "answer : "+self.orig_answer_text 75 | s += '\n' 76 | s += "doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 77 | s += '\n' 78 | if self.start_position: 79 | s += "start_position: %d" % (self.start_position) 80 | s += '\n' 81 | if self.end_position: 82 | s += "end_position: %d" % (self.end_position) 83 | s += '\n' 84 | if self.is_impossible: 85 | s += "is_impossible: %r" % (self.is_impossible) 86 | s += '\n' 87 | return s 88 | 89 | 90 | class InputFeatures(object): 91 | """A single set of features of data.""" 92 | 93 | def __init__(self, 94 | unique_id, 95 | example_index, 96 | doc_span_index, 97 | tokens, 98 | token_to_orig_map, 99 | token_is_max_context, 100 | input_ids, 101 | input_mask, 102 | segment_ids, 103 | start_position=None, 104 | end_position=None, 105 | is_impossible=None, 106 | is_yes=None, 107 | is_no=None): 108 | self.unique_id = unique_id 109 | self.example_index = example_index 110 | self.doc_span_index = doc_span_index 111 | self.tokens = tokens 112 | self.token_to_orig_map = token_to_orig_map 113 | self.token_is_max_context = token_is_max_context 114 | self.input_ids = input_ids 115 | self.input_mask = input_mask 116 | self.segment_ids = segment_ids 117 | self.start_position = start_position 118 | self.end_position = end_position 119 | self.is_impossible = is_impossible 120 | self.is_yes = is_yes 121 | self.is_no = is_no 122 | 123 | 124 | def f1_bow(src,tgt): 125 | src = set(src.split(' ')) 126 | tgt = set(tgt.split(' ')) 127 | p = len(src.intersection(tgt))*1./len(src) 128 | r = len(src.intersection(tgt))*1./len(tgt) 129 | 130 | if p == 0 and r == 0: 131 | return 0 132 | return 2.*p*r/(p+r) 133 | 134 | def find_gt_span(query,tokens,start,end): 135 | max_f1 = 0 136 | ms = -1 137 | me = -1 138 | for s in range(start,end+1): 139 | for e in range(start,end+1): 140 | f1 = f1_bow(query,' '.join(tokens[s:e+1])) 141 | if f1 > max_f1: 142 | ms = s 143 | me = e 144 | max_f1 = f1 145 | return ms,me 146 | 147 | def read_coqa_examples(input_file, is_training, version_2_with_negative): 148 | """Read a SQuAD json file into a list of SquadExample.""" 149 | with open(input_file, "r", encoding='utf-8') as reader: 150 | input_data = json.load(reader)["data"] 151 | 152 | def is_whitespace(c): 153 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 154 | return True 155 | return False 156 | 157 | def is_punc(c): 158 | if c in '?,.!()[]-_': 159 | return True 160 | return False 161 | 162 | def punc_sep(s): 163 | tokens = [] 164 | is_prev_white = True 165 | for c in s: 166 | if 
is_whitespace(c): 167 | is_prev_white = True 168 | else: 169 | if is_punc(c): 170 | tokens.append(c) 171 | is_prev_white = True 172 | else: 173 | if is_prev_white: 174 | is_prev_white = False 175 | tokens.append(c) 176 | else: 177 | tokens[-1]+=c 178 | return ' '.join(tokens) 179 | 180 | 181 | examples = [] 182 | for entry in input_data: 183 | paragraph_text = entry['story'] 184 | doc_tokens = [] 185 | char_to_word_offset = [] 186 | prev_is_whitespace = True 187 | for c in paragraph_text: 188 | if is_whitespace(c): 189 | prev_is_whitespace = True 190 | else: 191 | if prev_is_whitespace: 192 | doc_tokens.append(c) 193 | prev_is_whitespace = False 194 | else: 195 | doc_tokens[-1] += c 196 | 197 | char_to_word_offset.append(len(doc_tokens) - 1) 198 | 199 | questions = sorted(entry['questions'], key=lambda k: k['turn_id']) 200 | answers = sorted(entry['answers'], key=lambda k :k['turn_id']) 201 | for i in range(len(questions)): 202 | q = questions[i] 203 | a = answers[i] 204 | question_text = q['input_text'] 205 | orig_answer_text = a['span_text'] 206 | is_impossible = False 207 | is_yes = False 208 | is_no = False 209 | if i!=0: 210 | prev_ques = questions[i-1]['input_text'] 211 | prev_answ = answers[i-1]['input_text'] 212 | else: 213 | prev_ques = "" 214 | prev_answ = "" 215 | 216 | start_position = char_to_word_offset[a['span_start']] 217 | end_position = char_to_word_offset[a['span_end']-1] 218 | 219 | actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) 220 | # cleaned_answer_text = punc_sep(" ".join(whitespace_tokenize(orig_answer_text))) 221 | answer = a['input_text'] 222 | # start,end = find_gt_span(punc_sep(answer),doc_tokens,start_position,end_position) 223 | start,end = start_position,end_position 224 | if answer.lower() == 'yes' or answer.lower() == 'yes.': 225 | start = -1 226 | end = 0 227 | is_yes = True 228 | if answer.lower() == 'no' or answer.lower() == 'no.': 229 | start = 0 230 | end = -1 231 | is_no = True 232 | if answer.lower() == 'unknown' or answer.lower() == 'unknown.': 233 | start = -1 234 | end = -1 235 | is_impossible = True 236 | example = CoQAExample( 237 | qas_id=entry['id']+'_'+str(i+1), 238 | question_text=question_text, 239 | doc_tokens=doc_tokens, 240 | orig_answer_text=a['input_text'], 241 | start_position=start, 242 | end_position=end, 243 | is_impossible=is_impossible, 244 | is_yes= is_yes, 245 | is_no=is_no, 246 | answer_span=orig_answer_text, 247 | prev_ques=prev_ques, 248 | prev_answ=prev_answ) 249 | examples.append(example) 250 | print(example) 251 | 252 | return examples 253 | 254 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 255 | doc_stride, max_query_length, is_training): 256 | """Loads a data file into a list of `InputBatch`s.""" 257 | 258 | unique_id = 1000000000 259 | HA = True 260 | features = [] 261 | for (example_index, example) in enumerate(examples): 262 | def new_tok(s): 263 | new_s = [] 264 | for t in s.split(' '): 265 | if '[unused' in t: 266 | new_s.append(t) 267 | else: 268 | for tx in tokenizer.tokenize(t): 269 | new_s.append(tx) 270 | return new_s 271 | 272 | query_tokens = new_tok(example.question_text) 273 | prev_q_tokens = new_tok(example.prev_ques) 274 | prev_a_tokens = new_tok(example.prev_answ) 275 | if HA: 276 | if len(query_tokens) > 60: 277 | query_tokens = query_tokens[0:60] 278 | if len(prev_q_tokens) + len(query_tokens) > 90: 279 | prev_q_tokens = prev_q_tokens[:30] 280 | if len(prev_q_tokens) + len(prev_a_tokens) + len(query_tokens) > 100: 281 | prev_a_tokens = 
prev_a_tokens[:10] 282 | else: 283 | if len(query_tokens) > max_query_length: 284 | query_tokens = query_tokens[:max_query_length] 285 | 286 | tok_to_orig_index = [] 287 | orig_to_tok_index = [] 288 | all_doc_tokens = [] 289 | for (i, token) in enumerate(example.doc_tokens): 290 | orig_to_tok_index.append(len(all_doc_tokens)) 291 | sub_tokens = tokenizer.tokenize(token) 292 | if '[unused' in token: 293 | all_doc_tokens.append(token) 294 | tok_to_orig_index.append(i) 295 | else: 296 | for sub_token in sub_tokens: 297 | tok_to_orig_index.append(i) 298 | all_doc_tokens.append(sub_token) 299 | 300 | 301 | tok_start_position = None 302 | tok_end_position = None 303 | if is_training and example.is_impossible: 304 | tok_start_position = -1 305 | tok_end_position = -1 306 | 307 | if is_training and example.is_yes: 308 | tok_start_position = -1 309 | tok_end_position = 0 310 | 311 | if is_training and example.is_no: 312 | tok_start_position = 0 313 | tok_end_position = -1 314 | 315 | if is_training and not example.is_impossible and not example.is_yes and not example.is_no: 316 | tok_start_position = orig_to_tok_index[example.start_position] 317 | if example.end_position < len(example.doc_tokens) - 1: 318 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 319 | else: 320 | tok_end_position = len(all_doc_tokens) - 1 321 | (tok_start_position, tok_end_position) = _improve_answer_span( 322 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 323 | example.orig_answer_text) 324 | 325 | # The -3 accounts for [CLS], [SEP] and [SEP] 326 | if HA: 327 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 5 - len(prev_a_tokens) - len(prev_q_tokens) 328 | else: 329 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 330 | # We can have documents that are longer than the maximum sequence length. 331 | # To deal with this we do a sliding window approach, where we take chunks 332 | # of the up to our max length with a stride of `doc_stride`. 
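        # For example: with 1,000 sub-tokens, max_tokens_for_doc = 400 and
        # doc_stride = 128, the loop below yields (start, length) spans
        # (0, 400), (128, 400), (256, 400), (384, 400), (512, 400), (640, 360).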
333 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 334 | "DocSpan", ["start", "length"]) 335 | doc_spans = [] 336 | start_offset = 0 337 | while start_offset < len(all_doc_tokens): 338 | length = len(all_doc_tokens) - start_offset 339 | if length > max_tokens_for_doc: 340 | length = max_tokens_for_doc 341 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 342 | if start_offset + length == len(all_doc_tokens): 343 | break 344 | start_offset += min(length, doc_stride) 345 | 346 | for (doc_span_index, doc_span) in enumerate(doc_spans): 347 | tokens = [] 348 | token_to_orig_map = {} 349 | token_is_max_context = {} 350 | segment_ids = [] 351 | tokens.append("[CLS]") 352 | segment_ids.append(0) 353 | for token in query_tokens: 354 | tokens.append(token) 355 | segment_ids.append(0) 356 | if HA: 357 | tokens.append("[unused0]") 358 | segment_ids.append(0) 359 | for token in prev_q_tokens: 360 | tokens.append(token) 361 | segment_ids.append(0) 362 | tokens.append("[unused1]") 363 | segment_ids.append(0) 364 | for token in prev_a_tokens: 365 | tokens.append(token) 366 | segment_ids.append(0) 367 | tokens.append("[SEP]") 368 | segment_ids.append(0) 369 | 370 | for i in range(doc_span.length): 371 | split_token_index = doc_span.start + i 372 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 373 | 374 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 375 | split_token_index) 376 | token_is_max_context[len(tokens)] = is_max_context 377 | tokens.append(all_doc_tokens[split_token_index]) 378 | segment_ids.append(1) 379 | tokens.append("[SEP]") 380 | segment_ids.append(1) 381 | 382 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 383 | 384 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 385 | # tokens are attended to. 386 | input_mask = [1] * len(input_ids) 387 | 388 | # Zero-pad up to the sequence length. 389 | while len(input_ids) < max_seq_length: 390 | input_ids.append(0) 391 | input_mask.append(0) 392 | segment_ids.append(0) 393 | 394 | assert len(input_ids) == max_seq_length 395 | assert len(input_mask) == max_seq_length 396 | assert len(segment_ids) == max_seq_length 397 | 398 | start_position = None 399 | end_position = None 400 | if is_training and not example.is_impossible: 401 | # For training, if our document chunk does not contain an annotation 402 | # we throw it out, since there is nothing to predict. 
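            # Concretely, "throw it out" happens in two steps: the chunk is
            # first labelled (0, 0), i.e. both positions point at [CLS], and
            # the `skip` filter further below drops such chunks when training.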
403 | doc_start = doc_span.start 404 | doc_end = doc_span.start + doc_span.length - 1 405 | out_of_span = False 406 | if not (tok_start_position >= doc_start and 407 | tok_end_position <= doc_end): 408 | out_of_span = True 409 | if out_of_span: 410 | start_position = 0 411 | end_position = 0 412 | else: 413 | doc_offset = len(query_tokens) + 2 414 | if HA: 415 | doc_offset += 2 + len(prev_q_tokens) + len(prev_a_tokens) 416 | start_position = tok_start_position - doc_start + doc_offset 417 | end_position = tok_end_position - doc_start + doc_offset 418 | 419 | 420 | start_position,end_position = find_gt_span(' '.join(tokenizer.tokenize(example.orig_answer_text)),tokens,start_position,end_position) 421 | 422 | if is_training: 423 | if example.is_impossible: 424 | start_position = -1 425 | end_position = -1 426 | if example.is_yes: 427 | start_position = -1 428 | end_position = 0 429 | if example.is_no: 430 | start_position = 0 431 | end_position = -1 432 | 433 | # print("Tokens") 434 | # print(tokens) 435 | if example_index < 50: 436 | logger.info("*** Example ***") 437 | logger.info("unique_id: %s" % (unique_id)) 438 | logger.info("example_index: %s" % (example_index)) 439 | # logger.info("doc_span_index: %s" % (doc_span_index)) 440 | logger.info("tokens: %s" % " ".join(tokens)) 441 | logger.info("token_to_orig_map: %s" % " ".join([ 442 | "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])) 443 | logger.info("token_is_max_context: %s" % " ".join([ 444 | "%d:%s" % (x, y) for (x, y) in token_is_max_context.items() 445 | ])) 446 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 447 | logger.info( 448 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 449 | logger.info( 450 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 451 | if is_training and example.is_impossible: 452 | logger.info("impossible example") 453 | if is_training and not example.is_impossible: 454 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 455 | logger.info("start_position: %d" % (start_position)) 456 | logger.info("end_position: %d" % (end_position)) 457 | logger.info( 458 | "answer: %s" % (answer_text)) 459 | 460 | skip = False 461 | if is_training: 462 | skip = True 463 | if start_position == -1 and end_position == -1: 464 | skip = False 465 | if start_position > 0 and end_position > 0: 466 | skip = False 467 | if start_position == -1 and end_position == 0: 468 | skip = False 469 | 470 | if start_position == 0 and end_position == -1: 471 | skip = False 472 | 473 | if not skip: 474 | features.append( 475 | InputFeatures( 476 | unique_id=unique_id, 477 | example_index=example_index, 478 | doc_span_index=doc_span_index, 479 | tokens=tokens, 480 | token_to_orig_map=token_to_orig_map, 481 | token_is_max_context=token_is_max_context, 482 | input_ids=input_ids, 483 | input_mask=input_mask, 484 | segment_ids=segment_ids, 485 | start_position=start_position, 486 | end_position=end_position, 487 | is_impossible=example.is_impossible)) 488 | unique_id += 1 489 | 490 | return features 491 | 492 | 493 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 494 | orig_answer_text): 495 | """Returns tokenized answer spans that better match the annotated answer.""" 496 | 497 | # The SQuAD annotations are character based. We first project them to 498 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 499 | # often find a "better match". For example: 500 | # 501 | # Question: What year was John Smith born? 
502 |     #   Context: The leader was John Smith (1895-1943).
503 |     #   Answer: 1895
504 |     #
505 |     # The original whitespace-tokenized answer will be "(1895-1943).". However
506 |     # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
507 |     # the exact answer, 1895.
508 |     #
509 |     # However, this is not always possible. Consider the following:
510 |     #
511 |     #   Question: What country is the top exporter of electronics?
512 |     #   Context: The Japanese electronics industry is the largest in the world.
513 |     #   Answer: Japan
514 |     #
515 |     # In this case, the annotator chose "Japan" as a character sub-span of
516 |     # the word "Japanese". Since our WordPiece tokenizer does not split
517 |     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
518 |     # in SQuAD, but does happen.
519 |     tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
520 | 
521 |     for new_start in range(input_start, input_end + 1):
522 |         for new_end in range(input_end, new_start - 1, -1):
523 |             text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
524 |             if text_span == tok_answer_text:
525 |                 return (new_start, new_end)
526 | 
527 |     return (input_start, input_end)
528 | 
529 | 
530 | def _check_is_max_context(doc_spans, cur_span_index, position):
531 |     """Check if this is the 'max context' doc span for the token."""
532 | 
533 |     # Because of the sliding window approach taken to scoring documents, a single
534 |     # token can appear in multiple documents. E.g.
535 |     #  Doc: the man went to the store and bought a gallon of milk
536 |     #  Span A: the man went to the
537 |     #  Span B: to the store and bought
538 |     #  Span C: and bought a gallon of
539 |     #  ...
540 |     #
541 |     # Now the word 'bought' will have two scores from spans B and C. We only
542 |     # want to consider the score with "maximum context", which we define as
543 |     # the *minimum* of its left and right context (the *sum* of left and
544 |     # right context will always be the same, of course).
545 |     #
546 |     # In the example the maximum context for 'bought' would be span C since
547 |     # it has 1 left context and 3 right context, while span B has 4 left context
548 |     # and 0 right context.
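    # The 0.01 * doc_span.length term below is a tie-break that mildly
    # prefers longer spans when min(left, right) context is equal.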
549 | best_score = None 550 | best_span_index = None 551 | for (span_index, doc_span) in enumerate(doc_spans): 552 | end = doc_span.start + doc_span.length - 1 553 | if position < doc_span.start: 554 | continue 555 | if position > end: 556 | continue 557 | num_left_context = position - doc_span.start 558 | num_right_context = end - position 559 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 560 | if best_score is None or score > best_score: 561 | best_score = score 562 | best_span_index = span_index 563 | 564 | return cur_span_index == best_span_index 565 | 566 | 567 | RawResult = collections.namedtuple("RawResult", 568 | ["unique_id", "score","length"]) 569 | 570 | 571 | def write_predictions(all_examples, all_features, all_results, n_best_size, 572 | max_answer_length, do_lower_case, output_prediction_file, 573 | output_nbest_file, output_null_log_odds_file, verbose_logging, 574 | version_2_with_negative, null_score_diff_threshold): 575 | """Write final predictions to the json file and log-odds of null if needed.""" 576 | logger.info("Writing predictions to: %s" % (output_prediction_file)) 577 | logger.info("Writing nbest to: %s" % (output_nbest_file)) 578 | 579 | example_index_to_features = collections.defaultdict(list) 580 | for feature in all_features: 581 | example_index_to_features[feature.example_index].append(feature) 582 | 583 | unique_id_to_result = {} 584 | for result in all_results: 585 | unique_id_to_result[result.unique_id] = result 586 | 587 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 588 | "PrelimPrediction", 589 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 590 | 591 | all_predictions = collections.OrderedDict() 592 | all_nbest_json = collections.OrderedDict() 593 | scores_diff_json = collections.OrderedDict() 594 | 595 | for (example_index, example) in enumerate(all_examples): 596 | features = example_index_to_features[example_index] 597 | 598 | prelim_predictions = [] 599 | # keep track of the minimum score of null start+end of position 0 600 | score_null = 1000000 # large and positive 601 | min_null_feature_index = 0 # the paragraph slice with min null score 602 | null_start_logit = 0 # the start logit at the slice with min null score 603 | null_end_logit = 0 # the end logit at the slice with min null score 604 | for (feature_index, feature) in enumerate(features): 605 | result = unique_id_to_result[feature.unique_id] 606 | # start_indexes = _get_best_indexes(result.start_logits, n_best_size) 607 | # end_indexes = _get_best_indexes(result.end_logits, n_best_size) 608 | _,idx = result.score.max(dim = 0) 609 | # print(result.score.size(),idx,result.length) 610 | if idx < result.length*result.length: 611 | start_indexes = [int(idx / result.length)] 612 | end_indexes = [int(idx % result.length)] 613 | if idx == result.length*result.length: 614 | start_indexes = [-1] 615 | end_indexes = [0] 616 | if idx == result.length*result.length + 1: 617 | start_indexes = [0] 618 | end_indexes = [-1] 619 | if idx == result.length*result.length + 2: 620 | start_indexes = [-1] 621 | end_indexes = [-1] 622 | # print(example_index,idx,start_indexes,end_indexes,result.score[-3:]) 623 | # print(start_indexes, end_indexes) 624 | # if we could have irrelevant answers, get the min score of irrelevant 625 | if version_2_with_negative: 626 | feature_null_score = result.start_logits[0] + result.end_logits[0] 627 | if feature_null_score < score_null: 628 | score_null = feature_null_score 629 | 
min_null_feature_index = feature_index 630 | null_start_logit = result.start_logits[0] 631 | null_end_logit = result.end_logits[0] 632 | for start_index in start_indexes: 633 | for end_index in end_indexes: 634 | # We could hypothetically create invalid predictions, e.g., predict 635 | # that the start of the span is in the question. We throw out all 636 | # invalid predictions. 637 | if start_index == -1 or end_index == -1: 638 | 639 | prelim_predictions.append( 640 | _PrelimPrediction( 641 | feature_index=feature_index, 642 | start_index=start_index, 643 | end_index=end_index, 644 | start_logit=69, 645 | end_logit=69)) 646 | continue 647 | if start_index >= len(feature.tokens): 648 | continue 649 | if end_index >= len(feature.tokens): 650 | continue 651 | if start_index not in feature.token_to_orig_map: 652 | continue 653 | if end_index not in feature.token_to_orig_map: 654 | continue 655 | if not feature.token_is_max_context.get(start_index, False): 656 | continue 657 | if end_index < start_index: 658 | continue 659 | length = end_index - start_index + 1 660 | if length > max_answer_length: 661 | continue 662 | prelim_predictions.append( 663 | _PrelimPrediction( 664 | feature_index=feature_index, 665 | start_index=start_index, 666 | end_index=end_index, 667 | start_logit=69, 668 | end_logit=69)) 669 | if version_2_with_negative: 670 | prelim_predictions.append( 671 | _PrelimPrediction( 672 | feature_index=min_null_feature_index, 673 | start_index=0, 674 | end_index=0, 675 | start_logit=null_start_logit, 676 | end_logit=null_end_logit)) 677 | prelim_predictions = sorted( 678 | prelim_predictions, 679 | key=lambda x: (x.start_logit + x.end_logit), 680 | reverse=True) 681 | 682 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 683 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 684 | 685 | seen_predictions = {} 686 | nbest = [] 687 | for pred in prelim_predictions: 688 | if len(nbest) >= n_best_size: 689 | break 690 | feature = features[pred.feature_index] 691 | if pred.start_index > 0: # this is a non-null prediction 692 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 693 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 694 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 695 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 696 | tok_text = " ".join(tok_tokens) 697 | 698 | # De-tokenize WordPieces that have been split off. 
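                # e.g. tokens "john ##son said" become "johnson said"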
699 | tok_text = tok_text.replace(" ##", "") 700 | tok_text = tok_text.replace("##", "") 701 | 702 | # Clean whitespace 703 | tok_text = tok_text.strip() 704 | tok_text = " ".join(tok_text.split()) 705 | orig_text = " ".join(orig_tokens) 706 | 707 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) 708 | if final_text in seen_predictions: 709 | continue 710 | 711 | seen_predictions[final_text] = True 712 | else: 713 | final_text = "" 714 | if pred.start_index == -1 and pred.end_index == -1: 715 | final_text = "unknown" 716 | if pred.start_index == -1 and pred.end_index == 0: 717 | final_text = "yes" 718 | if pred.start_index == 0 and pred.end_index == -1: 719 | final_text = "no" 720 | 721 | seen_predictions[final_text] = True 722 | 723 | nbest.append( 724 | _NbestPrediction( 725 | text=final_text, 726 | start_logit=pred.start_logit, 727 | end_logit=pred.end_logit)) 728 | # if we didn't include the empty option in the n-best, include it 729 | if version_2_with_negative: 730 | if "" not in seen_predictions: 731 | nbest.append( 732 | _NbestPrediction( 733 | text="", 734 | start_logit=null_start_logit, 735 | end_logit=null_end_logit)) 736 | 737 | # In very rare edge cases we could only have single null prediction. 738 | # So we just create a nonce prediction in this case to avoid failure. 739 | if len(nbest)==1: 740 | nbest.insert(0, 741 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 742 | 743 | # In very rare edge cases we could have no valid predictions. So we 744 | # just create a nonce prediction in this case to avoid failure. 745 | if not nbest: 746 | nbest.append( 747 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 748 | 749 | assert len(nbest) >= 1 750 | 751 | total_scores = [] 752 | best_non_null_entry = None 753 | for entry in nbest: 754 | total_scores.append(entry.start_logit + entry.end_logit) 755 | if not best_non_null_entry: 756 | if entry.text: 757 | best_non_null_entry = entry 758 | 759 | probs = _compute_softmax(total_scores) 760 | 761 | nbest_json = [] 762 | for (i, entry) in enumerate(nbest): 763 | output = collections.OrderedDict() 764 | output["text"] = entry.text 765 | output["probability"] = probs[i] 766 | output["start_logit"] = entry.start_logit 767 | output["end_logit"] = entry.end_logit 768 | nbest_json.append(output) 769 | 770 | assert len(nbest_json) >= 1 771 | 772 | if not version_2_with_negative: 773 | all_predictions[example.qas_id] = nbest_json[0]["text"] 774 | else: 775 | # predict "" iff the null score - the score of best non-null > threshold 776 | score_diff = score_null - best_non_null_entry.start_logit - ( 777 | best_non_null_entry.end_logit) 778 | scores_diff_json[example.qas_id] = score_diff 779 | if score_diff > null_score_diff_threshold: 780 | all_predictions[example.qas_id] = "" 781 | else: 782 | all_predictions[example.qas_id] = best_non_null_entry.text 783 | all_nbest_json[example.qas_id] = nbest_json 784 | 785 | with open(output_prediction_file, "w") as writer: 786 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 787 | 788 | with open(output_nbest_file, "w") as writer: 789 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 790 | 791 | if version_2_with_negative: 792 | with open(output_null_log_odds_file, "w") as writer: 793 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 794 | 795 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 796 | """Project the tokenized prediction back to the original text.""" 797 | 798 | # 
When we created the data, we kept track of the alignment between original 799 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 800 | # now `orig_text` contains the span of our original text corresponding to the 801 | # span that we predicted. 802 | # 803 | # However, `orig_text` may contain extra characters that we don't want in 804 | # our prediction. 805 | # 806 | # For example, let's say: 807 | # pred_text = steve smith 808 | # orig_text = Steve Smith's 809 | # 810 | # We don't want to return `orig_text` because it contains the extra "'s". 811 | # 812 | # We don't want to return `pred_text` because it's already been normalized 813 | # (the SQuAD eval script also does punctuation stripping/lower casing but 814 | # our tokenizer does additional normalization like stripping accent 815 | # characters). 816 | # 817 | # What we really want to return is "Steve Smith". 818 | # 819 | # Therefore, we have to apply a semi-complicated alignment heuristic between 820 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 821 | # can fail in certain cases in which case we just return `orig_text`. 822 | 823 | def _strip_spaces(text): 824 | ns_chars = [] 825 | ns_to_s_map = collections.OrderedDict() 826 | for (i, c) in enumerate(text): 827 | if c == " ": 828 | continue 829 | ns_to_s_map[len(ns_chars)] = i 830 | ns_chars.append(c) 831 | ns_text = "".join(ns_chars) 832 | return (ns_text, ns_to_s_map) 833 | 834 | # We first tokenize `orig_text`, strip whitespace from the result 835 | # and `pred_text`, and check if they are the same length. If they are 836 | # NOT the same length, the heuristic has failed. If they are the same 837 | # length, we assume the characters are one-to-one aligned. 838 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 839 | 840 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 841 | 842 | start_position = tok_text.find(pred_text) 843 | if start_position == -1: 844 | if verbose_logging: 845 | logger.info( 846 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 847 | return orig_text 848 | end_position = start_position + len(pred_text) - 1 849 | 850 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 851 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 852 | 853 | if len(orig_ns_text) != len(tok_ns_text): 854 | if verbose_logging: 855 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'", 856 | orig_ns_text, tok_ns_text) 857 | return orig_text 858 | 859 | # We then project the characters in `pred_text` back to `orig_text` using 860 | # the character-to-character alignment. 
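    # Invert tok_ns_to_s_map: map each character position in `tok_text`
    # (which has spaces) to its index in the space-stripped string.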
861 | tok_s_to_ns_map = {} 862 | for (i, tok_index) in tok_ns_to_s_map.items(): 863 | tok_s_to_ns_map[tok_index] = i 864 | 865 | orig_start_position = None 866 | if start_position in tok_s_to_ns_map: 867 | ns_start_position = tok_s_to_ns_map[start_position] 868 | if ns_start_position in orig_ns_to_s_map: 869 | orig_start_position = orig_ns_to_s_map[ns_start_position] 870 | 871 | if orig_start_position is None: 872 | if verbose_logging: 873 | logger.info("Couldn't map start position") 874 | return orig_text 875 | 876 | orig_end_position = None 877 | if end_position in tok_s_to_ns_map: 878 | ns_end_position = tok_s_to_ns_map[end_position] 879 | if ns_end_position in orig_ns_to_s_map: 880 | orig_end_position = orig_ns_to_s_map[ns_end_position] 881 | 882 | if orig_end_position is None: 883 | if verbose_logging: 884 | logger.info("Couldn't map end position") 885 | return orig_text 886 | 887 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 888 | return output_text 889 | 890 | 891 | def _get_best_indexes(logits, n_best_size): 892 | """Get the n-best logits from a list.""" 893 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 894 | 895 | best_indexes = [] 896 | for i in range(len(index_and_score)): 897 | if i >= n_best_size: 898 | break 899 | best_indexes.append(index_and_score[i][0]) 900 | return best_indexes 901 | 902 | 903 | def _compute_softmax(scores): 904 | """Compute softmax probability over raw logits.""" 905 | if not scores: 906 | return [] 907 | 908 | max_score = None 909 | for score in scores: 910 | if max_score is None or score > max_score: 911 | max_score = score 912 | 913 | exp_scores = [] 914 | total_sum = 0.0 915 | for score in scores: 916 | x = math.exp(score - max_score) 917 | exp_scores.append(x) 918 | total_sum += x 919 | 920 | probs = [] 921 | for score in exp_scores: 922 | probs.append(score / total_sum) 923 | return probs -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
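# Note on the inference-time score layout (see gen_upper_triangle below and
# write_predictions in infer_utils.py): for a padded sequence of length L,
# the question-answering head emits a flat vector of L*L joint start/end
# span scores followed by three sentinel slots for yes/no/unknown. A
# decoding sketch with illustrative variable names, mirroring the index
# arithmetic in write_predictions:
#
#     idx = score.argmax()                  # score has shape (L*L + 3,)
#     if idx < L * L:
#         start, end = idx // L, idx % L    # ordinary extractive span
#     elif idx == L * L:
#         answer = "yes"                    # sentinel (start, end) = (-1, 0)
#     elif idx == L * L + 1:
#         answer = "no"                     # sentinel (0, -1)
#     else:
#         answer = "unknown"                # sentinel (-1, -1)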
16 | """PyTorch BERT model."""
17 | 
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 | 
20 | import copy
21 | import json
22 | import logging
23 | import math
24 | import os
25 | import shutil
26 | import tarfile
27 | import tempfile
28 | import sys
29 | from io import open
30 | 
31 | import torch
32 | from torch import nn
33 | from torch.nn import CrossEntropyLoss
34 | 
35 | from pytorch_pretrained_bert.file_utils import cached_path
36 | WEIGHTS_NAME = 'pytorch_model.bin'
37 | CONFIG_NAME = 'config.json'
38 | 
39 | logger = logging.getLogger(__name__)
40 | 
41 | PRETRAINED_MODEL_ARCHIVE_MAP = {
42 |     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
43 |     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
44 |     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
45 |     'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
46 |     'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
47 |     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
48 |     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
49 | }
50 | BERT_CONFIG_NAME = 'bert_config.json'
51 | TF_WEIGHTS_NAME = 'model.ckpt'
52 | 
53 | def gen_upper_triangle(score_s, score_e, max_len, use_cuda=False):
54 |     batch_size = score_s.shape[0]
55 |     context_len = score_s.shape[1]
56 |     # batch x context_len x context_len
57 |     expand_score = score_s.unsqueeze(2).expand([batch_size, context_len, context_len]) +\
58 |         score_e.unsqueeze(1).expand([batch_size, context_len, context_len])
59 |     score_mask = torch.ones(context_len)
60 |     if use_cuda:
61 |         score_mask = score_mask.cuda()
62 |     score_mask = torch.ger(score_mask, score_mask).triu().tril(max_len - 1)
63 |     empty_mask = score_mask.eq(0).unsqueeze(0).expand_as(expand_score)
64 |     expand_score.data.masked_fill_(empty_mask.data, -float('inf'))
65 |     return expand_score.contiguous().view(batch_size, -1)
66 | 
67 | def load_tf_weights_in_bert(model, tf_checkpoint_path):
68 |     """ Load tf checkpoints in a pytorch model
69 |     """
70 |     try:
71 |         import re
72 |         import numpy as np
73 |         import tensorflow as tf
74 |     except ImportError:
75 |         print("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
76 |               "https://www.tensorflow.org/install/ for installation instructions.")
77 |         raise
78 |     tf_path = os.path.abspath(tf_checkpoint_path)
79 |     print("Converting TensorFlow checkpoint from {}".format(tf_path))
80 |     # Load weights from TF model
81 |     init_vars = tf.train.list_variables(tf_path)
82 |     names = []
83 |     arrays = []
84 |     for name, shape in init_vars:
85 |         print("Loading TF weight {} with shape {}".format(name, shape))
86 |         array = tf.train.load_variable(tf_path, name)
87 |         names.append(name)
88 |         arrays.append(array)
89 | 
90 |     for name, array in zip(names, arrays):
91 |         name = name.split('/')
92 |         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
93 |         # which are not required for using the pretrained model
94 |         if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
95 |             print("Skipping {}".format("/".join(name)))
96 |             continue
97 |         pointer = model
98 |         for m_name in name:
99 |             if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
100 |                 l = re.split(r'_(\d+)', m_name)
101 |             else:
102 |                 l = [m_name]
103 |             if l[0] == 'kernel' or l[0] == 'gamma':
104 |                 pointer = getattr(pointer, 'weight')
105 |             elif l[0] == 'output_bias' or l[0] == 'beta':
106 |                 pointer = getattr(pointer, 'bias')
107 |             elif l[0] == 'output_weights':
108 |                 pointer = getattr(pointer, 'weight')
109 |             elif l[0] == 'squad':
110 |                 pointer = getattr(pointer, 'classifier')
111 |             else:
112 |                 try:
113 |                     pointer = getattr(pointer, l[0])
114 |                 except AttributeError:
115 |                     print("Skipping {}".format("/".join(name)))
116 |                     continue
117 |             if len(l) >= 2:
118 |                 num = int(l[1])
119 |                 pointer = pointer[num]
120 |         if m_name[-11:] == '_embeddings':
121 |             pointer = getattr(pointer, 'weight')
122 |         elif m_name == 'kernel':
123 |             array = np.transpose(array)
124 |         try:
125 |             assert pointer.shape == array.shape
126 |         except AssertionError as e:
127 |             e.args += (pointer.shape, array.shape)
128 |             raise
129 |         print("Initialize PyTorch weight {}".format(name))
130 |         pointer.data = torch.from_numpy(array)
131 |     return model
132 | 
133 | 
134 | def gelu(x):
135 |     """Implementation of the gelu activation function.
136 |     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
137 |     0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
138 |     Also see https://arxiv.org/abs/1606.08415
139 |     """
140 |     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
141 | 
142 | 
143 | def swish(x):
144 |     return x * torch.sigmoid(x)
145 | 
146 | 
147 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
148 | 
149 | 
150 | class BertConfig(object):
151 |     """Configuration class to store the configuration of a `BertModel`.
152 |     """
153 |     def __init__(self,
154 |                  vocab_size_or_config_json_file,
155 |                  hidden_size=768,
156 |                  num_hidden_layers=12,
157 |                  num_attention_heads=12,
158 |                  intermediate_size=3072,
159 |                  hidden_act="gelu",
160 |                  hidden_dropout_prob=0.1,
161 |                  attention_probs_dropout_prob=0.1,
162 |                  max_position_embeddings=512,
163 |                  type_vocab_size=2,
164 |                  initializer_range=0.02):
165 |         """Constructs BertConfig.
166 | 
167 |         Args:
168 |             vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
169 |             hidden_size: Size of the encoder layers and the pooler layer.
170 |             num_hidden_layers: Number of hidden layers in the Transformer encoder.
171 |             num_attention_heads: Number of attention heads for each attention layer in
172 |                 the Transformer encoder.
150 | class BertConfig(object): 151 | """Configuration class to store the configuration of a `BertModel`. 152 | """ 153 | def __init__(self, 154 | vocab_size_or_config_json_file, 155 | hidden_size=768, 156 | num_hidden_layers=12, 157 | num_attention_heads=12, 158 | intermediate_size=3072, 159 | hidden_act="gelu", 160 | hidden_dropout_prob=0.1, 161 | attention_probs_dropout_prob=0.1, 162 | max_position_embeddings=512, 163 | type_vocab_size=2, 164 | initializer_range=0.02): 165 | """Constructs BertConfig. 166 | 167 | Args: 168 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 169 | hidden_size: Size of the encoder layers and the pooler layer. 170 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 171 | num_attention_heads: Number of attention heads for each attention layer in 172 | the Transformer encoder. 173 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 174 | layer in the Transformer encoder. 175 | hidden_act: The non-linear activation function (function or string) in the 176 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 177 | hidden_dropout_prob: The dropout probability for all fully connected 178 | layers in the embeddings, encoder, and pooler. 179 | attention_probs_dropout_prob: The dropout ratio for the attention 180 | probabilities. 181 | max_position_embeddings: The maximum sequence length that this model might 182 | ever be used with. Typically set this to something large just in case 183 | (e.g., 512 or 1024 or 2048). 184 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 185 | `BertModel`. 186 | initializer_range: The stddev of the truncated_normal_initializer for 187 | initializing all weight matrices. 188 | """ 189 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 190 | and isinstance(vocab_size_or_config_json_file, unicode)): 191 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 192 | json_config = json.loads(reader.read()) 193 | for key, value in json_config.items(): 194 | self.__dict__[key] = value 195 | elif isinstance(vocab_size_or_config_json_file, int): 196 | self.vocab_size = vocab_size_or_config_json_file 197 | self.hidden_size = hidden_size 198 | self.num_hidden_layers = num_hidden_layers 199 | self.num_attention_heads = num_attention_heads 200 | self.hidden_act = hidden_act 201 | self.intermediate_size = intermediate_size 202 | self.hidden_dropout_prob = hidden_dropout_prob 203 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 204 | self.max_position_embeddings = max_position_embeddings 205 | self.type_vocab_size = type_vocab_size 206 | self.initializer_range = initializer_range 207 | else: 208 | raise ValueError("First argument must be either a vocabulary size (int) " 209 | "or the path to a pretrained model config file (str)") 210 | 211 | @classmethod 212 | def from_dict(cls, json_object): 213 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 214 | config = BertConfig(vocab_size_or_config_json_file=-1) 215 | for key, value in json_object.items(): 216 | config.__dict__[key] = value 217 | return config 218 | 219 | @classmethod 220 | def from_json_file(cls, json_file): 221 | """Constructs a `BertConfig` from a json file of parameters.""" 222 | with open(json_file, "r", encoding='utf-8') as reader: 223 | text = reader.read() 224 | return cls.from_dict(json.loads(text)) 225 | 226 | def __repr__(self): 227 | return str(self.to_json_string()) 228 | 229 | def to_dict(self): 230 | """Serializes this instance to a Python dictionary.""" 231 | output = copy.deepcopy(self.__dict__) 232 | return output 233 | 234 | def to_json_string(self): 235 | """Serializes this instance to a JSON string.""" 236 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 237 | 238 | def to_json_file(self, json_file_path): 239 | """ Save this instance to a json file.""" 240 | with open(json_file_path, "w", encoding='utf-8') as writer: 241 | writer.write(self.to_json_string()) 242 | 243 | try: 244 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 245 | except ImportError: 246 | logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 247 | class BertLayerNorm(nn.Module): 248 | def __init__(self, hidden_size, eps=1e-12): 249 | 
"""Construct a layernorm module in the TF style (epsilon inside the square root). 250 | """ 251 | super(BertLayerNorm, self).__init__() 252 | self.weight = nn.Parameter(torch.ones(hidden_size)) 253 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 254 | self.variance_epsilon = eps 255 | 256 | def forward(self, x): 257 | u = x.mean(-1, keepdim=True) 258 | s = (x - u).pow(2).mean(-1, keepdim=True) 259 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 260 | return self.weight * x + self.bias 261 | 262 | class BertEmbeddings(nn.Module): 263 | """Construct the embeddings from word, position and token_type embeddings. 264 | """ 265 | def __init__(self, config): 266 | super(BertEmbeddings, self).__init__() 267 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 268 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 269 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 270 | 271 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 272 | # any TensorFlow checkpoint file 273 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 274 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 275 | 276 | def forward(self, input_ids, token_type_ids=None): 277 | seq_length = input_ids.size(1) 278 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 279 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 280 | if token_type_ids is None: 281 | token_type_ids = torch.zeros_like(input_ids) 282 | 283 | words_embeddings = self.word_embeddings(input_ids) 284 | position_embeddings = self.position_embeddings(position_ids) 285 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 286 | 287 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 288 | embeddings = self.LayerNorm(embeddings) 289 | embeddings = self.dropout(embeddings) 290 | return embeddings 291 | 292 | 293 | class BertSelfAttention(nn.Module): 294 | def __init__(self, config): 295 | super(BertSelfAttention, self).__init__() 296 | if config.hidden_size % config.num_attention_heads != 0: 297 | raise ValueError( 298 | "The hidden size (%d) is not a multiple of the number of attention " 299 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 300 | self.num_attention_heads = config.num_attention_heads 301 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 302 | self.all_head_size = self.num_attention_heads * self.attention_head_size 303 | 304 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 305 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 306 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 307 | 308 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 309 | 310 | def transpose_for_scores(self, x): 311 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 312 | x = x.view(*new_x_shape) 313 | return x.permute(0, 2, 1, 3) 314 | 315 | def forward(self, hidden_states, attention_mask): 316 | mixed_query_layer = self.query(hidden_states) 317 | mixed_key_layer = self.key(hidden_states) 318 | mixed_value_layer = self.value(hidden_states) 319 | 320 | query_layer = self.transpose_for_scores(mixed_query_layer) 321 | key_layer = self.transpose_for_scores(mixed_key_layer) 322 | value_layer = self.transpose_for_scores(mixed_value_layer) 323 | 324 | # Take the dot 
product between "query" and "key" to get the raw attention scores. 325 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 326 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 327 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function) 328 | attention_scores = attention_scores + attention_mask 329 | 330 | # Normalize the attention scores to probabilities. 331 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 332 | 333 | # This is actually dropping out entire tokens to attend to, which might 334 | # seem a bit unusual, but is taken from the original Transformer paper. 335 | attention_probs = self.dropout(attention_probs) 336 | 337 | context_layer = torch.matmul(attention_probs, value_layer) 338 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 339 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 340 | context_layer = context_layer.view(*new_context_layer_shape) 341 | return context_layer 342 | 343 | 344 | class BertSelfOutput(nn.Module): 345 | def __init__(self, config): 346 | super(BertSelfOutput, self).__init__() 347 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 348 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 349 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 350 | 351 | def forward(self, hidden_states, input_tensor): 352 | hidden_states = self.dense(hidden_states) 353 | hidden_states = self.dropout(hidden_states) 354 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 355 | return hidden_states 356 | 357 | 358 | class BertAttention(nn.Module): 359 | def __init__(self, config): 360 | super(BertAttention, self).__init__() 361 | self.self = BertSelfAttention(config) 362 | self.output = BertSelfOutput(config) 363 | 364 | def forward(self, input_tensor, attention_mask): 365 | self_output = self.self(input_tensor, attention_mask) 366 | attention_output = self.output(self_output, input_tensor) 367 | return attention_output 368 | 369 | 370 | class BertIntermediate(nn.Module): 371 | def __init__(self, config): 372 | super(BertIntermediate, self).__init__() 373 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 374 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 375 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 376 | else: 377 | self.intermediate_act_fn = config.hidden_act 378 | 379 | def forward(self, hidden_states): 380 | hidden_states = self.dense(hidden_states) 381 | hidden_states = self.intermediate_act_fn(hidden_states) 382 | return hidden_states 383 | 384 | 385 | class BertOutput(nn.Module): 386 | def __init__(self, config): 387 | super(BertOutput, self).__init__() 388 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 389 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 390 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 391 | 392 | def forward(self, hidden_states, input_tensor): 393 | hidden_states = self.dense(hidden_states) 394 | hidden_states = self.dropout(hidden_states) 395 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 396 | return hidden_states 397 | 398 | 399 | class BertLayer(nn.Module): 400 | def __init__(self, config): 401 | super(BertLayer, self).__init__() 402 | self.attention = BertAttention(config) 403 | self.intermediate = BertIntermediate(config) 404 | self.output = BertOutput(config) 405 | 406 | def forward(self, 
hidden_states, attention_mask): 407 | attention_output = self.attention(hidden_states, attention_mask) 408 | intermediate_output = self.intermediate(attention_output) 409 | layer_output = self.output(intermediate_output, attention_output) 410 | return layer_output 411 | 412 | 413 | class BertEncoder(nn.Module): 414 | def __init__(self, config): 415 | super(BertEncoder, self).__init__() 416 | layer = BertLayer(config) 417 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 418 | 419 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): 420 | all_encoder_layers = [] 421 | for layer_module in self.layer: 422 | hidden_states = layer_module(hidden_states, attention_mask) 423 | if output_all_encoded_layers: 424 | all_encoder_layers.append(hidden_states) 425 | if not output_all_encoded_layers: 426 | all_encoder_layers.append(hidden_states) 427 | return all_encoder_layers 428 | 429 | 430 | class BertPooler(nn.Module): 431 | def __init__(self, config): 432 | super(BertPooler, self).__init__() 433 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 434 | self.activation = nn.Tanh() 435 | 436 | def forward(self, hidden_states): 437 | # We "pool" the model by simply taking the hidden state corresponding 438 | # to the first token. 439 | first_token_tensor = hidden_states[:, 0] 440 | pooled_output = self.dense(first_token_tensor) 441 | pooled_output = self.activation(pooled_output) 442 | return pooled_output 443 | 444 | 445 | class BertPredictionHeadTransform(nn.Module): 446 | def __init__(self, config): 447 | super(BertPredictionHeadTransform, self).__init__() 448 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 449 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 450 | self.transform_act_fn = ACT2FN[config.hidden_act] 451 | else: 452 | self.transform_act_fn = config.hidden_act 453 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 454 | 455 | def forward(self, hidden_states): 456 | hidden_states = self.dense(hidden_states) 457 | hidden_states = self.transform_act_fn(hidden_states) 458 | hidden_states = self.LayerNorm(hidden_states) 459 | return hidden_states 460 | 461 | 462 | class BertLMPredictionHead(nn.Module): 463 | def __init__(self, config, bert_model_embedding_weights): 464 | super(BertLMPredictionHead, self).__init__() 465 | self.transform = BertPredictionHeadTransform(config) 466 | 467 | # The output weights are the same as the input embeddings, but there is 468 | # an output-only bias for each token. 
469 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 470 | bert_model_embedding_weights.size(0), 471 | bias=False) 472 | self.decoder.weight = bert_model_embedding_weights  # tied to the input embedding matrix 473 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 474 | 475 | def forward(self, hidden_states): 476 | hidden_states = self.transform(hidden_states) 477 | hidden_states = self.decoder(hidden_states) + self.bias 478 | return hidden_states 479 | 480 | 481 | class BertOnlyMLMHead(nn.Module): 482 | def __init__(self, config, bert_model_embedding_weights): 483 | super(BertOnlyMLMHead, self).__init__() 484 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 485 | 486 | def forward(self, sequence_output): 487 | prediction_scores = self.predictions(sequence_output) 488 | return prediction_scores 489 | 490 | 491 | class BertOnlyNSPHead(nn.Module): 492 | def __init__(self, config): 493 | super(BertOnlyNSPHead, self).__init__() 494 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 495 | 496 | def forward(self, pooled_output): 497 | seq_relationship_score = self.seq_relationship(pooled_output) 498 | return seq_relationship_score 499 | 500 | 501 | class BertPreTrainingHeads(nn.Module): 502 | def __init__(self, config, bert_model_embedding_weights): 503 | super(BertPreTrainingHeads, self).__init__() 504 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 505 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 506 | 507 | def forward(self, sequence_output, pooled_output): 508 | prediction_scores = self.predictions(sequence_output) 509 | seq_relationship_score = self.seq_relationship(pooled_output) 510 | return prediction_scores, seq_relationship_score 511 | 512 |
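# A small sketch of the weight tying set up in BertLMPredictionHead above (the
# helper name is illustrative, and BertForPreTraining is defined later in this
# module): the decoder and the input embeddings share a single Parameter, so a
# gradient step on either updates both.
def _tied_decoder_demo(config):
    model = BertForPreTraining(config)  # config: a BertConfig instance
    emb = model.bert.embeddings.word_embeddings.weight
    dec = model.cls.predictions.decoder.weight
    assert dec is emb  # same Parameter object, same storage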
513 | class BertPreTrainedModel(nn.Module): 514 | """ An abstract class to handle weights initialization and 515 | a simple interface for downloading and loading pretrained models. 516 | """ 517 | def __init__(self, config, *inputs, **kwargs): 518 | super(BertPreTrainedModel, self).__init__() 519 | if not isinstance(config, BertConfig): 520 | raise ValueError( 521 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 522 | "To create a model from a Google pretrained model use " 523 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 524 | self.__class__.__name__, self.__class__.__name__ 525 | )) 526 | self.config = config 527 | 528 | def init_bert_weights(self, module): 529 | """ Initialize the weights. 530 | """ 531 | if isinstance(module, (nn.Linear, nn.Embedding)): 532 | # Slightly different from the TF version which uses truncated_normal for initialization 533 | # cf https://github.com/pytorch/pytorch/pull/5617 534 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 535 | elif isinstance(module, BertLayerNorm): 536 | module.bias.data.zero_() 537 | module.weight.data.fill_(1.0) 538 | if isinstance(module, nn.Linear) and module.bias is not None: 539 | module.bias.data.zero_() 540 | 541 | @classmethod 542 | def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, 543 | from_tf=False, *inputs, **kwargs): 544 | """ 545 | Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. 546 | Download and cache the pre-trained model file if needed. 547 | 548 | Params: 549 | pretrained_model_name_or_path: either: 550 | - a str with the name of a pre-trained model to load selected in the list of: 551 | . `bert-base-uncased` 552 | . `bert-large-uncased` 553 | . `bert-base-cased` 554 | . `bert-large-cased` 555 | . `bert-base-multilingual-uncased` 556 | . `bert-base-multilingual-cased` 557 | . `bert-base-chinese` 558 | - a path or url to a pretrained model archive containing: 559 | . `bert_config.json` a configuration file for the model 560 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 561 | - a path or url to a pretrained model archive containing: 562 | . `bert_config.json` a configuration file for the model 563 | . `model.ckpt` a TensorFlow checkpoint 564 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 565 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 566 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models 567 | *inputs, **kwargs: additional input for the specific Bert class 568 | (ex: num_labels for BertForSequenceClassification) 569 | """ 570 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 571 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] 572 | else: 573 | archive_file = pretrained_model_name_or_path 574 | # redirect to the cache, if necessary 575 | try: 576 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 577 | except EnvironmentError: 578 | logger.error( 579 | "Model name '{}' was not found in model name list ({}). " 580 | "We assumed '{}' was a path or url but couldn't find any file " 581 | "associated to this path or url.".format( 582 | pretrained_model_name_or_path, 583 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 584 | archive_file)) 585 | return None 586 | if resolved_archive_file == archive_file: 587 | logger.info("loading archive file {}".format(archive_file)) 588 | else: 589 | logger.info("loading archive file {} from cache at {}".format( 590 | archive_file, resolved_archive_file)) 591 | tempdir = None 592 | if os.path.isdir(resolved_archive_file) or from_tf: 593 | serialization_dir = resolved_archive_file 594 | else: 595 | # Extract archive to temp dir 596 | tempdir = tempfile.mkdtemp() 597 | logger.info("extracting archive file {} to temp dir {}".format( 598 | resolved_archive_file, tempdir)) 599 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 600 | archive.extractall(tempdir) 601 | serialization_dir = tempdir 602 | # Load config 603 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 604 | if not os.path.exists(config_file): 605 | # Backward compatibility with old naming format 606 | config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) 607 | config = BertConfig.from_json_file(config_file) 608 | logger.info("Model config {}".format(config)) 609 | # Instantiate model.
610 | model = cls(config, *inputs, **kwargs) 611 | if state_dict is None and not from_tf: 612 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 613 | state_dict = torch.load(weights_path, map_location='cpu') 614 | if tempdir: 615 | # Clean up temp dir 616 | shutil.rmtree(tempdir) 617 | if from_tf: 618 | # Directly load from a TensorFlow checkpoint 619 | weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) 620 | return load_tf_weights_in_bert(model, weights_path) 621 | # Load from a PyTorch state_dict 622 | old_keys = [] 623 | new_keys = [] 624 | for key in state_dict.keys(): 625 | new_key = None 626 | if 'gamma' in key: 627 | new_key = key.replace('gamma', 'weight') 628 | if 'beta' in key: 629 | new_key = key.replace('beta', 'bias') 630 | if new_key: 631 | old_keys.append(key) 632 | new_keys.append(new_key) 633 | for old_key, new_key in zip(old_keys, new_keys): 634 | state_dict[new_key] = state_dict.pop(old_key) 635 | 636 | missing_keys = [] 637 | unexpected_keys = [] 638 | error_msgs = [] 639 | # copy state_dict so _load_from_state_dict can modify it 640 | metadata = getattr(state_dict, '_metadata', None) 641 | state_dict = state_dict.copy() 642 | if metadata is not None: 643 | state_dict._metadata = metadata 644 | 645 | def load(module, prefix=''): 646 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 647 | module._load_from_state_dict( 648 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 649 | for name, child in module._modules.items(): 650 | if child is not None: 651 | load(child, prefix + name + '.') 652 | start_prefix = '' 653 | if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): 654 | start_prefix = 'bert.' 655 | load(model, prefix=start_prefix) 656 | if len(missing_keys) > 0: 657 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 658 | model.__class__.__name__, missing_keys)) 659 | if len(unexpected_keys) > 0: 660 | logger.info("Weights from pretrained model not used in {}: {}".format( 661 | model.__class__.__name__, unexpected_keys)) 662 | if len(error_msgs) > 0: 663 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 664 | model.__class__.__name__, "\n\t".join(error_msgs))) 665 | return model 666 | 667 |
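# A usage sketch of from_pretrained above (the helper name is illustrative;
# downloading a named archive needs network access, and the local-directory
# call assumes a folder such as the 'model' directory this repo's server.py
# loads, containing config.json or bert_config.json plus pytorch_model.bin):
def _from_pretrained_demo():
    base = BertModel.from_pretrained('bert-base-uncased')       # resolved via PRETRAINED_MODEL_ARCHIVE_MAP
    local = BertForQuestionAnswering.from_pretrained('model')   # local serialization dir
    return base, local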
668 | class BertModel(BertPreTrainedModel): 669 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 670 | 671 | Params: 672 | config: a BertConfig class instance with the configuration to build a new model 673 | 674 | Inputs: 675 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 676 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 677 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 678 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 679 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 680 | a `sentence B` token (see BERT paper for more details). 681 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 682 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 683 | input sequence length in the current batch. It's the mask that we typically use for attention when 684 | a batch has varying length sentences. 685 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 686 | 687 | Outputs: Tuple of (encoded_layers, pooled_output) 688 | `encoded_layers`: controlled by the `output_all_encoded_layers` argument: 689 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 690 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 691 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 692 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 693 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 694 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 695 | classifier pretrained on top of the hidden state associated to the first token of the 696 | input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper). 697 | 698 | Example usage: 699 | ```python 700 | # Already been converted into WordPiece token ids 701 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 702 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 703 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 704 | 705 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 706 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 707 | 708 | model = modeling.BertModel(config=config) 709 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 710 | ``` 711 | """ 712 | def __init__(self, config): 713 | super(BertModel, self).__init__(config) 714 | self.embeddings = BertEmbeddings(config) 715 | self.encoder = BertEncoder(config) 716 | self.pooler = BertPooler(config) 717 | self.apply(self.init_bert_weights) 718 | 719 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 720 | if attention_mask is None: 721 | attention_mask = torch.ones_like(input_ids) 722 | if token_type_ids is None: 723 | token_type_ids = torch.zeros_like(input_ids) 724 | 725 | # We create a 3D attention mask from a 2D tensor mask. 726 | # Sizes are [batch_size, 1, 1, to_seq_length] 727 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 728 | # this attention mask is simpler than the triangular masking of causal attention 729 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here. 730 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 731 | 732 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 733 | # masked positions, this operation will create a tensor which is 0.0 for 734 | # positions we want to attend and -10000.0 for masked positions. 735 | # Since we are adding it to the raw scores before the softmax, this is 736 | # effectively the same as removing these entirely.
737 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 738 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 739 | 740 | embedding_output = self.embeddings(input_ids, token_type_ids) 741 | encoded_layers = self.encoder(embedding_output, 742 | extended_attention_mask, 743 | output_all_encoded_layers=output_all_encoded_layers) 744 | sequence_output = encoded_layers[-1] 745 | pooled_output = self.pooler(sequence_output) 746 | if not output_all_encoded_layers: 747 | encoded_layers = encoded_layers[-1] 748 | return encoded_layers, pooled_output 749 | 750 | 751 | class BertForPreTraining(BertPreTrainedModel): 752 | """BERT model with pre-training heads. 753 | This module comprises the BERT model followed by the two pre-training heads: 754 | - the masked language modeling head, and 755 | - the next sentence classification head. 756 | 757 | Params: 758 | config: a BertConfig class instance with the configuration to build a new model. 759 | 760 | Inputs: 761 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 762 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 763 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 764 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 765 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 766 | a `sentence B` token (see BERT paper for more details). 767 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 768 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 769 | input sequence length in the current batch. It's the mask that we typically use for attention when 770 | a batch has varying length sentences. 771 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 772 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 773 | is only computed for the labels set in [0, ..., vocab_size] 774 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 775 | with indices selected in [0, 1]. 776 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 777 | 778 | Outputs: 779 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 780 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 781 | sentence classification loss. 782 | if `masked_lm_labels` or `next_sentence_label` is `None`: 783 | Outputs a tuple comprising 784 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 785 | - the next sentence classification logits of shape [batch_size, 2]. 
786 | 787 | Example usage: 788 | ```python 789 | # Already been converted into WordPiece token ids 790 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 791 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 792 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 793 | 794 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 795 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 796 | 797 | model = BertForPreTraining(config) 798 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 799 | ``` 800 | """ 801 | def __init__(self, config): 802 | super(BertForPreTraining, self).__init__(config) 803 | self.bert = BertModel(config) 804 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 805 | self.apply(self.init_bert_weights) 806 | 807 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): 808 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 809 | output_all_encoded_layers=False) 810 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 811 | 812 | if masked_lm_labels is not None and next_sentence_label is not None: 813 | loss_fct = CrossEntropyLoss(ignore_index=-1) 814 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 815 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 816 | total_loss = masked_lm_loss + next_sentence_loss 817 | return total_loss 818 | else: 819 | return prediction_scores, seq_relationship_score 820 | 821 | 822 | class BertForMaskedLM(BertPreTrainedModel): 823 | """BERT model with the masked language modeling head. 824 | This module comprises the BERT model followed by the masked language modeling head. 825 | 826 | Params: 827 | config: a BertConfig class instance with the configuration to build a new model. 828 | 829 | Inputs: 830 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 831 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 832 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 833 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 834 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 835 | a `sentence B` token (see BERT paper for more details). 836 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 837 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 838 | input sequence length in the current batch. It's the mask that we typically use for attention when 839 | a batch has varying length sentences. 840 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 841 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 842 | is only computed for the labels set in [0, ..., vocab_size] 843 | 844 | Outputs: 845 | if `masked_lm_labels` is not `None`: 846 | Outputs the masked language modeling loss. 847 | if `masked_lm_labels` is `None`: 848 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 
849 | 850 | Example usage: 851 | ```python 852 | # Already been converted into WordPiece token ids 853 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 854 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 855 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 856 | 857 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 858 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 859 | 860 | model = BertForMaskedLM(config) 861 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 862 | ``` 863 | """ 864 | def __init__(self, config): 865 | super(BertForMaskedLM, self).__init__(config) 866 | self.bert = BertModel(config) 867 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 868 | self.apply(self.init_bert_weights) 869 | 870 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 871 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 872 | output_all_encoded_layers=False) 873 | prediction_scores = self.cls(sequence_output) 874 | 875 | if masked_lm_labels is not None: 876 | loss_fct = CrossEntropyLoss(ignore_index=-1) 877 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 878 | return masked_lm_loss 879 | else: 880 | return prediction_scores 881 | 882 | 883 | class BertForNextSentencePrediction(BertPreTrainedModel): 884 | """BERT model with next sentence prediction head. 885 | This module comprises the BERT model followed by the next sentence classification head. 886 | 887 | Params: 888 | config: a BertConfig class instance with the configuration to build a new model. 889 | 890 | Inputs: 891 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 892 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 893 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 894 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 895 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 896 | a `sentence B` token (see BERT paper for more details). 897 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 898 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 899 | input sequence length in the current batch. It's the mask that we typically use for attention when 900 | a batch has varying length sentences. 901 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 902 | with indices selected in [0, 1]. 903 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 904 | 905 | Outputs: 906 | if `next_sentence_label` is not `None`: 907 | Outputs the next sentence classification (CrossEntropy) loss computed against 908 | `next_sentence_label`. 909 | if `next_sentence_label` is `None`: 910 | Outputs the next sentence classification logits of shape [batch_size, 2].
911 | 912 | Example usage: 913 | ```python 914 | # Already been converted into WordPiece token ids 915 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 916 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 917 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 918 | 919 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 920 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 921 | 922 | model = BertForNextSentencePrediction(config) 923 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 924 | ``` 925 | """ 926 | def __init__(self, config): 927 | super(BertForNextSentencePrediction, self).__init__(config) 928 | self.bert = BertModel(config) 929 | self.cls = BertOnlyNSPHead(config) 930 | self.apply(self.init_bert_weights) 931 | 932 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): 933 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 934 | output_all_encoded_layers=False) 935 | seq_relationship_score = self.cls( pooled_output) 936 | 937 | if next_sentence_label is not None: 938 | loss_fct = CrossEntropyLoss(ignore_index=-1) 939 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 940 | return next_sentence_loss 941 | else: 942 | return seq_relationship_score 943 | 944 | 945 | class BertForSequenceClassification(BertPreTrainedModel): 946 | """BERT model for classification. 947 | This module is composed of the BERT model with a linear layer on top of 948 | the pooled output. 949 | 950 | Params: 951 | `config`: a BertConfig class instance with the configuration to build a new model. 952 | `num_labels`: the number of classes for the classifier. Default = 2. 953 | 954 | Inputs: 955 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 956 | with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts 957 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 958 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 959 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 960 | a `sentence B` token (see BERT paper for more details). 961 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 962 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 963 | input sequence length in the current batch. It's the mask that we typically use for attention when 964 | a batch has varying length sentences. 965 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 966 | with indices selected in [0, ..., num_labels]. 967 | 968 | Outputs: 969 | if `labels` is not `None`: 970 | Outputs the CrossEntropy classification loss of the output with the labels. 971 | if `labels` is `None`: 972 | Outputs the classification logits of shape [batch_size, num_labels]. 
973 | 974 | Example usage: 975 | ```python 976 | # Already been converted into WordPiece token ids 977 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 978 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 979 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 980 | 981 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 982 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 983 | 984 | num_labels = 2 985 | 986 | model = BertForSequenceClassification(config, num_labels) 987 | logits = model(input_ids, token_type_ids, input_mask) 988 | ``` 989 | """ 990 | def __init__(self, config, num_labels): 991 | super(BertForSequenceClassification, self).__init__(config) 992 | self.num_labels = num_labels 993 | self.bert = BertModel(config) 994 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 995 | self.classifier = nn.Linear(config.hidden_size, num_labels) 996 | self.apply(self.init_bert_weights) 997 | 998 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 999 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1000 | pooled_output = self.dropout(pooled_output) 1001 | logits = self.classifier(pooled_output) 1002 | 1003 | if labels is not None: 1004 | loss_fct = CrossEntropyLoss() 1005 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1006 | return loss 1007 | else: 1008 | return logits 1009 | 1010 | 1011 | class BertForMultipleChoice(BertPreTrainedModel): 1012 | """BERT model for multiple choice tasks. 1013 | This module is composed of the BERT model with a linear layer on top of 1014 | the pooled output. 1015 | 1016 | Params: 1017 | `config`: a BertConfig class instance with the configuration to build a new model. 1018 | `num_choices`: the number of classes for the classifier. Default = 2. 1019 | 1020 | Inputs: 1021 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1022 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 1023 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1024 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1025 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1026 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1027 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1028 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1029 | input sequence length in the current batch. It's the mask that we typically use for attention when 1030 | a batch has varying length sentences. 1031 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1032 | with indices selected in [0, ..., num_choices]. 1033 | 1034 | Outputs: 1035 | if `labels` is not `None`: 1036 | Outputs the CrossEntropy classification loss of the output with the labels. 1037 | if `labels` is `None`: 1038 | Outputs the classification logits of shape [batch_size, num_choices].
1039 | 1040 | Example usage: 1041 | ```python 1042 | # Already been converted into WordPiece token ids 1043 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 1044 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 1045 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 1046 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1047 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1048 | 1049 | num_choices = 2 1050 | 1051 | model = BertForMultipleChoice(config, num_choices) 1052 | logits = model(input_ids, token_type_ids, input_mask) 1053 | ``` 1054 | """ 1055 | def __init__(self, config, num_choices): 1056 | super(BertForMultipleChoice, self).__init__(config) 1057 | self.num_choices = num_choices 1058 | self.bert = BertModel(config) 1059 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1060 | self.classifier = nn.Linear(config.hidden_size, 1) 1061 | self.apply(self.init_bert_weights) 1062 | 1063 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1064 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 1065 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1066 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1067 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) 1068 | pooled_output = self.dropout(pooled_output) 1069 | logits = self.classifier(pooled_output) 1070 | reshaped_logits = logits.view(-1, self.num_choices) 1071 | 1072 | if labels is not None: 1073 | loss_fct = CrossEntropyLoss() 1074 | loss = loss_fct(reshaped_logits, labels) 1075 | return loss 1076 | else: 1077 | return reshaped_logits 1078 | 1079 | 1080 | class BertForTokenClassification(BertPreTrainedModel): 1081 | """BERT model for token-level classification. 1082 | This module is composed of the BERT model with a linear layer on top of 1083 | the full hidden state of the last layer. 1084 | 1085 | Params: 1086 | `config`: a BertConfig class instance with the configuration to build a new model. 1087 | `num_labels`: the number of classes for the classifier. Default = 2. 1088 | 1089 | Inputs: 1090 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1091 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 1092 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1093 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1094 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1095 | a `sentence B` token (see BERT paper for more details). 1096 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1097 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1098 | input sequence length in the current batch. It's the mask that we typically use for attention when 1099 | a batch has varying length sentences. 1100 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1101 | with indices selected in [0, ..., num_labels]. 1102 | 1103 | Outputs: 1104 | if `labels` is not `None`: 1105 | Outputs the CrossEntropy classification loss of the output with the labels.
1106 | if `labels` is `None`: 1107 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 1108 | 1109 | Example usage: 1110 | ```python 1111 | # Already been converted into WordPiece token ids 1112 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1113 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1114 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1115 | 1116 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1117 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1118 | 1119 | num_labels = 2 1120 | 1121 | model = BertForTokenClassification(config, num_labels) 1122 | logits = model(input_ids, token_type_ids, input_mask) 1123 | ``` 1124 | """ 1125 | def __init__(self, config, num_labels): 1126 | super(BertForTokenClassification, self).__init__(config) 1127 | self.num_labels = num_labels 1128 | self.bert = BertModel(config) 1129 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1130 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1131 | self.apply(self.init_bert_weights) 1132 | 1133 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1134 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1135 | sequence_output = self.dropout(sequence_output) 1136 | logits = self.classifier(sequence_output) 1137 | 1138 | if labels is not None: 1139 | loss_fct = CrossEntropyLoss() 1140 | # Only keep active parts of the loss 1141 | if attention_mask is not None: 1142 | active_loss = attention_mask.view(-1) == 1 1143 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1144 | active_labels = labels.view(-1)[active_loss] 1145 | loss = loss_fct(active_logits, active_labels) 1146 | else: 1147 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1148 | return loss 1149 | else: 1150 | return logits 1151 | 1152 | 1153 | class BertForQuestionAnswering(BertPreTrainedModel): 1154 | """BERT model for Question Answering (span extraction). 1155 | This module is composed of the BERT model with a linear layer on top of 1156 | the sequence output that computes start_logits and end_logits 1157 | 1158 | Params: 1159 | `config`: a BertConfig class instance with the configuration to build a new model. 1160 | 1161 | Inputs: 1162 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1163 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1164 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1165 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1166 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1167 | a `sentence B` token (see BERT paper for more details). 1168 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1169 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1170 | input sequence length in the current batch. It's the mask that we typically use for attention when 1171 | a batch has varying length sentences. 1172 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1173 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1174 | into account for computing the loss. 
1175 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 1176 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken 1177 | into account for computing the loss. 1178 | 1179 | Outputs: 1180 | if `start_positions` and `end_positions` are not `None`: 1181 | Outputs (total_loss, rejected): the CrossEntropy loss over the joint span / yes / no / unknown targets, 1182 | and the number of examples whose span was clamped out of range and ignored. 1183 | if `start_positions` or `end_positions` is `None`: 1184 | Outputs `ex_score` of shape [batch_size, sequence_length*sequence_length + 3]: flattened (start, end) span scores followed by yes, no and unknown scores. 1185 | 1186 | Example usage: 1187 | ```python 1188 | # Already been converted into WordPiece token ids 1189 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1190 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1191 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1192 | 1193 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1194 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1195 | 1196 | model = BertForQuestionAnswering(config) 1197 | ex_score = model(input_ids, token_type_ids, input_mask) 1198 | ``` 1199 | """ 1200 | def __init__(self, config): 1201 | super(BertForQuestionAnswering, self).__init__(config) 1202 | self.bert = BertModel(config) 1203 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 1204 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 1205 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1206 | self.yes_output = nn.Linear(config.hidden_size, 1) 1207 | self.no_output = nn.Linear(config.hidden_size, 1) 1208 | self.unknown_output = nn.Linear(config.hidden_size, 1) 1209 | 1210 | self.apply(self.init_bert_weights) 1211 | 1212 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 1213 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1214 | logits = self.qa_outputs(sequence_output) 1215 | start_logits, end_logits = logits.split(1, dim=-1) 1216 | start_logits = start_logits.squeeze(-1) 1217 | end_logits = end_logits.squeeze(-1) 1218 | 1219 | yes_logits = self.yes_output(sequence_output[:,0,:])  # yes/no/unknown are scored from the [CLS] position 1220 | no_logits = self.no_output(sequence_output[:,0,:]) 1221 | unknown_logits = self.unknown_output(sequence_output[:,0,:]) 1222 | 1223 | score = gen_upper_triangle(start_logits, end_logits, start_logits.size(1), use_cuda=start_logits.is_cuda) 1224 | ex_score = torch.cat([score, yes_logits, no_logits, unknown_logits], dim=1)  # L*L span columns, then yes / no / unknown 1225 | 1226 | rejected = 0 1227 | if start_positions is not None and end_positions is not None: 1228 | # If we are on multi-GPU, splitting adds a dimension 1229 | if len(start_positions.size()) > 1: 1230 | start_positions = start_positions.squeeze(-1) 1231 | if len(end_positions.size()) > 1: 1232 | end_positions = end_positions.squeeze(-1) 1233 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1234 | ignored_index = start_logits.size(1) 1235 | start_positions.clamp_(-1, ignored_index)  # keep the -1 sentinels used for yes/no/unknown 1236 | end_positions.clamp_(-1, ignored_index) 1237 | 1238 | targets = [] 1239 | span_idx = int(logits.size(1)*logits.size(1)) 1240 | # build one joint target per example: a flattened span index, or yes / no / unknown, or the ignore code 1241 | for i in range(start_positions.size(0)): 1242 | if start_positions[i] == -1 and end_positions[i] == -1: 1243 | targets.append(span_idx + 2)  # code for unknown 1244 | 1245 | elif start_positions[i] == -1 and end_positions[i] == 0: 1246 | # code for yes 1247 | targets.append(span_idx) 1248 | 1249 | elif start_positions[i] == 0 and end_positions[i] == -1: 1250 | # code for no 1251 | targets.append(span_idx+1) 1252 | 1253 | elif start_positions[i] != -1 and end_positions[i] != -1: 1254 | if start_positions[i] == ignored_index or end_positions[i] == ignored_index: 1255 | targets.append(ignored_index*ignored_index+3)  # out-of-window span: ignored by the loss 1256 | rejected += 1 1257 | else: 1258 | targets.append(start_positions[i]*logits.size(1) + end_positions[i])  # flattened span index 1259 | 1260 | targets = torch.tensor(targets, dtype=torch.long, device=ex_score.device) 1261 | 1262 | 1263 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index*ignored_index+3) 1264 | 1265 | total_loss = loss_fct(ex_score, targets) 1266 | 1267 | return total_loss, rejected 1268 | else: 1269 | return ex_score 1270 |
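# A hedged inference sketch (the helper name and the `tokens` argument are
# illustrative, not part of the original project): recovering an answer from
# the flattened ex_score returned above, mirroring the index layout used for
# `targets` in forward().
def _decode_ex_score(ex_score, tokens):
    seq_len = int((ex_score.size(1) - 3) ** 0.5)  # ex_score has L*L + 3 columns
    span_idx = seq_len * seq_len
    best = ex_score[0].argmax().item()            # best column for batch item 0
    if best == span_idx:
        return 'yes'
    if best == span_idx + 1:
        return 'no'
    if best == span_idx + 2:
        return 'unknown'
    start, end = best // seq_len, best % seq_len  # unflatten the span index
    return ' '.join(tokens[start:end + 1])        # WordPiece pieces; detokenize as needed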
--------------------------------------------------------------------------------